Skip to content

Commit

Permalink
🐛 Fix nested schema names in JSON schema generation
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy authored and migduroli committed Sep 27, 2023
1 parent d17555d commit 78c1c37
Show file tree
Hide file tree
Showing 5 changed files with 482 additions and 107 deletions.
5 changes: 4 additions & 1 deletion flama/schemas/_libs/typesystem/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ def _get_field_type(
)

if isinstance(field, typesystem.Object):
return {k: self._get_field_type(v) for k, v in field.properties.items()}
object_fields = {k: self._get_field_type(v) for k, v in field.properties.items()}
if isinstance(field.additional_properties, (typesystem.Field, typesystem.Reference)):
object_fields[""] = self._get_field_type(field.additional_properties)
return object_fields

try:
return MAPPING_TYPES[field.__class__]
Expand Down
39 changes: 36 additions & 3 deletions flama/schemas/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,21 @@

t.TypeGuard = TypeGuard # type: ignore

if sys.version_info < (3, 11): # PORT: Remove when stop supporting 3.10 # pragma: no cover

class StrEnum(str, enum.Enum):
def _generate_next_value_(name, start, count, last_values):
return name.lower()

enum.StrEnum = StrEnum # type: ignore

__all__ = ["Field", "Schema", "Parameter", "Parameters"]


UNKNOWN = t.TypeVar("UNKNOWN")


class ParameterLocation(enum.Enum):
class ParameterLocation(enum.StrEnum): # type: ignore # PORT: Remove this comment when stop supporting 3.10
query = enum.auto()
path = enum.auto()
body = enum.auto()
Expand Down Expand Up @@ -126,9 +134,34 @@ def is_schema(cls, obj: t.Any) -> bool:
def name(self) -> str:
return schemas.adapter.name(self.schema)

def _fix_ref(self, value: str, refs: t.Dict[str, str]) -> str:
try:
prefix, name = value.rsplit("/", 1)
return f"{prefix}/{refs[name]}"
except KeyError:
return value

def _replace_json_schema_refs(self, schema: types.JSONField, refs: t.Dict[str, str]) -> types.JSONField:
if isinstance(schema, dict):
return {
k: self._fix_ref(t.cast(str, v), refs) if k == "$ref" else self._replace_json_schema_refs(v, refs)
for k, v in schema.items()
}

if isinstance(schema, (list, tuple, set)):
return [self._replace_json_schema_refs(x, refs) for x in schema]

return schema

@property
def json_schema(self) -> types.JSONSchema:
return schemas.adapter.to_json_schema(self.schema)
return t.cast(
types.JSONSchema,
self._replace_json_schema_refs(
schemas.adapter.to_json_schema(self.schema),
{Schema(x).name.rsplit(".", 1)[1]: Schema(x).name for x in self.nested_schemas()},
),
)

@property
def unique_schema(self) -> t.Any:
Expand All @@ -143,7 +176,7 @@ def nested_schemas(self, schema: t.Any = UNKNOWN) -> t.List[t.Any]:
return self.nested_schemas(self)

if schemas.adapter.is_schema(schema):
return [schema]
return [schemas.adapter.unique_schema(schema)]

if isinstance(schema, (list, tuple, set)):
return [x for field in schema for x in self.nested_schemas(field)]
Expand Down
105 changes: 105 additions & 0 deletions tests/schemas/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import typing as t
from collections import namedtuple

import marshmallow
import pydantic
import pytest
import typesystem
import typesystem.fields


@pytest.fixture(scope="function")
def foo_schema(app):
if app.schema.schema_library.lib == pydantic:
schema = pydantic.create_model("Foo", name=(str, ...))
name = "pydantic.main.Foo"
elif app.schema.schema_library.lib == typesystem:
schema = typesystem.Schema(title="Foo", fields={"name": typesystem.fields.String()})
name = "typesystem.schemas.Foo"
elif app.schema.schema_library.lib == marshmallow:
schema = type("Foo", (marshmallow.Schema,), {"name": marshmallow.fields.String()})
name = "abc.Foo"
else:
raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}")
return namedtuple("FooSchema", ("schema", "name"))(schema=schema, name=name)


@pytest.fixture(scope="function")
def bar_schema(app, foo_schema):
child_schema = foo_schema.schema
if app.schema.schema_library.lib == pydantic:
schema = pydantic.create_model("Bar", foo=(child_schema, ...))
name = "pydantic.main.Bar"
elif app.schema.schema_library.lib == typesystem:
schema = typesystem.Schema(
title="Bar",
fields={"foo": typesystem.Reference(to="Foo", definitions=typesystem.Definitions({"Foo": child_schema}))},
)
name = "typesystem.schemas.Bar"
elif app.schema.schema_library.lib == marshmallow:
schema = type("Bar", (marshmallow.Schema,), {"foo": marshmallow.fields.Nested(child_schema())})
name = "abc.Bar"
else:
raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}")
return namedtuple("BarSchema", ("schema", "name"))(schema=schema, name=name)


@pytest.fixture(scope="function")
def bar_list_schema(app, foo_schema):
child_schema = foo_schema.schema
if app.schema.schema_library.lib == pydantic:
schema = pydantic.create_model("BarList", foo=(t.List[child_schema], ...))
name = "pydantic.main.BarList"
elif app.schema.schema_library.lib == typesystem:
schema = typesystem.Schema(
title="BarList",
fields={
"foo": typesystem.Array(
typesystem.Reference(to="Foo", definitions=typesystem.Definitions({"Foo": child_schema}))
)
},
)
name = "typesystem.schemas.BarList"
elif app.schema.schema_library.lib == marshmallow:
schema = type(
"BarList",
(marshmallow.Schema,),
{"foo": marshmallow.fields.List(marshmallow.fields.Nested(child_schema()))},
)
name = "abc.BarList"
else:
raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}")
return namedtuple("BarListSchema", ("schema", "name"))(schema=schema, name=name)


@pytest.fixture(scope="function")
def bar_dict_schema(app, foo_schema):
child_schema = foo_schema.schema
if app.schema.schema_library.lib == pydantic:
schema = pydantic.create_model("BarDict", foo=(t.Dict[str, child_schema], ...))
name = "pydantic.main.BarDict"
elif app.schema.schema_library.lib == typesystem:
schema = typesystem.Schema(
title="BarDict",
fields={
"foo": typesystem.Object(
properties=typesystem.Reference(to="Foo", definitions=typesystem.Definitions({"Foo": child_schema}))
)
},
)
name = "typesystem.schemas.BarDict"
elif app.schema.schema_library.lib == marshmallow:
schema = type(
"BarDict",
(marshmallow.Schema,),
{"foo": marshmallow.fields.Dict(values=marshmallow.fields.Nested(child_schema()))},
)
name = "abc.BarDict"
else:
raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}")
return namedtuple("BarDictSchema", ("schema", "name"))(schema=schema, name=name)


@pytest.fixture(scope="function")
def schemas(foo_schema, bar_schema, bar_list_schema, bar_dict_schema):
return {"Foo": foo_schema, "Bar": bar_schema, "BarList": bar_list_schema, "BarDict": bar_dict_schema}
Loading

0 comments on commit 78c1c37

Please sign in to comment.