From 49561bb4e80c0559709ff352b5fc299af986461a Mon Sep 17 00:00:00 2001 From: Ajay Patel Date: Sun, 24 Dec 2023 18:44:41 -0500 Subject: [PATCH] Add pre-compute test --- src/tests/trainers/test_trainers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/tests/trainers/test_trainers.py b/src/tests/trainers/test_trainers.py index c36f4fcd..085768ed 100644 --- a/src/tests/trainers/test_trainers.py +++ b/src/tests/trainers/test_trainers.py @@ -1820,6 +1820,7 @@ def test_seq2seq(self, create_datadreamer, mocker): validation_rejected=val_dataset.output["rejected"], epochs=1, batch_size=8, + precompute_ref_log_probs=False, ) assert data_collator_spy.call_count == 3 spy_return_value = { @@ -1920,6 +1921,7 @@ def test_causal(self, create_datadreamer, mocker): validation_rejected=val_dataset.output["rejected"], epochs=1, batch_size=8, + precompute_ref_log_probs=False, ) assert data_collator_spy.call_count == 3 spy_return_value = { @@ -2057,6 +2059,7 @@ def test_peft(self, create_datadreamer, mocker): validation_rejected=val_dataset.output["rejected"], epochs=1, batch_size=8, + precompute_ref_log_probs=True, ) trainer_path = cast(str, trainer._output_folder_path) with open(os.path.join(trainer_path, "fingerprint.json"), "r") as f: