Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⬆️♻️ WIP: Fixes openapi generators for web-api, i.e. api/specs/web-server #6771

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/specs/web-server/_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_compatible_outputs_given_target_input(

@router.get(
"/catalog/services/{service_key}/{service_version}/resources",
response_model=ServiceResourcesGet,
response_model=Envelope[ServiceResourcesGet],
)
def get_service_resources(
_params: Annotated[ServicePathParams, Depends()],
Expand Down
88 changes: 59 additions & 29 deletions api/specs/web-server/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,64 +5,93 @@
import sys
from collections.abc import Callable
from pathlib import Path
from typing import Any, ClassVar, NamedTuple
from typing import Annotated, Any, NamedTuple, Optional, Union, get_args, get_origin

import yaml
from common_library.json_serialization import json_dumps
from common_library.pydantic_fields_extension import get_type
from fastapi import FastAPI, Query
from models_library.basic_types import LogLevel
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, ConfigDict, Field, Json, create_model
from pydantic.fields import FieldInfo
from servicelib.fastapi.openapi import override_fastapi_openapi_method

CURRENT_DIR = Path(sys.argv[0] if __name__ == "__main__" else __file__).resolve().parent


def _create_json_type(**schema_extras):
class _Json(str):
# FIXME: upgrade this to pydnatic v2 protocols
class _JsonStr(str):
__slots__ = ()

@classmethod
def __modify_schema__(cls, field_schema: dict[str, Any]) -> None:
# openapi.json schema is corrected here
field_schema.update(
type="string",
# format="json-string" NOTE: we need to get rid of openapi-core in web-server before using this!
)
if schema_extras:
field_schema.update(schema_extras)
def __get_pydantic_json_schema__(cls, schema: dict[str, Any]) -> dict[str, Any]:
# Update the schema with custom type and format
schema.update(type="string", format="json-string", **schema_extras)
return schema

return _Json
return _JsonStr


def _replace_basemodel_in_annotation(annotation, new_type):
origin = get_origin(annotation)

# Handle Annotated
if origin is Annotated:
args = get_args(annotation)
base_type = args[0]
metadata = args[1:]
if isinstance(base_type, type) and issubclass(base_type, BaseModel):
# Replace the BaseModel subclass
base_type = new_type

return Annotated[(base_type, *metadata)]

# Handle Optionals, Unions, or other generic types
if origin in (Optional, Union, list, dict, tuple): # Extendable for other generics
new_args = tuple(
_replace_basemodel_in_annotation(arg, new_type)
for arg in get_args(annotation)
)
return origin[new_args]

# Replace BaseModel subclass directly
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
return new_type

# Return as-is if no changes
return annotation


def as_query(model_class: type[BaseModel]) -> type[BaseModel]:
fields = {}
for field_name, field_info in model_class.model_fields.items():

field_type = get_type(field_info)
default_value = field_info.default

kwargs = {
field_default = field_info.default
assert not field_info.default_factory # nosec
query_kwargs = {
"alias": field_info.alias,
"title": field_info.title,
"description": field_info.description,
"metadata": field_info.metadata,
"json_schema_extra": field_info.json_schema_extra,
"json_schema_extra": field_info.json_schema_extra or {},
}

if issubclass(field_type, BaseModel):
# Complex fields
assert "json_schema_extra" in kwargs # nosec
assert kwargs["json_schema_extra"] # nosec
field_type = _create_json_type(
description=kwargs["description"],
example=kwargs.get("json_schema_extra", {}).get("example_json"),
)
json_field_type = Json
# _create_json_type(
# description=query_kwargs["description"],
# example=query_kwargs.get("json_schema_extra", {}).get("example_json"),
# )

default_value = json_dumps(default_value) if default_value else None
annotation = _replace_basemodel_in_annotation(
field_info.annotation, new_type=json_field_type
)

fields[field_name] = (field_type, Query(default=default_value, **kwargs))
if annotation != field_info.annotation:
# Complex fields are transformed to Json
field_default = json_dumps(field_default) if field_default else None

fields[field_name] = (annotation, Query(default=field_default, **query_kwargs))

new_model_name = f"{model_class.__name__}Query"
return create_model(new_model_name, **fields)
Expand All @@ -78,14 +107,15 @@ class Log(BaseModel):
None, description="name of the logger receiving this message"
)

class Config:
schema_extra: ClassVar[dict[str, Any]] = {
model_config = ConfigDict(
json_schema_extra={
"example": {
"message": "Hi there, Mr user",
"level": "INFO",
"logger": "user-logger",
}
}
)


class ErrorItem(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion api/specs/web-server/_publications.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
status_code=status.HTTP_204_NO_CONTENT,
)
def service_submission(
_file: Annotated[bytes, File(description="metadata.json submission file")]
file: Annotated[bytes, File(description="metadata.json submission file")]
):
"""
Submits files with new service candidate
"""
assert file # nosec
2 changes: 1 addition & 1 deletion api/specs/web-server/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"_activity",
"_announcements",
"_catalog",
"_catalog_tags", # after _catalog
"_catalog_tags", # MUST BE after _catalog
"_cluster",
"_computations",
"_exporter",
Expand Down
4 changes: 3 additions & 1 deletion api/specs/web-server/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Extra reqs, besides webserver's

fastapi==0.96.0
fastapi
jsonref
pydantic
pydantic-extra-types
python-multipart
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, TypeAlias
from typing import Annotated, Any, TypeAlias

from pydantic import (
AnyHttpUrl,
Expand Down Expand Up @@ -92,9 +92,10 @@ class ClusterDetails(BaseModel):


class ClusterGet(Cluster):
access_rights: dict[GroupID, ClusterAccessRights] = Field(
alias="accessRights", default_factory=dict
)
access_rights: Annotated[
dict[GroupID, ClusterAccessRights],
Field(alias="accessRights", default_factory=dict),
] = {}

model_config = ConfigDict(extra="allow", populate_by_name=True)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, TypeAlias

from pydantic import ConfigDict, Field
from pydantic import ConfigDict, Field, RootModel
from pydantic.main import BaseModel

from ..api_schemas_catalog import services as api_schemas_catalog_services
Expand Down Expand Up @@ -228,8 +228,8 @@ class ServiceGet(api_schemas_catalog_services.ServiceGet):
)


class ServiceResourcesGet(api_schemas_catalog_services.ServiceResourcesGet):
model_config = OutputSchema.model_config
class ServiceResourcesGet(RootModel[api_schemas_catalog_services.ServiceResourcesGet]):
...


class CatalogServiceGet(api_schemas_catalog_services.ServiceGetV2):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ClusterPathParams(BaseModel):
)


class ClusterGet(directorv2_clusters.ClusterGet):
class ClusterGet(directorv2_clusters.ClusterCreate):
model_config = OutputSchema.model_config


Expand Down
10 changes: 7 additions & 3 deletions packages/models-library/src/models_library/rest_filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Generic, TypeVar
from typing import Annotated, Generic, TypeVar

from pydantic import BaseModel, Field, Json
from pydantic import BaseModel, BeforeValidator, Field

from .utils.common_validators import parse_json_pre_validator


class Filters(BaseModel):
Expand All @@ -15,7 +17,9 @@ class Filters(BaseModel):


class FiltersQueryParameters(BaseModel, Generic[FilterT]):
filters: Json[FilterT] | None = Field( # pylint: disable=unsubscriptable-object
filters: Annotated[
FilterT | None, BeforeValidator(parse_json_pre_validator)
] = Field( # pylint: disable=unsubscriptable-object
default=None,
description="Custom filter query parameter encoded as JSON",
)
Loading
Loading