Skip to content

Commit

Permalink
Ignore parameters causing ValueError when dumping to YAML (#19804)
Browse files Browse the repository at this point in the history
  • Loading branch information
Callidior authored Jun 6, 2024
1 parent 4f96c83 commit 5fa32d9
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue with the LightningCLI not being able to set the `ModelCheckpoint(save_last=...)` argument ([#19808](https://github.com/Lightning-AI/pytorch-lightning/pull/19808))

- Fixed an issue causing ValueError for certain object such as TorchMetrics when dumping hyperparameters to YAML ([#19804](https://github.com/Lightning-AI/pytorch-lightning/pull/19804))


## [2.2.2] - 2024-04-11

Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], us
try:
v = v.name if isinstance(v, Enum) else v
yaml.dump(v)
except TypeError:
except (TypeError, ValueError):
warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
hparams[k] = type(v).__name__
else:
Expand Down
10 changes: 9 additions & 1 deletion tests/tests_pytorch/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def test_hparams_pickle_warning(tmp_path):
trainer.fit(model)


def test_hparams_save_yaml(tmp_path):
def test_save_hparams_to_yaml(tmp_path):
class Options(str, Enum):
option1name = "option1val"
option2name = "option2val"
Expand Down Expand Up @@ -590,6 +590,14 @@ def _compare_params(loaded_params, default_params: dict):
_compare_params(load_hparams_from_yaml(path_yaml), hparams)


def test_save_hparams_to_yaml_warning(tmp_path):
"""Test that we warn about unserializable parameters that need to be dropped."""
path_yaml = tmp_path / "hparams.yaml"
hparams = {"torch_type": torch.float32}
with pytest.warns(UserWarning, match="Skipping 'torch_type' parameter"):
save_hparams_to_yaml(path_yaml, hparams)


class NoArgsSubClassBoringModel(CustomBoringModel):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 5fa32d9

Please sign in to comment.