Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: patch out-of-sync / missing tzinfo timestamps coming back from API server #1182

Merged
merged 13 commits into from
Mar 27, 2024
Merged
1 change: 1 addition & 0 deletions configs/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ memgpt_version = 0.3.7

[client]
anon_clientid = 00000000-0000-0000-0000-000000000000

21 changes: 19 additions & 2 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
validate_function_response,
verify_first_message_correctness,
create_uuid_from_string,
is_utc_datetime,
)
from memgpt.constants import (
FIRST_MESSAGE_ATTEMPTS,
Expand Down Expand Up @@ -113,8 +114,8 @@
recall_memory: Optional[RecallMemory] = None,
include_char_count: bool = True,
):
full_system_message = "\n".join(

Check failure on line 117 in memgpt/agent.py

View workflow job for this annotation

GitHub Actions / Pyright types check (3.11)

No overloads for "join" match the provided arguments (reportCallIssue)
[

Check failure on line 118 in memgpt/agent.py

View workflow job for this annotation

GitHub Actions / Pyright types check (3.11)

Argument of type "list[str | Unknown | None]" cannot be assigned to parameter "iterable" of type "Iterable[str]" in function "join"   Type "Unknown | None" cannot be assigned to type "str"     "None" is incompatible with "str"   Type "Unknown | None" cannot be assigned to type "str"     "None" is incompatible with "str" (reportArgumentType)
system,
"\n",
f"### Memory [last modified: {memory_edit_timestamp.strip()}]",
Expand All @@ -140,7 +141,7 @@
recall_memory: Optional[RecallMemory] = None,
memory_edit_timestamp: Optional[str] = None,
include_initial_boot_message: bool = True,
):
) -> List[dict]:
if memory_edit_timestamp is None:
memory_edit_timestamp = get_local_time()

Expand Down Expand Up @@ -291,6 +292,13 @@
assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, self.agent_state.state["messages"])
self._messages.extend([cast(Message, msg) for msg in raw_messages if msg is not None])

for m in self._messages:
# assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
# TODO eventually do casting via an edit_message function
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)

else:
# print(f"Agent.__init__ :: creating, state={agent_state.state['messages']}")
init_messages = initialize_message_sequence(
Expand All @@ -309,6 +317,13 @@
self.messages_total = 0
self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])

for m in self._messages:
assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
# TODO eventually do casting via an edit_message function
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)

# Keep track of the total number of messages throughout all time
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
# self.messages_total_init = self.messages_total
Expand Down Expand Up @@ -445,6 +460,8 @@

# role: assistant (requesting tool call, set tool call ID)
messages.append(
# NOTE: we're recreating the message here
# TODO should probably just overwrite the fields?
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
Expand Down Expand Up @@ -710,7 +727,7 @@
# (if yes) Step 3: call the function
# (if yes) Step 4: send the info on the function call and function response to LLM
response_message = response.choices[0].message
response_message.copy()
response_message.model_copy() # TODO why are we copying here?
all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message)

# Add the extra metadata to the assistant response
Expand Down
4 changes: 2 additions & 2 deletions memgpt/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class PassageModel(Base):
metadata_ = Column(MutableJson)

# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True))
sarahwooders marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self):
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
Expand Down Expand Up @@ -217,7 +217,7 @@ class MessageModel(Base):
embedding_model = Column(String)

# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True))
sarahwooders marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self):
return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
Expand Down
2 changes: 1 addition & 1 deletion memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def get_agent_response_to_state(self, response: Union[GetAgentResponse, CreateAg
embedding_config=embedding_config,
state=response.agent_state.state,
# load datetime from timestampe
created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at),
created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at, tz=datetime.timezone.utc),
)
return agent_state

Expand Down
17 changes: 9 additions & 8 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """

import uuid
from datetime import datetime
from datetime import datetime, timezone
from typing import Optional, List, Dict, TypeVar
import numpy as np
from pydantic import BaseModel, Field, Json

from memgpt.constants import (
DEFAULT_HUMAN,
Expand All @@ -14,14 +15,9 @@
MAX_EMBEDDING_DIM,
TOOL_CALL_ID_MAX_LEN,
)
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
from memgpt.utils import get_utc_time, create_uuid_from_string
from memgpt.models import chat_completion_response
from memgpt.utils import get_human_text, get_persona_text, printd

from pydantic import BaseModel, Field, Json
from memgpt.utils import get_human_text, get_persona_text, printd

from pydantic import BaseModel, Field, Json
from memgpt.utils import get_human_text, get_persona_text, printd, is_utc_datetime


class Record:
Expand Down Expand Up @@ -136,6 +132,11 @@ def to_json(self):
json_message = vars(self)
if json_message["tool_calls"] is not None:
json_message["tool_calls"] = [vars(tc) for tc in json_message["tool_calls"]]
# turn datetime to ISO format
# also if the created_at is missing a timezone, add UTC
if not is_utc_datetime(self.created_at):
self.created_at = self.created_at.replace(tzinfo=timezone.utc)
json_message["created_at"] = self.created_at.isoformat()
return json_message

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion memgpt/functions/function_sets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def send_message(self: Agent, message: str) -> Optional[str]:
Optional[str]: None is always returned as this function does not produce a response.
"""
# FIXME passing of msg_obj here is a hack, unclear if guaranteed to be the correct reference
self.interface.assistant_message(message, msg_obj=self._messages[-1])
self.interface.assistant_message(message) # , msg_obj=self._messages[-1])
return None


