Support save/load ckpt for XLA FSDP (#32311)
* Support save/load ckpt for XLA FSDP

* Fix bug for save

* Fix style

* preserve sharded ckpt and better file naming

* minor fix

Co-authored-by: Zach Mueller <[email protected]>

* add is_fsdp_xla_v1_enabled

---------

Co-authored-by: Zach Mueller <[email protected]>
yitongh and muellerzr authored Aug 19, 2024
1 parent f1b720e commit 8a4857c
Showing 1 changed file with 66 additions and 4 deletions.
70 changes: 66 additions & 4 deletions src/transformers/trainer.py
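For context (not part of this diff): a minimal, hedged sketch of how XLA FSDP v1 is typically switched on through `TrainingArguments`, which is the configuration the new `is_fsdp_xla_v1_enabled` code paths below depend on. The argument names come from transformers' public API; the exact values are illustrative assumptions.

# Sketch (assumption, not part of this diff): enabling XLA FSDP v1 so that the
# new save/load checkpoint paths are exercised.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="xla_fsdp_out",
    fsdp="full_shard",              # shard parameters, gradients and optimizer states
    fsdp_config={"xla": True},      # route FSDP through torch_xla (v1 wrapper);
                                    # "xla_fsdp_v2": True would select the v2/SPMD path instead
    save_steps=500,
)
# trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
# trainer.train()  # periodic checkpoints now include rank*-of-* shard files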
@@ -702,6 +702,7 @@ def __init__(
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+        self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
 
     def _activate_neftune(self, model):
         r"""
@@ -3002,7 +3003,20 @@ def _save_rng_state(self, output_dir):
     def _save_optimizer_and_scheduler(self, output_dir):
         if is_torch_xla_available():
             xm.rendezvous("saving_optimizer_states")
-            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            if self.is_fsdp_xla_v1_enabled:
+                optm = {
+                    "optimizer": self.optimizer.state_dict(),
+                    "shard_metadata": self.model.get_shard_metadata(),
+                }
+                xm.save(
+                    optm,
+                    os.path.join(
+                        output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                    ),
+                    master_only=False,
+                )
+            else:
+                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
             with warnings.catch_warnings(record=True) as caught_warnings:
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
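With `master_only=False`, every XLA process writes its own optimizer shard together with the shard metadata needed to work with it later. A hedged illustration of the resulting files for a 4-device run (assumes `OPTIMIZER_NAME` is `"optimizer.pt"`, its value in current transformers):

# Sketch (not part of this diff): file names produced by the save path above for 4 devices.
import os

output_dir = "checkpoint-500"    # hypothetical checkpoint directory
world_size = 4
OPTIMIZER_NAME = "optimizer.pt"  # assumption: matches the transformers constant

shard_paths = [
    os.path.join(output_dir, f"rank{rank}-of-{world_size}-{OPTIMIZER_NAME}")
    for rank in range(world_size)
]
print(shard_paths)
# ['checkpoint-500/rank0-of-4-optimizer.pt', ..., 'checkpoint-500/rank3-of-4-optimizer.pt']
# Each file holds {"optimizer": <this rank's sharded state>, "shard_metadata": <FSDP shard info>}.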
@@ -3080,11 +3094,26 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 )
             )
         )
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
+            if self.is_fsdp_xla_v1_enabled
+            else checkpoint_file_exists
+        )
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_xla_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
-                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+                if self.is_fsdp_xla_v1_enabled:
+                    optimizer_state = torch.load(
+                        os.path.join(
+                            checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                        ),
+                        map_location="cpu",
+                    )
+                    # We only need `optimizer` when resuming from checkpoint
+                    optimizer_state = optimizer_state["optimizer"]
+                else:
+                    optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                 reissue_pt_warnings(caught_warnings)
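On resume each rank reads back only its own shard and keeps just the `optimizer` entry; per the diff's own comment, the `shard_metadata` stored alongside it is not consumed here. A hedged sketch of that read-back step in isolation (placeholder values; the actual `load_state_dict` call sits further down in `_load_optimizer_and_scheduler`, outside this hunk):

# Sketch (not part of this diff): reading one rank's optimizer shard back on CPU.
import os
import torch

checkpoint = "checkpoint-500"     # placeholder resume directory
process_index, world_size = 0, 4  # placeholder rank / total ranks
OPTIMIZER_NAME = "optimizer.pt"   # assumption: matches the transformers constant

path = os.path.join(checkpoint, f"rank{process_index}-of-{world_size}-{OPTIMIZER_NAME}")
saved = torch.load(path, map_location="cpu")  # {"optimizer": ..., "shard_metadata": ...}
optimizer_state = saved["optimizer"]
# Later (outside this hunk) the Trainer moves this state to the XLA device and restores it
# with self.optimizer.load_state_dict(optimizer_state).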
@@ -3499,15 +3528,48 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         model = self.model
         xm.mark_step()
 
-        if xm.is_master_ordinal():
+        if xm.is_master_ordinal(local=False):
             os.makedirs(output_dir, exist_ok=True)
             torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
         supported_classes = (PushToHubMixin,)
         xm.rendezvous("saving_checkpoint")
-        if not isinstance(model, supported_classes):
+        if self.is_fsdp_xla_v1_enabled:
+            ckpt = {
+                "model": model.state_dict(),
+                "shard_metadata": model.get_shard_metadata(),
+            }
+            ckpt_path = os.path.join(
+                output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}"
+            )
+            # All ranks save sharded checkpoint
+            xm.save(ckpt, ckpt_path, master_only=False)
+            # Make sure all ranks have saved checkpoints
+            xm.rendezvous("save_full_checkpoints")
+            # Master save full checkpoint
+            if self.args.should_save:
+                from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+
+                full_state_dict, _ = consolidate_sharded_model_checkpoints(
+                    ckpt_prefix=os.path.join(output_dir, ""),
+                    ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
+                    save_model=False,
+                )
+                model = model.module.module
+                unwrapped_model = self.accelerator.unwrap_model(model)
+                if isinstance(unwrapped_model, supported_classes):
+                    unwrapped_model.save_pretrained(
+                        output_dir,
+                        state_dict=full_state_dict,
+                        save_function=xm.save,
+                        safe_serialization=self.args.save_safetensors,
+                    )
+                else:
+                    logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                    xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        elif not isinstance(model, supported_classes):
             if isinstance(self.accelerator.unwrap_model(model), supported_classes):
                 self.accelerator.unwrap_model(model).save_pretrained(
                     output_dir,
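After the rendezvous guarantees that every rank has written its `rank*-of-*` shard, the master process merges them with torch_xla's `consolidate_sharded_model_checkpoints` and hands the full state dict to `save_pretrained`. The same helper can rebuild a full checkpoint from the shards outside the Trainer; a hedged sketch (assumes `WEIGHTS_NAME` is `"pytorch_model.bin"`, its value in current transformers, and that the shard files are already on disk):

# Sketch (not part of this diff): offline consolidation of the per-rank model shards,
# mirroring the call made in _save_tpu above. Requires torch_xla.
import os
import torch
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

output_dir = "checkpoint-500"       # placeholder directory containing rank*-of-*-pytorch_model.bin
WEIGHTS_NAME = "pytorch_model.bin"  # assumption: matches the transformers constant

full_state_dict, _ = consolidate_sharded_model_checkpoints(
    ckpt_prefix=os.path.join(output_dir, ""),
    ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
    save_model=False,               # just return the merged state dict
)
torch.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
# With a config.json alongside it, the consolidated file can then be loaded through
# a from_pretrained call on the saved directory.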
