Skip to content

Commit

Permalink
fix for num worker 0 causing issues in losses after 1 epoch
Browse files Browse the repository at this point in the history
Signed-off-by: arendu <[email protected]>
  • Loading branch information
arendu committed Nov 9, 2022
1 parent 6d9a8d2 commit 211e68e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import os
import pickle
Expand Down Expand Up @@ -319,7 +320,11 @@ def __getitem__(self, idx):

def collate_fn(self, batch, tp_workers=0):
""" Prepares input_ids, labels, loss mask, attention_mask, and position ids for global batch """
taskname_ids, input_ids, answer_starts = zip(*batch)
orig_taskname_ids, orig_input_ids, orig_answer_starts = zip(*batch)
taskname_ids = copy.deepcopy(orig_taskname_ids)
input_ids = copy.deepcopy(orig_input_ids)
answer_starts = copy.deepcopy(orig_answer_starts)


# Pad taskname_ids to be the same length for the prompt encoder
if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import copy
import json

import torch
Expand Down Expand Up @@ -195,7 +195,12 @@ def _insert_text_in_template(self, input_example, prompt_template_fields, doc, a
def collate_fn(self, batch):
""" Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch """

taskname_ids, enc_input, dec_input, dec_labels = zip(*batch)
orig_taskname_ids, orig_enc_input, orig_dec_input, orig_dec_labels = zip(*batch)
taskname_ids = copy.deepcopy(orig_taskname_ids)
enc_input = copy.deepcopy(orig_enc_input)
dec_input = copy.deepcopy(orig_dec_input)
dec_labels = copy.deepcopy(orig_dec_labels)


taskname_ids = self.pad_taskname_ids(taskname_ids)

Expand Down

0 comments on commit 211e68e

Please sign in to comment.