diff --git a/xontrib/chatgpt.py b/xontrib/chatgpt.py
index 40a0aca..7b13fe4 100644
--- a/xontrib/chatgpt.py
+++ b/xontrib/chatgpt.py
@@ -31,6 +31,8 @@ def _unload_xontrib_(xsh: XonshSession, **_):
     del xsh.aliases["chatgpt"]
     del xsh.aliases["chatgpt?"]
     del xsh.aliases["chat-manager"]
+    del xsh.ctx['ChatGPT']
+    del xsh.ctx['chat_manager']
     rm_events(xsh)
     rm_completers()
diff --git a/xontrib_chatgpt/chatgpt.py b/xontrib_chatgpt/chatgpt.py
index f67267f..6cf2f0f 100644
--- a/xontrib_chatgpt/chatgpt.py
+++ b/xontrib_chatgpt/chatgpt.py
@@ -13,6 +13,7 @@
 from xonsh.ansi_colors import ansi_partial_color_format
 from openai.error import OpenAIError
 
+from xontrib_chatgpt.conversation import Conversation
 from xontrib_chatgpt.args import _gpt_parse
 from xontrib_chatgpt.lazyobjs import (
     _openai,
@@ -92,7 +93,7 @@
 """
 
 
-class ChatGPT(Block):
+class ChatGPT(Conversation, Block):
     """Allows for communication with ChatGPT from the xonsh shell"""
 
     __xonsh_block__ = str
@@ -107,19 +108,19 @@ def __init__(self, alias: str = "", managed: bool = False) -> None:
             Alias to use for the instance. Defaults to ''.
             Will automatically register a xonsh alias under XSH.aliases[alias] if provided.
         """
-
+        super().__init__(hash(self))
         self.alias = alias
-        self._base: list[dict[str, str]] = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "system",
-                "content": "If your responses include code, make sure to wrap it in a markdown code block with the appropriate language.\nExample:\n```python\nprint('Hello World!')\n```",
-            },
-        ]
-        self.messages: list[dict[str, str]] = []
-        self._base_tokens: int = 53
-        self._tokens: list = []
-        self._max_tokens = 3000
+        # self._base: list[dict[str, str]] = [
+        #     {"role": "system", "content": "You are a helpful assistant."},
+        #     {
+        #         "role": "system",
+        #         "content": "If your responses include code, make sure to wrap it in a markdown code block with the appropriate language.\nExample:\n```python\nprint('Hello World!')\n```",
+        #     },
+        # ]
+        # self.messages: list[dict[str, str]] = []
+        # self._base_tokens: int = 53
+        # self._tokens: list = []
+        # self._max_tokens = 3000
         self._managed = managed
 
         if self.alias:
@@ -172,19 +173,19 @@ def __str__(self):
     def __repr__(self):
         return f"ChatGPT(alias={self.alias or None})"
 
-    @property
-    def tokens(self) -> int:
-        """Current convo tokens"""
-        return self._base_tokens + sum(self._tokens)
+    # @property
+    # def tokens(self) -> int:
+    #     """Current convo tokens"""
+    #     return self._base_tokens + sum(self._tokens)
 
-    @property
-    def base(self) -> list[dict[str, str]]:
-        return self._base
+    # @property
+    # def base(self) -> list[dict[str, str]]:
+    #     return self._base
 
-    @base.setter
-    def base(self, msgs: list[dict[str, str]]) -> None:
-        self._base_tokens = sum(get_token_list(msgs))
-        self._base = msgs
+    # @base.setter
+    # def base(self, msgs: list[dict[str, str]]) -> None:
+    #     self._base_tokens = sum(get_token_list(msgs))
+    #     self._base = msgs
 
     def stats(self) -> None:
         """Prints conversation stats to shell"""
