Skip to content

Commit

Permalink
Check foreign keys belong to tournament
Browse files Browse the repository at this point in the history
  • Loading branch information
evroon committed Feb 23, 2024
1 parent b2e5a0e commit 10bc7ca
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 41 deletions.
8 changes: 6 additions & 2 deletions backend/bracket/routes/matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from bracket.sql.courts import get_all_courts_in_tournament
from bracket.sql.matches import sql_create_match, sql_delete_match, sql_update_match
from bracket.sql.tournaments import sql_get_tournament
from bracket.sql.validation import check_inputs_belong_to_tournament
from bracket.utils.id_types import MatchId, TournamentId
from bracket.utils.types import assert_some

Expand Down Expand Up @@ -77,6 +78,8 @@ async def create_match(
match_body: MatchCreateBodyFrontend,
_: UserPublic = Depends(user_authenticated_for_tournament),
) -> SingleMatchResponse:
await check_inputs_belong_to_tournament(match_body, tournament_id)

tournament = await sql_get_tournament(tournament_id)
body_with_durations = MatchCreateBody(
**match_body.model_dump(),
Expand Down Expand Up @@ -105,6 +108,7 @@ async def reschedule_match(
body: MatchRescheduleBody,
_: UserPublic = Depends(user_authenticated_for_tournament),
) -> SuccessResponse:
await check_inputs_belong_to_tournament(body, tournament_id)
await handle_match_reschedule(tournament_id, body, match_id)
return SuccessResponse()

Expand Down Expand Up @@ -175,9 +179,9 @@ async def update_match_by_id(
_: UserPublic = Depends(user_authenticated_for_tournament),
match: Match = Depends(match_dependency),
) -> SuccessResponse:
assert match.id
await check_inputs_belong_to_tournament(match_body, tournament_id)
tournament = await sql_get_tournament(tournament_id)

await sql_update_match(match.id, match_body, tournament)
await sql_update_match(assert_some(match.id), match_body, tournament)
await recalculate_ranking_for_tournament_id(tournament_id)
return SuccessResponse()
3 changes: 3 additions & 0 deletions backend/bracket/routes/rounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from bracket.sql.rounds import get_next_round_name, set_round_active_or_draft, sql_create_round
from bracket.sql.stage_items import get_stage_item
from bracket.sql.stages import get_full_tournament_details
from bracket.sql.validation import check_inputs_belong_to_tournament
from bracket.utils.id_types import RoundId, TournamentId

router = APIRouter()
Expand Down Expand Up @@ -55,6 +56,8 @@ async def create_round(
round_body: RoundCreateBody,
user: UserPublic = Depends(user_authenticated_for_tournament),
) -> SuccessResponse:
await check_inputs_belong_to_tournament(round_body, tournament_id)

stages = await get_full_tournament_details(tournament_id)
existing_rounds = [
round_
Expand Down
3 changes: 3 additions & 0 deletions backend/bracket/routes/stage_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
sql_create_stage_item,
)
from bracket.sql.stages import get_full_tournament_details
from bracket.sql.validation import check_inputs_belong_to_tournament
from bracket.utils.id_types import StageItemId, TournamentId

router = APIRouter()
Expand Down Expand Up @@ -62,6 +63,8 @@ async def create_stage_item(
detail="Team count doesn't match number of inputs",
)

await check_inputs_belong_to_tournament(stage_body, tournament_id)

stages = await get_full_tournament_details(tournament_id)
existing_stage_items = [stage_item for stage in stages for stage_item in stage.stage_items]
check_requirement(existing_stage_items, user, "max_stage_items")
Expand Down
5 changes: 5 additions & 0 deletions backend/bracket/routes/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
get_teams_with_members,
sql_delete_team,
)
from bracket.sql.validation import check_inputs_belong_to_tournament
from bracket.utils.db import fetch_one_parsed
from bracket.utils.id_types import PlayerId, TeamId, TournamentId
from bracket.utils.pagination import PaginationTeams
Expand Down Expand Up @@ -78,6 +79,8 @@ async def update_team_by_id(
_: UserPublic = Depends(user_authenticated_for_tournament),
team: Team = Depends(team_dependency),
) -> SingleTeamResponse:
await check_inputs_belong_to_tournament(team_body, tournament_id)

