Use pytorch attention by default on nvidia when xformers isn't present.
Add a new argument --use-quad-cross-attention
comfyanonymous committed Jun 26, 2023
1 parent 9b93b92 commit 8248bab
Showing 2 changed files with 20 additions and 3 deletions.
comfy/cli_args.py (3 changes: 2 additions & 1 deletion)
@@ -53,7 +53,8 @@ class LatentPreviewMethod(enum.Enum):
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
 
 attn_group = parser.add_mutually_exclusive_group()
-attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
+attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
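
The three attention flags now share one mutually exclusive argparse group, so at most one of them can be passed per launch. A minimal standalone sketch of that pattern (a plain argparse parser for illustration, not ComfyUI's actual one):

# Illustration of the mutually exclusive attention flags; standalone, not ComfyUI's parser.
import argparse

parser = argparse.ArgumentParser()
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true")
attn_group.add_argument("--use-quad-cross-attention", action="store_true")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true")

# Zero or one of the flags is accepted.
args = parser.parse_args(["--use-quad-cross-attention"])
print(args.use_quad_cross_attention)  # True

# Passing two flags from the same group makes argparse exit with an error such as
# "argument --use-split-cross-attention: not allowed with argument --use-quad-cross-attention".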
comfy/model_management.py (20 changes: 18 additions & 2 deletions)
@@ -139,7 +139,23 @@ def get_total_memory(dev=None, torch_total_too=False):
     except:
         XFORMERS_IS_AVAILABLE = False
 
+def is_nvidia():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.cuda:
+            return True
+
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+
+if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+    try:
+        if is_nvidia():
+            torch_version = torch.version.__version__
+            if int(torch_version[0]) >= 2:
+                ENABLE_PYTORCH_ATTENTION = True
+    except:
+        pass
+
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
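
In plain terms, the new block only flips the default; an explicit attention flag or an available xformers install still wins. A standalone restatement of that decision (choose_pytorch_attention is an illustrative name, not a ComfyUI function, and it approximates the cpu_state check with torch.cuda.is_available()):

import torch

def choose_pytorch_attention(xformers_available, picked_split, picked_quad, picked_pytorch):
    # An explicit user choice always wins.
    if picked_pytorch:
        return True
    if xformers_available or picked_split or picked_quad:
        return False
    # Otherwise default to PyTorch attention on NVIDIA builds of PyTorch 2.x.
    try:
        on_nvidia = torch.cuda.is_available() and torch.version.cuda is not None
        return on_nvidia and int(torch.__version__.split(".")[0]) >= 2
    except Exception:
        return False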
@@ -347,7 +363,7 @@ def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
         #TODO: more reliable way of checking for flash attention?
-        if torch.version.cuda: #pytorch flash attention only works on Nvidia
+        if is_nvidia(): #pytorch flash attention only works on Nvidia
             return True
     return False

@@ -438,7 +454,7 @@ def soft_empty_cache():
     elif xpu_available:
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
-        if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
+        if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
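
ROCm builds of PyTorch also report torch.cuda.is_available() as True, which is why a vendor check is needed before the CUDA-specific cache calls. A small probe of the distinction is_nvidia() relies on (assuming a GPU build of PyTorch; torch.version.cuda is set on CUDA wheels, torch.version.hip on ROCm wheels):

import torch

if torch.cuda.is_available():
    if torch.version.cuda is not None:
        print("NVIDIA/CUDA build:", torch.version.cuda)
        torch.cuda.empty_cache()   # the commit keeps these calls on NVIDIA only
        torch.cuda.ipc_collect()
    elif torch.version.hip is not None:
        print("AMD/ROCm build:", torch.version.hip)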
