Skip to content

Commit

Permalink
build: vendorize refiners
Browse files Browse the repository at this point in the history
so we can still work in conda envs
  • Loading branch information
brycedrennan committed Jan 3, 2024
1 parent f84406f commit 158077f
Show file tree
Hide file tree
Showing 85 changed files with 9,543 additions and 33 deletions.
11 changes: 11 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,17 @@ vendorize_normal_map:
make af


# Vendor finegrain-ai/refiners into ./imaginairy/vendored/refiners at a pinned commit:
# download the repo, replace the vendored copy with the contents of src/refiners,
# copy the upstream LICENSE next to it, drop the training_utils subpackage, and
# record the source repo and commit in readme.txt. The `af` target is run afterwards.
vendorize_refiners:
	export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=20c229903f53d05dc1c44659ec97603660ef964c && \
	make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \
	mkdir -p ./imaginairy/vendored/$$PKG && \
	rm -rf ./imaginairy/vendored/$$PKG/* && \
	cp -R ./downloads/refiners/src/refiners/* ./imaginairy/vendored/$$PKG/ && \
	cp ./downloads/refiners/LICENSE ./imaginairy/vendored/$$PKG/ && \
	rm -rf ./imaginairy/vendored/$$PKG/training_utils && \
	echo "vendored from $$REPO @ $$COMMIT" | tee ./imaginairy/vendored/$$PKG/readme.txt
	make af


vendorize: ## vendorize a github repo. `make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip`
mkdir -p ./downloads
Expand Down
9 changes: 4 additions & 5 deletions imaginairy/api/generate_compvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
logger = logging.getLogger(__name__)


def _generate_single_image_compvis(
def _generate_single_image(
prompt: "ImaginePrompt",
debug_img_callback=None,
progress_img_callback=None,
progress_img_interval_steps=3,
progress_img_interval_min_s=0.1,
half_mode=None,
add_caption=False,
# controlnet, finetune, naive, auto
inpaint_method="finetune",
return_latent=False,
dtype=None,
):
import torch.nn
from PIL import Image, ImageOps
Expand Down Expand Up @@ -96,7 +96,7 @@ def _generate_single_image_compvis(
weights_location=prompt.model_weights,
config_path=prompt.model_architecture,
control_weights_locations=control_modes,
half_mode=half_mode,
half_mode=dtype == torch.float16,
for_inpainting=for_inpainting and inpaint_method == "finetune",
)
is_controlnet_model = hasattr(model, "control_key")
Expand Down Expand Up @@ -502,7 +502,6 @@ def _generate_composition_image(
):
from PIL import Image

from imaginairy.api.generate_refiners import generate_single_image
from imaginairy.utils import default, get_default_dtype

cutoff = normalize_image_size(cutoff)
Expand Down Expand Up @@ -532,7 +531,7 @@ def _generate_composition_image(
},
)

result = generate_single_image(composition_prompt, dtype=dtype)
result = _generate_single_image(composition_prompt, dtype=dtype)
img = result.images["generated"]
while img.width < target_width:
from imaginairy.enhancers.upscale_realesrgan import upscale_image
Expand Down
9 changes: 7 additions & 2 deletions imaginairy/api/generate_refiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def generate_single_image(
):
import torch.nn
from PIL import Image, ImageOps
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver
from tqdm import tqdm

from imaginairy.api.generate import (
Expand Down Expand Up @@ -61,6 +60,10 @@ def generate_single_image(
prepare_image_for_outpaint,
)
from imaginairy.utils.safety import create_safety_score
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers import (
DDIM,
DPMSolver,
)

if dtype is None:
dtype = torch.float16
Expand Down Expand Up @@ -513,7 +516,9 @@ def prep_control_input(
if not control_config:
msg = f"Unknown control mode: {control_input.mode}"
raise ValueError(msg)
from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
SD1ControlnetAdapter,
)

controlnet = SD1ControlnetAdapter( # type: ignore
name=control_input.mode,
Expand Down
50 changes: 30 additions & 20 deletions imaginairy/modules/refiners_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,52 @@
from functools import lru_cache
from typing import Any, List, Literal

import refiners.fluxion.layers as fl
import torch
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.layers.chain import ChainError
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.model import (
from torch import Tensor, device as Device, dtype as DType, nn
from torch.nn import functional as F

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.schema import WeightedPrompt
from imaginairy.utils.feather_tile import rebuild_image, tile_image
from imaginairy.vendored.refiners.fluxion.layers.attentions import (
ScaledDotProductAttention,
)
from imaginairy.vendored.refiners.fluxion.layers.chain import ChainError
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
CLIPTextEncoderL,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
TLatentDiffusionModel,
)
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.self_attention_guidance import (
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import (
DDIM,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import (
Scheduler,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.self_attention_guidance import (
SelfAttentionMap,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
Controlnet,
SD1ControlnetAdapter,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
SD1Autoencoder,
SD1UNet,
StableDiffusion_1 as RefinerStableDiffusion_1,
StableDiffusion_1_Inpainting as RefinerStableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import (
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import (
SDXLAutoencoder,
StableDiffusion_XL as RefinerStableDiffusion_XL,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
DoubleTextEncoder,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from torch import Tensor, device as Device, dtype as DType, nn
from torch.nn import functional as F

from imaginairy.schema import WeightedPrompt
from imaginairy.utils.feather_tile import rebuild_image, tile_image
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import (
SDXLUNet,
)
from imaginairy.weight_management.conversion import cast_weights

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -375,8 +385,8 @@ def prompts_to_embeddings(self, prompts: List[WeightedPrompt]) -> Tensor:
import torch

total_weight = sum(wp.weight for wp in prompts)
if str(self.clip_text_encoder.device) == "cpu":
self.clip_text_encoder = self.clip_text_encoder.to(dtype=torch.float32)
if str(self.clip_text_encoder.device) == "cpu": # type: ignore
self.clip_text_encoder = self.clip_text_encoder.to(dtype=torch.float32) # type: ignore
conditioning = sum(
self.clip_text_encoder(wp.text) * (wp.weight / total_weight)
for wp in prompts
Expand Down
16 changes: 12 additions & 4 deletions imaginairy/utils/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
try_to_load_from_cache,
)
from omegaconf import OmegaConf
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion import DoubleTextEncoder, SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from safetensors.torch import load_file

from imaginairy import config as iconfig
Expand All @@ -29,6 +26,17 @@
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.named_resolutions import normalize_image_size
from imaginairy.utils.paths import PKG_ROOT
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
CLIPTextEncoderL,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
DoubleTextEncoder,
SD1UNet,
SDXLUNet,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
LatentDiffusionModel,
)
from imaginairy.weight_management import translators

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -823,7 +831,7 @@ def open_weights(filepath, device=None):
device = get_device()

if "safetensor" in filepath.lower():
from refiners.fluxion.utils import safe_open
from imaginairy.vendored.refiners.fluxion.utils import safe_open

with safe_open(path=filepath, framework="pytorch", device=device) as tensors:
state_dict = {
Expand Down
21 changes: 21 additions & 0 deletions imaginairy/vendored/refiners/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Lagon Technologies

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
File renamed without changes.
3 changes: 3 additions & 0 deletions imaginairy/vendored/refiners/fluxion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from refiners.fluxion.utils import load_from_safetensors, manual_seed, norm, pad, save_to_safetensors

__all__ = ["norm", "manual_seed", "save_to_safetensors", "load_from_safetensors", "pad"]
3 changes: 3 additions & 0 deletions imaginairy/vendored/refiners/fluxion/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from refiners.fluxion.adapters.adapter import Adapter

__all__ = ["Adapter"]
101 changes: 101 additions & 0 deletions imaginairy/vendored/refiners/fluxion/adapters/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import contextlib
from typing import Any, Generic, Iterator, TypeVar

import refiners.fluxion.layers as fl

# Type of the module an Adapter wraps; constrained to fluxion modules.
T = TypeVar("T", bound=fl.Module)
# Used to annotate `self` so methods such as `inject` return the caller's own adapter type.
TAdapter = TypeVar("TAdapter", bound="Adapter[Any]") # Self (see PEP 673)


class Adapter(Generic[T]):
    """Mixin for `fl.Chain` subclasses that take the place of a target module.

    An adapter remembers the module it adapts (`target`) and can substitute
    itself for that module inside a parent chain (`inject`) or put the module
    back (`eject`). Every subclass must also be a `fl.Chain`; this is checked
    when the subclass is created.
    """

    # we store _target into a one element list to avoid pytorch thinking it is a submodule
    _target: "list[T]"

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Assert that every Adapter subclass is also a `fl.Chain` subclass."""
        super().__init_subclass__(**kwargs)
        assert issubclass(cls, fl.Chain), f"Adapter {cls.__name__} must be a Chain"

    @property
    def target(self) -> T:
        """The adapted module (the single element of `_target`)."""
        return self._target[0]

    @contextlib.contextmanager
    def setup_adapter(self, target: T) -> Iterator[None]:
        """Register `target` and wrap the adapter's Chain construction.

        Must be entered before the Chain constructor has added any module to
        this adapter (asserted below). When `target` is a `fl.ContextModule`,
        its `_can_refresh_parent` flag is set to False for the duration of the
        `with` body and restored afterwards, including when the body raises.
        """
        assert isinstance(self, fl.Chain)
        assert (not hasattr(self, "_modules")) or (
            len(self) == 0
        ), "Call the Chain constructor in the setup_adapter context."
        self._target = [target]

        if not isinstance(self.target, fl.ContextModule):
            yield
            return

        # NOTE(review): presumably this keeps the Chain constructor from re-parenting
        # the target while the adapter is being built -- confirm in fl.ContextModule.
        _old_can_refresh_parent = target._can_refresh_parent
        target._can_refresh_parent = False
        try:
            yield
        finally:
            # Restore the flag even if the body raised, so a failed adapter
            # construction does not leave the target permanently frozen.
            target._can_refresh_parent = _old_can_refresh_parent

    def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
        """Put this adapter in place of its target and return the adapter.

        Args:
            parent: Chain in which to look for the target. When omitted and the
                target is a `fl.ContextModule`, the target's own `parent` is used.

        If no parent can be determined, nothing is replaced: only the target's
        parent pointer is updated (when the target is a `fl.ContextModule`).
        """
        assert isinstance(self, fl.Chain)

        if (parent is None) and isinstance(self.target, fl.ContextModule):
            parent = self.target.parent
            if parent is not None:
                assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"

        # Parent of the target as found from this adapter's own tree.
        target_parent = self.find_parent(self.target)

        if parent is None:
            if isinstance(self.target, fl.ContextModule):
                self.target._set_parent(target_parent)  # type: ignore[reportPrivateUsage]
            return self

        # In general, `true_parent` is `parent`. We do this to support multiple adaptation,
        # i.e. initializing two adapters before injecting them.
        true_parent = parent.ensure_find_parent(self.target)
        true_parent.replace(
            old_module=self.target,
            new_module=self,
            old_module_parent=target_parent,
        )
        return self

    def eject(self) -> None:
        """Undo `inject`: put the (possibly further-adapted) target back in the parent.

        When this adapter has no parent, only the parent pointer of the restored
        module is cleared (if it is a `fl.ContextModule`).
        """
        assert isinstance(self, fl.Chain)

        # In general, the "actual target" is the target.
        # Here we deal with the edge case where the target
        # is part of the replacement block and has been adapted by
        # another adapter after this one. For instance, this is the
        # case when stacking Controlnets.
        actual_target = lookup_top_adapter(self, self.target)

        if (parent := self.parent) is None:
            if isinstance(actual_target, fl.ContextModule):
                actual_target._set_parent(None)  # type: ignore[reportPrivateUsage]
        else:
            parent.replace(old_module=self, new_module=actual_target)

    def _pre_structural_copy(self) -> None:
        """Refuse a structural copy while the adapted target is a Chain.

        Raises:
            RuntimeError: if `target` is a `fl.Chain`.
        """
        if isinstance(self.target, fl.Chain):
            raise RuntimeError("Chain adapters typically cannot be copied, eject them first.")

    def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
        """Point the copy at the same target object as `source`."""
        self._target = [source.target]


def lookup_top_adapter(top: fl.Chain, target: fl.Module) -> fl.Module:
    """Return the outermost Adapter found between `target` and `top`.

    Starts at the parent of `target` (as located by `top.find_parent`) and
    follows `.parent` links until `top` is reached, remembering the last
    ancestor that is an `Adapter`. `target` itself is returned when it has no
    parent under `top`, when its parent is `top`, or when no ancestor on the
    way up is an `Adapter`.
    """
    ancestor = top.find_parent(target)
    if ancestor is None:
        return target

    outermost = target
    # When the direct parent already is `top`, the loop body never runs.
    while ancestor != top:
        if isinstance(ancestor, Adapter):
            outermost = ancestor
        assert ancestor.parent, f"parent tree of {top} is broken"
        ancestor = ancestor.parent
    return outermost
Loading

0 comments on commit 158077f

Please sign in to comment.