Skip to content

Commit

Permalink
Add stable diffusion integration (georgia-tech-db#1240)
Browse files Browse the repository at this point in the history
Reopen the georgia-tech-db#1111.

---------

Co-authored-by: sudoboi <[email protected]>
Co-authored-by: Abhijith S Raj <[email protected]>
  • Loading branch information
3 people authored and a0x8o committed Oct 30, 2023
1 parent cfe71f0 commit 2b924b7
Show file tree
Hide file tree
Showing 13 changed files with 719 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,16 @@ parts:
title: OpenAI
- file: source/reference/ai/yolo
title: YOLO
<<<<<<< HEAD
<<<<<<< HEAD
- file: source/reference/ai/custom
title: Custom Model
<<<<<<< HEAD
=======
=======
- file: source/reference/ai/stablediffusion
title: Stable Diffusion
>>>>>>> bf022329 (Add stable diffusion integration (#1240))

- file: source/reference/ai/custom-ai-function
title: Bring Your Own AI Function
Expand Down
27 changes: 27 additions & 0 deletions docs/source/reference/ai/stablediffusion.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
.. _stablediffusion:

Stable Diffusion Models
======================================

This section provides an overview of how you can generate images from prompts in EvaDB using a Stable Diffusion model.


Introduction
------------

Stable Diffusion models leverage a controlled random walk process to generate intricate patterns and images from textual prompts,
bridging the gap between text and visual representation. EvaDB uses the stable diffusion implementation from `Replicate <https://replicate.com>`_.

Stable Diffusion UDF
--------------------

In order to create an image generation function in EvaDB, use the following SQL command:

.. code-block:: sql
CREATE FUNCTION IF NOT EXISTS StableDiffusion
IMPL 'evadb/functions/stable_diffusion.py';
EvaDB automatically uses the latest `stable diffusion release <https://replicate.com/stability-ai/stable-diffusion/versions>`_ available on Replicate.

To see a demo of how the function can be used, please check the `demo notebook <https://colab.research.google.com/github/georgia-tech-db/eva/blob/master/tutorials/18-stable-diffusion.ipynb>`_ on stable diffusion.
1 change: 1 addition & 0 deletions evadb/evadb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ third_party:
OPENAI_KEY: ""
PINECONE_API_KEY: ""
PINECONE_ENV: ""
REPLICATE_API_TOKEN: ""
88 changes: 88 additions & 0 deletions evadb/functions/dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from io import BytesIO

import numpy as np
import pandas as pd
import requests
from PIL import Image

from evadb.catalog.catalog_type import NdArrayType
from evadb.configuration.configuration_manager import ConfigurationManager
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.utils.generic_utils import try_to_import_openai


class DallEFunction(AbstractFunction):
@property
def name(self) -> str:
return "DallE"

def setup(self) -> None:
pass

@forward(
input_signatures=[
PandasDataframe(
columns=["prompt"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(None,)],
)
],
output_signatures=[
PandasDataframe(
columns=["response"],
column_types=[NdArrayType.FLOAT32],
column_shapes=[(None, None, 3)],
)
],
)
def forward(self, text_df):
try_to_import_openai()
import openai

# Register API key, try configuration manager first
openai.api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY")
# If not found, try OS Environment Variable
if len(openai.api_key) == 0:
openai.api_key = os.environ.get("OPENAI_KEY", "")
assert (
len(openai.api_key) != 0
), "Please set your OpenAI API key in evadb.yml file (third_party, open_api_key) or environment variable (OPENAI_KEY)"

def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
response = openai.Image.create(prompt=query, n=1, size="1024x1024")

# Download the image from the link
image_response = requests.get(response["data"][0]["url"])
image = Image.open(BytesIO(image_response.content))

# Convert the image to an array format suitable for the DataFrame
frame = np.array(image)
results.append(frame)

return results

df = pd.DataFrame({"response": generate_image(text_df=text_df)})
return df
14 changes: 14 additions & 0 deletions evadb/functions/function_bootstrap_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,18 @@
MODEL 'yolov8n.pt';
"""

stablediffusion_function_query = """CREATE FUNCTION IF NOT EXISTS StableDiffusion
IMPL '{}/functions/stable_diffusion.py';
""".format(
EvaDB_INSTALLATION_DIR
)

dalle_function_query = """CREATE FUNCTION IF NOT EXISTS DallE
IMPL '{}/functions/dalle.py';
""".format(
EvaDB_INSTALLATION_DIR
)


def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
"""Load the built-in functions into the system during system bootstrapping.
Expand Down Expand Up @@ -274,6 +286,8 @@ def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
# Mvit_function_query,
Sift_function_query,
Yolo_function_query,
stablediffusion_function_query,
dalle_function_query,
]

# if mode is 'debug', add debug functions
Expand Down
102 changes: 102 additions & 0 deletions evadb/functions/stable_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from io import BytesIO

import numpy as np
import pandas as pd
import requests
from PIL import Image

from evadb.catalog.catalog_type import NdArrayType
from evadb.configuration.configuration_manager import ConfigurationManager
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.utils.generic_utils import try_to_import_replicate


class StableDiffusion(AbstractFunction):
@property
def name(self) -> str:
return "StableDiffusion"

def setup(
self,
) -> None:
pass

@forward(
input_signatures=[
PandasDataframe(
columns=["prompt"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(None,)],
)
],
output_signatures=[
PandasDataframe(
columns=["response"],
column_types=[
# FileFormatType.IMAGE,
NdArrayType.FLOAT32
],
column_shapes=[(None, None, 3)],
)
],
)
def forward(self, text_df):
try_to_import_replicate()
import replicate

# Register API key, try configuration manager first
replicate_api_key = ConfigurationManager().get_value(
"third_party", "REPLICATE_API_TOKEN"
)
# If not found, try OS Environment Variable
if len(replicate_api_key) == 0:
replicate_api_key = os.environ.get("REPLICATE_API_TOKEN", "")
assert (
len(replicate_api_key) != 0
), "Please set your Replicate API key in evadb.yml file (third_party, replicate_api_token) or environment variable (REPLICATE_API_TOKEN)"
os.environ["REPLICATE_API_TOKEN"] = replicate_api_key

model_id = (
replicate.models.get("stability-ai/stable-diffusion").versions.list()[0].id
)

def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
output = replicate.run(
"stability-ai/stable-diffusion:" + model_id, input={"prompt": query}
)

# Download the image from the link
response = requests.get(output[0])
image = Image.open(BytesIO(response.content))

# Convert the image to an array format suitable for the DataFrame
frame = np.array(image)
results.append(frame)

return results

df = pd.DataFrame({"response": generate_image(text_df=text_df)})
return df
18 changes: 18 additions & 0 deletions evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,3 +654,21 @@ def string_comparison_case_insensitive(string_1, string_2) -> bool:
return False

return string_1.lower() == string_2.lower()


def try_to_import_replicate():
try:
import replicate # noqa: F401
except ImportError:
raise ValueError(
"""Could not import replicate python package.
Please install it with `pip install replicate`."""
)


def is_replicate_available():
try:
try_to_import_replicate()
return True
except ValueError:
return False
4 changes: 4 additions & 0 deletions script/test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ long_integration_test() {
}

notebook_test() {
<<<<<<< HEAD
<<<<<<< HEAD
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb"
=======
Expand All @@ -107,6 +108,9 @@ notebook_test() {
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb"
>>>>>>> 40a10ce1 (Bump v0.3.4+ dev)
>>>>>>> eva-master
=======
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb" --ignore="tutorials/18-stable-diffusion.ipynb"
>>>>>>> bf022329 (Add stable diffusion integration (#1240))
code=$?
print_error_code $code "NOTEBOOK TEST"
}
Expand Down
8 changes: 8 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ def read(path, encoding="utf-8"):
"neuralforecast" # MODEL TRAIN AND FINE TUNING
]

imagegen_libs = [
"replicate"
]

### NEEDED FOR DEVELOPER TESTING ONLY

dev_libs = [
Expand Down Expand Up @@ -195,6 +199,7 @@ def read(path, encoding="utf-8"):
"sklearn": sklearn_libs,
"forecasting": forecasting_libs,
# everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11.
<<<<<<< HEAD
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs,
<<<<<<< HEAD
=======
Expand All @@ -210,6 +215,9 @@ def read(path, encoding="utf-8"):
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs,
>>>>>>> 40a10ce1 (Bump v0.3.4+ dev)
>>>>>>> eva-master
=======
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs,
>>>>>>> bf022329 (Add stable diffusion integration (#1240))
}

setup(
Expand Down
Loading

0 comments on commit 2b924b7

Please sign in to comment.