Skip to content

Commit

Permalink
refactor: refactored msg_to_toml and toml_to_message into Message met…
Browse files Browse the repository at this point in the history
…hods
  • Loading branch information
ErikBjare committed Nov 6, 2023
1 parent c237dde commit 11b711f
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 58 deletions.
96 changes: 48 additions & 48 deletions gptme/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import textwrap
from datetime import datetime
from typing import Literal
from typing import Literal, Self

import tomlkit
from rich import print
Expand Down Expand Up @@ -40,6 +40,10 @@ def __init__(
# Wether this message should be printed on execution (will still print on resume, unlike hide)
self.quiet = quiet

def __repr__(self):
content = textwrap.shorten(self.content, 20, placeholder="...")
return f"<Message role={self.role} content={content}>"

def to_dict(self, keys=None):
"""Return a dict representation of the message, serializable to JSON."""
d = {
Expand All @@ -54,9 +58,48 @@ def to_dict(self, keys=None):
def format(self, oneline: bool = False, highlight: bool = False) -> str:
return format_msgs([self], oneline=oneline, highlight=highlight)[0]

def __repr__(self):
content = textwrap.shorten(self.content, 20, placeholder="...")
return f"<Message role={self.role} content={content}>"
def to_toml(self) -> str:
"""Converts a message to a TOML string, for easy editing by hand in editor to then be parsed back."""
flags = []
if self.pinned:
flags.append("pinned")
if self.hide:
flags.append("hide")
if self.quiet:
flags.append("quiet")
flags_toml = "\n".join(f"{flag} = true" for flag in flags)

# doublequotes need to be escaped
content = self.content.replace('"', '\\"')
return f'''[message]
role = "{self.role}"
content = """
{content}
"""
timestamp = "{self.timestamp.isoformat()}"
{flags_toml}
'''

@classmethod
def from_toml(cls, toml: str) -> Self:
"""
Converts a TOML string to a message.
The string can be a single [[message]].
"""

t = tomlkit.parse(toml)
assert "message" in t and isinstance(t["message"], dict)
msg: dict = t["message"] # type: ignore

return cls(
msg["role"],
msg["content"],
pinned=msg.get("pinned", False),
hide=msg.get("hide", False),
quiet=msg.get("quiet", False),
timestamp=datetime.fromisoformat(msg["timestamp"]),
)

def get_codeblocks(self, content=False) -> list[str]:
"""
Expand Down Expand Up @@ -154,58 +197,15 @@ def print_msg(
)


def msg_to_toml(msg: Message) -> str:
"""Converts a message to a TOML string, for easy editing by hand in editor to then be parsed back."""
# TODO: escape msg.content
flags = []
if msg.pinned:
flags.append("pinned")
if msg.hide:
flags.append("hide")
if msg.quiet:
flags.append("quiet")

# doublequotes need to be escaped
content = msg.content.replace('"', '\\"')
return f'''[message]
role = "{msg.role}"
content = """
{content}
"""
timestamp = "{msg.timestamp.isoformat()}"
'''


def msgs_to_toml(msgs: list[Message]) -> str:
"""Converts a list of messages to a TOML string, for easy editing by hand in editor to then be parsed back."""
t = ""
for msg in msgs:
t += msg_to_toml(msg).replace("[message]", "[[messages]]") + "\n\n"
t += msg.to_toml().replace("[message]", "[[messages]]") + "\n\n"

return t


def toml_to_msg(toml: str) -> Message:
"""
Converts a TOML string to a message.
The string can be a single [[message]].
"""

t = tomlkit.parse(toml)
assert "message" in t and isinstance(t["message"], dict)
msg: dict = t["message"] # type: ignore

return Message(
msg["role"],
msg["content"],
pinned=msg.get("pinned", False),
hide=msg.get("hide", False),
quiet=msg.get("quiet", False),
timestamp=datetime.fromisoformat(msg["timestamp"]),
)


def toml_to_msgs(toml: str) -> list[Message]:
"""
Converts a TOML string to a list of messages.
Expand Down
21 changes: 11 additions & 10 deletions tests/test_message.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
from gptme.message import (
Message,
msg_to_toml,
msgs_to_toml,
toml_to_msg,
toml_to_msgs,
)
from gptme.message import Message, msgs_to_toml, toml_to_msgs


def test_toml():
# single message, check escaping
msg = Message(
"system",
'''Hello world!
"""Difficult to handle string"""
''',
)
t = msg_to_toml(msg)
t = msg.to_toml()
print(t)
m = toml_to_msg(t)
m = Message.from_toml(t)
print(m)
assert msg.content == m.content
assert msg.role == m.role
assert msg.timestamp.date() == m.timestamp.date()
assert msg.timestamp.timetuple() == m.timestamp.timetuple()

msg2 = Message("user", "Hello computer!")
# multiple messages
msg2 = Message("user", "Hello computer!", pinned=True, hide=True, quiet=True)
ts = msgs_to_toml([msg, msg2])
print(ts)
ms = toml_to_msgs(ts)
Expand All @@ -34,6 +30,11 @@ def test_toml():
assert ms[0].content == msg.content
assert ms[1].content == msg2.content

# check flags
assert ms[1].pinned == msg2.pinned
assert ms[1].hide == msg2.hide
assert ms[1].quiet == msg2.quiet


def test_get_codeblocks():
# single codeblock only
Expand Down

0 comments on commit 11b711f

Please sign in to comment.