Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored link processing code and added add_links function #149

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ultrack/core/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def add_new_node(
node = Node.from_mask(
time=time,
mask=mask,
bbox=bbox,
bbox=np.asarray(bbox),
)
if node.area == 0:
raise ValueError("Node area is zero. Something went wrong.")
Expand Down
1 change: 1 addition & 0 deletions ultrack/core/linking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ultrack.core.linking.processing import add_links
178 changes: 143 additions & 35 deletions ultrack/core/linking/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,61 @@ def _compute_features(
]


def color_filtering_mask(
time: int,
current_nodes: List[Node],
next_nodes: List[Node],
images: Sequence[ArrayLike],
neighbors: ArrayLike,
z_score_threshold: float,
) -> ArrayLike:
"""
Filtering by color z-score.

Parameters
----------
time : int
Current time.
current_nodes : List[Node]
List of source nodes.
next_nodes : List[Node]
List of target nodes.
images : Sequence[ArrayLike]
Sequence of images to extract color features for filtering.
neighbors : ArrayLike
Neighbors indices (current/source) for each target (next) node.
z_score_threshold : float
Z-score threshold for color filtering.

Returns
-------
ArrayLike
Boolean mask of neighboring nodes within color z-score threshold.

"""
LOG.info(f"computing filtering by color z-score from t={time}")
(current_features,) = _compute_features(
time, current_nodes, images, [Node.intensity_mean]
)
# inserting dummy value for missing neighbors
current_features = np.append(
current_features,
np.zeros((1, current_features.shape[1])),
axis=0,
)
next_features, next_features_std = _compute_features(
time + 1, next_nodes, images, [Node.intensity_mean, Node.intensity_std]
)
LOG.info(
f"Features Std. Dev. range {next_features_std.min()} {next_features_std.max()}"
)
next_features_std[next_features_std <= 1e-6] = 1.0
difference = next_features[:, None, ...] - current_features[neighbors]
difference /= next_features_std[:, None, ...]
filtered_by_color = np.abs(difference).max(axis=-1) <= z_score_threshold
return filtered_by_color


@curry
def _process(
time: int,
Expand Down Expand Up @@ -91,71 +146,86 @@ def _process(
next_nodes = [row[0] for row in query]
next_shift = np.asarray([row[1:] for row in query])

current_pos = np.asarray([n.centroid for n in current_nodes])
next_pos = np.asarray([n.centroid for n in next_nodes], dtype=np.float32)
compute_spatial_neighbors(
time,
config,
current_nodes,
next_nodes,
next_shift,
scale=scale,
table_name=LinkDB.__tablename__,
db_path=db_path,
images=images,
write_lock=write_lock,
)

n_dim = next_pos.shape[1]
next_shift = next_shift[:, -n_dim:] # matching positions dimensions
next_pos += next_shift

def compute_spatial_neighbors(
time: int,
config: LinkingConfig,
source_nodes: List[Node],
target_nodes: List[Node],
target_shift: ArrayLike,
scale: Optional[Sequence[float]],
table_name: str,
db_path: str,
images: Sequence[ArrayLike],
write_lock: Optional[fasteners.InterProcessLock] = None,
) -> pd.DataFrame:

source_pos = np.asarray([n.centroid for n in source_nodes])
target_pos = np.asarray([n.centroid for n in target_nodes], dtype=np.float32)

n_dim = target_pos.shape[1]
target_shift = target_shift[:, -n_dim:] # matching positions dimensions
target_pos += target_shift

if scale is not None:
min_n_dim = min(n_dim, len(scale))
scale = scale[-min_n_dim:]
current_pos = current_pos[..., -min_n_dim:] * scale
next_pos = next_pos[..., -min_n_dim:] * scale
source_pos = source_pos[..., -min_n_dim:] * scale
target_pos = target_pos[..., -min_n_dim:] * scale

# finds neighbors nodes within the radius
# and connect the pairs with highest edge weight
current_kdtree = KDTree(current_pos)
current_kdtree = KDTree(source_pos)

distances, neighbors = current_kdtree.query(
next_pos,
target_pos,
# twice as expected because we select the nearest with highest edge weight
k=2 * config.max_neighbors,
distance_upper_bound=config.max_distance,
)

if len(images) > 0:
LOG.info(f"computing filtering by color z-score from t={time}")
(current_features,) = _compute_features(
time, current_nodes, images, [Node.intensity_mean]
)
# inserting dummy value for missing neighbors
current_features = np.append(
current_features,
np.zeros((1, current_features.shape[1])),
axis=0,
)
next_features, next_features_std = _compute_features(
time + 1, next_nodes, images, [Node.intensity_mean, Node.intensity_std]
filtered_by_color = color_filtering_mask(
time,
source_nodes,
target_nodes,
images,
neighbors,
config.z_score_threshold,
)
LOG.info(
f"Features Std. Dev. range {next_features_std.min()} {next_features_std.max()}"
)
next_features_std[next_features_std <= 1e-6] = 1.0
difference = next_features[:, None, ...] - current_features[neighbors]
difference /= next_features_std[:, None, ...]
filtered_by_color = np.abs(difference).max(axis=-1) <= config.z_score_threshold
else:
filtered_by_color = np.ones_like(neighbors, dtype=bool)

int_next_shift = np.round(next_shift).astype(int)
int_next_shift = np.round(target_shift).astype(int)
# NOTE: moving bbox with shift, MUST be after `feature computation`
for node, shift in zip(next_nodes, int_next_shift):
for node, shift in zip(target_nodes, int_next_shift):
node.bbox[:n_dim] += shift
node.bbox[-n_dim:] += shift

distance_w = config.distance_weight
links = []

for i, node in enumerate(next_nodes):
for i, node in enumerate(target_nodes):
valid = (~np.isinf(distances[i])) & filtered_by_color[i]
valid_neighbors = neighbors[i, valid]
neigh_distances = distances[i, valid]

neighborhood = []
for neigh_idx, neigh_dist in zip(valid_neighbors, neigh_distances):
neigh = current_nodes[neigh_idx]
neigh = source_nodes[neigh_idx]
edge_weight = node.IoU(neigh) - distance_w * neigh_dist
# using dist as a tie-breaker
neighborhood.append(
Expand All @@ -176,13 +246,14 @@ def _process(

with write_lock if write_lock is not None else nullcontext():
LOG.info(f"Pushing links from time {time} to {db_path}")
connect_args = {"timeout": 45} if write_lock is not None else {}
engine = sqla.create_engine(
db_path, hide_parameters=True, connect_args=connect_args
)
with engine.begin() as conn:
df.to_sql(
name=LinkDB.__tablename__, con=conn, if_exists="append", index=False
)
df.to_sql(name=table_name, con=conn, if_exists="append", index=False)

return df


def link(
Expand Down Expand Up @@ -230,3 +301,40 @@ def link(
multiprocessing_apply(
process, time_points, config.linking_config.n_workers, desc="Linking nodes."
)


def add_links(
config: MainConfig,
sources: ArrayLike,
targets: ArrayLike,
weights: ArrayLike,
) -> None:
"""
Adds user-defined links to the database.

Parameters
----------
config : MainConfig
Configuration parameters.
sources : ArrayLike
Sources (t) node id.
targets : ArrayLike
Targets (t + 1) node id.
weights : ArrayLike
Link weights, the higher the weight the more likely the link.
"""
df = pd.DataFrame(
{
"source_id": np.asarray(sources, dtype=int),
"target_id": np.asarray(targets, dtype=int),
"weight": weights,
}
)

engine = sqla.create_engine(
config.data_config.database_path,
hide_parameters=True,
)

with engine.begin() as conn:
df.to_sql(name=LinkDB.__tablename__, con=conn, if_exists="append", index=False)
28 changes: 19 additions & 9 deletions ultrack/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
tracks_layer_to_trackmate,
tracks_to_zarr,
)
from ultrack.core.linking.processing import link
from ultrack.core.linking.processing import add_links, link
from ultrack.core.main import track
from ultrack.core.segmentation.processing import segment
from ultrack.core.solve.processing import solve
Expand Down Expand Up @@ -71,32 +71,31 @@ def __init__(self, config: MainConfig) -> None:
@rename_argument("edges", "contours")
def segment(self, foreground: ArrayLike, contours: ArrayLike, **kwargs) -> None:
segment(foreground=foreground, contours=contours, config=self.config, **kwargs)
self.status = TrackerStatus.SEGMENTED
self.status &= ~TrackerStatus.NOT_COMPUTED
self.status |= TrackerStatus.SEGMENTED

@functools.wraps(add_flow)
def add_flow(self, vector_field: ArrayLike) -> None:
if TrackerStatus.SEGMENTED not in self.status:
raise ValueError("You must call `segment` before calling `add_flow`.")
self._assert_segmented("add_flow")
add_flow(config=self.config, vector_field=vector_field)

@functools.wraps(link)
def link(self, *args, **kwargs) -> None:
if TrackerStatus.SEGMENTED not in self.status:
raise ValueError("You must call `segment` before calling `link`.")
self._assert_segmented("link")
link(config=self.config, *args, **kwargs)
self.status = TrackerStatus.LINKED
self.status |= TrackerStatus.LINKED

@functools.wraps(solve)
def solve(self, *args, **kwargs) -> None:
if TrackerStatus.LINKED not in self.status:
raise ValueError("You must call `segment` & `link` before calling `solve`.")
solve(config=self.config, *args, **kwargs)
self.status = TrackerStatus.SOLVED
self.status |= TrackerStatus.SOLVED

@functools.wraps(track)
def track(self, *args, **kwargs) -> None:
track(config=self.config, *args, **kwargs)
self.status = TrackerStatus.SOLVED
self.status |= TrackerStatus.SOLVED

def _assert_solved(self) -> None:
"""Raise an error if the tracking is not solved."""
Expand All @@ -106,6 +105,11 @@ def _assert_solved(self) -> None:
"called `segment` &a `link` & `solve` or `track`."
)

def _assert_segmented(self, method_name: str) -> None:
"""Raise an error if segmentation is not done."""
if TrackerStatus.SEGMENTED not in self.status:
raise ValueError(f"You must call `segment` before calling `{method_name}`.")

@functools.wraps(tracks_layer_to_networkx)
def to_networkx(
self, *, tracks_df: Optional[pd.DataFrame] = None, **kwargs
Expand Down Expand Up @@ -155,6 +159,12 @@ def export_by_extension(self, filename: str, overwrite: bool = False) -> None:
self._assert_solved()
export_tracks_by_extension(self.config, filename, overwrite=overwrite)

@functools.wraps(add_links)
def add_links(self, **kwargs) -> None:
self._assert_segmented("add_links")
add_links(config=self.config, **kwargs)
self.status |= TrackerStatus.LINKED

@functools.wraps(add_nodes_prob)
def add_nodes_prob(
self,
Expand Down
Loading