-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added boilerplate for user management
- Loading branch information
1 parent
86c7aa2
commit 866270b
Showing
4 changed files
with
177 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
cognee/infrastructure/databases/relational/user_authentication/authentication_db.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import AsyncGenerator | ||
|
||
from fastapi import Depends | ||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase | ||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine | ||
from sqlalchemy.orm import DeclarativeBase | ||
|
||
DATABASE_URL = "sqlite+aiosqlite:///./test.db" | ||
|
||
|
||
class Base(DeclarativeBase): | ||
pass | ||
|
||
|
||
class User(SQLAlchemyBaseUserTableUUID, Base): | ||
pass | ||
|
||
|
||
|
||
|
||
engine = create_async_engine(DATABASE_URL) | ||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False) | ||
|
||
|
||
async def create_db_and_tables(): | ||
async with engine.begin() as conn: | ||
await conn.run_sync(Base.metadata.create_all) | ||
|
||
|
||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]: | ||
async with async_session_maker() as session: | ||
yield session | ||
|
||
|
||
async def get_user_db(session: AsyncSession = Depends(get_async_session)): | ||
yield SQLAlchemyUserDatabase(session, User) |
15 changes: 15 additions & 0 deletions
15
cognee/infrastructure/databases/relational/user_authentication/schemas.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import uuid | ||
|
||
from fastapi_users import schemas | ||
|
||
|
||
class UserRead(schemas.BaseUser[uuid.UUID]): | ||
pass | ||
|
||
|
||
class UserCreate(schemas.BaseUserCreate): | ||
pass | ||
|
||
|
||
class UserUpdate(schemas.BaseUserUpdate): | ||
pass |
55 changes: 55 additions & 0 deletions
55
cognee/infrastructure/databases/relational/user_authentication/users.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import uuid | ||
from typing import Optional | ||
|
||
from fastapi import Depends, Request | ||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models | ||
from fastapi_users.authentication import ( | ||
AuthenticationBackend, | ||
BearerTransport, | ||
JWTStrategy, | ||
) | ||
from fastapi_users.db import SQLAlchemyUserDatabase | ||
|
||
from app.db import User, get_user_db | ||
|
||
SECRET = "SECRET" | ||
|
||
|
||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): | ||
reset_password_token_secret = SECRET | ||
verification_token_secret = SECRET | ||
|
||
async def on_after_register(self, user: User, request: Optional[Request] = None): | ||
print(f"User {user.id} has registered.") | ||
|
||
async def on_after_forgot_password( | ||
self, user: User, token: str, request: Optional[Request] = None | ||
): | ||
print(f"User {user.id} has forgot their password. Reset token: {token}") | ||
|
||
async def on_after_request_verify( | ||
self, user: User, token: str, request: Optional[Request] = None | ||
): | ||
print(f"Verification requested for user {user.id}. Verification token: {token}") | ||
|
||
|
||
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): | ||
yield UserManager(user_db) | ||
|
||
|
||
bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") | ||
|
||
|
||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: | ||
return JWTStrategy(secret=SECRET, lifetime_seconds=3600) | ||
|
||
|
||
auth_backend = AuthenticationBackend( | ||
name="jwt", | ||
transport=bearer_transport, | ||
get_strategy=get_jwt_strategy, | ||
) | ||
|
||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) | ||
|
||
current_active_user = fastapi_users.current_user(active=True) |