diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
index 3d0c29673c83..2c23b2468585 100755
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import json
 import os
 import pickle
@@ -319,7 +320,11 @@ def __getitem__(self, idx):
 
     def collate_fn(self, batch, tp_workers=0):
         """ Prepares input_ids, labels, loss mask, attention_mask, and position ids for global batch """
-        taskname_ids, input_ids, answer_starts = zip(*batch)
+        orig_taskname_ids, orig_input_ids, orig_answer_starts = zip(*batch)
+        taskname_ids = copy.deepcopy(orig_taskname_ids)
+        input_ids = copy.deepcopy(orig_input_ids)
+        answer_starts = copy.deepcopy(orig_answer_starts)
+
 
         # Pad taskname_ids to be the same length for the prompt encoder
         if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py
index 0f39cd8e05c9..0d1521decda7 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import enum
+import copy
 import json
 
 import torch
@@ -195,7 +195,12 @@ def _insert_text_in_template(self, input_example, prompt_template_fields, doc, a
 
     def collate_fn(self, batch):
         """ Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch """
-        taskname_ids, enc_input, dec_input, dec_labels = zip(*batch)
+        orig_taskname_ids, orig_enc_input, orig_dec_input, orig_dec_labels = zip(*batch)
+        taskname_ids = copy.deepcopy(orig_taskname_ids)
+        enc_input = copy.deepcopy(orig_enc_input)
+        dec_input = copy.deepcopy(orig_dec_input)
+        dec_labels = copy.deepcopy(orig_dec_labels)
+
         taskname_ids = self.pad_taskname_ids(taskname_ids)
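For context (not part of the patch): in both collate_fn methods the tuples produced by zip(*batch) still reference the example objects the Dataset hands out, so padding those objects in place would also mutate the dataset's cached examples; deep-copying first keeps the cache untouched. Below is a minimal, self-contained sketch of that aliasing problem under toy assumptions; the names cached_examples, collate_in_place, collate_with_copy, and PAD_ID are illustrative and not from NeMo.

import copy

PAD_ID = 0

# Toy stand-in for the dataset's cached examples (not NeMo code): the
# Dataset returns these lists by reference, so anything collate_fn does
# to them in place is still visible on the next epoch.
cached_examples = [[1, 2, 3], [4, 5]]


def collate_in_place(batch):
    # Pads every example to the batch max length by mutating the lists.
    max_len = max(len(ids) for ids in batch)
    for ids in batch:
        ids.extend([PAD_ID] * (max_len - len(ids)))
    return batch


def collate_with_copy(batch):
    # Same padding, but on private copies, mirroring the deepcopy calls in the patch.
    batch = copy.deepcopy(batch)
    max_len = max(len(ids) for ids in batch)
    for ids in batch:
        ids.extend([PAD_ID] * (max_len - len(ids)))
    return batch


collate_in_place(list(cached_examples))
print(cached_examples)  # [[1, 2, 3], [4, 5, 0]] -> the cached short example grew a pad token

cached_examples = [[1, 2, 3], [4, 5]]
collate_with_copy(list(cached_examples))
print(cached_examples)  # [[1, 2, 3], [4, 5]] -> cache left intact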