Skip to content

Commit

Permalink
standard-tests[patch]: add async test for structured output (#26527)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme authored Sep 16, 2024
1 parent 1ab181f commit 88bc15d
Showing 1 changed file with 34 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def magic_function_no_args() -> int:
return 5


class Joke(BaseModel):
"""Joke to tell user."""

setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")


def _validate_tool_call_message(message: BaseMessage) -> None:
assert isinstance(message, AIMessage)
assert len(message.tool_calls) == 1
Expand Down Expand Up @@ -240,12 +247,6 @@ def test_structured_output(self, model: BaseChatModel) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

class Joke(BaseModel):
"""Joke to tell user."""

setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")

# Pydantic class
# Type ignoring since the interface only officially supports pydantic 1
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
Expand All @@ -268,6 +269,33 @@ class Joke(BaseModel):
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}

async def test_structured_output_async(self, model: BaseChatModel) -> None:
"""Test to verify structured output with a Pydantic model."""
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

# Pydantic class
# Type ignoring since the interface only officially supports pydantic 1
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
# We'll need to do a pass updating the type signatures.
chat = model.with_structured_output(Joke) # type: ignore[arg-type]
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, Joke)

async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)

# Schema
chat = model.with_structured_output(Joke.model_json_schema())
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}

async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}

@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
"""Test to verify compatibility with pydantic.v1.BaseModel.
Expand Down

0 comments on commit 88bc15d

Please sign in to comment.