diff --git a/cli/conf/pretrain/default.yaml b/cli/conf/pretrain/default.yaml
index 8bb5e0d..ce28bfc 100644
--- a/cli/conf/pretrain/default.yaml
+++ b/cli/conf/pretrain/default.yaml
@@ -39,6 +39,13 @@ trainer:
       mode: max
       save_top_k: -1
       every_n_epochs: ${floordiv:${trainer.max_epochs},10}
+    - _target_: uni2ts.callbacks.HuggingFaceCheckpoint.HuggingFaceCheckpoint
+      dirpath: ${hydra:runtime.output_dir}/HF_checkpoints
+      filename: last
+      monitor: epoch
+      mode: max
+      save_top_k: 1
+      every_n_epochs: 1
   # epoch-based training provides averaged metrics
   # cannot use max_steps with epoch-based training - resume from checkpoint on wrong epoch
   max_epochs: 1_000
diff --git a/src/uni2ts/callbacks/HuggingFaceCheckpoint.py b/src/uni2ts/callbacks/HuggingFaceCheckpoint.py
index ba0bdf4..105771c 100644
--- a/src/uni2ts/callbacks/HuggingFaceCheckpoint.py
+++ b/src/uni2ts/callbacks/HuggingFaceCheckpoint.py
@@ -36,7 +36,6 @@
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
 
-
 _PATH = Union[str, Path]
 
 
@@ -73,8 +72,13 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
         pl_module = trainer.model
         pretrain_module = pl_module.module
 
-        if hasattr(pretrain_module, "module"):
+        try:
             moirai_module = pretrain_module.module
+        except AttributeError:
+            moirai_module = pretrain_module
+            warnings.warn(
+                "Warning: no module attribute found in the model. Saving the model directly."
+            )
 
         # filepath in pytorch lightning usually ends with .ckpt
         # To get the directory to save the model, remove the .ckpt
@@ -82,7 +86,11 @@
             save_dir = filepath.split(".ckpt")[0]
         else:
             save_dir = filepath
-        moirai_module.save_pretrained(save_dir)
+
+        try:
+            moirai_module.save_pretrained(save_dir)
+        except Exception as e:
+            warnings.warn(f"An error occurred during model saving: {e}")
 
         self._last_global_step_saved = trainer.global_step
         self._last_checkpoint_saved = save_dir