From 63094c22ec2df1ed3c4ae8082ef228c390a129dc Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 5 May 2022 13:19:35 +0200 Subject: [PATCH] Fix compatibility with future pytorch (#1011) --- CHANGELOG.md | 2 +- torchmetrics/functional/pairwise/cosine.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fff4578b101..eec4a7b0577 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,7 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed multi device aggregation in `PearsonCorrCoef` ([#998](https://github.com/PyTorchLightning/metrics/pull/998)) -- +- Fixed compatibility with future Pytorch 1.12 in `pairwise_cosine_similarity` ([#1011](https://github.com/PyTorchLightning/metrics/pull/1011)) ## [0.8.1] - 2022-04-27 diff --git a/torchmetrics/functional/pairwise/cosine.py b/torchmetrics/functional/pairwise/cosine.py index 74df87352bb..29103a52dd9 100644 --- a/torchmetrics/functional/pairwise/cosine.py +++ b/torchmetrics/functional/pairwise/cosine.py @@ -20,6 +20,16 @@ from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix +def _safe_matmul(x: Tensor, y: Tensor) -> Tensor: + """Safe calculation of matrix multiplication. + + If input is float16, will cast to float32 for computation and back again. + """ + if x.dtype == torch.float16 or y.dtype == torch.float16: + return (x.float() @ y.T.float()).half() + return x @ y.T + + def _pairwise_cosine_similarity_update( x: Tensor, y: Optional[Tensor] = None, zero_diagonal: Optional[bool] = None ) -> Tensor: @@ -37,7 +47,7 @@ def _pairwise_cosine_similarity_update( norm = torch.norm(y, p=2, dim=1) y /= norm.unsqueeze(1) - distance = x @ y.T + distance = _safe_matmul(x, y) if zero_diagonal: distance.fill_diagonal_(0) return distance