Fixed some bugs regarding activation checkpoints and updated the BPE vocabulary loader (#125)

* Add token counter, update BPE vocab init

* Add special token security check

* update no_decay list

* [Fix] Fix how layer inputs are passed so activation checkpointing works correctly.

* update

* update

* update

---------

Co-authored-by: kaeli <[email protected]>
wmpscc and kaeli authored Apr 26, 2024
1 parent dc155e4 commit 27547a4
Showing 9 changed files with 93 additions and 56 deletions.
6 changes: 3 additions & 3 deletions tencentpretrain/encoders/transformer_encoder.py
@@ -130,14 +130,14 @@ def custom_forward(*inputs):
mpu.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.layers_num:
inputs = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num), inputs)
inputs = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num), *inputs)
l += self.deepspeed_checkpoint_layers_num
else:
for i in range(self.layers_num):
if self.parameter_sharing:
inputs = self.transformer(inputs)
inputs = self.transformer(*inputs)
else:
inputs = self.transformer[i](inputs)
inputs = self.transformer[i](*inputs)

hidden = inputs[0]

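The key change above is that the layer inputs are unpacked with `*` before being handed to the checkpointing call and to the transformer layers, so each tensor arrives as its own positional argument rather than as one tuple. Below is a minimal, self-contained sketch of that calling convention; it uses PyTorch's `torch.utils.checkpoint` as a stand-in for the DeepSpeed `checkpointing.checkpoint` used in this file (an assumption — the interfaces are similar but not identical) and a toy layer rather than the repository's transformer layer.

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyLayer(torch.nn.Module):
    """Stand-in for a transformer layer that takes and returns (hidden, mask)."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, hidden, mask):
        # Return the same (hidden, mask) tuple shape it receives, so the
        # outputs can feed the next checkpointed segment via *outputs.
        return self.linear(hidden) * mask, mask


layer = ToyLayer()
hidden = torch.randn(2, 4, 8, requires_grad=True)
mask = torch.ones(2, 4, 8)

inputs = (hidden, mask)
# *inputs hands each tensor to checkpoint() as its own positional argument,
# so the checkpointing wrapper sees (and can detach / recompute) every tensor.
outputs = checkpoint(layer, *inputs, use_reentrant=False)
outputs[0].sum().backward()
```

Passing the tuple itself as a single argument hides the individual tensors from the checkpointing wrapper, which appears to be the failure mode this commit addresses.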
3 changes: 2 additions & 1 deletion tencentpretrain/layers/multi_headed_attn.py
@@ -136,7 +136,8 @@ def unshape(x):
scores += prev_attn
prev_attn_out = scores

probs = nn.Softmax(dim=-1)(scores)
# probs = nn.Softmax(dim=-1)(scores)
probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
probs = self.dropout(probs)
output = unshape(torch.matmul(probs, value))
output = self.final_linear(output)
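The attention probabilities are now computed with `nn.functional.softmax(..., dtype=torch.float32)` and cast back to `query.dtype`, so the softmax is accumulated in full precision even when the model runs in fp16/bf16. A small sketch of the pattern (toy tensors, not the repository's attention code):

```python
import torch
import torch.nn as nn

# Mixed-precision activations, as produced under bf16/fp16 training.
query = torch.randn(2, 8, 16, 16, dtype=torch.bfloat16)
scores = torch.randn(2, 8, 16, 16, dtype=torch.bfloat16)

# Accumulate the softmax in float32, then cast the probabilities back to the
# working dtype so the following matmul with `value` stays in mixed precision.
probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
assert probs.dtype == torch.bfloat16
```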
19 changes: 10 additions & 9 deletions tencentpretrain/layers/transformer.py
@@ -3,6 +3,7 @@
from tencentpretrain.layers.multi_headed_attn import MultiHeadedAttention, ParallelMultiHeadedAttention
from tencentpretrain.layers import *


class TransformerLayer(nn.Module):
"""
Transformer layer mainly consists of two parts:
@@ -40,7 +41,7 @@ def __init__(self, args, layer_number=None):

self.self_attn = MultiHeadedAttention(
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias,
with_scale = with_scale, lora_params=lora_params, layer_number=layer_number
with_scale=with_scale, lora_params=lora_params, layer_number=layer_number
)
self.dropout_1 = nn.Dropout(args.dropout)

@@ -53,7 +54,7 @@ def __init__(self, args, layer_number=None):
self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)

def forward(self, inputs):
def forward(self, *inputs):

"""
Args:
@@ -63,7 +64,7 @@ def forward(self, inputs):
Returns:
output: [batch_size x seq_length x hidden_size]
"""
if len(inputs)==2:
if len(inputs) == 2:
hidden, mask = inputs
prev_attn = None
else:
@@ -136,7 +137,7 @@ def __init__(self, args, layer_number=None):

self.self_attn = ParallelMultiHeadedAttention(
args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias,
with_scale = with_scale, lora_params=lora_params, layer_number=layer_number
with_scale=with_scale, lora_params=lora_params, layer_number=layer_number
)
self.dropout_1 = nn.Dropout(args.dropout)

@@ -150,7 +151,7 @@ def __init__(self, args, layer_number=None):
self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)

def forward(self, inputs):
def forward(self, *inputs):

"""
Args:
@@ -161,7 +162,7 @@ def forward(self, inputs):
output: [batch_size x seq_length x hidden_size]
"""

if len(inputs)==2:
if len(inputs) == 2:
hidden, mask = inputs
prev_attn = None
else:
@@ -220,7 +221,7 @@ def generate_mask(self, seq_length, batch_size, device):
mask = mask.repeat(batch_size, 1, 1, 1)
return mask

def forward(self, inputs):
def forward(self, *inputs):

"""
Args:
@@ -231,15 +232,15 @@ def forward(self, inputs):
output: [batch_size x seq_length x hidden_size]
"""

if len(inputs)==2:
if len(inputs) == 2:
hidden, seg = inputs
prev_attn = None
else:
hidden, seg, prev_attn = inputs
batch_size, seq_length, _ = hidden.size()
mask = self.generate_mask(seq_length, batch_size, hidden.device)
layer_inputs = hidden, mask, prev_attn
outputs = self.layer(layer_inputs)
outputs = self.layer(*layer_inputs)

if self.has_residual_attention:
hidden, mask, prev_attn_out = outputs
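Throughout this file the layer `forward` signatures change from `forward(self, inputs)` to `forward(self, *inputs)`, and call sites now unpack with `*layer_inputs` / `*inputs`. The sketch below (a toy layer, not the repository's `TransformerLayer`) shows how that variadic signature dispatches on two versus three inputs and keeps chained calls uniform:

```python
import torch
import torch.nn as nn


class ToyTransformerLayer(nn.Module):
    def __init__(self, hidden_size=8):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, *inputs):
        # Accept either (hidden, mask) or (hidden, mask, prev_attn).
        if len(inputs) == 2:
            hidden, mask = inputs
            prev_attn = None  # unused in this toy layer
        else:
            hidden, mask, prev_attn = inputs
        hidden = self.proj(hidden) * mask
        # Returning a tuple keeps call sites uniform: the next layer (or a
        # checkpointing wrapper) can simply be invoked with *outputs.
        return hidden, mask


layer = ToyTransformerLayer()
hidden = torch.randn(2, 4, 8)
mask = torch.ones(2, 4, 8)

layer_inputs = hidden, mask
outputs = layer(*layer_inputs)  # matches the new self.layer(*layer_inputs) call
outputs = layer(*outputs)       # chaining works the same way
```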
21 changes: 10 additions & 11 deletions tencentpretrain/trainer.py
@@ -109,7 +109,7 @@ def init_optimizer(args, model_for_training):
if 'lora' not in n:
p.requires_grad = False
else:
no_decay = ["bias", "gamma", "beta"]
no_decay = ["bias", "gamma", "beta", "layer_norm.weight", "layer_norm_1.weight", "layer_norm_2.weight"]
optimizer_grouped_parameters = [
{"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
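The `no_decay` list now also exempts LayerNorm weights (`layer_norm*.weight`) from weight decay, alongside biases and the `gamma`/`beta` names. A minimal sketch of how the name-based grouping plays out, using a hypothetical two-module block rather than the real model:

```python
import torch


class ToyBlock(torch.nn.Module):
    """Hypothetical block: one linear layer plus one LayerNorm."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        self.layer_norm = torch.nn.LayerNorm(16)


model = ToyBlock()
no_decay = ["bias", "gamma", "beta",
            "layer_norm.weight", "layer_norm_1.weight", "layer_norm_2.weight"]
param_optimizer = list(model.named_parameters())

optimizer_grouped_parameters = [
    # linear.weight keeps the usual 0.01 decay.
    {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    # linear.bias, layer_norm.weight and layer_norm.bias are exempt from decay.
    {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-4)
```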
@@ -695,23 +695,22 @@ def worker(local_rank, gpu_ranks, args):
if args.pipeline_model_parallel_size > 1:
from deepspeed.pipe import PipelineModule, TiedLayerSpec, LayerSpec
def get_model(model, args):
layers = [LayerSpec(EmbeddingPipe, args,model=model),
*[LayerSpec(ParallelTransformerLayerPipe, args,model=model, layer_idx=idx) for idx in
layers = [LayerSpec(EmbeddingPipe, args, model=model),
*[LayerSpec(ParallelTransformerLayerPipe, args, model=model, layer_idx=idx) for idx in
range(args.layers_num)],
LayerSpec(TargetPipe, args=args,model=model)
]
LayerSpec(TargetPipe, args=args, model=model)]
return layers
layers = get_model(model_for_training,args)
layers = get_model(model_for_training, args)
from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(),
num_mp=mpu.get_tensor_model_parallel_world_size(),
num_dp=mpu.get_data_parallel_world_size())

model_for_training=PipelineModule(layers=layers,
num_stages=args.pipeline_model_parallel_size,
activation_checkpoint_interval=args.deepspeed_checkpoint_layers_num,
loss_fn=CrossEntropy,
checkpointable_layers=['ParallelTransformerLayerPipe'], topology=topo)
model_for_training = PipelineModule(layers=layers,
num_stages=args.pipeline_model_parallel_size,
activation_checkpoint_interval=args.deepspeed_checkpoint_layers_num,
loss_fn=CrossEntropy,
checkpointable_layers=['ParallelTransformerLayerPipe'], topology=topo)

# Build optimizer.
custom_optimizer, custom_scheduler, optimizer_grouped_parameters = init_optimizer(args, model_for_training)
3 changes: 2 additions & 1 deletion tencentpretrain/utils/__init__.py
@@ -19,7 +19,8 @@
"t5": T5Dataloader, "gsg": GsgDataloader, "bart": BartDataloader,
"cls": ClsDataloader, "prefixlm": PrefixlmDataloader, "cls_mlm": ClsMlmDataloader,
"vit": VitDataloader, "vilt": ViltDataloader, "clip": ClipDataloader, "s2t": S2tDataloader,
"beit":BeitDataloader, "dalle": DalleDataloader, "llm_sft": LlmSftDataloader}
"beit":BeitDataloader, "dalle": DalleDataloader, "llm_sft": LlmSftDataloader,
"llm_pretrain": LlmPretrainDataloader}

str2act = {"gelu": gelu, "gelu_fast": gelu_fast, "relu": relu, "silu": silu, "linear": linear}

18 changes: 8 additions & 10 deletions tencentpretrain/utils/constants.py
@@ -4,13 +4,11 @@
with open("models/special_tokens_map.json", mode="r", encoding="utf-8") as f:
special_tokens_map = json.load(f)

UNK_TOKEN = special_tokens_map["unk_token"]
CLS_TOKEN = special_tokens_map["cls_token"]
SEP_TOKEN = special_tokens_map["sep_token"]
MASK_TOKEN = special_tokens_map["mask_token"]
PAD_TOKEN = special_tokens_map["pad_token"]
try:
# e.g. <extra_id_0>, <extra_id_1>, ... , should have consecutive IDs.
SENTINEL_TOKEN = special_tokens_map["sentinel_token"]
except KeyError:
pass
UNK_TOKEN = special_tokens_map.get("unk_token")
CLS_TOKEN = special_tokens_map.get("cls_token")
SEP_TOKEN = special_tokens_map.get("sep_token")
MASK_TOKEN = special_tokens_map.get("mask_token")
PAD_TOKEN = special_tokens_map.get("pad_token")

# e.g. <extra_id_0>, <extra_id_1>, ... , should have consecutive IDs.
SENTINEL_TOKEN = special_tokens_map.get("sentinel_token")
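Reading the special-token map with `dict.get` means a missing entry now yields `None` instead of raising `KeyError` at import time (previously only `sentinel_token` was wrapped in `try/except`). A short sketch of the resulting behaviour, with an inline JSON string standing in for `models/special_tokens_map.json`:

```python
import json

# Inline stand-in for models/special_tokens_map.json with some tokens missing.
special_tokens_map = json.loads('{"unk_token": "[UNK]", "pad_token": "[PAD]"}')

UNK_TOKEN = special_tokens_map.get("unk_token")            # "[UNK]"
SENTINEL_TOKEN = special_tokens_map.get("sentinel_token")  # None, no KeyError

# Callers can now guard on None instead of wrapping the import in try/except.
if SENTINEL_TOKEN is None:
    print("No sentinel token configured; T5-style span masking is unavailable.")
```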
5 changes: 4 additions & 1 deletion tencentpretrain/utils/dataloader.py
@@ -754,7 +754,6 @@ class S2tDataloader(AudioDataloader):
def __iter__(self):
import torchaudio
import torchaudio.compliance.kaldi as ta_kaldi

padding_vector = torch.FloatTensor(self.audio_feature_size * [self.padding_value] if self.audio_feature_size > 1 else self.padding_value).unsqueeze(0).cuda(self.local_rank)
while True:
while self._empty():
@@ -949,3 +948,7 @@ def __iter__(self):
yield torch.LongTensor(src), \
torch.LongTensor(tgt), \
torch.LongTensor(seg)


class LlmPretrainDataloader(LmDataloader):
pass
54 changes: 40 additions & 14 deletions tencentpretrain/utils/dataset.py
@@ -17,7 +17,7 @@ def merge_dataset(dataset_path, workers_num):
for i in range(workers_num):
tmp_dataset_reader = open("dataset-tmp-" + str(i) + ".pt", "rb")
while True:
tmp_data = tmp_dataset_reader.read(2**20)
tmp_data = tmp_dataset_reader.read(2 ** 20)
if tmp_data:
dataset_writer.write(tmp_data)
else:
@@ -69,13 +69,21 @@ def build_and_save(self, workers_num):
if workers_num == 1:
self.worker(0, 0, lines_num)
else:
async_results = []
pool = Pool(workers_num)
for i in range(workers_num):
start = i * lines_num // workers_num
end = (i + 1) * lines_num // workers_num
pool.apply_async(func=self.worker, args=[i, start, end])
# pool.apply_async(func=self.worker, args=[i, start, end])
async_results.append(pool.apply_async(func=self.worker, args=[i, start, end]))
pool.close()
pool.join()
async_results = [res.get() for res in async_results]
if async_results[0] is not None:
samples_num = sum([res[0] for res in async_results])
tokens_num = sum([res[1] for res in async_results])
print("Number of samples:", samples_num)
print("Total number of tokens:", tokens_num)

# Merge datasets.
merge_dataset(self.dataset_path, workers_num)
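Instead of discarding the `apply_async` results, `build_and_save` now keeps the handles, calls `.get()` after `pool.join()`, and sums the per-worker `(samples_num, tokens_num)` pairs returned by workers that implement the new token counter. A standalone sketch of that pattern with a toy worker:

```python
from multiprocessing import Pool


def worker(proc_id, start, end):
    # Toy worker: pretend every line in [start, end) is one sample of 10 tokens.
    samples_num = end - start
    tokens_num = samples_num * 10
    return samples_num, tokens_num


if __name__ == "__main__":
    lines_num, workers_num = 1000, 4
    async_results = []
    pool = Pool(workers_num)
    for i in range(workers_num):
        start = i * lines_num // workers_num
        end = (i + 1) * lines_num // workers_num
        async_results.append(pool.apply_async(func=worker, args=[i, start, end]))
    pool.close()
    pool.join()

    results = [res.get() for res in async_results]
    # Workers that do not count (older dataset classes) would return None.
    if results[0] is not None:
        print("Number of samples:", sum(r[0] for r in results))
        print("Total number of tokens:", sum(r[1] for r in results))
```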
@@ -211,7 +219,8 @@ def create_ins_from_doc(self, all_documents, document_index):
pad_num = self.seq_length - len(src)

if not self.dynamic_masking:
src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
self.span_geo_prob, self.span_max_length)
src = (src, pad_num)
instance = (src, tgt_mlm, is_random_next, seg_pos)
else:
@@ -245,7 +254,8 @@ def worker(self, proc_id, start, end):
line = f.readline()
pos += 1

document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]
document = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(line)) + [self.vocab.get(SEP_TOKEN)]

if self.full_sentences:
if len(document) > 0:
@@ -293,7 +303,8 @@ def build_instances(self, all_documents):
seg_pos = [len(src)]

if not self.dynamic_masking:
src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob,
self.span_max_length)
instance = ((src, 0), tgt, seg_pos)
else:
instance = ((src, 0), seg_pos)
@@ -308,9 +319,10 @@ def build_instances(self, all_documents):
seg_pos = [len(src)]

pad_num = self.seq_length - len(src)

if not self.dynamic_masking:
src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob,
self.span_max_length)
instance = ((src, pad_num), tgt, seg_pos)
else:
instance = ((src, pad_num), seg_pos)
@@ -417,7 +429,8 @@ def create_ins_from_doc(self, document):
pad_num = self.seq_length - len(src)

if not self.dynamic_masking:
src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking,
self.span_geo_prob, self.span_max_length)
src = (src, pad_num)
instance = (src, tgt_mlm, is_wrong_order, seg_pos)
else:
@@ -464,7 +477,7 @@ def worker(self, proc_id, start, end):
seg_pos = [self.seq_length]
src = (src, 0)
pickle.dump((src, seg_pos), dataset_writer)
buffer = buffer[instances_num * (self.seq_length + 1): ]
buffer = buffer[instances_num * (self.seq_length + 1):]

else:
instances_num = len(document) // (self.seq_length + 1)
@@ -486,13 +499,17 @@ def worker(self, proc_id, start, end):

dataset_writer.close()


class LlmPretrainDataset(Dataset):
def __init__(self, args, vocab, tokenizer):
super(LlmPretrainDataset, self).__init__(args, vocab, tokenizer)
self.full_sentences = args.full_sentences

def worker(self, proc_id, start, end):
print("Worker %d is building dataset ... " % proc_id)
samples_num = 0
tokens_num = 0

set_seed(self.seed)
dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
pos = 0
@@ -517,7 +534,7 @@ def worker(self, proc_id, start, end):
seg_pos = [self.seq_length]
src = (src, 0)
pickle.dump((src, seg_pos), dataset_writer)
buffer = buffer[instances_num * (self.seq_length + 1): ]
buffer = buffer[instances_num * (self.seq_length + 1):]

else:
instances_num = len(document) // (self.seq_length + 1)
@@ -533,7 +550,8 @@ def worker(self, proc_id, start, end):
pad_num = self.seq_length + 1 - len(src)
src = (src, pad_num)
pickle.dump((src, seg_pos), dataset_writer)

tokens_num += len(src)
samples_num += 1
if pos >= end:
break

@@ -675,7 +693,8 @@ def create_ins_from_doc(self, all_documents, document_index):

while i < len(document):
segment = document[i]
if i in mask_seq_list and len(tgt) + len(segment) < target_tgt_seq_length and len(src) + 1 < target_seq_length:
if i in mask_seq_list and len(tgt) + len(segment) < target_tgt_seq_length and len(
src) + 1 < target_seq_length:
tgt = tgt + segment
src = src + [self.vocab.get(MASK_TOKEN)]
elif i not in mask_seq_list and len(src) + len(segment) < target_seq_length:
@@ -884,7 +903,8 @@ def worker(self, proc_id, start, end):
if len(line) == 2:
label = int(line[0])
text = line[1]
src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
src = [self.vocab.get(CLS_TOKEN)] + self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(text)) + [self.vocab.get(SEP_TOKEN)]
tgt_cls = label
seg_pos = [len(src)]
elif len(line) == 3: # For sentence pair input.
@@ -920,7 +940,8 @@ def worker(self, proc_id, start, end):

if not self.dynamic_masking:
src_single, pad_num = src
src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
src_single, tgt_mlm = mask_seq(src_single, self.tokenizer, self.whole_word_masking,
self.span_masking, self.span_geo_prob, self.span_max_length)
src = (src_single, pad_num)
instance = (src, tgt_mlm, tgt_cls, seg_pos)
else:
@@ -1046,6 +1067,8 @@ class DalleDataset(FileWithTextDataset):
class LlmSftDataset(Dataset):
def worker(self, proc_id, start, end):
print("Worker %d is building dataset ... " % proc_id)
samples_num = 0
tokens_num = 0
set_seed(self.seed)
dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
pos = 0
@@ -1079,7 +1102,10 @@ def worker(self, proc_id, start, end):
pad_num = self.seq_length - len(src)

pickle.dump(((src, pad_num), seg_pos), dataset_writer)
tokens_num += len(src)
samples_num += 1
if pos >= end:
break

dataset_writer.close()
return samples_num, tokens_num
