Skip to content

Commit

Permalink
Initial functional user auth
Browse files Browse the repository at this point in the history
  • Loading branch information
Vasilije1990 committed Jul 22, 2024
1 parent 77e8c1b commit e785b30
Show file tree
Hide file tree
Showing 10 changed files with 354 additions and 99 deletions.
13 changes: 9 additions & 4 deletions cognee/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@
traces_sample_rate = 1.0,
profiles_sample_rate = 1.0,
)

app = FastAPI(debug = os.getenv("ENV") != "prod")
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
# Not needed if you setup a migration system like Alembic
await create_db_and_tables()
yield
app = FastAPI(debug = os.getenv("ENV") != "prod", lifespan=lifespan)

origins = [
"http://frontend:3000",
Expand Down Expand Up @@ -338,8 +343,8 @@ def start_api_server(host: str = "0.0.0.0", port: int = 8000):
relational_config.create_engine()

from cognee.modules.data.deletion import prune_system, prune_data
asyncio.run(prune_data())
asyncio.run(prune_system(metadata = True))
# asyncio.run(prune_data())
# asyncio.run(prune_system(metadata = True))

uvicorn.run(app, host = host, port = port)
except Exception as e:
Expand Down
9 changes: 5 additions & 4 deletions cognee/infrastructure/databases/relational/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,22 @@

class RelationalConfig(BaseSettings):
db_path: str = os.path.join(get_absolute_path(".cognee_system"), "databases")
db_name: str = "cognee.db"
db_name: str = "cognee_db"
db_host: str = "localhost"
db_port: str = "5432"
db_user: str = "cognee"
db_password: str = "cognee"
db_provider: str = "duckdb"
database_engine: object = create_relational_engine(db_path, db_name, db_provider)
db_provider: str = "postgresql+asyncpg"
# database_engine: object = create_relational_engine(db_path, db_name, db_provider)
db_file_path: str = os.path.join(db_path, db_name)


model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

def create_engine(self):
self.db_file_path = os.path.join(self.db_path, self.db_name)
self.database_engine = create_relational_engine(self.db_path, self.db_name)
self.database_engine = create_relational_engine(self.db_path, self.db_name, self.db_provider, self.db_host, self.db_port, self.db_user, self.db_password)
return self.database_engine

def to_dict(self) -> dict:
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,39 @@

from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from cognee.infrastructure.files.storage import LocalStorage
from cognee.infrastructure.databases.relational import DuckDBAdapter, get_relationaldb_config
from cognee.infrastructure.databases.relational import DuckDBAdapter


class DBProvider(Enum):
DUCKDB = "duckdb"
POSTGRES = "postgres"
POSTGRES = "postgresql+asyncpg"



def create_relational_engine(db_path: str, db_name: str, db_provider:str):
def create_relational_engine(db_path: str, db_name: str, db_provider:str, db_host:str, db_port:str, db_user:str, db_password:str):
LocalStorage.ensure_directory_exists(db_path)

llm_config = get_relationaldb_config()

provider = DBProvider(llm_config.llm_provider)

provider = DBProvider(db_provider)

if provider == DBProvider.DUCKDB:

return DuckDBAdapter(
# return DuckDBAdapter(
# db_name = db_name,
# db_path = db_path,
# )
return SQLAlchemyAdapter(
db_name = db_name,
db_path = db_path,
db_type = db_provider,
db_host=db_host,
db_port=db_port,
db_user=db_user,
db_password=db_password
)
elif provider == DBProvider.POSTGRES:
return SQLAlchemyAdapter(
db_name = db_name,
db_path = db_path,
db_type = db_provider,
db_host= db_host,
db_port= db_port,
db_user= db_user,
db_password= db_password
)
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
import os
from sqlalchemy import create_engine, MetaData, Table, Column, String, Boolean, TIMESTAMP, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

class SQLAlchemyAdapter():
def __init__(self, db_type: str, db_path: str, db_name: str):
def __init__(self, db_type: str, db_path: str, db_name: str, db_user:str, db_password:str, db_host:str, db_port:str):
self.db_location = os.path.abspath(os.path.join(db_path, db_name))
self.engine = create_engine(f"{db_type}:///{self.db_location}")
self.Session = sessionmaker(bind=self.engine)
# self.engine = create_engine(f"{db_type}:///{self.db_location}")
if db_type == "duckdb":
self.engine = create_engine(f"duckdb:///{self.db_location}")
self.Session = sessionmaker(bind=self.engine)

else:
print("Name: ", db_name)
print("User: ", db_user)
print("Password: ", db_password)
print("Host: ", db_host)
print("Port: ", db_port)
self.engine = create_async_engine(f"postgresql+asyncpg://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}")
self.Session = sessionmaker(bind=self.engine, class_=AsyncSession, expire_on_commit=False)


async def get_async_session(self):
async_session_maker = self.Session
async with async_session_maker() as session:
yield session

def get_session(self):
session_maker = self.Session
with session_maker() as session:
yield session

def get_datasets(self):
with self.engine.connect() as connection:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from typing import AsyncGenerator
from typing import AsyncGenerator, Generator

from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase

from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session
from cognee.infrastructure.databases.relational import get_relationaldb_config
from cognee.infrastructure.databases.relational.create_relational_engine import create_relational_engine

DATABASE_URL = "sqlite+aiosqlite:///./test.db"


class Base(DeclarativeBase):
pass
Expand All @@ -19,23 +15,21 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
pass


relational_config = get_relationaldb_config()

llm_config = get_relationaldb_config()


engine = create_relational_engine(llm_config.db_path, llm_config.db_name, llm_config.db_provider)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

engine = relational_config.create_engine()
async_session_maker = async_sessionmaker(engine.engine, expire_on_commit=False)

async def create_db_and_tables():
async with engine.begin() as conn:
async with engine.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)


async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session_maker() as session:
yield session

# yield async_session_maker

async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User)
yield SQLAlchemyUserDatabase(session, User)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
)
from fastapi_users.db import SQLAlchemyUserDatabase

from app.db import User, get_user_db

from cognee.infrastructure.databases.relational.user_authentication.authentication_db import User, get_user_db

SECRET = "SECRET"

Expand Down
16 changes: 16 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,24 @@ services:
- 3001:3000
networks:
- cognee-network
postgres:
image: postgres:latest
container_name: postgres
environment:
POSTGRES_USER: cognee
POSTGRES_PASSWORD: cognee
POSTGRES_DB: cognee_db
volumes:
- postgres_data:/var/lib/postgresql/data
ports:
- 5432:5432
networks:
- cognee-network

networks:
cognee-network:
name: cognee-network

volumes:
postgres_data:

Loading

0 comments on commit e785b30

Please sign in to comment.