Skip to content

Commit

Permalink
add: dalle-mini files for inference server
Browse files Browse the repository at this point in the history
  • Loading branch information
biswaroop1547 committed Jul 12, 2023
1 parent 95b9f1f commit 8bba72f
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 0 deletions.
Empty file added dfs-dalle/README.md
Empty file.
44 changes: 44 additions & 0 deletions dfs-dalle/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from models import DalleBasedModel
from routes import router as api_router

load_dotenv()

logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)


def create_start_app_handler(app: FastAPI):
def start_app() -> None:
DalleBasedModel.get_model()

return start_app


def get_application() -> FastAPI:
application = FastAPI(title="prem-chat", debug=True, version="0.0.1")
application.include_router(api_router, prefix="/v1")
application.add_event_handler("startup", create_start_app_handler(application))
application.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
return application


app = get_application()


if __name__ == "__main__":
uvicorn.run("main:app", host="0.0.0.0", port=8000)
122 changes: 122 additions & 0 deletions dfs-dalle/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import base64
import io
import os
import random
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from dalle_mini import DalleBart, DalleBartProcessor
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key
from PIL import Image
from vqgan_jax.modeling_flax_vqgan import VQModel


class DalleBasedModel(object):
model = None
model_params = None
decoder = None
decoder_params = None
processor = None

generate_fn = None
decode_fn = None

rand_key = None

@classmethod
def generate(
cls,
prompt: str,
n: int,
size: str,
response_format: str,
negative_prompt: str = None,
top_k: float = None,
top_p: float = None,
temperature: float = None,
cond_scale: float = 5.0,
):
seed = random.randint(0, 2**32 - 1)
cls.rand_key = jax.random.PRNGKey(seed)
tokenized_prompts = cls.processor([prompt])
tokenized_prompt = replicate(tokenized_prompts)

data = []
for _ in range(n):
# get a new key
key, subkey = jax.random.split(cls.rand_key)
# generate images
encoded_images = cls.generate_fn(
tokenized_prompt,
shard_prng_key(subkey),
cls.model_params,
top_k,
top_p,
temperature,
cond_scale,
)
# remove BOS
encoded_images = encoded_images.sequences[..., 1:]
# decode images
decoded_images = cls.decode_fn(encoded_images, cls.decoder_params)
decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
# no loop over decoded_images since inference on single prompt -> single image
img = Image.fromarray(np.asarray(decoded_images[0] * 255, dtype=np.uint8))
buffered = io.BytesIO()
img.save(buffered, format="PNG")
data.append(
{response_format: base64.b64encode(buffered.getvalue()).decode("utf-8")}
)

return data

@classmethod
def get_model(cls):
jax.local_device_count()

@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
return cls.model.generate(
**tokenized_prompt,
prng_key=key,
params=params,
top_k=top_k,
top_p=top_p,
temperature=temperature,
condition_scale=condition_scale,
)

@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
return cls.decoder.decode_code(indices, params=params)

cls.generate_fn = p_generate
cls.decode_fn = p_decode

if cls.model is None:
cls.model, params = DalleBart.from_pretrained(
os.getenv("DALLE_MODEL_ID", "dalle-mini/dalle-mini"),
revision=None,
dtype=jnp.float16,
_do_init=False,
)
cls.decoder, vqgan_params = VQModel.from_pretrained(
os.getenv("VQGAN_MODEL_ID", "dalle-mini/vqgan_imagenet_f16_16384"),
revision=os.getenv(
"VQGAN_REVISION_ID", "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
),
_do_init=False,
)
cls.processor = DalleBartProcessor.from_pretrained(
os.getenv("DALLE_MODEL_ID", "dalle-mini/dalle-mini"), revision=None
)

cls.model_params = replicate(params)
cls.decoder_params = replicate(vqgan_params)

return cls.model
58 changes: 58 additions & 0 deletions dfs-dalle/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from datetime import datetime as dt
from typing import List, Union

from fastapi import APIRouter
from models import DalleBasedModel as model
from pydantic import BaseModel


class ImageGenerationInput(BaseModel):
prompt: str
n: int = 1
size: str = ""
cond_scale: float = 5.0
temperature: float = None
top_p: float = None
top_k: float = None
response_format: str = "b64_json"
user: str = ""


class ImageObjectUrl(BaseModel):
url: str


class ImageObjectBase64(BaseModel):
b64_json: str


class ImageGenerationResponse(BaseModel):
created: int = int(dt.now().timestamp())
data: Union[List[ImageObjectUrl], List[ImageObjectBase64]]


class HealthResponse(BaseModel):
status: bool


router = APIRouter()


@router.get("/", response_model=HealthResponse)
async def health():
return HealthResponse(status=True)


@router.post("/images/generations")
async def images_generations(body: ImageGenerationInput):
images = model.generate(
prompt=body.prompt,
n=body.n,
size=body.size,
temperature=body.temperature,
top_p=body.top_p,
top_k=body.top_k,
cond_scale=body.cond_scale,
response_format=body.response_format,
)
return ImageGenerationResponse(created=int(dt.now().timestamp()), data=images)
Empty file added dfs-dalle/tests/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions dfs-dalle/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from fastapi.testclient import TestClient
from main import get_application


def test_generate_image() -> None:
app = get_application()
with TestClient(app) as client:
response = client.post(
"/v1/images/generations",
json={
"prompt": "Hello World",
"n": 1,
},
)
assert response.status_code == 200
print(response.json())

0 comments on commit 8bba72f

Please sign in to comment.