diff --git a/fiftyone/utils/eval/detection.py b/fiftyone/utils/eval/detection.py index 32aef9af2f..270a7d3b25 100644 --- a/fiftyone/utils/eval/detection.py +++ b/fiftyone/utils/eval/detection.py @@ -64,7 +64,13 @@ def spin(): def _evaluate_detections_bulk( - _samples, gt_field, pred_field, eval_method, eval_key, progress=True + _samples, + gt_field, + pred_field, + eval_method, + eval_key, + progress=True, + save=True, ): matches = [] id_field = "id" @@ -75,6 +81,7 @@ def _evaluate_detections_bulk( decorated_func = spinner_decorator(enabled=True)(_samples.values) + # might want to fetch batches of data instead of the whole collection ids, ground_truths, predictions = decorated_func( [id_field, gt_field, pred_field] ) @@ -97,9 +104,10 @@ def _evaluate_detections_bulk( matches.extend(doc_matches) tp, fp, fn = _tally_matches(doc_matches) - sample_updates["sample_tp"][id] = tp - sample_updates["sample_fp"][id] = fp - sample_updates["sample_fn"][id] = fn + if save: + sample_updates["sample_tp"][id] = tp + sample_updates["sample_fp"][id] = fp + sample_updates["sample_fn"][id] = fn pb.update() docs = (ids, ground_truths, predictions) @@ -284,6 +292,7 @@ def evaluate_detections( eval_method, eval_key, progress=progress, + save=save, ) end_time = time.time() logger.info(