Skip to content

Commit

Permalink
🔥 Remove unnecessary invited_user_id logic (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
estebanx64 authored May 27, 2024
1 parent df0cbec commit b70eea2
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Remove invited_user_id column and relationships
Revision ID: c5cc3b0f01d6
Revises: a9b76125b71a
Create Date: 2024-05-24 15:47:43.647737
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = 'c5cc3b0f01d6'
down_revision = 'a9b76125b71a'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('invitation', 'email',
existing_type=sa.VARCHAR(),
nullable=False)
op.drop_constraint('invitation_invited_user_id_fkey', 'invitation', type_='foreignkey')
op.drop_column('invitation', 'invited_user_id')
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('invitation', sa.Column('invited_user_id', sa.INTEGER(), autoincrement=False, nullable=True))
op.create_foreign_key('invitation_invited_user_id_fkey', 'invitation', 'user', ['invited_user_id'], ['id'])
op.alter_column('invitation', 'email',
existing_type=sa.VARCHAR(),
nullable=True)
# ### end Alembic commands ###
28 changes: 4 additions & 24 deletions backend/app/api/routes/invitations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import HTMLResponse
from sqlalchemy.sql import and_, or_
from sqlmodel import col, func, select

from app.api.deps import CurrentUser, SessionDep, get_first_superuser
Expand Down Expand Up @@ -44,12 +43,12 @@ def read_invitations_me(
count_statement = (
select(func.count())
.select_from(Invitation)
.where(col(Invitation.invited_user_id) == current_user.id)
.where(col(Invitation.email) == current_user.email)
)
count = session.exec(count_statement).one()
statement = (
select(Invitation)
.where(col(Invitation.invited_user_id) == current_user.id)
.where(col(Invitation.email) == current_user.email)
.offset(skip)
.limit(limit)
)
Expand Down Expand Up @@ -175,12 +174,6 @@ def create_invitation(
)

if not user_to_invite:
if not invitation.email:
raise HTTPException(
status_code=400,
detail="The invitation must have an email to be sent to a user that does not exist in our platform",
)

session.add(invitation)
session.commit()
session.refresh(invitation)
Expand All @@ -205,10 +198,6 @@ def create_invitation(
detail="The user is already in the team",
)

# make sure if the user was found fill the email or FK in the invitation row
invitation.invited_user_id = user_to_invite.id
invitation.email = user_to_invite.email

session.add(invitation)
session.commit()
session.refresh(invitation)
Expand Down Expand Up @@ -238,13 +227,7 @@ def accept_invitation(
raise HTTPException(status_code=400, detail="Invalid invitation token")

invitation_query = select(Invitation).where(
and_(
col(Invitation.id) == invitation_id,
or_(
col(Invitation.email) == current_user.email,
col(Invitation.invited_user_id) == current_user.id,
),
)
col(Invitation.id) == invitation_id, col(Invitation.email) == current_user.email
)
invitation = session.exec(invitation_query).first()

Expand All @@ -257,9 +240,6 @@ def accept_invitation(
if current_user.id in {link.user_id for link in invitation.team.user_links}:
raise HTTPException(status_code=400, detail="User already in team")

if invitation.invited_user_id is None:
invitation.invited_user_id = current_user.id

invitation.status = InvitationStatus.accepted

add_user_to_team(
Expand Down Expand Up @@ -313,7 +293,7 @@ def invitation_html_content(invitation_id: int, session: SessionDep) -> Any:
if not invitation:
raise HTTPException(status_code=404, detail="Invitation not found")
token = generate_invitation_token(invitation_id=invitation_id)
email_to = invitation.email or invitation.receiver.email
email_to = invitation.email
email_from = invitation.sender.email
email_data = generate_invitation_token_email(
team_name=invitation.team.name,
Expand Down
11 changes: 0 additions & 11 deletions backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ class User(UserBase, table=True):
hashed_password: str

team_links: list[UserTeamLink] = Relationship(back_populates="user")
invitations: list["Invitation"] = Relationship(
back_populates="receiver",
sa_relationship_kwargs={"foreign_keys": "[Invitation.invited_user_id]"},
)
invitations_sent: list["Invitation"] = Relationship(
back_populates="sender",
sa_relationship_kwargs={"foreign_keys": "[Invitation.invited_by_id]"},
Expand Down Expand Up @@ -189,10 +185,8 @@ class InvitationPublic(InvitationBase):
id: int
team_id: int
invited_by_id: int
invited_user_id: int | None = None
status: InvitationStatus
created_at: datetime
receiver: UserPublic | None
sender: UserPublic
team: TeamPublic

Expand All @@ -206,15 +200,10 @@ class Invitation(InvitationBase, table=True):
id: int | None = Field(default=None, primary_key=True)
team_id: int = Field(foreign_key="team.id")
invited_by_id: int = Field(foreign_key="user.id")
invited_user_id: int | None = Field(default=None, foreign_key="user.id")
status: InvitationStatus = InvitationStatus.pending
created_at: datetime = Field(default_factory=get_datetime_utc)
expires_at: datetime

receiver: User = Relationship(
back_populates="invitations",
sa_relationship_kwargs={"foreign_keys": "[Invitation.invited_user_id]"},
)
sender: User = Relationship(
back_populates="invitations_sent",
sa_relationship_kwargs={"foreign_keys": "[Invitation.invited_by_id]"},
Expand Down
2 changes: 0 additions & 2 deletions backend/app/tests/api/routes/test_invitations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ def test_read_invitations_me(client: TestClient, db: Session) -> None:
assert invitations[1]["email"] == invited_user.email
assert invitations[0]["role"] == "member"
assert invitations[1]["role"] == "member"
assert invitations[0]["receiver"]["id"] == invited_user.id
assert invitations[1]["receiver"]["id"] == invited_user.id


def test_read_invitations_me_empty(client: TestClient, db: Session) -> None:
Expand Down

0 comments on commit b70eea2

Please sign in to comment.