From 7312d68d4e2e91c2e8cb2f4de84d24202a5488ff Mon Sep 17 00:00:00 2001
From: Derval Guillaume <gderval@uliege.be>
Date: Wed, 21 Dec 2022 16:22:40 +0100
Subject: [PATCH] Fix a bug in tree modifier + typing

---
 src/gboml/ast/base.py              | 6 +++++-
 src/gboml/redundant_definitions.py | 2 +-
 src/gboml/tools/tree_modifier.py   | 8 ++++++--
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/gboml/ast/base.py b/src/gboml/ast/base.py
index dad9d10..6e5c4f2 100644
--- a/src/gboml/ast/base.py
+++ b/src/gboml/ast/base.py
@@ -1,3 +1,4 @@
+import typing
 from dataclasses import dataclass, field
 from typing import Optional
 
@@ -15,4 +16,7 @@ class GBOMLObject:
 
 @dataclass
 class NamedGBOMLObject(GBOMLObject):
-    name: str
\ No newline at end of file
+    name: str
+
+
+AnyGBOMLObject = typing.TypeVar('AnyGBOMLObject', bound=GBOMLObject)
\ No newline at end of file
diff --git a/src/gboml/redundant_definitions.py b/src/gboml/redundant_definitions.py
index af765c6..1d47b46 100644
--- a/src/gboml/redundant_definitions.py
+++ b/src/gboml/redundant_definitions.py
@@ -46,7 +46,7 @@ from gboml.ast import *
 from gboml.tools.tree_modifier import modify
 
 
-def remove_redundant_definitions(elem: GBOMLObject) -> GBOMLObject:
+def remove_redundant_definitions(elem: AnyGBOMLObject) -> AnyGBOMLObject:
     return modify(elem, {Node: _modify_node, HyperEdge: _modify_hyperedge})
 
 
diff --git a/src/gboml/tools/tree_modifier.py b/src/gboml/tools/tree_modifier.py
index efee4f2..e8e9521 100644
--- a/src/gboml/tools/tree_modifier.py
+++ b/src/gboml/tools/tree_modifier.py
@@ -124,7 +124,8 @@ def _modify_gbomlobject(obj, by):
     # if we now what to do with the current object, let's do it
     for x in family_list[obj.__class__]:
         if x in by:
-            return by[x](obj)
+            obj = by[x](obj)
+            break
 
     # and now the hard part
     interesting_fields = set()
@@ -141,7 +142,10 @@ def _modify_gbomlobject(obj, by):
     return obj
 
 
-def modify(element: typing.Any, by: dict[typing.Type[GBOMLObject], typing.Callable[[GBOMLObject], GBOMLObject]]):
+T = typing.TypeVar('T')
+
+
+def modify(element: T, by: dict[typing.Type[AnyGBOMLObject], typing.Callable[[AnyGBOMLObject], AnyGBOMLObject]]) -> T:
     """
         Recursively modifies a GBOMLGraph tree (or any part of it) according to rules set in the dict `by`.
         `by` entries should be in the form `(cls: fun)`, where cls is a class derivating from GBOMLObject and
-- 
GitLab