From 6feab42642ab05fa8d18f45aab42067e270b02f4 Mon Sep 17 00:00:00 2001
From: Cheng hou <59219579+hhou435@users.noreply.github.com>
Date: Wed, 24 Jan 2024 19:34:47 +0800
Subject: [PATCH] fix bug (#123)

* Update word_embedding.py

* Update transformer_encoder.py
---
 tencentpretrain/embeddings/word_embedding.py    | 2 +-
 tencentpretrain/encoders/transformer_encoder.py | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/tencentpretrain/embeddings/word_embedding.py b/tencentpretrain/embeddings/word_embedding.py
index c7428dda..38e4e7f4 100644
--- a/tencentpretrain/embeddings/word_embedding.py
+++ b/tencentpretrain/embeddings/word_embedding.py
@@ -9,7 +9,7 @@ class WordEmbedding(nn.Module):
 
     def __init__(self, args, vocab_size):
         super(WordEmbedding, self).__init__()
-        if args.tensor_model_parallel_size > 1:
+        if hasattr(args, "tensor_model_parallel_size") and args.tensor_model_parallel_size > 1:
             self.embedding = mpu.VocabParallelEmbedding(vocab_size, args.emb_size)
         else:
             self.embedding = nn.Embedding(vocab_size, args.emb_size)
diff --git a/tencentpretrain/encoders/transformer_encoder.py b/tencentpretrain/encoders/transformer_encoder.py
index e84b1ed4..0552977e 100644
--- a/tencentpretrain/encoders/transformer_encoder.py
+++ b/tencentpretrain/encoders/transformer_encoder.py
@@ -20,7 +20,10 @@ def __init__(self, args):
         self.relative_position_embedding = args.relative_position_embedding
         self.rotary_position_embedding = args.rotary_position_embedding
         self.has_residual_attention = args.has_residual_attention
-        self.tensor_model_parallel_size = args.tensor_model_parallel_size
+        if hasattr(args, "tensor_model_parallel_size"):
+            self.tensor_model_parallel_size = args.tensor_model_parallel_size
+        else:
+            self.tensor_model_parallel_size = 1
 
         if self.relative_position_embedding:
             args.relative_pos_emb = RelativePositionEmbedding(bidirectional=True, heads_num=args.heads_num,
@@ -28,7 +31,7 @@ def __init__(self, args):
         elif self.rotary_position_embedding:
             args.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)
 
-        if "deepspeed_checkpoint_activations" in args:
+        if hasattr(args, "deepspeed_checkpoint_activations"):
             self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
             self.deepspeed_checkpoint_layers_num = args.deepspeed_checkpoint_layers_num
         else:
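
Not part of the patch: a minimal sketch of the failure mode these hasattr() guards
address, assuming a plain argparse Namespace stands in for the TencentPretrain args
object (the Namespace contents below are illustrative; only emb_size appears in the
patch). Configs built without tensor-parallel or DeepSpeed options simply lack those
attributes, so reading them directly raises AttributeError; the guards fall back to
the single-GPU defaults instead.

    from argparse import Namespace

    # Hypothetical config: no tensor-parallel or DeepSpeed fields present.
    args = Namespace(emb_size=768)

    # Before the patch: reading the missing attribute raises AttributeError.
    # use_parallel = args.tensor_model_parallel_size > 1

    # After the patch: evaluates to False, so the plain nn.Embedding branch is taken.
    use_parallel = hasattr(args, "tensor_model_parallel_size") and args.tensor_model_parallel_size > 1

    # The encoder change is the same pattern with a default of 1; getattr with a
    # default is an equivalent one-line spelling of that hasattr/else block.
    tensor_model_parallel_size = getattr(args, "tensor_model_parallel_size", 1)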