refactoring ChatGPT conversations
jpal91 committed Oct 12, 2023
1 parent 60b64d4 commit f4b13de
Showing 5 changed files with 210 additions and 39 deletions.
2 changes: 2 additions & 0 deletions xontrib/chatgpt.py
@@ -31,6 +31,8 @@ def _unload_xontrib_(xsh: XonshSession, **_):
del xsh.aliases["chatgpt"]
del xsh.aliases["chatgpt?"]
del xsh.aliases["chat-manager"]
del xsh.ctx['ChatGPT']
del xsh.ctx['chat_manager']

rm_events(xsh)
rm_completers()
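For reference, the whole unload hook is now small; a sketch reconstructed from the hunk above (rm_events and rm_completers are the xontrib's own teardown helpers):

def _unload_xontrib_(xsh: XonshSession, **_):
    # Drop the shell aliases registered at load time
    del xsh.aliases["chatgpt"]
    del xsh.aliases["chatgpt?"]
    del xsh.aliases["chat-manager"]
    # New in this commit: also remove the objects exposed in the shell context
    del xsh.ctx['ChatGPT']
    del xsh.ctx['chat_manager']

    rm_events(xsh)
    rm_completers()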
83 changes: 45 additions & 38 deletions xontrib_chatgpt/chatgpt.py
@@ -13,6 +13,7 @@
from xonsh.ansi_colors import ansi_partial_color_format
from openai.error import OpenAIError

from xontrib_chatgpt.conversation import Conversation
from xontrib_chatgpt.args import _gpt_parse
from xontrib_chatgpt.lazyobjs import (
_openai,
@@ -92,7 +93,7 @@
"""


class ChatGPT(Block):
class ChatGPT(Conversation, Block):
"""Allows for communication with ChatGPT from the xonsh shell"""

__xonsh_block__ = str
@@ -107,19 +108,19 @@ def __init__(self, alias: str = "", managed: bool = False) -> None:
Alias to use for the instance. Defaults to ''.
Will automatically register a xonsh alias under XSH.aliases[alias] if provided.
"""

super().__init__(hash(self))
self.alias = alias
self._base: list[dict[str, str]] = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "system",
"content": "If your responses include code, make sure to wrap it in a markdown code block with the appropriate language.\nExample:\n```python\nprint('Hello World!')\n```",
},
]
self.messages: list[dict[str, str]] = []
self._base_tokens: int = 53
self._tokens: list = []
self._max_tokens = 3000
# self._base: list[dict[str, str]] = [
# {"role": "system", "content": "You are a helpful assistant."},
# {
# "role": "system",
# "content": "If your responses include code, make sure to wrap it in a markdown code block with the appropriate language.\nExample:\n```python\nprint('Hello World!')\n```",
# },
# ]
# self.messages: list[dict[str, str]] = []
# self._base_tokens: int = 53
# self._tokens: list = []
# self._max_tokens = 3000
self._managed = managed

if self.alias:
@@ -172,19 +173,19 @@ def __str__(self):
def __repr__(self):
return f"ChatGPT(alias={self.alias or None})"

@property
def tokens(self) -> int:
"""Current convo tokens"""
return self._base_tokens + sum(self._tokens)
# @property
# def tokens(self) -> int:
# """Current convo tokens"""
# return self._base_tokens + sum(self._tokens)

@property
def base(self) -> list[dict[str, str]]:
return self._base
# @property
# def base(self) -> list[dict[str, str]]:
# return self._base

@base.setter
def base(self, msgs: list[dict[str, str]]) -> None:
self._base_tokens = sum(get_token_list(msgs))
self._base = msgs
# @base.setter
# def base(self, msgs: list[dict[str, str]]) -> None:
# self._base_tokens = sum(get_token_list(msgs))
# self._base = msgs

def stats(self) -> None:
"""Prints conversation stats to shell"""
@@ -203,7 +204,7 @@ def _stats(self) -> str:
stats = [
("Alias:", f"{self.alias or None}", "{BOLD_GREEN}", "🤖"),
("Tokens:", self.tokens, "{BOLD_BLUE}", "🪙"),
("Trim After:", f"{self._max_tokens} Tokens", "{BOLD_BLUE}", "🔪"),
("Trim After:", f"{self.max_tokens} Tokens", "{BOLD_BLUE}", "🔪"),
("Messages:", len(self.messages), "{BOLD_BLUE}", "📨"),
]
return stats
@@ -241,15 +242,15 @@ def chat(self, text: str) -> str:
raise NoApiKeyError()
openai.api_key = api_key

self.messages.append({"role": "user", "content": text})
self.append({"role": "user", "content": text})

try:
response = openai.ChatCompletion.create(
model=model,
messages=self.base + self.messages,
)
except OpenAIError as e:
self.messages.pop()
self.pop()
sys.exit(
ansi_partial_color_format(
"{}OpenAI Error{}: {}".format("{BOLD_RED}", "{RESET}", e)
@@ -262,20 +263,26 @@ def chat(self, text: str) -> str:
response["usage"]["completion_tokens"],
)

self.messages.append(res_text)
self.append(res_text)
self._tokens.extend([user_toks, gpt_toks])

if self.tokens >= self._max_tokens:
self._trim()
if self.tokens >= self.max_tokens:
self.trim()

if self._managed:
XSH.builtins.events.on_chat_complete.fire(
inst=self,
convo=self.messages[-2:]
)

return res_text["content"]

def _trim(self) -> None:
"""Trims conversation to make sure it doesn't exceed the max tokens"""
tokens = self.tokens
while tokens > self._max_tokens:
self.messages.pop(0)
tokens -= self._tokens.pop(0)
# def _trim(self) -> None:
# """Trims conversation to make sure it doesn't exceed the max tokens"""
# tokens = self.tokens
# while tokens > self._max_tokens:
# self.messages.pop(0)
# tokens -= self._tokens.pop(0)

def _print_res(self, res: str) -> None:
"""Called after receiving response from ChatGPT, prints the response to the shell"""
@@ -294,7 +301,7 @@ def _format_markdown(self, text: str) -> str:
def _get_json_convo(self, n: int) -> list[dict[str, str]]:
"""Returns the current conversation as a JSON string, up to n last items"""
n = -n if n != 0 else 0
messages = self.base + self.messages if n == 0 else self.messages[n:]
messages = self.full_convo if n == 0 else self.messages[n:]

return json.dumps(messages, indent=4)

@@ -303,7 +310,7 @@ def _get_printed_convo(self, n: int, color: bool = True) -> list[tuple[str, str]]:
user = XSH.env.get("USER", "user")
convo = []
n = -n if n != 0 else 0
messages = self.base + self.messages if n == 0 else self.messages[n:]
messages = self.full_convo if n == 0 else self.messages[n:]

for msg in messages:
if msg["role"] == "user":
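The net effect of the chatgpt.py changes: message bookkeeping (append, pop, trim, the token properties, full_convo) now comes from the new Conversation base class, and ChatGPT itself only orchestrates the API call. A rough sketch of the resulting round trip, assuming the inherited members behave as defined in conversation.py below; the OpenAI call is elided:

# Illustration only, not part of the diff
chat = ChatGPT(alias="gpt")                     # super().__init__(hash(self)) wires up Conversation
chat.append({"role": "user", "content": "hi"})  # inherited: updates messages and the cache together
# ... openai.ChatCompletion.create(...) runs here; on OpenAIError the
# user message is rolled back with chat.pop()
chat.append({"role": "assistant", "content": "hello!"})
if chat.tokens >= chat.max_tokens:              # both inherited from Conversation
    chat.trim()                                 # drop oldest messages until back under the limit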
19 changes: 18 additions & 1 deletion xontrib_chatgpt/chatmanager.py
@@ -12,6 +12,7 @@
from xonsh.lazyasd import LazyObject

from xontrib_chatgpt.chatgpt import ChatGPT
from xontrib_chatgpt.conversation import ConversationCache
from xontrib_chatgpt.lazyobjs import _FIND_NAME_REGEX, _YAML
from xontrib_chatgpt.args import _cm_parse
from xontrib_chatgpt.exceptions import (
@@ -102,6 +103,7 @@ class ChatManager:
def __init__(self):
self._instances: dict[int, dict[str, Union[str, ChatGPT]]] = defaultdict(dict)
self._current: Optional[int] = None
self.cache = ConversationCache()
self._update_inst_dict()

def __call__(self, args: list[str], stdin: TextIO = None):
@@ -241,7 +243,7 @@ def save(
chat = self.get_chat_by_name(chat_name)

try:
res = chat["inst"].save_convo(mode=mode)
res = chat["inst"].save_convo(mode=mode, override=override)
except (NoConversationsError, InvalidConversationsTypeError) as e:
print(e)
sys.exit(1)
@@ -352,6 +354,12 @@ def get_chat_by_name(self, chat_name: str) -> Optional[dict]:
sys.exit(1)

return chat

def get_name_by_chat(self, inst: ChatGPT) -> Optional[str]:
"""Returns the name of a chat instance"""
for chat in self._instances.values():
if chat["inst"] == inst:
return chat["name"]

def _find_saved(self) -> list[Optional[str]]:
"""Returns a list of saved chat files in the default directory"""
@@ -443,6 +451,8 @@ def on_chat_destroy_handler(self, inst: ChatGPT) -> None:
inst_hash = hash(inst)

if inst_hash in self._instances:
name = self._instances[inst_hash]["name"]
self.cache.rm_cache(name)
del self._instances[inst_hash]

if inst_hash == self._current:
@@ -451,6 +461,13 @@ def on_chat_destroy_handler(self, inst: ChatGPT) -> None:
def on_chat_used_handler(self, inst: ChatGPT) -> None:
"""Handler for on_chat_used. Updates the current chat instance."""
self._current = hash(inst)

def on_chat_complete_handler(self, inst: ChatGPT, convo: list[dict[str, str]]) -> None:
"""Handler for on_chat_complete. Updates the internal dict with the new conversation."""
if hash(inst) != self._current:
name = self.get_name_by_chat(inst)
self.cache.current = name
self.cache.extend(convo)

def tutorial(self) -> str:
"""Returns a usage string for the xontrib."""
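Taken together with the new event below, the manager now mirrors completed conversations into a ConversationCache. A hedged walkthrough of the intended flow — note that ConversationCache as committed in conversation.py takes a hash argument and does not yet define current, extend, or rm_cache, so these calls reflect the manager's usage rather than a finished API:

# Illustration only; names come from the diff above
cm = ChatManager()                          # creates cm.cache = ConversationCache()
gpt = ChatGPT(alias="work", managed=True)   # managed chats fire on_chat_complete
gpt.chat("hello")                           # fires the event with convo=self.messages[-2:],
                                            # and cm.on_chat_complete_handler extends cm.cache
del gpt                                     # on_chat_destroy fires; the handler also removes
                                            # the chat's cache via cm.cache.rm_cache(name)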
135 changes: 135 additions & 0 deletions xontrib_chatgpt/conversation.py
@@ -0,0 +1,135 @@
"""Conversation cache for ChatGPT instances"""
import json
from pathlib import Path
from typing import Optional

from xonsh.built_ins import XSH
from xonsh.lazyasd import LazyObject

from xontrib_chatgpt.lazyobjs import _tiktoken

tiktoken = LazyObject("_tiktoken", globals(), _tiktoken)

class ConversationCache:

def __init__(self, hash: int) -> None:
self.cache: list[dict[str, str]] = []
self.path = Path(XSH.env.get("XONSH_DATA_DIR")) / 'chatgpt' / '.cache' / f'{hash}.json'

def __del__(self) -> None:
self.path.unlink(missing_ok=True)

def dump(self) -> None:
with open(self.path, 'w') as f:
json.dump(self.cache, f)
self.cache = []

def load(self) -> list[dict[str, str]]:
if not self.path.exists():
return []

with open(self.path, 'r') as f:
self.cache = json.load(f)
return self.cache


class Conversation(ConversationCache):

def __init__(self, hash: int) -> None:
super().__init__(hash)
self._base: list[dict[str, str]] = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "system",
"content": "If your responses include code, make sure to wrap it in a markdown code block with the appropriate language.\nExample:\n```python\nprint('Hello World!')\n```",
},
]
self.messages: list[dict[str, str]] = []
self.cache = self.base.copy()  # seed the cache with the system prompts without aliasing _base
self.base_tokens: int = 53
self._tokens: list = []
self.max_tokens: int = 3000

@property
def tokens(self) -> int:
"""Current convo tokens"""
return self.base_tokens + sum(self._tokens)

@property
def base(self) -> list[dict[str, str]]:
return self._base

@base.setter
def base(self, msgs: list[dict[str, str]]) -> None:
self.base_tokens = sum(get_token_list(msgs))
self._base = msgs

@property
def full_convo(self) -> list[dict[str, str]]:
return self.cache

def append(self, msg: dict[str, str]) -> None:
self.messages.append(msg)
self.cache.append(msg)

def pop(self) -> None:
self.messages.pop()
self.cache.pop()

# def load_convo(self) -> None:
# convo = self.load()
# tokens = get_token_list(convo)

# while convo[0]['role'] == 'system':
# convo.pop(0)
# tokens.pop(0)

# cur_toks, idx = 0, -1

# while True:
# if cur_toks + tokens[idx] >= self.max_tokens:
# break
# cur_toks += tokens[idx]
# idx -= 1

# self.messages = convo[idx]
# self._tokens = tokens[idx]

# def dump_convo(self) -> None:
# self.dump()
# self.messages = []
# self._tokens = []

def trim(self) -> None:
"""Trims conversation to make sure it doesn't exceed the max tokens"""
tokens = self.tokens
while tokens > self.max_tokens:
self.messages.pop(0)
tokens -= self._tokens.pop(0)



def get_token_list(messages: list[dict[str, str]]) -> list[int]:
"""Gets the chat tokens for the loaded conversation
Parameters
----------
messages : list[dict[str, str]]
Messages from the conversation
Returns
-------
list[int]
List of tokens for each message
"""
tokens_per_message = 3
tokens = [3]

for message in messages:
num_tokens = 0
num_tokens += tokens_per_message
for v in message.values():
num_tokens += len(tiktoken.encode(v))
tokens.append(num_tokens)

return tokens
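A short usage sketch of the new classes, assuming a running xonsh session with XONSH_DATA_DIR set and the chatgpt/.cache directory already present (the committed code does not create it); the hash argument is whatever the owning ChatGPT instance passes via super().__init__(hash(self)):

convo = Conversation(12345)
convo.append({"role": "user", "content": "What is xonsh?"})
convo.append({"role": "assistant", "content": "A Python-powered shell."})
convo.full_convo                # system prompts plus the two appended messages
convo.tokens                    # base_tokens (53) + whatever counts the caller put in _tokens
convo.dump()                    # writes the cache to $XONSH_DATA_DIR/chatgpt/.cache/12345.json, then clears it
restored = convo.load()         # reads the JSON back into convo.cache

get_token_list(convo.messages)  # -> [3, n1, n2]: the reply primer plus one count per message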
10 changes: 10 additions & 0 deletions xontrib_chatgpt/events.py
@@ -30,6 +30,15 @@
of the chat instance to update the current chat for the manager.
""",
),
(
"on_chat_complete",
"""
on_chat_complete(inst: ChatGPT, convo: list[dict[str, str]]) -> None
Fires when a chat round trip completes successfully. Passes the chat
instance and the new messages to be added to the conversation cache.
"""
)
]


Expand All @@ -42,6 +51,7 @@ def add_events(xsh: XonshSession, cm: ChatManager):
events.on_chat_create(lambda *_, **kw: cm.on_chat_create_handler(**kw))
events.on_chat_destroy(lambda *_, **kw: cm.on_chat_destroy_handler(**kw))
events.on_chat_used(lambda *_, **kw: cm.on_chat_used_handler(**kw))
events.on_chat_complete(lambda *_, **kw: cm.on_chat_complete_handler(**kw))


def rm_events(xsh: XonshSession):
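Because these are ordinary xonsh events, user code can subscribe to the new event as well; a minimal sketch, assuming the xontrib is loaded:

from xonsh.built_ins import XSH

@XSH.builtins.events.on_chat_complete
def _log_round_trip(inst=None, convo=None, **_):
    # convo is the last user/assistant message pair, per the docstring above
    print(f"chat {inst!r} completed with {len(convo)} new messages")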
