diff --git a/src/gboml/ast/base.py b/src/gboml/ast/base.py index dad9d10f43915b1fbb481b343fcf782dc2868626..6e5c4f2a075eb0bf94689c4f59a6619576a3e667 100644 --- a/src/gboml/ast/base.py +++ b/src/gboml/ast/base.py @@ -1,3 +1,4 @@ +import typing from dataclasses import dataclass, field from typing import Optional @@ -15,4 +16,7 @@ class GBOMLObject: @dataclass class NamedGBOMLObject(GBOMLObject): - name: str \ No newline at end of file + name: str + + +AnyGBOMLObject = typing.TypeVar('AnyGBOMLObject', bound=GBOMLObject) \ No newline at end of file diff --git a/src/gboml/redundant_definitions.py b/src/gboml/redundant_definitions.py index af765c677838b93cf2cf32f9fda0154bd6526007..1d47b46fb3af973a7e75b965f2934312aa06ca72 100644 --- a/src/gboml/redundant_definitions.py +++ b/src/gboml/redundant_definitions.py @@ -46,7 +46,7 @@ from gboml.ast import * from gboml.tools.tree_modifier import modify -def remove_redundant_definitions(elem: GBOMLObject) -> GBOMLObject: +def remove_redundant_definitions(elem: AnyGBOMLObject) -> AnyGBOMLObject: return modify(elem, {Node: _modify_node, HyperEdge: _modify_hyperedge}) diff --git a/src/gboml/tools/tree_modifier.py b/src/gboml/tools/tree_modifier.py index efee4f2aa0c7788e4ef1c6bc1be0290fafadbfcb..e8e952189e1734aade51b949b9dbb8e3dda0ddc5 100644 --- a/src/gboml/tools/tree_modifier.py +++ b/src/gboml/tools/tree_modifier.py @@ -124,7 +124,8 @@ def _modify_gbomlobject(obj, by): # if we now what to do with the current object, let's do it for x in family_list[obj.__class__]: if x in by: - return by[x](obj) + obj = by[x](obj) + break # and now the hard part interesting_fields = set() @@ -141,7 +142,10 @@ def _modify_gbomlobject(obj, by): return obj -def modify(element: typing.Any, by: dict[typing.Type[GBOMLObject], typing.Callable[[GBOMLObject], GBOMLObject]]): +T = typing.TypeVar('T') + + +def modify(element: T, by: dict[typing.Type[AnyGBOMLObject], typing.Callable[[AnyGBOMLObject], AnyGBOMLObject]]) -> T: """ Recursively modifies a GBOMLGraph tree (or any part of it) according to rules set in the dict `by`. `by` entries should be in the form `(cls: fun)`, where cls is a class derivating from GBOMLObject and