forked from tortoise/tortoise-orm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
81e8ffb
commit eb13f78
Showing
7 changed files
with
185 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# mypy: no-disallow-untyped-decorators | ||
# pylint: disable=E0611,E0401 | ||
import os | ||
|
||
import pytest | ||
from asgi_lifespan import LifespanManager | ||
from httpx import AsyncClient | ||
from main import LOG_FILE, app | ||
from models import Users | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def anyio_backend(): | ||
return "asyncio" | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
async def client(): | ||
if LOG_FILE.exists(): | ||
LOG_FILE.unlink() | ||
async with LifespanManager(app): | ||
async with AsyncClient(app=app, base_url="http://test") as c: | ||
yield c | ||
assert not LOG_FILE.exists() | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_create_user(client: AsyncClient): # nosec | ||
response = await client.post("/users", json={"username": "admin"}) | ||
assert response.status_code == 200, response.text | ||
data = response.json() | ||
assert data["username"] == "admin" | ||
assert "id" in data | ||
user_id = data["id"] | ||
|
||
user_obj = await Users.get(id=user_id) | ||
assert user_obj.id == user_id | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_lifespan(client: AsyncClient): # nosec | ||
if os.getenv("USE_LIFESPAN"): | ||
assert LOG_FILE.exists() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# pylint: disable=E0611,E0401 | ||
import os | ||
from contextlib import asynccontextmanager | ||
from pathlib import Path | ||
from typing import List | ||
|
||
from fastapi import FastAPI | ||
from models import User_Pydantic, UserIn_Pydantic, Users | ||
from pydantic import BaseModel | ||
from starlette.exceptions import HTTPException | ||
|
||
from tortoise.contrib.fastapi import register_tortoise | ||
|
||
LOG_FILE = Path(__file__).parent / "foo.log" | ||
|
||
|
||
@asynccontextmanager | ||
async def lifespan(app: FastAPI): | ||
print("app startup") | ||
if not LOG_FILE.exists(): | ||
LOG_FILE.touch() | ||
yield | ||
print("app teardown") | ||
if LOG_FILE.exists(): | ||
LOG_FILE.unlink() | ||
|
||
|
||
if os.getenv("USE_LIFESPAN"): | ||
app = FastAPI(title="Tortoise ORM FastAPI test", lifespan=lifespan) | ||
else: | ||
app = FastAPI(title="Tortoise ORM FastAPI test") | ||
|
||
|
||
class Status(BaseModel): | ||
message: str | ||
|
||
|
||
@app.get("/users", response_model=List[User_Pydantic]) | ||
async def get_users(): | ||
return await User_Pydantic.from_queryset(Users.all()) | ||
|
||
|
||
@app.post("/users", response_model=User_Pydantic) | ||
async def create_user(user: UserIn_Pydantic): | ||
user_obj = await Users.create(**user.model_dump(exclude_unset=True)) | ||
return await User_Pydantic.from_tortoise_orm(user_obj) | ||
|
||
|
||
@app.get("/user/{user_id}", response_model=User_Pydantic) | ||
async def get_user(user_id: int): | ||
return await User_Pydantic.from_queryset_single(Users.get(id=user_id)) | ||
|
||
|
||
@app.put("/user/{user_id}", response_model=User_Pydantic) | ||
async def update_user(user_id: int, user: UserIn_Pydantic): | ||
await Users.filter(id=user_id).update(**user.model_dump(exclude_unset=True)) | ||
return await User_Pydantic.from_queryset_single(Users.get(id=user_id)) | ||
|
||
|
||
@app.delete("/user/{user_id}", response_model=Status) | ||
async def delete_user(user_id: int): | ||
deleted_count = await Users.filter(id=user_id).delete() | ||
if not deleted_count: | ||
raise HTTPException(status_code=404, detail=f"User {user_id} not found") | ||
return Status(message=f"Deleted user {user_id}") | ||
|
||
|
||
register_tortoise( | ||
app, | ||
db_url="sqlite://:memory:", | ||
modules={"models": ["models"]}, | ||
generate_schemas=True, | ||
add_exception_handlers=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from tortoise import fields, models | ||
from tortoise.contrib.pydantic import pydantic_model_creator | ||
|
||
|
||
class Users(models.Model): | ||
""" | ||
The User model | ||
""" | ||
|
||
id = fields.IntField(pk=True) | ||
#: This is a username | ||
username = fields.CharField(max_length=20, unique=True) | ||
name = fields.CharField(max_length=50, null=True) | ||
family_name = fields.CharField(max_length=50, null=True) | ||
category = fields.CharField(max_length=30, default="misc") | ||
password_hash = fields.CharField(max_length=128, null=True) | ||
created_at = fields.DatetimeField(auto_now_add=True) | ||
modified_at = fields.DatetimeField(auto_now=True) | ||
|
||
def full_name(self) -> str: | ||
""" | ||
Returns the best name | ||
""" | ||
if self.name or self.family_name: | ||
return f"{self.name or ''} {self.family_name or ''}".strip() | ||
return self.username | ||
|
||
class PydanticMeta: | ||
computed = ["full_name"] | ||
exclude = ["password_hash"] | ||
|
||
|
||
User_Pydantic = pydantic_model_creator(Users, name="User") | ||
UserIn_Pydantic = pydantic_model_creator(Users, name="UserIn", exclude_readonly=True) |