Skip to content

Commit

Permalink
Allow target nDCG metric to be integer larger than 1 (#349)
Browse files Browse the repository at this point in the history
* fix

* changelog
  • Loading branch information
SkafteNicki authored Jul 7, 2021
1 parent 3f02ba2 commit b20cbda
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 7 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

**Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.**

## [0.x.x] - ????-??-??

### Added

- Added support in `nDCG` metric for target with values larger than 1 ([#343](https://github.com/PyTorchLightning/metrics/issues/343))

### Changed

### Deprecated

### Removed

### Fixed

## [0.4.1] - 2021-07-05

Expand Down
21 changes: 21 additions & 0 deletions tests/retrieval/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes as _irs_mis_sz
from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes_func as _irs_mis_sz_fn
from tests.retrieval.inputs import _input_retrieval_scores_no_target as _irs_no_tgt
from tests.retrieval.inputs import _input_retrieval_scores_non_binary_target as _irs_non_binary
from tests.retrieval.inputs import _input_retrieval_scores_wrong_targets as _irs_bad_tgt

seed_all(42)
Expand Down Expand Up @@ -223,6 +224,16 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
]
)

_default_metric_class_input_arguments_with_non_binary_target = dict(
argnames="indexes,preds,target",
argvalues=[
(_irs.indexes, _irs.preds, _irs.target),
(_irs_extra.indexes, _irs_extra.preds, _irs_extra.target),
(_irs_no_tgt.indexes, _irs_no_tgt.preds, _irs_no_tgt.target),
(_irs_non_binary.indexes, _irs_non_binary.preds, _irs_non_binary.target),
]
)

_default_metric_functional_input_arguments = dict(
argnames="preds,target",
argvalues=[
Expand All @@ -232,6 +243,16 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
]
)

_default_metric_functional_input_arguments_with_non_binary_target = dict(
argnames="preds,target",
argvalues=[
(_irs.preds, _irs.target),
(_irs_extra.preds, _irs_extra.target),
(_irs_no_tgt.preds, _irs_no_tgt.target),
(_irs_non_binary.preds, _irs_non_binary.target),
]
)


def _errors_test_class_metric(
indexes: Tensor,
Expand Down
6 changes: 6 additions & 0 deletions tests/retrieval/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
)

_input_retrieval_scores_non_binary_target = Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=4, size=(NUM_BATCHES, BATCH_SIZE)),
)

# with errors
_input_retrieval_scores_no_target = Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
Expand Down
12 changes: 6 additions & 6 deletions tests/retrieval/test_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from tests.retrieval.helpers import (
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_functional_input_arguments,
_default_metric_class_input_arguments_with_non_binary_target,
_default_metric_functional_input_arguments_with_non_binary_target,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_k,
_errors_test_class_metric_parameters_no_pos_target,
Expand Down Expand Up @@ -56,7 +56,7 @@ class TestNDCG(RetrievalMetricTester):
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ['skip', 'neg', 'pos'])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
@pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)
def test_class_metric(
self,
ddp: bool,
Expand All @@ -80,7 +80,7 @@ def test_class_metric(
metric_args=metric_args,
)

@pytest.mark.parametrize(**_default_metric_functional_input_arguments)
@pytest.mark.parametrize(**_default_metric_functional_input_arguments_with_non_binary_target)
@pytest.mark.parametrize("k", [None, 1, 4, 10])
def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
self.run_functional_metric_test(
Expand All @@ -92,7 +92,7 @@ def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
k=k,
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments)
@pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)
def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
self.run_precision_test_cpu(
indexes=indexes,
Expand All @@ -102,7 +102,7 @@ def test_precision_cpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
metric_functional=retrieval_normalized_dcg,
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments)
@pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)
def test_precision_gpu(self, indexes: Tensor, preds: Tensor, target: Tensor):
self.run_precision_test_gpu(
indexes=indexes,
Expand Down
5 changes: 4 additions & 1 deletion torchmetrics/retrieval/retrieval_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
process_group=process_group,
dist_sync_fn=dist_sync_fn
)
self.allow_non_binary_target = False

empty_target_action_options = ('error', 'skip', 'neg', 'pos')
if empty_target_action not in empty_target_action_options:
Expand All @@ -98,7 +99,9 @@ def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # typ
if indexes is None:
raise ValueError("Argument `indexes` cannot be None")

indexes, preds, target = _check_retrieval_inputs(indexes, preds, target)
indexes, preds, target = _check_retrieval_inputs(
indexes, preds, target, allow_non_binary_target=self.allow_non_binary_target
)

self.indexes.append(indexes)
self.preds.append(preds)
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/retrieval/retrieval_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
if (k is not None) and not (isinstance(k, int) and k > 0):
raise ValueError("`k` has to be a positive integer or None")
self.k = k
self.allow_non_binary_target = True

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_normalized_dcg(preds, target, k=self.k)

0 comments on commit b20cbda

Please sign in to comment.