Skip to content

Commit

Permalink
Allow frozen dataclasses in apply_to_collection (#98)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka <[email protected]>
  • Loading branch information
janEbert and Borda authored Feb 7, 2023
1 parent d4ee2ca commit 9eab563
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Allow frozen dataclasses in `apply_to_collection` ([#98](https://github.com/Lightning-AI/utilities/pull/98))


### Changed
Expand Down
24 changes: 22 additions & 2 deletions src/lightning_utilities/core/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def apply_to_collection(
*args: Any,
wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
include_none: bool = True,
allow_frozen: bool = False,
**kwargs: Any,
) -> Any:
"""Recursively applies a function to all elements of a certain dtype.
Expand All @@ -37,6 +38,7 @@ def apply_to_collection(
wrong_dtype: the given function won't be applied if this type is specified and the given collections
is of the ``wrong_dtype`` even if it is of type ``dtype``
include_none: Whether to include an element if the output of ``function`` is ``None``.
allow_frozen: Whether not to error upon encountering a frozen dataclass instance.
**kwargs: keyword arguments (will be forwarded to calls of ``function``)
Returns:
Expand All @@ -53,7 +55,14 @@ def apply_to_collection(
out = []
for k, v in data.items():
v = apply_to_collection(
v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
v,
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
allow_frozen=allow_frozen,
**kwargs,
)
if include_none or v is not None:
out.append((k, v))
Expand All @@ -67,7 +76,14 @@ def apply_to_collection(
out = []
for d in data:
v = apply_to_collection(
d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
d,
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
allow_frozen=allow_frozen,
**kwargs,
)
if include_none or v is not None:
out.append(v)
Expand Down Expand Up @@ -95,13 +111,17 @@ def apply_to_collection(
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
allow_frozen=allow_frozen,
**kwargs,
)
if not field_init or (not include_none and v is None): # retain old value
v = getattr(data, field_name)
try:
setattr(result, field_name, v)
except dataclasses.FrozenInstanceError as e:
if allow_frozen:
# Quit early if we encounter a frozen data class; return `result` as is.
break
raise ValueError(
"A frozen dataclass was passed to `apply_to_collection` but this is not allowed."
) from e
Expand Down
10 changes: 10 additions & 0 deletions tests/unittests/core/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,13 @@ class Foo:
foo = Foo(0)
with pytest.raises(ValueError, match="frozen dataclass was passed"):
apply_to_collection(foo, int, lambda x: x + 1)


def test_apply_to_collection_allow_frozen_dataclass():
@dataclasses.dataclass(frozen=True)
class Foo:
input: int

foo = Foo(0)
result = apply_to_collection(foo, int, lambda x: x + 1, allow_frozen=True)
assert foo == result

0 comments on commit 9eab563

Please sign in to comment.