Skip to content

Commit

Permalink
update async method & add conftest
Browse files Browse the repository at this point in the history
  • Loading branch information
RexWzh committed Sep 29, 2024
1 parent 7ebd718 commit 797f7b8
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 20 deletions.
7 changes: 7 additions & 0 deletions chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def save_envs(env_file:str):
elif platform.startswith("darwin"):
platform = "macos"

# is jupyter notebook
try:
get_ipython
is_jupyter = True
except:
is_jupyter = False

def default_prompt(msg:str):
"""Default prompt message for the API call
Expand Down
21 changes: 15 additions & 6 deletions chattool/chattype.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def getresponse( self
max_tries = max(max_tries, max_requests)
if options.get('stream'):
options['stream'] = False
warnings.warn("Use `async_stream_responses()` instead.")
warnings.warn("Use `stream_responses` instead.")
options = self._init_options(**options)
# make requests
api_key, chat_log, chat_url = self.api_key, self.chat_log, self.chat_url
Expand All @@ -259,29 +259,38 @@ async def async_stream_responses( self
Returns:
str: response text
Examples:
>>> chat = Chat("Hello")
>>> # in Jupyter notebook
>>> async for resp in chat.async_stream_responses():
>>> print(resp)
"""
async for resp in _async_stream_responses(
self.api_key, self.chat_url, self.chat_log, self.model, timeout=timeout, **options):
yield resp.delta_content if textonly else resp

def stream_responses(self, timeout:int=0, textonly:bool=False, **options):
def stream_responses(self, timeout:int=0, textonly:bool=True, **options):
"""Post request synchronously and stream the responses
Args:
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
textonly (bool, optional): whether to only return the text. Defaults to False.
textonly (bool, optional): whether to only return the text. Defaults to True.
options (dict, optional): other options like `temperature`, `top_p`, etc.
Returns:
str: response text
Examples:
>>> chat = Chat("Hello")
>>> for resp in chat.stream_responses():
>>> print(resp)
"""

assert not chattool.is_jupyter, "use `await chat.async_stream_responses()` in Jupyter notebook"
async_gen = self.async_stream_responses(timeout=timeout, textonly=textonly, **options)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

try:
# Synchronously get responses from the async generator
while True:
try:
# Run the async generator to get each response
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
from chattool import *

TEST_PATH = 'tests/testfiles/'

@pytest.fixture(scope="session")
def testpath():
return TEST_PATH
11 changes: 6 additions & 5 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
chatlogs = [
[{"role": "user", "content": f"Print hello using {lang}"}] for lang in langs
]
testpath = 'tests/testfiles/'

def test_simple():
# set api_key in the environment variable
Expand All @@ -32,6 +31,8 @@ async def show_resp(chat):
async for resp in chat.async_stream_responses():
print(resp.delta_content, end='')
asyncio.run(show_resp(chat))
for resp in chat.stream_responses():
print(resp, end='')

def test_async_typewriter():
def typewriter_effect(text, delay):
Expand All @@ -51,23 +52,23 @@ async def show_resp(chat):
chat = Chat("Print hello using Python")
asyncio.run(show_resp(chat))

def test_async_process():
def test_async_process(testpath):
chkpoint = testpath + "test_async.jsonl"
t = time.time()
async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, nproc=3)
async_chat_completion(chatlogs, chkpoint, nproc=3)
print(f"Time elapsed: {time.time() - t:.2f}s")

# broken test
def test_failed_async():
def test_failed_async(testpath):
api_key = chattool.api_key
chattool.api_key = "sk-invalid"
chkpoint = testpath + "test_async_fail.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
resp = async_chat_completion(words, chkpoint, clearfile=True, nproc=3)
chattool.api_key = api_key

def test_async_process_withfunc():
def test_async_process_withfunc(testpath):
chkpoint = testpath + "test_async_withfunc.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
def msg2log(msg):
Expand All @@ -77,7 +78,7 @@ def msg2log(msg):
return chat.chat_log
async_chat_completion(words, chkpoint, clearfile=True, nproc=3, msg2log=msg2log)

def test_normal_process():
def test_normal_process(testpath):
chkpoint = testpath + "test_nomal.jsonl"
def data2chat(data):
chat = Chat(data)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_chattool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from chattool import cli
from chattool import Chat, Resp, findcost
import pytest
testpath = 'tests/testfiles/'


def test_command_line_interface():
Expand All @@ -21,7 +20,7 @@ def test_command_line_interface():
assert '--help Show this message and exit.' in help_result.output

# test for the chat class
def test_chat():
def test_chat(testpath):
# initialize
chat = Chat()
assert chat.chat_log == []
Expand Down
5 changes: 2 additions & 3 deletions tests/test_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os, responses
from chattool import Chat, load_chats, process_chats, api_key
testpath = 'tests/testfiles/'

def test_with_checkpoint():
def test_with_checkpoint(testpath):
# save chats without chatid
chat = Chat()
checkpath = testpath + "tmp.jsonl"
Expand Down Expand Up @@ -38,7 +37,7 @@ def test_with_checkpoint():
]
assert chats == [Chat(log) if log is not None else None for log in chat_logs]

def test_process_chats():
def test_process_chats(testpath):
def msg2chat(msg):
chat = Chat()
chat.system("You are a helpful translator for numbers.")
Expand Down
3 changes: 1 addition & 2 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import chattool
from chattool import Chat, save_envs, load_envs
testpath = 'tests/testfiles/'

def test_model_api_key():
api_key, model = chattool.api_key, chattool.model
Expand Down Expand Up @@ -43,7 +42,7 @@ def test_apibase():

chattool.api_base, chattool.base_url = api_base, base_url

def test_env_file():
def test_env_file(testpath):
save_envs(testpath + "chattool.env")
with open(testpath + "test.env", "w") as f:
f.write("OPENAI_API_KEY=sk-132\n")
Expand Down
3 changes: 1 addition & 2 deletions tests/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
)
import pytest, chattool, os
api_key, base_url, api_base = chattool.api_key, chattool.base_url, chattool.api_base
testpath = 'tests/testfiles/'

def test_valid_models():
if chattool.api_base:
Expand Down Expand Up @@ -40,7 +39,7 @@ def test_normalize_url():
assert normalize_url("api.openai.com") == "https://api.openai.com"
assert normalize_url("example.com/foo/bar") == "https://example.com/foo/bar"

def test_broken_requests():
def test_broken_requests(testpath):
"""Test the broken requests"""
with open(testpath + "test.txt", "w") as f:
f.write("hello world")
Expand Down

0 comments on commit 797f7b8

Please sign in to comment.