Manage quantization of models within the loader
brandonrising committed Aug 12, 2024
1 parent 9bf8ac7 commit c8f84d2
Showing 8 changed files with 244 additions and 284 deletions.
1 change: 1 addition & 0 deletions invokeai/app/invocations/fields.py
@@ -126,6 +126,7 @@ class FieldDescriptions:
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5Encoder = "T5 tokenizer and text encoder"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
181 changes: 54 additions & 127 deletions invokeai/app/invocations/flux_text_to_image.py
@@ -6,14 +6,14 @@
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.model import TransformerField, CLIPField, T5EncoderField, VAEField
from optimum.quanto import qfloat8

Check failure on line 10 (GitHub Actions / python-checks, Ruff F401): invokeai/app/invocations/flux_text_to_image.py:10:28: `optimum.quanto.qfloat8` imported but unused
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers.models.auto import AutoModelForTextEncoding

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata, UIType, Input
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata, Input, FieldDescriptions
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
@@ -42,14 +42,24 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model."""

flux_model: ModelIdentifierField = InputField(
description="The Flux model",
input=Input.Any,
ui_type=UIType.FluxMainModel
transformer: TransformerField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="Transformer",
)
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
use_8bit: bool = InputField(
default=False, description="Whether to quantize the transformer model to 8-bit precision."
clip: CLIPField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5Encoder: T5EncoderField = InputField(
title="T5EncoderField",
description=FieldDescriptions.t5Encoder,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
@@ -63,45 +73,40 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
image = self._run_vae_decoding(context, model_path, latents)
t5_embeddings, clip_embeddings = self._encode_prompt(context)
latents = self._run_diffusion(context, clip_embeddings, t5_embeddings)
image = self._run_vae_decoding(context, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
# Determine the T5 max sequence length based on the model.
if self.model == "flux-schnell":
max_seq_len = 256
# elif self.model == "flux-dev":
# max_seq_len = 512
else:
raise ValueError(f"Unknown model: {self.model}")
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# TODO: Determine the T5 max sequence length based on the model.
# if self.model == "flux-schnell":
max_seq_len = 256
# # elif self.model == "flux-dev":
# # max_seq_len = 512
# else:
# raise ValueError(f"Unknown model: {self.model}")

# Load the CLIP tokenizer.
clip_tokenizer_path = flux_model_dir / "tokenizer"
clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
assert isinstance(clip_tokenizer, CLIPTokenizer)
# Load CLIP.
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)

# Load the T5 tokenizer.
t5_tokenizer_path = flux_model_dir / "tokenizer_2"
t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
assert isinstance(t5_tokenizer, T5TokenizerFast)
# Load T5.
t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)

clip_text_encoder_path = flux_model_dir / "text_encoder"
t5_text_encoder_path = flux_model_dir / "text_encoder_2"
with (
context.models.load_local_model(
model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
) as clip_text_encoder,
context.models.load_local_model(
model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
) as t5_text_encoder,
clip_text_encoder_info as clip_text_encoder,
t5_text_encoder_info as t5_text_encoder,
clip_tokenizer_info as clip_tokenizer,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
assert isinstance(t5_tokenizer, T5TokenizerFast)

pipeline = FluxPipeline(
scheduler=None,
vae=None,
@@ -114,7 +119,7 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tu

# prompt_embeds: T5 embeddings
# pooled_prompt_embeds: CLIP embeddings
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
prompt=self.positive_prompt,
prompt_2=self.positive_prompt,
device=TorchDevice.choose_torch_device(),
@@ -128,22 +133,23 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tu
def _run_diffusion(
self,
context: InvocationContext,
flux_model_dir: Path,
clip_embeddings: torch.Tensor,
t5_embeddings: torch.Tensor,
):
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
scheduler_info = context.models.load(self.transformer.scheduler)
transformer_info = context.models.load(self.transformer.transformer)

# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty.
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
# context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

transformer_path = flux_model_dir / "transformer"
with context.models.load_local_model(
model_path=transformer_path, loader=self._load_flux_transformer
) as transformer:
with (
transformer_info as transformer,
scheduler_info as scheduler
):
assert isinstance(transformer, FluxTransformer2DModel)
assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)

flux_pipeline_with_transformer = FluxPipeline(
scheduler=scheduler,
@@ -176,11 +182,10 @@ def _run_diffusion(
def _run_vae_decoding(
self,
context: InvocationContext,
flux_model_dir: Path,
latents: torch.Tensor,
) -> Image.Image:
vae_path = flux_model_dir / "vae"
with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
vae_info = context.models.load(self.vae.vae)
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)

flux_pipeline_with_vae = FluxPipeline(
@@ -205,81 +210,3 @@ def _run_vae_decoding(

assert isinstance(image, Image.Image)
return image

@staticmethod
def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
model = CLIPTextModel.from_pretrained(path, local_files_only=True)
assert isinstance(model, CLIPTextModel)
return model

def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
if self.use_8bit:
model_8bit_path = path / "quantized"
if model_8bit_path.exists():
# The quantized model exists, load it.
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
# something that we should be able to make much faster.
q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)

# Access the underlying wrapped model.
# We access the wrapped model, even though it is private, because it simplifies the type checking by
# always returning a T5EncoderModel from this function.
model = q_model._wrapped
else:
# The quantized model does not exist yet, quantize and save it.
# TODO(ryand): dtype?
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
assert isinstance(model, T5EncoderModel)

q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)

model_8bit_path.mkdir(parents=True, exist_ok=True)
q_model.save_pretrained(model_8bit_path)

# (See earlier comment about accessing the wrapped model.)
model = q_model._wrapped
else:
model = T5EncoderModel.from_pretrained(path, local_files_only=True)

assert isinstance(model, T5EncoderModel)
return model

def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
if self.use_8bit:
model_8bit_path = path / "quantized"
if model_8bit_path.exists():
# The quantized model exists, load it.
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
# something that we should be able to make much faster.
q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)

# Access the underlying wrapped model.
# We access the wrapped model, even though it is private, because it simplifies the type checking by
# always returning a FluxTransformer2DModel from this function.
model = q_model._wrapped
else:
# The quantized model does not exist yet, quantize and save it.
# TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
# GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
# here.
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
assert isinstance(model, FluxTransformer2DModel)

q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)

model_8bit_path.mkdir(parents=True, exist_ok=True)
q_model.save_pretrained(model_8bit_path)

# (See earlier comment about accessing the wrapped model.)
model = q_model._wrapped
else:
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)

assert isinstance(model, FluxTransformer2DModel)
return model

@staticmethod
def _load_flux_vae(path: Path) -> AutoencoderKL:
model = AutoencoderKL.from_pretrained(path, local_files_only=True)
assert isinstance(model, AutoencoderKL)
return model
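
For context, a condensed sketch of the loading pattern this file moves to, distilled from the hunks above. It is not part of the commit and is not runnable on its own, since it needs a live `InvocationContext` and a node carrying the new fields:

```python
from invokeai.app.services.shared.invocation_context import InvocationContext


def encode_prompt_sketch(node, context: InvocationContext):
    # Submodels are resolved through the model manager's loader (which now also
    # owns quantization) instead of ad-hoc from_pretrained() calls on a local dir.
    clip_tokenizer_info = context.models.load(node.clip.tokenizer)
    clip_text_encoder_info = context.models.load(node.clip.text_encoder)
    t5_tokenizer_info = context.models.load(node.t5Encoder.tokenizer)
    t5_text_encoder_info = context.models.load(node.t5Encoder.text_encoder)

    with (
        clip_text_encoder_info as clip_text_encoder,
        t5_text_encoder_info as t5_text_encoder,
        clip_tokenizer_info as clip_tokenizer,
        t5_tokenizer_info as t5_tokenizer,
    ):
        # Build a FluxPipeline from the loaded submodels and call encode_prompt(),
        # exactly as _encode_prompt() does above.
        ...
```

The point of the change, per the commit title, is that quantization and caching decisions now live behind `context.models.load(...)` rather than inside the invocation, which is why the static `_load_flux_*` helpers above could be deleted.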
10 changes: 7 additions & 3 deletions invokeai/app/invocations/model.py
@@ -65,6 +65,10 @@ class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")

class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")


class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
@@ -133,8 +137,8 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""

transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5Encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@@ -166,7 +170,7 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=text_encoder2),
vae=VAEField(vae=vae),
)

@@ -78,7 +78,12 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy

# TO DO: Add exception handling
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
if module in ["diffusers", "transformers"]:
if module in [
"diffusers",
"transformers",
"invokeai.backend.quantization.fast_quantized_transformers_model",
"invokeai.backend.quantization.fast_quantized_diffusion_model",
]:
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines
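
A small, self-contained sketch of how this whitelist can be used to resolve the new quantized wrapper classes by module path. The final `getattr` step reflects what the surrounding (unchanged) code presumably does with `res_type`; treat the helper and the class name below as illustrative assumptions, not the actual implementation:

```python
import importlib
import sys


def resolve_class(module: str, class_name: str) -> type:
    # Make sure the module is imported so the sys.modules lookup used above succeeds.
    importlib.import_module(module)
    res_type = sys.modules[module]
    return getattr(res_type, class_name)


# Hypothetical call resolving one of the newly whitelisted quantized wrappers:
quantized_cls = resolve_class(
    "invokeai.backend.quantization.fast_quantized_transformers_model",
    "FastQuantizedTransformersModel",
)
```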
9 changes: 8 additions & 1 deletion invokeai/backend/model_manager/load/model_util.py
@@ -9,7 +9,7 @@
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer
from transformers import CLIPTokenizer, T5TokenizerFast

from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
@@ -48,6 +48,13 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
),
):
return model.calc_size()
elif isinstance(
model,
(
T5TokenizerFast,
),
):
return len(model)
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
# supported model types.
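
The new branch sizes a `T5TokenizerFast` by `len(model)`, which for Hugging Face tokenizers is the vocabulary size, so this is a rough stand-in rather than a true memory footprint. A quick illustrative check (the repo id is an assumption; any T5 tokenizer reports its size the same way):

```python
from transformers import T5TokenizerFast

# FLUX pairs CLIP with a T5 XXL text encoder; this repo id is only illustrative.
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
print(len(tokenizer))  # vocabulary size, used above as the cached "model size"
```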
14 changes: 8 additions & 6 deletions invokeai/backend/quantization/fast_quantized_diffusion_model.py
@@ -12,15 +12,17 @@
)
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

from invokeai.backend.requantize import requantize


class FastQuantizedDiffusersModel(QuantizedDiffusersModel):

Check failure on line 20 (GitHub Actions / python-checks, Ruff I001): invokeai/backend/quantization/fast_quantized_diffusion_model.py:1:1: Import block is un-sorted or un-formatted
@classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class = FluxTransformer2DModel, **kwargs):
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
if cls.base_class is None:
base_class = base_class or cls.base_class
if base_class is None:
raise ValueError("The `base_class` attribute needs to be configured.")

if not is_accelerate_available():
@@ -43,16 +45,16 @@ def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):

with open(model_config_path, "r", encoding="utf-8") as f:
original_model_cls_name = json.load(f)["_class_name"]
configured_cls_name = cls.base_class.__name__
configured_cls_name = base_class.__name__
if configured_cls_name != original_model_cls_name:
raise ValueError(
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
)

# Create an empty model
config = cls.base_class.load_config(model_name_or_path)
config = base_class.load_config(model_name_or_path)
with init_empty_weights():
model = cls.base_class.from_config(config)
model = base_class.from_config(config)

# Look for the index of a sharded checkpoint
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
@@ -72,6 +74,6 @@ def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
# Requantize and load quantized weights from state_dict
requantize(model, state_dict=state_dict, quantization_map=qmap)
model.eval()
return cls(model)
return cls(model)._wrapped
else:
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
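
With the new `base_class` parameter, callers no longer have to subclass `QuantizedDiffusersModel` just to pin `base_class`. A hedged usage sketch (the checkpoint path is hypothetical; note that `from_pretrained()` now returns the unwrapped base model via `._wrapped`):

```python
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel

# Hypothetical path to a directory produced earlier by quantize() + save_pretrained().
transformer = FastQuantizedDiffusersModel.from_pretrained(
    "/models/flux/transformer/quantized",
    base_class=FluxTransformer2DModel,
)
# After this commit the return value is the plain FluxTransformer2DModel
# (cls(model)._wrapped), ready to be handed to the model cache.
```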
