From a16598ad07aa38f8285e18501c817a84a7b8a176 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Tue, 18 Jan 2022 19:28:41 -0800 Subject: [PATCH 01/22] init checking of p-tune method Signed-off-by: Yi Dong --- .../ptune_text_classification_config.yaml | 118 +++ .../ptune_text_classification.py | 154 ++++ .../ptune_text_classification_dataset.py | 60 ++ .../language_modeling/megatron/gpt_model.py | 2 + .../ptune_text_classification_model.py | 423 +++++++++ .../modules/common/megatron/language_model.py | 5 +- .../nlp/modules/common/prompt_encoder.py | 59 ++ tutorials/nlp/PTune_sentiment_analysis.ipynb | 816 ++++++++++++++++++ 8 files changed, 1635 insertions(+), 2 deletions(-) create mode 100644 examples/nlp/text_classification/conf/ptune_text_classification_config.yaml create mode 100644 examples/nlp/text_classification/ptune_text_classification.py create mode 100644 nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py create mode 100644 nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py create mode 100644 nemo/collections/nlp/modules/common/prompt_encoder.py create mode 100644 tutorials/nlp/PTune_sentiment_analysis.ipynb diff --git a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml new file mode 100644 index 000000000000..88f0d326c135 --- /dev/null +++ b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml @@ -0,0 +1,118 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Config file for text classification with pre-trained BERT models + +trainer: + gpus: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: 100 + max_steps: null # precedence over max_epochs + accumulate_grad_batches: 1 # accumulates grads every k batches + gradient_clip_val: 0.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: ddp + log_every_n_steps: 1 # Interval of logging. + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + + checkpoint_callback: False # Provided by exp_manager + logger: False # Provided by exp_manager + +model: + tensor_model_parallel_size: 2 # tensor model parallel size used in the LM model + seed: 1234 + nemo_path: ptune_text_classification_model.nemo # filename to save the model and associated artifacts to .nemo file + use_lm_finetune: False # whether fine tune the language model + pseudo_token: '[PROMPT]' # pseudo prompt tokens + + tokenizer: + library: 'megatron' + type: 'GPT2BPETokenizer' + model: null + vocab_file: null + merge_file: null + + language_model: + nemo_file: null + + prompt_encoder: + template: [3, 3, 0] + dropout: 0.1 + + dataset: + classes: ??? # The class labels, e.g. ['positive', 'neutral', 'negative'] + do_lower_case: false # true for uncased models, false for cased models, will be set automatically if pre-trained tokenizer model is used + max_seq_length: 256 # the maximum length BERT supports is 512 + class_balancing: null # null or 'weighted_loss'. 'weighted_loss' enables the weighted class balancing of the loss, may be used for handling unbalanced classes + use_cache: false # uses a cache to store the processed dataset, you may use it for large datasets for speed up + + train_ds: + file_path: null + batch_size: 64 + shuffle: true + num_samples: -1 # number of samples to be considered, -1 means all the dataset + num_workers: 3 + drop_last: false + pin_memory: false + + validation_ds: + file_path: null + batch_size: 64 + shuffle: false + num_samples: -1 # number of samples to be considered, -1 means all the dataset + num_workers: 3 + drop_last: false + pin_memory: false + + test_ds: + file_path: null + batch_size: 64 + shuffle: false + num_samples: -1 # number of samples to be considered, -1 means all the dataset + num_workers: 3 + drop_last: false + pin_memory: false + + optim: + name: adam + lr: 2e-5 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.01 + + # scheduler setup + sched: + name: WarmupAnnealing + # Scheduler params + warmup_steps: null + warmup_ratio: 0.1 + last_epoch: -1 + # pytorch lightning args + monitor: val_loss + reduce_on_plateau: false + + # List of some sample queries for inference after training is done + infer_samples: [ + 'by the end of no such thing the audience , like beatrice , has a watchful affection for the monster .', + 'director rob marshall went out gunning to make a great one .', + 'uneasy mishmash of styles and genres .', + ] + +exp_manager: + exp_dir: null # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: "PTuneTextClassification" # The name of your model + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback diff --git a/examples/nlp/text_classification/ptune_text_classification.py b/examples/nlp/text_classification/ptune_text_classification.py new file mode 100644 index 000000000000..6ca660f8a47c --- /dev/null +++ b/examples/nlp/text_classification/ptune_text_classification.py @@ -0,0 +1,154 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script contains an example on how to train, evaluate and perform inference with the TextClassificationModel. +TextClassificationModel in NeMo supports text classification problems such as sentiment analysis or +domain/intent detection for dialogue systems, as long as the data follows the format specified below. + +***Data format*** +TextClassificationModel requires the data to be stored in TAB separated files (.tsv) with two columns of sentence and +label. Each line of the data file contains text sequences, where words are separated with spaces and label separated +with [TAB], i.e.: + +[WORD][SPACE][WORD][SPACE][WORD][TAB][LABEL] + +For example: + +hide new secretions from the parental units[TAB]0 +that loves its characters and communicates something rather beautiful about human nature[TAB]1 +... + +If your dataset is stored in another format, you need to convert it to this format to use the TextClassificationModel. + + +***Setting the configs*** +The model and the PT trainer are defined in a config file which declares multiple important sections. +The most important ones are: + model: All arguments that are related to the Model - language model, tokenizer, head classifier, optimizer, + schedulers, and datasets/data loaders. + trainer: Any argument to be passed to PyTorch Lightning including number of epochs, number of GPUs, + precision level, etc. + +This script uses the `/examples/nlp/text_classification/conf/text_classification_config.yaml` default config file +by default. You may update the config file from the file directly or by using the command line arguments. +Other option is to set another config file via command line arguments by `--config-name=CONFIG_FILE_PATH'. + +You first need to set the num_classes in the config file which specifies the number of classes in the dataset. +Notice that some config lines, including `model.dataset.classes_num`, have `???` as their value, this means that values +for these fields are required to be specified by the user. We need to specify and set the `model.train_ds.file_name`, +`model.validation_ds.file_name`, and `model.test_ds.file_name` in the config file to the paths of the train, validation, + and test files if they exist. We may do it by updating the config file or by setting them from the command line. + + +***How to run the script?*** +For example the following would train a model for 50 epochs in 2 GPUs on a classification task with 2 classes: + +# python text_classification_with_bert.py + model.dataset.num_classes=2 + model.train_ds=PATH_TO_TRAIN_FILE + model.validation_ds=PATH_TO_VAL_FILE + trainer.max_epochs=50 + trainer.gpus=2 + +This script would also reload the last checkpoint after the training is done and does evaluation on the dev set, +then performs inference on some sample queries. + +By default, this script uses examples/nlp/text_classification/conf/text_classifciation_config.py config file, and +you may update all the params in the config file from the command line. You may also use another config file like this: + +# python text_classification_with_bert.py --config-name==PATH_TO_CONFIG_FILE + model.dataset.num_classes=2 + model.train_ds=PATH_TO_TRAIN_FILE + model.validation_ds=PATH_TO_VAL_FILE + trainer.max_epochs=50 + trainer.gpus=2 + +***Load a saved model*** +This script would save the model after training into '.nemo' checkpoint file specified by nemo_path of the model config. +You may restore the saved model like this: + model = TextClassificationModel.restore_from(restore_path=NEMO_FILE_PATH) + +***Evaluation a saved model on another dataset*** +# If you wanted to evaluate the saved model on another dataset, you may restore the model and create a new data loader: + eval_model = TextClassificationModel.restore_from(restore_path=checkpoint_path) + +# Then, you may create a dataloader config for evaluation: + eval_config = OmegaConf.create( + {'file_path': cfg.model.test_ds.file_path, 'batch_size': 64, 'shuffle': False, 'num_workers': 3} + ) + eval_model.setup_test_data(test_data_config=eval_config) + +# You need to create a new trainer: + eval_trainer = pl.Trainer(gpus=1) + eval_model.set_trainer(eval_trainer) + eval_trainer.test(model=eval_model, verbose=False) +""" +import pytorch_lightning as pl +from omegaconf import DictConfig, OmegaConf + +from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import PTuneTextClassificationModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="ptune_text_classification_config") +def main(cfg: DictConfig) -> None: + logging.info(f'\nConfig Params:\n{OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(plugins=[NLPDDPPlugin()], **cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + + if not cfg.model.train_ds.file_path: + raise ValueError("'train_ds.file_path' need to be set for the training!") + + model = PTuneTextClassificationModel(cfg.model, trainer=trainer) + logging.info("===========================================================================================") + logging.info('Starting training...') + trainer.fit(model) + logging.info('Training finished!') + logging.info("===========================================================================================") + + if cfg.model.nemo_path: + # '.nemo' file contains the last checkpoint and the params to initialize the model + model.save_to(cfg.model.nemo_path) + logging.info(f'Model is saved into `.nemo` file: {cfg.model.nemo_path}') + + # We evaluate the trained model on the test set if test_ds is set in the config file + if cfg.model.test_ds.file_path: + logging.info("===========================================================================================") + logging.info("Starting the testing of the trained model on test set...") + trainer.test(model=model, ckpt_path=None, verbose=False) + logging.info("Testing finished!") + logging.info("===========================================================================================") + + # perform inference on a list of queries. + if "infer_samples" in cfg.model and cfg.model.infer_samples: + logging.info("===========================================================================================") + logging.info("Starting the inference on some sample queries...") + + # max_seq_length=512 is the maximum length BERT supports. + results = model.classifytext(queries=cfg.model.infer_samples, batch_size=16, max_seq_length=512) + logging.info('The prediction results of some sample queries with the trained model:') + for query, result in zip(cfg.model.infer_samples, results): + logging.info(f'Query : {query}') + logging.info(f'Predicted label: {result}') + + logging.info("Inference finished!") + logging.info("===========================================================================================") + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py new file mode 100644 index 000000000000..9ebb16472273 --- /dev/null +++ b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py @@ -0,0 +1,60 @@ +# Copyright 2018 The Google AI Language Team Authors and +# The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List +from nemo.core.classes import Dataset + +__all__ = ['BankPTextClassificationDataset', 'token_wrapper'] + +import json + + +def load_file(filename): + data = [] + with open(filename, "r") as f: + for line in f.readlines(): + data.append(json.loads(line)) + return data + + +def token_wrapper(token: str) -> str: + return 'Ġ' + token + + +class BankPTextClassificationDataset(Dataset): + def __init__(self, input_file: str, sentiments: List[str]): + super().__init__() + if input_file and not os.path.exists(input_file): + raise FileNotFoundError( + f'Data file `{input_file}` not found! Each line of the data file should contain json object' + f'where `sentence` key maps to sentence and `sentiment` key maps to sentiment' + ) + data = load_file(input_file) + self.x_hs, self.x_ts = [], [] + self.data = data + + for d in data: + if d['sentiment'] not in sentiments: + continue + self.x_ts.append(d['sentiment']) + self.x_hs.append(d['sentence']) + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i]['sentence'], self.data[i]['sentiment'] diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py index fc0310d06cd3..0a2f753ec5d8 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -154,6 +154,7 @@ def forward( layer_past=None, get_key_value=False, forward_method_parallel_output=None, + encoder_input=None, ): lm_output = self.language_model( @@ -163,6 +164,7 @@ def forward( prompt_tags=prompt_tags, layer_past=layer_past, get_key_value=get_key_value, + encoder_input=encoder_input ) if self.post_process: diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py new file mode 100644 index 000000000000..abaad707ff4c --- /dev/null +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -0,0 +1,423 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from typing import Dict, List, Optional + +import torch +from omegaconf import DictConfig +from pytorch_lightning import Trainer + +from nemo.collections.common.losses import CrossEntropyLoss +from nemo.collections.nlp.data.text_classification.ptune_text_classification_dataset import BankPTextClassificationDataset, token_wrapper +from nemo.collections.nlp.metrics.classification_report import ClassificationReport +from nemo.collections.nlp.models.nlp_model import NLPModel +from nemo.collections.nlp.modules.common import SequenceClassifier +from nemo.collections.nlp.modules.common.lm_utils import get_lm_model +from nemo.collections.nlp.parts.utils_funcs import tensor2list +from nemo.core.classes.common import typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.neural_types import NeuralType +from nemo.utils import logging +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.prompt_encoder import PromptEncoder +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.collections.nlp.modules.common.megatron.megatron_init import ( + initialize_model_parallel_for_nemo, +) +from torch.nn.utils.rnn import pad_sequence + +__all__ = ['PTuneTextClassificationModel'] + + +class PTuneTextClassificationModel(NLPModel, Exportable): + + # @property + # def input_types(self) -> Optional[Dict[str, NeuralType]]: + # return self.bert_model.input_types + + # @property + # def output_types(self) -> Optional[Dict[str, NeuralType]]: + # return self.classifier.output_types + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + """Initializes the BERTTextClassifier model.""" + super().__init__(cfg=cfg, trainer=trainer) + + initialize_model_parallel_for_nemo( + world_size=trainer.world_size, + global_rank=trainer.global_rank, + local_rank=trainer.local_rank, + tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1), + seed=cfg.get('seed', 1234), + ) + + # shared params for dataset and data loaders + self.dataset_cfg = cfg.dataset + # tokenizer needs to get initialized before the super.__init__() + # as dataloaders and datasets need it to process the data + self.tokenizer = get_nmt_tokenizer( + library=cfg.tokenizer.library, + model_name=cfg.tokenizer.type, + tokenizer_model=self.register_artifact("tokenizer.model", cfg.tokenizer.model), + vocab_file=self.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file), + merges_file=self.register_artifact("tokenizer.merges_file", cfg.tokenizer.merge_file), + ) + + self.class_weights = None + + self.model = MegatronGPTModel.restore_from(self.register_artifact('language_model.nemo_file', cfg.language_model.get('nemo_file', None)), + trainer=trainer).half() + + for param in self.model.parameters(): + param.requires_grad = cfg.use_lm_finetune + + hidden_size = self.model.cfg.hidden_size + + + # self.create_loss_module() + + # register the file containing the labels into the artifacts to get stored in the '.nemo' file later + self.classes = cfg.dataset.classes + + # setup to track metrics + self.classification_report = ClassificationReport( + num_classes=len(self.classes), mode='micro', dist_sync_on_step=True + ) + + self.embeddings = self.model.model.language_model.embedding.word_embeddings + + # set allowed vocab set + self.vocab = self.tokenizer.tokenizer.get_vocab() + + self.allowed_vocab_ids = set(self.vocab[token_wrapper(k)] for k in cfg.dataset.classes) + + self.template = cfg.prompt_encoder.template + + self.prompt_encoder = PromptEncoder( + template=cfg.prompt_encoder.template, + hidden_size=hidden_size, + lstm_dropout=cfg.prompt_encoder.dropout + ) + + # load prompt encoder + self.hidden_size = hidden_size + self.tokenizer.add_special_tokens({'additional_special_tokens': [cfg.pseudo_token]}) + + # if 'megatron' in self.args.model_name: + # self.pseudo_token_id = self.tokenizer.tokenizer.convert_tokens_to_ids( + # self.args.pseudo_token) + # self.pad_token_id = self.tokenizer.eod + # else: + self.pseudo_token_id = self.tokenizer.tokenizer.get_vocab()[cfg.pseudo_token] + self.pad_token_id = self.tokenizer.tokenizer.pad_token_id if self.tokenizer.tokenizer.pad_token_id is not None else self.tokenizer.tokenizer.unk_token_id + self.spell_length = sum(self.template) + + + def embed_input(self, queries): + bz = queries.shape[0] + queries_for_embedding = queries.clone() + + queries_for_embedding[(queries == self.pseudo_token_id)] = self.pad_token_id + raw_embeds = self.embeddings(queries_for_embedding) + + blocked_indices = (queries == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1] # bz + replace_embeds = self.prompt_encoder() + for bidx in range(bz): + for i in range(self.prompt_encoder.spell_length): + raw_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :] + return raw_embeds + + def get_query(self, x_h, prompt_tokens, x_t=None): + return [prompt_tokens * self.template[0] + + self.tokenizer.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenizer.tokenize(' ' + x_h)) # head entity + + prompt_tokens * self.template[1] + + (self.tokenizer.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(' ' + x_t)) if x_t is not None else []) + ] + + def forward(self, x_hs, x_ts, return_candidates=False): + bz = len(x_hs) + + # construct query ids + prompt_tokens = [self.pseudo_token_id] + x_ts = [token_wrapper(x_t) for x_t in x_ts] + queries = [torch.LongTensor(self.get_query(x_hs[i], prompt_tokens)).squeeze(0) for i in range(bz)] + queries = pad_sequence(queries, True, padding_value=self.pad_token_id).long().to(self.device) + + # construct label ids + label_ids = torch.LongTensor(self.tokenizer.tokenizer.convert_tokens_to_ids(x_ts)).reshape( + (bz, -1)).to(self.device) + attention_mask = queries != self.pad_token_id + # get embedded input + inputs_embeds = self.embed_input(queries) + + def megatron_out(): + bz, seq_len, _ = inputs_embeds.shape + labels = torch.empty_like(queries).fill_(-100).long() # bz * seq_len + label_mask = (attention_mask.long().sum(dim=1) - 1).unsqueeze(1) + labels = labels.scatter_(1, label_mask, label_ids) + + causal_mask = torch.tril( + torch.ones((bz, seq_len, seq_len), + device=self.device)).view(bz, 1, + seq_len, seq_len) + r = causal_mask.permute((1, 2, 0, 3)) * attention_mask.int() + new_atten = r.permute((2, 0, 1, 3)) + new_atten = new_atten < 0.5 + + position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device) + position_ids = position_ids.unsqueeze(0).expand_as(inputs_embeds[:, :, 0]) + position_embeddings = self.model.model.language_model.embedding.position_embeddings(position_ids) + encoder_input = inputs_embeds + position_embeddings + + output = self.model.model(None, None, encoder_input=encoder_input.half(), + attention_mask=new_atten, + labels=labels) + loss, logits = output + floss = (loss[(labels != -100)]).mean() + + pred_ids = torch.argsort(logits, dim=2, descending=True) + hit1 = 0 + top10 = [] + for i in range(bz): + top10.append([]) + pred_seq = pred_ids[i, label_mask[i, 0]].tolist() + for pred in pred_seq: + if pred in self.allowed_vocab_ids: + top10[-1].append(pred) + if len(top10[-1]) >= 10: + break + pred = top10[-1][0] + if pred == label_ids[i, 0]: + hit1 += 1 + if return_candidates: + return floss, hit1, top10 + return floss, hit1 + return megatron_out() + + def create_loss_module(self): + # create the loss module if it is not yet created by the training data loader + if not hasattr(self, 'loss'): + if hasattr(self, 'class_weights') and self.class_weights: + # You may need to increase the number of epochs for convergence when using weighted_loss + self.loss = CrossEntropyLoss(weight=self.class_weights) + else: + self.loss = CrossEntropyLoss() + + def training_step(self, batch, batch_idx): + """ + Lightning calls this inside the training loop with the data from the training dataloader + passed in as `batch`. + """ + # forward pass + xs, ts = batch + train_loss, hit1 = self.forward(xs, ts) + + lr = self._optimizer.param_groups[0]['lr'] + self.log('train_loss', train_loss) + self.log('lr', lr, prog_bar=True) + + return { + 'loss': train_loss, + 'lr': lr, + } + + def validation_step(self, batch, batch_idx): + """ + Lightning calls this inside the validation loop with the data from the validation dataloader + passed in as `batch`. + """ + input_ids, input_type_ids, input_mask, labels = batch + logits = self.forward(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask) + + val_loss = self.loss(logits=logits, labels=labels) + + preds = torch.argmax(logits, axis=-1) + + tp, fn, fp, _ = self.classification_report(preds, labels) + + return {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp} + + def validation_epoch_end(self, outputs): + """ + Called at the end of validation to aggregate outputs. + :param outputs: list of individual outputs of each validation step. + """ + if not outputs: + return {} + if self.trainer.testing: + prefix = 'test' + else: + prefix = 'val' + + avg_loss = torch.stack([x[f'val_loss'] for x in outputs]).mean() + + # calculate metrics and classification report + precision, recall, f1, report = self.classification_report.compute() + + logging.info(f'{prefix}_report: {report}') + + self.log(f'{prefix}_loss', avg_loss, prog_bar=True) + self.log(f'{prefix}_precision', precision) + self.log(f'{prefix}_f1', f1) + self.log(f'{prefix}_recall', recall) + + self.classification_report.reset() + + def test_step(self, batch, batch_idx): + """ + Lightning calls this inside the test loop with the data from the test dataloader + passed in as `batch`. + """ + return self.validation_step(batch, batch_idx) + + def test_epoch_end(self, outputs): + """ + Called at the end of test to aggregate outputs. + :param outputs: list of individual outputs of each test step. + """ + return self.validation_epoch_end(outputs) + + def setup_training_data(self, train_data_config: Optional[DictConfig]): + if not train_data_config or not train_data_config.file_path: + logging.info( + f"Dataloader config or file_path for the train is missing, so no data loader for test is created!" + ) + self._test_dl = None + return + self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config) + + # calculate the class weights to be used in the loss function + if self.cfg.dataset.class_balancing == 'weighted_loss': + self.class_weights = calc_class_weights(train_data_config.file_path, self.cfg.dataset.num_classes) + else: + self.class_weights = None + # we need to create/update the loss module by using the weights calculated from the training data + self.create_loss_module() + + def setup_validation_data(self, val_data_config: Optional[DictConfig]): + if not val_data_config or not val_data_config.file_path: + logging.info( + f"Dataloader config or file_path for the validation is missing, so no data loader for test is created!" + ) + self._test_dl = None + return + self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config) + + def setup_test_data(self, test_data_config: Optional[DictConfig]): + if not test_data_config or not test_data_config.file_path: + logging.info( + f"Dataloader config or file_path for the test is missing, so no data loader for test is created!" + ) + self._test_dl = None + return + self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config) + + def _setup_dataloader_from_config(self, cfg: Dict) -> 'torch.utils.data.DataLoader': + input_file = cfg.file_path + if not os.path.exists(input_file): + raise FileNotFoundError( + f'{input_file} not found! The data should be be stored in TAB-separated files \n\ + "validation_ds.file_path" and "train_ds.file_path" for train and evaluation respectively. \n\ + Each line of the files contains text sequences, where words are separated with spaces. \n\ + The label of the example is separated with TAB at the end of each line. \n\ + Each line of the files should follow the format: \n\ + [WORD][SPACE][WORD][SPACE][WORD][...][TAB][LABEL]' + ) + + dataset = BankPTextClassificationDataset( + input_file, + self._cfg.dataset.classes + ) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=cfg.batch_size, + shuffle=cfg.shuffle, + num_workers=cfg.get("num_workers", 0), + pin_memory=cfg.get("pin_memory", False), + drop_last=cfg.get("drop_last", False), + collate_fn=dataset.collate_fn, + ) + + @torch.no_grad() + def classifytext(self, queries: List[str], batch_size: int = 1, max_seq_length: int = -1) -> List[int]: + """ + Get prediction for the queries + Args: + queries: text sequences + batch_size: batch size to use during inference + max_seq_length: sequences longer than max_seq_length will get truncated. default -1 disables truncation. + Returns: + all_preds: model predictions + """ + # store predictions for all queries in a single list + all_preds = [] + mode = self.training + device = next(self.parameters()).device + try: + # Switch model to evaluation mode + self.eval() + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + dataloader_cfg = {"batch_size": batch_size, "num_workers": 3, "pin_memory": False} + infer_datalayer = self._setup_infer_dataloader(dataloader_cfg, queries, max_seq_length) + + for i, batch in enumerate(infer_datalayer): + input_ids, input_type_ids, input_mask, subtokens_mask = batch + + logits = self.forward( + input_ids=input_ids.to(device), + token_type_ids=input_type_ids.to(device), + attention_mask=input_mask.to(device), + ) + + preds = tensor2list(torch.argmax(logits, axis=-1)) + all_preds.extend(preds) + finally: + # set mode back to its original value + self.train(mode=mode) + logging.set_verbosity(logging_level) + return all_preds + + def _setup_infer_dataloader( + self, cfg: Dict, queries: List[str], max_seq_length: int = -1 + ) -> 'torch.utils.data.DataLoader': + """ + Setup function for a infer data loader. + + Args: + cfg: config dictionary containing data loader params like batch_size, num_workers and pin_memory + queries: text + max_seq_length: maximum length of queries, default is -1 for no limit + Returns: + A pytorch DataLoader. + """ + pass + # dataset = BankPTextClassificationDataset() + # return torch.utils.data.DataLoader( + # dataset=dataset, + # batch_size=cfg["batch_size"], + # shuffle=False, + # num_workers=cfg.get("num_workers", 0), + # pin_memory=cfg.get("pin_memory", False), + # drop_last=False, + # collate_fn=dataset.collate_fn, + # ) + + @classmethod + def list_available_models(cls) -> Optional[Dict[str, str]]: + pass diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index 14647b42c209..b2f1bd5c55e9 100644 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -679,9 +679,10 @@ def forward( pooling_sequence_index=0, enc_hidden_states=None, output_enc_hidden_only=False, + encoder_input=None, ): # Embeddings. - if self.pre_process: + if self.pre_process and encoder_input is None: embedding_output = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids) # Soft prompts @@ -694,7 +695,7 @@ def forward( else: encoder_input = embedding_output else: - encoder_input = None + encoder_input = encoder_input # encoder. if enc_hidden_states is None: diff --git a/nemo/collections/nlp/modules/common/prompt_encoder.py b/nemo/collections/nlp/modules/common/prompt_encoder.py new file mode 100644 index 000000000000..b354b123955a --- /dev/null +++ b/nemo/collections/nlp/modules/common/prompt_encoder.py @@ -0,0 +1,59 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from nemo.core.classes import Exportable, NeuralModule +import torch +from torch import nn + +__all__ = ['SequenceClassifier'] + + +class PromptEncoder(NeuralModule, Exportable): + + def __init__(self, + template: List[int], + hidden_size: int, + lstm_dropout: float): + super().__init__() + self.spell_length = sum(template) + self.hidden_size = hidden_size + # ent embedding + self.cloze_length = template + self.cloze_mask = [ + [1] * self.cloze_length[0] # first cloze + + [1] * self.cloze_length[1] # second cloze + + [1] * self.cloze_length[2] # third cloze + ] + self.cloze_mask = torch.LongTensor(self.cloze_mask).bool() + self.register_buffer('seq_indices', torch.LongTensor(list(range(len(self.cloze_mask[0]))))) + + # embedding + self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), self.hidden_size) + # LSTM + self.lstm_head = torch.nn.LSTM(input_size=self.hidden_size, + hidden_size=self.hidden_size // 2, + num_layers=2, + dropout=lstm_dropout, + bidirectional=True, + batch_first=True) + self.mlp_head = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), + nn.ReLU(), + nn.Linear(self.hidden_size, self.hidden_size)) + + def forward(self): + input_embeds = self.embedding(self.seq_indices).unsqueeze(0) + output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze() + return output_embeds diff --git a/tutorials/nlp/PTune_sentiment_analysis.ipynb b/tutorials/nlp/PTune_sentiment_analysis.ipynb new file mode 100644 index 000000000000..f2239ae47531 --- /dev/null +++ b/tutorials/nlp/PTune_sentiment_analysis.ipynb @@ -0,0 +1,816 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "b7a434f4", + "metadata": {}, + "outputs": [], + "source": [ + "BRANCH='main'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "developmental-gibraltar", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. Run this cell to set up dependencies.\n", + "\"\"\"\n", + "# If you're using Google Colab and not running locally, run this cell\n", + "\n", + "# install NeMo\n", + "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[nlp]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "challenging-pioneer", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "################################################################################\n", + "### WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk\n", + "### (please add 'export KALDI_ROOT=' in your $HOME/.profile)\n", + "### (or run as: KALDI_ROOT= python .py)\n", + "################################################################################\n", + "\n", + "[NeMo W 2022-01-18 18:59:06 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-18 18:59:06 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-18 18:59:06 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-18 18:59:06 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n" + ] + } + ], + "source": [ + "from nemo.collections import nlp as nemo_nlp\n", + "from nemo.utils.exp_manager import exp_manager\n", + "\n", + "import os\n", + "import wget \n", + "import torch\n", + "import pytorch_lightning as pl\n", + "from omegaconf import OmegaConf" + ] + }, + { + "cell_type": "markdown", + "id": "employed-ethiopia", + "metadata": {}, + "source": [ + "In this tutorial, we are going to describe how to finetune BioMegatron - a [BERT](https://arxiv.org/abs/1810.04805)-like [Megatron-LM](https://arxiv.org/pdf/1909.08053.pdf) model pre-trained on large biomedical text corpus ([PubMed](https://pubmed.ncbi.nlm.nih.gov/) abstracts and full-text commercial use collection) - on the [NCBI Disease Dataset](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3951655/) for Named Entity Recognition.\n", + "\n", + "The model size of Megatron-LM can be larger than BERT, up to multi-billion parameters, compared to 345 million parameters of BERT-large.\n", + "There are some alternatives of BioMegatron, most notably [BioBERT](https://arxiv.org/abs/1901.08746). Compared to BioBERT BioMegatron is larger by model size and pre-trained on larger text corpus.\n", + "\n", + "A more general tutorial of using BERT-based models, including Megatron-LM, for downstream natural language processing tasks can be found [here](https://github.com/NVIDIA/NeMo/blob/stable/tutorials/nlp/01_Pretrained_Language_Models_for_Downstream_Tasks.ipynb).\n", + "\n", + "# Task Description\n", + "**Named entity recognition (NER)**, also referred to as entity chunking, identification or extraction, is the task of detecting and classifying key information (entities) in text.\n", + "\n", + "For instance, **given sentences from medical abstracts, what diseases are mentioned?**
\n", + "In this case, our data input is sentences from the abstracts, and our labels are the precise locations of the named disease entities. Take a look at the information provided for the dataset.\n", + "\n", + "For more details and general examples on Named Entity Recognition, please refer to the [Token Classification and Named Entity Recognition tutorial notebook](https://github.com/NVIDIA/NeMo/blob/stable/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb).\n", + "\n", + "# Dataset\n", + "\n", + "The [NCBI-disease corpus](https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/) is a set of 793 PubMed abstracts, annotated by 14 annotators. The annotations take the form of HTML-style tags inserted into the abstract text using the clearly defined rules. The annotations identify named diseases, and can be used to fine-tune a language model to identify disease mentions in future abstracts, *whether those diseases were part of the original training set or not*.\n", + "\n", + "Here's an example of what an annotated abstract from the corpus looks like:\n", + "\n", + "```html\n", + "10021369\tIdentification of APC2, a homologue of the adenomatous polyposis coli tumour suppressor .\tThe adenomatous polyposis coli ( APC ) tumour-suppressor protein controls the Wnt signalling pathway by forming a complex with glycogen synthase kinase 3beta ( GSK-3beta ) , axin / conductin and betacatenin . Complex formation induces the rapid degradation of betacatenin . In colon carcinoma cells , loss of APC leads to the accumulation of betacatenin in the nucleus , where it binds to and activates the Tcf-4 transcription factor ( reviewed in [ 1 ] [ 2 ] ) . Here , we report the identification and genomic structure of APC homologues . Mammalian APC2 , which closely resembles APC in overall domain structure , was functionally analyzed and shown to contain two SAMP domains , both of which are required for binding to conductin . Like APC , APC2 regulates the formation of active betacatenin-Tcf complexes , as demonstrated using transient transcriptional activation assays in APC - / - colon carcinoma cells . Human APC2 maps to chromosome 19p13 . 3 . APC and APC2 may therefore have comparable functions in development and cancer .\n", + "```\n", + "\n", + "In this example, we see the following tags within the abstract:\n", + "```html\n", + "adenomatous polyposis coli tumour\n", + "adenomatous polyposis coli ( APC ) tumour\n", + "colon carcinoma\n", + "colon carcinoma\n", + "cancer\n", + "```\n", + "\n", + "For our purposes, we will consider any identified category (such as \"Modifier\", \"Specific Disease\", and a few others) to generally be a \"disease\".\n", + "\n", + "Let's download the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "federal-beads", + "metadata": {}, + "outputs": [], + "source": [ + "DATA_DIR = \"DATA_DIR\"\n", + "os.makedirs(DATA_DIR, exist_ok=True)\n", + "os.makedirs(os.path.join(DATA_DIR, 'SA'), exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "id": "1c1e1b08", + "metadata": {}, + "source": [ + "## Downloading Financial Phrase Bank Dataset\n", + "\n", + "The datase is collected by Malo et al. 2014, and can be downloaded from this [link](https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip). The zip file for the Financial Phrase Bank Dataset has been provided for ease of download and use." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "8ad03fc0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2022-01-18 19:17:05-- https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", + "Resolving www.researchgate.net (www.researchgate.net)... 104.17.32.105, 104.17.33.105, 2606:4700::6811:2069, ...\n", + "Connecting to www.researchgate.net (www.researchgate.net)|104.17.32.105|:443... connected.\n", + "HTTP request sent, awaiting response... 301 Moved Permanently\n", + "Location: https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip [following]\n", + "--2022-01-18 19:17:05-- https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", + "Reusing existing connection to www.researchgate.net:443.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 681890 (666K) [application/zip]\n", + "Saving to: ‘FinancialPhraseBank-v10.zip’\n", + "\n", + "FinancialPhraseBank 100%[===================>] 665.91K --.-KB/s in 0.04s \n", + "\n", + "2022-01-18 19:17:05 (17.9 MB/s) - ‘FinancialPhraseBank-v10.zip’ saved [681890/681890]\n", + "\n", + "Archive: DATA_DIR/FinancialPhraseBank-v10.zip\n", + " creating: DATA_DIR/FinancialPhraseBank-v1.0/\n", + " inflating: DATA_DIR/FinancialPhraseBank-v1.0/License.txt \n", + " creating: DATA_DIR/__MACOSX/\n", + " creating: DATA_DIR/__MACOSX/FinancialPhraseBank-v1.0/\n", + " inflating: DATA_DIR/__MACOSX/FinancialPhraseBank-v1.0/._License.txt \n", + " inflating: DATA_DIR/FinancialPhraseBank-v1.0/README.txt \n", + " inflating: DATA_DIR/__MACOSX/FinancialPhraseBank-v1.0/._README.txt \n", + " inflating: DATA_DIR/FinancialPhraseBank-v1.0/Sentences_50Agree.txt \n", + " inflating: DATA_DIR/FinancialPhraseBank-v1.0/Sentences_66Agree.txt \n", + " inflating: DATA_DIR/FinancialPhraseBank-v1.0/Sentences_75Agree.txt \n", + " inflating: DATA_DIR/FinancialPhraseBank-v1.0/Sentences_AllAgree.txt \n" + ] + } + ], + "source": [ + "!wget https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", + "!mv FinancialPhraseBank-v10.zip {DATA_DIR}\n", + "!unzip {DATA_DIR}/FinancialPhraseBank-v10.zip -d {DATA_DIR}" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "radical-castle", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "According to Gran , the company has no plans to move all production to Russia , although that is where the company is growing .@neutral\n" + ] + } + ], + "source": [ + "# If you want to see more examples, you can explore the text of the corpus using the file browser to the left, or open files directly, for example typing a command like the following in a code-cell:\n", + "\n", + "! head -1 $DATA_DIR/FinancialPhraseBank-v1.0/Sentences_50Agree.txt" + ] + }, + { + "cell_type": "markdown", + "id": "specified-maine", + "metadata": {}, + "source": [ + "We have two datasets derived from this corpus: a text classification dataset and a named entity recognition (NER) dataset. The text classification dataset labels the abstracts among three broad disease groupings. We'll use this simple split to demonstrate the NLP text classification task. The NER dataset labels individual words as diseases. This dataset will be used for the NLP NER task. " + ] + }, + { + "cell_type": "markdown", + "id": "affected-numbers", + "metadata": {}, + "source": [ + "## Pre-process dataset\n", + "A pre-processed NCBI-disease dataset for NER can be found [here](https://github.com/spyysalo/ncbi-disease/tree/master/conll) or [here](https://github.com/dmis-lab/biobert#datasets).
\n", + "We download the files under {DATA_DIR/NER} directory." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "198287d4", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "files = ['Sentences_50Agree.txt', 'Sentences_66Agree.txt', 'Sentences_75Agree.txt', 'Sentences_AllAgree.txt']\n", + "base_dir = DATA_DIR + '/FinancialPhraseBank-v1.0/'\n", + "files = [base_dir + f for f in files]\n", + "\n", + "alllines = []\n", + "for fn in files:\n", + " with open(fn, 'r', encoding=\"ISO-8859-1\") as f:\n", + " alllines.extend(f.readlines())\n", + "\n", + "fold = 10\n", + "fold_size = len(alllines) // fold\n", + "\n", + "chunk_start = list(range(0, 14780, 1478))\n", + "\n", + "chunks = []\n", + "\n", + "for start_id in chunk_start:\n", + " chunks.append(alllines[start_id:start_id+fold_size])\n", + "\n", + "special = '<|endoftext|>'\n", + "\n", + "def gen_file(data, fold_id, split_type):\n", + " filename = \"{}/{}_{}.txt\".format(base_dir, split_type, fold_id)\n", + " with open(filename, 'w') as f:\n", + " obj = {}\n", + " for line in data:\n", + " splits = line.split('@')\n", + " part1 = splits[0].strip()\n", + " part2 = splits[1].strip()\n", + " obj['sentence'] = part1 +'. Sentiment '\n", + " obj['sentiment'] = part2\n", + " f.write(json.dumps(obj)+'\\n')\n", + "\n", + "\n", + "def gen_fold(fold_number):\n", + " lists = list(range(fold))\n", + " test_id = (fold_number + fold) % fold\n", + " val_id = (fold_number + fold - 1) % fold\n", + " test_set = chunks[test_id]\n", + " val_set = chunks[val_id]\n", + " lists.remove(test_id)\n", + " lists.remove(val_id)\n", + " train_set = []\n", + " for idd in lists:\n", + " train_set += chunks[idd]\n", + " gen_file(train_set, fold_number, 'train')\n", + " gen_file(val_set, fold_number, 'validation')\n", + " gen_file(test_set, fold_number, 'test')\n", + "\n", + "for i in range(fold):\n", + " gen_fold(i)" + ] + }, + { + "cell_type": "markdown", + "id": "graphic-debate", + "metadata": {}, + "source": [ + "The NER task requires two files: the text sentences, and the labels. Run the next two cells to see a sample of the two files." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "sound-surgeon", + "metadata": {}, + "outputs": [], + "source": [ + "!head $NER_DATA_DIR/text_train.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "spectacular-strain", + "metadata": {}, + "outputs": [], + "source": [ + "!head $NER_DATA_DIR/labels_train.txt" + ] + }, + { + "cell_type": "markdown", + "id": "3813cc36", + "metadata": {}, + "source": [ + "## Convert the Megatron-LM Weights to Nemo file\n", + "\n", + "If you prefer to use the Huggingface BERT models, please skip this section and refer to `Setting up a NeMo Experiment` setction to load a model from `nemo_nlp.modules.get_pretrained_lm_models_list()`\n", + "\n", + "NeMo Megatron BERT can [load from a pretrained model](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/core/core.html?highlight=nemo%20file#restore) using `.nemo` file. We can convert the Megatron-LM checkpoint to the `.nemo` file. Let's first download the pretrained model weights and vocabulary file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82b8e08e", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.nlp.modules.common.megatron.megatron_utils import MEGATRON_CONFIG_MAP\n", + "import pathlib\n", + "# specify BERT-like model, you want to use\n", + "PRETRAINED_BERT_MODEL = \"biomegatron-bert-345m-cased\"\n", + "\n", + "checkpoint_url = MEGATRON_CONFIG_MAP[PRETRAINED_BERT_MODEL]['checkpoint']\n", + "vocab_url = MEGATRON_CONFIG_MAP[PRETRAINED_BERT_MODEL]['vocab']\n", + "checkpoint_filename = pathlib.Path(checkpoint_url).name\n", + "vocab_filename = pathlib.Path(vocab_url).name\n", + "if not pathlib.Path(checkpoint_filename).exists():\n", + " print('downloading from checkpoint url', checkpoint_url)\n", + " !wget $checkpoint_url\n", + "if not pathlib.Path(vocab_filename).exists():\n", + " print('downloading from vocab url', vocab_url)\n", + " !wget $vocab_url" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b00ee86", + "metadata": {}, + "outputs": [], + "source": [ + "WORK_DIR = \"WORK_DIR\"\n", + "os.makedirs(WORK_DIR, exist_ok=True)\n", + "\n", + "# Prepare the model parameters \n", + "# download the model's configuration file \n", + "config_dir = WORK_DIR + '/configs/'\n", + "MODEL_CONFIG = \"megatron_bert_config.yaml\"\n", + "os.makedirs(config_dir, exist_ok=True)\n", + "if not os.path.exists(config_dir + MODEL_CONFIG):\n", + " print('Downloading config file...')\n", + " wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/conf/' + MODEL_CONFIG, config_dir)\n", + "else:\n", + " print ('config file is already exists')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ae5a1a9", + "metadata": {}, + "outputs": [], + "source": [ + "# this line will print the entire config of the model\n", + "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n", + "print(config_path)\n", + "config = OmegaConf.load(config_path)\n", + "config.model.num_layers = 24\n", + "config.model.hidden_size = 1024\n", + "config.model.ffn_hidden_size = 4096\n", + "config.model.num_attention_heads = 16\n", + "config.model.tokenizer.vocab_file = vocab_filename\n", + "config.model.tokenizer.type = 'BertWordPieceCase'\n", + "config.model.tensor_model_parallel_size = 1\n", + "config.model.data.data_prefix = ''\n", + "config.model.max_position_embeddings = 512\n", + "config.model.data.seq_length = 512\n", + "config.cfg = {}\n", + "config.cfg.cfg = config.model\n", + "with open('hparams.yaml', 'w') as f:\n", + " f.write(OmegaConf.to_yaml(config.cfg))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e1beda4", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "PWD = os.getcwd()\n", + "wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py')\n", + "!python -m torch.distributed.run --nproc_per_node=1 megatron_lm_ckpt_to_nemo.py --checkpoint_folder=$PWD --checkpoint_name=$checkpoint_filename --hparams_file=$PWD/hparams.yaml --nemo_file_path=$PWD/biomegatron.nemo --model_type=bert --tensor_model_parallel_size=1" + ] + }, + { + "cell_type": "markdown", + "id": "84b455a6", + "metadata": {}, + "source": [ + "# Model configuration\n", + "\n", + "Our Named Entity Recognition model is comprised of the pretrained [BERT](https://arxiv.org/pdf/1810.04805.pdf) model followed by a Token Classification layer.\n", + "\n", + "The model is defined in a config file which declares multiple important sections. They are:\n", + "- **model**: All arguments that are related to the Model - language model, token classifier, optimizer and schedulers, datasets and any other related information\n", + "\n", + "- **trainer**: Any argument to be passed to PyTorch Lightning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "speaking-grant", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_CONFIG = \"token_classification_config.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "demanding-ballet", + "metadata": {}, + "outputs": [], + "source": [ + "# download the model's configuration file \n", + "config_dir = WORK_DIR + '/configs/'\n", + "os.makedirs(config_dir, exist_ok=True)\n", + "if not os.path.exists(config_dir + MODEL_CONFIG):\n", + " print('Downloading config file...')\n", + " wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/token_classification/conf/' + MODEL_CONFIG, config_dir)\n", + "else:\n", + " print ('config file is already exists')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "criminal-outdoors", + "metadata": {}, + "outputs": [], + "source": [ + "# this line will print the entire config of the model\n", + "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n", + "print(config_path)\n", + "config = OmegaConf.load(config_path)\n", + "# Note: these are small batch-sizes - increase as appropriate to available GPU capacity\n", + "config.model.train_ds.batch_size=8\n", + "config.model.validation_ds.batch_size=8" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "informed-purse", + "metadata": {}, + "outputs": [], + "source": [ + "# in this tutorial train and dev datasets are located in the same folder, so it is enought to add the path of the data directory to the config\n", + "config.model.dataset.data_dir = os.path.join(DATA_DIR, 'NER')\n", + "\n", + "# if you want to decrease the size of your datasets, uncomment the lines below:\n", + "# NUM_SAMPLES = 1000\n", + "# config.model.train_ds.num_samples = NUM_SAMPLES\n", + "# config.model.validation_ds.num_samples = NUM_SAMPLES" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "divine-belly", + "metadata": {}, + "outputs": [], + "source": [ + "print(OmegaConf.to_yaml(config))" + ] + }, + { + "cell_type": "markdown", + "id": "dedicated-effort", + "metadata": {}, + "source": [ + "# Model Training\n", + "## Setting up Data within the config\n", + "\n", + "Among other things, the config file contains dictionaries called dataset, train_ds and validation_ds. These are configurations used to setup the Dataset and DataLoaders of the corresponding config.\n" + ] + }, + { + "cell_type": "markdown", + "id": "15e2c67a", + "metadata": {}, + "source": [ + "\n", + "We assume that both training and evaluation files are located in the same directory, and use the default names mentioned during the data download step. \n", + "So, to start model training, we simply need to specify `model.dataset.data_dir`, like we are going to do below.\n" + ] + }, + { + "cell_type": "markdown", + "id": "89dd468d", + "metadata": {}, + "source": [ + "\n", + "Also notice that some config lines, including `model.dataset.data_dir`, have `???` in place of paths, this means that values for these fields are required to be specified by the user.\n", + "\n", + "Let's now add the data directory path to the config." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a312ed76", + "metadata": {}, + "outputs": [], + "source": [ + "# in this tutorial train and dev datasets are located in the same folder, so it is enought to add the path of the data directory to the config\n", + "config.model.dataset.data_dir = os.path.join(DATA_DIR, 'NER')\n", + "\n", + "# if you want to decrease the size of your datasets, uncomment the lines below:\n", + "# NUM_SAMPLES = 1000\n", + "# config.model.train_ds.num_samples = NUM_SAMPLES\n", + "# config.model.validation_ds.num_samples = NUM_SAMPLES" + ] + }, + { + "cell_type": "markdown", + "id": "changed-mauritius", + "metadata": {}, + "source": [ + "## Building the PyTorch Lightning Trainer\n", + "\n", + "NeMo models are primarily PyTorch Lightning modules - and therefore are entirely compatible with the PyTorch Lightning ecosystem.\n", + "\n", + "Let's first instantiate a Trainer object" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "computational-battlefield", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Trainer config - \\n\")\n", + "print(OmegaConf.to_yaml(config.trainer))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "unique-genre", + "metadata": {}, + "outputs": [], + "source": [ + "# lets modify some trainer configs\n", + "# checks if we have GPU available and uses it\n", + "cuda = 1 if torch.cuda.is_available() else 0\n", + "config.trainer.gpus = cuda\n", + "\n", + "# for PyTorch Native AMP set precision=16\n", + "config.trainer.precision = 16 if torch.cuda.is_available() else 32\n", + "\n", + "# remove distributed training flags\n", + "config.trainer.accelerator = None\n", + "\n", + "trainer = pl.Trainer(**config.trainer)" + ] + }, + { + "cell_type": "markdown", + "id": "overall-literature", + "metadata": {}, + "source": [ + "## Setting up a NeMo Experiment\n", + "\n", + "NeMo has an experiment manager that handles logging and checkpointing for us, so let's use it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "mathematical-portable", + "metadata": {}, + "outputs": [], + "source": [ + "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))\n", + "os.makedirs(WORK_DIR, exist_ok=True)\n", + "\n", + "# the exp_dir provides a path to the current experiment for easy access\n", + "exp_dir = str(exp_dir)\n", + "exp_dir" + ] + }, + { + "cell_type": "markdown", + "id": "f62ea6cd", + "metadata": {}, + "source": [ + "To load the pretrained BERT LM model, we can either load it from the converted `.nemo` file as shown above or load it from a list of included model names. \n", + "\n", + "We can get the list of names by following command \n", + "```python\n", + "# complete list of supported BERT-like models\n", + "print(nemo_nlp.modules.get_pretrained_lm_models_list())\n", + "```\n", + "We can change the `model.language_mode` config to use it\n", + "```python\n", + "# add the specified above model parameters to the config\n", + "config.model.language_model.pretrained_model_name = MODEL_NAME\n", + "```\n", + "\n", + "In this notebook, we will use the converted `.nemo` file as our LM model, which is BioMegatron, [Megatron-LM BERT](https://arxiv.org/abs/1909.08053) pre-trained on [PubMed](https://pubmed.ncbi.nlm.nih.gov/) biomedical text corpus." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "compact-horse", + "metadata": {}, + "outputs": [], + "source": [ + "# add the specified above model parameters to the config\n", + "# config.model.language_model.pretrained_model_name = PRETRAINED_BERT_MODEL\n", + "config.model.language_model.nemo_file = 'biomegatron.nemo'\n", + "config.model.language_model.pretrained_model_name = 'megatron-bert-cased'\n", + "config.model.tokenizer.vocab_file='vocab.txt'\n", + "config.model.tokenizer.tokenizer_model = 'BertWordPieceCase'\n", + " \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "seeing-geometry", + "metadata": {}, + "source": [ + "Now, we are ready to initialize our model. During the model initialization call, the dataset and data loaders we'll be prepared for training and evaluation.\n", + "Also, the pretrained BERT model will be downloaded, note it can take up to a few minutes depending on the size of the chosen BERT model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "indoor-france", + "metadata": {}, + "outputs": [], + "source": [ + "model_ner = nemo_nlp.models.TokenClassificationModel(cfg=config.model, trainer=trainer)" + ] + }, + { + "cell_type": "markdown", + "id": "genuine-pipeline", + "metadata": {}, + "source": [ + "## Monitoring training progress\n", + "Optionally, you can create a Tensorboard visualization to monitor training progress.\n", + "If you're not using Colab, refer to [https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks](https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks) if you're facing issues with running the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "changed-expense", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from google import colab\n", + " COLAB_ENV = True\n", + "except (ImportError, ModuleNotFoundError):\n", + " COLAB_ENV = False\n", + "\n", + "# Load the TensorBoard notebook extension\n", + "if COLAB_ENV:\n", + " %load_ext tensorboard\n", + " %tensorboard --logdir {exp_dir}\n", + "else:\n", + " print(\"To use tensorboard, please use this notebook in a Google Colab environment.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "applied-quality", + "metadata": {}, + "outputs": [], + "source": [ + "# start model training\n", + "trainer.fit(model_ner)" + ] + }, + { + "cell_type": "markdown", + "id": "cooperative-michael", + "metadata": {}, + "source": [ + "# Inference\n", + "\n", + "To see how the model performs, we can run generate prediction similar to the way we did it earlier" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "classical-scientist", + "metadata": {}, + "outputs": [], + "source": [ + "# let's first create a subset of our dev data\n", + "! head -n 100 $NER_DATA_DIR/text_dev.txt > $NER_DATA_DIR/sample_text_dev.txt\n", + "! head -n 100 $NER_DATA_DIR/labels_dev.txt > $NER_DATA_DIR/sample_labels_dev.txt" + ] + }, + { + "cell_type": "markdown", + "id": "adult-ranking", + "metadata": {}, + "source": [ + "Now, let's generate predictions for the provided text file.\n", + "If labels file is also specified, the model will evaluate the predictions and plot confusion matrix. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "twenty-abortion", + "metadata": {}, + "outputs": [], + "source": [ + "model_ner.half().evaluate_from_file(\n", + " text_file=os.path.join(NER_DATA_DIR, 'sample_text_dev.txt'),\n", + " labels_file=os.path.join(NER_DATA_DIR, 'sample_labels_dev.txt'),\n", + " output_dir=exp_dir,\n", + " add_confusion_matrix=False,\n", + " normalize_confusion_matrix=True,\n", + " batch_size=1\n", + ")\n", + "# Please check matplotlib version if encountering any error plotting confusion matrix:\n", + "# https://stackoverflow.com/questions/63212347/importerror-cannot-import-name-png-from-matplotlib" + ] + }, + { + "cell_type": "markdown", + "id": "connected-typing", + "metadata": {}, + "source": [ + "## Training Script\n", + "\n", + "If you have NeMo installed locally, you can also train the model with `nlp/token_classification/token_classification_train.py.`\n", + "\n", + "To run training script, use:\n", + "\n", + "`python token_classification_train.py model.dataset.data_dir=PATH_TO_DATA_DIR exp_manager.exp_dir=EXP_DIR model.language_model.pretrained_model_name=megatron-bert-cased model.tokenizer.vocab_file=VOCAB_FILE model.tokenizer.tokenizer_model=BertWordPieceCase model.language_model.nemo_file=NEMO_FILE`\n" + ] + }, + { + "cell_type": "markdown", + "id": "legitimate-electric", + "metadata": {}, + "source": [ + "The training could take several minutes and the result should look something like\n", + "```\n", + "[NeMo I 2020-05-22 17:13:48 token_classification_callback:82] Accuracy: 0.9882348032875798\n", + "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 weighted: 98.82\n", + "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 macro: 93.74\n", + "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 micro: 98.82\n", + "[NeMo I 2020-05-22 17:13:49 token_classification_callback:89] precision recall f1-score support\n", + " \n", + " O (label id: 0) 0.9938 0.9957 0.9947 22092\n", + " B (label id: 1) 0.8843 0.9034 0.8938 787\n", + " I (label id: 2) 0.9505 0.8982 0.9236 1090\n", + " \n", + " accuracy 0.9882 23969\n", + " macro avg 0.9429 0.9324 0.9374 23969\n", + " weighted avg 0.9882 0.9882 0.9882 23969\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 06d13440374a6458abd069f4a2f80b4a3196bdfb Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 19 Jan 2022 09:36:15 -0800 Subject: [PATCH 02/22] training is working Signed-off-by: Yi Dong --- .../ptune_text_classification_config.yaml | 4 -- .../ptune_text_classification_dataset.py | 5 +- .../language_modeling/megatron/gpt_model.py | 7 ++- .../ptune_text_classification_model.py | 61 +++++++++---------- 4 files changed, 39 insertions(+), 38 deletions(-) diff --git a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml index 88f0d326c135..2e81a1045f33 100644 --- a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml +++ b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml @@ -54,10 +54,6 @@ model: dataset: classes: ??? # The class labels, e.g. ['positive', 'neutral', 'negative'] - do_lower_case: false # true for uncased models, false for cased models, will be set automatically if pre-trained tokenizer model is used - max_seq_length: 256 # the maximum length BERT supports is 512 - class_balancing: null # null or 'weighted_loss'. 'weighted_loss' enables the weighted class balancing of the loss, may be used for handling unbalanced classes - use_cache: false # uses a cache to store the processed dataset, you may use it for large datasets for speed up train_ds: file_path: null diff --git a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py index 9ebb16472273..b9035555226b 100644 --- a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py +++ b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py @@ -36,14 +36,15 @@ def token_wrapper(token: str) -> str: class BankPTextClassificationDataset(Dataset): - def __init__(self, input_file: str, sentiments: List[str]): + def __init__(self, input_file: str, sentiments: List[str], data: List[str]=None): super().__init__() if input_file and not os.path.exists(input_file): raise FileNotFoundError( f'Data file `{input_file}` not found! Each line of the data file should contain json object' f'where `sentence` key maps to sentence and `sentiment` key maps to sentiment' ) - data = load_file(input_file) + if data is None: + data = load_file(input_file) self.x_hs, self.x_ts = [], [] self.data = data diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py index 0a2f753ec5d8..043ef32e6b17 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -31,6 +31,7 @@ def post_language_model_processing( parallel_output, forward_method_parallel_output, fp16_lm_cross_entropy, + return_logits=False ): if get_key_value: lm_output, presents = lm_output @@ -52,7 +53,10 @@ def post_language_model_processing( else: loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) - return loss + if return_logits: + return loss, output + else: + return loss class GPTModel(MegatronModule): @@ -176,6 +180,7 @@ def forward( self.parallel_output, forward_method_parallel_output, self.fp16_lm_cross_entropy, + return_logits=encoder_input is not None ) else: return lm_output diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index abaad707ff4c..a5c3f99b0d7c 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -38,6 +38,7 @@ initialize_model_parallel_for_nemo, ) from torch.nn.utils.rnn import pad_sequence +from nemo.utils import logging __all__ = ['PTuneTextClassificationModel'] @@ -87,16 +88,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): hidden_size = self.model.cfg.hidden_size - # self.create_loss_module() # register the file containing the labels into the artifacts to get stored in the '.nemo' file later self.classes = cfg.dataset.classes - # setup to track metrics - self.classification_report = ClassificationReport( - num_classes=len(self.classes), mode='micro', dist_sync_on_step=True - ) - self.embeddings = self.model.model.language_model.embedding.word_embeddings # set allowed vocab set @@ -104,6 +99,19 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.allowed_vocab_ids = set(self.vocab[token_wrapper(k)] for k in cfg.dataset.classes) + # map from id to label + self.allowed_vocab = {} + label_ids = {} + for i, k in enumerate(cfg.dataset.classes): + self.allowed_vocab[self.vocab[token_wrapper(k)]] = i + label_ids[k] = i + + # setup to track metrics + self.classification_report = ClassificationReport( + num_classes=len(self.classes), label_ids=label_ids, mode='micro', dist_sync_on_step=True + ) + + self.template = cfg.prompt_encoder.template self.prompt_encoder = PromptEncoder( @@ -124,7 +132,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.pseudo_token_id = self.tokenizer.tokenizer.get_vocab()[cfg.pseudo_token] self.pad_token_id = self.tokenizer.tokenizer.pad_token_id if self.tokenizer.tokenizer.pad_token_id is not None else self.tokenizer.tokenizer.unk_token_id self.spell_length = sum(self.template) - def embed_input(self, queries): bz = queries.shape[0] @@ -141,8 +148,14 @@ def embed_input(self, queries): return raw_embeds def get_query(self, x_h, prompt_tokens, x_t=None): + max_seq_len = self.model._cfg.encoder_seq_length + input_token_ids = self.tokenizer.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenizer.tokenize(' ' + x_h)) + cut = 0 + if len(input_token_ids) + sum(self.template) > max_seq_len: + logging.warning("Input sequence is longer than the LM model max seq, will cut it off to fit") + cut = len(input_token_ids) + sum(self.template) - max_seq_len return [prompt_tokens * self.template[0] - + self.tokenizer.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenizer.tokenize(' ' + x_h)) # head entity + + input_token_ids[cut:] # head entity + prompt_tokens * self.template[1] + (self.tokenizer.tokenizer.convert_tokens_to_ids( self.tokenizer.tokenize(' ' + x_t)) if x_t is not None else []) @@ -161,7 +174,7 @@ def forward(self, x_hs, x_ts, return_candidates=False): label_ids = torch.LongTensor(self.tokenizer.tokenizer.convert_tokens_to_ids(x_ts)).reshape( (bz, -1)).to(self.device) attention_mask = queries != self.pad_token_id - # get embedded input + # get embedded input inputs_embeds = self.embed_input(queries) def megatron_out(): @@ -192,6 +205,8 @@ def megatron_out(): pred_ids = torch.argsort(logits, dim=2, descending=True) hit1 = 0 top10 = [] + returned_pred = [] + returned_label = [] for i in range(bz): top10.append([]) pred_seq = pred_ids[i, label_mask[i, 0]].tolist() @@ -201,22 +216,15 @@ def megatron_out(): if len(top10[-1]) >= 10: break pred = top10[-1][0] + returned_pred.append(self.allowed_vocab[pred]) + returned_label.append(self.allowed_vocab[label_ids[i, 0].item()]) if pred == label_ids[i, 0]: hit1 += 1 if return_candidates: return floss, hit1, top10 - return floss, hit1 + return floss, hit1, torch.tensor(returned_pred).to(self.device), torch.tensor(returned_label).to(self.device) return megatron_out() - def create_loss_module(self): - # create the loss module if it is not yet created by the training data loader - if not hasattr(self, 'loss'): - if hasattr(self, 'class_weights') and self.class_weights: - # You may need to increase the number of epochs for convergence when using weighted_loss - self.loss = CrossEntropyLoss(weight=self.class_weights) - else: - self.loss = CrossEntropyLoss() - def training_step(self, batch, batch_idx): """ Lightning calls this inside the training loop with the data from the training dataloader @@ -224,7 +232,7 @@ def training_step(self, batch, batch_idx): """ # forward pass xs, ts = batch - train_loss, hit1 = self.forward(xs, ts) + train_loss, hit1, pred_ids, label_ids = self.forward(xs, ts) lr = self._optimizer.param_groups[0]['lr'] self.log('train_loss', train_loss) @@ -240,12 +248,8 @@ def validation_step(self, batch, batch_idx): Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`. """ - input_ids, input_type_ids, input_mask, labels = batch - logits = self.forward(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask) - - val_loss = self.loss(logits=logits, labels=labels) - - preds = torch.argmax(logits, axis=-1) + xs, ts = batch + val_loss, hit1 , preds, labels = self.forward(xs, ts) tp, fn, fp, _ = self.classification_report(preds, labels) @@ -300,11 +304,6 @@ def setup_training_data(self, train_data_config: Optional[DictConfig]): return self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config) - # calculate the class weights to be used in the loss function - if self.cfg.dataset.class_balancing == 'weighted_loss': - self.class_weights = calc_class_weights(train_data_config.file_path, self.cfg.dataset.num_classes) - else: - self.class_weights = None # we need to create/update the loss module by using the weights calculated from the training data self.create_loss_module() From 1ee3972e4a15b104d49a8fe8ee092e13a82b9c7f Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 19 Jan 2022 11:37:36 -0800 Subject: [PATCH 03/22] refactor to seperate prediction and loss computation Signed-off-by: Yi Dong --- .../ptune_text_classification_model.py | 158 +++++++++++------- 1 file changed, 95 insertions(+), 63 deletions(-) diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index a5c3f99b0d7c..9d03a399f551 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -43,6 +43,8 @@ __all__ = ['PTuneTextClassificationModel'] +SMALL_LOGITS = -100 + class PTuneTextClassificationModel(NLPModel, Exportable): # @property @@ -161,69 +163,102 @@ def get_query(self, x_h, prompt_tokens, x_t=None): self.tokenizer.tokenize(' ' + x_t)) if x_t is not None else []) ] - def forward(self, x_hs, x_ts, return_candidates=False): - bz = len(x_hs) - + def get_ground_truth_labels(self, batch_size, label_ids): + returned_label = [] + for i in range(batch_size): + returned_label.append(self.allowed_vocab[label_ids[i, 0].item()]) + return torch.tensor(returned_label).to(self.device) + + def get_prediction(self, batch_size, label_position, logits): + pred_ids = torch.argsort(logits, dim=2, descending=True) + top10 = [] + returned_pred = [] + for i in range(batch_size): + top10.append([]) + pred_seq = pred_ids[i, label_position[i, 0]].tolist() + for pred in pred_seq: + if pred in self.allowed_vocab_ids: + top10[-1].append(pred) + if len(top10[-1]) >= 10: + break + pred = top10[-1][0] + returned_pred.append(self.allowed_vocab[pred]) + return top10, torch.tensor(returned_pred).to(self.device) + + def get_encoder_input(self, sentences): + batch_size = len(sentences) # construct query ids prompt_tokens = [self.pseudo_token_id] - x_ts = [token_wrapper(x_t) for x_t in x_ts] - queries = [torch.LongTensor(self.get_query(x_hs[i], prompt_tokens)).squeeze(0) for i in range(bz)] + + queries = [torch.LongTensor(self.get_query(sentences[i], prompt_tokens)).squeeze(0) for i in range(batch_size)] queries = pad_sequence(queries, True, padding_value=self.pad_token_id).long().to(self.device) - # construct label ids - label_ids = torch.LongTensor(self.tokenizer.tokenizer.convert_tokens_to_ids(x_ts)).reshape( - (bz, -1)).to(self.device) + # attention_mask indicates the boundary of attention attention_mask = queries != self.pad_token_id # get embedded input inputs_embeds = self.embed_input(queries) - def megatron_out(): - bz, seq_len, _ = inputs_embeds.shape - labels = torch.empty_like(queries).fill_(-100).long() # bz * seq_len - label_mask = (attention_mask.long().sum(dim=1) - 1).unsqueeze(1) - labels = labels.scatter_(1, label_mask, label_ids) - - causal_mask = torch.tril( - torch.ones((bz, seq_len, seq_len), - device=self.device)).view(bz, 1, - seq_len, seq_len) - r = causal_mask.permute((1, 2, 0, 3)) * attention_mask.int() - new_atten = r.permute((2, 0, 1, 3)) - new_atten = new_atten < 0.5 - - position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device) - position_ids = position_ids.unsqueeze(0).expand_as(inputs_embeds[:, :, 0]) - position_embeddings = self.model.model.language_model.embedding.position_embeddings(position_ids) - encoder_input = inputs_embeds + position_embeddings - - output = self.model.model(None, None, encoder_input=encoder_input.half(), - attention_mask=new_atten, - labels=labels) - loss, logits = output - floss = (loss[(labels != -100)]).mean() - - pred_ids = torch.argsort(logits, dim=2, descending=True) - hit1 = 0 - top10 = [] - returned_pred = [] - returned_label = [] - for i in range(bz): - top10.append([]) - pred_seq = pred_ids[i, label_mask[i, 0]].tolist() - for pred in pred_seq: - if pred in self.allowed_vocab_ids: - top10[-1].append(pred) - if len(top10[-1]) >= 10: - break - pred = top10[-1][0] - returned_pred.append(self.allowed_vocab[pred]) - returned_label.append(self.allowed_vocab[label_ids[i, 0].item()]) - if pred == label_ids[i, 0]: - hit1 += 1 - if return_candidates: - return floss, hit1, top10 - return floss, hit1, torch.tensor(returned_pred).to(self.device), torch.tensor(returned_label).to(self.device) - return megatron_out() + bz, seq_len, _ = inputs_embeds.shape + + # get the GPT causal mask + causal_mask = torch.tril( + torch.ones((bz, seq_len, seq_len), + device=self.device)).view(bz, 1, + seq_len, seq_len) + # combine the attention_mask and causal_mask + r = causal_mask.permute((1, 2, 0, 3)) * attention_mask.int() + new_atten = r.permute((2, 0, 1, 3)) + # convert it to the boolean + new_atten = new_atten < 0.5 + + # calculate the position embedding based on the seq_len + position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device) + position_ids = position_ids.unsqueeze(0).expand_as(inputs_embeds[:, :, 0]) + position_embeddings = self.model.model.language_model.embedding.position_embeddings(position_ids) + + # get the final input for encoder + encoder_input = inputs_embeds + position_embeddings + + # calculate the position of the output token + label_position = (attention_mask.long().sum(dim=1) - 1).unsqueeze(1) + return encoder_input, new_atten, label_position + + def get_label_input(self, labels, label_position, seq_len): + batch_size, _ = label_position.shape + x_ts = [token_wrapper(x_t) for x_t in labels] + + # construct label ids + label_ids = torch.LongTensor(self.tokenizer.tokenizer.convert_tokens_to_ids(x_ts)).reshape( + (batch_size, -1)).to(self.device) + labels = torch.zeros(batch_size, seq_len).to(self.device).fill_(SMALL_LOGITS).long() # bz * seq_len + labels = labels.scatter_(1, label_position, label_ids) + return labels, label_ids + + def forward_eval(self, sentences): + encoder_input, new_atten, label_position = self.get_encoder_input(sentences) + batch_size, _, seq_len, _ = new_atten.shape + + output = self.model.model(None, None, encoder_input=encoder_input, + attention_mask=new_atten) + logits = output + + _, returned_pred = self.get_prediction(batch_size, label_position, logits) + return returned_pred + + def forward(self, sentences, labels): + encoder_input, new_atten, label_position = self.get_encoder_input(sentences) + batch_size, _, seq_len, _ = new_atten.shape + labels_input, label_ids = self.get_label_input(labels, label_position, seq_len) + + output = self.model.model(None, None, encoder_input=encoder_input, + attention_mask=new_atten, + labels=labels_input) + loss, logits = output + floss = (loss[(labels != SMALL_LOGITS)]).mean() + + _, returned_pred = self.get_prediction(batch_size, label_position, logits) + returned_label = self.get_ground_truth_labels(batch_size, label_ids) + return floss, returned_pred, returned_label def training_step(self, batch, batch_idx): """ @@ -231,8 +266,8 @@ def training_step(self, batch, batch_idx): passed in as `batch`. """ # forward pass - xs, ts = batch - train_loss, hit1, pred_ids, label_ids = self.forward(xs, ts) + sentences, labels = batch + train_loss, _, _ = self.forward(sentences, labels) lr = self._optimizer.param_groups[0]['lr'] self.log('train_loss', train_loss) @@ -248,10 +283,10 @@ def validation_step(self, batch, batch_idx): Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`. """ - xs, ts = batch - val_loss, hit1 , preds, labels = self.forward(xs, ts) + sentences, labels = batch + val_loss, preds, gt_labels = self.forward(sentences, labels) - tp, fn, fp, _ = self.classification_report(preds, labels) + tp, fn, fp, _ = self.classification_report(preds, gt_labels) return {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp} @@ -304,9 +339,6 @@ def setup_training_data(self, train_data_config: Optional[DictConfig]): return self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config) - # we need to create/update the loss module by using the weights calculated from the training data - self.create_loss_module() - def setup_validation_data(self, val_data_config: Optional[DictConfig]): if not val_data_config or not val_data_config.file_path: logging.info( From 1c14df0b214da2483790462132d6446d71524375 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 19 Jan 2022 19:15:51 -0800 Subject: [PATCH 04/22] updated the notebook Signed-off-by: Yi Dong --- .../megatron_lm_ckpt_to_nemo.py | 3 +- .../ptune_text_classification_dataset.py | 10 +- .../models/text_classification/__init__.py | 2 + .../ptune_text_classification_model.py | 42 +- tutorials/nlp/PTune_sentiment_analysis.ipynb | 626 +++++++++++++++--- 5 files changed, 545 insertions(+), 138 deletions(-) diff --git a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py index bb4d68216afc..85a30d221bc9 100644 --- a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py +++ b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py @@ -230,6 +230,7 @@ def convert(rank, world_size, args): ## this dictionary is used to rename the model parameters name_translate = {} name_translate['transformer'] = 'encoder' + name_translate['.attention.'] = '.self_attention.' model = load_from_checkpoint( MegatronGPTModel, checkpoint_path, @@ -242,7 +243,7 @@ def convert(rank, world_size, args): ## this dictionary is used to rename the model parameters name_translate = {} name_translate['transformer'] = 'encoder' - name_translate['attention.'] = 'self_attention.' + name_translate['.attention.'] = '.self_attention.' model = load_from_checkpoint( MegatronBertModel, checkpoint_path, diff --git a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py index b9035555226b..85ccc66e705e 100644 --- a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py +++ b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py @@ -44,11 +44,15 @@ def __init__(self, input_file: str, sentiments: List[str], data: List[str]=None) f'where `sentence` key maps to sentence and `sentiment` key maps to sentiment' ) if data is None: - data = load_file(input_file) + json_data = load_file(input_file) + else: + json_data = [] + for line in data: + json_data.append({'sentence': line+' Sentiment ', 'sentiment': ''}) self.x_hs, self.x_ts = [], [] - self.data = data + self.data = json_data - for d in data: + for d in json_data: if d['sentiment'] not in sentiments: continue self.x_ts.append(d['sentiment']) diff --git a/nemo/collections/nlp/models/text_classification/__init__.py b/nemo/collections/nlp/models/text_classification/__init__.py index 10ef0f00b883..6d5dc10fc600 100644 --- a/nemo/collections/nlp/models/text_classification/__init__.py +++ b/nemo/collections/nlp/models/text_classification/__init__.py @@ -13,3 +13,5 @@ # limitations under the License. from nemo.collections.nlp.models.text_classification.text_classification_model import TextClassificationModel +from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import PTuneTextClassificationModel + diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index 9d03a399f551..7ee313a25fe2 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -104,9 +104,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # map from id to label self.allowed_vocab = {} label_ids = {} + self.id_to_label = {} for i, k in enumerate(cfg.dataset.classes): self.allowed_vocab[self.vocab[token_wrapper(k)]] = i label_ids[k] = i + self.id_to_label[i] = k # setup to track metrics self.classification_report = ClassificationReport( @@ -398,26 +400,17 @@ def classifytext(self, queries: List[str], batch_size: int = 1, max_seq_length: # store predictions for all queries in a single list all_preds = [] mode = self.training - device = next(self.parameters()).device try: # Switch model to evaluation mode self.eval() logging_level = logging.get_verbosity() logging.set_verbosity(logging.WARNING) dataloader_cfg = {"batch_size": batch_size, "num_workers": 3, "pin_memory": False} - infer_datalayer = self._setup_infer_dataloader(dataloader_cfg, queries, max_seq_length) - + infer_datalayer = self._setup_infer_dataloader(dataloader_cfg, queries) for i, batch in enumerate(infer_datalayer): - input_ids, input_type_ids, input_mask, subtokens_mask = batch - - logits = self.forward( - input_ids=input_ids.to(device), - token_type_ids=input_type_ids.to(device), - attention_mask=input_mask.to(device), - ) - - preds = tensor2list(torch.argmax(logits, axis=-1)) - all_preds.extend(preds) + sentences, _ = batch + preds = self.forward_eval(sentences) + all_preds.extend([self.id_to_label[i.item()] for i in preds]) finally: # set mode back to its original value self.train(mode=mode) @@ -425,7 +418,7 @@ def classifytext(self, queries: List[str], batch_size: int = 1, max_seq_length: return all_preds def _setup_infer_dataloader( - self, cfg: Dict, queries: List[str], max_seq_length: int = -1 + self, cfg: Dict, queries: List[str] ) -> 'torch.utils.data.DataLoader': """ Setup function for a infer data loader. @@ -437,17 +430,16 @@ def _setup_infer_dataloader( Returns: A pytorch DataLoader. """ - pass - # dataset = BankPTextClassificationDataset() - # return torch.utils.data.DataLoader( - # dataset=dataset, - # batch_size=cfg["batch_size"], - # shuffle=False, - # num_workers=cfg.get("num_workers", 0), - # pin_memory=cfg.get("pin_memory", False), - # drop_last=False, - # collate_fn=dataset.collate_fn, - # ) + dataset = BankPTextClassificationDataset(None, None, queries) + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=cfg["batch_size"], + shuffle=False, + num_workers=cfg.get("num_workers", 0), + pin_memory=cfg.get("pin_memory", False), + drop_last=False, + collate_fn=dataset.collate_fn, + ) @classmethod def list_available_models(cls) -> Optional[Dict[str, str]]: diff --git a/tutorials/nlp/PTune_sentiment_analysis.ipynb b/tutorials/nlp/PTune_sentiment_analysis.ipynb index f2239ae47531..4883f808ac95 100644 --- a/tutorials/nlp/PTune_sentiment_analysis.ipynb +++ b/tutorials/nlp/PTune_sentiment_analysis.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 16, "id": "b7a434f4", "metadata": {}, "outputs": [], @@ -34,27 +34,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 44, "id": "challenging-pioneer", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "################################################################################\n", - "### WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk\n", - "### (please add 'export KALDI_ROOT=' in your $HOME/.profile)\n", - "### (or run as: KALDI_ROOT= python .py)\n", - "################################################################################\n", - "\n", - "[NeMo W 2022-01-18 18:59:06 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-18 18:59:06 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-18 18:59:06 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-18 18:59:06 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n" - ] - } - ], + "outputs": [], "source": [ "from nemo.collections import nlp as nemo_nlp\n", "from nemo.utils.exp_manager import exp_manager\n", @@ -112,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "federal-beads", "metadata": {}, "outputs": [], @@ -134,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 4, "id": "8ad03fc0", "metadata": {}, "outputs": [ @@ -142,45 +125,34 @@ "name": "stdout", "output_type": "stream", "text": [ - "--2022-01-18 19:17:05-- https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", + "--2022-01-20 01:48:29-- https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", "Resolving www.researchgate.net (www.researchgate.net)... 104.17.32.105, 104.17.33.105, 2606:4700::6811:2069, ...\n", "Connecting to www.researchgate.net (www.researchgate.net)|104.17.32.105|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip [following]\n", - "--2022-01-18 19:17:05-- https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", + "--2022-01-20 01:48:29-- https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", "Reusing existing connection to www.researchgate.net:443.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 681890 (666K) [application/zip]\n", "Saving to: ‘FinancialPhraseBank-v10.zip’\n", "\n", - "FinancialPhraseBank 100%[===================>] 665.91K --.-KB/s in 0.04s \n", + "FinancialPhraseBank 100%[===================>] 665.91K --.-KB/s in 0.02s \n", "\n", - "2022-01-18 19:17:05 (17.9 MB/s) - ‘FinancialPhraseBank-v10.zip’ saved [681890/681890]\n", + "2022-01-20 01:48:30 (28.1 MB/s) - ‘FinancialPhraseBank-v10.zip’ saved [681890/681890]\n", "\n", - "Archive: DATA_DIR/FinancialPhraseBank-v10.zip\n", - " creating: DATA_DIR/FinancialPhraseBank-v1.0/\n", - " inflating: DATA_DIR/FinancialPhraseBank-v1.0/License.txt \n", - " creating: DATA_DIR/__MACOSX/\n", - " creating: DATA_DIR/__MACOSX/FinancialPhraseBank-v1.0/\n", - " inflating: DATA_DIR/__MACOSX/FinancialPhraseBank-v1.0/._License.txt \n", - " inflating: DATA_DIR/FinancialPhraseBank-v1.0/README.txt \n", - " inflating: DATA_DIR/__MACOSX/FinancialPhraseBank-v1.0/._README.txt \n", - " inflating: DATA_DIR/FinancialPhraseBank-v1.0/Sentences_50Agree.txt \n", - " inflating: DATA_DIR/FinancialPhraseBank-v1.0/Sentences_66Agree.txt \n", - " inflating: DATA_DIR/FinancialPhraseBank-v1.0/Sentences_75Agree.txt \n", - " inflating: DATA_DIR/FinancialPhraseBank-v1.0/Sentences_AllAgree.txt \n" + "Archive: DATA_DIR/FinancialPhraseBank-v10.zip\n" ] } ], "source": [ "!wget https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", "!mv FinancialPhraseBank-v10.zip {DATA_DIR}\n", - "!unzip {DATA_DIR}/FinancialPhraseBank-v10.zip -d {DATA_DIR}" + "!unzip -f {DATA_DIR}/FinancialPhraseBank-v10.zip -d {DATA_DIR}" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 5, "id": "radical-castle", "metadata": {}, "outputs": [ @@ -218,13 +190,15 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 6, "id": "198287d4", "metadata": {}, "outputs": [], "source": [ "import json\n", + "import random\n", "\n", + "random.seed(1234)\n", "files = ['Sentences_50Agree.txt', 'Sentences_66Agree.txt', 'Sentences_75Agree.txt', 'Sentences_AllAgree.txt']\n", "base_dir = DATA_DIR + '/FinancialPhraseBank-v1.0/'\n", "files = [base_dir + f for f in files]\n", @@ -234,6 +208,7 @@ " with open(fn, 'r', encoding=\"ISO-8859-1\") as f:\n", " alllines.extend(f.readlines())\n", "\n", + "random.shuffle(alllines)\n", "fold = 10\n", "fold_size = len(alllines) // fold\n", "\n", @@ -254,7 +229,7 @@ " splits = line.split('@')\n", " part1 = splits[0].strip()\n", " part2 = splits[1].strip()\n", - " obj['sentence'] = part1 +'. Sentiment '\n", + " obj['sentence'] = part1 +' Sentiment '\n", " obj['sentiment'] = part2\n", " f.write(json.dumps(obj)+'\\n')\n", "\n", @@ -288,12 +263,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "sound-surgeon", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\"sentence\": \"The contract includes heating plant equipment and associated installation work . Sentiment \", \"sentiment\": \"neutral\"}\n", + "{\"sentence\": \"The utility will also provide services related to electricity management , such as hedging trades and risk management and reporting . Sentiment \", \"sentiment\": \"neutral\"}\n" + ] + } + ], "source": [ - "!head $NER_DATA_DIR/text_train.txt" + "!head -n 2 $DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt" ] }, { @@ -302,9 +286,7 @@ "id": "spectacular-strain", "metadata": {}, "outputs": [], - "source": [ - "!head $NER_DATA_DIR/labels_train.txt" - ] + "source": [] }, { "cell_type": "markdown", @@ -313,41 +295,47 @@ "source": [ "## Convert the Megatron-LM Weights to Nemo file\n", "\n", - "If you prefer to use the Huggingface BERT models, please skip this section and refer to `Setting up a NeMo Experiment` setction to load a model from `nemo_nlp.modules.get_pretrained_lm_models_list()`\n", + "P-Tuning method works the best with large GPT lanague models. From our experiences, models of size 5B or above give good performance. If you already have a large GPT model ready, skip this section. \n", "\n", - "NeMo Megatron BERT can [load from a pretrained model](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/core/core.html?highlight=nemo%20file#restore) using `.nemo` file. We can convert the Megatron-LM checkpoint to the `.nemo` file. Let's first download the pretrained model weights and vocabulary file." + "In this example, we will use the pretrained 344M NeMo Megatron GPT model from [Megatron-LM project](https://github.com/NVIDIA/Megatron-LM). To load it in NeMo Megatron, We first need to convert the Megatron-LM checkpoint to the `.nemo` file. Let's download the pretrained model weights and vocabulary file.\n", + "\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "id": "82b8e08e", "metadata": {}, "outputs": [], "source": [ - "from nemo.collections.nlp.modules.common.megatron.megatron_utils import MEGATRON_CONFIG_MAP\n", "import pathlib\n", - "# specify BERT-like model, you want to use\n", - "PRETRAINED_BERT_MODEL = \"biomegatron-bert-345m-cased\"\n", + "gpt_file = 'megatron_lm_345m_v0.0.zip'\n", + "vocab_file = 'gpt2-vocab.json'\n", + "merge_file = 'gpt2-merge.txt'\n", + "checkpoint_filename = 'model_optim_rng.pt'\n", "\n", - "checkpoint_url = MEGATRON_CONFIG_MAP[PRETRAINED_BERT_MODEL]['checkpoint']\n", - "vocab_url = MEGATRON_CONFIG_MAP[PRETRAINED_BERT_MODEL]['vocab']\n", - "checkpoint_filename = pathlib.Path(checkpoint_url).name\n", - "vocab_filename = pathlib.Path(vocab_url).name\n", - "if not pathlib.Path(checkpoint_filename).exists():\n", - " print('downloading from checkpoint url', checkpoint_url)\n", - " !wget $checkpoint_url\n", - "if not pathlib.Path(vocab_filename).exists():\n", - " print('downloading from vocab url', vocab_url)\n", - " !wget $vocab_url" + "if not pathlib.Path(gpt_file).exists():\n", + " !wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O $gpt_file\n", + " !unzip -f $gpt_file\n", + " !wget https://s3.amazonaws.com/models.huggingface.co/bert/$vocab_file -O $vocab_file \n", + " !wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -O $merge_file\n", + "\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "4b00ee86", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading config file...\n" + ] + } + ], "source": [ "WORK_DIR = \"WORK_DIR\"\n", "os.makedirs(WORK_DIR, exist_ok=True)\n", @@ -355,7 +343,7 @@ "# Prepare the model parameters \n", "# download the model's configuration file \n", "config_dir = WORK_DIR + '/configs/'\n", - "MODEL_CONFIG = \"megatron_bert_config.yaml\"\n", + "MODEL_CONFIG = \"megatron_gpt_config.yaml\"\n", "os.makedirs(config_dir, exist_ok=True)\n", "if not os.path.exists(config_dir + MODEL_CONFIG):\n", " print('Downloading config file...')\n", @@ -366,10 +354,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "0ae5a1a9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WORK_DIR/configs/megatron_gpt_config.yaml\n" + ] + } + ], "source": [ "# this line will print the entire config of the model\n", "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n", @@ -379,12 +375,12 @@ "config.model.hidden_size = 1024\n", "config.model.ffn_hidden_size = 4096\n", "config.model.num_attention_heads = 16\n", - "config.model.tokenizer.vocab_file = vocab_filename\n", - "config.model.tokenizer.type = 'BertWordPieceCase'\n", + "config.model.tokenizer.vocab_file = vocab_file\n", + "config.model.tokenizer.merge_file = merge_file\n", "config.model.tensor_model_parallel_size = 1\n", "config.model.data.data_prefix = ''\n", - "config.model.max_position_embeddings = 512\n", - "config.model.data.seq_length = 512\n", + "config.model.max_position_embeddings = 1024\n", + "config.model.data.seq_length = 1024\n", "config.cfg = {}\n", "config.cfg.cfg = config.model\n", "with open('hparams.yaml', 'w') as f:\n", @@ -393,15 +389,47 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "9e1beda4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "################################################################################\n", + "### WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk\n", + "### (please add 'export KALDI_ROOT=' in your $HOME/.profile)\n", + "### (or run as: KALDI_ROOT= python .py)\n", + "################################################################################\n", + "\n", + "[NeMo W 2022-01-20 02:37:43 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 02:37:43 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 02:37:43 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 02:37:43 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "I0120 02:37:44.114219 139878815041344 distributed_c10d.py:218] Added key: store_based_barrier_key:1 to store for rank: 0\n", + "I0120 02:37:44.114770 139878815041344 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.\n", + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "converted 354.87M parameters\n", + "[NeMo I 2022-01-20 02:37:44 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /NeMo/tutorials/nlp/gpt2-vocab.json\n", + "[NeMo I 2022-01-20 02:37:44 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /NeMo/tutorials/nlp/gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n", + "Using sep_token, but it is not set yet.\n", + "Using cls_token, but it is not set yet.\n", + "Using pad_token, but it is not set yet.\n", + "Using mask_token, but it is not set yet.\n", + "[NeMo I 2022-01-20 02:37:48 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", + "[NeMo I 2022-01-20 02:46:42 megatron_lm_ckpt_to_nemo:265] NeMo model saved to: /NeMo/tutorials/nlp/gpt_344m.nemo\n", + "\u001b[0m" + ] + } + ], "source": [ "import os\n", "PWD = os.getcwd()\n", "wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py')\n", - "!python -m torch.distributed.run --nproc_per_node=1 megatron_lm_ckpt_to_nemo.py --checkpoint_folder=$PWD --checkpoint_name=$checkpoint_filename --hparams_file=$PWD/hparams.yaml --nemo_file_path=$PWD/biomegatron.nemo --model_type=bert --tensor_model_parallel_size=1" + "!python -m torch.distributed.run --nproc_per_node=1 megatron_lm_ckpt_to_nemo.py --checkpoint_folder=$PWD/release/mp_rank_00/ --checkpoint_name=$checkpoint_filename --hparams_file=$PWD/hparams.yaml --nemo_file_path=$PWD/gpt_344m.nemo --model_type=gpt --tensor_model_parallel_size=1" ] }, { @@ -421,20 +449,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "id": "speaking-grant", "metadata": {}, "outputs": [], "source": [ - "MODEL_CONFIG = \"token_classification_config.yaml\"" + "MODEL_CONFIG = \"ptune_text_classification_config.yaml\"" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "id": "demanding-ballet", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "config file is already exists\n" + ] + } + ], "source": [ "# download the model's configuration file \n", "config_dir = WORK_DIR + '/configs/'\n", @@ -448,10 +484,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "id": "criminal-outdoors", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WORK_DIR/configs/ptune_text_classification_config.yaml\n" + ] + } + ], "source": [ "# this line will print the entire config of the model\n", "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n", @@ -464,13 +508,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 36, "id": "informed-purse", "metadata": {}, "outputs": [], "source": [ "# in this tutorial train and dev datasets are located in the same folder, so it is enought to add the path of the data directory to the config\n", - "config.model.dataset.data_dir = os.path.join(DATA_DIR, 'NER')\n", + "#config.model.dataset.classes = ['positive', 'neutral', 'negative']\n", + "config.model.train_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/train_0.txt'\n", + "config.model.validation_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/validation_0.txt'\n", + "config.model.test_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/test_0.txt'\n", + "\n", "\n", "# if you want to decrease the size of your datasets, uncomment the lines below:\n", "# NUM_SAMPLES = 1000\n", @@ -480,10 +528,103 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 31, "id": "divine-belly", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainer:\n", + " gpus: 1\n", + " num_nodes: 1\n", + " max_epochs: 100\n", + " max_steps: null\n", + " accumulate_grad_batches: 1\n", + " gradient_clip_val: 0.0\n", + " precision: 32\n", + " accelerator: ddp\n", + " log_every_n_steps: 1\n", + " val_check_interval: 1.0\n", + " resume_from_checkpoint: null\n", + " num_sanity_val_steps: 0\n", + " checkpoint_callback: false\n", + " logger: false\n", + "model:\n", + " tensor_model_parallel_size: 1\n", + " seed: 1234\n", + " nemo_path: ptune_text_classification_model.nemo\n", + " use_lm_finetune: false\n", + " pseudo_token: '[PROMPT]'\n", + " tokenizer:\n", + " library: megatron\n", + " type: GPT2BPETokenizer\n", + " model: null\n", + " vocab_file: null\n", + " merge_file: null\n", + " language_model:\n", + " nemo_file: null\n", + " prompt_encoder:\n", + " template:\n", + " - 3\n", + " - 3\n", + " - 0\n", + " dropout: 0.1\n", + " dataset:\n", + " classes: ???\n", + " train_ds:\n", + " file_path: DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt\n", + " batch_size: 8\n", + " shuffle: true\n", + " num_samples: -1\n", + " num_workers: 3\n", + " drop_last: false\n", + " pin_memory: false\n", + " validation_ds:\n", + " file_path: DATA_DIR/FinancialPhraseBank-v1.0/validation_0.txt\n", + " batch_size: 8\n", + " shuffle: false\n", + " num_samples: -1\n", + " num_workers: 3\n", + " drop_last: false\n", + " pin_memory: false\n", + " test_ds:\n", + " file_path: DATA_DIR/FinancialPhraseBank-v1.0/test_0.txt\n", + " batch_size: 64\n", + " shuffle: false\n", + " num_samples: -1\n", + " num_workers: 3\n", + " drop_last: false\n", + " pin_memory: false\n", + " optim:\n", + " name: adam\n", + " lr: 2.0e-05\n", + " betas:\n", + " - 0.9\n", + " - 0.999\n", + " weight_decay: 0.01\n", + " sched:\n", + " name: WarmupAnnealing\n", + " warmup_steps: null\n", + " warmup_ratio: 0.1\n", + " last_epoch: -1\n", + " monitor: val_loss\n", + " reduce_on_plateau: false\n", + " infer_samples:\n", + " - by the end of no such thing the audience , like beatrice , has a watchful affection\n", + " for the monster .\n", + " - director rob marshall went out gunning to make a great one .\n", + " - uneasy mishmash of styles and genres .\n", + "exp_manager:\n", + " exp_dir: null\n", + " name: PTuneTextClassification\n", + " create_tensorboard_logger: true\n", + " create_checkpoint_callback: true\n", + "\n" + ] + } + ], "source": [ "print(OmegaConf.to_yaml(config))" ] @@ -522,13 +663,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 32, "id": "a312ed76", "metadata": {}, "outputs": [], "source": [ "# in this tutorial train and dev datasets are located in the same folder, so it is enought to add the path of the data directory to the config\n", - "config.model.dataset.data_dir = os.path.join(DATA_DIR, 'NER')\n", + "config.model.dataset.data_dir = os.path.join(DATA_DIR, 'SA')\n", "\n", "# if you want to decrease the size of your datasets, uncomment the lines below:\n", "# NUM_SAMPLES = 1000\n", @@ -550,10 +691,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 33, "id": "computational-battlefield", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trainer config - \n", + "\n", + "gpus: 1\n", + "num_nodes: 1\n", + "max_epochs: 100\n", + "max_steps: null\n", + "accumulate_grad_batches: 1\n", + "gradient_clip_val: 0.0\n", + "precision: 32\n", + "accelerator: ddp\n", + "log_every_n_steps: 1\n", + "val_check_interval: 1.0\n", + "resume_from_checkpoint: null\n", + "num_sanity_val_steps: 0\n", + "checkpoint_callback: false\n", + "logger: false\n", + "\n" + ] + } + ], "source": [ "print(\"Trainer config - \\n\")\n", "print(OmegaConf.to_yaml(config.trainer))" @@ -561,10 +726,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 34, "id": "unique-genre", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using 16bit native Automatic Mixed Precision (AMP)\n", + "[NeMo W 2022-01-20 02:54:59 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:48: LightningDeprecationWarning: Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7. Use `max_steps = -1` instead.\n", + " rank_zero_deprecation(\n", + " \n", + "[NeMo W 2022-01-20 02:54:59 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:147: LightningDeprecationWarning: Setting `Trainer(checkpoint_callback=False)` is deprecated in v1.5 and will be removed in v1.7. Please consider using `Trainer(enable_checkpointing=False)`.\n", + " rank_zero_deprecation(\n", + " \n", + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n" + ] + } + ], "source": [ "# lets modify some trainer configs\n", "# checks if we have GPU available and uses it\n", @@ -592,10 +774,39 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 35, "id": "mathematical-portable", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-20 02:55:04 exp_manager:283] Experiments will be logged at /NeMo/tutorials/nlp/nemo_experiments/PTuneTextClassification/2022-01-20_02-55-04\n", + "[NeMo I 2022-01-20 02:55:04 exp_manager:648] TensorboardLogger has been set up\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2022-01-20 02:55:04 exp_manager:889] The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to -1. Please ensure that max_steps will run for at least 1 epochs to ensure that checkpointing will not error out.\n", + "[NeMo W 2022-01-20 02:55:04 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:243: LightningDeprecationWarning: `ModelCheckpoint(every_n_val_epochs)` is deprecated in v1.4 and will be removed in v1.6. Please use `every_n_epochs` instead.\n", + " rank_zero_deprecation(\n", + " \n" + ] + }, + { + "data": { + "text/plain": [ + "'/NeMo/tutorials/nlp/nemo_experiments/PTuneTextClassification/2022-01-20_02-55-04'" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))\n", "os.makedirs(WORK_DIR, exist_ok=True)\n", @@ -628,19 +839,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 39, "id": "compact-horse", "metadata": {}, "outputs": [], "source": [ "# add the specified above model parameters to the config\n", "# config.model.language_model.pretrained_model_name = PRETRAINED_BERT_MODEL\n", - "config.model.language_model.nemo_file = 'biomegatron.nemo'\n", - "config.model.language_model.pretrained_model_name = 'megatron-bert-cased'\n", - "config.model.tokenizer.vocab_file='vocab.txt'\n", - "config.model.tokenizer.tokenizer_model = 'BertWordPieceCase'\n", - " \n", - "\n" + "config.model.language_model.nemo_file = 'gpt_344m.nemo'\n", + "config.model.tensor_model_parallel_size = 1\n", + "config.model.dataset.classes = ['positive', 'neutral', 'negative']\n", + "config.model.tokenizer.vocab_file = vocab_file\n", + "config.model.tokenizer.merge_file = merge_file" ] }, { @@ -654,12 +864,86 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 49, "id": "indoor-france", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-20 03:05:26 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /NeMo/tutorials/nlp/gpt2-vocab.json\n", + "[NeMo I 2022-01-20 03:05:26 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /NeMo/tutorials/nlp/gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using sep_token, but it is not set yet.\n", + "Using cls_token, but it is not set yet.\n", + "Using pad_token, but it is not set yet.\n", + "Using mask_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-20 03:05:41 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /tmp/tmpb7y7lez5/bc1a5de6bb3a4c3fa09426d2951b450c_gpt2-vocab.json\n", + "[NeMo I 2022-01-20 03:05:41 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /tmp/tmpb7y7lez5/bc1a5de6bb3a4c3fa09426d2951b450c_gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using sep_token, but it is not set yet.\n", + "Using cls_token, but it is not set yet.\n", + "Using pad_token, but it is not set yet.\n", + "Using mask_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-20 03:05:47 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", + "[NeMo I 2022-01-20 03:05:48 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /tmp/tmpb7y7lez5/bc1a5de6bb3a4c3fa09426d2951b450c_gpt2-vocab.json\n", + "[NeMo I 2022-01-20 03:05:48 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /tmp/tmpb7y7lez5/bc1a5de6bb3a4c3fa09426d2951b450c_gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using sep_token, but it is not set yet.\n", + "Using cls_token, but it is not set yet.\n", + "Using pad_token, but it is not set yet.\n", + "Using mask_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-20 03:05:52 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", + "[NeMo I 2022-01-20 03:05:53 save_restore_connector:149] Model MegatronGPTModel was successfully restored from /NeMo/tutorials/nlp/gpt_344m.nemo.\n", + "[NeMo I 2022-01-20 03:05:53 auto_tokenizer:171] 1 special tokens added, resize your model accordingly.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using pad_token, but it is not set yet.\n", + "Using mask_token, but it is not set yet.\n" + ] + } + ], "source": [ - "model_ner = nemo_nlp.models.TokenClassificationModel(cfg=config.model, trainer=trainer)" + "from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import PTuneTextClassificationModel\n", + "model_ptune = PTuneTextClassificationModel(cfg=config.model, trainer=trainer)" ] }, { @@ -674,10 +958,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 50, "id": "changed-expense", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "To use tensorboard, please use this notebook in a Google Colab environment.\n" + ] + } + ], "source": [ "try:\n", " from google import colab\n", @@ -695,13 +987,129 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 51, "id": "applied-quality", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2022-01-20 03:06:08 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:287: LightningDeprecationWarning: Base `Callback.on_train_batch_start` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.\n", + " rank_zero_deprecation(\n", + " \n", + "[NeMo W 2022-01-20 03:06:08 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:287: LightningDeprecationWarning: Base `Callback.on_train_batch_end` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.\n", + " rank_zero_deprecation(\n", + " \n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4]\n", + "[NeMo W 2022-01-20 03:06:09 modelPT:475] The lightning trainer received accelerator: . We recommend to use 'ddp' instead.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-20 03:06:09 modelPT:566] Optimizer config = Adam (\n", + " Parameter Group 0\n", + " amsgrad: False\n", + " betas: [0.9, 0.999]\n", + " eps: 1e-08\n", + " lr: 2e-05\n", + " weight_decay: 0.01\n", + " )\n", + "[NeMo I 2022-01-20 03:06:09 lr_scheduler:833] Scheduler \"\" \n", + " will be used during training (effective maximum steps = 147800) - \n", + " Parameters : \n", + " (warmup_steps: null\n", + " warmup_ratio: 0.1\n", + " last_epoch: -1\n", + " max_steps: 147800\n", + " )\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------------------------------\n", + "0 | model | MegatronGPTModel | 354 M \n", + "1 | embeddings | VocabParallelEmbedding | 51.5 M\n", + "2 | classification_report | ClassificationReport | 0 \n", + "3 | prompt_encoder | PromptEncoder | 14.7 M\n", + "-----------------------------------------------------------------\n", + "14.7 M Trainable params\n", + "354 M Non-trainable params\n", + "369 M Total params\n", + "739.152 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c6c41b857467495987f968a4300465b0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "ename": "AssertionError", + "evalue": "intra_layer_model parallel group is not initialized", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_1537560/3447343327.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# start model training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_ptune\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)\u001b[0m\n\u001b[1;32m 735\u001b[0m )\n\u001b[1;32m 736\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 737\u001b[0;31m self._call_and_handle_interrupt(\n\u001b[0m\u001b[1;32m 738\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_impl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 739\u001b[0m )\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(self, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 680\u001b[0m \"\"\"\n\u001b[1;32m 681\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 682\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrainer_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 683\u001b[0m \u001b[0;31m# TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 684\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexception\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 770\u001b[0m \u001b[0;31m# TODO: ckpt_path only in v1.7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 771\u001b[0m \u001b[0mckpt_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mckpt_path\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 772\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 773\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 774\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstopped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1194\u001b[0m \u001b[0;31m# dispatch `start_training` or `start_evaluating` or `start_predicting`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1195\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1197\u001b[0m \u001b[0;31m# plugin will finalized fitting (e.g. ddp_spawn will load trained model)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_dispatch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1273\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_predicting\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1274\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1275\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1276\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1277\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrun_stage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36mstart_training\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"pl.Trainer\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;31m# double dispatch to initiate the training loop\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 202\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_stage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 203\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_evaluating\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"pl.Trainer\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1283\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredicting\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1284\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_predict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1285\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1286\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1287\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_pre_training_routine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1313\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1314\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_detect_anomaly\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_detect_anomaly\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1315\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1316\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1317\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_evaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0m_EVALUATE_OUTPUT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_epoch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepoch_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_fetcher\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0;31m# the global step is manually decreased here due to backwards compatibility with existing loggers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_batch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m \u001b[0mbatch_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 194\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_progress\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mincrement_processed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautomatic_optimization\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0moptimizers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_get_active_optimizers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_frequencies\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplit_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmanual_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplit_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, batch, *args, **kwargs)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore[override]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m result = self._run_optimization(\n\u001b[0m\u001b[1;32m 216\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_idx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_run_optimization\u001b[0;34m(self, split_batch, batch_idx, optimizer, opt_idx)\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[0;31m# gradient update with accumulated gradients\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 266\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 267\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsume_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_optimizer_step\u001b[0;34m(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 376\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 377\u001b[0m \u001b[0;31m# model hook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 378\u001b[0;31m lightning_module.optimizer_step(\n\u001b[0m\u001b[1;32m 379\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurrent_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)\u001b[0m\n\u001b[1;32m 1650\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1651\u001b[0m \"\"\"\n\u001b[0;32m-> 1652\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moptimizer_closure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1653\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1654\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_zero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mtrainer\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprofiler_action\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, opt_idx, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 334\u001b[0m \"\"\"\n\u001b[1;32m 335\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 336\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 337\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_zero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcurrent_epoch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, model, optimizer, optimizer_idx, closure, **kwargs)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;34mf\"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx}).\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m )\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0mclosure_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0;31m# `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscaler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munscale_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 160\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 161\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_result\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36mclosure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mClosureResult\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_profiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training_step_and_backward\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 142\u001b[0;31m \u001b[0mstep_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_step_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstep_output\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclosure_loss\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_training_step\u001b[0;34m(self, split_batch, batch_idx, opt_idx)\u001b[0m\n\u001b[1;32m 433\u001b[0m \u001b[0mlightning_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_current_fx_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"training_step\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 434\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training_step\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 435\u001b[0;31m \u001b[0mtraining_step_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 436\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 437\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, step_kwargs)\u001b[0m\n\u001b[1;32m 214\u001b[0m \"\"\"\n\u001b[1;32m 215\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 216\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mstep_kwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 217\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 218\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 213\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 214\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/NeMo/nemo/utils/model_utils.py\u001b[0m in \u001b[0;36mwrap_training_step\u001b[0;34m(wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mwrapt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecorator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrap_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwrapped\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstance\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'pl.LightningModule'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 355\u001b[0;31m \u001b[0moutput_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 356\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0moutput_dict\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m'log'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0moutput_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/NeMo/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 270\u001b[0m \u001b[0;31m# forward pass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 271\u001b[0m \u001b[0msentences\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 272\u001b[0;31m \u001b[0mtrain_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msentences\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 273\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 274\u001b[0m \u001b[0mlr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_groups\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'lr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/NeMo/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, sentences, labels)\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mlabels_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_ids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_label_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_position\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseq_len\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 255\u001b[0;31m output = self.model.model(None, None, encoder_input=encoder_input,\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_atten\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m labels=labels_input)\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1108\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1111\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1112\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/NeMo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, position_ids, attention_mask, labels, prompt_tags, tokentype_ids, layer_past, get_key_value, forward_method_parallel_output, encoder_input)\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_process\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 175\u001b[0;31m return post_language_model_processing(\n\u001b[0m\u001b[1;32m 176\u001b[0m \u001b[0mlm_output\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/NeMo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py\u001b[0m in \u001b[0;36mpost_language_model_processing\u001b[0;34m(lm_output, labels, logit_weights, get_key_value, parallel_output, forward_method_parallel_output, fp16_lm_cross_entropy, return_logits)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor_parallel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvocab_parallel_cross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor_parallel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvocab_parallel_cross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 55\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_logits\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/apex/transformer/tensor_parallel/cross_entropy.py\u001b[0m in \u001b[0;36mvocab_parallel_cross_entropy\u001b[0;34m(vocab_parallel_logits, target)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvocab_parallel_cross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_parallel_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;34m\"\"\"Helper function for the cross entropy.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_VocabParallelCrossEntropy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_parallel_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/apex/transformer/tensor_parallel/cross_entropy.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(ctx, vocab_parallel_logits, target)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mlogits_max\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_parallel_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m torch.distributed.all_reduce(\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mlogits_max\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdistributed\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mReduceOp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMAX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mget_tensor_model_parallel_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m )\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# Subtract the maximum value.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/apex/transformer/parallel_state.py\u001b[0m in \u001b[0;36mget_tensor_model_parallel_group\u001b[0;34m()\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_tensor_model_parallel_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[0;34m\"\"\"Get the tensor model parallel group the caller rank belongs to.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 174\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0m_TENSOR_MODEL_PARALLEL_GROUP\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"intra_layer_model parallel group is not initialized\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 175\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_TENSOR_MODEL_PARALLEL_GROUP\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAssertionError\u001b[0m: intra_layer_model parallel group is not initialized" + ] + } + ], "source": [ "# start model training\n", - "trainer.fit(model_ner)" + "trainer.fit(model_ptune)" ] }, { From 7cb36408823f10e449e778b9748a667d6231bc9f Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 19 Jan 2022 19:43:03 -0800 Subject: [PATCH 05/22] match the original hyper parameters Signed-off-by: Yi Dong --- .../conf/ptune_text_classification_config.yaml | 12 +++++------- nemo/core/optim/lr_scheduler.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml index 2e81a1045f33..ff9559939004 100644 --- a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml +++ b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml @@ -50,7 +50,7 @@ model: prompt_encoder: template: [3, 3, 0] - dropout: 0.1 + dropout: 0.0 dataset: classes: ??? # The class labels, e.g. ['positive', 'neutral', 'negative'] @@ -84,18 +84,16 @@ model: optim: name: adam - lr: 2e-5 + lr: 1e-5 # optimizer arguments betas: [0.9, 0.999] - weight_decay: 0.01 + weight_decay: 0.0005 # scheduler setup sched: - name: WarmupAnnealing + name: ExponentialLR # Scheduler params - warmup_steps: null - warmup_ratio: 0.1 - last_epoch: -1 + gamma: 0.98 # pytorch lightning args monitor: val_loss reduce_on_plateau: false diff --git a/nemo/core/optim/lr_scheduler.py b/nemo/core/optim/lr_scheduler.py index 7f04ccbe8a4b..06a925f72fb8 100644 --- a/nemo/core/optim/lr_scheduler.py +++ b/nemo/core/optim/lr_scheduler.py @@ -821,7 +821,7 @@ def prepare_lr_scheduler( return None # Inject max_steps (effective or provided) into the scheduler config - if add_max_args_flag: + if add_max_args_flag and scheduler_config.get('name', '') != "ExponentialLR": scheduler_args['max_steps'] = max_steps # Get the scheduler class from the config From c748a81957f832e61d5b0ed10c9f9e74572f4e9d Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Thu, 20 Jan 2022 15:48:59 -0800 Subject: [PATCH 06/22] fixed the loss bug Signed-off-by: Yi Dong --- .../ptune_text_classification_model.py | 2 +- tutorials/nlp/PTune_sentiment_analysis.ipynb | 52 +++++++++++++------ 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index 7ee313a25fe2..ed50ab9a1351 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -256,7 +256,7 @@ def forward(self, sentences, labels): attention_mask=new_atten, labels=labels_input) loss, logits = output - floss = (loss[(labels != SMALL_LOGITS)]).mean() + floss = (loss[(labels_input != SMALL_LOGITS)]).mean() _, returned_pred = self.get_prediction(batch_size, label_position, logits) returned_label = self.get_ground_truth_labels(batch_size, label_ids) diff --git a/tutorials/nlp/PTune_sentiment_analysis.ipynb b/tutorials/nlp/PTune_sentiment_analysis.ipynb index 4883f808ac95..fe89e581f26d 100644 --- a/tutorials/nlp/PTune_sentiment_analysis.ipynb +++ b/tutorials/nlp/PTune_sentiment_analysis.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 16, + "execution_count": 4, "id": "b7a434f4", "metadata": {}, "outputs": [], @@ -34,10 +34,27 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 1, "id": "challenging-pioneer", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "################################################################################\n", + "### WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk\n", + "### (please add 'export KALDI_ROOT=' in your $HOME/.profile)\n", + "### (or run as: KALDI_ROOT= python .py)\n", + "################################################################################\n", + "\n", + "[NeMo W 2022-01-20 20:59:05 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 20:59:05 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 20:59:05 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 20:59:05 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n" + ] + } + ], "source": [ "from nemo.collections import nlp as nemo_nlp\n", "from nemo.utils.exp_manager import exp_manager\n", @@ -303,7 +320,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 2, "id": "82b8e08e", "metadata": {}, "outputs": [], @@ -324,7 +341,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 5, "id": "4b00ee86", "metadata": {}, "outputs": [ @@ -354,7 +371,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 6, "id": "0ae5a1a9", "metadata": {}, "outputs": [ @@ -381,6 +398,7 @@ "config.model.data.data_prefix = ''\n", "config.model.max_position_embeddings = 1024\n", "config.model.data.seq_length = 1024\n", + "config.model.encoder_seq_length = 1024\n", "config.cfg = {}\n", "config.cfg.cfg = config.model\n", "with open('hparams.yaml', 'w') as f:\n", @@ -389,7 +407,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 7, "id": "9e1beda4", "metadata": {}, "outputs": [ @@ -403,24 +421,24 @@ "### (or run as: KALDI_ROOT= python .py)\n", "################################################################################\n", "\n", - "[NeMo W 2022-01-20 02:37:43 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 02:37:43 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 02:37:43 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 02:37:43 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "I0120 02:37:44.114219 139878815041344 distributed_c10d.py:218] Added key: store_based_barrier_key:1 to store for rank: 0\n", - "I0120 02:37:44.114770 139878815041344 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.\n", + "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "I0120 21:01:09.749301 140536743184192 distributed_c10d.py:218] Added key: store_based_barrier_key:1 to store for rank: 0\n", + "I0120 21:01:09.749543 140536743184192 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.\n", "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "converted 354.87M parameters\n", - "[NeMo I 2022-01-20 02:37:44 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /NeMo/tutorials/nlp/gpt2-vocab.json\n", - "[NeMo I 2022-01-20 02:37:44 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /NeMo/tutorials/nlp/gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n", + "[NeMo I 2022-01-20 21:01:10 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /NeMo/tutorials/nlp/gpt2-vocab.json\n", + "[NeMo I 2022-01-20 21:01:10 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /NeMo/tutorials/nlp/gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n", "Using sep_token, but it is not set yet.\n", "Using cls_token, but it is not set yet.\n", "Using pad_token, but it is not set yet.\n", "Using mask_token, but it is not set yet.\n", - "[NeMo I 2022-01-20 02:37:48 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", - "[NeMo I 2022-01-20 02:46:42 megatron_lm_ckpt_to_nemo:265] NeMo model saved to: /NeMo/tutorials/nlp/gpt_344m.nemo\n", + "[NeMo I 2022-01-20 21:01:13 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", + "[NeMo I 2022-01-20 21:10:10 megatron_lm_ckpt_to_nemo:265] NeMo model saved to: /NeMo/tutorials/nlp/gpt_344m.nemo\n", "\u001b[0m" ] } From 80f48c282d445ec111757d1bcae79e6008ca3eca Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Thu, 20 Jan 2022 16:18:20 -0800 Subject: [PATCH 07/22] better scheduler Signed-off-by: Yi Dong --- .../conf/ptune_text_classification_config.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml index ff9559939004..9f7a871c2d9a 100644 --- a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml +++ b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml @@ -91,9 +91,11 @@ model: # scheduler setup sched: - name: ExponentialLR + name: WarmupAnnealing # Scheduler params - gamma: 0.98 + warmup_steps: null + warmup_ratio: 0.1 + last_epoch: -1 # pytorch lightning args monitor: val_loss reduce_on_plateau: false From 8a85024c095523b841bd128af9ef985d647e3251 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Fri, 21 Jan 2022 05:43:22 -0800 Subject: [PATCH 08/22] notebook runs Signed-off-by: Yi Dong --- tutorials/nlp/PTune_Sentiment_Analysis.ipynb | 1293 ++++++++++++++++++ tutorials/nlp/PTune_sentiment_analysis.ipynb | 1242 ----------------- 2 files changed, 1293 insertions(+), 1242 deletions(-) create mode 100644 tutorials/nlp/PTune_Sentiment_Analysis.ipynb delete mode 100644 tutorials/nlp/PTune_sentiment_analysis.ipynb diff --git a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb new file mode 100644 index 000000000000..ba7dc3aff894 --- /dev/null +++ b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb @@ -0,0 +1,1293 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "b7a434f4", + "metadata": {}, + "outputs": [], + "source": [ + "BRANCH='main'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "developmental-gibraltar", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. Run this cell to set up dependencies.\n", + "\"\"\"\n", + "# If you're using Google Colab and not running locally, run this cell\n", + "\n", + "# install NeMo\n", + "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[nlp]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "challenging-pioneer", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "################################################################################\n", + "### WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk\n", + "### (please add 'export KALDI_ROOT=' in your $HOME/.profile)\n", + "### (or run as: KALDI_ROOT= python .py)\n", + "################################################################################\n", + "\n", + "[NeMo W 2022-01-21 13:35:51 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-21 13:35:51 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-21 13:35:51 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-21 13:35:51 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n" + ] + } + ], + "source": [ + "from nemo.collections import nlp as nemo_nlp\n", + "from nemo.utils.exp_manager import exp_manager\n", + "\n", + "import os\n", + "import wget \n", + "import torch\n", + "import pytorch_lightning as pl\n", + "from omegaconf import OmegaConf" + ] + }, + { + "cell_type": "markdown", + "id": "employed-ethiopia", + "metadata": {}, + "source": [ + "In this tutorial, we are going to describe how to finetune BioMegatron - a [BERT](https://arxiv.org/abs/1810.04805)-like [Megatron-LM](https://arxiv.org/pdf/1909.08053.pdf) model pre-trained on large biomedical text corpus ([PubMed](https://pubmed.ncbi.nlm.nih.gov/) abstracts and full-text commercial use collection) - on the [NCBI Disease Dataset](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3951655/) for Named Entity Recognition.\n", + "\n", + "The model size of Megatron-LM can be larger than BERT, up to multi-billion parameters, compared to 345 million parameters of BERT-large.\n", + "There are some alternatives of BioMegatron, most notably [BioBERT](https://arxiv.org/abs/1901.08746). Compared to BioBERT BioMegatron is larger by model size and pre-trained on larger text corpus.\n", + "\n", + "A more general tutorial of using BERT-based models, including Megatron-LM, for downstream natural language processing tasks can be found [here](https://github.com/NVIDIA/NeMo/blob/stable/tutorials/nlp/01_Pretrained_Language_Models_for_Downstream_Tasks.ipynb).\n", + "\n", + "# Task Description\n", + "**Named entity recognition (NER)**, also referred to as entity chunking, identification or extraction, is the task of detecting and classifying key information (entities) in text.\n", + "\n", + "For instance, **given sentences from medical abstracts, what diseases are mentioned?**
\n", + "In this case, our data input is sentences from the abstracts, and our labels are the precise locations of the named disease entities. Take a look at the information provided for the dataset.\n", + "\n", + "For more details and general examples on Named Entity Recognition, please refer to the [Token Classification and Named Entity Recognition tutorial notebook](https://github.com/NVIDIA/NeMo/blob/stable/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb).\n", + "\n", + "# Dataset\n", + "\n", + "The [NCBI-disease corpus](https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/) is a set of 793 PubMed abstracts, annotated by 14 annotators. The annotations take the form of HTML-style tags inserted into the abstract text using the clearly defined rules. The annotations identify named diseases, and can be used to fine-tune a language model to identify disease mentions in future abstracts, *whether those diseases were part of the original training set or not*.\n", + "\n", + "Here's an example of what an annotated abstract from the corpus looks like:\n", + "\n", + "```html\n", + "10021369\tIdentification of APC2, a homologue of the adenomatous polyposis coli tumour suppressor .\tThe adenomatous polyposis coli ( APC ) tumour-suppressor protein controls the Wnt signalling pathway by forming a complex with glycogen synthase kinase 3beta ( GSK-3beta ) , axin / conductin and betacatenin . Complex formation induces the rapid degradation of betacatenin . In colon carcinoma cells , loss of APC leads to the accumulation of betacatenin in the nucleus , where it binds to and activates the Tcf-4 transcription factor ( reviewed in [ 1 ] [ 2 ] ) . Here , we report the identification and genomic structure of APC homologues . Mammalian APC2 , which closely resembles APC in overall domain structure , was functionally analyzed and shown to contain two SAMP domains , both of which are required for binding to conductin . Like APC , APC2 regulates the formation of active betacatenin-Tcf complexes , as demonstrated using transient transcriptional activation assays in APC - / - colon carcinoma cells . Human APC2 maps to chromosome 19p13 . 3 . APC and APC2 may therefore have comparable functions in development and cancer .\n", + "```\n", + "\n", + "In this example, we see the following tags within the abstract:\n", + "```html\n", + "adenomatous polyposis coli tumour\n", + "adenomatous polyposis coli ( APC ) tumour\n", + "colon carcinoma\n", + "colon carcinoma\n", + "cancer\n", + "```\n", + "\n", + "For our purposes, we will consider any identified category (such as \"Modifier\", \"Specific Disease\", and a few others) to generally be a \"disease\".\n", + "\n", + "Let's download the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "federal-beads", + "metadata": {}, + "outputs": [], + "source": [ + "DATA_DIR = \"DATA_DIR\"\n", + "os.makedirs(DATA_DIR, exist_ok=True)\n", + "os.makedirs(os.path.join(DATA_DIR, 'SA'), exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "id": "1c1e1b08", + "metadata": {}, + "source": [ + "## Downloading Financial Phrase Bank Dataset\n", + "\n", + "The datase is collected by Malo et al. 2014, and can be downloaded from this [link](https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip). The zip file for the Financial Phrase Bank Dataset has been provided for ease of download and use." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8ad03fc0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2022-01-20 01:48:29-- https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", + "Resolving www.researchgate.net (www.researchgate.net)... 104.17.32.105, 104.17.33.105, 2606:4700::6811:2069, ...\n", + "Connecting to www.researchgate.net (www.researchgate.net)|104.17.32.105|:443... connected.\n", + "HTTP request sent, awaiting response... 301 Moved Permanently\n", + "Location: https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip [following]\n", + "--2022-01-20 01:48:29-- https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", + "Reusing existing connection to www.researchgate.net:443.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 681890 (666K) [application/zip]\n", + "Saving to: ‘FinancialPhraseBank-v10.zip’\n", + "\n", + "FinancialPhraseBank 100%[===================>] 665.91K --.-KB/s in 0.02s \n", + "\n", + "2022-01-20 01:48:30 (28.1 MB/s) - ‘FinancialPhraseBank-v10.zip’ saved [681890/681890]\n", + "\n", + "Archive: DATA_DIR/FinancialPhraseBank-v10.zip\n" + ] + } + ], + "source": [ + "!wget https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", + "!mv FinancialPhraseBank-v10.zip {DATA_DIR}\n", + "!unzip -f {DATA_DIR}/FinancialPhraseBank-v10.zip -d {DATA_DIR}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "radical-castle", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "According to Gran , the company has no plans to move all production to Russia , although that is where the company is growing .@neutral\n" + ] + } + ], + "source": [ + "# If you want to see more examples, you can explore the text of the corpus using the file browser to the left, or open files directly, for example typing a command like the following in a code-cell:\n", + "\n", + "! head -1 $DATA_DIR/FinancialPhraseBank-v1.0/Sentences_50Agree.txt" + ] + }, + { + "cell_type": "markdown", + "id": "specified-maine", + "metadata": {}, + "source": [ + "We have two datasets derived from this corpus: a text classification dataset and a named entity recognition (NER) dataset. The text classification dataset labels the abstracts among three broad disease groupings. We'll use this simple split to demonstrate the NLP text classification task. The NER dataset labels individual words as diseases. This dataset will be used for the NLP NER task. " + ] + }, + { + "cell_type": "markdown", + "id": "affected-numbers", + "metadata": {}, + "source": [ + "## Pre-process dataset\n", + "A pre-processed NCBI-disease dataset for NER can be found [here](https://github.com/spyysalo/ncbi-disease/tree/master/conll) or [here](https://github.com/dmis-lab/biobert#datasets).
\n", + "We download the files under {DATA_DIR/NER} directory." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "198287d4", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import random\n", + "\n", + "random.seed(1234)\n", + "files = ['Sentences_50Agree.txt', 'Sentences_66Agree.txt', 'Sentences_75Agree.txt', 'Sentences_AllAgree.txt']\n", + "base_dir = DATA_DIR + '/FinancialPhraseBank-v1.0/'\n", + "files = [base_dir + f for f in files]\n", + "\n", + "alllines = []\n", + "for fn in files:\n", + " with open(fn, 'r', encoding=\"ISO-8859-1\") as f:\n", + " alllines.extend(f.readlines())\n", + "\n", + "random.shuffle(alllines)\n", + "fold = 10\n", + "fold_size = len(alllines) // fold\n", + "\n", + "chunk_start = list(range(0, 14780, 1478))\n", + "\n", + "chunks = []\n", + "\n", + "for start_id in chunk_start:\n", + " chunks.append(alllines[start_id:start_id+fold_size])\n", + "\n", + "special = '<|endoftext|>'\n", + "\n", + "def gen_file(data, fold_id, split_type):\n", + " filename = \"{}/{}_{}.txt\".format(base_dir, split_type, fold_id)\n", + " with open(filename, 'w') as f:\n", + " obj = {}\n", + " for line in data:\n", + " splits = line.split('@')\n", + " part1 = splits[0].strip()\n", + " part2 = splits[1].strip()\n", + " obj['sentence'] = part1 +' Sentiment '\n", + " obj['sentiment'] = part2\n", + " f.write(json.dumps(obj)+'\\n')\n", + "\n", + "\n", + "def gen_fold(fold_number):\n", + " lists = list(range(fold))\n", + " test_id = (fold_number + fold) % fold\n", + " val_id = (fold_number + fold - 1) % fold\n", + " test_set = chunks[test_id]\n", + " val_set = chunks[val_id]\n", + " lists.remove(test_id)\n", + " lists.remove(val_id)\n", + " train_set = []\n", + " for idd in lists:\n", + " train_set += chunks[idd]\n", + " gen_file(train_set, fold_number, 'train')\n", + " gen_file(val_set, fold_number, 'validation')\n", + " gen_file(test_set, fold_number, 'test')\n", + "\n", + "for i in range(fold):\n", + " gen_fold(i)" + ] + }, + { + "cell_type": "markdown", + "id": "graphic-debate", + "metadata": {}, + "source": [ + "The NER task requires two files: the text sentences, and the labels. Run the next two cells to see a sample of the two files." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "sound-surgeon", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\"sentence\": \"The contract includes heating plant equipment and associated installation work . Sentiment \", \"sentiment\": \"neutral\"}\n", + "{\"sentence\": \"The utility will also provide services related to electricity management , such as hedging trades and risk management and reporting . Sentiment \", \"sentiment\": \"neutral\"}\n" + ] + } + ], + "source": [ + "!head -n 2 $DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "spectacular-strain", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "3813cc36", + "metadata": {}, + "source": [ + "## Convert the Megatron-LM Weights to Nemo file\n", + "\n", + "P-Tuning method works the best with large GPT lanague models. From our experiences, models of size 5B or above give good performance. If you already have a large GPT model ready, skip this section. \n", + "\n", + "In this example, we will use the pretrained 344M NeMo Megatron GPT model from [Megatron-LM project](https://github.com/NVIDIA/Megatron-LM). To load it in NeMo Megatron, We first need to convert the Megatron-LM checkpoint to the `.nemo` file. Let's download the pretrained model weights and vocabulary file.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "82b8e08e", + "metadata": {}, + "outputs": [], + "source": [ + "import pathlib\n", + "gpt_file = 'megatron_lm_345m_v0.0.zip'\n", + "vocab_file = 'gpt2-vocab.json'\n", + "merge_file = 'gpt2-merge.txt'\n", + "checkpoint_filename = 'model_optim_rng.pt'\n", + "\n", + "if not pathlib.Path(gpt_file).exists():\n", + " !wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O $gpt_file\n", + " !unzip -f $gpt_file\n", + " !wget https://s3.amazonaws.com/models.huggingface.co/bert/$vocab_file -O $vocab_file \n", + " !wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -O $merge_file\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4b00ee86", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "config file is already exists\n" + ] + } + ], + "source": [ + "WORK_DIR = \"WORK_DIR\"\n", + "os.makedirs(WORK_DIR, exist_ok=True)\n", + "\n", + "# Prepare the model parameters \n", + "# download the model's configuration file \n", + "config_dir = WORK_DIR + '/configs/'\n", + "MODEL_CONFIG = \"megatron_gpt_config.yaml\"\n", + "os.makedirs(config_dir, exist_ok=True)\n", + "if not os.path.exists(config_dir + MODEL_CONFIG):\n", + " print('Downloading config file...')\n", + " wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/conf/' + MODEL_CONFIG, config_dir)\n", + "else:\n", + " print ('config file is already exists')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0ae5a1a9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WORK_DIR/configs/megatron_gpt_config.yaml\n" + ] + } + ], + "source": [ + "# this line will print the entire config of the model\n", + "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n", + "print(config_path)\n", + "config = OmegaConf.load(config_path)\n", + "config.model.num_layers = 24\n", + "config.model.hidden_size = 1024\n", + "config.model.ffn_hidden_size = 4096\n", + "config.model.num_attention_heads = 16\n", + "config.model.tokenizer.vocab_file = vocab_file\n", + "config.model.tokenizer.merge_file = merge_file\n", + "config.model.tensor_model_parallel_size = 1\n", + "config.model.data.data_prefix = ''\n", + "config.model.max_position_embeddings = 1024\n", + "config.model.data.seq_length = 1024\n", + "config.model.encoder_seq_length = 1024\n", + "config.cfg = {}\n", + "config.cfg.cfg = config.model\n", + "with open('hparams.yaml', 'w') as f:\n", + " f.write(OmegaConf.to_yaml(config.cfg))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9e1beda4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "################################################################################\n", + "### WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk\n", + "### (please add 'export KALDI_ROOT=' in your $HOME/.profile)\n", + "### (or run as: KALDI_ROOT= python .py)\n", + "################################################################################\n", + "\n", + "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", + "I0120 21:01:09.749301 140536743184192 distributed_c10d.py:218] Added key: store_based_barrier_key:1 to store for rank: 0\n", + "I0120 21:01:09.749543 140536743184192 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.\n", + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "converted 354.87M parameters\n", + "[NeMo I 2022-01-20 21:01:10 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /NeMo/tutorials/nlp/gpt2-vocab.json\n", + "[NeMo I 2022-01-20 21:01:10 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /NeMo/tutorials/nlp/gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n", + "Using sep_token, but it is not set yet.\n", + "Using cls_token, but it is not set yet.\n", + "Using pad_token, but it is not set yet.\n", + "Using mask_token, but it is not set yet.\n", + "[NeMo I 2022-01-20 21:01:13 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", + "[NeMo I 2022-01-20 21:10:10 megatron_lm_ckpt_to_nemo:265] NeMo model saved to: /NeMo/tutorials/nlp/gpt_344m.nemo\n", + "\u001b[0m" + ] + } + ], + "source": [ + "import os\n", + "PWD = os.getcwd()\n", + "wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py')\n", + "!python -m torch.distributed.run --nproc_per_node=1 megatron_lm_ckpt_to_nemo.py --checkpoint_folder=$PWD/release/mp_rank_00/ --checkpoint_name=$checkpoint_filename --hparams_file=$PWD/hparams.yaml --nemo_file_path=$PWD/gpt_344m.nemo --model_type=gpt --tensor_model_parallel_size=1" + ] + }, + { + "cell_type": "markdown", + "id": "84b455a6", + "metadata": {}, + "source": [ + "# Model configuration\n", + "\n", + "Our Named Entity Recognition model is comprised of the pretrained [BERT](https://arxiv.org/pdf/1810.04805.pdf) model followed by a Token Classification layer.\n", + "\n", + "The model is defined in a config file which declares multiple important sections. They are:\n", + "- **model**: All arguments that are related to the Model - language model, token classifier, optimizer and schedulers, datasets and any other related information\n", + "\n", + "- **trainer**: Any argument to be passed to PyTorch Lightning" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "speaking-grant", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_CONFIG = \"ptune_text_classification_config.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "demanding-ballet", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "config file is already exists\n" + ] + } + ], + "source": [ + "# download the model's configuration file \n", + "config_dir = WORK_DIR + '/configs/'\n", + "os.makedirs(config_dir, exist_ok=True)\n", + "if not os.path.exists(config_dir + MODEL_CONFIG):\n", + " print('Downloading config file...')\n", + " wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/token_classification/conf/' + MODEL_CONFIG, config_dir)\n", + "else:\n", + " print ('config file is already exists')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "criminal-outdoors", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WORK_DIR/configs/ptune_text_classification_config.yaml\n" + ] + } + ], + "source": [ + "# this line will print the entire config of the model\n", + "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n", + "print(config_path)\n", + "config = OmegaConf.load(config_path)\n", + "# Note: these are small batch-sizes - increase as appropriate to available GPU capacity\n", + "config.model.train_ds.batch_size=8\n", + "config.model.validation_ds.batch_size=8" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "informed-purse", + "metadata": {}, + "outputs": [], + "source": [ + "# in this tutorial train and dev datasets are located in the same folder, so it is enought to add the path of the data directory to the config\n", + "#config.model.dataset.classes = ['positive', 'neutral', 'negative']\n", + "config.model.train_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/train_0.txt'\n", + "config.model.validation_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/validation_0.txt'\n", + "config.model.test_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/test_0.txt'\n", + "\n", + "\n", + "# if you want to decrease the size of your datasets, uncomment the lines below:\n", + "# NUM_SAMPLES = 1000\n", + "# config.model.train_ds.num_samples = NUM_SAMPLES\n", + "# config.model.validation_ds.num_samples = NUM_SAMPLES" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "divine-belly", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainer:\n", + " gpus: 1\n", + " num_nodes: 1\n", + " max_epochs: 100\n", + " max_steps: null\n", + " accumulate_grad_batches: 1\n", + " gradient_clip_val: 0.0\n", + " precision: 32\n", + " accelerator: ddp\n", + " log_every_n_steps: 1\n", + " val_check_interval: 1.0\n", + " resume_from_checkpoint: null\n", + " num_sanity_val_steps: 0\n", + " checkpoint_callback: false\n", + " logger: false\n", + "model:\n", + " tensor_model_parallel_size: 2\n", + " seed: 1234\n", + " nemo_path: ptune_text_classification_model.nemo\n", + " use_lm_finetune: false\n", + " pseudo_token: '[PROMPT]'\n", + " tokenizer:\n", + " library: megatron\n", + " type: GPT2BPETokenizer\n", + " model: null\n", + " vocab_file: null\n", + " merge_file: null\n", + " language_model:\n", + " nemo_file: null\n", + " prompt_encoder:\n", + " template:\n", + " - 3\n", + " - 3\n", + " - 0\n", + " dropout: 0.0\n", + " dataset:\n", + " classes: ???\n", + " train_ds:\n", + " file_path: DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt\n", + " batch_size: 8\n", + " shuffle: true\n", + " num_samples: -1\n", + " num_workers: 3\n", + " drop_last: false\n", + " pin_memory: false\n", + " validation_ds:\n", + " file_path: DATA_DIR/FinancialPhraseBank-v1.0/validation_0.txt\n", + " batch_size: 8\n", + " shuffle: false\n", + " num_samples: -1\n", + " num_workers: 3\n", + " drop_last: false\n", + " pin_memory: false\n", + " test_ds:\n", + " file_path: DATA_DIR/FinancialPhraseBank-v1.0/test_0.txt\n", + " batch_size: 64\n", + " shuffle: false\n", + " num_samples: -1\n", + " num_workers: 3\n", + " drop_last: false\n", + " pin_memory: false\n", + " optim:\n", + " name: adam\n", + " lr: 1.0e-05\n", + " betas:\n", + " - 0.9\n", + " - 0.999\n", + " weight_decay: 0.0005\n", + " sched:\n", + " name: WarmupAnnealing\n", + " warmup_steps: null\n", + " warmup_ratio: 0.1\n", + " last_epoch: -1\n", + " monitor: val_loss\n", + " reduce_on_plateau: false\n", + " infer_samples:\n", + " - by the end of no such thing the audience , like beatrice , has a watchful affection\n", + " for the monster .\n", + " - director rob marshall went out gunning to make a great one .\n", + " - uneasy mishmash of styles and genres .\n", + "exp_manager:\n", + " exp_dir: null\n", + " name: PTuneTextClassification\n", + " create_tensorboard_logger: true\n", + " create_checkpoint_callback: true\n", + "\n" + ] + } + ], + "source": [ + "print(OmegaConf.to_yaml(config))" + ] + }, + { + "cell_type": "markdown", + "id": "dedicated-effort", + "metadata": {}, + "source": [ + "# Model Training\n", + "## Setting up Data within the config\n", + "\n", + "Among other things, the config file contains dictionaries called dataset, train_ds and validation_ds. These are configurations used to setup the Dataset and DataLoaders of the corresponding config.\n" + ] + }, + { + "cell_type": "markdown", + "id": "15e2c67a", + "metadata": {}, + "source": [ + "\n", + "We assume that both training and evaluation files are located in the same directory, and use the default names mentioned during the data download step. \n", + "So, to start model training, we simply need to specify `model.dataset.data_dir`, like we are going to do below.\n" + ] + }, + { + "cell_type": "markdown", + "id": "89dd468d", + "metadata": {}, + "source": [ + "\n", + "Also notice that some config lines, including `model.dataset.data_dir`, have `???` in place of paths, this means that values for these fields are required to be specified by the user.\n", + "\n", + "Let's now add the data directory path to the config." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a312ed76", + "metadata": {}, + "outputs": [], + "source": [ + "# in this tutorial train and dev datasets are located in the same folder, so it is enought to add the path of the data directory to the config\n", + "config.model.dataset.data_dir = os.path.join(DATA_DIR, 'SA')\n", + "\n", + "# if you want to decrease the size of your datasets, uncomment the lines below:\n", + "# NUM_SAMPLES = 1000\n", + "# config.model.train_ds.num_samples = NUM_SAMPLES\n", + "# config.model.validation_ds.num_samples = NUM_SAMPLES" + ] + }, + { + "cell_type": "markdown", + "id": "changed-mauritius", + "metadata": {}, + "source": [ + "## Building the PyTorch Lightning Trainer\n", + "\n", + "NeMo models are primarily PyTorch Lightning modules - and therefore are entirely compatible with the PyTorch Lightning ecosystem.\n", + "\n", + "Let's first instantiate a Trainer object" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "computational-battlefield", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trainer config - \n", + "\n", + "gpus: 1\n", + "num_nodes: 1\n", + "max_epochs: 100\n", + "max_steps: null\n", + "accumulate_grad_batches: 1\n", + "gradient_clip_val: 0.0\n", + "precision: 32\n", + "accelerator: ddp\n", + "log_every_n_steps: 1\n", + "val_check_interval: 1.0\n", + "resume_from_checkpoint: null\n", + "num_sanity_val_steps: 0\n", + "checkpoint_callback: false\n", + "logger: false\n", + "\n" + ] + } + ], + "source": [ + "print(\"Trainer config - \\n\")\n", + "print(OmegaConf.to_yaml(config.trainer))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "unique-genre", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2022-01-21 13:37:28 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py:107: LightningDeprecationWarning: Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. Notice that it will be overriden by the trainer setting.\n", + " rank_zero_deprecation(\n", + " \n", + "[NeMo W 2022-01-21 13:37:28 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py:113: LightningDeprecationWarning: Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. Notice that it will be overriden by the trainer setting.\n", + " rank_zero_deprecation(\n", + " \n", + "[NeMo W 2022-01-21 13:37:28 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:324: LightningDeprecationWarning: Passing `strategy` to the `plugins` flag in Trainer has been deprecated in v1.5 and will be removed in v1.7. Use `Trainer(strategy=)` instead.\n", + " rank_zero_deprecation(\n", + " \n", + "Using 16bit native Automatic Mixed Precision (AMP)\n", + "[NeMo W 2022-01-21 13:37:28 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:48: LightningDeprecationWarning: Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7. Use `max_steps = -1` instead.\n", + " rank_zero_deprecation(\n", + " \n", + "[NeMo W 2022-01-21 13:37:28 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:147: LightningDeprecationWarning: Setting `Trainer(checkpoint_callback=False)` is deprecated in v1.5 and will be removed in v1.7. Please consider using `Trainer(enable_checkpointing=False)`.\n", + " rank_zero_deprecation(\n", + " \n", + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n" + ] + } + ], + "source": [ + "from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin\n", + "\n", + "\n", + "# lets modify some trainer configs\n", + "# checks if we have GPU available and uses it\n", + "cuda = 1 if torch.cuda.is_available() else 0\n", + "config.trainer.gpus = cuda\n", + "\n", + "# for PyTorch Native AMP set precision=16\n", + "config.trainer.precision = 16 if torch.cuda.is_available() else 32\n", + "\n", + "# remove distributed training flags\n", + "config.trainer.accelerator = None\n", + "\n", + "trainer = pl.Trainer(plugins=[NLPDDPPlugin()], **config.trainer)" + ] + }, + { + "cell_type": "markdown", + "id": "overall-literature", + "metadata": {}, + "source": [ + "## Setting up a NeMo Experiment\n", + "\n", + "NeMo has an experiment manager that handles logging and checkpointing for us, so let's use it:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "mathematical-portable", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-21 13:37:41 exp_manager:283] Experiments will be logged at /NeMo/tutorials/nlp/nemo_experiments/PTuneTextClassification/2022-01-21_13-37-41\n", + "[NeMo I 2022-01-21 13:37:41 exp_manager:648] TensorboardLogger has been set up\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2022-01-21 13:37:41 exp_manager:889] The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to -1. Please ensure that max_steps will run for at least 1 epochs to ensure that checkpointing will not error out.\n", + "[NeMo W 2022-01-21 13:37:41 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:243: LightningDeprecationWarning: `ModelCheckpoint(every_n_val_epochs)` is deprecated in v1.4 and will be removed in v1.6. Please use `every_n_epochs` instead.\n", + " rank_zero_deprecation(\n", + " \n" + ] + }, + { + "data": { + "text/plain": [ + "'/NeMo/tutorials/nlp/nemo_experiments/PTuneTextClassification/2022-01-21_13-37-41'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))\n", + "os.makedirs(WORK_DIR, exist_ok=True)\n", + "\n", + "# the exp_dir provides a path to the current experiment for easy access\n", + "exp_dir = str(exp_dir)\n", + "exp_dir" + ] + }, + { + "cell_type": "markdown", + "id": "f62ea6cd", + "metadata": {}, + "source": [ + "To load the pretrained BERT LM model, we can either load it from the converted `.nemo` file as shown above or load it from a list of included model names. \n", + "\n", + "We can get the list of names by following command \n", + "```python\n", + "# complete list of supported BERT-like models\n", + "print(nemo_nlp.modules.get_pretrained_lm_models_list())\n", + "```\n", + "We can change the `model.language_mode` config to use it\n", + "```python\n", + "# add the specified above model parameters to the config\n", + "config.model.language_model.pretrained_model_name = MODEL_NAME\n", + "```\n", + "\n", + "In this notebook, we will use the converted `.nemo` file as our LM model, which is BioMegatron, [Megatron-LM BERT](https://arxiv.org/abs/1909.08053) pre-trained on [PubMed](https://pubmed.ncbi.nlm.nih.gov/) biomedical text corpus." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "compact-horse", + "metadata": {}, + "outputs": [], + "source": [ + "# add the specified above model parameters to the config\n", + "# config.model.language_model.pretrained_model_name = PRETRAINED_BERT_MODEL\n", + "config.model.language_model.nemo_file = 'gpt_344m.nemo'\n", + "config.model.tensor_model_parallel_size = 1\n", + "config.model.dataset.classes = ['positive', 'neutral', 'negative']\n", + "config.model.tokenizer.vocab_file = vocab_file\n", + "config.model.tokenizer.merge_file = merge_file" + ] + }, + { + "cell_type": "markdown", + "id": "seeing-geometry", + "metadata": {}, + "source": [ + "Now, we are ready to initialize our model. During the model initialization call, the dataset and data loaders we'll be prepared for training and evaluation.\n", + "Also, the pretrained BERT model will be downloaded, note it can take up to a few minutes depending on the size of the chosen BERT model." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "indoor-france", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-21 13:38:19 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /NeMo/tutorials/nlp/gpt2-vocab.json\n", + "[NeMo I 2022-01-21 13:38:19 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /NeMo/tutorials/nlp/gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using sep_token, but it is not set yet.\n", + "Using cls_token, but it is not set yet.\n", + "Using pad_token, but it is not set yet.\n", + "Using mask_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-21 13:38:34 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /tmp/tmp1vxu9jzs/3f23abcf03b94354899f3c5b5beab943_gpt2-vocab.json\n", + "[NeMo I 2022-01-21 13:38:34 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /tmp/tmp1vxu9jzs/3f23abcf03b94354899f3c5b5beab943_gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using sep_token, but it is not set yet.\n", + "Using cls_token, but it is not set yet.\n", + "Using pad_token, but it is not set yet.\n", + "Using mask_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-21 13:38:38 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", + "[NeMo I 2022-01-21 13:38:39 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /tmp/tmp1vxu9jzs/3f23abcf03b94354899f3c5b5beab943_gpt2-vocab.json\n", + "[NeMo I 2022-01-21 13:38:39 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /tmp/tmp1vxu9jzs/3f23abcf03b94354899f3c5b5beab943_gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using sep_token, but it is not set yet.\n", + "Using cls_token, but it is not set yet.\n", + "Using pad_token, but it is not set yet.\n", + "Using mask_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-21 13:38:43 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", + "[NeMo I 2022-01-21 13:38:43 save_restore_connector:149] Model MegatronGPTModel was successfully restored from /NeMo/tutorials/nlp/gpt_344m.nemo.\n", + "[NeMo I 2022-01-21 13:38:44 auto_tokenizer:171] 1 special tokens added, resize your model accordingly.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using pad_token, but it is not set yet.\n", + "Using mask_token, but it is not set yet.\n" + ] + } + ], + "source": [ + "from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import PTuneTextClassificationModel\n", + "model_ptune = PTuneTextClassificationModel(cfg=config.model, trainer=trainer)" + ] + }, + { + "cell_type": "markdown", + "id": "genuine-pipeline", + "metadata": {}, + "source": [ + "## Monitoring training progress\n", + "Optionally, you can create a Tensorboard visualization to monitor training progress.\n", + "If you're not using Colab, refer to [https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks](https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks) if you're facing issues with running the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "changed-expense", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "To use tensorboard, please use this notebook in a Google Colab environment.\n" + ] + } + ], + "source": [ + "try:\n", + " from google import colab\n", + " COLAB_ENV = True\n", + "except (ImportError, ModuleNotFoundError):\n", + " COLAB_ENV = False\n", + "\n", + "# Load the TensorBoard notebook extension\n", + "if COLAB_ENV:\n", + " %load_ext tensorboard\n", + " %tensorboard --logdir {exp_dir}\n", + "else:\n", + " print(\"To use tensorboard, please use this notebook in a Google Colab environment.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "applied-quality", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2022-01-21 13:38:50 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:287: LightningDeprecationWarning: Base `Callback.on_train_batch_start` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.\n", + " rank_zero_deprecation(\n", + " \n", + "[NeMo W 2022-01-21 13:38:50 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:287: LightningDeprecationWarning: Base `Callback.on_train_batch_end` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.\n", + " rank_zero_deprecation(\n", + " \n", + "initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1\n", + "I0121 13:38:50.209178 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:1 to store for rank: 0\n", + "I0121 13:38:50.209801 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.\n", + "----------------------------------------------------------------------------------------------------\n", + "distributed_backend=nccl\n", + "All distributed processes registered. Starting with 1 processes\n", + "----------------------------------------------------------------------------------------------------\n", + "\n", + "I0121 13:38:50.211127 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:2 to store for rank: 0\n", + "I0121 13:38:50.211528 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.\n", + "I0121 13:38:50.212052 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:3 to store for rank: 0\n", + "I0121 13:38:50.212460 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:3 with 1 nodes.\n", + "I0121 13:38:50.212984 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:4 to store for rank: 0\n", + "I0121 13:38:50.213450 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:4 with 1 nodes.\n", + "I0121 13:38:50.213927 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:5 to store for rank: 0\n", + "I0121 13:38:50.214323 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:5 with 1 nodes.\n", + "I0121 13:38:50.214805 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:6 to store for rank: 0\n", + "I0121 13:38:50.215201 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:6 with 1 nodes.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> initializing tensor model parallel with size 1\n", + "> initializing pipeline model parallel with size 1\n", + "> initializing data parallel with size 1\n", + "[NeMo I 2022-01-21 13:38:50 nlp_overrides:137] mp_rank: 0\n", + "[NeMo I 2022-01-21 13:38:50 nlp_overrides:138] dp_rank: 0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4]\n", + "[NeMo W 2022-01-21 13:38:50 modelPT:475] The lightning trainer received accelerator: . We recommend to use 'ddp' instead.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-21 13:38:50 modelPT:566] Optimizer config = Adam (\n", + " Parameter Group 0\n", + " amsgrad: False\n", + " betas: [0.9, 0.999]\n", + " eps: 1e-08\n", + " lr: 1e-05\n", + " weight_decay: 0.0005\n", + " )\n", + "[NeMo I 2022-01-21 13:38:50 lr_scheduler:833] Scheduler \"\" \n", + " will be used during training (effective maximum steps = 147800) - \n", + " Parameters : \n", + " (warmup_steps: null\n", + " warmup_ratio: 0.1\n", + " last_epoch: -1\n", + " max_steps: 147800\n", + " )\n", + "[NeMo I 2022-01-21 13:38:51 nlp_overrides:92] Configuring DDP for model parallelism.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "-----------------------------------------------------------------\n", + "0 | model | MegatronGPTModel | 354 M \n", + "1 | embeddings | VocabParallelEmbedding | 51.5 M\n", + "2 | classification_report | ClassificationReport | 0 \n", + "3 | prompt_encoder | PromptEncoder | 14.7 M\n", + "-----------------------------------------------------------------\n", + "14.7 M Trainable params\n", + "354 M Non-trainable params\n", + "369 M Total params\n", + "739.152 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "46ac1cdd81ad40c39fb6f076790aee94", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2022-01-21 13:38:52 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:129: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n", + " warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n", + " \n", + "I0121 13:38:52.474958 140425413850944 distributed.py:902] Reducer buckets have been rebuilt in this iteration.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "357dcda709c249f7b5d4ce45e39691dc", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2022-01-21 13:41:51 ptune_text_classification_model:312] val_report: \n", + " label precision recall f1 support \n", + " positive (label_id: 0) 49.78 27.68 35.58 401\n", + " neutral (label_id: 1) 67.28 94.38 78.56 889\n", + " negative (label_id: 2) 75.00 3.19 6.12 188\n", + " -------------------\n", + " micro avg 64.68 64.68 64.68 1478\n", + " macro avg 64.02 41.75 40.09 1478\n", + " weighted avg 63.51 64.68 57.68 1478\n", + " \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 0, global step 1477: val_loss reached 0.93801 (best 0.93801), saving model to \"/NeMo/tutorials/nlp/nemo_experiments/PTuneTextClassification/2022-01-21_13-37-41/checkpoints/PTuneTextClassification--val_loss=0.9380-epoch=0.ckpt\" as top 3\n", + "[NeMo W 2022-01-21 13:42:29 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:685: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n", + " rank_zero_warn(\"Detected KeyboardInterrupt, attempting graceful shutdown...\")\n", + " \n" + ] + } + ], + "source": [ + "# start model training\n", + "trainer.fit(model_ptune)" + ] + }, + { + "cell_type": "markdown", + "id": "cooperative-michael", + "metadata": {}, + "source": [ + "# Inference\n", + "\n", + "To see how the model performs, we can run generate prediction similar to the way we did it earlier" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "classical-scientist", + "metadata": {}, + "outputs": [], + "source": [ + "# let's first create a subset of our dev data\n", + "! head -n 100 $NER_DATA_DIR/text_dev.txt > $NER_DATA_DIR/sample_text_dev.txt\n", + "! head -n 100 $NER_DATA_DIR/labels_dev.txt > $NER_DATA_DIR/sample_labels_dev.txt" + ] + }, + { + "cell_type": "markdown", + "id": "adult-ranking", + "metadata": {}, + "source": [ + "Now, let's generate predictions for the provided text file.\n", + "If labels file is also specified, the model will evaluate the predictions and plot confusion matrix. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "twenty-abortion", + "metadata": {}, + "outputs": [], + "source": [ + "model_ner.half().evaluate_from_file(\n", + " text_file=os.path.join(NER_DATA_DIR, 'sample_text_dev.txt'),\n", + " labels_file=os.path.join(NER_DATA_DIR, 'sample_labels_dev.txt'),\n", + " output_dir=exp_dir,\n", + " add_confusion_matrix=False,\n", + " normalize_confusion_matrix=True,\n", + " batch_size=1\n", + ")\n", + "# Please check matplotlib version if encountering any error plotting confusion matrix:\n", + "# https://stackoverflow.com/questions/63212347/importerror-cannot-import-name-png-from-matplotlib" + ] + }, + { + "cell_type": "markdown", + "id": "connected-typing", + "metadata": {}, + "source": [ + "## Training Script\n", + "\n", + "If you have NeMo installed locally, you can also train the model with `nlp/token_classification/token_classification_train.py.`\n", + "\n", + "To run training script, use:\n", + "\n", + "`python token_classification_train.py model.dataset.data_dir=PATH_TO_DATA_DIR exp_manager.exp_dir=EXP_DIR model.language_model.pretrained_model_name=megatron-bert-cased model.tokenizer.vocab_file=VOCAB_FILE model.tokenizer.tokenizer_model=BertWordPieceCase model.language_model.nemo_file=NEMO_FILE`\n" + ] + }, + { + "cell_type": "markdown", + "id": "legitimate-electric", + "metadata": {}, + "source": [ + "The training could take several minutes and the result should look something like\n", + "```\n", + "[NeMo I 2020-05-22 17:13:48 token_classification_callback:82] Accuracy: 0.9882348032875798\n", + "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 weighted: 98.82\n", + "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 macro: 93.74\n", + "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 micro: 98.82\n", + "[NeMo I 2020-05-22 17:13:49 token_classification_callback:89] precision recall f1-score support\n", + " \n", + " O (label id: 0) 0.9938 0.9957 0.9947 22092\n", + " B (label id: 1) 0.8843 0.9034 0.8938 787\n", + " I (label id: 2) 0.9505 0.8982 0.9236 1090\n", + " \n", + " accuracy 0.9882 23969\n", + " macro avg 0.9429 0.9324 0.9374 23969\n", + " weighted avg 0.9882 0.9882 0.9882 23969\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/nlp/PTune_sentiment_analysis.ipynb b/tutorials/nlp/PTune_sentiment_analysis.ipynb deleted file mode 100644 index fe89e581f26d..000000000000 --- a/tutorials/nlp/PTune_sentiment_analysis.ipynb +++ /dev/null @@ -1,1242 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 4, - "id": "b7a434f4", - "metadata": {}, - "outputs": [], - "source": [ - "BRANCH='main'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "developmental-gibraltar", - "metadata": {}, - "outputs": [], - "source": [ - "\"\"\"\n", - "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", - "\n", - "Instructions for setting up Colab are as follows:\n", - "1. Open a new Python 3 notebook.\n", - "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", - "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", - "4. Run this cell to set up dependencies.\n", - "\"\"\"\n", - "# If you're using Google Colab and not running locally, run this cell\n", - "\n", - "# install NeMo\n", - "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[nlp]" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "challenging-pioneer", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "################################################################################\n", - "### WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk\n", - "### (please add 'export KALDI_ROOT=' in your $HOME/.profile)\n", - "### (or run as: KALDI_ROOT= python .py)\n", - "################################################################################\n", - "\n", - "[NeMo W 2022-01-20 20:59:05 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 20:59:05 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 20:59:05 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 20:59:05 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n" - ] - } - ], - "source": [ - "from nemo.collections import nlp as nemo_nlp\n", - "from nemo.utils.exp_manager import exp_manager\n", - "\n", - "import os\n", - "import wget \n", - "import torch\n", - "import pytorch_lightning as pl\n", - "from omegaconf import OmegaConf" - ] - }, - { - "cell_type": "markdown", - "id": "employed-ethiopia", - "metadata": {}, - "source": [ - "In this tutorial, we are going to describe how to finetune BioMegatron - a [BERT](https://arxiv.org/abs/1810.04805)-like [Megatron-LM](https://arxiv.org/pdf/1909.08053.pdf) model pre-trained on large biomedical text corpus ([PubMed](https://pubmed.ncbi.nlm.nih.gov/) abstracts and full-text commercial use collection) - on the [NCBI Disease Dataset](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3951655/) for Named Entity Recognition.\n", - "\n", - "The model size of Megatron-LM can be larger than BERT, up to multi-billion parameters, compared to 345 million parameters of BERT-large.\n", - "There are some alternatives of BioMegatron, most notably [BioBERT](https://arxiv.org/abs/1901.08746). Compared to BioBERT BioMegatron is larger by model size and pre-trained on larger text corpus.\n", - "\n", - "A more general tutorial of using BERT-based models, including Megatron-LM, for downstream natural language processing tasks can be found [here](https://github.com/NVIDIA/NeMo/blob/stable/tutorials/nlp/01_Pretrained_Language_Models_for_Downstream_Tasks.ipynb).\n", - "\n", - "# Task Description\n", - "**Named entity recognition (NER)**, also referred to as entity chunking, identification or extraction, is the task of detecting and classifying key information (entities) in text.\n", - "\n", - "For instance, **given sentences from medical abstracts, what diseases are mentioned?**
\n", - "In this case, our data input is sentences from the abstracts, and our labels are the precise locations of the named disease entities. Take a look at the information provided for the dataset.\n", - "\n", - "For more details and general examples on Named Entity Recognition, please refer to the [Token Classification and Named Entity Recognition tutorial notebook](https://github.com/NVIDIA/NeMo/blob/stable/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb).\n", - "\n", - "# Dataset\n", - "\n", - "The [NCBI-disease corpus](https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/) is a set of 793 PubMed abstracts, annotated by 14 annotators. The annotations take the form of HTML-style tags inserted into the abstract text using the clearly defined rules. The annotations identify named diseases, and can be used to fine-tune a language model to identify disease mentions in future abstracts, *whether those diseases were part of the original training set or not*.\n", - "\n", - "Here's an example of what an annotated abstract from the corpus looks like:\n", - "\n", - "```html\n", - "10021369\tIdentification of APC2, a homologue of the adenomatous polyposis coli tumour suppressor .\tThe adenomatous polyposis coli ( APC ) tumour-suppressor protein controls the Wnt signalling pathway by forming a complex with glycogen synthase kinase 3beta ( GSK-3beta ) , axin / conductin and betacatenin . Complex formation induces the rapid degradation of betacatenin . In colon carcinoma cells , loss of APC leads to the accumulation of betacatenin in the nucleus , where it binds to and activates the Tcf-4 transcription factor ( reviewed in [ 1 ] [ 2 ] ) . Here , we report the identification and genomic structure of APC homologues . Mammalian APC2 , which closely resembles APC in overall domain structure , was functionally analyzed and shown to contain two SAMP domains , both of which are required for binding to conductin . Like APC , APC2 regulates the formation of active betacatenin-Tcf complexes , as demonstrated using transient transcriptional activation assays in APC - / - colon carcinoma cells . Human APC2 maps to chromosome 19p13 . 3 . APC and APC2 may therefore have comparable functions in development and cancer .\n", - "```\n", - "\n", - "In this example, we see the following tags within the abstract:\n", - "```html\n", - "adenomatous polyposis coli tumour\n", - "adenomatous polyposis coli ( APC ) tumour\n", - "colon carcinoma\n", - "colon carcinoma\n", - "cancer\n", - "```\n", - "\n", - "For our purposes, we will consider any identified category (such as \"Modifier\", \"Specific Disease\", and a few others) to generally be a \"disease\".\n", - "\n", - "Let's download the dataset." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "federal-beads", - "metadata": {}, - "outputs": [], - "source": [ - "DATA_DIR = \"DATA_DIR\"\n", - "os.makedirs(DATA_DIR, exist_ok=True)\n", - "os.makedirs(os.path.join(DATA_DIR, 'SA'), exist_ok=True)" - ] - }, - { - "cell_type": "markdown", - "id": "1c1e1b08", - "metadata": {}, - "source": [ - "## Downloading Financial Phrase Bank Dataset\n", - "\n", - "The datase is collected by Malo et al. 2014, and can be downloaded from this [link](https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip). The zip file for the Financial Phrase Bank Dataset has been provided for ease of download and use." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "8ad03fc0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--2022-01-20 01:48:29-- https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", - "Resolving www.researchgate.net (www.researchgate.net)... 104.17.32.105, 104.17.33.105, 2606:4700::6811:2069, ...\n", - "Connecting to www.researchgate.net (www.researchgate.net)|104.17.32.105|:443... connected.\n", - "HTTP request sent, awaiting response... 301 Moved Permanently\n", - "Location: https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip [following]\n", - "--2022-01-20 01:48:29-- https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", - "Reusing existing connection to www.researchgate.net:443.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 681890 (666K) [application/zip]\n", - "Saving to: ‘FinancialPhraseBank-v10.zip’\n", - "\n", - "FinancialPhraseBank 100%[===================>] 665.91K --.-KB/s in 0.02s \n", - "\n", - "2022-01-20 01:48:30 (28.1 MB/s) - ‘FinancialPhraseBank-v10.zip’ saved [681890/681890]\n", - "\n", - "Archive: DATA_DIR/FinancialPhraseBank-v10.zip\n" - ] - } - ], - "source": [ - "!wget https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", - "!mv FinancialPhraseBank-v10.zip {DATA_DIR}\n", - "!unzip -f {DATA_DIR}/FinancialPhraseBank-v10.zip -d {DATA_DIR}" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "radical-castle", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "According to Gran , the company has no plans to move all production to Russia , although that is where the company is growing .@neutral\n" - ] - } - ], - "source": [ - "# If you want to see more examples, you can explore the text of the corpus using the file browser to the left, or open files directly, for example typing a command like the following in a code-cell:\n", - "\n", - "! head -1 $DATA_DIR/FinancialPhraseBank-v1.0/Sentences_50Agree.txt" - ] - }, - { - "cell_type": "markdown", - "id": "specified-maine", - "metadata": {}, - "source": [ - "We have two datasets derived from this corpus: a text classification dataset and a named entity recognition (NER) dataset. The text classification dataset labels the abstracts among three broad disease groupings. We'll use this simple split to demonstrate the NLP text classification task. The NER dataset labels individual words as diseases. This dataset will be used for the NLP NER task. " - ] - }, - { - "cell_type": "markdown", - "id": "affected-numbers", - "metadata": {}, - "source": [ - "## Pre-process dataset\n", - "A pre-processed NCBI-disease dataset for NER can be found [here](https://github.com/spyysalo/ncbi-disease/tree/master/conll) or [here](https://github.com/dmis-lab/biobert#datasets).
\n", - "We download the files under {DATA_DIR/NER} directory." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "198287d4", - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "import random\n", - "\n", - "random.seed(1234)\n", - "files = ['Sentences_50Agree.txt', 'Sentences_66Agree.txt', 'Sentences_75Agree.txt', 'Sentences_AllAgree.txt']\n", - "base_dir = DATA_DIR + '/FinancialPhraseBank-v1.0/'\n", - "files = [base_dir + f for f in files]\n", - "\n", - "alllines = []\n", - "for fn in files:\n", - " with open(fn, 'r', encoding=\"ISO-8859-1\") as f:\n", - " alllines.extend(f.readlines())\n", - "\n", - "random.shuffle(alllines)\n", - "fold = 10\n", - "fold_size = len(alllines) // fold\n", - "\n", - "chunk_start = list(range(0, 14780, 1478))\n", - "\n", - "chunks = []\n", - "\n", - "for start_id in chunk_start:\n", - " chunks.append(alllines[start_id:start_id+fold_size])\n", - "\n", - "special = '<|endoftext|>'\n", - "\n", - "def gen_file(data, fold_id, split_type):\n", - " filename = \"{}/{}_{}.txt\".format(base_dir, split_type, fold_id)\n", - " with open(filename, 'w') as f:\n", - " obj = {}\n", - " for line in data:\n", - " splits = line.split('@')\n", - " part1 = splits[0].strip()\n", - " part2 = splits[1].strip()\n", - " obj['sentence'] = part1 +' Sentiment '\n", - " obj['sentiment'] = part2\n", - " f.write(json.dumps(obj)+'\\n')\n", - "\n", - "\n", - "def gen_fold(fold_number):\n", - " lists = list(range(fold))\n", - " test_id = (fold_number + fold) % fold\n", - " val_id = (fold_number + fold - 1) % fold\n", - " test_set = chunks[test_id]\n", - " val_set = chunks[val_id]\n", - " lists.remove(test_id)\n", - " lists.remove(val_id)\n", - " train_set = []\n", - " for idd in lists:\n", - " train_set += chunks[idd]\n", - " gen_file(train_set, fold_number, 'train')\n", - " gen_file(val_set, fold_number, 'validation')\n", - " gen_file(test_set, fold_number, 'test')\n", - "\n", - "for i in range(fold):\n", - " gen_fold(i)" - ] - }, - { - "cell_type": "markdown", - "id": "graphic-debate", - "metadata": {}, - "source": [ - "The NER task requires two files: the text sentences, and the labels. Run the next two cells to see a sample of the two files." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "sound-surgeon", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{\"sentence\": \"The contract includes heating plant equipment and associated installation work . Sentiment \", \"sentiment\": \"neutral\"}\n", - "{\"sentence\": \"The utility will also provide services related to electricity management , such as hedging trades and risk management and reporting . Sentiment \", \"sentiment\": \"neutral\"}\n" - ] - } - ], - "source": [ - "!head -n 2 $DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "spectacular-strain", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "id": "3813cc36", - "metadata": {}, - "source": [ - "## Convert the Megatron-LM Weights to Nemo file\n", - "\n", - "P-Tuning method works the best with large GPT lanague models. From our experiences, models of size 5B or above give good performance. If you already have a large GPT model ready, skip this section. \n", - "\n", - "In this example, we will use the pretrained 344M NeMo Megatron GPT model from [Megatron-LM project](https://github.com/NVIDIA/Megatron-LM). To load it in NeMo Megatron, We first need to convert the Megatron-LM checkpoint to the `.nemo` file. Let's download the pretrained model weights and vocabulary file.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "82b8e08e", - "metadata": {}, - "outputs": [], - "source": [ - "import pathlib\n", - "gpt_file = 'megatron_lm_345m_v0.0.zip'\n", - "vocab_file = 'gpt2-vocab.json'\n", - "merge_file = 'gpt2-merge.txt'\n", - "checkpoint_filename = 'model_optim_rng.pt'\n", - "\n", - "if not pathlib.Path(gpt_file).exists():\n", - " !wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O $gpt_file\n", - " !unzip -f $gpt_file\n", - " !wget https://s3.amazonaws.com/models.huggingface.co/bert/$vocab_file -O $vocab_file \n", - " !wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -O $merge_file\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "4b00ee86", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading config file...\n" - ] - } - ], - "source": [ - "WORK_DIR = \"WORK_DIR\"\n", - "os.makedirs(WORK_DIR, exist_ok=True)\n", - "\n", - "# Prepare the model parameters \n", - "# download the model's configuration file \n", - "config_dir = WORK_DIR + '/configs/'\n", - "MODEL_CONFIG = \"megatron_gpt_config.yaml\"\n", - "os.makedirs(config_dir, exist_ok=True)\n", - "if not os.path.exists(config_dir + MODEL_CONFIG):\n", - " print('Downloading config file...')\n", - " wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/conf/' + MODEL_CONFIG, config_dir)\n", - "else:\n", - " print ('config file is already exists')" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "0ae5a1a9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WORK_DIR/configs/megatron_gpt_config.yaml\n" - ] - } - ], - "source": [ - "# this line will print the entire config of the model\n", - "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n", - "print(config_path)\n", - "config = OmegaConf.load(config_path)\n", - "config.model.num_layers = 24\n", - "config.model.hidden_size = 1024\n", - "config.model.ffn_hidden_size = 4096\n", - "config.model.num_attention_heads = 16\n", - "config.model.tokenizer.vocab_file = vocab_file\n", - "config.model.tokenizer.merge_file = merge_file\n", - "config.model.tensor_model_parallel_size = 1\n", - "config.model.data.data_prefix = ''\n", - "config.model.max_position_embeddings = 1024\n", - "config.model.data.seq_length = 1024\n", - "config.model.encoder_seq_length = 1024\n", - "config.cfg = {}\n", - "config.cfg.cfg = config.model\n", - "with open('hparams.yaml', 'w') as f:\n", - " f.write(OmegaConf.to_yaml(config.cfg))" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "9e1beda4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "################################################################################\n", - "### WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk\n", - "### (please add 'export KALDI_ROOT=' in your $HOME/.profile)\n", - "### (or run as: KALDI_ROOT= python .py)\n", - "################################################################################\n", - "\n", - "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "I0120 21:01:09.749301 140536743184192 distributed_c10d.py:218] Added key: store_based_barrier_key:1 to store for rank: 0\n", - "I0120 21:01:09.749543 140536743184192 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.\n", - "GPU available: True, used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "converted 354.87M parameters\n", - "[NeMo I 2022-01-20 21:01:10 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /NeMo/tutorials/nlp/gpt2-vocab.json\n", - "[NeMo I 2022-01-20 21:01:10 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /NeMo/tutorials/nlp/gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n", - "Using sep_token, but it is not set yet.\n", - "Using cls_token, but it is not set yet.\n", - "Using pad_token, but it is not set yet.\n", - "Using mask_token, but it is not set yet.\n", - "[NeMo I 2022-01-20 21:01:13 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", - "[NeMo I 2022-01-20 21:10:10 megatron_lm_ckpt_to_nemo:265] NeMo model saved to: /NeMo/tutorials/nlp/gpt_344m.nemo\n", - "\u001b[0m" - ] - } - ], - "source": [ - "import os\n", - "PWD = os.getcwd()\n", - "wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py')\n", - "!python -m torch.distributed.run --nproc_per_node=1 megatron_lm_ckpt_to_nemo.py --checkpoint_folder=$PWD/release/mp_rank_00/ --checkpoint_name=$checkpoint_filename --hparams_file=$PWD/hparams.yaml --nemo_file_path=$PWD/gpt_344m.nemo --model_type=gpt --tensor_model_parallel_size=1" - ] - }, - { - "cell_type": "markdown", - "id": "84b455a6", - "metadata": {}, - "source": [ - "# Model configuration\n", - "\n", - "Our Named Entity Recognition model is comprised of the pretrained [BERT](https://arxiv.org/pdf/1810.04805.pdf) model followed by a Token Classification layer.\n", - "\n", - "The model is defined in a config file which declares multiple important sections. They are:\n", - "- **model**: All arguments that are related to the Model - language model, token classifier, optimizer and schedulers, datasets and any other related information\n", - "\n", - "- **trainer**: Any argument to be passed to PyTorch Lightning" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "speaking-grant", - "metadata": {}, - "outputs": [], - "source": [ - "MODEL_CONFIG = \"ptune_text_classification_config.yaml\"" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "demanding-ballet", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "config file is already exists\n" - ] - } - ], - "source": [ - "# download the model's configuration file \n", - "config_dir = WORK_DIR + '/configs/'\n", - "os.makedirs(config_dir, exist_ok=True)\n", - "if not os.path.exists(config_dir + MODEL_CONFIG):\n", - " print('Downloading config file...')\n", - " wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/token_classification/conf/' + MODEL_CONFIG, config_dir)\n", - "else:\n", - " print ('config file is already exists')" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "criminal-outdoors", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WORK_DIR/configs/ptune_text_classification_config.yaml\n" - ] - } - ], - "source": [ - "# this line will print the entire config of the model\n", - "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n", - "print(config_path)\n", - "config = OmegaConf.load(config_path)\n", - "# Note: these are small batch-sizes - increase as appropriate to available GPU capacity\n", - "config.model.train_ds.batch_size=8\n", - "config.model.validation_ds.batch_size=8" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "informed-purse", - "metadata": {}, - "outputs": [], - "source": [ - "# in this tutorial train and dev datasets are located in the same folder, so it is enought to add the path of the data directory to the config\n", - "#config.model.dataset.classes = ['positive', 'neutral', 'negative']\n", - "config.model.train_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/train_0.txt'\n", - "config.model.validation_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/validation_0.txt'\n", - "config.model.test_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/test_0.txt'\n", - "\n", - "\n", - "# if you want to decrease the size of your datasets, uncomment the lines below:\n", - "# NUM_SAMPLES = 1000\n", - "# config.model.train_ds.num_samples = NUM_SAMPLES\n", - "# config.model.validation_ds.num_samples = NUM_SAMPLES" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "divine-belly", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "trainer:\n", - " gpus: 1\n", - " num_nodes: 1\n", - " max_epochs: 100\n", - " max_steps: null\n", - " accumulate_grad_batches: 1\n", - " gradient_clip_val: 0.0\n", - " precision: 32\n", - " accelerator: ddp\n", - " log_every_n_steps: 1\n", - " val_check_interval: 1.0\n", - " resume_from_checkpoint: null\n", - " num_sanity_val_steps: 0\n", - " checkpoint_callback: false\n", - " logger: false\n", - "model:\n", - " tensor_model_parallel_size: 1\n", - " seed: 1234\n", - " nemo_path: ptune_text_classification_model.nemo\n", - " use_lm_finetune: false\n", - " pseudo_token: '[PROMPT]'\n", - " tokenizer:\n", - " library: megatron\n", - " type: GPT2BPETokenizer\n", - " model: null\n", - " vocab_file: null\n", - " merge_file: null\n", - " language_model:\n", - " nemo_file: null\n", - " prompt_encoder:\n", - " template:\n", - " - 3\n", - " - 3\n", - " - 0\n", - " dropout: 0.1\n", - " dataset:\n", - " classes: ???\n", - " train_ds:\n", - " file_path: DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt\n", - " batch_size: 8\n", - " shuffle: true\n", - " num_samples: -1\n", - " num_workers: 3\n", - " drop_last: false\n", - " pin_memory: false\n", - " validation_ds:\n", - " file_path: DATA_DIR/FinancialPhraseBank-v1.0/validation_0.txt\n", - " batch_size: 8\n", - " shuffle: false\n", - " num_samples: -1\n", - " num_workers: 3\n", - " drop_last: false\n", - " pin_memory: false\n", - " test_ds:\n", - " file_path: DATA_DIR/FinancialPhraseBank-v1.0/test_0.txt\n", - " batch_size: 64\n", - " shuffle: false\n", - " num_samples: -1\n", - " num_workers: 3\n", - " drop_last: false\n", - " pin_memory: false\n", - " optim:\n", - " name: adam\n", - " lr: 2.0e-05\n", - " betas:\n", - " - 0.9\n", - " - 0.999\n", - " weight_decay: 0.01\n", - " sched:\n", - " name: WarmupAnnealing\n", - " warmup_steps: null\n", - " warmup_ratio: 0.1\n", - " last_epoch: -1\n", - " monitor: val_loss\n", - " reduce_on_plateau: false\n", - " infer_samples:\n", - " - by the end of no such thing the audience , like beatrice , has a watchful affection\n", - " for the monster .\n", - " - director rob marshall went out gunning to make a great one .\n", - " - uneasy mishmash of styles and genres .\n", - "exp_manager:\n", - " exp_dir: null\n", - " name: PTuneTextClassification\n", - " create_tensorboard_logger: true\n", - " create_checkpoint_callback: true\n", - "\n" - ] - } - ], - "source": [ - "print(OmegaConf.to_yaml(config))" - ] - }, - { - "cell_type": "markdown", - "id": "dedicated-effort", - "metadata": {}, - "source": [ - "# Model Training\n", - "## Setting up Data within the config\n", - "\n", - "Among other things, the config file contains dictionaries called dataset, train_ds and validation_ds. These are configurations used to setup the Dataset and DataLoaders of the corresponding config.\n" - ] - }, - { - "cell_type": "markdown", - "id": "15e2c67a", - "metadata": {}, - "source": [ - "\n", - "We assume that both training and evaluation files are located in the same directory, and use the default names mentioned during the data download step. \n", - "So, to start model training, we simply need to specify `model.dataset.data_dir`, like we are going to do below.\n" - ] - }, - { - "cell_type": "markdown", - "id": "89dd468d", - "metadata": {}, - "source": [ - "\n", - "Also notice that some config lines, including `model.dataset.data_dir`, have `???` in place of paths, this means that values for these fields are required to be specified by the user.\n", - "\n", - "Let's now add the data directory path to the config." - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "a312ed76", - "metadata": {}, - "outputs": [], - "source": [ - "# in this tutorial train and dev datasets are located in the same folder, so it is enought to add the path of the data directory to the config\n", - "config.model.dataset.data_dir = os.path.join(DATA_DIR, 'SA')\n", - "\n", - "# if you want to decrease the size of your datasets, uncomment the lines below:\n", - "# NUM_SAMPLES = 1000\n", - "# config.model.train_ds.num_samples = NUM_SAMPLES\n", - "# config.model.validation_ds.num_samples = NUM_SAMPLES" - ] - }, - { - "cell_type": "markdown", - "id": "changed-mauritius", - "metadata": {}, - "source": [ - "## Building the PyTorch Lightning Trainer\n", - "\n", - "NeMo models are primarily PyTorch Lightning modules - and therefore are entirely compatible with the PyTorch Lightning ecosystem.\n", - "\n", - "Let's first instantiate a Trainer object" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "computational-battlefield", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Trainer config - \n", - "\n", - "gpus: 1\n", - "num_nodes: 1\n", - "max_epochs: 100\n", - "max_steps: null\n", - "accumulate_grad_batches: 1\n", - "gradient_clip_val: 0.0\n", - "precision: 32\n", - "accelerator: ddp\n", - "log_every_n_steps: 1\n", - "val_check_interval: 1.0\n", - "resume_from_checkpoint: null\n", - "num_sanity_val_steps: 0\n", - "checkpoint_callback: false\n", - "logger: false\n", - "\n" - ] - } - ], - "source": [ - "print(\"Trainer config - \\n\")\n", - "print(OmegaConf.to_yaml(config.trainer))" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "unique-genre", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using 16bit native Automatic Mixed Precision (AMP)\n", - "[NeMo W 2022-01-20 02:54:59 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:48: LightningDeprecationWarning: Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7. Use `max_steps = -1` instead.\n", - " rank_zero_deprecation(\n", - " \n", - "[NeMo W 2022-01-20 02:54:59 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:147: LightningDeprecationWarning: Setting `Trainer(checkpoint_callback=False)` is deprecated in v1.5 and will be removed in v1.7. Please consider using `Trainer(enable_checkpointing=False)`.\n", - " rank_zero_deprecation(\n", - " \n", - "GPU available: True, used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n" - ] - } - ], - "source": [ - "# lets modify some trainer configs\n", - "# checks if we have GPU available and uses it\n", - "cuda = 1 if torch.cuda.is_available() else 0\n", - "config.trainer.gpus = cuda\n", - "\n", - "# for PyTorch Native AMP set precision=16\n", - "config.trainer.precision = 16 if torch.cuda.is_available() else 32\n", - "\n", - "# remove distributed training flags\n", - "config.trainer.accelerator = None\n", - "\n", - "trainer = pl.Trainer(**config.trainer)" - ] - }, - { - "cell_type": "markdown", - "id": "overall-literature", - "metadata": {}, - "source": [ - "## Setting up a NeMo Experiment\n", - "\n", - "NeMo has an experiment manager that handles logging and checkpointing for us, so let's use it:" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "mathematical-portable", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-20 02:55:04 exp_manager:283] Experiments will be logged at /NeMo/tutorials/nlp/nemo_experiments/PTuneTextClassification/2022-01-20_02-55-04\n", - "[NeMo I 2022-01-20 02:55:04 exp_manager:648] TensorboardLogger has been set up\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[NeMo W 2022-01-20 02:55:04 exp_manager:889] The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to -1. Please ensure that max_steps will run for at least 1 epochs to ensure that checkpointing will not error out.\n", - "[NeMo W 2022-01-20 02:55:04 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:243: LightningDeprecationWarning: `ModelCheckpoint(every_n_val_epochs)` is deprecated in v1.4 and will be removed in v1.6. Please use `every_n_epochs` instead.\n", - " rank_zero_deprecation(\n", - " \n" - ] - }, - { - "data": { - "text/plain": [ - "'/NeMo/tutorials/nlp/nemo_experiments/PTuneTextClassification/2022-01-20_02-55-04'" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))\n", - "os.makedirs(WORK_DIR, exist_ok=True)\n", - "\n", - "# the exp_dir provides a path to the current experiment for easy access\n", - "exp_dir = str(exp_dir)\n", - "exp_dir" - ] - }, - { - "cell_type": "markdown", - "id": "f62ea6cd", - "metadata": {}, - "source": [ - "To load the pretrained BERT LM model, we can either load it from the converted `.nemo` file as shown above or load it from a list of included model names. \n", - "\n", - "We can get the list of names by following command \n", - "```python\n", - "# complete list of supported BERT-like models\n", - "print(nemo_nlp.modules.get_pretrained_lm_models_list())\n", - "```\n", - "We can change the `model.language_mode` config to use it\n", - "```python\n", - "# add the specified above model parameters to the config\n", - "config.model.language_model.pretrained_model_name = MODEL_NAME\n", - "```\n", - "\n", - "In this notebook, we will use the converted `.nemo` file as our LM model, which is BioMegatron, [Megatron-LM BERT](https://arxiv.org/abs/1909.08053) pre-trained on [PubMed](https://pubmed.ncbi.nlm.nih.gov/) biomedical text corpus." - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "compact-horse", - "metadata": {}, - "outputs": [], - "source": [ - "# add the specified above model parameters to the config\n", - "# config.model.language_model.pretrained_model_name = PRETRAINED_BERT_MODEL\n", - "config.model.language_model.nemo_file = 'gpt_344m.nemo'\n", - "config.model.tensor_model_parallel_size = 1\n", - "config.model.dataset.classes = ['positive', 'neutral', 'negative']\n", - "config.model.tokenizer.vocab_file = vocab_file\n", - "config.model.tokenizer.merge_file = merge_file" - ] - }, - { - "cell_type": "markdown", - "id": "seeing-geometry", - "metadata": {}, - "source": [ - "Now, we are ready to initialize our model. During the model initialization call, the dataset and data loaders we'll be prepared for training and evaluation.\n", - "Also, the pretrained BERT model will be downloaded, note it can take up to a few minutes depending on the size of the chosen BERT model." - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "indoor-france", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-20 03:05:26 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /NeMo/tutorials/nlp/gpt2-vocab.json\n", - "[NeMo I 2022-01-20 03:05:26 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /NeMo/tutorials/nlp/gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using sep_token, but it is not set yet.\n", - "Using cls_token, but it is not set yet.\n", - "Using pad_token, but it is not set yet.\n", - "Using mask_token, but it is not set yet.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-20 03:05:41 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /tmp/tmpb7y7lez5/bc1a5de6bb3a4c3fa09426d2951b450c_gpt2-vocab.json\n", - "[NeMo I 2022-01-20 03:05:41 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /tmp/tmpb7y7lez5/bc1a5de6bb3a4c3fa09426d2951b450c_gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using sep_token, but it is not set yet.\n", - "Using cls_token, but it is not set yet.\n", - "Using pad_token, but it is not set yet.\n", - "Using mask_token, but it is not set yet.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-20 03:05:47 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", - "[NeMo I 2022-01-20 03:05:48 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /tmp/tmpb7y7lez5/bc1a5de6bb3a4c3fa09426d2951b450c_gpt2-vocab.json\n", - "[NeMo I 2022-01-20 03:05:48 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /tmp/tmpb7y7lez5/bc1a5de6bb3a4c3fa09426d2951b450c_gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using sep_token, but it is not set yet.\n", - "Using cls_token, but it is not set yet.\n", - "Using pad_token, but it is not set yet.\n", - "Using mask_token, but it is not set yet.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-20 03:05:52 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", - "[NeMo I 2022-01-20 03:05:53 save_restore_connector:149] Model MegatronGPTModel was successfully restored from /NeMo/tutorials/nlp/gpt_344m.nemo.\n", - "[NeMo I 2022-01-20 03:05:53 auto_tokenizer:171] 1 special tokens added, resize your model accordingly.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using pad_token, but it is not set yet.\n", - "Using mask_token, but it is not set yet.\n" - ] - } - ], - "source": [ - "from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import PTuneTextClassificationModel\n", - "model_ptune = PTuneTextClassificationModel(cfg=config.model, trainer=trainer)" - ] - }, - { - "cell_type": "markdown", - "id": "genuine-pipeline", - "metadata": {}, - "source": [ - "## Monitoring training progress\n", - "Optionally, you can create a Tensorboard visualization to monitor training progress.\n", - "If you're not using Colab, refer to [https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks](https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks) if you're facing issues with running the cell below." - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "changed-expense", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "To use tensorboard, please use this notebook in a Google Colab environment.\n" - ] - } - ], - "source": [ - "try:\n", - " from google import colab\n", - " COLAB_ENV = True\n", - "except (ImportError, ModuleNotFoundError):\n", - " COLAB_ENV = False\n", - "\n", - "# Load the TensorBoard notebook extension\n", - "if COLAB_ENV:\n", - " %load_ext tensorboard\n", - " %tensorboard --logdir {exp_dir}\n", - "else:\n", - " print(\"To use tensorboard, please use this notebook in a Google Colab environment.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "applied-quality", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[NeMo W 2022-01-20 03:06:08 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:287: LightningDeprecationWarning: Base `Callback.on_train_batch_start` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.\n", - " rank_zero_deprecation(\n", - " \n", - "[NeMo W 2022-01-20 03:06:08 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:287: LightningDeprecationWarning: Base `Callback.on_train_batch_end` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.\n", - " rank_zero_deprecation(\n", - " \n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4]\n", - "[NeMo W 2022-01-20 03:06:09 modelPT:475] The lightning trainer received accelerator: . We recommend to use 'ddp' instead.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-20 03:06:09 modelPT:566] Optimizer config = Adam (\n", - " Parameter Group 0\n", - " amsgrad: False\n", - " betas: [0.9, 0.999]\n", - " eps: 1e-08\n", - " lr: 2e-05\n", - " weight_decay: 0.01\n", - " )\n", - "[NeMo I 2022-01-20 03:06:09 lr_scheduler:833] Scheduler \"\" \n", - " will be used during training (effective maximum steps = 147800) - \n", - " Parameters : \n", - " (warmup_steps: null\n", - " warmup_ratio: 0.1\n", - " last_epoch: -1\n", - " max_steps: 147800\n", - " )\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - " | Name | Type | Params\n", - "-----------------------------------------------------------------\n", - "0 | model | MegatronGPTModel | 354 M \n", - "1 | embeddings | VocabParallelEmbedding | 51.5 M\n", - "2 | classification_report | ClassificationReport | 0 \n", - "3 | prompt_encoder | PromptEncoder | 14.7 M\n", - "-----------------------------------------------------------------\n", - "14.7 M Trainable params\n", - "354 M Non-trainable params\n", - "369 M Total params\n", - "739.152 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c6c41b857467495987f968a4300465b0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "ename": "AssertionError", - "evalue": "intra_layer_model parallel group is not initialized", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_1537560/3447343327.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# start model training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_ptune\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)\u001b[0m\n\u001b[1;32m 735\u001b[0m )\n\u001b[1;32m 736\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 737\u001b[0;31m self._call_and_handle_interrupt(\n\u001b[0m\u001b[1;32m 738\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_impl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 739\u001b[0m )\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(self, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 680\u001b[0m \"\"\"\n\u001b[1;32m 681\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 682\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrainer_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 683\u001b[0m \u001b[0;31m# TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 684\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexception\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 770\u001b[0m \u001b[0;31m# TODO: ckpt_path only in v1.7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 771\u001b[0m \u001b[0mckpt_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mckpt_path\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 772\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 773\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 774\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstopped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1194\u001b[0m \u001b[0;31m# dispatch `start_training` or `start_evaluating` or `start_predicting`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1195\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1197\u001b[0m \u001b[0;31m# plugin will finalized fitting (e.g. ddp_spawn will load trained model)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_dispatch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1273\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_predicting\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1274\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1275\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1276\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1277\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrun_stage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36mstart_training\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"pl.Trainer\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;31m# double dispatch to initiate the training loop\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 202\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_stage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 203\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_evaluating\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"pl.Trainer\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1283\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredicting\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1284\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_predict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1285\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1286\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1287\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_pre_training_routine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1313\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1314\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_detect_anomaly\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_detect_anomaly\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1315\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1316\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1317\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_evaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0m_EVALUATE_OUTPUT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_epoch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepoch_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_fetcher\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0;31m# the global step is manually decreased here due to backwards compatibility with existing loggers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_batch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m \u001b[0mbatch_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 194\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_progress\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mincrement_processed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautomatic_optimization\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0moptimizers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_get_active_optimizers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_frequencies\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplit_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmanual_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplit_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, batch, *args, **kwargs)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore[override]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m result = self._run_optimization(\n\u001b[0m\u001b[1;32m 216\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_idx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_run_optimization\u001b[0;34m(self, split_batch, batch_idx, optimizer, opt_idx)\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[0;31m# gradient update with accumulated gradients\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 266\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 267\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsume_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_optimizer_step\u001b[0;34m(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 376\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 377\u001b[0m \u001b[0;31m# model hook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 378\u001b[0;31m lightning_module.optimizer_step(\n\u001b[0m\u001b[1;32m 379\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurrent_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)\u001b[0m\n\u001b[1;32m 1650\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1651\u001b[0m \"\"\"\n\u001b[0;32m-> 1652\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moptimizer_closure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1653\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1654\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_zero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mtrainer\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprofiler_action\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, opt_idx, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 334\u001b[0m \"\"\"\n\u001b[1;32m 335\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 336\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 337\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_zero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcurrent_epoch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, model, optimizer, optimizer_idx, closure, **kwargs)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;34mf\"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx}).\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m )\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0mclosure_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0;31m# `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscaler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munscale_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 160\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 161\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_result\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36mclosure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mClosureResult\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_profiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training_step_and_backward\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 142\u001b[0;31m \u001b[0mstep_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_step_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstep_output\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclosure_loss\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_training_step\u001b[0;34m(self, split_batch, batch_idx, opt_idx)\u001b[0m\n\u001b[1;32m 433\u001b[0m \u001b[0mlightning_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_current_fx_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"training_step\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 434\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training_step\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 435\u001b[0;31m \u001b[0mtraining_step_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 436\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 437\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, step_kwargs)\u001b[0m\n\u001b[1;32m 214\u001b[0m \"\"\"\n\u001b[1;32m 215\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 216\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mstep_kwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 217\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 218\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 213\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 214\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/NeMo/nemo/utils/model_utils.py\u001b[0m in \u001b[0;36mwrap_training_step\u001b[0;34m(wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mwrapt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecorator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrap_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwrapped\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstance\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'pl.LightningModule'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 355\u001b[0;31m \u001b[0moutput_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 356\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0moutput_dict\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m'log'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0moutput_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/NeMo/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 270\u001b[0m \u001b[0;31m# forward pass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 271\u001b[0m \u001b[0msentences\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 272\u001b[0;31m \u001b[0mtrain_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msentences\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 273\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 274\u001b[0m \u001b[0mlr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_groups\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'lr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/NeMo/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, sentences, labels)\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mlabels_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_ids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_label_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_position\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseq_len\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 255\u001b[0;31m output = self.model.model(None, None, encoder_input=encoder_input,\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_atten\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m labels=labels_input)\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1108\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1111\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1112\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/NeMo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, position_ids, attention_mask, labels, prompt_tags, tokentype_ids, layer_past, get_key_value, forward_method_parallel_output, encoder_input)\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_process\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 175\u001b[0;31m return post_language_model_processing(\n\u001b[0m\u001b[1;32m 176\u001b[0m \u001b[0mlm_output\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/NeMo/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py\u001b[0m in \u001b[0;36mpost_language_model_processing\u001b[0;34m(lm_output, labels, logit_weights, get_key_value, parallel_output, forward_method_parallel_output, fp16_lm_cross_entropy, return_logits)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor_parallel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvocab_parallel_cross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor_parallel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvocab_parallel_cross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 55\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_logits\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/apex/transformer/tensor_parallel/cross_entropy.py\u001b[0m in \u001b[0;36mvocab_parallel_cross_entropy\u001b[0;34m(vocab_parallel_logits, target)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvocab_parallel_cross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_parallel_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;34m\"\"\"Helper function for the cross entropy.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_VocabParallelCrossEntropy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_parallel_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/apex/transformer/tensor_parallel/cross_entropy.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(ctx, vocab_parallel_logits, target)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mlogits_max\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_parallel_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m torch.distributed.all_reduce(\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mlogits_max\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdistributed\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mReduceOp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMAX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mget_tensor_model_parallel_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m )\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# Subtract the maximum value.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/apex/transformer/parallel_state.py\u001b[0m in \u001b[0;36mget_tensor_model_parallel_group\u001b[0;34m()\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_tensor_model_parallel_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[0;34m\"\"\"Get the tensor model parallel group the caller rank belongs to.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 174\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0m_TENSOR_MODEL_PARALLEL_GROUP\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"intra_layer_model parallel group is not initialized\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 175\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_TENSOR_MODEL_PARALLEL_GROUP\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAssertionError\u001b[0m: intra_layer_model parallel group is not initialized" - ] - } - ], - "source": [ - "# start model training\n", - "trainer.fit(model_ptune)" - ] - }, - { - "cell_type": "markdown", - "id": "cooperative-michael", - "metadata": {}, - "source": [ - "# Inference\n", - "\n", - "To see how the model performs, we can run generate prediction similar to the way we did it earlier" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "classical-scientist", - "metadata": {}, - "outputs": [], - "source": [ - "# let's first create a subset of our dev data\n", - "! head -n 100 $NER_DATA_DIR/text_dev.txt > $NER_DATA_DIR/sample_text_dev.txt\n", - "! head -n 100 $NER_DATA_DIR/labels_dev.txt > $NER_DATA_DIR/sample_labels_dev.txt" - ] - }, - { - "cell_type": "markdown", - "id": "adult-ranking", - "metadata": {}, - "source": [ - "Now, let's generate predictions for the provided text file.\n", - "If labels file is also specified, the model will evaluate the predictions and plot confusion matrix. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "twenty-abortion", - "metadata": {}, - "outputs": [], - "source": [ - "model_ner.half().evaluate_from_file(\n", - " text_file=os.path.join(NER_DATA_DIR, 'sample_text_dev.txt'),\n", - " labels_file=os.path.join(NER_DATA_DIR, 'sample_labels_dev.txt'),\n", - " output_dir=exp_dir,\n", - " add_confusion_matrix=False,\n", - " normalize_confusion_matrix=True,\n", - " batch_size=1\n", - ")\n", - "# Please check matplotlib version if encountering any error plotting confusion matrix:\n", - "# https://stackoverflow.com/questions/63212347/importerror-cannot-import-name-png-from-matplotlib" - ] - }, - { - "cell_type": "markdown", - "id": "connected-typing", - "metadata": {}, - "source": [ - "## Training Script\n", - "\n", - "If you have NeMo installed locally, you can also train the model with `nlp/token_classification/token_classification_train.py.`\n", - "\n", - "To run training script, use:\n", - "\n", - "`python token_classification_train.py model.dataset.data_dir=PATH_TO_DATA_DIR exp_manager.exp_dir=EXP_DIR model.language_model.pretrained_model_name=megatron-bert-cased model.tokenizer.vocab_file=VOCAB_FILE model.tokenizer.tokenizer_model=BertWordPieceCase model.language_model.nemo_file=NEMO_FILE`\n" - ] - }, - { - "cell_type": "markdown", - "id": "legitimate-electric", - "metadata": {}, - "source": [ - "The training could take several minutes and the result should look something like\n", - "```\n", - "[NeMo I 2020-05-22 17:13:48 token_classification_callback:82] Accuracy: 0.9882348032875798\n", - "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 weighted: 98.82\n", - "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 macro: 93.74\n", - "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 micro: 98.82\n", - "[NeMo I 2020-05-22 17:13:49 token_classification_callback:89] precision recall f1-score support\n", - " \n", - " O (label id: 0) 0.9938 0.9957 0.9947 22092\n", - " B (label id: 1) 0.8843 0.9034 0.8938 787\n", - " I (label id: 2) 0.9505 0.8982 0.9236 1090\n", - " \n", - " accuracy 0.9882 23969\n", - " macro avg 0.9429 0.9324 0.9374 23969\n", - " weighted avg 0.9882 0.9882 0.9882 23969\n", - "```" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 716880818c864133d9f9dc9cff1769388ef250a1 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Fri, 21 Jan 2022 06:32:13 -0800 Subject: [PATCH 09/22] added neural types Signed-off-by: Yi Dong --- .../ptune_text_classification_config.yaml | 1 + .../ptune_text_classification.py | 4 +- .../ptune_text_classification_dataset.py | 16 ++- .../language_modeling/megatron/gpt_model.py | 6 +- .../models/text_classification/__init__.py | 5 +- .../ptune_text_classification_model.py | 112 ++++++++++-------- .../modules/common/megatron/transformer.py | 1 - .../nlp/modules/common/prompt_encoder.py | 36 +++--- 8 files changed, 103 insertions(+), 78 deletions(-) diff --git a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml index 9f7a871c2d9a..0817c9638e7d 100644 --- a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml +++ b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml @@ -51,6 +51,7 @@ model: prompt_encoder: template: [3, 3, 0] dropout: 0.0 + num_layers: 2 dataset: classes: ??? # The class labels, e.g. ['positive', 'neutral', 'negative'] diff --git a/examples/nlp/text_classification/ptune_text_classification.py b/examples/nlp/text_classification/ptune_text_classification.py index 6ca660f8a47c..afd54fbb3326 100644 --- a/examples/nlp/text_classification/ptune_text_classification.py +++ b/examples/nlp/text_classification/ptune_text_classification.py @@ -98,7 +98,9 @@ import pytorch_lightning as pl from omegaconf import DictConfig, OmegaConf -from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import PTuneTextClassificationModel +from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import ( + PTuneTextClassificationModel, +) from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin from nemo.core.config import hydra_runner from nemo.utils import logging diff --git a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py index 85ccc66e705e..c0595d10d57d 100644 --- a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py +++ b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py @@ -14,14 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os -from typing import List +from typing import Dict, List, Optional + from nemo.core.classes import Dataset +from nemo.core.classes.common import typecheck +from nemo.core.neural_types import NeuralType, StringLabel, StringType __all__ = ['BankPTextClassificationDataset', 'token_wrapper'] -import json - def load_file(filename): data = [] @@ -36,7 +38,11 @@ def token_wrapper(token: str) -> str: class BankPTextClassificationDataset(Dataset): - def __init__(self, input_file: str, sentiments: List[str], data: List[str]=None): + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return {"sentences": [NeuralType(('T'), StringType())], "labels": [NeuralType(('T'), StringLabel())]} + + def __init__(self, input_file: str, sentiments: List[str], data: List[str] = None): super().__init__() if input_file and not os.path.exists(input_file): raise FileNotFoundError( @@ -48,7 +54,7 @@ def __init__(self, input_file: str, sentiments: List[str], data: List[str]=None) else: json_data = [] for line in data: - json_data.append({'sentence': line+' Sentiment ', 'sentiment': ''}) + json_data.append({'sentence': line + ' Sentiment ', 'sentiment': ''}) self.x_hs, self.x_ts = [], [] self.data = json_data diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py index 043ef32e6b17..43959ff2f8c9 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -31,7 +31,7 @@ def post_language_model_processing( parallel_output, forward_method_parallel_output, fp16_lm_cross_entropy, - return_logits=False + return_logits=False, ): if get_key_value: lm_output, presents = lm_output @@ -168,7 +168,7 @@ def forward( prompt_tags=prompt_tags, layer_past=layer_past, get_key_value=get_key_value, - encoder_input=encoder_input + encoder_input=encoder_input, ) if self.post_process: @@ -180,7 +180,7 @@ def forward( self.parallel_output, forward_method_parallel_output, self.fp16_lm_cross_entropy, - return_logits=encoder_input is not None + return_logits=encoder_input is not None, ) else: return lm_output diff --git a/nemo/collections/nlp/models/text_classification/__init__.py b/nemo/collections/nlp/models/text_classification/__init__.py index 6d5dc10fc600..76ac32823cf4 100644 --- a/nemo/collections/nlp/models/text_classification/__init__.py +++ b/nemo/collections/nlp/models/text_classification/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import ( + PTuneTextClassificationModel, +) from nemo.collections.nlp.models.text_classification.text_classification_model import TextClassificationModel -from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import PTuneTextClassificationModel - diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index ed50ab9a1351..7965ec723dc1 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -19,25 +19,25 @@ import torch from omegaconf import DictConfig from pytorch_lightning import Trainer +from torch.nn.utils.rnn import pad_sequence from nemo.collections.common.losses import CrossEntropyLoss -from nemo.collections.nlp.data.text_classification.ptune_text_classification_dataset import BankPTextClassificationDataset, token_wrapper +from nemo.collections.nlp.data.text_classification.ptune_text_classification_dataset import ( + BankPTextClassificationDataset, + token_wrapper, +) from nemo.collections.nlp.metrics.classification_report import ClassificationReport +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.models.nlp_model import NLPModel from nemo.collections.nlp.modules.common import SequenceClassifier from nemo.collections.nlp.modules.common.lm_utils import get_lm_model +from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo +from nemo.collections.nlp.modules.common.prompt_encoder import PromptEncoder +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.collections.nlp.parts.utils_funcs import tensor2list from nemo.core.classes.common import typecheck from nemo.core.classes.exportable import Exportable -from nemo.core.neural_types import NeuralType -from nemo.utils import logging -from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel -from nemo.collections.nlp.modules.common.prompt_encoder import PromptEncoder -from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer -from nemo.collections.nlp.modules.common.megatron.megatron_init import ( - initialize_model_parallel_for_nemo, -) -from torch.nn.utils.rnn import pad_sequence +from nemo.core.neural_types import LossType, NeuralType, PredictionsType, StringLabel, StringType from nemo.utils import logging __all__ = ['PTuneTextClassificationModel'] @@ -45,15 +45,19 @@ SMALL_LOGITS = -100 -class PTuneTextClassificationModel(NLPModel, Exportable): - # @property - # def input_types(self) -> Optional[Dict[str, NeuralType]]: - # return self.bert_model.input_types +class PTuneTextClassificationModel(NLPModel, Exportable): + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + return {"sentences": [NeuralType(('T'), StringType())], "labels": [NeuralType(('T'), StringLabel())]} - # @property - # def output_types(self) -> Optional[Dict[str, NeuralType]]: - # return self.classifier.output_types + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "floss": NeuralType((), LossType()), + "returned_pred": NeuralType(('B'), PredictionsType()), + "returned_label": NeuralType(('B'), PredictionsType()), + } def __init__(self, cfg: DictConfig, trainer: Trainer = None): """Initializes the BERTTextClassifier model.""" @@ -81,16 +85,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.class_weights = None - self.model = MegatronGPTModel.restore_from(self.register_artifact('language_model.nemo_file', cfg.language_model.get('nemo_file', None)), - trainer=trainer).half() + self.model = MegatronGPTModel.restore_from( + self.register_artifact('language_model.nemo_file', cfg.language_model.get('nemo_file', None)), + trainer=trainer, + ).half() for param in self.model.parameters(): param.requires_grad = cfg.use_lm_finetune hidden_size = self.model.cfg.hidden_size - - # register the file containing the labels into the artifacts to get stored in the '.nemo' file later self.classes = cfg.dataset.classes @@ -115,13 +119,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): num_classes=len(self.classes), label_ids=label_ids, mode='micro', dist_sync_on_step=True ) - self.template = cfg.prompt_encoder.template self.prompt_encoder = PromptEncoder( template=cfg.prompt_encoder.template, hidden_size=hidden_size, - lstm_dropout=cfg.prompt_encoder.dropout + lstm_dropout=cfg.prompt_encoder.dropout, + num_layers=cfg.prompt_encoder.num_layers, ) # load prompt encoder @@ -134,7 +138,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # self.pad_token_id = self.tokenizer.eod # else: self.pseudo_token_id = self.tokenizer.tokenizer.get_vocab()[cfg.pseudo_token] - self.pad_token_id = self.tokenizer.tokenizer.pad_token_id if self.tokenizer.tokenizer.pad_token_id is not None else self.tokenizer.tokenizer.unk_token_id + self.pad_token_id = ( + self.tokenizer.tokenizer.pad_token_id + if self.tokenizer.tokenizer.pad_token_id is not None + else self.tokenizer.tokenizer.unk_token_id + ) self.spell_length = sum(self.template) def embed_input(self, queries): @@ -144,7 +152,9 @@ def embed_input(self, queries): queries_for_embedding[(queries == self.pseudo_token_id)] = self.pad_token_id raw_embeds = self.embeddings(queries_for_embedding) - blocked_indices = (queries == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1] # bz + blocked_indices = ( + (queries == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1] + ) # bz replace_embeds = self.prompt_encoder() for bidx in range(bz): for i in range(self.prompt_encoder.spell_length): @@ -158,12 +168,16 @@ def get_query(self, x_h, prompt_tokens, x_t=None): if len(input_token_ids) + sum(self.template) > max_seq_len: logging.warning("Input sequence is longer than the LM model max seq, will cut it off to fit") cut = len(input_token_ids) + sum(self.template) - max_seq_len - return [prompt_tokens * self.template[0] - + input_token_ids[cut:] # head entity - + prompt_tokens * self.template[1] - + (self.tokenizer.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(' ' + x_t)) if x_t is not None else []) - ] + return [ + prompt_tokens * self.template[0] + + input_token_ids[cut:] # head entity + + prompt_tokens * self.template[1] + + ( + self.tokenizer.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(' ' + x_t)) + if x_t is not None + else [] + ) + ] def get_ground_truth_labels(self, batch_size, label_ids): returned_label = [] @@ -203,10 +217,7 @@ def get_encoder_input(self, sentences): bz, seq_len, _ = inputs_embeds.shape # get the GPT causal mask - causal_mask = torch.tril( - torch.ones((bz, seq_len, seq_len), - device=self.device)).view(bz, 1, - seq_len, seq_len) + causal_mask = torch.tril(torch.ones((bz, seq_len, seq_len), device=self.device)).view(bz, 1, seq_len, seq_len) # combine the attention_mask and causal_mask r = causal_mask.permute((1, 2, 0, 3)) * attention_mask.int() new_atten = r.permute((2, 0, 1, 3)) @@ -230,31 +241,35 @@ def get_label_input(self, labels, label_position, seq_len): x_ts = [token_wrapper(x_t) for x_t in labels] # construct label ids - label_ids = torch.LongTensor(self.tokenizer.tokenizer.convert_tokens_to_ids(x_ts)).reshape( - (batch_size, -1)).to(self.device) + label_ids = ( + torch.LongTensor(self.tokenizer.tokenizer.convert_tokens_to_ids(x_ts)) + .reshape((batch_size, -1)) + .to(self.device) + ) labels = torch.zeros(batch_size, seq_len).to(self.device).fill_(SMALL_LOGITS).long() # bz * seq_len labels = labels.scatter_(1, label_position, label_ids) return labels, label_ids + @typecheck() def forward_eval(self, sentences): encoder_input, new_atten, label_position = self.get_encoder_input(sentences) batch_size, _, seq_len, _ = new_atten.shape - output = self.model.model(None, None, encoder_input=encoder_input, - attention_mask=new_atten) + output = self.model.model(None, None, encoder_input=encoder_input, attention_mask=new_atten) logits = output _, returned_pred = self.get_prediction(batch_size, label_position, logits) return returned_pred + @typecheck() def forward(self, sentences, labels): encoder_input, new_atten, label_position = self.get_encoder_input(sentences) batch_size, _, seq_len, _ = new_atten.shape labels_input, label_ids = self.get_label_input(labels, label_position, seq_len) - output = self.model.model(None, None, encoder_input=encoder_input, - attention_mask=new_atten, - labels=labels_input) + output = self.model.model( + None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input + ) loss, logits = output floss = (loss[(labels_input != SMALL_LOGITS)]).mean() @@ -269,7 +284,7 @@ def training_step(self, batch, batch_idx): """ # forward pass sentences, labels = batch - train_loss, _, _ = self.forward(sentences, labels) + train_loss, _, _ = self.forward(sentences=sentences, labels=labels) lr = self._optimizer.param_groups[0]['lr'] self.log('train_loss', train_loss) @@ -286,7 +301,7 @@ def validation_step(self, batch, batch_idx): passed in as `batch`. """ sentences, labels = batch - val_loss, preds, gt_labels = self.forward(sentences, labels) + val_loss, preds, gt_labels = self.forward(sentences=sentences, labels=sentences) tp, fn, fp, _ = self.classification_report(preds, gt_labels) @@ -371,10 +386,7 @@ def _setup_dataloader_from_config(self, cfg: Dict) -> 'torch.utils.data.DataLoad [WORD][SPACE][WORD][SPACE][WORD][...][TAB][LABEL]' ) - dataset = BankPTextClassificationDataset( - input_file, - self._cfg.dataset.classes - ) + dataset = BankPTextClassificationDataset(input_file, self._cfg.dataset.classes) return torch.utils.data.DataLoader( dataset=dataset, @@ -417,9 +429,7 @@ def classifytext(self, queries: List[str], batch_size: int = 1, max_seq_length: logging.set_verbosity(logging_level) return all_preds - def _setup_infer_dataloader( - self, cfg: Dict, queries: List[str] - ) -> 'torch.utils.data.DataLoader': + def _setup_infer_dataloader(self, cfg: Dict, queries: List[str]) -> 'torch.utils.data.DataLoader': """ Setup function for a infer data loader. diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index 6da558afa781..5023d7d4d57f 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -29,7 +29,6 @@ bias_dropout_add_fused_train, ) from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu - from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import get_layer_norm from nemo.collections.nlp.modules.common.megatron.module import MegatronModule from nemo.collections.nlp.modules.common.megatron.utils import attention_mask_func, erf_gelu diff --git a/nemo/collections/nlp/modules/common/prompt_encoder.py b/nemo/collections/nlp/modules/common/prompt_encoder.py index b354b123955a..3156c89794ce 100644 --- a/nemo/collections/nlp/modules/common/prompt_encoder.py +++ b/nemo/collections/nlp/modules/common/prompt_encoder.py @@ -12,21 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import Dict, List, Optional -from nemo.core.classes import Exportable, NeuralModule import torch from torch import nn +from nemo.core.classes import Exportable, NeuralModule +from nemo.core.classes.common import typecheck +from nemo.core.neural_types import ChannelType, NeuralType + __all__ = ['SequenceClassifier'] class PromptEncoder(NeuralModule, Exportable): + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return {"output_embeds": NeuralType(('T', 'C'), ChannelType())} - def __init__(self, - template: List[int], - hidden_size: int, - lstm_dropout: float): + def __init__(self, template: List[int], hidden_size: int, lstm_dropout: float, num_layers: int): super().__init__() self.spell_length = sum(template) self.hidden_size = hidden_size @@ -43,16 +46,19 @@ def __init__(self, # embedding self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), self.hidden_size) # LSTM - self.lstm_head = torch.nn.LSTM(input_size=self.hidden_size, - hidden_size=self.hidden_size // 2, - num_layers=2, - dropout=lstm_dropout, - bidirectional=True, - batch_first=True) - self.mlp_head = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), - nn.ReLU(), - nn.Linear(self.hidden_size, self.hidden_size)) + self.lstm_head = torch.nn.LSTM( + input_size=self.hidden_size, + hidden_size=self.hidden_size // 2, + num_layers=num_layers, + dropout=lstm_dropout, + bidirectional=True, + batch_first=True, + ) + self.mlp_head = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(), nn.Linear(self.hidden_size, self.hidden_size) + ) + @typecheck() def forward(self): input_embeds = self.embedding(self.seq_indices).unsqueeze(0) output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze() From e59c2e4854b8a5fa0b8f93064b1be6820afd3df6 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Fri, 21 Jan 2022 06:59:29 -0800 Subject: [PATCH 10/22] updated the doc Signed-off-by: Yi Dong --- .../ptune_text_classification.py | 34 ++++++------- .../ptune_text_classification_dataset.py | 17 +++---- .../ptune_text_classification_model.py | 31 +++++------- .../nlp/modules/common/prompt_encoder.py | 16 +++++- tutorials/nlp/PTune_Sentiment_Analysis.ipynb | 49 ++++++------------- 5 files changed, 67 insertions(+), 80 deletions(-) diff --git a/examples/nlp/text_classification/ptune_text_classification.py b/examples/nlp/text_classification/ptune_text_classification.py index afd54fbb3326..b21a721a8a2e 100644 --- a/examples/nlp/text_classification/ptune_text_classification.py +++ b/examples/nlp/text_classification/ptune_text_classification.py @@ -13,24 +13,22 @@ # limitations under the License. """ -This script contains an example on how to train, evaluate and perform inference with the TextClassificationModel. -TextClassificationModel in NeMo supports text classification problems such as sentiment analysis or +This script contains an example on how to train, evaluate and perform inference with the PTuneTextClassificationModel. +PTuneTextClassificationModel in NeMo supports text classification problems such as sentiment analysis or domain/intent detection for dialogue systems, as long as the data follows the format specified below. ***Data format*** -TextClassificationModel requires the data to be stored in TAB separated files (.tsv) with two columns of sentence and -label. Each line of the data file contains text sequences, where words are separated with spaces and label separated -with [TAB], i.e.: - -[WORD][SPACE][WORD][SPACE][WORD][TAB][LABEL] +PTuneTextClassificationModel requires the data to be stored in loose json format with two keys of sentence and +label in each line, i.e. +{"sentence": "sentence string", "label": "label string"} For example: -hide new secretions from the parental units[TAB]0 -that loves its characters and communicates something rather beautiful about human nature[TAB]1 +{"sentence": "hide new secretions from the parental units", "label": "0"} +{"sentence": "that loves its characters and communicates something rather beautiful about human nature", "label":"1"} ... -If your dataset is stored in another format, you need to convert it to this format to use the TextClassificationModel. +If your dataset is stored in another format, you need to convert it to this format to use the PTuneTextClassificationModel. ***Setting the configs*** @@ -41,12 +39,12 @@ trainer: Any argument to be passed to PyTorch Lightning including number of epochs, number of GPUs, precision level, etc. -This script uses the `/examples/nlp/text_classification/conf/text_classification_config.yaml` default config file +This script uses the `/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml` default config file by default. You may update the config file from the file directly or by using the command line arguments. Other option is to set another config file via command line arguments by `--config-name=CONFIG_FILE_PATH'. -You first need to set the num_classes in the config file which specifies the number of classes in the dataset. -Notice that some config lines, including `model.dataset.classes_num`, have `???` as their value, this means that values +You first need to set the classes in the config file which specifies the class types in the dataset. +Notice that some config lines, including `model.dataset.classes`, have `???` as their value, this means that values for these fields are required to be specified by the user. We need to specify and set the `model.train_ds.file_name`, `model.validation_ds.file_name`, and `model.test_ds.file_name` in the config file to the paths of the train, validation, and test files if they exist. We may do it by updating the config file or by setting them from the command line. @@ -55,8 +53,8 @@ ***How to run the script?*** For example the following would train a model for 50 epochs in 2 GPUs on a classification task with 2 classes: -# python text_classification_with_bert.py - model.dataset.num_classes=2 +# python ptune_text_classification.py + model.dataset.classes=[Label1, Label2] model.train_ds=PATH_TO_TRAIN_FILE model.validation_ds=PATH_TO_VAL_FILE trainer.max_epochs=50 @@ -65,10 +63,10 @@ This script would also reload the last checkpoint after the training is done and does evaluation on the dev set, then performs inference on some sample queries. -By default, this script uses examples/nlp/text_classification/conf/text_classifciation_config.py config file, and +By default, this script uses examples/nlp/text_classification/conf/ptune_text_classifciation_config.py config file, and you may update all the params in the config file from the command line. You may also use another config file like this: -# python text_classification_with_bert.py --config-name==PATH_TO_CONFIG_FILE +# python ptune_text_classification.py --config-name==PATH_TO_CONFIG_FILE model.dataset.num_classes=2 model.train_ds=PATH_TO_TRAIN_FILE model.validation_ds=PATH_TO_VAL_FILE @@ -78,7 +76,7 @@ ***Load a saved model*** This script would save the model after training into '.nemo' checkpoint file specified by nemo_path of the model config. You may restore the saved model like this: - model = TextClassificationModel.restore_from(restore_path=NEMO_FILE_PATH) + model = PTuneTextClassificationModel.restore_from(restore_path=NEMO_FILE_PATH) ***Evaluation a saved model on another dataset*** # If you wanted to evaluate the saved model on another dataset, you may restore the model and create a new data loader: diff --git a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py index c0595d10d57d..79f6b3a39a87 100644 --- a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py +++ b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py @@ -19,10 +19,9 @@ from typing import Dict, List, Optional from nemo.core.classes import Dataset -from nemo.core.classes.common import typecheck from nemo.core.neural_types import NeuralType, StringLabel, StringType -__all__ = ['BankPTextClassificationDataset', 'token_wrapper'] +__all__ = ['PTuneTextClassificationDataset', 'token_wrapper'] def load_file(filename): @@ -37,35 +36,35 @@ def token_wrapper(token: str) -> str: return 'Ġ' + token -class BankPTextClassificationDataset(Dataset): +class PTuneTextClassificationDataset(Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: return {"sentences": [NeuralType(('T'), StringType())], "labels": [NeuralType(('T'), StringLabel())]} - def __init__(self, input_file: str, sentiments: List[str], data: List[str] = None): + def __init__(self, input_file: str, sentiments: List[str], data: List[str] = None, prompt: str = 'Sentiment'): super().__init__() if input_file and not os.path.exists(input_file): raise FileNotFoundError( f'Data file `{input_file}` not found! Each line of the data file should contain json object' - f'where `sentence` key maps to sentence and `sentiment` key maps to sentiment' + f'where `sentence` key maps to sentence and `label` key maps to label' ) if data is None: json_data = load_file(input_file) else: json_data = [] for line in data: - json_data.append({'sentence': line + ' Sentiment ', 'sentiment': ''}) + json_data.append({'sentence': line + f' {prompt} ', 'label': ''}) self.x_hs, self.x_ts = [], [] self.data = json_data for d in json_data: - if d['sentiment'] not in sentiments: + if d['label'] not in sentiments: continue - self.x_ts.append(d['sentiment']) + self.x_ts.append(d['label']) self.x_hs.append(d['sentence']) def __len__(self): return len(self.data) def __getitem__(self, i): - return self.data[i]['sentence'], self.data[i]['sentiment'] + return self.data[i]['sentence'], self.data[i]['label'] diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index 7965ec723dc1..8f88e7647a5d 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -21,20 +21,16 @@ from pytorch_lightning import Trainer from torch.nn.utils.rnn import pad_sequence -from nemo.collections.common.losses import CrossEntropyLoss from nemo.collections.nlp.data.text_classification.ptune_text_classification_dataset import ( - BankPTextClassificationDataset, + PTuneTextClassificationDataset, token_wrapper, ) from nemo.collections.nlp.metrics.classification_report import ClassificationReport from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.models.nlp_model import NLPModel -from nemo.collections.nlp.modules.common import SequenceClassifier -from nemo.collections.nlp.modules.common.lm_utils import get_lm_model from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo from nemo.collections.nlp.modules.common.prompt_encoder import PromptEncoder from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer -from nemo.collections.nlp.parts.utils_funcs import tensor2list from nemo.core.classes.common import typecheck from nemo.core.classes.exportable import Exportable from nemo.core.neural_types import LossType, NeuralType, PredictionsType, StringLabel, StringType @@ -60,7 +56,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: } def __init__(self, cfg: DictConfig, trainer: Trainer = None): - """Initializes the BERTTextClassifier model.""" + """Initializes the PTune TextClassifier model.""" super().__init__(cfg=cfg, trainer=trainer) initialize_model_parallel_for_nemo( @@ -103,6 +99,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # set allowed vocab set self.vocab = self.tokenizer.tokenizer.get_vocab() + #make sure classes are part of the vocab + for k in cfg.dataset.classes: + if token_wrapper(k) not in self.vocab: + logging.error(f'class {k} is not part of the vocabulary. Please add it to your vocab') self.allowed_vocab_ids = set(self.vocab[token_wrapper(k)] for k in cfg.dataset.classes) # map from id to label @@ -132,11 +132,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.hidden_size = hidden_size self.tokenizer.add_special_tokens({'additional_special_tokens': [cfg.pseudo_token]}) - # if 'megatron' in self.args.model_name: - # self.pseudo_token_id = self.tokenizer.tokenizer.convert_tokens_to_ids( - # self.args.pseudo_token) - # self.pad_token_id = self.tokenizer.eod - # else: self.pseudo_token_id = self.tokenizer.tokenizer.get_vocab()[cfg.pseudo_token] self.pad_token_id = ( self.tokenizer.tokenizer.pad_token_id @@ -386,7 +381,7 @@ def _setup_dataloader_from_config(self, cfg: Dict) -> 'torch.utils.data.DataLoad [WORD][SPACE][WORD][SPACE][WORD][...][TAB][LABEL]' ) - dataset = BankPTextClassificationDataset(input_file, self._cfg.dataset.classes) + dataset = PTuneTextClassificationDataset(input_file, self._cfg.dataset.classes) return torch.utils.data.DataLoader( dataset=dataset, @@ -399,13 +394,13 @@ def _setup_dataloader_from_config(self, cfg: Dict) -> 'torch.utils.data.DataLoad ) @torch.no_grad() - def classifytext(self, queries: List[str], batch_size: int = 1, max_seq_length: int = -1) -> List[int]: + def classifytext(self, queries: List[str], batch_size: int = 1, prompt: str = 'Sentiment') -> List[int]: """ Get prediction for the queries Args: queries: text sequences batch_size: batch size to use during inference - max_seq_length: sequences longer than max_seq_length will get truncated. default -1 disables truncation. + prompt: the prompt string appended at the end of your input sentence Returns: all_preds: model predictions """ @@ -418,7 +413,7 @@ def classifytext(self, queries: List[str], batch_size: int = 1, max_seq_length: logging_level = logging.get_verbosity() logging.set_verbosity(logging.WARNING) dataloader_cfg = {"batch_size": batch_size, "num_workers": 3, "pin_memory": False} - infer_datalayer = self._setup_infer_dataloader(dataloader_cfg, queries) + infer_datalayer = self._setup_infer_dataloader(dataloader_cfg, queries, prompt) for i, batch in enumerate(infer_datalayer): sentences, _ = batch preds = self.forward_eval(sentences) @@ -429,18 +424,18 @@ def classifytext(self, queries: List[str], batch_size: int = 1, max_seq_length: logging.set_verbosity(logging_level) return all_preds - def _setup_infer_dataloader(self, cfg: Dict, queries: List[str]) -> 'torch.utils.data.DataLoader': + def _setup_infer_dataloader(self, cfg: Dict, queries: List[str], prompt: str) -> 'torch.utils.data.DataLoader': """ Setup function for a infer data loader. Args: cfg: config dictionary containing data loader params like batch_size, num_workers and pin_memory queries: text - max_seq_length: maximum length of queries, default is -1 for no limit + prompt: the prompt string appended at the end of your input sentence Returns: A pytorch DataLoader. """ - dataset = BankPTextClassificationDataset(None, None, queries) + dataset = PTuneTextClassificationDataset(None, None, queries, prompt) return torch.utils.data.DataLoader( dataset=dataset, batch_size=cfg["batch_size"], diff --git a/nemo/collections/nlp/modules/common/prompt_encoder.py b/nemo/collections/nlp/modules/common/prompt_encoder.py index 3156c89794ce..56ffbabf5629 100644 --- a/nemo/collections/nlp/modules/common/prompt_encoder.py +++ b/nemo/collections/nlp/modules/common/prompt_encoder.py @@ -21,15 +21,27 @@ from nemo.core.classes.common import typecheck from nemo.core.neural_types import ChannelType, NeuralType -__all__ = ['SequenceClassifier'] +__all__ = ['PromptEncoder'] class PromptEncoder(NeuralModule, Exportable): + """ + The Prompt Encoder network that is used to generate the virtual token embeddings + """ + @property def output_types(self) -> Optional[Dict[str, NeuralType]]: return {"output_embeds": NeuralType(('T', 'C'), ChannelType())} def __init__(self, template: List[int], hidden_size: int, lstm_dropout: float, num_layers: int): + """ + Initializes the PromptEncoder module. + Args: + template: the template sizes of the vitural tokens for different clozes + hidden_size: hidden dimension + lstm_dropout: the dropout used for the LSTM + num_layers: number of layers used in the LSTM + """ super().__init__() self.spell_length = sum(template) self.hidden_size = hidden_size @@ -59,7 +71,7 @@ def __init__(self, template: List[int], hidden_size: int, lstm_dropout: float, n ) @typecheck() - def forward(self): + def forward(self) -> torch.Tensor: input_embeds = self.embedding(self.seq_indices).unsqueeze(0) output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze() return output_embeds diff --git a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb index ba7dc3aff894..4234a9bcb75a 100644 --- a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb +++ b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 20, "id": "b7a434f4", "metadata": {}, "outputs": [], @@ -34,27 +34,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 21, "id": "challenging-pioneer", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "################################################################################\n", - "### WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk\n", - "### (please add 'export KALDI_ROOT=' in your $HOME/.profile)\n", - "### (or run as: KALDI_ROOT= python .py)\n", - "################################################################################\n", - "\n", - "[NeMo W 2022-01-21 13:35:51 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-21 13:35:51 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-21 13:35:51 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-21 13:35:51 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n" - ] - } - ], + "outputs": [], "source": [ "from nemo.collections import nlp as nemo_nlp\n", "from nemo.utils.exp_manager import exp_manager\n", @@ -112,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 22, "id": "federal-beads", "metadata": {}, "outputs": [], @@ -134,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 23, "id": "8ad03fc0", "metadata": {}, "outputs": [ @@ -142,20 +125,20 @@ "name": "stdout", "output_type": "stream", "text": [ - "--2022-01-20 01:48:29-- https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", - "Resolving www.researchgate.net (www.researchgate.net)... 104.17.32.105, 104.17.33.105, 2606:4700::6811:2069, ...\n", + "--2022-01-21 14:56:30-- https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", + "Resolving www.researchgate.net (www.researchgate.net)... 104.17.32.105, 104.17.33.105, 2606:4700::6811:2169, ...\n", "Connecting to www.researchgate.net (www.researchgate.net)|104.17.32.105|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip [following]\n", - "--2022-01-20 01:48:29-- https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", + "--2022-01-21 14:56:30-- https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", "Reusing existing connection to www.researchgate.net:443.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 681890 (666K) [application/zip]\n", "Saving to: ‘FinancialPhraseBank-v10.zip’\n", "\n", - "FinancialPhraseBank 100%[===================>] 665.91K --.-KB/s in 0.02s \n", + "FinancialPhraseBank 100%[===================>] 665.91K 1.74MB/s in 0.4s \n", "\n", - "2022-01-20 01:48:30 (28.1 MB/s) - ‘FinancialPhraseBank-v10.zip’ saved [681890/681890]\n", + "2022-01-21 14:56:31 (1.74 MB/s) - ‘FinancialPhraseBank-v10.zip’ saved [681890/681890]\n", "\n", "Archive: DATA_DIR/FinancialPhraseBank-v10.zip\n" ] @@ -169,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 24, "id": "radical-castle", "metadata": {}, "outputs": [ @@ -207,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 25, "id": "198287d4", "metadata": {}, "outputs": [], @@ -247,7 +230,7 @@ " part1 = splits[0].strip()\n", " part2 = splits[1].strip()\n", " obj['sentence'] = part1 +' Sentiment '\n", - " obj['sentiment'] = part2\n", + " obj['label'] = part2\n", " f.write(json.dumps(obj)+'\\n')\n", "\n", "\n", @@ -280,7 +263,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 26, "id": "sound-surgeon", "metadata": {}, "outputs": [ @@ -288,8 +271,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "{\"sentence\": \"The contract includes heating plant equipment and associated installation work . Sentiment \", \"sentiment\": \"neutral\"}\n", - "{\"sentence\": \"The utility will also provide services related to electricity management , such as hedging trades and risk management and reporting . Sentiment \", \"sentiment\": \"neutral\"}\n" + "{\"sentence\": \"The contract includes heating plant equipment and associated installation work . Sentiment \", \"label\": \"neutral\"}\n", + "{\"sentence\": \"The utility will also provide services related to electricity management , such as hedging trades and risk management and reporting . Sentiment \", \"label\": \"neutral\"}\n" ] } ], From d234cf6f34ae20e130b624f3712482e7864cbad8 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Fri, 21 Jan 2022 09:20:04 -0800 Subject: [PATCH 11/22] fixed the notebook Signed-off-by: Yi Dong --- .../ptune_text_classification_config.yaml | 6 +- .../ptune_text_classification.py | 7 +- .../ptune_text_classification_dataset.py | 19 +- .../ptune_text_classification_model.py | 11 +- tutorials/nlp/PTune_Sentiment_Analysis.ipynb | 776 +++--------------- 5 files changed, 114 insertions(+), 705 deletions(-) diff --git a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml index 0817c9638e7d..ed7e376dfb46 100644 --- a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml +++ b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml @@ -103,9 +103,9 @@ model: # List of some sample queries for inference after training is done infer_samples: [ - 'by the end of no such thing the audience , like beatrice , has a watchful affection for the monster .', - 'director rob marshall went out gunning to make a great one .', - 'uneasy mishmash of styles and genres .', + 'For example , net sales increased by 5.9 % from the first quarter , and EBITDA increased from a negative EUR 0.2 mn in the first quarter of 2009 .', + '8 May 2009 - Finnish liquid handling products and diagnostic test systems maker Biohit Oyj ( HEL : BIOBV ) said today ( 8 May 2009 ) its net loss narrowed to EUR0 .1 m ( USD0 .14 m ) for the first quarter of 2009 from EUR0 .4 m for the same period of 2008 .', + 'CHS Expo Freight is a major Finnish fair , exhibition and culture logistics company that provides logistics services to various events by land , air and sea .', ] exp_manager: diff --git a/examples/nlp/text_classification/ptune_text_classification.py b/examples/nlp/text_classification/ptune_text_classification.py index b21a721a8a2e..c91554f5539d 100644 --- a/examples/nlp/text_classification/ptune_text_classification.py +++ b/examples/nlp/text_classification/ptune_text_classification.py @@ -24,8 +24,8 @@ For example: -{"sentence": "hide new secretions from the parental units", "label": "0"} -{"sentence": "that loves its characters and communicates something rather beautiful about human nature", "label":"1"} +{"sentence": "The output of the contracts totals 72 MWe. ", "label": "neutral"} +{"sentence": "Pretax profit totaled EUR 9.0 mn , down from EUR 36.3 mn in 2007 .", "label": "negative"} ... If your dataset is stored in another format, you need to convert it to this format to use the PTuneTextClassificationModel. @@ -130,6 +130,7 @@ def main(cfg: DictConfig) -> None: if cfg.model.test_ds.file_path: logging.info("===========================================================================================") logging.info("Starting the testing of the trained model on test set...") + trainer = pl.Trainer(**cfg.trainer) trainer.test(model=model, ckpt_path=None, verbose=False) logging.info("Testing finished!") logging.info("===========================================================================================") @@ -140,7 +141,7 @@ def main(cfg: DictConfig) -> None: logging.info("Starting the inference on some sample queries...") # max_seq_length=512 is the maximum length BERT supports. - results = model.classifytext(queries=cfg.model.infer_samples, batch_size=16, max_seq_length=512) + results = model.cuda().classifytext(queries=cfg.model.infer_samples, batch_size=1, prompt='Sentiment') logging.info('The prediction results of some sample queries with the trained model:') for query, result in zip(cfg.model.infer_samples, results): logging.info(f'Query : {query}') diff --git a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py index 79f6b3a39a87..8db901ec51fe 100644 --- a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py +++ b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py @@ -41,27 +41,28 @@ class PTuneTextClassificationDataset(Dataset): def output_types(self) -> Optional[Dict[str, NeuralType]]: return {"sentences": [NeuralType(('T'), StringType())], "labels": [NeuralType(('T'), StringLabel())]} - def __init__(self, input_file: str, sentiments: List[str], data: List[str] = None, prompt: str = 'Sentiment'): + def __init__(self, input_file: str, queries: List[str] = None, prompt: str = 'Sentiment'): + """ + A dataset class that feed data for P-tuning model + Args: + input_file: loose json data file. The format is {"sentence":"input sentence", "label":"class label"} + queries: list of query input sentences + prompt: the prompt string appended at the end of your input sentence + """ super().__init__() if input_file and not os.path.exists(input_file): raise FileNotFoundError( f'Data file `{input_file}` not found! Each line of the data file should contain json object' f'where `sentence` key maps to sentence and `label` key maps to label' ) - if data is None: + if queries is None: json_data = load_file(input_file) else: json_data = [] - for line in data: + for line in queries: json_data.append({'sentence': line + f' {prompt} ', 'label': ''}) - self.x_hs, self.x_ts = [], [] self.data = json_data - for d in json_data: - if d['label'] not in sentiments: - continue - self.x_ts.append(d['label']) - self.x_hs.append(d['sentence']) def __len__(self): return len(self.data) diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index 8f88e7647a5d..538ae5c22a93 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -245,15 +245,14 @@ def get_label_input(self, labels, label_position, seq_len): labels = labels.scatter_(1, label_position, label_ids) return labels, label_ids - @typecheck() def forward_eval(self, sentences): encoder_input, new_atten, label_position = self.get_encoder_input(sentences) batch_size, _, seq_len, _ = new_atten.shape - output = self.model.model(None, None, encoder_input=encoder_input, attention_mask=new_atten) + output = self.model.model(None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device)) logits = output - _, returned_pred = self.get_prediction(batch_size, label_position, logits) + _, returned_pred = self.get_prediction(batch_size, label_position.to(self.device), logits) return returned_pred @typecheck() @@ -296,7 +295,7 @@ def validation_step(self, batch, batch_idx): passed in as `batch`. """ sentences, labels = batch - val_loss, preds, gt_labels = self.forward(sentences=sentences, labels=sentences) + val_loss, preds, gt_labels = self.forward(sentences=sentences, labels=labels) tp, fn, fp, _ = self.classification_report(preds, gt_labels) @@ -381,7 +380,7 @@ def _setup_dataloader_from_config(self, cfg: Dict) -> 'torch.utils.data.DataLoad [WORD][SPACE][WORD][SPACE][WORD][...][TAB][LABEL]' ) - dataset = PTuneTextClassificationDataset(input_file, self._cfg.dataset.classes) + dataset = PTuneTextClassificationDataset(input_file) return torch.utils.data.DataLoader( dataset=dataset, @@ -435,7 +434,7 @@ def _setup_infer_dataloader(self, cfg: Dict, queries: List[str], prompt: str) -> Returns: A pytorch DataLoader. """ - dataset = PTuneTextClassificationDataset(None, None, queries, prompt) + dataset = PTuneTextClassificationDataset(None, queries, prompt) return torch.utils.data.DataLoader( dataset=dataset, batch_size=cfg["batch_size"], diff --git a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb index 4234a9bcb75a..8676989700b9 100644 --- a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb +++ b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "b7a434f4", "metadata": {}, "outputs": [], @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "id": "challenging-pioneer", "metadata": {}, "outputs": [], @@ -54,48 +54,34 @@ "id": "employed-ethiopia", "metadata": {}, "source": [ - "In this tutorial, we are going to describe how to finetune BioMegatron - a [BERT](https://arxiv.org/abs/1810.04805)-like [Megatron-LM](https://arxiv.org/pdf/1909.08053.pdf) model pre-trained on large biomedical text corpus ([PubMed](https://pubmed.ncbi.nlm.nih.gov/) abstracts and full-text commercial use collection) - on the [NCBI Disease Dataset](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3951655/) for Named Entity Recognition.\n", + "In this tutorial, we are going to describe how to use [P-Tuning method](https://arxiv.org/pdf/2103.10385.pdf) to find good prompts for large GPT models, so it can solve downstream NLP tasks with good performance. P-Tuning leverages few continuous free parameters to serve as prompts fed as the input to the pre-trained language models. Freezing the large language model weights, P-Tuning model can be trained efficiently while delivering stats of art performance. \n", "\n", - "The model size of Megatron-LM can be larger than BERT, up to multi-billion parameters, compared to 345 million parameters of BERT-large.\n", - "There are some alternatives of BioMegatron, most notably [BioBERT](https://arxiv.org/abs/1901.08746). Compared to BioBERT BioMegatron is larger by model size and pre-trained on larger text corpus.\n", - "\n", - "A more general tutorial of using BERT-based models, including Megatron-LM, for downstream natural language processing tasks can be found [here](https://github.com/NVIDIA/NeMo/blob/stable/tutorials/nlp/01_Pretrained_Language_Models_for_Downstream_Tasks.ipynb).\n", + "Large Language Model can be trained with [Megatron-LM project](https://github.com/NVIDIA/Megatron-LM), up to multi-billion parameters. In this notebook, we will use the pre-trained 344M GPT model released from NGC.\n", "\n", "# Task Description\n", - "**Named entity recognition (NER)**, also referred to as entity chunking, identification or extraction, is the task of detecting and classifying key information (entities) in text.\n", - "\n", - "For instance, **given sentences from medical abstracts, what diseases are mentioned?**
\n", - "In this case, our data input is sentences from the abstracts, and our labels are the precise locations of the named disease entities. Take a look at the information provided for the dataset.\n", + "In this notebook, we are going to use P-Tuning method for **Sentiment Analysis** task, also known as opinion mining or emotion AI. It is a sub-field of NLP that tries to identify and extract opinions within a given text across blogs, reviews, social media, forums, news etc.\n", "\n", - "For more details and general examples on Named Entity Recognition, please refer to the [Token Classification and Named Entity Recognition tutorial notebook](https://github.com/NVIDIA/NeMo/blob/stable/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb).\n", + "For instance, **given sentences from news title, is it a good or bad news?**
\n", "\n", "# Dataset\n", "\n", - "The [NCBI-disease corpus](https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/) is a set of 793 PubMed abstracts, annotated by 14 annotators. The annotations take the form of HTML-style tags inserted into the abstract text using the clearly defined rules. The annotations identify named diseases, and can be used to fine-tune a language model to identify disease mentions in future abstracts, *whether those diseases were part of the original training set or not*.\n", + "The [Financial PhraseBank dataset](https://huggingface.co/datasets/financial_phrasebank) contains the sentiments for financial news headlines from the perspective of a retail investor. Further details about the dataset can be found in: Malo, P., Sinha, A., Takala, P., Korhonen, P. and Wallenius, J. (2014): “Good debt or bad debt: Detecting semantic orientations in economic texts.” Journal of the American Society for Information Science and Technology.\n", "\n", "Here's an example of what an annotated abstract from the corpus looks like:\n", "\n", - "```html\n", - "10021369\tIdentification of APC2, a homologue of the adenomatous polyposis coli tumour suppressor .\tThe adenomatous polyposis coli ( APC ) tumour-suppressor protein controls the Wnt signalling pathway by forming a complex with glycogen synthase kinase 3beta ( GSK-3beta ) , axin / conductin and betacatenin . Complex formation induces the rapid degradation of betacatenin . In colon carcinoma cells , loss of APC leads to the accumulation of betacatenin in the nucleus , where it binds to and activates the Tcf-4 transcription factor ( reviewed in [ 1 ] [ 2 ] ) . Here , we report the identification and genomic structure of APC homologues . Mammalian APC2 , which closely resembles APC in overall domain structure , was functionally analyzed and shown to contain two SAMP domains , both of which are required for binding to conductin . Like APC , APC2 regulates the formation of active betacatenin-Tcf complexes , as demonstrated using transient transcriptional activation assays in APC - / - colon carcinoma cells . Human APC2 maps to chromosome 19p13 . 3 . APC and APC2 may therefore have comparable functions in development and cancer .\n", "```\n", - "\n", - "In this example, we see the following tags within the abstract:\n", - "```html\n", - "adenomatous polyposis coli tumour\n", - "adenomatous polyposis coli ( APC ) tumour\n", - "colon carcinoma\n", - "colon carcinoma\n", - "cancer\n", + "HELSINKI Thomson Financial - Shares in Cargotec fell sharply in early afternoon trade after the cargo handling group posted a surprise drop in April-June profits , which overshadowed the large number of new orders received during the three months .@negative\n", + "LONDON MarketWatch -- Share prices ended lower in London Monday as a rebound in bank stocks failed to offset broader weakness for the FTSE 100 .@negative\n", + "Operating profit fell to EUR 35.4 mn from EUR 68.8 mn in 2007 , including vessel sales gain of EUR 12.3 mn .@negative\n", + "Sales in Finland decreased by 10.5 % in January , while sales outside Finland dropped by 17 % .@negative\n", "```\n", "\n", - "For our purposes, we will consider any identified category (such as \"Modifier\", \"Specific Disease\", and a few others) to generally be a \"disease\".\n", - "\n", "Let's download the dataset." ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "federal-beads", "metadata": {}, "outputs": [], @@ -117,33 +103,10 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "id": "8ad03fc0", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--2022-01-21 14:56:30-- https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", - "Resolving www.researchgate.net (www.researchgate.net)... 104.17.32.105, 104.17.33.105, 2606:4700::6811:2169, ...\n", - "Connecting to www.researchgate.net (www.researchgate.net)|104.17.32.105|:443... connected.\n", - "HTTP request sent, awaiting response... 301 Moved Permanently\n", - "Location: https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip [following]\n", - "--2022-01-21 14:56:30-- https://www.researchgate.net/profile/Pekka-Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", - "Reusing existing connection to www.researchgate.net:443.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 681890 (666K) [application/zip]\n", - "Saving to: ‘FinancialPhraseBank-v10.zip’\n", - "\n", - "FinancialPhraseBank 100%[===================>] 665.91K 1.74MB/s in 0.4s \n", - "\n", - "2022-01-21 14:56:31 (1.74 MB/s) - ‘FinancialPhraseBank-v10.zip’ saved [681890/681890]\n", - "\n", - "Archive: DATA_DIR/FinancialPhraseBank-v10.zip\n" - ] - } - ], + "outputs": [], "source": [ "!wget https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n", "!mv FinancialPhraseBank-v10.zip {DATA_DIR}\n", @@ -152,45 +115,29 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "id": "radical-castle", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "According to Gran , the company has no plans to move all production to Russia , although that is where the company is growing .@neutral\n" - ] - } - ], + "outputs": [], "source": [ "# If you want to see more examples, you can explore the text of the corpus using the file browser to the left, or open files directly, for example typing a command like the following in a code-cell:\n", "\n", "! head -1 $DATA_DIR/FinancialPhraseBank-v1.0/Sentences_50Agree.txt" ] }, - { - "cell_type": "markdown", - "id": "specified-maine", - "metadata": {}, - "source": [ - "We have two datasets derived from this corpus: a text classification dataset and a named entity recognition (NER) dataset. The text classification dataset labels the abstracts among three broad disease groupings. We'll use this simple split to demonstrate the NLP text classification task. The NER dataset labels individual words as diseases. This dataset will be used for the NLP NER task. " - ] - }, { "cell_type": "markdown", "id": "affected-numbers", "metadata": {}, "source": [ "## Pre-process dataset\n", - "A pre-processed NCBI-disease dataset for NER can be found [here](https://github.com/spyysalo/ncbi-disease/tree/master/conll) or [here](https://github.com/dmis-lab/biobert#datasets).
\n", - "We download the files under {DATA_DIR/NER} directory." + "\n", + "In this pre-process step, we are going to convert the downloaded dataset into the format that can be used for P-Tuning dataloader. The data is split into 10 folds so we can do 10-fold cross validation. " ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "id": "198287d4", "metadata": {}, "outputs": [], @@ -258,24 +205,16 @@ "id": "graphic-debate", "metadata": {}, "source": [ - "The NER task requires two files: the text sentences, and the labels. Run the next two cells to see a sample of the two files." + "The data is converted to the loss json file. Each line has two keys \"sentence\" and \"label\". Note we append \"Sentiment\" at the end of the input sentence to cue the model for sentiment analysis. \n", + "Here are the first two lines of converted data:" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "id": "sound-surgeon", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{\"sentence\": \"The contract includes heating plant equipment and associated installation work . Sentiment \", \"label\": \"neutral\"}\n", - "{\"sentence\": \"The utility will also provide services related to electricity management , such as hedging trades and risk management and reporting . Sentiment \", \"label\": \"neutral\"}\n" - ] - } - ], + "outputs": [], "source": [ "!head -n 2 $DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt" ] @@ -303,7 +242,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "82b8e08e", "metadata": {}, "outputs": [], @@ -324,18 +263,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "4b00ee86", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "config file is already exists\n" - ] - } - ], + "outputs": [], "source": [ "WORK_DIR = \"WORK_DIR\"\n", "os.makedirs(WORK_DIR, exist_ok=True)\n", @@ -354,18 +285,10 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "0ae5a1a9", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WORK_DIR/configs/megatron_gpt_config.yaml\n" - ] - } - ], + "outputs": [], "source": [ "# this line will print the entire config of the model\n", "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n", @@ -390,42 +313,10 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "9e1beda4", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "################################################################################\n", - "### WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk\n", - "### (please add 'export KALDI_ROOT=' in your $HOME/.profile)\n", - "### (or run as: KALDI_ROOT= python .py)\n", - "################################################################################\n", - "\n", - "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "[NeMo W 2022-01-20 21:01:09 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n", - "I0120 21:01:09.749301 140536743184192 distributed_c10d.py:218] Added key: store_based_barrier_key:1 to store for rank: 0\n", - "I0120 21:01:09.749543 140536743184192 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.\n", - "GPU available: True, used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "converted 354.87M parameters\n", - "[NeMo I 2022-01-20 21:01:10 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /NeMo/tutorials/nlp/gpt2-vocab.json\n", - "[NeMo I 2022-01-20 21:01:10 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /NeMo/tutorials/nlp/gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n", - "Using sep_token, but it is not set yet.\n", - "Using cls_token, but it is not set yet.\n", - "Using pad_token, but it is not set yet.\n", - "Using mask_token, but it is not set yet.\n", - "[NeMo I 2022-01-20 21:01:13 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", - "[NeMo I 2022-01-20 21:10:10 megatron_lm_ckpt_to_nemo:265] NeMo model saved to: /NeMo/tutorials/nlp/gpt_344m.nemo\n", - "\u001b[0m" - ] - } - ], + "outputs": [], "source": [ "import os\n", "PWD = os.getcwd()\n", @@ -440,7 +331,7 @@ "source": [ "# Model configuration\n", "\n", - "Our Named Entity Recognition model is comprised of the pretrained [BERT](https://arxiv.org/pdf/1810.04805.pdf) model followed by a Token Classification layer.\n", + "Our P-Tuning text classification model is comprised of the pretrained GPT LM model followed by a prompt encoder layer.\n", "\n", "The model is defined in a config file which declares multiple important sections. They are:\n", "- **model**: All arguments that are related to the Model - language model, token classifier, optimizer and schedulers, datasets and any other related information\n", @@ -450,7 +341,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "speaking-grant", "metadata": {}, "outputs": [], @@ -460,18 +351,10 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "demanding-ballet", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "config file is already exists\n" - ] - } - ], + "outputs": [], "source": [ "# download the model's configuration file \n", "config_dir = WORK_DIR + '/configs/'\n", @@ -485,18 +368,10 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "criminal-outdoors", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WORK_DIR/configs/ptune_text_classification_config.yaml\n" - ] - } - ], + "outputs": [], "source": [ "# this line will print the entire config of the model\n", "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n", @@ -507,14 +382,25 @@ "config.model.validation_ds.batch_size=8" ] }, + { + "cell_type": "markdown", + "id": "dedicated-effort", + "metadata": {}, + "source": [ + "# Model Training\n", + "## Setting up Data within the config\n", + "\n", + "Among other things, the config file contains dictionaries called train_ds, validation_ds and test_ds. These are configurations used to setup the Dataset and DataLoaders of the corresponding config.\n" + ] + }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "informed-purse", "metadata": {}, "outputs": [], "source": [ - "# in this tutorial train and dev datasets are located in the same folder, so it is enought to add the path of the data directory to the config\n", + "# in this tutorial train and dev datasets are located in the same folder, so it is enough to add the path of the data directory to the config\n", "#config.model.dataset.classes = ['positive', 'neutral', 'negative']\n", "config.model.train_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/train_0.txt'\n", "config.model.validation_ds.file_path = DATA_DIR+'/FinancialPhraseBank-v1.0/validation_0.txt'\n", @@ -529,153 +415,12 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "divine-belly", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "trainer:\n", - " gpus: 1\n", - " num_nodes: 1\n", - " max_epochs: 100\n", - " max_steps: null\n", - " accumulate_grad_batches: 1\n", - " gradient_clip_val: 0.0\n", - " precision: 32\n", - " accelerator: ddp\n", - " log_every_n_steps: 1\n", - " val_check_interval: 1.0\n", - " resume_from_checkpoint: null\n", - " num_sanity_val_steps: 0\n", - " checkpoint_callback: false\n", - " logger: false\n", - "model:\n", - " tensor_model_parallel_size: 2\n", - " seed: 1234\n", - " nemo_path: ptune_text_classification_model.nemo\n", - " use_lm_finetune: false\n", - " pseudo_token: '[PROMPT]'\n", - " tokenizer:\n", - " library: megatron\n", - " type: GPT2BPETokenizer\n", - " model: null\n", - " vocab_file: null\n", - " merge_file: null\n", - " language_model:\n", - " nemo_file: null\n", - " prompt_encoder:\n", - " template:\n", - " - 3\n", - " - 3\n", - " - 0\n", - " dropout: 0.0\n", - " dataset:\n", - " classes: ???\n", - " train_ds:\n", - " file_path: DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt\n", - " batch_size: 8\n", - " shuffle: true\n", - " num_samples: -1\n", - " num_workers: 3\n", - " drop_last: false\n", - " pin_memory: false\n", - " validation_ds:\n", - " file_path: DATA_DIR/FinancialPhraseBank-v1.0/validation_0.txt\n", - " batch_size: 8\n", - " shuffle: false\n", - " num_samples: -1\n", - " num_workers: 3\n", - " drop_last: false\n", - " pin_memory: false\n", - " test_ds:\n", - " file_path: DATA_DIR/FinancialPhraseBank-v1.0/test_0.txt\n", - " batch_size: 64\n", - " shuffle: false\n", - " num_samples: -1\n", - " num_workers: 3\n", - " drop_last: false\n", - " pin_memory: false\n", - " optim:\n", - " name: adam\n", - " lr: 1.0e-05\n", - " betas:\n", - " - 0.9\n", - " - 0.999\n", - " weight_decay: 0.0005\n", - " sched:\n", - " name: WarmupAnnealing\n", - " warmup_steps: null\n", - " warmup_ratio: 0.1\n", - " last_epoch: -1\n", - " monitor: val_loss\n", - " reduce_on_plateau: false\n", - " infer_samples:\n", - " - by the end of no such thing the audience , like beatrice , has a watchful affection\n", - " for the monster .\n", - " - director rob marshall went out gunning to make a great one .\n", - " - uneasy mishmash of styles and genres .\n", - "exp_manager:\n", - " exp_dir: null\n", - " name: PTuneTextClassification\n", - " create_tensorboard_logger: true\n", - " create_checkpoint_callback: true\n", - "\n" - ] - } - ], - "source": [ - "print(OmegaConf.to_yaml(config))" - ] - }, - { - "cell_type": "markdown", - "id": "dedicated-effort", - "metadata": {}, - "source": [ - "# Model Training\n", - "## Setting up Data within the config\n", - "\n", - "Among other things, the config file contains dictionaries called dataset, train_ds and validation_ds. These are configurations used to setup the Dataset and DataLoaders of the corresponding config.\n" - ] - }, - { - "cell_type": "markdown", - "id": "15e2c67a", - "metadata": {}, - "source": [ - "\n", - "We assume that both training and evaluation files are located in the same directory, and use the default names mentioned during the data download step. \n", - "So, to start model training, we simply need to specify `model.dataset.data_dir`, like we are going to do below.\n" - ] - }, - { - "cell_type": "markdown", - "id": "89dd468d", - "metadata": {}, - "source": [ - "\n", - "Also notice that some config lines, including `model.dataset.data_dir`, have `???` in place of paths, this means that values for these fields are required to be specified by the user.\n", - "\n", - "Let's now add the data directory path to the config." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "a312ed76", - "metadata": {}, "outputs": [], "source": [ - "# in this tutorial train and dev datasets are located in the same folder, so it is enought to add the path of the data directory to the config\n", - "config.model.dataset.data_dir = os.path.join(DATA_DIR, 'SA')\n", - "\n", - "# if you want to decrease the size of your datasets, uncomment the lines below:\n", - "# NUM_SAMPLES = 1000\n", - "# config.model.train_ds.num_samples = NUM_SAMPLES\n", - "# config.model.validation_ds.num_samples = NUM_SAMPLES" + "print(OmegaConf.to_yaml(config))" ] }, { @@ -692,34 +437,10 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "computational-battlefield", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Trainer config - \n", - "\n", - "gpus: 1\n", - "num_nodes: 1\n", - "max_epochs: 100\n", - "max_steps: null\n", - "accumulate_grad_batches: 1\n", - "gradient_clip_val: 0.0\n", - "precision: 32\n", - "accelerator: ddp\n", - "log_every_n_steps: 1\n", - "val_check_interval: 1.0\n", - "resume_from_checkpoint: null\n", - "num_sanity_val_steps: 0\n", - "checkpoint_callback: false\n", - "logger: false\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Trainer config - \\n\")\n", "print(OmegaConf.to_yaml(config.trainer))" @@ -727,36 +448,10 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "unique-genre", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[NeMo W 2022-01-21 13:37:28 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py:107: LightningDeprecationWarning: Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. Notice that it will be overriden by the trainer setting.\n", - " rank_zero_deprecation(\n", - " \n", - "[NeMo W 2022-01-21 13:37:28 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py:113: LightningDeprecationWarning: Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. Notice that it will be overriden by the trainer setting.\n", - " rank_zero_deprecation(\n", - " \n", - "[NeMo W 2022-01-21 13:37:28 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:324: LightningDeprecationWarning: Passing `strategy` to the `plugins` flag in Trainer has been deprecated in v1.5 and will be removed in v1.7. Use `Trainer(strategy=)` instead.\n", - " rank_zero_deprecation(\n", - " \n", - "Using 16bit native Automatic Mixed Precision (AMP)\n", - "[NeMo W 2022-01-21 13:37:28 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:48: LightningDeprecationWarning: Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7. Use `max_steps = -1` instead.\n", - " rank_zero_deprecation(\n", - " \n", - "[NeMo W 2022-01-21 13:37:28 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:147: LightningDeprecationWarning: Setting `Trainer(checkpoint_callback=False)` is deprecated in v1.5 and will be removed in v1.7. Please consider using `Trainer(enable_checkpointing=False)`.\n", - " rank_zero_deprecation(\n", - " \n", - "GPU available: True, used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n" - ] - } - ], + "outputs": [], "source": [ "from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin\n", "\n", @@ -787,39 +482,10 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "mathematical-portable", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-21 13:37:41 exp_manager:283] Experiments will be logged at /NeMo/tutorials/nlp/nemo_experiments/PTuneTextClassification/2022-01-21_13-37-41\n", - "[NeMo I 2022-01-21 13:37:41 exp_manager:648] TensorboardLogger has been set up\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[NeMo W 2022-01-21 13:37:41 exp_manager:889] The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to -1. Please ensure that max_steps will run for at least 1 epochs to ensure that checkpointing will not error out.\n", - "[NeMo W 2022-01-21 13:37:41 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:243: LightningDeprecationWarning: `ModelCheckpoint(every_n_val_epochs)` is deprecated in v1.4 and will be removed in v1.6. Please use `every_n_epochs` instead.\n", - " rank_zero_deprecation(\n", - " \n" - ] - }, - { - "data": { - "text/plain": [ - "'/NeMo/tutorials/nlp/nemo_experiments/PTuneTextClassification/2022-01-21_13-37-41'" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))\n", "os.makedirs(WORK_DIR, exist_ok=True)\n", @@ -834,25 +500,12 @@ "id": "f62ea6cd", "metadata": {}, "source": [ - "To load the pretrained BERT LM model, we can either load it from the converted `.nemo` file as shown above or load it from a list of included model names. \n", - "\n", - "We can get the list of names by following command \n", - "```python\n", - "# complete list of supported BERT-like models\n", - "print(nemo_nlp.modules.get_pretrained_lm_models_list())\n", - "```\n", - "We can change the `model.language_mode` config to use it\n", - "```python\n", - "# add the specified above model parameters to the config\n", - "config.model.language_model.pretrained_model_name = MODEL_NAME\n", - "```\n", - "\n", - "In this notebook, we will use the converted `.nemo` file as our LM model, which is BioMegatron, [Megatron-LM BERT](https://arxiv.org/abs/1909.08053) pre-trained on [PubMed](https://pubmed.ncbi.nlm.nih.gov/) biomedical text corpus." + "We will use the converted `.nemo` file as our LM model." ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "compact-horse", "metadata": {}, "outputs": [], @@ -871,89 +524,15 @@ "id": "seeing-geometry", "metadata": {}, "source": [ - "Now, we are ready to initialize our model. During the model initialization call, the dataset and data loaders we'll be prepared for training and evaluation.\n", - "Also, the pretrained BERT model will be downloaded, note it can take up to a few minutes depending on the size of the chosen BERT model." + "Now, we are ready to initialize our model. During the model initialization call, the dataset and data loaders we'll be prepared for training and evaluation." ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "indoor-france", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-21 13:38:19 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /NeMo/tutorials/nlp/gpt2-vocab.json\n", - "[NeMo I 2022-01-21 13:38:19 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /NeMo/tutorials/nlp/gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using sep_token, but it is not set yet.\n", - "Using cls_token, but it is not set yet.\n", - "Using pad_token, but it is not set yet.\n", - "Using mask_token, but it is not set yet.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-21 13:38:34 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /tmp/tmp1vxu9jzs/3f23abcf03b94354899f3c5b5beab943_gpt2-vocab.json\n", - "[NeMo I 2022-01-21 13:38:34 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /tmp/tmp1vxu9jzs/3f23abcf03b94354899f3c5b5beab943_gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using sep_token, but it is not set yet.\n", - "Using cls_token, but it is not set yet.\n", - "Using pad_token, but it is not set yet.\n", - "Using mask_token, but it is not set yet.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-21 13:38:38 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", - "[NeMo I 2022-01-21 13:38:39 tokenizer_utils:190] Getting Megatron tokenizer for pretrained model name: megatron-gpt-345m and custom vocab file: /tmp/tmp1vxu9jzs/3f23abcf03b94354899f3c5b5beab943_gpt2-vocab.json\n", - "[NeMo I 2022-01-21 13:38:39 tokenizer_utils:123] Getting HuggingFace AutoTokenizer with pretrained_model_name: gpt2, vocab_file: /tmp/tmp1vxu9jzs/3f23abcf03b94354899f3c5b5beab943_gpt2-vocab.json, special_tokens_dict: {}, and use_fast: False\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using sep_token, but it is not set yet.\n", - "Using cls_token, but it is not set yet.\n", - "Using pad_token, but it is not set yet.\n", - "Using mask_token, but it is not set yet.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-21 13:38:43 megatron_gpt_model:754] Padded vocab_size: 50304, original vocab_size: 50257, dummy tokens: 47.\n", - "[NeMo I 2022-01-21 13:38:43 save_restore_connector:149] Model MegatronGPTModel was successfully restored from /NeMo/tutorials/nlp/gpt_344m.nemo.\n", - "[NeMo I 2022-01-21 13:38:44 auto_tokenizer:171] 1 special tokens added, resize your model accordingly.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using pad_token, but it is not set yet.\n", - "Using mask_token, but it is not set yet.\n" - ] - } - ], + "outputs": [], "source": [ "from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import PTuneTextClassificationModel\n", "model_ptune = PTuneTextClassificationModel(cfg=config.model, trainer=trainer)" @@ -971,18 +550,10 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "id": "changed-expense", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "To use tensorboard, please use this notebook in a Google Colab environment.\n" - ] - } - ], + "outputs": [], "source": [ "try:\n", " from google import colab\n", @@ -1000,165 +571,10 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "applied-quality", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[NeMo W 2022-01-21 13:38:50 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:287: LightningDeprecationWarning: Base `Callback.on_train_batch_start` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.\n", - " rank_zero_deprecation(\n", - " \n", - "[NeMo W 2022-01-21 13:38:50 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:287: LightningDeprecationWarning: Base `Callback.on_train_batch_end` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.\n", - " rank_zero_deprecation(\n", - " \n", - "initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1\n", - "I0121 13:38:50.209178 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:1 to store for rank: 0\n", - "I0121 13:38:50.209801 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.\n", - "----------------------------------------------------------------------------------------------------\n", - "distributed_backend=nccl\n", - "All distributed processes registered. Starting with 1 processes\n", - "----------------------------------------------------------------------------------------------------\n", - "\n", - "I0121 13:38:50.211127 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:2 to store for rank: 0\n", - "I0121 13:38:50.211528 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.\n", - "I0121 13:38:50.212052 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:3 to store for rank: 0\n", - "I0121 13:38:50.212460 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:3 with 1 nodes.\n", - "I0121 13:38:50.212984 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:4 to store for rank: 0\n", - "I0121 13:38:50.213450 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:4 with 1 nodes.\n", - "I0121 13:38:50.213927 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:5 to store for rank: 0\n", - "I0121 13:38:50.214323 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:5 with 1 nodes.\n", - "I0121 13:38:50.214805 140425413850944 distributed_c10d.py:218] Added key: store_based_barrier_key:6 to store for rank: 0\n", - "I0121 13:38:50.215201 140425413850944 distributed_c10d.py:252] Rank 0: Completed store-based barrier for key:store_based_barrier_key:6 with 1 nodes.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "> initializing tensor model parallel with size 1\n", - "> initializing pipeline model parallel with size 1\n", - "> initializing data parallel with size 1\n", - "[NeMo I 2022-01-21 13:38:50 nlp_overrides:137] mp_rank: 0\n", - "[NeMo I 2022-01-21 13:38:50 nlp_overrides:138] dp_rank: 0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4]\n", - "[NeMo W 2022-01-21 13:38:50 modelPT:475] The lightning trainer received accelerator: . We recommend to use 'ddp' instead.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-21 13:38:50 modelPT:566] Optimizer config = Adam (\n", - " Parameter Group 0\n", - " amsgrad: False\n", - " betas: [0.9, 0.999]\n", - " eps: 1e-08\n", - " lr: 1e-05\n", - " weight_decay: 0.0005\n", - " )\n", - "[NeMo I 2022-01-21 13:38:50 lr_scheduler:833] Scheduler \"\" \n", - " will be used during training (effective maximum steps = 147800) - \n", - " Parameters : \n", - " (warmup_steps: null\n", - " warmup_ratio: 0.1\n", - " last_epoch: -1\n", - " max_steps: 147800\n", - " )\n", - "[NeMo I 2022-01-21 13:38:51 nlp_overrides:92] Configuring DDP for model parallelism.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - " | Name | Type | Params\n", - "-----------------------------------------------------------------\n", - "0 | model | MegatronGPTModel | 354 M \n", - "1 | embeddings | VocabParallelEmbedding | 51.5 M\n", - "2 | classification_report | ClassificationReport | 0 \n", - "3 | prompt_encoder | PromptEncoder | 14.7 M\n", - "-----------------------------------------------------------------\n", - "14.7 M Trainable params\n", - "354 M Non-trainable params\n", - "369 M Total params\n", - "739.152 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "46ac1cdd81ad40c39fb6f076790aee94", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[NeMo W 2022-01-21 13:38:52 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:129: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n", - " warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n", - " \n", - "I0121 13:38:52.474958 140425413850944 distributed.py:902] Reducer buckets have been rebuilt in this iteration.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "357dcda709c249f7b5d4ce45e39691dc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[NeMo I 2022-01-21 13:41:51 ptune_text_classification_model:312] val_report: \n", - " label precision recall f1 support \n", - " positive (label_id: 0) 49.78 27.68 35.58 401\n", - " neutral (label_id: 1) 67.28 94.38 78.56 889\n", - " negative (label_id: 2) 75.00 3.19 6.12 188\n", - " -------------------\n", - " micro avg 64.68 64.68 64.68 1478\n", - " macro avg 64.02 41.75 40.09 1478\n", - " weighted avg 63.51 64.68 57.68 1478\n", - " \n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Epoch 0, global step 1477: val_loss reached 0.93801 (best 0.93801), saving model to \"/NeMo/tutorials/nlp/nemo_experiments/PTuneTextClassification/2022-01-21_13-37-41/checkpoints/PTuneTextClassification--val_loss=0.9380-epoch=0.ckpt\" as top 3\n", - "[NeMo W 2022-01-21 13:42:29 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:685: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n", - " rank_zero_warn(\"Detected KeyboardInterrupt, attempting graceful shutdown...\")\n", - " \n" - ] - } - ], + "outputs": [], "source": [ "# start model training\n", "trainer.fit(model_ptune)" @@ -1171,7 +587,7 @@ "source": [ "# Inference\n", "\n", - "To see how the model performs, we can run generate prediction similar to the way we did it earlier" + "To see how the model performs, we can run model in the inference mode" ] }, { @@ -1182,36 +598,17 @@ "outputs": [], "source": [ "# let's first create a subset of our dev data\n", - "! head -n 100 $NER_DATA_DIR/text_dev.txt > $NER_DATA_DIR/sample_text_dev.txt\n", - "! head -n 100 $NER_DATA_DIR/labels_dev.txt > $NER_DATA_DIR/sample_labels_dev.txt" - ] - }, - { - "cell_type": "markdown", - "id": "adult-ranking", - "metadata": {}, - "source": [ - "Now, let's generate predictions for the provided text file.\n", - "If labels file is also specified, the model will evaluate the predictions and plot confusion matrix. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "twenty-abortion", - "metadata": {}, - "outputs": [], - "source": [ - "model_ner.half().evaluate_from_file(\n", - " text_file=os.path.join(NER_DATA_DIR, 'sample_text_dev.txt'),\n", - " labels_file=os.path.join(NER_DATA_DIR, 'sample_labels_dev.txt'),\n", - " output_dir=exp_dir,\n", - " add_confusion_matrix=False,\n", - " normalize_confusion_matrix=True,\n", - " batch_size=1\n", - ")\n", - "# Please check matplotlib version if encountering any error plotting confusion matrix:\n", - "# https://stackoverflow.com/questions/63212347/importerror-cannot-import-name-png-from-matplotlib" + "query_examples = [\n", + "\"For example , net sales increased by 5.9 % from the first quarter , and EBITDA increased from a negative EUR 0.2 mn in the first quarter of 2000 .\",\n", + "\"EPS for the quarter was EUR0 .00 , as compared with EUR0 .01 in the third quarter of 2008 , representing a Group net sales for the third quarter were EUR15 .3 m , up by 2.8 % as compared with EUR14 .9 m in the third quarter of 2008 .\",\n", + "\"The NTSB said investigators are set to conduct sight distance tests on July 18 , using trains similar to those involved in the accident .\",\n", + "\"Pretax profit totaled EUR 9.0 mn , down from EUR 36.3 mn in 2007 .\",\n", + "\"However , the proportion of the paid standing orders grew in 2009 .\"]\n", + "results = model_ptune.classifytext(queries=query_examples, batch_size=1, prompt='Sentiment')\n", + "print('The prediction results of some sample queries with the trained model:')\n", + "for query, result in zip(query_examples, results):\n", + " print(f'Query : {query}')\n", + " print(f'Predicted label: {result}')" ] }, { @@ -1221,11 +618,22 @@ "source": [ "## Training Script\n", "\n", - "If you have NeMo installed locally, you can also train the model with `nlp/token_classification/token_classification_train.py.`\n", + "If you have NeMo installed locally, you can also train the model with `examples/nlp/text_classification/ptune_text_classification.py`.\n", "\n", "To run training script, use:\n", - "\n", - "`python token_classification_train.py model.dataset.data_dir=PATH_TO_DATA_DIR exp_manager.exp_dir=EXP_DIR model.language_model.pretrained_model_name=megatron-bert-cased model.tokenizer.vocab_file=VOCAB_FILE model.tokenizer.tokenizer_model=BertWordPieceCase model.language_model.nemo_file=NEMO_FILE`\n" + "```\n", + "python examples/nlp/text_classification/ptune_text_classification.py \\\n", + " trainer.gpus=1 \\\n", + " model.tokenizer.vocab_file=VOCAB_FILE \\\n", + " model.tensor_model_parallel_size=1 \\\n", + " model.tokenizer.merge_file=MERGE_FILE \\\n", + " model.language_model.nemo_file=gpt_344m.nemo \\\n", + " model.dataset.classes=[positive, neutral, negative] \\\n", + " model.train_ds.file_path=TRAIN_FILE \\\n", + " model.train_ds.batch_size=8 \\\n", + " model.validation_ds.file_path=VAL_FILE \\\n", + " model.test_ds.file_path=TEST_FILE \\\n", + "```" ] }, { From 51cbea24f50d40d026a99acbab38c4c33d244383 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Fri, 21 Jan 2022 11:12:00 -0800 Subject: [PATCH 12/22] updated expected result Signed-off-by: Yi Dong --- tutorials/nlp/PTune_Sentiment_Analysis.ipynb | 29 ++++++++++---------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb index 8676989700b9..237ea501052c 100644 --- a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb +++ b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb @@ -641,23 +641,24 @@ "id": "legitimate-electric", "metadata": {}, "source": [ - "The training could take several minutes and the result should look something like\n", + "The training could take several hours and the result should look something like\n", "```\n", - "[NeMo I 2020-05-22 17:13:48 token_classification_callback:82] Accuracy: 0.9882348032875798\n", - "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 weighted: 98.82\n", - "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 macro: 93.74\n", - "[NeMo I 2020-05-22 17:13:48 token_classification_callback:86] F1 micro: 98.82\n", - "[NeMo I 2020-05-22 17:13:49 token_classification_callback:89] precision recall f1-score support\n", - " \n", - " O (label id: 0) 0.9938 0.9957 0.9947 22092\n", - " B (label id: 1) 0.8843 0.9034 0.8938 787\n", - " I (label id: 2) 0.9505 0.8982 0.9236 1090\n", - " \n", - " accuracy 0.9882 23969\n", - " macro avg 0.9429 0.9324 0.9374 23969\n", - " weighted avg 0.9882 0.9882 0.9882 23969\n", + " label precision recall f1 support\n", + " positive (label_id: 0) 87.75 89.28 88.50 401\n", + " neutral (label_id: 1) 94.26 94.26 94.26 889\n", + " negative (label_id: 2) 95.03 91.49 93.22 188\n", + " -------------------\n", + " micro avg 92.56 92.56 92.56 1478\n", + " macro avg 92.35 91.68 92.00 1478\n", + " weighted avg 92.59 92.56 92.57 1478\n", "```" ] + }, + { + "cell_type": "markdown", + "id": "0ddb3960", + "metadata": {}, + "source": [] } ], "metadata": { From 12e0a252076ea22e6dc84c6e02e0f07e4637c6d9 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Fri, 21 Jan 2022 13:26:14 -0800 Subject: [PATCH 13/22] added accuracy Signed-off-by: Yi Dong --- .../ptune_text_classification_model.py | 12 ++++++++++-- tutorials/nlp/PTune_Sentiment_Analysis.ipynb | 3 ++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index 538ae5c22a93..234ea1e9ef56 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -297,9 +297,14 @@ def validation_step(self, batch, batch_idx): sentences, labels = batch val_loss, preds, gt_labels = self.forward(sentences=sentences, labels=labels) + hit = 0 + for pred, gt_label in zip(preds, gt_labels): + if pred == gt_label: + hit += 1 + tp, fn, fp, _ = self.classification_report(preds, gt_labels) - return {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp} + return {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp, 'hit': hit} def validation_epoch_end(self, outputs): """ @@ -315,11 +320,14 @@ def validation_epoch_end(self, outputs): avg_loss = torch.stack([x[f'val_loss'] for x in outputs]).mean() + total_hit = sum([x[f'hit'] for x in outputs]) # calculate metrics and classification report precision, recall, f1, report = self.classification_report.compute() + total_data = torch.sum(self.classification_report.num_examples_per_class) + accuracy = total_hit / total_data.item() logging.info(f'{prefix}_report: {report}') - + logging.info(f'{total_hit} correct out of {total_data}, accuracy: {accuracy*100:.2f}') self.log(f'{prefix}_loss', avg_loss, prog_bar=True) self.log(f'{prefix}_precision', precision) self.log(f'{prefix}_f1', f1) diff --git a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb index 237ea501052c..431a817182e1 100644 --- a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb +++ b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb @@ -460,6 +460,7 @@ "# checks if we have GPU available and uses it\n", "cuda = 1 if torch.cuda.is_available() else 0\n", "config.trainer.gpus = cuda\n", + "config.trainer.max_epochs = 6\n", "\n", "# for PyTorch Native AMP set precision=16\n", "config.trainer.precision = 16 if torch.cuda.is_available() else 32\n", @@ -604,7 +605,7 @@ "\"The NTSB said investigators are set to conduct sight distance tests on July 18 , using trains similar to those involved in the accident .\",\n", "\"Pretax profit totaled EUR 9.0 mn , down from EUR 36.3 mn in 2007 .\",\n", "\"However , the proportion of the paid standing orders grew in 2009 .\"]\n", - "results = model_ptune.classifytext(queries=query_examples, batch_size=1, prompt='Sentiment')\n", + "results = model_ptune.cuda().classifytext(queries=query_examples, batch_size=1, prompt='Sentiment')\n", "print('The prediction results of some sample queries with the trained model:')\n", "for query, result in zip(query_examples, results):\n", " print(f'Query : {query}')\n", From 0f8444ff3bfbc92f79d37e946f683f2b82631e3c Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Fri, 21 Jan 2022 13:26:59 -0800 Subject: [PATCH 14/22] style fix Signed-off-by: Yi Dong --- .../ptune_text_classification_dataset.py | 1 - .../text_classification/ptune_text_classification_model.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py index 8db901ec51fe..9005e4e97920 100644 --- a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py +++ b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py @@ -63,7 +63,6 @@ def __init__(self, input_file: str, queries: List[str] = None, prompt: str = 'Se json_data.append({'sentence': line + f' {prompt} ', 'label': ''}) self.data = json_data - def __len__(self): return len(self.data) diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index 234ea1e9ef56..bf525b038a8c 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -99,7 +99,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # set allowed vocab set self.vocab = self.tokenizer.tokenizer.get_vocab() - #make sure classes are part of the vocab + # make sure classes are part of the vocab for k in cfg.dataset.classes: if token_wrapper(k) not in self.vocab: logging.error(f'class {k} is not part of the vocabulary. Please add it to your vocab') @@ -249,7 +249,9 @@ def forward_eval(self, sentences): encoder_input, new_atten, label_position = self.get_encoder_input(sentences) batch_size, _, seq_len, _ = new_atten.shape - output = self.model.model(None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device)) + output = self.model.model( + None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device) + ) logits = output _, returned_pred = self.get_prediction(batch_size, label_position.to(self.device), logits) From b2ebf812e6a65e89a8b3e3b5dcfee22d9401eb08 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Mon, 24 Jan 2022 06:41:22 -0800 Subject: [PATCH 15/22] fix reassgin Signed-off-by: Yi Dong --- nemo/collections/nlp/modules/common/megatron/language_model.py | 2 +- tutorials/nlp/PTune_Sentiment_Analysis.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index b2f1bd5c55e9..34b4505fd4f3 100644 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -695,7 +695,7 @@ def forward( else: encoder_input = embedding_output else: - encoder_input = encoder_input + pass # encoder. if enc_hidden_states is None: diff --git a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb index 431a817182e1..6f3a44f080ad 100644 --- a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb +++ b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb @@ -629,7 +629,7 @@ " model.tensor_model_parallel_size=1 \\\n", " model.tokenizer.merge_file=MERGE_FILE \\\n", " model.language_model.nemo_file=gpt_344m.nemo \\\n", - " model.dataset.classes=[positive, neutral, negative] \\\n", + " model.dataset.classes=[positive,neutral,negative] \\\n", " model.train_ds.file_path=TRAIN_FILE \\\n", " model.train_ds.batch_size=8 \\\n", " model.validation_ds.file_path=VAL_FILE \\\n", From 8991080edde25dbb9979ebb4069cb6827d5a0e83 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Mon, 24 Jan 2022 06:46:34 -0800 Subject: [PATCH 16/22] log accuracy Signed-off-by: Yi Dong --- .../text_classification/ptune_text_classification_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index bf525b038a8c..138d41c74163 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -331,6 +331,7 @@ def validation_epoch_end(self, outputs): logging.info(f'{prefix}_report: {report}') logging.info(f'{total_hit} correct out of {total_data}, accuracy: {accuracy*100:.2f}') self.log(f'{prefix}_loss', avg_loss, prog_bar=True) + self.log(f'{prefix}_accuracy', accuracy) self.log(f'{prefix}_precision', precision) self.log(f'{prefix}_f1', f1) self.log(f'{prefix}_recall', recall) From 1d76088a4ad997d37ca0db82d5d9376cbbe5deb2 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Mon, 24 Jan 2022 17:31:12 -0800 Subject: [PATCH 17/22] load the best checkpoint Signed-off-by: Yi Dong --- .../ptune_text_classification.py | 47 ++++++++++++++++--- nemo/collections/nlp/parts/nlp_overrides.py | 2 +- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/examples/nlp/text_classification/ptune_text_classification.py b/examples/nlp/text_classification/ptune_text_classification.py index c91554f5539d..ca6bbe3fc0d2 100644 --- a/examples/nlp/text_classification/ptune_text_classification.py +++ b/examples/nlp/text_classification/ptune_text_classification.py @@ -93,7 +93,11 @@ eval_model.set_trainer(eval_trainer) eval_trainer.test(model=eval_model, verbose=False) """ +import os +import pathlib + import pytorch_lightning as pl +import torch from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import ( @@ -121,27 +125,56 @@ def main(cfg: DictConfig) -> None: logging.info('Training finished!') logging.info("===========================================================================================") - if cfg.model.nemo_path: - # '.nemo' file contains the last checkpoint and the params to initialize the model - model.save_to(cfg.model.nemo_path) - logging.info(f'Model is saved into `.nemo` file: {cfg.model.nemo_path}') - # We evaluate the trained model on the test set if test_ds is set in the config file if cfg.model.test_ds.file_path: logging.info("===========================================================================================") logging.info("Starting the testing of the trained model on test set...") - trainer = pl.Trainer(**cfg.trainer) trainer.test(model=model, ckpt_path=None, verbose=False) logging.info("Testing finished!") logging.info("===========================================================================================") + # extract the path of the best checkpoint from the training, you may update it to any checkpoint + checkpoint_path = trainer.checkpoint_callback.best_model_path + tensor_parallel_size = cfg.model.tensor_model_parallel_size + pathobj = pathlib.Path(checkpoint_path) + checkpoint_folder = str(pathobj.parent) + checkpoint_name = str(pathobj.name) + + rank = trainer.accelerator.training_type_plugin.local_rank + if tensor_parallel_size > 1: + # inject model parallel rank + checkpoint_path = os.path.join(checkpoint_folder, f'mp_rank_{rank:02d}', checkpoint_name) + else: + checkpoint_path = os.path.join(checkpoint_folder, checkpoint_name) + + # Load the checkpoint + best_eval_model = PTuneTextClassificationModel.load_from_checkpoint( + checkpoint_path=checkpoint_path, strict=False, trainer=trainer + ) + logging.info(f'best checkpoint path: {checkpoint_path}') + logging.info("Running Test with best EVAL checkpoint!") + # setup the test dataset + best_eval_model.setup_test_data(test_data_config=cfg.model.test_ds) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + trainer.test(model=best_eval_model, ckpt_path=None, verbose=False) + logging.info("Beset EVAL Testing finished!") + logging.info("===========================================================================================") + + if cfg.model.nemo_path: + # '.nemo' file contains the last checkpoint and the params to initialize the model + best_eval_model.save_to(cfg.model.nemo_path) + logging.info(f'Model is saved into `.nemo` file: {cfg.model.nemo_path}') + # perform inference on a list of queries. if "infer_samples" in cfg.model and cfg.model.infer_samples: logging.info("===========================================================================================") logging.info("Starting the inference on some sample queries...") # max_seq_length=512 is the maximum length BERT supports. - results = model.cuda().classifytext(queries=cfg.model.infer_samples, batch_size=1, prompt='Sentiment') + results = best_eval_model.cuda().classifytext( + queries=cfg.model.infer_samples, batch_size=1, prompt='Sentiment' + ) logging.info('The prediction results of some sample queries with the trained model:') for query, result in zip(cfg.model.infer_samples, results): logging.info(f'Query : {query}') diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 0d61f9ed62b7..59c595decfa0 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -129,7 +129,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None: # we initialize megatron-lm model parallel and data parallel groups # after initializing DDP with PTL. if app_state.model_parallel_size is not None: - if torch.distributed.is_initialized(): + if torch.distributed.is_initialized() and app_state.data_parallel_group is None: parallel_state.initialize_model_parallel(app_state.model_parallel_size) app_state.model_parallel_group = parallel_state.get_tensor_model_parallel_group() app_state.data_parallel_group = parallel_state.get_data_parallel_group() From d4e2cddfee5b442125bbf7eb8998cc501259ede3 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Tue, 25 Jan 2022 16:07:04 -0800 Subject: [PATCH 18/22] address PR comments Signed-off-by: Yi Dong --- .../ptune_text_classification_config.yaml | 6 +++--- .../ptune_text_classification_dataset.py | 2 +- .../ptune_text_classification_model.py | 19 ++++++++++++++----- tutorials/nlp/PTune_Sentiment_Analysis.ipynb | 2 +- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml index ed7e376dfb46..13ce98096526 100644 --- a/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml +++ b/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,9 +32,9 @@ trainer: logger: False # Provided by exp_manager model: - tensor_model_parallel_size: 2 # tensor model parallel size used in the LM model + tensor_model_parallel_size: 1 # tensor model parallel size used in the LM model seed: 1234 - nemo_path: ptune_text_classification_model.nemo # filename to save the model and associated artifacts to .nemo file + nemo_path: null # filename to save the model and associated artifacts to .nemo file use_lm_finetune: False # whether fine tune the language model pseudo_token: '[PROMPT]' # pseudo prompt tokens diff --git a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py index 9005e4e97920..827b36ea4d36 100644 --- a/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py +++ b/nemo/collections/nlp/data/text_classification/ptune_text_classification_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Google AI Language Team Authors and +# Copyright 2022 The Google AI Language Team Authors and # The HuggingFace Inc. team. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index 138d41c74163..bb72f546aa9c 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -84,7 +84,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.model = MegatronGPTModel.restore_from( self.register_artifact('language_model.nemo_file', cfg.language_model.get('nemo_file', None)), trainer=trainer, - ).half() + ) for param in self.model.parameters(): param.requires_grad = cfg.use_lm_finetune @@ -262,10 +262,19 @@ def forward(self, sentences, labels): encoder_input, new_atten, label_position = self.get_encoder_input(sentences) batch_size, _, seq_len, _ = new_atten.shape labels_input, label_ids = self.get_label_input(labels, label_position, seq_len) + # workaround to do auto-cast + # get the LM dtype + dtype = self.model.model.language_model.encoder.layers[0].dtype - output = self.model.model( - None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input - ) + if dtype == torch.float32: + output = self.model.model( + None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input + ) + else: + with torch.autocast(device_type="cuda", dtype=dtype): + output = self.model.model( + None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input + ) loss, logits = output floss = (loss[(labels_input != SMALL_LOGITS)]).mean() diff --git a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb index 6f3a44f080ad..f55dded9d7f0 100644 --- a/tutorials/nlp/PTune_Sentiment_Analysis.ipynb +++ b/tutorials/nlp/PTune_Sentiment_Analysis.ipynb @@ -56,7 +56,7 @@ "source": [ "In this tutorial, we are going to describe how to use [P-Tuning method](https://arxiv.org/pdf/2103.10385.pdf) to find good prompts for large GPT models, so it can solve downstream NLP tasks with good performance. P-Tuning leverages few continuous free parameters to serve as prompts fed as the input to the pre-trained language models. Freezing the large language model weights, P-Tuning model can be trained efficiently while delivering stats of art performance. \n", "\n", - "Large Language Model can be trained with [Megatron-LM project](https://github.com/NVIDIA/Megatron-LM), up to multi-billion parameters. In this notebook, we will use the pre-trained 344M GPT model released from NGC.\n", + "Large Language Model can be trained with [NeMo Megatron](https://github.com/NVIDIA/NeMo/tree/main/examples/nlp/language_modeling), up to multi-billion parameters. In this notebook, we will use the pre-trained 344M GPT model released from NGC.\n", "\n", "# Task Description\n", "In this notebook, we are going to use P-Tuning method for **Sentiment Analysis** task, also known as opinion mining or emotion AI. It is a sub-field of NLP that tries to identify and extract opinions within a given text across blogs, reviews, social media, forums, news etc.\n", From b3db9078831352a24220ae9f88a8bbd62c9eab01 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 26 Jan 2022 05:55:52 -0800 Subject: [PATCH 19/22] added ci test Signed-off-by: Yi Dong --- Jenkinsfile | 27 +++++++++++++++++++ .../ptune_text_classification_model.py | 16 ++++++++--- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 752c134b0415..f94b8939ef4f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1863,6 +1863,33 @@ pipeline { sh "rm -rf examples/nlp/language_modeling/bert_pretrain_results" } } + stage('L2: Megatron P-Tuning GPT LM') { + when { + anyOf { + branch 'main' + changeRequest target: 'main' + } + } + failFast true + steps { + sh "python examples/nlp/text_classification/ptune_text_classification.py \ + trainer.gpus=2 \ + trainer.max_epochs=1 \ + +trainer.limit_val_batches=10 \ + +trainer.limit_train_batches=10 \ + +trainer.limit_test_batches=10 \ + exp_manager.exp_dir=examples/nlp/language_modeling/ptune_results \ + model.tokenizer.vocab_file=/home/TestData/nlp/ptune/gpt2-vocab.json \ + model.tensor_model_parallel_size=2 \ + model.tokenizer.merge_file=/home/TestData/nlp/ptune/gpt2-merges.txt \ + model.language_model.nemo_file=/home/TestData/nlp/ptune/small_gpt.nemo \ + model.dataset.classes=[positive,neutral,negative] \ + model.train_ds.file_path=/home/TestData/nlp/ptune/data/train_0.txt \ + model.validation_ds.file_path=/home/TestData/nlp/ptune/data/validation_0.txt \ + model.test_ds.file_path=/home/TestData/nlp/ptune/data/test_0.txt " + sh "rm -rf examples/nlp/language_modeling/ptune_results" + } + } stage('L2: Megatron GPT Pretraining and Resume Training') { when { anyOf { diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index bb72f546aa9c..61b94fde59d8 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -249,9 +249,19 @@ def forward_eval(self, sentences): encoder_input, new_atten, label_position = self.get_encoder_input(sentences) batch_size, _, seq_len, _ = new_atten.shape - output = self.model.model( - None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device) - ) + # workaround to do auto-cast + # get the LM dtype + dtype = self.model.model.language_model.encoder.layers[0].dtype + + if dtype == torch.float32: + output = self.model.model( + None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device) + ) + else: + with torch.autocast(device_type="cuda", dtype=dtype): + output = self.model.model( + None, None, encoder_input=encoder_input.to(self.device), attention_mask=new_atten.to(self.device) + ) logits = output _, returned_pred = self.get_prediction(batch_size, label_position.to(self.device), logits) From f70542cc704828cc9019edb35debe9b869a85908 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 26 Jan 2022 12:27:10 -0800 Subject: [PATCH 20/22] fixed max_step calculation error due to wrong number of workers Signed-off-by: Yi Dong --- nemo/core/classes/modelPT.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 5b9840609541..049867f04ea5 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -458,6 +458,8 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N logging.warning(f"Trainer wasn't specified in model constructor. Make sure that you really wanted it.") if 'sched' in optim_config and self._trainer is not None: + from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin + if not isinstance(self._trainer.accumulate_grad_batches, int): raise ValueError("We do not currently support gradient acculumation that is not an integer.") if self._trainer.max_steps is None or self.trainer.max_steps < 0: @@ -471,6 +473,9 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N optim_config['sched']['t_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes elif self._trainer.accelerator == "ddp": optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + elif isinstance(self._trainer.accelerator.training_type_plugin, NLPDDPPlugin): + app = AppState() + optim_config['sched']['t_num_workers'] = app.data_parallel_size else: logging.warning( f"The lightning trainer received accelerator: {self._trainer.accelerator}. We " From 2856e849129b59932facbd424461dbe459232f6b Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 26 Jan 2022 14:11:43 -0800 Subject: [PATCH 21/22] add import guard for nlp plugin Signed-off-by: Yi Dong --- nemo/collections/nlp/parts/nlp_overrides.py | 18 +++++++++--------- nemo/core/classes/modelPT.py | 10 ++++++++-- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 59c595decfa0..24aac079e3f3 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -12,6 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +try: + from apex.transformer import parallel_state + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + + HAVE_APEX = False + import os import shutil import tempfile @@ -35,15 +44,6 @@ from nemo.core.optim import MasterOptimizerWrapper from nemo.utils import AppState, logging -try: - from apex.transformer import parallel_state - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - - HAVE_APEX = False - class NLPDDPPlugin(DDPPlugin): """ DDP plugin for Pytorch Lightning. Needed to customize DDP for model parallel models. diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 049867f04ea5..8b735567072e 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -35,6 +35,13 @@ from nemo.utils.app_state import AppState from nemo.utils.get_rank import is_global_rank_zero +try: + from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin + + HAVE_NLPPLUGIN = True +except (ImportError, ModuleNotFoundError): + HAVE_NLPPLUGIN = False + __all__ = ['ModelPT'] @@ -458,7 +465,6 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N logging.warning(f"Trainer wasn't specified in model constructor. Make sure that you really wanted it.") if 'sched' in optim_config and self._trainer is not None: - from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin if not isinstance(self._trainer.accumulate_grad_batches, int): raise ValueError("We do not currently support gradient acculumation that is not an integer.") @@ -473,7 +479,7 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N optim_config['sched']['t_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes elif self._trainer.accelerator == "ddp": optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes - elif isinstance(self._trainer.accelerator.training_type_plugin, NLPDDPPlugin): + elif HAVE_NLPPLUGIN and isinstance(self._trainer.accelerator.training_type_plugin, NLPDDPPlugin): app = AppState() optim_config['sched']['t_num_workers'] = app.data_parallel_size else: From f54da3c791be874e24ace89b0588468470a5b4d6 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Thu, 27 Jan 2022 07:03:38 -0800 Subject: [PATCH 22/22] fixed the metric report issue when using tensor parallel Signed-off-by: Yi Dong --- .../ptune_text_classification_model.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py index 61b94fde59d8..900c48020b2a 100644 --- a/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/ptune_text_classification_model.py @@ -35,6 +35,7 @@ from nemo.core.classes.exportable import Exportable from nemo.core.neural_types import LossType, NeuralType, PredictionsType, StringLabel, StringType from nemo.utils import logging +from nemo.utils.app_state import AppState __all__ = ['PTuneTextClassificationModel'] @@ -107,18 +108,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # map from id to label self.allowed_vocab = {} - label_ids = {} + self.label_ids = {} self.id_to_label = {} for i, k in enumerate(cfg.dataset.classes): self.allowed_vocab[self.vocab[token_wrapper(k)]] = i - label_ids[k] = i + self.label_ids[k] = i self.id_to_label[i] = k - # setup to track metrics - self.classification_report = ClassificationReport( - num_classes=len(self.classes), label_ids=label_ids, mode='micro', dist_sync_on_step=True - ) - self.template = cfg.prompt_encoder.template self.prompt_encoder = PromptEncoder( @@ -140,6 +136,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): ) self.spell_length = sum(self.template) + def setup(self, stage): + # setup to track metrics, need to put here + # as data_parallel_group is initialized when calling `fit, or test function` + app = AppState() + self.classification_report = ClassificationReport( + num_classes=len(self.classes), + label_ids=self.label_ids, + mode='micro', + dist_sync_on_step=True, + process_group=app.data_parallel_group, + ) + def embed_input(self, queries): bz = queries.shape[0] queries_for_embedding = queries.clone()