From c7e3f0e4e2d33dd332aee5bb90a3b521a7a0db45 Mon Sep 17 00:00:00 2001
From: Bryce
Date: Sat, 21 Sep 2024 17:18:24 -0700
Subject: [PATCH] feature: flux model works on my nvidia 4090 and probably not
 many other places. rush job. will not be providing support but will accept
 pull requests

---
 imaginairy/api/generate.py      |  25 +++-
 imaginairy/api/generate_flux.py | 244 ++++++++++++++++++++++++++++++++
 imaginairy/config.py            |  16 +++
 imaginairy/schema.py            |  21 ++-
 4 files changed, 300 insertions(+), 6 deletions(-)
 create mode 100644 imaginairy/api/generate_flux.py

diff --git a/imaginairy/api/generate.py b/imaginairy/api/generate.py
index 4a2e1f5f..83509a23 100755
--- a/imaginairy/api/generate.py
+++ b/imaginairy/api/generate.py
@@ -58,7 +58,6 @@ def imagine_image_files(
     """
     from PIL import ImageDraw
 
-    from imaginairy.api.video_sample import generate_video
    from imaginairy.utils import get_next_filenumber, prompt_normalized
 
     generated_imgs_path = os.path.join(outdir, "generated")
@@ -105,6 +104,8 @@ def _record_step(img, description, image_count, step_count, prompt):
                 continue
             result_filenames.append(primary_filename)
             if primary_filename and videogen:
+                from imaginairy.api.video_sample import generate_video
+
                 try:
                     generate_video(
                         input_path=primary_filename,
@@ -229,7 +230,6 @@ def imagine(
     """
     import torch.nn
 
-    from imaginairy.api.generate_refiners import generate_single_image
     from imaginairy.schema import ImaginePrompt
     from imaginairy.utils import (
         check_torch_version,
@@ -262,10 +262,27 @@ def imagine(
         concrete_prompt = prompt.make_concrete_copy()
         prog_text = f"{i + 1}/{num_prompts}"
         logger.info(f"🖼 {prog_text} {concrete_prompt.prompt_description()}")
+        # Determine which generate function to use based on the model
+        if (
+            concrete_prompt.model_architecture
+            and concrete_prompt.model_architecture.name.lower() == "flux"
+        ):
+            from imaginairy.api.generate_flux import (
+                generate_single_image as generate_single_flux_image,
+            )
+
+            generate_func = generate_single_flux_image
+        else:
+            from imaginairy.api.generate_refiners import (
+                generate_single_image as generate_single_image_refiners,
+            )
+
+            generate_func = generate_single_image_refiners
+
         for attempt in range(unsafe_retry_count + 1):
             if attempt > 0 and isinstance(concrete_prompt.seed, int):
                 concrete_prompt.seed += 100_000_000 + attempt
-            result = generate_single_image(
+            result = generate_func(
                 concrete_prompt,
                 debug_img_callback=debug_img_callback,
                 progress_img_callback=progress_img_callback,
@@ -275,7 +292,7 @@ def imagine(
                 dtype=torch.float16 if half_mode else torch.float32,
                 output_perf=True,
             )
-            if not result.safety_score.is_filtered:
+            if not result.safety_score or not result.safety_score.is_filtered:
                 break
             if attempt < unsafe_retry_count:
                 logger.info("   Image was unsafe, retrying with new seed...")
diff --git a/imaginairy/api/generate_flux.py b/imaginairy/api/generate_flux.py
new file mode 100644
index 00000000..d481a19e
--- /dev/null
+++ b/imaginairy/api/generate_flux.py
@@ -0,0 +1,244 @@
+import logging
+import os
+from functools import lru_cache
+
+from imaginairy.schema import ImaginePrompt, ImagineResult
+from imaginairy.utils.log_utils import ImageLoggingContext
+
+logger = logging.getLogger(__name__)
+
+
+@lru_cache(maxsize=1)
+def load_flux_models():
+    import torch
+    from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
+    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+    from optimum.quanto import freeze, qfloat8, quantize
+    from transformers import (
+        CLIPTextModel,
+        CLIPTokenizer,
+        T5EncoderModel,
+        T5TokenizerFast,
+    )
+
+    from imaginairy.utils.downloads import get_cache_dir
+
+    dtype = torch.bfloat16
+    bfl_repo = "black-forest-labs/FLUX.1-schnell"
+    revision = "refs/pr/1"
+    quant_type = "qfloat8"  # Define the quantization type
+
+    # Define paths for saved quantized models
+    quantized_dir = os.path.join(get_cache_dir(), "quantized_flux_models")
+    os.makedirs(quantized_dir, exist_ok=True)
+    transformer_path = os.path.join(
+        quantized_dir, f"quantized_transformer_{quant_type}.pt"
+    )
+    text_encoder_2_path = os.path.join(
+        quantized_dir, f"quantized_text_encoder_2_{quant_type}.pt"
+    )
+
+    # Load and set up models
+    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+        bfl_repo, subfolder="scheduler", revision=revision
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        "openai/clip-vit-large-patch14", torch_dtype=dtype
+    )
+    tokenizer = CLIPTokenizer.from_pretrained(
+        "openai/clip-vit-large-patch14", torch_dtype=dtype
+    )
+    tokenizer_2 = T5TokenizerFast.from_pretrained(
+        bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype, revision=revision
+    )
+    vae = AutoencoderKL.from_pretrained(
+        bfl_repo, subfolder="vae", torch_dtype=dtype, revision=revision
+    )
+
+    # Load or create quantized models
+    if os.path.exists(transformer_path):
+        transformer = torch.load(transformer_path)
+    else:
+        transformer = FluxTransformer2DModel.from_pretrained(
+            bfl_repo, subfolder="transformer", torch_dtype=dtype, revision=revision
+        )
+        quantize(transformer, weights=qfloat8)
+        freeze(transformer)
+        torch.save(transformer, transformer_path)
+
+    if os.path.exists(text_encoder_2_path):
+        text_encoder_2 = torch.load(text_encoder_2_path)
+    else:
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype, revision=revision
+        )
+        quantize(text_encoder_2, weights=qfloat8)
+        freeze(text_encoder_2)
+        torch.save(text_encoder_2, text_encoder_2_path)
+
+    return (
+        scheduler,
+        text_encoder,
+        tokenizer,
+        text_encoder_2,
+        tokenizer_2,
+        vae,
+        transformer,
+    )
+
+
+def generate_single_image(
+    prompt: ImaginePrompt,
+    debug_img_callback=None,
+    progress_img_callback=None,
+    progress_img_interval_steps=3,
+    progress_img_interval_min_s=0.1,
+    add_caption=False,
+    return_latent=False,
+    dtype=None,
+    logging_context: ImageLoggingContext | None = None,
+    output_perf=False,
+    image_name="",
+):
+    import torch
+    from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+
+    from imaginairy.api.generate import IMAGINAIRY_SAFETY_MODE
+    from imaginairy.enhancers.upscale import upscale_image
+    from imaginairy.utils import clear_gpu_cache, seed_everything
+    from imaginairy.utils.log_utils import ImageLoggingContext
+    from imaginairy.utils.safety import create_safety_score
+
+    # Initialize result containers before latent_logger closes over progress_latents
+    result_images = {}
+    progress_latents = []
+
+    # Initialize logging context
+    if not logging_context:
+
+        def latent_logger(latents):
+            progress_latents.append(latents)
+
+        lc = ImageLoggingContext(
+            prompt=prompt,
+            debug_img_callback=debug_img_callback,
+            progress_img_callback=progress_img_callback,
+            progress_img_interval_steps=progress_img_interval_steps,
+            progress_img_interval_min_s=progress_img_interval_min_s,
+            progress_latent_callback=latent_logger
+            if prompt.collect_progress_latents
+            else None,
+        )
+    else:
+        lc = logging_context
+
+    with lc:
+        # Seed for reproducibility
+        seed_everything(prompt.seed)
+        clear_gpu_cache()
+
+        # Load models
+        with lc.timing("model-load"):
+            (
+                scheduler,
+                text_encoder,
+                tokenizer,
+                text_encoder_2,
+                tokenizer_2,
+                vae,
+                transformer,
+            ) = load_flux_models()
+
+        # Set up pipeline
+        pipe = FluxPipeline(
+            scheduler=scheduler,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            text_encoder_2=text_encoder_2,
+            tokenizer_2=tokenizer_2,
+            vae=vae,
+            transformer=transformer,
+        )
+        pipe.enable_model_cpu_offload()
+
+        # Generate image
+        generator = torch.Generator().manual_seed(prompt.seed)
+        with lc.timing("image-generation"):
+            output = pipe(
+                prompt=prompt.prompt_text,
+                width=prompt.width,
+                height=prompt.height,
+                num_inference_steps=prompt.steps,
+                guidance_scale=prompt.prompt_strength,
+                generator=generator,
+            )
+            image = output.images[0]
+
+        # Perform safety check
+        with lc.timing("safety-filter"):
+            safety_score = create_safety_score(
+                image,
+                safety_mode=IMAGINAIRY_SAFETY_MODE,
+            )
+        is_filtered = safety_score.is_filtered
+
+        # If the image is unsafe, we can discard it or handle it accordingly
+        if is_filtered:
+            image = None  # Discard the unsafe image
+        else:
+            result_images["generated"] = image
+
+        # Optionally upscale the image
+        if prompt.upscale:
+            with lc.timing("upscaling"):
+                upscaled_img = upscale_image(image)
+                result_images["upscaled"] = upscaled_img
+                final_image = upscaled_img
+        else:
+            final_image = image
+
+        if add_caption:
+            with lc.timing("caption-img"):
+                from imaginairy.enhancers.describe_image_blip import (
+                    generate_caption,
+                )
+
+                caption = generate_caption(final_image)
+                logger.info(f"Generated caption: {caption}")
+
+        if prompt.fix_faces:
+            with lc.timing("face-enhancement"):
+                from imaginairy.enhancers.face_restoration_codeformer import (
+                    enhance_faces,
+                )
+
+                logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
+                final_image = enhance_faces(
+                    final_image, fidelity=prompt.fix_faces_fidelity
+                )
+                result_images["face_enhanced"] = final_image
+
+        # Create ImagineResult
+        result = ImagineResult(
+            img=final_image,
+            prompt=prompt,
+            is_nsfw=safety_score.is_nsfw if safety_score else False,
+            safety_score=safety_score,
+            result_images=result_images,
+            performance_stats=lc.get_performance_stats(),
+            progress_latents=progress_latents,
+        )
+
+        _image_name = f"{image_name} " if image_name else ""
+        logger.info(f"Generated {_image_name}image in {result.total_time():.1f}s")
+
+        if result.performance_stats:
+            log = logger.info if output_perf else logger.debug
+            log(f"    Timings: {result.timings_str()}")
+            if torch.cuda.is_available():
+                log(f"    Peak VRAM: {result.gpu_str('memory_peak')}")
+                log(f"    Peak VRAM Delta: {result.gpu_str('memory_peak_delta')}")
+                log(f"    Ending VRAM: {result.gpu_str('memory_end')}")
+
+        clear_gpu_cache()
+        return result
diff --git a/imaginairy/config.py b/imaginairy/config.py
index 9cce76ac..ada5ab25 100644
--- a/imaginairy/config.py
+++ b/imaginairy/config.py
@@ -101,6 +101,12 @@ def primary_alias(self):
         defaults={"size": "1024x576"},
         config_path="configs/svd_xt_image_decoder.yaml",
     ),
+    ModelArchitecture(
+        name="Flux",
+        aliases=["flux"],
+        output_modality="image",
+        defaults={"size": "1024", "steps": 3},
+    ),
 ]
 
 MODEL_ARCHITECTURE_LOOKUP = {}
@@ -242,6 +248,16 @@ def __post_init__(self):
         weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt_image_decoder.fp16.safetensors",
         defaults={"frames": 25, "steps": 30},
     ),
