Skip to content

Commit

Permalink
Add pre-compute test
Browse files Browse the repository at this point in the history
  • Loading branch information
AjayP13 committed Dec 24, 2023
1 parent f2f1701 commit 49561bb
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/tests/trainers/test_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,7 @@ def test_seq2seq(self, create_datadreamer, mocker):
validation_rejected=val_dataset.output["rejected"],
epochs=1,
batch_size=8,
precompute_ref_log_probs=False,
)
assert data_collator_spy.call_count == 3
spy_return_value = {
Expand Down Expand Up @@ -1920,6 +1921,7 @@ def test_causal(self, create_datadreamer, mocker):
validation_rejected=val_dataset.output["rejected"],
epochs=1,
batch_size=8,
precompute_ref_log_probs=False,
)
assert data_collator_spy.call_count == 3
spy_return_value = {
Expand Down Expand Up @@ -2057,6 +2059,7 @@ def test_peft(self, create_datadreamer, mocker):
validation_rejected=val_dataset.output["rejected"],
epochs=1,
batch_size=8,
precompute_ref_log_probs=True,
)
trainer_path = cast(str, trainer._output_folder_path)
with open(os.path.join(trainer_path, "fingerprint.json"), "r") as f:
Expand Down

0 comments on commit 49561bb

Please sign in to comment.