support GPTSFTChatDataset
Signed-off-by: Cheng-Ping Hsieh <[email protected]>
hsiehjackson committed Aug 10, 2023
1 parent 29f775c commit cf2f2b9
Showing 2 changed files with 26 additions and 7 deletions.
@@ -207,7 +207,6 @@ def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int,
# not going to train on the header
target[:header_len] = IGNORE_INDEX
input_ids = torch.LongTensor(input_ids)

_mask_targets(
target,
tokenized_lens,
@@ -222,7 +221,11 @@ def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int,
)
mask = (target != IGNORE_INDEX).bool()
assert mask.sum().item() != 0, "mask is empty"
return dict(input_ids=input_ids, mask=mask)
# Use the last conversation turn as the answer; the earlier turns are the context
last_ignore_index_pos = torch.nonzero(target == IGNORE_INDEX)[-1].item() + 1
context_ids = input_ids[:last_ignore_index_pos]
answer_ids = input_ids[last_ignore_index_pos:]
return dict(input_ids=input_ids, mask=mask, context_ids=context_ids, answer_ids=answer_ids)


def _check_token_in_vocab(tokenizer, token):
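For reference, a minimal sketch of how the new split behaves (the token ids and the IGNORE_INDEX value of -100 below are illustrative assumptions, not taken from the file):

import torch

IGNORE_INDEX = -100  # assumed value; the real constant is defined elsewhere in the module

# Toy example: the first five positions (prompt/history) are masked in target,
# the last three positions belong to the final assistant turn.
input_ids = torch.LongTensor([11, 12, 13, 14, 15, 16, 17, 18])
target = torch.LongTensor([IGNORE_INDEX] * 5 + [16, 17, 18])

# The position right after the last masked token is the context/answer boundary.
last_ignore_index_pos = torch.nonzero(target == IGNORE_INDEX)[-1].item() + 1  # -> 5
context_ids = input_ids[:last_ignore_index_pos]  # tensor([11, 12, 13, 14, 15])
answer_ids = input_ids[last_ignore_index_pos:]   # tensor([16, 17, 18])
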
@@ -261,20 +264,29 @@ def _process_example(self, example):
BOS, EOS, and SEP are added if specified.
"""
result = preprocess(example, self.tokenizer, self.extra_id_2_token_id, self.new_line_token_id)


# store metadata in the dataset, in case the user has keys required in the prediction JSON files
metadata = {k: v for k, v in example.items() if k not in ['conversations']}
result['metadata'] = metadata

return result
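
As a rough illustration of the metadata pass-through (the extra keys below are hypothetical), everything except 'conversations' is carried along with the example:

example = {
    'conversations': [...],   # the chat turns consumed by preprocess()
    'id': 42,                 # hypothetical extra key
    'category': 'coding',     # hypothetical extra key
}
metadata = {k: v for k, v in example.items() if k not in ['conversations']}
# metadata == {'id': 42, 'category': 'coding'}
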

def collate_fn(self, batch):
input_ids = [item['input_ids'][:-1].tolist() for item in batch]
labels = [item['input_ids'][1:].tolist() for item in batch]
contexts = [item['context_ids'].tolist() for item in batch]
answers = [item['answer_ids'].tolist() for item in batch]
loss_mask = [item['mask'][1:].tolist() for item in batch]

max_length = max([len(x) for x in input_ids])
metadata = [item['metadata'] for item in batch]

max_length = max(max([len(x) for x in input_ids]), max([len(x) for x in contexts]) + self.tokens_to_generate)
if max_length > self.max_seq_length:
# truncate the sequences if they are longer than max_seq_length
input_ids = [x[: self.max_seq_length] for x in input_ids]
labels = [x[: self.max_seq_length] for x in labels]
loss_mask = [x[: self.max_seq_length] for x in loss_mask]
contexts = [x[: self.max_seq_length] for x in contexts]

# increase max_length to the nearest multiple of 4 or 8
if self.pad_to_max_length:
max_length = self.max_seq_length
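
The branch that rounds max_length up when pad_to_max_length is disabled is collapsed out of this hunk; a common way to implement such rounding (a sketch under that assumption, not necessarily the exact code used here) is:

def round_up_to_multiple(n, multiple=8):
    # e.g. 9 -> 16 and 16 -> 16 with multiple=8
    return (n + multiple - 1) // multiple * multiple
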
@@ -291,13 +303,20 @@ def collate_fn(self, batch):
)
labels = torch.LongTensor(self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id))
loss_mask = torch.LongTensor(self._collate_item(loss_mask, max_length=max_length, pad_id=0))

context_lengths = torch.LongTensor([len(x) for x in contexts])
contexts = torch.LongTensor(self._collate_item(contexts, max_length=max_length, pad_id=self.tokenizer.eos_id))
answers = torch.LongTensor(self._collate_item(answers, max_length=max_length, pad_id=self.tokenizer.eos_id))

processed_batch = {
'tokens': input_ids,
'labels': labels,
'attention_mask': attention_mask,
'loss_mask': loss_mask,
'position_ids': position_ids,
'contexts': contexts,
'context_lengths': context_lengths,
'answers': answers,
'metadata': metadata,
}

return processed_batch
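
A rough usage sketch of the collated output (dataset construction omitted; 'dataset' stands in for a hypothetical GPTSFTChatDataset instance and the batch size of 4 is arbitrary):

batch = [dataset[i] for i in range(4)]
out = dataset.collate_fn(batch)
out['tokens'].shape        # (4, max_length): input_ids without the last token
out['labels'].shape        # (4, max_length): input_ids without the first token (next-token targets)
out['contexts'].shape      # (4, max_length): prompt/history ids, padded with eos_id
out['context_lengths']     # (4,): unpadded length of each context
out['answers'].shape       # (4, max_length): final-turn ids, padded with eos_id
out['metadata']            # list of 4 dicts holding the extra per-example keys
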
@@ -404,7 +404,7 @@ def inference_step(self, dataloader_iter, batch_idx, mode, dataloader_idx=0):
data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds
self._reconfigure_and_process_inference_batch(batch, data_cfg)
# Metadata from the dataset
metadata = batch.pop('metadata')
metadata = batch.get('metadata', [{}] * len(batch['tokens']))
loss = super().validation_step(itertools.chain([batch]), batch_idx)

# We need _inference_config to get generation params
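The switch from pop to get means a batch that carries no 'metadata' key (e.g. one produced by a dataset class that does not add it) falls back to one empty dict per sample instead of raising a KeyError; a small sketch of the fallback:

import torch

batch = {'tokens': torch.zeros(4, 16, dtype=torch.long)}  # no 'metadata' key in this toy batch
metadata = batch.get('metadata', [{}] * len(batch['tokens']))
# metadata == [{}, {}, {}, {}], so downstream per-sample indexing still works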