From b2a039f0830fe06b212e1cb75dc3bccacab43877 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 25 Nov 2023 17:16:20 +0100 Subject: [PATCH] Fix support for half precision in Perplexity metric (#2235) (cherry picked from commit c35a2fbb81120c1143f67e768fafcef3e44dbb09) --- CHANGELOG.md | 5 +- .../functional/text/perplexity.py | 11 +-- tests/unittests/bases/test_collections.py | 76 +++++++++---------- tests/unittests/text/test_perplexity.py | 23 +++++- 4 files changed, 67 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21eb389299b..6fbf0eb7921 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,7 +35,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed numerical stability issue in `UniversalImageQualityIndex` metric ([#2222](https://github.com/Lightning-AI/torchmetrics/pull/2222)) -- Fix device and dtype for `LearnedPerceptualImagePatchSimilarity` functional metric ([#2234](https://github.com/Lightning-AI/torchmetrics/pull/2234)) +- Fixed support for half precision in Perplexity metric ([#2235](https://github.com/Lightning-AI/torchmetrics/pull/2235)) + + +- Fixed device and dtype for `LearnedPerceptualImagePatchSimilarity` functional metric ([#2234](https://github.com/Lightning-AI/torchmetrics/pull/2234)) ## [1.2.0] - 2023-09-22 diff --git a/src/torchmetrics/functional/text/perplexity.py b/src/torchmetrics/functional/text/perplexity.py index 127d3c74a67..cb0bafd5082 100644 --- a/src/torchmetrics/functional/text/perplexity.py +++ b/src/torchmetrics/functional/text/perplexity.py @@ -16,9 +16,6 @@ import torch from torch import Tensor -from torch.nn import functional as F # noqa: N812 - -_TORCH_FLOAT_OR_DOUBLE = (torch.float32, torch.float64) def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> None: @@ -59,10 +56,8 @@ def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> None: "Input tensors `preds` and `target` are expected to have equaling first two dimensions," f" [batch_size, seq_len], but got {preds.shape[:2]} and {target.shape}." ) - if preds.dtype not in _TORCH_FLOAT_OR_DOUBLE: - raise TypeError( - f"Input tensor `preds` is expected to be of a type one of {_TORCH_FLOAT_OR_DOUBLE} but got {preds.dtype}." - ) + if not preds.is_floating_point(): + raise TypeError(f"Input tensor `preds` is expected to be of floating point type but got {preds.dtype}.") if target.dtype != torch.int64: raise TypeError(f"Input tensor `target` is expected to be of a type {torch.int64} but got {target.dtype}.") @@ -87,7 +82,7 @@ def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int """ _check_shape_and_type_consistency(preds, target) - probs = F.softmax(preds.reshape(-1, preds.shape[-1]), dim=1) + probs = torch.nn.functional.softmax(preds.reshape(-1, preds.shape[-1]), dim=1) target = target.reshape(-1) if ignore_index is not None: diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 6f3e64d1b6c..ce4bb1ba8c8 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import pickle -import time from copy import deepcopy from typing import Any @@ -480,43 +479,44 @@ def _compare(m1, m2): _compare(metric_cg, metric_no_cg) -@pytest.mark.parametrize( - "metrics", - [ - {"acc0": MulticlassAccuracy(3), "acc1": MulticlassAccuracy(3)}, - [MulticlassPrecision(3), MulticlassRecall(3)], - [MulticlassConfusionMatrix(3), MulticlassCohenKappa(3), MulticlassRecall(3), MulticlassPrecision(3)], - { - "acc": MulticlassAccuracy(3), - "acc2": MulticlassAccuracy(3), - "acc3": MulticlassAccuracy(num_classes=3, average="macro"), - "f1": MulticlassF1Score(3), - "recall": MulticlassRecall(3), - "confmat": MulticlassConfusionMatrix(3), - }, - ], -) -@pytest.mark.parametrize("steps", [1000]) -def test_check_compute_groups_is_faster(metrics, steps): - """Check that compute groups are formed after initialization.""" - m = MetricCollection(deepcopy(metrics), compute_groups=True) - # Construct without for comparison - m2 = MetricCollection(deepcopy(metrics), compute_groups=False) - - preds = torch.randn(10, 3).softmax(dim=-1) - target = torch.randint(3, (10,)) - - start = time.time() - for _ in range(steps): - m.update(preds, target) - time_cg = time.time() - start - - start = time.time() - for _ in range(steps): - m2.update(preds, target) - time_no_cg = time.time() - start - - assert time_cg < time_no_cg, "using compute groups were not faster" +# TODO: test is flaky +# @pytest.mark.parametrize( +# "metrics", +# [ +# {"acc0": MulticlassAccuracy(3), "acc1": MulticlassAccuracy(3)}, +# [MulticlassPrecision(3), MulticlassRecall(3)], +# [MulticlassConfusionMatrix(3), MulticlassCohenKappa(3), MulticlassRecall(3), MulticlassPrecision(3)], +# { +# "acc": MulticlassAccuracy(3), +# "acc2": MulticlassAccuracy(3), +# "acc3": MulticlassAccuracy(num_classes=3, average="macro"), +# "f1": MulticlassF1Score(3), +# "recall": MulticlassRecall(3), +# "confmat": MulticlassConfusionMatrix(3), +# }, +# ], +# ) +# @pytest.mark.parametrize("steps", [1000]) +# def test_check_compute_groups_is_faster(metrics, steps): +# """Check that compute groups are formed after initialization.""" +# m = MetricCollection(deepcopy(metrics), compute_groups=True) +# # Construct without for comparison +# m2 = MetricCollection(deepcopy(metrics), compute_groups=False) + +# preds = torch.randn(10, 3).softmax(dim=-1) +# target = torch.randint(3, (10,)) + +# start = time.time() +# for _ in range(steps): +# m.update(preds, target) +# time_cg = time.time() - start + +# start = time.time() +# for _ in range(steps): +# m2.update(preds, target) +# time_no_cg = time.time() - start + +# assert time_cg < time_no_cg, "using compute groups were not faster" def test_compute_group_define_by_user(): diff --git a/tests/unittests/text/test_perplexity.py b/tests/unittests/text/test_perplexity.py index b79f33391df..658c6eee878 100644 --- a/tests/unittests/text/test_perplexity.py +++ b/tests/unittests/text/test_perplexity.py @@ -71,7 +71,7 @@ def test_perplexity_fn(self, preds, target, ignore_index): metric_args={"ignore_index": ignore_index}, ) - def test_accuracy_differentiability(self, preds, target, ignore_index): + def test_perplexity_differentiability(self, preds, target, ignore_index): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" self.run_differentiability_test( preds=preds, @@ -80,3 +80,24 @@ def test_accuracy_differentiability(self, preds, target, ignore_index): metric_functional=perplexity, metric_args={"ignore_index": ignore_index}, ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_perplexity_dtypes_cpu(self, preds, target, ignore_index, dtype): + """Test dtype support of the metric on CPU.""" + if dtype == torch.half: + with pytest.raises(RuntimeError, match="\"softmax_lastdim_kernel_impl\" not implemented for 'Half'"): + self.run_precision_test_cpu( + preds, target, Perplexity, perplexity, metric_args={"ignore_index": ignore_index}, dtype=dtype + ) + else: + self.run_precision_test_cpu( + preds, target, Perplexity, perplexity, metric_args={"ignore_index": ignore_index}, dtype=dtype + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_perplexity_dtypes_gpu(self, preds, target, ignore_index, dtype): + """Test dtype support of the metric on GPU.""" + self.run_precision_test_gpu( + preds, target, Perplexity, perplexity, metric_args={"ignore_index": ignore_index}, dtype=dtype + )