add sharded checkpoint loading for AutoTP path to reduce the peak mem… #3102

Merged (10 commits) on May 4, 2023
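Context for the change (editor's sketch, not part of the PR): the AutoTP path is `deepspeed.init_inference` without kernel injection. With this PR, a checkpoint described by a JSON shard list can be loaded shard-by-shard onto the already-partitioned model instead of materializing the full state dict up front. A minimal usage sketch, assuming a BLOOM-style model; the model name and checkpoint path are illustrative (see the JSON sketch after the diff):

import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

# Editor's sketch: build the model on the meta device so no rank materializes
# the full weights, then let init_inference stream the shards in.
config = AutoConfig.from_pretrained("bigscience/bloom-7b1")
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

engine = deepspeed.init_inference(
    model,
    mp_size=2,                          # tensor-parallel degree
    dtype=torch.float16,
    checkpoint="checkpoints.json",      # JSON listing the checkpoint shards
    replace_with_kernel_inject=False,   # AutoTP path: no kernel injection
)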
46 changes: 35 additions & 11 deletions deepspeed/inference/engine.py
@@ -151,9 +151,6 @@ def __init__(self, model, config):
             assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
                 "If you want to use cuda graph, please upgrade torch to at least v1.10"
 
-        if config.checkpoint and not config.replace_with_kernel_inject:
-            self._load_checkpoint(config.checkpoint)
-
         # convert model to intended dtype
         if config.dtype:
             self._convert_to_dtype(config)
@@ -173,10 +170,6 @@ def __init__(self, model, config):
         if moe and dist.get_world_size() > 1:
             self._create_ep_parallel_group(config.moe.moe_experts)
 
-        # retain this from the old conditional argument being passed to apply_injection_policy()
-        if not config.replace_with_kernel_inject:
-            config.checkpoint = None
-
         # We only support three modes: 1) user specified policy for tensor-parallelism, 2) kernel injection (replace_with_kernel_inject), and 3) automatic tensor parallelism.
         if self.injection_dict:
             # 1. User specified Tensor Parallelism
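Editor's note: dropping the `config.checkpoint = None` reset here (together with the early `_load_checkpoint` call removed above) is what lets the checkpoint setting survive `__init__`, so `_apply_injection_policy` below can pick it up as `checkpoint_dir` and defer loading until after the model has been partitioned.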
@@ -343,18 +336,38 @@ def load_model_with_checkpoint(self, r_module):
         def load(module, state_dict, prefix):
             args = (state_dict, prefix, {}, True, [], [], error_msgs)
             if hasattr(module, 'weight'):
+                if module.weight.data.is_meta:
+                    # meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
+                    module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data,
+                                                                                       device="cpu"),
+                                                                 requires_grad=module.weight.data.requires_grad)
                 if 'query_key_value' in prefix:
                     module.weight = self.mp_replace.strided_copy(module.weight.data,
                                                                  state_dict[prefix + 'weight'],
                                                                  num_splits=3)
                 else:
                     module.weight = self.mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
             else:
+                if module.norm.weight.data.is_meta:
+                    # meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
+                    module.norm.weight = torch.nn.parameter.Parameter(
+                        data=torch.empty_like(module.norm.weight.data, device="cpu"),
+                        requires_grad=module.norm.weight.data.requires_grad)
                 module.norm.weight = self.mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])
             if prefix + 'bias' in self.key_list:
                 if hasattr(module, 'norm'):
+                    if module.norm.bias.data.is_meta:
+                        # meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
+                        module.norm.bias = torch.nn.parameter.Parameter(
+                            data=torch.empty_like(module.norm.bias.data, device="cpu"),
+                            requires_grad=module.norm.bias.data.requires_grad)
                     module.norm.bias = self.mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
                 else:
+                    if module.bias.data.is_meta:
+                        # meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
+                        module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data,
+                                                                                         device="cpu"),
+                                                                   requires_grad=module.bias.data.requires_grad)
                     data = state_dict[prefix + 'bias']
                     data = data.to(get_accelerator().current_device_name())
                     module.bias = self.mp_replace.copy(module.bias, data)
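The four `is_meta` branches added above share one pattern: a parameter created on the meta device has no storage, so it cannot be the destination of a copy; it is first swapped for an empty CPU tensor of the same shape and dtype, and only then filled from the shard. A self-contained sketch of just that pattern (the `Linear` module is illustrative, not from the PR):

import torch

# Hypothetical stand-in for a module whose weights were created on the
# meta device (shape/dtype only, no storage allocated).
linear = torch.nn.Linear(4, 4, device="meta")
assert linear.weight.is_meta

# A meta tensor cannot be copied into, so replace the parameter with a
# real CPU tensor of the same shape and dtype first...
linear.weight = torch.nn.parameter.Parameter(data=torch.empty_like(linear.weight.data, device="cpu"),
                                             requires_grad=linear.weight.data.requires_grad)

# ...after which checkpoint data can be copied in normally.
linear.weight.data.copy_(torch.ones(4, 4))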
@@ -383,6 +396,15 @@ def load_module_recursive(module, prefix='', level=0):
 
         load_module_recursive(r_module)
 
+        embedding_weight = None
+
+        for n, p in r_module.named_parameters():
+            if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
+                embedding_weight = p
+        if embedding_weight is not None and hasattr(r_module, "lm_head") and hasattr(
+                r_module.lm_head, "weight") and r_module.lm_head.weight.is_meta:
+            r_module.lm_head.weight = embedding_weight
+
     def _apply_injection_policy(self, config, client_module=None):
         # client_module is only passed when using the injection_dict method.
         checkpoint_dir = config.checkpoint
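The `embedding_weight` block added above handles tied embeddings: when the shards carry no separate `lm_head.weight` (BLOOM's `word_embeddings`, Llama-style `embed_tokens`, GPT-2's `wte`), the head would otherwise remain a meta tensor after loading, so it is re-pointed at the loaded input embedding. A standalone sketch with a hypothetical tiny model:

import torch

# Hypothetical model whose head is tied to its input embedding; "wte." matches
# one of the substrings the loop above searches for (in "wte.weight").
class TinyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = torch.nn.Embedding(16, 8)
        self.lm_head = torch.nn.Linear(8, 16, bias=False, device="meta")

model = TinyLM()

embedding_weight = None
for n, p in model.named_parameters():
    if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
        embedding_weight = p
if embedding_weight is not None and model.lm_head.weight.is_meta:
    model.lm_head.weight = embedding_weight  # re-tie instead of leaving a meta tensor

assert not model.lm_head.weight.is_meta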
@@ -434,16 +456,18 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
         else:
             sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, self.checkpoint_engine)
 
-        if type(sd_loader) is list:
-            self.sd = torch.load(sd_loader[0], map_location='cpu')
+        checkpoint = sd_loader['checkpoints']
+
+        if type(checkpoint) is list:
+            self.sd = torch.load(checkpoint[0], map_location='cpu')
             self.key_list = list(self.sd.keys())
 
             self.load_model_with_checkpoint(self.module)
 
-            for i in range(1, len(sd_loader)):
+            for i in range(1, len(checkpoint)):
                 if not dist.is_initialized() or dist.get_rank() == 0:
                     print(f"loading checkpoint ({i})")
-                self.sd = torch.load(sd_loader[i], map_location=get_accelerator().device_name())
+                self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
                 self.key_list = list(self.sd.keys())
                 self.load_model_with_checkpoint(self.module)
         else:
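`_load_checkpoint` now indexes the loader result as `sd_loader['checkpoints']`: the JSON handed to `SDLoaderFactory.get_sd_loader_json` describes a list of shard files, and each shard is `torch.load`-ed and applied in turn, so only one shard's tensors are resident at a time (the peak-memory reduction named in the title). A hedged sketch of such a description file, written from Python; the `"checkpoints"` key follows what the diff reads back, while the type tag and shard names are assumed from common DeepSpeed usage:

import json

# Editor's sketch of a checkpoint-description JSON for the sharded path.
checkpoint_desc = {
    "type": "BLOOM",
    "version": 1.0,
    "checkpoints": [
        "pytorch_model_00001-of-00003.bin",
        "pytorch_model_00002-of-00003.bin",
        "pytorch_model_00003-of-00003.bin",
    ],
}
with open("checkpoints.json", "w") as f:
    json.dump(checkpoint_desc, f, indent=2)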