Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing the run and model scripts for running the BingBertSquad #58

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 56 additions & 24 deletions BingBertSquad/convert_bert_ckpt_to_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def set_data(param, array):
try:
assert param.shape == array.shape
Expand All @@ -22,6 +23,7 @@ def set_data(param, array):
raise
param.data = torch.from_numpy(array)


def load_tf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
""" Load tf checkpoints in DeepSpeed model.
"""
Expand Down Expand Up @@ -52,10 +54,10 @@ def load_tf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
name = name_str.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using a pretrained model
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
if any(n in [
"adam_v", "adam_m", "AdamWeightDecayOptimizer",
"AdamWeightDecayOptimizer_1", "global_step"
] for n in name):
logger.info("Skipping {}".format("/".join(name)))
continue
pointer = model
Expand All @@ -76,11 +78,14 @@ def load_tf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
# Special in deepspeed.
elif name_str.find("bert/pooler/dense") >= 0 and scope_names[0] == "dense":
elif name_str.find(
"bert/pooler/dense") >= 0 and scope_names[0] == "dense":
pointer = getattr(pointer, "dense_act")
elif name_str.find("bert/embeddings/LayerNorm/gamma") >= 0 and scope_names[0] == "gamma":
elif name_str.find("bert/embeddings/LayerNorm/gamma"
) >= 0 and scope_names[0] == "gamma":
pointer = getattr(pointer, "weight")
elif name_str.find("bert/embeddings/LayerNorm/beta") >= 0 and scope_names[0] == "beta":
elif name_str.find("bert/embeddings/LayerNorm/beta"
) >= 0 and scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
else:
try:
Expand Down Expand Up @@ -121,16 +126,26 @@ def load_tf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
pointer = getattr(pointer, "inter_w")
elif name_str.find("intermediate/dense/bias") > 0:
pointer = getattr(pointer, "inter_b")
elif name_str.find("output/dense/kernel") > 0 and name_str.find("attention") < 0:
elif name_str.find(
"output/dense/kernel") > 0 and name_str.find(
"attention") < 0:
pointer = getattr(pointer, "output_w")
elif name_str.find("output/dense/bias") > 0 and name_str.find("attention") < 0:
elif name_str.find(
"output/dense/bias") > 0 and name_str.find(
"attention") < 0:
pointer = getattr(pointer, "output_b")
elif name_str.find("output/LayerNorm/gamma") > 0 and name_str.find("attention") < 0:
elif name_str.find(
"output/LayerNorm/gamma") > 0 and name_str.find(
"attention") < 0:
pointer = getattr(pointer, "norm_w")
elif name_str.find("output/LayerNorm/beta") > 0 and name_str.find("attention") < 0:
elif name_str.find(
"output/LayerNorm/beta") > 0 and name_str.find(
"attention") < 0:
pointer = getattr(pointer, "norm_b")
else:
raise ValueError(f"unexpect scope name {name_str} in transformer layer.")
raise ValueError(
f"unexpect scope name {name_str} in transformer layer."
)
break

if skipping:
Expand Down Expand Up @@ -161,7 +176,8 @@ def load_tf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
continue

# The DeepSpeed BERT model's vocab size is aligned to a multiple of 8, so pad the word embeddings with zero rows to match.
if voc_size_diff > 0 and name_str.find("embeddings/word_embeddings") >= 0:
if voc_size_diff > 0 and name_str.find(
"embeddings/word_embeddings") >= 0:
z = np.zeros((voc_size_diff, array.shape[1]), dtype=array.dtype)
array = np.concatenate((array, z), axis=0)

Expand All @@ -170,6 +186,7 @@ def load_tf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):

return model


def load_hf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
""" Load huggingface checkpoints and convert to a deepspeed model.
"""
Expand All @@ -181,7 +198,8 @@ def load_hf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
qkv = {}
for name_str in ckpt.keys():
array = ckpt[name_str].numpy()
logger.info("Loading Huggingface weight {} with shape {}".format(name_str, array.shape))
logger.info("Loading Huggingface weight {} with shape {}".format(
name_str, array.shape))
name = name_str.split(".")
pointer = model
key = None
Expand Down Expand Up @@ -235,16 +253,22 @@ def load_hf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
pointer = getattr(pointer, "inter_w")
elif name_str.find("intermediate.dense.bias") > 0:
pointer = getattr(pointer, "inter_b")
elif name_str.find("output.dense.weight") > 0 and name_str.find("attention") < 0:
elif name_str.find("output.dense.weight"
) > 0 and name_str.find("attention") < 0:
pointer = getattr(pointer, "output_w")
elif name_str.find("output.dense.bias") > 0 and name_str.find("attention") < 0:
elif name_str.find("output.dense.bias") > 0 and name_str.find(
"attention") < 0:
pointer = getattr(pointer, "output_b")
elif name_str.find("output.LayerNorm.weight") > 0 and name_str.find("attention") < 0:
elif name_str.find("output.LayerNorm.weight"
) > 0 and name_str.find("attention") < 0:
pointer = getattr(pointer, "norm_w")
elif name_str.find("output.LayerNorm.bias") > 0 and name_str.find("attention") < 0:
elif name_str.find("output.LayerNorm.bias"
) > 0 and name_str.find("attention") < 0:
pointer = getattr(pointer, "norm_b")
else:
raise ValueError(f"unexpect scope name {name_str} in transformer layer.")
raise ValueError(
f"unexpect scope name {name_str} in transformer layer."
)
break

