Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: DynamoDB Typeerror with AWS Bedrock #1163

Closed
wants to merge 9 commits into from
30 changes: 30 additions & 0 deletions backend/chainlit/data/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
import random
from decimal import Decimal
from dataclasses import asdict
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -71,6 +72,32 @@ def _deserialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
for key, value in item.items()
}

def _convert_floats_to_decimal(self, obj):
munday-tech marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(obj, dict):
for key, value in obj.items():
if isinstance(value, float):
obj[key] = Decimal(str(value))
elif isinstance(value, dict):
self._convert_floats_to_decimal(value)
elif isinstance(value, list):
obj[key] = [self._convert_floats_to_decimal(i) for i in value]
elif isinstance(obj, list):
return [self._convert_floats_to_decimal(i) for i in obj]
return obj
munday-tech marked this conversation as resolved.
Show resolved Hide resolved

def _convert_decimal_to_floats(self, obj):
if isinstance(obj, dict):
for key, value in obj.items():
if isinstance(value, Decimal):
obj[key] = float(value)
elif isinstance(value, dict):
self._convert_decimal_to_floats(value)
elif isinstance(value, list):
obj[key] = [self._convert_decimal_to_floats(i) for i in value]
elif isinstance(obj, list):
return [self._convert_decimal_to_floats(i) for i in obj]
return obj
munday-tech marked this conversation as resolved.
Show resolved Hide resolved

def _update_item(self, key: Dict[str, Any], updates: Dict[str, Any]):
update_expr: List[str] = []
expression_attribute_names = {}
Expand All @@ -83,6 +110,8 @@ def _update_item(self, key: Dict[str, Any], updates: Dict[str, Any]):
k, v = f"#{index}", f":{index}"
update_expr.append(f"{k} = {v}")
expression_attribute_names[k] = attr
if isinstance(value, (dict, list)):
value = self._convert_floats_to_decimal(value)
expression_attribute_values[v] = value

self.client.update_item(
Expand Down Expand Up @@ -510,6 +539,7 @@ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
steps = []
elements = []

thread_items = self._convert_decimal_to_floats(thread_items)
munday-tech marked this conversation as resolved.
Show resolved Hide resolved
for item in thread_items:
if item["SK"] == "THREAD":
thread_dict = item
Expand Down