Fix checkpoint conversion when model layers share weights #3825

Merged
deepspeed/runtime/engine.py: 22 changes (14 additions & 8 deletions)
@@ -3236,24 +3236,30 @@ def _get_shared_params(self):
         e.g. in `zero_to_fp32`. Each dict entry is a pair of param names, where the key is the name
         of the variable that isn't stored and the value is the actual param holding data.
         """
-        shared_ds_ids = {}
+        shared_index = {}
         shared_params_by_full_name = {}
 
+        is_zero3_model = (self.zero_optimization_partition_weights()
+                          and any(hasattr(param, "ds_id") for param in self.module.parameters()))
+
         def get_layer_state_dict(module, prefix=""):
             # handle params
             for name, param in module.named_parameters(recurse=False):
-                if param is None or not hasattr(param, "ds_id"):
+                if param is None or (is_zero3_model and not hasattr(param, "ds_id")):
                     continue
                 key = prefix + name
-                # can't rely on param.data_ptr() as it will be reused as weights gets
-                # gathered and reduced, but param.ds_id is unique across all zero weights
-                if param.ds_id in shared_ds_ids:
+
+                # When weights are managed by stage 3, we can't rely on param.data_ptr() as it will be reused
+                # as weights get gathered and reduced, but param.ds_id is unique across all zero weights
+                # (and shared params will have the same param.ds_id)
+                param_id = param.ds_id if is_zero3_model else param.data_ptr()
+
+                if param_id in shared_index:
                     # shared weights
-                    #print(f"`{key}` is shared with `{shared_ds_ids[param.ds_id]}`")
-                    shared_params_by_full_name[key] = shared_ds_ids[param.ds_id]
+                    #print(f"`{key}` is shared with `{shared_index[param_id]}`")
+                    shared_params_by_full_name[key] = shared_index[param_id]
                 else:
-                    shared_ds_ids[param.ds_id] = key
+                    shared_index[param_id] = key
 
             for name, child in module.named_children():
                 if child is not None:
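The fix in context: outside ZeRO stage 3, parameters carry no `ds_id` attribute, so the old code skipped every parameter and the shared-params map came out empty, which broke `zero_to_fp32` conversion for models with tied weights. The new code keys the index by `ds_id` only when stage 3 actually manages the weights, and falls back to `param.data_ptr()` otherwise, since tied parameters are the same tensor and therefore share storage. Below is a minimal standalone sketch of that `data_ptr()`-based detection; it is not DeepSpeed code, and `find_shared_params` and `TiedModel` are illustrative names only.

import torch.nn as nn


def find_shared_params(model: nn.Module) -> dict:
    """Map each tied parameter name to the name that actually holds the data."""
    index = {}   # storage id (data_ptr) -> first parameter name seen with that storage
    shared = {}  # duplicate parameter name -> canonical parameter name

    def visit(module, prefix=""):
        for name, param in module.named_parameters(recurse=False):
            if param is None:
                continue
            key = prefix + name
            # Tied nn.Parameters are the same tensor object, so they share a data_ptr().
            param_id = param.data_ptr()
            if param_id in index:
                shared[key] = index[param_id]
            else:
                index[param_id] = key
        for name, child in module.named_children():
            if child is not None:
                visit(child, prefix + name + ".")

    visit(model)
    return shared


class TiedModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 4)
        self.layer2 = nn.Linear(4, 4)
        self.layer2.weight = self.layer1.weight  # tie the two weights


print(find_shared_params(TiedModel()))  # {'layer2.weight': 'layer1.weight'}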
tests/unit/checkpoint/test_shared_weights.py: 49 changes (49 additions & 0 deletions)
@@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import torch.nn as nn

import deepspeed
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from unit.common import DistributedTest


class ModelWithSharedWeights(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(100, 100)
        self.layer1 = nn.Linear(200, 200)
        self.layer2 = nn.Linear(300, 300)
        # tie layer 1 and layer 2
        self.layer1.weight = self.layer2.weight


class TestCheckpointSharedWeights(DistributedTest):
    world_size = 2

    def test_checkpoint_shared_weights(self, tmp_path):
        config = {
            "train_micro_batch_size_per_gpu": 2,
            "zero_allow_untested_optimizer": True,
            "zero_optimization": {
                "stage": 2
            },
        }
        model = ModelWithSharedWeights()
        optimizer = torch.optim.Adam(model.parameters())

        deepspeed_engine, _, _, _ = deepspeed.initialize(
            config=config,
            model=model,
            optimizer=optimizer,
        )
        filename = tmp_path / "checkpoint.pt"
        deepspeed_engine.save_checkpoint(filename, tag="checkpoint")

        model = ModelWithSharedWeights()
        state_dict = get_fp32_state_dict_from_zero_checkpoint(filename, tag="checkpoint")
        model.load_state_dict(state_dict, strict=True)
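For `ModelWithSharedWeights`, the engine's `_get_shared_params` should see `layer1.weight` first and map the tied name onto it, i.e. record roughly {'layer2.weight': 'layer1.weight'} (assuming traversal follows registration order); `get_fp32_state_dict_from_zero_checkpoint` then emits the recovered tensor under both names, which is what lets the `strict=True` load succeed. The lines below are an illustrative follow-up, not part of the PR's test, and assume they continue directly from the test body above.

        # Illustrative extra assertions (not in the PR): both tied names are
        # expected in the recovered fp32 state dict, with equal values.
        assert "layer1.weight" in state_dict and "layer2.weight" in state_dict
        assert torch.equal(state_dict["layer1.weight"], state_dict["layer2.weight"])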