feature: flux model
Works on my NVIDIA 4090 and probably not many other places. Rush job. I will not be providing support but will accept pull requests.
brycedrennan committed Sep 22, 2024
1 parent e83b17b commit c7e3f0e
Showing 4 changed files with 300 additions and 6 deletions.
25 changes: 21 additions & 4 deletions imaginairy/api/generate.py
@@ -58,7 +58,6 @@ def imagine_image_files(
"""
from PIL import ImageDraw

-    from imaginairy.api.video_sample import generate_video
from imaginairy.utils import get_next_filenumber, prompt_normalized

generated_imgs_path = os.path.join(outdir, "generated")
@@ -105,6 +104,8 @@ def _record_step(img, description, image_count, step_count, prompt):
continue
result_filenames.append(primary_filename)
if primary_filename and videogen:
from imaginairy.api.video_sample import generate_video

try:
generate_video(
input_path=primary_filename,
@@ -229,7 +230,6 @@ def imagine(
"""
import torch.nn

-    from imaginairy.api.generate_refiners import generate_single_image
from imaginairy.schema import ImaginePrompt
from imaginairy.utils import (
check_torch_version,
@@ -262,10 +262,27 @@ def imagine(
concrete_prompt = prompt.make_concrete_copy()
prog_text = f"{i + 1}/{num_prompts}"
logger.info(f"🖼 {prog_text} {concrete_prompt.prompt_description()}")
# Determine which generate function to use based on the model
if (
concrete_prompt.model_architecture
and concrete_prompt.model_architecture.name.lower() == "flux"
):
from imaginairy.api.generate_flux import (
generate_single_image as generate_single_flux_image,
)

generate_func = generate_single_flux_image
else:
from imaginairy.api.generate_refiners import (
generate_single_image as generate_single_image_refiners,
)

generate_func = generate_single_image_refiners

for attempt in range(unsafe_retry_count + 1):
if attempt > 0 and isinstance(concrete_prompt.seed, int):
concrete_prompt.seed += 100_000_000 + attempt
-            result = generate_single_image(
            result = generate_func(
concrete_prompt,
debug_img_callback=debug_img_callback,
progress_img_callback=progress_img_callback,
@@ -275,7 +292,7 @@ def imagine(
dtype=torch.float16 if half_mode else torch.float32,
output_perf=True,
)
-            if not result.safety_score.is_filtered:
            if not result.safety_score or not result.safety_score.is_filtered:
break
if attempt < unsafe_retry_count:
logger.info(" Image was unsafe, retrying with new seed...")
244 changes: 244 additions & 0 deletions imaginairy/api/generate_flux.py
@@ -0,0 +1,244 @@
import logging
import os
from functools import lru_cache

from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils.log_utils import ImageLoggingContext

logger = logging.getLogger(__name__)


@lru_cache(maxsize=1)
def load_flux_models():
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, quantize
from transformers import (
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)

from imaginairy.utils.downloads import get_cache_dir

dtype = torch.bfloat16
bfl_repo = "black-forest-labs/FLUX.1-schnell"
revision = "refs/pr/1"
quant_type = "qfloat8" # Define the quantization type

# Define paths for saved quantized models
quantized_dir = os.path.join(get_cache_dir(), "quantized_flux_models")
os.makedirs(quantized_dir, exist_ok=True)
transformer_path = os.path.join(
quantized_dir, f"quantized_transformer_{quant_type}.pt"
)
text_encoder_2_path = os.path.join(
quantized_dir, f"quantized_text_encoder_2_{quant_type}.pt"
)

# Load and set up models
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
bfl_repo, subfolder="scheduler", revision=revision
)
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", torch_dtype=dtype
)
tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14", torch_dtype=dtype
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype, revision=revision
)
vae = AutoencoderKL.from_pretrained(
bfl_repo, subfolder="vae", torch_dtype=dtype, revision=revision
)

# Load or create quantized models
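    # First run: quantize the weights to qfloat8, freeze() to materialize the
    # quantized tensors, then torch.save() the module. Later runs reload the
    # already-quantized module and skip the slow quantization step entirely.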
if os.path.exists(transformer_path):
transformer = torch.load(transformer_path)
else:
transformer = FluxTransformer2DModel.from_pretrained(
bfl_repo, subfolder="transformer", torch_dtype=dtype, revision=revision
)
quantize(transformer, weights=qfloat8)
freeze(transformer)
torch.save(transformer, transformer_path)

if os.path.exists(text_encoder_2_path):
text_encoder_2 = torch.load(text_encoder_2_path)
else:
text_encoder_2 = T5EncoderModel.from_pretrained(
bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype, revision=revision
)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
torch.save(text_encoder_2, text_encoder_2_path)

return (
scheduler,
text_encoder,
tokenizer,
text_encoder_2,
tokenizer_2,
vae,
transformer,
)


def generate_single_image(
prompt: ImaginePrompt,
debug_img_callback=None,
progress_img_callback=None,
progress_img_interval_steps=3,
progress_img_interval_min_s=0.1,
add_caption=False,
return_latent=False,
dtype=None,
logging_context: ImageLoggingContext | None = None,
output_perf=False,
image_name="",
):
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

from imaginairy.api.generate import IMAGINAIRY_SAFETY_MODE
from imaginairy.enhancers.upscale import upscale_image
from imaginairy.utils import clear_gpu_cache, seed_everything
from imaginairy.utils.log_utils import ImageLoggingContext
from imaginairy.utils.safety import create_safety_score

    # Must exist before latent_logger below closes over it; the callback can
    # fire during generation, well before results are assembled.
    progress_latents: list = []

    # Initialize logging context
    if not logging_context:

        def latent_logger(latents):
            progress_latents.append(latents)

lc = ImageLoggingContext(
prompt=prompt,
debug_img_callback=debug_img_callback,
progress_img_callback=progress_img_callback,
progress_img_interval_steps=progress_img_interval_steps,
progress_img_interval_min_s=progress_img_interval_min_s,
progress_latent_callback=latent_logger
if prompt.collect_progress_latents
else None,
)
else:
lc = logging_context

with lc:
# Seed for reproducibility
seed_everything(prompt.seed)
clear_gpu_cache()

# Load models
with lc.timing("model-load"):
(
scheduler,
text_encoder,
tokenizer,
text_encoder_2,
tokenizer_2,
vae,
transformer,
) = load_flux_models()

# Set up pipeline
pipe = FluxPipeline(
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
vae=vae,
transformer=transformer,
)
pipe.enable_model_cpu_offload()

# Generate image
generator = torch.Generator().manual_seed(prompt.seed)
with lc.timing("image-generation"):
output = pipe(
prompt=prompt.prompt_text,
width=prompt.width,
height=prompt.height,
num_inference_steps=prompt.steps,
guidance_scale=prompt.prompt_strength,
generator=generator,
)
image = output.images[0]

# Perform safety check
with lc.timing("safety-filter"):
safety_score = create_safety_score(
image,
safety_mode=IMAGINAIRY_SAFETY_MODE,
)
is_filtered = safety_score.is_filtered

        # Initialize result images
        result_images = {}

# If the image is unsafe, we can discard it or handle it accordingly
if is_filtered:
image = None # Discard the unsafe image
else:
result_images["generated"] = image

# Optionally upscale the image
if prompt.upscale:
with lc.timing("upscaling"):
upscaled_img = upscale_image(image)
result_images["upscaled"] = upscaled_img
final_image = upscaled_img
else:
final_image = image

if add_caption:
with lc.timing("caption-img"):
from imaginairy.enhancers.describe_image_blip import (
generate_caption,
)

caption = generate_caption(final_image)
logger.info(f"Generated caption: {caption}")

if prompt.fix_faces:
with lc.timing("face-enhancement"):
from imaginairy.enhancers.face_restoration_codeformer import (
enhance_faces,
)

logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
final_image = enhance_faces(
final_image, fidelity=prompt.fix_faces_fidelity
)
result_images["face_enhanced"] = final_image

# Create ImagineResult
result = ImagineResult(
img=final_image,
prompt=prompt,
is_nsfw=safety_score.is_nsfw if safety_score else False,
safety_score=safety_score,
result_images=result_images,
performance_stats=lc.get_performance_stats(),
progress_latents=progress_latents,
)

_image_name = f"{image_name} " if image_name else ""
logger.info(f"Generated {_image_name}image in {result.total_time():.1f}s")

if result.performance_stats:
log = logger.info if output_perf else logger.debug
log(f" Timings: {result.timings_str()}")
if torch.cuda.is_available():
log(f" Peak VRAM: {result.gpu_str('memory_peak')}")
log(f" Peak VRAM Delta: {result.gpu_str('memory_peak_delta')}")
log(f" Ending VRAM: {result.gpu_str('memory_end')}")

clear_gpu_cache()
return result
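The function can also be driven directly; a hedged sketch (the ImaginePrompt keyword names and the result's img attribute are assumptions, and per the safety branch above, img may be None when the filter fires):

    from imaginairy.api.generate_flux import generate_single_image
    from imaginairy.schema import ImaginePrompt

    prompt = ImaginePrompt(
        "a watercolor jellyfish", model_weights="flux", seed=1, steps=4
    )
    result = generate_single_image(prompt, output_perf=True)
    if result.img is not None:  # None when the safety filter discarded the image
        result.img.save("jellyfish.png")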
16 changes: 16 additions & 0 deletions imaginairy/config.py
@@ -101,6 +101,12 @@ def primary_alias(self):
defaults={"size": "1024x576"},
config_path="configs/svd_xt_image_decoder.yaml",
),
ModelArchitecture(
name="Flux",
aliases=["flux"],
output_modality="image",
defaults={"size": "1024", "steps": 3},
),
]

MODEL_ARCHITECTURE_LOOKUP = {}
@@ -242,6 +248,16 @@ def __post_init__(self):
weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt_image_decoder.fp16.safetensors",
defaults={"frames": 25, "steps": 30},
),
ModelWeightsConfig(
name="FLUX.1-schnell",
aliases=MODEL_ARCHITECTURE_LOOKUP["flux"].aliases,
architecture=MODEL_ARCHITECTURE_LOOKUP["flux"],
weights_location="",
defaults={
"steps": 5,
"negative_prompt": "",
},
),
]

MODEL_WEIGHT_CONFIG_LOOKUP = {}
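Both entries are reachable by the shared "flux" alias once these lookup dicts are populated. A small sketch, assuming the weights lookup is keyed by alias the same way MODEL_ARCHITECTURE_LOOKUP is used above:

    from imaginairy import config

    arch = config.MODEL_ARCHITECTURE_LOOKUP["flux"]
    weights = config.MODEL_WEIGHT_CONFIG_LOOKUP["flux"]
    print(arch.defaults)     # {'size': '1024', 'steps': 3}
    print(weights.defaults)  # {'steps': 5, 'negative_prompt': ''}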
21 changes: 19 additions & 2 deletions imaginairy/schema.py
@@ -656,10 +656,27 @@ def validate_solver_type(cls, v, info: core_schema.FieldValidationInfo):

@field_validator("steps", mode="before")
def validate_steps(cls, v, info: core_schema.FieldValidationInfo):
steps_lookup = {"ddim": 50, "dpmpp": 20}
model_weights = info.data.get("model_weights")

# Try to get steps from model weights defaults
if (
v is None
and model_weights
and isinstance(model_weights, config.ModelWeightsConfig)
):
v = model_weights.defaults.get("steps")

# If not found in model weights, try model architecture defaults
if v is None and model_weights and model_weights.architecture:
v = model_weights.architecture.defaults.get("steps")

# If still not found, use solver-specific defaults
if v is None:
-            v = steps_lookup[info.data["solver_type"]]
            solver_type = info.data.get("solver_type", "ddim").lower()
            steps_lookup = {"ddim": 50, "dpmpp": 20}
            v = steps_lookup.get(solver_type, 50)  # default to 50 if solver not recognized

try:
return int(v)
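For the FLUX.1-schnell entry above, this resolution order means the weights-level default (steps=5) wins over the architecture default (steps=3), and the solver table only applies when neither config supplies a value. A standalone restatement of the chain (illustrative, minus the isinstance guard in the real validator):

    def resolve_steps(v, model_weights, solver_type="ddim"):
        if v is None and model_weights is not None:
            v = model_weights.defaults.get("steps")
            if v is None and model_weights.architecture:
                v = model_weights.architecture.defaults.get("steps")
        if v is None:
            v = {"ddim": 50, "dpmpp": 20}.get(solver_type.lower(), 50)
        return int(v)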
