From 72ccc75c4b7a760a5f975aa93d069ea0991af72e Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 1 May 2024 13:39:44 -0700 Subject: [PATCH 1/7] add aoai assistants streaming tests --- .../azure-openai/tests/test_assistants.py | 228 +++++++++++++++++ .../tests/test_assistants_async.py | 230 ++++++++++++++++++ 2 files changed, 458 insertions(+) diff --git a/sdk/openai/azure-openai/tests/test_assistants.py b/sdk/openai/azure-openai/tests/test_assistants.py index 09149351add2..618b87ce7e92 100644 --- a/sdk/openai/azure-openai/tests/test_assistants.py +++ b/sdk/openai/azure-openai/tests/test_assistants.py @@ -10,9 +10,177 @@ import uuid from devtools_testutils import AzureRecordedTestCase from conftest import ASST_AZURE, ASST_AZUREAD, PREVIEW, GPT_4_OPENAI, configure +from openai import AssistantEventHandler +from openai.types.beta.threads import ( + Text, + Message, + ImageFile, + TextDelta, + MessageDelta, +) +from openai.types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta TIMEOUT = 300 + +class EventHandler(AssistantEventHandler): + def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: + if delta.value: + assert delta.value is not None + if delta.annotations: + for annotation in delta.annotations: + if annotation.type == "file_citation": + assert annotation.index is not None + assert annotation.file_citation.file_id + assert annotation.file_citation.quote + elif annotation.type == "file_path": + assert annotation.index is not None + assert annotation.file_path.file_id + + def on_run_step_done(self, run_step: RunStep) -> None: + details = run_step.step_details + if details.type == "tool_calls": + for tool in details.tool_calls: + if tool.type == "code_interpreter": + assert tool.id + assert tool.code_interpreter.outputs + assert tool.code_interpreter.input is not None + elif tool.type == "function": + assert tool.id + assert tool.function.arguments is not None + assert tool.function.name is not None + + def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None: + details = delta.step_details + if details is not None: + if details.type == "tool_calls": + for tool in details.tool_calls or []: + if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input: + assert tool.index is not None + assert tool.code_interpreter.input is not None + elif details.type == "message_creation": + assert details.message_creation.message_id + + def on_run_step_created(self, run_step: RunStep): + assert run_step.object == "thread.run.step" + assert run_step.id + assert run_step.type + assert run_step.created_at + assert run_step.assistant_id + assert run_step.thread_id + assert run_step.run_id + assert run_step.status + assert run_step.step_details + + def on_message_created(self, message: Message): + assert message.object == "thread.message" + assert message.id + assert message.created_at + assert message.file_ids is not None + assert message.status + assert message.thread_id + + def on_message_delta(self, delta: MessageDelta, snapshot: Message): + if delta.content: + for content in delta.content: + if content.type == "text": + assert content.index is not None + if content.text: + if content.text.value: + assert content.text.value is not None + if content.text.annotations: + for annot in content.text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + elif content.type == "image_file": + assert content.index is not None + assert content.image_file.file_id + + + def on_message_done(self, message: Message): + for msg in message.content: + if msg.type == "image_file": + assert msg.image_file.file_id + if msg.type == "text": + assert msg.text.value + if msg.text.annotations: + for annot in msg.text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + assert annot.text is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + assert annot.text is not None + + def on_text_created(self, text: Text): + assert text.value is not None + + def on_text_done(self, text: Text): + assert text.value is not None + for annot in text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + assert annot.text is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + assert annot.text is not None + + def on_image_file_done(self, image_file: ImageFile): + assert image_file.file_id + + def on_tool_call_created(self, tool_call: ToolCall): + assert tool_call.id + + def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall): + if delta.type == "code_interpreter": + assert delta.index is not None + if delta.code_interpreter: + if delta.code_interpreter.input: + assert delta.code_interpreter.input is not None + if delta.code_interpreter.outputs: + for output in delta.code_interpreter.outputs: + if output.type == "image": + assert output.image.file_id + elif output.type == "logs": + assert output.logs + if delta.type == "function": + assert delta.id + if delta.function: + assert delta.function.arguments is not None + assert delta.function.name is not None + + def on_tool_call_done(self, tool_call: ToolCall): + if tool_call.type == "code_interpreter": + assert tool_call.id + assert tool_call.code_interpreter.input is not None + for output in tool_call.code_interpreter.outputs: + if output.type == "image": + assert output.image.file_id + elif output.type == "logs": + assert output.logs + if tool_call.type == "function": + assert tool_call.id + assert tool_call.function.arguments is not None + assert tool_call.function.name is not None + + class TestAssistants(AzureRecordedTestCase): @configure @@ -406,3 +574,63 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs ) assert delete_thread.id assert delete_thread.deleted is True + + @configure + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + def test_assistants_streaming(self, client, api_type, api_version, **kwargs): + assistant = client.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write and run code to answer math questions.", + tools=[{"type": "code_interpreter"}], + **kwargs, + ) + try: + thread = client.beta.threads.create() + client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content="I need to solve the equation `3x + 11 = 14`. Can you help me?", + ) + stream = client.beta.threads.runs.create( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="Please address the user as Jane Doe. The user has a premium account.", + stream=True, + ) + + for event in stream: + assert event + finally: + client.beta.assistants.delete(assistant.id) + + @configure + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + def test_assistants_stream_event_handler(self, client, api_type, api_version, **kwargs): + assistant = client.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write and run code to answer math questions.", + tools=[{"type": "code_interpreter"}], + **kwargs + ) + + try: + question = "I need to solve the equation `3x + 11 = 14`. Can you help me and then generate an image with the answer?" + + thread = client.beta.threads.create( + messages=[ + { + "role": "user", + "content": question, + }, + ] + ) + + with client.beta.threads.runs.stream( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="Please address the user as Jane Doe. The user has a premium account.", + event_handler=EventHandler(), + ) as stream: + stream.until_done() + finally: + client.beta.assistants.delete(assistant.id) diff --git a/sdk/openai/azure-openai/tests/test_assistants_async.py b/sdk/openai/azure-openai/tests/test_assistants_async.py index f4c437eea83d..8fbbdf4c8685 100644 --- a/sdk/openai/azure-openai/tests/test_assistants_async.py +++ b/sdk/openai/azure-openai/tests/test_assistants_async.py @@ -10,9 +10,177 @@ import uuid from devtools_testutils import AzureRecordedTestCase from conftest import ASST_AZURE, ASST_AZUREAD, PREVIEW, GPT_4_OPENAI, configure_async +from openai import AsyncAssistantEventHandler +from openai.types.beta.threads import ( + Text, + Message, + ImageFile, + TextDelta, + MessageDelta, +) +from openai.types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta + TIMEOUT = 300 + +class AsyncEventHandler(AsyncAssistantEventHandler): + async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: + if delta.value: + assert delta.value is not None + if delta.annotations: + for annotation in delta.annotations: + if annotation.type == "file_citation": + assert annotation.index is not None + assert annotation.file_citation.file_id + assert annotation.file_citation.quote + elif annotation.type == "file_path": + assert annotation.index is not None + assert annotation.file_path.file_id + + async def on_run_step_done(self, run_step: RunStep) -> None: + details = run_step.step_details + if details.type == "tool_calls": + for tool in details.tool_calls: + if tool.type == "code_interpreter": + assert tool.id + assert tool.code_interpreter.outputs + assert tool.code_interpreter.input is not None + elif tool.type == "function": + assert tool.id + assert tool.function.arguments is not None + assert tool.function.name is not None + + async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None: + details = delta.step_details + if details is not None: + if details.type == "tool_calls": + for tool in details.tool_calls or []: + if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input: + assert tool.index is not None + assert tool.code_interpreter.input is not None + elif details.type == "message_creation": + assert details.message_creation.message_id + + async def on_run_step_created(self, run_step: RunStep): + assert run_step.object == "thread.run.step" + assert run_step.id + assert run_step.type + assert run_step.created_at + assert run_step.assistant_id + assert run_step.thread_id + assert run_step.run_id + assert run_step.status + assert run_step.step_details + + async def on_message_created(self, message: Message): + assert message.object == "thread.message" + assert message.id + assert message.created_at + assert message.file_ids is not None + assert message.status + assert message.thread_id + + async def on_message_delta(self, delta: MessageDelta, snapshot: Message): + if delta.content: + for content in delta.content: + if content.type == "text": + assert content.index is not None + if content.text: + if content.text.value: + assert content.text.value is not None + if content.text.annotations: + for annot in content.text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + elif content.type == "image_file": + assert content.index is not None + assert content.image_file.file_id + + async def on_message_done(self, message: Message): + for msg in message.content: + if msg.type == "image_file": + assert msg.image_file.file_id + if msg.type == "text": + assert msg.text.value + if msg.text.annotations: + for annot in msg.text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + assert annot.text is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + assert annot.text is not None + + async def on_text_created(self, text: Text): + assert text.value is not None + + async def on_text_done(self, text: Text): + assert text.value is not None + for annot in text.annotations: + if annot.type == "file_citation": + assert annot.end_index is not None + assert annot.file_citation.file_id + assert annot.file_citation.quote + assert annot.start_index is not None + assert annot.text is not None + elif annot.type == "file_path": + assert annot.end_index is not None + assert annot.file_path.file_id + assert annot.start_index is not None + assert annot.text is not None + + async def on_image_file_done(self, image_file: ImageFile): + assert image_file.file_id + + async def on_tool_call_created(self, tool_call: ToolCall): + assert tool_call.id + + async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall): + if delta.type == "code_interpreter": + assert delta.index is not None + if delta.code_interpreter: + if delta.code_interpreter.input: + assert delta.code_interpreter.input is not None + if delta.code_interpreter.outputs: + for output in delta.code_interpreter.outputs: + if output.type == "image": + assert output.image.file_id + elif output.type == "logs": + assert output.logs + if delta.type == "function": + assert delta.id + if delta.function: + assert delta.function.arguments is not None + assert delta.function.name is not None + + async def on_tool_call_done(self, tool_call: ToolCall): + if tool_call.type == "code_interpreter": + assert tool_call.id + assert tool_call.code_interpreter.input is not None + for output in tool_call.code_interpreter.outputs: + if output.type == "image": + assert output.image.file_id + elif output.type == "logs": + assert output.logs + if tool_call.type == "function": + assert tool_call.id + assert tool_call.function.arguments is not None + assert tool_call.function.name is not None + + class TestAssistantsAsync(AzureRecordedTestCase): @configure_async @@ -414,3 +582,65 @@ async def test_assistants_runs_functions(self, client_async, api_type, api_versi ) assert delete_thread.id assert delete_thread.deleted is True + + @configure_async + @pytest.mark.asyncio + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + async def test_assistants_streaming(self, client_async, api_type, api_version, **kwargs): + assistant = await client_async.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write and run code to answer math questions.", + tools=[{"type": "code_interpreter"}], + **kwargs, + ) + try: + thread = await client_async.beta.threads.create() + await client_async.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content="I need to solve the equation `3x + 11 = 14`. Can you help me?", + ) + stream = await client_async.beta.threads.runs.create( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="Please address the user as Jane Doe. The user has a premium account.", + stream=True, + ) + + async for event in stream: + assert event + finally: + await client_async.beta.assistants.delete(assistant.id) + + @configure_async + @pytest.mark.asyncio + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + async def test_assistants_stream_event_handler(self, client_async, api_type, api_version, **kwargs): + assistant = await client_async.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write and run code to answer math questions.", + tools=[{"type": "code_interpreter"}], + **kwargs + ) + + try: + question = "I need to solve the equation `3x + 11 = 14`. Can you help me and then generate an image with the answer?" + + thread = await client_async.beta.threads.create( + messages=[ + { + "role": "user", + "content": question, + }, + ] + ) + + async with client_async.beta.threads.runs.stream( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="Please address the user as Jane Doe. The user has a premium account.", + event_handler=AsyncEventHandler(), + ) as stream: + await stream.until_done() + finally: + await client_async.beta.assistants.delete(assistant.id) From 4db16058222b3854e0257ac21cad7ac5e4bf4fa2 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Tue, 21 May 2024 17:55:03 -0700 Subject: [PATCH 2/7] update to 2024-05-01-preview and add assistants v2 tests --- sdk/openai/azure-openai/tests/conftest.py | 2 +- .../azure-openai/tests/test_assistants.py | 199 +++++++++++++++--- .../tests/test_assistants_async.py | 198 ++++++++++++++--- 3 files changed, 347 insertions(+), 52 deletions(-) diff --git a/sdk/openai/azure-openai/tests/conftest.py b/sdk/openai/azure-openai/tests/conftest.py index 89cf69ae8ee4..4081e42e057e 100644 --- a/sdk/openai/azure-openai/tests/conftest.py +++ b/sdk/openai/azure-openai/tests/conftest.py @@ -20,7 +20,7 @@ # for pytest.parametrize GA = "2024-02-01" -PREVIEW = "2024-03-01-preview" +PREVIEW = "2024-05-01-preview" LATEST = PREVIEW AZURE = "azure" diff --git a/sdk/openai/azure-openai/tests/test_assistants.py b/sdk/openai/azure-openai/tests/test_assistants.py index 618b87ce7e92..912f2bb1c7e0 100644 --- a/sdk/openai/azure-openai/tests/test_assistants.py +++ b/sdk/openai/azure-openai/tests/test_assistants.py @@ -76,7 +76,7 @@ def on_message_created(self, message: Message): assert message.object == "thread.message" assert message.id assert message.created_at - assert message.file_ids is not None + assert message.attachments is not None assert message.status assert message.thread_id @@ -260,7 +260,6 @@ def test_assistants_threads_crud(self, client, api_type, api_version, **kwargs): assert delete_thread.id == thread.id assert delete_thread.deleted is True - @pytest.mark.skip(reason="AOAI doesn't support assistants v2 yet") @configure @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) def test_assistants_messages_crud(self, client, api_type, api_version, **kwargs): @@ -330,6 +329,149 @@ def test_assistants_messages_crud(self, client, api_type, api_version, **kwargs) ) assert delete_thread.id == thread.id assert delete_thread.deleted is True + delete_file = client.files.delete(file.id) + assert delete_file.deleted is True + + @configure + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + def test_assistants_vector_stores_crud(self, client, api_type, api_version, **kwargs): + file_name = f"test{uuid.uuid4()}.txt" + with open(file_name, "w") as f: + f.write("test") + + path = pathlib.Path(file_name) + + file = client.files.create( + file=open(path, "rb"), + purpose="assistants" + ) + + try: + vector_store = client.beta.vector_stores.create( + name="Support FAQ" + ) + assert vector_store.name == "Support FAQ" + assert vector_store.id + assert vector_store.object == "vector_store" + assert vector_store.created_at + assert vector_store.file_counts.total == 0 + + vectors = client.beta.vector_stores.list() + for vector in vectors: + assert vector.id + assert vector_store.object == "vector_store" + assert vector_store.created_at + + vector_store = client.beta.vector_stores.update( + vector_store_id=vector_store.id, + name="Support FAQ and more", + metadata={"Q": "A"} + ) + retrieved_vector = client.beta.vector_stores.retrieve( + vector_store_id=vector_store.id + ) + assert retrieved_vector.id == vector_store.id + assert retrieved_vector.name == "Support FAQ and more" + assert retrieved_vector.metadata == {"Q": "A"} + + vector_store_file = client.beta.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert vector_store_file.id + assert vector_store_file.object == "vector_store.file" + assert vector_store_file.created_at + assert vector_store_file.vector_store_id == vector_store.id + + vector_store_files = client.beta.vector_stores.files.list( + vector_store_id=vector_store.id + ) + for vector_file in vector_store_files: + assert vector_file.id + assert vector_file.object == "vector_store.file" + assert vector_store_file.created_at + assert vector_store_file.vector_store_id == vector_store.id + + vector_store_file_2 = client.beta.vector_stores.files.retrieve( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert vector_store_file_2.id == vector_store_file.id + assert vector_store_file.vector_store_id == vector_store.id + + finally: + os.remove(path) + deleted_vector_store_file = client.beta.vector_stores.files.delete( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert deleted_vector_store_file.deleted is True + deleted_vector_store = client.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True + + @configure + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + def test_assistants_vector_stores_batch_crud(self, client, api_type, api_version, **kwargs): + file_name = f"test{uuid.uuid4()}.txt" + file_name_2 = f"test{uuid.uuid4()}.txt" + with open(file_name, "w") as f: + f.write("test") + + path = pathlib.Path(file_name) + + file = client.files.create( + file=open(path, "rb"), + purpose="assistants" + ) + with open(file_name_2, "w") as f: + f.write("test") + path_2 = pathlib.Path(file_name_2) + + file_2 = client.files.create( + file=open(path_2, "rb"), + purpose="assistants" + ) + try: + vector_store = client.beta.vector_stores.create( + name="Support FAQ" + ) + vector_store_file_batch = client.beta.vector_stores.file_batches.create( + vector_store_id=vector_store.id, + file_ids=[file.id, file_2.id] + ) + assert vector_store_file_batch.id + assert vector_store_file_batch.object == "vector_store.file_batch" + assert vector_store_file_batch.created_at + assert vector_store_file_batch.status + + vectors = client.beta.vector_stores.file_batches.list_files( + vector_store_id=vector_store.id, + batch_id=vector_store_file_batch.id + ) + for vector in vectors: + assert vector.id + assert vector.object == "vector_store.file" + assert vector.created_at + + retrieved_vector_store_file_batch = client.beta.vector_stores.file_batches.retrieve( + vector_store_id=vector_store.id, + batch_id=vector_store_file_batch.id + ) + assert retrieved_vector_store_file_batch.id == vector_store_file_batch.id + + finally: + os.remove(path) + os.remove(path_2) + delete_file = client.files.delete(file.id) + assert delete_file.deleted is True + delete_file = client.files.delete(file_2.id) + assert delete_file.deleted is True + deleted_vector_store = client.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True @configure @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) @@ -354,7 +496,7 @@ def test_assistants_runs_code(self, client, api_type, api_version, **kwargs): thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe.", - # additional_instructions="After solving each equation, say 'Isn't math fun?'", # not supported by AOAI yet + additional_instructions="After solving each equation, say 'Isn't math fun?'", ) start_time = time.time() @@ -396,10 +538,9 @@ def test_assistants_runs_code(self, client, api_type, api_version, **kwargs): assert delete_thread.id == thread.id assert delete_thread.deleted is True - @pytest.mark.skip("AOAI does not support retrieval tools yet") @configure @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) - def test_assistants_runs_retrieval(self, client, api_type, api_version, **kwargs): + def test_assistants_runs_file_search(self, client, api_type, api_version, **kwargs): file_name = f"test{uuid.uuid4()}.txt" with open(file_name, "w") as f: f.write("Contoso company policy requires that all employees take at least 10 vacation days a year.") @@ -412,15 +553,24 @@ def test_assistants_runs_retrieval(self, client, api_type, api_version, **kwargs ) try: + vector_store = client.beta.vector_stores.create( + name="Support FAQ", + file_ids=[file.id] + ) + assistant = client.beta.assistants.create( name="python test", instructions="You help answer questions about Contoso company policy.", - tools=[{"type": "retrieval"}], - file_ids=[file.id], + tools=[{"type": "file_search"}], + tool_resources={ + "file_search": { + "vector_store_ids": [vector_store.id] + } + }, **kwargs ) - run = client.beta.threads.create_and_run( + run = client.beta.threads.create_and_run_poll( assistant_id=assistant.id, thread={ "messages": [ @@ -429,24 +579,12 @@ def test_assistants_runs_retrieval(self, client, api_type, api_version, **kwargs } ) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") + if run.status == "completed": + messages = client.beta.threads.messages.list(thread_id=run.thread_id) - run = client.beta.threads.runs.retrieve(thread_id=run.thread_id, run_id=run.id) - - if run.status == "completed": - messages = client.beta.threads.messages.list(thread_id=run.thread_id) - - for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value - - break - - time.sleep(5) + for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value finally: os.remove(path) @@ -461,6 +599,15 @@ def test_assistants_runs_retrieval(self, client, api_type, api_version, **kwargs ) assert delete_thread.id assert delete_thread.deleted is True + deleted_vector_store_file = client.beta.vector_stores.files.delete( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert deleted_vector_store_file.deleted is True + deleted_vector_store = client.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True @configure @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) @@ -552,7 +699,7 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs retrieved_step = client.beta.threads.runs.steps.retrieve( thread_id=run.thread_id, run_id=r.id, - step_id=run_steps.data[0].id + step_id=run_steps[0].id ) assert retrieved_step.id assert retrieved_step.created_at diff --git a/sdk/openai/azure-openai/tests/test_assistants_async.py b/sdk/openai/azure-openai/tests/test_assistants_async.py index 8fbbdf4c8685..d4057c54a846 100644 --- a/sdk/openai/azure-openai/tests/test_assistants_async.py +++ b/sdk/openai/azure-openai/tests/test_assistants_async.py @@ -77,7 +77,7 @@ async def on_message_created(self, message: Message): assert message.object == "thread.message" assert message.id assert message.created_at - assert message.file_ids is not None + assert message.attachments is not None assert message.status assert message.thread_id @@ -262,7 +262,6 @@ async def test_assistants_threads_crud(self, client_async, api_type, api_version assert delete_thread.id == thread.id assert delete_thread.deleted is True - @pytest.mark.skip(reason="AOAI doesn't support assistants v2 yet") @configure_async @pytest.mark.asyncio @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) @@ -333,6 +332,151 @@ async def test_assistants_messages_crud(self, client_async, api_type, api_versio ) assert delete_thread.id == thread.id assert delete_thread.deleted is True + delete_file = await client_async.files.delete(file.id) + assert delete_file.deleted is True + + @configure_async + @pytest.mark.asyncio + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + async def test_assistants_vector_stores_crud(self, client_async, api_type, api_version, **kwargs): + file_name = f"test{uuid.uuid4()}.txt" + with open(file_name, "w") as f: + f.write("test") + + path = pathlib.Path(file_name) + + file = await client_async.files.create( + file=open(path, "rb"), + purpose="assistants" + ) + + try: + vector_store = await client_async.beta.vector_stores.create( + name="Support FAQ" + ) + assert vector_store.name == "Support FAQ" + assert vector_store.id + assert vector_store.object == "vector_store" + assert vector_store.created_at + assert vector_store.file_counts.total == 0 + + vectors = client_async.beta.vector_stores.list() + async for vector in vectors: + assert vector.id + assert vector_store.object == "vector_store" + assert vector_store.created_at + + vector_store = await client_async.beta.vector_stores.update( + vector_store_id=vector_store.id, + name="Support FAQ and more", + metadata={"Q": "A"} + ) + retrieved_vector = await client_async.beta.vector_stores.retrieve( + vector_store_id=vector_store.id + ) + assert retrieved_vector.id == vector_store.id + assert retrieved_vector.name == "Support FAQ and more" + assert retrieved_vector.metadata == {"Q": "A"} + + vector_store_file = await client_async.beta.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert vector_store_file.id + assert vector_store_file.object == "vector_store.file" + assert vector_store_file.created_at + assert vector_store_file.vector_store_id == vector_store.id + + vector_store_files = client_async.beta.vector_stores.files.list( + vector_store_id=vector_store.id + ) + async for vector_file in vector_store_files: + assert vector_file.id + assert vector_file.object == "vector_store.file" + assert vector_store_file.created_at + assert vector_store_file.vector_store_id == vector_store.id + + vector_store_file_2 = await client_async.beta.vector_stores.files.retrieve( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert vector_store_file_2.id == vector_store_file.id + assert vector_store_file.vector_store_id == vector_store.id + + finally: + os.remove(path) + deleted_vector_store_file = await client_async.beta.vector_stores.files.delete( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert deleted_vector_store_file.deleted is True + deleted_vector_store = await client_async.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True + + @configure_async + @pytest.mark.asyncio + @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) + async def test_assistants_vector_stores_batch_crud(self, client_async, api_type, api_version, **kwargs): + file_name = f"test{uuid.uuid4()}.txt" + file_name_2 = f"test{uuid.uuid4()}.txt" + with open(file_name, "w") as f: + f.write("test") + + path = pathlib.Path(file_name) + + file = await client_async.files.create( + file=open(path, "rb"), + purpose="assistants" + ) + with open(file_name_2, "w") as f: + f.write("test") + path_2 = pathlib.Path(file_name_2) + + file_2 = await client_async.files.create( + file=open(path_2, "rb"), + purpose="assistants" + ) + try: + vector_store = await client_async.beta.vector_stores.create( + name="Support FAQ" + ) + vector_store_file_batch = await client_async.beta.vector_stores.file_batches.create( + vector_store_id=vector_store.id, + file_ids=[file.id, file_2.id] + ) + assert vector_store_file_batch.id + assert vector_store_file_batch.object == "vector_store.file_batch" + assert vector_store_file_batch.created_at + assert vector_store_file_batch.status + + vectors = await client_async.beta.vector_stores.file_batches.list_files( + vector_store_id=vector_store.id, + batch_id=vector_store_file_batch.id + ) + for vector in vectors: + assert vector.id + assert vector.object == "vector_store.file" + assert vector.created_at + + retrieved_vector_store_file_batch = await client_async.beta.vector_stores.file_batches.retrieve( + vector_store_id=vector_store.id, + batch_id=vector_store_file_batch.id + ) + assert retrieved_vector_store_file_batch.id == vector_store_file_batch.id + + finally: + os.remove(path) + os.remove(path_2) + delete_file = await client_async.files.delete(file.id) + assert delete_file.deleted is True + delete_file = await client_async.files.delete(file_2.id) + assert delete_file.deleted is True + deleted_vector_store = await client_async.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True @configure_async @pytest.mark.asyncio @@ -358,7 +502,7 @@ async def test_assistants_runs_code(self, client_async, api_type, api_version, * thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe.", - # additional_instructions="After solving each equation, say 'Isn't math fun?'", # not supported by AOAI yet + additional_instructions="After solving each equation, say 'Isn't math fun?'", ) start_time = time.time() @@ -400,11 +544,10 @@ async def test_assistants_runs_code(self, client_async, api_type, api_version, * assert delete_thread.id == thread.id assert delete_thread.deleted is True - @pytest.mark.skip("AOAI does not support retrieval tools yet") @configure_async @pytest.mark.asyncio @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) - async def test_assistants_runs_retrieval(self, client_async, api_type, api_version, **kwargs): + async def test_assistants_runs_file_search(self, client_async, api_type, api_version, **kwargs): file_name = f"test{uuid.uuid4()}.txt" with open(file_name, "w") as f: f.write("Contoso company policy requires that all employees take at least 10 vacation days a year.") @@ -416,15 +559,23 @@ async def test_assistants_runs_retrieval(self, client_async, api_type, api_versi purpose="assistants" ) try: + vector_store = await client_async.beta.vector_stores.create( + name="Support FAQ", + file_ids=[file.id] + ) assistant = await client_async.beta.assistants.create( name="python test", instructions="You help answer questions about Contoso company policy.", - tools=[{"type": "retrieval"}], - file_ids=[file.id], + tools=[{"type": "file_search"}], + tool_resources={ + "file_search": { + "vector_store_ids": [vector_store.id] + } + }, **kwargs ) - run = await client_async.beta.threads.create_and_run( + run = await client_async.beta.threads.create_and_run_poll( assistant_id=assistant.id, thread={ "messages": [ @@ -433,24 +584,12 @@ async def test_assistants_runs_retrieval(self, client_async, api_type, api_versi } ) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") + if run.status == "completed": + messages = client_async.beta.threads.messages.list(thread_id=run.thread_id) - run = await client_async.beta.threads.runs.retrieve(thread_id=run.thread_id, run_id=run.id) - - if run.status == "completed": - messages = client_async.beta.threads.messages.list(thread_id=run.thread_id) - - async for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value - - break - - time.sleep(5) + async for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value finally: os.remove(path) @@ -465,6 +604,15 @@ async def test_assistants_runs_retrieval(self, client_async, api_type, api_versi ) assert delete_thread.id assert delete_thread.deleted is True + deleted_vector_store_file = await client_async.beta.vector_stores.files.delete( + vector_store_id=vector_store.id, + file_id=file.id + ) + assert deleted_vector_store_file.deleted is True + deleted_vector_store = await client_async.beta.vector_stores.delete( + vector_store_id=vector_store.id + ) + assert deleted_vector_store.deleted is True @configure_async @pytest.mark.asyncio From 9683ea5f24844385982974164ced6bf1bd9df0eb Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 29 May 2024 15:35:04 -0700 Subject: [PATCH 3/7] fix some tests --- sdk/openai/azure-openai/tests/test_assistants.py | 2 +- sdk/openai/azure-openai/tests/test_chat_completions.py | 6 ++---- .../azure-openai/tests/test_chat_completions_async.py | 6 ++---- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/sdk/openai/azure-openai/tests/test_assistants.py b/sdk/openai/azure-openai/tests/test_assistants.py index 912f2bb1c7e0..bb44dfbeb803 100644 --- a/sdk/openai/azure-openai/tests/test_assistants.py +++ b/sdk/openai/azure-openai/tests/test_assistants.py @@ -699,7 +699,7 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs retrieved_step = client.beta.threads.runs.steps.retrieve( thread_id=run.thread_id, run_id=r.id, - step_id=run_steps[0].id + step_id=run_steps.data[0].id ) assert retrieved_step.id assert retrieved_step.created_at diff --git a/sdk/openai/azure-openai/tests/test_chat_completions.py b/sdk/openai/azure-openai/tests/test_chat_completions.py index 0b5cdda390d0..6f97dd014d7e 100644 --- a/sdk/openai/azure-openai/tests/test_chat_completions.py +++ b/sdk/openai/azure-openai/tests/test_chat_completions.py @@ -789,9 +789,8 @@ def test_chat_completion_seed(self, client, api_type, api_version, **kwargs): ] completion = client.chat.completions.create(messages=messages, seed=42, **kwargs) - assert completion.system_fingerprint - completion = client.chat.completions.create(messages=messages, seed=42, **kwargs) - assert completion.system_fingerprint + if api_type != GPT_4_OPENAI: # bug in openai where system_fingerprint is not always returned + assert completion.system_fingerprint @configure @pytest.mark.parametrize("api_type, api_version", [(GPT_4_AZURE, GA), (GPT_4_AZURE, GA), (GPT_4_OPENAI, "v1")]) @@ -804,7 +803,6 @@ def test_chat_completion_json_response(self, client, api_type, api_version, **kw completion = client.chat.completions.create(messages=messages, response_format={ "type": "json_object" }, **kwargs) assert completion.id assert completion.object == "chat.completion" - assert completion.system_fingerprint assert completion.model assert completion.created assert completion.usage.completion_tokens is not None diff --git a/sdk/openai/azure-openai/tests/test_chat_completions_async.py b/sdk/openai/azure-openai/tests/test_chat_completions_async.py index 2c16440b54d8..27e13869199a 100644 --- a/sdk/openai/azure-openai/tests/test_chat_completions_async.py +++ b/sdk/openai/azure-openai/tests/test_chat_completions_async.py @@ -807,9 +807,8 @@ async def test_chat_completion_seed(self, client_async, api_type, api_version, * ] completion = await client_async.chat.completions.create(messages=messages, seed=42, **kwargs) - assert completion.system_fingerprint - completion = await client_async.chat.completions.create(messages=messages, seed=42, **kwargs) - assert completion.system_fingerprint + if api_type != GPT_4_OPENAI: # bug in openai where system_fingerprint is not always returned + assert completion.system_fingerprint @configure_async @pytest.mark.asyncio @@ -823,7 +822,6 @@ async def test_chat_completion_json_response(self, client_async, api_type, api_v completion = await client_async.chat.completions.create(messages=messages, response_format={ "type": "json_object" }, **kwargs) assert completion.id assert completion.object == "chat.completion" - assert completion.system_fingerprint assert completion.model assert completion.created assert completion.usage.completion_tokens is not None From 62494b93c430e1a73fb08958caef8c0a6344a797 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Tue, 4 Jun 2024 14:19:40 -0700 Subject: [PATCH 4/7] fix --- sdk/openai/azure-openai/tests/test_chat_completions.py | 2 +- sdk/openai/azure-openai/tests/test_chat_completions_async.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/openai/azure-openai/tests/test_chat_completions.py b/sdk/openai/azure-openai/tests/test_chat_completions.py index 6f97dd014d7e..72628d964deb 100644 --- a/sdk/openai/azure-openai/tests/test_chat_completions.py +++ b/sdk/openai/azure-openai/tests/test_chat_completions.py @@ -793,7 +793,7 @@ def test_chat_completion_seed(self, client, api_type, api_version, **kwargs): assert completion.system_fingerprint @configure - @pytest.mark.parametrize("api_type, api_version", [(GPT_4_AZURE, GA), (GPT_4_AZURE, GA), (GPT_4_OPENAI, "v1")]) + @pytest.mark.parametrize("api_type, api_version", [(GPT_4_AZURE, GA), (GPT_4_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) def test_chat_completion_json_response(self, client, api_type, api_version, **kwargs): messages = [ {"role": "system", "content": "You are a helpful assistant."}, diff --git a/sdk/openai/azure-openai/tests/test_chat_completions_async.py b/sdk/openai/azure-openai/tests/test_chat_completions_async.py index 27e13869199a..a654de4512bb 100644 --- a/sdk/openai/azure-openai/tests/test_chat_completions_async.py +++ b/sdk/openai/azure-openai/tests/test_chat_completions_async.py @@ -812,7 +812,7 @@ async def test_chat_completion_seed(self, client_async, api_type, api_version, * @configure_async @pytest.mark.asyncio - @pytest.mark.parametrize("api_type, api_version", [(GPT_4_AZURE, GA), (GPT_4_AZURE, GA), (GPT_4_OPENAI, "v1")]) + @pytest.mark.parametrize("api_type, api_version", [(GPT_4_AZURE, GA), (GPT_4_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) async def test_chat_completion_json_response(self, client_async, api_type, api_version, **kwargs): messages = [ {"role": "system", "content": "You are a helpful assistant."}, From ee7c4fb60b30dbad18a8846de81d3f0944cd379d Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Thu, 6 Jun 2024 15:09:35 -0700 Subject: [PATCH 5/7] try to fix flaky assistants tests --- .../azure-openai/tests/test_assistants.py | 49 ++++++++----------- .../tests/test_assistants_async.py | 47 ++++++++---------- 2 files changed, 40 insertions(+), 56 deletions(-) diff --git a/sdk/openai/azure-openai/tests/test_assistants.py b/sdk/openai/azure-openai/tests/test_assistants.py index 584cc2bb174e..2c1fd57d1572 100644 --- a/sdk/openai/azure-openai/tests/test_assistants.py +++ b/sdk/openai/azure-openai/tests/test_assistants.py @@ -8,6 +8,7 @@ import pytest import pathlib import uuid +import openai from devtools_testutils import AzureRecordedTestCase from conftest import ASST_AZURE, PREVIEW, GPT_4_OPENAI, configure from openai import AssistantEventHandler @@ -643,7 +644,7 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs **kwargs, ) - run = client.beta.threads.create_and_run( + run = client.beta.threads.create_and_run_poll( assistant_id=assistant.id, thread={ "messages": [ @@ -651,36 +652,26 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs ] } ) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") - - run = client.beta.threads.runs.retrieve(thread_id=run.thread_id, run_id=run.id) - - if run.status == "requires_action": - run = client.beta.threads.runs.submit_tool_outputs( - thread_id=run.thread_id, - run_id=run.id, - tool_outputs=[ - { - "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[0].id, - "output": "{\"temperature\": \"22\", \"unit\": \"celsius\", \"description\": \"Sunny\"}" - } - ] - ) - - if run.status == "completed": - messages = client.beta.threads.messages.list(thread_id=run.thread_id) - - for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value + if run.status == "failed": + raise openai.OpenAIError(run.last_error.message) + if run.status == "requires_action": + run = client.beta.threads.runs.submit_tool_outputs_and_poll( + thread_id=run.thread_id, + run_id=run.id, + tool_outputs=[ + { + "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[0].id, + "output": "{\"temperature\": \"22\", \"unit\": \"celsius\", \"description\": \"Sunny\"}" + } + ] + ) - break + if run.status == "completed": + messages = client.beta.threads.messages.list(thread_id=run.thread_id) - time.sleep(5) + for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value runs = client.beta.threads.runs.list(thread_id=run.thread_id) for r in runs: diff --git a/sdk/openai/azure-openai/tests/test_assistants_async.py b/sdk/openai/azure-openai/tests/test_assistants_async.py index 5ce6e8399285..6bee74e10f68 100644 --- a/sdk/openai/azure-openai/tests/test_assistants_async.py +++ b/sdk/openai/azure-openai/tests/test_assistants_async.py @@ -8,6 +8,7 @@ import pytest import pathlib import uuid +import openai from devtools_testutils import AzureRecordedTestCase from conftest import ASST_AZURE, PREVIEW, GPT_4_OPENAI, configure_async from openai import AsyncAssistantEventHandler @@ -648,7 +649,7 @@ async def test_assistants_runs_functions(self, client_async, api_type, api_versi **kwargs, ) - run = await client_async.beta.threads.create_and_run( + run = await client_async.beta.threads.create_and_run_poll( assistant_id=assistant.id, thread={ "messages": [ @@ -656,36 +657,28 @@ async def test_assistants_runs_functions(self, client_async, api_type, api_versi ] } ) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") - run = await client_async.beta.threads.runs.retrieve(thread_id=run.thread_id, run_id=run.id) - - if run.status == "requires_action": - run = await client_async.beta.threads.runs.submit_tool_outputs( - thread_id=run.thread_id, - run_id=run.id, - tool_outputs=[ - { - "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[0].id, - "output": "{\"temperature\": \"22\", \"unit\": \"celsius\", \"description\": \"Sunny\"}" - } - ] - ) - - if run.status == "completed": - messages = client_async.beta.threads.messages.list(thread_id=run.thread_id) + if run.status == "failed": + raise openai.OpenAIError(run.last_error.message) + if run.status == "requires_action": + run = await client_async.beta.threads.runs.submit_tool_outputs_and_poll( + thread_id=run.thread_id, + run_id=run.id, + tool_outputs=[ + { + "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[0].id, + "output": "{\"temperature\": \"22\", \"unit\": \"celsius\", \"description\": \"Sunny\"}" + } + ] + ) - async for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value + if run.status == "completed": + messages = client_async.beta.threads.messages.list(thread_id=run.thread_id) - break + async for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value - time.sleep(5) runs = client_async.beta.threads.runs.list(thread_id=run.thread_id) async for r in runs: From 7a79dcd42bc7385d2dea3f15201682bc894f3b93 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Fri, 7 Jun 2024 09:51:26 -0700 Subject: [PATCH 6/7] more updates to fix flakiness --- .../azure-openai/tests/test_assistants.py | 51 ++++++------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/sdk/openai/azure-openai/tests/test_assistants.py b/sdk/openai/azure-openai/tests/test_assistants.py index 2c1fd57d1572..e1834d4fe2ec 100644 --- a/sdk/openai/azure-openai/tests/test_assistants.py +++ b/sdk/openai/azure-openai/tests/test_assistants.py @@ -4,7 +4,6 @@ # ------------------------------------ import os -import time import pytest import pathlib import uuid @@ -21,8 +20,6 @@ ) from openai.types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta -TIMEOUT = 300 - class EventHandler(AssistantEventHandler): def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: @@ -493,32 +490,21 @@ def test_assistants_runs_code(self, client, api_type, api_version, **kwargs): content="I need to solve the equation `3x + 11 = 14`. Can you help me?", ) - run = client.beta.threads.runs.create( + run = client.beta.threads.runs.create_and_poll( thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe.", additional_instructions="After solving each equation, say 'Isn't math fun?'", ) + if run.status == "failed": + raise openai.OpenAIError(run.last_error.message) + if run.status == "completed": + messages = client.beta.threads.messages.list(thread_id=thread.id) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") - - run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) - - if run.status == "completed": - messages = client.beta.threads.messages.list(thread_id=thread.id) - - for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value + for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value - break - else: - time.sleep(5) - run = client.beta.threads.runs.update( thread_id=thread.id, run_id=run.id, @@ -548,17 +534,14 @@ def test_assistants_runs_file_search(self, client, api_type, api_version, **kwar path = pathlib.Path(file_name) - file = client.files.create( - file=open(path, "rb"), - purpose="assistants" - ) - try: vector_store = client.beta.vector_stores.create( - name="Support FAQ", - file_ids=[file.id] + name="Support FAQ" + ) + client.beta.vector_stores.files.upload_and_poll( + vector_store_id=vector_store.id, + file_id=path ) - assistant = client.beta.assistants.create( name="python test", instructions="You help answer questions about Contoso company policy.", @@ -579,7 +562,8 @@ def test_assistants_runs_file_search(self, client, api_type, api_version, **kwar ] } ) - + if run.status == "failed": + raise openai.OpenAIError(run.last_error.message) if run.status == "completed": messages = client.beta.threads.messages.list(thread_id=run.thread_id) @@ -600,11 +584,6 @@ def test_assistants_runs_file_search(self, client, api_type, api_version, **kwar ) assert delete_thread.id assert delete_thread.deleted is True - deleted_vector_store_file = client.beta.vector_stores.files.delete( - vector_store_id=vector_store.id, - file_id=file.id - ) - assert deleted_vector_store_file.deleted is True deleted_vector_store = client.beta.vector_stores.delete( vector_store_id=vector_store.id ) From ee5553a3af40d4e29125059166444d7dbc8f212e Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Fri, 7 Jun 2024 12:09:55 -0700 Subject: [PATCH 7/7] fix more flakiness and only run openai on weekly --- sdk/openai/azure-openai/tests/conftest.py | 11 ++++ .../azure-openai/tests/test_assistants.py | 27 +++++--- .../tests/test_assistants_async.py | 62 +++++++------------ 3 files changed, 52 insertions(+), 48 deletions(-) diff --git a/sdk/openai/azure-openai/tests/conftest.py b/sdk/openai/azure-openai/tests/conftest.py index 0efd79af540b..43d853fead3e 100644 --- a/sdk/openai/azure-openai/tests/conftest.py +++ b/sdk/openai/azure-openai/tests/conftest.py @@ -16,6 +16,7 @@ DefaultAzureCredential as AsyncDefaultAzureCredential, get_bearer_token_provider as get_bearer_token_provider_async, ) +from ci_tools.variables import in_ci # for pytest.parametrize @@ -65,8 +66,15 @@ ENV_OPENAI_TTS_MODEL = "tts-1" +def skip_openai_test(api_type) -> bool: + return in_ci() and "openai" in api_type and "tests-weekly" not in os.getenv("SYSTEM_DEFINITIONNAME", "") + + @pytest.fixture def client(api_type, api_version): + if skip_openai_test(api_type): + pytest.skip("Skipping openai tests - they only run on tests-weekly.") + if api_type == "azure": client = openai.AzureOpenAI( azure_endpoint=os.getenv(ENV_AZURE_OPENAI_ENDPOINT), @@ -100,6 +108,9 @@ def client(api_type, api_version): @pytest.fixture def client_async(api_type, api_version): + if skip_openai_test(api_type): + pytest.skip("Skipping openai tests - they only run on tests-weekly.") + if api_type == "azure": client = openai.AsyncAzureOpenAI( azure_endpoint=os.getenv(ENV_AZURE_OPENAI_ENDPOINT), diff --git a/sdk/openai/azure-openai/tests/test_assistants.py b/sdk/openai/azure-openai/tests/test_assistants.py index e1834d4fe2ec..e4e19dd4e7d4 100644 --- a/sdk/openai/azure-openai/tests/test_assistants.py +++ b/sdk/openai/azure-openai/tests/test_assistants.py @@ -18,6 +18,7 @@ TextDelta, MessageDelta, ) +from openai.types.beta.threads import Run from openai.types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta @@ -181,6 +182,14 @@ def on_tool_call_done(self, tool_call: ToolCall): class TestAssistants(AzureRecordedTestCase): + def handle_run_failure(self, run: Run): + if run.status == "failed": + if "Rate limit" in run.last_error.message: + pytest.skip("Skipping - Rate limit reached.") + raise openai.OpenAIError(run.last_error.message) + if run.status not in ["completed", "requires_action"]: + raise openai.OpenAIError(f"Run in unexpected status: {run.status}") + @configure @pytest.mark.parametrize( "api_type, api_version", @@ -496,8 +505,7 @@ def test_assistants_runs_code(self, client, api_type, api_version, **kwargs): instructions="Please address the user as Jane Doe.", additional_instructions="After solving each equation, say 'Isn't math fun?'", ) - if run.status == "failed": - raise openai.OpenAIError(run.last_error.message) + self.handle_run_failure(run) if run.status == "completed": messages = client.beta.threads.messages.list(thread_id=thread.id) @@ -540,7 +548,7 @@ def test_assistants_runs_file_search(self, client, api_type, api_version, **kwar ) client.beta.vector_stores.files.upload_and_poll( vector_store_id=vector_store.id, - file_id=path + file=path ) assistant = client.beta.assistants.create( name="python test", @@ -562,8 +570,7 @@ def test_assistants_runs_file_search(self, client, api_type, api_version, **kwar ] } ) - if run.status == "failed": - raise openai.OpenAIError(run.last_error.message) + self.handle_run_failure(run) if run.status == "completed": messages = client.beta.threads.messages.list(thread_id=run.thread_id) @@ -631,8 +638,7 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs ] } ) - if run.status == "failed": - raise openai.OpenAIError(run.last_error.message) + self.handle_run_failure(run) if run.status == "requires_action": run = client.beta.threads.runs.submit_tool_outputs_and_poll( thread_id=run.thread_id, @@ -644,7 +650,7 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs } ] ) - + self.handle_run_failure(run) if run.status == "completed": messages = client.beta.threads.messages.list(thread_id=run.thread_id) @@ -666,10 +672,13 @@ def test_assistants_runs_functions(self, client, api_type, api_version, **kwargs thread_id=run.thread_id, run_id=r.id ) + for step in run_steps: + assert step.id + retrieved_step = client.beta.threads.runs.steps.retrieve( thread_id=run.thread_id, run_id=r.id, - step_id=run_steps.data[0].id + step_id=step.id ) assert retrieved_step.id assert retrieved_step.created_at diff --git a/sdk/openai/azure-openai/tests/test_assistants_async.py b/sdk/openai/azure-openai/tests/test_assistants_async.py index 6bee74e10f68..3ceabd3dd9b3 100644 --- a/sdk/openai/azure-openai/tests/test_assistants_async.py +++ b/sdk/openai/azure-openai/tests/test_assistants_async.py @@ -4,7 +4,6 @@ # ------------------------------------ import os -import time import pytest import pathlib import uuid @@ -19,10 +18,9 @@ TextDelta, MessageDelta, ) +from openai.types.beta.threads import Run from openai.types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta -TIMEOUT = 300 - class AsyncEventHandler(AsyncAssistantEventHandler): async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: @@ -183,6 +181,14 @@ async def on_tool_call_done(self, tool_call: ToolCall): class TestAssistantsAsync(AzureRecordedTestCase): + def handle_run_failure(self, run: Run): + if run.status == "failed": + if "Rate limit" in run.last_error.message: + pytest.skip("Skipping - Rate limit reached.") + raise openai.OpenAIError(run.last_error.message) + if run.status not in ["completed", "requires_action"]: + raise openai.OpenAIError(f"Run in unexpected status: {run.status}") + @configure_async @pytest.mark.asyncio @pytest.mark.parametrize( @@ -240,7 +246,6 @@ async def test_assistants_threads_crud(self, client_async, api_type, api_version ], metadata={"key": "value"}, ) - retrieved_thread = await client_async.beta.threads.retrieve( thread_id=thread.id, ) @@ -482,7 +487,6 @@ async def test_assistants_vector_stores_batch_crud(self, client_async, api_type, @pytest.mark.asyncio @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")]) async def test_assistants_runs_code(self, client_async, api_type, api_version, **kwargs): - try: assistant = await client_async.beta.assistants.create( name="python test", @@ -498,31 +502,19 @@ async def test_assistants_runs_code(self, client_async, api_type, api_version, * content="I need to solve the equation `3x + 11 = 14`. Can you help me?", ) - run = await client_async.beta.threads.runs.create( + run = await client_async.beta.threads.runs.create_and_poll( thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe.", additional_instructions="After solving each equation, say 'Isn't math fun?'", ) + self.handle_run_failure(run) + if run.status == "completed": + messages = client_async.beta.threads.messages.list(thread_id=thread.id) - start_time = time.time() - - while True: - if time.time() - start_time > TIMEOUT: - raise TimeoutError("Run timed out") - - run = await client_async.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) - - if run.status == "completed": - messages = client_async.beta.threads.messages.list(thread_id=thread.id) - - async for message in messages: - assert message.content[0].type == "text" - assert message.content[0].text.value - - break - else: - time.sleep(5) + async for message in messages: + assert message.content[0].type == "text" + assert message.content[0].text.value run = await client_async.beta.threads.runs.update( thread_id=thread.id, @@ -554,14 +546,13 @@ async def test_assistants_runs_file_search(self, client_async, api_type, api_ver path = pathlib.Path(file_name) - file = await client_async.files.create( - file=open(path, "rb"), - purpose="assistants" - ) try: vector_store = await client_async.beta.vector_stores.create( name="Support FAQ", - file_ids=[file.id] + ) + await client_async.beta.vector_stores.files.upload_and_poll( + vector_store_id=vector_store.id, + file=path ) assistant = await client_async.beta.assistants.create( name="python test", @@ -583,7 +574,7 @@ async def test_assistants_runs_file_search(self, client_async, api_type, api_ver ] } ) - + self.handle_run_failure(run) if run.status == "completed": messages = client_async.beta.threads.messages.list(thread_id=run.thread_id) @@ -604,11 +595,6 @@ async def test_assistants_runs_file_search(self, client_async, api_type, api_ver ) assert delete_thread.id assert delete_thread.deleted is True - deleted_vector_store_file = await client_async.beta.vector_stores.files.delete( - vector_store_id=vector_store.id, - file_id=file.id - ) - assert deleted_vector_store_file.deleted is True deleted_vector_store = await client_async.beta.vector_stores.delete( vector_store_id=vector_store.id ) @@ -657,9 +643,7 @@ async def test_assistants_runs_functions(self, client_async, api_type, api_versi ] } ) - - if run.status == "failed": - raise openai.OpenAIError(run.last_error.message) + self.handle_run_failure(run) if run.status == "requires_action": run = await client_async.beta.threads.runs.submit_tool_outputs_and_poll( thread_id=run.thread_id, @@ -671,7 +655,7 @@ async def test_assistants_runs_functions(self, client_async, api_type, api_versi } ] ) - + self.handle_run_failure(run) if run.status == "completed": messages = client_async.beta.threads.messages.list(thread_id=run.thread_id)