Skip to content

Commit

Permalink
Fix query, form, header model extra not honored (#201)
Browse files Browse the repository at this point in the history
* Fix query, form, header model extra not honored

* update

---------

Co-authored-by: luolingchun <[email protected]>
  • Loading branch information
luolingchun and luolingchun authored Dec 1, 2024
1 parent 012d1f4 commit 79ed3fe
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
20 changes: 20 additions & 0 deletions flask_openapi3/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ def _get_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, mo
def _validate_header(header: Type[BaseModel], func_kwargs: dict):
request_headers = dict(request.headers)
header_dict = {}
model_properties = header.model_json_schema().get("properties", {})
for model_field_key, model_field_value in header.model_fields.items():
key_title = model_field_key.replace("_", "-").title()
model_field_schema = model_properties.get(model_field_value.alias or model_field_key)
if model_field_value.alias and header.model_config.get("populate_by_name"):
key = model_field_value.alias
key_alias_title = model_field_value.alias.replace("_", "-").title()
Expand All @@ -57,6 +59,12 @@ def _validate_header(header: Type[BaseModel], func_kwargs: dict):
value = request_headers[key_title]
if value is not None:
header_dict[key] = value
if model_field_schema.get("type") == "null":
header_dict[key] = value # type:ignore
# extra keys
for key, value in request_headers.items():
if key not in header_dict.keys():
header_dict[key] = value
func_kwargs["header"] = header.model_validate(obj=header_dict)


Expand All @@ -81,6 +89,12 @@ def _validate_query(query: Type[BaseModel], func_kwargs: dict):
key, value = _get_value(query, request_args, model_field_key, model_field_value)
if value is not None and value != []:
query_dict[key] = value
if model_field_schema.get("type") == "null":
query_dict[key] = value
# extra keys
for key, value in request_args.items():
if key not in query_dict.keys():
query_dict[key] = value
func_kwargs["query"] = query.model_validate(obj=query_dict)


Expand Down Expand Up @@ -114,6 +128,12 @@ def _validate_form(form: Type[BaseModel], func_kwargs: dict):
value = _value
if value is not None and value != []:
form_dict[key] = value
if model_field_schema.get("type") == "null":
form_dict[key] = value
# extra keys
for key, value in {**dict(request_form), **dict(request_files)}.items():
if key not in form_dict.keys():
form_dict[key] = value
func_kwargs["form"] = form.model_validate(obj=form_dict)


Expand Down
78 changes: 78 additions & 0 deletions tests/test_model_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# @Author : llc
# @Time : 2024/11/20 14:45
from typing import Optional

import pytest
from pydantic import BaseModel, Field, ConfigDict

from flask_openapi3 import OpenAPI

app = OpenAPI(__name__)
app.config["TESTING"] = True


class BookQuery(BaseModel):
age: Optional[int] = Field(None, description="Age")

model_config = ConfigDict(extra="allow")


class BookForm(BaseModel):
string: str

model_config = ConfigDict(extra="forbid")


class BookHeader(BaseModel):
api_key: str = Field(..., description="API Key")

model_config = ConfigDict(extra="forbid")


@pytest.fixture
def client():
client = app.test_client()

return client


@app.get("/book")
def get_books(query: BookQuery):
"""get books
to get all books
"""
assert query.age == 3
assert query.author == "joy"
return {"code": 0, "message": "ok"}


@app.post("/form")
def api_form(form: BookForm):
print(form)
return {"code": 0, "message": "ok"}


def test_query(client):
resp = client.get("/book?age=3&author=joy")
assert resp.status_code == 200


@app.get("/header")
def get_book(header: BookHeader):
return header.model_dump(by_alias=True)


def test_form(client):
data = {
"string": "a",
"string_list": ["a", "b", "c"]
}
r = client.post("/form", data=data, content_type="multipart/form-data")
assert r.status_code == 422


def test_header(client):
headers = {"Hello1": "111", "hello2": "222", "api_key": "333", "api_type": "A", "x-hello": "444"}
resp = client.get("/header", headers=headers)
assert resp.status_code == 422

0 comments on commit 79ed3fe

Please sign in to comment.