Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor context attribute access in data modules #1319

Merged
merged 5 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions backend/chainlit/data/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def _update_item(self, key: Dict[str, Any], updates: Dict[str, Any]):
ExpressionAttributeValues=self._serialize_item(expression_attribute_values),
)

@property
def context(self):
return context

async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
_logger.info("DynamoDB: get_user identifier=%s", identifier)

Expand Down Expand Up @@ -241,7 +245,7 @@ async def create_element(self, element: "Element"):
if not element.mime:
element.mime = "application/octet-stream"

context_user = context.session.user
context_user = self.context.session.user
user_folder = getattr(context_user, "id", "unknown")
file_object_key = f"{user_folder}/{element.thread_id}/{element.id}"

Expand Down Expand Up @@ -293,7 +297,7 @@ async def get_element(

@queue_until_user_message()
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
thread_id = context.session.thread_id
thread_id = self.context.session.thread_id
_logger.info(
"DynamoDB: delete_element thread=%s element=%s", thread_id, element_id
)
Expand Down Expand Up @@ -349,7 +353,7 @@ async def update_step(self, step_dict: "StepDict"):

@queue_until_user_message()
async def delete_step(self, step_id: str):
thread_id = context.session.thread_id
thread_id = self.context.session.thread_id
_logger.info("DynamoDB: delete_feedback thread=%s step=%s", thread_id, step_id)

self.client.delete_item(
Expand Down
87 changes: 67 additions & 20 deletions backend/chainlit/data/sql_alchemy.py
dokterbob marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import aiofiles
import aiohttp
from chainlit.context import context

from chainlit.data.base import BaseDataLayer, BaseStorageClient
from chainlit.data.utils import queue_until_user_message
from chainlit.element import ElementDict
Expand Down Expand Up @@ -84,6 +84,9 @@ async def execute_sql(
if result.returns_rows:
json_result = [dict(row._mapping) for row in result.fetchall()]
clean_json_result = self.clean_result(json_result)
assert isinstance(clean_json_result, list) or isinstance(
clean_json_result, int
)
return clean_json_result
else:
return result.rowcount
Expand Down Expand Up @@ -118,7 +121,47 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]:
result = await self.execute_sql(query=query, parameters=parameters)
if result and isinstance(result, list):
user_data = result[0]
return PersistedUser(**user_data)

# SQLite returns JSON as string, we most convert it. (#1137)
metadata = user_data.get("metadata", {})
if isinstance(metadata, str):
metadata = json.loads(metadata)

assert isinstance(metadata, dict)
assert isinstance(user_data["id"], str)
assert isinstance(user_data["identifier"], str)
assert isinstance(user_data["createdAt"], str)

return PersistedUser(
id=user_data["id"],
identifier=user_data["identifier"],
createdAt=user_data["createdAt"],
metadata=metadata,
)
return None

async def _get_user_identifer_by_id(self, user_id: str) -> str:
if self.show_logger:
logger.info(f"SQLAlchemy: _get_user_identifer_by_id, user_id={user_id}")
query = "SELECT identifier FROM users WHERE id = :user_id"
parameters = {"user_id": user_id}
result = await self.execute_sql(query=query, parameters=parameters)

assert result
assert isinstance(result, list)

return result[0]["identifier"]

async def _get_user_id_by_thread(self, thread_id: str) -> Optional[str]:
if self.show_logger:
logger.info(f"SQLAlchemy: _get_user_id_by_thread, thread_id={thread_id}")
query = "SELECT userId FROM threads WHERE id = :thread_id"
parameters = {"thread_id": thread_id}
result = await self.execute_sql(query=query, parameters=parameters)
if result:
assert isinstance(result, list)
return result[0]["userId"]

return None

async def create_user(self, user: User) -> Optional[PersistedUser]:
Expand Down Expand Up @@ -179,10 +222,11 @@ async def update_thread(
):
if self.show_logger:
logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
if context.session.user is not None:
user_identifier = context.session.user.identifier
else:
raise ValueError("User not found in session context")

DanielAvdar marked this conversation as resolved.
Show resolved Hide resolved
user_identifier = None
if user_id:
user_identifier = await self._get_user_identifer_by_id(user_id)

data = {
"id": thread_id,
"createdAt": (
Expand Down Expand Up @@ -294,8 +338,7 @@ async def list_threads(
async def create_step(self, step_dict: "StepDict"):
if self.show_logger:
logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
if not getattr(context.session.user, "id", None):
raise ValueError("No authenticated user in context")

step_dict["showInput"] = (
str(step_dict.get("showInput", "")).lower()
if "showInput" in step_dict
Expand Down Expand Up @@ -373,12 +416,18 @@ async def delete_feedback(self, feedback_id: str) -> bool:
return True

###### Elements ######
async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]:
async def get_element(
self, thread_id: str, element_id: str
) -> Optional["ElementDict"]:
if self.show_logger:
logger.info(f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}")
logger.info(
f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}"
)
query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id"""
parameters = {"thread_id": thread_id, "element_id": element_id}
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(query=query, parameters=parameters)
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(
query=query, parameters=parameters
)
if isinstance(element, list) and element:
element_dict: Dict[str, Any] = element[0]
return ElementDict(
Expand All @@ -396,7 +445,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen
autoPlay=element_dict.get("autoPlay"),
playerConfig=element_dict.get("playerConfig"),
forId=element_dict.get("forId"),
mime=element_dict.get("mime")
mime=element_dict.get("mime"),
)
else:
return None
Expand All @@ -405,8 +454,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen
async def create_element(self, element: "Element"):
if self.show_logger:
logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
if not getattr(context.session.user, "id", None):
raise ValueError("No authenticated user in context")

if not self.storage_provider:
logger.warn(
"SQLAlchemy: create_element error. No blob_storage_client is configured!"
Expand Down Expand Up @@ -434,10 +482,8 @@ async def create_element(self, element: "Element"):
if content is None:
raise ValueError("Content is None, cannot upload file")

context_user = context.session.user

user_folder = getattr(context_user, "id", "unknown")
file_object_key = f"{user_folder}/{element.id}" + (
user_id: str = await self._get_user_id_by_thread(element.thread_id) or "unknown"
file_object_key = f"{user_id}/{element.id}" + (
f"/{element.name}" if element.name else ""
)

Expand Down Expand Up @@ -607,8 +653,9 @@ async def get_all_user_threads(
tags=step_feedback.get("step_tags"),
input=(
step_feedback.get("step_input", "")
if step_feedback.get("step_showinput") not in [None, "false"]
else None
if step_feedback.get("step_showinput")
not in [None, "false"]
else ""
),
output=step_feedback.get("step_output", ""),
createdAt=step_feedback.get("step_createdat"),
Expand Down
Loading
Loading