Skip to content

Commit

Permalink
Fix broken clone method for classification metrics (#1250)
Browse files Browse the repository at this point in the history
* fix clone method
* chlog

Co-authored-by: Jirka <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 6, 2022
1 parent 5b3eb5c commit 2856e0b
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
8 changes: 4 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
-


### Deprecated

-
-


### Removed

-
-


### Fixed

-
- Fixed broken clone method for classification metrics ([#1250](https://github.com/Lightning-AI/metrics/pull/1250))


## [0.10.0] - 2022-10-04
Expand Down
5 changes: 4 additions & 1 deletion src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -848,6 +848,9 @@ def __pos__(self) -> "Metric":
def __getitem__(self, idx: int) -> "Metric":
return CompositionalMetric(lambda x: x[idx], self, None)

def __getnewargs__(self) -> Tuple:
return (Metric.__str__(self),)


def _neg(x: Tensor) -> Tensor:
return -torch.abs(x)
Expand Down
5 changes: 5 additions & 0 deletions tests/unittests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ def _class_test(
if check_scriptable:
torch.jit.script(metric)

# check that metric can be cloned
clone = metric.clone()
assert clone is not metric, "Clone is not a different object than the metric"
assert type(clone) == type(metric), "Type of clone did not match metric type"

# move to device
metric = metric.to(device)
preds = apply_to_collection(preds, Tensor, lambda x: x.to(device))
Expand Down

0 comments on commit 2856e0b

Please sign in to comment.