-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
866270b
commit 77e8c1b
Showing 5 changed files with 164 additions and 8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 29 additions & 6 deletions
35
cognee/infrastructure/databases/relational/create_relational_engine.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,33 @@ | ||
from enum import Enum | ||
|
||
from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter | ||
from cognee.infrastructure.files.storage import LocalStorage | ||
from cognee.infrastructure.databases.relational import DuckDBAdapter | ||
from cognee.infrastructure.databases.relational import DuckDBAdapter, get_relationaldb_config | ||
|
||
|
||
class DBProvider(Enum):
    """Closed set of relational database backends the engine factory can build."""

    DUCKDB = "duckdb"
    POSTGRES = "postgres"
def create_relational_engine(db_path: str, db_name: str, db_provider: str):
    """Build a relational database adapter for the requested provider.

    Args:
        db_path: Directory that holds the database file; created if missing.
        db_name: Name of the database file.
        db_provider: Backend identifier; must be a valid ``DBProvider`` value
            (``"duckdb"`` or ``"postgres"``).

    Returns:
        A ``DuckDBAdapter`` or ``SQLAlchemyAdapter`` instance.

    Raises:
        ValueError: If ``db_provider`` does not name a known ``DBProvider``.
    """
    LocalStorage.ensure_directory_exists(db_path)

    # BUG FIX: the provider must come from the caller-supplied db_provider
    # argument. The old code read get_relationaldb_config().llm_provider (the
    # local was even named llm_config) — an attribute belonging to the LLM
    # config, a copy/paste slip that ignored the parameter entirely.
    provider = DBProvider(db_provider)

    if provider == DBProvider.DUCKDB:
        return DuckDBAdapter(
            db_name = db_name,
            db_path = db_path,
        )
    elif provider == DBProvider.POSTGRES:
        return SQLAlchemyAdapter(
            db_name = db_name,
            db_path = db_path,
            db_type = db_provider,
        )
