Skip to content

Commit

Permalink
fix: fix default value access for prompt runners
Browse files Browse the repository at this point in the history
call field.get_default which will call .default_factory if passed in
instead of just assuming .default is there
  • Loading branch information
jonasi committed Mar 29, 2024
1 parent ad3a5e1 commit 2a7cbc5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
7 changes: 4 additions & 3 deletions langdspy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ def __init__(self, n_jobs=1, **kwargs):
self.kwargs = {**kwargs, 'trained_state': self.trained_state}
for field_name, field in self.__fields__.items():
if issubclass(field.type_, PromptRunner):
self.prompt_runners.append((field_name, field.default))
field_value = field.get_default()
self.prompt_runners.append((field_name, field_value))

field.default.set_model_kwargs(self.kwargs)
field_value.set_model_kwargs(self.kwargs)
# Necessary since pydantic creates a new version of the object
setattr(self, field_name, field.default)
setattr(self, field_name, field_value)


def save(self, filepath):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_prompt_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_trained_state_in_inputs():

print(result)
print(f"Called with {mock_invoke.call_count} {mock_invoke.call_args_list} {mock_invoke.call_args}")
call_args = {**input_dict, 'print_prompt': "TEST", 'trained_state': model.trained_state, 'use_training': True, 'llm_type': "test"}
call_args = {**input_dict, 'print_prompt': False, 'trained_state': model.trained_state, 'use_training': True, 'llm_type': "test"}
print(f"Expecting call {call_args}")
mock_invoke.assert_called_with(**call_args)

Expand All @@ -96,6 +96,6 @@ def test_use_training():

print(result)
print(f"Called with {mock_invoke.call_count} {mock_invoke.call_args_list} {mock_invoke.call_args}")
call_args = {**input_dict, 'print_prompt': "TEST", 'trained_state': model.trained_state, 'use_training': False, 'llm_type': "test"}
call_args = {**input_dict, 'print_prompt': False, 'trained_state': model.trained_state, 'use_training': False, 'llm_type': "test"}
print(f"Expecting call {call_args}")
mock_invoke.assert_called_with(**call_args)
mock_invoke.assert_called_with(**call_args)

0 comments on commit 2a7cbc5

Please sign in to comment.