From d529747a3b5a0d0fe466e0dcb70e8dc57901b635 Mon Sep 17 00:00:00 2001
From: Derval Guillaume <gderval@uliege.be>
Date: Fri, 2 Dec 2022 21:41:59 +0100
Subject: [PATCH] Node/hyperedge generators

---
 src/gboml/ast/__init__.py              |  2 +-
 src/gboml/ast/hyperedges.py            |  9 ++++++++
 src/gboml/ast/nodes.py                 | 13 ++++++++++++
 src/gboml/gboml.lark                   |  4 ++--
 src/gboml/parsing.py                   | 29 ++++++++++++++++++++------
 tests/instances/ok/complex_parsing.txt |  4 ++++
 6 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/src/gboml/ast/__init__.py b/src/gboml/ast/__init__.py
index 7a1e59d..550739b 100644
--- a/src/gboml/ast/__init__.py
+++ b/src/gboml/ast/__init__.py
@@ -6,7 +6,7 @@ __all__ = [
     "StdConstraint", "SOSConstraint", "Objective", "VariableDefinition", "Node",
     "HyperEdge", "NodeDefinition", "NodeImport", "HyperEdgeDefinition", "HyperEdgeImport",
     "ExpressionOp", "GBOMLGraph", "ImplicitLoop", "RValue", "RValueWithGen", "GeneratedRValue",
-    "Range", "MultiLoop", "DictEntry", "Dictionary"
+    "Range", "MultiLoop", "DictEntry", "Dictionary", "NodeGenerator", "HyperEdgeGenerator"
 ]
 
 from gboml.ast.arrays import *
diff --git a/src/gboml/ast/hyperedges.py b/src/gboml/ast/hyperedges.py
index c918750..2f89464 100644
--- a/src/gboml/ast/hyperedges.py
+++ b/src/gboml/ast/hyperedges.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 
+from gboml.ast import Loop
 from gboml.ast.base import GBOMLObject
 from gboml.ast.constraints import Constraint
 from gboml.ast.path import VarOrParam
@@ -18,6 +19,14 @@ class HyperEdgeDefinition(HyperEdge):
     constraints: list[Constraint]
 
 
+@dataclass
+class HyperEdgeGenerator(HyperEdge):
+    name: VarOrParam
+    loop: Loop
+    parameters: list[Definition]
+    constraints: list[Constraint]
+
+
 @dataclass
 class HyperEdgeImport(HyperEdge):
     name: str
diff --git a/src/gboml/ast/nodes.py b/src/gboml/ast/nodes.py
index d93ae5f..3b6a7d7 100644
--- a/src/gboml/ast/nodes.py
+++ b/src/gboml/ast/nodes.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 
+from gboml.ast import Loop
 from gboml.ast.base import GBOMLObject
 from gboml.ast.constraints import Constraint
 from gboml.ast.path import VarOrParam
@@ -24,6 +25,18 @@ class NodeDefinition(Node):
     objectives: list[Objective]
 
 
+@dataclass
+class NodeGenerator(Node):
+    name: VarOrParam
+    loop: Loop
+    parameters: list[Definition]
+    nodes: list[Node]
+    hyperedges: list[HyperEdge]
+    variables: list[VariableDefinition]
+    constraints: list[Constraint]
+    objectives: list[Objective]
+
+
 @dataclass
 class NodeImport(Node):
     name: str
diff --git a/src/gboml/gboml.lark b/src/gboml/gboml.lark
index f8ff758..a5864f7 100644
--- a/src/gboml/gboml.lark
+++ b/src/gboml/gboml.lark
@@ -25,7 +25,7 @@ _program: node | hyperedge
 
 // NODES
 ?node: node_definition | node_import
-node_definition: "#NODE" ID \
+node_definition: "#NODE" var_or_param [loop] \
                   parameters_block \
                   program_block \
                   variables_block \
@@ -45,7 +45,7 @@ variable_scope_change: ID SCOPE ";"
 // HYPEREDGES
 ?hyperedge: hyperedge_definition | hyperedge_import
 
-hyperedge_definition: "#HYPEREDGE" ID \
+hyperedge_definition: "#HYPEREDGE" var_or_param [loop] \
                       parameters_block \
                       constraints_block
 
diff --git a/src/gboml/parsing.py b/src/gboml/parsing.py
index 19e99d7..83fd7cf 100644
--- a/src/gboml/parsing.py
+++ b/src/gboml/parsing.py
@@ -61,7 +61,6 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph:
         # obj(*children, meta=meta)
         #
         to_obj = {
-            "hyperedge_definition": HyperEdgeDefinition,
             "hyperedge_import": HyperEdgeImport,
             "definition": Definition,
             "var_or_param_leaf": VarOrParamLeaf,
@@ -119,13 +118,31 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph:
         def program_block(self, meta: Meta, *childrens: list[Node | HyperEdge]) -> NodesAndHyperEdges:
             return self.NodesAndHyperEdges([x for x in childrens if isinstance(x, Node)], [x for x in childrens if isinstance(x, HyperEdge)])
 
-        def node_definition(self, meta: Meta, name: str, param_block: list[Definition], subprogram_block: NodesAndHyperEdges,
+        def hyperedge_definition(self, meta: Meta, name: VarOrParam, loop: Optional[Loop], param_block: list[Definition], constraint_block: list[Constraint]):
+            if loop is None:
+                if len(name.path) != 1 or len(name.path[0].indices) != 0:
+                    raise Exception(f"Invalid name for node: {name}")
+                return HyperEdgeDefinition(name.path[0].name, param_block, constraint_block, meta=meta)
+            else:
+                return HyperEdgeGenerator(name, loop, param_block, constraint_block, meta=meta)
+
+        def node_definition(self, meta: Meta, name: VarOrParam, loop: Optional[Loop], param_block: list[Definition], subprogram_block: NodesAndHyperEdges,
                             variable_block: list[VariableDefinition], constraint_block: list[Constraint],
                             objectives_block: list[Objective]):
-            return NodeDefinition(name, param_block,
-                                  subprogram_block.nodes, subprogram_block.hyperedges,
-                                  variable_block, constraint_block,
-                                  objectives_block, meta=meta)
+            if loop is None:
+                if len(name.path) != 1 or len(name.path[0].indices) != 0:
+                    raise Exception(f"Invalid name for node: {name}")
+                return NodeDefinition(name.path[0].name,
+                                      param_block,
+                                      subprogram_block.nodes, subprogram_block.hyperedges,
+                                      variable_block, constraint_block,
+                                      objectives_block, meta=meta)
+            else:
+                return NodeGenerator(name, loop,
+                                     param_block,
+                                     subprogram_block.nodes, subprogram_block.hyperedges,
+                                     variable_block, constraint_block,
+                                     objectives_block, meta=meta)
 
         def node_import(self, meta: Meta, name: str, imported_name: VarOrParam, imported_from: str, redef: list[ScopeChange | Definition]):
             return NodeImport(name, imported_name, imported_from,
diff --git a/tests/instances/ok/complex_parsing.txt b/tests/instances/ok/complex_parsing.txt
index 4549ea1..fed325b 100644
--- a/tests/instances/ok/complex_parsing.txt
+++ b/tests/instances/ok/complex_parsing.txt
@@ -13,6 +13,10 @@
     a = 2;
     b external;
     c internal;
+#NODE nodeL[i] for i in [0:10]
+    #VARIABLES
+#HYPEREDGE E[i] for i in [0:10]
+    #CONSTRAINTS
 #NODE node1
     #PARAMETERS
         a = 2;
-- 
GitLab