Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve image generation endpoints #1176

Merged
merged 2 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion agixt/providers/ezlocalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import random
import re
import numpy as np
import requests
import os
import uuid

try:
import openai
Expand Down Expand Up @@ -47,7 +50,15 @@ def __init__(

@staticmethod
def services():
return ["llm", "tts", "transcription", "translation", "vision", "embeddings"]
return [
"llm",
"tts",
"image",
"embeddings",
"transcription",
"translation",
"vision",
]

def rotate_uri(self):
self.FAILURES.append(self.API_URI)
Expand Down Expand Up @@ -166,6 +177,25 @@ async def text_to_speech(self, text: str):
)
return tts_response.content

async def generate_image(self, prompt: str) -> str:
    """Generate an image for ``prompt`` and return a URL pointing at it.

    Requests a single 512x512 image from the configured OpenAI-compatible
    endpoint using the ``stabilityai/sdxl-turbo`` model, downloads the
    result into ``./WORKSPACE`` under a random UUID filename, and returns
    an address of the form ``{AGIXT_URI}/outputs/<filename>``.

    Args:
        prompt: Text description of the image to generate.

    Returns:
        The ``AGIXT_URI`` environment variable (default
        ``http://localhost:7437``) joined with ``/outputs/<filename>``.

    Raises:
        requests.RequestException: If the generated image cannot be
            downloaded (timeout, connection error, or non-2xx status).
    """
    filename = f"{uuid.uuid4()}.png"
    image_path = f"./WORKSPACE/{filename}"
    openai.base_url = self.API_URI if self.API_URI else "https://api.openai.com/v1/"
    # Use this provider's own key, the same attribute embeddings() reads;
    # OPENAI_API_KEY belongs to the OpenAI provider this method was
    # adapted from.
    openai.api_key = self.EZLOCALAI_API_KEY
    response = openai.images.generate(
        prompt=prompt,
        model="stabilityai/sdxl-turbo",
        n=1,
        size="512x512",
        response_format="url",
    )
    logging.info(f"Image Generated for prompt:{prompt}")
    url = response.data[0].url
    # Fetch and validate before opening the file so a failed or hung
    # download never leaves an empty or error-page ".png" in the workspace.
    download = requests.get(url, timeout=60)
    download.raise_for_status()
    with open(image_path, "wb") as f:
        f.write(download.content)
    agixt_uri = os.environ.get("AGIXT_URI", "http://localhost:7437")
    return f"{agixt_uri}/outputs/{filename}"

