Skip to content

Commit

Permalink
Allow not enforcing function usage when a single function is passed t…
Browse files Browse the repository at this point in the history
…o openai function executable (langchain-ai#14308)

- **Description:** allows not enforcing function usage when a single
function is passed to an openAI function executable (or corresponding
legacy chain). This is a desired feature in the case where the model
does not have enough information to call a function, and needs to get
back to the user.
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Tag maintainer:** N/A
  • Loading branch information
karimassi authored and aymeric-roucher committed Dec 11, 2023
1 parent 2237217 commit 4c39a9e
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions libs/langchain/langchain/chains/openai_functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def create_openai_fn_runnable(
llm: Runnable,
prompt: BasePromptTemplate,
*,
enforce_single_function_usage: bool = True,
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
**kwargs: Any,
) -> Runnable:
Expand All @@ -222,6 +223,9 @@ def create_openai_fn_runnable(
pydantic.BaseModels for arguments.
llm: Language model to use, assumed to support the OpenAI function-calling API.
prompt: BasePromptTemplate to pass to the model.
enforce_single_function_usage: only used if a single function is passed in. If
True, then the model will be forced to use the given function. If False,
then the model will be given the option to use the given function or not.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
Expand Down Expand Up @@ -276,7 +280,7 @@ class RecordDog(BaseModel):
raise ValueError("Need to pass in at least one function. Received zero.")
openai_functions = [convert_to_openai_function(f) for f in functions]
llm_kwargs: Dict[str, Any] = {"functions": openai_functions, **kwargs}
if len(openai_functions) == 1:
if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
output_parser = output_parser or get_openai_output_parser(functions)
return prompt | llm.bind(**llm_kwargs) | output_parser
Expand Down Expand Up @@ -373,6 +377,7 @@ def create_openai_fn_chain(
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
enforce_single_function_usage: bool = True,
output_key: str = "function",
output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any,
Expand All @@ -392,6 +397,9 @@ def create_openai_fn_chain(
pydantic.BaseModels for arguments.
llm: Language model to use, assumed to support the OpenAI function-calling API.
prompt: BasePromptTemplate to pass to the model.
enforce_single_function_usage: only used if a single function is passed in. If
True, then the model will be forced to use the given function. If False,
then the model will be given the option to use the given function or not.
output_key: The key to use when returning the output in LLMChain.__call__.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
Expand Down Expand Up @@ -451,7 +459,7 @@ class RecordDog(BaseModel):
llm_kwargs: Dict[str, Any] = {
"functions": openai_functions,
}
if len(openai_functions) == 1:
if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
llm_chain = LLMChain(
llm=llm,
Expand Down

0 comments on commit 4c39a9e

Please sign in to comment.