From 1a4c361a5a4591047a81da122e4533f0f100ab5d Mon Sep 17 00:00:00 2001
From: Josh Hayes <35790761+hayescode@users.noreply.github.com>
Date: Wed, 20 Mar 2024 22:58:58 -0500
Subject: [PATCH 1/5] Create sql_alchemy.py

---
 backend/chainlit/sql_alchemy.py | 449 ++++++++++++++++++++++++++++++++
 1 file changed, 449 insertions(+)
 create mode 100644 backend/chainlit/sql_alchemy.py

diff --git a/backend/chainlit/sql_alchemy.py b/backend/chainlit/sql_alchemy.py
new file mode 100644
index 0000000000..fac8e1e965
--- /dev/null
+++ b/backend/chainlit/sql_alchemy.py
@@ -0,0 +1,449 @@
+import uuid
+import ssl
+from datetime import datetime, timezone
+import json
+from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+import aiofiles
+import asyncio
+from dataclasses import asdict
+from sqlalchemy import text
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.orm import sessionmaker
+from azure.storage.filedatalake import FileSystemClient, ContentSettings  # type: ignore
+from chainlit.context import context
+from chainlit.logger import logger
+from chainlit.data import BaseDataLayer, queue_until_user_message
+from chainlit.user import User, PersistedUser, UserDict
+from chainlit.types import Feedback, FeedbackDict, Pagination, ThreadDict, ThreadFilter
+from literalai import PageInfo, PaginatedResponse
+from chainlit.step import StepDict
+
+if TYPE_CHECKING:
+    from chainlit.element import Element
+    from chainlit.step import StepDict
+
+class SQLAlchemyDataLayer(BaseDataLayer):
+    def __init__(self, conninfo, ssl_require=False):
+        self._conninfo = conninfo
+        ssl_args = {}
+        if ssl_require:
+            # Create an SSL context to require an SSL connection
+            ssl_context = ssl.create_default_context()
+            ssl_context.check_hostname = False
+            ssl_context.verify_mode = ssl.CERT_NONE
+            ssl_args['ssl'] = ssl_context
+        self.engine = create_async_engine(self._conninfo, connect_args=ssl_args)
+        self.async_session = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
+        self.thread_update_lock = asyncio.Lock()
+        self.step_update_lock = asyncio.Lock()
+
+    async def add_blob_storage_client(self, blob_storage_client, access_token: Optional[str]) -> None:
+        if isinstance(blob_storage_client, FileSystemClient):
+            self.blob_storage_client = blob_storage_client
+            self.blob_access_token = access_token
+            self.blob_storage_provider = 'Azure'
+            logger.info("Azure Data Lake Storage client initialized")
+        # Add other checks here for AWS/Google/etc.
+        else:
+            raise ValueError("The provided blob_storage is not recognized")
+
+    ###### SQL Helpers ######
+    async def execute_sql(self, query: str, parameters: dict) -> Union[List[Dict[str, Any]], int, None]:
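+        """Execute a parameterized statement; returns mapped rows for queries, the rowcount for DML, or None on error."""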
+        parameterized_query = text(query)
+        async with self.async_session() as session:
+            try:
+                await session.begin()
+                result = await session.execute(parameterized_query, parameters)
+                await session.commit()
+                if result.returns_rows:
+                    json_result = [dict(row._mapping) for row in result.fetchall()]
+                    clean_json_result = self.clean_result(json_result)
+                    return clean_json_result
+                else:
+                    return result.rowcount
+            except SQLAlchemyError as e:
+                await session.rollback()
+                logger.warn(f"An error occurred: {e}")
+                return None
+            except Exception as e:
+                await session.rollback()
+                logger.warn(f"An unexpected error occurred: {e}")
+                return None
+
+    async def get_current_timestamp(self) -> str:
+        return datetime.now(timezone.utc).astimezone().isoformat()
+    
+    def clean_result(self, obj):
+        """Recursively convert UUID values to str so query results are JSON-serializable"""
+        if isinstance(obj, dict):
+            for k, v in obj.items():
+                obj[k] = self.clean_result(v)
+        elif isinstance(obj, list):
+            return [self.clean_result(item) for item in obj]
+        elif isinstance(obj, uuid.UUID):
+            return str(obj)
+        return obj
+
+    ###### User ######
+    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
+        logger.info(f"Postgres: get_user, identifier={identifier}")
+        query = "SELECT * FROM users WHERE identifier = :identifier"
+        parameters = {"identifier": identifier}
+        result = await self.execute_sql(query=query, parameters=parameters)
+        if result and isinstance(result, list):
+            user_data = result[0]
+            return PersistedUser(**user_data)
+        return None
+        
+    async def create_user(self, user: User) -> Optional[PersistedUser]:
+        logger.info(f"Postgres: create_user, user_identifier={user.identifier}")
+        existing_user: Optional['PersistedUser'] = await self.get_user(user.identifier)
+        user_dict: Dict[str, Any]  = {
+            "identifier": str(user.identifier),
+            "metadata": json.dumps(user.metadata) or {}
+            }
+        if not existing_user: # create the user
+            logger.info("Postgres: create_user, creating the user")
+            user_dict['id'] = str(uuid.uuid4())
+            user_dict['createdAt'] = await self.get_current_timestamp()
+            query = "INSERT INTO users (id, identifier, createdAt, metadata) VALUES (:id, :identifier, :createdAt, :metadata)"
+            await self.execute_sql(query=query, parameters=user_dict)
+        else: # update the user
+            query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
+            await self.execute_sql(query=query, parameters=user_dict) # We want to update the metadata
+        return await self.get_user(user.identifier)
+        
+    ###### Threads ######
+    async def get_thread_author(self, thread_id: str) -> str:
+        logger.info(f"Postgres: get_thread_author, thread_id={thread_id}")
+        query = """SELECT u.* FROM threads t JOIN users u ON t."user_id" = u."id" WHERE t."id" = :id"""
+        parameters = {"id": thread_id}
+        result = await self.execute_sql(query=query, parameters=parameters)
+        if result and isinstance(result, list) and result[0]:
+            author_identifier = result[0].get('identifier')
+            if author_identifier is not None:
+                return author_identifier
+        raise ValueError(f"Author not found for thread_id {thread_id}")
+    
+    async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
+        logger.info(f"Postgres: get_thread, thread_id={thread_id}")
+        user_identifier = await self.get_thread_author(thread_id=thread_id)
+        if user_identifier is None:
+            raise ValueError("User identifier not found for the given thread_id")
+        user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(user_identifier=user_identifier)
+        if not user_threads:
+            return None
+        for thread in user_threads:
+            if thread['id'] == thread_id:
+                return thread
+        return None
+
+    async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None):
+        logger.info(f"Postgres: update_thread, thread_id={thread_id}")
+        async with self.thread_update_lock:  # Acquire the lock before updating the thread
+            data = {
+                "id": thread_id,
+                "createdAt": await self.get_current_timestamp() if metadata is None else None,
+                "name": name if name is not None else (metadata.get('name') if metadata and 'name' in metadata else None),
+                "user_id": user_id,
+                "tags": tags,
+                "metadata": json.dumps(metadata) if metadata else None,
+            }
+            parameters = {key: value for key, value in data.items() if value is not None} # Remove keys with None values
+            columns = ', '.join(f'"{key}"' for key in parameters.keys())
+            values = ', '.join(f':{key}' for key in parameters.keys())
+            updates = ', '.join(f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != 'id')
+            query = f"""
+                INSERT INTO threads ({columns})
+                VALUES ({values})
+                ON CONFLICT ("id") DO UPDATE
+                SET {updates};
+            """
+            await self.execute_sql(query=query, parameters=parameters)
+
+    async def delete_thread(self, thread_id: str):
+        logger.info(f"Postgres: delete_thread, thread_id={thread_id}")
+        query = """DELETE FROM threads WHERE "id" = :id"""
+        parameters = {"id": thread_id}
+        await self.execute_sql(query=query, parameters=parameters)
+    
+    async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse[ThreadDict]:
+        logger.info(f"Postgres: list_threads, pagination={pagination}, filters={filters}")
+        if not filters.userIdentifier:
+            raise ValueError("userIdentifier is required")
+        all_user_threads: List[ThreadDict] = await self.get_all_user_threads(user_identifier=filters.userIdentifier) or []
+
+        search_keyword = filters.search.lower() if filters.search else None
+        feedback_value = int(filters.feedback) if filters.feedback is not None else None
+
+        filtered_threads = []
+        for thread in all_user_threads:
+            keyword_match = True
+            feedback_match = True
+            if search_keyword or feedback_value is not None:
+                if search_keyword:
+                    keyword_match = any(search_keyword in step['output'].lower() for step in thread['steps'] if 'output' in step)
+                if feedback_value is not None:
+                    feedback_match = False
+                    for step in thread['steps']:
+                        feedback = step.get('feedback')
+                        if feedback and feedback.get('value') == feedback_value:
+                            feedback_match = True
+                            break
+            if keyword_match and feedback_match:
+                filtered_threads.append(thread)
+
+        # Apply pagination: the cursor is the id of the last thread on the previous page
+        start = 0
+        if pagination.cursor:
+            for i, thread in enumerate(filtered_threads):
+                if thread['id'] == pagination.cursor:
+                    start = i + 1
+                    break
+        end = start + pagination.first
+        paginated_threads = filtered_threads[start:end] or []
+
+        has_next_page = len(filtered_threads) > end
+        end_cursor = paginated_threads[-1]['id'] if paginated_threads else None
+
+        return PaginatedResponse(
+            pageInfo=PageInfo(hasNextPage=has_next_page, endCursor=end_cursor),
+            data=paginated_threads
+        )
+    
+    ###### Steps ######
+    @queue_until_user_message()
+    async def create_step(self, step_dict: 'StepDict'):
+        logger.info(f"Postgres: create_step, step_id={step_dict.get('id')}")
+        async with self.thread_update_lock:  # Wait for update_thread
+            pass
+        async with self.step_update_lock:  # Acquire the lock before updating the step
+            step_dict['showInput'] = str(step_dict.get('showInput', '')).lower() if 'showInput' in step_dict else None
+            parameters = {key: value for key, value in step_dict.items() if value is not None}  # Remove keys with None values
+
+            columns = ', '.join(f'"{key}"' for key in parameters.keys())
+            values = ', '.join(f':{key}' for key in parameters.keys())
+            updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id')
+            query = f"""
+                INSERT INTO steps ({columns})
+                VALUES ({values})
+                ON CONFLICT (id) DO UPDATE
+                SET {updates};
+            """
+            await self.execute_sql(query=query, parameters=parameters)
+    
+    @queue_until_user_message()
+    async def update_step(self, step_dict: 'StepDict'):
+        logger.info(f"Postgres: update_step, step_id={step_dict.get('id')}")
+        await self.create_step(step_dict)
+
+    @queue_until_user_message()
+    async def delete_step(self, step_id: str):
+        logger.info(f"Postgres: delete_step, step_id={step_id}")
+        query = """DELETE FROM steps WHERE "id" = :id"""
+        parameters = {"id": step_id}
+        await self.execute_sql(query=query, parameters=parameters)
+    
+    ###### Feedback ######
+    async def upsert_feedback(self, feedback: Feedback) -> str:
+        logger.info(f"Postgres: upsert_feedback, feedback_id={feedback.id}")
+        feedback.id = feedback.id or str(uuid.uuid4())
+        feedback_dict = asdict(feedback)
+        parameters = {key: value for key, value in feedback_dict.items() if value is not None}  # Remove keys with None values
+
+        columns = ', '.join(f'"{key}"' for key in parameters.keys())
+        values = ', '.join(f':{key}' for key in parameters.keys())
+        updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id')
+        query = f"""
+            INSERT INTO feedbacks ({columns})
+            VALUES ({values})
+            ON CONFLICT (id) DO UPDATE
+            SET {updates};
+        """
+        await self.execute_sql(query=query, parameters=parameters)
+        return feedback.id
+
+    ###### Elements ######
+    @queue_until_user_message()
+    async def create_element(self, element: 'Element'):
+        logger.info(f"Postgres: create_element, element_id = {element.id}")
+        async with self.thread_update_lock:
+            pass
+        async with self.step_update_lock:
+            pass
+        if not self.blob_storage_client:
+            raise ValueError("No blob_storage_client is configured")
+        if not element.for_id:
+            return
+        element_dict = element.to_dict()
+        content: Optional[Union[bytes, str]] = None
+        object_key: Optional[str] = None
+
+        if not element.url:
+            if element.path:
+                async with aiofiles.open(element.path, "rb") as f:
+                    content = await f.read()
+            elif element.content:
+                content = element.content
+            else:
+                raise ValueError("Either path or content must be provided")
+
+            context_user = context.session.user
+            if not context_user or not getattr(context_user, 'id', None):
+                raise ValueError("No valid user in context")
+
+            user_folder = getattr(context_user, 'id', 'unknown')
+            object_key = f"{user_folder}/{element.id}" + (f"/{element.name}" if element.name else "")
+
+            if self.blob_storage_provider == 'Azure':
+                file_client = self.blob_storage_client.get_file_client(object_key)
+                content_type = ContentSettings(content_type=element.mime)
+                file_client.upload_data(content, overwrite=True, content_settings=content_type)
+                element.url = file_client.url + (self.blob_access_token or '')
+
+        element_dict['url'] = element.url
+        element_dict['objectKey'] = object_key
+        element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}
+
+        columns = ', '.join(f'"{column}"' for column in element_dict_cleaned.keys())
+        placeholders = ', '.join(f':{column}' for column in element_dict_cleaned.keys())
+        query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})"
+        await self.execute_sql(query=query, parameters=element_dict_cleaned)
+
+    @queue_until_user_message()
+    async def delete_element(self, element_id: str):
+        logger.info(f"Postgres: delete_element, element_id={element_id}")
+        query = """DELETE FROM elements WHERE "id" = :id"""
+        parameters = {"id": element_id}
+        await self.execute_sql(query=query, parameters=parameters)
+
+    async def delete_user_session(self, id: str) -> bool:
+        return False  # This data layer does not persist user sessions, so there is nothing to delete
+
+    #### NEW OPTIMIZATION ####
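+    # Single joined query: one round-trip for threads, steps, feedback and elements,
+    # at the cost of repeated thread/step columns on every joined row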
+    async def get_all_user_threads(self, user_identifier: str) -> Optional[List[ThreadDict]]:
+        """Fetch all user threads  for fast retrieval"""
+        logger.info(f"Postgres: get_all_user_threads")
+        parameters = {"identifier": user_identifier}
+        sql_query = """
+        SELECT
+            t."id" AS thread_id,
+            t."createdAt" AS thread_createdat,
+            t."name" AS thread_name,
+            t."tags" AS thread_tags,
+            t."metadata" AS thread_metadata,
+            u."id" AS user_id,
+            u."identifier" AS user_identifier,
+            u."metadata" AS user_metadata,
+            s."id" AS step_id,
+            s."name" AS step_name,
+            s."type" AS step_type,
+            s."threadId" AS step_threadid,
+            s."parentId" AS step_parentid,
+            s."disableFeedback" AS step_disablefeedback,
+            s."streaming" AS step_streaming,
+            s."waitForAnswer" AS step_waitforanswer,
+            s."isError" AS step_iserror,
+            s."metadata" AS step_metadata,
+            s."input" AS step_input,
+            s."output" AS step_output,
+            s."createdAt" AS step_createdat,
+            s."start" AS step_start,
+            s."end" AS step_end,
+            s."generation" AS step_generation,
+            s."showInput" AS step_showinput,
+            s."language" AS step_language,
+            s."indent" AS step_indent,
+            f."value" AS feedback_value,
+            f."strategy" AS feedback_strategy,
+            f."comment" AS feedback_comment,
+            e."id" AS element_id,
+            e."threadId" as element_threadid,
+            e."type" AS element_type,
+            e."url" AS element_url,
+            e."chainlitKey" AS element_chainlitkey,
+            e."objectKey" as element_objectkey,
+            e."name" AS element_name,
+            e."display" AS element_display,
+            e."size" AS element_size,
+            e."language" AS element_language,
+            e."page" AS element_page,
+            e."forId" AS element_forid,
+            e."mime" AS element_mime
+        FROM
+            threads t
+            LEFT JOIN users u ON t."user_id" = u."id"
+            LEFT JOIN steps s ON t."id" = s."threadId"
+            LEFT JOIN feedbacks f ON s."id" = f."forId"
+            LEFT JOIN elements e ON t."id" = e."threadId"
+        WHERE u."identifier" = :identifier
+        ORDER BY t."createdAt" DESC, s."start" ASC 
+        """
+        results = await self.execute_sql(query=sql_query, parameters=parameters)
+        threads: List[ThreadDict] = []
+        if not isinstance(results, list):
+            raise ValueError("Expected a list of results")
+        for row in results:
+            thread_id = row['thread_id']
+            thread = next((t for t in threads if t['id'] == thread_id), None)
+            if not thread:
+                thread = ThreadDict(
+                    id=thread_id,
+                    createdAt=row['thread_createdat'],
+                    name=row['thread_name'],
+                    user= UserDict(
+                        id=row['user_id'],
+                        identifier=row['user_identifier'],
+                        metadata=row['user_metadata']
+                    ) if row['user_id'] else None,
+                    tags=row['thread_tags'],
+                    metadata=row['thread_metadata'],
+                    steps=[],
+                    elements=[]
+                )
+                threads.append(thread)
+            if row['step_id'] and not any(s['id'] == row['step_id'] for s in thread['steps']):  # skip duplicates from the element join fan-out
+                step = StepDict(
+                    id=row['step_id'],
+                    name=row['step_name'],
+                    type=row['step_type'],
+                    threadId=row['step_threadid'],
+                    parentId=row['step_parentid'],
+                    disableFeedback=row['step_disablefeedback'],
+                    streaming=row['step_streaming'],
+                    waitForAnswer=row['step_waitforanswer'],
+                    isError=row['step_iserror'],
+                    metadata=row['step_metadata'],
+                    input=row['step_input'] if row['step_showinput'] else None,
+                    output=row['step_output'],
+                    createdAt=row['step_createdat'],
+                    start=row['step_start'],
+                    end=row['step_end'],
+                    generation=row['step_generation'],
+                    showInput=row['step_showinput'],
+                    language=row['step_language'],
+                    indent=row['step_indent'],
+                    feedback= FeedbackDict(
+                        value=row['feedback_value'],
+                        strategy=row['feedback_strategy'],
+                        comment=row['feedback_comment']
+                     ) if row['feedback_value'] is not None else None
+                )
+                thread['steps'].append(step)
+            if row['element_id'] and not any(e['id'] == row['element_id'] for e in (thread['elements'] or [])):  # skip duplicates from the step join fan-out
+                element: Dict[str, Any] = {
+                    "id":row['element_id'],
+                    "threadId":row['element_threadid'],
+                    "type":row['element_type'],
+                    "chainlitKey":row['element_chainlitkey'],
+                    "url":row['element_url'],
+                    "objectKey":row['element_objectkey'],
+                    "name":row['element_name'],
+                    "display":row['element_display'],
+                    "size":row['element_size'],
+                    "language":row['element_language'],
+                    "page":row['element_page'],
+                    "forId":row['element_forid'],
+                    "mime":row['element_mime']
+                }
+                if thread['elements'] is None:
+                    thread['elements'] = []
+                thread['elements'].append(element)   # type: ignore
+        return threads

From bec6e5dbc5acd4afb9b869ae60d52ea993b81a62 Mon Sep 17 00:00:00 2001
From: Josh Hayes <35790761+hayescode@users.noreply.github.com>
Date: Wed, 20 Mar 2024 22:59:52 -0500
Subject: [PATCH 2/5] Update pyproject.toml

---
 backend/pyproject.toml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index a0a6ba0c36..b1c84d3122 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -46,6 +46,8 @@ lazify = "^0.4.0"
 packaging = "^23.1"
 python-multipart = "^0.0.9"
 pyjwt = "^2.8.0"
+asyncpg = "^0.29.0"
+SQLAlchemy = "^2.0.28"
 
 [tool.poetry.group.tests]
 optional = true

From 354ddc2b0ea622c55b1d22db496323ada6ea1564 Mon Sep 17 00:00:00 2001
From: Josh Hayes <35790761+hayescode@users.noreply.github.com>
Date: Wed, 20 Mar 2024 23:01:23 -0500
Subject: [PATCH 3/5] Create sql_alchemy.py

---
 cypress/e2e/custom_data_layer/sql_alchemy.py | 28 ++++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 cypress/e2e/custom_data_layer/sql_alchemy.py

diff --git a/cypress/e2e/custom_data_layer/sql_alchemy.py b/cypress/e2e/custom_data_layer/sql_alchemy.py
new file mode 100644
index 0000000000..8a059d8aa0
--- /dev/null
+++ b/cypress/e2e/custom_data_layer/sql_alchemy.py
@@ -0,0 +1,28 @@
+from typing import List, Optional
+
+import chainlit.data as cl_data
+from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
+from literalai.helper import utc_now
+
+import chainlit as cl
+
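+# Example conninfo (asyncpg driver URL; adjust to your database): "postgresql+asyncpg://user:password@localhost:5432/chainlit"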
+cl_data._data_layer = SQLAlchemyDataLayer(conninfo="<your conninfo>")
+
+
+@cl.on_chat_start
+async def main():
+    await cl.Message("Hello, send me a message!", disable_feedback=True).send()
+
+
+@cl.on_message
+async def handle_message():
+    await cl.sleep(2)
+    await cl.Message("Ok!").send()
+
+
+@cl.password_auth_callback
+def auth_callback(username: str, password: str) -> Optional[cl.User]:
+    if (username, password) == ("admin", "admin"):
+        return cl.User(identifier="admin")
+    else:
+        return None

From 610a0f4e19ff88bc99e6d37c447fbd8404be701f Mon Sep 17 00:00:00 2001
From: Josh Hayes <35790761+hayescode@users.noreply.github.com>
Date: Mon, 25 Mar 2024 09:55:47 -0500
Subject: [PATCH 4/5] Create sql_alchemy.py

---
 backend/chainlit/data/sql_alchemy.py | 478 +++++++++++++++++++++++++++
 1 file changed, 478 insertions(+)
 create mode 100644 backend/chainlit/data/sql_alchemy.py

diff --git a/backend/chainlit/data/sql_alchemy.py b/backend/chainlit/data/sql_alchemy.py
new file mode 100644
index 0000000000..d7f1c037eb
--- /dev/null
+++ b/backend/chainlit/data/sql_alchemy.py
@@ -0,0 +1,478 @@
+import uuid
+import ssl
+from datetime import datetime, timezone
+import json
+from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+import aiofiles
+import asyncio
+from dataclasses import asdict
+from sqlalchemy import text
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.orm import sessionmaker
+from azure.storage.filedatalake import FileSystemClient, ContentSettings  # type: ignore
+from chainlit.context import context
+from chainlit.logger import logger
+from chainlit.data import BaseDataLayer, queue_until_user_message
+from chainlit.user import User, PersistedUser, UserDict
+from chainlit.types import Feedback, FeedbackDict, Pagination, ThreadDict, ThreadFilter, PageInfo, PaginatedResponse
+from chainlit.step import StepDict
+from chainlit.element import ElementDict
+
+if TYPE_CHECKING:
+    from chainlit.element import Element, ElementDict
+    from chainlit.step import StepDict
+
+class SQLAlchemyDataLayer(BaseDataLayer):
+    def __init__(self, conninfo, ssl_require=False, user_thread_limit=100):
+        self._conninfo = conninfo
+        self.user_thread_limit = user_thread_limit
+        ssl_args = {}
+        if ssl_require:
+            # Create an SSL context to require an SSL connection
+            ssl_context = ssl.create_default_context()
+            ssl_context.check_hostname = False
+            ssl_context.verify_mode = ssl.CERT_NONE
+            ssl_args['ssl'] = ssl_context
+        self.engine = create_async_engine(self._conninfo, connect_args=ssl_args)
+        self.async_session = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
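+        # These locks let create_step/create_element wait until any in-flight thread or step upsert has committed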
+        self.thread_update_lock = asyncio.Lock()
+        self.step_update_lock = asyncio.Lock()
+
+    async def add_blob_storage_client(self, blob_storage_client, access_token: Optional[str]) -> None:
+        if isinstance(blob_storage_client, FileSystemClient):
+            self.blob_storage_client = blob_storage_client
+            self.blob_access_token = access_token
+            self.blob_storage_provider = 'Azure'
+            logger.info("Azure Data Lake Storage client initialized")
+        # Add other checks here for AWS/Google/etc.
+        else:
+            raise ValueError("The provided blob_storage is not recognized")
+
+    ###### SQL Helpers ######
+    async def execute_sql(self, query: str, parameters: dict) -> Union[List[Dict[str, Any]], int, None]:
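+        """Execute a parameterized statement; returns mapped rows for queries, the rowcount for DML, or None on error."""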
+        parameterized_query = text(query)
+        async with self.async_session() as session:
+            try:
+                await session.begin()
+                result = await session.execute(parameterized_query, parameters)
+                await session.commit()
+                if result.returns_rows:
+                    json_result = [dict(row._mapping) for row in result.fetchall()]
+                    clean_json_result = self.clean_result(json_result)
+                    return clean_json_result
+                else:
+                    return result.rowcount
+            except SQLAlchemyError as e:
+                await session.rollback()
+                logger.warn(f"An error occurred: {e}")
+                return None
+            except Exception as e:
+                await session.rollback()
+                logger.warn(f"An unexpected error occurred: {e}")
+                return None
+
+    async def get_current_timestamp(self) -> str:
+        return datetime.now(timezone.utc).astimezone().isoformat()
+    
+    def clean_result(self, obj):
+        """Recursively convert UUID values to str so query results are JSON-serializable"""
+        if isinstance(obj, dict):
+            for k, v in obj.items():
+                obj[k] = self.clean_result(v)
+        elif isinstance(obj, list):
+            return [self.clean_result(item) for item in obj]
+        elif isinstance(obj, uuid.UUID):
+            return str(obj)
+        return obj
+
+    ###### User ######
+    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
+        logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
+        query = "SELECT * FROM users WHERE identifier = :identifier"
+        parameters = {"identifier": identifier}
+        result = await self.execute_sql(query=query, parameters=parameters)
+        if result and isinstance(result, list):
+            user_data = result[0]
+            return PersistedUser(**user_data)
+        return None
+        
+    async def create_user(self, user: User) -> Optional[PersistedUser]:
+        logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
+        existing_user: Optional['PersistedUser'] = await self.get_user(user.identifier)
+        user_dict: Dict[str, Any]  = {
+            "identifier": str(user.identifier),
+            "metadata": json.dumps(user.metadata) or {}
+            }
+        if not existing_user: # create the user
+            logger.info("SQLAlchemy: create_user, creating the user")
+            user_dict['id'] = str(uuid.uuid4())
+            user_dict['createdAt'] = await self.get_current_timestamp()
+            query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)"""
+            await self.execute_sql(query=query, parameters=user_dict)
+        else: # update the user
+            query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
+            await self.execute_sql(query=query, parameters=user_dict) # We want to update the metadata
+        return await self.get_user(user.identifier)
+        
+    ###### Threads ######
+    async def get_thread_author(self, thread_id: str) -> str:
+        logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
+        query = """SELECT u.* FROM threads t JOIN users u ON t."user_id" = u."id" WHERE t."id" = :id"""
+        parameters = {"id": thread_id}
+        result = await self.execute_sql(query=query, parameters=parameters)
+        if result and isinstance(result, list) and result[0]:
+            author_identifier = result[0].get('identifier')
+            if author_identifier is not None:
+                return author_identifier
+        raise ValueError(f"Author not found for thread_id {thread_id}")
+    
+    async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
+        logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
+        user_identifier = await self.get_thread_author(thread_id=thread_id)
+        if user_identifier is None:
+            raise ValueError("User identifier not found for the given thread_id")
+        user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(user_identifier=user_identifier)
+        if not user_threads:
+            return None
+        for thread in user_threads:
+            if thread['id'] == thread_id:
+                return thread
+        return None
+
+    async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None):
+        logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
+        async with self.thread_update_lock:  # Acquire the lock before updating the thread
+            data = {
+                "id": thread_id,
+                "createdAt": await self.get_current_timestamp() if metadata is None else None,
+                "name": name if name is not None else (metadata.get('name') if metadata and 'name' in metadata else None),
+                "user_id": user_id,
+                "tags": tags,
+                "metadata": json.dumps(metadata) if metadata else None,
+            }
+            parameters = {key: value for key, value in data.items() if value is not None} # Remove keys with None values
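+            # Build a partial upsert from only the provided fields; EXCLUDED refers to the row proposed for insertion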
+            columns = ', '.join(f'"{key}"' for key in parameters.keys())
+            values = ', '.join(f':{key}' for key in parameters.keys())
+            updates = ', '.join(f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != 'id')
+            query = f"""
+                INSERT INTO threads ({columns})
+                VALUES ({values})
+                ON CONFLICT ("id") DO UPDATE
+                SET {updates};
+            """
+            await self.execute_sql(query=query, parameters=parameters)
+
+    async def delete_thread(self, thread_id: str):
+        logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
+        query = """DELETE FROM threads WHERE "id" = :id"""
+        parameters = {"id": thread_id}
+        await self.execute_sql(query=query, parameters=parameters)
+    
+    async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse[ThreadDict]:
+        logger.info(f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}")
+        if not filters.userIdentifier:
+            raise ValueError("userIdentifier is required")
+        all_user_threads: List[ThreadDict] = await self.get_all_user_threads(user_identifier=filters.userIdentifier) or []
+
+        search_keyword = filters.search.lower() if filters.search else None
+        feedback_value = int(filters.feedback) if filters.feedback is not None else None
+
+        filtered_threads = []
+        for thread in all_user_threads:
+            keyword_match = True
+            feedback_match = True  # Initialize feedback_match to True
+            if search_keyword or feedback_value is not None:
+                if search_keyword:
+                    keyword_match = any(search_keyword in step['output'].lower() for step in thread['steps'] if 'output' in step)
+                if feedback_value is not None:
+                    feedback_match = False  # Assume no match until found
+                    for step in thread['steps']:
+                        feedback = step.get('feedback')
+                        if feedback and feedback.get('value') == feedback_value:
+                            feedback_match = True
+                            break
+            if keyword_match and feedback_match:
+                filtered_threads.append(thread)
+
+        # Find the start index using pagination.cursor (the id of the last thread on the previous page)
+        start = 0
+        if pagination.cursor:
+            for i, thread in enumerate(filtered_threads):
+                if thread['id'] == pagination.cursor:
+                    start = i + 1
+                    break
+        end = start + pagination.first
+        paginated_threads = filtered_threads[start:end] or []
+
+        has_next_page = len(filtered_threads) > end
+        end_cursor = paginated_threads[-1]['id'] if paginated_threads else None
+
+        return PaginatedResponse(
+            pageInfo=PageInfo(hasNextPage=has_next_page, endCursor=end_cursor),
+            data=paginated_threads
+        )
+    
+    ###### Steps ######
+    @queue_until_user_message()
+    async def create_step(self, step_dict: 'StepDict'):
+        logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
+        logger.info(f"SQLAlchemy: name={step_dict.get('name')}, input={step_dict.get('input')}, output={step_dict.get('name')}")
+        async with self.thread_update_lock:  # Wait for update_thread
+            pass
+        async with self.step_update_lock:  # Acquire the lock before updating the step
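+            # showInput may arrive as a bool or a string; store its lowercase string form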
+            step_dict['showInput'] = str(step_dict.get('showInput', '')).lower() if 'showInput' in step_dict else None
+            # Drop None values and empty dicts so they do not overwrite existing columns
+            parameters = {key: value for key, value in step_dict.items() if value is not None and not (isinstance(value, dict) and not value)}
+
+            columns = ', '.join(f'"{key}"' for key in parameters.keys())
+            values = ', '.join(f':{key}' for key in parameters.keys())
+            updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id')
+            query = f"""
+                INSERT INTO steps ({columns})
+                VALUES ({values})
+                ON CONFLICT (id) DO UPDATE
+                SET {updates};
+            """
+            await self.execute_sql(query=query, parameters=parameters)
+    
+    @queue_until_user_message()
+    async def update_step(self, step_dict: 'StepDict'):
+        logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
+        await self.create_step(step_dict)
+
+    @queue_until_user_message()
+    async def delete_step(self, step_id: str):
+        logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
+        query = """DELETE FROM steps WHERE "id" = :id"""
+        parameters = {"id": step_id}
+        await self.execute_sql(query=query, parameters=parameters)
+    
+    ###### Feedback ######
+    async def upsert_feedback(self, feedback: Feedback) -> str:
+        logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
+        feedback.id = feedback.id or str(uuid.uuid4())
+        feedback_dict = asdict(feedback)
+        parameters = {key: value for key, value in feedback_dict.items() if value is not None}  # Remove keys with None values
+
+        columns = ', '.join(f'"{key}"' for key in parameters.keys())
+        values = ', '.join(f':{key}' for key in parameters.keys())
+        updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id')
+        query = f"""
+            INSERT INTO feedbacks ({columns})
+            VALUES ({values})
+            ON CONFLICT (id) DO UPDATE
+            SET {updates};
+        """
+        await self.execute_sql(query=query, parameters=parameters)
+        return feedback.id
+
+    ###### Elements ######
+    @queue_until_user_message()
+    async def create_element(self, element: 'Element'):
+        logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
+        async with self.thread_update_lock:
+            pass
+        async with self.step_update_lock:
+            pass
+        if not self.blob_storage_client:
+            raise ValueError("No blob_storage_client is configured")
+        if not element.for_id:
+            return
+        element_dict = element.to_dict()
+        content: Optional[Union[bytes, str]] = None
+        object_key: Optional[str] = None
+
+        if not element.url:
+            if element.path:
+                async with aiofiles.open(element.path, "rb") as f:
+                    content = await f.read()
+            elif element.content:
+                content = element.content
+            else:
+                raise ValueError("Either path or content must be provided")
+
+            context_user = context.session.user
+            if not context_user or not getattr(context_user, 'id', None):
+                raise ValueError("No valid user in context")
+
+            user_folder = getattr(context_user, 'id', 'unknown')
+            object_key = f"{user_folder}/{element.id}" + (f"/{element.name}" if element.name else "")
+
+            if self.blob_storage_provider == 'Azure':
+                file_client = self.blob_storage_client.get_file_client(object_key)
+                content_type = ContentSettings(content_type=element.mime)
+                file_client.upload_data(content, overwrite=True, content_settings=content_type)
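+                # Append the access token (e.g. an Azure SAS query string) so the stored URL is directly fetchable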
+                element.url = file_client.url + (self.blob_access_token or '')
+
+        element_dict['url'] = element.url
+        element_dict['objectKey'] = object_key
+        element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}
+
+        columns = ', '.join(f'"{column}"' for column in element_dict_cleaned.keys())
+        placeholders = ', '.join(f':{column}' for column in element_dict_cleaned.keys())
+        query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})"
+        await self.execute_sql(query=query, parameters=element_dict_cleaned)
+
+    @queue_until_user_message()
+    async def delete_element(self, element_id: str):
+        logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
+        query = """DELETE FROM elements WHERE "id" = :id"""
+        parameters = {"id": element_id}
+        await self.execute_sql(query=query, parameters=parameters)
+
+    async def delete_user_session(self, id: str) -> bool:
+        return False  # This data layer does not persist user sessions, so there is nothing to delete
+
+    async def get_all_user_threads(self, user_identifier: str) -> Optional[List[ThreadDict]]:
+        """Fetch all user threads for fast retrieval, up to self.user_thread_limit"""
+        logger.info(f"SQLAlchemy: get_all_user_threads")
+        user_threads_query = """
+            SELECT
+                t."id" AS thread_id,
+                t."createdAt" AS thread_createdat,
+                t."name" AS thread_name,
+                t."tags" AS thread_tags,
+                t."metadata" AS thread_metadata,
+                u."id" AS user_id,
+                u."identifier" AS user_identifier,
+                u."metadata" AS user_metadata
+            FROM threads t JOIN users u ON t."user_id" = u."id"
+            WHERE u."identifier" = :identifier
+            ORDER BY t."createdAt" DESC
+            LIMIT :limit
+        """
+        user_threads = await self.execute_sql(query=user_threads_query, parameters={"identifier": user_identifier, "limit": self.user_thread_limit})
+        if not isinstance(user_threads, list):
+            return None
+        if not user_threads:
+            return []
+        # The ids come from the database itself (UUIDs), so inlining them into the IN clause is safe
+        thread_ids = "('" + "','".join(map(str, [thread['thread_id'] for thread in user_threads])) + "')"
+
+        steps_feedbacks_query = f"""
+            SELECT
+                s."id" AS step_id,
+                s."name" AS step_name,
+                s."type" AS step_type,
+                s."threadId" AS step_threadid,
+                s."parentId" AS step_parentid,
+                s."disableFeedback" AS step_disablefeedback,
+                s."streaming" AS step_streaming,
+                s."waitForAnswer" AS step_waitforanswer,
+                s."isError" AS step_iserror,
+                s."metadata" AS step_metadata,
+                s."input" AS step_input,
+                s."output" AS step_output,
+                s."createdAt" AS step_createdat,
+                s."start" AS step_start,
+                s."end" AS step_end,
+                s."generation" AS step_generation,
+                s."showInput" AS step_showinput,
+                s."language" AS step_language,
+                s."indent" AS step_indent,
+                f."value" AS feedback_value,
+                f."strategy" AS feedback_strategy,
+                f."comment" AS feedback_comment
+            FROM steps s LEFT JOIN feedbacks f ON s."id" = f."forId"
+            WHERE s."threadId" IN {thread_ids}
+            ORDER BY s."createdAt" ASC
+        """
+        steps_feedbacks = await self.execute_sql(query=steps_feedbacks_query, parameters={})
+        
+        elements_query = f"""
+            SELECT
+                e."id" AS element_id,
+                e."threadId" as element_threadid,
+                e."type" AS element_type,
+                e."url" AS element_url,
+                e."chainlitKey" AS element_chainlitkey,
+                e."objectKey" as element_objectkey,
+                e."name" AS element_name,
+                e."display" AS element_display,
+                e."size" AS element_size,
+                e."language" AS element_language,
+                e."page" AS element_page,
+                e."forId" AS element_forid,
+                e."mime" AS element_mime
+            FROM elements e
+            WHERE e."threadId" IN {thread_ids}
+        """
+        elements = await self.execute_sql(query=elements_query, parameters={})
+        
+        # Initialize a dictionary to hold ThreadDict objects keyed by thread_id
+        thread_dicts = {}
+        # Process threads_users to create initial ThreadDict objects
+        for thread in user_threads:
+            thread_id = thread['thread_id']
+            thread_dicts[thread_id] = ThreadDict(
+                id=thread_id,
+                createdAt=thread['thread_createdat'],
+                name=thread['thread_name'],
+                user=UserDict(
+                    id=thread['user_id'],
+                    identifier=thread['user_identifier'],
+                    metadata=thread['user_metadata']
+                ),
+                tags=thread['thread_tags'],
+                metadata=thread['thread_metadata'],
+                steps=[],
+                elements=[]
+            )
+        # Process steps_feedbacks to populate the steps in the corresponding ThreadDict
+        if isinstance(steps_feedbacks, list):
+            for step_feedback in steps_feedbacks:
+                thread_id = step_feedback['step_threadid']
+                feedback = None
+                if step_feedback['feedback_value'] is not None:
+                    feedback = FeedbackDict(
+                        value=step_feedback['feedback_value'],
+                        strategy=step_feedback['feedback_strategy'],
+                        comment=step_feedback.get('feedback_comment')
+                    )
+                step_dict = StepDict(
+                    id=step_feedback['step_id'],
+                    name=step_feedback['step_name'],
+                    type=step_feedback['step_type'],
+                    threadId=thread_id,
+                    parentId=step_feedback.get('step_parentid'),
+                    disableFeedback=step_feedback.get('step_disablefeedback', False),
+                    streaming=step_feedback.get('step_streaming', False),
+                    waitForAnswer=step_feedback.get('step_waitforanswer'),
+                    isError=step_feedback.get('step_iserror'),
+                    metadata=step_feedback.get('step_metadata', {}),
+                    input=step_feedback.get('step_input', '') if step_feedback['step_showinput'] else None,
+                    output=step_feedback.get('step_output', ''),
+                    createdAt=step_feedback.get('step_createdat'),
+                    start=step_feedback.get('step_start'),
+                    end=step_feedback.get('step_end'),
+                    generation=step_feedback.get('step_generation'),
+                    showInput=step_feedback.get('step_showinput'),
+                    language=step_feedback.get('step_language'),
+                    indent=step_feedback.get('step_indent'),
+                    feedback=feedback
+                )
+                # Append the step to the steps list of the corresponding ThreadDict
+                thread_dicts[thread_id]['steps'].append(step_dict)
+
+        if isinstance(elements, list):
+            for element in elements:
+                thread_id = element['element_threadid']
+                element_dict = ElementDict(
+                    id=element['element_id'],
+                    threadId=thread_id,
+                    type=element['element_type'],
+                    chainlitKey=element.get('element_chainlitkey'),
+                    url=element.get('element_url'),
+                    objectKey=element.get('element_objectkey'),
+                    name=element['element_name'],
+                    display=element['element_display'],
+                    size=element.get('element_size'),
+                    language=element.get('element_language'),
+                    page=element.get('element_page'),
+                    forId=element.get('element_forid'),
+                    mime=element.get('element_mime'),
+                )
+                thread_dicts[thread_id]['elements'].append(element_dict)   # type: ignore
+
+        return list(thread_dicts.values()) 

From e11d00dbeb7143e4a535409e8957a7a38dfa481b Mon Sep 17 00:00:00 2001
From: Josh Hayes <35790761+hayescode@users.noreply.github.com>
Date: Mon, 25 Mar 2024 10:03:29 -0500
Subject: [PATCH 5/5] Delete backend/chainlit/sql_alchemy.py

---
 backend/chainlit/sql_alchemy.py | 449 --------------------------------
 1 file changed, 449 deletions(-)
 delete mode 100644 backend/chainlit/sql_alchemy.py

diff --git a/backend/chainlit/sql_alchemy.py b/backend/chainlit/sql_alchemy.py
deleted file mode 100644
index fac8e1e965..0000000000
--- a/backend/chainlit/sql_alchemy.py
+++ /dev/null
@@ -1,449 +0,0 @@
-import uuid
-import ssl
-from datetime import datetime, timezone
-import json
-from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
-import aiofiles
-import asyncio
-from dataclasses import asdict
-from sqlalchemy import text
-from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
-from sqlalchemy.orm import sessionmaker
-from azure.storage.filedatalake import FileSystemClient, ContentSettings  # type: ignore
-from chainlit.context import context
-from chainlit.logger import logger
-from chainlit.data import BaseDataLayer, queue_until_user_message
-from chainlit.user import User, PersistedUser, UserDict
-from chainlit.types import Feedback, FeedbackDict, Pagination, ThreadDict, ThreadFilter
-from literalai import PageInfo, PaginatedResponse
-from chainlit.step import StepDict
-
-if TYPE_CHECKING:
-    from chainlit.element import Element
-    from chainlit.step import StepDict
-
-class SQLAlchemyDataLayer(BaseDataLayer):
-    def __init__(self, conninfo, ssl_require=False):
-        self._conninfo = conninfo
-        ssl_args = {}
-        if ssl_require:
-            # Create an SSL context to require an SSL connection
-            ssl_context = ssl.create_default_context()
-            ssl_context.check_hostname = False
-            ssl_context.verify_mode = ssl.CERT_NONE
-            ssl_args['ssl'] = ssl_context
-        self.engine = create_async_engine(self._conninfo, connect_args=ssl_args)
-        self.async_session = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
-        self.thread_update_lock = asyncio.Lock()
-        self.step_update_lock = asyncio.Lock()
-
-    async def add_blob_storage_client(self, blob_storage_client, access_token: Optional[str]) -> None:
-        if isinstance(blob_storage_client, FileSystemClient):
-            self.blob_storage_client = blob_storage_client
-            self.blob_access_token = access_token
-            self.blob_storage_provider = 'Azure'
-            logger.info("Azure Data Lake Storage client initialized")
-        # Add other checks here for AWS/Google/etc.
-        else:
-            raise ValueError("The provided blob_storage is not recognized")
-
-    ###### SQL Helpers ######
-    async def execute_sql(self, query: str, parameters: dict) -> Union[List[Dict[str, Any]], int, None]:
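-        """Execute a parameterized statement; returns mapped rows for queries, the rowcount for DML, or None on error."""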
-        parameterized_query = text(query)
-        async with self.async_session() as session:
-            try:
-                await session.begin()
-                result = await session.execute(parameterized_query, parameters)
-                await session.commit()
-                if result.returns_rows:
-                    json_result = [dict(row._mapping) for row in result.fetchall()]
-                    clean_json_result = self.clean_result(json_result)
-                    return clean_json_result
-                else:
-                    return result.rowcount
-            except SQLAlchemyError as e:
-                await session.rollback()
-                logger.warn(f"An error occurred: {e}")
-                return None
-            except Exception as e:
-                await session.rollback()
-                logger.warn(f"An unexpected error occurred: {e}")
-                return None
-
-    async def get_current_timestamp(self) -> str:
-        return datetime.now(timezone.utc).astimezone().isoformat()
-    
-    def clean_result(self, obj):
-        """Recursively convert UUID values to str so query results are JSON-serializable"""
-        if isinstance(obj, dict):
-            for k, v in obj.items():
-                obj[k] = self.clean_result(v)
-        elif isinstance(obj, list):
-            return [self.clean_result(item) for item in obj]
-        elif isinstance(obj, uuid.UUID):
-            return str(obj)
-        return obj
-
-    ###### User ######
-    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
-        logger.info(f"Postgres: get_user, identifier={identifier}")
-        query = "SELECT * FROM users WHERE identifier = :identifier"
-        parameters = {"identifier": identifier}
-        result = await self.execute_sql(query=query, parameters=parameters)
-        if result and isinstance(result, list):
-            user_data = result[0]
-            return PersistedUser(**user_data)
-        return None
-        
-    async def create_user(self, user: User) -> Optional[PersistedUser]:
-        logger.info(f"Postgres: create_user, user_identifier={user.identifier}")
-        existing_user: Optional['PersistedUser'] = await self.get_user(user.identifier)
-        user_dict: Dict[str, Any]  = {
-            "identifier": str(user.identifier),
-            "metadata": json.dumps(user.metadata) or {}
-            }
-        if not existing_user: # create the user
-            logger.info("Postgres: create_user, creating the user")
-            user_dict['id'] = str(uuid.uuid4())
-            user_dict['createdAt'] = await self.get_current_timestamp()
-            query = "INSERT INTO users (id, identifier, createdAt, metadata) VALUES (:id, :identifier, :createdAt, :metadata)"
-            await self.execute_sql(query=query, parameters=user_dict)
-        else: # update the user
-            query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
-            await self.execute_sql(query=query, parameters=user_dict) # We want to update the metadata
-        return await self.get_user(user.identifier)
-        
-    ###### Threads ######
-    async def get_thread_author(self, thread_id: str) -> str:
-        logger.info(f"Postgres: get_thread_author, thread_id={thread_id}")
-        query = """SELECT u.* FROM threads t JOIN users u ON t."user_id" = u."id" WHERE t."id" = :id"""
-        parameters = {"id": thread_id}
-        result = await self.execute_sql(query=query, parameters=parameters)
-        if result and isinstance(result, list) and result[0]:
-            author_identifier = result[0].get('identifier')
-            if author_identifier is not None:
-                return author_identifier
-        raise ValueError(f"Author not found for thread_id {thread_id}")
-    
-    async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
-        logger.info(f"Postgres: get_thread, thread_id={thread_id}")
-        user_identifier = await self.get_thread_author(thread_id=thread_id)
-        if user_identifier is None:
-            raise ValueError("User identifier not found for the given thread_id")
-        user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(user_identifier=user_identifier)
-        if not user_threads:
-            return None
-        for thread in user_threads:
-            if thread['id'] == thread_id:
-                return thread
-        return None
-
-    async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None):
-        logger.info(f"Postgres: update_thread, thread_id={thread_id}")
-        async with self.thread_update_lock:  # Acquire the lock before updating the thread
-            data = {
-                "id": thread_id,
-                "createdAt": await self.get_current_timestamp() if metadata is None else None,
-                "name": name if name is not None else (metadata.get('name') if metadata and 'name' in metadata else None),
-                "user_id": user_id,
-                "tags": tags,
-                "metadata": json.dumps(metadata) if metadata else None,
-            }
-            parameters = {key: value for key, value in data.items() if value is not None} # Remove keys with None values
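-            # Build the upsert dynamically from whichever fields were actually provided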
-            columns = ', '.join(f'"{key}"' for key in parameters.keys())
-            values = ', '.join(f':{key}' for key in parameters.keys())
-            updates = ', '.join(f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != 'id')
-            query = f"""
-                INSERT INTO threads ({columns})
-                VALUES ({values})
-                ON CONFLICT ("id") DO UPDATE
-                SET {updates};
-            """
-            await self.execute_sql(query=query, parameters=parameters)
-
-    async def delete_thread(self, thread_id: str):
-        logger.info(f"Postgres: delete_thread, thread_id={thread_id}")
-        query = """DELETE FROM threads WHERE "id" = :id"""
-        parameters = {"id": thread_id}
-        await self.execute_sql(query=query, parameters=parameters)
-
-    async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse[ThreadDict]:
-        logger.info(f"Postgres: list_threads, pagination={pagination}, filters={filters}")
-        if not filters.userIdentifier:
-            raise ValueError("userIdentifier is required")
-        all_user_threads: List[ThreadDict] = await self.get_all_user_threads(user_identifier=filters.userIdentifier) or []
-
-        search_keyword = filters.search.lower() if filters.search else None
-        feedback_value = int(filters.feedback) if filters.feedback is not None else None
-
-        filtered_threads = []
-        for thread in all_user_threads:
-            if search_keyword or feedback_value is not None:
-                keyword_match = any(search_keyword in step['output'].lower() for step in thread['steps'] if 'output' in step) if search_keyword else True
-                feedback_match = feedback_value is None  # defaults to True when no feedback filter is set
-                if feedback_value is not None:
-                    for step in thread['steps']:
-                        feedback = step.get('feedback')
-                        if feedback and feedback.get('value') == feedback_value:
-                            feedback_match = True
-                            break
-                if keyword_match and feedback_match:
-                    filtered_threads.append(thread)
-            else:
-                filtered_threads.append(thread)
-
-        # Apply pagination: the cursor is the id of the last thread on the previous page
-        start = 0
-        if pagination.cursor:
-            for i, thread in enumerate(filtered_threads):
-                if thread['id'] == pagination.cursor:
-                    start = i + 1
-                    break
-        end = start + pagination.first
-        paginated_threads = filtered_threads[start:end] or []
-
-        has_next_page = len(filtered_threads) > end
-        end_cursor = paginated_threads[-1]['id'] if paginated_threads else None
-
-        return PaginatedResponse(
-            pageInfo=PageInfo(hasNextPage=has_next_page, endCursor=end_cursor),
-            data=paginated_threads
-        )
-
-    ###### Steps ######
-    @queue_until_user_message()
-    async def create_step(self, step_dict: 'StepDict'):
-        logger.info(f"Postgres: create_step, step_id={step_dict.get('id')}")
-        async with self.thread_update_lock:  # Acquire and release the lock so any in-flight update_thread (which inserts the parent thread row) finishes first
-            pass
-        async with self.step_update_lock:  # Acquire the lock before updating the step
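-            # "showInput" may arrive as a bool or a string (e.g. "json"); normalize it to a lowercase string for storage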
-            if 'showInput' in step_dict:
-                step_dict['showInput'] = str(step_dict['showInput']).lower()
-            parameters = {key: value for key, value in step_dict.items() if value is not None}  # Remove keys with None values
-
-            columns = ', '.join(f'"{key}"' for key in parameters.keys())
-            values = ', '.join(f':{key}' for key in parameters.keys())
-            updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id')
-            query = f"""
-                INSERT INTO steps ({columns})
-                VALUES ({values})
-                ON CONFLICT (id) DO UPDATE
-                SET {updates};
-            """
-            await self.execute_sql(query=query, parameters=parameters)
-
-    @queue_until_user_message()
-    async def update_step(self, step_dict: 'StepDict'):
-        logger.info(f"Postgres: update_step, step_id={step_dict.get('id')}")
-        await self.create_step(step_dict)
-
-    @queue_until_user_message()
-    async def delete_step(self, step_id: str):
-        logger.info(f"Postgres: delete_step, step_id={step_id}")
-        query = """DELETE FROM steps WHERE "id" = :id"""
-        parameters = {"id": step_id}
-        await self.execute_sql(query=query, parameters=parameters)
-
-    ###### Feedback ######
-    async def upsert_feedback(self, feedback: Feedback) -> str:
-        logger.info(f"Postgres: upsert_feedback, feedback_id={feedback.id}")
-        feedback.id = feedback.id or str(uuid.uuid4())
-        feedback_dict = asdict(feedback)
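-        # Same dynamic-upsert pattern as create_step, keyed on the feedback id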
-        parameters = {key: value for key, value in feedback_dict.items() if value is not None}  # Remove keys with None values
-
-        columns = ', '.join(f'"{key}"' for key in parameters.keys())
-        values = ', '.join(f':{key}' for key in parameters.keys())
-        updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id')
-        query = f"""
-            INSERT INTO feedbacks ({columns})
-            VALUES ({values})
-            ON CONFLICT (id) DO UPDATE
-            SET {updates};
-        """
-        await self.execute_sql(query=query, parameters=parameters)
-        return feedback.id
-
-    ###### Elements ######
-    @queue_until_user_message()
-    async def create_element(self, element: 'Element'):
-        logger.info(f"Postgres: create_element, element_id = {element.id}")
-        async with self.thread_update_lock:
-            pass
-        async with self.step_update_lock:
-            pass
-        if not getattr(self, 'blob_storage_client', None):
-            raise ValueError("No blob_storage_client is configured")
-        if not element.for_id:
-            return
-        element_dict = element.to_dict()
-        content: Optional[Union[bytes, str]] = None
-        object_key: Optional[str] = None
-
-        if not element.url:
-            if element.path:
-                async with aiofiles.open(element.path, "rb") as f:
-                    content = await f.read()
-            elif element.content:
-                content = element.content
-            else:
-                raise ValueError("Either path or content must be provided")
-
-            context_user = context.session.user
-            if not context_user or not getattr(context_user, 'id', None):
-                raise ValueError("No valid user in context")
-
-            user_folder = context_user.id  # guaranteed to exist by the check above
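-            # Blobs are stored under "<user_id>/<element_id>[/<element_name>]"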
-            object_key = f"{user_folder}/{element.id}" + (f"/{element.name}" if element.name else "")
-
-            if self.blob_storage_provider == 'Azure':
-                file_client = self.blob_storage_client.get_file_client(object_key)
-                content_type = ContentSettings(content_type=element.mime)
-                file_client.upload_data(content, overwrite=True, content_settings=content_type)
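-                # Append the access token (if any) so the stored URL is directly downloadable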
-                element.url = file_client.url + (self.blob_access_token or '')
-
-        element_dict['url'] = element.url
-        element_dict['objectKey'] = object_key
-        element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}
-
-        columns = ', '.join(f'"{column}"' for column in element_dict_cleaned.keys())
-        placeholders = ', '.join(f':{column}' for column in element_dict_cleaned.keys())
-        query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})"
-        await self.execute_sql(query=query, parameters=element_dict_cleaned)
-
-    @queue_until_user_message()
-    async def delete_element(self, element_id: str):
-        logger.info(f"Postgres: delete_element, element_id={element_id}")
-        query = """DELETE FROM elements WHERE "id" = :id"""
-        parameters = {"id": element_id}
-        await self.execute_sql(query=query, parameters=parameters)
-
-    async def delete_user_session(self, id: str) -> bool:
-        return False  # User sessions are not persisted by this data layer, so there is nothing to delete
-
-    ###### Optimized thread retrieval ######
-    async def get_all_user_threads(self, user_identifier: str) -> Optional[List[ThreadDict]]:
-        """Fetch all user threads  for fast retrieval"""
-        logger.info(f"Postgres: get_all_user_threads")
-        parameters = {"identifier": user_identifier}
-        sql_query = """
-        SELECT
-            t."id" AS thread_id,
-            t."createdAt" AS thread_createdat,
-            t."name" AS thread_name,
-            t."tags" AS thread_tags,
-            t."metadata" AS thread_metadata,
-            u."id" AS user_id,
-            u."identifier" AS user_identifier,
-            u."metadata" AS user_metadata,
-            s."id" AS step_id,
-            s."name" AS step_name,
-            s."type" AS step_type,
-            s."threadId" AS step_threadid,
-            s."parentId" AS step_parentid,
-            s."disableFeedback" AS step_disablefeedback,
-            s."streaming" AS step_streaming,
-            s."waitForAnswer" AS step_waitforanswer,
-            s."isError" AS step_iserror,
-            s."metadata" AS step_metadata,
-            s."input" AS step_input,
-            s."output" AS step_output,
-            s."createdAt" AS step_createdat,
-            s."start" AS step_start,
-            s."end" AS step_end,
-            s."generation" AS step_generation,
-            s."showInput" AS step_showinput,
-            s."language" AS step_language,
-            s."indent" AS step_indent,
-            f."value" AS feedback_value,
-            f."strategy" AS feedback_strategy,
-            f."comment" AS feedback_comment,
-            e."id" AS element_id,
-            e."threadId" as element_threadid,
-            e."type" AS element_type,
-            e."url" AS element_url,
-            e."chainlitKey" AS element_chainlitkey,
-            e."objectKey" as element_objectkey,
-            e."name" AS element_name,
-            e."display" AS element_display,
-            e."size" AS element_size,
-            e."language" AS element_language,
-            e."page" AS element_page,
-            e."forId" AS element_forid,
-            e."mime" AS element_mime
-        FROM
-            threads t
-            LEFT JOIN users u ON t."user_id" = u."id"
-            LEFT JOIN steps s ON t."id" = s."threadId"
-            LEFT JOIN feedbacks f ON s."id" = f."forId"
-            LEFT JOIN elements e ON t."id" = e."threadId"
-        WHERE u."identifier" = :identifier
-        ORDER BY t."createdAt" DESC, s."start" ASC 
-        """
-        results = await self.execute_sql(query=sql_query, parameters=parameters)
-        threads: List[ThreadDict] = []
-        if not isinstance(results, list):
-            raise ValueError("Expected a list of results")
-        for row in results:
-            thread_id = row['thread_id']
-            thread = next((t for t in threads if t['id'] == thread_id), None)
-            if not thread:
-                thread = ThreadDict(
-                    id=thread_id,
-                    createdAt=row['thread_createdat'],
-                    name=row['thread_name'],
-                    user=UserDict(
-                        id=row['user_id'],
-                        identifier=row['user_identifier'],
-                        metadata=row['user_metadata']
-                    ) if row['user_id'] else None,
-                    tags=row['thread_tags'],
-                    metadata=row['thread_metadata'],
-                    steps=[],
-                    elements=[]
-                )
-                threads.append(thread)
-            if row['step_id']:
-                step = StepDict(
-                    id=row['step_id'],
-                    name=row['step_name'],
-                    type=row['step_type'],
-                    threadId=row['step_threadid'],
-                    parentId=row['step_parentid'],
-                    disableFeedback=row['step_disablefeedback'],
-                    streaming=row['step_streaming'],
-                    waitForAnswer=row['step_waitforanswer'],
-                    isError=row['step_iserror'],
-                    metadata=row['step_metadata'],
-                    input=row['step_input'] if row['step_showinput'] not in (None, 'false') else None,
-                    output=row['step_output'],
-                    createdAt=row['step_createdat'],
-                    start=row['step_start'],
-                    end=row['step_end'],
-                    generation=row['step_generation'],
-                    showInput=row['step_showinput'],
-                    language=row['step_language'],
-                    indent=row['step_indent'],
-                    feedback=FeedbackDict(
-                        value=row['feedback_value'],
-                        strategy=row['feedback_strategy'],
-                        comment=row['feedback_comment']
-                    ) if row['feedback_value'] is not None else None
-                )
-                thread['steps'].append(step)
-            if row['element_id']:
-                element: Dict[str, Any] = {
-                    "id":row['element_id'],
-                    "threadId":row['element_threadid'],
-                    "type":row['element_type'],
-                    "chainlitKey":row['element_chainlitkey'],
-                    "url":row['element_url'],
-                    "objectKey":row['element_objectkey'],
-                    "name":row['element_name'],
-                    "display":row['element_display'],
-                    "size":row['element_size'],
-                    "language":row['element_language'],
-                    "page":row['element_page'],
-                    "forId":row['element_forid'],
-                    "mime":row['element_mime']
-                }
-                if thread['elements'] is None:
-                    thread['elements'] = []
-                thread['elements'].append(element)   # type: ignore
-        return threads