diff --git a/src/tests/trainers/test_trainers.py b/src/tests/trainers/test_trainers.py index c36f4fcd..085768ed 100644 --- a/src/tests/trainers/test_trainers.py +++ b/src/tests/trainers/test_trainers.py @@ -1820,6 +1820,7 @@ def test_seq2seq(self, create_datadreamer, mocker): validation_rejected=val_dataset.output["rejected"], epochs=1, batch_size=8, + precompute_ref_log_probs=False, ) assert data_collator_spy.call_count == 3 spy_return_value = { @@ -1920,6 +1921,7 @@ def test_causal(self, create_datadreamer, mocker): validation_rejected=val_dataset.output["rejected"], epochs=1, batch_size=8, + precompute_ref_log_probs=False, ) assert data_collator_spy.call_count == 3 spy_return_value = { @@ -2057,6 +2059,7 @@ def test_peft(self, create_datadreamer, mocker): validation_rejected=val_dataset.output["rejected"], epochs=1, batch_size=8, + precompute_ref_log_probs=True, ) trainer_path = cast(str, trainer._output_folder_path) with open(os.path.join(trainer_path, "fingerprint.json"), "r") as f: