Skip to content

Commit

Permalink
More fixes to SentencePiece for T5 (#3515)
Browse files Browse the repository at this point in the history
* More fixes to spm for T5

Signed-off-by: MaximumEntropy <[email protected]>

* Style fixes

Signed-off-by: MaximumEntropy <[email protected]>
  • Loading branch information
MaximumEntropy authored Jan 28, 2022
1 parent 67aae90 commit 33290ac
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,10 @@ def unk_id(self):
@property
def additional_special_tokens_ids(self):
"""Returns a list of the ids of the additional special tokens (excluding bos, eos, pad, mask, cls, sep). Used to return sentinel tokens for e.g. T5."""
return list(self.special_token_to_id.values())
special_tokens = set(
[self.bos_token, self.eos_token, self.pad_token, self.mask_token, self.cls_token, self.sep_token]
)
return [v for k, v in self.special_token_to_id.items() if k not in special_tokens]

@property
def vocab(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
self.vocab_id_to_token_dict = {idx: token for idx, token in enumerate(self.vocab_id_list)}

self.sentinel_tokens = tokenizer.additional_special_tokens_ids
assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
assert len(self.sentinel_tokens) > 0

def __len__(self):
return self.samples_mapping.shape[0]
Expand Down
22 changes: 17 additions & 5 deletions nemo/collections/nlp/models/language_modeling/megatron_t5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,25 +480,37 @@ def _add_special_tokens_to_tokenizer(self):

# bos, eos, pad and unk may be present in the provided spm .model file; if they are, use them.
if not hasattr(self.tokenizer, 'pad_token'):
if hasattr(self.tokenizer.tokenizer, 'pad_id'):
if hasattr(self.tokenizer.tokenizer, 'pad_id') and self.tokenizer.tokenizer.pad_id() > 0:
self.tokenizer.pad_token = self.tokenizer.tokenizer.id_to_piece(self.tokenizer.tokenizer.pad_id())
else:
self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
else:
self.tokenizer.add_special_tokens({'pad_token': '<pad>'})

if not hasattr(self.tokenizer, 'bos_token'):
if hasattr(self.tokenizer.tokenizer, 'bos_id'):
if hasattr(self.tokenizer.tokenizer, 'bos_id') and self.tokenizer.tokenizer.bos_id() > 0:
self.tokenizer.bos_token = self.tokenizer.tokenizer.id_to_piece(self.tokenizer.tokenizer.bos_id())
else:
self.tokenizer.add_special_tokens({'bos_token': '<bos>'})
else:
self.tokenizer.add_special_tokens({'bos_token': '<s>'})

if not hasattr(self.tokenizer, 'eos_token'):
if hasattr(self.tokenizer.tokenizer, 'eos_id'):
if hasattr(self.tokenizer.tokenizer, 'eos_id') and self.tokenizer.tokenizer.eos_id() > 0:
self.tokenizer.eos_token = self.tokenizer.tokenizer.id_to_piece(self.tokenizer.tokenizer.eos_id())
else:
self.tokenizer.add_special_tokens({'eos_token': '<eos>'})
else:
self.tokenizer.add_special_tokens({'eos_token': '</s>'})

additional_tokens = [f'<extra_id_{i}>' for i in range(self.num_sentinel_tokens)]
self.tokenizer.add_special_tokens(additional_tokens)
# Special check to see if <extra_id_{i}> is already present in the tokenizer vocab. If it is, only register its existing id in special_token_to_id (which backs additional_special_tokens_ids); otherwise add it as a new special token.
for i in range(self.num_sentinel_tokens):
if f'▁<extra_id_{i}>' in self.tokenizer.vocab:
self.tokenizer.special_token_to_id[f'<extra_id_{i}>'] = self.tokenizer.text_to_ids(
f'<extra_id_{i}>'
)[0]
else:
self.tokenizer.add_special_tokens([f'<extra_id_{i}>'])

def list_available_models():
pass
3 changes: 2 additions & 1 deletion nemo/collections/nlp/modules/common/lm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ def get_lm_model(
)

if nemo_file is not None:
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
import torch

from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel

class Identity(torch.nn.Module):
def __init__(self):
super(Identity, self).__init__()
Expand Down

0 comments on commit 33290ac

Please sign in to comment.