[core] Freenoise memory improvements #9262

Merged: 24 commits, Sep 6, 2024
Changes from 22 of 24 commits
d0a81ae
update
a-r-r-o-w Aug 14, 2024
d55903d
implement prompt interpolation
a-r-r-o-w Aug 15, 2024
a86eabe
make style
a-r-r-o-w Aug 15, 2024
94438e1
resnet memory optimizations
a-r-r-o-w Aug 18, 2024
74e3ab0
more memory optimizations; todo: refactor
a-r-r-o-w Aug 18, 2024
ec91064
update
a-r-r-o-w Aug 18, 2024
6568681
update animatediff controlnet with latest changes
a-r-r-o-w Aug 18, 2024
76f931d
Merge branch 'main' into animatediff/freenoise-improvements
a-r-r-o-w Aug 19, 2024
761c44d
refactor chunked inference changes
a-r-r-o-w Aug 21, 2024
6830fb0
remove print statements
a-r-r-o-w Aug 21, 2024
49e40ef
Merge branch 'main' into animatediff/freenoise-improvements
a-r-r-o-w Aug 23, 2024
9e215c0
update
a-r-r-o-w Aug 24, 2024
2cef5c7
Merge branch 'main' into animatediff/freenoise-memory-improvements
a-r-r-o-w Sep 5, 2024
fb96059
chunk -> split
a-r-r-o-w Sep 5, 2024
dc2c12b
remove changes from incorrect conflict resolution
a-r-r-o-w Sep 5, 2024
12f0ae1
remove changes from incorrect conflict resolution
a-r-r-o-w Sep 5, 2024
661a0b3
add explanation of SplitInferenceModule
a-r-r-o-w Sep 5, 2024
c55a50a
update docs
a-r-r-o-w Sep 5, 2024
8797cc3
Merge branch 'main' into animatediff/freenoise-memory-improvements
a-r-r-o-w Sep 5, 2024
32961be
Revert "update docs"
a-r-r-o-w Sep 5, 2024
256ee34
update docstring for freenoise split inference
a-r-r-o-w Sep 5, 2024
c7bf8dd
apply suggestions from review
a-r-r-o-w Sep 5, 2024
9e556be
add tests
a-r-r-o-w Sep 5, 2024
098bfd1
apply suggestions from review
a-r-r-o-w Sep 5, 2024
22 changes: 20 additions & 2 deletions src/diffusers/models/attention.py
@@ -1104,8 +1104,26 @@ def forward(
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
num_times_accumulated[:, frame_start:frame_end] += weights

hidden_states = torch.where(
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
# TODO(aryan): Maybe this could be done in a better way.
#
# Previously, this was:
# hidden_states = torch.where(
# num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
# )
#
# The reasoning for the change here is that `torch.where` became a bottleneck at some point when golfing memory
# spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
# from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
# looked into this deeply because other memory optimizations led to more pronounced reductions.
hidden_states = torch.cat(
[
torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
for accumulated_split, num_times_split in zip(
accumulated_values.split(self.context_length, dim=1),
num_times_accumulated.split(self.context_length, dim=1),
)
],
dim=1,
Comment on lines +1118 to +1126
Member:
So, this seems to be a form of chunking?

Member Author:
Yep. At some point while lowering the peaks on the memory traces, torch.where became the bottleneck. This was actually first noticed by @DN6, so credit to him.

sayakpaul (Member), Sep 5, 2024:
Hmm, do we know the situations in which torch.where() leads to spikes? It seems a little weird to me, honestly, because native conditionals like torch.where() are supposed to be more efficient.

Member Author:
I think the spike we see comes from tensors being copied. The intermediate attention dimensions get large when generating many frames (say, 200+). We could do something different here too; I just did what seemed like the easiest thing, since these changes were made while I was trying out different approaches in quick succession to golf down the memory spikes.

Member:
Ah, cool. Let's perhaps make a note of this to revisit later? At least that way we are aware of it.

Member Author:
Alright, made a note. LMK if any further changes are needed on this :)

).to(dtype)

# 3. Feed-forward
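To make the pattern above concrete, here is a minimal, self-contained sketch of the split-then-`torch.where`-then-concatenate approach. This is illustrative only, not the library code; the function name, tensor shapes, and `context_length` value are made up for the example.

```python
import torch


def weighted_average_per_split(accumulated_values, num_times_accumulated, context_length):
    # Instead of a single torch.where over the full (batch, num_frames, dim) tensors,
    # split both accumulators along the frame dimension and normalize each split
    # separately, so the temporaries created by the division and torch.where only
    # cover one context window at a time.
    outputs = []
    for accumulated_split, num_times_split in zip(
        accumulated_values.split(context_length, dim=1),
        num_times_accumulated.split(context_length, dim=1),
    ):
        outputs.append(
            torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
        )
    return torch.cat(outputs, dim=1)


# Hypothetical shapes: batch of 2, 200 frames, 1024-dim features, context window of 16 frames.
accumulated_values = torch.randn(2, 200, 1024)
num_times_accumulated = torch.randint(0, 3, (2, 200, 1)).float()
hidden_states = weighted_average_per_split(accumulated_values, num_times_accumulated, context_length=16)
print(hidden_states.shape)  # torch.Size([2, 200, 1024])
```

Presumably this is what flattens the spike discussed above: the full-size intermediates produced by a single torch.where call are replaced by per-window temporaries, at the cost of a short Python loop over the splits.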
101 changes: 40 additions & 61 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -187,20 +187,20 @@ def forward(
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_in(input=hidden_states)

# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)

# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = self.proj_out(input=hidden_states)
hidden_states = (
hidden_states[None, None, :]
.reshape(batch_size, height, width, num_frames, channel)
@@ -344,15 +344,15 @@ def custom_forward(*inputs):
)

else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

hidden_states = motion_module(hidden_states, num_frames=num_frames)

output_states = output_states + (hidden_states,)

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states=hidden_states)

output_states = output_states + (hidden_states,)

@@ -531,25 +531,18 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]

hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
@@ -563,7 +556,7 @@ def custom_forward(*inputs):

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states=hidden_states)

output_states = output_states + (hidden_states,)

@@ -757,33 +750,26 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]

hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)

if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

return hidden_states

@@ -929,13 +915,13 @@ def custom_forward(*inputs):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

hidden_states = motion_module(hidden_states, num_frames=num_frames)

if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

return hidden_states

@@ -1080,10 +1066,19 @@ def forward(
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)

blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]

if self.training and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ def custom_forward(*inputs):
return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
@@ -1117,19 +1104,11 @@ def custom_forward(*inputs):
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)

return hidden_states
