Skip to content

Commit

Permalink
Fix dtype casting in spearman and pearson (#2379)
Browse files Browse the repository at this point in the history
* fix dtype in pearson and spearman
* small refactor
* update classification
* update to new pytest format
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit f4ef8a8)
  • Loading branch information
SkafteNicki authored and Borda committed Mar 18, 2024
1 parent 37f0219 commit 9563563
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed negative variance estimates in certain image metrics ([#2378](https://github.com/Lightning-AI/torchmetrics/pull/2378))


- Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379))


## [1.3.1] - 2024-02-12

### Fixed
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def _pearson_corrcoef_update(
mx_new = (num_prior * mean_x + preds.sum(0)) / (num_prior + num_obs)
my_new = (num_prior * mean_y + target.sum(0)) / (num_prior + num_obs)
else:
mx_new = preds.mean(0)
my_new = target.mean(0)
mx_new = preds.mean(0).to(mean_x.dtype)
my_new = target.mean(0).to(mean_y.dtype)

num_prior += num_obs

Expand Down
10 changes: 9 additions & 1 deletion src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
torch._C._log_api_usage_once(f"torchmetrics.metric.{self.__class__.__name__}")

self._device = torch.device("cpu")
self._dtype = torch.get_default_dtype()

self.compute_on_cpu = kwargs.pop("compute_on_cpu", False)
if not isinstance(self.compute_on_cpu, bool):
Expand Down Expand Up @@ -729,6 +730,11 @@ def device(self) -> "torch.device":
"""Return the device of the metric."""
return self._device

@property
def dtype(self) -> "torch.dtype":
"""Return the default dtype of the metric."""
return self._dtype

def type(self, dst_type: Union[str, torch.dtype]) -> "Metric": # noqa: A003
"""Override default and prevent dtype casting.
Expand Down Expand Up @@ -813,7 +819,9 @@ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:

# make sure to update the device attribute
# if the dummy tensor moves device by fn function we should also update the attribute
self._device = fn(torch.zeros(1, device=self.device)).device
_dummy_tensor = fn(torch.zeros(1, device=self.device))
self._device = _dummy_tensor.device
self._dtype = _dummy_tensor.dtype

# Additional apply to forward cache and computed attributes (may be nested)
if this._computed is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/regression/spearman.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def __init__(
def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
preds, target = _spearman_corrcoef_update(preds, target, num_outputs=self.num_outputs)
self.preds.append(preds)
self.target.append(target)
self.preds.append(preds.to(self.dtype))
self.target.append(target.to(self.dtype))

def compute(self) -> Tensor:
"""Compute Spearman's correlation coefficient."""
Expand Down

0 comments on commit 9563563

Please sign in to comment.