Skip to content
Snippets Groups Projects
Commit 34a63f23 authored by Derval Guillaume's avatar Derval Guillaume
Browse files

New semantic for import resolution

parent 312aa405
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,7 @@ class HyperEdge(NamedGBOMLObject):
@dataclass
class HyperEdgeDefinition(HyperEdge):
name: str
import_from: Optional[Extends] = None
import_from: Optional[Extends | HyperEdge] = None
parameters: list[Definition] = field(default_factory=list)
constraints: list[Constraint] = field(default_factory=list)
activations: list[CtrActivation] = field(default_factory=list)
......@@ -29,7 +29,7 @@ class HyperEdgeGenerator(HyperEdge):
name: str
indices: list[str]
loop: Loop
import_from: Optional[Extends] = None
import_from: Optional[Extends | HyperEdge] = None
parameters: list[Definition] = field(default_factory=list)
constraints: list[Constraint] = field(default_factory=list)
activations: list[CtrActivation] = field(default_factory=list)
......
......@@ -19,7 +19,7 @@ class Node(NamedGBOMLObject):
@dataclass
class NodeDefinition(Node):
name: str
import_from: Optional[Extends] = None
import_from: Optional[Extends | Node] = None
parameters: list[Definition] = field(default_factory=list)
nodes: list[Node] = field(default_factory=list)
hyperedges: list[HyperEdge] = field(default_factory=list)
......@@ -35,7 +35,7 @@ class NodeGenerator(Node):
name: str
indices: list[str]
loop: Loop
import_from: Optional[Extends] = None
import_from: Optional[Extends | Node] = None
parameters: list[Definition] = field(default_factory=list)
nodes: list[Node] = field(default_factory=list)
hyperedges: list[HyperEdge] = field(default_factory=list)
......
"""
This step aims at resolving imports and extension of other GBOML models.
At the end of this step, no "extends" or "import" keyword may remain in the resulting graph
At the end of this step, no "Extends" or "import" cls may remain in the resulting graph
"""
import dataclasses
from pathlib import Path
from typing import Optional
from gboml.ast import *
from gboml.parsing import parse_file
from gboml.redundant_definitions import remove_redundant_definitions
from gboml.tools.tree_modifier import modify
# Singleton used in _load_file to detect cyclic imports
WORKING = object
file_cache = {}
inheritable_ast = NodeDefinition | NodeGenerator | HyperEdgeDefinition | HyperEdgeGenerator
def load_file(fpath: Path):
def _load_file(fpath: Path, file_cache: dict[Path, GBOMLGraph]):
""" Loads a file and resolves its imports. file_cache is used as a cache for already-seen files. """
fpath = fpath.absolute()
if fpath not in file_cache:
file_cache[fpath] = resolve_imports(parse_file(fpath), fpath.parent)
file_cache[fpath] = resolve_imports(parse_file(fpath), fpath.parent, file_cache)
elif file_cache[fpath] is WORKING:
raise RuntimeError("Recursive import")
raise RuntimeError("Cyclic import")
return file_cache[fpath]
def _merge(newAst: Node | HyperEdge,
extends: Node | HyperEdge,
additionnal_parameters: list[Definition]) -> Node | HyperEdge:
""" Merges a node/hyperedge and its parent, forming a full node/hyperedge without extension. """
if isinstance(newAst, Node):
merge_fields = {"nodes", "hyperedges", "variables", "constraints", "objectives", "activations"}
else:
merge_fields = {"constraints", "activations"}
def _update_import_from(child: inheritable_ast, parent: inheritable_ast, parent_indices: list[Definition]) -> inheritable_ast:
""" Replaces the "Extend" element of child with its true parent, and add the needed parent_indices to its parameters """
return dataclasses.replace(
child,
import_from=parent,
parameters=parent_indices + parent.parameters
)
def _check_indices(child: inheritable_ast,
parent: inheritable_ast):
""" Checks that no indices are overriden """
child_indices = set()
if isinstance(child, NodeGenerator) or isinstance(child, HyperEdgeGenerator):
child_indices = set(child.indices)
child_parameters = {definition.name for definition in child.parameters}
if not child_indices.isdisjoint(child_parameters):
raise RuntimeError(f"The following indices are redefined: " + str(child_indices.intersection(child_parameters)))
while parent is not None:
if isinstance(parent, NodeGenerator) or isinstance(parent, HyperEdgeGenerator):
parent_indices = set(parent.indices)
if not parent_indices.isdisjoint(child_indices):
raise RuntimeError(f"{child.name} cannot share indices {parent_indices.intersection(child_indices)} with its parent {parent.name}. Change the name of the indice(s).")
if not parent_indices.isdisjoint(child_parameters):
raise RuntimeError(f"{child.name} cannot override indices {parent_indices.intersection(child_parameters)} of its parent {parent.name}.")
parent_parameters = {definition.name for definition in parent.parameters}
if not parent_parameters.isdisjoint(child_indices):
raise RuntimeError(f"{child.name}'s indices {parent_parameters.intersection(child_indices)} override parameters of its parent {parent.name}.")
parent = parent.import_from
return remove_redundant_definitions(dataclasses.replace(
newAst,
import_from=None,
tags=newAst.tags | extends.tags,
parameters=additionnal_parameters + extends.parameters + newAst.parameters,
**{f: getattr(extends, f) + getattr(newAst, f) for f in merge_fields}
))
def _find_elem_with_name(l, name):
......@@ -51,14 +68,29 @@ def _find_elem_with_name(l, name):
raise RuntimeError(f"Multiple nodes/hyperedges have the same name '{name}'")
return valid_nodes[0]
def resolve_imports(tree: GBOMLObject, current_dir: Path, file_cache: Optional[dict[Path, GBOMLGraph]] = None) -> GBOMLObject:
"""
Resolves imports, transforming all `Extends` entries to Nodes/HyperEdges.
def resolve_imports(tree: GBOMLObject, current_dir: Path) -> GBOMLObject:
def update(ast: NodeDefinition | NodeGenerator | HyperEdgeDefinition | HyperEdgeGenerator) \
-> NodeDefinition | NodeGenerator | HyperEdgeDefinition | HyperEdgeGenerator:
Args:
tree:
current_dir:
file_cache: dict to be used as a cache. Should initially be empty, and should be reused between calls to
resolve_imports.
Returns:
A modified tree where Node/HyperEdges with an import_from value that is of type Extends have been
replaced by a Node/HyperEdge
"""
if file_cache is None:
file_cache = {}
def update(ast: inheritable_ast) -> inheritable_ast:
if ast.import_from is None:
return ast
imported_file = load_file(current_dir / ast.import_from.filename)
imported_file = _load_file(current_dir / ast.import_from.filename, file_cache)
# for now, we only resolve "directly-named" nodes in other files.
# in the future we may resolve nodes referenced inside arrays or parameters, but for now we don't.
......@@ -74,22 +106,23 @@ def resolve_imports(tree: GBOMLObject, current_dir: Path) -> GBOMLObject:
raise RuntimeError("Invalid number of indices.")
# last element of the path
cur_ast = _find_elem_with_name(cur_ast.nodes if isinstance(ast, Node) else cur_ast.hyperedges,
ast.import_from.name.path[-1].name)
parent = _find_elem_with_name(cur_ast.nodes if isinstance(ast, Node) else cur_ast.hyperedges,
ast.import_from.name.path[-1].name)
# pay attention to indices
additional_parameters = []
parent_indices = []
if ast.import_from.name.path[-1].indices:
last_indices = ast.import_from.name.path[-1].indices
if not isinstance(cur_ast, NodeGenerator | HyperEdgeGenerator):
if not isinstance(parent, NodeGenerator | HyperEdgeGenerator):
raise RuntimeError("This element is not a Node/Hyperedge generator.")
if len(last_indices) != len(cur_ast.indices):
if len(last_indices) != len(parent.indices):
raise RuntimeError("Invalid number of indices.")
for a, b in zip(cur_ast.indices, last_indices):
additional_parameters.append(ExpressionDefinition(a, ExpressionUseGenScope(b)))
for a, b in zip(parent.indices, last_indices):
parent_indices.append(ExpressionDefinition(a, ExpressionUseGenScope(b)))
_check_indices(ast, parent)
# merge node/hyperedge
return _merge(ast, cur_ast, additional_parameters)
return _update_import_from(ast, parent, parent_indices)
return modify(tree, {Node: update, HyperEdge: update})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment