Skip to content

Commit

Permalink
tests for pydantic generic model support
Browse files Browse the repository at this point in the history
  • Loading branch information
jfschneider committed Dec 3, 2024
1 parent fb3cc7f commit 9575fa7
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 3 deletions.
6 changes: 5 additions & 1 deletion openapi_pydantic/v3/v3_0/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import re
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -242,6 +243,9 @@ def _traverse(obj: Any) -> None:


def _construct_ref_obj(pydantic_schema: PydanticSchema[PydanticType]) -> Reference:
ref_obj = Reference(**{"$ref": ref_prefix + pydantic_schema.schema_class.__name__})
ref_name = re.sub(
r"[^a-zA-Z0-9.\-_]", "_", pydantic_schema.schema_class.__name__
).replace(".", "__")
ref_obj = Reference(**{"$ref": ref_prefix + ref_name})
logger.debug(f"ref_obj={ref_obj}")
return ref_obj
76 changes: 75 additions & 1 deletion tests/util/test_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Callable
from typing import Callable, Generic, TypeVar

import pytest
from pydantic import BaseModel, Field

from openapi_pydantic import (
Expand Down Expand Up @@ -68,6 +69,43 @@ def test_construct_open_api_with_schema_class_3() -> None:
assert "resp_bar" in schema_without_alias.properties


@pytest.mark.skipif(PYDANTIC_V2, reason="generic type for Pydantic V1")
def test_construct_open_api_with_schema_class_4_generic_response_v1() -> None:
DataT = TypeVar("DataT")
from pydantic.v1.generics import GenericModel

class GenericResponse(GenericModel, Generic[DataT]):
msg: str = Field(description="message of the generic response")
data: DataT = Field(description="data value of the generic response")

open_api_4 = construct_base_open_api_4_generic_response(
GenericResponse[PongResponse]
)

result = construct_open_api_with_schema_class(open_api_4)
assert result.components is not None
assert result.components.schemas is not None
assert "GenericResponse_PongResponse_" in result.components.schemas


@pytest.mark.skipif(not PYDANTIC_V2, reason="generic type for Pydantic V2")
def test_construct_open_api_with_schema_class_4_generic_response_v2() -> None:
DataT = TypeVar("DataT")

class GenericResponse(BaseModel, Generic[DataT]):
msg: str = Field(description="message of the generic response")
data: DataT = Field(description="data value of the generic response")

open_api_4 = construct_base_open_api_4_generic_response(
GenericResponse[PongResponse]
)

result = construct_open_api_with_schema_class(open_api_4)
assert result.components is not None
assert result.components.schemas is not None
assert "GenericResponse_PongResponse_" in result.components.schemas


def construct_base_open_api_1() -> OpenAPI:
model_validate: Callable[[dict], OpenAPI] = getattr(
OpenAPI, "model_validate" if PYDANTIC_V2 else "parse_obj"
Expand Down Expand Up @@ -176,6 +214,42 @@ def construct_base_open_api_3() -> OpenAPI:
)


def construct_base_open_api_4_generic_response(response_schema: type) -> OpenAPI:
return OpenAPI(
info=Info(
title="My own API",
version="v0.0.1",
),
paths={
"/ping": PathItem(
post=Operation(
requestBody=RequestBody(
content={
"application/json": MediaType(
media_type_schema=PydanticSchema(
schema_class=PingRequest
)
)
}
),
responses={
"200": Response(
description="pong",
content={
"application/json": MediaType(
media_type_schema=PydanticSchema(
schema_class=response_schema
)
)
},
)
},
)
)
},
)


class PingRequest(BaseModel):
"""Ping Request"""

Expand Down
76 changes: 75 additions & 1 deletion tests/v3_0/test_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Callable, Literal
from typing import Callable, Generic, Literal, TypeVar

import pytest
from pydantic import BaseModel, Field

from openapi_pydantic.compat import PYDANTIC_V2
Expand Down Expand Up @@ -74,6 +75,43 @@ def test_construct_open_api_with_schema_class_3() -> None:
assert "resp_bar" in schema_without_alias.properties


@pytest.mark.skipif(PYDANTIC_V2, reason="generic type for Pydantic V1")
def test_construct_open_api_with_schema_class_4_generic_response_v1() -> None:
DataT = TypeVar("DataT")
from pydantic.v1.generics import GenericModel

class GenericResponse(GenericModel, Generic[DataT]):
msg: str = Field(description="message of the generic response")
data: DataT = Field(description="data value of the generic response")

open_api_4 = construct_base_open_api_4_generic_response(
GenericResponse[PongResponse]
)

result = construct_open_api_with_schema_class(open_api_4)
assert result.components is not None
assert result.components.schemas is not None
assert "GenericResponse_PongResponse_" in result.components.schemas


@pytest.mark.skipif(not PYDANTIC_V2, reason="generic type for Pydantic V2")
def test_construct_open_api_with_schema_class_4_generic_response() -> None:
DataT = TypeVar("DataT")

class GenericResponse(BaseModel, Generic[DataT]):
msg: str = Field(description="message of the generic response")
data: DataT = Field(description="data value of the generic response")

open_api_4 = construct_base_open_api_4_generic_response(
GenericResponse[PongResponse]
)

result = construct_open_api_with_schema_class(open_api_4)
assert result.components is not None
assert result.components.schemas is not None
assert "GenericResponse_PongResponse_" in result.components.schemas


def construct_base_open_api_1() -> OpenAPI:
model_validate: Callable[[dict], OpenAPI] = getattr(
OpenAPI, "model_validate" if PYDANTIC_V2 else "parse_obj"
Expand Down Expand Up @@ -215,6 +253,42 @@ def construct_base_open_api_3_plus() -> OpenAPI:
)


def construct_base_open_api_4_generic_response(response_schema: type) -> OpenAPI:
return OpenAPI(
info=Info(
title="My own API",
version="v0.0.1",
),
paths={
"/ping": PathItem(
post=Operation(
requestBody=RequestBody(
content={
"application/json": MediaType(
media_type_schema=PydanticSchema(
schema_class=PingRequest
)
)
}
),
responses={
"200": Response(
description="pong",
content={
"application/json": MediaType(
media_type_schema=PydanticSchema(
schema_class=response_schema
)
)
},
)
},
)
)
},
)


class PingRequest(BaseModel):
"""Ping Request"""

Expand Down

0 comments on commit 9575fa7

Please sign in to comment.