diff --git a/CHANGELOG.md b/CHANGELOG.md
index 700e18924b963..6453989f72136 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -226,6 +226,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))
 
+- Added `use_omegaconf` argument to `save_hparams_to_yaml` ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
+
+
 - Added `ckpt_path` argument for `trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
 
 
@@ -242,6 +245,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   Old [neptune-client](https://github.com/neptune-ai/neptune-client) API is supported by `NeptuneClient` from [neptune-contrib](https://github.com/neptune-ai/neptune-contrib) repo.
 
+- Parsing of `Enum`-typed hyperparameters saved to the `hparams.yaml` file by the TensorBoard and CSV loggers has been fixed to match how omegaconf serializes them. ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
+
+
 - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
 
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 782fbe99c0425..2f9463e42fec8 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -19,6 +19,7 @@
 import os
 from argparse import Namespace
 from copy import deepcopy
+from enum import Enum
 from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
 from warnings import warn
 
@@ -318,8 +319,8 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict
 
     Args:
         config_yaml: Path to config yaml file
-        use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
-            the hparams will be converted to `DictConfig` if possible
+        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
+            the hparams will be converted to ``DictConfig`` if possible.
 
     >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
     >>> path_yaml = './testing-hparams.yaml'
@@ -346,11 +347,14 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict
     return hparams
 
 
-def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
+def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None:
     """
     Args:
         config_yaml: path to new YAML file
         hparams: parameters to be saved
+        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
+            the hparams will be converted to ``DictConfig`` if possible.
+
     """
     fs = get_filesystem(config_yaml)
     if not fs.isdir(os.path.dirname(config_yaml)):
@@ -363,7 +367,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         hparams = dict(hparams)
 
     # saving with OmegaConf objects
-    if _OMEGACONF_AVAILABLE:
+    if _OMEGACONF_AVAILABLE and use_omegaconf:
         # deepcopy: hparams from user shouldn't be resolved
         hparams = deepcopy(hparams)
         hparams = apply_to_collection(hparams, DictConfig, OmegaConf.to_container, resolve=True)
@@ -381,6 +385,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
     # drop paramaters which contain some strange datatypes as fsspec
     for k, v in hparams.items():
         try:
+            v = v.name if isinstance(v, Enum) else v
             yaml.dump(v)
         except TypeError:
             warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index a95bdf9ba76d7..dbd51d33bf0ed 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -17,6 +17,7 @@
 import pickle
 from argparse import Namespace
 from dataclasses import dataclass
+from enum import Enum
 from unittest import mock
 
 import cloudpickle
@@ -24,6 +25,7 @@
 import torch
 from fsspec.implementations.local import LocalFileSystem
 from omegaconf import Container, OmegaConf
+from omegaconf.dictconfig import DictConfig
 from torch.utils.data import DataLoader
 
 from pytorch_lightning import LightningModule, Trainer
@@ -477,22 +479,40 @@ def test_hparams_pickle_warning(tmpdir):
 
 
 def test_hparams_save_yaml(tmpdir):
+    class Options(str, Enum):
+        option1name = "option1val"
+        option2name = "option2val"
+        option3name = "option3val"
+
     hparams = dict(
-        batch_size=32, learning_rate=0.001, data_root="./any/path/here", nasted=dict(any_num=123, anystr="abcd")
+        batch_size=32,
+        learning_rate=0.001,
+        data_root="./any/path/here",
+        nested=dict(any_num=123, anystr="abcd"),
+        switch=Options.option3name,
     )
     path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
 
+    def _compare_params(loaded_params, default_params: dict):
+        assert isinstance(loaded_params, (dict, DictConfig))
+        assert loaded_params.keys() == default_params.keys()
+        for k, v in default_params.items():
+            if isinstance(v, Enum):
+                assert v.name == loaded_params[k]
+            else:
+                assert v == loaded_params[k]
+
     save_hparams_to_yaml(path_yaml, hparams)
-    assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
+    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
 
     save_hparams_to_yaml(path_yaml, Namespace(**hparams))
-    assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
+    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
 
     save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
-    assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
+    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
 
     save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
-    assert load_hparams_from_yaml(path_yaml) == hparams
+    _compare_params(load_hparams_from_yaml(path_yaml), hparams)
 
 
 class NoArgsSubClassBoringModel(CustomBoringModel):