From 4bc6e959d0be798b00ea8ccac96e9c8e2f9616a0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 10 Nov 2021 22:20:59 +0100
Subject: [PATCH] Fix support for dataclasses with ClassVar/InitVar in
 `apply_to_collection` (#9702)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
Co-authored-by: Carlos Mocholi
---
 CHANGELOG.md                                  |   2 +-
 .../connectors/logger_connector/result.py     |  26 ++-
 pytorch_lightning/utilities/apply_func.py     |  25 ++-
 tests/core/test_results.py                    |   2 +-
 tests/models/test_tpu.py                      |   2 +-
 tests/utilities/test_apply_func.py            | 158 +++++++++++++-----
 6 files changed, 159 insertions(+), 56 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c93d1618eb088..8b18ac8db873a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
 - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
-
+- Fixed an issue where class or init-only variables of dataclasses were passed to the dataclass constructor in `utilities.apply_to_collection` ([#9702](https://github.com/PyTorchLightning/pytorch-lightning/issues/9702))
 
 ## [1.5.1] - 2021-11-09
 
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index f798cf3ee2b82..53034ac77db3f 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -51,8 +51,8 @@ class _Sync:
     fn: Optional[Callable] = None
     _should: bool = False
     rank_zero_only: bool = False
-    op: Optional[str] = None
-    group: Optional[Any] = None
+    _op: Optional[str] = None
+    _group: Optional[Any] = None
 
     def __post_init__(self) -> None:
         self._generate_sync_fn()
@@ -67,6 +67,26 @@ def should(self, should: bool) -> None:
         # `self._fn` needs to be re-generated.
         self._generate_sync_fn()
 
+    @property
+    def op(self) -> Optional[str]:
+        return self._op
+
+    @op.setter
+    def op(self, op: Optional[str]) -> None:
+        self._op = op
+        # `self._fn` needs to be re-generated.
+        self._generate_sync_fn()
+
+    @property
+    def group(self) -> Optional[Any]:
+        return self._group
+
+    @group.setter
+    def group(self, group: Optional[Any]) -> None:
+        self._group = group
+        # `self._fn` needs to be re-generated.
+        self._generate_sync_fn()
+
     def _generate_sync_fn(self) -> None:
         """Used to compute the syncing function and cache it."""
         fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
@@ -426,7 +446,7 @@ def log(
             dataloader_idx=dataloader_idx,
             metric_attribute=metric_attribute,
         )
-        meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)
+        meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, _group=sync_dist_group, rank_zero_only=rank_zero_only)
 
         # register logged value if it doesn't exist
         if key not in self:
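A note on the `result.py` hunk above: `op` and `group` move from public dataclass fields to underscore-prefixed fields wrapped in properties, so that assigning to `sync.op` or `sync.group` after construction re-generates the cached sync function instead of leaving it stale. Below is a minimal, self-contained sketch of that pattern; `CachedOp`, `_generate`, and `_description` are invented names for illustration, not Lightning's API:

    import dataclasses
    from typing import Optional

    @dataclasses.dataclass
    class CachedOp:
        _op: Optional[str] = None

        def __post_init__(self) -> None:
            self._generate()

        @property
        def op(self) -> Optional[str]:
            return self._op

        @op.setter
        def op(self, op: Optional[str]) -> None:
            self._op = op
            # the cached value depends on `op`, so refresh it on every assignment
            self._generate()

        def _generate(self) -> None:
            # stand-in for _Sync._generate_sync_fn(), which caches the sync callable
            self._description = f"reduce with {self._op or 'no-op'}"

    cop = CachedOp()
    assert cop._description == "reduce with no-op"
    cop.op = "SUM"  # runs the setter, so the cached value stays consistent
    assert cop._description == "reduce with SUM"

The trade-off is that the generated `__init__` now takes the private names, which is why the call sites in this patch change to `_Sync(..., _should=..., _op=..., _group=...)`.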
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 1e981a0f543e7..5a76f402bcc02 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -16,7 +16,7 @@
 from abc import ABC
 from collections import defaultdict, OrderedDict
 from collections.abc import Mapping, Sequence
-from copy import copy
+from copy import copy, deepcopy
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
 
@@ -119,11 +119,21 @@ def apply_to_collection(
         return elem_type(*out) if is_namedtuple else elem_type(out)
 
     if _is_dataclass_instance(data):
-        out_dict = {}
+        # make a deepcopy of the data,
+        # but do not deepcopy mapped fields since the computation would
+        # be wasted on values that likely get immediately overwritten
+        fields = {}
+        memo = {}
         for field in dataclasses.fields(data):
-            if field.init:
+            field_value = getattr(data, field.name)
+            fields[field.name] = (field_value, field.init)
+            memo[id(field_value)] = field_value
+        result = deepcopy(data, memo=memo)
+        # apply function to each field
+        for field_name, (field_value, field_init) in fields.items():
+            if field_init:
                 v = apply_to_collection(
-                    getattr(data, field.name),
+                    field_value,
                     dtype,
                     function,
                     *args,
@@ -131,9 +141,10 @@ def apply_to_collection(
                     include_none=include_none,
                     **kwargs,
                 )
-            if include_none or v is not None:
-                out_dict[field.name] = v
-        return elem_type(**out_dict)
+            if not field_init or (not include_none and v is None):  # retain old value
+                v = getattr(data, field_name)
+            setattr(result, field_name, v)
+        return result
 
     # data is neither of dtype, nor a collection
     return data
diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index 0e62441b1d40e..a39ce51788ff9 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -33,7 +33,7 @@ def _setup_ddp(rank, worldsize):
 def _ddp_test_fn(rank, worldsize):
     _setup_ddp(rank, worldsize)
     tensor = torch.tensor([1.0])
-    sync = _Sync(sync_ddp_if_available, _should=True, op="SUM")
+    sync = _Sync(sync_ddp_if_available, _should=True, _op="SUM")
     actual = sync(tensor)
 
     assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index d8ceb4106fd07..31ebd3968ff3e 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -407,7 +407,7 @@ def test_tpu_sync_dist():
     """Test tpu spawn sync dist operation."""
 
     def test_sync_dist(_):
-        sync = _Sync(TPUSpawnPlugin().reduce, should=True, op=torch.distributed.ReduceOp.SUM)
+        sync = _Sync(TPUSpawnPlugin().reduce, _should=True, _op=torch.distributed.ReduceOp.SUM)
         value = torch.tensor([1.0])
         value = (sync(value),)
         assert value.item() == 8
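The heart of the fix is the `apply_to_collection` hunk above. The old code collected the mapped values of `init` fields into a dict and rebuilt the dataclass with `elem_type(**out_dict)`, which re-runs `__init__` and `__post_init__` and cannot reproduce state that is not part of the constructor signature (fields with `init=False`, values assigned from `InitVar`s, and so on). The new code deepcopies the instance, memoizing each field value so it is not copied needlessly, and then overwrites the fields in place with `setattr`, never calling the constructor. A hedged sketch of the resulting behavior; the `Example` class is invented for illustration:

    import dataclasses

    import torch

    from pytorch_lightning.utilities.apply_func import apply_to_collection

    @dataclasses.dataclass
    class Example:
        tensor: torch.Tensor
        # not an __init__ argument; assigned by __post_init__, much like
        # ModelExample.some_constant in the tests below
        constant: int = dataclasses.field(init=False)

        def __post_init__(self):
            self.constant = 7

    ex = Example(torch.tensor([1.0, 2.0]))
    ex.constant = 99  # state mutated after construction
    out = apply_to_collection(ex, torch.Tensor, lambda t: t * 2)
    assert torch.equal(out.tensor, torch.tensor([2.0, 4.0]))
    # with the old constructor round-trip, __post_init__ would reset this to 7
    assert out.constant == 99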
diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py
index da309f7d22b50..9b0fcbd643744 100644
--- a/tests/utilities/test_apply_func.py
+++ b/tests/utilities/test_apply_func.py
@@ -14,7 +14,8 @@
 import dataclasses
 import numbers
 from collections import defaultdict, namedtuple, OrderedDict
-from typing import List
+from dataclasses import InitVar
+from typing import Any, ClassVar, List, Optional
 
 import numpy as np
 import pytest
@@ -31,6 +32,12 @@ class Feature:
         input_ids: torch.Tensor
         segment_ids: np.ndarray
 
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, Feature):
+                return NotImplemented
+            else:
+                return torch.equal(self.input_ids, o.input_ids) and np.equal(self.segment_ids, o.segment_ids).all()
+
     @dataclasses.dataclass
     class ModelExample:
         example_ids: List[str]
@@ -41,6 +48,71 @@ class ModelExample:
         def __post_init__(self):
             self.some_constant = 7
 
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, ModelExample):
+                return NotImplemented
+            else:
+                return (
+                    self.example_ids == o.example_ids
+                    and self.feature == o.feature
+                    and torch.equal(self.label, o.label)
+                    and self.some_constant == o.some_constant
+                )
+
+    @dataclasses.dataclass
+    class WithClassVar:
+        class_var: ClassVar[int] = 0
+        dummy: Any
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithClassVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    @dataclasses.dataclass
+    class WithInitVar:
+        dummy: Any
+        override: InitVar[Optional[Any]] = None
+
+        def __post_init__(self, override: Optional[Any]):
+            if override is not None:
+                self.dummy = override
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithInitVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    @dataclasses.dataclass
+    class WithClassAndInitVar:
+        class_var: ClassVar[torch.Tensor] = torch.tensor(0)
+        dummy: Any
+        override: InitVar[Optional[Any]] = torch.tensor(1)
+
+        def __post_init__(self, override: Optional[Any]):
+            if override is not None:
+                self.dummy = override
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithClassAndInitVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    model_example = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
+        label=torch.tensor([7.0, 8.0, 9.0]),
+    )
+
     to_reduce = {
         "a": torch.tensor([1.0]),  # Tensor
         "b": [torch.tensor([2.0])],  # list
@@ -50,13 +122,18 @@ def __post_init__(self):
         "f": "this_is_a_dummy_str",  # string
         "g": 12.0,  # number
         "h": Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),  # dataclass
-        "i": ModelExample(
-            example_ids=["i-1", "i-2", "i-3"],
-            feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
-            label=torch.tensor([7.0, 8.0, 9.0]),
-        ),  # nested dataclass
+        "i": model_example,  # nested dataclass
+        "j": WithClassVar(torch.arange(3)),  # dataclass with class variable
+        "k": WithInitVar("this_gets_overridden", torch.tensor([2.0])),  # dataclass with init-only variable
+        "l": WithClassAndInitVar(model_example, None),  # nested dataclass with class and init-only variables
     }
 
+    model_example_result = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
+        label=torch.tensor([14.0, 16.0, 18.0]),
+    )
+
     expected_result = {
         "a": torch.tensor([2.0]),
         "b": [torch.tensor([4.0])],
@@ -66,32 +143,31 @@ def __post_init__(self):
         "f": "this_is_a_dummy_str",
         "g": 24.0,
         "h": Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
-        "i": ModelExample(
-            example_ids=["i-1", "i-2", "i-3"],
-            feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
-            label=torch.tensor([14.0, 16.0, 18.0]),
-        ),
+        "i": model_example_result,
+        "j": WithClassVar(torch.arange(0, 6, 2)),
+        "k": WithInitVar(torch.tensor([4.0])),
+        "l": WithClassAndInitVar(model_example_result, None),
     }
 
     reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)
-    assert isinstance(reduced, dict), " Type Consistency of dict not preserved"
+    assert isinstance(reduced, dict), "Type Consistency of dict not preserved"
     assert all(x in reduced for x in to_reduce), "Not all entries of the dict were preserved"
     assert all(
         isinstance(reduced[k], type(expected_result[k])) for k in to_reduce
     ), "At least one type was not correctly preserved"
 
     assert isinstance(reduced["a"], torch.Tensor), "Reduction Result of a Tensor should be a Tensor"
-    assert torch.allclose(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"
+    assert torch.equal(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"
 
     assert isinstance(reduced["b"], list), "Reduction Result of a list should be a list"
     assert all(
-        torch.allclose(x, y) for x, y in zip(reduced["b"], expected_result["b"])
+        torch.equal(x, y) for x, y in zip(reduced["b"], expected_result["b"])
    ), "At least one value of list reduction did not come out as expected"
 
     assert isinstance(reduced["c"], tuple), "Reduction Result of a tuple should be a tuple"
     assert all(
-        torch.allclose(x, y) for x, y in zip(reduced["c"], expected_result["c"])
+        torch.equal(x, y) for x, y in zip(reduced["c"], expected_result["c"])
     ), "At least one value of tuple reduction did not come out as expected"
 
     assert isinstance(reduced["d"], ntc), "Type Consistency for named tuple not given"
@@ -109,34 +185,30 @@ def __post_init__(self):
     assert isinstance(reduced["g"], numbers.Number), "Reduction of a number should result in a number"
     assert reduced["g"] == expected_result["g"], "Reduction of a number did not yield the desired result"
 
-    assert dataclasses.is_dataclass(reduced["h"]) and not isinstance(
-        reduced["h"], type
-    ), "Reduction of a dataclass should result in a dataclass"
-    assert torch.allclose(
-        reduced["h"].input_ids, expected_result["h"].input_ids
-    ), "Reduction of a dataclass did not yield the desired result"
-    assert np.allclose(
-        reduced["h"].segment_ids, expected_result["h"].segment_ids
-    ), "Reduction of a dataclass did not yield the desired result"
-
-    assert dataclasses.is_dataclass(reduced["i"]) and not isinstance(
-        reduced["i"], type
-    ), "Reduction of a dataclass should result in a dataclass"
-    assert dataclasses.is_dataclass(reduced["i"].feature) and not isinstance(
-        reduced["i"].feature, type
-    ), "Reduction of a nested dataclass should result in a nested dataclass"
-    assert (
-        reduced["i"].example_ids == expected_result["i"].example_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert torch.allclose(
-        reduced["i"].label, expected_result["i"].label
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert torch.allclose(
-        reduced["i"].feature.input_ids, expected_result["i"].feature.input_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert np.allclose(
-        reduced["i"].feature.segment_ids, expected_result["i"].feature.segment_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
+    def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):
+        assert dataclasses.is_dataclass(actual) and not isinstance(
+            actual, type
+        ), f"Reduction of a {dataclass_type} dataclass should result in a dataclass"
+        for field in dataclasses.fields(actual):
+            if dataclasses.is_dataclass(field.type):
+                _assert_dataclass_reduction(getattr(actual, field.name), getattr(expected, field.name), "nested")
+        assert actual == expected, f"Reduction of a {dataclass_type} dataclass did not yield the desired result"
+
+    _assert_dataclass_reduction(reduced["h"], expected_result["h"])
+
+    _assert_dataclass_reduction(reduced["i"], expected_result["i"])
+
+    dataclass_type = "ClassVar-containing"
+    _assert_dataclass_reduction(reduced["j"], expected_result["j"], dataclass_type)
+    assert WithClassVar.class_var == 0, f"Reduction of a {dataclass_type} dataclass should not change the class var"
+
+    _assert_dataclass_reduction(reduced["k"], expected_result["k"], "InitVar-containing")
+
+    dataclass_type = "Class-and-InitVar-containing"
+    _assert_dataclass_reduction(reduced["l"], expected_result["l"], dataclass_type)
+    assert torch.equal(
+        WithClassAndInitVar.class_var, torch.tensor(0)
+    ), f"Reduction of a {dataclass_type} dataclass should not change the class var"
 
     # mapping support
     reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x))
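A plain-Python footnote to the new tests: `dataclasses.fields()` deliberately omits `ClassVar` and `InitVar` pseudo-fields, and an `InitVar` value is handed to `__post_init__` without ever being stored on the instance, so the constructor arguments of an existing instance cannot, in general, be recovered from its fields. That is why the fix copies an existing instance instead of re-invoking the constructor. A small demonstration; `WithBoth` is an illustrative class modeled on `WithClassAndInitVar` above, not part of the patch:

    import dataclasses
    from dataclasses import InitVar
    from typing import Any, ClassVar, Optional

    @dataclasses.dataclass
    class WithBoth:
        class_var: ClassVar[int] = 0
        dummy: Any = None
        override: InitVar[Optional[Any]] = None

        def __post_init__(self, override: Optional[Any]):
            if override is not None:
                self.dummy = override

    obj = WithBoth(dummy="ignored", override="kept")
    assert [f.name for f in dataclasses.fields(obj)] == ["dummy"]  # no class_var, no override
    assert obj.dummy == "kept"  # set by __post_init__ from the InitVar
    assert "override" not in vars(obj)  # the init-only value is never stored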