Skip to content

Commit

Permalink
🏷️ Add mypy to the GitHub Action for tests and fixed types in the who…
Browse files Browse the repository at this point in the history
…le project (fastapi#655)

Co-authored-by: Sebastián Ramírez <[email protected]>
  • Loading branch information
estebanx64 and tiangolo authored Mar 10, 2024
1 parent 86c0ed7 commit 16f2564
Show file tree
Hide file tree
Showing 19 changed files with 106 additions and 79 deletions.
2 changes: 1 addition & 1 deletion backend/app/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)


def get_db() -> Generator:
def get_db() -> Generator[Session, None, None]:
with Session(engine) as session:
yield session

Expand Down
21 changes: 11 additions & 10 deletions backend/app/api/routes/users.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import delete, func, select
from sqlmodel import col, delete, func, select

from app import crud
from app.api.deps import (
Expand Down Expand Up @@ -189,16 +189,17 @@ def delete_user(
user = session.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")

if (user == current_user and not current_user.is_superuser) or (
user != current_user and current_user.is_superuser
):
statement = delete(Item).where(Item.owner_id == user_id)
session.exec(statement)
session.delete(user)
session.commit()
return Message(message="User deleted successfully")
elif user != current_user and not current_user.is_superuser:
raise HTTPException(
status_code=403, detail="The user doesn't have enough privileges"
)
elif user == current_user and current_user.is_superuser:
raise HTTPException(
status_code=403, detail="Super users are not allowed to delete themselves"
)

statement = delete(Item).where(col(Item.owner_id) == user_id)
session.exec(statement) # type: ignore
session.delete(user)
session.commit()
return Message(message="User deleted successfully")
6 changes: 3 additions & 3 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def assemble_db_connection(cls, v: str | None, info: ValidationInfo) -> Any:
@field_validator("EMAILS_FROM_NAME")
def get_project_name(cls, v: str | None, info: ValidationInfo) -> str:
if not v:
return info.data["PROJECT_NAME"]
return str(info.data["PROJECT_NAME"])
return v

EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48
Expand All @@ -89,7 +89,7 @@ def get_emails_enabled(cls, v: bool, info: ValidationInfo) -> bool:
FIRST_SUPERUSER: str
FIRST_SUPERUSER_PASSWORD: str
USERS_OPEN_REGISTRATION: bool = False
model_config = SettingsConfigDict(case_sensitive=True)
model_config = SettingsConfigDict(env_file=".env")


settings = Settings()
settings = Settings() # type: ignore
9 changes: 2 additions & 7 deletions backend/app/core/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,8 @@
ALGORITHM = "HS256"


def create_access_token(subject: str | Any, expires_delta: timedelta = None) -> str:
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
expire = datetime.utcnow() + expires_delta
to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
Expand Down
2 changes: 1 addition & 1 deletion backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from app.core.config import settings


def custom_generate_unique_id(route: APIRoute):
def custom_generate_unique_id(route: APIRoute) -> str:
return f"{route.tags[0]}-{route.name}"


Expand Down
4 changes: 2 additions & 2 deletions backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class UserCreateOpen(SQLModel):
# Properties to receive via API on update, all are optional
# TODO replace email str with EmailStr when sqlmodel supports it
class UserUpdate(UserBase):
email: str | None = None
email: str | None = None # type: ignore
password: str | None = None


Expand Down Expand Up @@ -70,7 +70,7 @@ class ItemCreate(ItemBase):

# Properties to receive on item update
class ItemUpdate(ItemBase):
title: str | None = None
title: str | None = None # type: ignore


# Database model, database table inferred from class name
Expand Down
22 changes: 11 additions & 11 deletions backend/app/tests/api/api_v1/test_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def test_create_item(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
data = {"title": "Foo", "description": "Fighters"}
response = client.post(
Expand All @@ -23,7 +23,7 @@ def test_create_item(


def test_read_item(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
response = client.get(
Expand All @@ -39,7 +39,7 @@ def test_read_item(


def test_read_item_not_found(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
response = client.get(
f"{settings.API_V1_STR}/items/999",
Expand All @@ -51,7 +51,7 @@ def test_read_item_not_found(


def test_read_item_not_enough_permissions(
client: TestClient, normal_user_token_headers: dict, db: Session
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
response = client.get(
Expand All @@ -64,7 +64,7 @@ def test_read_item_not_enough_permissions(


def test_read_items(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
create_random_item(db)
create_random_item(db)
Expand All @@ -78,7 +78,7 @@ def test_read_items(


def test_update_item(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
data = {"title": "Updated title", "description": "Updated description"}
Expand All @@ -96,7 +96,7 @@ def test_update_item(


def test_update_item_not_found(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
data = {"title": "Updated title", "description": "Updated description"}
response = client.put(
Expand All @@ -110,7 +110,7 @@ def test_update_item_not_found(


def test_update_item_not_enough_permissions(
client: TestClient, normal_user_token_headers: dict, db: Session
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
data = {"title": "Updated title", "description": "Updated description"}
Expand All @@ -125,7 +125,7 @@ def test_update_item_not_enough_permissions(


def test_delete_item(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
response = client.delete(
Expand All @@ -138,7 +138,7 @@ def test_delete_item(


def test_delete_item_not_found(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
response = client.delete(
f"{settings.API_V1_STR}/items/999",
Expand All @@ -150,7 +150,7 @@ def test_delete_item_not_found(


def test_delete_item_not_enough_permissions(
client: TestClient, normal_user_token_headers: dict, db: Session
client: TestClient, normal_user_token_headers: dict[str, str], db: Session
) -> None:
item = create_random_item(db)
response = client.delete(
Expand Down
5 changes: 3 additions & 2 deletions backend/app/tests/api/api_v1/test_login.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from app.utils import generate_password_reset_token
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture

from app.core.config import settings
from app.utils import generate_password_reset_token


def test_get_access_token(client: TestClient) -> None:
Expand Down Expand Up @@ -38,7 +39,7 @@ def test_use_access_token(


def test_recovery_password(
client: TestClient, normal_user_token_headers: dict[str, str], mocker
client: TestClient, normal_user_token_headers: dict[str, str], mocker: MockerFixture
) -> None:
mocker.patch("app.utils.send_email", return_value=None)
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
Expand Down
29 changes: 19 additions & 10 deletions backend/app/tests/api/api_v1/test_users.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
from sqlmodel import Session

from app import crud
Expand Down Expand Up @@ -30,7 +31,10 @@ def test_get_users_normal_user_me(


def test_create_user_new_email(
client: TestClient, superuser_token_headers: dict, db: Session, mocker
client: TestClient,
superuser_token_headers: dict[str, str],
db: Session,
mocker: MockerFixture,
) -> None:
mocker.patch("app.utils.send_email")
mocker.patch("app.core.config.settings.EMAILS_ENABLED", True)
Expand All @@ -50,7 +54,7 @@ def test_create_user_new_email(


def test_get_existing_user(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
username = random_email()
password = random_lower_string()
Expand Down Expand Up @@ -107,7 +111,7 @@ def test_get_existing_user_permissions_error(


def test_create_user_existing_username(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
username = random_email()
# username = email
Expand Down Expand Up @@ -140,7 +144,7 @@ def test_create_user_by_normal_user(


def test_retrieve_users(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
username = random_email()
password = random_lower_string()
Expand Down Expand Up @@ -179,7 +183,7 @@ def test_update_user_me(


def test_update_password_me(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
new_password = random_lower_string()
data = {
Expand Down Expand Up @@ -209,7 +213,7 @@ def test_update_password_me(


def test_update_password_me_incorrect_password(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
new_password = random_lower_string()
data = {"current_password": new_password, "new_password": new_password}
Expand All @@ -224,7 +228,7 @@ def test_update_password_me_incorrect_password(


def test_update_password_me_same_password_error(
client: TestClient, superuser_token_headers: dict, db: Session
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
data = {
"current_password": settings.FIRST_SUPERUSER_PASSWORD,
Expand All @@ -242,7 +246,7 @@ def test_update_password_me_same_password_error(
)


def test_create_user_open(client: TestClient, mocker) -> None:
def test_create_user_open(client: TestClient, mocker: MockerFixture) -> None:
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
username = random_email()
password = random_lower_string()
Expand All @@ -258,7 +262,9 @@ def test_create_user_open(client: TestClient, mocker) -> None:
assert created_user["full_name"] == full_name


def test_create_user_open_forbidden_error(client: TestClient, mocker) -> None:
def test_create_user_open_forbidden_error(
client: TestClient, mocker: MockerFixture
) -> None:
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", False)
username = random_email()
password = random_lower_string()
Expand All @@ -272,7 +278,9 @@ def test_create_user_open_forbidden_error(client: TestClient, mocker) -> None:
assert r.json()["detail"] == "Open user registration is forbidden on this server"


def test_create_user_open_already_exists_error(client: TestClient, mocker) -> None:
def test_create_user_open_already_exists_error(
client: TestClient, mocker: MockerFixture
) -> None:
mocker.patch("app.core.config.settings.USERS_OPEN_REGISTRATION", True)
password = random_lower_string()
full_name = random_lower_string()
Expand Down Expand Up @@ -382,6 +390,7 @@ def test_delete_user_current_super_user_error(
client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER)
assert super_user
user_id = super_user.id

r = client.delete(
Expand Down
4 changes: 2 additions & 2 deletions backend/app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@pytest.fixture(scope="session", autouse=True)
def db() -> Generator:
def db() -> Generator[Session, None, None]:
with Session(engine) as session:
init_db(session)
yield session
Expand All @@ -25,7 +25,7 @@ def db() -> Generator:


@pytest.fixture(scope="module")
def client() -> Generator:
def client() -> Generator[TestClient, None, None]:
with TestClient(app) as c:
yield c

Expand Down
21 changes: 13 additions & 8 deletions backend/app/tests/scripts/test_backend_pre_start.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,33 @@
from unittest.mock import MagicMock

from pytest_mock import MockerFixture
from sqlmodel import select

from app.backend_pre_start import init, logger


def test_init_successful_connection(mocker):
def test_init_successful_connection(mocker: MockerFixture) -> None:
engine_mock = MagicMock()

session_mock = MagicMock()
exec_mock = MagicMock(return_value=True)
session_mock.configure_mock(**{'exec.return_value': exec_mock})
mocker.patch('sqlmodel.Session', return_value=session_mock)
session_mock.configure_mock(**{"exec.return_value": exec_mock})
mocker.patch("sqlmodel.Session", return_value=session_mock)

mocker.patch.object(logger, 'info')
mocker.patch.object(logger, 'error')
mocker.patch.object(logger, 'warn')
mocker.patch.object(logger, "info")
mocker.patch.object(logger, "error")
mocker.patch.object(logger, "warn")

try:
init(engine_mock)
connection_successful = True
except Exception:
connection_successful = False

assert connection_successful, "The database connection should be successful and not raise an exception."
assert (
connection_successful
), "The database connection should be successful and not raise an exception."

assert session_mock.exec.called_once_with(select(1)), "The session should execute a select statement once."
assert session_mock.exec.called_once_with(
select(1)
), "The session should execute a select statement once."
Loading

0 comments on commit 16f2564

Please sign in to comment.