Skip to content

Commit

Permalink
store matches instead
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Jan 1, 2025
1 parent 663adf8 commit 33f06a8
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 166 deletions.
148 changes: 50 additions & 98 deletions fiftyone/utils/eval/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from copy import deepcopy
import logging
import inspect
import itertools
import warnings

import numpy as np
Expand Down Expand Up @@ -349,9 +350,7 @@ def evaluate_samples(

nc = len(values)
confusion_matrix = np.zeros((nc, nc), dtype=int)
weights_dict = {}
ytrue_ids_dict = {}
ypred_ids_dict = {}
matches = []

bandwidth = self.config.bandwidth
average = self.config.average
Expand Down Expand Up @@ -394,18 +393,16 @@ def evaluate_samples(
)
sample_conf_mat += image_conf_mat

for index in zip(*np.nonzero(image_conf_mat)):
if index not in weights_dict:
weights_dict[index] = []
weights_dict[index].append(int(image_conf_mat[index]))

if index not in ytrue_ids_dict:
ytrue_ids_dict[index] = []
ytrue_ids_dict[index].append(gt_seg.id)

if index not in ypred_ids_dict:
ypred_ids_dict[index] = []
ypred_ids_dict[index].append(pred_seg.id)
for i, j in zip(*np.nonzero(image_conf_mat)):
matches.append(
(
classes[i],
classes[j],
int(image_conf_mat[i, j]),
gt_seg.id,
pred_seg.id,
)
)

if processing_frames and save:
facc, fpre, frec = _compute_accuracy_precision_recall(
Expand Down Expand Up @@ -440,9 +437,7 @@ def evaluate_samples(
eval_key,
confusion_matrix,
classes,
weights_dict=weights_dict,
ytrue_ids_dict=ytrue_ids_dict,
ypred_ids_dict=ypred_ids_dict,
matches=matches,
missing=missing,
backend=self,
)
Expand All @@ -457,11 +452,9 @@ class SegmentationResults(BaseEvaluationResults):
eval_key: the evaluation key
pixel_confusion_matrix: a pixel value confusion matrix
classes: a list of class labels corresponding to the confusion matrix
weights_dict (None): a dict mapping ``(i, j)`` tuples to pixel counts
ytrue_ids_dict (None): a dict mapping ``(i, j)`` tuples to lists of
ground truth IDs
ypred_ids_dict (None): a dict mapping ``(i, j)`` tuples to lists of
predicted label IDs
matches (None): a list of
``(gt_label, pred_label, pixel_count, gt_id, pred_id)``
matches
missing (None): a missing (background) class
backend (None): a :class:`SegmentationEvaluation` backend
"""
Expand All @@ -473,27 +466,20 @@ def __init__(
eval_key,
pixel_confusion_matrix,
classes,
weights_dict=None,
ytrue_ids_dict=None,
ypred_ids_dict=None,
matches=None,
missing=None,
backend=None,
):
pixel_confusion_matrix = np.asarray(pixel_confusion_matrix)

(
ytrue,
ypred,
weights,
ytrue_ids,
ypred_ids,
) = self._parse_confusion_matrix(
pixel_confusion_matrix,
classes,
weights_dict=weights_dict,
ytrue_ids_dict=ytrue_ids_dict,
ypred_ids_dict=ypred_ids_dict,
)
if matches is not None:
ytrue, ypred, weights, ytrue_ids, ypred_ids = zip(*matches)
else:
ytrue, ypred, weights = self._parse_confusion_matrix(
pixel_confusion_matrix, classes
)
ytrue_ids = None
ypred_ids = None

super().__init__(
samples,
Expand All @@ -510,20 +496,6 @@ def __init__(
)

self.pixel_confusion_matrix = pixel_confusion_matrix
self.weights_dict = weights_dict
self.ytrue_ids_dict = ytrue_ids_dict
self.ypred_ids_dict = ypred_ids_dict

def attributes(self):
return [
"cls",
"pixel_confusion_matrix",
"classes",
"weights_dict",
"ytrue_ids_dict",
"ypred_ids_dict",
"missing",
]

def dice_score(self):
"""Computes the Dice score across all samples in the evaluation.
Expand All @@ -535,70 +507,50 @@ def dice_score(self):

@classmethod
def _from_dict(cls, d, samples, config, eval_key, **kwargs):
ytrue = d.get("ytrue", None)
ypred = d.get("ypred", None)
weights = d.get("weights", None)
ytrue_ids = d.get("ytrue_ids", None)
ypred_ids = d.get("ypred_ids", None)

if ytrue is not None and ypred is not None and weights is not None:
if ytrue_ids is None:
ytrue_ids = itertools.repeat(None)

if ypred_ids is None:
ypred_ids = itertools.repeat(None)

matches = list(zip(ytrue, ypred, weights, ytrue_ids, ypred_ids))
else:
# Legacy format segmentations
matches = None

return cls(
samples,
config,
eval_key,
d["pixel_confusion_matrix"],
d["classes"],
weights_dict=_parse_index_dict(d.get("weights_dict", None)),
ytrue_ids_dict=_parse_index_dict(d.get("ytrue_ids_dict", None)),
ypred_ids_dict=_parse_index_dict(d.get("ypred_ids_dict", None)),
matches=matches,
missing=d.get("missing", None),
**kwargs,
)

@staticmethod
def _parse_confusion_matrix(
confusion_matrix,
classes,
weights_dict=None,
ytrue_ids_dict=None,
ypred_ids_dict=None,
):
have_ids = ytrue_ids_dict is not None and ypred_ids_dict is not None

def _parse_confusion_matrix(confusion_matrix, classes):
ytrue = []
ypred = []
weights = []
if have_ids:
ytrue_ids = []
ypred_ids = []
else:
ytrue_ids = None
ypred_ids = None

nrows, ncols = confusion_matrix.shape
for i in range(nrows):
for j in range(ncols):
cij = confusion_matrix[i, j]
if cij > 0:
if have_ids:
index = (i, j)
classi = classes[i]
classj = classes[j]
for weight, ytrue_id, ypred_id in zip(
weights_dict[index],
ytrue_ids_dict[index],
ypred_ids_dict[index],
):
ytrue.append(classi)
ypred.append(classj)
weights.append(weight)
ytrue_ids.append(ytrue_id)
ypred_ids.append(ypred_id)
else:
ytrue.append(classes[i])
ypred.append(classes[j])
weights.append(cij)

return ytrue, ypred, weights, ytrue_ids, ypred_ids


def _parse_index_dict(d):
import ast

return {ast.literal_eval(k): v for k, v in d.items()}
ytrue.append(classes[i])
ypred.append(classes[j])
weights.append(cij)

return ytrue, ypred, weights


def _parse_config(pred_field, gt_field, method, **kwargs):
Expand Down
Loading

0 comments on commit 33f06a8

Please sign in to comment.