diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f2aa69ba29..f970dbc9df2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424)) ### Deprecated diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 698d0f51848..0920118c919 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -216,19 +216,19 @@ def plot( def __getattr__(self, name: str) -> Union[Tensor, "Module"]: """Get attribute from classwise wrapper.""" - # return state from self.metric - if name in ["tp", "fp", "fn", "tn"]: - return getattr(self.metric, name) + if name == "metric" or (name in self.__dict__ and name not in self.metric.__dict__): + # we need this to prevent from infinite getattribute loop. + return super().__getattr__(name) - return super().__getattr__(name) + return getattr(self.metric, name) def __setattr__(self, name: str, value: Any) -> None: """Set attribute to classwise wrapper.""" - super().__setattr__(name, value) - if name == "metric": - self._defaults = self.metric._defaults - self._persistent = self.metric._persistent - self._reductions = self.metric._reductions - if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]: - # update ``_update_count`` and ``_computed`` of internal metric to prevent warning. + if hasattr(self, "metric") and name in self.metric._defaults: setattr(self.metric, name, value) + else: + super().__setattr__(name, value) + if name == "metric": + self._defaults = self.metric._defaults + self._persistent = self.metric._persistent + self._reductions = self.metric._reductions