-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
210 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
"""Conversation cache for ChatGPT instances""" | ||
import json | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
from xonsh.built_ins import XSH | ||
from xonsh.lazyasd import LazyObject | ||
|
||
from xontrib_chatgpt.lazyobjs import _tiktoken | ||
|
||
# LazyObject takes (load, ctx, name): `load` is the zero-argument callable that
# builds the real object on first attribute access, and `name` is the key under
# which the loaded object replaces this proxy inside `ctx`.
# NOTE(review): `_tiktoken` presumably returns a tiktoken encoding — confirm in
# xontrib_chatgpt.lazyobjs.
tiktoken = LazyObject(_tiktoken, globals(), "tiktoken")
|
||
class ConversationCache:
    """JSON file cache for the messages of a single conversation.

    The cache file lives at ``$XONSH_DATA_DIR/chatgpt/.cache/<hash>.json`` and
    is removed when the instance is garbage collected.
    """

    def __init__(self, hash: int) -> None:
        """Set up an empty in-memory cache and compute the cache file path.

        Parameters
        ----------
        hash : int
            Identifier used as the cache file's base name
        """
        self.cache: list[dict[str, str]] = []
        # Convert to Path *before* joining: the env value may be a plain
        # string, and ``str / str`` raises TypeError.
        data_dir = Path(XSH.env.get("XONSH_DATA_DIR"))
        self.path = data_dir / 'chatgpt' / '.cache' / f'{hash}.json'

    def __del__(self) -> None:
        """Delete the cache file, if any, when the instance is destroyed."""
        # __init__ may have failed before ``path`` was assigned.
        path = getattr(self, 'path', None)
        if path is not None:
            path.unlink(missing_ok=True)

    def dump(self) -> None:
        """Write the in-memory cache to disk as JSON, then clear it."""
        # The cache directory is not guaranteed to exist on first use.
        self.path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.path, 'w') as f:
            json.dump(self.cache, f)
        self.cache = []

    def load(self) -> list[dict[str, str]]:
        """Read the cache file back into memory.

        Returns
        -------
        list[dict[str, str]]
            The loaded messages, or an empty list when no cache file exists
            (in which case the in-memory cache is left untouched)
        """
        if not self.path.exists():
            return []

        with open(self.path, 'r') as f:
            self.cache = json.load(f)
        return self.cache
|
||
|
||
class Conversation(ConversationCache):
    """Conversation state layered on top of ConversationCache.

    Keeps three lists: ``_base`` (the system prompts), ``messages`` (the
    messages added through ``append``) and the inherited ``cache`` (base
    prompts followed by every appended message).
    """

    def __init__(self, hash: int) -> None:
        """Create a conversation seeded with the default system prompts.

        Parameters
        ----------
        hash : int
            Identifier forwarded to ConversationCache for the cache file name
        """
        super().__init__(hash)
        self._base: list[dict[str, str]] = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "system",
                "content": "If your responses include code, make sure to wrap it in a markdown code block with the appropriate language.\nExample:\n```python\nprint('Hello World!')\n```",
            },
        ]
        self.messages: list[dict[str, str]] = []
        # Copy rather than alias: appending to the cache must not grow the
        # base prompt list.
        self.cache = list(self._base)
        # NOTE(review): presumably the precomputed token count of the default
        # base messages above — confirm.
        self.base_tokens: int = 53
        self._tokens: list[int] = []
        self.max_tokens: int = 3000

    @property
    def tokens(self) -> int:
        """Current convo tokens"""
        return self.base_tokens + sum(self._tokens)

    @property
    def base(self) -> list[dict[str, str]]:
        """Base (system) messages of the conversation"""
        return self._base

    @base.setter
    def base(self, msgs: list[dict[str, str]]) -> None:
        """Replace the base messages and recompute their token total"""
        self.base_tokens = sum(get_token_list(msgs))
        self._base = msgs

    @property
    def full_convo(self) -> list[dict[str, str]]:
        """Everything currently held in the cache list"""
        return self.cache

    def append(self, msg: dict[str, str]) -> None:
        """Add a message to both the message list and the cache"""
        self.messages.append(msg)
        self.cache.append(msg)

    def pop(self) -> None:
        """Remove the most recent entry from both the message list and the cache

        Raises
        ------
        IndexError
            If the message list is empty
        """
        self.messages.pop()
        self.cache.pop()

    # TODO: disabled draft of cache-backed load/dump, kept for reference.
    # NOTE(review): `convo[idx]` / `tokens[idx]` below look like they were
    # meant to be slices (`[idx:]`) — confirm before enabling.
    #
    # def load_convo(self) -> None:
    #     convo = self.load()
    #     tokens = get_token_list(convo)
    #
    #     while convo[0]['role'] == 'system':
    #         convo.pop(0)
    #         tokens.pop(0)
    #
    #     cur_toks, idx = 0, -1
    #
    #     while True:
    #         if cur_toks + tokens[idx] >= self.max_tokens:
    #             break
    #         cur_toks += tokens[idx]
    #         idx -= 1
    #
    #     self.messages = convo[idx]
    #     self._tokens = tokens[idx]
    #
    # def dump_convo(self) -> None:
    #     self.dump()
    #     self.messages = []
    #     self._tokens = []

    def trim(self) -> None:
        """Trims conversation to make sure it doesn't exceed the max tokens"""
        tokens = self.tokens
        # Stop once nothing is left to drop; otherwise an over-budget base
        # prompt would make the pops raise IndexError.
        while tokens > self.max_tokens and self.messages and self._tokens:
            self.messages.pop(0)
            tokens -= self._tokens.pop(0)
|
||
|
||
|
||
def get_token_list(messages: list[dict[str, str]]) -> list[int]:
    """Count the chat tokens of each message in a conversation

    Parameters
    ----------
    messages : list[dict[str, str]]
        Messages from the conversation

    Returns
    -------
    list[int]
        A leading constant 3 followed by one token count per message, in
        the same order as ``messages``
    """
    per_message_overhead = 3

    # Each message costs a fixed overhead plus the encoded length of every
    # value it holds.
    counts = [
        per_message_overhead
        + sum(len(tiktoken.encode(text)) for text in msg.values())
        for msg in messages
    ]

    return [3, *counts]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters