Fix support for dataclasses with ClassVar/InitVar in `apply_to_collection` (#9702)

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
2 people authored and lexierule committed Nov 16, 2021
1 parent ab44b81 commit 4bc6e95
Showing 6 changed files with 159 additions and 56 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))

- Fixed an issue where class or init-only variables of dataclasses were passed to the dataclass constructor in `utilities.apply_to_collection` ([#9702](https://github.com/PyTorchLightning/pytorch-lightning/issues/9702))


## [1.5.1] - 2021-11-09
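For readers unfamiliar with the corner case, the following minimal sketch (plain Python, not Lightning code; `Item` is an illustrative name) shows how `ClassVar`/`InitVar` pseudo-fields behave and why rebuilding such a dataclass purely through its constructor is fragile:

```python
import dataclasses
from dataclasses import dataclass, InitVar
from typing import Any, ClassVar, Optional


@dataclass
class Item:
    class_var: ClassVar[int] = 0
    dummy: Any = None
    override: InitVar[Optional[Any]] = None

    def __post_init__(self, override: Optional[Any]) -> None:
        if override is not None:
            self.dummy = override


item = Item(dummy=1, override=2)
print([f.name for f in dataclasses.fields(item)])  # ['dummy'] -- pseudo-fields are not real fields
print("override" in vars(item))                     # False -- the InitVar value is consumed by __post_init__
# Item(class_var=5) raises TypeError: __init__() got an unexpected keyword argument 'class_var'
```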
26 changes: 23 additions & 3 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -51,8 +51,8 @@ class _Sync:
fn: Optional[Callable] = None
_should: bool = False
rank_zero_only: bool = False
op: Optional[str] = None
group: Optional[Any] = None
_op: Optional[str] = None
_group: Optional[Any] = None

def __post_init__(self) -> None:
self._generate_sync_fn()
@@ -67,6 +67,26 @@ def should(self, should: bool) -> None:
# `self._fn` needs to be re-generated.
self._generate_sync_fn()

@property
def op(self) -> Optional[str]:
return self._op

@op.setter
def op(self, op: Optional[str]) -> None:
self._op = op
# `self._fn` needs to be re-generated.
self._generate_sync_fn()

@property
def group(self) -> Optional[Any]:
return self._group

@group.setter
def group(self, group: Optional[Any]) -> None:
self._group = group
# `self._fn` needs to be re-generated.
self._generate_sync_fn()

def _generate_sync_fn(self) -> None:
"""Used to compute the syncing function and cache it."""
fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
@@ -426,7 +446,7 @@ def log(
dataloader_idx=dataloader_idx,
metric_attribute=metric_attribute,
)
meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)
meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, _group=sync_dist_group, rank_zero_only=rank_zero_only)

# register logged value if it doesn't exist
if key not in self:
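Side note on the diff above: because `op` and `group` now read from the private `_op`/`_group` dataclass fields through properties, call sites construct `_Sync` with the underscored keywords (see the `_op=`/`_group=` changes elsewhere in this commit). A rough, self-contained sketch of the pattern, with `Cached` and `_refresh` as illustrative stand-ins rather than Lightning API:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Cached:
    _op: Optional[str] = None

    def __post_init__(self) -> None:
        self._refresh()

    @property
    def op(self) -> Optional[str]:
        return self._op

    @op.setter
    def op(self, op: Optional[str]) -> None:
        self._op = op
        self._refresh()  # keep derived state in sync with the new value

    def _refresh(self) -> None:
        # stand-in for _Sync._generate_sync_fn(): recompute whatever is cached
        self.cached = f"reduce-{self._op}"


c = Cached(_op="sum")  # the generated __init__ takes the field name, hence the underscore
c.op = "mean"          # the property setter updates the field and refreshes the cache
assert c.cached == "reduce-mean"
```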
25 changes: 18 additions & 7 deletions pytorch_lightning/utilities/apply_func.py
@@ -16,7 +16,7 @@
from abc import ABC
from collections import defaultdict, OrderedDict
from collections.abc import Mapping, Sequence
from copy import copy
from copy import copy, deepcopy
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union

@@ -119,21 +119,32 @@ def apply_to_collection(
return elem_type(*out) if is_namedtuple else elem_type(out)

if _is_dataclass_instance(data):
out_dict = {}
# make a deepcopy of the data,
# but do not deepcopy mapped fields since the computation would
# be wasted on values that likely get immediately overwritten
fields = {}
memo = {}
for field in dataclasses.fields(data):
if field.init:
field_value = getattr(data, field.name)
fields[field.name] = (field_value, field.init)
memo[id(field_value)] = field_value
result = deepcopy(data, memo=memo)
# apply function to each field
for field_name, (field_value, field_init) in fields.items():
if field_init:
v = apply_to_collection(
getattr(data, field.name),
field_value,
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs,
)
if include_none or v is not None:
out_dict[field.name] = v
return elem_type(**out_dict)
if not field_init or (not include_none and v is None): # retain old value
v = getattr(data, field_name)
setattr(result, field_name, v)
return result

# data is neither of dtype, nor a collection
return data
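The `memo` trick above relies on how `copy.deepcopy` uses its memo dictionary: an entry keyed by `id(obj)` is treated as already copied and returned as-is, so the mapped field values are not duplicated only to be overwritten. A small stand-alone illustration (not Lightning code):

```python
from copy import deepcopy

big = [0] * 1_000_000
container = {"payload": big, "flag": True}

# Seed the memo so `big` is treated as "already copied" and reused as-is.
memo = {id(big): big}
copied = deepcopy(container, memo=memo)

assert copied is not container   # the outer structure is still copied
assert copied["payload"] is big  # the memoized value was not duplicated
```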
2 changes: 1 addition & 1 deletion tests/core/test_results.py
@@ -33,7 +33,7 @@ def _setup_ddp(rank, worldsize):
def _ddp_test_fn(rank, worldsize):
_setup_ddp(rank, worldsize)
tensor = torch.tensor([1.0])
sync = _Sync(sync_ddp_if_available, _should=True, op="SUM")
sync = _Sync(sync_ddp_if_available, _should=True, _op="SUM")
actual = sync(tensor)
assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"

2 changes: 1 addition & 1 deletion tests/models/test_tpu.py
@@ -407,7 +407,7 @@ def test_tpu_sync_dist():
"""Test tpu spawn sync dist operation."""

def test_sync_dist(_):
sync = _Sync(TPUSpawnPlugin().reduce, should=True, op=torch.distributed.ReduceOp.SUM)
sync = _Sync(TPUSpawnPlugin().reduce, should=True, _op=torch.distributed.ReduceOp.SUM)
value = torch.tensor([1.0])
value = (sync(value),)
assert value.item() == 8
158 changes: 115 additions & 43 deletions tests/utilities/test_apply_func.py
@@ -14,7 +14,8 @@
import dataclasses
import numbers
from collections import defaultdict, namedtuple, OrderedDict
from typing import List
from dataclasses import InitVar
from typing import Any, ClassVar, List, Optional

import numpy as np
import pytest
@@ -31,6 +32,12 @@ class Feature:
input_ids: torch.Tensor
segment_ids: np.ndarray

def __eq__(self, o: object) -> bool:
if not isinstance(o, Feature):
return NotImplemented
else:
return torch.equal(self.input_ids, o.input_ids) and np.equal(self.segment_ids, o.segment_ids).all()

@dataclasses.dataclass
class ModelExample:
example_ids: List[str]
@@ -41,6 +48,71 @@ class ModelExample:
def __post_init__(self):
self.some_constant = 7

def __eq__(self, o: object) -> bool:
if not isinstance(o, ModelExample):
return NotImplemented
else:
return (
self.example_ids == o.example_ids
and self.feature == o.feature
and torch.equal(self.label, o.label)
and self.some_constant == o.some_constant
)

@dataclasses.dataclass
class WithClassVar:
class_var: ClassVar[int] = 0
dummy: Any

def __eq__(self, o: object) -> bool:
if not isinstance(o, WithClassVar):
return NotImplemented
elif isinstance(self.dummy, torch.Tensor):
return torch.equal(self.dummy, o.dummy)
else:
return self.dummy == o.dummy

@dataclasses.dataclass
class WithInitVar:
dummy: Any
override: InitVar[Optional[Any]] = None

def __post_init__(self, override: Optional[Any]):
if override is not None:
self.dummy = override

def __eq__(self, o: object) -> bool:
if not isinstance(o, WithInitVar):
return NotImplemented
elif isinstance(self.dummy, torch.Tensor):
return torch.equal(self.dummy, o.dummy)
else:
return self.dummy == o.dummy

@dataclasses.dataclass
class WithClassAndInitVar:
class_var: ClassVar[torch.Tensor] = torch.tensor(0)
dummy: Any
override: InitVar[Optional[Any]] = torch.tensor(1)

def __post_init__(self, override: Optional[Any]):
if override is not None:
self.dummy = override

def __eq__(self, o: object) -> bool:
if not isinstance(o, WithClassAndInitVar):
return NotImplemented
elif isinstance(self.dummy, torch.Tensor):
return torch.equal(self.dummy, o.dummy)
else:
return self.dummy == o.dummy

model_example = ModelExample(
example_ids=["i-1", "i-2", "i-3"],
feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
label=torch.tensor([7.0, 8.0, 9.0]),
)

to_reduce = {
"a": torch.tensor([1.0]), # Tensor
"b": [torch.tensor([2.0])], # list
@@ -50,13 +122,18 @@ def __post_init__(self):
"f": "this_is_a_dummy_str", # string
"g": 12.0, # number
"h": Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])), # dataclass
"i": ModelExample(
example_ids=["i-1", "i-2", "i-3"],
feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
label=torch.tensor([7.0, 8.0, 9.0]),
), # nested dataclass
"i": model_example, # nested dataclass
"j": WithClassVar(torch.arange(3)), # dataclass with class variable
"k": WithInitVar("this_gets_overridden", torch.tensor([2.0])), # dataclass with init-only variable
"l": WithClassAndInitVar(model_example, None), # nested dataclass with class and init-only variables
}

model_example_result = ModelExample(
example_ids=["i-1", "i-2", "i-3"],
feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
label=torch.tensor([14.0, 16.0, 18.0]),
)

expected_result = {
"a": torch.tensor([2.0]),
"b": [torch.tensor([4.0])],
@@ -66,32 +143,31 @@ def __post_init__(self):
"f": "this_is_a_dummy_str",
"g": 24.0,
"h": Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
"i": ModelExample(
example_ids=["i-1", "i-2", "i-3"],
feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
label=torch.tensor([14.0, 16.0, 18.0]),
),
"i": model_example_result,
"j": WithClassVar(torch.arange(0, 6, 2)),
"k": WithInitVar(torch.tensor([4.0])),
"l": WithClassAndInitVar(model_example_result, None),
}

reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)

assert isinstance(reduced, dict), " Type Consistency of dict not preserved"
assert isinstance(reduced, dict), "Type Consistency of dict not preserved"
assert all(x in reduced for x in to_reduce), "Not all entries of the dict were preserved"
assert all(
isinstance(reduced[k], type(expected_result[k])) for k in to_reduce
), "At least one type was not correctly preserved"

assert isinstance(reduced["a"], torch.Tensor), "Reduction Result of a Tensor should be a Tensor"
assert torch.allclose(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"
assert torch.equal(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"

assert isinstance(reduced["b"], list), "Reduction Result of a list should be a list"
assert all(
torch.allclose(x, y) for x, y in zip(reduced["b"], expected_result["b"])
torch.equal(x, y) for x, y in zip(reduced["b"], expected_result["b"])
), "At least one value of list reduction did not come out as expected"

assert isinstance(reduced["c"], tuple), "Reduction Result of a tuple should be a tuple"
assert all(
torch.allclose(x, y) for x, y in zip(reduced["c"], expected_result["c"])
torch.equal(x, y) for x, y in zip(reduced["c"], expected_result["c"])
), "At least one value of tuple reduction did not come out as expected"

assert isinstance(reduced["d"], ntc), "Type Consistency for named tuple not given"
@@ -109,34 +185,30 @@ def __post_init__(self):
assert isinstance(reduced["g"], numbers.Number), "Reduction of a number should result in a number"
assert reduced["g"] == expected_result["g"], "Reduction of a number did not yield the desired result"

assert dataclasses.is_dataclass(reduced["h"]) and not isinstance(
reduced["h"], type
), "Reduction of a dataclass should result in a dataclass"
assert torch.allclose(
reduced["h"].input_ids, expected_result["h"].input_ids
), "Reduction of a dataclass did not yield the desired result"
assert np.allclose(
reduced["h"].segment_ids, expected_result["h"].segment_ids
), "Reduction of a dataclass did not yield the desired result"

assert dataclasses.is_dataclass(reduced["i"]) and not isinstance(
reduced["i"], type
), "Reduction of a dataclass should result in a dataclass"
assert dataclasses.is_dataclass(reduced["i"].feature) and not isinstance(
reduced["i"].feature, type
), "Reduction of a nested dataclass should result in a nested dataclass"
assert (
reduced["i"].example_ids == expected_result["i"].example_ids
), "Reduction of a nested dataclass did not yield the desired result"
assert torch.allclose(
reduced["i"].label, expected_result["i"].label
), "Reduction of a nested dataclass did not yield the desired result"
assert torch.allclose(
reduced["i"].feature.input_ids, expected_result["i"].feature.input_ids
), "Reduction of a nested dataclass did not yield the desired result"
assert np.allclose(
reduced["i"].feature.segment_ids, expected_result["i"].feature.segment_ids
), "Reduction of a nested dataclass did not yield the desired result"
def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):
assert dataclasses.is_dataclass(actual) and not isinstance(
actual, type
), f"Reduction of a {dataclass_type} dataclass should result in a dataclass"
for field in dataclasses.fields(actual):
if dataclasses.is_dataclass(field.type):
_assert_dataclass_reduction(getattr(actual, field.name), getattr(expected, field.name), "nested")
assert actual == expected, f"Reduction of a {dataclass_type} dataclass did not yield the desired result"

_assert_dataclass_reduction(reduced["h"], expected_result["h"])

_assert_dataclass_reduction(reduced["i"], expected_result["i"])

dataclass_type = "ClassVar-containing"
_assert_dataclass_reduction(reduced["j"], expected_result["j"], dataclass_type)
assert WithClassVar.class_var == 0, f"Reduction of a {dataclass_type} dataclass should not change the class var"

_assert_dataclass_reduction(reduced["k"], expected_result["k"], "InitVar-containing")

dataclass_type = "Class-and-InitVar-containing"
_assert_dataclass_reduction(reduced["l"], expected_result["l"], dataclass_type)
assert torch.equal(
WithClassAndInitVar.class_var, torch.tensor(0)
), f"Reduction of a {dataclass_type} dataclass should not change the class var"

# mapping support
reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x))
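A rough usage sketch of the fixed behaviour, reusing the `WithInitVar` test dataclass defined above; this assumes the patched `apply_to_collection`, and the assertion reflects the intended behaviour rather than output verified here:

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

item = WithInitVar(dummy=torch.tensor([1.0, 2.0]))
doubled = apply_to_collection(item, torch.Tensor, lambda t: t * 2)

assert torch.equal(doubled.dummy, torch.tensor([2.0, 4.0]))  # init fields are mapped
# The InitVar pseudo-field is left alone instead of being passed back into the
# dataclass constructor, which is what #9702 fixed.
```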
