Skip to content

Commit

Permalink
Module Group Offloading (#10503)
Browse files Browse the repository at this point in the history
* update

* fix

* non_blocking; handle parameters and buffers

* update

* Group offloading with cuda stream prefetching (#10516)

* cuda stream prefetch

* remove breakpoints

* update

* copy model hook implementation from pab

* update; ~very workaround based implementation but it seems to work as expected; needs cleanup and rewrite

* more workarounds to make it actually work

* cleanup

* rewrite

* update

* make sure to sync current stream before overwriting with pinned params

not doing so will lead to erroneous computations on the GPU and cause bad results

* better check

* update

* remove hook implementation to not deal with merge conflict

* re-add hook changes

* why use more memory when less memory do trick

* why still use slightly more memory when less memory do trick

* optimise

* add model tests

* add pipeline tests

* update docs

* add layernorm and groupnorm

* address review comments

* improve tests; add docs

* improve docs

* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

* apply suggestions from code review

* update tests

* apply suggestions from review

* enable_group_offloading -> enable_group_offload for naming consistency

* raise errors if multiple offloading strategies used; add relevant tests

* handle .to() when group offload applied

* refactor some repeated code

* remove unintentional change from merge conflict

* handle .cuda()

---------

Co-authored-by: Steven Liu <[email protected]>
  • Loading branch information
a-r-r-o-w and stevhliu authored Feb 14, 2025
1 parent ab42820 commit 9a147b8
Show file tree
Hide file tree
Showing 44 changed files with 1,239 additions and 4 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/api/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ Utility and helper functions for working with 🤗 Diffusers.
## apply_layerwise_casting

[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting

## apply_group_offloading

[[autodoc]] hooks.group_offloading.apply_group_offloading
40 changes: 40 additions & 0 deletions docs/source/en/optimization/memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,46 @@ In order to properly offload models after they're called, it is required to run

</Tip>

## Group offloading

Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced.

To enable group offloading, call the [`~ModelMixin.enable_group_offload`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video

# Load the pipeline
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# We can utilize the enable_group_offload method for Diffusers model implementations
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)

# For any other model implementations, the apply_group_offloading function can be used
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")

prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
# This utilized about 14.79 GB. It can be further reduced by using tiling and using leaf_level offloading throughout the pipeline.
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)
```

Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.

## FP8 layerwise weight-casting

PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


if is_torch_available():
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
678 changes: 678 additions & 0 deletions src/diffusers/hooks/group_offloading.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/autoencoder_oobleck.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
"""

_supports_gradient_checkpointing = False
_supports_group_offloading = False

@register_to_config
def __init__(
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/models/autoencoders/consistency_decoder_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
```
"""

_supports_group_offloading = False

@register_to_config
def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/vq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class VQModel(ModelMixin, ConfigMixin):
"""

_skip_layerwise_casting_patterns = ["quantize"]
_supports_group_offloading = False

@register_to_config
def __init__(
Expand Down
92 changes: 91 additions & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from typing_extensions import Self

from .. import __version__
from ..hooks import apply_layerwise_casting
from ..hooks import apply_group_offloading, apply_layerwise_casting
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
Expand Down Expand Up @@ -87,7 +87,17 @@


def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
from ..hooks.group_offloading import _get_group_onload_device

try:
# Try to get the onload device from the group offloading hook
return _get_group_onload_device(parameter)
except ValueError:
pass

try:
# If the onload device is not available due to no group offloading hooks, try to get the device
# from the first parameter or buffer
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
return next(parameters_and_buffers).device
except StopIteration:
Expand Down Expand Up @@ -166,6 +176,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_no_split_modules = None
_keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -437,6 +448,55 @@ def enable_layerwise_casting(
self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
)

def enable_group_offload(
self,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
) -> None:
r"""
Activates group offloading for the current model.
See [`~hooks.group_offloading.apply_group_offloading`] for more information.
Example:
```python
>>> from diffusers import CogVideoXTransformer3DModel
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> transformer.enable_group_offload(
... onload_device=torch.device("cuda"),
... offload_device=torch.device("cpu"),
... offload_type="leaf_level",
... use_stream=True,
... )
```
"""
if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
msg = (
"Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
"forward pass is executed with tiling enabled. Please make sure to either:\n"
"1. Run a forward pass with small input shapes.\n"
"2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)."
)
logger.warning(msg)
if not self._supports_group_offloading:
raise ValueError(
f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute "
f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
Expand Down Expand Up @@ -1170,6 +1230,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Adapted from `transformers`.
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
from ..hooks.group_offloading import _is_group_offload_enabled

# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if getattr(self, "is_loaded_in_8bit", False):
Expand All @@ -1182,13 +1244,34 @@ def cuda(self, *args, **kwargs):
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
)

# Checks if group offloading is enabled
if _is_group_offload_enabled(self):
logger.warning(
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported."
)
return self

return super().cuda(*args, **kwargs)

# Adapted from `transformers`.
@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
from ..hooks.group_offloading import _is_group_offload_enabled

device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
dtype_present_in_args = "dtype" in kwargs

# Try converting arguments to torch.device in case they are passed as strings
for arg in args:
if not isinstance(arg, str):
continue
try:
torch.device(arg)
device_arg_or_kwarg_present = True
except RuntimeError:
pass

if not dtype_present_in_args:
for arg in args:
if isinstance(arg, torch.dtype):
Expand All @@ -1213,6 +1296,13 @@ def to(self, *args, **kwargs):
"Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
)

if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
logger.warning(
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
)
return self

return super().to(*args, **kwargs)

# Taken from `transformers`.
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/dit_transformer_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):

_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_supports_gradient_checkpointing = True
_supports_group_offloading = False

@register_to_config
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
"""

_skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
_supports_group_offloading = False

@register_to_config
def __init__(
Expand Down
53 changes: 50 additions & 3 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def to(self, *args, **kwargs):
)

device = device or device_arg
device_type = torch.device(device).type if device is not None else None
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())

# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
Expand Down Expand Up @@ -424,7 +425,7 @@ def module_is_offloaded(module):
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
)

if device and torch.device(device).type == "cuda":
if device_type == "cuda":
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
Expand All @@ -437,7 +438,7 @@ def module_is_offloaded(module):

# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
if pipeline_is_offloaded and device_type == "cuda":
logger.warning(
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
Expand All @@ -449,6 +450,7 @@ def module_is_offloaded(module):
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
for module in modules:
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module)

if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
logger.warning(
Expand All @@ -460,11 +462,21 @@ def module_is_offloaded(module):
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
)

# Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
# components can be from outside diffusers too, but still have group offloading enabled.
if (
self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module)
and device is not None
):
logger.warning(
f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported."
)

# This can happen for `transformer` models. CPU placement was added in
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
module.to(device=device)
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
module.to(device, dtype)

if (
Expand Down Expand Up @@ -1023,6 +1035,19 @@ def _execution_device(self):
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
Accelerate's module hooks.
"""
from ..hooks.group_offloading import _get_group_onload_device

# When apply group offloading at the leaf_level, we're in the same situation as accelerate's sequential
# offloading. We need to return the onload device of the group offloading hooks so that the intermediates
# required for computation (latents, prompt embeddings, etc.) can be created on the correct device.
for name, model in self.components.items():
if not isinstance(model, torch.nn.Module):
continue
try:
return _get_group_onload_device(model)
except ValueError:
pass

for name, model in self.components.items():
if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
continue
Expand Down Expand Up @@ -1061,6 +1086,8 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)

is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
Expand Down Expand Up @@ -1172,6 +1199,8 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)

if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
from accelerate import cpu_offload
else:
Expand Down Expand Up @@ -1896,6 +1925,24 @@ def from_pipe(cls, pipeline, **kwargs):

return new_pipeline

def _maybe_raise_error_if_group_offload_active(
self, raise_error: bool = False, module: Optional[torch.nn.Module] = None
) -> bool:
from ..hooks.group_offloading import _is_group_offload_enabled

components = self.components.values() if module is None else [module]
components = [component for component in components if isinstance(component, torch.nn.Module)]
for component in components:
if _is_group_offload_enabled(component):
if raise_error:
raise ValueError(
"You are trying to apply model/sequential CPU offloading to a pipeline that contains components "
"with group offloading enabled. This is not supported. Please disable group offloading for "
"components of the pipeline to use other offloading methods."
)
return True
return False


class StableDiffusionMixin:
r"""
Expand Down
Loading

0 comments on commit 9a147b8

Please sign in to comment.