From 7312d68d4e2e91c2e8cb2f4de84d24202a5488ff Mon Sep 17 00:00:00 2001 From: Derval Guillaume <gderval@uliege.be> Date: Wed, 21 Dec 2022 16:22:40 +0100 Subject: [PATCH] Fix a bug in tree modifier + typing --- src/gboml/ast/base.py | 6 +++++- src/gboml/redundant_definitions.py | 2 +- src/gboml/tools/tree_modifier.py | 8 ++++++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/gboml/ast/base.py b/src/gboml/ast/base.py index dad9d10..6e5c4f2 100644 --- a/src/gboml/ast/base.py +++ b/src/gboml/ast/base.py @@ -1,3 +1,4 @@ +import typing from dataclasses import dataclass, field from typing import Optional @@ -15,4 +16,7 @@ class GBOMLObject: @dataclass class NamedGBOMLObject(GBOMLObject): - name: str \ No newline at end of file + name: str + + +AnyGBOMLObject = typing.TypeVar('AnyGBOMLObject', bound=GBOMLObject) \ No newline at end of file diff --git a/src/gboml/redundant_definitions.py b/src/gboml/redundant_definitions.py index af765c6..1d47b46 100644 --- a/src/gboml/redundant_definitions.py +++ b/src/gboml/redundant_definitions.py @@ -46,7 +46,7 @@ from gboml.ast import * from gboml.tools.tree_modifier import modify -def remove_redundant_definitions(elem: GBOMLObject) -> GBOMLObject: +def remove_redundant_definitions(elem: AnyGBOMLObject) -> AnyGBOMLObject: return modify(elem, {Node: _modify_node, HyperEdge: _modify_hyperedge}) diff --git a/src/gboml/tools/tree_modifier.py b/src/gboml/tools/tree_modifier.py index efee4f2..e8e9521 100644 --- a/src/gboml/tools/tree_modifier.py +++ b/src/gboml/tools/tree_modifier.py @@ -124,7 +124,8 @@ def _modify_gbomlobject(obj, by): # if we now what to do with the current object, let's do it for x in family_list[obj.__class__]: if x in by: - return by[x](obj) + obj = by[x](obj) + break # and now the hard part interesting_fields = set() @@ -141,7 +142,10 @@ def _modify_gbomlobject(obj, by): return obj -def modify(element: typing.Any, by: dict[typing.Type[GBOMLObject], typing.Callable[[GBOMLObject], GBOMLObject]]): +T = typing.TypeVar('T') + + +def modify(element: T, by: dict[typing.Type[AnyGBOMLObject], typing.Callable[[AnyGBOMLObject], AnyGBOMLObject]]) -> T: """ Recursively modifies a GBOMLGraph tree (or any part of it) according to rules set in the dict `by`. `by` entries should be in the form `(cls: fun)`, where cls is a class derivating from GBOMLObject and -- GitLab