
[core] Freenoise memory improvements #9262

Merged: 24 commits, Sep 6, 2024

Changes from 16 commits
Commits (24)
d0a81ae
update
a-r-r-o-w Aug 14, 2024
d55903d
implement prompt interpolation
a-r-r-o-w Aug 15, 2024
a86eabe
make style
a-r-r-o-w Aug 15, 2024
94438e1
resnet memory optimizations
a-r-r-o-w Aug 18, 2024
74e3ab0
more memory optimizations; todo: refactor
a-r-r-o-w Aug 18, 2024
ec91064
update
a-r-r-o-w Aug 18, 2024
6568681
update animatediff controlnet with latest changes
a-r-r-o-w Aug 18, 2024
76f931d
Merge branch 'main' into animatediff/freenoise-improvements
a-r-r-o-w Aug 19, 2024
761c44d
refactor chunked inference changes
a-r-r-o-w Aug 21, 2024
6830fb0
remove print statements
a-r-r-o-w Aug 21, 2024
49e40ef
Merge branch 'main' into animatediff/freenoise-improvements
a-r-r-o-w Aug 23, 2024
9e215c0
update
a-r-r-o-w Aug 24, 2024
2cef5c7
Merge branch 'main' into animatediff/freenoise-memory-improvements
a-r-r-o-w Sep 5, 2024
fb96059
chunk -> split
a-r-r-o-w Sep 5, 2024
dc2c12b
remove changes from incorrect conflict resolution
a-r-r-o-w Sep 5, 2024
12f0ae1
remove changes from incorrect conflict resolution
a-r-r-o-w Sep 5, 2024
661a0b3
add explanation of SplitInferenceModule
a-r-r-o-w Sep 5, 2024
c55a50a
update docs
a-r-r-o-w Sep 5, 2024
8797cc3
Merge branch 'main' into animatediff/freenoise-memory-improvements
a-r-r-o-w Sep 5, 2024
32961be
Revert "update docs"
a-r-r-o-w Sep 5, 2024
256ee34
update docstring for freenoise split inference
a-r-r-o-w Sep 5, 2024
c7bf8dd
apply suggestions from review
a-r-r-o-w Sep 5, 2024
9e556be
add tests
a-r-r-o-w Sep 5, 2024
098bfd1
apply suggestions from review
a-r-r-o-w Sep 5, 2024
11 changes: 9 additions & 2 deletions src/diffusers/models/attention.py
@@ -1104,8 +1104,15 @@ def forward(
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
num_times_accumulated[:, frame_start:frame_end] += weights

hidden_states = torch.where(
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
hidden_states = torch.cat(
[
torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
for accumulated_split, num_times_split in zip(
accumulated_values.split(self.context_length, dim=1),
num_times_accumulated.split(self.context_length, dim=1),
)
],
dim=1,
Comment on lines +1118 to +1126
Member:

So, this seems to be a form of chunking?

Member Author:

Yep. At some point while lowering the peaks in the memory traces, torch.where became the bottleneck. This was actually first noticed by @DN6, so credits to him.

Member (@sayakpaul, Sep 5, 2024):

Hmm, do we know the situations where torch.where() leads to spikes? Seems a little weird to me honestly because native conditionals like torch.where() are supposed to be more efficient.

Member Author:

I think the spike that we see is due to tensors being copied. The intermediate dimensions for attention get large when generating many frames (let's say, 200+) here. We could do something different here too - I just did what seemed like the easiest thing to do (as these changes were made when I was trying out different things in quick succession to golf the memory spikes)

Member:

Ah cool. Let's perhaps make a note of this to revisit later? At least this way, we are aware?

Member Author:

Alright, made a note. LMK if any further changes needed on this :)
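For anyone revisiting this later, here is a minimal standalone sketch of the pattern introduced above (shapes and values are made up for illustration): the single torch.where over the full accumulated tensor is replaced by per-split torch.where calls, so the division and copy temporaries never exceed one context window.

import torch

batch, num_frames, dim, context_length = 2, 64, 320, 16
accumulated_values = torch.randn(batch, num_frames, dim)
num_times_accumulated = torch.randint(0, 4, (batch, num_frames, 1)).float()

# Average split-by-split instead of in one shot over the whole tensor.
hidden_states = torch.cat(
    [
        torch.where(counts > 0, accum / counts, accum)
        for accum, counts in zip(
            accumulated_values.split(context_length, dim=1),
            num_times_accumulated.split(context_length, dim=1),
        )
    ],
    dim=1,
)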

).to(dtype)

# 3. Feed-forward
101 changes: 40 additions & 61 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -187,20 +187,20 @@ def forward(
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_in(input=hidden_states)

# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)

# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = self.proj_out(input=hidden_states)
hidden_states = (
hidden_states[None, None, :]
.reshape(batch_size, height, width, num_frames, channel)
@@ -344,15 +344,15 @@ def custom_forward(*inputs):
)

else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

hidden_states = motion_module(hidden_states, num_frames=num_frames)

output_states = output_states + (hidden_states,)

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states=hidden_states)

output_states = output_states + (hidden_states,)

@@ -531,25 +531,18 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]

hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
@@ -563,7 +556,7 @@ def custom_forward(*inputs):

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states=hidden_states)

output_states = output_states + (hidden_states,)

@@ -757,33 +750,26 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]

hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)

if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

return hidden_states

@@ -929,13 +915,13 @@ def custom_forward(*inputs):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

hidden_states = motion_module(hidden_states, num_frames=num_frames)

if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

return hidden_states

@@ -1080,10 +1066,19 @@ def forward(
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)

blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]

if self.training and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ def custom_forward(*inputs):
return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
@@ -1117,19 +1104,11 @@ def custom_forward(*inputs):
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

return hidden_states

107 changes: 106 additions & 1 deletion src/diffusers/pipelines/free_noise_utils.py
@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ..models.transformers.transformer_2d import Transformer2DModel
from ..models.unets.unet_motion_model import (
AnimateDiffTransformer3D,
CrossAttnDownBlockMotion,
DownBlockMotion,
UpBlockMotion,
@@ -30,6 +34,53 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class SplitInferenceModule(nn.Module):
def __init__(
self,
module: nn.Module,
split_size: int = 1,
split_dim: int = 0,
input_kwargs_to_split: List[str] = ["hidden_states"],
) -> None:
super().__init__()

Collaborator:

Would also add a docstring here to explain the init arguments. Maybe the workflow example in forward can be moved up here too?

Member Author:

Sounds good!

self.module = module
self.split_size = split_size
self.split_dim = split_dim
self.input_kwargs_to_split = set(input_kwargs_to_split)

def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
r"""Forward method of `SplitInferenceModule`.

All inputs that should be split should be passed as keyword arguments. Only the keyword arguments specified in `input_kwargs_to_split` when initializing the module will be split.
"""
split_inputs = {}

for key in list(kwargs.keys()):
if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
continue
split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
kwargs.pop(key)

results = []
for split_input in zip(*split_inputs.values()):
inputs = dict(zip(split_inputs.keys(), split_input))
inputs.update(kwargs)
Comment on lines +127 to +129
Collaborator:

Clean 😎

Member Author:

Not too proud of this one 😬 but hey, I'm not afraid of not understanding what sorcery is going on here two weeks later



intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
results.append(intermediate_tensor_or_tensor_tuple)

if isinstance(results[0], torch.Tensor):
return torch.cat(results, dim=self.split_dim)
elif isinstance(results[0], tuple):
return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
else:
raise ValueError(
"In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
)
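A quick usage sketch of SplitInferenceModule (the wrapped layer and sizes are hypothetical, for illustration only): the wrapper splits the listed keyword arguments along split_dim, runs the underlying module once per split, and concatenates the outputs.

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
# Process the batch dimension in splits of 2 instead of all at once.
split_layer = SplitInferenceModule(layer, split_size=2, split_dim=0, input_kwargs_to_split=["input"])

x = torch.randn(6, 8)
out = split_layer(input=x)  # three forward passes of batch size 2, then torch.cat
assert torch.allclose(out, layer(x))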


class AnimateDiffFreeNoiseMixin:
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""

@@ -70,6 +121,9 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow
motion_module.transformer_blocks[i].load_state_dict(
basic_transfomer_block.state_dict(), strict=True
)
motion_module.transformer_blocks[i].set_chunk_feed_forward(
Collaborator:

Do we always need chunked feed forward set when enabling free noise? Might be overkill no?

Member Author:

This is only there to carry forward the chunked FF behaviour if and only if it was already enabled in the BasicTransformerBlock. Basically, if it was not enabled in the BTB, motion_module.transformer_blocks[i]._chunk_size would be None, leading to the default behaviour of no chunking. If it was enabled in the BTB, it would by default carry forward to the FreeNoiseTransformerBlock.

basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
)
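For readers unfamiliar with what is being carried over here: when _chunk_size is set, the block runs its feed-forward over slices along the chunked dimension and concatenates the results, lowering the activation peak. A simplified sketch of that behaviour (not the exact diffusers helper):

import torch
import torch.nn as nn

def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
    # Apply the feed-forward to one chunk at a time and stitch the outputs back together.
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError("`chunk_size` must evenly divide the chunked dimension.")
    return torch.cat(
        [ff(chunk) for chunk in hidden_states.split(chunk_size, dim=chunk_dim)],
        dim=chunk_dim,
    )

# When _chunk_size is None, the block simply calls ff(hidden_states) directly.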

def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
r"""Helper function to disable FreeNoise in transformer blocks."""
@@ -98,6 +152,9 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do
motion_module.transformer_blocks[i].load_state_dict(
free_noise_transfomer_block.state_dict(), strict=True
)
motion_module.transformer_blocks[i].set_chunk_feed_forward(
free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
)

def _check_inputs_free_noise(
self,
@@ -410,6 +467,54 @@ def disable_free_noise(self) -> None:
for block in blocks:
self._disable_free_noise_in_block(block)

def _enable_split_inference_motion_modules_(
self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
) -> None:
for motion_module in motion_modules:
motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])

for i in range(len(motion_module.transformer_blocks)):
motion_module.transformer_blocks[i] = SplitInferenceModule(
motion_module.transformer_blocks[i],
spatial_split_size,
0,
["hidden_states", "encoder_hidden_states"],
)

motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])

def _enable_split_inference_attentions_(
self, attentions: List[Transformer2DModel], temporal_split_size: int
) -> None:
for i in range(len(attentions)):
attentions[i] = SplitInferenceModule(
attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"]
Member:

Should ["hidden_states", "encoder_hidden_states"] not be configurable or not really?

Member Author:

Sorry, I don't understand the comment very well.

SplitInferenceModule sets input_kwargs_to_split to ["hidden_states"] by default if no parameter is passed. I want both hidden_states and encoder_hidden_states to be split based on split_size here

Member:

Sorry. Okay. I am assuming that is a reasonable default to choose? I was wondering if it could make sense to let the users choose the inputs they wanna split?

Member Author:

I would say let's keep it in mind to allow users to have more control on this, but for now let's keep the scope of changes minimal. I would like to experiment on FreeNoise for CogVideoX as discussed internally, and so would like to get this in soon :)

Member:

Okay then we could add a comment before the blocks that could be configured and revisit those if needed?

Member Author:

Thanks, updated. WDYT?

)

def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None:
for i in range(len(resnets)):
resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"])

def _enable_split_inference_samplers_(
self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int
) -> None:
for i in range(len(samplers)):
samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])

def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
if getattr(block, "motion_modules", None) is not None:
self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size)
if getattr(block, "attentions", None) is not None:
self._enable_split_inference_attentions_(block.attentions, temporal_split_size)
if getattr(block, "resnets", None) is not None:
self._enable_split_inference_resnets_(block.resnets, temporal_split_size)
if getattr(block, "downsamplers", None) is not None:
self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size)
if getattr(block, "upsamplers", None) is not None:
self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size)

Comment on lines +583 to +593
Member:

Same. Shouldn't attentions, resnets, etc. be configurable?

Member Author:

Not sure I understand this comment either. Basically, we're going to be splitting across the batch dimension for the layers based on chosen spatial_split_size and temporal_split_size values

Member:

Similarly.

Basically, we're going to be splitting across the batch dimension for the layers based on chosen spatial_split_size and temporal_split_size values

Could it make sense to let the users choose the kind of layers they wanna apply splitting to? Perhaps we default to all (attentions, motion_modules, resnets, downsamplers, upsamplers), or not really?

@property
def free_noise_enabled(self):
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
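For completeness, a rough end-to-end sketch of how the new toggle is intended to be called on an AnimateDiff pipeline (the checkpoints and split sizes below are illustrative, not prescribed by this PR):

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Any AnimateDiff-compatible base model and motion adapter should work here.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Enable FreeNoise for long videos, then split large intermediate batches so the
# spatial and temporal layers never see the full batch at once.
pipe.enable_free_noise(context_length=16, context_stride=4)
pipe.enable_free_noise_split_inference(spatial_split_size=256, temporal_split_size=16)

video = pipe(prompt="a panda surfing, high quality", num_frames=64).frames[0]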