diff --git a/src/openai/lib/_parsing/_completions.py b/src/openai/lib/_parsing/_completions.py
index f9d1d6b351..2ef1bf3553 100644
--- a/src/openai/lib/_parsing/_completions.py
+++ b/src/openai/lib/_parsing/_completions.py
@@ -9,9 +9,9 @@
 from .._tools import PydanticFunctionTool
 from ..._types import NOT_GIVEN, NotGiven
 from ..._utils import is_dict, is_given
-from ..._compat import model_parse_json
+from ..._compat import PYDANTIC_V2, model_parse_json
 from ..._models import construct_type_unchecked
-from .._pydantic import to_strict_json_schema
+from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
 from ...types.chat import (
     ParsedChoice,
     ChatCompletion,
@@ -216,14 +216,16 @@ def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
     return cast(FunctionDefinition, input_fn).get("strict") or False
 
 
-def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
-    return issubclass(typ, pydantic.BaseModel)
-
-
 def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
     if is_basemodel_type(response_format):
         return cast(ResponseFormatT, model_parse_json(response_format, content))
 
+    if is_dataclass_like_type(response_format):
+        if not PYDANTIC_V2:
+            raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
+
+        return pydantic.TypeAdapter(response_format).validate_json(content)
+
     raise TypeError(f"Unable to automatically parse response format type {response_format}")
 
 
@@ -241,14 +243,22 @@ def type_to_response_format_param(
     # can only be a `type`
     response_format = cast(type, response_format)
 
-    if not is_basemodel_type(response_format):
+    json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
+
+    if is_basemodel_type(response_format):
+        name = response_format.__name__
+        json_schema_type = response_format
+    elif is_dataclass_like_type(response_format):
+        name = response_format.__name__
+        json_schema_type = pydantic.TypeAdapter(response_format)
+    else:
         raise TypeError(f"Unsupported response_format type - {response_format}")
 
     return {
         "type": "json_schema",
         "json_schema": {
-            "schema": to_strict_json_schema(response_format),
-            "name": response_format.__name__,
+            "schema": to_strict_json_schema(json_schema_type),
+            "name": name,
             "strict": True,
         },
     }
diff --git a/src/openai/lib/_pydantic.py b/src/openai/lib/_pydantic.py
index f989ce3ed0..22c7a1f3cd 100644
--- a/src/openai/lib/_pydantic.py
+++ b/src/openai/lib/_pydantic.py
@@ -1,17 +1,26 @@
 from __future__ import annotations
 
-from typing import Any
+import inspect
+from typing import Any, TypeVar
 from typing_extensions import TypeGuard
 
 import pydantic
 
 from .._types import NOT_GIVEN
 from .._utils import is_dict as _is_dict, is_list
-from .._compat import model_json_schema
+from .._compat import PYDANTIC_V2, model_json_schema
 
+_T = TypeVar("_T")
+
+
+def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
+    if inspect.isclass(model) and is_basemodel_type(model):
+        schema = model_json_schema(model)
+    elif PYDANTIC_V2 and isinstance(model, pydantic.TypeAdapter):
+        schema = model.json_schema()
+    else:
+        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")
 
-def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]:
-    schema = model_json_schema(model)
     return _ensure_strict_json_schema(schema, path=(), root=schema)
 
 
@@ -117,6 +126,15 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
     return resolved
 
 
+def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
+    return issubclass(typ, pydantic.BaseModel)
+
+
+def is_dataclass_like_type(typ: type) -> bool:
+    """Returns True if the given type likely used `@pydantic.dataclass`"""
+    return hasattr(typ, "__pydantic_config__")
+
+
 def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
     # just pretend that we know there are only `str` keys
     # as that check is not worth the performance cost
diff --git a/tests/lib/chat/test_completions.py b/tests/lib/chat/test_completions.py
index aea449b097..d67d5129cd 100644
--- a/tests/lib/chat/test_completions.py
+++ b/tests/lib/chat/test_completions.py
@@ -3,7 +3,7 @@
 import os
 import json
 from enum import Enum
-from typing import Any, Callable, Optional
+from typing import Any, List, Callable, Optional
 from typing_extensions import Literal, TypeVar
 
 import httpx
@@ -317,6 +317,63 @@ class Location(BaseModel):
 )
 
 
+@pytest.mark.respx(base_url=base_url)
+@pytest.mark.skipif(not PYDANTIC_V2, reason="dataclasses only supported in v2")
+def test_parse_pydantic_dataclass(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    from pydantic.dataclasses import dataclass
+
+    @dataclass
+    class CalendarEvent:
+        name: str
+        date: str
+        participants: List[str]
+
+    completion = _make_snapshot_request(
+        lambda c: c.beta.chat.completions.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {"role": "system", "content": "Extract the event information."},
+                {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
+            ],
+            response_format=CalendarEvent,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3", "object": "chat.completion", "created": 1723761008, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"name\\":\\"Science Fair\\",\\"date\\":\\"Friday\\",\\"participants\\":[\\"Alice\\",\\"Bob\\"]}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 32, "completion_tokens": 17, "total_tokens": 49}, "system_fingerprint": "fp_2a322c9ffc"}'
+        ),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(completion, monkeypatch) == snapshot(
+        """\
+ParsedChatCompletion[CalendarEvent](
+    choices=[
+        ParsedChoice[CalendarEvent](
+            finish_reason='stop',
+            index=0,
+            logprobs=None,
+            message=ParsedChatCompletionMessage[CalendarEvent](
+                content='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}',
+                function_call=None,
+                parsed=CalendarEvent(name='Science Fair', date='Friday', participants=['Alice', 'Bob']),
+                refusal=None,
+                role='assistant',
+                tool_calls=[]
+            )
+        )
+    ],
+    created=1723761008,
+    id='chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3',
+    model='gpt-4o-2024-08-06',
+    object='chat.completion',
+    service_tier=None,
+    system_fingerprint='fp_2a322c9ffc',
+    usage=CompletionUsage(completion_tokens=17, prompt_tokens=32, total_tokens=49)
+)
+"""
+    )
+
+
 @pytest.mark.respx(base_url=base_url)
 def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
     completion = _make_snapshot_request(
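
The core mechanism this patch leans on for non-`BaseModel` types is Pydantic v2's `TypeAdapter`, which covers both halves of the flow: `type_to_response_format_param` uses it for JSON-schema generation (via the new branch in `to_strict_json_schema`), and `_parse_content` uses it to validate the JSON the API returns. A minimal standalone sketch of that mechanism, assuming Pydantic v2 is installed; the `CalendarEvent` class and values are illustrative, mirroring the new test:

```python
from typing import List

import pydantic
from pydantic.dataclasses import dataclass


@dataclass
class CalendarEvent:
    name: str
    date: str
    participants: List[str]


# Pydantic dataclasses carry __pydantic_config__, which is what the new
# is_dataclass_like_type() helper checks for.
assert hasattr(CalendarEvent, "__pydantic_config__")

adapter = pydantic.TypeAdapter(CalendarEvent)

# Schema-generation path: to_strict_json_schema() calls .json_schema() on the
# adapter before applying its strictness post-processing.
schema = adapter.json_schema()
print(schema["required"])  # ['name', 'date', 'participants']

# Parsing path: _parse_content() calls .validate_json() on the raw message content.
event = adapter.validate_json('{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}')
print(event)  # CalendarEvent(name='Science Fair', date='Friday', participants=['Alice', 'Bob'])
```

On Pydantic v1, where `TypeAdapter` does not exist, both paths raise `TypeError` instead, which is why the new test is gated with `skipif(not PYDANTIC_V2, ...)`.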
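From the caller's side, the change means a `@pydantic.dataclass` can now be passed as `response_format` exactly like a `BaseModel` subclass, as exercised by the new `test_parse_pydantic_dataclass` snapshot test. A sketch of that end-to-end usage against the live API, assuming `OPENAI_API_KEY` is set in the environment (the test itself runs against a respx mock instead):

```python
from typing import List

from openai import OpenAI
from pydantic.dataclasses import dataclass


@dataclass
class CalendarEvent:
    name: str
    date: str
    participants: List[str]


client = OpenAI()  # reads OPENAI_API_KEY from the environment

completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format=CalendarEvent,
)

# .parsed is a validated CalendarEvent instance, produced by the new
# TypeAdapter-based branch in _parse_content().
event = completion.choices[0].message.parsed
print(event)
```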