diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py index 418bb41187b..f18a40c79f1 100644 --- a/tests/test_adaption_prompt.py +++ b/tests/test_adaption_prompt.py @@ -137,7 +137,7 @@ def test_save_pretrained_regression(self) -> None: assert state_dict.keys() == state_dict_from_pretrained.keys() # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). - assert len(list(state_dict.keys())) == 4 + assert len(state_dict) == 4 # check if tensors equal for key in state_dict.keys(): @@ -180,7 +180,7 @@ def test_save_pretrained(self) -> None: assert state_dict.keys() == state_dict_from_pretrained.keys() # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). - assert len(list(state_dict.keys())) == 4 + assert len(state_dict) == 4 # check if tensors equal for key in state_dict.keys(): @@ -228,7 +228,7 @@ def test_save_pretrained_selected_adapters(self) -> None: assert state_dict.keys() == state_dict_from_pretrained.keys() # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). - assert len(list(state_dict.keys())) == 4 + assert len(state_dict) == 4 # check if tensors equal for key in state_dict.keys(): diff --git a/tests/test_multitask_prompt_tuning.py b/tests/test_multitask_prompt_tuning.py index 2b95cee15fb..e6ead98f68a 100644 --- a/tests/test_multitask_prompt_tuning.py +++ b/tests/test_multitask_prompt_tuning.py @@ -134,7 +134,7 @@ def test_save_pretrained(self) -> None: assert state_dict.keys() == state_dict_from_pretrained.keys() # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). - assert len(list(state_dict.keys())) == 3 + assert len(state_dict) == 3 # check if tensors equal for key in state_dict.keys(): @@ -177,7 +177,7 @@ def test_save_pretrained_regression(self) -> None: assert state_dict.keys() == state_dict_from_pretrained.keys() # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate). - assert len(list(state_dict.keys())) == 3 + assert len(state_dict) == 3 # check if tensors equal for key in state_dict.keys():