if skipping:
Expand All @@ -270,7 +294,8 @@ def load_hf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):
continue

# The DeepSpeed BERT model's vocab size is aligned to a multiple of 8, so pad the word embeddings with zero rows to match.
if voc_size_diff > 0 and name_str.find("embeddings.word_embeddings") >= 0:
if voc_size_diff > 0 and name_str.find(
"embeddings.word_embeddings") >= 0:
z = np.zeros((voc_size_diff, array.shape[1]), dtype=array.dtype)
array = np.concatenate((array, z), axis=0)

Expand All @@ -279,6 +304,7 @@ def load_hf_weights_in_bert_kernel(model, ckpt_path, voc_size_diff):

return model


def load_hf_weights_in_bert_torch(model, ckpt_path, voc_size_diff):
""" Load huggingface checkpoints and convert to a deepspeed model.
"""
Expand All @@ -290,7 +316,8 @@ def load_hf_weights_in_bert_torch(model, ckpt_path, voc_size_diff):
qkv = {}
for name_str in ckpt.keys():
array = ckpt[name_str].numpy()
logger.info("Loading Huggingface weight {} with shape {}".format(name_str, array.shape))
logger.info("Loading Huggingface weight {} with shape {}".format(
name_str, array.shape))
name = name_str.split(".")
pointer = model
key = None
Expand All @@ -314,7 +341,8 @@ def load_hf_weights_in_bert_torch(model, ckpt_path, voc_size_diff):
continue

# The DeepSpeed BERT model's vocab size is aligned to a multiple of 8, so pad the word embeddings with zero rows to match.
if voc_size_diff > 0 and name_str.find("embeddings.word_embeddings") >= 0:
if voc_size_diff > 0 and name_str.find(
"embeddings.word_embeddings") >= 0:
z = np.zeros((voc_size_diff, array.shape[1]), dtype=array.dtype)
array = np.concatenate((array, z), axis=0)

Expand All @@ -323,7 +351,9 @@ def load_hf_weights_in_bert_torch(model, ckpt_path, voc_size_diff):

return model

def convert_ckpt_to_deepspeed(model, ckpt_type, ckpt_path, vocab_diff, kernel_enabled):

def convert_ckpt_to_deepspeed(model, ckpt_type, ckpt_path, vocab_diff,
kernel_enabled):

