From 1a4c361a5a4591047a81da122e4533f0f100ab5d Mon Sep 17 00:00:00 2001 From: Josh Hayes <35790761+hayescode@users.noreply.github.com> Date: Wed, 20 Mar 2024 22:58:58 -0500 Subject: [PATCH 1/5] Create sql_alchemy.py --- backend/chainlit/sql_alchemy.py | 449 ++++++++++++++++++++++++++++++++ 1 file changed, 449 insertions(+) create mode 100644 backend/chainlit/sql_alchemy.py diff --git a/backend/chainlit/sql_alchemy.py b/backend/chainlit/sql_alchemy.py new file mode 100644 index 0000000000..fac8e1e965 --- /dev/null +++ b/backend/chainlit/sql_alchemy.py @@ -0,0 +1,449 @@ +import uuid +import ssl +from datetime import datetime, timezone +import json +from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING +import aiofiles +import asyncio +from dataclasses import asdict +from sqlalchemy import text +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +from azure.storage.filedatalake import FileSystemClient, ContentSettings # type: ignore +from chainlit.context import context +from chainlit.logger import logger +from chainlit.data import BaseDataLayer, queue_until_user_message +from chainlit.user import User, PersistedUser, UserDict +from chainlit.types import Feedback, FeedbackDict, Pagination, ThreadDict, ThreadFilter +from literalai import PageInfo, PaginatedResponse +from chainlit.step import StepDict + +if TYPE_CHECKING: + from chainlit.element import Element + from chainlit.step import StepDict + +class SQLAlchemyDataLayer(BaseDataLayer): + def __init__(self, conninfo, ssl_require=False): + self._conninfo = conninfo + ssl_args = {} + if ssl_require: + # Create an SSL context to require an SSL connection + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + ssl_args['ssl'] = ssl_context + self.engine = create_async_engine(self._conninfo, connect_args=ssl_args) + self.async_session = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) + self.thread_update_lock = asyncio.Lock() + self.step_update_lock = asyncio.Lock() + + async def add_blob_storage_client(self, blob_storage_client, access_token: Optional[str]) -> None: + if isinstance(blob_storage_client, FileSystemClient): + self.blob_storage_client = blob_storage_client + self.blob_access_token = access_token + self.blob_storage_provider = 'Azure' + logger.info("Azure Data Lake Storage client initialized") + # Add other checks here for AWS/Google/etc. 
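A minimal instantiation sketch for the constructor above, assuming a PostgreSQL database reached through asyncpg (the driver PATCH 2 adds to pyproject.toml); the DSN values are placeholders, and the import path matches where this patch creates the file (later patches move it under chainlit/data):

    # Hypothetical usage sketch; the DSN below is a placeholder, not part of this patch.
    from chainlit.sql_alchemy import SQLAlchemyDataLayer

    data_layer = SQLAlchemyDataLayer(
        conninfo="postgresql+asyncpg://user:password@localhost:5432/chainlit",
        ssl_require=True,  # builds an unverified SSL context (check_hostname=False, CERT_NONE)
    )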
+ else: + raise ValueError("The provided blob_storage is not recognized") + + ###### SQL Helpers ###### + async def execute_sql(self, query: str, parameters: dict) -> Union[List[Dict[str, Any]], int, None]: + parameterized_query = text(query) + async with self.async_session() as session: + try: + await session.begin() + result = await session.execute(parameterized_query, parameters) + await session.commit() + if result.returns_rows: + json_result = [dict(row._mapping) for row in result.fetchall()] + clean_json_result = self.clean_result(json_result) + return clean_json_result + else: + return result.rowcount + except SQLAlchemyError as e: + await session.rollback() + logger.warn(f"An error occurred: {e}") + return None + except Exception as e: + await session.rollback() + logger.warn(f"An unexpected error occurred: {e}") + return None + + async def get_current_timestamp(self) -> str: + return datetime.now(timezone.utc).astimezone().isoformat() + + def clean_result(self, obj): + """Recursively change UUI -> STR and serialize dictionaries""" + if isinstance(obj, dict): + for k, v in obj.items(): + obj[k] = self.clean_result(v) + elif isinstance(obj, list): + return [self.clean_result(item) for item in obj] + elif isinstance(obj, uuid.UUID): + return str(obj) + elif isinstance(obj, dict): + return json.dumps(obj) + return obj + + ###### User ###### + async def get_user(self, identifier: str) -> Optional[PersistedUser]: + logger.info(f"Postgres: get_user, identifier={identifier}") + query = "SELECT * FROM users WHERE identifier = :identifier" + parameters = {"identifier": identifier} + result = await self.execute_sql(query=query, parameters=parameters) + if result and isinstance(result, list): + user_data = result[0] + return PersistedUser(**user_data) + return None + + async def create_user(self, user: User) -> Optional[PersistedUser]: + logger.info(f"Postgres: create_user, user_identifier={user.identifier}") + existing_user: Optional['PersistedUser'] = await self.get_user(user.identifier) + user_dict: Dict[str, Any] = { + "identifier": str(user.identifier), + "metadata": json.dumps(user.metadata) or {} + } + if not existing_user: # create the user + logger.info("Postgres: create_user, creating the user") + user_dict['id'] = str(uuid.uuid4()) + user_dict['createdAt'] = await self.get_current_timestamp() + query = "INSERT INTO users (id, identifier, createdAt, metadata) VALUES (:id, :identifier, :createdAt, :metadata)" + await self.execute_sql(query=query, parameters=user_dict) + else: # update the user + query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier""" + await self.execute_sql(query=query, parameters=user_dict) # We want to update the metadata + return await self.get_user(user.identifier) + + ###### Threads ###### + async def get_thread_author(self, thread_id: str) -> str: + logger.info(f"Postgres: get_thread_author, thread_id={thread_id}") + query = """SELECT u.* FROM threads t JOIN users u ON t."user_id" = u."id" WHERE t."id" = :id""" + parameters = {"id": thread_id} + result = await self.execute_sql(query=query, parameters=parameters) + if result and isinstance(result, list) and result[0]: + author_identifier = result[0].get('identifier') + if author_identifier is not None: + return author_identifier + raise ValueError(f"Author not found for thread_id {thread_id}") + + async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: + logger.info(f"Postgres: get_thread, thread_id={thread_id}") + user_identifier = await 
self.get_thread_author(thread_id=thread_id) + if user_identifier is None: + raise ValueError("User identifier not found for the given thread_id") + user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(user_identifier=user_identifier) + if not user_threads: + return None + for thread in user_threads: + if thread['id'] == thread_id: + return thread + return None + + async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None): + logger.info(f"Postgres: update_thread, thread_id={thread_id}") + async with self.thread_update_lock: # Acquire the lock before updating the thread + data = { + "id": thread_id, + "createdAt": await self.get_current_timestamp() if metadata is None else None, + "name": name if name is not None else (metadata.get('name') if metadata and 'name' in metadata else None), + "user_id": user_id, + "tags": tags, + "metadata": json.dumps(metadata) if metadata else None, + } + parameters = {key: value for key, value in data.items() if value is not None} # Remove keys with None values + columns = ', '.join(f'"{key}"' for key in parameters.keys()) + values = ', '.join(f':{key}' for key in parameters.keys()) + updates = ', '.join(f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != 'id') + query = f""" + INSERT INTO threads ({columns}) + VALUES ({values}) + ON CONFLICT ("id") DO UPDATE + SET {updates}; + """ + await self.execute_sql(query=query, parameters=parameters) + + async def delete_thread(self, thread_id: str): + logger.info(f"Postgres: delete_thread, thread_id={thread_id}") + query = """DELETE FROM threads WHERE "id" = :id""" + parameters = {"id": thread_id} + await self.execute_sql(query=query, parameters=parameters) + + async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse[ThreadDict]: + logger.info(f"Postgres: list_threads, pagination={pagination}, filters={filters}") + if not filters.userIdentifier: + raise ValueError("userIdentifier is required") + all_user_threads: List[ThreadDict] = await self.get_all_user_threads(user_identifier=filters.userIdentifier) or [] + + search_keyword = filters.search.lower() if filters.search else None + feedback_value = int(filters.feedback) if filters.feedback else None + + filtered_threads = [] + for thread in all_user_threads: + if search_keyword or feedback_value: + keyword_match = any(search_keyword in step['output'].lower() for step in thread['steps'] if 'output' in step) if search_keyword else True + if feedback_value is not None: + for step in thread['steps']: + feedback = step.get('feedback') + if feedback and feedback.get('value') == feedback_value: + feedback_match = True + break + else: + feedback_match = False + if keyword_match and feedback_match: + filtered_threads.append(thread) + else: + filtered_threads.append(thread) + + # Apply pagination + start = int(pagination.cursor) if pagination.cursor else 0 + end = start + pagination.first + paginated_threads = filtered_threads[start:end] or [] + + has_next_page = len(filtered_threads) > end + end_cursor = paginated_threads[-1]['id'] if paginated_threads else None + + return PaginatedResponse( + pageInfo=PageInfo(hasNextPage=has_next_page, endCursor=end_cursor), + data=paginated_threads + ) + + ###### Steps ###### + @queue_until_user_message() + async def create_step(self, step_dict: 'StepDict'): + logger.info(f"Postgres: create_step, step_id={step_dict.get('id')}") + async with 
self.thread_update_lock: # Wait for update_thread + pass + async with self.step_update_lock: # Acquire the lock before updating the step + step_dict['showInput'] = str(step_dict.get('showInput', '')).lower() if 'showInput' in step_dict else None + parameters = {key: value for key, value in step_dict.items() if value is not None} # Remove keys with None values + + columns = ', '.join(f'"{key}"' for key in parameters.keys()) + values = ', '.join(f':{key}' for key in parameters.keys()) + updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id') + query = f""" + INSERT INTO steps ({columns}) + VALUES ({values}) + ON CONFLICT (id) DO UPDATE + SET {updates}; + """ + await self.execute_sql(query=query, parameters=parameters) + + @queue_until_user_message() + async def update_step(self, step_dict: 'StepDict'): + logger.info(f"Postgres: update_step, step_id={step_dict.get('id')}") + await self.create_step(step_dict) + + @queue_until_user_message() + async def delete_step(self, step_id: str): + logger.info(f"Postgres: delete_step, step_id={step_id}") + query = """DELETE FROM steps WHERE "id" = :id""" + parameters = {"id": step_id} + await self.execute_sql(query=query, parameters=parameters) + + ###### Feedback ###### + async def upsert_feedback(self, feedback: Feedback) -> str: + logger.info(f"Postgres: upsert_feedback, feedback_id={feedback.id}") + feedback.id = feedback.id or str(uuid.uuid4()) + feedback_dict = asdict(feedback) + parameters = {key: value for key, value in feedback_dict.items() if value is not None} # Remove keys with None values + + columns = ', '.join(f'"{key}"' for key in parameters.keys()) + values = ', '.join(f':{key}' for key in parameters.keys()) + updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id') + query = f""" + INSERT INTO feedbacks ({columns}) + VALUES ({values}) + ON CONFLICT (id) DO UPDATE + SET {updates}; + """ + await self.execute_sql(query=query, parameters=parameters) + return feedback.id + + ###### Elements ###### + @queue_until_user_message() + async def create_element(self, element: 'Element'): + logger.info(f"Postgres: create_element, element_id = {element.id}") + async with self.thread_update_lock: + pass + async with self.step_update_lock: + pass + if not self.blob_storage_client: + raise ValueError("No blob_storage_client is configured") + if not element.for_id: + return + element_dict = element.to_dict() + content: Optional[Union[bytes, str]] = None + + if not element.url: + if element.path: + async with aiofiles.open(element.path, "rb") as f: + content = await f.read() + elif element.content: + content = element.content + else: + raise ValueError("Either path or content must be provided") + + context_user = context.session.user + if not context_user or not getattr(context_user, 'id', None): + raise ValueError("No valid user in context") + + user_folder = getattr(context_user, 'id', 'unknown') + object_key = f"{user_folder}/{element.id}" + (f"/{element.name}" if element.name else "") + + if self.blob_storage_provider == 'Azure': + file_client = self.blob_storage_client.get_file_client(object_key) + content_type = ContentSettings(content_type=element.mime) + file_client.upload_data(content, overwrite=True, content_settings=content_type) + element.url = file_client.url + (self.blob_access_token or '') + + element_dict['url'] = element.url + element_dict['objectKey'] = object_key if 'object_key' in locals() else None + element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None} + 
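For readers following the dynamically assembled upserts above (create_step and upsert_feedback share the pattern), this is roughly what the string building expands to for a hypothetical minimal step dict; the values are illustrative only:

    # Illustration of the generated upsert; the parameters here are hypothetical.
    parameters = {"id": "step-1", "threadId": "thread-1", "output": "Ok!"}
    columns = ', '.join(f'"{key}"' for key in parameters.keys())  # "id", "threadId", "output"
    values = ', '.join(f':{key}' for key in parameters.keys())    # :id, :threadId, :output
    updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id')
    # Resulting SQL:
    #   INSERT INTO steps ("id", "threadId", "output")
    #   VALUES (:id, :threadId, :output)
    #   ON CONFLICT (id) DO UPDATE
    #   SET "threadId" = :threadId, "output" = :output;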
+ columns = ', '.join(f'"{column}"' for column in element_dict_cleaned.keys()) + placeholders = ', '.join(f':{column}' for column in element_dict_cleaned.keys()) + query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})" + await self.execute_sql(query=query, parameters=element_dict_cleaned) + + @queue_until_user_message() + async def delete_element(self, element_id: str): + logger.info(f"Postgres: delete_element, element_id={element_id}") + query = """DELETE FROM elements WHERE "id" = :id""" + parameters = {"id": element_id} + await self.execute_sql(query=query, parameters=parameters) + + async def delete_user_session(self, id: str) -> bool: + return False # Not sure why documentation wants this + + #### NEW OPTIMIZATION #### + async def get_all_user_threads(self, user_identifier: str) -> Optional[List[ThreadDict]]: + """Fetch all user threads for fast retrieval""" + logger.info(f"Postgres: get_all_user_threads") + parameters = {"identifier": user_identifier} + sql_query = """ + SELECT + t."id" AS thread_id, + t."createdAt" AS thread_createdat, + t."name" AS thread_name, + t."tags" AS thread_tags, + t."metadata" AS thread_metadata, + u."id" AS user_id, + u."identifier" AS user_identifier, + u."metadata" AS user_metadata, + s."id" AS step_id, + s."name" AS step_name, + s."type" AS step_type, + s."threadId" AS step_threadid, + s."parentId" AS step_parentid, + s."disableFeedback" AS step_disablefeedback, + s."streaming" AS step_streaming, + s."waitForAnswer" AS step_waitforanswer, + s."isError" AS step_iserror, + s."metadata" AS step_metadata, + s."input" AS step_input, + s."output" AS step_output, + s."createdAt" AS step_createdat, + s."start" AS step_start, + s."end" AS step_end, + s."generation" AS step_generation, + s."showInput" AS step_showinput, + s."language" AS step_language, + s."indent" AS step_indent, + f."value" AS feedback_value, + f."strategy" AS feedback_strategy, + f."comment" AS feedback_comment, + e."id" AS element_id, + e."threadId" as element_threadid, + e."type" AS element_type, + e."url" AS element_url, + e."chainlitKey" AS element_chainlitkey, + e."objectKey" as element_objectkey, + e."name" AS element_name, + e."display" AS element_display, + e."size" AS element_size, + e."language" AS element_language, + e."page" AS element_page, + e."forId" AS element_forid, + e."mime" AS element_mime + FROM + threads t + LEFT JOIN users u ON t."user_id" = u."id" + LEFT JOIN steps s ON t."id" = s."threadId" + LEFT JOIN feedbacks f ON s."id" = f."forId" + LEFT JOIN elements e ON t."id" = e."threadId" + WHERE u."identifier" = :identifier + ORDER BY t."createdAt" DESC, s."start" ASC + """ + results = await self.execute_sql(query=sql_query, parameters=parameters) + threads: List[ThreadDict] = [] + if not isinstance(results, list): + raise ValueError("Expected a list of results") + for row in results: + thread_id = row['thread_id'] + thread = next((t for t in threads if t['id'] == thread_id), None) + if not thread: + thread = ThreadDict( + id=thread_id, + createdAt=row['thread_createdat'], + name=row['thread_name'], + user= UserDict( + id=row['user_id'], + identifier=row['user_identifier'], + metadata=row['user_metadata'] + ) if row['user_id'] else None, + tags=row['thread_tags'], + metadata=row['thread_metadata'], + steps=[], + elements=[] + ) + threads.append(thread) + if row['step_id']: + step = StepDict( + id=row['step_id'], + name=row['step_name'], + type=row['step_type'], + threadId=row['step_threadid'], + parentId=row['step_parentid'], + 
disableFeedback=row['step_disablefeedback'], + streaming=row['step_streaming'], + waitForAnswer=row['step_waitforanswer'], + isError=row['step_iserror'], + metadata=row['step_metadata'], + input=row['step_input'] if row['step_showinput'] else None, + output=row['step_output'], + createdAt=row['step_createdat'], + start=row['step_start'], + end=row['step_end'], + generation=row['step_generation'], + showInput=row['step_showinput'], + language=row['step_language'], + indent=row['step_indent'], + feedback= FeedbackDict( + value=row['feedback_value'], + strategy=row['feedback_strategy'], + comment=row['feedback_comment'] + ) if row['feedback_value'] is not None else None + ) + thread['steps'].append(step) + if row['element_id']: + element: Dict[str, Any] = { + "id":row['element_id'], + "threadId":row['element_threadid'], + "type":row['element_type'], + "chainlitKey":row['element_chainlitkey'], + "url":row['element_url'], + "objectKey":row['element_objectkey'], + "name":row['element_name'], + "display":row['element_display'], + "size":row['element_size'], + "language":row['element_language'], + "page":row['element_page'], + "forId":row['element_forid'], + "mime":row['element_mime'] + } + if thread['elements'] is None: + thread['elements'] = [] + thread['elements'].append(element) # type: ignore + return threads From bec6e5dbc5acd4afb9b869ae60d52ea993b81a62 Mon Sep 17 00:00:00 2001 From: Josh Hayes <35790761+hayescode@users.noreply.github.com> Date: Wed, 20 Mar 2024 22:59:52 -0500 Subject: [PATCH 2/5] Update pyproject.toml --- backend/pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index a0a6ba0c36..b1c84d3122 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -46,6 +46,8 @@ lazify = "^0.4.0" packaging = "^23.1" python-multipart = "^0.0.9" pyjwt = "^2.8.0" +asyncpg = "^0.29.0" +SQLAlchemy = "^2.0.28" [tool.poetry.group.tests] optional = true From 354ddc2b0ea622c55b1d22db496323ada6ea1564 Mon Sep 17 00:00:00 2001 From: Josh Hayes <35790761+hayescode@users.noreply.github.com> Date: Wed, 20 Mar 2024 23:01:23 -0500 Subject: [PATCH 3/5] Create sql_alchemy.py --- cypress/e2e/custom_data_layer/sql_alchemy.py | 28 ++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 cypress/e2e/custom_data_layer/sql_alchemy.py diff --git a/cypress/e2e/custom_data_layer/sql_alchemy.py b/cypress/e2e/custom_data_layer/sql_alchemy.py new file mode 100644 index 0000000000..8a059d8aa0 --- /dev/null +++ b/cypress/e2e/custom_data_layer/sql_alchemy.py @@ -0,0 +1,28 @@ +from typing import List, Optional + +import chainlit.data as cl_data +from chainlit.data.sql_alchemy import SQLAlchemyDataLayer +from literalai.helper import utc_now + +import chainlit as cl + +cl_data._data_layer = SQLAlchemyDataLayer(conninfo="<your conninfo>") + + +@cl.on_chat_start +async def main(): + await cl.Message("Hello, send me a message!", disable_feedback=True).send() + + +@cl.on_message +async def handle_message(): + await cl.sleep(2) + await cl.Message("Ok!").send() + + +@cl.password_auth_callback +def auth_callback(username: str, password: str) -> Optional[cl.User]: + if (username, password) == ("admin", "admin"): + return cl.User(identifier="admin") + else: + return None From 610a0f4e19ff88bc99e6d37c447fbd8404be701f Mon Sep 17 00:00:00 2001 From: Josh Hayes <35790761+hayescode@users.noreply.github.com> Date: Mon, 25 Mar 2024 09:55:47 -0500 Subject: [PATCH 4/5] Create sql_alchemy.py --- backend/chainlit/data/sql_alchemy.py | 478 
+++++++++++++++++++++++++++ 1 file changed, 478 insertions(+) create mode 100644 backend/chainlit/data/sql_alchemy.py diff --git a/backend/chainlit/data/sql_alchemy.py b/backend/chainlit/data/sql_alchemy.py new file mode 100644 index 0000000000..d7f1c037eb --- /dev/null +++ b/backend/chainlit/data/sql_alchemy.py @@ -0,0 +1,478 @@ +import uuid +import ssl +from datetime import datetime, timezone +import json +from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING +import aiofiles +import asyncio +from dataclasses import asdict +from sqlalchemy import text +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +from azure.storage.filedatalake import FileSystemClient, ContentSettings # type: ignore +from chainlit.context import context +from chainlit.logger import logger +from chainlit.data import BaseDataLayer, queue_until_user_message +from chainlit.user import User, PersistedUser, UserDict +from chainlit.types import Feedback, FeedbackDict, Pagination, ThreadDict, ThreadFilter, PageInfo, PaginatedResponse +from chainlit.step import StepDict +from chainlit.element import ElementDict + +if TYPE_CHECKING: + from chainlit.element import Element, ElementDict + from chainlit.step import StepDict + +class SQLAlchemyDataLayer(BaseDataLayer): + def __init__(self, conninfo, ssl_require=False, user_thread_limit=100): + self._conninfo = conninfo + self.user_thread_limit = user_thread_limit + ssl_args = {} + if ssl_require: + # Create an SSL context to require an SSL connection + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + ssl_args['ssl'] = ssl_context + self.engine = create_async_engine(self._conninfo, connect_args=ssl_args) + self.async_session = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) + self.thread_update_lock = asyncio.Lock() + self.step_update_lock = asyncio.Lock() + + async def add_blob_storage_client(self, blob_storage_client, access_token: Optional[str]) -> None: + if isinstance(blob_storage_client, FileSystemClient): + self.blob_storage_client = blob_storage_client + self.blob_access_token = access_token + self.blob_storage_provider = 'Azure' + logger.info("Azure Data Lake Storage client initialized") + # Add other checks here for AWS/Google/etc. 
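A sketch of wiring up the Azure Data Lake client this method expects; the account URL, file system name, credential and SAS token below are placeholders, not values from this patch:

    # Hypothetical Azure setup; every identifier below is a placeholder.
    from azure.storage.filedatalake import FileSystemClient

    fs_client = FileSystemClient(
        account_url="https://myaccount.dfs.core.windows.net",
        file_system_name="chainlit-elements",
        credential="<account-key-or-sas-credential>",
    )
    await data_layer.add_blob_storage_client(fs_client, access_token="<sas-query-string>")
    # create_element later appends access_token verbatim to file_client.url,
    # so a URL-ready SAS query string (or None) is the expected shape.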
+ else: + raise ValueError("The provided blob_storage is not recognized") + + ###### SQL Helpers ###### + async def execute_sql(self, query: str, parameters: dict) -> Union[List[Dict[str, Any]], int, None]: + parameterized_query = text(query) + async with self.async_session() as session: + try: + await session.begin() + result = await session.execute(parameterized_query, parameters) + await session.commit() + if result.returns_rows: + json_result = [dict(row._mapping) for row in result.fetchall()] + clean_json_result = self.clean_result(json_result) + return clean_json_result + else: + return result.rowcount + except SQLAlchemyError as e: + await session.rollback() + logger.warn(f"An error occurred: {e}") + return None + except Exception as e: + await session.rollback() + logger.warn(f"An unexpected error occurred: {e}") + return None + + async def get_current_timestamp(self) -> str: + return datetime.now(timezone.utc).astimezone().isoformat() + + def clean_result(self, obj): + """Recursively change UUI -> STR and serialize dictionaries""" + if isinstance(obj, dict): + for k, v in obj.items(): + obj[k] = self.clean_result(v) + elif isinstance(obj, list): + return [self.clean_result(item) for item in obj] + elif isinstance(obj, uuid.UUID): + return str(obj) + elif isinstance(obj, dict): + return json.dumps(obj) + return obj + + ###### User ###### + async def get_user(self, identifier: str) -> Optional[PersistedUser]: + logger.info(f"SQLAlchemy: get_user, identifier={identifier}") + query = "SELECT * FROM users WHERE identifier = :identifier" + parameters = {"identifier": identifier} + result = await self.execute_sql(query=query, parameters=parameters) + if result and isinstance(result, list): + user_data = result[0] + return PersistedUser(**user_data) + return None + + async def create_user(self, user: User) -> Optional[PersistedUser]: + logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}") + existing_user: Optional['PersistedUser'] = await self.get_user(user.identifier) + user_dict: Dict[str, Any] = { + "identifier": str(user.identifier), + "metadata": json.dumps(user.metadata) or {} + } + if not existing_user: # create the user + logger.info("SQLAlchemy: create_user, creating the user") + user_dict['id'] = str(uuid.uuid4()) + user_dict['createdAt'] = await self.get_current_timestamp() + query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)""" + await self.execute_sql(query=query, parameters=user_dict) + else: # update the user + query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier""" + await self.execute_sql(query=query, parameters=user_dict) # We want to update the metadata + return await self.get_user(user.identifier) + + ###### Threads ###### + async def get_thread_author(self, thread_id: str) -> str: + logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}") + query = """SELECT u.* FROM threads t JOIN users u ON t."user_id" = u."id" WHERE t."id" = :id""" + parameters = {"id": thread_id} + result = await self.execute_sql(query=query, parameters=parameters) + if result and isinstance(result, list) and result[0]: + author_identifier = result[0].get('identifier') + if author_identifier is not None: + return author_identifier + raise ValueError(f"Author not found for thread_id {thread_id}") + + async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: + logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}") + user_identifier = await 
self.get_thread_author(thread_id=thread_id) + if user_identifier is None: + raise ValueError("User identifier not found for the given thread_id") + user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(user_identifier=user_identifier) + if not user_threads: + return None + for thread in user_threads: + if thread['id'] == thread_id: + return thread + return None + + async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None): + logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}") + async with self.thread_update_lock: # Acquire the lock before updating the thread + data = { + "id": thread_id, + "createdAt": await self.get_current_timestamp() if metadata is None else None, + "name": name if name is not None else (metadata.get('name') if metadata and 'name' in metadata else None), + "user_id": user_id, + "tags": tags, + "metadata": json.dumps(metadata) if metadata else None, + } + parameters = {key: value for key, value in data.items() if value is not None} # Remove keys with None values + columns = ', '.join(f'"{key}"' for key in parameters.keys()) + values = ', '.join(f':{key}' for key in parameters.keys()) + updates = ', '.join(f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != 'id') + query = f""" + INSERT INTO threads ({columns}) + VALUES ({values}) + ON CONFLICT ("id") DO UPDATE + SET {updates}; + """ + await self.execute_sql(query=query, parameters=parameters) + + async def delete_thread(self, thread_id: str): + logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}") + query = """DELETE FROM threads WHERE "id" = :id""" + parameters = {"id": thread_id} + await self.execute_sql(query=query, parameters=parameters) + + async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse[ThreadDict]: + logger.info(f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}") + if not filters.userIdentifier: + raise ValueError("userIdentifier is required") + all_user_threads: List[ThreadDict] = await self.get_all_user_threads(user_identifier=filters.userIdentifier) or [] + + search_keyword = filters.search.lower() if filters.search else None + feedback_value = int(filters.feedback) if filters.feedback else None + + filtered_threads = [] + for thread in all_user_threads: + keyword_match = True + feedback_match = True # Initialize feedback_match to True + if search_keyword or feedback_value is not None: + if search_keyword: + keyword_match = any(search_keyword in step['output'].lower() for step in thread['steps'] if 'output' in step) + if feedback_value is not None: + feedback_match = False # Assume no match until found + for step in thread['steps']: + feedback = step.get('feedback') + if feedback and feedback.get('value') == feedback_value: + feedback_match = True + break + if keyword_match and feedback_match: + filtered_threads.append(thread) + + start = 0 # Find the start index using pagination.cursor + if pagination.cursor: + for i, thread in enumerate(filtered_threads): + if thread['id'] == pagination.cursor: + start = i + 1 + break + end = start + pagination.first + paginated_threads = filtered_threads[start:end] or [] + + has_next_page = len(filtered_threads) > end + end_cursor = paginated_threads[-1]['id'] if paginated_threads else None + + return PaginatedResponse( + pageInfo=PageInfo(hasNextPage=has_next_page, endCursor=end_cursor), + data=paginated_threads + ) + + ###### Steps 
###### + @queue_until_user_message() + async def create_step(self, step_dict: 'StepDict'): + logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}") + logger.info(f"SQLAlchemy: name={step_dict.get('name')}, input={step_dict.get('input')}, output={step_dict.get('name')}") + async with self.thread_update_lock: # Wait for update_thread + pass + async with self.step_update_lock: # Acquire the lock before updating the step + step_dict['showInput'] = str(step_dict.get('showInput', '')).lower() if 'showInput' in step_dict else None + # parameters = {key: value for key, value in step_dict.items() if value is not None} # Remove keys with None values + parameters = {key: value for key, value in step_dict.items() if value is not None and not (isinstance(value, dict) and not value)} + + + columns = ', '.join(f'"{key}"' for key in parameters.keys()) + values = ', '.join(f':{key}' for key in parameters.keys()) + updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id') + query = f""" + INSERT INTO steps ({columns}) + VALUES ({values}) + ON CONFLICT (id) DO UPDATE + SET {updates}; + """ + await self.execute_sql(query=query, parameters=parameters) + + @queue_until_user_message() + async def update_step(self, step_dict: 'StepDict'): + logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}") + await self.create_step(step_dict) + + @queue_until_user_message() + async def delete_step(self, step_id: str): + logger.info(f"SQLAlchemy: delete_step, step_id={step_id}") + query = """DELETE FROM steps WHERE "id" = :id""" + parameters = {"id": step_id} + await self.execute_sql(query=query, parameters=parameters) + + ###### Feedback ###### + async def upsert_feedback(self, feedback: Feedback) -> str: + logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}") + feedback.id = feedback.id or str(uuid.uuid4()) + feedback_dict = asdict(feedback) + parameters = {key: value for key, value in feedback_dict.items() if value is not None} # Remove keys with None values + + columns = ', '.join(f'"{key}"' for key in parameters.keys()) + values = ', '.join(f':{key}' for key in parameters.keys()) + updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id') + query = f""" + INSERT INTO feedbacks ({columns}) + VALUES ({values}) + ON CONFLICT (id) DO UPDATE + SET {updates}; + """ + await self.execute_sql(query=query, parameters=parameters) + return feedback.id + + ###### Elements ###### + @queue_until_user_message() + async def create_element(self, element: 'Element'): + logger.info(f"SQLAlchemy: create_element, element_id = {element.id}") + async with self.thread_update_lock: + pass + async with self.step_update_lock: + pass + if not self.blob_storage_client: + raise ValueError("No blob_storage_client is configured") + if not element.for_id: + return + element_dict = element.to_dict() + content: Optional[Union[bytes, str]] = None + + if not element.url: + if element.path: + async with aiofiles.open(element.path, "rb") as f: + content = await f.read() + elif element.content: + content = element.content + else: + raise ValueError("Either path or content must be provided") + + context_user = context.session.user + if not context_user or not getattr(context_user, 'id', None): + raise ValueError("No valid user in context") + + user_folder = getattr(context_user, 'id', 'unknown') + object_key = f"{user_folder}/{element.id}" + (f"/{element.name}" if element.name else "") + + if self.blob_storage_provider == 'Azure': + file_client = 
self.blob_storage_client.get_file_client(object_key) + content_type = ContentSettings(content_type=element.mime) + file_client.upload_data(content, overwrite=True, content_settings=content_type) + element.url = file_client.url + (self.blob_access_token or '') + + element_dict['url'] = element.url + element_dict['objectKey'] = object_key if 'object_key' in locals() else None + element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None} + + columns = ', '.join(f'"{column}"' for column in element_dict_cleaned.keys()) + placeholders = ', '.join(f':{column}' for column in element_dict_cleaned.keys()) + query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})" + await self.execute_sql(query=query, parameters=element_dict_cleaned) + + @queue_until_user_message() + async def delete_element(self, element_id: str): + logger.info(f"SQLAlchemy: delete_element, element_id={element_id}") + query = """DELETE FROM elements WHERE "id" = :id""" + parameters = {"id": element_id} + await self.execute_sql(query=query, parameters=parameters) + + async def delete_user_session(self, id: str) -> bool: + return False # Not sure why documentation wants this + + async def get_all_user_threads(self, user_identifier: str) -> Optional[List[ThreadDict]]: + """Fetch all user threads for fast retrieval, up to self.user_thread_limit""" + logger.info(f"SQLAlchemy: get_all_user_threads") + user_threads_query = """ + SELECT + t."id" AS thread_id, + t."createdAt" AS thread_createdat, + t."name" AS thread_name, + t."tags" AS thread_tags, + t."metadata" AS thread_metadata, + u."id" AS user_id, + u."identifier" AS user_identifier, + u."metadata" AS user_metadata + FROM threads t JOIN users u ON t."user_id" = u."id" + WHERE u."identifier" = :identifier + ORDER BY t."createdAt" DESC + LIMIT :limit + """ + user_threads = await self.execute_sql(query=user_threads_query, parameters={"identifier": user_identifier, "limit": self.user_thread_limit}) + if not isinstance(user_threads, list): + return None + thread_ids = "('" + "','".join(map(str, [thread['thread_id'] for thread in user_threads])) + "')" + if not thread_ids: + return [] + + steps_feedbacks_query = f""" + SELECT + s."id" AS step_id, + s."name" AS step_name, + s."type" AS step_type, + s."threadId" AS step_threadid, + s."parentId" AS step_parentid, + s."disableFeedback" AS step_disablefeedback, + s."streaming" AS step_streaming, + s."waitForAnswer" AS step_waitforanswer, + s."isError" AS step_iserror, + s."metadata" AS step_metadata, + s."input" AS step_input, + s."output" AS step_output, + s."createdAt" AS step_createdat, + s."start" AS step_start, + s."end" AS step_end, + s."generation" AS step_generation, + s."showInput" AS step_showinput, + s."language" AS step_language, + s."indent" AS step_indent, + f."value" AS feedback_value, + f."strategy" AS feedback_strategy, + f."comment" AS feedback_comment + FROM steps s LEFT JOIN feedbacks f ON s."id" = f."forId" + WHERE s."threadId" IN {thread_ids} + ORDER BY s."createdAt" ASC + """ + steps_feedbacks = await self.execute_sql(query=steps_feedbacks_query, parameters={}) + + elements_query = f""" + SELECT + e."id" AS element_id, + e."threadId" as element_threadid, + e."type" AS element_type, + e."url" AS element_url, + e."chainlitKey" AS element_chainlitkey, + e."objectKey" as element_objectkey, + e."name" AS element_name, + e."display" AS element_display, + e."size" AS element_size, + e."language" AS element_language, + e."page" AS element_page, + e."forId" AS element_forid, + e."mime" AS element_mime 
+ FROM elements e + WHERE e."threadId" IN {thread_ids} + """ + elements = await self.execute_sql(query=elements_query, parameters={}) + + # Initialize a dictionary to hold ThreadDict objects keyed by thread_id + thread_dicts = {} + # Process threads_users to create initial ThreadDict objects + for thread in user_threads: + thread_id = thread['thread_id'] + thread_dicts[thread_id] = ThreadDict( + id=thread_id, + createdAt=thread['thread_createdat'], + name=thread['thread_name'], + user=UserDict( + id=thread['user_id'], + identifier=thread['user_identifier'], + metadata=thread['user_metadata'] + ), + tags=thread['thread_tags'], + metadata=thread['thread_metadata'], + steps=[], + elements=[] + ) + # Process steps_feedbacks to populate the steps in the corresponding ThreadDict + if isinstance(steps_feedbacks, list): + for step_feedback in steps_feedbacks: + thread_id = step_feedback['step_threadid'] + feedback = None + if step_feedback['feedback_value'] is not None: + feedback = FeedbackDict( + value=step_feedback['feedback_value'], + strategy=step_feedback['feedback_strategy'], + comment=step_feedback.get('feedback_comment') + ) + step_dict = StepDict( + id=step_feedback['step_id'], + name=step_feedback['step_name'], + type=step_feedback['step_type'], + threadId=thread_id, + parentId=step_feedback.get('step_parentid'), + disableFeedback=step_feedback.get('step_disableFeedback', False), + streaming=step_feedback.get('step_streaming', False), + waitForAnswer=step_feedback.get('step_waitForAnswer'), + isError=step_feedback.get('step_isError'), + metadata=step_feedback.get('step_metadata', {}), + input=step_feedback.get('step_input', '') if step_feedback['step_showinput'] else None, + output=step_feedback.get('step_output', ''), + createdAt=step_feedback.get('step_createdAt'), + start=step_feedback.get('step_start'), + end=step_feedback.get('step_end'), + generation=step_feedback.get('step_generation'), + showInput=step_feedback.get('step_showInput'), + language=step_feedback.get('step_language'), + indent=step_feedback.get('step_indent'), + feedback=feedback + ) + # Append the step to the steps list of the corresponding ThreadDict + thread_dicts[thread_id]['steps'].append(step_dict) + + if isinstance(elements, list): + for element in elements: + thread_id = element['element_threadid'] + element_dict = ElementDict( + id=element['element_id'], + threadId=thread_id, + type=element['element_type'], + chainlitKey=element.get('element_chainlitKey'), + url=element.get('element_url'), + objectKey=element.get('element_objectKey'), + name=element['element_name'], + display=element['element_display'], + size=element.get('element_size'), + language=element.get('element_language'), + page=element.get('element_page'), + forId=element.get('element_forId'), + mime=element.get('element_mime'), + ) + thread_dicts[thread_id]['elements'].append(element_dict) # type: ignore + + return list(thread_dicts.values()) From e11d00dbeb7143e4a535409e8957a7a38dfa481b Mon Sep 17 00:00:00 2001 From: Josh Hayes <35790761+hayescode@users.noreply.github.com> Date: Mon, 25 Mar 2024 10:03:29 -0500 Subject: [PATCH 5/5] Delete backend/chainlit/sql_alchemy.py --- backend/chainlit/sql_alchemy.py | 449 -------------------------------- 1 file changed, 449 deletions(-) delete mode 100644 backend/chainlit/sql_alchemy.py diff --git a/backend/chainlit/sql_alchemy.py b/backend/chainlit/sql_alchemy.py deleted file mode 100644 index fac8e1e965..0000000000 --- a/backend/chainlit/sql_alchemy.py +++ /dev/null @@ -1,449 +0,0 @@ -import uuid -import 
ssl -from datetime import datetime, timezone -import json -from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING -import aiofiles -import asyncio -from dataclasses import asdict -from sqlalchemy import text -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession -from sqlalchemy.orm import sessionmaker -from azure.storage.filedatalake import FileSystemClient, ContentSettings # type: ignore -from chainlit.context import context -from chainlit.logger import logger -from chainlit.data import BaseDataLayer, queue_until_user_message -from chainlit.user import User, PersistedUser, UserDict -from chainlit.types import Feedback, FeedbackDict, Pagination, ThreadDict, ThreadFilter -from literalai import PageInfo, PaginatedResponse -from chainlit.step import StepDict - -if TYPE_CHECKING: - from chainlit.element import Element - from chainlit.step import StepDict - -class SQLAlchemyDataLayer(BaseDataLayer): - def __init__(self, conninfo, ssl_require=False): - self._conninfo = conninfo - ssl_args = {} - if ssl_require: - # Create an SSL context to require an SSL connection - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - ssl_args['ssl'] = ssl_context - self.engine = create_async_engine(self._conninfo, connect_args=ssl_args) - self.async_session = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) - self.thread_update_lock = asyncio.Lock() - self.step_update_lock = asyncio.Lock() - - async def add_blob_storage_client(self, blob_storage_client, access_token: Optional[str]) -> None: - if isinstance(blob_storage_client, FileSystemClient): - self.blob_storage_client = blob_storage_client - self.blob_access_token = access_token - self.blob_storage_provider = 'Azure' - logger.info("Azure Data Lake Storage client initialized") - # Add other checks here for AWS/Google/etc. 
- else: - raise ValueError("The provided blob_storage is not recognized") - - ###### SQL Helpers ###### - async def execute_sql(self, query: str, parameters: dict) -> Union[List[Dict[str, Any]], int, None]: - parameterized_query = text(query) - async with self.async_session() as session: - try: - await session.begin() - result = await session.execute(parameterized_query, parameters) - await session.commit() - if result.returns_rows: - json_result = [dict(row._mapping) for row in result.fetchall()] - clean_json_result = self.clean_result(json_result) - return clean_json_result - else: - return result.rowcount - except SQLAlchemyError as e: - await session.rollback() - logger.warn(f"An error occurred: {e}") - return None - except Exception as e: - await session.rollback() - logger.warn(f"An unexpected error occurred: {e}") - return None - - async def get_current_timestamp(self) -> str: - return datetime.now(timezone.utc).astimezone().isoformat() - - def clean_result(self, obj): - """Recursively change UUI -> STR and serialize dictionaries""" - if isinstance(obj, dict): - for k, v in obj.items(): - obj[k] = self.clean_result(v) - elif isinstance(obj, list): - return [self.clean_result(item) for item in obj] - elif isinstance(obj, uuid.UUID): - return str(obj) - elif isinstance(obj, dict): - return json.dumps(obj) - return obj - - ###### User ###### - async def get_user(self, identifier: str) -> Optional[PersistedUser]: - logger.info(f"Postgres: get_user, identifier={identifier}") - query = "SELECT * FROM users WHERE identifier = :identifier" - parameters = {"identifier": identifier} - result = await self.execute_sql(query=query, parameters=parameters) - if result and isinstance(result, list): - user_data = result[0] - return PersistedUser(**user_data) - return None - - async def create_user(self, user: User) -> Optional[PersistedUser]: - logger.info(f"Postgres: create_user, user_identifier={user.identifier}") - existing_user: Optional['PersistedUser'] = await self.get_user(user.identifier) - user_dict: Dict[str, Any] = { - "identifier": str(user.identifier), - "metadata": json.dumps(user.metadata) or {} - } - if not existing_user: # create the user - logger.info("Postgres: create_user, creating the user") - user_dict['id'] = str(uuid.uuid4()) - user_dict['createdAt'] = await self.get_current_timestamp() - query = "INSERT INTO users (id, identifier, createdAt, metadata) VALUES (:id, :identifier, :createdAt, :metadata)" - await self.execute_sql(query=query, parameters=user_dict) - else: # update the user - query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier""" - await self.execute_sql(query=query, parameters=user_dict) # We want to update the metadata - return await self.get_user(user.identifier) - - ###### Threads ###### - async def get_thread_author(self, thread_id: str) -> str: - logger.info(f"Postgres: get_thread_author, thread_id={thread_id}") - query = """SELECT u.* FROM threads t JOIN users u ON t."user_id" = u."id" WHERE t."id" = :id""" - parameters = {"id": thread_id} - result = await self.execute_sql(query=query, parameters=parameters) - if result and isinstance(result, list) and result[0]: - author_identifier = result[0].get('identifier') - if author_identifier is not None: - return author_identifier - raise ValueError(f"Author not found for thread_id {thread_id}") - - async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: - logger.info(f"Postgres: get_thread, thread_id={thread_id}") - user_identifier = await 
self.get_thread_author(thread_id=thread_id) - if user_identifier is None: - raise ValueError("User identifier not found for the given thread_id") - user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(user_identifier=user_identifier) - if not user_threads: - return None - for thread in user_threads: - if thread['id'] == thread_id: - return thread - return None - - async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None): - logger.info(f"Postgres: update_thread, thread_id={thread_id}") - async with self.thread_update_lock: # Acquire the lock before updating the thread - data = { - "id": thread_id, - "createdAt": await self.get_current_timestamp() if metadata is None else None, - "name": name if name is not None else (metadata.get('name') if metadata and 'name' in metadata else None), - "user_id": user_id, - "tags": tags, - "metadata": json.dumps(metadata) if metadata else None, - } - parameters = {key: value for key, value in data.items() if value is not None} # Remove keys with None values - columns = ', '.join(f'"{key}"' for key in parameters.keys()) - values = ', '.join(f':{key}' for key in parameters.keys()) - updates = ', '.join(f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != 'id') - query = f""" - INSERT INTO threads ({columns}) - VALUES ({values}) - ON CONFLICT ("id") DO UPDATE - SET {updates}; - """ - await self.execute_sql(query=query, parameters=parameters) - - async def delete_thread(self, thread_id: str): - logger.info(f"Postgres: delete_thread, thread_id={thread_id}") - query = """DELETE FROM threads WHERE "id" = :id""" - parameters = {"id": thread_id} - await self.execute_sql(query=query, parameters=parameters) - - async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse[ThreadDict]: - logger.info(f"Postgres: list_threads, pagination={pagination}, filters={filters}") - if not filters.userIdentifier: - raise ValueError("userIdentifier is required") - all_user_threads: List[ThreadDict] = await self.get_all_user_threads(user_identifier=filters.userIdentifier) or [] - - search_keyword = filters.search.lower() if filters.search else None - feedback_value = int(filters.feedback) if filters.feedback else None - - filtered_threads = [] - for thread in all_user_threads: - if search_keyword or feedback_value: - keyword_match = any(search_keyword in step['output'].lower() for step in thread['steps'] if 'output' in step) if search_keyword else True - if feedback_value is not None: - for step in thread['steps']: - feedback = step.get('feedback') - if feedback and feedback.get('value') == feedback_value: - feedback_match = True - break - else: - feedback_match = False - if keyword_match and feedback_match: - filtered_threads.append(thread) - else: - filtered_threads.append(thread) - - # Apply pagination - start = int(pagination.cursor) if pagination.cursor else 0 - end = start + pagination.first - paginated_threads = filtered_threads[start:end] or [] - - has_next_page = len(filtered_threads) > end - end_cursor = paginated_threads[-1]['id'] if paginated_threads else None - - return PaginatedResponse( - pageInfo=PageInfo(hasNextPage=has_next_page, endCursor=end_cursor), - data=paginated_threads - ) - - ###### Steps ###### - @queue_until_user_message() - async def create_step(self, step_dict: 'StepDict'): - logger.info(f"Postgres: create_step, step_id={step_dict.get('id')}") - async with 
self.thread_update_lock: # Wait for update_thread - pass - async with self.step_update_lock: # Acquire the lock before updating the step - step_dict['showInput'] = str(step_dict.get('showInput', '')).lower() if 'showInput' in step_dict else None - parameters = {key: value for key, value in step_dict.items() if value is not None} # Remove keys with None values - - columns = ', '.join(f'"{key}"' for key in parameters.keys()) - values = ', '.join(f':{key}' for key in parameters.keys()) - updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id') - query = f""" - INSERT INTO steps ({columns}) - VALUES ({values}) - ON CONFLICT (id) DO UPDATE - SET {updates}; - """ - await self.execute_sql(query=query, parameters=parameters) - - @queue_until_user_message() - async def update_step(self, step_dict: 'StepDict'): - logger.info(f"Postgres: update_step, step_id={step_dict.get('id')}") - await self.create_step(step_dict) - - @queue_until_user_message() - async def delete_step(self, step_id: str): - logger.info(f"Postgres: delete_step, step_id={step_id}") - query = """DELETE FROM steps WHERE "id" = :id""" - parameters = {"id": step_id} - await self.execute_sql(query=query, parameters=parameters) - - ###### Feedback ###### - async def upsert_feedback(self, feedback: Feedback) -> str: - logger.info(f"Postgres: upsert_feedback, feedback_id={feedback.id}") - feedback.id = feedback.id or str(uuid.uuid4()) - feedback_dict = asdict(feedback) - parameters = {key: value for key, value in feedback_dict.items() if value is not None} # Remove keys with None values - - columns = ', '.join(f'"{key}"' for key in parameters.keys()) - values = ', '.join(f':{key}' for key in parameters.keys()) - updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id') - query = f""" - INSERT INTO feedbacks ({columns}) - VALUES ({values}) - ON CONFLICT (id) DO UPDATE - SET {updates}; - """ - await self.execute_sql(query=query, parameters=parameters) - return feedback.id - - ###### Elements ###### - @queue_until_user_message() - async def create_element(self, element: 'Element'): - logger.info(f"Postgres: create_element, element_id = {element.id}") - async with self.thread_update_lock: - pass - async with self.step_update_lock: - pass - if not self.blob_storage_client: - raise ValueError("No blob_storage_client is configured") - if not element.for_id: - return - element_dict = element.to_dict() - content: Optional[Union[bytes, str]] = None - - if not element.url: - if element.path: - async with aiofiles.open(element.path, "rb") as f: - content = await f.read() - elif element.content: - content = element.content - else: - raise ValueError("Either path or content must be provided") - - context_user = context.session.user - if not context_user or not getattr(context_user, 'id', None): - raise ValueError("No valid user in context") - - user_folder = getattr(context_user, 'id', 'unknown') - object_key = f"{user_folder}/{element.id}" + (f"/{element.name}" if element.name else "") - - if self.blob_storage_provider == 'Azure': - file_client = self.blob_storage_client.get_file_client(object_key) - content_type = ContentSettings(content_type=element.mime) - file_client.upload_data(content, overwrite=True, content_settings=content_type) - element.url = file_client.url + (self.blob_access_token or '') - - element_dict['url'] = element.url - element_dict['objectKey'] = object_key if 'object_key' in locals() else None - element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None} - 
- columns = ', '.join(f'"{column}"' for column in element_dict_cleaned.keys()) - placeholders = ', '.join(f':{column}' for column in element_dict_cleaned.keys()) - query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})" - await self.execute_sql(query=query, parameters=element_dict_cleaned) - - @queue_until_user_message() - async def delete_element(self, element_id: str): - logger.info(f"Postgres: delete_element, element_id={element_id}") - query = """DELETE FROM elements WHERE "id" = :id""" - parameters = {"id": element_id} - await self.execute_sql(query=query, parameters=parameters) - - async def delete_user_session(self, id: str) -> bool: - return False # Not sure why documentation wants this - - #### NEW OPTIMIZATION #### - async def get_all_user_threads(self, user_identifier: str) -> Optional[List[ThreadDict]]: - """Fetch all user threads for fast retrieval""" - logger.info(f"Postgres: get_all_user_threads") - parameters = {"identifier": user_identifier} - sql_query = """ - SELECT - t."id" AS thread_id, - t."createdAt" AS thread_createdat, - t."name" AS thread_name, - t."tags" AS thread_tags, - t."metadata" AS thread_metadata, - u."id" AS user_id, - u."identifier" AS user_identifier, - u."metadata" AS user_metadata, - s."id" AS step_id, - s."name" AS step_name, - s."type" AS step_type, - s."threadId" AS step_threadid, - s."parentId" AS step_parentid, - s."disableFeedback" AS step_disablefeedback, - s."streaming" AS step_streaming, - s."waitForAnswer" AS step_waitforanswer, - s."isError" AS step_iserror, - s."metadata" AS step_metadata, - s."input" AS step_input, - s."output" AS step_output, - s."createdAt" AS step_createdat, - s."start" AS step_start, - s."end" AS step_end, - s."generation" AS step_generation, - s."showInput" AS step_showinput, - s."language" AS step_language, - s."indent" AS step_indent, - f."value" AS feedback_value, - f."strategy" AS feedback_strategy, - f."comment" AS feedback_comment, - e."id" AS element_id, - e."threadId" as element_threadid, - e."type" AS element_type, - e."url" AS element_url, - e."chainlitKey" AS element_chainlitkey, - e."objectKey" as element_objectkey, - e."name" AS element_name, - e."display" AS element_display, - e."size" AS element_size, - e."language" AS element_language, - e."page" AS element_page, - e."forId" AS element_forid, - e."mime" AS element_mime - FROM - threads t - LEFT JOIN users u ON t."user_id" = u."id" - LEFT JOIN steps s ON t."id" = s."threadId" - LEFT JOIN feedbacks f ON s."id" = f."forId" - LEFT JOIN elements e ON t."id" = e."threadId" - WHERE u."identifier" = :identifier - ORDER BY t."createdAt" DESC, s."start" ASC - """ - results = await self.execute_sql(query=sql_query, parameters=parameters) - threads: List[ThreadDict] = [] - if not isinstance(results, list): - raise ValueError("Expected a list of results") - for row in results: - thread_id = row['thread_id'] - thread = next((t for t in threads if t['id'] == thread_id), None) - if not thread: - thread = ThreadDict( - id=thread_id, - createdAt=row['thread_createdat'], - name=row['thread_name'], - user= UserDict( - id=row['user_id'], - identifier=row['user_identifier'], - metadata=row['user_metadata'] - ) if row['user_id'] else None, - tags=row['thread_tags'], - metadata=row['thread_metadata'], - steps=[], - elements=[] - ) - threads.append(thread) - if row['step_id']: - step = StepDict( - id=row['step_id'], - name=row['step_name'], - type=row['step_type'], - threadId=row['step_threadid'], - parentId=row['step_parentid'], - 
disableFeedback=row['step_disablefeedback'], - streaming=row['step_streaming'], - waitForAnswer=row['step_waitforanswer'], - isError=row['step_iserror'], - metadata=row['step_metadata'], - input=row['step_input'] if row['step_showinput'] else None, - output=row['step_output'], - createdAt=row['step_createdat'], - start=row['step_start'], - end=row['step_end'], - generation=row['step_generation'], - showInput=row['step_showinput'], - language=row['step_language'], - indent=row['step_indent'], - feedback= FeedbackDict( - value=row['feedback_value'], - strategy=row['feedback_strategy'], - comment=row['feedback_comment'] - ) if row['feedback_value'] is not None else None - ) - thread['steps'].append(step) - if row['element_id']: - element: Dict[str, Any] = { - "id":row['element_id'], - "threadId":row['element_threadid'], - "type":row['element_type'], - "chainlitKey":row['element_chainlitkey'], - "url":row['element_url'], - "objectKey":row['element_objectkey'], - "name":row['element_name'], - "display":row['element_display'], - "size":row['element_size'], - "language":row['element_language'], - "page":row['element_page'], - "forId":row['element_forid'], - "mime":row['element_mime'] - } - if thread['elements'] is None: - thread['elements'] = [] - thread['elements'].append(element) # type: ignore - return threads
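A closing note on why PATCH 4 rewrites get_all_user_threads: the PATCH 1 version LEFT JOINs steps, feedbacks and elements onto threads in a single query, so a thread with S steps and E elements comes back as S x E rows, and the reconstruction loop appends duplicate steps and elements. PATCH 4 instead issues three queries (threads, steps plus feedbacks, elements) and stitches them together by thread id, which also lets user_thread_limit cap the thread query alone. A toy illustration of the fan-out, with hypothetical data:

    steps = ["s1", "s2", "s3"]      # steps of one thread
    elements = ["e1", "e2"]         # elements of the same thread
    joined = [(s, e) for s in steps for e in elements]  # row shape of the 4-way LEFT JOIN
    assert len(joined) == 6         # each step row is repeated once per element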
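Relatedly, list_threads in PATCH 1 read pagination.cursor as a numeric offset while handing the last thread id back as endCursor, so the cursor a client received could not be used to request the next page; PATCH 4 resolves the cursor against thread ids instead. That lookup in isolation, with hypothetical ids:

    filtered_threads = [{"id": "t1"}, {"id": "t2"}, {"id": "t3"}]
    cursor, first = "t2", 2
    start = next((i + 1 for i, t in enumerate(filtered_threads) if t["id"] == cursor), 0)
    page = filtered_threads[start:start + first]             # [{"id": "t3"}]
    has_next_page = len(filtered_threads) > start + first    # False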