+    ModelWeightsConfig(
+        name="FLUX.1-schnell",
+        aliases=MODEL_ARCHITECTURE_LOOKUP["flux"].aliases,
+        architecture=MODEL_ARCHITECTURE_LOOKUP["flux"],
+        weights_location="",
+        defaults={
+            "steps": 5,
"negative_prompt": "", + }, + ), ] MODEL_WEIGHT_CONFIG_LOOKUP = {} diff --git a/imaginairy/schema.py b/imaginairy/schema.py index 73fc3584..92c6f4eb 100644 --- a/imaginairy/schema.py +++ b/imaginairy/schema.py @@ -656,10 +656,27 @@ def validate_solver_type(cls, v, info: core_schema.FieldValidationInfo): @field_validator("steps", mode="before") def validate_steps(cls, v, info: core_schema.FieldValidationInfo): - steps_lookup = {"ddim": 50, "dpmpp": 20} + model_weights = info.data.get("model_weights") + # Try to get steps from model weights defaults + if ( + v is None + and model_weights + and isinstance(model_weights, config.ModelWeightsConfig) + ): + v = model_weights.defaults.get("steps") + + # If not found in model weights, try model architecture defaults + if v is None and model_weights and model_weights.architecture: + v = model_weights.architecture.defaults.get("steps") + + # If still not found, use solver-specific defaults if v is None: - v = steps_lookup[info.data["solver_type"]] + solver_type = info.data.get("solver_type", "ddim").lower() + steps_lookup = {"ddim": 50, "dpmpp": 20} + v = steps_lookup.get( + solver_type, 50 + ) # Default to 50 if solver not recognized try: return int(v)