Skip to content

Commit

Permalink
Get rid of context from SQL Alchemy data layer (#1319), fix SQLite su…
Browse files Browse the repository at this point in the history
…pport (#1137).

* Add SQLite DB tests and fixtures
* Get rid of context in SQL Alchemy data layer.

---------

Signed-off-by: DanielAvdar <[email protected]>
Co-authored-by: Mathijs de Bruin <[email protected]>
  • Loading branch information
DanielAvdar and dokterbob authored Sep 20, 2024
1 parent 2bdd541 commit 1964409
Show file tree
Hide file tree
Showing 6 changed files with 1,333 additions and 1,088 deletions.
10 changes: 7 additions & 3 deletions backend/chainlit/data/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def _update_item(self, key: Dict[str, Any], updates: Dict[str, Any]):
ExpressionAttributeValues=self._serialize_item(expression_attribute_values),
)

@property
def context(self):
return context

async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
_logger.info("DynamoDB: get_user identifier=%s", identifier)

Expand Down Expand Up @@ -241,7 +245,7 @@ async def create_element(self, element: "Element"):
if not element.mime:
element.mime = "application/octet-stream"

context_user = context.session.user
context_user = self.context.session.user
user_folder = getattr(context_user, "id", "unknown")
file_object_key = f"{user_folder}/{element.thread_id}/{element.id}"

Expand Down Expand Up @@ -293,7 +297,7 @@ async def get_element(

@queue_until_user_message()
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
thread_id = context.session.thread_id
thread_id = self.context.session.thread_id
_logger.info(
"DynamoDB: delete_element thread=%s element=%s", thread_id, element_id
)
Expand Down Expand Up @@ -349,7 +353,7 @@ async def update_step(self, step_dict: "StepDict"):

@queue_until_user_message()
async def delete_step(self, step_id: str):
thread_id = context.session.thread_id
thread_id = self.context.session.thread_id
_logger.info("DynamoDB: delete_feedback thread=%s step=%s", thread_id, step_id)

self.client.delete_item(
Expand Down
87 changes: 67 additions & 20 deletions backend/chainlit/data/sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import aiofiles
import aiohttp
from chainlit.context import context

from chainlit.data.base import BaseDataLayer, BaseStorageClient
from chainlit.data.utils import queue_until_user_message
from chainlit.element import ElementDict
Expand Down Expand Up @@ -84,6 +84,9 @@ async def execute_sql(
if result.returns_rows:
json_result = [dict(row._mapping) for row in result.fetchall()]
clean_json_result = self.clean_result(json_result)
assert isinstance(clean_json_result, list) or isinstance(
clean_json_result, int
)
return clean_json_result
else:
return result.rowcount
Expand Down Expand Up @@ -118,7 +121,47 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]:
result = await self.execute_sql(query=query, parameters=parameters)
if result and isinstance(result, list):
user_data = result[0]
return PersistedUser(**user_data)

# SQLite returns JSON as string, we most convert it. (#1137)
metadata = user_data.get("metadata", {})
if isinstance(metadata, str):
metadata = json.loads(metadata)

assert isinstance(metadata, dict)
assert isinstance(user_data["id"], str)
assert isinstance(user_data["identifier"], str)
assert isinstance(user_data["createdAt"], str)

return PersistedUser(
id=user_data["id"],
identifier=user_data["identifier"],
createdAt=user_data["createdAt"],
metadata=metadata,
)
return None

async def _get_user_identifer_by_id(self, user_id: str) -> str:
if self.show_logger:
logger.info(f"SQLAlchemy: _get_user_identifer_by_id, user_id={user_id}")
query = "SELECT identifier FROM users WHERE id = :user_id"
parameters = {"user_id": user_id}
result = await self.execute_sql(query=query, parameters=parameters)

assert result
assert isinstance(result, list)

return result[0]["identifier"]

async def _get_user_id_by_thread(self, thread_id: str) -> Optional[str]:
if self.show_logger:
logger.info(f"SQLAlchemy: _get_user_id_by_thread, thread_id={thread_id}")
query = "SELECT userId FROM threads WHERE id = :thread_id"
parameters = {"thread_id": thread_id}
result = await self.execute_sql(query=query, parameters=parameters)
if result:
assert isinstance(result, list)
return result[0]["userId"]

return None

async def create_user(self, user: User) -> Optional[PersistedUser]:
Expand Down Expand Up @@ -179,10 +222,11 @@ async def update_thread(
):
if self.show_logger:
logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
if context.session.user is not None:
user_identifier = context.session.user.identifier
else:
raise ValueError("User not found in session context")

user_identifier = None
if user_id:
user_identifier = await self._get_user_identifer_by_id(user_id)

data = {
"id": thread_id,
"createdAt": (
Expand Down Expand Up @@ -294,8 +338,7 @@ async def list_threads(
async def create_step(self, step_dict: "StepDict"):
if self.show_logger:
logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
if not getattr(context.session.user, "id", None):
raise ValueError("No authenticated user in context")

step_dict["showInput"] = (
str(step_dict.get("showInput", "")).lower()
if "showInput" in step_dict
Expand Down Expand Up @@ -373,12 +416,18 @@ async def delete_feedback(self, feedback_id: str) -> bool:
return True

###### Elements ######
async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]:
async def get_element(
self, thread_id: str, element_id: str
) -> Optional["ElementDict"]:
if self.show_logger:
logger.info(f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}")
logger.info(
f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}"
)
query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id"""
parameters = {"thread_id": thread_id, "element_id": element_id}
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(query=query, parameters=parameters)
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(
query=query, parameters=parameters
)
if isinstance(element, list) and element:
element_dict: Dict[str, Any] = element[0]
return ElementDict(
Expand All @@ -396,7 +445,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen
autoPlay=element_dict.get("autoPlay"),
playerConfig=element_dict.get("playerConfig"),
forId=element_dict.get("forId"),
mime=element_dict.get("mime")
mime=element_dict.get("mime"),
)
else:
return None
Expand All @@ -405,8 +454,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen
async def create_element(self, element: "Element"):
if self.show_logger:
logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
if not getattr(context.session.user, "id", None):
raise ValueError("No authenticated user in context")

if not self.storage_provider:
logger.warn(
"SQLAlchemy: create_element error. No blob_storage_client is configured!"
Expand Down Expand Up @@ -434,10 +482,8 @@ async def create_element(self, element: "Element"):
if content is None:
raise ValueError("Content is None, cannot upload file")

context_user = context.session.user

user_folder = getattr(context_user, "id", "unknown")
file_object_key = f"{user_folder}/{element.id}" + (
user_id: str = await self._get_user_id_by_thread(element.thread_id) or "unknown"
file_object_key = f"{user_id}/{element.id}" + (
f"/{element.name}" if element.name else ""
)

Expand Down Expand Up @@ -607,8 +653,9 @@ async def get_all_user_threads(
tags=step_feedback.get("step_tags"),
input=(
step_feedback.get("step_input", "")
if step_feedback.get("step_showinput") not in [None, "false"]
else None
if step_feedback.get("step_showinput")
not in [None, "false"]
else ""
),
output=step_feedback.get("step_output", ""),
createdAt=step_feedback.get("step_createdat"),
Expand Down
Loading

0 comments on commit 1964409

Please sign in to comment.