Skip to content

Commit

Permalink
Merge branch 'master' into typing/reg
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Jun 30, 2021
2 parents 62b7841 + 8412701 commit 42d0da3
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Extend typing ([#330](https://github.com/PyTorchLightning/metrics/pull/330),
[#332](https://github.com/PyTorchLightning/metrics/pull/332),
[#333](https://github.com/PyTorchLightning/metrics/pull/333))


Expand Down
4 changes: 0 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,3 @@ ignore_missing_imports = True
# todo: add proper typing to this module...
[mypy-torchmetrics.classification.*]
ignore_errors = True

# todo: add proper typing to this module...
[mypy-torchmetrics.retrieval.*]
ignore_errors = True
4 changes: 3 additions & 1 deletion torchmetrics/functional/retrieval/ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor, tensor

Expand All @@ -22,7 +24,7 @@ def _dcg(target: Tensor) -> Tensor:
return (target / denom).sum()


def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
def retrieval_normalized_dcg(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
"""
Computes Normalized Discounted Cumulative Gain (for information retrieval), as explained
`here <https://en.wikipedia.org/wiki/Discounted_cumulative_gain>`__.
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/functional/retrieval/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_precision(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
def retrieval_precision(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
"""
Computes the precision metric (for information retrieval),
as explained `here <https://en.wikipedia.org/wiki/Precision_and_recall#Precision>`__.
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/functional/retrieval/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_recall(preds: Tensor, target: Tensor, k: int = None) -> Tensor:
def retrieval_recall(preds: Tensor, target: Tensor, k: Optional[int] = None) -> Tensor:
"""
Computes the recall metric (for information retrieval),
as explained `here <https://en.wikipedia.org/wiki/Precision_and_recall#Recall>`__.
Expand Down
16 changes: 8 additions & 8 deletions torchmetrics/retrieval/retrieval_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,18 @@ def __init__(

empty_target_action_options = ('error', 'skip', 'neg', 'pos')
if empty_target_action not in empty_target_action_options:
raise ValueError(f"`empty_target_action` received a wrong value `{empty_target_action}`.")
raise ValueError(f"Argument `empty_target_action` received a wrong value `{empty_target_action}`.")

self.empty_target_action = empty_target_action

self.add_state("indexes", default=[], dist_reduce_fx=None)
self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)

def update(self, preds: Tensor, target: Tensor, indexes: Tensor = None) -> None:
def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # type: ignore
""" Check shape, check and convert dtypes, flatten and add to accumulators. """
if indexes is None:
raise ValueError("`indexes` cannot be None")
raise ValueError("Argument `indexes` cannot be None")

indexes, preds, target = _check_retrieval_inputs(indexes, preds, target)

Expand All @@ -103,10 +103,10 @@ def update(self, preds: Tensor, target: Tensor, indexes: Tensor = None) -> None:

def compute(self) -> Tensor:
"""
First concat state `indexes`, `preds` and `target` since they were stored as lists. After that,
compute list of groups that will help in keeping together predictions about the same query.
Finally, for each group compute the `_metric` if the number of positive targets is at least
1, otherwise behave as specified by `self.empty_target_action`.
First concat state ``indexes``, ``preds`` and ``target`` since they were stored as lists.
After that, compute list of groups that will help in keeping together predictions about the same query.
Finally, for each group compute the ``_metric`` if the number of positive targets is at least
1, otherwise behave as specified by ``self.empty_target_action``.
"""
indexes = torch.cat(self.indexes, dim=0)
preds = torch.cat(self.preds, dim=0)
Expand All @@ -127,7 +127,7 @@ def compute(self) -> Tensor:
elif self.empty_target_action == 'neg':
res.append(tensor(0.0))
else:
# ensure list containt only float tensors
# ensure list contains only float tensors
res.append(self._metric(mini_preds, mini_target))

return torch.stack([x.to(preds) for x in res]).mean() if res else tensor(0.0).to(preds)
Expand Down

0 comments on commit 42d0da3

Please sign in to comment.