Skip to content

Commit

Permalink
revise HF callback and add to default config
Browse files Browse the repository at this point in the history
  • Loading branch information
liu-jc committed Aug 22, 2024
1 parent 1bf15e9 commit d38b388
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
7 changes: 7 additions & 0 deletions cli/conf/pretrain/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ trainer:
mode: max
save_top_k: -1
every_n_epochs: ${floordiv:${trainer.max_epochs},10}
- _target_: uni2ts.callbacks.HuggingFaceCheckpoint.HuggingFaceCheckpoint
dirpath: ${hydra:runtime.output_dir}/HF_checkpoints
filename: last
monitor: epoch
mode: max
save_top_k: 1
every_n_epochs: 1
# epoch-based training provides averaged metrics
# cannot use max_steps with epoch-based training - resume from checkpoint on wrong epoch
max_epochs: 1_000
Expand Down
14 changes: 11 additions & 3 deletions src/uni2ts/callbacks/HuggingFaceCheckpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
log = logging.getLogger(__name__)
warning_cache = WarningCache()


_PATH = Union[str, Path]


Expand Down Expand Up @@ -73,16 +72,25 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
pl_module = trainer.model
pretrain_module = pl_module.module

if hasattr(pretrain_module, "module"):
try:
moirai_module = pretrain_module.module
except AttributeError:
moirai_module = pretrain_module
warnings.warn(
"Warning: no module attribute found in the model. Saving the model directly."
)

# filepath in pytorch lightning usually ends with .ckpt
# To get the directory to save the model, remove the .ckpt
if filepath.endswith(".ckpt"):
save_dir = filepath.split(".ckpt")[0]
else:
save_dir = filepath
moirai_module.save_pretrained(save_dir)

try:
moirai_module.save_pretrained(save_dir)
except Exception as e:
warnings.warn(f"An error occurred during model saving: {e}")

self._last_global_step_saved = trainer.global_step
self._last_checkpoint_saved = save_dir
Expand Down

0 comments on commit d38b388

Please sign in to comment.