Remove unnecessary hasattr checks for scaled_dot_product_attention.

We pin the torch version, so there should be no concern that this function does not exist.
RyanJDick authored and hipsterusername committed Oct 10, 2024
1 parent ea54a26 commit ac08c31
1 changed file: invokeai/backend/stable_diffusion/diffusers_pipeline.py (+5, -10)
@@ -198,29 +198,24 @@ def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
             self.disable_attention_slicing()
             return
         elif config.attention_type == "torch-sdp":
-            if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-                # diffusers enables sdp automatically
-                return
-            else:
-                raise Exception("torch-sdp attention slicing not available")
+            # torch-sdp is the default in diffusers.
+            return

         # See https://github.com/invoke-ai/InvokeAI/issues/7049 for context.
         # Bumping torch from 2.2.2 to 2.4.1 caused the sliced attention implementation to produce incorrect results.
         # For now, if a user is on an MPS device and has not explicitly set the attention_type, then we select the
         # non-sliced torch-sdp implementation. This keeps things working on MPS at the cost of increased peak memory
         # utilization.
         if torch.backends.mps.is_available():
-            assert hasattr(torch.nn.functional, "scaled_dot_product_attention")
             return

-        # the remainder if this code is called when attention_type=='auto'
+        # The remainder if this code is called when attention_type=='auto'.
         if self.unet.device.type == "cuda":
             if is_xformers_available() and prefer_xformers:
                 self.enable_xformers_memory_efficient_attention()
                 return
-            elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-                # diffusers enables sdp automatically
-                return
+            # torch-sdp is the default in diffusers.
+            return

         if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
             mem_free = psutil.virtual_memory().free
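For context, a minimal sketch (not part of the commit): torch.nn.functional.scaled_dot_product_attention has been part of PyTorch's stable API since 2.0, so on the pinned torch versions referenced above (2.2.2/2.4.1) the removed hasattr/assert guards could never fail.

import torch
import torch.nn.functional as F

# On any torch >= 2.0 this attribute always exists, which is why the removed guard was dead code.
assert hasattr(F, "scaled_dot_product_attention")

# Tensors follow the (batch, heads, tokens, head_dim) layout expected by SDP.
q = torch.randn(1, 8, 64, 40)
k = torch.randn(1, 8, 64, 40)
v = torch.randn(1, 8, 64, 40)
out = F.scaled_dot_product_attention(q, k, v)  # torch dispatches to a fused kernel when one is available
print(out.shape)  # torch.Size([1, 8, 64, 40])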

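The "torch-sdp is the default in diffusers" comments rely on diffusers' processor selection: when scaled_dot_product_attention is available, attention layers default to the SDP-based AttnProcessor2_0, so the torch-sdp branch needs no explicit setup. A rough, hypothetical illustration (not InvokeAI code; the tiny UNet config is invented just to keep construction cheap):

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

# Small made-up config, for illustration only.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

# With torch >= 2.0 installed, the default processors already use
# torch.nn.functional.scaled_dot_product_attention.
print({type(p).__name__ for p in unet.attn_processors.values()})  # expected: {'AttnProcessor2_0'}

# Setting the SDP processor explicitly is redundant with the default, which is
# why the method above can simply return when attention_type == "torch-sdp".
unet.set_attn_processor(AttnProcessor2_0())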