From 06257cf669ad7350dbf41cfe4f7a9f2f0b233611 Mon Sep 17 00:00:00 2001
From: Huy Vu <86480512+huvunvidia@users.noreply.github.com>
Date: Mon, 21 Oct 2024 14:15:46 -0400
Subject: [PATCH] Update T5 tokenizer (adding additional tokens to tokenizer
 config) (#10972)

* initial commit

* restore t5_pretraining

* Apply isort and black reformatting

Signed-off-by: huvunvidia

---------

Signed-off-by: huvunvidia
Co-authored-by: Huy Vu2
Co-authored-by: huvunvidia
---
 .../common/tokenizers/huggingface/auto_tokenizer.py   | 11 ++++++++++-
 nemo/collections/llm/t5/data/fine_tuning.py           |  2 --
 nemo/collections/llm/t5/data/pre_training.py          |  4 ----
 .../collections/nlp/modules/common/tokenizer_utils.py |  9 +++++++--
 tests/collections/llm/megatron_t5_finetuning.py       |  3 +++
 tests/collections/llm/megatron_t5_pretraining.py      |  3 +++
 6 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
index 76dca1268c3b..439322b8e810 100644
--- a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
+++ b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from collections import OrderedDict
-from typing import Optional
+from typing import List, Optional
 
 from transformers import AutoTokenizer as AUTOTOKENIZER
 
@@ -43,6 +43,7 @@ def __init__(
         sep_token: Optional[str] = None,
         cls_token: Optional[str] = None,
         unk_token: Optional[str] = None,
+        additional_special_tokens: Optional[List] = [],
         use_fast: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
     ):
@@ -60,6 +61,7 @@ def __init__(
             sep_token: token used for separating sequences
             cls_token: class token. Usually equal to bos_token
             unk_token: token to use for unknown tokens
+            additional_special_tokens: list of other tokens beside standard special tokens (bos, eos, pad, etc.). For example, sentinel tokens for T5 (<extra_id_0>, <extra_id_1>, etc.)
             use_fast: whether to use fast HuggingFace tokenizer
         """
         try:
@@ -124,10 +126,17 @@ def __init__(
         elif self.tokenizer.cls_token is None and self.tokenizer.bos_token:
             special_tokens_dict["cls_token"] = self.tokenizer.bos_token
 
+        # add additional special tokens (not standard special tokens such as bos, eod, sep)
+        if additional_special_tokens is not None:
+            special_tokens_dict["additional_special_tokens"] = additional_special_tokens
+
         new_tokens_in_vocab = []
         for token in [mask_token, bos_token, eos_token, pad_token, sep_token, cls_token, unk_token]:
             if token is not None and token not in self.tokenizer.get_vocab():
                 new_tokens_in_vocab.append(token)
+        for token in additional_special_tokens:
+            if token is not None and token not in self.tokenizer.get_vocab():
+                new_tokens_in_vocab.append(token)
 
         if len(new_tokens_in_vocab) > 0:
             """
diff --git a/nemo/collections/llm/t5/data/fine_tuning.py b/nemo/collections/llm/t5/data/fine_tuning.py
index b1315f7a708a..9326dabe7b84 100644
--- a/nemo/collections/llm/t5/data/fine_tuning.py
+++ b/nemo/collections/llm/t5/data/fine_tuning.py
@@ -61,8 +61,6 @@ def __init__(
         from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 
         self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "BertWordPieceCase")
-        additional_tokens = {'additional_special_tokens': [f'<extra_id_{i}>' for i in range(100)]}
-        self.tokenizer.add_special_tokens(additional_tokens)
 
         self.memmap_workers = memmap_workers
         self.num_workers = num_workers
diff --git a/nemo/collections/llm/t5/data/pre_training.py b/nemo/collections/llm/t5/data/pre_training.py
index 2c73e0b78b11..e6f619972284 100644
--- a/nemo/collections/llm/t5/data/pre_training.py
+++ b/nemo/collections/llm/t5/data/pre_training.py
@@ -130,10 +130,6 @@ def __init__(
         # add additional tokens for T5 tokenizer
         from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 
-        self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "BertWordPieceCase")
-        additional_tokens = {'additional_special_tokens': [f'<extra_id_{i}>' for i in range(100)]}
-        self.tokenizer.add_special_tokens(additional_tokens)
-
         self.data_sampler = MegatronDataSampler(
             seq_len=self.seq_length,
             micro_batch_size=micro_batch_size,
diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py
index 4e6f9e15b839..dfc55a6c9065 100644
--- a/nemo/collections/nlp/modules/common/tokenizer_utils.py
+++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py
@@ -69,7 +69,8 @@ def get_tokenizer(
             To see the list of all HuggingFace pretrained models, use:
             nemo_nlp.modules.common.get_huggingface_pretrained_lm_models_list()
         tokenizer_model: tokenizer model file of sentencepiece
-        special_tokens: dict of special tokens
+        special_tokens: dict of special tokens.
+            For additional special tokens besides standard special tokens (bos, eos, pad, etc.), such as sentinel tokens for T5 (<extra_id_0>, <extra_id_1>, etc.), use key 'additional_special_tokens'
         vocab_file: path to vocab file
         use_fast: (only for HuggingFace AutoTokenizer) set to True to use fast HuggingFace tokenizer
         bpe_dropout: (experimental) BPE dropout tries to corrupt the standard segmentation
@@ -224,7 +225,11 @@ def get_nmt_tokenizer(
             f'Getting Megatron tokenizer for pretrained model name: {model_name}, custom vocab file: {vocab_file}, and merges file: {merges_file}'
         )
         return get_tokenizer(
-            tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file, chat_template=chat_template
+            tokenizer_name=model_name,
+            vocab_file=vocab_file,
+            merges_file=merges_file,
+            special_tokens=special_tokens_dict,
+            chat_template=chat_template,
         )
     elif library == 'tabular':
         from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer
diff --git a/tests/collections/llm/megatron_t5_finetuning.py b/tests/collections/llm/megatron_t5_finetuning.py
index a204e6797926..f54e858cfb43 100644
--- a/tests/collections/llm/megatron_t5_finetuning.py
+++ b/tests/collections/llm/megatron_t5_finetuning.py
@@ -35,9 +35,12 @@ def get_args():
 
     args = get_args()
 
+    special_tokens = {}
+    special_tokens['additional_special_tokens'] = [f'<extra_id_{i}>' for i in range(100)]
     tokenizer = get_nmt_tokenizer(
         "megatron",
         "BertWordPieceCase",
+        special_tokens=special_tokens,
     )
 
     data = SquadDataModule(
diff --git a/tests/collections/llm/megatron_t5_pretraining.py b/tests/collections/llm/megatron_t5_pretraining.py
index 5d8f55a7f26f..a5460be3d154 100644
--- a/tests/collections/llm/megatron_t5_pretraining.py
+++ b/tests/collections/llm/megatron_t5_pretraining.py
@@ -50,10 +50,13 @@ def get_args():
 
     args = get_args()
 
+    special_tokens = {}
+    special_tokens['additional_special_tokens'] = [f'<extra_id_{i}>' for i in range(100)]
     tokenizer = get_nmt_tokenizer(
         "megatron",
         "BertWordPieceCase",
         vocab_file=args.vocab_path,
+        special_tokens=special_tokens,
     )
     data = PreTrainingDataModule(
         paths=args.data_path,
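
Usage note: with this change, callers opt in to the T5 sentinel tokens by passing a special_tokens dict to get_nmt_tokenizer instead of mutating the tokenizer afterwards with add_special_tokens. Below is a minimal sketch of that call, mirroring the updated test scripts above; the "megatron"/"BertWordPieceCase" arguments and the 100 <extra_id_*> tokens come straight from those tests, and anything beyond that is illustrative rather than part of the patch.

    from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

    # Build the T5 sentinel tokens <extra_id_0> ... <extra_id_99>.
    special_tokens = {'additional_special_tokens': [f'<extra_id_{i}>' for i in range(100)]}

    # get_nmt_tokenizer forwards the dict to get_tokenizer (see the tokenizer_utils.py
    # hunk), and the AutoTokenizer wrapper now accepts an additional_special_tokens
    # keyword and registers any tokens missing from the vocabulary.
    tokenizer = get_nmt_tokenizer(
        "megatron",
        "BertWordPieceCase",
        special_tokens=special_tokens,
    )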