Skip to content

Commit

Permalink
Save checkpoint to temporary folder first
Browse files Browse the repository at this point in the history
Since partial or missing files due to failures throw an error during load
  • Loading branch information
SilverSoldier committed Jan 23, 2025
1 parent 8e4cedd commit 81175b6
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import tempfile
import time
import warnings
import errno
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
Expand Down Expand Up @@ -3210,31 +3211,41 @@ def _save_checkpoint(self, model, trial):
self.store_flos()

run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir, _internal_call=True)
checkpoint_dir = os.path.join(run_dir, checkpoint_folder)
with tempfile.TemporaryDirectory(prefix=f"tmp-{PREFIX_CHECKPOINT_DIR}-", dir=run_dir) as output_dir:
self.save_model(output_dir, _internal_call=True)

if not self.args.save_only_model:
# Save optimizer and scheduler
self._save_optimizer_and_scheduler(output_dir)
# Save RNG state
self._save_rng_state(output_dir)
if not self.args.save_only_model:
# Save optimizer and scheduler
self._save_optimizer_and_scheduler(output_dir)
# Save RNG state
self._save_rng_state(output_dir)

# Save the Trainer state
if self.args.should_save:
# Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
for cb in [
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
]:
cb_name = cb.__class__.__name__
cb_state = cb.state()
if isinstance(self.state.stateful_callbacks[cb_name], list):
self.state.stateful_callbacks[cb_name].append(cb_state)
# Save the Trainer state
if self.args.should_save:
# Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
for cb in [
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
]:
cb_name = cb.__class__.__name__
cb_state = cb.state()
if isinstance(self.state.stateful_callbacks[cb_name], list):
self.state.stateful_callbacks[cb_name].append(cb_state)
else:
self.state.stateful_callbacks[cb_name] = cb_state
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

try:
os.renames(output_dir, checkpoint_dir)
except OSError as e:
if e.errno in [errno.ENOTEMPTY, errno.EEXIST]: # Directory/File already exists
shutil.rmtree(checkpoint_dir)
os.renames(output_dir, checkpoint_dir)
else:
self.state.stateful_callbacks[cb_name] = cb_state
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
raise

if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
self._push_from_checkpoint(checkpoint_dir)

# Maybe delete some older checkpoints.
if self.args.should_save:
Expand Down

0 comments on commit 81175b6

Please sign in to comment.