diff --git a/CHANGELOG.md b/CHANGELOG.md index c66330725d6a1..667c1ae8c8729 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -141,6 +141,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389)) +- Added `multifile` option to `LightningCLI` to enable/disable config save to preserve multiple files structure ([#9073](https://github.com/PyTorchLightning/pytorch-lightning/pull/9073)) + + - Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546)) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index d97ef9ccddebb..6007dad4478f4 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -338,6 +338,13 @@ def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: Lis class SaveConfigCallback(Callback): """Saves a LightningCLI config to the log_dir when training starts. + Args: + parser: The parser object used to parse the configuration. + config: The parsed configuration that will be saved. + config_filename: Filename for the config file. + overwrite: Whether to overwrite an existing config file. + multifile: When input is multiple config files, saved config preserves this structure.
+ Raises: RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run """ @@ -348,11 +355,13 @@ def __init__( config: Union[Namespace, Dict[str, Any]], config_filename: str, overwrite: bool = False, + multifile: bool = False, ) -> None: self.parser = parser self.config = config self.config_filename = config_filename self.overwrite = overwrite + self.multifile = multifile def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: # save the config in `setup` because (1) we want it to save regardless of the trainer function run @@ -372,7 +381,9 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st # the `log_dir` needs to be created as we rely on the logger to do it usually # but it hasn't logged anything at this point get_filesystem(log_dir).makedirs(log_dir, exist_ok=True) - self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite) + self.parser.save( + self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile + ) def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]: # `ArgumentParser` is un-pickleable. Drop it @@ -389,6 +400,7 @@ def __init__( save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback, save_config_filename: str = "config.yaml", save_config_overwrite: bool = False, + save_config_multifile: bool = False, trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer, trainer_defaults: Optional[Dict[str, Any]] = None, seed_everything_default: Optional[int] = None, @@ -420,6 +432,7 @@ def __init__( save_config_callback: A callback class to save the training config. save_config_filename: Filename for the config file. save_config_overwrite: Whether to overwrite an existing config file. + save_config_multifile: When input is multiple config files, saved config preserves this structure.
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called. trainer_defaults: Set to override Trainer defaults or add persistent callbacks. @@ -443,6 +456,7 @@ def __init__( self.save_config_callback = save_config_callback self.save_config_filename = save_config_filename self.save_config_overwrite = save_config_overwrite + self.save_config_multifile = save_config_multifile self.trainer_class = trainer_class self.trainer_defaults = trainer_defaults or {} self.seed_everything_default = seed_everything_default @@ -615,7 +629,11 @@ def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback] config["callbacks"].append(self.trainer_defaults["callbacks"]) if self.save_config_callback and not config["fast_dev_run"]: config_callback = self.save_config_callback( - self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite + self.parser, + self.config, + self.save_config_filename, + overwrite=self.save_config_overwrite, + multifile=self.save_config_multifile, ) config["callbacks"].append(config_callback) return self.trainer_class(**config)