Skip to content

Commit

Permalink
fix #5254
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Dec 13, 2024
1 parent 64cf79b commit 686be45
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions fiftyone/operators/builtins/panels/model_evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""

from collections import defaultdict, Counter
import os
import traceback
import fiftyone.operators.types as types

from collections import defaultdict, Counter
import numpy as np

from fiftyone import ViewField as F
from fiftyone.operators.categories import Categories
from fiftyone.operators.panel import Panel, PanelConfig
from fiftyone.core.plots.plotly import _to_log_colorscale
import fiftyone.operators.types as types


STORE_NAME = "model_evaluation_panel_builtin"
Expand Down Expand Up @@ -104,29 +105,32 @@ def get_avg_confidence(self, per_class_metrics):
total += metrics["confidence"]
return total / count if count > 0 else None

def get_tp_fp_fn(self, ctx):
view_state = ctx.panel.get_state("view") or {}
key = view_state.get("key")
dataset = ctx.dataset
tp_key = f"{key}_tp"
fp_key = f"{key}_fp"
fn_key = f"{key}_fn"
tp_total = (
sum(ctx.dataset.values(tp_key))
if dataset.has_field(tp_key)
else None
)
fp_total = (
sum(ctx.dataset.values(fp_key))
if dataset.has_field(fp_key)
else None
)
fn_total = (
sum(ctx.dataset.values(fn_key))
if dataset.has_field(fn_key)
else None
)
return tp_total, fp_total, fn_total
def get_tp_fp_fn(self, info, results):
# Binary classification
if (
info.config.type == "classification"
and info.config.method == "binary"
):
neg_label, pos_label = results.classes
tp_count = np.count_nonzero(
(results.ytrue == pos_label) & (results.ypred == pos_label)
)
fp_count = np.count_nonzero(
(results.ytrue != pos_label) & (results.ypred == pos_label)
)
fn_count = np.count_nonzero(
(results.ytrue == pos_label) & (results.ypred != pos_label)
)
return tp_count, fp_count, fn_count

# Object detection
if info.config.type == "detection":
tp_count = np.count_nonzero(results.ytrue == results.ypred)
fp_count = np.count_nonzero(results.ytrue == results.missing)
fn_count = np.count_nonzero(results.ypred == results.missing)
return tp_count, fp_count, fn_count

return None, None, None

def get_map(self, results):
try:
Expand Down Expand Up @@ -298,7 +302,7 @@ def load_evaluation(self, ctx):
per_class_metrics
)
metrics["tp"], metrics["fp"], metrics["fn"] = self.get_tp_fp_fn(
ctx
info, results
)
metrics["mAP"] = self.get_map(results)
evaluation_data = {
Expand Down

0 comments on commit 686be45

Please sign in to comment.