diff --git a/configs/server_config.yaml b/configs/server_config.yaml index 67a63e91a2..f9156faad2 100644 --- a/configs/server_config.yaml +++ b/configs/server_config.yaml @@ -36,3 +36,4 @@ memgpt_version = 0.3.7 [client] anon_clientid = 00000000-0000-0000-0000-000000000000 + diff --git a/memgpt/agent.py b/memgpt/agent.py index ebe06791f1..706e9489b0 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -29,6 +29,7 @@ validate_function_response, verify_first_message_correctness, create_uuid_from_string, + is_utc_datetime, ) from memgpt.constants import ( FIRST_MESSAGE_ATTEMPTS, @@ -140,7 +141,7 @@ def initialize_message_sequence( recall_memory: Optional[RecallMemory] = None, memory_edit_timestamp: Optional[str] = None, include_initial_boot_message: bool = True, -): +) -> List[dict]: if memory_edit_timestamp is None: memory_edit_timestamp = get_local_time() @@ -291,6 +292,13 @@ def __init__( assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, self.agent_state.state["messages"]) self._messages.extend([cast(Message, msg) for msg in raw_messages if msg is not None]) + for m in self._messages: + # assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}" + # TODO eventually do casting via an edit_message function + if not is_utc_datetime(m.created_at): + printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')") + m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc) + else: # print(f"Agent.__init__ :: creating, state={agent_state.state['messages']}") init_messages = initialize_message_sequence( @@ -309,6 +317,13 @@ def __init__( self.messages_total = 0 self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None]) + for m in self._messages: + assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}" + # TODO eventually do casting via an edit_message function + if not is_utc_datetime(m.created_at): + printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')") + m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc) + # Keep track of the total number of messages throughout all time self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) # self.messages_total_init = self.messages_total @@ -445,6 +460,8 @@ def _handle_ai_response( # role: assistant (requesting tool call, set tool call ID) messages.append( + # NOTE: we're recreating the message here + # TODO should probably just overwrite the fields? Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, @@ -710,7 +727,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str: # (if yes) Step 3: call the function # (if yes) Step 4: send the info on the function call and function response to LLM response_message = response.choices[0].message - response_message.copy() + response_message.model_copy() # TODO why are we copying here? all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message) # Add the extra metadata to the assistant response diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index cafdd9166c..d36a584af4 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -151,7 +151,7 @@ class PassageModel(Base): metadata_ = Column(MutableJson) # Add a datetime column, with default value as the current time - created_at = Column(DateTime(timezone=True), server_default=func.now()) + created_at = Column(DateTime(timezone=True)) def __repr__(self): return f" None: """Handle the agent's internal monologue""" assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" if self.debug: print(msg) + print(vars(msg_obj)) + print(msg_obj.created_at.isoformat()) new_message = {"internal_monologue": msg} # add extra metadata if msg_obj is not None: new_message["id"] = str(msg_obj.id) + assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at new_message["date"] = msg_obj.created_at.isoformat() self.buffer.put(new_message) def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None: """Handle the agent sending a message""" - assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" + # assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" + if self.debug: print(msg) + if msg_obj is not None: + print(vars(msg_obj)) + print(msg_obj.created_at.isoformat()) new_message = {"assistant_message": msg} # add extra metadata if msg_obj is not None: new_message["id"] = str(msg_obj.id) + assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at new_message["date"] = msg_obj.created_at.isoformat() + else: + # FIXME this is a total hack + assert self.buffer.qsize() > 1, "Tried to reach back to grab function call data, but couldn't find a buffer message." + # TODO also should not be accessing protected member here + + new_message["id"] = self.buffer.queue[-1]["id"] + # assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at + new_message["date"] = self.buffer.queue[-1]["date"] self.buffer.put(new_message) @@ -95,6 +116,8 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ if self.debug: print(msg) + print(vars(msg_obj)) + print(msg_obj.created_at.isoformat()) if msg.startswith("Running "): msg = msg.replace("Running ", "") @@ -121,6 +144,7 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ # add extra metadata if msg_obj is not None: new_message["id"] = str(msg_obj.id) + assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at new_message["date"] = msg_obj.created_at.isoformat() self.buffer.put(new_message) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 023a9bd45b..183d578e90 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1064,7 +1064,7 @@ def get_agent_recall_cursor( order_by: Optional[str] = "created_at", order: Optional[str] = "asc", reverse: Optional[bool] = False, - ): + ) -> Tuple[uuid.UUID, List[dict]]: if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: diff --git a/memgpt/utils.py b/memgpt/utils.py index d4c9aa00a4..0b61439075 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta import copy import re import json @@ -469,6 +469,10 @@ ] +def is_utc_datetime(dt: datetime) -> bool: + return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) == timedelta(0) + + def get_tool_call_id() -> str: return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]