diff --git a/src/gboml/ast/__init__.py b/src/gboml/ast/__init__.py index ef2a017ba0348fb63b4e3ddce94bcab672f266d6..391efbe5008caf59b7dc24f69331167a83379460 100644 --- a/src/gboml/ast/__init__.py +++ b/src/gboml/ast/__init__.py @@ -7,7 +7,7 @@ __all__ = [ "HyperEdge", "NodeDefinition", "NodeImport", "HyperEdgeDefinition", "HyperEdgeImport", "ExpressionOp", "GBOMLGraph", "ImplicitLoop", "RValue", "RValueWithGen", "GeneratedRValue", "Range", "MultiLoop", "DictEntry", "Dictionary", "NodeGenerator", "HyperEdgeGenerator", - "DefinitionType" + "DefinitionType", "FunctionDefinition", "ConstantDefinition", "ExpressionDefinition" ] from gboml.ast.arrays import * diff --git a/src/gboml/ast/variables.py b/src/gboml/ast/variables.py index 87a7927673d00683f9c0244f9129ab46b979f7c1..57562838df277bdf789949f1e5e624f479ef01cd 100644 --- a/src/gboml/ast/variables.py +++ b/src/gboml/ast/variables.py @@ -22,10 +22,30 @@ class DefinitionType(Enum): constant = "=" expression = "<-" + @dataclass class Definition(GBOMLObject): + pass + + +@dataclass +class ConstantDefinition(Definition): + name: str + value: RValue + tags: list[str] = field(default_factory=list) + + +@dataclass +class ExpressionDefinition(Definition): + name: str + value: RValue + tags: list[str] = field(default_factory=list) + + +@dataclass +class FunctionDefinition(Definition): name: str - type: DefinitionType + args: list[str] value: RValue tags: list[str] = field(default_factory=list) diff --git a/src/gboml/gboml.lark b/src/gboml/gboml.lark index 886c3cbc12cce52b34a8546612be3769acd68897..d37f0e6a95e461ac178ccb12ee13c0b0913ca52a 100644 --- a/src/gboml/gboml.lark +++ b/src/gboml/gboml.lark @@ -92,7 +92,7 @@ bool_expression_comparison: expression (COMPARISON_OPERATOR | CTR_OPERATOR) expr COMPARISON_OPERATOR: "<" | ">" | "!=" // DEFINITIONS -definition: ID DEF_TYPE rvalue tags ";" +definition: ID ["(" separated_list{ID, ","} ")"] DEF_TYPE rvalue tags ";" DEF_TYPE: "=" | "<-" // ARRAYS diff --git a/src/gboml/parsing.py b/src/gboml/parsing.py index b896a8c51a1eadbd56eb1854d7477591e1883398..5709e5d7f48e9016215a8b67e19342f901eb3169 100644 --- a/src/gboml/parsing.py +++ b/src/gboml/parsing.py @@ -64,7 +64,6 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph: # to_obj = { "hyperedge_import": HyperEdgeImport, - "definition": Definition, "var_or_param_leaf": VarOrParamLeaf, "var_or_param": VarOrParam, "constraint_std": StdConstraint, @@ -195,4 +194,14 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph: return Array(entries, meta=meta) raise Exception("An array cannot contain dictionary entries (and conversely)") + def definition(self, meta: Meta, name: str, args: Optional[list[str]], typ: DefinitionType, val: RValue, tags: list[str]): + if args is not None: + if typ != DefinitionType.expression: + raise Exception("Functions can only be defined as expressions (use `<-` instead of `=`)") + return FunctionDefinition(name, args, val, tags, meta=meta) + elif typ == DefinitionType.expression: + return ExpressionDefinition(name, val, tags, meta=meta) + else: + return ConstantDefinition(name, val, tags, meta=meta) + return GBOMLLarkTransformer().transform(tree)