await database.execute(
query=teams.update().where(
(teams.c.id == team.id) & (teams.c.tournament_id == tournament_id)
Expand Down Expand Up @@ -133,6 +136,8 @@ async def create_team(
tournament_id: TournamentId,
user: UserPublic = Depends(user_authenticated_for_tournament),
) -> SingleTeamResponse:
await check_inputs_belong_to_tournament(team_to_insert, tournament_id)

existing_teams = await get_teams_with_members(tournament_id)
check_requirement(existing_teams, user, "max_teams")

Expand Down
2 changes: 1 addition & 1 deletion backend/bracket/routes/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def get_me(


@router.put("/users/{user_id}", response_model=UserPublicResponse)
async def put_user(
async def update_user_details(
user_id: UserId,
user_to_update: UserToUpdate,
user_public: UserPublic = Depends(user_authenticated),
Expand Down
13 changes: 13 additions & 0 deletions backend/bracket/sql/players.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ async def get_all_players_in_tournament(
return [Player.model_validate(x) for x in result]


async def get_player_by_id(player_id: PlayerId, tournament_id: TournamentId) -> Player | None:
query = """
SELECT *
FROM players
WHERE id = :player_id
AND tournament_id = :tournament_id
"""
result = await database.fetch_one(
query=query, values={"player_id": player_id, "tournament_id": tournament_id}
)
return Player.model_validate(result) if result is not None else None


async def get_player_count(
tournament_id: TournamentId,
*,
Expand Down
10 changes: 0 additions & 10 deletions backend/bracket/sql/stage_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,3 @@ async def get_stage_item(
return None

return stages[0].stage_items[0]


async def get_stage_items(
tournament_id: TournamentId, stage_item_ids: set[StageItemId]
) -> list[StageItemWithRounds]:
stages = await get_full_tournament_details(tournament_id, stage_item_ids=stage_item_ids)
if len(stages) < 1:
return []

return stages[0].stage_items
156 changes: 131 additions & 25 deletions backend/bracket/sql/validation.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,141 @@
from collections.abc import Awaitable, Callable
from typing import Any, cast

from fastapi import HTTPException
from pydantic import BaseModel
from starlette import status

from bracket.models.db.stage_item import StageItemCreateBody
from bracket.models.db.stage_item_inputs import (
StageItemInputCreateBodyFinal,
StageItemInputCreateBodyTentative,
from bracket.models.db.util import StageWithStageItems
from bracket.sql.courts import get_all_courts_in_tournament
from bracket.sql.players import get_all_players_in_tournament, get_player_by_id
from bracket.sql.stages import get_full_tournament_details
from bracket.sql.teams import get_team_by_id
from bracket.utils.id_types import (
CourtId,
MatchId,
PlayerId,
RoundId,
StageId,
StageItemId,
StageItemInputId,
TeamId,
TournamentId,
)
from bracket.sql.stage_items import get_stage_items
from bracket.sql.teams import get_teams_by_id
from bracket.utils.id_types import TournamentId

CheckCallableT = Callable[[Any, list[StageWithStageItems], TournamentId], Awaitable[bool]]


async def check_stage_belongs_to_tournament(
stage_id: StageId, stages: list[StageWithStageItems], _: TournamentId
) -> bool:
return any(stage.id == stage_id for stage in stages)


async def check_team_belongs_to_tournament(
team_id: TeamId, _: list[StageWithStageItems], tournament_id: TournamentId
) -> bool:
return await get_team_by_id(team_id, tournament_id) is not None


async def check_stage_item_belongs_to_tournament(
stage_item_id: StageItemId, stages: list[StageWithStageItems], _: TournamentId
) -> bool:
return any(
stage_item.id == stage_item_id for stage in stages for stage_item in stage.stage_items
)


async def check_stage_item_input_belongs_to_tournament(
stage_item_input_id: StageItemInputId, stages: list[StageWithStageItems], _: TournamentId
) -> bool:
return any(
stage_item_input.id == stage_item_input_id
for stage in stages
for stage_item in stage.stage_items
for stage_item_input in stage_item.inputs
)


async def check_round_belongs_to_tournament(
round_id: RoundId, stages: list[StageWithStageItems], _: TournamentId
) -> bool:
return any(
round_.id == round_id
for stage in stages
for stage_item in stage.stage_items
for round_ in stage_item.rounds
)

async def todo_check_inputs_belong_to_tournament(
stage_body: StageItemCreateBody, tournament_id: TournamentId

async def check_match_belongs_to_tournament(
match_id: MatchId, stages: list[StageWithStageItems], _: TournamentId
) -> bool:
return any(
match.id == match_id
for stage in stages
for stage_item in stage.stage_items
for round_ in stage_item.rounds
for match in round_.matches
)


async def check_player_belongs_to_tournament(
player_id: PlayerId, _: list[StageWithStageItems], tournament_id: TournamentId
) -> bool:
return await get_player_by_id(player_id, tournament_id) is not None


async def check_players_belong_to_tournament(
player_ids: set[PlayerId], tournament_id: TournamentId
) -> bool:
return player_ids.issubset(
player.id for player in await get_all_players_in_tournament(tournament_id)
)


async def check_court_belongs_to_tournament(
court_id: CourtId, _: list[StageWithStageItems], tournament_id: TournamentId
) -> bool:
return any(court_id == court.id for court in await get_all_courts_in_tournament(tournament_id))


async def check_inputs_belong_to_tournament(
some_body: BaseModel, tournament_id: TournamentId
) -> None:
teams = {
input_.team_id
for input_ in stage_body.inputs
if isinstance(input_, StageItemInputCreateBodyFinal)
}
teams_fetched = await get_teams_by_id(teams, tournament_id)
stages = await get_full_tournament_details(tournament_id)

stage_items = {
input_.winner_from_stage_item_id
for input_ in stage_body.inputs
if isinstance(input_, StageItemInputCreateBodyTentative)
check_lookup: dict[type[Any], CheckCallableT] = {
StageId: check_stage_belongs_to_tournament,
TeamId: check_team_belongs_to_tournament,
StageItemId: check_stage_item_belongs_to_tournament,
StageItemInputId: check_stage_item_input_belongs_to_tournament,
RoundId: check_round_belongs_to_tournament,
PlayerId: check_player_belongs_to_tournament,
MatchId: check_match_belongs_to_tournament,
CourtId: check_court_belongs_to_tournament,
}
stage_items_fetched = await get_stage_items(tournament_id, stage_items)

if len(teams) != len(teams_fetched) or len(stage_items) != len(stage_items_fetched):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Could not find all team ids or stages",
)
for field_key, field_info in some_body.model_fields.items():
field_value = getattr(some_body, field_key)

if isinstance(field_value, BaseModel):
await check_inputs_belong_to_tournament(field_value, tournament_id)
elif isinstance(field_value, set):
if field_info.annotation == set[PlayerId]:
await check_players_belong_to_tournament(field_value, tournament_id)
else:
raise Exception(f"Unknown set type: {field_info.annotation}")
else:
check_callable = check_lookup.get(cast(Any, field_info.annotation))
if check_callable is not None and not await check_callable(
field_value, stages, tournament_id
):
field_name = (
field_info.annotation.__name__
if field_info.annotation is not None
else "Unknown type"
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Could not find {field_name.replace('Id', '')} with ID {field_value}",
)
7 changes: 4 additions & 3 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ ignore_missing_imports = true

[tool.pylint.'MESSAGES CONTROL']
disable = [
'broad-except',
'broad-exception-raised',
'consider-iterating-dictionary',
'dangerous-default-value',
'duplicate-code',
'fixme',
'import-outside-toplevel',
'invalid-name',
'logging-fstring-interpolation',
Expand All @@ -66,9 +70,6 @@ disable = [
'unspecified-encoding',
'unused-argument', # Gives false positives.
'wrong-import-position',
'fixme',
'broad-except',
'consider-iterating-dictionary',
]

[tool.bandit]
Expand Down

0 comments on commit 10bc7ca

Please sign in to comment.