From 9dcf20cc2289c759a9a4405aae8e04eb1bd436db Mon Sep 17 00:00:00 2001 From: Paulo Nascimento <37284051+paulonasc@users.noreply.github.com> Date: Mon, 19 Feb 2024 19:11:58 -0800 Subject: [PATCH 1/3] feat: openai dalle image generation langchain tool --- .../openai_dalle_image_generation/__init__.py | 7 +++++ .../openai_dalle_image_generation/tool.py | 29 +++++++++++++++++++ .../openai_dalle_image_generation/__init__.py | 0 .../test_image_generation.py | 13 +++++++++ 4 files changed, 49 insertions(+) create mode 100644 libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py create mode 100644 libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py create mode 100644 libs/community/tests/unit_tests/tools/openai_dalle_image_generation/__init__.py create mode 100644 libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py diff --git a/libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py b/libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py new file mode 100644 index 0000000000000..dbdf41b11b253 --- /dev/null +++ b/libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py @@ -0,0 +1,7 @@ +"""Tool to generate an image using DALLE OpenAI V1 SDK.""" + +from langchain_community.tools.openai_dalle_image_generation.tool import ( + OpenAIDALLEImageGenerationTool, +) + +__all__ = ["OpenAIDALLEImageGenerationTool"] diff --git a/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py b/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py new file mode 100644 index 0000000000000..69b173c24345b --- /dev/null +++ b/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py @@ -0,0 +1,29 @@ +"""Tool for the OpenAI DALLE V1 Image Generation SDK.""" + +from typing import Optional + +from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.tools import BaseTool + +from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper + + +class OpenAIDALLEImageGenerationTool(BaseTool): + """Tool that generates an image using OpenAI DALLE.""" + + name: str = "OpenAI DALLE" + description: str = ( + "A wrapper around OpenAI DALLE Image Generation. " + "Useful for when you need to generate an image of" + "people, places, paintings, animals, or other subjects. " + "Input should be a text prompt to generate an image." + ) + api_wrapper: DallEAPIWrapper + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the OpenAI DALLE Image Generation tool.""" + return self.api_wrapper.run(query) diff --git a/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/__init__.py b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py new file mode 100644 index 0000000000000..4f287c25d41be --- /dev/null +++ b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py @@ -0,0 +1,13 @@ +from unittest.mock import MagicMock + +from langchain_community.tools.openai_dalle_image_generation import OpenAIDALLEImageGenerationTool + + +def test_generate_image() -> None: + """Test OpenAI DALLE Image Generation.""" + mock_api_resource = MagicMock() + # bypass pydantic validation as openai is not a package dependency + tool = OpenAIDALLEImageGenerationTool.construct(api_wrapper=mock_api_resource) + tool_input = {"query": "parrot on a branch"} + result = tool.run(tool_input) + assert result.startswith("https://") From 8e6ec25ac59a9b83fb03a62e2b8d8f3369e19de8 Mon Sep 17 00:00:00 2001 From: Paulo Nascimento <37284051+paulonasc@users.noreply.github.com> Date: Mon, 19 Feb 2024 19:17:48 -0800 Subject: [PATCH 2/3] formatting --- .../openai_dalle_image_generation/test_image_generation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py index 4f287c25d41be..7358fc32926e4 100644 --- a/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py +++ b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py @@ -1,6 +1,8 @@ from unittest.mock import MagicMock -from langchain_community.tools.openai_dalle_image_generation import OpenAIDALLEImageGenerationTool +from langchain_community.tools.openai_dalle_image_generation import ( + OpenAIDALLEImageGenerationTool, +) def test_generate_image() -> None: From 3dee236da3b455c7ed1df1ededf68d8a5c67443c Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 28 Mar 2024 17:20:54 -0700 Subject: [PATCH 3/3] fmt --- .../tools/openai_dalle_image_generation/tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py b/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py index 69b173c24345b..36374e887f74b 100644 --- a/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py +++ b/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py @@ -11,7 +11,7 @@ class OpenAIDALLEImageGenerationTool(BaseTool): """Tool that generates an image using OpenAI DALLE.""" - name: str = "OpenAI DALLE" + name: str = "openai_dalle" description: str = ( "A wrapper around OpenAI DALLE Image Generation. " "Useful for when you need to generate an image of"