From cfce64371ea623cfb7c737877a59c6a143c6e782 Mon Sep 17 00:00:00 2001 From: cpacker Date: Thu, 21 Mar 2024 21:12:41 -0700 Subject: [PATCH 01/11] assert that timezone is included for fresh agent, cast timezone onto existing agent (TODO fix DB problem then add back assert) --- memgpt/agent.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index ebe06791f1..dfcd513648 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -29,6 +29,7 @@ validate_function_response, verify_first_message_correctness, create_uuid_from_string, + is_utc_datetime, ) from memgpt.constants import ( FIRST_MESSAGE_ATTEMPTS, @@ -140,7 +141,7 @@ def initialize_message_sequence( recall_memory: Optional[RecallMemory] = None, memory_edit_timestamp: Optional[str] = None, include_initial_boot_message: bool = True, -): +) -> List[dict]: if memory_edit_timestamp is None: memory_edit_timestamp = get_local_time() @@ -291,6 +292,13 @@ def __init__( assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, self.agent_state.state["messages"]) self._messages.extend([cast(Message, msg) for msg in raw_messages if msg is not None]) + for m in self._messages: + # assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}" + # TODO eventually do casting via an edit_message function + if not is_utc_datetime(m.created_at): + printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')") + m.created_at.replace(tzinfo=datetime.timezone.utc) + else: # print(f"Agent.__init__ :: creating, state={agent_state.state['messages']}") init_messages = initialize_message_sequence( @@ -309,6 +317,13 @@ def __init__( self.messages_total = 0 self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None]) + for m in self._messages: + assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}" + # TODO eventually do casting via an edit_message function + if not is_utc_datetime(m.created_at): + printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')") + m.created_at.replace(tzinfo=datetime.timezone.utc) + # Keep track of the total number of messages throughout all time self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) # self.messages_total_init = self.messages_total @@ -445,6 +460,8 @@ def _handle_ai_response( # role: assistant (requesting tool call, set tool call ID) messages.append( + # NOTE: we're recreating the message here + # TODO should probably just overwrite the fields? Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, @@ -710,7 +727,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str: # (if yes) Step 3: call the function # (if yes) Step 4: send the info on the function call and function response to LLM response_message = response.choices[0].message - response_message.copy() + response_message.model_copy() # TODO why are we copying here? all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message) # Add the extra metadata to the assistant response From 5715acf622385898a376fb81640d629da7548a2f Mon Sep 17 00:00:00 2001 From: cpacker Date: Thu, 21 Mar 2024 21:12:59 -0700 Subject: [PATCH 02/11] remove db default on created_at, should be handle higher up --- memgpt/agent_store/db.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index aeb9a40f75..47821c7fea 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -151,7 +151,7 @@ class PassageModel(Base): metadata_ = Column(MutableJson) # Add a datetime column, with default value as the current time - created_at = Column(DateTime(timezone=True), server_default=func.now()) + created_at = Column(DateTime(timezone=True)) def __repr__(self): return f" Date: Thu, 21 Mar 2024 21:14:51 -0700 Subject: [PATCH 05/11] add timezone UTC checker --- memgpt/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/memgpt/utils.py b/memgpt/utils.py index d4c9aa00a4..0b61439075 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta import copy import re import json @@ -469,6 +469,10 @@ ] +def is_utc_datetime(dt: datetime) -> bool: + return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) == timedelta(0) + + def get_tool_call_id() -> str: return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN] From 8812498d5a86427ca499ed4b6d6d6dd33027f1e6 Mon Sep 17 00:00:00 2001 From: cpacker Date: Thu, 21 Mar 2024 21:15:19 -0700 Subject: [PATCH 06/11] validate timestamp provided by REST client --- memgpt/server/rest_api/agents/message.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index 4455d9e6c8..863ae206eb 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -1,20 +1,21 @@ import asyncio import json import uuid -from datetime import datetime +from datetime import datetime, timezone from asyncio import AbstractEventLoop from enum import Enum from functools import partial -from typing import List, Optional +from typing import List, Optional, Any from fastapi import APIRouter, Body, HTTPException, Query, Depends -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator from starlette.responses import StreamingResponse from memgpt.constants import JSON_ENSURE_ASCII from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer +from memgpt.data_types import Message router = APIRouter() @@ -33,6 +34,14 @@ class UserMessageRequest(BaseModel): description="Timestamp to tag the message with (in ISO format). If null, timestamp will be created server-side on receipt of message.", ) + @validator("timestamp") + def validate_timestamp(cls, value: Any) -> Any: + if value.tzinfo is None or value.tzinfo.utcoffset(value) is None: + raise ValueError("Timestamp must include timezone information.") + if value.tzinfo.utcoffset(value) != datetime.fromtimestamp(timezone.utc).utcoffset(): + raise ValueError("Timestamp must be in UTC.") + return value + class UserMessageResponse(BaseModel): messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.") @@ -90,6 +99,12 @@ def get_agent_messages_cursor( [_, messages] = server.get_agent_recall_cursor( user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True ) + print("====> messages-cursor DEBUG") + for i, msg in enumerate(messages): + print(f"message {i+1}/{len(messages)}") + # print(f"UTC created-at: {msg.created_at.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'}") + print(f"ISO format string: {msg['created_at']}") + print(msg) return GetAgentMessagesResponse(messages=messages) @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse) From ae3097052dfd9b36074f0a95663658d824ab665b Mon Sep 17 00:00:00 2001 From: cpacker Date: Thu, 21 Mar 2024 21:16:12 -0700 Subject: [PATCH 07/11] add more debug printing to queueing interface, and read from the back of the buffer in the assistant_message instead of relying on msg_obj to get pulled correctly (it was stale, likely a concurrency bug) --- memgpt/server/rest_api/interface.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/memgpt/server/rest_api/interface.py b/memgpt/server/rest_api/interface.py index eb8f635a98..440decdf20 100644 --- a/memgpt/server/rest_api/interface.py +++ b/memgpt/server/rest_api/interface.py @@ -7,6 +7,7 @@ from memgpt.interface import AgentInterface from memgpt.data_types import Message +from memgpt.utils import is_utc_datetime class QueuingInterface(AgentInterface): @@ -57,34 +58,54 @@ def error(self, error: str): def user_message(self, msg: str, msg_obj: Optional[Message] = None): """Handle reception of a user message""" assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" + if self.debug: + print(msg) + print(vars(msg_obj)) + print(msg_obj.created_at.isoformat()) def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None: """Handle the agent's internal monologue""" assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" if self.debug: print(msg) + print(vars(msg_obj)) + print(msg_obj.created_at.isoformat()) new_message = {"internal_monologue": msg} # add extra metadata if msg_obj is not None: new_message["id"] = str(msg_obj.id) + assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at new_message["date"] = msg_obj.created_at.isoformat() self.buffer.put(new_message) def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None: """Handle the agent sending a message""" - assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" + # assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" + if self.debug: print(msg) + if msg_obj is not None: + print(vars(msg_obj)) + print(msg_obj.created_at.isoformat()) new_message = {"assistant_message": msg} # add extra metadata if msg_obj is not None: new_message["id"] = str(msg_obj.id) + assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at new_message["date"] = msg_obj.created_at.isoformat() + else: + # FIXME this is a total hack + assert self.buffer.qsize() > 1, "Tried to reach back to grab function call data, but couldn't find a buffer message." + # TODO also should not be accessing protected member here + + new_message["id"] = self.buffer.queue[-1]["id"] + # assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at + new_message["date"] = self.buffer.queue[-1]["date"] self.buffer.put(new_message) @@ -95,6 +116,8 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ if self.debug: print(msg) + print(vars(msg_obj)) + print(msg_obj.created_at.isoformat()) if msg.startswith("Running "): msg = msg.replace("Running ", "") @@ -121,6 +144,7 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ # add extra metadata if msg_obj is not None: new_message["id"] = str(msg_obj.id) + assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at new_message["date"] = msg_obj.created_at.isoformat() self.buffer.put(new_message) From ee1bf6189d8eb17c434c9fdeb63e390737758239 Mon Sep 17 00:00:00 2001 From: cpacker Date: Thu, 21 Mar 2024 21:16:39 -0700 Subject: [PATCH 08/11] drop msg_obj passing in send_message base function since it is bugged --- memgpt/functions/function_sets/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/memgpt/functions/function_sets/base.py b/memgpt/functions/function_sets/base.py index 99e3ac426b..f63d39ebab 100644 --- a/memgpt/functions/function_sets/base.py +++ b/memgpt/functions/function_sets/base.py @@ -22,7 +22,10 @@ def send_message(self: Agent, message: str) -> Optional[str]: Optional[str]: None is always returned as this function does not produce a response. """ # FIXME passing of msg_obj here is a hack, unclear if guaranteed to be the correct reference - self.interface.assistant_message(message, msg_obj=self._messages[-1]) + # print("SEND_MESSAGE:::") + # print(self.interface.__class__.__name__) + # print("messages:::", [vars(m) for m in self._messages]) + self.interface.assistant_message(message) # , msg_obj=self._messages[-1]) return None From 69243ae4ad486a4b8a228a8096a8d356700184bd Mon Sep 17 00:00:00 2001 From: cpacker Date: Thu, 21 Mar 2024 21:59:13 -0700 Subject: [PATCH 09/11] correct misuse of replace --- memgpt/agent.py | 4 ++-- memgpt/data_types.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index dfcd513648..706e9489b0 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -297,7 +297,7 @@ def __init__( # TODO eventually do casting via an edit_message function if not is_utc_datetime(m.created_at): printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')") - m.created_at.replace(tzinfo=datetime.timezone.utc) + m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc) else: # print(f"Agent.__init__ :: creating, state={agent_state.state['messages']}") @@ -322,7 +322,7 @@ def __init__( # TODO eventually do casting via an edit_message function if not is_utc_datetime(m.created_at): printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')") - m.created_at.replace(tzinfo=datetime.timezone.utc) + m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc) # Keep track of the total number of messages throughout all time self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 4196a2baed..2e09588dcc 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -135,7 +135,7 @@ def to_json(self): # turn datetime to ISO format # also if the created_at is missing a timezone, add UTC if not is_utc_datetime(self.created_at): - self.created_at.replace(tzinfo=timezone.utc) + self.created_at = self.created_at.replace(tzinfo=timezone.utc) json_message["created_at"] = self.created_at.isoformat() return json_message From e0dfd5b79ec454abd6e9714367e24bb615fd4999 Mon Sep 17 00:00:00 2001 From: cpacker Date: Tue, 26 Mar 2024 19:42:16 -0700 Subject: [PATCH 10/11] clean --- configs/server_config.yaml | 1 + memgpt/functions/function_sets/base.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/configs/server_config.yaml b/configs/server_config.yaml index 67a63e91a2..f9156faad2 100644 --- a/configs/server_config.yaml +++ b/configs/server_config.yaml @@ -36,3 +36,4 @@ memgpt_version = 0.3.7 [client] anon_clientid = 00000000-0000-0000-0000-000000000000 + diff --git a/memgpt/functions/function_sets/base.py b/memgpt/functions/function_sets/base.py index f63d39ebab..769a05d2e1 100644 --- a/memgpt/functions/function_sets/base.py +++ b/memgpt/functions/function_sets/base.py @@ -22,9 +22,6 @@ def send_message(self: Agent, message: str) -> Optional[str]: Optional[str]: None is always returned as this function does not produce a response. """ # FIXME passing of msg_obj here is a hack, unclear if guaranteed to be the correct reference - # print("SEND_MESSAGE:::") - # print(self.interface.__class__.__name__) - # print("messages:::", [vars(m) for m in self._messages]) self.interface.assistant_message(message) # , msg_obj=self._messages[-1]) return None From 673c8a9f922c4fcccb39dd58af8e5e30940dd013 Mon Sep 17 00:00:00 2001 From: cpacker Date: Tue, 26 Mar 2024 19:43:21 -0700 Subject: [PATCH 11/11] remove prints --- memgpt/server/rest_api/agents/message.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index 863ae206eb..3c1d0b67a2 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -99,12 +99,12 @@ def get_agent_messages_cursor( [_, messages] = server.get_agent_recall_cursor( user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True ) - print("====> messages-cursor DEBUG") - for i, msg in enumerate(messages): - print(f"message {i+1}/{len(messages)}") - # print(f"UTC created-at: {msg.created_at.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'}") - print(f"ISO format string: {msg['created_at']}") - print(msg) + # print("====> messages-cursor DEBUG") + # for i, msg in enumerate(messages): + # print(f"message {i+1}/{len(messages)}") + # print(f"UTC created-at: {msg.created_at.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'}") + # print(f"ISO format string: {msg['created_at']}") + # print(msg) return GetAgentMessagesResponse(messages=messages) @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)