Merge pull request #40 from bigcode-project/mqa-checkpoint-utils
support mqa in checkpoint-merging tools
RaymondLi0 authored May 8, 2023
2 parents 22b8611 + 57f21b7 commit 1a7d54b
Showing 5 changed files with 58 additions and 17 deletions.
2 changes: 1 addition & 1 deletion megatron/arguments.py
@@ -430,7 +430,7 @@ def _add_network_size_args(parser):
'attention. This is set to '
' args.hidden_size // args.num_attention_heads '
'if not provided.')
-group.add_argument('--attention-head-type', type=str, default='multihead',
+group.add_argument('--attention-head-type', type=str, default=None,
choices=['multihead', 'multiquery'],
help='Type of attention heads. `multihead` is the standard multi-head attention.'
'`multiquery` shares the values and keys across attention heads')
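One way to read this one-line change (an interpretation, not text from the commit): with a None default, the conversion tools can tell "flag not passed" apart from an explicit multihead choice, so a multiquery checkpoint is not silently reinterpreted. A minimal hypothetical helper, not code from this repository, illustrating that fallback logic:

def resolve_attention_head_type(cli_value, checkpoint_value):
    # Hypothetical helper: None means the flag was not passed on the CLI,
    # so the value stored in the checkpoint wins; an explicit choice overrides it.
    if cli_value is None:
        return checkpoint_value
    return cli_value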
20 changes: 12 additions & 8 deletions megatron/checkpointing.py
Expand Up @@ -92,7 +92,7 @@ def ensure_directory_exists(filename):


def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
-pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
+pipeline_parallel=None, tensor_rank=None, pipeline_rank=None, only_model=False):
"""Determine the directory name for this rank's checkpoint."""
if release:
directory = 'release'
@@ -119,7 +119,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,

if use_distributed_optimizer:
model_name = os.path.join(common_path, "model_rng.pt")
-optim_name = os.path.join(
+optim_name = None if only_model else os.path.join(
common_path + "_%03d" % mpu.get_data_parallel_rank(),
"optim.pt")
else:
@@ -139,14 +139,14 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimize
# Look for checkpoint with no pipelining
filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
pipeline_parallel=False,
-tensor_rank=0, pipeline_rank=0)
+tensor_rank=0, pipeline_rank=0, only_model=True)
if os.path.isfile(filenames[0]):
return filenames

# Look for checkpoint with pipelining
filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
pipeline_parallel=True,
-tensor_rank=0, pipeline_rank=0)
+tensor_rank=0, pipeline_rank=0, only_model=True)
if os.path.isfile(filenames[0]):
return filenames

@@ -379,10 +379,11 @@ def fix_query_key_value_ordering(model, checkpoint_version):
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))

-def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iteration=None, release=None):
+def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iteration=None, release=None, no_load_optim=False):
""" Load the base state_dict from the given directory
If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
+If rank0 is true or no_load_optim is true, we do not care about the optimizer, only the model checkpoint.
"""

# Read the tracker file and set the iteration.
@@ -408,7 +409,7 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iter
release)
else:
checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
-release)
+release, only_model=no_load_optim)
if release:
print_rank_0(f' loading release checkpoint from {load_dir}')
else:
@@ -419,7 +420,9 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iter
# Load the checkpoint.
try:
model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
-if use_distributed_optimizer:
+if rank0 or no_load_optim:
+optim_state_dict = None
+elif use_distributed_optimizer:
optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
else:
optim_state_dict = model_state_dict
@@ -572,7 +575,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
use_distributed_optimizer=args.use_distributed_optimizer,
rank0=False,
iteration=iteration,
-release=release)
+release=release,
+no_load_optim=args.no_load_optim)

if model_state_dict is None:
return 0
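The net effect of only_model and no_load_optim is a model-only load path for the conversion tools. A simplified sketch of that behaviour (it mirrors the names above and assumes the usual Megatron file layout, where the non-distributed-optimizer case stores everything in model_optim_rng.pt; it is not the full Megatron function):

import os
import torch

def load_model_only(common_path, use_distributed_optimizer):
    # With the distributed optimizer, weights and optimizer state live in separate
    # files; otherwise both sit in a single model_optim_rng.pt (assumed layout).
    if use_distributed_optimizer:
        model_name = os.path.join(common_path, "model_rng.pt")
    else:
        model_name = os.path.join(common_path, "model_optim_rng.pt")
    model_state_dict = torch.load(model_name, map_location='cpu')
    # only_model / no_load_optim: skip the per-data-parallel-rank optim.pt shards,
    # which may not even be present on the machine doing the conversion.
    optim_state_dict = None
    return model_state_dict, optim_state_dict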
27 changes: 23 additions & 4 deletions tools/checkpoint_loader_megatron.py
@@ -50,6 +50,9 @@ def _load_checkpoint(queue, args):
'--no-initialization',
'--load', args.load_dir
]
+if args.use_distributed_optimizer:
+sys.argv.append("--use-distributed-optimizer")


margs = parse_args()
margs = load_args_from_checkpoint(margs)
@@ -78,6 +81,7 @@ def check_for_arg(arg_name):
check_for_arg('iteration')
check_for_arg('bert_binary_head')
check_for_arg('params_dtype')
+check_for_arg('attention_head_type')

# Determine how to make our models
if args.model_type == 'GPT':
@@ -147,6 +151,7 @@ def get_models(count, dtype, pre_process, post_process):
# metadata
md = types.SimpleNamespace()
md.model_type = args.model_type
+md.attention_head_type = margs.attention_head_type
md.num_layers = margs.num_layers
md.hidden_size = margs.hidden_size
md.seq_length = margs.seq_length
@@ -202,26 +207,40 @@ def queue_put(name, msg):
message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
+if margs.attention_head_type == "multiquery":
+# MQA: kv is shared across tp-ranks
+message["kv weight"] = layer.self_attention.key_value.weight.data
+message["kv bias"] = layer.self_attention.key_value.bias.data

# Grab all parallel tensors for this layer
qkv_weight = []
qkv_bias = []
+q_weight = []
+q_bias = []
dense_weight = []
mlp_l0_weight = []
mlp_l0_bias = []
mlp_l1_weight = []
for tp_rank, model in enumerate(models):
layer = model.language_model.encoder.layers[layer_num]
-qkv_weight.append(layer.self_attention.query_key_value.weight.data)
-qkv_bias.append(layer.self_attention.query_key_value.bias.data)
+if margs.attention_head_type == "multihead":
+qkv_weight.append(layer.self_attention.query_key_value.weight.data)
+qkv_bias.append(layer.self_attention.query_key_value.bias.data)
+elif margs.attention_head_type == "multiquery":
+q_weight.append(layer.self_attention.query.weight.data)
+q_bias.append(layer.self_attention.query.bias.data)
dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)

# concat them
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
if margs.attention_head_type == "multihead":
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
elif margs.attention_head_type == "multiquery":
message["q weight"] = torch.cat(q_weight, dim=0)
message["q bias"] = torch.cat(q_bias, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
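The shape handling above is the crux of the loader change: the query projection is split column-parallel across tensor-parallel ranks and has to be concatenated along dim 0, while the multiquery key-value projection is replicated, so a single rank's copy is forwarded unchanged. An illustrative sketch with made-up sizes (dummy tensors, not tied to any real checkpoint):

import torch

hidden_size, head_dim, tp_size = 8, 4, 2

# Per-rank query shards (column-parallel): each rank holds hidden_size // tp_size rows.
q_shards = [torch.randn(hidden_size // tp_size, hidden_size) for _ in range(tp_size)]
full_q = torch.cat(q_shards, dim=0)  # (hidden_size, hidden_size), analogous to "q weight" above

# Multiquery key-value: one (2 * head_dim, hidden_size) matrix, identical on every
# rank, so it is taken as-is with no concatenation, analogous to "kv weight" above.
full_kv = torch.randn(2 * head_dim, hidden_size)

assert full_q.shape == (hidden_size, hidden_size)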
23 changes: 19 additions & 4 deletions tools/checkpoint_saver_megatron.py
@@ -95,6 +95,7 @@ def check_message(msg):
'--seq-length', str(md.seq_length),
'--num-attention-heads', str(md.num_attention_heads),
'--max-position-embeddings', str(md.max_position_embeddings),
+'--attention-head-type', str(md.attention_head_type),
'--tokenizer-type', str(md.tokenizer_type),
'--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
'--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
@@ -225,10 +226,17 @@ def get_models(count, dtype, pre_process, post_process):
post_layernorm_weight = msg.pop("post layernorm weight")
post_layernorm_bias = msg.pop("post layernorm bias")
mlp_l1_bias = msg.pop("mlp l1 bias")
+if margs.attention_head_type == "multiquery":
+kv_weight = msg.pop("kv weight")
+kv_bias = msg.pop("kv bias")

# Split up the parallel tensors
-qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
-qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+if margs.attention_head_type == "multihead":
+qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
+qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
+elif margs.attention_head_type == "multiquery":
+q_weight = torch.chunk(msg.pop("q weight"), args.target_tensor_parallel_size, dim=0)
+q_bias = torch.chunk(msg.pop("q bias"), args.target_tensor_parallel_size, dim=0)
dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
@@ -239,8 +247,15 @@ def get_models(count, dtype, pre_process, post_process):
l = models[tp_rank].language_model.encoder.layers[layer]
l.input_layernorm.weight.data.copy_(input_layernorm_weight)
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
-l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
-l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+if margs.attention_head_type == "multihead":
+l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
+l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
+elif margs.attention_head_type == "multiquery":
+# MQA: key-value are shared across tp-ranks
+l.self_attention.key_value.weight.data.copy_(kv_weight)
+l.self_attention.key_value.bias.data.copy_(kv_bias)
+l.self_attention.query.weight.data.copy_(q_weight[tp_rank])
+l.self_attention.query.bias.data.copy_(q_bias[tp_rank])
l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
l.self_attention.dense.bias.data.copy_(dense_bias)
l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
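The saver inverts that transformation: the merged query tensor is chunked across the target tensor-parallel ranks, while the shared key-value tensor is copied unchanged into every rank's key_value module. A hedged sketch with dummy tensors standing in for the real Megatron modules:

import torch

target_tp = 2
hidden_size, head_dim = 8, 4

full_q = torch.randn(hidden_size, hidden_size)      # merged "q weight" from the loader
kv_weight = torch.randn(2 * head_dim, hidden_size)  # shared "kv weight"

q_chunks = torch.chunk(full_q, target_tp, dim=0)    # one query shard per target tp rank
for tp_rank in range(target_tp):
    rank_query = q_chunks[tp_rank].clone()  # would go to self_attention.query.weight on this rank
    rank_kv = kv_weight.clone()             # same key_value weight copied to every rank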
3 changes: 3 additions & 0 deletions tools/checkpoint_util.py
@@ -124,6 +124,9 @@ def main():
parser.add_argument('--no-checking', action='store_false',
help='Do not perform checking on the name and ordering of weights',
dest='checking')

+parser.add_argument('--use-distributed-optimizer', action='store_true',
+help='Loaded checkpoint uses distributed optimizer.')

known_args, _ = parser.parse_known_args()
loader = load_plugin('loader', known_args.loader)
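The new flag rides on the tool's existing parse_known_args pattern visible above: the first pass only needs the plugin names (and now this flag), while plugin-specific arguments are still unknown at that point and are handled once the loader/saver plugins register theirs. A self-contained sketch of that pattern using plain argparse (not the real tool; --some-plugin-flag is invented for illustration):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--loader', default='megatron')
parser.add_argument('--use-distributed-optimizer', action='store_true',
                    help='Loaded checkpoint uses distributed optimizer.')

known, extra = parser.parse_known_args(
    ['--loader', 'megatron', '--use-distributed-optimizer', '--some-plugin-flag', '1'])
print(known.use_distributed_optimizer)  # True
print(extra)                            # ['--some-plugin-flag', '1'], left for the plugin arguments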
