fix for num worker 0 causing issues in losses after 1 epoch (NVIDIA#5379) (NVIDIA#5384)

Co-authored-by: Adi Renduchintala <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
2 people authored and Hainan Xu committed Nov 29, 2022
1 parent d7ed3eb commit 9c113ec
Showing 2 changed files with 6 additions and 4 deletions.
@@ -364,6 +364,7 @@ def collate_fn(self, batch, tp_workers=0):
     def pad_batch_and_build_loss_mask(self, input_ids, batch_max, answer_starts):
         """ Pad input_ids in batch to max batch length while building loss mask """
         batch_loss_masks = []
+        padded_input_ids = []
         for ids, answer_start_idx in zip(input_ids, answer_starts):
             if answer_start_idx is not None:
                 # Loss mask where answer tokens are 1.0 and all other tokens are 0.0
@@ -375,17 +376,19 @@ def pad_batch_and_build_loss_mask(self, input_ids, batch_max, answer_starts):
             # Pad to max length
             input_length = len(ids)
             padding_length = batch_max - input_length
-            ids.extend([self.pad_token_id] * padding_length)
+            pad_extend = [self.pad_token_id] * padding_length
+            ids = ids + pad_extend
+            padded_input_ids.append(ids)
 
             # Account for padding in loss mask
             loss_mask.extend([0.0] * padding_length)
             batch_loss_masks.append(torch.tensor(loss_mask, dtype=torch.float))
 
         # Make into torch tensors
-        input_ids = torch.tensor(input_ids, dtype=torch.long)
+        padded_input_ids = torch.tensor(padded_input_ids, dtype=torch.long)
         batch_loss_masks = torch.stack(batch_loss_masks)
 
-        return input_ids, batch_loss_masks
+        return padded_input_ids, batch_loss_masks
 
     def inference_collate_fn(self, batch):
         """
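
Why the fix works: the removed ids.extend(...) padded the very lists the Dataset object holds. With num_workers=0 the dataset lives in the training process, so epoch 1's pad tokens are still attached to every example in epoch 2 (with workers > 0, each epoch's freshly forked workers see pristine copies, which hid the bug). Because the loss mask is built over len(ids) before padding is accounted for, the leaked pad tokens fall inside the answer region and get a weight of 1.0, corrupting the loss after one epoch. A minimal standalone sketch of the failure mode (names are illustrative, not from the repo):

pad_token_id = 0
dataset = [[7, 8], [5, 6, 9]]  # token-id lists owned by the Dataset

def buggy_pad(batch):
    batch_max = max(len(ids) for ids in batch)
    for ids in batch:
        # Mutates the dataset's own lists, like the removed ids.extend(...) line.
        ids.extend([pad_token_id] * (batch_max - len(ids)))
    return batch

def fixed_pad(batch):
    batch_max = max(len(ids) for ids in batch)
    # Builds fresh lists instead, as the commit now does via ids = ids + pad_extend.
    return [ids + [pad_token_id] * (batch_max - len(ids)) for ids in batch]

for epoch in range(2):
    print(f"epoch {epoch}: lengths = {[len(ids) for ids in dataset]}")
    buggy_pad(dataset)
# epoch 0: lengths = [2, 3]   <- original examples
# epoch 1: lengths = [3, 3]   <- epoch-0 padding leaked back into the dataset

Swapping buggy_pad for fixed_pad keeps the printed lengths at [2, 3] in every epoch, leaving the dataset untouched.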
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import enum
 import json
 
 import torch
