From 34a63f23825e84da7def7389a0850d9b8c3369f1 Mon Sep 17 00:00:00 2001 From: Derval Guillaume <gderval@uliege.be> Date: Fri, 28 Apr 2023 17:22:20 +0200 Subject: [PATCH] New semantic for import resolution --- src/gboml/ast/hyperedges.py | 4 +- src/gboml/ast/nodes.py | 4 +- src/gboml/resolve_imports.py | 101 +++++++++++++++++++++++------------ 3 files changed, 71 insertions(+), 38 deletions(-) diff --git a/src/gboml/ast/hyperedges.py b/src/gboml/ast/hyperedges.py index bf61a40..efafe02 100644 --- a/src/gboml/ast/hyperedges.py +++ b/src/gboml/ast/hyperedges.py @@ -17,7 +17,7 @@ class HyperEdge(NamedGBOMLObject): @dataclass class HyperEdgeDefinition(HyperEdge): name: str - import_from: Optional[Extends] = None + import_from: Optional[Extends | HyperEdge] = None parameters: list[Definition] = field(default_factory=list) constraints: list[Constraint] = field(default_factory=list) activations: list[CtrActivation] = field(default_factory=list) @@ -29,7 +29,7 @@ class HyperEdgeGenerator(HyperEdge): name: str indices: list[str] loop: Loop - import_from: Optional[Extends] = None + import_from: Optional[Extends | HyperEdge] = None parameters: list[Definition] = field(default_factory=list) constraints: list[Constraint] = field(default_factory=list) activations: list[CtrActivation] = field(default_factory=list) diff --git a/src/gboml/ast/nodes.py b/src/gboml/ast/nodes.py index 8c7167d..94296c6 100644 --- a/src/gboml/ast/nodes.py +++ b/src/gboml/ast/nodes.py @@ -19,7 +19,7 @@ class Node(NamedGBOMLObject): @dataclass class NodeDefinition(Node): name: str - import_from: Optional[Extends] = None + import_from: Optional[Extends | Node] = None parameters: list[Definition] = field(default_factory=list) nodes: list[Node] = field(default_factory=list) hyperedges: list[HyperEdge] = field(default_factory=list) @@ -35,7 +35,7 @@ class NodeGenerator(Node): name: str indices: list[str] loop: Loop - import_from: Optional[Extends] = None + import_from: Optional[Extends | Node] = None parameters: list[Definition] = field(default_factory=list) nodes: list[Node] = field(default_factory=list) hyperedges: list[HyperEdge] = field(default_factory=list) diff --git a/src/gboml/resolve_imports.py b/src/gboml/resolve_imports.py index 4bb7ee9..c0dcb48 100644 --- a/src/gboml/resolve_imports.py +++ b/src/gboml/resolve_imports.py @@ -1,45 +1,62 @@ """ This step aims at resolving imports and extension of other GBOML models. -At the end of this step, no "extends" or "import" keyword may remain in the resulting graph +At the end of this step, no "Extends" or "import" cls may remain in the resulting graph """ import dataclasses from pathlib import Path +from typing import Optional from gboml.ast import * from gboml.parsing import parse_file from gboml.redundant_definitions import remove_redundant_definitions from gboml.tools.tree_modifier import modify +# Singleton used in _load_file to detect cyclic imports WORKING = object -file_cache = {} +inheritable_ast = NodeDefinition | NodeGenerator | HyperEdgeDefinition | HyperEdgeGenerator -def load_file(fpath: Path): +def _load_file(fpath: Path, file_cache: dict[Path, GBOMLGraph]): + """ Loads a file and resolves its imports. file_cache is used as a cache for already-seen files. """ fpath = fpath.absolute() if fpath not in file_cache: - file_cache[fpath] = resolve_imports(parse_file(fpath), fpath.parent) + file_cache[fpath] = resolve_imports(parse_file(fpath), fpath.parent, file_cache) elif file_cache[fpath] is WORKING: - raise RuntimeError("Recursive import") - + raise RuntimeError("Cyclic import") return file_cache[fpath] -def _merge(newAst: Node | HyperEdge, - extends: Node | HyperEdge, - additionnal_parameters: list[Definition]) -> Node | HyperEdge: - """ Merges a node/hyperedge and its parent, forming a full node/hyperedge without extension. """ - if isinstance(newAst, Node): - merge_fields = {"nodes", "hyperedges", "variables", "constraints", "objectives", "activations"} - else: - merge_fields = {"constraints", "activations"} +def _update_import_from(child: inheritable_ast, parent: inheritable_ast, parent_indices: list[Definition]) -> inheritable_ast: + """ Replaces the "Extend" element of child with its true parent, and add the needed parent_indices to its parameters """ + return dataclasses.replace( + child, + import_from=parent, + parameters=parent_indices + parent.parameters + ) + +def _check_indices(child: inheritable_ast, + parent: inheritable_ast): + """ Checks that no indices are overriden """ + child_indices = set() + if isinstance(child, NodeGenerator) or isinstance(child, HyperEdgeGenerator): + child_indices = set(child.indices) + child_parameters = {definition.name for definition in child.parameters} + + if not child_indices.isdisjoint(child_parameters): + raise RuntimeError(f"The following indices are redefined: " + str(child_indices.intersection(child_parameters))) + + while parent is not None: + if isinstance(parent, NodeGenerator) or isinstance(parent, HyperEdgeGenerator): + parent_indices = set(parent.indices) + if not parent_indices.isdisjoint(child_indices): + raise RuntimeError(f"{child.name} cannot share indices {parent_indices.intersection(child_indices)} with its parent {parent.name}. Change the name of the indice(s).") + if not parent_indices.isdisjoint(child_parameters): + raise RuntimeError(f"{child.name} cannot override indices {parent_indices.intersection(child_parameters)} of its parent {parent.name}.") + parent_parameters = {definition.name for definition in parent.parameters} + if not parent_parameters.isdisjoint(child_indices): + raise RuntimeError(f"{child.name}'s indices {parent_parameters.intersection(child_indices)} override parameters of its parent {parent.name}.") + parent = parent.import_from - return remove_redundant_definitions(dataclasses.replace( - newAst, - import_from=None, - tags=newAst.tags | extends.tags, - parameters=additionnal_parameters + extends.parameters + newAst.parameters, - **{f: getattr(extends, f) + getattr(newAst, f) for f in merge_fields} - )) def _find_elem_with_name(l, name): @@ -51,14 +68,29 @@ def _find_elem_with_name(l, name): raise RuntimeError(f"Multiple nodes/hyperedges have the same name '{name}'") return valid_nodes[0] +def resolve_imports(tree: GBOMLObject, current_dir: Path, file_cache: Optional[dict[Path, GBOMLGraph]] = None) -> GBOMLObject: + """ + Resolves imports, transforming all `Extends` entries to Nodes/HyperEdges. -def resolve_imports(tree: GBOMLObject, current_dir: Path) -> GBOMLObject: - def update(ast: NodeDefinition | NodeGenerator | HyperEdgeDefinition | HyperEdgeGenerator) \ - -> NodeDefinition | NodeGenerator | HyperEdgeDefinition | HyperEdgeGenerator: + Args: + tree: + current_dir: + file_cache: dict to be used as a cache. Should initially be empty, and should be reused between calls to + resolve_imports. + + Returns: + A modified tree where Node/HyperEdges with an import_from value that is of type Extends have been + replaced by a Node/HyperEdge + """ + + if file_cache is None: + file_cache = {} + + def update(ast: inheritable_ast) -> inheritable_ast: if ast.import_from is None: return ast - imported_file = load_file(current_dir / ast.import_from.filename) + imported_file = _load_file(current_dir / ast.import_from.filename, file_cache) # for now, we only resolve "directly-named" nodes in other files. # in the future we may resolve nodes referenced inside arrays or parameters, but for now we don't. @@ -74,22 +106,23 @@ def resolve_imports(tree: GBOMLObject, current_dir: Path) -> GBOMLObject: raise RuntimeError("Invalid number of indices.") # last element of the path - cur_ast = _find_elem_with_name(cur_ast.nodes if isinstance(ast, Node) else cur_ast.hyperedges, - ast.import_from.name.path[-1].name) + parent = _find_elem_with_name(cur_ast.nodes if isinstance(ast, Node) else cur_ast.hyperedges, + ast.import_from.name.path[-1].name) # pay attention to indices - additional_parameters = [] + parent_indices = [] if ast.import_from.name.path[-1].indices: last_indices = ast.import_from.name.path[-1].indices - if not isinstance(cur_ast, NodeGenerator | HyperEdgeGenerator): + if not isinstance(parent, NodeGenerator | HyperEdgeGenerator): raise RuntimeError("This element is not a Node/Hyperedge generator.") - if len(last_indices) != len(cur_ast.indices): + if len(last_indices) != len(parent.indices): raise RuntimeError("Invalid number of indices.") - for a, b in zip(cur_ast.indices, last_indices): - additional_parameters.append(ExpressionDefinition(a, ExpressionUseGenScope(b))) + for a, b in zip(parent.indices, last_indices): + parent_indices.append(ExpressionDefinition(a, ExpressionUseGenScope(b))) + + _check_indices(ast, parent) - # merge node/hyperedge - return _merge(ast, cur_ast, additional_parameters) + return _update_import_from(ast, parent, parent_indices) return modify(tree, {Node: update, HyperEdge: update}) -- GitLab