# Load weights from checkpoint
if ckpt_type == "HF":
Expand All @@ -335,6 +365,8 @@ def convert_ckpt_to_deepspeed(model, ckpt_type, ckpt_path, vocab_diff, kernel_en
if kernel_enabled:
load_tf_weights_in_bert_kernel(model, ckpt_path, vocab_diff)
else:
raise ValueError("--deepspeed_transformer_kernel is required for loading TF checkpoint.")
raise ValueError(
"--deepspeed_transformer_kernel is required for loading TF checkpoint."
)
else:
raise ValueError(f"Invalid ckpt_type.")
43 changes: 28 additions & 15 deletions BingBertSquad/nvidia_run_squad_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,11 +795,15 @@ def main():
else:
# Models from Tensorflow and Huggingface are post-LN.
if args.preln:
raise ValueError("Should NOT use --preln if the loading checkpoint doesn't use pre-layer-norm.")
raise ValueError(
"Should NOT use --preln if the loading checkpoint doesn't use pre-layer-norm."
)

# Use the original BERT config when loading from a non-DeepSpeed checkpoint.
if args.origin_bert_config_file is None:
raise ValueError("--origin_bert_config_file is required for loading non-DeepSpeed checkpoint.")
raise ValueError(
"--origin_bert_config_file is required for loading non-DeepSpeed checkpoint."
)

bert_config = BertConfig.from_json_file(args.origin_bert_config_file)

Expand All @@ -812,6 +816,7 @@ def main():
vocab_diff = 8 - (bert_config.vocab_size % 8)
bert_config.vocab_size += vocab_diff

torch.distributed.init_process_group(backend='nccl')
if args.preln:
model = BertForQuestionAnsweringPreLN(bert_config, args)
else:
Expand All @@ -822,20 +827,22 @@ def main():
logger.info(f"Loading Pretrained Bert Encoder from: {args.model_file}")

if args.ckpt_type == "DS":
checkpoint_state_dict = torch.load(args.model_file,
map_location=torch.device("cpu"))
checkpoint_state_dict = torch.load(
args.model_file, map_location=torch.device("cpu"))
if 'module' in checkpoint_state_dict:
logger.info('Loading DeepSpeed v2.0 style checkpoint')
model.load_state_dict(checkpoint_state_dict['module'],
strict=False)
elif 'model_state_dict' in checkpoint_state_dict:
model.load_state_dict(checkpoint_state_dict['model_state_dict'],
strict=False)
model.load_state_dict(
checkpoint_state_dict['model_state_dict'], strict=False)
else:
raise ValueError("Unable to find model state in checkpoint")
else:
from convert_bert_ckpt_to_deepspeed import convert_ckpt_to_deepspeed
convert_ckpt_to_deepspeed(model, args.ckpt_type, args.model_file, vocab_diff, args.deepspeed_transformer_kernel)
convert_ckpt_to_deepspeed(model, args.ckpt_type, args.model_file,
vocab_diff,
args.deepspeed_transformer_kernel)

logger.info(f"Pretrained Bert Encoder Loaded from: {args.model_file}")

Expand All @@ -852,7 +859,7 @@ def main():
[p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay':
0.01
},{
}, {
'params':
[p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
'weight_decay':
Expand All @@ -864,7 +871,7 @@ def main():
model=model,
model_parameters=optimizer_grouped_parameters,
dist_init_required=True)

if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available()
and not args.no_cuda else "cpu")
Expand Down Expand Up @@ -911,8 +918,6 @@ def main():
else:
args.summary_writer = None



logger.info("propagate deepspeed-config settings to client settings")
args.train_batch_size = model.train_micro_batch_size_per_gpu()
args.gradient_accumulation_steps = model.gradient_accumulation_steps()
Expand Down Expand Up @@ -1056,12 +1061,20 @@ def main():
f'Warning: Early epoch termination due to max steps limit, epoch step ={epoch_step}, global step = {global_step}, epoch = {num_epoch}'
)
break
one_step_time = time.time() -start_time
one_step_time = time.time() - start_time
all_step_time += one_step_time
if (step + 1)%(ave_rounds) == 0 and torch.distributed.get_rank() == 0:
print('At Step {}, Averaged Throughput for {} rounds is: {} Samples/s'.format(step, ave_rounds, bs_size * ave_rounds * torch.distributed.get_world_size() / all_step_time ), flush=True )
if (step + 1) % (
ave_rounds) == 0 and torch.distributed.get_rank() == 0:
print(
'At Step {}, Averaged Throughput for {} rounds is: {} Samples/s'
.format(
step, ave_rounds,
bs_size * ave_rounds *
torch.distributed.get_world_size() /
all_step_time),
flush=True)
all_step_time = 0.0

# Save a trained model
# model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
#output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
Expand Down
2 changes: 1 addition & 1 deletion BingBertSquad/run_squad_deepspeed.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ else
GRAD_ACCUM_STEPS=$((PER_GPU_BATCH_SIZE/MAX_GPU_BATCH_SIZE))
fi
JOB_NAME="deepspeed_${NGPU}GPUs_${EFFECTIVE_BATCH_SIZE}batch_size"
config_json=onebit_deepspeed_bsz24_config.json
config_json=deepspeed_bsz24_config.json
run_cmd="deepspeed --num_nodes ${NUM_NODES} --num_gpus ${NGPU_PER_NODE} \
--master_port=${MASTER_PORT} \
--hostfile ${HOSTFILE} \
Expand Down
2 changes: 2 additions & 0 deletions BingBertSquad/turing/nvidia_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,8 @@ def __init__(self, config, args):
hidden_dropout_ratio=config.hidden_dropout_prob,
num_hidden_layers=config.num_hidden_layers,
initializer_range=config.initializer_range,
local_rank=args.local_rank
if hasattr(args, 'local_rank') else -1,
seed=args.seed,
fp16=ds_config.fp16_enabled,
pre_layer_norm=False)
Expand Down
2 changes: 2 additions & 0 deletions BingBertSquad/turing/nvidia_modelingpreln.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,8 @@ def __init__(self, config, args):
hidden_dropout_ratio=config.hidden_dropout_prob,
num_hidden_layers=config.num_hidden_layers,
initializer_range=config.initializer_range,
local_rank=args.local_rank
if hasattr(args, 'local_rank') else -1,
seed=args.seed,
fp16=ds_config.fp16_enabled,
pre_layer_norm=True)
Expand Down
7 changes: 5 additions & 2 deletions BingBertSquad/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,16 @@ def get_argument_parser():
'--ckpt_type',
type=str,
default="DS",
help="Checkpoint's type, DS - DeepSpeed, TF - Tensorflow, HF - Huggingface.")
help=
"Checkpoint's type, DS - DeepSpeed, TF - Tensorflow, HF - Huggingface."
)

parser.add_argument(
"--origin_bert_config_file",
type=str,
default=None,
help="The config json file corresponding to the non-DeepSpeed pre-trained BERT model."
help=
"The config json file corresponding to the non-DeepSpeed pre-trained BERT model."
)

return parser
Expand Down