diff --git a/Jenkinsfile b/Jenkinsfile index 355c7d0643c0..4ec372ba796c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -3815,6 +3815,8 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' model.data.test_ds.global_batch_size=1 \ model.data.test_ds.micro_batch_size=1 \ model.data.test_ds.tokens_to_generate=10 \ + model.data.test_ds.write_predictions_to_file=True \ + model.data.test_ds.output_file_path_prefix='/home/TestData/nlp/lora_tuning_tp2/out' \ inference.greedy=True \ inference.repetition_penalty=1.0 \ inference.outfile_path='/home/TestData/nlp/lora_tuning_tp2/out.jsonl'" @@ -3874,6 +3876,8 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' model.data.test_ds.micro_batch_size=1 \ model.data.test_ds.tokens_to_generate=30 \ model.data.test_ds.max_seq_length=6000 \ + model.data.test_ds.write_predictions_to_file=True \ + model.data.test_ds.output_file_path_prefix='examples/nlp/language_modeling/out' \ inference.greedy=True \ inference.repetition_penalty=1.0 \ inference.outfile_path='examples/nlp/language_modeling/out.jsonl' && \ diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_eval_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_eval_config.yaml index 14b931fc9929..fecf6b346142 100755 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_eval_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_eval_config.yaml @@ -60,6 +60,7 @@ model: # of each chunk at the specified granularity # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null answer_only_loss: False # not used right now gradient_as_bucket_view: False @@ -113,7 +114,7 @@ model: truncation_field: ${data.train_ds.truncation_field} # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. prompt_template: ${data.train_ds.prompt_template} - tokens_to_generate: ??? + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics metric: name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml index 7972a7dbfc11..1e06e72295bb 100755 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml @@ -68,6 +68,7 @@ model: # of each chunk at the specified granularity # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null answer_only_loss: True gradient_as_bucket_view: False @@ -160,7 +161,7 @@ model: truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" - + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics metric: name: "loss" # Name of the evaluation metric to use. 
Options: ['exact_string_match', 'loss'] average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. @@ -188,7 +189,7 @@ model: truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. prompt_template: ${model.data.train_ds.prompt_template} - + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics metric: name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml index 775584e4f25b..8d3b77600b2f 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml @@ -58,6 +58,8 @@ model: # of each chunk at the specified granularity # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + # This feature is valid only when used with pipeline-model-parallelism. More details in megatron_gpt_config.yaml. answer_only_loss: False # not used right now gradient_as_bucket_view: False seq_len_interpolation_factor: null # if not None, seq_len_interpolation_factor will match the base model's value @@ -109,15 +111,15 @@ model: names: null # Names of the corresponding datasets used to log metrics. global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} - shuffle: True + shuffle: False num_workers: 4 memmap_workers: ${model.data.train_ds.memmap_workers} pin_memory: True max_seq_length: 2048 min_seq_length: 1 - drop_last: True - context_key: 'input' - label_key: 'output' + drop_last: False + context_key: ${model.data.train_ds.context_key} + label_key: ${model.data.train_ds.label_key} add_eos: ${model.data.train_ds.add_eos} add_sep: ${model.data.train_ds.add_sep} add_bos: ${model.data.train_ds.add_bos} @@ -127,10 +129,11 @@ model: truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics hf_dataset: False # Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. metric: - name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'rouge', 'token_f1'] average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. num_classes: null @@ -139,15 +142,15 @@ model: names: null # Names of the corresponding datasets used to log metrics. 
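Reviewer note: the Jenkins stages and the PEFT eval/tuning configs above add three evaluation knobs that the rest of this diff consumes: write_predictions_to_file, output_file_path_prefix, and a concrete tokens_to_generate default. A minimal OmegaConf sketch of how the dotlist-style overrides used in the Jenkinsfile compose with the YAML defaults; the output path below is a placeholder, not one of the repo's test paths:

from omegaconf import OmegaConf

# Defaults mirroring the keys added to model.data.test_ds in the configs above.
base = OmegaConf.create(
    {"model": {"data": {"test_ds": {
        "tokens_to_generate": 32,
        "write_predictions_to_file": False,
        "output_file_path_prefix": None,
    }}}}
)

# Command-line style overrides, as in the Jenkins stages above (placeholder path).
overrides = OmegaConf.from_dotlist([
    "model.data.test_ds.write_predictions_to_file=True",
    "model.data.test_ds.output_file_path_prefix=/tmp/peft_eval/out",
    "model.data.test_ds.tokens_to_generate=10",
])

cfg = OmegaConf.merge(base, overrides)
assert cfg.model.data.test_ds.write_predictions_to_file
print(OmegaConf.to_yaml(cfg.model.data.test_ds))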
global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} - shuffle: True + shuffle: False num_workers: 4 memmap_workers: ${model.data.train_ds.memmap_workers} pin_memory: True max_seq_length: 2048 min_seq_length: 1 - drop_last: True - context_key: 'input' - label_key: 'output' + drop_last: False + context_key: ${model.data.train_ds.context_key} + label_key: ${model.data.train_ds.label_key} add_eos: ${model.data.train_ds.add_eos} add_sep: ${model.data.train_ds.add_sep} add_bos: ${model.data.train_ds.add_bos} @@ -171,3 +174,14 @@ model: betas: - 0.9 - 0.98 + +inference: + greedy: True # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. + compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + compute_attention_mask: True \ No newline at end of file diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py index e226b17426b5..850ce8693c93 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py @@ -131,6 +131,7 @@ def main(cfg) -> None: peft_model_cfg.data.test_ds = cfg.model.data.test_ds peft_model_cfg.activations_checkpoint_granularity = None peft_model_cfg.activations_checkpoint_method = None + peft_model_cfg.activations_checkpoint_layers_per_pipeline = None if peft_model_cfg.get("use_flash_attention", False): peft_model_cfg.use_flash_attention = cfg.model.use_flash_attention if cfg.model.get("seq_len_interpolation_factor", None) is not None: @@ -171,40 +172,10 @@ def main(cfg) -> None: ) model.freeze() - _test_ds = model._build_dataset(peft_model_cfg.data.test_ds, is_train=False) - request_dl = DataLoader( - dataset=_test_ds[0], - batch_size=peft_model_cfg.data.test_ds.global_batch_size, - collate_fn=_test_ds[0].collate_fn, - ) config = OmegaConf.to_container(cfg.inference, resolve=True) model.set_inference_config(config) - response = trainer.predict(model, request_dl) - - if model.global_rank == 0: - print("***************************") - if cfg.inference.outfile_path is not None: - with open(cfg.inference.outfile_path, "w", encoding="utf-8") as f: - for batch in response: - batch_sentences = [s for s in batch['sentences']] - batch_tokens = [s for s in batch['tokens']] - if cfg.inference.compute_logprob: - batch_logprob = [s.tolist() for s in batch['logprob']] - for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob): - if cfg.inference.get("verbose", False): - d = { - 'sentence': s, - 'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]), - } - f.write(json.dumps(d, sort_keys=True, indent=2) + '\n') - else: - for s in batch_sentences: - d = {'sentence': s} - f.write(json.dumps(d) + '\n') - print("predictions saved to {}".format(cfg.inference.outfile_path)) - else: - print(response) - print("***************************") + + trainer.test(model) 
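Reviewer note: megatron_gpt_peft_eval.py no longer builds its own DataLoader and hand-rolls a JSONL writer; it sets the inference config and calls trainer.test(model), so predictions are emitted by the model's write_predictions_to_file hook (added further down in this diff) as one JSON object per line with 'input', 'pred', 'label' plus any dataset metadata keys. A small consumer sketch; the file name is a placeholder, since the real name is derived from output_file_path_prefix and the dataset log key:

import json

# Placeholder path; the actual file is <output_file_path_prefix>_<log_key>_inputs_preds_labels.jsonl.
path = "/tmp/peft_eval/out_test_inputs_preds_labels.jsonl"

exact, total = 0, 0
with open(path, encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)  # keys: 'input', 'pred', 'label', plus dataset metadata
        total += 1
        exact += int(record["pred"].strip() == record["label"].strip())

print(f"exact string match: {exact / max(total, 1):.4f} over {total} predictions")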
if __name__ == "__main__": diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py index 572bc498d5c2..ada3e3dded89 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py @@ -85,6 +85,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None) gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None) gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None) + gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get( + "activations_checkpoint_layers_per_pipeline", None + ) gpt_cfg.data = cfg.model.data gpt_cfg.optim = cfg.model.optim gpt_cfg.precision = cfg.trainer.precision diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py index c5b8a3affed5..c17a62f9edd8 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py @@ -50,6 +50,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None) gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None) gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None) + gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get( + "activations_checkpoint_layers_per_pipeline", None + ) gpt_cfg.data = cfg.model.data gpt_cfg.optim = cfg.model.optim gpt_cfg.precision = cfg.trainer.precision @@ -61,6 +64,8 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): gpt_cfg.hidden_dropout = cfg.model.get('hidden_dropout', 0.0) gpt_cfg.attention_dropout = cfg.model.get('attention_dropout', 0.0) gpt_cfg.ffn_dropout = cfg.model.ffn_dropout + gpt_cfg.use_flash_attention = cfg.model.get('use_flash_attention', False) + sft_cls = MegatronGPTSFTModel gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}" @@ -200,6 +205,10 @@ def main(cfg) -> None: validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint) model = load_from_checkpoint_dir(MegatronGPTSFTModel, cfg, trainer, modify_confg_fn=_modify_config) + if 'inference' in cfg: + config = OmegaConf.to_container(cfg.inference, resolve=True) + model.set_inference_config(config) + trainer.fit(model) diff --git a/nemo/collections/common/metrics/classification_accuracy.py b/nemo/collections/common/metrics/classification_accuracy.py index 09eb14b06406..46eca7474a5e 100644 --- a/nemo/collections/common/metrics/classification_accuracy.py +++ b/nemo/collections/common/metrics/classification_accuracy.py @@ -13,7 +13,10 @@ # limitations under the License. 
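Reviewer note: the next hunks add a TokenF1Score metric (SQuAD v1.1-style token-overlap F1) to classification_accuracy.py and register it as 'token_f1'. A standalone sanity check of the same normalization and Counter-overlap computation, with made-up example strings; it needs only the standard library:

import re
import string
from collections import Counter

def normalize(s: str) -> str:
    # Lowercase, drop punctuation, drop articles, collapse whitespace (SQuAD v1.1 convention).
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def token_f1(pred: str, target: str) -> float:
    pred_tokens = normalize(pred).split()
    target_tokens = normalize(target).split()
    common = Counter(pred_tokens) & Counter(target_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(target_tokens)
    return 2 * precision * recall / (precision + recall)

print(round(token_f1("the cat sat", "a cat sat down"), 3))  # 0.8: precision 1.0, recall 2/3
print(round(token_f1("The answer is 42.", "42"), 3))        # 0.5: one shared token out of three predicted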
import logging -from typing import List +import re +import string +from collections import Counter +from typing import List, Union import torch from torchmetrics import Metric @@ -207,3 +210,53 @@ def update(self, pred: str, target: str): def compute(self): return self.correct.float() / self.total + + +class TokenF1Score(Metric): + """Taken from the official evaluation script for v1.1 of the SQuAD dataset""" + + def __init__(self, dist_sync_on_step=False, *args, **kwargs): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, pred: str, target: Union[str, List[str]]): + if isinstance(target, str): + self.correct += self.f1_score(pred, target) + elif isinstance(target, list): + self.correct += max([self.f1_score(pred, tgt) for tgt in target]) + self.total += 1 + + def compute(self): + return self.correct.float() / self.total + + def f1_score(self, prediction, ground_truth): + prediction_tokens = self.normalize(prediction).split() + ground_truth_tokens = self.normalize(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0.0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + def normalize(self, s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) diff --git a/nemo/collections/common/metrics/metric_string_to_torchmetric.py b/nemo/collections/common/metrics/metric_string_to_torchmetric.py index 2d1e094a0d8b..b38047b576cc 100644 --- a/nemo/collections/common/metrics/metric_string_to_torchmetric.py +++ b/nemo/collections/common/metrics/metric_string_to_torchmetric.py @@ -15,7 +15,7 @@ from torchmetrics import Accuracy, AveragePrecision, F1Score, MatthewsCorrCoef, PearsonCorrCoef, SpearmanCorrCoef from torchmetrics.text.rouge import ROUGEScore -from nemo.collections.common.metrics.classification_accuracy import ExactStringMatchMetric +from nemo.collections.common.metrics.classification_accuracy import ExactStringMatchMetric, TokenF1Score __all__ = ['MetricStringToTorchMetric'] @@ -25,6 +25,7 @@ 'accuracy': Accuracy, 'average_precision': AveragePrecision, 'f1': F1Score, + 'token_f1': TokenF1Score, 'pearson_corr_coef': PearsonCorrCoef, 'spearman_corr_coef': SpearmanCorrCoef, 'matthews_corr_coef': MatthewsCorrCoef, diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py index 733ff0f829cd..e05a0d7a5425 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py @@ -207,7 +207,6 @@ def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int, # not going to train on the header target[:header_len] = IGNORE_INDEX input_ids = torch.LongTensor(input_ids) - _mask_targets( target, 
tokenized_lens, @@ -222,7 +221,11 @@ def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int, ) mask = (target != IGNORE_INDEX).bool() assert mask.sum().item() != 0, "mask is empty" - return dict(input_ids=input_ids, mask=mask) + # Choose the last conversation as answer other history are context + last_ignore_index_pos = torch.nonzero(target == IGNORE_INDEX)[-1].item() + 1 + context_ids = input_ids[:last_ignore_index_pos] + answer_ids = input_ids[last_ignore_index_pos:] + return dict(input_ids=input_ids, mask=mask, context_ids=context_ids, answer_ids=answer_ids) def _check_token_in_vocab(tokenizer, token): @@ -262,19 +265,28 @@ def _process_example(self, example): """ result = preprocess(example, self.tokenizer, self.extra_id_2_token_id, self.new_line_token_id) + # store metadata in dataset, in case user may have keys required in the prediction json files + metadata = {k: v for k, v in example.items() if k not in ['conversations']} + result['metadata'] = metadata + return result def collate_fn(self, batch): input_ids = [item['input_ids'][:-1].tolist() for item in batch] labels = [item['input_ids'][1:].tolist() for item in batch] + contexts = [item['context_ids'].tolist() for item in batch] + answers = [item['answer_ids'].tolist() for item in batch] loss_mask = [item['mask'][1:].tolist() for item in batch] + metadata = [item['metadata'] for item in batch] - max_length = max([len(x) for x in input_ids]) + max_length = max(max([len(x) for x in input_ids]), max([len(x) for x in contexts]) + self.tokens_to_generate) if max_length > self.max_seq_length: # truncate the sequences if it is longer than max_seq_length input_ids = [x[: self.max_seq_length] for x in input_ids] labels = [x[: self.max_seq_length] for x in labels] loss_mask = [x[: self.max_seq_length] for x in loss_mask] + contexts = [x[: self.max_seq_length] for x in contexts] + # increase max length to nearest multiple of 4 or 8 if self.pad_to_max_length: max_length = self.max_seq_length @@ -291,6 +303,9 @@ def collate_fn(self, batch): ) labels = torch.LongTensor(self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id)) loss_mask = torch.LongTensor(self._collate_item(loss_mask, max_length=max_length, pad_id=0)) + context_lengths = torch.LongTensor([len(x) for x in contexts]) + contexts = torch.LongTensor(self._collate_item(contexts, max_length=max_length, pad_id=self.tokenizer.eos_id)) + answers = torch.LongTensor(self._collate_item(answers, max_length=max_length, pad_id=self.tokenizer.eos_id)) processed_batch = { 'tokens': input_ids, @@ -298,6 +313,10 @@ def collate_fn(self, batch): 'attention_mask': attention_mask, 'loss_mask': loss_mask, 'position_ids': position_ids, + 'contexts': contexts, + 'context_lengths': context_lengths, + 'answers': answers, + 'metadata': metadata, } return processed_batch diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py index cb62b3ac9f7a..39c335ac19b8 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py @@ -72,7 +72,7 @@ def __init__( truncation_field: Field to use for truncation. (Options: "answer", "context"). Field to be used for truncation if the combined length exceeds the max sequence length. pad_to_max_length: Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch. 
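Reviewer note: preprocess() in the chat dataset above now also returns context_ids/answer_ids, split at the position right after the last IGNORE_INDEX in the target, so collate_fn can batch 'contexts' for generation and decode 'answers' as references. A toy illustration of that split, assuming the usual IGNORE_INDEX value of -100; the token ids are made up:

import torch

IGNORE_INDEX = -100  # assumption: the usual masked-label value used by the dataset
input_ids = torch.LongTensor([5, 6, 7, 8, 9, 10, 11])
target = torch.LongTensor([IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 9, 10, 11])

last_ignore_index_pos = torch.nonzero(target == IGNORE_INDEX)[-1].item() + 1
context_ids = input_ids[:last_ignore_index_pos]
answer_ids = input_ids[last_ignore_index_pos:]

print(context_ids.tolist())  # [5, 6, 7, 8]  -> batched as 'contexts' for generation
print(answer_ids.tolist())   # [9, 10, 11]   -> decoded as the reference 'answers'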
index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset. - prompt_template: Prompt template to inject via an fstring. Formatted like Q: {input}\n\nA: {output} + prompt_template: Prompt template to inject via an fstring. Formatted like Q: {context_key}\n\nA: {label_key} hf_dataset: Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. """ self.tokenizer = tokenizer @@ -150,6 +150,9 @@ def __getitem__(self, idx): idx = idx.item() assert idx < len(self.indexed_dataset) + # idx may < 0 because we pad_samples_to_global_batch_size, e.g. id = -1 + if idx < 0: + idx = len(self) + idx try: example = self.indexed_dataset[idx] except Exception as e: @@ -185,9 +188,9 @@ def _process_example(self, example): f'{{{self.label_key}}}', output ) - if self.separate_prompt_and_response_with_newline and self.prompt_template is None: + elif self.separate_prompt_and_response_with_newline: text = context + '\n' + output - elif not self.separate_prompt_and_response_with_newline and self.prompt_template is None: + else: text = context + ' ' + output if self.virtual_tokens: @@ -206,7 +209,8 @@ def _process_example(self, example): total_ids += 1 if self.add_sep: total_ids += 1 - if self.add_eos: + # Only training need to consider eos token + if self.add_eos and self.tokens_to_generate == 0: total_ids += 1 # If the total number of token is greater than the max, we will try to truncate the answer @@ -217,34 +221,41 @@ def _process_example(self, example): elif self.truncation_field == "context": context_ids = context_ids[: -min(truncation_length, len(context_ids))] - if len(context_ids) > self.max_seq_length: - context_ids = context_ids[: self.max_seq_length] - - assert len(context_ids) <= self.max_seq_length input_ids = context_ids - answer_start_idx = len(input_ids) + + # Adds bos token in the start + if self.add_bos: + context_ids = [self.tokenizer.bos_id] + context_ids + input_ids = [self.tokenizer.bos_id] + input_ids + answer_start_idx += 1 + # Adds sep token between text/prompt and answer if self.add_sep: + context_ids = context_ids + [self.sep_id] input_ids = input_ids + [self.sep_id] answer_start_idx += 1 input_ids = input_ids + answer_ids - if self.add_bos: - input_ids = [self.tokenizer.bos_id] + input_ids - answer_start_idx += 1 - if self.add_eos: + # Only training need to consider eos token + if self.add_eos and self.tokens_to_generate == 0: input_ids = input_ids + [self.tokenizer.eos_id] - if len(input_ids) < self.min_seq_length or len(input_ids) > self.max_seq_length: + if len(input_ids) > self.max_seq_length: + logging.warning(f'Input ids length {len(input_ids)} exceed max sequence length {self.max_seq_length}') input_ids = input_ids[: self.max_seq_length] + # store metadata in dataset, in case user may have keys required in the prediction json files + metadata = {k: v for k, v in example.items() if k not in [self.context_key, self.label_key]} + processed_example = { 'input_ids': input_ids, 'answer_start_idx': answer_start_idx, 'context_ids': context_ids, 'context_length': len(context_ids), + 'answer_ids': answer_ids, + 'metadata': metadata, } return processed_example @@ -292,9 +303,11 @@ def collate_fn(self, batch): labels = [item['input_ids'][1:] for item in batch] contexts = [item['context_ids'] for item in batch] context_lengths = torch.LongTensor([item['context_length'] for item in batch]) + answers = [item['answer_ids'] for item in batch] loss_mask = 
[self._build_loss_mask(item)[1:] for item in batch] + metadata = [item['metadata'] for item in batch] - max_length = max([len(x) for x in input_ids]) + self.tokens_to_generate + max_length = max(max([len(x) for x in input_ids]), max([len(x) for x in contexts]) + self.tokens_to_generate) # increase max length to nearest multiple of 4 or 8 if self.pad_to_max_length: max_length = self.max_seq_length @@ -312,6 +325,7 @@ def collate_fn(self, batch): labels = torch.LongTensor(self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id)) loss_mask = torch.LongTensor(self._collate_item(loss_mask, max_length=max_length, pad_id=0)) contexts = torch.LongTensor(self._collate_item(contexts, max_length=max_length, pad_id=self.tokenizer.eos_id)) + answers = torch.LongTensor(self._collate_item(answers, max_length=max_length, pad_id=self.tokenizer.eos_id)) processed_batch = { 'tokens': input_ids, @@ -321,6 +335,8 @@ def collate_fn(self, batch): 'position_ids': position_ids, 'contexts': contexts, 'context_lengths': context_lengths, + 'answers': answers, + 'metadata': metadata, } return processed_batch diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index 2dbe6353cd7d..5b065b834a3a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import itertools import json from functools import partial from typing import Any, Optional @@ -45,6 +45,7 @@ try: from apex.transformer.pipeline_parallel.utils import ( _reconfigure_microbatch_calculator, + get_current_global_batch_size, get_micro_batch_size, get_num_microbatches, ) @@ -79,20 +80,27 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): ) super().__init__(cfg, trainer=trainer) self.sep_id = cfg.get('sep_id', 49704) - self.val_metric, self.val_metric_name = self.setup_metric(self.cfg.data.validation_ds) - self.val_metric = torch.nn.ModuleList(self.val_metric) if self.val_metric is not None else None + if hasattr(self.cfg.data, "validation_ds"): + self.val_metric, self.val_metric_name = self.setup_metric(self.cfg.data.validation_ds) + self.val_metric = torch.nn.ModuleList(self.val_metric) if self.val_metric is not None else None + # Used other keys from metadata to calulate metrics + if hasattr(self.cfg.data.validation_ds, "metric"): + self.val_metric_label_key = self.cfg.data.validation_ds.metric.get('label_key', 'labels') + if hasattr(self.cfg.data, "test_ds"): self.test_metric, self.test_metric_name = self.setup_metric(self.cfg.data.test_ds) self.test_metric = torch.nn.ModuleList(self.test_metric) if self.test_metric is not None else None + # Used other keys from metadata to calulate metrics + if hasattr(self.cfg.data.test_ds, "metric"): + self.test_metric_label_key = self.cfg.data.test_ds.metric.get('label_key', 'labels') if self.cfg.get('megatron_amp_O2', False): base_module = self.model.module else: base_module = self.model - self.original_checkpointing_granularity = base_module.language_model.encoder.activations_checkpoint_granularity - self.original_checkpointing_num_layers = base_module.language_model.encoder.activations_checkpoint_num_layers - self.original_checkpointing_method = 
base_module.language_model.encoder.activations_checkpoint_method + self._reset_activation_checkpointing_args() + self._reset_sequence_parallelism_args() self.virtual_tokens = 0 def setup_metric(self, data_cfg): @@ -244,6 +252,16 @@ def _build_dataset(self, data_cfg, is_train=True): else: num_train_samples_per_dataset = [[None]] * len(data_cfg.file_names) + # Check dataset max_seq_legnth and max_position_embeddings size + if ( + self.cfg.get('position_embedding_type', None) in [None, 'learned_absolute'] + and data_cfg.max_seq_length > self.cfg.max_position_embeddings + ): + logging.warning( + f"Set dataset max_seq_length to max_position_embeddings {self.cfg.max_position_embeddings} if using learned_absolute position embedding" + ) + data_cfg.max_seq_length = self.cfg.max_position_embeddings + for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset): if self.cfg.data.get("chat", False): dataset_cls = GPTSFTChatDataset @@ -267,7 +285,7 @@ def _build_dataset(self, data_cfg, is_train=True): ), answer_only_loss=self.cfg.get('answer_only_loss', True), truncation_field=data_cfg.get('truncation_field', 'context'), - pad_to_max_length=False, + pad_to_max_length=data_cfg.get('pad_to_max_length', False), index_mapping_dir=data_cfg.get('index_mapping_dir', None), prompt_template=data_cfg.get('prompt_template', None), virtual_tokens=self.virtual_tokens, @@ -307,6 +325,8 @@ def _determine_log_key(self, data_config, dataloader_idx, metric_name, mode): def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): batch = next(dataloader_iter) + # Pass only torch.Tensor to prevent errors when process get_iterator_k_split() + batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)} _, seq_length = batch['tokens'].shape tensor_shape = [seq_length, get_micro_batch_size(), self.cfg.hidden_size] data_iter = get_iterator_k_split(batch, get_num_microbatches()) @@ -380,92 +400,50 @@ def test_step(self, dataloader_iter, batch_idx, dataloader_idx=0): return self.inference_step(dataloader_iter, batch_idx, 'test', dataloader_idx) def inference_step(self, dataloader_iter, batch_idx, mode, dataloader_idx=0): - # Call parent validation step to get the loss. 
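Reviewer note: because the collate_fns above now return non-tensor entries ('metadata' is a list of per-example dicts), fwd_bwd_step filters the batch down to tensors before get_iterator_k_split() carves it into microbatches. A toy version of that filter, with made-up shapes:

import torch

batch = {
    "tokens": torch.zeros(4, 16, dtype=torch.long),
    "loss_mask": torch.ones(4, 16),
    "metadata": [{"id": i} for i in range(4)],  # per-example dicts, not splittable as a tensor
}

tensor_only = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}
print(sorted(tensor_only))  # ['loss_mask', 'tokens'] -- 'metadata' stays out of the microbatch split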
- loss = super().validation_step(dataloader_iter, batch_idx) - # loss can be None as super().validation_step returns None when dataloader_iter is exhausted - # which can lead to error, adding check to prevent it - if loss is not None: # Ensure its not None - outputs = { - 'loss': loss, - 'preds': None, - 'labels': None, - 'inputs': None, - } - if mode == 'validation': - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - # super().validation_step appends just loss to self.validation_step_outputs, replace the last appended loss with the outputs dict - self.validation_step_outputs[dataloader_idx][-1] = outputs - else: - # super().validation_step appends just loss to self.validation_step_outputs, replace the last appended loss with the outputs dict - self.validation_step_outputs[-1] = outputs + batch = next(dataloader_iter) + data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds + self._reconfigure_and_process_inference_batch(batch, data_cfg) + # Meta data from dataset + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + loss = super().validation_step(itertools.chain([batch]), batch_idx) + + # We need _inference_config to get generation params + # add_BOS and tokens_to_generate are set in dataset + if self.get_inference_config() is None: + self.set_inference_config(inference_config={}) + self._inference_config['add_BOS'] = data_cfg.add_bos + self._inference_config['tokens_to_generate'] = data_cfg.get('tokens_to_generate') + + output = self.predict_step(batch, batch_idx, dataloader_idx) + + inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] + labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']] + preds_text = [ + self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')]) + for t, l in zip(output['token_ids'], batch['context_lengths']) + ] + + outputs = { + 'loss': loss, + 'preds': preds_text, # [str] + 'labels': labels_text, # [str] + 'inputs': inputs_text, # [str] + 'metadata': metadata, # [dict] + } + + if mode == 'validation': + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + # super().validation_step appends just loss to self.validation_step_outputs, replace the last appended loss with the outputs dict + self.validation_step_outputs[dataloader_idx][-1] = outputs else: - if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(outputs) - else: - self.test_step_outputs.append(outputs) - return outputs - # TODO (sandeepsub): Figure out the subsequent decode bits. - # length_params: LengthParam = { - # "min_length": 0, - # "max_length": batch['tokens'].size(1) - batch['context_lengths'].max(), - # } - # sampling_params: SamplingParam = { - # "use_greedy": True, - # "temperature": 1.0, - # "top_k": 1, - # "top_p": 0.94, - # "repetition_penalty": 1.2, - # "add_BOS": False, - # "all_probs": False, - # "compute_logprob": False, - # "end_strings": ["<|endoftext|>"], - # } - # result = megatron_gpt_generate( - # model=self, - # inputs=( - # batch['tokens'].cuda(), - # (batch['context_lengths'] - 1).cuda(), - # ), # NOTE: We do -1 here to remove the space between context and response. 
- # tokenizer=self.tokenizer, - # sampling_params=sampling_params, - # length_params=length_params, - # check_sequence_parallel_and_checkpointing=False, # We need to skip these checks since we'll manually enbale and disable checkpointing between training and validation. - # ) - - # preds_text = [] - # labels_text = [] - # input_text = [] - # for idx, item in enumerate(result['token_ids']): - # pred = self.tokenizer.ids_to_text(item[batch['context_lengths'][idx] - 1 :]) - # input = self.tokenizer.ids_to_text(item[: batch['context_lengths'][idx] - 1]) - # label = self.tokenizer.ids_to_text(batch['tokens'][idx][batch['context_lengths'][idx] :].tolist()) - # preds_text.append(pred.strip()) - # labels_text.append(label.strip()) - # input_text.append(input.strip()) - - # metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx] - # assert len(preds_text) == len(labels_text) == len(input_text) - # for _, (pred, label) in enumerate(zip(preds_text, labels_text)): - # # To compute metrics like pearson or spearman correlation, we need to cast the predicted string and labels to floats. - # pred, label = self.cast_for_metric( - # pred=pred.strip(), - # label=label.strip(), - # metric_name=self.val_metric_name if mode == 'validation' else self.test_metric_name, - # class_labels=self.cfg.data.validation_ds.metric.get('class_labels', None) - # if mode == 'validation' - # else self.cfg.data.test_ds.metric.get('class_labels', None), - # labels_are_strings=self.cfg.data.validation_ds.metric.get('labels_are_strings', False) - # if mode == 'validation' - # else self.cfg.data.test_ds.metric.get('labels_are_strings', False), - # ) - # _ = metric(pred, label) - - # return { - # 'loss': loss, - # 'preds': preds_text, - # 'labels': labels_text, - # 'inputs': input_text, - # } + # super().validation_step appends just loss to self.validation_step_outputs, replace the last appended loss with the outputs dict + self.validation_step_outputs[-1] = outputs + else: + if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: + self.test_step_outputs[dataloader_idx][-1] = outputs + else: + self.test_step_outputs[-1] = outputs + return outputs def inference_epoch_end(self, outputs, mode, data_cfg): # Parent class will handle logging of the loss. @@ -477,7 +455,6 @@ def inference_epoch_end(self, outputs, mode, data_cfg): averaged_loss = [] averaged_metric = [] - metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name # Log metrics for each provided validation/test dataset. for dataloader_idx, output in enumerate(outputs): # Expand on_validation_epoch_end from parent class MegatronGPTModel as on_validation_epoch_end doesnt take outputs arg @@ -505,74 +482,90 @@ def inference_epoch_end(self, outputs, mode, data_cfg): self.log(loss_log_key, loss, batch_size=1) averaged_loss.append(loss) - # Skip the rest of this loop if the user wants to monitor the loss only. - if self.val_metric is None: - continue - # Determine the key used to log the eval metric based on the user provided name of the dataset or the dataloader index. - metric_log_key = self._determine_log_key(data_cfg, dataloader_idx, metric_name, mode) - metric_object = ( - self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx] + # Gather the outputs object from all data parallel ranks since we are using the DistributedSampler which splits data across DDP ranks. 
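Reviewer note: the added epoch-end code gathers every rank's {preds, labels, inputs, metadata} with all_gather_object and then, in the hunk that follows, drops duplicates keyed on input+label, since the distributed sampler (and pad_samples_to_global_batch_size) can replicate examples. A pure-Python sketch of that dedup pass on toy data:

# Toy stand-in for the list all_gather_object fills: one entry per data-parallel rank,
# each a list of per-batch dicts. Rank 1 deliberately repeats rank 0's examples.
gathered_outputs = [
    [{"preds": ["p1", "p2"], "labels": ["l1", "l2"], "inputs": ["i1", "i2"], "metadata": [{}, {}]}],
    [{"preds": ["p2", "p1"], "labels": ["l2", "l1"], "inputs": ["i2", "i1"], "metadata": [{}, {}]}],
]

inp_label_set = set()
deduplicated = {"preds": [], "labels": [], "inputs": [], "metadata": []}
total_size = 0
for rank_output in gathered_outputs:
    for batch in rank_output:
        for pred, label, inp, meta in zip(batch["preds"], batch["labels"], batch["inputs"], batch["metadata"]):
            total_size += 1
            key = inp + label
            if key not in inp_label_set:
                inp_label_set.add(key)
                deduplicated["preds"].append(pred)
                deduplicated["labels"].append(label)
                deduplicated["inputs"].append(inp)
                deduplicated["metadata"].append(meta)

print(total_size, len(deduplicated["inputs"]))  # 4 gathered rows -> 2 unique examples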
+ gathered_outputs = [None for _ in range(parallel_state.get_data_parallel_world_size())] + torch.distributed.all_gather_object( + gathered_outputs, + [ + {'preds': x['preds'], 'labels': x['labels'], 'inputs': x['inputs'], 'metadata': x['metadata']} + for x in output + ], + group=parallel_state.get_data_parallel_group(), ) - metric = metric_object.compute() - # Handle logging of GLUE/XNLI separately here. XNLI has a separate metric per language. - if isinstance(metric, dict): - if metric_name == 'rouge': - metric = metric['rougeL_fmeasure'] + + # Remove duplicate examples due to distributed sampler. + inp_label_set = set() + deduplicated_outputs = { + 'preds': [], + 'labels': [], + 'inputs': [], + 'metadata': [], + } + total_size = 0 + for rank in range(0, parallel_state.get_data_parallel_world_size()): + for batch in gathered_outputs[rank]: + for pred, label, input, metadata in zip( + batch['preds'], batch['labels'], batch['inputs'], batch['metadata'] + ): + key = input + label + total_size += 1 + if key not in inp_label_set: + inp_label_set.add(key) + deduplicated_outputs['preds'].append(pred) + deduplicated_outputs['labels'].append(label) + deduplicated_outputs['inputs'].append(input) + deduplicated_outputs['metadata'].append(metadata) + + # Compute metric score + metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name + metric_label_key = self.val_metric_label_key if mode == 'validation' else self.test_metric_label_key + if metric_name != 'loss': + metric_log_key = self._determine_log_key(data_cfg, dataloader_idx, metric_name, mode) + metric_fn = ( + self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx] + ) + if metric_label_key in deduplicated_outputs['metadata'][0]: + labels = [m[metric_label_key] for m in deduplicated_outputs['metadata']] else: - metric = metric['acc'] - torch.distributed.all_reduce( - metric, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_data_parallel_group() - ) - metric = metric / parallel_state.get_data_parallel_world_size() - self.log(metric_log_key, metric) - logging.info(f"{mode} {metric_name}: {metric}") + labels = deduplicated_outputs['labels'] + + for pred, label in zip(deduplicated_outputs['preds'], labels): + _ = metric_fn(pred, label) - metric_object.reset() + metric_result = metric_fn.compute() + + if metric_name == 'rouge': + for k, v in metric_result.items(): + if 'fmeasure' in k: + self.log(metric_log_key + f'_{k}', v.item(), sync_dist=True) + logging.info(f"{mode} {metric_name} {k}: {v.item()}") + metric_result = metric_result['rouge1_fmeasure'] + else: + self.log(metric_log_key, metric_result.item(), sync_dist=True) + logging.info(f"{mode} {metric_name}: {metric_result.item()}") - averaged_metric.append(metric) + metric_fn.reset() + averaged_metric.append(metric_result) - # Write predictions, labels, and inputs to a file for each validation/test dataset. - if data_cfg.get("write_predictions_to_file", False): + # Write predictions to file + if self.global_rank == 0 and data_cfg.get("write_predictions_to_file", False): + logging.info( + f"Total deduplicated inference data size: {total_size} to {len(deduplicated_outputs['inputs'])}" + ) # Check if the user provided a prefix path to the file(s) they want to write. if not hasattr(data_cfg, "output_file_path_prefix") or data_cfg.output_file_path_prefix is None: raise ValueError( f"Cannot write predictions to file when output_file_path_prefix is not set or present in the yaml config file." 
) - - # Gather the outputs object from all data parallel ranks since we are using the DistributedSampler which splits data across DDP ranks. - gathered_outputs = [None for _ in range(parallel_state.get_data_parallel_world_size())] - torch.distributed.all_gather_object( - gathered_outputs, - [{'preds': x['preds'], 'labels': x['labels'], 'inputs': x['inputs'],} for x in output], - group=parallel_state.get_data_parallel_group(), - ) - - # Figure out what the suffix of the file should be. filename_log_key = self._determine_log_key(data_cfg, dataloader_idx, None, mode) + self.write_predictions_to_file( + deduplicated_outputs, f"{data_cfg.output_file_path_prefix}_{filename_log_key}" + ) - # Keep a set of ground truths and inputs to write deduplicated predictions. Distributed Sampler may duplicate examples. - gt_inp_set = set() - deduplicated_outputs = { - 'preds': [], - 'labels': [], - 'inputs': [], - } - - # PTL models have a self.global_rank attribute and we want to write to disk only on global rank 0. - if self.global_rank == 0: - for rank in range(0, parallel_state.get_data_parallel_world_size()): - for batch in gathered_outputs[rank]: - for pred, label, input in zip(batch['preds'], batch['labels'], batch['inputs']): - gt_inp_set.add(input + label) - deduplicated_outputs['preds'].append(pred) - deduplicated_outputs['labels'].append(label) - deduplicated_outputs['inputs'].append(input) - self.write_predictions_to_file( - deduplicated_outputs, f"{data_cfg.output_file_path_prefix}_{filename_log_key}" - ) - torch.distributed.barrier() + torch.distributed.barrier(group=parallel_state.get_data_parallel_group()) outputs[dataloader_idx].clear() # free memory + # Logging of the averaged metrics: averaged_loss = sum(averaged_loss) / len(averaged_loss) averaged_metric = sum(averaged_metric) / len(averaged_metric) if len(averaged_metric) > 1 else None @@ -589,7 +582,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): if averaged_metric is not None: self.log(f"validation_{self.val_metric_name}", averaged_metric) elif mode == 'test': - self.log("test_loss", averaged_loss) + self.log("test_loss", averaged_loss, batch_size=1) if averaged_metric is not None: self.log(f"test_{self.test_metric_name}", averaged_metric) @@ -619,11 +612,12 @@ def inference_epoch_end(self, outputs, mode, data_cfg): def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: inference_config = self.get_inference_config() - if inference_config is None: - return None # need to overwrite some configuration, make it immutable inference_config = inference_config.copy() - compute_logprob = inference_config['compute_logprob'] + global_batch_size_per_gpu = batch['tokens'].size(0) + num_micro_batches_before_decode = get_num_microbatches() + + compute_logprob = inference_config.get('compute_logprob', False) if compute_logprob: inference_config['inputs'] = batch inference_config['tokens_to_generate'] = 1 @@ -631,8 +625,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] inference_config["add_BOS"] = False inference_config['greedy'] = True response = generate(self, **inference_config) - compute_prob_response = get_computeprob_response(self.tokenizer, response, batch) - return compute_prob_response + response = get_computeprob_response(self.tokenizer, response, batch) else: # for megatron_gpt_eval.py if isinstance(batch, list): @@ -640,13 +633,33 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] else: # peft_eval.py 
inference_config['inputs'] = (batch['contexts'].cuda(), batch['context_lengths'].cuda()) - return generate(self, **inference_config) + response = generate(self, **inference_config) + + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_size_per_gpu // num_micro_batches_before_decode, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + return response def write_predictions_to_file(self, outputs, output_file_path_prefix): - with open(output_file_path_prefix + "_inputs_preds_labels.jsonl", "w") as f_json: - assert len(outputs['inputs']) == len(outputs['preds']) == len(outputs['labels']) - for i, p, l in zip(outputs['inputs'], outputs['preds'], outputs['labels']): - f_json.write(json.dumps({'input': i, 'pred': p, 'label': l}) + '\n') + output_file_path = output_file_path_prefix + "_inputs_preds_labels.jsonl" + with open(output_file_path, "w") as f_json: + assert ( + len(outputs['inputs']) == len(outputs['preds']) == len(outputs['labels']) == len(outputs['metadata']) + ) + for i, p, l, m in zip(outputs['inputs'], outputs['preds'], outputs['labels'], outputs['metadata']): + json_string = {'input': i, 'pred': p, 'label': l} + for k, v in m.items(): + if k not in json_string: + json_string[k] = v + f_json.write(json.dumps(json_string) + '\n') + + logging.info(f'Predictions saved to {output_file_path}') def cast_for_metric(self, pred, label, metric_name, class_labels=None, labels_are_strings=False): if metric_name == 'exact_string_match' or 'rouge' in metric_name: @@ -704,19 +717,36 @@ def cast_for_metric(self, pred, label, metric_name, class_labels=None, labels_ar return pred, label # Override the parent batch reconfiguring logic. - def _reconfigure_and_process_inference_batch(self, batch): - global_batch_per_gpu = batch['tokens'].size(0) - # This should happen only on the last batch of the validation/test dataset with drop_last=False. - if global_batch_per_gpu != self.cfg.data.validation_ds.global_batch_size: - app_state = AppState() - _reconfigure_microbatch_calculator( - rank=app_state.global_rank, - rampup_batch_size=None, - global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(), - micro_batch_size=global_batch_per_gpu, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - return batch + def _reconfigure_and_process_inference_batch(self, batch, data_cfg): + global_batch_size_per_gpu = batch['tokens'].size(0) + # This should happen only on the last batch of the dataset. + if ( + global_batch_size_per_gpu + != get_current_global_batch_size() // parallel_state.get_data_parallel_world_size() + ): + # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches. 
+ if ( + global_batch_size_per_gpu + != data_cfg.global_batch_size // parallel_state.get_data_parallel_world_size() + ): + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_size_per_gpu, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + # NOTE: need to explicitly handle resetting for multi-validation + else: + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=data_cfg.global_batch_size, + micro_batch_size=data_cfg.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) def build_train_valid_test_datasets(self, stage): if stage != 'test': @@ -754,8 +784,8 @@ def build_data_loader(self, dataset, data_cfg, consumed_samples=0): global_batch_size=data_cfg.global_batch_size, data_parallel_rank=parallel_state.get_data_parallel_rank(), data_parallel_size=parallel_state.get_data_parallel_world_size(), - drop_last=True, - pad_samples_to_global_batch_size=False, + drop_last=data_cfg.drop_last, + pad_samples_to_global_batch_size=not data_cfg.drop_last, ) return torch.utils.data.DataLoader( dataset, @@ -779,27 +809,9 @@ def setup_eval_dataloader(self, datasets, data_cfg): dataloaders.append(eval_dl) return dataloaders - def _reset_activation_checkpointing_args(self): - if self.cfg.get('megatron_amp_O2', False): - base_module = self.model.module - else: - base_module = self.model - - base_module.language_model.encoder.activations_checkpoint_granularity = None - base_module.language_model.encoder.activations_checkpoint_method = None - base_module.language_model.encoder.activations_checkpoint_num_layers = None - - def _restore_activation_checkpointing_args(self): - if self.cfg.get('megatron_amp_O2', False): - base_module = self.model.module - else: - base_module = self.model - base_module.language_model.encoder.activations_checkpoint_granularity = self.original_checkpointing_granularity - base_module.language_model.encoder.activations_checkpoint_method = self.original_checkpointing_method - base_module.language_model.encoder.activations_checkpoint_num_layers = self.original_checkpointing_num_layers - def on_validation_epoch_start(self): self._reset_activation_checkpointing_args() + self._reset_sequence_parallelism_args() app_state = AppState() _reconfigure_microbatch_calculator( rank=app_state.global_rank, @@ -811,8 +823,9 @@ def on_validation_epoch_start(self): return super().on_validation_epoch_start() def on_test_epoch_start(self): - app_state = AppState() self._reset_activation_checkpointing_args() + self._reset_sequence_parallelism_args() + app_state = AppState() _reconfigure_microbatch_calculator( rank=app_state.global_rank, rampup_batch_size=None,
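Reviewer note (end of excerpt): with drop_last now configurable and pad_samples_to_global_batch_size tied to it, the last eval batch can be smaller than the configured global batch, which is why _reconfigure_and_process_inference_batch rebuilds the microbatch calculator. The arithmetic behind that branch, in plain Python with illustrative numbers (not a call into Apex):

data_parallel_size = 4
configured_global_batch_size = 32   # data_cfg.global_batch_size
configured_micro_batch_size = 4     # data_cfg.micro_batch_size
last_batch_per_gpu = 5              # batch['tokens'].size(0) on the final, short batch

if last_batch_per_gpu != configured_global_batch_size // data_parallel_size:
    # Rebuild so one step consumes exactly what arrived, with a single microbatch per GPU
    # (no gradient accumulation during validation/test).
    new_global = last_batch_per_gpu * data_parallel_size   # 20
    new_micro = last_batch_per_gpu                         # 5
else:
    new_global, new_micro = configured_global_batch_size, configured_micro_batch_size

print(new_global, new_micro)  # 20 5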