From 9c9eaf16ac9287124f70edc2edb1827b43440c2f Mon Sep 17 00:00:00 2001 From: shibing624 Date: Thu, 11 Jan 2024 16:10:02 +0800 Subject: [PATCH] update group text function. --- pretraining.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/pretraining.py b/pretraining.py index 9aab9ea..9f3fdf2 100644 --- a/pretraining.py +++ b/pretraining.py @@ -425,9 +425,11 @@ def tokenize_function(examples): return tokenized_inputs + def tokenize_wo_pad_function(examples): + return tokenizer(examples["text"]) + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. - def tokenize_and_group_text_function(examples): - examples = tokenizer(examples["text"]) + def group_text_function(examples): # Concatenate all texts. concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) @@ -535,8 +537,16 @@ def tokenize_and_group_text_function(examples): with training_args.main_process_first(desc="Dataset tokenization and grouping"): if not data_args.streaming: if training_args.group_by_length: - lm_datasets = raw_datasets.map( - tokenize_and_group_text_function, + tokenized_datasets = raw_datasets.map( + tokenize_wo_pad_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + lm_datasets = tokenized_datasets.map( + group_text_function, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, @@ -553,8 +563,13 @@ def tokenize_and_group_text_function(examples): ) else: if training_args.group_by_length: - lm_datasets = raw_datasets.map( - tokenize_and_group_text_function, + tokenized_datasets = raw_datasets.map( + tokenize_wo_pad_function, + batched=True, + remove_columns=column_names, + ) + lm_datasets = tokenized_datasets.map( + group_text_function, batched=True, ) else: