LM adapted T5 dataset #3654

Merged Feb 17, 2022 · 23 commits (diff below shown from 17 commits)

Commits
e1bb282
LM adapted T5 dataset
MaximumEntropy Feb 11, 2022
54047bd
Style fixes
MaximumEntropy Feb 11, 2022
b255e58
Merge branch 'main' into t5_lm_adaptation
MaximumEntropy Feb 11, 2022
99fabbb
Merge branch 'main' into t5_lm_adaptation
michalivne Feb 14, 2022
d19d09b
File renaming
MaximumEntropy Feb 14, 2022
8b75d01
change assert to raising valueerror
MaximumEntropy Feb 14, 2022
377342c
Style fixes
MaximumEntropy Feb 14, 2022
a826ac7
Merge branch 'main' into t5_lm_adaptation
ericharper Feb 14, 2022
3ab1a71
Printing changes
MaximumEntropy Feb 14, 2022
803b3f7
Merge branch 't5_lm_adaptation' of github.com:NVIDIA/NeMo into t5_lm_…
MaximumEntropy Feb 14, 2022
0f1579d
Merge branch 'main' into t5_lm_adaptation
MaximumEntropy Feb 14, 2022
007b9e2
Style fixes
MaximumEntropy Feb 14, 2022
9d4c3aa
Merge branch 'main' into t5_lm_adaptation
MaximumEntropy Feb 14, 2022
fd8c672
Merge main and fix conflicts
MaximumEntropy Feb 15, 2022
69acc37
Merge branch 'main' into t5_lm_adaptation
MaximumEntropy Feb 15, 2022
7d1626f
Merge branch 'main' into t5_lm_adaptation
MaximumEntropy Feb 16, 2022
5da9724
Merge branch 'main' into t5_lm_adaptation
ericharper Feb 16, 2022
acf4909
Update conflict merges
MaximumEntropy Feb 16, 2022
77ae69f
Merge branch 'main' into t5_lm_adaptation
MaximumEntropy Feb 16, 2022
8339503
Comment out ICT dataset
MaximumEntropy Feb 16, 2022
c40f91d
Merge branch 't5_lm_adaptation' of github.com:NVIDIA/NeMo into t5_lm_…
MaximumEntropy Feb 16, 2022
ed3be60
Merge branch 'main' into t5_lm_adaptation
MaximumEntropy Feb 16, 2022
8d026b1
Merge branch 'main' into t5_lm_adaptation
MaximumEntropy Feb 16, 2022
Files changed
@@ -106,6 +106,7 @@ model:
dataloader_type: single # cyclic
masked_lm_prob: 0.15
short_seq_prob: 0.1
dataset_type: 't5'
Review comment (Collaborator):

maybe add possible types here in a comment like # t5, ...

max_ngram_size: 10
mean_ngram_size: null
geometric_dist: True
4 changes: 2 additions & 2 deletions examples/nlp/language_modeling/megatron_t5_pretraining.py
@@ -15,6 +15,7 @@

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
@@ -43,7 +44,7 @@ def main(cfg) -> None:
if cfg.get('cluster_type', None) == 'BCP':
plugins.append(TorchElasticEnvironment())

trainer = Trainer(plugins=plugins, **cfg.trainer)
trainer = Trainer(plugins=plugins, **cfg.trainer, callbacks=[ModelSummary(max_depth=3)])

exp_manager(trainer, cfg.exp_manager)

@@ -62,7 +63,6 @@ def main(cfg) -> None:
cfg.model.precision = cfg.trainer.precision

model = MegatronT5Model(cfg.model, trainer)

trainer.fit(model)


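For context, the added ModelSummary callback makes the layer/parameter summary printed at the start of fit() descend three levels into nested modules, instead of the shallow default summary. A minimal standalone sketch (not the NeMo script itself):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary

# Print a module/parameter summary down to 3 levels of nesting when fit() starts.
trainer = Trainer(max_epochs=1, callbacks=[ModelSummary(max_depth=3)])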
nemo/collections/nlp/data/language_modeling/megatron/base_dataset_utils.py (new file)
@@ -0,0 +1,75 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math


def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):

# The data prefix should be in the format of:
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
assert len(data_prefix) % 2 == 0
num_datasets = len(data_prefix) // 2
weights = [0] * num_datasets
prefixes = [0] * num_datasets
for i in range(num_datasets):
weights[i] = float(data_prefix[2 * i])
prefixes[i] = (data_prefix[2 * i + 1]).strip()
# Normalize weights
weight_sum = 0.0
for weight in weights:
weight_sum += weight
assert weight_sum > 0.0
weights = [weight / weight_sum for weight in weights]

# Add 0.5% (the 1.005 factor) so in case the blending dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
# TODO: check data leakage between train/val/test?
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
[int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
)

return prefixes, weights, datasets_train_valid_test_num_samples


def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""

splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
if len(splits) != 3:
raise ValueError(f"Invalid splits string: {splits_string}. Expected 3 comma separated values.")
while len(splits) < 3:
splits.append(0.0)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] + int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index
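For reference, a minimal usage sketch of the two helpers above (the data prefixes are hypothetical; the expected values follow from the 1.005 oversampling factor and the rounding in the split helper):

from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
    get_datasets_weights_and_num_samples,
    get_train_valid_test_split_,
)

# Blend two corpora with weights 0.3 and 0.7 (paths are hypothetical).
prefixes, weights, num_samples = get_datasets_weights_and_num_samples(
    ['0.3', '/data/corpus_a', '0.7', '/data/corpus_b'], [1000, 100, 100]
)
# prefixes == ['/data/corpus_a', '/data/corpus_b'], weights ≈ [0.3, 0.7]
# num_samples == [[302, 31, 31], [704, 71, 71]]  (each count is ceil(n * weight * 1.005))

# A 980/10/10 split of a corpus with 10,000 documents.
print(get_train_valid_test_split_("980,10,10", 10000))
# -> [0, 9800, 9900, 10000]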
@@ -39,8 +39,13 @@
import numpy as np
import torch

from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
get_datasets_weights_and_num_samples,
get_train_valid_test_split_,
)
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset
from nemo.collections.nlp.data.language_modeling.megatron.lm_adapted_t5_dataset import T5LMAdaptedDataset
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero

@@ -57,38 +62,9 @@
DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPE_T5_LM = 't5_prefix_lm'

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]


def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):

# The data prefix should be in the format of:
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
assert len(data_prefix) % 2 == 0
num_datasets = len(data_prefix) // 2
weights = [0] * num_datasets
prefixes = [0] * num_datasets
for i in range(num_datasets):
weights[i] = float(data_prefix[2 * i])
prefixes[i] = (data_prefix[2 * i + 1]).strip()
# Normalize weights
weight_sum = 0.0
for weight in weights:
weight_sum += weight
assert weight_sum > 0.0
weights = [weight / weight_sum for weight in weights]

# Add 0.5% (the 1.005 factor) so in case the bleding dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
[int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
)

return prefixes, weights, datasets_train_valid_test_num_samples
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_T5_LM]


def compile_helper():
@@ -634,6 +610,7 @@ def build_dataset(index, name):
)

if dataset_type == DSET_TYPE_ICT:
raise NotImplementedError("ICT dataset is not implemented yet.")
dataset = ICTDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
@@ -644,6 +621,7 @@ def build_dataset(index, name):
)
elif dataset_type == DSET_TYPE_T5:
assert tokenizer is not None, "Tokenizer is required for T5 dataset"
logging.info("Instatiating T5 Dataset ...")
dataset = T5Dataset(
cfg=cfg,
trainer=trainer,
@@ -661,6 +639,7 @@
**kwargs,
)
elif dataset_type == DSET_TYPE_BERT:
logging.info("Instatiating BERT Dataset ...")
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
Expand All @@ -669,6 +648,18 @@ def build_dataset(index, name):
tokenizer=tokenizer,
**kwargs,
)
elif dataset_type == DSET_TYPE_T5_LM:
documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)
logging.info("Instatiating T5 Prefix-LM Dataset ...")
dataset = T5LMAdaptedDataset(
cfg=cfg,
trainer=trainer,
tokenizer=tokenizer,
documents=documents,
indexed_dataset=indexed_dataset,
num_samples=int(train_valid_test_num_samples[index]),
**kwargs,
)
else:
raise NotImplementedError("Dataset type not fully implemented.")

@@ -702,33 +693,6 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""

splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.0)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] + int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index


def get_samples_mapping(
indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head
):
@@ -20,11 +20,11 @@
import numpy as np
import torch

from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
get_datasets_weights_and_num_samples,
get_train_valid_test_split_,
)
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.indexed_dataset import make_dataset as make_indexed_dataset
from nemo.collections.nlp.data.language_modeling.megatron.megatron_dataset import MegatronDataset
from nemo.utils import logging
nemo/collections/nlp/data/language_modeling/megatron/lm_adapted_t5_dataset.py (new file)
@@ -0,0 +1,73 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import GPTDataset


class T5LMAdaptedDataset(GPTDataset):
"""
Dataset for unlearning span corruption (https://arxiv.org/abs/2104.08691) in T5 models.
Corresponds to the prefix-LM objective in the T5 paper (Table 3 in https://arxiv.org/abs/1910.10683).
"""

def __init__(
self, cfg, trainer, tokenizer, name, data_prefix, documents, indexed_dataset, num_samples, seed, **kwargs
):
self.seq_length_encoder = cfg.data.seq_length
self.seq_length_decoder = cfg.data.seq_length_dec
self.tokenizer = tokenizer
super().__init__(
cfg,
trainer,
name,
data_prefix,
documents,
indexed_dataset,
num_samples,
self.seq_length_encoder + self.seq_length_decoder,
seed,
)

def __getitem__(self, idx):
text = super().__getitem__(idx)
text = text['text']

# Split text sequence into encoder and decoder inputs
tokens_enc = text[: self.seq_length_encoder]

# NOTE: Add bos only and not eos because the model will always generate till max seq length.
tokens_dec = np.concatenate(([self.tokenizer.bos_id], text[self.seq_length_encoder :]))

# Shift sequences for teacher forcing
tokens_dec_in = tokens_dec[:-1]
labels = tokens_dec[1:]

# Create attention masks
enc_mask = (tokens_enc != self.tokenizer.pad_id).astype(np.int64)
dec_mask = (tokens_dec_in != self.tokenizer.pad_id).astype(np.int64)

loss_mask = dec_mask

train_sample = {
'text_enc': tokens_enc,
'text_dec': tokens_dec_in,
'labels': labels,
'loss_mask': loss_mask,
'truncated': False,
'enc_mask': enc_mask,
'dec_mask': dec_mask,
}
return train_sample
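To make the split-and-shift logic in __getitem__ concrete, here is a standalone sketch on a toy token sequence (the pad/bos ids and sequence lengths are assumed values, not the NeMo tokenizer's):

import numpy as np

bos_id, pad_id = 1, 0                                # assumed toy ids
seq_length_encoder, seq_length_decoder = 4, 3        # assumed toy lengths

# A GPT-style sample of length seq_length_encoder + seq_length_decoder.
text = np.array([11, 12, 13, 14, 15, 16, 17], dtype=np.int64)

tokens_enc = text[:seq_length_encoder]                               # [11 12 13 14]
tokens_dec = np.concatenate(([bos_id], text[seq_length_encoder:]))   # [ 1 15 16 17]

tokens_dec_in = tokens_dec[:-1]   # decoder input: [ 1 15 16]
labels = tokens_dec[1:]           # targets:       [15 16 17]

enc_mask = (tokens_enc != pad_id).astype(np.int64)
dec_mask = (tokens_dec_in != pad_id).astype(np.int64)
# The encoder sees the first seq_length_encoder tokens; the decoder learns to
# continue the document from <bos>, which is the prefix-LM objective.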
@@ -205,7 +205,6 @@ def process_batch(self, batch):

keys = ['text_enc', 'text_dec', 'labels', 'loss_mask', 'enc_mask', 'dec_mask']
datatype = torch.int64

data = batch
data_b = tensor_parallel.broadcast_data(keys, data, datatype)

@@ -104,6 +104,12 @@ def build_train_valid_test_datasets(self):
eval_iters * global_batch_size,
test_iters * global_batch_size,
]
# Make sure the user specifies dataset type as either 't5' or 't5_prefix_lm' only.
if self._cfg.data.get('dataset_type', None) is not None:
if self._cfg.data.get('dataset_type') not in ['t5', 't5_prefix_lm']:
raise ValueError(
f"dataset_type must be either 't5' or 't5_prefix_lm'. found {self._cfg.data.get('dataset_type')}"
)
self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets(
cfg=self._cfg,
trainer=self.trainer,
@@ -118,7 +124,7 @@ def build_train_valid_test_datasets(self):
short_seq_prob=self._cfg.data.short_seq_prob,
seed=self._cfg.seed,
skip_warmup=self._cfg.data.skip_warmup,
dataset_type='t5',
dataset_type=self._cfg.data.get('dataset_type', 't5'),
max_ngram_size=self._cfg.get('max_ngram_size', 10),
mean_ngram_size=self._cfg.get('mean_ngram_size', None),
geometric_dist=self._cfg.get('geometric_dist', True),
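A small sketch of how the new config knob behaves (OmegaConf is already a dependency of the pretraining script; the config layout here is simplified, not the full NeMo config):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"data": {"dataset_type": "t5_prefix_lm"}})

# Mirrors the validation above: only 't5' (the default) and 't5_prefix_lm' are accepted.
dataset_type = cfg.data.get("dataset_type", "t5")
if dataset_type not in ("t5", "t5_prefix_lm"):
    raise ValueError(f"dataset_type must be either 't5' or 't5_prefix_lm'. found {dataset_type}")
print(dataset_type)  # -> t5_prefix_lm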