Expand Down
4 changes: 2 additions & 2 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Optional, List

from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
from memgpt.utils import get_local_time, enforce_types
from memgpt.utils import enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
from memgpt.config import MemGPTConfig
from memgpt.functions.functions import load_all_function_sets
Expand Down Expand Up @@ -549,7 +549,7 @@ def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
)
for k, v in available_functions.items()
]
print(results)
# print(results)
return results
# results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
# return [r.to_record() for r in results]
Expand Down
2 changes: 1 addition & 1 deletion memgpt/persistence_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
BaseRecallMemory,
EmbeddingArchivalMemory,
)
from memgpt.utils import get_local_time, printd
from memgpt.utils import printd
from memgpt.data_types import Message, ToolCall, AgentState

from datetime import datetime
Expand Down
21 changes: 18 additions & 3 deletions memgpt/server/rest_api/agents/message.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import asyncio
import json
import uuid
from datetime import datetime
from datetime import datetime, timezone
from asyncio import AbstractEventLoop
from enum import Enum
from functools import partial
from typing import List, Optional
from typing import List, Optional, Any

from fastapi import APIRouter, Body, HTTPException, Query, Depends
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
from starlette.responses import StreamingResponse

from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.data_types import Message

router = APIRouter()

Expand All @@ -33,6 +34,14 @@ class UserMessageRequest(BaseModel):
description="Timestamp to tag the message with (in ISO format). If null, timestamp will be created server-side on receipt of message.",
)

@validator("timestamp")
def validate_timestamp(cls, value: Any) -> Any:
if value.tzinfo is None or value.tzinfo.utcoffset(value) is None:
raise ValueError("Timestamp must include timezone information.")
if value.tzinfo.utcoffset(value) != datetime.fromtimestamp(timezone.utc).utcoffset():
raise ValueError("Timestamp must be in UTC.")
return value


class UserMessageResponse(BaseModel):
messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
Expand Down Expand Up @@ -90,6 +99,12 @@ def get_agent_messages_cursor(
[_, messages] = server.get_agent_recall_cursor(
user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True
)
# print("====> messages-cursor DEBUG")
# for i, msg in enumerate(messages):
# print(f"message {i+1}/{len(messages)}")
# print(f"UTC created-at: {msg.created_at.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'}")
# print(f"ISO format string: {msg['created_at']}")
# print(msg)
return GetAgentMessagesResponse(messages=messages)

@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
Expand Down
26 changes: 25 additions & 1 deletion memgpt/server/rest_api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from memgpt.interface import AgentInterface
from memgpt.data_types import Message
from memgpt.utils import is_utc_datetime


class QueuingInterface(AgentInterface):
Expand Down Expand Up @@ -57,34 +58,54 @@ def error(self, error: str):
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
"""Handle reception of a user message"""
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
if self.debug:
print(msg)
print(vars(msg_obj))
print(msg_obj.created_at.isoformat())

def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""Handle the agent's internal monologue"""
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
if self.debug:
print(msg)
print(vars(msg_obj))
print(msg_obj.created_at.isoformat())

new_message = {"internal_monologue": msg}

# add extra metadata
if msg_obj is not None:
new_message["id"] = str(msg_obj.id)
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat()

self.buffer.put(new_message)

def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""Handle the agent sending a message"""
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
# assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"

if self.debug:
print(msg)
if msg_obj is not None:
print(vars(msg_obj))
print(msg_obj.created_at.isoformat())

new_message = {"assistant_message": msg}

# add extra metadata
if msg_obj is not None:
new_message["id"] = str(msg_obj.id)
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat()
else:
# FIXME this is a total hack
assert self.buffer.qsize() > 1, "Tried to reach back to grab function call data, but couldn't find a buffer message."
# TODO also should not be accessing protected member here

new_message["id"] = self.buffer.queue[-1]["id"]
# assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = self.buffer.queue[-1]["date"]

self.buffer.put(new_message)

Expand All @@ -95,6 +116,8 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_

if self.debug:
print(msg)
print(vars(msg_obj))
print(msg_obj.created_at.isoformat())

if msg.startswith("Running "):
msg = msg.replace("Running ", "")
Expand All @@ -121,6 +144,7 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_
# add extra metadata
if msg_obj is not None:
new_message["id"] = str(msg_obj.id)
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat()

self.buffer.put(new_message)
2 changes: 1 addition & 1 deletion memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ def get_agent_recall_cursor(
order_by: Optional[str] = "created_at",
order: Optional[str] = "asc",
reverse: Optional[bool] = False,
):
) -> Tuple[uuid.UUID, List[dict]]:
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
Expand Down
6 changes: 5 additions & 1 deletion memgpt/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime, timezone
from datetime import datetime, timezone, timedelta
import copy
import re
import json
Expand Down Expand Up @@ -469,6 +469,10 @@
]


def is_utc_datetime(dt: datetime) -> bool:
return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) == timedelta(0)


def get_tool_call_id() -> str:
return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]

Expand Down
Loading