diff --git a/CHANGELOG.md b/CHANGELOG.md index c1a8b5e05b9..f909e2b8306 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed negative variance estimates in certain image metrics ([#2378](https://github.com/Lightning-AI/torchmetrics/pull/2378)) +- Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379)) + + ## [1.3.1] - 2024-02-12 ### Fixed diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py index b76c3a12992..c98bc65a85c 100644 --- a/src/torchmetrics/functional/regression/pearson.py +++ b/src/torchmetrics/functional/regression/pearson.py @@ -59,8 +59,8 @@ def _pearson_corrcoef_update( mx_new = (num_prior * mean_x + preds.sum(0)) / (num_prior + num_obs) my_new = (num_prior * mean_y + target.sum(0)) / (num_prior + num_obs) else: - mx_new = preds.mean(0) - my_new = target.mean(0) + mx_new = preds.mean(0).to(mean_x.dtype) + my_new = target.mean(0).to(mean_y.dtype) num_prior += num_obs diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 29fdf81b46e..6a921ab0b96 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -108,6 +108,7 @@ def __init__( torch._C._log_api_usage_once(f"torchmetrics.metric.{self.__class__.__name__}") self._device = torch.device("cpu") + self._dtype = torch.get_default_dtype() self.compute_on_cpu = kwargs.pop("compute_on_cpu", False) if not isinstance(self.compute_on_cpu, bool): @@ -729,6 +730,11 @@ def device(self) -> "torch.device": """Return the device of the metric.""" return self._device + @property + def dtype(self) -> "torch.dtype": + """Return the default dtype of the metric.""" + return self._dtype + def type(self, dst_type: Union[str, torch.dtype]) -> "Metric": # noqa: A003 """Override default and prevent dtype casting. @@ -813,7 +819,9 @@ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module: # make sure to update the device attribute # if the dummy tensor moves device by fn function we should also update the attribute - self._device = fn(torch.zeros(1, device=self.device)).device + _dummy_tensor = fn(torch.zeros(1, device=self.device)) + self._device = _dummy_tensor.device + self._dtype = _dummy_tensor.dtype # Additional apply to forward cache and computed attributes (may be nested) if this._computed is not None: diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py index 5aed8e96a01..8d6c85f298a 100644 --- a/src/torchmetrics/regression/spearman.py +++ b/src/torchmetrics/regression/spearman.py @@ -96,8 +96,8 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" preds, target = _spearman_corrcoef_update(preds, target, num_outputs=self.num_outputs) - self.preds.append(preds) - self.target.append(target) + self.preds.append(preds.to(self.dtype)) + self.target.append(target.to(self.dtype)) def compute(self) -> Tensor: """Compute Spearman's correlation coefficient."""