diff --git a/src/motile_toolbox/visualization/napari_utils.py b/src/motile_toolbox/visualization/napari_utils.py index 74859b2..da0c4bd 100644 --- a/src/motile_toolbox/visualization/napari_utils.py +++ b/src/motile_toolbox/visualization/napari_utils.py @@ -1,6 +1,8 @@ import networkx as nx import numpy as np +from motile_toolbox.candidate_graph import NodeAttr + def assign_tracklet_ids(graph: nx.DiGraph) -> nx.DiGraph: """Add a tracklet_id attribute to a graph by removing division edges, @@ -36,7 +38,7 @@ def assign_tracklet_ids(graph: nx.DiGraph) -> nx.DiGraph: def to_napari_tracks_layer( - graph, frame_key="t", location_keys=("y", "x"), properties=() + graph, frame_key=NodeAttr.TIME.value, location_key=NodeAttr.POS.value, properties=() ): """Function to take a networkx graph and return the data needed to add to a napari tracks layer. @@ -44,9 +46,9 @@ def to_napari_tracks_layer( Args: graph (nx.DiGraph): _description_ frame_key (str, optional): Key in graph attributes containing time frame. - Defaults to "t". - location_keys (tuple, optional): Keys in graph node attributes containing - location. Should be in order: (Z), Y, X. Defaults to ("y", "x"). + Defaults to NodeAttr.TIME.value. + location_key (str, optional): Key in graph node attributes containing + location. Defaults to NodeAttr.POS.value. properties (tuple, optional): Keys in graph node attributes to add to the visualization layer. Defaults to (). NOTE: not working now :( @@ -65,13 +67,16 @@ def to_napari_tracks_layer( case of track splitting, or more than one (the track has multiple parents, but only one child) in the case of track merging. """ - napari_data = np.zeros((graph.number_of_nodes(), len(location_keys) + 2)) + for _, loc in graph.nodes(data=location_key): + ndim = len(loc) + break + napari_data = np.zeros((graph.number_of_nodes(), ndim + 2)) napari_properties = {prop: np.zeros(graph.number_of_nodes()) for prop in properties} napari_edges = {} graph, intertrack_edges = assign_tracklet_ids(graph) for index, node in enumerate(graph.nodes(data=True)): node_id, data = node - location = [data[loc_key] for loc_key in location_keys] + location = data[location_key] napari_data[index] = [data["tracklet_id"], data[frame_key], *location] for prop in properties: if prop in data: