Fix gpt trainer test (#6915)
* Add trainer.test() for GPT

Signed-off-by: hsiehjackson <[email protected]>

* Remove unused part

Signed-off-by: hsiehjackson <[email protected]>

* Add trainer.test() for GPT

Signed-off-by: hsiehjackson <[email protected]>

* Remove unused part

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix training part

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix config

Signed-off-by: hsiehjackson <[email protected]>

* Fix references and add CI

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix config error

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix dataset

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add metadata

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix config

Signed-off-by: hsiehjackson <[email protected]>

* Fix empty batch

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix config

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix config

Signed-off-by: hsiehjackson <[email protected]>

* Fix max seq length

Signed-off-by: hsiehjackson <[email protected]>

* Fix dataset

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix dataset

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add token f1

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add FA (flash attention) in SFT

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add inference config

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix bug

Signed-off-by: hsiehjackson <[email protected]>

* Fix pad

Signed-off-by: hsiehjackson <[email protected]>

* Fix num batch

Signed-off-by: hsiehjackson <[email protected]>

* Add query_key

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove pdb

Signed-off-by: hsiehjackson <[email protected]>

* Fix write json

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix dataset bug and refactor

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add logging for prediction

Signed-off-by: hsiehjackson <[email protected]>

* Fix retrain

Signed-off-by: hsiehjackson <[email protected]>

* Add query_key in config

Signed-off-by: hsiehjackson <[email protected]>

* Fix bug

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix config

Signed-off-by: hsiehjackson <[email protected]>

* Fix bug

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add inference config

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix bug

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix mask

Signed-off-by: hsiehjackson <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix mask

Signed-off-by: hsiehjackson <[email protected]>

* Split PR

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* Undo commit

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* Add query_key to doc_string

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* Address yzhang123's review comments

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix error and follow comments

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove query key

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* Remove logic and query

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove query from model

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* Remove query_key

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* Fix error

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix pdb

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* Add default tokens_to_generate

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* Revert prompt truncate re-prompt

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Skip generation when the metric is loss

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix bug

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* Support GPTSFTChatDataset

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add comment

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

---------

Signed-off-by: hsiehjackson <[email protected]>
Signed-off-by: Cheng-Ping Hsieh <[email protected]>
Signed-off-by: Cheng-Ping Hsieh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
hsiehjackson and pre-commit-ci[bot] authored Aug 10, 2023
1 parent 06f882f commit e0f8b9b
Showing 12 changed files with 362 additions and 257 deletions.
4 changes: 4 additions & 0 deletions Jenkinsfile
@@ -3815,6 +3815,8 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.data.test_ds.global_batch_size=1 \
model.data.test_ds.micro_batch_size=1 \
model.data.test_ds.tokens_to_generate=10 \
model.data.test_ds.write_predictions_to_file=True \
model.data.test_ds.output_file_path_prefix='/home/TestData/nlp/lora_tuning_tp2/out' \
inference.greedy=True \
inference.repetition_penalty=1.0 \
inference.outfile_path='/home/TestData/nlp/lora_tuning_tp2/out.jsonl'"
@@ -3874,6 +3876,8 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.data.test_ds.micro_batch_size=1 \
model.data.test_ds.tokens_to_generate=30 \
model.data.test_ds.max_seq_length=6000 \
model.data.test_ds.write_predictions_to_file=True \
model.data.test_ds.output_file_path_prefix='examples/nlp/language_modeling/out' \
inference.greedy=True \
inference.repetition_penalty=1.0 \
inference.outfile_path='examples/nlp/language_modeling/out.jsonl' && \
@@ -60,6 +60,7 @@ model:
# of each chunk at the specified granularity
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
activations_checkpoint_num_layers: null # not used with 'selective'
activations_checkpoint_layers_per_pipeline: null
answer_only_loss: False # not used right now
gradient_as_bucket_view: False

@@ -113,7 +114,7 @@ model:
truncation_field: ${data.train_ds.truncation_field} # Options: ['context', 'answer']
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: ${data.train_ds.prompt_template}
tokens_to_generate: ???
tokens_to_generate: 32 # how many tokens to generate when evaluating performance with string metrics

metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
@@ -68,6 +68,7 @@ model:
# of each chunk at the specified granularity
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
activations_checkpoint_num_layers: null # not used with 'selective'
activations_checkpoint_layers_per_pipeline: null
answer_only_loss: True
gradient_as_bucket_view: False

@@ -160,7 +161,7 @@ model:
truncation_field: "context" # Options: ['context', 'answer']
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"

tokens_to_generate: 32 # how many tokens to generate when evaluating performance with string metrics
metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
@@ -188,7 +189,7 @@ model:
truncation_field: "context" # Options: ['context', 'answer']
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: ${model.data.train_ds.prompt_template}

tokens_to_generate: 32 # how many tokens to generate when evaluating performance with string metrics
metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
32 changes: 23 additions & 9 deletions examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml
@@ -58,6 +58,8 @@ model:
# of each chunk at the specified granularity
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
activations_checkpoint_num_layers: null # not used with 'selective'
activations_checkpoint_layers_per_pipeline: null
# This feature is valid only when used with pipeline-model-parallelism. More details in megatron_gpt_config.yaml.
answer_only_loss: False # not used right now
gradient_as_bucket_view: False
seq_len_interpolation_factor: null # if not None, seq_len_interpolation_factor will match the base model's value
@@ -109,15 +111,15 @@ model:
names: null # Names of the corresponding datasets used to log metrics.
global_batch_size: ${model.global_batch_size}
micro_batch_size: ${model.micro_batch_size}
shuffle: True
shuffle: False
num_workers: 4
memmap_workers: ${model.data.train_ds.memmap_workers}
pin_memory: True
max_seq_length: 2048
min_seq_length: 1
drop_last: True
context_key: 'input'
label_key: 'output'
drop_last: False
context_key: ${model.data.train_ds.context_key}
label_key: ${model.data.train_ds.label_key}
add_eos: ${model.data.train_ds.add_eos}
add_sep: ${model.data.train_ds.add_sep}
add_bos: ${model.data.train_ds.add_bos}
@@ -127,10 +129,11 @@ model:
truncation_field: "context" # Options: ['context', 'answer']
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
tokens_to_generate: 32 # how many tokens to generate when evaluating performance with string metrics
hf_dataset: False # Whether to load the JSON file with HuggingFace datasets; otherwise, load the JSONL file with JSONLMemMapDataset.

metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'rouge', 'token_f1']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
num_classes: null

@@ -139,15 +142,15 @@ model:
names: null # Names of the corresponding datasets used to log metrics.
global_batch_size: ${model.global_batch_size}
micro_batch_size: ${model.micro_batch_size}
shuffle: True
shuffle: False
num_workers: 4
memmap_workers: ${model.data.train_ds.memmap_workers}
pin_memory: True
max_seq_length: 2048
min_seq_length: 1
drop_last: True
context_key: 'input'
label_key: 'output'
drop_last: False
context_key: ${model.data.train_ds.context_key}
label_key: ${model.data.train_ds.label_key}
add_eos: ${model.data.train_ds.add_eos}
add_sep: ${model.data.train_ds.add_sep}
add_bos: ${model.data.train_ds.add_bos}
@@ -171,3 +174,14 @@ model:
betas:
- 0.9
- 0.98

inference:
  greedy: True # use greedy decoding instead of sampling
  top_k: 0 # The number of highest-probability vocabulary tokens to keep for top-k filtering.
  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
  temperature: 1.0 # sampling temperature
  all_probs: False # whether to return the log probs for every token in the vocabulary
  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
  min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
  compute_logprob: False # compute the log prob of the input text (a special inference mode); default False
  compute_attention_mask: True
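For intuition, the decoding knobs above map onto standard logits filtering. The following is an illustrative sketch only, not NeMo's implementation; the function name and signature are invented for this example.

import torch

def sample_next_token(logits, greedy=True, top_k=0, top_p=0.9, temperature=1.0):
    # Greedy decoding ignores the sampling knobs entirely.
    if greedy:
        return int(torch.argmax(logits))
    logits = logits / temperature
    if top_k > 0:
        # Keep only the top_k highest-probability tokens.
        kth_best = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p < 1.0:
        # Nucleus filtering: keep the smallest set of tokens whose
        # cumulative probability reaches top_p (always keep the best one).
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(probs, dim=-1)
        drop = cum_probs - probs > top_p
        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(0, sorted_idx, sorted_logits)
    return int(torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1))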
35 changes: 3 additions & 32 deletions examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py
@@ -131,6 +131,7 @@ def main(cfg) -> None:
    peft_model_cfg.data.test_ds = cfg.model.data.test_ds
    peft_model_cfg.activations_checkpoint_granularity = None
    peft_model_cfg.activations_checkpoint_method = None
    peft_model_cfg.activations_checkpoint_layers_per_pipeline = None
    if peft_model_cfg.get("use_flash_attention", False):
        peft_model_cfg.use_flash_attention = cfg.model.use_flash_attention
    if cfg.model.get("seq_len_interpolation_factor", None) is not None:
@@ -171,40 +172,10 @@ def main(cfg) -> None:
    )

    model.freeze()
    _test_ds = model._build_dataset(peft_model_cfg.data.test_ds, is_train=False)
    request_dl = DataLoader(
        dataset=_test_ds[0],
        batch_size=peft_model_cfg.data.test_ds.global_batch_size,
        collate_fn=_test_ds[0].collate_fn,
    )
    config = OmegaConf.to_container(cfg.inference, resolve=True)
    model.set_inference_config(config)
    response = trainer.predict(model, request_dl)

    if model.global_rank == 0:
        print("***************************")
        if cfg.inference.outfile_path is not None:
            with open(cfg.inference.outfile_path, "w", encoding="utf-8") as f:
                for batch in response:
                    batch_sentences = [s for s in batch['sentences']]
                    batch_tokens = [s for s in batch['tokens']]
                    if cfg.inference.compute_logprob:
                        batch_logprob = [s.tolist() for s in batch['logprob']]
                        for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob):
                            if cfg.inference.get("verbose", False):
                                d = {
                                    'sentence': s,
                                    'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]),
                                }
                                f.write(json.dumps(d, sort_keys=True, indent=2) + '\n')
                    else:
                        for s in batch_sentences:
                            d = {'sentence': s}
                            f.write(json.dumps(d) + '\n')
            print("predictions saved to {}".format(cfg.inference.outfile_path))
        else:
            print(response)
        print("***************************")

    trainer.test(model)


if __name__ == "__main__":
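The net effect of this file's change: the hand-rolled DataLoader, trainer.predict() call, and JSONL-writing loop are replaced by a single trainer.test(model) call, with prediction writing driven by the test_ds config (see the new write_predictions_to_file and output_file_path_prefix overrides in the Jenkinsfile above). A minimal sketch of the resulting flow, assuming already-constructed trainer, model, and cfg objects (the helper name is hypothetical):

from omegaconf import OmegaConf

def run_peft_eval(trainer, model, cfg):
    # Freeze weights, attach the inference config, and let Lightning's
    # test loop drive generation and prediction writing.
    model.freeze()
    inference_config = OmegaConf.to_container(cfg.inference, resolve=True)
    model.set_inference_config(inference_config)
    trainer.test(model)  # runs the model's test_step over cfg.model.data.test_ds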
@@ -85,6 +85,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
    gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None)
    gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None)
    gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None)
    gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get(
        "activations_checkpoint_layers_per_pipeline", None
    )
    gpt_cfg.data = cfg.model.data
    gpt_cfg.optim = cfg.model.optim
    gpt_cfg.precision = cfg.trainer.precision
9 changes: 9 additions & 0 deletions examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
@@ -50,6 +50,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
    gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None)
    gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None)
    gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None)
    gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get(
        "activations_checkpoint_layers_per_pipeline", None
    )
    gpt_cfg.data = cfg.model.data
    gpt_cfg.optim = cfg.model.optim
    gpt_cfg.precision = cfg.trainer.precision
@@ -61,6 +64,8 @@
    gpt_cfg.hidden_dropout = cfg.model.get('hidden_dropout', 0.0)
    gpt_cfg.attention_dropout = cfg.model.get('attention_dropout', 0.0)
    gpt_cfg.ffn_dropout = cfg.model.ffn_dropout
    gpt_cfg.use_flash_attention = cfg.model.get('use_flash_attention', False)

    sft_cls = MegatronGPTSFTModel
    gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

@@ -200,6 +205,10 @@ def main(cfg) -> None:
        validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
        model = load_from_checkpoint_dir(MegatronGPTSFTModel, cfg, trainer, modify_confg_fn=_modify_config)

    if 'inference' in cfg:
        config = OmegaConf.to_container(cfg.inference, resolve=True)
        model.set_inference_config(config)

    trainer.fit(model)


55 changes: 54 additions & 1 deletion nemo/collections/common/metrics/classification_accuracy.py
@@ -13,7 +13,10 @@
# limitations under the License.

import logging
from typing import List
import re
import string
from collections import Counter
from typing import List, Union

import torch
from torchmetrics import Metric
@@ -207,3 +210,53 @@ def update(self, pred: str, target: str):

    def compute(self):
        return self.correct.float() / self.total


class TokenF1Score(Metric):
    """Taken from the official evaluation script for v1.1 of the SQuAD dataset"""

    def __init__(self, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: str, target: Union[str, List[str]]):
        if isinstance(target, str):
            self.correct += self.f1_score(pred, target)
        elif isinstance(target, list):
            self.correct += max([self.f1_score(pred, tgt) for tgt in target])
        self.total += 1

    def compute(self):
        return self.correct.float() / self.total

    def f1_score(self, prediction, ground_truth):
        prediction_tokens = self.normalize(prediction).split()
        ground_truth_tokens = self.normalize(ground_truth).split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0.0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    def normalize(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""

        def remove_articles(text):
            return re.sub(r"\b(a|an|the)\b", " ", text)

        def white_space_fix(text):
            return " ".join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return "".join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))
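A quick usage sketch of the new metric (toy strings invented for illustration; note that the SQuAD-style normalization strips articles and punctuation before token overlap is computed):

metric = TokenF1Score()
metric.update("The cat sat on the mat.", "a cat sat on a mat")
print(metric.compute())  # tensor(1.) here: both sides normalize to "cat sat on mat"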
@@ -15,7 +15,7 @@
from torchmetrics import Accuracy, AveragePrecision, F1Score, MatthewsCorrCoef, PearsonCorrCoef, SpearmanCorrCoef
from torchmetrics.text.rouge import ROUGEScore

from nemo.collections.common.metrics.classification_accuracy import ExactStringMatchMetric
from nemo.collections.common.metrics.classification_accuracy import ExactStringMatchMetric, TokenF1Score

__all__ = ['MetricStringToTorchMetric']

@@ -25,6 +25,7 @@
    'accuracy': Accuracy,
    'average_precision': AveragePrecision,
    'f1': F1Score,
    'token_f1': TokenF1Score,
    'pearson_corr_coef': PearsonCorrCoef,
    'spearman_corr_coef': SpearmanCorrCoef,
    'matthews_corr_coef': MatthewsCorrCoef,
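With the registry entry in place, an SFT config that sets metric.name: token_f1 resolves to the new class through this mapping. Roughly, assuming the package re-exports the dict as elsewhere in NeMo:

from nemo.collections.common.metrics import MetricStringToTorchMetric

metric_cls = MetricStringToTorchMetric['token_f1']  # -> TokenF1Score
metric = metric_cls()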
@@ -207,7 +207,6 @@ def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int,
    # not going to train on the header
    target[:header_len] = IGNORE_INDEX
    input_ids = torch.LongTensor(input_ids)

    _mask_targets(
        target,
        tokenized_lens,
@@ -222,7 +221,11 @@
    )
    mask = (target != IGNORE_INDEX).bool()
    assert mask.sum().item() != 0, "mask is empty"
    return dict(input_ids=input_ids, mask=mask)
    # Use the last turn of the conversation as the answer; the preceding history is the context
    last_ignore_index_pos = torch.nonzero(target == IGNORE_INDEX)[-1].item() + 1
    context_ids = input_ids[:last_ignore_index_pos]
    answer_ids = input_ids[last_ignore_index_pos:]
    return dict(input_ids=input_ids, mask=mask, context_ids=context_ids, answer_ids=answer_ids)
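To make the new split concrete: positions still equal to IGNORE_INDEX in target are exactly the non-answer (prompt and history) tokens, so the position one past the last IGNORE_INDEX marks where the final answer begins. A toy example (IGNORE_INDEX assumed to be -100, the usual label-masking convention; token values invented):

import torch

IGNORE_INDEX = -100
target = torch.tensor([-100, -100, 7, 8, -100, -100, 5, 9])  # two turns; final answer is [5, 9]
input_ids = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])

last_ignore_index_pos = torch.nonzero(target == IGNORE_INDEX)[-1].item() + 1  # -> 6
context_ids = input_ids[:last_ignore_index_pos]  # tensor([10, 11, 12, 13, 14, 15]): all history
answer_ids = input_ids[last_ignore_index_pos:]   # tensor([16, 17]): the final answer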


def _check_token_in_vocab(tokenizer, token):
@@ -262,19 +265,28 @@ def _process_example(self, example):
"""
result = preprocess(example, self.tokenizer, self.extra_id_2_token_id, self.new_line_token_id)

# store metadata in dataset, in case user may have keys required in the prediction json files
metadata = {k: v for k, v in example.items() if k not in ['conversations']}
result['metadata'] = metadata

return result

    def collate_fn(self, batch):
        input_ids = [item['input_ids'][:-1].tolist() for item in batch]
        labels = [item['input_ids'][1:].tolist() for item in batch]
        contexts = [item['context_ids'].tolist() for item in batch]
        answers = [item['answer_ids'].tolist() for item in batch]
        loss_mask = [item['mask'][1:].tolist() for item in batch]
        metadata = [item['metadata'] for item in batch]

        max_length = max([len(x) for x in input_ids])
        max_length = max(max([len(x) for x in input_ids]), max([len(x) for x in contexts]) + self.tokens_to_generate)
        if max_length > self.max_seq_length:
            # truncate sequences longer than max_seq_length
            input_ids = [x[: self.max_seq_length] for x in input_ids]
            labels = [x[: self.max_seq_length] for x in labels]
            loss_mask = [x[: self.max_seq_length] for x in loss_mask]
            contexts = [x[: self.max_seq_length] for x in contexts]

        # increase max length to nearest multiple of 4 or 8
        if self.pad_to_max_length:
            max_length = self.max_seq_length
@@ -291,13 +303,20 @@
        )
        labels = torch.LongTensor(self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id))
        loss_mask = torch.LongTensor(self._collate_item(loss_mask, max_length=max_length, pad_id=0))
        context_lengths = torch.LongTensor([len(x) for x in contexts])
        contexts = torch.LongTensor(self._collate_item(contexts, max_length=max_length, pad_id=self.tokenizer.eos_id))
        answers = torch.LongTensor(self._collate_item(answers, max_length=max_length, pad_id=self.tokenizer.eos_id))

        processed_batch = {
            'tokens': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
            'loss_mask': loss_mask,
            'position_ids': position_ids,
            'contexts': contexts,
            'context_lengths': context_lengths,
            'answers': answers,
            'metadata': metadata,
        }

        return processed_batch
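The revised max_length computation matters because test-time generation starts from the padded contexts tensor: each row must hold the longest context plus tokens_to_generate newly generated tokens, which can exceed the longest teacher-forced input. A toy check with invented lengths:

input_lens = [48, 60]    # tokenized full conversations (shifted by one token)
context_lens = [55, 40]  # prompt-only prefixes
tokens_to_generate = 32

# Padding only to max(input_lens) == 60 would leave no room to grow the
# 55-token context during generation; the new rule reserves that headroom.
max_length = max(max(input_lens), max(context_lens) + tokens_to_generate)
assert max_length == 87  # 55 + 32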