@@ -203,7 +204,7 @@ def _stats(self) -> str:
         stats = [
             ("Alias:", f"{self.alias or None}", "{BOLD_GREEN}", "🤖"),
             ("Tokens:", self.tokens, "{BOLD_BLUE}", "🪙"),
-            ("Trim After:", f"{self._max_tokens} Tokens", "{BOLD_BLUE}", "🔪"),
+            ("Trim After:", f"{self.max_tokens} Tokens", "{BOLD_BLUE}", "🔪"),
             ("Messages:", len(self.messages), "{BOLD_BLUE}", "📨"),
         ]
         return stats
@@ -241,7 +242,7 @@ def chat(self, text: str) -> str:
             raise NoApiKeyError()
 
         openai.api_key = api_key
-        self.messages.append({"role": "user", "content": text})
+        self.append({"role": "user", "content": text})
 
         try:
             response = openai.ChatCompletion.create(
@@ -249,7 +250,7 @@ def chat(self, text: str) -> str:
                 messages=self.base + self.messages,
             )
         except OpenAIError as e:
-            self.messages.pop()
+            self.pop()
             sys.exit(
                 ansi_partial_color_format(
                     "{}OpenAI Error{}: {}".format("{BOLD_RED}", "{RESET}", e)
@@ -262,20 +263,26 @@ def chat(self, text: str) -> str:
             response["usage"]["completion_tokens"],
         )
 
-        self.messages.append(res_text)
+        self.append(res_text)
         self._tokens.extend([user_toks, gpt_toks])
 
-        if self.tokens >= self._max_tokens:
-            self._trim()
+        if self.tokens >= self.max_tokens:
+            self.trim()
+
+        if self._managed:
+            XSH.builtins.events.on_chat_complete.fire(
+                inst=self,
+                convo=self.messages[-2:]
+            )
 
         return res_text["content"]
 
-    def _trim(self) -> None:
-        """Trims conversation to make sure it doesn't exceed the max tokens"""
-        tokens = self.tokens
-        while tokens > self._max_tokens:
-            self.messages.pop(0)
-            tokens -= self._tokens.pop(0)
+    # def _trim(self) -> None:
+    #     """Trims conversation to make sure it doesn't exceed the max tokens"""
+    #     tokens = self.tokens
+    #     while tokens > self._max_tokens:
+    #         self.messages.pop(0)
+    #         tokens -= self._tokens.pop(0)
 
     def _print_res(self, res: str) -> None:
         """Called after receiving response from ChatGPT, prints the response to the shell"""
@@ -294,7 +301,7 @@ def _format_markdown(self, text: str) -> str:
 
     def _get_json_convo(self, n: int) -> list[dict[str, str]]:
         """Returns the current conversation as a JSON string, up to n last items"""
         n = -n if n != 0 else 0
-        messages = self.base + self.messages if n == 0 else self.messages[n:]
+        messages = self.full_convo if n == 0 else self.messages[n:]
 
         return json.dumps(messages, indent=4)
@@ -303,7 +310,7 @@ def _get_printed_convo(self, n: int, color: bool = True) -> list[tuple[str, str]
         user = XSH.env.get("USER", "user")
         convo = []
         n = -n if n != 0 else 0
-        messages = self.base + self.messages if n == 0 else self.messages[n:]
+        messages = self.full_convo if n == 0 else self.messages[n:]
 
         for msg in messages:
             if msg["role"] == "user":
diff --git a/xontrib_chatgpt/chatmanager.py b/xontrib_chatgpt/chatmanager.py
index 18f63c1..3a878f8 100644
--- a/xontrib_chatgpt/chatmanager.py
+++ b/xontrib_chatgpt/chatmanager.py
@@ -12,6 +12,7 @@
 from xonsh.lazyasd import LazyObject
 
 from xontrib_chatgpt.chatgpt import ChatGPT
+from xontrib_chatgpt.conversation import ConversationCache
 from xontrib_chatgpt.lazyobjs import _FIND_NAME_REGEX, _YAML
 from xontrib_chatgpt.args import _cm_parse
 from xontrib_chatgpt.exceptions import (
@@ -102,6 +103,7 @@ class ChatManager:
     def __init__(self):
         self._instances: dict[int, dict[str, Union[str, ChatGPT]]] = defaultdict(dict)
         self._current: Optional[int] = None
+        self.cache = ConversationCache()
         self._update_inst_dict()
 
     def __call__(self, args: list[str], stdin: TextIO = None):
@@ -241,7 +243,7 @@ def save(
         chat = self.get_chat_by_name(chat_name)
 
         try:
-            res = chat["inst"].save_convo(mode=mode)
+            res = chat["inst"].save_convo(mode=mode, override=override)
         except (NoConversationsError, InvalidConversationsTypeError) as e:
             print(e)
             sys.exit(1)
@@ -352,6 +354,12 @@ def get_chat_by_name(self, chat_name: str) -> Optional[dict]:
             sys.exit(1)
 
         return chat
+
+    def get_name_by_chat(self, inst: ChatGPT) -> Optional[str]:
+        """Returns the name of a chat instance"""
+        for chat in self._instances.values():
+            if chat["inst"] == inst:
+                return chat["name"]
 
     def _find_saved(self) -> list[Optional[str]]:
         """Returns a list of saved chat files in the default directory"""
@@ -443,6 +451,8 @@ def on_chat_destroy_handler(self, inst: ChatGPT) -> None:
         inst_hash = hash(inst)
 
         if inst_hash in self._instances:
+            name = self._instances[inst_hash]["name"]
+            self.cache.rm_cache(name)
             del self._instances[inst_hash]
 
         if inst_hash == self._current:
@@ -451,6 +461,13 @@ def on_chat_used_handler(self, inst: ChatGPT) -> None:
         """Handler for on_chat_used. Updates the current chat instance."""
         self._current = hash(inst)
+
+    def on_chat_complete_handler(self, inst: ChatGPT, convo: list[dict[str, str]]) -> None:
+        """Handler for on_chat_complete. Updates the cache with the new conversation."""
+        if hash(inst) != self._current:
+            name = self.get_name_by_chat(inst)
+            self.cache.current = name
+        self.cache.extend(convo)
 
     def tutorial(self) -> str:
         """Returns a usage string for the xontrib."""
diff --git a/xontrib_chatgpt/conversation.py b/xontrib_chatgpt/conversation.py
new file mode 100644
index 0000000..27f13df
--- /dev/null
+++ b/xontrib_chatgpt/conversation.py
@@ -0,0 +1,135 @@
+"""Conversation cache for ChatGPT instances"""
+import json
+from pathlib import Path
+from typing import Optional
+
+from xonsh.built_ins import XSH
+from xonsh.lazyasd import LazyObject
+
+from xontrib_chatgpt.lazyobjs import _tiktoken
+
+tiktoken = LazyObject("_tiktoken", globals(), _tiktoken)
+
+
+class ConversationCache:
+
+    def __init__(self, hash: int) -> None:
+        self.cache: list[dict[str, str]] = []
+        # XONSH_DATA_DIR is a plain string, so wrap it in Path before joining
+        self.path = Path(XSH.env.get("XONSH_DATA_DIR")) / 'chatgpt' / '.cache' / f'{hash}.json'
+
+    def __del__(self) -> None:
+        self.path.unlink(missing_ok=True)
+
+    def dump(self) -> None:
+        with open(self.path, 'w') as f:
+            json.dump(self.cache, f)
+        self.cache = []
+
+    def load(self) -> list[dict[str, str]]:
+        if not self.path.exists():
+            return []
+
+        with open(self.path, 'r') as f:
+            self.cache = json.load(f)
+        return self.cache
+
+
+class Conversation(ConversationCache):
+
+    def __init__(self, hash: int) -> None:
+        super().__init__(hash)
+        self._base: list[dict[str, str]] = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "system",
+                "content": "If your responses include code, make sure to wrap it in a markdown code block with the appropriate language.\nExample:\n```python\nprint('Hello World!')\n```",
+            },
+        ]
+        self.messages: list[dict[str, str]] = []
+        # Copy the base messages so appending to the cache doesn't mutate _base
+        self.cache = self.base.copy()
+        self.base_tokens: int = 53
+        self._tokens: list = []
+        self.max_tokens: int = 3000
+
+    @property
+    def tokens(self) -> int:
+        """Current convo tokens"""
+        return self.base_tokens + sum(self._tokens)
+
+    @property
+    def base(self) -> list[dict[str, str]]:
+        return self._base
+
+    @base.setter
+    def base(self, msgs: list[dict[str, str]]) -> None:
+        self.base_tokens = sum(get_token_list(msgs))
+        self._base = msgs
+
+    @property
+    def full_convo(self) -> list[dict[str, str]]:
+        return self.cache
+
+    def append(self, msg: dict[str, str]) -> None:
+        self.messages.append(msg)
+        self.cache.append(msg)
+
+    def pop(self) -> None:
+        self.messages.pop()
+        self.cache.pop()
+
+    # def load_convo(self) -> None:
+    #     convo = self.load()
+    #     tokens = get_token_list(convo)
+
+    #     while convo[0]['role'] == 'system':
+    #         convo.pop(0)
+    #         tokens.pop(0)
+
+    #     cur_toks, idx = 0, -1
+
+    #     while True:
+    #         if cur_toks + tokens[idx] >= self.max_tokens:
+    #             break
+    #         cur_toks += tokens[idx]
+    #         idx -= 1
+
+    #     self.messages = convo[idx]
+    #     self._tokens = tokens[idx]
+
+    # def dump_convo(self) -> None:
+    #     self.dump()
+    #     self.messages = []
+    #     self._tokens = []
+
+    def trim(self) -> None:
+        """Trims conversation to make sure it doesn't exceed the max tokens"""
+        tokens = self.tokens
+        while tokens > self.max_tokens:
+            self.messages.pop(0)
+            tokens -= self._tokens.pop(0)
+
+
+def get_token_list(messages: list[dict[str, str]]) -> list[int]:
+    """Gets the chat tokens for the loaded conversation
+
+    Parameters
+    ----------
+    messages : list[dict[str, str]]
+        Messages from the conversation
+
+    Returns
+    -------
+    list[int]
+        List of tokens for each message
+    """
+    tokens_per_message = 3
+    # Start at 3 to account for the tokens that prime the assistant's reply
+    tokens = [3]
+
+    for message in messages:
+        num_tokens = 0
+        num_tokens += tokens_per_message
+        for v in message.values():
+            num_tokens += len(tiktoken.encode(v))
+        tokens.append(num_tokens)
+
+    return tokens
\ No newline at end of file
diff --git a/xontrib_chatgpt/events.py b/xontrib_chatgpt/events.py
index 1c6d571..5277d82 100644
--- a/xontrib_chatgpt/events.py
+++ b/xontrib_chatgpt/events.py
@@ -30,6 +30,15 @@
     of the chat instance to update the current chat for the manager.
     """,
     ),
+    (
+        "on_chat_complete",
+        """
+    on_chat_complete(inst: ChatGPT, convo: list[dict[str, str]]) -> None
+
+    Fires when a chat completion succeeds. Passes the chat instance and the
+    conversation to be added to the cache.
+    """,
+    ),
 ]
 
 
@@ -42,6 +51,7 @@ def add_events(xsh: XonshSession, cm: ChatManager):
     events.on_chat_create(lambda *_, **kw: cm.on_chat_create_handler(**kw))
     events.on_chat_destroy(lambda *_, **kw: cm.on_chat_destroy_handler(**kw))
     events.on_chat_used(lambda *_, **kw: cm.on_chat_used_handler(**kw))
+    events.on_chat_complete(lambda *_, **kw: cm.on_chat_complete_handler(**kw))
 
 
 def rm_events(xsh: XonshSession):