diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f6ced3dd..f3e1139bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop() - Add docstrings to all public methods #1076 - Update DataJoint to 0.14.2 #1081 - Allow restriction based on parent keys in `Merge.fetch_nwb()` #1086 +- Import `datajoint.dependencies.unite_master_parts` -> `topo_sort` #1116 ### Pipelines diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 6269ecaba..26a944e20 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -13,7 +13,6 @@ from datajoint import FreeTable, Table from datajoint.condition import make_condition -from datajoint.dependencies import unite_master_parts from datajoint.hash import key_hash from datajoint.user_tables import TableMeta from datajoint.utils import get_master, to_camel_case @@ -35,6 +34,11 @@ unique_dicts, ) +try: # Datajoint 0.14.2+ uses topo_sort instead of unite_master_parts + from datajoint.dependencies import topo_sort as dj_topo_sort +except ImportError: + from datajoint.dependencies import unite_master_parts as dj_topo_sort + class Direction(Enum): """Cascade direction enum. Calling Up returns True. Inverting flips.""" @@ -474,7 +478,7 @@ def _topo_sort( if not self._is_out(node, warn=False) ] graph = self.graph.subgraph(nodes) if subgraph else self.graph - ordered = unite_master_parts(list(topological_sort(graph))) + ordered = dj_topo_sort(list(topological_sort(graph))) if reverse: ordered.reverse() return [n for n in ordered if n in nodes] diff --git a/tests/conftest.py b/tests/conftest.py index adb3ccc08..954dad204 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -394,6 +394,7 @@ def frequent_imports(): from spyglass.lfp.analysis.v1 import LFPBandSelection from spyglass.mua.v1.mua import MuaEventsV1 from spyglass.ripple.v1.ripple import RippleTimesV1 + from spyglass.spikesorting.analysis.v1.unit_annotation import UnitAnnotation from spyglass.spikesorting.v0.figurl_views import SpikeSortingRecordingView return ( @@ -403,6 +404,7 @@ def frequent_imports(): RippleTimesV1, SortedSpikesIndicatorSelection, SpikeSortingRecordingView, + UnitAnnotation, UnitMarksIndicatorSelection, )