LM adapted T5 dataset (#3654)
* LM adapted T5 dataset

Signed-off-by: MaximumEntropy <[email protected]>

* Style fixes

Signed-off-by: MaximumEntropy <[email protected]>

* File renaming

Signed-off-by: MaximumEntropy <[email protected]>

* Change assert to raising ValueError

Signed-off-by: MaximumEntropy <[email protected]>

* Style fixes

Signed-off-by: MaximumEntropy <[email protected]>

* Printing changes

Signed-off-by: MaximumEntropy <[email protected]>

* Style fixes

Signed-off-by: MaximumEntropy <[email protected]>

* Comment out ICT dataset

Signed-off-by: MaximumEntropy <[email protected]>

Co-authored-by: Micha Livne <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
3 people authored Feb 17, 2022
1 parent a8f29af commit 49183b7
Showing 8 changed files with 184 additions and 64 deletions.
@@ -108,6 +108,7 @@ model:
num_workers: 0
dataloader_type: single # cyclic
masked_lm_prob: 0.15
dataset_type: 't5'
short_seq_prob: 0.0
max_ngram_size: 10
mean_ngram_size: null
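
The new dataset_type key above selects between the existing span-corruption T5 dataset ('t5') and the LM-adapted prefix-LM dataset ('t5_prefix_lm') added in this commit. Below is a minimal sketch (not part of the diff) of how the flag is read and validated downstream, assuming OmegaConf-style config access as in the model changes later in this commit:

from omegaconf import OmegaConf

cfg = OmegaConf.create({'data': {'dataset_type': 't5_prefix_lm'}})  # illustrative config fragment

# A missing key falls back to the original span-corruption dataset.
dataset_type = cfg.data.get('dataset_type', 't5')
if dataset_type not in ['t5', 't5_prefix_lm']:
    raise ValueError(f"dataset_type must be either 't5' or 't5_prefix_lm'. found {dataset_type}")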
4 changes: 2 additions & 2 deletions examples/nlp/language_modeling/megatron_t5_pretraining.py
@@ -15,6 +15,7 @@

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
@@ -59,7 +60,7 @@ def main(cfg) -> None:
if cfg.get('cluster_type', None) == 'BCP':
plugins.append(TorchElasticEnvironment())

trainer = Trainer(plugins=plugins, **cfg.trainer)
trainer = Trainer(plugins=plugins, **cfg.trainer, callbacks=[ModelSummary(max_depth=3)])

exp_manager(trainer, cfg.exp_manager)

@@ -78,7 +79,6 @@ def main(cfg) -> None:
cfg.model.precision = cfg.trainer.precision

model = MegatronT5Model(cfg.model, trainer)

trainer.fit(model)


nemo/collections/nlp/data/language_modeling/megatron/base_dataset_utils.py (new file, 75 additions)
@@ -0,0 +1,75 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math


def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):

# The data prefix should be in the format of:
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
assert len(data_prefix) % 2 == 0
num_datasets = len(data_prefix) // 2
weights = [0] * num_datasets
prefixes = [0] * num_datasets
for i in range(num_datasets):
weights[i] = float(data_prefix[2 * i])
prefixes[i] = (data_prefix[2 * i + 1]).strip()
# Normalize weights
weight_sum = 0.0
for weight in weights:
weight_sum += weight
assert weight_sum > 0.0
weights = [weight / weight_sum for weight in weights]

# Add 0.5% (the 1.005 factor) so in case the blending dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
# TODO: check data leakage between train/val/test?
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
[int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
)

return prefixes, weights, datasets_train_valid_test_num_samples
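
A hypothetical usage of the helper above (the prefixes and counts are illustrative, not from the commit); the import path matches the new module introduced here:

from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
    get_datasets_weights_and_num_samples,
)

prefixes, weights, num_samples = get_datasets_weights_and_num_samples(
    data_prefix=['0.3', '/data/wiki_text_document', '0.7', '/data/books_text_document'],
    train_valid_test_num_samples=[1000, 100, 100],
)
# prefixes    -> ['/data/wiki_text_document', '/data/books_text_document']
# weights     -> [0.3, 0.7]
# num_samples -> [[302, 31, 31], [704, 71, 71]]  (ceil(count * weight * 1.005) per split)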


def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""

splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
if len(splits) != 3:
raise ValueError(f"Invalid splits string: {splits_string}. Expected 3 comma separated values.")
while len(splits) < 3:
splits.append(0.0)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] + int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index
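
A quick worked example of the split helper above (not part of the diff), assuming the same module path:

from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
    get_train_valid_test_split_,
)

# '98,1,1' is normalized to [0.98, 0.01, 0.01] and turned into cumulative document boundaries.
print(get_train_valid_test_split_('98,1,1', 10000))  # [0, 9800, 9900, 10000]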
nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
@@ -39,8 +39,13 @@
import numpy as np
import torch

from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
get_datasets_weights_and_num_samples,
get_train_valid_test_split_,
)
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset
from nemo.collections.nlp.data.language_modeling.megatron.lm_adapted_t5_dataset import T5LMAdaptedDataset
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero

@@ -57,38 +62,9 @@
DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPE_T5_LM = 't5_prefix_lm'

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]


def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):

# The data prefix should be in the format of:
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
assert len(data_prefix) % 2 == 0
num_datasets = len(data_prefix) // 2
weights = [0] * num_datasets
prefixes = [0] * num_datasets
for i in range(num_datasets):
weights[i] = float(data_prefix[2 * i])
prefixes[i] = (data_prefix[2 * i + 1]).strip()
# Normalize weights
weight_sum = 0.0
for weight in weights:
weight_sum += weight
assert weight_sum > 0.0
weights = [weight / weight_sum for weight in weights]

# Add 0.5% (the 1.005 factor) so in case the bleding dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
[int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
)

return prefixes, weights, datasets_train_valid_test_num_samples
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_T5_LM]


def compile_helper():
@@ -634,6 +610,8 @@ def build_dataset(index, name):
)

if dataset_type == DSET_TYPE_ICT:
raise NotImplementedError("ICT dataset is not implemented yet.")
'''
dataset = ICTDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
@@ -642,8 +620,10 @@ def build_dataset(index, name):
binary_head=binary_head,
**kwargs,
)
'''
elif dataset_type == DSET_TYPE_T5:
assert tokenizer is not None, "Tokenizer is required for T5 dataset"
logging.info("Instatiating T5 Dataset ...")
dataset = T5Dataset(
cfg=cfg,
trainer=trainer,
@@ -661,6 +641,7 @@ def build_dataset(index, name):
**kwargs,
)
elif dataset_type == DSET_TYPE_BERT:
logging.info("Instatiating BERT Dataset ...")
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
@@ -669,6 +650,18 @@ def build_dataset(index, name):
tokenizer=tokenizer,
**kwargs,
)
elif dataset_type == DSET_TYPE_T5_LM:
documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)
logging.info("Instatiating T5 Prefix-LM Dataset ...")
dataset = T5LMAdaptedDataset(
cfg=cfg,
trainer=trainer,
tokenizer=tokenizer,
documents=documents,
indexed_dataset=indexed_dataset,
num_samples=int(train_valid_test_num_samples[index]),
**kwargs,
)
else:
raise NotImplementedError("Dataset type not fully implemented.")

@@ -702,33 +695,6 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""

splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.0)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] + int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index


def get_samples_mapping(
indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head
):
@@ -20,11 +20,11 @@
import numpy as np
import torch

from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
get_datasets_weights_and_num_samples,
get_train_valid_test_split_,
)
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset
from nemo.collections.nlp.data.language_modeling.megatron.megatron_dataset import MegatronDataset
from nemo.utils import logging
nemo/collections/nlp/data/language_modeling/megatron/lm_adapted_t5_dataset.py (new file, 73 additions)
@@ -0,0 +1,73 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import GPTDataset


class T5LMAdaptedDataset(GPTDataset):
"""
Dataset for unlearning span corruption (https://arxiv.org/abs/2104.08691) in T5 models.
Corresponds to the prefix-LM objective in the T5 paper (Table 3 in https://arxiv.org/abs/1910.10683).
"""

def __init__(
self, cfg, trainer, tokenizer, name, data_prefix, documents, indexed_dataset, num_samples, seed, **kwargs
):
self.seq_length_encoder = cfg.data.seq_length
self.seq_length_decoder = cfg.data.seq_length_dec
self.tokenizer = tokenizer
super().__init__(
cfg,
trainer,
name,
data_prefix,
documents,
indexed_dataset,
num_samples,
self.seq_length_encoder + self.seq_length_decoder,
seed,
)

def __getitem__(self, idx):
text = super().__getitem__(idx)
text = text['text']

# Split text sequence into encoder and decoder inputs
tokens_enc = text[: self.seq_length_encoder]

# NOTE: Add bos only and not eos because the model will always generate till max seq length.
tokens_dec = np.concatenate(([self.tokenizer.bos_id], text[self.seq_length_encoder :]))

# Shift sequences for teacher forcing
tokens_dec_in = tokens_dec[:-1]
labels = tokens_dec[1:]

# Create attention masks
enc_mask = (tokens_enc != self.tokenizer.pad_id).astype(np.int64)
dec_mask = (tokens_dec_in != self.tokenizer.pad_id).astype(np.int64)

loss_mask = dec_mask

train_sample = {
'text_enc': tokens_enc,
'text_dec': tokens_dec_in,
'labels': labels,
'loss_mask': loss_mask,
'truncated': False,
'enc_mask': enc_mask,
'dec_mask': dec_mask,
}
return train_sample
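
To make the prefix-LM construction above concrete, the following standalone sketch (not part of the diff) replays the slicing and teacher-forcing shift from __getitem__ on a toy token array; the sequence lengths, bos_id=1 and pad_id=0 are illustrative assumptions, not NeMo defaults:

import numpy as np

seq_length_encoder, seq_length_decoder = 4, 3
bos_id, pad_id = 1, 0

text = np.array([11, 12, 13, 14, 21, 22, 23])                       # flat GPT-style token sample
tokens_enc = text[:seq_length_encoder]                               # [11 12 13 14] encoder prefix
tokens_dec = np.concatenate(([bos_id], text[seq_length_encoder:]))   # [ 1 21 22 23] bos + continuation
tokens_dec_in = tokens_dec[:-1]                                       # [ 1 21 22]    decoder input
labels = tokens_dec[1:]                                               # [21 22 23]    shifted targets
enc_mask = (tokens_enc != pad_id).astype(np.int64)                    # 1 where the token is not padding
dec_mask = (tokens_dec_in != pad_id).astype(np.int64)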
@@ -266,7 +266,6 @@ def process_batch(self, batch):

keys = ['text_enc', 'text_dec', 'labels', 'loss_mask', 'enc_mask', 'dec_mask']
datatype = torch.int64

data = batch
data_b = tensor_parallel.broadcast_data(keys, data, datatype)

@@ -104,6 +104,12 @@ def build_train_valid_test_datasets(self):
eval_iters * global_batch_size,
test_iters * global_batch_size,
]
# Make sure the user specifies dataset type as either 't5' or 't5_prefix_lm' only.
if self._cfg.data.get('dataset_type', None) is not None:
if self._cfg.data.get('dataset_type') not in ['t5', 't5_prefix_lm']:
raise ValueError(
f"dataset_type must be either 't5' or 't5_prefix_lm'. found {self._cfg.data.get('dataset_type')}"
)
self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets(
cfg=self._cfg,
trainer=self.trainer,
@@ -118,7 +124,7 @@ def build_train_valid_test_datasets(self):
short_seq_prob=self._cfg.data.short_seq_prob,
seed=self._cfg.seed,
skip_warmup=self._cfg.data.skip_warmup,
dataset_type='t5',
dataset_type=self._cfg.data.get('dataset_type', 't5'),
max_ngram_size=self._cfg.get('max_ngram_size', 10),
mean_ngram_size=self._cfg.get('mean_ngram_size', None),
geometric_dist=self._cfg.get('geometric_dist', True),
