Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LightningCLI save config multifile option required to support save to fsspec filesystem #9073

Merged
merged 7 commits into from
Sep 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))


- Added `multifile` option to `LightningCLI` to enable/disable config save to preserve multiple files structure ([#9073](https://github.com/PyTorchLightning/pytorch-lightning/pull/9073))


- Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546))


Expand Down
22 changes: 20 additions & 2 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,13 @@ def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: Lis
class SaveConfigCallback(Callback):
"""Saves a LightningCLI config to the log_dir when training starts.

Args:
parser: The parser object used to parse the configuration.
config: The parsed configuration that will be saved.
config_filename: Filename for the config file.
overwrite: Whether to overwrite an existing config file.
multifile: When input is multiple config files, saved config preserves this structure.

Raises:
RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
"""
Expand All @@ -348,11 +355,13 @@ def __init__(
config: Union[Namespace, Dict[str, Any]],
config_filename: str,
overwrite: bool = False,
multifile: bool = False,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
self.parser = parser
self.config = config
self.config_filename = config_filename
self.overwrite = overwrite
self.multifile = multifile

def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
# save the config in `setup` because (1) we want it to save regardless of the trainer function run
Expand All @@ -372,7 +381,9 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st
# the `log_dir` needs to be created as we rely on the logger to do it usually
# but it hasn't logged anything at this point
get_filesystem(log_dir).makedirs(log_dir, exist_ok=True)
self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite)
self.parser.save(
self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
)

def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]:
# `ArgumentParser` is un-pickleable. Drop it
Expand All @@ -389,6 +400,7 @@ def __init__(
save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
save_config_filename: str = "config.yaml",
save_config_overwrite: bool = False,
save_config_multifile: bool = False,
trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
trainer_defaults: Optional[Dict[str, Any]] = None,
seed_everything_default: Optional[int] = None,
Expand Down Expand Up @@ -420,6 +432,7 @@ def __init__(
save_config_callback: A callback class to save the training config.
save_config_filename: Filename for the config file.
save_config_overwrite: Whether to overwrite an existing config file.
save_config_multifile: When input is multiple config files, saved config preserves this structure.
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
Expand All @@ -443,6 +456,7 @@ def __init__(
self.save_config_callback = save_config_callback
self.save_config_filename = save_config_filename
self.save_config_overwrite = save_config_overwrite
self.save_config_multifile = save_config_multifile
self.trainer_class = trainer_class
self.trainer_defaults = trainer_defaults or {}
self.seed_everything_default = seed_everything_default
Expand Down Expand Up @@ -615,7 +629,11 @@ def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]
config["callbacks"].append(self.trainer_defaults["callbacks"])
if self.save_config_callback and not config["fast_dev_run"]:
config_callback = self.save_config_callback(
self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite
self.parser,
self.config,
self.save_config_filename,
overwrite=self.save_config_overwrite,
multifile=self.save_config_multifile,
)
config["callbacks"].append(config_callback)
return self.trainer_class(**config)
Expand Down