diff --git a/airflow/providers/openai/hooks/openai.py b/airflow/providers/openai/hooks/openai.py index 31a4c16b9f29c..e66283afd6108 100644 --- a/airflow/providers/openai/hooks/openai.py +++ b/airflow/providers/openai/hooks/openai.py @@ -18,13 +18,22 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, BinaryIO, Literal from openai import OpenAI if TYPE_CHECKING: - from openai.types.beta import Assistant, AssistantDeleted, Thread, ThreadDeleted + from openai.types import FileDeleted, FileObject + from openai.types.beta import ( + Assistant, + AssistantDeleted, + Thread, + ThreadDeleted, + VectorStore, + VectorStoreDeleted, + ) from openai.types.beta.threads import Message, Run + from openai.types.beta.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted from openai.types.chat import ( ChatCompletionAssistantMessageParam, ChatCompletionFunctionMessageParam, @@ -111,7 +120,8 @@ def create_chat_completion( return response.choices def create_assistant(self, model: str = "gpt-3.5-turbo", **kwargs: Any) -> Assistant: - """Create an OpenAI assistant using the given model. + """ + Create an OpenAI assistant using the given model. :param model: The OpenAI model for the assistant to use. """ @@ -132,19 +142,9 @@ def get_assistants(self, **kwargs: Any) -> list[Assistant]: assistants = self.conn.beta.assistants.list(**kwargs) return assistants.data - def get_assistant_by_name(self, assistant_name: str) -> Assistant | None: - """Get an OpenAI Assistant object for a given name. - - :param assistant_name: The name of the assistant to retrieve - """ - response = self.get_assistants() - for assistant in response: - if assistant.name == assistant_name: - return assistant - return None - def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant: - """Modify an existing Assistant object. + """ + Modify an existing Assistant object. :param assistant_id: The ID of the assistant to be modified. """ @@ -152,7 +152,8 @@ def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant: return assistant def delete_assistant(self, assistant_id: str) -> AssistantDeleted: - """Delete an OpenAI Assistant for a given ID. + """ + Delete an OpenAI Assistant for a given ID. :param assistant_id: The ID of the assistant to delete. """ @@ -165,16 +166,18 @@ def create_thread(self, **kwargs: Any) -> Thread: return thread def modify_thread(self, thread_id: str, metadata: dict[str, Any]) -> Thread: - """Modify an existing Thread object. + """ + Modify an existing Thread object. - :param thread_id: The ID of the thread to modify. + :param thread_id: The ID of the thread to modify. Only the metadata can be modified. :param metadata: Set of 16 key-value pairs that can be attached to an object. """ thread = self.conn.beta.threads.update(thread_id=thread_id, metadata=metadata) return thread def delete_thread(self, thread_id: str) -> ThreadDeleted: - """Delete an OpenAI thread for a given thread_id. + """ + Delete an OpenAI thread for a given thread_id. :param thread_id: The ID of the thread to delete. """ @@ -184,7 +187,8 @@ def delete_thread(self, thread_id: str) -> ThreadDeleted: def create_message( self, thread_id: str, role: Literal["user", "assistant"], content: str, **kwargs: Any ) -> Message: - """Create a message for a given Thread. + """ + Create a message for a given Thread. :param thread_id: The ID of the thread to create a message for. :param role: The role of the entity that is creating the message. Allowed values include: 'user', 'assistant'. @@ -196,7 +200,8 @@ def create_message( return thread_message def get_messages(self, thread_id: str, **kwargs: Any) -> list[Message]: - """Return a list of messages for a given Thread. + """ + Return a list of messages for a given Thread. :param thread_id: The ID of the thread the messages belong to. """ @@ -204,7 +209,8 @@ def get_messages(self, thread_id: str, **kwargs: Any) -> list[Message]: return messages.data def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message: - """Modify an existing message for a given Thread. + """ + Modify an existing message for a given Thread. :param thread_id: The ID of the thread to which this message belongs. :param message_id: The ID of the message to modify. @@ -215,7 +221,8 @@ def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message: return thread_message def create_run(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run: - """Create a run for a given thread and assistant. + """ + Create a run for a given thread and assistant. :param thread_id: The ID of the thread to run. :param assistant_id: The ID of the assistant to use to execute this run. @@ -223,8 +230,22 @@ def create_run(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run: run = self.conn.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id, **kwargs) return run + def create_run_and_poll(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run: + """ + Create a run for a given thread and assistant and then polls until completion. + + :param thread_id: The ID of the thread to run. + :param assistant_id: The ID of the assistant to use to execute this run. + :return: An OpenAI Run object + """ + run = self.conn.beta.threads.runs.create_and_poll( + thread_id=thread_id, assistant_id=assistant_id, **kwargs + ) + return run + def get_run(self, thread_id: str, run_id: str) -> Run: - """Retrieve a run for a given thread and run. + """ + Retrieve a run for a given thread and run. :param thread_id: The ID of the thread that was run. :param run_id: The ID of the run to retrieve. @@ -257,7 +278,8 @@ def create_embeddings( model: str = "text-embedding-ada-002", **kwargs: Any, ) -> list[float]: - """Generate embeddings for the given text using the given model. + """ + Generate embeddings for the given text using the given model. :param text: The text to generate embeddings for. :param model: The model to use for generating embeddings. @@ -265,3 +287,109 @@ def create_embeddings( response = self.conn.embeddings.create(model=model, input=text, **kwargs) embeddings: list[float] = response.data[0].embedding return embeddings + + def upload_file(self, file: str, purpose: Literal["fine-tune", "assistants"]) -> FileObject: + """ + Upload a file that can be used across various endpoints. The size of all the files uploaded by one organization can be up to 100 GB. + + :param file: The File object (not file name) to be uploaded. + :param purpose: The intended purpose of the uploaded file. Use "fine-tune" for + Fine-tuning and "assistants" for Assistants and Messages. + """ + with open(file, "rb") as file_stream: + file_object = self.conn.files.create(file=file_stream, purpose=purpose) + return file_object + + def get_file(self, file_id: str) -> FileObject: + """ + Return information about a specific file. + + :param file_id: The ID of the file to use for this request. + """ + file = self.conn.files.retrieve(file_id=file_id) + return file + + def get_files(self) -> list[FileObject]: + """Return a list of files that belong to the user's organization.""" + files = self.conn.files.list() + return files.data + + def delete_file(self, file_id: str) -> FileDeleted: + """ + Delete a file. + + :param file_id: The ID of the file to be deleted. + """ + response = self.conn.files.delete(file_id=file_id) + return response + + def create_vector_store(self, **kwargs: Any) -> VectorStore: + """Create a vector store.""" + vector_store = self.conn.beta.vector_stores.create(**kwargs) + return vector_store + + def get_vector_stores(self, **kwargs: Any) -> list[VectorStore]: + """Return a list of vector stores.""" + vector_stores = self.conn.beta.vector_stores.list(**kwargs) + return vector_stores.data + + def get_vector_store(self, vector_store_id: str) -> VectorStore: + """ + Retrieve a vector store. + + :param vector_store_id: The ID of the vector store to retrieve. + """ + vector_store = self.conn.beta.vector_stores.retrieve(vector_store_id=vector_store_id) + return vector_store + + def modify_vector_store(self, vector_store_id: str, **kwargs: Any) -> VectorStore: + """ + Modify a vector store. + + :param vector_store_id: The ID of the vector store to modify. + """ + vector_store = self.conn.beta.vector_stores.update(vector_store_id=vector_store_id, **kwargs) + return vector_store + + def delete_vector_store(self, vector_store_id: str) -> VectorStoreDeleted: + """ + Delete a vector store. + + :param vector_store_id: The ID of the vector store to delete. + """ + response = self.conn.beta.vector_stores.delete(vector_store_id=vector_store_id) + return response + + def upload_files_to_vector_store( + self, vector_store_id: str, files: list[BinaryIO] + ) -> VectorStoreFileBatch: + """ + Upload files to a vector store and poll until completion. + + :param vector_store_id: The ID of the vector store the files are to be uploaded + to. + :param files: A list of binary files to upload. + """ + file_batch = self.conn.beta.vector_stores.file_batches.upload_and_poll( + vector_store_id=vector_store_id, files=files + ) + return file_batch + + def get_vector_store_files(self, vector_store_id: str) -> list[VectorStoreFile]: + """ + Return a list of vector store files. + + :param vector_store_id: + """ + vector_store_files = self.conn.beta.vector_stores.files.list(vector_store_id=vector_store_id) + return vector_store_files.data + + def delete_vector_store_file(self, vector_store_id: str, file_id: str) -> VectorStoreFileDeleted: + """ + Delete a vector store file. This will remove the file from the vector store but the file itself will not be deleted. To delete the file, use delete_file. + + :param vector_store_id: The ID of the vector store that the file belongs to. + :param file_id: The ID of the file to delete. + """ + response = self.conn.beta.vector_stores.files.delete(vector_store_id=vector_store_id, file_id=file_id) + return response diff --git a/airflow/providers/openai/provider.yaml b/airflow/providers/openai/provider.yaml index 9d12cddc502bb..fce0b5650cdcb 100644 --- a/airflow/providers/openai/provider.yaml +++ b/airflow/providers/openai/provider.yaml @@ -41,7 +41,7 @@ integrations: dependencies: - apache-airflow>=2.7.0 - - openai[datalib]>=1.16 + - openai[datalib]>=1.23 hooks: - integration-name: OpenAI diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index d91b48858029a..c17fe2d2e291a 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -805,7 +805,7 @@ "openai": { "deps": [ "apache-airflow>=2.7.0", - "openai[datalib]>=1.16" + "openai[datalib]>=1.23" ], "devel-deps": [], "cross-providers-deps": [], diff --git a/tests/providers/openai/hooks/test_openai.py b/tests/providers/openai/hooks/test_openai.py index aa7a479bbbe0b..374e0f61c99e7 100644 --- a/tests/providers/openai/hooks/test_openai.py +++ b/tests/providers/openai/hooks/test_openai.py @@ -23,10 +23,20 @@ openai = pytest.importorskip("openai") +from unittest.mock import mock_open + from openai.pagination import SyncCursorPage -from openai.types import CreateEmbeddingResponse, Embedding -from openai.types.beta import Assistant, AssistantDeleted, Thread, ThreadDeleted +from openai.types import CreateEmbeddingResponse, Embedding, FileDeleted, FileObject +from openai.types.beta import ( + Assistant, + AssistantDeleted, + Thread, + ThreadDeleted, + VectorStore, + VectorStoreDeleted, +) from openai.types.beta.threads import Message, Run +from openai.types.beta.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted from openai.types.chat import ChatCompletion from airflow.models import Connection @@ -39,7 +49,12 @@ MESSAGE_ID = "test_message_abc123" RUN_ID = "test_run_abc123" MODEL = "gpt-4" +FILE_ID = "test_file_abc123" +FILE_NAME = "test_file.pdf" METADATA = {"modified": "true", "user": "abc123"} +VECTOR_STORE_ID = "test_vs_abc123" +VECTOR_STORE_NAME = "Test Vector Store" +VECTOR_FILE_STORE_BATCH_ID = "test_vfsb_abc123" @pytest.fixture @@ -161,6 +176,90 @@ def mock_run_list(mock_run): return SyncCursorPage[Run](data=[mock_run]) +@pytest.fixture +def mock_file(): + return FileObject( + id=FILE_ID, + object="file", + bytes=120000, + created_at=1677610602, + filename=FILE_NAME, + purpose="assistants", + status="processed", + ) + + +@pytest.fixture +def mock_file_list(mock_file): + return SyncCursorPage[FileObject](data=[mock_file]) + + +@pytest.fixture +def mock_vector_store(): + return VectorStore( + id=VECTOR_STORE_ID, + object="vector_store", + created_at=1698107661, + usage_bytes=123456, + last_active_at=1698107661, + name=VECTOR_STORE_NAME, + bytes=123456, + status="completed", + file_counts={"in_progress": 0, "completed": 100, "cancelled": 0, "failed": 0, "total": 100}, + metadata={}, + last_used_at=1698107661, + ) + + +@pytest.fixture +def mock_vector_store_list(mock_vector_store): + return SyncCursorPage[VectorStore](data=[mock_vector_store]) + + +@pytest.fixture +def mock_vector_file_store_batch(): + return VectorStoreFileBatch( + id=VECTOR_FILE_STORE_BATCH_ID, + object="vector_store.files_batch", + created_at=1699061776, + vector_store_id=VECTOR_STORE_ID, + status="completed", + file_counts={ + "in_progress": 0, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 0, + }, + ) + + +@pytest.fixture +def mock_vector_file_store_list(): + return SyncCursorPage[VectorStoreFile]( + data=[ + VectorStoreFile( + id="test-file-abc123", + object="vector_store.file", + created_at=1699061776, + usage_bytes=1234, + vector_store_id=VECTOR_STORE_ID, + status="completed", + last_error=None, + ), + VectorStoreFile( + id="test-file-abc456", + object="vector_store.file", + created_at=1699061776, + usage_bytes=1234, + vector_store_id=VECTOR_STORE_ID, + status="completed", + last_error=None, + ), + ] + ) + + def test_create_chat_completion(mock_openai_hook, mock_completion): messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -197,12 +296,6 @@ def test_get_assistants(mock_openai_hook, mock_assistant_list): assert isinstance(assistants, list) -def test_get_assistant_by_name(mock_openai_hook, mock_assistant_list): - mock_openai_hook.conn.beta.assistants.list.return_value = mock_assistant_list - assistant = mock_openai_hook.get_assistant_by_name(assistant_name=ASSISTANT_NAME) - assert assistant.name == ASSISTANT_NAME - - def test_modify_assistant(mock_openai_hook, mock_assistant): new_assistant_name = "New Test Assistant" mock_assistant.name = new_assistant_name @@ -269,6 +362,14 @@ def test_create_run(mock_openai_hook, mock_run): assert run.id == RUN_ID +def test_create_run_and_poll(mock_openai_hook, mock_run): + thread_id = THREAD_ID + assistant_id = ASSISTANT_ID + mock_openai_hook.conn.beta.threads.runs.create_and_poll.return_value = mock_run + run = mock_openai_hook.create_run_and_poll(thread_id=thread_id, assistant_id=assistant_id) + assert run.id == RUN_ID + + def test_get_runs(mock_openai_hook, mock_run_list): mock_openai_hook.conn.beta.threads.runs.list.return_value = mock_run_list runs = mock_openai_hook.get_runs(thread_id=THREAD_ID) @@ -296,6 +397,103 @@ def test_create_embeddings(mock_openai_hook, mock_embeddings_response): assert embeddings == [0.1, 0.2, 0.3] +@patch("builtins.open", new_callable=mock_open, read_data="test-data") +def test_upload_file(mock_file_open, mock_openai_hook, mock_file): + mock_file.name = FILE_NAME + mock_file.purpose = "assistants" + mock_openai_hook.conn.files.create.return_value = mock_file + file = mock_openai_hook.upload_file(file=mock_file_open(), purpose="assistants") + assert file.name == FILE_NAME + assert file.purpose == "assistants" + + +def test_get_file(mock_openai_hook, mock_file): + mock_openai_hook.conn.files.retrieve.return_value = mock_file + file = mock_openai_hook.get_file(file_id=FILE_ID) + assert file.id == FILE_ID + assert file.filename == FILE_NAME + + +def test_get_files(mock_openai_hook, mock_file_list): + mock_openai_hook.conn.files.list.return_value = mock_file_list + files = mock_openai_hook.get_files() + assert isinstance(files, list) + + +def test_delete_file(mock_openai_hook): + delete_response = FileDeleted(id=FILE_ID, object="file", deleted=True) + mock_openai_hook.conn.files.delete.return_value = delete_response + file_deleted = mock_openai_hook.delete_file(file_id=FILE_ID) + assert file_deleted.deleted + + +def test_create_vector_store(mock_openai_hook, mock_vector_store): + mock_openai_hook.conn.beta.vector_stores.create.return_value = mock_vector_store + vector_store = mock_openai_hook.create_vector_store(name=VECTOR_STORE_NAME) + assert vector_store.id == VECTOR_STORE_ID + assert vector_store.name == VECTOR_STORE_NAME + + +def test_get_vector_store(mock_openai_hook, mock_vector_store): + mock_openai_hook.conn.beta.vector_stores.retrieve.return_value = mock_vector_store + vector_store = mock_openai_hook.get_vector_store(vector_store_id=VECTOR_STORE_ID) + assert vector_store.id == VECTOR_STORE_ID + assert vector_store.name == VECTOR_STORE_NAME + + +def test_get_vector_stores(mock_openai_hook, mock_vector_store_list): + mock_openai_hook.conn.beta.vector_stores.list.return_value = mock_vector_store_list + vector_stores = mock_openai_hook.get_vector_stores() + assert isinstance(vector_stores, list) + + +def test_modify_vector_store(mock_openai_hook, mock_vector_store): + new_vector_store_name = "New Vector Store" + mock_vector_store.name = new_vector_store_name + mock_openai_hook.conn.beta.vector_stores.update.return_value = mock_vector_store + vector_store = mock_openai_hook.modify_vector_store( + vector_store_id=VECTOR_STORE_ID, name=new_vector_store_name + ) + assert vector_store.name == new_vector_store_name + + +def test_delete_vector_store(mock_openai_hook): + delete_response = VectorStoreDeleted(id=VECTOR_STORE_ID, object="vector_store.deleted", deleted=True) + mock_openai_hook.conn.beta.vector_stores.delete.return_value = delete_response + vector_store_deleted = mock_openai_hook.delete_vector_store(vector_store_id=VECTOR_STORE_ID) + assert vector_store_deleted.deleted + + +def test_upload_files_to_vector_store(mock_openai_hook, mock_vector_file_store_batch): + files = ["file1.txt", "file2.txt", "file3.txt"] + mock_openai_hook.conn.beta.vector_stores.file_batches.upload_and_poll.return_value = ( + mock_vector_file_store_batch + ) + vector_file_store_batch = mock_openai_hook.upload_files_to_vector_store( + vector_store_id=VECTOR_STORE_ID, files=files + ) + assert vector_file_store_batch.id == VECTOR_FILE_STORE_BATCH_ID + assert vector_file_store_batch.file_counts.completed == len(files) + + +def test_get_vector_store_files(mock_openai_hook, mock_vector_file_store_list): + mock_openai_hook.conn.beta.vector_stores.files.list.return_value = mock_vector_file_store_list + vector_file_store_list = mock_openai_hook.get_vector_store_files(vector_store_id=VECTOR_STORE_ID) + assert isinstance(vector_file_store_list, list) + + +def test_delete_vector_store_file(mock_openai_hook): + delete_response = VectorStoreFileDeleted( + id="test_file_abc123", object="vector_store.file.deleted", deleted=True + ) + mock_openai_hook.conn.beta.vector_stores.files.delete.return_value = delete_response + vector_store_file_deleted = mock_openai_hook.delete_vector_store_file( + vector_store_id=VECTOR_STORE_ID, file_id=FILE_ID + ) + assert vector_store_file_deleted.id == FILE_ID + assert vector_store_file_deleted.deleted + + def test_openai_hook_test_connection(mock_openai_hook): result, message = mock_openai_hook.test_connection() assert result is True