Skip to content

Commit

Permalink
Allows not enforcing function usage when a single function is passed …
Browse files Browse the repository at this point in the history
…to openai function chain/executable
  • Loading branch information
karimassi committed Dec 5, 2023
1 parent f758c8a commit 79f1c59
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions libs/langchain/langchain/chains/openai_functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def create_openai_fn_runnable(
llm: Runnable,
prompt: BasePromptTemplate,
*,
enforce_single_function_usage: bool = True,
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
**kwargs: Any,
) -> Runnable:
Expand All @@ -222,6 +223,9 @@ def create_openai_fn_runnable(
pydantic.BaseModels for arguments.
llm: Language model to use, assumed to support the OpenAI function-calling API.
prompt: BasePromptTemplate to pass to the model.
enforce_single_function_usage: only used if a single function is passed in. If
True, then the model will be forced to use the given function. If False,
then the model will be given the option to use the given function or not.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
Expand Down Expand Up @@ -276,7 +280,7 @@ class RecordDog(BaseModel):
raise ValueError("Need to pass in at least one function. Received zero.")
openai_functions = [convert_to_openai_function(f) for f in functions]
llm_kwargs: Dict[str, Any] = {"functions": openai_functions, **kwargs}
if len(openai_functions) == 1:
if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
output_parser = output_parser or get_openai_output_parser(functions)
return prompt | llm.bind(**llm_kwargs) | output_parser
Expand Down Expand Up @@ -373,6 +377,7 @@ def create_openai_fn_chain(
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
enforce_single_function_usage: bool = True,
output_key: str = "function",
output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any,
Expand All @@ -392,6 +397,9 @@ def create_openai_fn_chain(
pydantic.BaseModels for arguments.
llm: Language model to use, assumed to support the OpenAI function-calling API.
prompt: BasePromptTemplate to pass to the model.
enforce_single_function_usage: only used if a single function is passed in. If
True, then the model will be forced to use the given function. If False,
then the model will be given the option to use the given function or not.
output_key: The key to use when returning the output in LLMChain.__call__.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
Expand Down Expand Up @@ -451,7 +459,7 @@ class RecordDog(BaseModel):
llm_kwargs: Dict[str, Any] = {
"functions": openai_functions,
}
if len(openai_functions) == 1:
if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
llm_chain = LLMChain(
llm=llm,
Expand Down

0 comments on commit 79f1c59

Please sign in to comment.