diff --git a/chattool/__init__.py b/chattool/__init__.py index 5cd9732..a654eff 100644 --- a/chattool/__init__.py +++ b/chattool/__init__.py @@ -87,6 +87,13 @@ def save_envs(env_file:str): elif platform.startswith("darwin"): platform = "macos" +# is jupyter notebook +try: + get_ipython + is_jupyter = True +except: + is_jupyter = False + def default_prompt(msg:str): """Default prompt message for the API call diff --git a/chattool/chattype.py b/chattool/chattype.py index e9e4c25..897d184 100644 --- a/chattool/chattype.py +++ b/chattool/chattype.py @@ -236,7 +236,7 @@ def getresponse( self max_tries = max(max_tries, max_requests) if options.get('stream'): options['stream'] = False - warnings.warn("Use `async_stream_responses()` instead.") + warnings.warn("Use `stream_responses` instead.") options = self._init_options(**options) # make requests api_key, chat_log, chat_url = self.api_key, self.chat_log, self.chat_url @@ -259,29 +259,38 @@ async def async_stream_responses( self Returns: str: response text + + Examples: + >>> chat = Chat("Hello") + >>> # in Jupyter notebook + >>> async for resp in chat.async_stream_responses(): + >>> print(resp) """ async for resp in _async_stream_responses( self.api_key, self.chat_url, self.chat_log, self.model, timeout=timeout, **options): yield resp.delta_content if textonly else resp - def stream_responses(self, timeout:int=0, textonly:bool=False, **options): + def stream_responses(self, timeout:int=0, textonly:bool=True, **options): """Post request synchronously and stream the responses Args: timeout (int, optional): timeout for the API call. Defaults to 0(no timeout). - textonly (bool, optional): whether to only return the text. Defaults to False. + textonly (bool, optional): whether to only return the text. Defaults to True. options (dict, optional): other options like `temperature`, `top_p`, etc. Returns: str: response text + + Examples: + >>> chat = Chat("Hello") + >>> for resp in chat.stream_responses(): + >>> print(resp) """ - + assert not chattool.is_jupyter, "use `await chat.async_stream_responses()` in Jupyter notebook" async_gen = self.async_stream_responses(timeout=timeout, textonly=textonly, **options) loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - try: - # Synchronously get responses from the async generator while True: try: # Run the async generator to get each response diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..511346d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +import pytest +from chattool import * + +TEST_PATH = 'tests/testfiles/' + +@pytest.fixture(scope="session") +def testpath(): + return TEST_PATH \ No newline at end of file diff --git a/tests/test_async.py b/tests/test_async.py index 210cd0a..5d80a2f 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -8,7 +8,6 @@ chatlogs = [ [{"role": "user", "content": f"Print hello using {lang}"}] for lang in langs ] -testpath = 'tests/testfiles/' def test_simple(): # set api_key in the environment variable @@ -32,6 +31,8 @@ async def show_resp(chat): async for resp in chat.async_stream_responses(): print(resp.delta_content, end='') asyncio.run(show_resp(chat)) + for resp in chat.stream_responses(): + print(resp, end='') def test_async_typewriter(): def typewriter_effect(text, delay): @@ -51,7 +52,7 @@ async def show_resp(chat): chat = Chat("Print hello using Python") asyncio.run(show_resp(chat)) -def test_async_process(): +def test_async_process(testpath): chkpoint = testpath + "test_async.jsonl" t = time.time() async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, nproc=3) @@ -59,7 +60,7 @@ def test_async_process(): print(f"Time elapsed: {time.time() - t:.2f}s") # broken test -def test_failed_async(): +def test_failed_async(testpath): api_key = chattool.api_key chattool.api_key = "sk-invalid" chkpoint = testpath + "test_async_fail.jsonl" @@ -67,7 +68,7 @@ def test_failed_async(): resp = async_chat_completion(words, chkpoint, clearfile=True, nproc=3) chattool.api_key = api_key -def test_async_process_withfunc(): +def test_async_process_withfunc(testpath): chkpoint = testpath + "test_async_withfunc.jsonl" words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"] def msg2log(msg): @@ -77,7 +78,7 @@ def msg2log(msg): return chat.chat_log async_chat_completion(words, chkpoint, clearfile=True, nproc=3, msg2log=msg2log) -def test_normal_process(): +def test_normal_process(testpath): chkpoint = testpath + "test_nomal.jsonl" def data2chat(data): chat = Chat(data) diff --git a/tests/test_chattool.py b/tests/test_chattool.py index 5520492..ae1739c 100644 --- a/tests/test_chattool.py +++ b/tests/test_chattool.py @@ -7,7 +7,6 @@ from chattool import cli from chattool import Chat, Resp, findcost import pytest -testpath = 'tests/testfiles/' def test_command_line_interface(): @@ -21,7 +20,7 @@ def test_command_line_interface(): assert '--help Show this message and exit.' in help_result.output # test for the chat class -def test_chat(): +def test_chat(testpath): # initialize chat = Chat() assert chat.chat_log == [] diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index cd81821..ff737b0 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -1,8 +1,7 @@ import os, responses from chattool import Chat, load_chats, process_chats, api_key -testpath = 'tests/testfiles/' -def test_with_checkpoint(): +def test_with_checkpoint(testpath): # save chats without chatid chat = Chat() checkpath = testpath + "tmp.jsonl" @@ -38,7 +37,7 @@ def test_with_checkpoint(): ] assert chats == [Chat(log) if log is not None else None for log in chat_logs] -def test_process_chats(): +def test_process_chats(testpath): def msg2chat(msg): chat = Chat() chat.system("You are a helpful translator for numbers.") diff --git a/tests/test_envs.py b/tests/test_envs.py index 8a5a290..30a4dc2 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -2,7 +2,6 @@ import chattool from chattool import Chat, save_envs, load_envs -testpath = 'tests/testfiles/' def test_model_api_key(): api_key, model = chattool.api_key, chattool.model @@ -43,7 +42,7 @@ def test_apibase(): chattool.api_base, chattool.base_url = api_base, base_url -def test_env_file(): +def test_env_file(testpath): save_envs(testpath + "chattool.env") with open(testpath + "test.env", "w") as f: f.write("OPENAI_API_KEY=sk-132\n") diff --git a/tests/test_request.py b/tests/test_request.py index 7a14bbb..54a4a89 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -7,7 +7,6 @@ ) import pytest, chattool, os api_key, base_url, api_base = chattool.api_key, chattool.base_url, chattool.api_base -testpath = 'tests/testfiles/' def test_valid_models(): if chattool.api_base: @@ -40,7 +39,7 @@ def test_normalize_url(): assert normalize_url("api.openai.com") == "https://api.openai.com" assert normalize_url("example.com/foo/bar") == "https://example.com/foo/bar" -def test_broken_requests(): +def test_broken_requests(testpath): """Test the broken requests""" with open(testpath + "test.txt", "w") as f: f.write("hello world")