Skip to content

Commit

Permalink
be more defensive about model type check
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasi committed Mar 20, 2024
1 parent 4ddbf81 commit d16f790
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion langdspy/prompt_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ def _determine_llm_type(self, llm):
else:
return 'openai' # Default to OpenAI if model type cannot be determined

def _determine_llm_model(self, llm):
if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models
return llm.model_name
elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models
return llm.model
elif hasattr(llm, 'model_name'):
return llm.model_name
elif hasattr(llm, 'model'):
return llm.model
else:
return '???'

def get_prompt_history(self):
return self.prompt_history.history

Expand Down Expand Up @@ -193,7 +205,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
parsed_output[attr_name] = transformed_val

end_time = time.time()
self.prompt_history.add_entry(config["llm"].model_name, formatted_prompt, res, parsed_output, validation_err, start_time, end_time)
self.prompt_history.add_entry(self._determine_llm_type(config['llm']) + " " + self._determine_llm_model(config['llm']), formatted_prompt, res, parsed_output, validation_err, start_time, end_time)

res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()}

Expand Down

0 comments on commit d16f790

Please sign in to comment.