125 changes: 125 additions & 0 deletions
125
cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import os | ||
from sqlalchemy import create_engine, MetaData, Table, Column, String, Boolean, TIMESTAMP, text | ||
from sqlalchemy.orm import sessionmaker | ||
|
||
class SQLAlchemyAdapter():
    """Relational database adapter backed by SQLAlchemy.

    Mirrors the DuckDBAdapter surface so the two can be used interchangeably
    behind create_relational_engine.
    """

    def __init__(self, db_type: str, db_path: str, db_name: str):
        """
        Args:
            db_type: SQLAlchemy dialect name used to build the URL
                (e.g. "sqlite", "duckdb").
            db_path: Directory that contains the database file.
            db_name: Database file name.
        """
        self.db_location = os.path.abspath(os.path.join(db_path, db_name))
        self.engine = create_engine(f"{db_type}:///{self.db_location}")
        self.Session = sessionmaker(bind=self.engine)

    def get_datasets(self):
        """Return dataset (schema) names, excluding staging schemas and "cognee"."""
        with self.engine.connect() as connection:
            # BUG FIX: information_schema.tables exposes the schema column as
            # "table_schema" (Postgres and DuckDB alike); "schema_name" only
            # exists on information_schema.schemata. Alias it so the row
            # access below is unchanged.
            result = connection.execute(
                text("SELECT DISTINCT table_schema AS schema_name FROM information_schema.tables;")
            )
            # NOTE(review): row["schema_name"] relies on SQLAlchemy 1.x Row
            # mapping access; under 2.0 this needs row._mapping — confirm the
            # pinned SQLAlchemy version.
            tables = [row["schema_name"] for row in result]
            return list(
                filter(
                    lambda schema_name: not schema_name.endswith("staging") and schema_name != "cognee",
                    tables
                )
            )

    def get_files_metadata(self, dataset_name: str):
        """Return id/name/file_path/extension/mime_type rows for a dataset.

        NOTE(review): dataset_name is interpolated into SQL directly — do not
        pass untrusted input.
        """
        with self.engine.connect() as connection:
            result = connection.execute(text(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;"))
            return [dict(row) for row in result]

    def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
        """Create schema_name.table_name from table_config ({"name", "type"} dicts), if absent."""
        fields_query_parts = [f"{item['name']} {item['type']}" for item in table_config]
        with self.engine.connect() as connection:
            connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
            connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"))

    def delete_table(self, table_name: str):
        """Drop table_name if it exists."""
        with self.engine.connect() as connection:
            connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))

    def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
        """Bulk-insert rows (a list of same-keyed dicts) into schema_name.table_name."""
        # Guard: empty input previously raised IndexError on data[0].
        if not data:
            return
        columns = ", ".join(data[0].keys())
        # BUG FIX: the old code emitted one "(:key, ...)" tuple PER ROW, all
        # reusing the same parameter names, so the statement expected
        # len(data) tuples while executemany supplied one row's values at a
        # time. Emit a single VALUES tuple and let SQLAlchemy executemany
        # over the list of row dicts.
        placeholders = ", ".join(f":{key}" for key in data[0].keys())
        insert_query = text(f"INSERT INTO {schema_name}.{table_name} ({columns}) VALUES ({placeholders});")
        with self.engine.connect() as connection:
            connection.execute(insert_query, data)

    def get_data(self, table_name: str, filters: dict = None):
        """Return a {data_id: status} mapping from table_name.

        Args:
            table_name: Table to read (interpolated into SQL — trusted input only).
            filters: Optional column -> value map; a scalar means equality, a
                list means an IN clause.
        """
        with self.engine.connect() as connection:
            query = f"SELECT * FROM {table_name}"
            if filters:
                # BUG FIX: the old code generated :key0..:keyN placeholders
                # for list filters but bound the raw list under "key", so the
                # parameters never matched. Expand list values into
                # individually named bind parameters.
                conditions = []
                params = {}
                for key, value in filters.items():
                    if isinstance(value, list):
                        names = [f"{key}{i}" for i in range(len(value))]
                        conditions.append(f"{key} IN ({', '.join(':' + name for name in names)})")
                        params.update(dict(zip(names, value)))
                    else:
                        conditions.append(f"{key} = :{key}")
                        params[key] = value
                query += f" WHERE {' AND '.join(conditions)};"
                results = connection.execute(text(query), params)
            else:
                query += ";"
                results = connection.execute(text(query))
            return {result["data_id"]: result["status"] for result in results}

    def execute_query(self, query):
        """Run a raw SQL string and return the rows as dicts."""
        with self.engine.connect() as connection:
            result = connection.execute(text(query))
            return [dict(row) for row in result]

    def load_cognify_data(self, data):
        """Ensure the cognify tracking table exists and bulk-insert the given rows."""
        metadata = MetaData()
        cognify_table = Table(
            'cognify', metadata,
            Column('document_id', String),
            Column('layer_id', String),
            Column('created_at', TIMESTAMP, server_default=text('CURRENT_TIMESTAMP')),
            Column('updated_at', TIMESTAMP, nullable=True, default=None),
            Column('processed', Boolean, default=False),
            Column('document_id_target', String, nullable=True)
        )
        metadata.create_all(self.engine)
        # BUG FIX: insert().values(document_id=text(':document_id'), ...) baked
        # literal placeholder text into the compiled statement. Pass the plain
        # insert() plus the row dicts so SQLAlchemy binds each row's values.
        with self.engine.connect() as connection:
            connection.execute(cognify_table.insert(), data)

    def fetch_cognify_data(self, excluded_document_id: str):
        """Fetch unprocessed cognify rows (excluding one document) and mark them processed."""
        with self.engine.connect() as connection:
            # NOTE(review): "STRING" is DuckDB column syntax; Postgres expects
            # TEXT/VARCHAR — confirm the target dialect before reuse.
            connection.execute(text("""
                CREATE TABLE IF NOT EXISTS cognify (
                    document_id STRING,
                    layer_id STRING,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT NULL,
                    processed BOOLEAN DEFAULT FALSE,
                    document_id_target STRING NULL
                );
            """))
            query = text("""
                SELECT document_id, layer_id, created_at, updated_at, processed
                FROM cognify
                WHERE document_id != :excluded_document_id AND processed = FALSE;
            """)
            records = connection.execute(query, {'excluded_document_id': excluded_document_id}).fetchall()
            if records:
                document_ids = tuple(record['document_id'] for record in records)
                # BUG FIX: "IN :document_ids" with a raw tuple is not a
                # portable bind; expand into one named parameter per id.
                id_params = {f"id{i}": doc_id for i, doc_id in enumerate(document_ids)}
                placeholders = ", ".join(f":{name}" for name in id_params)
                update_query = text(f"UPDATE cognify SET processed = TRUE WHERE document_id IN ({placeholders});")
                connection.execute(update_query, id_params)
            return [dict(record) for record in records]

    def delete_cognify_data(self):
        """Empty and drop the cognify tracking table (creating it first if absent)."""
        with self.engine.connect() as connection:
            connection.execute(text("""
                CREATE TABLE IF NOT EXISTS cognify (
                    document_id STRING,
                    layer_id STRING,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT NULL,
                    processed BOOLEAN DEFAULT FALSE,
                    document_id_target STRING NULL
                );
            """))
            connection.execute(text("DELETE FROM cognify;"))
            connection.execute(text("DROP TABLE cognify;"))

    def delete_database(self):
        """Remove the database file (and its write-ahead log, if present) from disk."""
        from cognee.infrastructure.files.storage import LocalStorage
        LocalStorage.remove(self.db_location)
        if LocalStorage.file_exists(self.db_location + ".wal"):
            LocalStorage.remove(self.db_location + ".wal")
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters