Skip to content

Commit

Permalink
fix: pydantic models for http server working now. Fixes #380
Browse files Browse the repository at this point in the history
  • Loading branch information
brycedrennan committed Sep 29, 2023
1 parent ba51364 commit 68fb7d0
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 42 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ tests/vastai_cli.py
**/.eggs
/img_size_memory_usage.csv
/tests/test_cluster_output/
/.env
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,10 @@ A: The AI models are cached in `~/.cache/` (or `HUGGINGFACE_HUB_CACHE`). To dele

## ChangeLog

**13.2.1**
- fix: pydantic models for http server working now. Fixes #380
- fix: install triton so annoying message is gone

**13.2.0**
- fix: allow tile_mode to be set to True or False for backward compatibility
- fix: various pydantic issues have been resolved
Expand Down
4 changes: 2 additions & 2 deletions imaginairy/http/stablestudio/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,5 @@ class StableStudioBatchResponse(BaseModel):
images: List[StableStudioImage]


StableStudioInput.update_forward_refs()
StableStudioImage.update_forward_refs()
StableStudioInput.model_rebuild()
StableStudioImage.model_rebuild()
4 changes: 2 additions & 2 deletions imaginairy/http/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def generate_image(prompt):
"""ImaginPrompt to generated image."""
"""ImaginePrompt to generated image."""
result = next(imagine([prompt]))
img = result.images["generated"]
img_io = BytesIO()
Expand All @@ -27,7 +27,7 @@ def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, v):
def validate(cls, v, info):
if isinstance(v, bytes):
return v
if isinstance(v, str):
Expand Down
22 changes: 11 additions & 11 deletions imaginairy/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def __getattr__(self, key):
self._load_img()
return getattr(self._img, key)

def __setstate__(self, state):
self.__dict__.update(state)

def __getstate__(self):
return self.__dict__

def _load_img(self):
if self._img is None:
from PIL import Image, ImageOps
Expand Down Expand Up @@ -192,16 +198,10 @@ def __repr__(self):
shows filepath or url if available.
"""
return f"<LazyLoadingImage filepath={self._lazy_filepath} url={self._lazy_url}>"


#
# LazyLoadingImage = Annotated[
# _LazyLoadingImage,
# AfterValidator(_LazyLoadingImage.validate),
# PlainSerializer(lambda i: str(i), return_type=str),
# WithJsonSchema({"type": "string"}, mode="serialization"),
# ]
try:
return f"<LazyLoadingImage filepath={self._lazy_filepath} url={self._lazy_url}>"
except Exception as e: # noqa
return f"<LazyLoadingImage RENDER EXCEPTION*{e}*>"


class ControlNetInput(BaseModel):
Expand Down Expand Up @@ -245,7 +245,7 @@ def __repr__(self):
return f"{self.weight}*({self.text})"


class ImaginePrompt(BaseModel):
class ImaginePrompt(BaseModel, protected_namespaces=()):
prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True)
negative_prompt: Optional[List[WeightedPrompt]] = Field(
default=None, validate_default=True
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ pytest-randomly
pytest-sugar
responses
wheel

-c tests/constraints.txt
54 changes: 28 additions & 26 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ anyio==3.7.1
# via
# fastapi
# starlette
astroid==2.15.6
astroid==2.15.8
# via pylint
async-timeout==4.0.3
# via aiohttp
Expand All @@ -42,13 +42,13 @@ click-help-colors==0.9.2
# via imaginAIry (setup.py)
click-shell==2.1
# via imaginAIry (setup.py)
contourpy==1.1.0
contourpy==1.1.1
# via matplotlib
coverage==7.3.1
# via -r requirements-dev.in
cycler==0.11.0
cycler==0.12.0
# via matplotlib
diffusers==0.20.2
diffusers==0.21.3
# via imaginAIry (setup.py)
dill==0.3.7
# via pylint
Expand All @@ -62,9 +62,9 @@ facexlib==0.3.0
# via imaginAIry (setup.py)
fairscale==0.4.13
# via imaginAIry (setup.py)
fastapi==0.103.1
fastapi==0.103.2
# via imaginAIry (setup.py)
filelock==3.12.3
filelock==3.12.4
# via
# diffusers
# huggingface-hub
Expand All @@ -77,7 +77,7 @@ frozenlist==1.4.0
# via
# aiohttp
# aiosignal
fsspec[http]==2023.9.0
fsspec[http]==2023.9.2
# via
# huggingface-hub
# pytorch-lightning
Expand All @@ -87,7 +87,7 @@ ftfy==6.1.1
# open-clip-torch
h11==0.14.0
# via uvicorn
huggingface-hub==0.17.1
huggingface-hub==0.17.3
# via
# diffusers
# open-clip-torch
Expand All @@ -98,7 +98,7 @@ idna==3.4
# anyio
# requests
# yarl
imageio==2.31.3
imageio==2.31.4
# via imaginAIry (setup.py)
importlib-metadata==6.8.0
# via diffusers
Expand All @@ -120,10 +120,12 @@ lightning-utilities==0.9.0
# via
# pytorch-lightning
# torchmetrics
llvmlite==0.40.1
llvmlite==0.41.0
# via numba
matplotlib==3.7.3
# via filterpy
# via
# -c tests/constraints.txt
# filterpy
mccabe==0.7.0
# via
# pylama
Expand All @@ -136,10 +138,11 @@ mypy-extensions==1.0.0
# via
# black
# typing-inspect
numba==0.57.1
numba==0.58.0
# via facexlib
numpy==1.24.4
# via
# -c tests/constraints.txt
# contourpy
# diffusers
# facexlib
Expand All @@ -159,7 +162,7 @@ omegaconf==2.3.0
# via imaginAIry (setup.py)
open-clip-torch==2.20.0
# via imaginAIry (setup.py)
opencv-python==4.8.0.76
opencv-python==4.8.1.78
# via
# facexlib
# imaginAIry (setup.py)
Expand All @@ -178,7 +181,7 @@ pathspec==0.11.2
# via
# black
# pycln
pillow==10.0.0
pillow==10.0.1
# via
# diffusers
# facexlib
Expand All @@ -202,19 +205,19 @@ pycln==2.2.2
# via -r requirements-dev.in
pycodestyle==2.11.0
# via pylama
pydantic==2.3.0
pydantic==2.4.2
# via
# fastapi
# imaginAIry (setup.py)
pydantic-core==2.6.3
pydantic-core==2.10.1
# via pydantic
pydocstyle==6.3.0
# via pylama
pyflakes==3.1.0
# via pylama
pylama==8.4.1
# via -r requirements-dev.in
pylint==2.17.5
pylint==2.17.6
# via -r requirements-dev.in
pyparsing==3.1.1
# via matplotlib
Expand Down Expand Up @@ -257,7 +260,7 @@ requests==2.31.0
# transformers
responses==0.23.3
# via -r requirements-dev.in
ruff==0.0.288
ruff==0.0.291
# via -r requirements-dev.in
safetensors==0.3.3
# via
Expand Down Expand Up @@ -312,7 +315,7 @@ torch==1.13.1
# torchvision
torchdiffeq==0.2.3
# via imaginAIry (setup.py)
torchmetrics==1.1.2
torchmetrics==1.2.0
# via
# imaginAIry (setup.py)
# pytorch-lightning
Expand All @@ -330,18 +333,17 @@ tqdm==4.66.1
# open-clip-torch
# pytorch-lightning
# transformers
transformers==4.33.1
transformers==4.33.3
# via imaginAIry (setup.py)
typer==0.9.0
# via pycln
types-pyyaml==6.0.12.11
types-pyyaml==6.0.12.12
# via responses
typing-extensions==4.7.1
typing-extensions==4.8.0
# via
# astroid
# black
# fastapi
# filelock
# huggingface-hub
# libcst
# lightning-utilities
Expand All @@ -355,19 +357,19 @@ typing-extensions==4.7.1
# uvicorn
typing-inspect==0.9.0
# via libcst
urllib3==2.0.4
urllib3==2.0.5
# via
# requests
# responses
uvicorn==0.23.2
# via imaginAIry (setup.py)
wcwidth==0.2.6
wcwidth==0.2.7
# via ftfy
wheel==0.41.2
# via -r requirements-dev.in
wrapt==1.15.0
# via astroid
yarl==1.9.2
# via aiohttp
zipp==3.16.2
zipp==3.17.0
# via importlib-metadata
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,10 @@ def get_git_revision_hash() -> str:
"scipy<1.11",
"timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip
"torchdiffeq>=0.2.0",
"transformers>=4.19.2",
"torchmetrics>=0.6.0",
"torchvision>=0.13.1",
"transformers>=4.19.2",
"triton>=2.0.0; sys_platform!='darwin' and platform_machine!='aarch64'",
"kornia>=0.6",
"uvicorn>=0.16.0",
"xformers>=0.0.16; sys_platform!='darwin' and platform_machine!='aarch64'",
Expand Down
3 changes: 3 additions & 0 deletions tests/constraints.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# held back for python 3.8 compatability
matplotlib<3.8.0
numpy<1.25.0
11 changes: 11 additions & 0 deletions tests/test_schema/test_lazy_load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,14 @@ def test_image_deserialization(red_path, red_url):
for row in rows:
obj = TestModel.model_validate(row)
assert obj.header_img.size == (512, 512)


def test_image_state(red_path):
"""I dont remember what this fixes. Maybe the ability of pydantic to copy an object?."""
img = LazyLoadingImage(filepath=red_path)

# bypass init
img2 = LazyLoadingImage.__new__(LazyLoadingImage)
img2.__setstate__(img.__getstate__())

assert repr(img) == repr(img2)

0 comments on commit 68fb7d0

Please sign in to comment.