* LM adapted T5 dataset
* Style fixes
* File renaming
* Change assert to raising ValueError
* Style fixes
* Printing changes
* Style fixes
* Comment out ICT dataset

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: Micha Livne <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Commit 49183b7 (parent: a8f29af)
Showing 8 changed files with 184 additions and 64 deletions.
nemo/collections/nlp/data/language_modeling/megatron/base_dataset_utils.py (new file, 75 additions, 0 deletions)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math


def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()
    # Normalize weights
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so that if the blending dataset does not
    # distribute the number of samples uniformly, we still have samples
    # left over to feed to the network.
    # TODO: check data leakage between train/val/test?
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
        )

    return prefixes, weights, datasets_train_valid_test_num_samples

def get_train_valid_test_split_(splits_string, size):
    """Get dataset splits from a comma- or '/'-separated string list."""
    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    if len(splits) != 3:
        raise ValueError(f"Invalid splits string: {splits_string}. Expected 3 comma- or '/'-separated values.")
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    # Turn the normalized fractions into cumulative index boundaries over the dataset.
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] + int(round(split * float(size))))
    # Rounding can make the final boundary miss size; shift the boundaries so it lands exactly on size.
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
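
For context, a minimal usage sketch of the two helpers above. The module path is the one added by this commit, but the data prefixes and sample counts below are made up for illustration:

from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
    get_datasets_weights_and_num_samples,
    get_train_valid_test_split_,
)

# Blend two corpora with weights 0.7 / 0.3 (hypothetical prefixes).
data_prefix = ['0.7', '/data/web_text_document', '0.3', '/data/books_text_document']
train_valid_test_num_samples = [1000, 100, 100]

prefixes, weights, per_dataset = get_datasets_weights_and_num_samples(
    data_prefix, train_valid_test_num_samples
)
# weights     -> [0.7, 0.3]
# per_dataset -> [[704, 71, 71], [302, 31, 31]]  (each count padded by the 1.005 factor)

# Split a 100-sample dataset 90/5/5 into cumulative index boundaries.
print(get_train_valid_test_split_('90,5,5', 100))  # [0, 90, 95, 100]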
nemo/collections/nlp/data/language_modeling/megatron/lm_adapted_t5_dataset.py (new file, 73 additions, 0 deletions)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import GPTDataset


class T5LMAdaptedDataset(GPTDataset):
    """
    Dataset for unlearning span corruption (https://arxiv.org/abs/2104.08691) in T5 models.
    Corresponds to the prefix-LM objective in the T5 paper (Table 3 in https://arxiv.org/abs/1910.10683).
    """

    def __init__(
        self, cfg, trainer, tokenizer, name, data_prefix, documents, indexed_dataset, num_samples, seed, **kwargs
    ):
        self.seq_length_encoder = cfg.data.seq_length
        self.seq_length_decoder = cfg.data.seq_length_dec
        self.tokenizer = tokenizer
        # The parent GPTDataset yields one contiguous token sequence covering
        # both the encoder and decoder portions of each sample.
        super().__init__(
            cfg,
            trainer,
            name,
            data_prefix,
            documents,
            indexed_dataset,
            num_samples,
            self.seq_length_encoder + self.seq_length_decoder,
            seed,
        )

    def __getitem__(self, idx):
        text = super().__getitem__(idx)
        text = text['text']

        # Split the text sequence into encoder and decoder inputs
        tokens_enc = text[: self.seq_length_encoder]

        # NOTE: Add BOS only (and no EOS) because the model always generates up to the max sequence length.
        tokens_dec = np.concatenate(([self.tokenizer.bos_id], text[self.seq_length_encoder :]))

        # Shift sequences for teacher forcing
        tokens_dec_in = tokens_dec[:-1]
        labels = tokens_dec[1:]

        # Create attention masks
        enc_mask = (tokens_enc != self.tokenizer.pad_id).astype(np.int64)
        dec_mask = (tokens_dec_in != self.tokenizer.pad_id).astype(np.int64)

        loss_mask = dec_mask

        train_sample = {
            'text_enc': tokens_enc,
            'text_dec': tokens_dec_in,
            'labels': labels,
            'loss_mask': loss_mask,
            'truncated': False,
            'enc_mask': enc_mask,
            'dec_mask': dec_mask,
        }
        return train_sample
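
To make the split-and-shift concrete, here is a self-contained NumPy sketch of what __getitem__ does to a single sample. The token ids, bos_id, pad_id, and sequence lengths below are made-up stand-ins for values the tokenizer and config would normally provide:

import numpy as np

bos_id, pad_id = 1, 0  # hypothetical special-token ids
seq_length_encoder = 4

# One contiguous token stream covering the encoder and decoder portions
# (the length here is purely illustrative).
text = np.array([11, 12, 13, 14, 21, 22, 23])

tokens_enc = text[:seq_length_encoder]                              # [11 12 13 14]
tokens_dec = np.concatenate(([bos_id], text[seq_length_encoder:]))  # [ 1 21 22 23]

# Teacher forcing: the decoder input is the target shifted right by one.
tokens_dec_in = tokens_dec[:-1]  # decoder input:   [ 1 21 22]
labels = tokens_dec[1:]          # decoder targets: [21 22 23]

# Attention masks are 1 wherever the token is not padding.
enc_mask = (tokens_enc != pad_id).astype(np.int64)     # [1 1 1 1]
dec_mask = (tokens_dec_in != pad_id).astype(np.int64)  # [1 1 1]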