From 88bc15d69ba0a70303c54448b038eb132f72c2b1 Mon Sep 17 00:00:00 2001 From: ccurme Date: Mon, 16 Sep 2024 11:15:23 -0400 Subject: [PATCH] standard-tests[patch]: add async test for structured output (#26527) --- .../integration_tests/chat_models.py | 40 ++++++++++++++++--- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 805cf57549aae..850f56c33d882 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -44,6 +44,13 @@ def magic_function_no_args() -> int: return 5 +class Joke(BaseModel): + """Joke to tell user.""" + + setup: str = Field(description="question to set up a joke") + punchline: str = Field(description="answer to resolve the joke") + + def _validate_tool_call_message(message: BaseMessage) -> None: assert isinstance(message, AIMessage) assert len(message.tool_calls) == 1 @@ -240,12 +247,6 @@ def test_structured_output(self, model: BaseChatModel) -> None: if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - class Joke(BaseModel): - """Joke to tell user.""" - - setup: str = Field(description="question to set up a joke") - punchline: str = Field(description="answer to resolve the joke") - # Pydantic class # Type ignoring since the interface only officially supports pydantic 1 # or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2. @@ -268,6 +269,33 @@ class Joke(BaseModel): assert isinstance(chunk, dict) # for mypy assert set(chunk.keys()) == {"setup", "punchline"} + async def test_structured_output_async(self, model: BaseChatModel) -> None: + """Test to verify structured output with a Pydantic model.""" + if not self.has_tool_calling: + pytest.skip("Test requires tool calling.") + + # Pydantic class + # Type ignoring since the interface only officially supports pydantic 1 + # or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2. + # We'll need to do a pass updating the type signatures. + chat = model.with_structured_output(Joke) # type: ignore[arg-type] + result = await chat.ainvoke("Tell me a joke about cats.") + assert isinstance(result, Joke) + + async for chunk in chat.astream("Tell me a joke about cats."): + assert isinstance(chunk, Joke) + + # Schema + chat = model.with_structured_output(Joke.model_json_schema()) + result = await chat.ainvoke("Tell me a joke about cats.") + assert isinstance(result, dict) + assert set(result.keys()) == {"setup", "punchline"} + + async for chunk in chat.astream("Tell me a joke about cats."): + assert isinstance(chunk, dict) + assert isinstance(chunk, dict) # for mypy + assert set(chunk.keys()) == {"setup", "punchline"} + @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.") def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None: """Test to verify compatibility with pydantic.v1.BaseModel.