From 00e73e6089f2d54a38944ec47303aa00f9d211d7 Mon Sep 17 00:00:00 2001 From: Josix Date: Thu, 22 Aug 2024 18:34:35 +0800 Subject: [PATCH] feat(providers/openai): support batch api in hook/operator/trigger (#41554) * feat(providers/openai) * support batch api in hook/operator/trigger * add wait_for_completion to OpenAITriggerBatchOperator --------- Co-authored-by: YungHsiu Chen Co-authored-by: Wei Lee --- airflow/providers/openai/exceptions.py | 28 +++ airflow/providers/openai/hooks/openai.py | 103 ++++++++++- airflow/providers/openai/operators/openai.py | 94 +++++++++- airflow/providers/openai/provider.yaml | 5 + airflow/providers/openai/triggers/__init__.py | 16 ++ airflow/providers/openai/triggers/openai.py | 112 ++++++++++++ .../operators/openai.rst | 25 +++ tests/providers/openai/hooks/test_openai.py | 72 +++++++- .../providers/openai/operators/test_openai.py | 73 +++++++- tests/providers/openai/test_exceptions.py | 39 ++++ tests/providers/openai/triggers/__init__.py | 16 ++ .../providers/openai/triggers/test_openai.py | 166 ++++++++++++++++++ .../openai/example_trigger_batch_operator.py | 117 ++++++++++++ 13 files changed, 858 insertions(+), 8 deletions(-) create mode 100644 airflow/providers/openai/exceptions.py create mode 100644 airflow/providers/openai/triggers/__init__.py create mode 100644 airflow/providers/openai/triggers/openai.py create mode 100644 tests/providers/openai/test_exceptions.py create mode 100644 tests/providers/openai/triggers/__init__.py create mode 100644 tests/providers/openai/triggers/test_openai.py create mode 100644 tests/system/providers/openai/example_trigger_batch_operator.py diff --git a/airflow/providers/openai/exceptions.py b/airflow/providers/openai/exceptions.py new file mode 100644 index 0000000000000..eafba088c4b10 --- /dev/null +++ b/airflow/providers/openai/exceptions.py @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.exceptions import AirflowException + + +class OpenAIBatchJobException(AirflowException): + """Raise when OpenAI Batch Job fails to start AFTER processing the request.""" + + +class OpenAIBatchTimeout(AirflowException): + """Raise when OpenAI Batch Job times out.""" diff --git a/airflow/providers/openai/hooks/openai.py b/airflow/providers/openai/hooks/openai.py index e66283afd6108..cc8375f9ba040 100644 --- a/airflow/providers/openai/hooks/openai.py +++ b/airflow/providers/openai/hooks/openai.py @@ -17,6 +17,8 @@ from __future__ import annotations +import time +from enum import Enum from functools import cached_property from typing import TYPE_CHECKING, Any, BinaryIO, Literal @@ -24,6 +26,7 @@ if TYPE_CHECKING: from openai.types import FileDeleted, FileObject + from openai.types.batch import Batch from openai.types.beta import ( Assistant, AssistantDeleted, @@ -42,8 +45,29 @@ ChatCompletionToolMessageParam, ChatCompletionUserMessageParam, ) - from airflow.hooks.base import BaseHook +from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout + + +class BatchStatus(str, Enum): + """Enum for the status of a batch.""" + + VALIDATING = "validating" + FAILED = "failed" + IN_PROGRESS = "in_progress" + FINALIZING = "finalizing" + COMPLETED = "completed" + EXPIRED = "expired" + CANCELLING = "cancelling" + CANCELLED = "cancelled" + + def __str__(self) -> str: + return str(self.value) + + @classmethod + def is_in_progress(cls, status: str) -> bool: + """Check if the batch status is in progress.""" + return status in (cls.VALIDATING, cls.IN_PROGRESS, cls.FINALIZING) class OpenAIHook(BaseHook): @@ -288,13 +312,13 @@ def create_embeddings( embeddings: list[float] = response.data[0].embedding return embeddings - def upload_file(self, file: str, purpose: Literal["fine-tune", "assistants"]) -> FileObject: + def upload_file(self, file: str, purpose: Literal["fine-tune", "assistants", "batch"]) -> FileObject: """ Upload a file that can be used across various endpoints. The size of all the files uploaded by one organization can be up to 100 GB. :param file: The File object (not file name) to be uploaded. :param purpose: The intended purpose of the uploaded file. Use "fine-tune" for - Fine-tuning and "assistants" for Assistants and Messages. + Fine-tuning, "assistants" for Assistants and Messages, and "batch" for Batch API. """ with open(file, "rb") as file_stream: file_object = self.conn.files.create(file=file_stream, purpose=purpose) @@ -393,3 +417,76 @@ def delete_vector_store_file(self, vector_store_id: str, file_id: str) -> Vector """ response = self.conn.beta.vector_stores.files.delete(vector_store_id=vector_store_id, file_id=file_id) return response + + def create_batch( + self, + file_id: str, + endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"], + metadata: dict[str, str] | None = None, + completion_window: Literal["24h"] = "24h", + ) -> Batch: + """ + Create a batch for a given model and files. + + :param file_id: The ID of the file to be used for this batch. + :param endpoint: The endpoint to use for this batch. Allowed values include: + '/v1/chat/completions', '/v1/embeddings', '/v1/completions'. + :param metadata: A set of key-value pairs that can be attached to an object. + :param completion_window: The time window for the batch to complete. Default is 24 hours. + """ + batch = self.conn.batches.create( + input_file_id=file_id, endpoint=endpoint, metadata=metadata, completion_window=completion_window + ) + return batch + + def get_batch(self, batch_id: str) -> Batch: + """ + Get the status of a batch. + + :param batch_id: The ID of the batch to get the status of. + """ + batch = self.conn.batches.retrieve(batch_id=batch_id) + return batch + + def wait_for_batch(self, batch_id: str, wait_seconds: float = 3, timeout: float = 3600) -> None: + """ + Poll a batch to check if it finishes. + + :param batch_id: Id of the Batch to wait for. + :param wait_seconds: Optional. Number of seconds between checks. + :param timeout: Optional. How many seconds wait for batch to be ready. + Used only if not ran in deferred operator. + """ + start = time.monotonic() + while True: + if start + timeout < time.monotonic(): + self.cancel_batch(batch_id=batch_id) + raise OpenAIBatchTimeout(f"Timeout: OpenAI Batch {batch_id} is not ready after {timeout}s") + batch = self.get_batch(batch_id=batch_id) + + if BatchStatus.is_in_progress(batch.status): + time.sleep(wait_seconds) + continue + if batch.status == BatchStatus.COMPLETED: + return + if batch.status == BatchStatus.FAILED: + raise OpenAIBatchJobException(f"Batch failed - \n{batch_id}") + elif batch.status in (BatchStatus.CANCELLED, BatchStatus.CANCELLING): + raise OpenAIBatchJobException(f"Batch failed - batch was cancelled:\n{batch_id}") + elif batch.status == BatchStatus.EXPIRED: + raise OpenAIBatchJobException( + f"Batch failed - batch couldn't be completed within the hour time window :\n{batch_id}" + ) + + raise OpenAIBatchJobException( + f"Batch failed - encountered unexpected status `{batch.status}` for batch_id `{batch_id}`" + ) + + def cancel_batch(self, batch_id: str) -> Batch: + """ + Cancel a batch. + + :param batch_id: The ID of the batch to delete. + """ + batch = self.conn.batches.cancel(batch_id=batch_id) + return batch diff --git a/airflow/providers/openai/operators/openai.py b/airflow/providers/openai/operators/openai.py index 1697e88b98371..7ce834865409b 100644 --- a/airflow/providers/openai/operators/openai.py +++ b/airflow/providers/openai/operators/openai.py @@ -17,11 +17,15 @@ from __future__ import annotations +import time from functools import cached_property -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, Literal, Sequence +from airflow.configuration import conf from airflow.models import BaseOperator +from airflow.providers.openai.exceptions import OpenAIBatchJobException from airflow.providers.openai.hooks.openai import OpenAIHook +from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -74,3 +78,91 @@ def execute(self, context: Context) -> list[float]: embeddings = self.hook.create_embeddings(self.input_text, model=self.model, **self.embedding_kwargs) self.log.info("Generated embeddings for %d items", len(embeddings)) return embeddings + + +class OpenAITriggerBatchOperator(BaseOperator): + """ + Operator that triggers an OpenAI Batch API endpoint and waits for the batch to complete. + + :param file_id: Required. The ID of the batch file to trigger. + :param endpoint: Required. The OpenAI Batch API endpoint to trigger. + :param conn_id: Optional. The OpenAI connection ID to use. Defaults to 'openai_default'. + :param deferrable: Optional. Run operator in the deferrable mode. + :param wait_seconds: Optional. Number of seconds between checks. Only used when ``deferrable`` is False. + Defaults to 3 seconds. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Only used when ``deferrable`` is False. Defaults to 24 hour, which is the SLA for OpenAI Batch API. + :param wait_for_completion: Optional. Whether to wait for the batch to complete. If set to False, the operator + will return immediately after triggering the batch. Defaults to True. + + .. seealso:: + For more information on how to use this operator, please take a look at the guide: + :ref:`howto/operator:OpenAITriggerBatchOperator` + """ + + template_fields: Sequence[str] = ("file_id",) + + def __init__( + self, + file_id: str, + endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"], + conn_id: str = OpenAIHook.default_conn_name, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + wait_seconds: float = 3, + timeout: float = 24 * 60 * 60, + wait_for_completion: bool = True, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.conn_id = conn_id + self.file_id = file_id + self.endpoint = endpoint + self.deferrable = deferrable + self.wait_seconds = wait_seconds + self.timeout = timeout + self.wait_for_completion = wait_for_completion + self.batch_id: str | None = None + + @cached_property + def hook(self) -> OpenAIHook: + """Return an instance of the OpenAIHook.""" + return OpenAIHook(conn_id=self.conn_id) + + def execute(self, context: Context) -> str: + batch = self.hook.create_batch(file_id=self.file_id, endpoint=self.endpoint) + self.batch_id = batch.id + if self.wait_for_completion: + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=OpenAIBatchTrigger( + conn_id=self.conn_id, + batch_id=self.batch_id, + poll_interval=60, + end_time=time.time() + self.timeout, + ), + method_name="execute_complete", + ) + else: + self.log.info("Waiting for batch %s to complete", self.batch_id) + self.hook.wait_for_batch(self.batch_id, wait_seconds=self.wait_seconds, timeout=self.timeout) + return self.batch_id + + def execute_complete(self, context: Context, event: Any = None) -> str: + """ + Invoke this callback when the trigger fires; return immediately. + + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise OpenAIBatchJobException(event["message"]) + + self.log.info("%s completed successfully.", self.task_id) + return event["batch_id"] + + def on_kill(self) -> None: + """Cancel the batch if task is cancelled.""" + if self.batch_id: + self.log.info("on_kill: cancel the OpenAI Batch %s", self.batch_id) + self.hook.cancel_batch(self.batch_id) diff --git a/airflow/providers/openai/provider.yaml b/airflow/providers/openai/provider.yaml index a78338f7ce614..c08a6d00e14b0 100644 --- a/airflow/providers/openai/provider.yaml +++ b/airflow/providers/openai/provider.yaml @@ -57,6 +57,11 @@ operators: python-modules: - airflow.providers.openai.operators.openai +triggers: + - integration-name: OpenAI + python-modules: + - airflow.providers.openai.triggers.openai + connection-types: - hook-class-name: airflow.providers.openai.hooks.openai.OpenAIHook connection-type: openai diff --git a/airflow/providers/openai/triggers/__init__.py b/airflow/providers/openai/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/openai/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/openai/triggers/openai.py b/airflow/providers/openai/triggers/openai.py new file mode 100644 index 0000000000000..481753f7bcc35 --- /dev/null +++ b/airflow/providers/openai/triggers/openai.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import time +from typing import Any, AsyncIterator + +from airflow.providers.openai.hooks.openai import BatchStatus, OpenAIHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class OpenAIBatchTrigger(BaseTrigger): + """Triggers OpenAI Batch API.""" + + def __init__( + self, + conn_id: str, + batch_id: str, + poll_interval: float, + end_time: float, + ) -> None: + super().__init__() + self.conn_id = conn_id + self.poll_interval = poll_interval + self.batch_id = batch_id + self.end_time = end_time + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize OpenAIBatchTrigger arguments and class path.""" + return ( + "airflow.providers.openai.triggers.openai.OpenAIBatchTrigger", + { + "conn_id": self.conn_id, + "batch_id": self.batch_id, + "poll_interval": self.poll_interval, + "end_time": self.end_time, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Make connection to OpenAI Client, and poll the status of batch.""" + hook = OpenAIHook(conn_id=self.conn_id) + try: + while (batch := hook.get_batch(self.batch_id)) and BatchStatus.is_in_progress(batch.status): + if self.end_time < time.time(): + yield TriggerEvent( + { + "status": "error", + "message": f"Batch {self.batch_id} has not reached a terminal status after " + f"{time.time() - self.end_time} seconds.", + "batch_id": self.batch_id, + } + ) + return + await asyncio.sleep(self.poll_interval) + if batch.status == BatchStatus.COMPLETED: + yield TriggerEvent( + { + "status": "success", + "message": f"Batch {self.batch_id} has completed successfully.", + "batch_id": self.batch_id, + } + ) + elif batch.status in {BatchStatus.CANCELLED, BatchStatus.CANCELLING}: + yield TriggerEvent( + { + "status": "cancelled", + "message": f"Batch {self.batch_id} has been cancelled.", + "batch_id": self.batch_id, + } + ) + elif batch.status == BatchStatus.FAILED: + yield TriggerEvent( + { + "status": "error", + "message": f"Batch failed:\n{self.batch_id}", + "batch_id": self.batch_id, + } + ) + elif batch.status == BatchStatus.EXPIRED: + yield TriggerEvent( + { + "status": "error", + "message": f"Batch couldn't be completed within the hour time window :\n{self.batch_id}", + "batch_id": self.batch_id, + } + ) + + yield TriggerEvent( + { + "status": "error", + "message": f"Batch {self.batch_id} has failed.", + "batch_id": self.batch_id, + } + ) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e), "batch_id": self.batch_id}) diff --git a/docs/apache-airflow-providers-openai/operators/openai.rst b/docs/apache-airflow-providers-openai/operators/openai.rst index e3f4f4d403e6a..fef8521188df6 100644 --- a/docs/apache-airflow-providers-openai/operators/openai.rst +++ b/docs/apache-airflow-providers-openai/operators/openai.rst @@ -36,3 +36,28 @@ An example using the operator is in way: :language: python :start-after: [START howto_operator_openai_embedding] :end-before: [END howto_operator_openai_embedding] + +.. _howto/operator:OpenAITriggerBatchOperator: + +OpenAITriggerBatchOperator +=========================== + +Use the :class:`~airflow.providers.open_ai.operators.open_ai.OpenAITriggerBatchOperator` to +interact with Open APIs to trigger a batch job. This operator is used to trigger a batch job and wait for the job to complete. + + +Using the Operator +^^^^^^^^^^^^^^^^^^ + +The OpenAITriggerBatchOperator requires the prepared batch file as an input to trigger the batch job. Provide the ``file_id`` and the ``endpoint`` to trigger the batch job. +Use the ``conn_id`` parameter to specify the OpenAI connection to use to + + +The OpenAITriggerBatchOperator + +An example using the operator is in way: + +.. exampleinclude:: /../../tests/system/providers/openai/example_trigger_batch_operator.py + :language: python + :start-after: [START howto_operator_openai_trigger_operator] + :end-before: [END howto_operator_openai_trigger_operator] diff --git a/tests/providers/openai/hooks/test_openai.py b/tests/providers/openai/hooks/test_openai.py index a4e4cdbbbf290..6c84845c7555d 100644 --- a/tests/providers/openai/hooks/test_openai.py +++ b/tests/providers/openai/hooks/test_openai.py @@ -26,7 +26,7 @@ from unittest.mock import mock_open from openai.pagination import SyncCursorPage -from openai.types import CreateEmbeddingResponse, Embedding, FileDeleted, FileObject +from openai.types import Batch, CreateEmbeddingResponse, Embedding, FileDeleted, FileObject from openai.types.beta import ( Assistant, AssistantDeleted, @@ -40,6 +40,7 @@ from openai.types.chat import ChatCompletion from airflow.models import Connection +from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout from airflow.providers.openai.hooks.openai import OpenAIHook ASSISTANT_ID = "test_assistant_abc123" @@ -55,6 +56,19 @@ VECTOR_STORE_ID = "test_vs_abc123" VECTOR_STORE_NAME = "Test Vector Store" VECTOR_FILE_STORE_BATCH_ID = "test_vfsb_abc123" +BATCH_ID = "test_batch_abc123" + + +def create_batch(status) -> Batch: + return Batch( + id=BATCH_ID, + object="batch", + completion_window="24h", + created_at=1699061776, + endpoint="/v1/chat/completions", + input_file_id=FILE_ID, + status=status, + ) @pytest.fixture @@ -261,6 +275,24 @@ def mock_vector_file_store_list(): ) +@pytest.fixture( + params=[ + "completed", + "expired", + "cancelling", + "cancelled", + "failed", + ] +) +def mock_terminated_batch(request): + return create_batch(request.param) + + +@pytest.fixture(params=["validating", "in_progress", "finalizing"]) +def mock_wip_batch(request): + return create_batch(request.param) + + def test_create_chat_completion(mock_openai_hook, mock_completion): messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -495,6 +527,44 @@ def test_delete_vector_store_file(mock_openai_hook): assert vector_store_file_deleted.deleted +def test_create_batch(mock_openai_hook, mock_terminated_batch): + mock_openai_hook.conn.batches.create.return_value = mock_terminated_batch + batch = mock_openai_hook.create_batch(endpoint="/v1/chat/completions", file_id=FILE_ID) + assert batch.id == mock_terminated_batch.id + + +def test_get_batch(mock_openai_hook, mock_terminated_batch): + mock_openai_hook.conn.batches.retrieve.return_value = mock_terminated_batch + batch = mock_openai_hook.get_batch(batch_id=BATCH_ID) + assert batch.id == mock_terminated_batch.id + + +def test_cancel_batch(mock_openai_hook, mock_terminated_batch): + mock_openai_hook.conn.batches.cancel.return_value = mock_terminated_batch + batch = mock_openai_hook.cancel_batch(batch_id=BATCH_ID) + assert batch.id == mock_terminated_batch.id + + +def test_wait_for_finished_batch(mock_openai_hook, mock_terminated_batch): + mock_openai_hook.conn.batches.retrieve.return_value = mock_terminated_batch + if mock_terminated_batch.status == "completed": + try: + mock_openai_hook.wait_for_batch(batch_id=BATCH_ID) + except Exception as e: + pytest.fail(f"Should not have raised exception: {e}") + else: + with pytest.raises(OpenAIBatchJobException, match="Batch failed"): + mock_openai_hook.wait_for_batch(batch_id=BATCH_ID, wait_seconds=0.01, timeout=0.1) + + +def test_wait_for_in_progress_batch_timeout(mock_openai_hook, mock_wip_batch): + mock_openai_hook.conn.batches.retrieve.return_value = mock_wip_batch + with pytest.raises(OpenAIBatchTimeout, match="Timeout"): + mock_openai_hook.wait_for_batch(batch_id=BATCH_ID, wait_seconds=0.2, timeout=0.01) + assert mock_openai_hook.conn.batches.retrieve.call_count >= 1 + assert mock_openai_hook.conn.batches.cancel.call_count == 1 + + def test_openai_hook_test_connection(mock_openai_hook): result, message = mock_openai_hook.test_connection() assert result is True diff --git a/tests/providers/openai/operators/test_openai.py b/tests/providers/openai/operators/test_openai.py index f45cbb5da9733..f10b8a9b266da 100644 --- a/tests/providers/openai/operators/test_openai.py +++ b/tests/providers/openai/operators/test_openai.py @@ -19,16 +19,38 @@ from unittest.mock import Mock import pytest +from openai.types.batch import Batch openai = pytest.importorskip("openai") -from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator +from airflow.exceptions import TaskDeferred +from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator, OpenAITriggerBatchOperator +from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger from airflow.utils.context import Context +TASK_ID = "TaskId" +CONN_ID = "test_conn_id" +BATCH_ID = "batch_id" +FILE_ID = "file_id" +BATCH_ENDPOINT = "/v1/chat/completions" + + +@pytest.fixture +def mock_batch(): + return Batch( + id=BATCH_ID, + object="batch", + completion_window="24h", + created_at=1699061776, + endpoint=BATCH_ENDPOINT, + input_file_id=FILE_ID, + status="in_progress", + ) + def test_execute_with_input_text(): operator = OpenAIEmbeddingOperator( - task_id="TaskId", conn_id="test_conn_id", model="test_model", input_text="Test input text" + task_id=TASK_ID, conn_id=CONN_ID, model="test_model", input_text="Test input text" ) mock_hook_instance = Mock() mock_hook_instance.create_embeddings.return_value = [1.0, 2.0, 3.0] @@ -43,8 +65,53 @@ def test_execute_with_input_text(): @pytest.mark.parametrize("invalid_input", ["", None, 123]) def test_execute_with_invalid_input(invalid_input): operator = OpenAIEmbeddingOperator( - task_id="TaskId", conn_id="test_conn_id", model="test_model", input_text=invalid_input + task_id=TASK_ID, conn_id=CONN_ID, model="test_model", input_text=invalid_input ) context = Context() with pytest.raises(ValueError): operator.execute(context) + + +@pytest.mark.parametrize("wait_for_completion", [True, False]) +def test_openai_trigger_batch_operator_not_deferred(mock_batch, wait_for_completion): + operator = OpenAITriggerBatchOperator( + task_id=TASK_ID, + conn_id=CONN_ID, + file_id=FILE_ID, + endpoint=BATCH_ENDPOINT, + wait_for_completion=wait_for_completion, + deferrable=False, + ) + mock_hook_instance = Mock() + mock_hook_instance.get_batch.return_value = mock_batch + mock_hook_instance.create_batch.return_value = mock_batch + operator.hook = mock_hook_instance + + context = Context() + batch_id = operator.execute(context) + assert batch_id == BATCH_ID + + +@pytest.mark.parametrize("wait_for_completion", [True, False]) +def test_openai_trigger_batch_operator_with_deferred(mock_batch, wait_for_completion): + operator = OpenAITriggerBatchOperator( + task_id=TASK_ID, + conn_id=CONN_ID, + file_id=FILE_ID, + endpoint=BATCH_ENDPOINT, + deferrable=True, + wait_for_completion=wait_for_completion, + ) + mock_hook_instance = Mock() + mock_hook_instance.get_batch.return_value = mock_batch + mock_hook_instance.create_batch.return_value = mock_batch + operator.hook = mock_hook_instance + + context = Context() + if wait_for_completion: + with pytest.raises(TaskDeferred) as exc: + operator.execute(context) + assert isinstance(exc.value.trigger, OpenAIBatchTrigger) + else: + batch_id = operator.execute(context) + assert batch_id == BATCH_ID diff --git a/tests/providers/openai/test_exceptions.py b/tests/providers/openai/test_exceptions.py new file mode 100644 index 0000000000000..38f6d38a50992 --- /dev/null +++ b/tests/providers/openai/test_exceptions.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout + + +@pytest.mark.parametrize( + "exception_class", + [ + OpenAIBatchTimeout, + OpenAIBatchJobException, + ], +) +def test_wait_for_batch_raise_exception(exception_class): + mock_hook_instance = Mock() + mock_hook_instance.wait_for_batch.side_effect = exception_class + hook = mock_hook_instance + with pytest.raises(exception_class): + hook.wait_for_batch(batch_id="batch_id") diff --git a/tests/providers/openai/triggers/__init__.py b/tests/providers/openai/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/openai/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/openai/triggers/test_openai.py b/tests/providers/openai/triggers/test_openai.py new file mode 100644 index 0000000000000..7d19d00494932 --- /dev/null +++ b/tests/providers/openai/triggers/test_openai.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import time +from typing import Literal +from unittest import mock + +import pytest +from openai.types import Batch + +from airflow.providers.openai.hooks.openai import BatchStatus +from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger +from airflow.triggers.base import TriggerEvent + + +class TestOpenAIBatchTrigger: + BATCH_ID = "batch_id" + CONN_ID = "openai_default" + END_TIME = time.time() + 24 * 60 * 60 + POLL_INTERVAL = 3.0 + + def mock_get_batch( + self, + status: Literal[ + "validating", + "failed", + "in_progress", + "finalizing", + "completed", + "expired", + "cancelling", + "cancelled", + ], + ) -> Batch: + return Batch( + id=self.BATCH_ID, + object="batch", + completion_window="24h", + created_at=1699061776, + endpoint="/v1/chat/completions", + input_file_id="file-id", + status=status, + ) + + def test_serialization(self): + """Assert TestOpenAIBatchTrigger correctly serializes its arguments and class path.""" + trigger = OpenAIBatchTrigger( + conn_id=self.CONN_ID, + batch_id=self.BATCH_ID, + poll_interval=self.POLL_INTERVAL, + end_time=self.END_TIME, + ) + class_path, kwargs = trigger.serialize() + assert class_path == "airflow.providers.openai.triggers.openai.OpenAIBatchTrigger" + assert kwargs == { + "conn_id": self.CONN_ID, + "batch_id": self.BATCH_ID, + "poll_interval": self.POLL_INTERVAL, + "end_time": self.END_TIME, + } + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_batch_status, mock_status, mock_message", + [ + (str(BatchStatus.COMPLETED), "success", "Batch batch_id has completed successfully."), + (str(BatchStatus.CANCELLING), "cancelled", "Batch batch_id has been cancelled."), + (str(BatchStatus.CANCELLED), "cancelled", "Batch batch_id has been cancelled."), + (str(BatchStatus.FAILED), "error", "Batch failed:\nbatch_id"), + ( + str(BatchStatus.EXPIRED), + "error", + "Batch couldn't be completed within the hour time window :\nbatch_id", + ), + ], + ) + @mock.patch("airflow.providers.openai.hooks.openai.OpenAIHook.get_batch") + async def test_openai_batch_for_terminal_status( + self, mock_batch, mock_batch_status, mock_status, mock_message + ): + """Assert that run trigger messages in case of job finished""" + mock_batch.return_value = self.mock_get_batch(mock_batch_status) + trigger = OpenAIBatchTrigger( + conn_id=self.CONN_ID, + batch_id=self.BATCH_ID, + poll_interval=self.POLL_INTERVAL, + end_time=self.END_TIME, + ) + expected_result = { + "status": mock_status, + "message": mock_message, + "batch_id": self.BATCH_ID, + } + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + assert TriggerEvent(expected_result) == task.result() + asyncio.get_event_loop().stop() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_batch_status", + [ + (str(BatchStatus.FINALIZING)), + (str(BatchStatus.IN_PROGRESS)), + (str(BatchStatus.VALIDATING)), + ], + ) + @mock.patch("airflow.providers.openai.hooks.openai.OpenAIHook.get_batch") + @mock.patch("time.time") + async def test_openai_batch_for_timeout(self, mock_check_time, mock_batch, mock_batch_status): + """Assert that run trigger messages in case of batch is still running after timeout""" + MOCK_TIME = 1724068066.6468632 + mock_batch.return_value = self.mock_get_batch(mock_batch_status) + mock_check_time.return_value = MOCK_TIME + 1 + trigger = OpenAIBatchTrigger( + conn_id=self.CONN_ID, + batch_id=self.BATCH_ID, + poll_interval=self.POLL_INTERVAL, + end_time=MOCK_TIME, + ) + expected_result = { + "status": "error", + "message": f"Batch {self.BATCH_ID} has not reached a terminal status after {mock_check_time.return_value - MOCK_TIME} seconds.", + "batch_id": self.BATCH_ID, + } + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.1) + assert TriggerEvent(expected_result) == task.result() + asyncio.get_event_loop().stop() + + @pytest.mark.asyncio + @mock.patch("airflow.providers.openai.hooks.openai.OpenAIHook.get_batch") + async def test_openai_batch_for_unexpected_error(self, mock_batch): + """Assert that run trigger messages in case of unexpected error""" + mock_batch.return_value = 1.0 # FORCE FAILURE TO TEST EXCEPTION + trigger = OpenAIBatchTrigger( + conn_id=self.CONN_ID, + batch_id=self.BATCH_ID, + poll_interval=self.POLL_INTERVAL, + end_time=self.END_TIME, + ) + expected_result = { + "status": "error", + "message": "'float' object has no attribute 'status'", + "batch_id": self.BATCH_ID, + } + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.1) + assert TriggerEvent(expected_result) == task.result() + asyncio.get_event_loop().stop() diff --git a/tests/system/providers/openai/example_trigger_batch_operator.py b/tests/system/providers/openai/example_trigger_batch_operator.py new file mode 100644 index 0000000000000..6f01f648ccc7b --- /dev/null +++ b/tests/system/providers/openai/example_trigger_batch_operator.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any, Literal + +from airflow.decorators import dag, task + +OPENAI_CONN_ID = "openai_default" + +POKEMONS = [ + "pikachu", + "charmander", + "bulbasaur", +] + + +@dag( + schedule=None, + catchup=False, +) +def openai_batch_chat_completions(): + @task + def generate_messages(pokemon, **context) -> list[dict[str, Any]]: + return [{"role": "user", "content": f"Describe the info about {pokemon}?"}] + + @task + def batch_upload(messages_batch, **context) -> str: + import tempfile + import uuid + + from pydantic import BaseModel, Field + + from airflow.providers.openai.hooks.openai import OpenAIHook + + class RequestBody(BaseModel): + model: str + messages: list[dict[str, Any]] + max_tokens: int = Field(default=1000) + + class BatchModel(BaseModel): + custom_id: str + method: Literal["POST"] + url: Literal["/v1/chat/completions"] + body: RequestBody + + model = "gpt-4o-mini" + max_tokens = 1000 + hook = OpenAIHook(conn_id=OPENAI_CONN_ID) + with tempfile.NamedTemporaryFile(mode="w", delete=False) as file: + for messages in messages_batch: + file.write( + BatchModel( + custom_id=str(uuid.uuid4()), + method="POST", + url="/v1/chat/completions", + body=RequestBody( + model=model, + max_tokens=max_tokens, + messages=messages, + ), + ).model_dump_json() + + "\n" + ) + batch_file = hook.upload_file(file.name, purpose="batch") + return batch_file.id + + @task + def cleanup_batch_output_file(batch_id, **context): + from airflow.providers.openai.hooks.openai import OpenAIHook + + hook = OpenAIHook(conn_id=OPENAI_CONN_ID) + batch = hook.get_batch(batch_id) + if batch.output_file_id: + hook.delete_file(batch.output_file_id) + + messages = generate_messages.expand(pokemon=POKEMONS) + batch_file_id = batch_upload(messages_batch=messages) + + # [START howto_operator_openai_trigger_operator] + from airflow.providers.openai.operators.openai import OpenAITriggerBatchOperator + + batch_id = OpenAITriggerBatchOperator( + task_id="batch_operator_deferred", + conn_id=OPENAI_CONN_ID, + file_id=batch_file_id, + endpoint="/v1/chat/completions", + deferrable=True, + ) + # [END howto_operator_openai_trigger_operator] + cleanup_batch_output = cleanup_batch_output_file( + batch_id="{{ ti.xcom_pull(task_ids='batch_operator_deferred', key='return_value') }}" + ) + batch_id >> cleanup_batch_output + + +openai_batch_chat_completions() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)