Allow threshold to be outside (0,1) domain #351

Merged · 5 commits · Jul 7, 2021
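Why this change matters: when a model emits raw logits instead of probabilities, the natural binary decision boundary is a threshold of 0.0, which the old (0,1) check rejected. A minimal sketch of the equivalence (plain PyTorch, not part of this diff):

```python
import torch

logits = torch.tensor([-2.0, -0.5, 0.3, 1.7])
probs = torch.sigmoid(logits)

# sigmoid is monotonic and sigmoid(0) == 0.5, so thresholding logits at 0.0
# selects exactly the same predictions as thresholding probabilities at 0.5.
assert torch.equal(logits > 0.0, probs > 0.5)
```

Removing the range check lets users pass logits straight to the metric with `threshold=0.0` instead of applying a sigmoid first.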
11 changes: 10 additions & 1 deletion CHANGELOG.md
@@ -6,20 +6,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

**Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.**

-## [0.x.x] - ????-??-??
+## [unreleased] - YYYY-MM-??

### Added

- Added support in `nDCG` metric for target with values larger than 1 ([#343](https://github.com/PyTorchLightning/metrics/issues/343))


### Changed


### Deprecated


### Removed

+- Removed restriction that `threshold` has to be in (0,1) range to support logit input ([#351](https://github.com/PyTorchLightning/metrics/pull/351))


### Fixed



## [0.4.1] - 2021-07-05

### Changed
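A short usage sketch of the new behavior (hypothetical tensors; mirrors the `threshold=0.0` logit cases the updated tests parametrize below, and assumes the default `reduce="micro"`, which returns the five counts `[tp, fp, tn, fn, support]`):

```python
import torch
from torchmetrics import StatScores

# Predictions are raw logits, so the decision boundary lives in logit space.
metric = StatScores(threshold=0.0)
preds = torch.tensor([-1.2, 0.4, 2.3, -0.1])  # logits, not probabilities
target = torch.tensor([0, 1, 1, 0])
tp, fp, tn, fn, support = metric(preds, target)
```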
67 changes: 36 additions & 31 deletions tests/classification/test_stat_scores.py
@@ -29,18 +29,18 @@
 from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
 from tests.helpers import seed_all
-from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_CLASSES, MetricTester
 from torchmetrics import StatScores
 from torchmetrics.functional import stat_scores
 from torchmetrics.utilities.checks import _input_format_classification

 seed_all(42)


-def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce=None):
+def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, threshold, mdmc_reduce=None):
     # todo: `mdmc_reduce` is unused
     preds, target, _ = _input_format_classification(
-        preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k
+        preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k
     )
     sk_preds, sk_target = preds.numpy(), target.numpy()

@@ -75,23 +75,25 @@ def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index
     return sk_stats


-def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k):
+def _sk_stat_scores_mdim_mcls(
+    preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k, threshold
+):
     preds, target, _ = _input_format_classification(
-        preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k
+        preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k
     )

     if mdmc_reduce == "global":
         preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
         target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])

-        return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k)
+        return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k, threshold)
     if mdmc_reduce == "samplewise":
         scores = []

         for i in range(preds.shape[0]):
             pred_i = preds[i, ...].T
             target_i = target[i, ...].T
-            scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k)
+            scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k, threshold)

             scores.append(np.expand_dims(scores_i, 0))

@@ -128,34 +130,32 @@ def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
         sts(inputs.preds[0], inputs.target[0])


-def test_wrong_threshold():
-    with pytest.raises(ValueError):
-        StatScores(threshold=1.5)
-
-
 @pytest.mark.parametrize("ignore_index", [None, 0])
 @pytest.mark.parametrize("reduce", ["micro", "macro", "samples"])
 @pytest.mark.parametrize(
-    "preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k",
+    "preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k, threshold",
     [
-        (_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None),
-        (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None),
-        (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None),
-        (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
-        (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
-        (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
-        (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
-        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
-        (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
-        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
-        (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
-        (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None),
+        (_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None, 0.0),
+        (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None, 0.5),
+        (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None, 0.5),
+        (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0),
+        (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5),
+        (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.5),
+        (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None, 0.5),
+        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5),
+        (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0),
+        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.0),
+        (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0),
+        (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None, 0.0),
         (
             _input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None,
-            None
+            None, 0.0
         ),
-        (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None),
-        (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None),
+        (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None, 0.0),
+        (
+            _input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None,
+            None, 0.0
+        ),
     ],
 )
 class TestStatScores(MetricTester):
@@ -175,6 +175,7 @@ def test_stat_scores_class(
         multiclass: Optional[bool],
         ignore_index: Optional[int],
         top_k: Optional[int],
+        threshold: Optional[float],
     ):
         if ignore_index is not None and preds.ndim == 2:
             pytest.skip("Skipping ignore_index test with binary inputs.")
@@ -192,13 +193,14 @@ def test_stat_scores_class(
                 multiclass=multiclass,
                 ignore_index=ignore_index,
                 top_k=top_k,
+                threshold=threshold
             ),
             dist_sync_on_step=dist_sync_on_step,
             metric_args={
                 "num_classes": num_classes,
                 "reduce": reduce,
                 "mdmc_reduce": mdmc_reduce,
-                "threshold": THRESHOLD,
+                "threshold": threshold,
                 "multiclass": multiclass,
                 "ignore_index": ignore_index,
                 "top_k": top_k,
@@ -218,6 +220,7 @@ def test_stat_scores_fn(
         multiclass: Optional[bool],
         ignore_index: Optional[int],
         top_k: Optional[int],
+        threshold: Optional[float],
     ):
         if ignore_index is not None and preds.ndim == 2:
             pytest.skip("Skipping ignore_index test with binary inputs.")
@@ -234,12 +237,13 @@ def test_stat_scores_fn(
                 multiclass=multiclass,
                 ignore_index=ignore_index,
                 top_k=top_k,
+                threshold=threshold
             ),
             metric_args={
                 "num_classes": num_classes,
                 "reduce": reduce,
                 "mdmc_reduce": mdmc_reduce,
-                "threshold": THRESHOLD,
+                "threshold": threshold,
                 "multiclass": multiclass,
                 "ignore_index": ignore_index,
                 "top_k": top_k,
@@ -257,6 +261,7 @@ def test_stat_scores_differentiability(
         multiclass: Optional[bool],
         ignore_index: Optional[int],
         top_k: Optional[int],
+        threshold: Optional[float],
     ):
         if ignore_index is not None and preds.ndim == 2:
             pytest.skip("Skipping ignore_index test with binary inputs.")
@@ -270,7 +275,7 @@ def test_stat_scores_differentiability(
                 "num_classes": num_classes,
                 "reduce": reduce,
                 "mdmc_reduce": mdmc_reduce,
-                "threshold": THRESHOLD,
+                "threshold": threshold,
                 "multiclass": multiclass,
                 "ignore_index": ignore_index,
                 "top_k": top_k,
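The comparison helpers above now thread the same `threshold` the metric uses through to `_input_format_classification`, so the sklearn reference sees the identical binarization. Conceptually that binarization reduces to the following sketch (`binarize` is a hypothetical helper for illustration, not a function from the codebase, and the choice of `>=` is an assumption):

```python
import numpy as np

def binarize(preds: np.ndarray, threshold: float) -> np.ndarray:
    # Probabilities pair naturally with threshold=0.5, raw logits with 0.0,
    # matching the per-case thresholds in the parametrization above.
    return (preds >= threshold).astype(int)

binarize(np.array([0.2, 0.7]), threshold=0.5)   # -> array([0, 1])
binarize(np.array([-1.3, 0.8]), threshold=0.0)  # -> array([0, 1])
```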
3 changes: 0 additions & 3 deletions torchmetrics/classification/stat_scores.py
@@ -165,9 +165,6 @@ def __init__(
         self.ignore_index = ignore_index
         self.top_k = top_k

-        if not 0 < threshold < 1:
-            raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")
-
         if reduce not in ["micro", "macro", "samples"]:
             raise ValueError(f"The `reduce` {reduce} is not valid.")
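With the guard above removed, constructor arguments that previously raised are now accepted. A quick before/after sketch of the behavioral change:

```python
from torchmetrics import StatScores

# Before this PR, both lines raised:
#   ValueError: The `threshold` should be a float in the (0,1) interval, got ...
# After removing the check, they construct normally.
StatScores(threshold=0.0)  # logit-space decision boundary
StatScores(threshold=1.5)  # the value the deleted test_wrong_threshold used
```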