diff --git a/examples/nlp/language_modeling/conf/megatron_t5_config.yaml b/examples/nlp/language_modeling/conf/megatron_t5_config.yaml
index 6800b2864b16..44b0e6d67ff7 100644
--- a/examples/nlp/language_modeling/conf/megatron_t5_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_t5_config.yaml
@@ -108,6 +108,7 @@ model:
     num_workers: 0
     dataloader_type: single # cyclic
     masked_lm_prob: 0.15
+    dataset_type: 't5'
    short_seq_prob: 0.0
     max_ngram_size: 10
     mean_ngram_size: null
diff --git a/examples/nlp/language_modeling/megatron_t5_pretraining.py b/examples/nlp/language_modeling/megatron_t5_pretraining.py
index eeabd0905cbb..a641f75f6501 100644
--- a/examples/nlp/language_modeling/megatron_t5_pretraining.py
+++ b/examples/nlp/language_modeling/megatron_t5_pretraining.py
@@ -15,6 +15,7 @@
 from omegaconf.omegaconf import OmegaConf, open_dict
 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelSummary
 from pytorch_lightning.callbacks.timer import Timer
 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
@@ -59,7 +60,7 @@ def main(cfg) -> None:
     if cfg.get('cluster_type', None) == 'BCP':
         plugins.append(TorchElasticEnvironment())

-    trainer = Trainer(plugins=plugins, **cfg.trainer)
+    trainer = Trainer(plugins=plugins, **cfg.trainer, callbacks=[ModelSummary(max_depth=3)])

     exp_manager(trainer, cfg.exp_manager)
@@ -78,7 +79,6 @@ def main(cfg) -> None:
         cfg.model.precision = cfg.trainer.precision

     model = MegatronT5Model(cfg.model, trainer)
-
     trainer.fit(model)
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/base_dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/base_dataset_utils.py
new file mode 100644
index 000000000000..ef7cd8ae5660
--- /dev/null
+++ b/nemo/collections/nlp/data/language_modeling/megatron/base_dataset_utils.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+
+def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
+
+    # The data prefix should be in the format of:
+    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
+    assert len(data_prefix) % 2 == 0
+    num_datasets = len(data_prefix) // 2
+    weights = [0] * num_datasets
+    prefixes = [0] * num_datasets
+    for i in range(num_datasets):
+        weights[i] = float(data_prefix[2 * i])
+        prefixes[i] = (data_prefix[2 * i + 1]).strip()
+    # Normalize weights
+    weight_sum = 0.0
+    for weight in weights:
+        weight_sum += weight
+    assert weight_sum > 0.0
+    weights = [weight / weight_sum for weight in weights]
+
+    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
+    # not uniformly distribute the number of samples, we still have
+    # samples left to feed to the network.
+    datasets_train_valid_test_num_samples = []
+    for weight in weights:
+        datasets_train_valid_test_num_samples.append(
+            [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
+        )
+
+    return prefixes, weights, datasets_train_valid_test_num_samples
+
+
+def get_train_valid_test_split_(splits_string, size):
+    """ Get dataset splits from comma or '/' separated string list."""
+
+    splits = []
+    if splits_string.find(',') != -1:
+        splits = [float(s) for s in splits_string.split(',')]
+    elif splits_string.find('/') != -1:
+        splits = [float(s) for s in splits_string.split('/')]
+    else:
+        splits = [float(splits_string)]
+    while len(splits) < 3:
+        splits.append(0.0)
+    splits = splits[:3]
+    splits_sum = sum(splits)
+    assert splits_sum > 0.0
+    splits = [split / splits_sum for split in splits]
+    splits_index = [0]
+    for index, split in enumerate(splits):
+        splits_index.append(splits_index[index] + int(round(split * float(size))))
+    diff = splits_index[-1] - size
+    for index in range(1, len(splits_index)):
+        splits_index[index] -= diff
+    assert len(splits_index) == 4
+    assert splits_index[-1] == size
+    return splits_index
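For reference, a small usage sketch of the two helpers above. The data prefixes and sample counts are made-up values, not anything shipped with NeMo; the expected outputs follow directly from the arithmetic in the code.

from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
    get_datasets_weights_and_num_samples,
    get_train_valid_test_split_,
)

# Blend two (hypothetical) preprocessed corpora with weights 0.3 and 0.7.
data_prefix = ['0.3', '/data/wiki_text_document', '0.7', '/data/books_text_document']
train_valid_test_num_samples = [1000, 100, 10]

prefixes, weights, per_dataset = get_datasets_weights_and_num_samples(
    data_prefix, train_valid_test_num_samples
)
# weights     -> [0.3, 0.7]
# per_dataset -> [[302, 31, 4], [704, 71, 8]]   each entry is ceil(total * weight * 1.005)

# A "98,1,1" split over a 1000-document corpus becomes four boundary indices.
print(get_train_valid_test_split_('98,1,1', 1000))  # [0, 980, 990, 1000]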
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
index eec074ef7319..45a7d63bb46e 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
@@ -39,8 +39,13 @@
 import numpy as np
 import torch

+from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
+    get_datasets_weights_and_num_samples,
+    get_train_valid_test_split_,
+)
 from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
 from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset
+from nemo.collections.nlp.data.language_modeling.megatron.lm_adapted_t5_dataset import T5LMAdaptedDataset
 from nemo.utils import logging
 from nemo.utils.get_rank import is_global_rank_zero
@@ -57,38 +62,9 @@
 DSET_TYPE_BERT = 'standard_bert'
 DSET_TYPE_ICT = 'ict'
 DSET_TYPE_T5 = 't5'
+DSET_TYPE_T5_LM = 't5_prefix_lm'

-DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
-
-
-def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
-
-    # The data prefix should be in the format of:
-    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
-    assert len(data_prefix) % 2 == 0
-    num_datasets = len(data_prefix) // 2
-    weights = [0] * num_datasets
-    prefixes = [0] * num_datasets
-    for i in range(num_datasets):
-        weights[i] = float(data_prefix[2 * i])
-        prefixes[i] = (data_prefix[2 * i + 1]).strip()
-    # Normalize weights
-    weight_sum = 0.0
-    for weight in weights:
-        weight_sum += weight
-    assert weight_sum > 0.0
-    weights = [weight / weight_sum for weight in weights]
-
-    # Add 0.5% (the 1.005 factor) so in case the bleding dataset does
-    # not uniformly distribute the number of samples, we still have
-    # samples left to feed to the network.
-    datasets_train_valid_test_num_samples = []
-    for weight in weights:
-        datasets_train_valid_test_num_samples.append(
-            [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
-        )
-
-    return prefixes, weights, datasets_train_valid_test_num_samples
+DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_T5_LM]


 def compile_helper():
@@ -634,6 +610,8 @@ def build_dataset(index, name):
             )

         if dataset_type == DSET_TYPE_ICT:
+            raise NotImplementedError("ICT dataset is not implemented yet.")
+            '''
             dataset = ICTDataset(
                 block_dataset=indexed_dataset,
                 title_dataset=title_dataset,
@@ -642,8 +620,10 @@ def build_dataset(index, name):
                 binary_head=binary_head,
                 **kwargs,
             )
+            '''
         elif dataset_type == DSET_TYPE_T5:
             assert tokenizer is not None, "Tokenizer is required for T5 dataset"
+            logging.info("Instantiating T5 Dataset ...")
             dataset = T5Dataset(
                 cfg=cfg,
                 trainer=trainer,
@@ -661,6 +641,7 @@ def build_dataset(index, name):
                 **kwargs,
             )
         elif dataset_type == DSET_TYPE_BERT:
+            logging.info("Instantiating BERT Dataset ...")
             dataset = BertDataset(
                 indexed_dataset=indexed_dataset,
                 masked_lm_prob=masked_lm_prob,
@@ -669,6 +650,18 @@ def build_dataset(index, name):
                 tokenizer=tokenizer,
                 **kwargs,
             )
+        elif dataset_type == DSET_TYPE_T5_LM:
+            documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)
+            logging.info("Instantiating T5 Prefix-LM Dataset ...")
+            dataset = T5LMAdaptedDataset(
+                cfg=cfg,
+                trainer=trainer,
+                tokenizer=tokenizer,
+                documents=documents,
+                indexed_dataset=indexed_dataset,
+                num_samples=int(train_valid_test_num_samples[index]),
+                **kwargs,
+            )
         else:
             raise NotImplementedError("Dataset type not fully implemented.")
@@ -702,33 +695,6 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
     return indexed_dataset


-def get_train_valid_test_split_(splits_string, size):
-    """ Get dataset splits from comma or '/' separated string list."""
-
-    splits = []
-    if splits_string.find(',') != -1:
-        splits = [float(s) for s in splits_string.split(',')]
-    elif splits_string.find('/') != -1:
-        splits = [float(s) for s in splits_string.split('/')]
-    else:
-        splits = [float(splits_string)]
-    while len(splits) < 3:
-        splits.append(0.0)
-    splits = splits[:3]
-    splits_sum = sum(splits)
-    assert splits_sum > 0.0
-    splits = [split / splits_sum for split in splits]
-    splits_index = [0]
-    for index, split in enumerate(splits):
-        splits_index.append(splits_index[index] + int(round(split * float(size))))
-    diff = splits_index[-1] - size
-    for index in range(1, len(splits_index)):
-        splits_index[index] -= diff
-    assert len(splits_index) == 4
-    assert splits_index[-1] == size
-    return splits_index
-
-
 def get_samples_mapping(
     indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head
 ):
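For reference, a rough sketch of what the new DSET_TYPE_T5_LM branch does with the split boundaries before handing off to T5LMAdaptedDataset. The numbers reuse the example above and are purely illustrative.

import numpy as np

splits = [0, 980, 990, 1000]              # output of get_train_valid_test_split_
train_valid_test_num_samples = [704, 71, 8]

index = 0                                  # 0 = train, 1 = validation, 2 = test
documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)
# documents -> array([0, 1, ..., 979]); T5LMAdaptedDataset then draws
# num_samples=704 prefix-LM samples from just these documents for the train split.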
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
index c2c8e0ebaaa6..1e207e6aa550 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
@@ -20,11 +20,11 @@
 import numpy as np
 import torch

-from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
-from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
+from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
     get_datasets_weights_and_num_samples,
     get_train_valid_test_split_,
 )
+from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
 from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset
 from nemo.collections.nlp.data.language_modeling.megatron.megatron_dataset import MegatronDataset
 from nemo.utils import logging
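A note on the import reshuffle above: dataset_utils.py now imports T5LMAdaptedDataset, which subclasses GPTDataset, while gpt_dataset.py previously pulled the split helpers from dataset_utils.py. Moving the helpers into the new leaf module base_dataset_utils.py presumably avoids the circular import that would otherwise result. The layering after this change looks roughly like this:

# Dependency layering after the change (arrows point at the module being imported):
#
#   dataset_utils.py --> lm_adapted_t5_dataset.py --> gpt_dataset.py --> base_dataset_utils.py
#   dataset_utils.py ----------------------------------------------- --> base_dataset_utils.py
#
# Both call sites now pull the split helpers from the same leaf module:
from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
    get_datasets_weights_and_num_samples,
    get_train_valid_test_split_,
)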
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/lm_adapted_t5_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/lm_adapted_t5_dataset.py
new file mode 100644
index 000000000000..3f50391a22aa
--- /dev/null
+++ b/nemo/collections/nlp/data/language_modeling/megatron/lm_adapted_t5_dataset.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import GPTDataset
+
+
+class T5LMAdaptedDataset(GPTDataset):
+    """
+    Dataset for unlearning span corruption (https://arxiv.org/abs/2104.08691) in T5 models.
+    Corresponds to the prefix-LM objective in the T5 paper (Table 3 in https://arxiv.org/abs/1910.10683).
+    """
+
+    def __init__(
+        self, cfg, trainer, tokenizer, name, data_prefix, documents, indexed_dataset, num_samples, seed, **kwargs
+    ):
+        self.seq_length_encoder = cfg.data.seq_length
+        self.seq_length_decoder = cfg.data.seq_length_dec
+        self.tokenizer = tokenizer
+        super().__init__(
+            cfg,
+            trainer,
+            name,
+            data_prefix,
+            documents,
+            indexed_dataset,
+            num_samples,
+            self.seq_length_encoder + self.seq_length_decoder,
+            seed,
+        )
+
+    def __getitem__(self, idx):
+        text = super().__getitem__(idx)
+        text = text['text']
+
+        # Split text sequence into encoder and decoder inputs
+        tokens_enc = text[: self.seq_length_encoder]
+
+        # NOTE: Add bos only and not eos because the model will always generate till max seq length.
+        tokens_dec = np.concatenate(([self.tokenizer.bos_id], text[self.seq_length_encoder :]))
+
+        # Shift sequences for teacher forcing
+        tokens_dec_in = tokens_dec[:-1]
+        labels = tokens_dec[1:]
+
+        # Create attention masks
+        enc_mask = (tokens_enc != self.tokenizer.pad_id).astype(np.int64)
+        dec_mask = (tokens_dec_in != self.tokenizer.pad_id).astype(np.int64)
+
+        loss_mask = dec_mask
+
+        train_sample = {
+            'text_enc': tokens_enc,
+            'text_dec': tokens_dec_in,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': False,
+            'enc_mask': enc_mask,
+            'dec_mask': dec_mask,
+        }
+        return train_sample
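To make the prefix-LM sample construction above concrete, a toy walk-through of __getitem__. Token IDs, sequence lengths, and special-token IDs are invented here; the real values come from the tokenizer and the data config.

import numpy as np

# Assume seq_length_encoder=4, seq_length_decoder=4, bos_id=1, pad_id=0.
text = np.array([37, 52, 11, 89, 23, 71, 5, 66])    # sequence drawn by GPTDataset

tokens_enc = text[:4]                                # [37, 52, 11, 89]   -> encoder input
tokens_dec = np.concatenate(([1], text[4:]))         # [ 1, 23, 71,  5, 66]

tokens_dec_in = tokens_dec[:-1]                      # [ 1, 23, 71,  5]   -> decoder input
labels = tokens_dec[1:]                              # [23, 71,  5, 66]   -> shifted targets

enc_mask = (tokens_enc != 0).astype(np.int64)        # all ones here; zeros only where padded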
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index c3c4ffa2dba7..a3d5b8ff4974 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -266,7 +266,6 @@ def process_batch(self, batch):
         keys = ['text_enc', 'text_dec', 'labels', 'loss_mask', 'enc_mask', 'dec_mask']
         datatype = torch.int64
-
         data = batch
         data_b = tensor_parallel.broadcast_data(keys, data, datatype)
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py
index ac625b12b55c..6672d9edd62f 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py
@@ -104,6 +104,12 @@ def build_train_valid_test_datasets(self):
             eval_iters * global_batch_size,
             test_iters * global_batch_size,
         ]
+        # Make sure the user specifies dataset type as either 't5' or 't5_prefix_lm' only.
+        if self._cfg.data.get('dataset_type', None) is not None:
+            if self._cfg.data.get('dataset_type') not in ['t5', 't5_prefix_lm']:
+                raise ValueError(
+                    f"dataset_type must be either 't5' or 't5_prefix_lm'. found {self._cfg.data.get('dataset_type')}"
+                )
         self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets(
             cfg=self._cfg,
             trainer=self.trainer,
@@ -118,7 +124,7 @@ def build_train_valid_test_datasets(self):
             short_seq_prob=self._cfg.data.short_seq_prob,
             seed=self._cfg.seed,
             skip_warmup=self._cfg.data.skip_warmup,
-            dataset_type='t5',
+            dataset_type=self._cfg.data.get('dataset_type', 't5'),
             max_ngram_size=self._cfg.get('max_ngram_size', 10),
             mean_ngram_size=self._cfg.get('mean_ngram_size', None),
             geometric_dist=self._cfg.get('geometric_dist', True),
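With these changes, the pretraining objective is selected purely from the data config. A sketch of how an existing config might be switched to the prefix-LM objective; the config path mirrors megatron_t5_config.yaml above, and the same override can typically be passed on the megatron_t5_pretraining.py command line as model.data.dataset_type=t5_prefix_lm.

from omegaconf import OmegaConf

cfg = OmegaConf.load('examples/nlp/language_modeling/conf/megatron_t5_config.yaml')
cfg.model.data.dataset_type = 't5_prefix_lm'   # default is 't5' (span corruption)
# Any other value now fails fast in MegatronT5Model.build_train_valid_test_datasets
# via the ValueError added above.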