add sharded checkpoint loading for AutoTP path to reduce the peak memory in initialization stage

Signed-off-by: Wang, Yi A <[email protected]>
sywangyi committed Mar 27, 2023
1 parent b3ec1c9 commit b48734a
Showing 2 changed files with 166 additions and 25 deletions.
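
For context, the new branch is taken when DeepSpeed inference is initialized without kernel injection but with a checkpoint descriptor. The snippet below is a minimal sketch of such a call, not part of this commit; the model name, tensor-parallel degree, dtype, and descriptor filename are illustrative assumptions.

import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model on the meta device so no real weights are allocated yet.
config = AutoConfig.from_pretrained("bigscience/bloom-7b1")
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(config)

# replace_with_kernel_inject=False selects the AutoTP path; "checkpoints.json"
# is assumed to describe the sharded state_dict files (see the descriptor
# sketch further down).
engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.float16,
                                  replace_with_kernel_inject=False,
                                  checkpoint="checkpoints.json")
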
49 changes: 38 additions & 11 deletions deepspeed/inference/engine.py
@@ -92,9 +92,6 @@ def __init__(self, model, config):
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
"If you want to use cuda graph, please upgrade torch to at least v1.10"

if config.checkpoint and not config.replace_with_kernel_inject:
self._load_checkpoint(config.checkpoint)

# convert model to intended dtype
if config.dtype:
self._convert_to_dtype(config)
@@ -115,10 +112,6 @@ def __init__(self, model, config):
if moe and dist.get_world_size() > 1:
self._create_ep_parallel_group(config.moe.moe_experts)

# retain this from the old conditional argument being passed to apply_injection_policy()
if not config.replace_with_kernel_inject:
config.checkpoint = None

# We only support three modes: 1) user specified policy for tensor-parallelism, 2) kernel injection (replace_with_kernel_inject), and 3) automatic tensor parallelism.
if self.injection_dict:
# 1. User specified Tensor Parallelism
@@ -298,6 +291,12 @@ def load_model_with_checkpoint(self, r_module):
def load(module, state_dict, prefix):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
if hasattr(module, 'weight'):
if module.weight.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.weight = torch.nn.parameter.Parameter(
data=torch.empty_like(module.weight.data,
device="cpu"),
requires_grad=module.weight.data.requires_grad)
if 'query_key_value' in prefix:
module.weight = self.mp_replace.qkv_copy(
module.weight.data,
@@ -306,13 +305,31 @@ def load(module, state_dict, prefix):
module.weight = self.mp_replace.copy(module.weight.data,
state_dict[prefix + 'weight'])
else:
if module.norm.weight.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.norm.weight = torch.nn.parameter.Parameter(
data=torch.empty_like(module.norm.weight.data,
device="cpu"),
requires_grad=module.norm.weight.data.requires_grad)
module.norm.weight = self.mp_replace.copy(module.norm.weight.data,
state_dict[prefix + 'weight'])
if prefix + 'bias' in self.key_list:
if hasattr(module, 'norm'):
if module.norm.bias.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.norm.bias = torch.nn.parameter.Parameter(
data=torch.empty_like(module.norm.bias.data,
device="cpu"),
requires_grad=module.norm.bias.data.requires_grad)
module.norm.bias = self.mp_replace.copy(module.norm.bias,
state_dict[prefix + 'bias'])
else:
if module.bias.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.bias = torch.nn.parameter.Parameter(
data=torch.empty_like(module.bias.data,
device="cpu"),
requires_grad=module.bias.data.requires_grad)
data = state_dict[prefix + 'bias']
data = data.to(get_accelerator().current_device_name())
module.bias = self.mp_replace.copy(module.bias, data)
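
The block above repeats one pattern: a parameter still on the meta device has no backing storage, so it is first replaced by a real CPU tensor and only then filled from the checkpoint shard. A stand-alone sketch of that pattern, using made-up layer and shard names:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4, device="meta")                  # weight has no storage yet
shard = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

if layer.weight.is_meta:
    # materialize on CPU before copying, as done for module.weight above
    layer.weight = nn.Parameter(torch.empty_like(layer.weight, device="cpu"),
                                requires_grad=layer.weight.requires_grad)
layer.weight.data.copy_(shard["weight"])
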
@@ -346,6 +363,14 @@ def load_module_recursive(module, prefix='', level=0):

load_module_recursive(r_module)

embedding_weight = None

for n, p in r_module.named_parameters():
if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
embedding_weight = p
if embedding_weight is not None and r_module.lm_head.weight.is_meta:
r_module.lm_head.weight = embedding_weight

def _apply_injection_policy(self, config, client_module=None):
# client_module is only passed when using the injection_dict method.
checkpoint_dir = config.checkpoint
@@ -407,16 +432,18 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
self.checkpoint_engine)

if type(sd_loader) is list:
self.sd = torch.load(sd_loader[0], map_location='cpu')
checkpoint = sd_loader['checkpoints']

if type(checkpoint) is list:
self.sd = torch.load(checkpoint[0], map_location='cpu')
self.key_list = list(self.sd.keys())

self.load_model_with_checkpoint(self.module)

for i in range(1, len(sd_loader)):
for i in range(1, len(checkpoint)):
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})")
self.sd = torch.load(sd_loader[i],
self.sd = torch.load(checkpoint[i],
map_location=get_accelerator().device_name())
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
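
_load_checkpoint now takes the shard list from the loader's 'checkpoints' entry, so the checkpoint argument is expected to point at a JSON descriptor rather than a single state_dict file. A plausible minimal descriptor is sketched below; the file paths are placeholders and every key other than "checkpoints" is an assumption, not something this diff confirms.

import json

descriptor = {
    "type": "ds_model",            # assumed metadata field
    "version": 1.0,                # assumed metadata field
    "checkpoints": [               # the key read above: sd_loader['checkpoints']
        "shard_00.pt",
        "shard_01.pt",
    ],
}
with open("checkpoints.json", "w") as f:
    json.dump(descriptor, f)
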
142 changes: 128 additions & 14 deletions deepspeed/module_inject/replace_module.py
@@ -21,6 +21,8 @@

from .utils import policy_to_ds_container

import gc


class ReplaceWithTensorSlicing:
def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
@@ -401,11 +403,13 @@ def replace_with_policy(child,

return _container.module

def replace_wo_policy(module, all_reduce_linears):
def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
mp_size = config.tensor_parallel.tp_size
mp_group = config.tensor_parallel.tp_group

def _replace(child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
weight_shape = child.weight.shape
if name in all_reduce_linears:
@@ -424,6 +428,7 @@ def _replace(child, name, conv_linear_layer):
dtype=child.weight.dtype)
if child.bias is not None:
new_bias.data.copy_(child.bias.data)
setattr(child, "replaced", True)
return LinearAllreduce(data, child.bias if child.bias is None else \
torch.nn.parameter.Parameter(new_bias.to(get_accelerator().current_device_name())), mp_group)
else:
@@ -444,11 +449,14 @@ def _replace(child, name, conv_linear_layer):
bias_data = None if child.bias is None else mp_replace.copy(
new_bias,
child.bias.data).to(get_accelerator().current_device_name())
setattr(child, "replaced", True)
return LinearLayer(weight=data.to(
get_accelerator().current_device_name()),
bias=bias_data)

def _slice_embedding(child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
new_weight = torch.empty((child.weight.shape[0],
child.weight.shape[1] // mp_size),
@@ -460,9 +468,12 @@ def _slice_embedding(child, name, conv_linear_layer):
new_embedding = nn.Embedding(child.weight.shape[0],
child.weight.shape[1] // mp_size)
new_embedding.weight.data.copy_(data)
setattr(child, "replaced", True)
return new_embedding

def update_mp_params(child):
if getattr(child, "replaced", False) == True:
return
if hasattr(child, 'n_heads'):
assert child.n_heads%mp_size == 0, "n_heads ({}) must be divisible by mp_size ({})".format(child.n_heads, mp_size)
child.n_heads = child.n_heads // mp_size
@@ -487,6 +498,7 @@ def update_mp_params(child):
if hasattr(child, 'hidden_size'):
assert child.hidden_size%mp_size == 0, "hidden_size ({}) must be divisible by mp_size ({})".format(child.hidden_size, mp_size)
child.hidden_size = child.hidden_size // mp_size
setattr(child, "replaced", True)

conv_linear_layer = False
if linear_layer_setting is not None:
@@ -506,6 +518,14 @@ def update_mp_params(child):

def _replace_module(r_module, prev_name=''):
for name, child in r_module.named_children():
if child.__class__ in [nn.Linear,
nn.Embedding,
nn.LayerNorm] and state_dict != None:
full_prefix = prefix + '.' + prev_name + '.' + name + '.' if prev_name != "" else prefix + '.' + name + '.'
if prefix_check(full_prefix, state_dict):
load(child, state_dict, full_prefix, mp_group)
else:
continue
if child.__class__ in linear_policies:
setattr(
r_module,
@@ -520,7 +540,7 @@ def _replace_module(r_module, prev_name=''):

return _replace_module(module)

def replace_fn(child, _policy, layer_id=0):
def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
training = False # todo: refactor this part to go in the config
if training:
# copy relevant state from child -> new module
@@ -535,19 +555,37 @@ def replace_fn(child, _policy, layer_id=0):
inference=True,
layer_id=layer_id)
else:
new_module = replace_wo_policy(child, _policy)
new_module = replace_wo_policy(child,
_policy,
prefix=prefix,
state_dict=state_dict)

return new_module

replaced_module = replace_module(model=model,
orig_class=orig_layer_impl,
replace_fn=replace_fn,
_replace_policy=config.injection_policy_tuple)
if checkpoint_dict != None and not config.replace_with_kernel_inject:
# AutoTP shard loading
checkpoint = checkpoint_dict["checkpoints"]
pbar = tqdm.tqdm(total=len(checkpoint),
desc=f"Loading {len(checkpoint)} checkpoint shards")
for i in range(len(checkpoint)):
replaced_module = replace_module(
model=model,
orig_class=orig_layer_impl,
replace_fn=replace_fn,
_replace_policy=config.injection_policy_tuple,
checkpoint=checkpoint[i])
pbar.update(1)
gc.collect()
else:
replaced_module = replace_module(model=model,
orig_class=orig_layer_impl,
replace_fn=replace_fn,
_replace_policy=config.injection_policy_tuple)

quantizer = GroupQuantizer(q_int8=quantize)
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
if checkpoint_dict is not None:
if checkpoint_dict is not None and config.replace_with_kernel_inject:
assert container_g.ckpt_load_enabled, \
f"Meta Tensor checkpoint loading not supported in {container_g.__class__.__name__} container"
start_time = time.time()
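
The peak-memory saving in the new AutoTP branch comes from handling one shard at a time: each shard is loaded on CPU, copied into the already-sliced modules, and released before the next one is read. The same idea in isolation, with a hypothetical apply_shard callback standing in for the replace_module/load machinery above:

import gc
import json
import torch

def load_shards_one_by_one(module, descriptor_path, apply_shard):
    with open(descriptor_path) as f:
        shard_paths = json.load(f)["checkpoints"]
    for path in shard_paths:
        sd = torch.load(path, map_location="cpu")   # only one shard resident at a time
        apply_shard(module, sd)                      # copy/slice into the model
        del sd
        gc.collect()                                 # drop the host copy before the next shard
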
@@ -577,7 +615,6 @@ def replace_fn(child, _policy, layer_id=0):
container=container_g)
pbar.update(1)
else:
import gc
num_checkpoints = len(ckpt_list) // ckpt_mp_size
tp_split_size = (world_size / ckpt_mp_size)
sd_offset = int(rank / tp_split_size)
@@ -778,7 +815,7 @@ def replace_fn(child, _replace_policy, layer_id):
_replace_policy=None)


def replace_module(model, orig_class, replace_fn, _replace_policy):
def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=None):
""" Scan the model for instances of ``orig_clas:`` to replace using ``replace_fn``.
Arguments:
model (torch.nn.Module): the model to augment
@@ -788,6 +825,9 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
Returns:
A modified ``model``.
"""
sd = None
if checkpoint != None:
sd = torch.load(checkpoint, map_location='cpu')
policy = {}
if orig_class is not None:
policy.update({orig_class: (replace_fn, _replace_policy)})
@@ -804,35 +844,109 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
"No default policy found! Please specify your policy injection_policy (like {BertLayer:HFBEertLayerPolicy})." +\
"You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"

replaced_module, _ = _replace_module(model, policy)
replaced_module, _ = _replace_module(model, policy, state_dict=sd)

embedding_weight = None
for n, p in replaced_module.named_parameters():
if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
embedding_weight = p
if embedding_weight is not None and replaced_module.lm_head.weight.is_meta:
replaced_module.lm_head.weight = embedding_weight
return replaced_module


from ..pipe import PipelineModule

import re


def prefix_check(name, checkpoint_dict):
# if keys start with 'model.', don't skip level 0 prefix
for key in checkpoint_dict.keys():
if re.match(name, key):
return True
return False


def _replace_module(model, policies, layer_id=0):
def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_dict=None):
""" Traverse model's children recursively and apply any transformations in ``policies``.
Arguments:
model (torch.nn.Module): model to augment
policies (dict): Mapping of source class to replacement function.
Returns:
Modified ``model``.
"""
load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
for name, child in model.named_children():
if child.__class__ in policies:
replaced_module = policies[child.__class__][0](child,
policies[child.__class__][-1],
layer_id)
layer_id,
prefix=prefix + name,
state_dict=state_dict)
setattr(model, name, replaced_module)
if isinstance(model, PipelineModule):
assert hasattr(model, 'forward_funcs'),\
"we require pipe-module to have the list of fwd_functions"
model.forward_funcs[model.fwd_map[name]] = replaced_module
layer_id += 1
else:
_, layer_id = _replace_module(child, policies, layer_id=layer_id)
if child.__class__ in load_layers and state_dict != None:
if prefix_check(prefix + name + '.', state_dict):
load(
child,
state_dict,
prefix + name + '.',
)
else:
continue
_, layer_id = _replace_module(child, policies, prefix if level_id == 0 else prefix + name + '.',layer_id=layer_id, level_id=level_id+1, state_dict=state_dict)

# Add the reset_cache func to the model, so that it can be called in the beginning of text-generation.
model.reset_cache = transformer_inference.DeepSpeedTransformerInference.reset_cache
return model, layer_id


def load(module, state_dict, prefix, mp_group=None):
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
if hasattr(module, 'weight'):
if module.weight.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.weight = torch.nn.parameter.Parameter(
data=torch.empty_like(module.weight.data,
device="cpu"),
requires_grad=module.weight.data.requires_grad)
if 'query_key_value' in prefix:
module.weight = mp_replace.qkv_copy(module.weight.data,
state_dict[prefix + 'weight'])
else:
module.weight = mp_replace.copy(module.weight.data,
state_dict[prefix + 'weight'])
else:
if hasattr(module, 'norm') and hasattr(module.norm, 'weight'):
if module.norm.weight.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.norm.weight = torch.nn.parameter.Parameter(
data=torch.empty_like(module.norm.weight.data,
device="cpu"),
requires_grad=module.norm.weight.data.requires_grad)
module.norm.weight = mp_replace.copy(module.norm.weight.data,
state_dict[prefix + 'weight'])

if hasattr(module, 'bias'):
if module.bias.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.bias = torch.nn.parameter.Parameter(
data=torch.empty_like(module.bias.data,
device="cpu"),
requires_grad=module.bias.data.requires_grad)
module.bias = mp_replace.copy(module.bias, state_dict[prefix + 'bias'])
else:
if hasattr(module, 'norm') and hasattr(module.norm, 'bias'):
if module.norm.bias.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.norm.bias = torch.nn.parameter.Parameter(
data=torch.empty_like(module.norm.bias.data,
device="cpu"),
requires_grad=module.norm.bias.data.requires_grad)
module.norm.bias = mp_replace.copy(module.norm.bias,
state_dict[prefix + 'bias'])
