From 211e68ed2aefa44bdd32425d38dd81b2d3c841e8 Mon Sep 17 00:00:00 2001
From: arendu
Date: Wed, 9 Nov 2022 13:53:52 -0800
Subject: [PATCH] fix for num worker 0 causing issues in losses after 1 epoch

Signed-off-by: arendu
---
 .../megatron/gpt_prompt_learning_dataset.py  | 7 ++++++-
 .../megatron/t5_prompt_learning_dataset.py   | 9 +++++++--
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
index 3d0c29673c83..2c23b2468585 100755
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import json
 import os
 import pickle
@@ -319,7 +320,11 @@ def __getitem__(self, idx):
 
     def collate_fn(self, batch, tp_workers=0):
         """ Prepares input_ids, labels, loss mask, attention_mask, and position ids for global batch """
-        taskname_ids, input_ids, answer_starts = zip(*batch)
+        orig_taskname_ids, orig_input_ids, orig_answer_starts = zip(*batch)
+        taskname_ids = copy.deepcopy(orig_taskname_ids)
+        input_ids = copy.deepcopy(orig_input_ids)
+        answer_starts = copy.deepcopy(orig_answer_starts)
+
 
         # Pad taskname_ids to be the same length for the prompt encoder
         if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py
index 0f39cd8e05c9..0d1521decda7 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import enum
+import copy
 import json
 
 import torch
@@ -195,7 +195,12 @@ def _insert_text_in_template(self, input_example, prompt_template_fields, doc, a
 
     def collate_fn(self, batch):
         """ Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch """
-        taskname_ids, enc_input, dec_input, dec_labels = zip(*batch)
+        orig_taskname_ids, orig_enc_input, orig_dec_input, orig_dec_labels = zip(*batch)
+        taskname_ids = copy.deepcopy(orig_taskname_ids)
+        enc_input = copy.deepcopy(orig_enc_input)
+        dec_input = copy.deepcopy(orig_dec_input)
+        dec_labels = copy.deepcopy(orig_dec_labels)
+
 
         taskname_ids = self.pad_taskname_ids(taskname_ids)
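
Why the deepcopy is needed: with num_workers=0 the DataLoader runs collate_fn in the
main process, so the tuples handed to collate_fn hold references to the examples cached
inside the dataset. Any in-place padding then mutates that cache, and from the second
epoch on training sees already-padded examples, which surfaces as broken losses. Below
is a minimal, self-contained sketch of this failure mode and the fix; it is not NeMo
code, and CachedDataset, buggy_collate, and fixed_collate are hypothetical names chosen
for illustration.

# Minimal sketch (assumed names, not NeMo code) of the num_workers=0 bug:
# in-place padding in collate_fn mutates the dataset's cached examples.
import copy

import torch
from torch.utils.data import DataLoader, Dataset


class CachedDataset(Dataset):
    """Caches pre-tokenized examples once, like the prompt-learning datasets."""

    def __init__(self):
        self.examples = [[1, 2], [1, 2, 3]]  # ragged token id lists

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]  # returns a reference, not a copy


def buggy_collate(batch):
    # Pads in place: extends the very lists cached inside the dataset.
    max_len = max(len(ids) for ids in batch)
    for ids in batch:
        ids.extend([0] * (max_len - len(ids)))
    return torch.tensor(batch)


def fixed_collate(batch):
    # The patch's approach: deep-copy first, then pad only the copies.
    batch = copy.deepcopy(batch)
    max_len = max(len(ids) for ids in batch)
    for ids in batch:
        ids.extend([0] * (max_len - len(ids)))
    return torch.tensor(batch)


ds = CachedDataset()
loader = DataLoader(ds, batch_size=2, num_workers=0, collate_fn=buggy_collate)
for _ in loader:  # one "epoch"
    pass
print(ds.examples)  # [[1, 2, 0], [1, 2, 3]] -- cached example was mutated

ds = CachedDataset()
loader = DataLoader(ds, batch_size=2, num_workers=0, collate_fn=fixed_collate)
for _ in loader:
    pass
print(ds.examples)  # [[1, 2], [1, 2, 3]] -- cache survives the epoch unchanged

With num_workers > 0 the bug stays hidden, because each worker process mutates its own
copy of the dataset and the mutations are discarded with the worker. The patch takes
the same approach as fixed_collate above: it deep-copies the tuples right after
zip(*batch) in collate_fn, so all subsequent padding operates on copies and the cached
examples are identical in every epoch.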