diff --git a/airflow/providers/openai/hooks/openai.py b/airflow/providers/openai/hooks/openai.py
index 31a4c16b9f29c..5e73ac12fb2db 100644
--- a/airflow/providers/openai/hooks/openai.py
+++ b/airflow/providers/openai/hooks/openai.py
@@ -18,13 +18,22 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, BinaryIO, Literal
 
 from openai import OpenAI
 
 if TYPE_CHECKING:
-    from openai.types.beta import Assistant, AssistantDeleted, Thread, ThreadDeleted
+    from openai.types import FileDeleted, FileObject
+    from openai.types.beta import (
+        Assistant,
+        AssistantDeleted,
+        Thread,
+        ThreadDeleted,
+        VectorStore,
+        VectorStoreDeleted,
+    )
     from openai.types.beta.threads import Message, Run
+    from openai.types.beta.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted
     from openai.types.chat import (
         ChatCompletionAssistantMessageParam,
         ChatCompletionFunctionMessageParam,
@@ -111,7 +120,8 @@ def create_chat_completion(
         return response.choices
 
     def create_assistant(self, model: str = "gpt-3.5-turbo", **kwargs: Any) -> Assistant:
-        """Create an OpenAI assistant using the given model.
+        """
+        Create an OpenAI assistant using the given model.
 
         :param model: The OpenAI model for the assistant to use.
         """
@@ -133,18 +143,17 @@ def get_assistants(self, **kwargs: Any) -> list[Assistant]:
         return assistants.data
 
     def get_assistant_by_name(self, assistant_name: str) -> Assistant | None:
-        """Get an OpenAI Assistant object for a given name.
+        """
+        Get an OpenAI Assistant object for a given name.
 
         :param assistant_name: The name of the assistant to retrieve
         """
-        response = self.get_assistants()
-        for assistant in response:
-            if assistant.name == assistant_name:
-                return assistant
-        return None
+        assistants = self.get_assistants()
+        return next((assistant for assistant in assistants if assistant.name == assistant_name), None)
 
     def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant:
-        """Modify an existing Assistant object.
+        """
+        Modify an existing Assistant object.
 
         :param assistant_id: The ID of the assistant to be modified.
         """
@@ -152,7 +161,8 @@ def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant:
         return assistant
 
     def delete_assistant(self, assistant_id: str) -> AssistantDeleted:
-        """Delete an OpenAI Assistant for a given ID.
+        """
+        Delete an OpenAI Assistant for a given ID.
 
         :param assistant_id: The ID of the assistant to delete.
         """
@@ -165,16 +175,18 @@ def create_thread(self, **kwargs: Any) -> Thread:
         return thread
 
     def modify_thread(self, thread_id: str, metadata: dict[str, Any]) -> Thread:
-        """Modify an existing Thread object.
+        """
+        Modify an existing Thread object.
 
-        :param thread_id: The ID of the thread to modify.
+        :param thread_id: The ID of the thread to modify. Only the metadata can be modified.
         :param metadata: Set of 16 key-value pairs that can be attached to an object.
         """
         thread = self.conn.beta.threads.update(thread_id=thread_id, metadata=metadata)
         return thread
 
     def delete_thread(self, thread_id: str) -> ThreadDeleted:
-        """Delete an OpenAI thread for a given thread_id.
+        """
+        Delete an OpenAI thread for a given thread_id.
 
         :param thread_id: The ID of the thread to delete.
         """
@@ -184,7 +196,8 @@
     def create_message(
         self, thread_id: str, role: Literal["user", "assistant"], content: str, **kwargs: Any
     ) -> Message:
-        """Create a message for a given Thread.
+ """ + Create a message for a given Thread. :param thread_id: The ID of the thread to create a message for. :param role: The role of the entity that is creating the message. Allowed values include: 'user', 'assistant'. @@ -196,7 +209,8 @@ def create_message( return thread_message def get_messages(self, thread_id: str, **kwargs: Any) -> list[Message]: - """Return a list of messages for a given Thread. + """ + Return a list of messages for a given Thread. :param thread_id: The ID of the thread the messages belong to. """ @@ -204,7 +218,8 @@ def get_messages(self, thread_id: str, **kwargs: Any) -> list[Message]: return messages.data def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message: - """Modify an existing message for a given Thread. + """ + Modify an existing message for a given Thread. :param thread_id: The ID of the thread to which this message belongs. :param message_id: The ID of the message to modify. @@ -215,7 +230,8 @@ def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message: return thread_message def create_run(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run: - """Create a run for a given thread and assistant. + """ + Create a run for a given thread and assistant. :param thread_id: The ID of the thread to run. :param assistant_id: The ID of the assistant to use to execute this run. @@ -223,8 +239,22 @@ def create_run(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run: run = self.conn.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id, **kwargs) return run + def create_run_and_poll(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run: + """ + Create a run for a given thread and assistant and then polls until completion. + + :param thread_id: The ID of the thread to run. + :param assistant_id: The ID of the assistant to use to execute this run. + :return: An OpenAI Run object + """ + run = self.conn.beta.threads.runs.create_and_poll( + thread_id=thread_id, assistant_id=assistant_id, **kwargs + ) + return run + def get_run(self, thread_id: str, run_id: str) -> Run: - """Retrieve a run for a given thread and run. + """ + Retrieve a run for a given thread and run. :param thread_id: The ID of the thread that was run. :param run_id: The ID of the run to retrieve. @@ -257,7 +287,8 @@ def create_embeddings( model: str = "text-embedding-ada-002", **kwargs: Any, ) -> list[float]: - """Generate embeddings for the given text using the given model. + """ + Generate embeddings for the given text using the given model. :param text: The text to generate embeddings for. :param model: The model to use for generating embeddings. @@ -265,3 +296,128 @@ def create_embeddings( response = self.conn.embeddings.create(model=model, input=text, **kwargs) embeddings: list[float] = response.data[0].embedding return embeddings + + def upload_file(self, file: str, purpose: Literal["fine-tune", "assistants"]) -> FileObject: + """ + Upload a file that can be used across various endpoints. The size of all the files uploaded by one organization can be up to 100 GB. + + :param file: The File object (not file name) to be uploaded. + :param purpose: The intended purpose of the uploaded file. Use "fine-tune" for + Fine-tuning and "assistants" for Assistants and Messages. + """ + file_object = self.conn.files.create(file=open(file, "rb"), purpose=purpose) + return file_object + + def get_file(self, file_id: str) -> FileObject: + """ + Return information about a specific file. 
+
+        :param file_id: The ID of the file to use for this request.
+        """
+        file = self.conn.files.retrieve(file_id=file_id)
+        return file
+
+    def get_files(self) -> list[FileObject]:
+        """Return a list of files that belong to the user's organization."""
+        files = self.conn.files.list()
+        return files.data
+
+    def get_file_by_name(self, file_name: str) -> FileObject | None:
+        """
+        Get an OpenAI File object for a given name.
+
+        :param file_name: The name of the file object to retrieve.
+        """
+        files = self.get_files()
+        return next((file for file in files if file.filename == file_name), None)
+
+    def delete_file(self, file_id: str) -> FileDeleted:
+        """
+        Delete a file.
+
+        :param file_id: The ID of the file to be deleted.
+        """
+        response = self.conn.files.delete(file_id=file_id)
+        return response
+
+    def create_vector_store(self, **kwargs) -> VectorStore:
+        """Create a vector store."""
+        vector_store = self.conn.beta.vector_stores.create(**kwargs)
+        return vector_store
+
+    def get_vectors_stores(self, **kwargs) -> list[VectorStore]:
+        """Return a list of vector stores."""
+        vector_stores = self.conn.beta.vector_stores.list(**kwargs)
+        return vector_stores.data
+
+    def get_vector_store(self, vector_store_id: str) -> VectorStore:
+        """
+        Retrieve a vector store.
+
+        :param vector_store_id: The ID of the vector store to retrieve.
+        """
+        vector_store = self.conn.beta.vector_stores.retrieve(vector_store_id=vector_store_id)
+        return vector_store
+
+    def get_vector_store_by_name(self, vector_store_name: str) -> VectorStore | None:
+        """
+        Get an OpenAI Vector Store object for a given name.
+
+        :param vector_store_name: The name of the vector store to retrieve.
+        """
+        vector_stores = self.get_vectors_stores()
+        return next(
+            (vector_store for vector_store in vector_stores if vector_store.name == vector_store_name), None
+        )
+
+    def modify_vector_store(self, vector_store_id: str, **kwargs) -> VectorStore:
+        """
+        Modify a vector store.
+
+        :param vector_store_id: The ID of the vector store to modify.
+        """
+        vector_store = self.conn.beta.vector_stores.update(vector_store_id=vector_store_id, **kwargs)
+        return vector_store
+
+    def delete_vector_store(self, vector_store_id: str) -> VectorStoreDeleted:
+        """
+        Delete a vector store.
+
+        :param vector_store_id: The ID of the vector store to delete.
+        """
+        response = self.conn.beta.vector_stores.delete(vector_store_id=vector_store_id)
+        return response
+
+    def upload_files_to_vector_store(
+        self, vector_store_id: str, files: list[BinaryIO]
+    ) -> VectorStoreFileBatch:
+        """
+        Upload files to a vector store and poll until completion.
+
+        :param vector_store_id: The ID of the vector store the files are to be
+            uploaded to.
+        :param files: A list of binary files to upload.
+        """
+        file_batch = self.conn.beta.vector_stores.file_batches.upload_and_poll(
+            vector_store_id=vector_store_id, files=files
+        )
+        return file_batch
+
+    def get_vector_store_files(self, vector_store_id: str) -> list[VectorStoreFile]:
+        """
+        Return a list of vector store files.
+
+        :param vector_store_id: The ID of the vector store to list files from.
+        """
+        vector_store_files = self.conn.beta.vector_stores.files.list(vector_store_id=vector_store_id)
+        return vector_store_files.data
+
+    def delete_vector_store_file(self, vector_store_id: str, file_id: str) -> VectorStoreFileDeleted:
+        """
+        Delete a vector store file. This will remove the file from the vector store but the file itself will not be deleted. To delete the file, use delete_file.
+
+        :param vector_store_id: The ID of the vector store that the file belongs to.
+        :param file_id: The ID of the file to delete.
+        """
+        response = self.conn.beta.vector_stores.files.delete(vector_store_id=vector_store_id, file_id=file_id)
+        return response
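
For reviewers, a minimal usage sketch of how the new vector store helpers fit together; it is not part of the patch, uses only methods added above, and the connection id "openai_default", the store name, and the file paths are assumptions for illustration.

from __future__ import annotations

from airflow.providers.openai.hooks.openai import OpenAIHook


def build_knowledge_base(paths: list[str]) -> str:
    # Assumes an Airflow connection named "openai_default" holding the API key.
    hook = OpenAIHook(conn_id="openai_default")

    # Reuse an existing vector store by name, or create one; extra kwargs are
    # passed straight through to client.beta.vector_stores.create().
    vector_store = hook.get_vector_store_by_name("airflow-docs") or hook.create_vector_store(
        name="airflow-docs"
    )

    # upload_files_to_vector_store() expects open binary handles and blocks
    # until the file batch has finished processing.
    handles = [open(path, "rb") for path in paths]
    try:
        batch = hook.upload_files_to_vector_store(vector_store.id, handles)
    finally:
        for handle in handles:
            handle.close()

    print(batch.status, batch.file_counts)
    return vector_store.id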