Merge branch 'main' into angularLoss1.0
Showing 21 changed files with 1,299 additions and 4 deletions.
80 changes: 80 additions & 0 deletions
examples/nlp/language_modeling/conf/transformer_lm_config.yaml
@@ -0,0 +1,80 @@
# Config file for training left-to-right Transformer language model
name: &name TransformerLM

trainer:
  gpus: 1 # the number of gpus, 0 for CPU
  num_nodes: 1
  max_epochs: 2
  max_steps: 400 # takes precedence over max_epochs
  accumulate_grad_batches: 1 # accumulates grads every k batches
  amp_level: O2 # O1/O2 for mixed precision
  precision: 16 # Should be set to 16 for O1 and O2; default is 16 as PT ignores it when amp_level is O0
  distributed_backend: ddp
  checkpoint_callback: False # Provided by exp_manager
  logger: False # Provided by exp_manager
  row_log_interval: 1 # Interval of logging.
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for a number of iterations
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training; restores the whole state including the epoch, step, LR schedulers, apex, etc.

model:

  language_model:
    tokenizer: word
    vocab_file: ???
    hidden_size: 512
    num_layers: 6
    num_attn_heads: 8
    inner_size: 2048
    max_seq_length: 256
    embedding_dropout: 0
    ffn_dropout: 0
    attn_score_dropout: 0
    attn_layer_dropout: 0

  dataset:
    max_seq_length: 256
    num_workers: 2 # number of workers for data loaders
    drop_last: false # drops the last batch if it is smaller than the batch size
    pin_memory: false # enables pin_memory feature of the data loaders

  train_ds:
    file_name: ??? # path to file with training data
    batch_size: 32
    shuffle: true
    num_samples: -1 # number of samples to be considered, -1 means the whole dataset

  validation_ds:
    file_name: ??? # path to file with validation data
    batch_size: 32
    shuffle: false
    num_samples: -1 # number of samples to be considered, -1 means the whole dataset
    predict_last_k: 64

  optim:
    name: adam
    lr: 1e-4
    betas: [0.9, 0.999]
    weight_decay: 0

    sched:
      name: WarmupAnnealing
      warmup_steps: null
      warmup_ratio: 0.05
      last_epoch: -1

      # pytorch lightning args
      monitor: val_loss
      reduce_on_plateau: false

exp_manager:
  exp_dir: null # where to store logs and checkpoints
  name: *name # name of experiment
  create_tensorboard_logger: True
  create_checkpoint_callback: True

hydra:
  run:
    dir: .
  job_logging:
    root:
      handlers: null
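The ??? entries are OmegaConf mandatory values: training will not start until they are supplied, either by editing the file, by passing dotted overrides on the command line to the Hydra-based script below (e.g. model.language_model.vocab_file=/path/to/vocab.txt), or programmatically. A minimal sketch of the programmatic route, with placeholder paths that are not part of this commit:

from omegaconf import OmegaConf

# Load the new config file; the ??? entries are OmegaConf mandatory values.
cfg = OmegaConf.load("examples/nlp/language_modeling/conf/transformer_lm_config.yaml")

# Fill the mandatory fields programmatically (the paths below are placeholders).
cfg.model.language_model.vocab_file = "/data/vocab.txt"
cfg.model.train_ds.file_name = "/data/train.txt"
cfg.model.validation_ds.file_name = "/data/valid.txt"

# Reading a still-missing ??? value would raise omegaconf's MissingMandatoryValue.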
@@ -0,0 +1,35 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytorch_lightning as pl
from omegaconf import DictConfig

from nemo.collections.nlp.models.language_modeling import TransformerLMModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="transformer_lm_config")
def main(cfg: DictConfig) -> None:
    logging.info(f'Config: {cfg.pretty()}')
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    transformer_lm = TransformerLMModel(cfg.model, trainer=trainer)
    trainer.fit(transformer_lm)


if __name__ == '__main__':
    main()
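Once trainer.fit completes, the trained model can be serialized and later reloaded through NeMo's standard checkpointing interface; a minimal sketch, assuming the usual save_to/restore_from methods inherited from NeMo's base model class (the .nemo file name is a placeholder, not part of this commit):

# Hypothetical follow-up to the training run above; the file name is a placeholder.
transformer_lm.save_to("transformer_lm.nemo")
restored_model = TransformerLMModel.restore_from("transformer_lm.nemo")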
@@ -0,0 +1,108 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

__all__ = ['CharTokenizer']


class CharTokenizer(TokenizerSpec):
    def __init__(
        self,
        vocab_file: str,
        bos_token: str = "<BOS>",
        eos_token: str = "<EOS>",
        pad_token: str = "<PAD>",
        unk_token: str = "<UNK>",
    ):
        """
        Args:
            vocab_file: path to file with vocabulary which consists
                of characters separated by \n
            bos_token: the beginning of sequence token
            eos_token: the end of sequence token
            pad_token: token to use for padding
            unk_token: token to use for unknown tokens
        """

        vocab_list = open(vocab_file, "r").readlines()
        self.vocab = {vocab_list[i].strip(): i for i in range(len(vocab_list))}

        special_tokens_dict = {
            "bos_token": bos_token,
            "eos_token": eos_token,
            "pad_token": pad_token,
            "unk_token": unk_token,
        }

        self.add_special_tokens(special_tokens_dict)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.vocab_size = len(self.vocab)
        self.special_tokens = self.tokens_to_ids(special_tokens_dict.values())

    def add_special_tokens(self, special_tokens_dict: dict) -> int:
        """
        Adds a dictionary of special tokens (eos, pad, cls...).
        If special tokens are NOT in the vocabulary, they are added
        to it (indexed starting from the last index of the current vocabulary).
        Args:
            special_tokens_dict: dict of special tokens
        """
        for token in special_tokens_dict:
            token_str = special_tokens_dict[token]
            if token_str not in self.vocab:
                self.vocab[token_str] = len(self.vocab)
            setattr(self, token, token_str)

    def text_to_tokens(self, text):
        token_candidates = [char for char in text]
        tokens = []
        for token in token_candidates:
            if token in self.vocab:
                tokens.append(token)
            else:
                tokens.append(self.unk_token)
        return tokens

    def tokens_to_text(self, tokens):
        return self.ids_to_text(self.tokens_to_ids(tokens))

    def text_to_ids(self, text):
        return [self.vocab[token] for token in self.text_to_tokens(text)]

    def ids_to_text(self, ids):
        ids_ = [id_ for id_ in ids if id_ not in self.special_tokens]
        return "".join(self.ids_to_tokens(ids_))

    def tokens_to_ids(self, tokens):
        return [self.vocab[token] for token in tokens]

    def ids_to_tokens(self, ids):
        return [self.inv_vocab[id] for id in ids]

    @property
    def pad_id(self):
        return self.vocab[self.pad_token]

    @property
    def bos_id(self):
        return self.vocab[self.bos_token]

    @property
    def eos_id(self):
        return self.vocab[self.eos_token]

    @property
    def unk_id(self):
        return self.vocab[self.unk_token]
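A minimal usage sketch of CharTokenizer, assuming a toy vocabulary file with one character per line (the file name and contents are illustrative, not part of this commit):

# Write a toy character vocabulary: one character per line, lowercase letters only.
with open("char_vocab.txt", "w") as f:
    f.write("\n".join("abcdefghijklmnopqrstuvwxyz"))

tokenizer = CharTokenizer(vocab_file="char_vocab.txt")

ids = tokenizer.text_to_ids("hello world")
# The space is not in the vocabulary, so it is mapped to the <UNK> id.
print(tokenizer.ids_to_text(ids))
# 'helloworld': ids_to_text drops special-token ids, including <UNK>.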
@@ -0,0 +1,53 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer

__all__ = ['WordTokenizer']


class WordTokenizer(CharTokenizer):
    def __init__(
        self,
        vocab_file: str,
        bos_token: str = "<BOS>",
        eos_token: str = "<EOS>",
        pad_token: str = "<PAD>",
        unk_token: str = "<UNK>",
    ):
        """
        Args:
            vocab_file: path to file with vocabulary which consists
                of words separated by \n
            bos_token: the beginning of sequence token
            eos_token: the end of sequence token
            pad_token: token to use for padding
            unk_token: token to use for unknown tokens
        """

        super().__init__(vocab_file, bos_token, eos_token, pad_token, unk_token)

    def text_to_tokens(self, text):
        token_candidates = text.strip().split()
        tokens = []
        for token in token_candidates:
            if token in self.vocab:
                tokens.append(token)
            else:
                tokens.append(self.unk_token)
        return tokens

    def ids_to_text(self, ids):
        ids_ = [id_ for id_ in ids if id_ not in self.special_tokens]
        return " ".join(self.ids_to_tokens(ids_))