diff --git a/alchemiscale/utils.py b/alchemiscale/utils.py index df7368e8..c50dad84 100644 --- a/alchemiscale/utils.py +++ b/alchemiscale/utils.py @@ -84,24 +84,33 @@ def gufe_to_digraph(gufe_obj): other GufeTokenizables. """ graph = nx.DiGraph() + shallow_dicts = {} def add_edges(o): + # if we've made a shallow dict before, we've already added this one + # and all its dependencies; return `None` to avoid going down the tree + # again + sd = shallow_dicts.get(o.key) + if sd is not None: + return None + + # if not, then we make the shallow dict only once, add it to our index, + # add edges to dependencies, and return it so we continue down the tree + sd = o.to_shallow_dict() + + shallow_dicts[o.key] = sd + # add the object node in case there aren't any connections graph.add_node(o) - connections = gufe_objects_from_shallow_dict(o.to_shallow_dict()) + connections = gufe_objects_from_shallow_dict(sd) for c in connections: graph.add_edge(o, c) - add_edges(gufe_obj) - - def modifier(o): - add_edges(o) - return o.to_shallow_dict() + return sd - _ = modify_dependencies( - gufe_obj.to_shallow_dict(), modifier, is_gufe_obj, mode="encode" - ) + sd = add_edges(gufe_obj) + _ = modify_dependencies(sd, add_edges, is_gufe_obj, mode="encode") return graph