[core] Freenoise memory improvements #9262
```diff
@@ -12,12 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
+import torch.nn as nn

 from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
+from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from ..models.transformers.transformer_2d import Transformer2DModel
 from ..models.unets.unet_motion_model import (
+    AnimateDiffTransformer3D,
     CrossAttnDownBlockMotion,
     DownBlockMotion,
     UpBlockMotion,
```
```diff
@@ -30,6 +34,53 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class SplitInferenceModule(nn.Module):
+    def __init__(
+        self,
+        module: nn.Module,
+        split_size: int = 1,
+        split_dim: int = 0,
+        input_kwargs_to_split: List[str] = ["hidden_states"],
+    ) -> None:
+        super().__init__()
```
**Review comment:** Would also add a docstring here to explain the init arguments. Maybe the workflow example in forward can be moved up here too?

**Reply:** Sounds good!
```diff
+        self.module = module
+        self.split_size = split_size
+        self.split_dim = split_dim
+        self.input_kwargs_to_split = set(input_kwargs_to_split)
+
+    def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+        r"""Forward method of `SplitInferenceModule`.
+
+        All inputs that should be split must be passed as keyword arguments. Only the keyword arguments
+        specified in `input_kwargs_to_split` when initializing the module will be split.
+        """
+        split_inputs = {}
+
+        for key in list(kwargs.keys()):
+            if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
+                continue
+            split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
+            kwargs.pop(key)
+
+        results = []
+        for split_input in zip(*split_inputs.values()):
+            inputs = dict(zip(split_inputs.keys(), split_input))
+            inputs.update(kwargs)
+
+            intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
+            results.append(intermediate_tensor_or_tensor_tuple)
+
+        if isinstance(results[0], torch.Tensor):
+            return torch.cat(results, dim=self.split_dim)
+        elif isinstance(results[0], tuple):
+            return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
+        else:
+            raise ValueError(
+                "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
+            )
```

**Review comment** (on the results loop): Clean 😎
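As a concrete illustration, here is a hypothetical usage sketch of the class above (it is not part of the PR; the `nn.Linear` wrapper target is chosen only because `torch.nn.Linear`'s forward parameter is named `input`, so that is the kwarg to split):

```python
import torch
import torch.nn as nn

# Assume SplitInferenceModule is the class defined in the diff above.
layer = nn.Linear(8, 8)
wrapped = SplitInferenceModule(layer, split_size=2, split_dim=0, input_kwargs_to_split=["input"])

x = torch.randn(6, 8)   # a "batch" of 6 rows
out = wrapped(input=x)  # runs 3 forward passes on chunks of 2, then torch.cat's the results
assert torch.allclose(out, layer(x), atol=1e-6)
```

The output is identical to calling the module directly; the only difference is that intermediate activations exist for one chunk at a time, which lowers peak memory at the cost of some speed.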
```diff
+
+
 class AnimateDiffFreeNoiseMixin:
     r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
```
```diff
@@ -70,6 +121,9 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow
                 motion_module.transformer_blocks[i].load_state_dict(
                     basic_transfomer_block.state_dict(), strict=True
                 )
+                motion_module.transformer_blocks[i].set_chunk_feed_forward(
+                    basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
+                )

     def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
         r"""Helper function to disable FreeNoise in transformer blocks."""
```

**Review comment:** Do we always need chunked feed-forward set when enabling FreeNoise? Might be overkill, no?

**Reply:** This is only there to carry forward the chunked feed-forward behaviour if and only if it was already enabled in the `BasicTransformerBlock`. Basically, if it was not enabled in the BTB, `_chunk_size` is `None`, so it stays disabled here as well.
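For context, a minimal sketch of the idea behind chunked feed-forward (an illustration, not necessarily the exact diffusers implementation):

```python
import torch
import torch.nn as nn

def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
    # Assumes hidden_states.shape[chunk_dim] is divisible by chunk_size.
    # Only one chunk's intermediate activations are alive at a time, which
    # lowers the peak memory of the (usually 4x-wide) feed-forward layer.
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    return torch.cat(
        [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
```

With `_chunk_size` left as `None`, blocks skip this path entirely, which is why the carry-over above only matters when chunking was already enabled.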
```diff
@@ -98,6 +152,9 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do
                 motion_module.transformer_blocks[i].load_state_dict(
                     free_noise_transfomer_block.state_dict(), strict=True
                 )
+                motion_module.transformer_blocks[i].set_chunk_feed_forward(
+                    free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
+                )

     def _check_inputs_free_noise(
         self,
```
```diff
@@ -410,6 +467,54 @@ def disable_free_noise(self) -> None:
         for block in blocks:
             self._disable_free_noise_in_block(block)

+    def _enable_split_inference_motion_modules_(
+        self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
+    ) -> None:
+        for motion_module in motion_modules:
+            motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])
+
+            for i in range(len(motion_module.transformer_blocks)):
+                motion_module.transformer_blocks[i] = SplitInferenceModule(
+                    motion_module.transformer_blocks[i],
+                    spatial_split_size,
+                    0,
+                    ["hidden_states", "encoder_hidden_states"],
+                )
+
+            motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])
+
+    def _enable_split_inference_attentions_(
+        self, attentions: List[Transformer2DModel], temporal_split_size: int
+    ) -> None:
+        for i in range(len(attentions)):
+            attentions[i] = SplitInferenceModule(
+                attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"]
+            )
```

**Review comment:** Should …

**Reply:** Sorry, I don't understand the comment very well.

**Review comment:** Sorry. Okay, I am assuming that is a reasonable default to choose? I was wondering if it could make sense to let the users choose the inputs they want to split?

**Reply:** I would say let's keep it in mind to allow users to have more control over this, but for now let's keep the scope of changes minimal. I would like to experiment with FreeNoise for CogVideoX as discussed internally, and so would like to get this in soon :)

**Review comment:** Okay, then we could add a comment before the blocks that could be configured and revisit those if needed?

**Reply:** Thanks, updated. WDYT?
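Note that the kwarg names passed to `SplitInferenceModule` mirror each wrapped module's forward signature (`proj_in`/`proj_out` take `input`, while the transformer blocks take `hidden_states` and `encoder_hidden_states`), since the wrapper forwards each split chunk as keyword arguments via `self.module(*args, **inputs)`.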
```diff
+    def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None:
+        for i in range(len(resnets)):
+            resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"])
+
+    def _enable_split_inference_samplers_(
+        self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int
+    ) -> None:
+        for i in range(len(samplers)):
+            samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])
+
+    def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
+        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        for block in blocks:
+            if getattr(block, "motion_modules", None) is not None:
+                self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size)
+            if getattr(block, "attentions", None) is not None:
+                self._enable_split_inference_attentions_(block.attentions, temporal_split_size)
+            if getattr(block, "resnets", None) is not None:
+                self._enable_split_inference_resnets_(block.resnets, temporal_split_size)
+            if getattr(block, "downsamplers", None) is not None:
+                self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size)
+            if getattr(block, "upsamplers", None) is not None:
+                self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size)
```
**Review comment** (on lines +583 to +593): Same. Should …

**Reply:** Not sure I understand this comment either. Basically, we're going to be splitting across the batch dimension for the layers based on the chosen …

**Review comment:** Similarly, could it make sense to let the users choose the kinds of layers they want to apply splitting to? Perhaps we default to all (…
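A hypothetical end-to-end usage sketch of the new `enable_free_noise_split_inference` entry point (the checkpoint names and surrounding pipeline setup are assumptions based on the usual AnimateDiff + FreeNoise workflow, not part of this diff):

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

pipe.enable_free_noise(context_length=16, context_stride=4)
# Added in this PR: trade some speed for a lower memory peak by processing
# spatial batches of 256 and temporal batches of 16 at a time.
pipe.enable_free_noise_split_inference(spatial_split_size=256, temporal_split_size=16)
```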
```diff
     @property
     def free_noise_enabled(self):
         return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
```
**Review comment:** So, this seems to be a form of chunking?

**Reply:** Yep. At some point while lowering the peaks on memory traces, `torch.where` became the bottleneck. This was actually first noticed by @DN6, so credits to him.

**Review comment:** Hmm, do we know the situations where `torch.where()` leads to spikes? It seems a little weird to me, honestly, because native conditionals like `torch.where()` are supposed to be more efficient.

**Reply:** I think the spike that we see is due to tensors being copied. The intermediate dimensions for attention get large when generating many frames (let's say, 200+). We could do something different here too - I just did what seemed like the easiest thing, as these changes were made while I was trying out different things in quick succession to golf the memory spikes.

**Review comment:** Ah, cool. Let's perhaps make a note of this to revisit later? At least this way, we are aware.

**Reply:** Alright, made a note. LMK if any further changes are needed on this :)
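To make the memory argument concrete, here is an illustrative sketch (an assumption about the mechanism, not code from this PR): a single `torch.where` over an attention-sized tensor materializes the full result in one call, while writing chunk-by-chunk into a preallocated output keeps the transient allocations small.

```python
import torch

frames, tokens, dim = 256, 1024, 64             # gets large when generating 200+ frames
mask = torch.rand(frames, tokens, 1) > 0.5      # broadcasts over the feature dim
a = torch.randn(frames, tokens, dim)
b = torch.randn(frames, tokens, dim)

# One-shot: allocates the full (frames, tokens, dim) result at once.
expected = torch.where(mask, a, b)

# Chunked along the frame axis: transient allocations stay chunk-sized.
out = torch.empty_like(a)
chunk = 16
for i in range(0, frames, chunk):
    s = slice(i, i + chunk)
    out[s] = torch.where(mask[s], a[s], b[s])

assert torch.equal(out, expected)
```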