Skip to content
Snippets Groups Projects
Commit 39800b0f authored by Derval Guillaume's avatar Derval Guillaume
Browse files

Rationalize some parts of the parsing and AST

parent df47c3d8
No related branches found
No related tags found
No related merge requests found
...@@ -12,3 +12,7 @@ class Meta: ...@@ -12,3 +12,7 @@ class Meta:
@dataclass @dataclass
class GBOMLObject: class GBOMLObject:
meta: Optional[Meta] = field(default=None, kw_only=True, repr=False) meta: Optional[Meta] = field(default=None, kw_only=True, repr=False)
@dataclass
class NamedGBOMLObject(GBOMLObject):
name: str
\ No newline at end of file
...@@ -26,7 +26,7 @@ class StdConstraint(Constraint): ...@@ -26,7 +26,7 @@ class StdConstraint(Constraint):
op: Operator op: Operator
rhs: Expression rhs: Expression
loop: Optional[Loop] = None loop: Optional[Loop] = None
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
@dataclass @dataclass
...@@ -34,7 +34,7 @@ class SOSConstraint(Constraint): ...@@ -34,7 +34,7 @@ class SOSConstraint(Constraint):
type: SOSType type: SOSType
content: Array content: Array
loop: Optional[Loop] = None loop: Optional[Loop] = None
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
@dataclass @dataclass
......
...@@ -13,6 +13,3 @@ Expression = int | float | ExpressionObj ...@@ -13,6 +13,3 @@ Expression = int | float | ExpressionObj
@dataclass @dataclass
class BoolExpression(GBOMLObject): class BoolExpression(GBOMLObject):
pass pass
...@@ -3,14 +3,14 @@ from typing import Optional ...@@ -3,14 +3,14 @@ from typing import Optional
from gboml.ast.importable import Extends from gboml.ast.importable import Extends
from gboml.ast.loops import Loop from gboml.ast.loops import Loop
from gboml.ast.base import GBOMLObject from gboml.ast.base import NamedGBOMLObject
from gboml.ast.constraints import Constraint, CtrActivation from gboml.ast.constraints import Constraint, CtrActivation
from gboml.ast.path import VarOrParam from gboml.ast.path import VarOrParam
from gboml.ast.variables import Definition from gboml.ast.variables import Definition
@dataclass @dataclass
class HyperEdge(GBOMLObject): class HyperEdge(NamedGBOMLObject):
pass pass
...@@ -21,15 +21,16 @@ class HyperEdgeDefinition(HyperEdge): ...@@ -21,15 +21,16 @@ class HyperEdgeDefinition(HyperEdge):
parameters: list[Definition] = field(default_factory=list) parameters: list[Definition] = field(default_factory=list)
constraints: list[Constraint] = field(default_factory=list) constraints: list[Constraint] = field(default_factory=list)
activations: list[CtrActivation] = field(default_factory=list) activations: list[CtrActivation] = field(default_factory=list)
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
@dataclass @dataclass
class HyperEdgeGenerator(HyperEdge): class HyperEdgeGenerator(HyperEdge):
name: VarOrParam name: str
indices: list["RValue"]
loop: Loop loop: Loop
import_from: Optional[Extends] = None import_from: Optional[Extends] = None
parameters: list[Definition] = field(default_factory=list) parameters: list[Definition] = field(default_factory=list)
constraints: list[Constraint] = field(default_factory=list) constraints: list[Constraint] = field(default_factory=list)
activations: list[CtrActivation] = field(default_factory=list) activations: list[CtrActivation] = field(default_factory=list)
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
...@@ -3,17 +3,16 @@ from typing import Optional ...@@ -3,17 +3,16 @@ from typing import Optional
from gboml.ast.loops import Loop from gboml.ast.loops import Loop
from gboml.ast.activation import Activation from gboml.ast.activation import Activation
from gboml.ast.base import GBOMLObject from gboml.ast.base import NamedGBOMLObject
from gboml.ast.constraints import Constraint from gboml.ast.constraints import Constraint
from gboml.ast.importable import Extends from gboml.ast.importable import Extends
from gboml.ast.path import VarOrParam
from gboml.ast.hyperedges import HyperEdge from gboml.ast.hyperedges import HyperEdge
from gboml.ast.objectives import Objective from gboml.ast.objectives import Objective
from gboml.ast.variables import Definition, VariableDefinition, ScopeChange from gboml.ast.variables import Definition, VariableDefinition, ScopeChange
@dataclass @dataclass
class Node(GBOMLObject): class Node(NamedGBOMLObject):
pass pass
...@@ -28,12 +27,13 @@ class NodeDefinition(Node): ...@@ -28,12 +27,13 @@ class NodeDefinition(Node):
constraints: list[Constraint] = field(default_factory=list) constraints: list[Constraint] = field(default_factory=list)
objectives: list[Objective] = field(default_factory=list) objectives: list[Objective] = field(default_factory=list)
activations: list[Activation] = field(default_factory=list) activations: list[Activation] = field(default_factory=list)
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
@dataclass @dataclass
class NodeGenerator(Node): class NodeGenerator(Node):
name: VarOrParam name: str
indices: list["RValue"]
loop: Loop loop: Loop
import_from: Optional[Extends] = None import_from: Optional[Extends] = None
parameters: list[Definition] = field(default_factory=list) parameters: list[Definition] = field(default_factory=list)
...@@ -43,4 +43,4 @@ class NodeGenerator(Node): ...@@ -43,4 +43,4 @@ class NodeGenerator(Node):
constraints: list[Constraint] = field(default_factory=list) constraints: list[Constraint] = field(default_factory=list)
objectives: list[Objective] = field(default_factory=list) objectives: list[Objective] = field(default_factory=list)
activations: list[Activation] = field(default_factory=list) activations: list[Activation] = field(default_factory=list)
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
...@@ -19,7 +19,7 @@ class Objective(GBOMLObject): ...@@ -19,7 +19,7 @@ class Objective(GBOMLObject):
name: Optional[str] name: Optional[str]
expression: Expression expression: Expression
loop: Optional[Loop] = None loop: Optional[Loop] = None
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
@dataclass @dataclass
class ObjActivation(Activation): class ObjActivation(Activation):
......
...@@ -2,7 +2,7 @@ from dataclasses import dataclass, field ...@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
from gboml.ast.base import GBOMLObject from gboml.ast.base import GBOMLObject, NamedGBOMLObject
from gboml.ast.path import VarOrParam from gboml.ast.path import VarOrParam
from gboml.ast.rvalue import RValue from gboml.ast.rvalue import RValue
...@@ -24,42 +24,39 @@ class DefinitionType(Enum): ...@@ -24,42 +24,39 @@ class DefinitionType(Enum):
@dataclass @dataclass
class Definition(GBOMLObject): class Definition(NamedGBOMLObject):
pass name: str
@dataclass @dataclass
class ConstantDefinition(Definition): class ConstantDefinition(Definition):
name: str
value: RValue value: RValue
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
@dataclass @dataclass
class ExpressionDefinition(Definition): class ExpressionDefinition(Definition):
name: str
value: RValue value: RValue
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
@dataclass @dataclass
class FunctionDefinition(Definition): class FunctionDefinition(Definition):
name: str
args: list[str] args: list[str]
value: RValue value: RValue
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
@dataclass @dataclass
class VariableDefinition(GBOMLObject): class VariableDefinition(NamedGBOMLObject):
name: str
indices: list[str]
scope: VarScope scope: VarScope
type: VarType type: VarType
name: VarOrParam
import_from: Optional[VarOrParam] = None import_from: Optional[VarOrParam] = None
tags: list[str] = field(default_factory=list) tags: set[str] = field(default_factory=set)
@dataclass @dataclass
class ScopeChange(GBOMLObject): class ScopeChange(GBOMLObject):
var: str name: str
scope: VarScope scope: VarScope
\ No newline at end of file
...@@ -37,7 +37,7 @@ extends: "extends" var_or_param ["from" STRING] ...@@ -37,7 +37,7 @@ extends: "extends" var_or_param ["from" STRING]
// NODES // NODES
?node: node_definition | node_import ?node: node_definition | node_import
node_definition: _block_shortcut{_node_header, _node_content} node_definition: _block_shortcut{_node_header, _node_content}
_node_header: "#NODE" var_or_param [extends] [loop] tags _node_header: "#NODE" ID olist{index} [extends] [loop] tags
_node_content: parameters_block program_block variables_block constraints_block objectives_block _node_content: parameters_block program_block variables_block constraints_block objectives_block
parameters_block: (_block_repeat_or_pass{_opt_param_header,definition})? parameters_block: (_block_repeat_or_pass{_opt_param_header,definition})?
...@@ -57,14 +57,15 @@ variable_scope_change: ID SCOPE ";" ...@@ -57,14 +57,15 @@ variable_scope_change: ID SCOPE ";"
?hyperedge: hyperedge_definition | hyperedge_import ?hyperedge: hyperedge_definition | hyperedge_import
hyperedge_definition: _block_shortcut{_hyperedge_header, _hyperedge_content} hyperedge_definition: _block_shortcut{_hyperedge_header, _hyperedge_content}
_hyperedge_header: "#HYPEREDGE" var_or_param [extends] [loop] tags _hyperedge_header: "#HYPEREDGE" ID olist{index} [extends] [loop] tags
_hyperedge_content: parameters_block constraints_block _hyperedge_content: parameters_block constraints_block
hyperedge_import: "#HYPEREDGE" ID "=" "import" var_or_param "from" STRING hyperedge_redefs hyperedge_import: "#HYPEREDGE" ID "=" "import" var_or_param "from" STRING hyperedge_redefs
hyperedge_redefs: "with" definition* | ";" hyperedge_redefs: "with" definition* | ";"
// VARIABLES // VARIABLES
variable_definition: SCOPE [VTYPE] ":" separated_list{var_or_param,","} [_LARROW separated_list{var_or_param, ","}] tags ";" variable_definition: SCOPE [VTYPE] ":" separated_list{variable_name,","} [_LARROW separated_list{var_or_param, ","}] tags ";"
variable_name: ID olist{index}
SCOPE: "internal" | "external" SCOPE: "internal" | "external"
VTYPE: "binary" | "continuous" | "integer" VTYPE: "binary" | "continuous" | "integer"
_LARROW.1: "<-" _LARROW.1: "<-"
......
...@@ -54,7 +54,10 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph: ...@@ -54,7 +54,10 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph:
as_list = { as_list = {
"objectives_block", "constraints_block", "objectives_block", "constraints_block",
"parameters_block", "global_block", "olist", "mlist", "node_redefs", "parameters_block", "global_block", "olist", "mlist", "node_redefs",
"hyperedge_redefs", "separated_list", "separated_maybe_empty_list", "hyperedge_redefs", "separated_list", "separated_maybe_empty_list"
}
as_sets = {
"tags" "tags"
} }
...@@ -93,12 +96,15 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph: ...@@ -93,12 +96,15 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph:
"ctr_deactivate": lambda *x, meta: CtrActivation(ActivationType.deactivate, *x, meta=meta), "ctr_deactivate": lambda *x, meta: CtrActivation(ActivationType.deactivate, *x, meta=meta),
"obj_activate": lambda *x, meta: ObjActivation(ActivationType.activate, *x, meta=meta), "obj_activate": lambda *x, meta: ObjActivation(ActivationType.activate, *x, meta=meta),
"obj_deactivate": lambda *x, meta: ObjActivation(ActivationType.deactivate, *x, meta=meta), "obj_deactivate": lambda *x, meta: ObjActivation(ActivationType.deactivate, *x, meta=meta),
"extends": Extends "extends": Extends,
"variable_name": lambda *x, meta: x
} }
def __default__(self, data, children, _): def __default__(self, data, children, _):
if data in self.as_list: if data in self.as_list:
return list(children) return list(children)
if data in self.as_sets:
return set(children)
if data in self.to_obj: if data in self.to_obj:
return self.to_obj[data](*children, meta=gen_meta(data)) return self.to_obj[data](*children, meta=gen_meta(data))
raise RuntimeError(f"Unknown rule {data}") raise RuntimeError(f"Unknown rule {data}")
...@@ -125,8 +131,8 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph: ...@@ -125,8 +131,8 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph:
def program_block(self, meta: Meta, *childrens: list[Node | HyperEdge]) -> NodesAndHyperEdges: def program_block(self, meta: Meta, *childrens: list[Node | HyperEdge]) -> NodesAndHyperEdges:
return self.NodesAndHyperEdges([x for x in childrens if isinstance(x, Node)], [x for x in childrens if isinstance(x, HyperEdge)]) return self.NodesAndHyperEdges([x for x in childrens if isinstance(x, Node)], [x for x in childrens if isinstance(x, HyperEdge)])
def hyperedge_definition(self, meta: Meta, name: VarOrParam, extends: Optional[Extends], def hyperedge_definition(self, meta: Meta, name: str, indices: list[RValue], extends: Optional[Extends],
loop: Optional[Loop], tags: list[str], param_block: list[Definition] = None, loop: Optional[Loop], tags: set[str], param_block: list[Definition] = None,
constraint_block: list[Constraint | CtrActivation] = None): constraint_block: list[Constraint | CtrActivation] = None):
constraint_block = constraint_block or [] constraint_block = constraint_block or []
activations = [x for x in constraint_block if isinstance(x, CtrActivation)] activations = [x for x in constraint_block if isinstance(x, CtrActivation)]
...@@ -134,16 +140,16 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph: ...@@ -134,16 +140,16 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph:
param_block = param_block or [] param_block = param_block or []
if loop is None: if loop is None:
if len(name.path) != 1 or len(name.path[0].indices) != 0: return HyperEdgeDefinition(name, extends, param_block, constraint_block,
raise Exception(f"Invalid name for node: {name}")
return HyperEdgeDefinition(name.path[0].name, extends, param_block, constraint_block,
activations, tags, meta=meta) activations, tags, meta=meta)
else: else:
return HyperEdgeGenerator(name, loop, extends, param_block, constraint_block, if len(indices) == 0:
raise Exception(f"Invalid name for node: {name}")
return HyperEdgeGenerator(name, indices, loop, extends, param_block, constraint_block,
activations, tags, meta=meta) activations, tags, meta=meta)
def node_definition(self, meta: Meta, name: VarOrParam, extends: Optional[Extends], def node_definition(self, meta: Meta, name: str, indices: list[RValue], extends: Optional[Extends],
loop: Optional[Loop], tags: list[str], loop: Optional[Loop], tags: set[str],
param_block: list[Definition] = None, subprogram_block: NodesAndHyperEdges = None, param_block: list[Definition] = None, subprogram_block: NodesAndHyperEdges = None,
variable_block: list[VariableDefinition] = None, variable_block: list[VariableDefinition] = None,
constraint_block: list[Constraint | CtrActivation] = None, constraint_block: list[Constraint | CtrActivation] = None,
...@@ -159,14 +165,14 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph: ...@@ -159,14 +165,14 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph:
objectives_block = [x for x in objectives_block if isinstance(x, Objective)] objectives_block = [x for x in objectives_block if isinstance(x, Objective)]
if loop is None: if loop is None:
if len(name.path) != 1 or len(name.path[0].indices) != 0: return NodeDefinition(name, extends, param_block,
raise Exception(f"Invalid name for node: {name}")
return NodeDefinition(name.path[0].name, extends, param_block,
subprogram_block.nodes, subprogram_block.hyperedges, subprogram_block.nodes, subprogram_block.hyperedges,
variable_block, constraint_block, variable_block, constraint_block,
objectives_block, activations, tags, meta=meta) objectives_block, activations, tags, meta=meta)
else: else:
return NodeGenerator(name, loop, extends, param_block, if len(indices) == 0:
raise Exception(f"Invalid name for node: {name}")
return NodeGenerator(name, indices, loop, extends, param_block,
subprogram_block.nodes, subprogram_block.hyperedges, subprogram_block.nodes, subprogram_block.hyperedges,
variable_block, constraint_block, variable_block, constraint_block,
objectives_block, activations, tags, meta=meta) objectives_block, activations, tags, meta=meta)
...@@ -184,13 +190,13 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph: ...@@ -184,13 +190,13 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph:
def start(self, meta: Meta, time_horizon: Optional[int], global_defs: list[Definition], nodes_hyperedges: NodesAndHyperEdges): def start(self, meta: Meta, time_horizon: Optional[int], global_defs: list[Definition], nodes_hyperedges: NodesAndHyperEdges):
return GBOMLGraph(time_horizon, global_defs, nodes_hyperedges.nodes, nodes_hyperedges.hyperedges, meta=meta) return GBOMLGraph(time_horizon, global_defs, nodes_hyperedges.nodes, nodes_hyperedges.hyperedges, meta=meta)
def variable_definition(self, meta: Meta, scope: VarScope, type: Optional[VarType], names: list[VarOrParam], def variable_definition(self, meta: Meta, scope: VarScope, type: Optional[VarType], names: list[(str, list[str])],
imports_from: Optional[list[VarOrParam]], tags: list[str]): imports_from: Optional[list[VarOrParam]], tags: set[str]):
if imports_from is not None and len(imports_from) != len(names): if imports_from is not None and len(imports_from) != len(names):
raise Exception("Invalid variable import, numbers of variables on the left and on the right-side of " raise Exception("Invalid variable import, numbers of variables on the left and on the right-side of "
"`<-` don't match") "`<-` don't match")
for name, import_from in zip(names, imports_from or repeat(None, len(names))): for name, import_from in zip(names, imports_from or repeat(None, len(names))):
yield VariableDefinition(scope, type or VarType.continuous, name, import_from, tags, meta=meta) yield VariableDefinition(name[0], name[1], scope, type or VarType.continuous, import_from, tags, meta=meta)
def variables_block(self, _: Meta, *defs: Tuple[Iterable[VariableDefinition]]): def variables_block(self, _: Meta, *defs: Tuple[Iterable[VariableDefinition]]):
return [vd for iterable in defs for vd in iterable] return [vd for iterable in defs for vd in iterable]
...@@ -205,7 +211,7 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph: ...@@ -205,7 +211,7 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph:
return Array(entries, meta=meta) return Array(entries, meta=meta)
raise Exception("An array cannot contain dictionary entries (and conversely)") raise Exception("An array cannot contain dictionary entries (and conversely)")
def definition(self, meta: Meta, name: str, args: Optional[list[str]], typ: DefinitionType, val: RValue, tags: list[str]): def definition(self, meta: Meta, name: str, args: Optional[list[str]], typ: DefinitionType, val: RValue, tags: set[str]):
if args is not None: if args is not None:
if typ != DefinitionType.expression: if typ != DefinitionType.expression:
raise Exception("Functions can only be defined as expressions (use `<-` instead of `=`)") raise Exception("Functions can only be defined as expressions (use `<-` instead of `=`)")
...@@ -215,4 +221,5 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph: ...@@ -215,4 +221,5 @@ def _lark_to_gboml(tree: Tree, filename: Optional[str] = None) -> GBOMLGraph:
else: else:
return ConstantDefinition(name, val, tags, meta=meta) return ConstantDefinition(name, val, tags, meta=meta)
return GBOMLLarkTransformer().transform(tree) return GBOMLLarkTransformer().transform(tree)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment