[core] Freenoise memory improvements (#9262)
* update

* implement prompt interpolation

* make style

* resnet memory optimizations

* more memory optimizations; todo: refactor

* update

* update animatediff controlnet with latest changes

* refactor chunked inference changes

* remove print statements

* update

* chunk -> split

* remove changes from incorrect conflict resolution

* remove changes from incorrect conflict resolution

* add explanation of SplitInferenceModule

* update docs

* Revert "update docs"

This reverts commit c55a50a.

* update docstring for freenoise split inference

* apply suggestions from review

* add tests

* apply suggestions from review
a-r-r-o-w authored Sep 6, 2024
1 parent 5249a26 commit 6dfa499
Showing 5 changed files with 294 additions and 64 deletions.
22 changes: 20 additions & 2 deletions src/diffusers/models/attention.py
@@ -1104,8 +1104,26 @@ def forward(
             accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
             num_times_accumulated[:, frame_start:frame_end] += weights

-        hidden_states = torch.where(
-            num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
-        ).to(dtype)
+        # TODO(aryan): Maybe this could be done in a better way.
+        #
+        # Previously, this was:
+        # hidden_states = torch.where(
+        #     num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        # )
+        #
+        # The reasoning for the change here is that `torch.where` became a bottleneck at some point when golfing memory
+        # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
+        # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
+        # looked into this deeply because other memory optimizations led to more pronounced reductions.
+        hidden_states = torch.cat(
+            [
+                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                for accumulated_split, num_times_split in zip(
+                    accumulated_values.split(self.context_length, dim=1),
+                    num_times_accumulated.split(self.context_length, dim=1),
+                )
+            ],
+            dim=1,
+        ).to(dtype)

         # 3. Feed-forward
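For intuition, a minimal standalone sketch of the pattern adopted above (hypothetical sizes; `context_length` stands in for `self.context_length`, and this is not the diffusers code itself). Applying `torch.where` window by window keeps each boolean mask and quotient temporary at window size instead of allocating frame-count-sized copies in one shot:

import torch

# Hypothetical sizes: batch 1, 64 frames, 8 channels; 16-frame context windows.
context_length = 16
accumulated_values = torch.randn(1, 64, 8)
num_times_accumulated = torch.randint(0, 3, (1, 64, 1)).float()

# Single-shot form: the mask and the quotient are materialized for all 64 frames at once.
#   hidden_states = torch.where(
#       num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
#   )

# Windowed form: each torch.where only sees one 16-frame split, so the
# temporaries stay window-sized; torch.cat stitches the splits back together.
hidden_states = torch.cat(
    [
        torch.where(counts > 0, values / counts, values)
        for values, counts in zip(
            accumulated_values.split(context_length, dim=1),
            num_times_accumulated.split(context_length, dim=1),
        )
    ],
    dim=1,
)
assert hidden_states.shape == accumulated_values.shape

Both forms compute the same masked average; the split form trades a few extra kernel launches for a lower peak allocation, which matters most at high frame counts.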
101 changes: 40 additions & 61 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -187,20 +187,20 @@ def forward(
         hidden_states = self.norm(hidden_states)
         hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

-        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.proj_in(input=hidden_states)

         # 2. Blocks
         for block in self.transformer_blocks:
             hidden_states = block(
-                hidden_states,
+                hidden_states=hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 timestep=timestep,
                 cross_attention_kwargs=cross_attention_kwargs,
                 class_labels=class_labels,
             )

         # 3. Output
-        hidden_states = self.proj_out(hidden_states)
+        hidden_states = self.proj_out(input=hidden_states)
         hidden_states = (
             hidden_states[None, None, :]
             .reshape(batch_size, height, width, num_frames, channel)
@@ -344,15 +344,15 @@ def custom_forward(*inputs):
                 )

             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

             hidden_states = motion_module(hidden_states, num_frames=num_frames)

             output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states=hidden_states)

             output_states = output_states + (hidden_states,)

@@ -531,25 +531,18 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]

-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             hidden_states = motion_module(
                 hidden_states,
                 num_frames=num_frames,
@@ -563,7 +556,7 @@ def custom_forward(*inputs):

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states=hidden_states)

             output_states = output_states + (hidden_states,)

@@ -757,33 +750,26 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]

-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             hidden_states = motion_module(
                 hidden_states,
                 num_frames=num_frames,
             )

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

         return hidden_states

@@ -929,13 +915,13 @@ def custom_forward(*inputs):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

             hidden_states = motion_module(hidden_states, num_frames=num_frames)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

         return hidden_states

@@ -1080,10 +1066,19 @@ def forward(
         if cross_attention_kwargs.get("scale", None) is not None:
             logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

-        hidden_states = self.resnets[0](hidden_states, temb)
+        hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)

         blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
         for attn, resnet, motion_module in blocks:
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ def custom_forward(*inputs):
                     return custom_forward

                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(motion_module),
                     hidden_states,
@@ -1117,19 +1104,11 @@ def custom_forward(*inputs):
                     **ckpt_kwargs,
                 )
             else:
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
                 hidden_states = motion_module(
                     hidden_states,
                     num_frames=num_frames,
                 )
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

         return hidden_states
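The remaining three changed files (per the commit messages: the AnimateDiff ControlNet pipeline, the FreeNoise split-inference utilities, and tests) are not rendered here. For orientation, a rough sketch of what a `SplitInferenceModule`-style wrapper does; the names and defaults below are illustrative assumptions, not the exact diffusers API:

import torch
import torch.nn as nn


class SplitInferenceModule(nn.Module):
    """Sketch: wrap a module so selected keyword tensor inputs are processed in
    smaller splits along one dimension, trading extra forward passes for lower
    peak memory."""

    def __init__(self, module, split_size=1, split_dim=0, input_kwargs_to_split=("hidden_states",)):
        super().__init__()
        self.module = module
        self.split_size = split_size
        self.split_dim = split_dim
        self.input_kwargs_to_split = set(input_kwargs_to_split)

    def forward(self, *args, **kwargs):
        # Split every targeted tensor kwarg into a tuple of smaller tensors.
        split_inputs = {
            name: tensor.split(self.split_size, dim=self.split_dim)
            for name, tensor in kwargs.items()
            if name in self.input_kwargs_to_split and torch.is_tensor(tensor)
        }
        remaining = {name: value for name, value in kwargs.items() if name not in split_inputs}

        # Run the wrapped module once per split, then stitch the outputs together.
        outputs = []
        for split_group in zip(*split_inputs.values()):
            split_kwargs = dict(zip(split_inputs.keys(), split_group))
            outputs.append(self.module(*args, **split_kwargs, **remaining))
        return torch.cat(outputs, dim=self.split_dim)


# Usage sketch: process 64 "frames" sixteen at a time.
class FramewiseMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


wrapped = SplitInferenceModule(FramewiseMLP(), split_size=16, split_dim=0)
out = wrapped(hidden_states=torch.randn(64, 8))

A wrapper like this can only intercept inputs that are passed by name, which is likely why the diff above rewrites positional calls such as `resnet(hidden_states, temb)` into keyword form (`resnet(input_tensor=hidden_states, temb=temb)`).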
