Skip to content

Commit

Permalink
refactor score_less_than_thresh in _binclf_one_curve_python
Browse files Browse the repository at this point in the history
  • Loading branch information
jpcbertoldo committed Feb 9, 2024
1 parent 7737aed commit ffbd354
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/anomalib/metrics/per_image/binclf_curve_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import itertools
import logging
from dataclasses import dataclass
from functools import partial
from typing import ClassVar

import numpy as np
Expand Down Expand Up @@ -140,24 +141,21 @@ def _binclf_one_curve_python(scores: ndarray, gts: ndarray, threshs: ndarray) ->
num_neg = current_count_fp = scores_negatives.size
fps = np.empty((num_th,), dtype=np.int64)

def score_less_than_thresh(thresh): # noqa: ANN001, ANN202
def func(score) -> bool: # noqa: ANN001
return score < thresh

return func
def score_less_than_thresh(score: float, thresh: float) -> bool:
return score < thresh

# it will progressively drop the scores that are below the current thresh
for thresh_idx, thresh in enumerate(threshs):
# UPDATE POSITIVES
# < becasue it is the same as ~(>=)
num_drop = sum(1 for _ in itertools.takewhile(score_less_than_thresh(thresh), scores_positives))
num_drop = sum(1 for _ in itertools.takewhile(partial(score_less_than_thresh, thresh=thresh), scores_positives))
scores_positives = scores_positives[num_drop:]
current_count_tp -= num_drop
tps[thresh_idx] = current_count_tp

# UPDATE NEGATIVES
# same with the negatives
num_drop = sum(1 for _ in itertools.takewhile(score_less_than_thresh(thresh), scores_negatives))
num_drop = sum(1 for _ in itertools.takewhile(partial(score_less_than_thresh, thresh=thresh), scores_negatives))
scores_negatives = scores_negatives[num_drop:]
current_count_fp -= num_drop
fps[thresh_idx] = current_count_fp
Expand Down

0 comments on commit ffbd354

Please sign in to comment.