Skip to content

Commit

Permalink
Fix compatibility with future pytorch (#1011)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored May 5, 2022
1 parent 1789b17 commit 63094c2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed multi device aggregation in `PearsonCorrCoef` ([#998](https://github.com/PyTorchLightning/metrics/pull/998))


-
- Fixed compatibility with future Pytorch 1.12 in `pairwise_cosine_similarity` ([#1011](https://github.com/PyTorchLightning/metrics/pull/1011))


## [0.8.1] - 2022-04-27
Expand Down
12 changes: 11 additions & 1 deletion torchmetrics/functional/pairwise/cosine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix


def _safe_matmul(x: Tensor, y: Tensor) -> Tensor:
"""Safe calculation of matrix multiplication.
If input is float16, will cast to float32 for computation and back again.
"""
if x.dtype == torch.float16 or y.dtype == torch.float16:
return (x.float() @ y.T.float()).half()
return x @ y.T


def _pairwise_cosine_similarity_update(
x: Tensor, y: Optional[Tensor] = None, zero_diagonal: Optional[bool] = None
) -> Tensor:
Expand All @@ -37,7 +47,7 @@ def _pairwise_cosine_similarity_update(
norm = torch.norm(y, p=2, dim=1)
y /= norm.unsqueeze(1)

distance = x @ y.T
distance = _safe_matmul(x, y)
if zero_diagonal:
distance.fill_diagonal_(0)
return distance
Expand Down

0 comments on commit 63094c2

Please sign in to comment.