Skip to content

Commit

Permalink
Merge pull request #14 from jpal91/save2
Browse files Browse the repository at this point in the history
Merge Save2 to Main
  • Loading branch information
jpal91 authored Oct 13, 2023
2 parents 60b64d4 + 3f3fd12 commit a502210
Show file tree
Hide file tree
Showing 9 changed files with 396 additions and 300 deletions.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,9 @@ chat-manager?

**BONUS**: If [xontrib-abbrevs](https://github.com/xonsh/xontrib-abbrevs) is loaded, use `cm` to expand to `chat-manager`

#### Version 0.1.6 Notes
- Added the ability to edit ChatGPT system messages
- Saved convos now include system messages so conversation can continue with same instructions when loaded
- See [Edit System Messages](docs/edit_sys_messages.md) for more details
#### See Also
- [Tips and Tricks](/docs/tips_and_tricks.md)
- [How To Edit System Messages](/docs/edit_sys_messages.md)

## Future Plans
- **Streaming Responses**
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

[project]
name = "xontrib-chatgpt"
version = "0.1.6"
version = "0.2.0"
license = {file = "LICENSE"}
description = "Gives the ability to use ChatGPT directly from the command line"
classifiers = [
Expand Down
25 changes: 24 additions & 1 deletion tests/xontrib_chatgpt/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import shutil
import pytest
from xontrib_chatgpt.events import chat_events

Expand All @@ -6,4 +7,26 @@ def cm_events(xession):
events = xession.builtins.events
for c in chat_events:
events.doc(*c)
yield events
yield events

@pytest.fixture(scope="module")
def temp_home(tmpdir_factory):
home = tmpdir_factory.mktemp("home")
home.mkdir("expected")
home.mkdir("saved")
data_dir = home.mkdir("data_dir")
data_dir.mkdir("chatgpt")
fixtures = [
"color_convo.txt",
"no_color_convo.txt",
"no_color_convo2.txt",
"convo.json",
"convo2.json",
"long_convo.txt",
]
for f in fixtures:
shutil.copy(f"tests/fixtures/{f}", f"{home}/expected/{f}")
shutil.copy(
f"tests/fixtures/no_color_convo.txt", f"{data_dir}/chatgpt/no_color_convo.txt"
)
yield home
81 changes: 34 additions & 47 deletions tests/xontrib_chatgpt/test_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,28 @@ def raise_it(*_, **__):
with pytest.raises(SystemExit):
chat.chat("test")

def test_chat_convo(xession, chat):
assert chat.chat_convo == chat.base
chat.messages =[
{"role": "user", "content": "test"},
{"role": "assistant", "content": "test"},
]
chat.chat_idx -= 2
assert chat.chat_convo == chat.base + chat.messages

def test_chat_response(xession, monkeypatch_openai, chat):
xession.env["OPENAI_API_KEY"] = "test"
assert chat.chat_idx == 0
chat.chat("test") == "test"
assert chat.messages == [
{"role": "user", "content": "test"},
{"role": "assistant", "content": "test"},
]
assert chat._tokens == [1, 1]
assert chat.tokens == 55
assert chat.chat_idx == -2


@pytest.mark.skip()
def test_trim(xession, chat):
chat._tokens = [1000, 1000, 900]
chat.messages = ["test", "test", "test"]
Expand All @@ -153,18 +163,25 @@ def test_trim(xession, chat):
assert len(chat._tokens) == 3
assert len(chat.messages) == 3

def test_trim_convo(xession, chat):
toks = chat._tokens = [1000, 1000, 900]
idx = chat.chat_idx = -3
chat.trim_convo()
assert chat._tokens == toks
assert chat.chat_idx == idx
chat._tokens.append(1000)
chat.chat_idx -= 1
chat.trim_convo()
assert chat._tokens == toks
assert chat.chat_idx == idx


def test_set_base_msgs(xession, chat):
assert chat._base_tokens == 53
chat.base = [{'role': 'system', 'content': 'test'}]
assert chat._base_tokens == 8


def test__format_markdown(xession, chat):
md = chat._format_markdown(MARKDOWN_BLOCK)
assert "\x1b" in md
assert "```" not in md


def test__get_json_convo(xession, chat):
chat.messages.append({"role": "user", "content": "test"})
res = chat._get_json_convo(n=1)
Expand Down Expand Up @@ -381,54 +398,24 @@ def test_loads_from_convo_raises_file_not_found(xession, temp_home):
with pytest.raises(FileNotFoundError):
ChatGPT.fromconvo("invalid.txt")


# parse_convo


def test_parses_json(xession, temp_home):
json_path = temp_home / "expected" / "convo.json"
with open(json_path) as f:
exp_json = f.read()

assert parse_convo(exp_json) == json.loads(exp_json)


def test_parses_text(xession, temp_home):
text_path = temp_home / "expected" / "long_convo.txt"
with open(text_path) as f:
exp_text = f.read()

msgs, base = parse_convo(exp_text)
assert len(msgs) == 6
assert len(base) == 1

for r in msgs:
assert r["role"] in ["user", "assistant"]
assert r["content"] != ""

assert base[0] == {'role': 'system', 'content': 'This is a test.\n'}


# get_token_list


def test_get_token_list(xession, temp_home):
json_path = temp_home / "expected" / "convo2.json"
with open(json_path) as f:
exp_json = json.load(f)
res = get_token_list(exp_json)
assert len(res) == 7
assert sum(res) == 835
def test_loads_and_trims(xession, temp_home, monkeypatch):
chat_file = temp_home / 'expected' / 'convo2.json'
def trim_convo(self):
while self.tokens > 650:
self.chat_idx += 1
monkeypatch.setattr('xontrib_chatgpt.chatgpt.ChatGPT.trim_convo', trim_convo)
new_cls = ChatGPT.fromconvo(chat_file)
assert new_cls.chat_idx == -1


@pytest.fixture
def inc_test(xession):
xession.ctx["test"] = 0

def inc_test(**kw):
def _inc_test(**_):
xession.ctx["test"] += 1

return inc_test
return _inc_test


def test_on_chat_create_handler(xession, cm_events, inc_test):
Expand Down
38 changes: 0 additions & 38 deletions tests/xontrib_chatgpt/test_chatmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,6 @@ def cm():
return ChatManager()


@pytest.fixture
def sys_msgs():
l = """[
{'role': 'system', 'content': 'Hello'},
{'role': 'system', 'content': 'Hi there!'},
]
"""

d = '{"content": "Hello"}'

y = dedent(
"""
- role: system
content: Hello
- role: system
content: Hi there!
"""
)

return l, d, y


def test_update_inst_dict(xession, cm):
insts = [
("test", ChatGPT("test")),
Expand Down Expand Up @@ -248,22 +226,6 @@ def test_cli(xession, cm, action, args, expected, monkeypatch):
assert getattr(cm, f"_{action}") == expected


def test_convert_to_sys(xession, sys_msgs):
l, d, y = sys_msgs
res = convert_to_sys(l)
assert res == [
{"role": "system", "content": "Hello"},
{"role": "system", "content": "Hi there!"},
]
res = convert_to_sys(d)
assert res == [{"role": "system", "content": "Hello"}]
res = convert_to_sys(y)
assert res == [
{"role": "system", "content": "Hello"},
{"role": "system", "content": "Hi there!"},
]


def test_edit(xession, cm, cm_events):
cm_events.on_chat_create(lambda *args, **kw: cm.on_chat_create_handler(*args, **kw))
cm.add("test")
Expand Down
89 changes: 89 additions & 0 deletions tests/xontrib_chatgpt/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import json
import pytest
from textwrap import dedent

from xontrib_chatgpt.utils import (
parse_convo,
get_token_list,
format_markdown,
convert_to_sys,
)

@pytest.fixture
def sys_msgs():
l = """[
{'role': 'system', 'content': 'Hello'},
{'role': 'system', 'content': 'Hi there!'},
]
"""

d = '{"content": "Hello"}'

y = dedent(
"""
- role: system
content: Hello
- role: system
content: Hi there!
"""
)

return l, d, y

MARKDOWN_BLOCK = """\
Hello!
```python
print('Hello World!')
```
"""

def test_format_markdown(xession):
md = format_markdown(MARKDOWN_BLOCK)
assert "\x1b" in md
assert "```" not in md

def test_parses_json(xession, temp_home):
json_path = temp_home / "expected" / "convo.json"
with open(json_path) as f:
exp_json = f.read()
msg, base = parse_convo(exp_json)
assert base + msg == json.loads(exp_json)


def test_parses_text(xession, temp_home):
text_path = temp_home / "expected" / "long_convo.txt"
with open(text_path) as f:
exp_text = f.read()

msgs, base = parse_convo(exp_text)
assert len(msgs) == 6
assert len(base) == 1

for r in msgs:
assert r["role"] in ["user", "assistant"]
assert r["content"] != ""

assert base[0] == {'role': 'system', 'content': 'This is a test.\n'}

def test_get_token_list(xession, temp_home):
json_path = temp_home / "expected" / "convo2.json"
with open(json_path) as f:
exp_json = json.load(f)
res = get_token_list(exp_json)
assert len(res) == 7
assert sum(res) == 835

def test_convert_to_sys(xession, sys_msgs):
l, d, y = sys_msgs
res = convert_to_sys(l)
assert res == [
{"role": "system", "content": "Hello"},
{"role": "system", "content": "Hi there!"},
]
res = convert_to_sys(d)
assert res == [{"role": "system", "content": "Hello"}]
res = convert_to_sys(y)
assert res == [
{"role": "system", "content": "Hello"},
{"role": "system", "content": "Hi there!"},
]
Loading

0 comments on commit a502210

Please sign in to comment.