From 2a7cbc5384d931fc0d68e0b09fc245a9ea1a9f0f Mon Sep 17 00:00:00 2001 From: Isao Jonas Date: Fri, 29 Mar 2024 07:45:47 -0500 Subject: [PATCH] fix: fix default value access for prompt runners call field.get_default which will call .default_factory if passed in instead of just assuming .default is there --- langdspy/model.py | 7 ++++--- tests/test_prompt_runner.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/langdspy/model.py b/langdspy/model.py index 855264f..1d928b5 100644 --- a/langdspy/model.py +++ b/langdspy/model.py @@ -52,11 +52,12 @@ def __init__(self, n_jobs=1, **kwargs): self.kwargs = {**kwargs, 'trained_state': self.trained_state} for field_name, field in self.__fields__.items(): if issubclass(field.type_, PromptRunner): - self.prompt_runners.append((field_name, field.default)) + field_value = field.get_default() + self.prompt_runners.append((field_name, field_value)) - field.default.set_model_kwargs(self.kwargs) + field_value.set_model_kwargs(self.kwargs) # Necessary since pydantic creates a new version of the object - setattr(self, field_name, field.default) + setattr(self, field_name, field_value) def save(self, filepath): diff --git a/tests/test_prompt_runner.py b/tests/test_prompt_runner.py index d0c3f3e..9c87242 100644 --- a/tests/test_prompt_runner.py +++ b/tests/test_prompt_runner.py @@ -80,7 +80,7 @@ def test_trained_state_in_inputs(): print(result) print(f"Called with {mock_invoke.call_count} {mock_invoke.call_args_list} {mock_invoke.call_args}") - call_args = {**input_dict, 'print_prompt': "TEST", 'trained_state': model.trained_state, 'use_training': True, 'llm_type': "test"} + call_args = {**input_dict, 'print_prompt': False, 'trained_state': model.trained_state, 'use_training': True, 'llm_type': "test"} print(f"Expecting call {call_args}") mock_invoke.assert_called_with(**call_args) @@ -96,6 +96,6 @@ def test_use_training(): print(result) print(f"Called with {mock_invoke.call_count} {mock_invoke.call_args_list} {mock_invoke.call_args}") - call_args = {**input_dict, 'print_prompt': "TEST", 'trained_state': model.trained_state, 'use_training': False, 'llm_type': "test"} + call_args = {**input_dict, 'print_prompt': False, 'trained_state': model.trained_state, 'use_training': False, 'llm_type': "test"} print(f"Expecting call {call_args}") - mock_invoke.assert_called_with(**call_args) \ No newline at end of file + mock_invoke.assert_called_with(**call_args)