Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More fixes to SentencePiece for T5 #3515

Merged
merged 5 commits into from
Jan 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,10 @@ def unk_id(self):
@property
def additional_special_tokens_ids(self):
"""Returns a list of the additional special tokens (excluding bos, eos, pad, unk). Used to return sentinel tokens for e.g. T5."""
return list(self.special_token_to_id.values())
special_tokens = set(
[self.bos_token, self.eos_token, self.pad_token, self.mask_token, self.cls_token, self.sep_token]
)
return [v for k, v in self.special_token_to_id.items() if k not in special_tokens]

@property
def vocab(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
self.vocab_id_to_token_dict = {idx: token for idx, token in enumerate(self.vocab_id_list)}

self.sentinel_tokens = tokenizer.additional_special_tokens_ids
assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
assert len(self.sentinel_tokens) > 0

def __len__(self):
return self.samples_mapping.shape[0]
Expand Down
22 changes: 17 additions & 5 deletions nemo/collections/nlp/models/language_modeling/megatron_t5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,25 +480,37 @@ def _add_special_tokens_to_tokenizer(self):

# bos, eos, pad and unk may be present in the provided spm .model file, if they are, use it.
if not hasattr(self.tokenizer, 'pad_token'):
if hasattr(self.tokenizer.tokenizer, 'pad_id'):
if hasattr(self.tokenizer.tokenizer, 'pad_id') and self.tokenizer.tokenizer.pad_id() > 0:
self.tokenizer.pad_token = self.tokenizer.tokenizer.id_to_piece(self.tokenizer.tokenizer.pad_id())
else:
self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
else:
self.tokenizer.add_special_tokens({'pad_token': '<pad>'})

if not hasattr(self.tokenizer, 'bos_token'):
if hasattr(self.tokenizer.tokenizer, 'bos_id'):
if hasattr(self.tokenizer.tokenizer, 'bos_id') and self.tokenizer.tokenizer.bos_id() > 0:
self.tokenizer.bos_token = self.tokenizer.tokenizer.id_to_piece(self.tokenizer.tokenizer.bos_id())
else:
self.tokenizer.add_special_tokens({'bos_token': '<bos>'})
else:
self.tokenizer.add_special_tokens({'bos_token': '<s>'})

if not hasattr(self.tokenizer, 'eos_token'):
if hasattr(self.tokenizer.tokenizer, 'eos_id'):
if hasattr(self.tokenizer.tokenizer, 'eos_id') and self.tokenizer.tokenizer.eos_id() > 0:
self.tokenizer.eos_token = self.tokenizer.tokenizer.id_to_piece(self.tokenizer.tokenizer.eos_id())
else:
self.tokenizer.add_special_tokens({'eos_token': '<eos>'})
else:
self.tokenizer.add_special_tokens({'eos_token': '</s>'})

additional_tokens = [f'<extra_id_{i}>' for i in range(self.num_sentinel_tokens)]
self.tokenizer.add_special_tokens(additional_tokens)
# Special check to see if <extra_id_{}> is already present in the tokenizer. If it is, only modify the additional_special_tokens function.
for i in range(self.num_sentinel_tokens):
if f'▁<extra_id_{i}>' in self.tokenizer.vocab:
self.tokenizer.special_token_to_id[f'<extra_id_{i}>'] = self.tokenizer.text_to_ids(
f'<extra_id_{i}>'
)[0]
else:
self.tokenizer.add_special_tokens([f'<extra_id_{i}>'])

def list_available_models():
pass
3 changes: 2 additions & 1 deletion nemo/collections/nlp/modules/common/lm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ def get_lm_model(
)

if nemo_file is not None:
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
import torch

from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel

class Identity(torch.nn.Module):
def __init__(self):
super(Identity, self).__init__()
Expand Down