Skip to content

Commit

Permalink
Merge branch 'master' into fix_threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Jul 7, 2021
2 parents f7e7ba3 + b20cbda commit fc0e98f
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 8 deletions.
11 changes: 10 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

**Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.**

## [0.x.x] - ????-??-??

## [unreleased] - YYYY-MM-??

### Added

- Added support in `nDCG` metric for target with values larger than 1 ([#343](https://github.com/PyTorchLightning/metrics/issues/343))


### Changed


### Deprecated


### Removed

- Removed restriction that `threshold` has to be in (0,1) range to support logit input ([#351](https://github.com/PyTorchLightning/metrics/pull/351))


### Fixed



## [0.4.1] - 2021-07-05

### Changed
Expand Down
21 changes: 21 additions & 0 deletions tests/retrieval/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes as _irs_mis_sz
from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes_func as _irs_mis_sz_fn
from tests.retrieval.inputs import _input_retrieval_scores_no_target as _irs_no_tgt
from tests.retrieval.inputs import _input_retrieval_scores_non_binary_target as _irs_non_binary
from tests.retrieval.inputs import _input_retrieval_scores_wrong_targets as _irs_bad_tgt

seed_all(42)
Expand Down Expand Up @@ -223,6 +224,16 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
]
)

_default_metric_class_input_arguments_with_non_binary_target = dict(
argnames="indexes,preds,target",
argvalues=[
(_irs.indexes, _irs.preds, _irs.target),
(_irs_extra.indexes, _irs_extra.preds, _irs_extra.target),
(_irs_no_tgt.indexes, _irs_no_tgt.preds, _irs_no_tgt.target),
(_irs_non_binary.indexes, _irs_non_binary.preds, _irs_non_binary.target),
]
)

_default_metric_functional_input_arguments = dict(
argnames="preds,target",
argvalues=[
Expand All @@ -232,6 +243,16 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
]
)

_default_metric_functional_input_arguments_with_non_binary_target = dict(
argnames="preds,target",
argvalues=[
(_irs.preds, _irs.target),
(_irs_extra.preds, _irs_extra.target),
(_irs_no_tgt.preds, _irs_no_tgt.target),
(_irs_non_binary.preds, _irs_non_binary.target),
]
)


def _errors_test_class_metric(
indexes: Tensor,
Expand Down
6 changes: 6 additions & 0 deletions tests/retrieval/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
)

_input_retrieval_scores_non_binary_target = Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=4, size=(NUM_BATCHES, BATCH_SIZE)),
)

# with errors
_input_retrieval_scores_no_target = Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
Expand Down
12 changes: 6 additions & 6 deletions tests/retrieval/test_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from tests.retrieval.helpers import (
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_functional_input_arguments,
_default_metric_class_input_arguments_with_non_binary_target,
_default_metric_functional_input_arguments_with_non_binary_target,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_k,
_errors_test_class_metric_parameters_no_pos_target,
Expand Down Expand Up @@ -56,7 +56,7 @@ class TestNDCG(RetrievalMetricTester):
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos'])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
@pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)
def test_class_metric(
self,
ddp: bool,
Expand All @@ -80,7 +80,7 @@ def test_class_metric(
metric_args=metric_args,
)

@pytest.mark.parametrize(**_default_metric_functional_input_arguments)
@pytest.mark.parametrize(**_default_metric_functional_input_arguments_with_non_binary_target)
@pytest.mark.parametrize("k", [None, 1, 4, 10])
def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
self.run_functional_metric_test(
Expand All @@ -92,7 +92,7 @@ def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
k=k,
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments)
@pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)
def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
self.run_precision_test_cpu(
indexes=indexes,
Expand All @@ -102,7 +102,7 @@ def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
metric_functional=retrieval_normalized_dcg,
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments)
@pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)
def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
self.run_precision_test_gpu(
indexes=indexes,
Expand Down
5 changes: 4 additions & 1 deletion torchmetrics/retrieval/retrieval_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
process_group=process_group,
dist_sync_fn=dist_sync_fn
)
self.allow_non_binary_target = False

empty_target_action_options = ('error', 'skip', 'neg', 'pos')
if empty_target_action not in empty_target_action_options:
Expand All @@ -98,7 +99,9 @@ def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # typ
if indexes is None:
raise ValueError("Argument `indexes` cannot be None")

indexes, preds, target = _check_retrieval_inputs(indexes, preds, target)
indexes, preds, target = _check_retrieval_inputs(
indexes, preds, target, allow_non_binary_target=self.allow_non_binary_target
)

self.indexes.append(indexes)
self.preds.append(preds)
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/retrieval/retrieval_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
if (k is not None) and not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")
self.k = k
self.allow_non_binary_target = True

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_normalized_dcg(preds, target, k=self.k)

0 comments on commit fc0e98f

Please sign in to comment.