Skip to content

Commit

Permalink
Fixed adding float values into DynamoDB (#26562)
Browse files Browse the repository at this point in the history
Thank you for contributing to LangChain!

- [x] **PR title**: Add float Message into Dynamo DB
  -  community
  - Example: "community: Chat Message History 


- [x] **PR message**: 
- **Description:** pushing float values into dynamo db creates error ,
solved that by converting to str type
    - **Issue:** Float values are not getting pushed
    - **Twitter handle:** VpkPrasanna
    
    
Have added an utility function for str conversion , let me know where to
place it happy to do an commit.
    
    This PR is from an discussion of #26543
    
    @hwchase17 @baskaryan @efriis

---------

Co-authored-by: Chester Curme <[email protected]>
  • Loading branch information
VpkPrasanna and ccurme authored Dec 18, 2024
1 parent 50ea1c3 commit 684b146
Showing 1 changed file with 19 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from decimal import Decimal
from typing import TYPE_CHECKING, Dict, List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
Expand All @@ -17,6 +18,16 @@
logger = logging.getLogger(__name__)


def convert_messages(item: List) -> List:
if isinstance(item, list):
return [convert_messages(i) for i in item]
elif isinstance(item, dict):
return {k: convert_messages(v) for k, v in item.items()}
elif isinstance(item, float):
return Decimal(str(item))
return item


class DynamoDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in AWS DynamoDB.
Expand Down Expand Up @@ -47,6 +58,8 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
limit. If not None then only the latest `history_size` messages are stored.
history_messages_key: Key for the chat history where the messages
are stored and updated
coerce_float_to_decimal: If True, all float values in the messages will be
converted to Decimal.
"""

def __init__(
Expand All @@ -62,6 +75,8 @@ def __init__(
ttl_key_name: str = "expireAt",
history_size: Optional[int] = None,
history_messages_key: Optional[str] = "History",
*,
coerce_float_to_decimal: bool = False,
):
if boto3_session:
client = boto3_session.resource("dynamodb", endpoint_url=endpoint_url)
Expand All @@ -83,6 +98,7 @@ def __init__(
self.ttl_key_name = ttl_key_name
self.history_size = history_size
self.history_messages_key = history_messages_key
self.coerce_float_to_decimal = coerce_float_to_decimal

if kms_key_id:
try:
Expand Down Expand Up @@ -159,6 +175,9 @@ def add_message(self, message: BaseMessage) -> None:
_message = message_to_dict(message)
messages.append(_message)

if self.coerce_float_to_decimal:
messages = convert_messages(messages)

if self.history_size:
messages = messages[-self.history_size :]

Expand Down

0 comments on commit 684b146

Please sign in to comment.