From 4454a1f3bdde1a609b877fbebdb52f9a047d08d5 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Fri, 7 Feb 2025 10:54:59 -0800 Subject: [PATCH] Filter instances while generating indices --- sleap_nn/data/custom_datasets.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/sleap_nn/data/custom_datasets.py b/sleap_nn/data/custom_datasets.py index 45516af9..c797c05c 100644 --- a/sleap_nn/data/custom_datasets.py +++ b/sleap_nn/data/custom_datasets.py @@ -347,11 +347,6 @@ def _fill_cache(self): lf = self.labels[lf_idx] video_idx = self._get_video_idx(lf) - # Filter to user instances - if self.data_config.user_instances_only: - if lf.user_instances is not None and len(lf.user_instances) > 0: - lf.instances = lf.user_instances - if lf_idx == self.cache_lf[0]: img = self.cache_lf[1] else: @@ -362,8 +357,7 @@ def _fill_cache(self): instances = [] for inst in lf: - if not inst.is_empty: - instances.append(inst.numpy()) + instances.append(inst.numpy()) instances = np.stack(instances, axis=0) # Add singleton time dimension for single frames. @@ -443,8 +437,13 @@ def _get_instance_idx_list(self) -> List[Tuple[int]]: """Return list of tuples with indices of labelled frames and instances.""" instance_idx_list = [] for lf_idx, lf in enumerate(self.labels): - for inst_idx, _ in enumerate(lf.instances): - instance_idx_list.append((lf_idx, inst_idx)) + # Filter to user instances + if self.data_config.user_instances_only: + if lf.user_instances is not None and len(lf.user_instances) > 0: + lf.instances = lf.user_instances + for inst_idx, inst in enumerate(lf.instances): + if not inst.is_empty: # filter all NaN instances. + instance_idx_list.append((lf_idx, inst_idx)) return instance_idx_list def __len__(self) -> int: