DeepSpeed Zero 2 Fails to Load All Checkpoint Parameters #15694
Comments
Any updates on this, or quick workarounds?
You can try using this code to convert the DeepSpeed checkpoint to a Lightning checkpoint while patching in all parameters that aren't loaded from the DeepSpeed checkpoint. I can't guarantee that the parameters which aren't processed by the conversion utility are always patched in correctly, so please verify the loaded model.

```python
import os
import re

import torch
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict,
    get_model_state_file,
    get_optim_files,
    ds_checkpoint_dir,
)

DS_PARAM_REGEX = r'_forward_module\.(.+)'


def convert_deepspeed_checkpoint(deepspeed_ckpt_path: str, pl_ckpt_path: str = None):
    '''
    Creates a PyTorch Lightning checkpoint from the DeepSpeed checkpoint directory, while patching
    in parameters which are improperly loaded by the DeepSpeed conversion utility.

    deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
    pl_ckpt_path: Path to the reconstructed PyTorch Lightning checkpoint. If not specified, it will be
        placed in the same directory as the DeepSpeed checkpoint directory, with the same name but
        a .pt extension.

    Returns: path to the converted checkpoint.
    '''
    if not (deepspeed_ckpt_path.endswith('.ckpt') and os.path.isdir(deepspeed_ckpt_path)):
        raise ValueError(
            'deepspeed_ckpt_path should point to the checkpoint directory'
            ' output by DeepSpeed (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt").'
        )

    # Convert the sharded Zero state dict to a single fp32 PyTorch checkpoint
    if not pl_ckpt_path:
        pl_ckpt_path = f'{deepspeed_ckpt_path[:-4]}pt'  # .ckpt --> .pt
    if not os.path.exists(pl_ckpt_path):
        convert_zero_checkpoint_to_fp32_state_dict(deepspeed_ckpt_path, pl_ckpt_path)

    # Patch in missing parameters that failed to be converted by the DeepSpeed utility
    pl_ckpt = _merge_deepspeed_weights(deepspeed_ckpt_path, pl_ckpt_path)
    torch.save(pl_ckpt, pl_ckpt_path)

    return pl_ckpt_path


def _merge_deepspeed_weights(deepspeed_ckpt_path: str, fp32_ckpt_path: str):
    '''
    Merges tensors whose keys are in the DeepSpeed checkpoint but not in the fp32 checkpoint
    into the fp32 state dict.

    deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
    fp32_ckpt_path: Path to the reconstructed PyTorch Lightning checkpoint.
    '''
    # This first part is based on pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict
    checkpoint_dir = ds_checkpoint_dir(deepspeed_ckpt_path)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location='cpu')
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    deepspeed_model_file = get_model_state_file(checkpoint_dir, zero_stage)

    # Add all parameters from the DeepSpeed ckpt that are missing from the generated PyTorch Lightning ckpt
    ds_ckpt = torch.load(deepspeed_model_file, map_location='cpu')
    ds_sd = ds_ckpt['module']
    fp32_ckpt = torch.load(fp32_ckpt_path, map_location='cpu')
    fp32_sd = fp32_ckpt['state_dict']

    for k, v in ds_sd.items():
        try:
            match = re.match(DS_PARAM_REGEX, k)
            param_name = match.group(1)
        except AttributeError:
            print(f'Failed to extract parameter name from DeepSpeed key {k}')
            continue

        v = v.to(torch.float32)
        if param_name not in fp32_sd:
            print(f'Adding parameter {param_name} from DeepSpeed state_dict to fp32_sd')
            fp32_sd[param_name] = v
        else:
            assert torch.allclose(v, fp32_sd[param_name], atol=1e-2)

    return fp32_ckpt
```
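A minimal usage sketch of the helper above; `MyLightningModule` and the checkpoint path are placeholders for your own module class and output directory, not part of the original workaround:

```python
# Hypothetical usage: convert the DeepSpeed checkpoint directory, then load the
# patched Lightning checkpoint with strict key checking enabled.
converted_path = convert_deepspeed_checkpoint('lightning_logs/version_0/checkpoints/last.ckpt')
model = MyLightningModule.load_from_checkpoint(converted_path, strict=True)
```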
Thank you @kelvins64, I will try this out.
I can confirm this does work (though it's missing an `import re`).
Confirm this works. Great work. Thanks.
Thanks @yakazimir for pointing me to this issue. I looked into it and found that the problem lies in DeepSpeed. When saving a checkpoint, DeepSpeed is not able to identify the shared parameters, and when converting/loading the checkpoint it doesn't reconstruct them properly, leading to the missing-keys error. I boiled this down to a reproducible script with DeepSpeed and submitted a ticket and a PR with the fix. If my PR gets merged, the workaround posted here won't be necessary anymore. For reference, my investigation was done with DeepSpeed master (0.9.5dev) and Lightning master (2.1.0dev), starting from this script based on the original submission but with minor modifications to fit the newer API (a small standalone illustration of the shared-parameter problem follows the script):

```python
import os
import shutil

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import lightning.pytorch as pl
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
from transformers import MBartForConditionalGeneration, MBart50Tokenizer


class TextDataset(Dataset):
    def __init__(self, model_name, length):
        self.model_name = model_name
        self.length = length
        self.tokenizer = MBart50Tokenizer.from_pretrained(model_name)
        self.data = self.tokenizer(
            [f'Hello world {i}!' for i in range(length)],
            padding='longest',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, index):
        return {
            'input_ids': self.data['input_ids'][index],
            'attention_mask': self.data['attention_mask'][index],
            'labels': self.data['input_ids'][index]  # Have the target text be the input text
        }

    def __len__(self):
        return self.length


class BoringModel(LightningModule):
    def __init__(self, model_name):
        super().__init__()
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)

    def forward(self, batch):
        return self.model(**batch)[0]  # Return loss

    def training_step(self, batch, batch_idx):
        loss = self(batch)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch)
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def run():
    if os.path.exists("lightning_logs"):
        shutil.rmtree("lightning_logs")

    model_name = 'facebook/mbart-large-50'
    pl.seed_everything(42)
    train_data = DataLoader(TextDataset(model_name, 64), batch_size=2)
    val_data = DataLoader(TextDataset(model_name, 64), batch_size=2)
    model = BoringModel(model_name)
    trainer = Trainer(
        accelerator="cuda",
        devices=2,
        strategy="deepspeed_stage_2",
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        deterministic=True
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

    from pprint import pprint

    # Convert checkpoint
    if trainer.is_global_zero:
        pprint(trainer.strategy.config)
        zero_ckpt_dir = os.path.join(os.getcwd(), 'lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt')
        ckpt_path = zero_ckpt_dir[:-4] + 'pth'
        convert_zero_checkpoint_to_fp32_state_dict(zero_ckpt_dir, ckpt_path)
        # Attempt to load checkpoint
        model.load_from_checkpoint(ckpt_path, model_name=model_name, strict=True)


if __name__ == "__main__":
    run()
```
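To make the root cause more concrete, here is a minimal sketch in plain PyTorch (not the Lightning/DeepSpeed code path) of what happens when a checkpoint writer records only one key of a tied pair; `TiedModel` is a hypothetical stand-in for the weight tying MBart uses between its embeddings and LM head:

```python
import torch
import torch.nn as nn


class TiedModel(nn.Module):
    # Hypothetical stand-in for weight tying (embedding <-> lm_head), as in MBart
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # both modules share one Parameter


full_sd = TiedModel().state_dict()                       # contains both 'embed.weight' and 'lm_head.weight'
partial_sd = {'embed.weight': full_sd['embed.weight']}   # a writer that keeps only one of the tied keys

model = TiedModel()
try:
    model.load_state_dict(partial_sd, strict=True)
except RuntimeError as err:
    print(err)  # Missing key(s) in state_dict: "lm_head.weight"
```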
Fix was merged in deepspeed: deepspeedai/DeepSpeed#3825
I have to change
Bug description
Using DeepSpeed Zero 2 with certain models fails to properly save and reload the model checkpoint after conversion to the Lightning format.
In the provided example, several parameters do not appear in the `param_shapes` value of the Zero checkpoint (which the generated reconstruction script uses to build the state dict), despite appearing in the `module` value of the Zero checkpoint.

How to reproduce the bug
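A hedged diagnostic sketch for seeing the mismatch directly; it assumes the usual layout of a Lightning DeepSpeed checkpoint directory (a `checkpoint/mp_rank_00_model_states.pt` file holding both `param_shapes` and `module`), and the checkpoint path is a placeholder for your own run:

```python
import torch

# Placeholder path: point this at the model-states file inside your saved .ckpt directory.
model_states = torch.load(
    'epoch=4-step=39150.ckpt/checkpoint/mp_rank_00_model_states.pt',
    map_location='cpu',
)

# 'param_shapes' is a list of dicts (one per optimizer param group) mapping parameter
# names to shapes; 'module' is the full module state dict.
shape_keys = {name for group in model_states['param_shapes'] for name in group}
module_keys = set(model_states['module'].keys())

print('Parameters present in "module" but missing from "param_shapes":')
print(sorted(module_keys - shape_keys))
```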
Error messages and logs
Running the above code, we encounter a missing-keys error when loading the converted checkpoint with `strict=True`.
Environment
More info
cc @awaelchli