def embeddings(self, input) -> np.ndarray:
openai.base_url = self.API_URI
openai.api_key = self.EZLOCALAI_API_KEY
Expand Down
4 changes: 1 addition & 3 deletions agixt/providers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ async def inference(self, prompt, tokens: int = 0, images: list = []):
async def generate_image(
self,
prompt: str,
filename: str = "",
negative_prompt: str = "out of frame,lowres,text,error,cropped,worst quality,low quality,jpeg artifacts,ugly,duplicate,morbid,mutilated,out of frame,extra fingers,mutated hands,poorly drawn hands,poorly drawn face,mutation,deformed,blurry,dehydrated,bad anatomy,bad proportions,extra limbs,cloned face,disfigured,gross proportions,malformed limbs,missing arms,missing legs,extra arms,extra legs,fused fingers,too many fingers,long neck,username,watermark,signature",
batch_size: int = 1,
cfg_scale: int = 7,
Expand All @@ -126,8 +125,7 @@ async def generate_image(
tiling: bool = False,
width: int = 768,
) -> str:
if filename == "":
filename = f"{uuid.uuid4()}.png"
filename = f"{uuid.uuid4()}.png"
image_path = f"./WORKSPACE/{filename}"
headers = {}
if (
Expand Down
25 changes: 15 additions & 10 deletions agixt/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import time
import logging
import random
import base64
import requests
import uuid
import os
import numpy as np
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

Expand Down Expand Up @@ -175,21 +177,24 @@ async def text_to_speech(self, text: str):
)
return tts_response.content

async def generate_image(self, prompt: str, filename: str = "image.png") -> str:
async def generate_image(self, prompt: str) -> str:
filename = f"{uuid.uuid4()}.png"
image_path = f"./WORKSPACE/{filename}"
openai.base_url = self.API_URI if self.API_URI else "https://api.openai.com/v1/"
openai.api_key = self.OPENAI_API_KEY
response = openai.Image.create(
response = openai.images.generate(
prompt=prompt,
model="dall-e-3",
n=1,
size="256x256",
response_format="b64_json",
size="1024x1024",
response_format="url",
)
logging.info(f"Image Generated for prompt:{prompt}")
image_data = base64.b64decode(response["data"][0]["b64_json"])
with open(image_path, mode="wb") as png:
png.write(image_data)
encoded_image_data = base64.b64encode(image_data).decode("utf-8")
return f"data:image/png;base64,{encoded_image_data}"
url = response.data[0].url
with open(image_path, "wb") as f:
f.write(requests.get(url).content)
agixt_uri = os.environ.get("AGIXT_URI", "http://localhost:7437")
return f"{agixt_uri}/outputs/{filename}"

def embeddings(self, input) -> np.ndarray:
openai.base_url = self.API_URI
Expand Down
1 change: 1 addition & 0 deletions docker-compose-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ services:
- UVICORN_WORKERS=${UVICORN_WORKERS:-10}
- USING_JWT=${USING_JWT:-false}
- AGIXT_API_KEY=${AGIXT_API_KEY}
- AGIXT_URI=${AGIXT_URI-http://agixt:7437}
- DISABLED_EXTENSIONS=${DISABLED_EXTENSIONS:-}
- DISABLED_PROVIDERS=${DISABLED_PROVIDERS:-}
- WORKING_DIRECTORY=${WORKING_DIRECTORY:-/agixt/WORKSPACE}
Expand Down
3 changes: 2 additions & 1 deletion docker-compose-local-nvidia-sd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ services:
environment:
- UVICORN_WORKERS=${UVICORN_WORKERS:-10}
- AGIXT_API_KEY=${AGIXT_API_KEY}
- AGIXT_URI=${AGIXT_URI-http://agixt:7437}
- WORKING_DIRECTORY=${WORKING_DIRECTORY:-/agixt/WORKSPACE}
- TOKENIZERS_PARALLELISM=False
- TZ=${TZ-America/New_York}
Expand All @@ -25,7 +26,7 @@ services:
depends_on:
- agixt
environment:
- AGIXT_URI=http://agixt:7437
- AGIXT_URI=${AGIXT_URI-http://agixt:7437}
- AGIXT_API_KEY=${AGIXT_API_KEY}
volumes:
- ./agixt/WORKSPACE:/app/WORKSPACE
Expand Down
3 changes: 2 additions & 1 deletion docker-compose-local-nvidia.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ services:
environment:
- UVICORN_WORKERS=${UVICORN_WORKERS:-10}
- AGIXT_API_KEY=${AGIXT_API_KEY}
- AGIXT_URI=${AGIXT_URI-http://agixt:7437}
- WORKING_DIRECTORY=${WORKING_DIRECTORY:-/agixt/WORKSPACE}
- TOKENIZERS_PARALLELISM=False
- TZ=${TZ-America/New_York}
Expand All @@ -25,7 +26,7 @@ services:
depends_on:
- agixt
environment:
- AGIXT_URI=http://agixt:7437
- AGIXT_URI=${AGIXT_URI-http://agixt:7437}
- AGIXT_API_KEY=${AGIXT_API_KEY}
volumes:
- ./agixt/WORKSPACE:/app/WORKSPACE
Expand Down
3 changes: 2 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ services:
environment:
- UVICORN_WORKERS=${UVICORN_WORKERS:-10}
- AGIXT_API_KEY=${AGIXT_API_KEY}
- AGIXT_URI=${AGIXT_URI-http://agixt:7437}
- WORKING_DIRECTORY=${WORKING_DIRECTORY:-/agixt/WORKSPACE}
- TOKENIZERS_PARALLELISM=False
- TZ=${TZ-America/New_York}
Expand All @@ -25,7 +26,7 @@ services:
depends_on:
- agixt
environment:
- AGIXT_URI=http://agixt:7437
- AGIXT_URI=${AGIXT_URI-http://agixt:7437}
- AGIXT_API_KEY=${AGIXT_API_KEY}
volumes:
- ./agixt/WORKSPACE:/app/WORKSPACE
Expand Down
Loading