Forge Setup
Commandline args:
Profile: stable-diffusion-webui-forge/modules/txt2img.py, lines 104 to 122 in 29be1da

```python
def txt2img_function(id_task: str, request: gr.Request, *args):
    from torch.profiler import profile, record_function, ProfilerActivity
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True,
                 profile_memory=True,  # Track memory allocation
                 with_stack=True) as prof:
        with record_function("model_inference"):
            p = txt2img_create_processing(id_task, request, *args)
            with closing(p):
                processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
                if processed is None:
                    processed = processing.process_images(p)
    # Print profiling results
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    # Export to Chrome trace format
    prof.export_chrome_trace("trace.json")
```

Commandline report
Note: Interestingly, although Forge's total duration is longer than A1111's, the per-step duration for steps after step 1 is significantly shorter in Forge, going from about 500 ms (A1111) to about 250 ms (Forge).
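Total duration and per-step duration can point in opposite directions when the first step of a cold-start run carries one-off costs. A minimal sketch, assuming you already have per-step wall-clock times in milliseconds (the helper name `summarize_steps` and the one-step warm-up default are illustrative, not from any of the three UIs), that reports the warm-up and steady-state parts separately:

```python
from statistics import mean, median


def summarize_steps(step_ms, warmup=1):
    """Split per-step timings (ms) into a warm-up part and a steady-state part.

    step_ms: per-step durations in milliseconds, in sampling order.
    warmup: number of leading steps treated as one-off (cold-start) cost.
    """
    if warmup < 0 or len(step_ms) <= warmup:
        raise ValueError("need at least one step after the warm-up steps")
    warm = step_ms[:warmup]
    steady = step_ms[warmup:]
    return {
        "total_ms": sum(step_ms),
        "warmup_ms": sum(warm),
        "steady_mean_ms": mean(steady),
        "steady_median_ms": median(steady),
    }
```

Comparing `steady_mean_ms` across UIs isolates the per-step cost from cold-start overhead, while `total_ms` still reflects what the user waits for.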
-
ComfyUI Setup
No extra commandline args
Profile https://github.com/comfyanonymous/ComfyUI/blob/c61eadf69a3ba4033dcf22e2e190fd54f779fc5b/nodes.py#L1321-L1344

```python
class KSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                     "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                     "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
                     "positive": ("CONDITIONING", ),
                     "negative": ("CONDITIONING", ),
                     "latent_image": ("LATENT", ),
                     "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                     }
                }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

    CATEGORY = "sampling"

    def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
        from torch.profiler import profile, record_function, ProfilerActivity
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                     record_shapes=True,
                     profile_memory=True,  # Track memory allocation
                     with_stack=True) as prof:
            with record_function("model_inference"):
                result = common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
        # Print profiling results
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
        # Export to Chrome trace format
        prof.export_chrome_trace("trace.json")
        return result
```

Commandline report
-
All experiments above are cold-start first runs after UI launch. In subsequent runs, the first denoising step's duration is more in line with the following steps. Some interesting observations:
-
The main difference in denoising-step performance seems to come from the duration of the cross-attention function.

Perfetto queries

A1111:
```sql
INCLUDE PERFETTO MODULE slices.slices;

SELECT
  sum(dur) / count(1) as dur,
  count(1) as count
FROM _slice_with_thread_and_process_info
WHERE name = 'modules/sd_hijack_optimizations.py(480): xformers_attention_forward'
```

Forge:
```sql
INCLUDE PERFETTO MODULE slices.slices;

SELECT
  sum(dur) / count(1) as dur,
  count(1) as count
FROM _slice_with_thread_and_process_info
WHERE name = 'ldm_patched/ldm/modules/attention.py(388): forward'
```
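The same average-duration query can also be run without Perfetto by reading the `trace.json` written by `prof.export_chrome_trace(...)`. A minimal sketch, assuming the export follows the Chrome trace-event JSON layout (a top-level `traceEvents` list, or a bare list in older exports, whose complete events have `"ph": "X"` and a `dur` in microseconds, whereas Perfetto reports `dur` in nanoseconds); `load_trace` and `slice_stats` are hypothetical helpers:

```python
import json


def load_trace(path):
    """Load a Chrome trace-event JSON file such as the exported trace.json."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def slice_stats(trace, name):
    """Return (average dur, count) of complete events with exactly this name.

    Mirrors `sum(dur) / count(1), count(1)` from the Perfetto query, but the
    average is in the trace file's units (microseconds) and is not truncated.
    """
    # Accept both {"traceEvents": [...]} and a bare list of events.
    events = trace.get("traceEvents", []) if isinstance(trace, dict) else trace
    durs = [
        ev["dur"]
        for ev in events
        if ev.get("ph") == "X" and ev.get("name") == name and "dur" in ev
    ]
    if not durs:
        return 0.0, 0
    return sum(durs) / len(durs), len(durs)
```

For example, `slice_stats(load_trace("trace.json"), "ldm_patched/ldm/modules/attention.py(388): forward")` would give the Forge row of the comparison.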
-
Problem 1: A1111 is using `rearrange` to reshape q/k/v in its xformers attention.

A1111 impl:
```python
def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))

    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)
```

Introduced in AUTOMATIC1111/stable-diffusion-webui#1851

Forge impl:
```python
def attention_xformers(q, k, v, heads, mask=None):
    b, _, dim_head = q.shape
    dim_head //= heads

    if BROKEN_XFORMERS:
        if b * heads > 65535:
            return attention_pytorch(q, k, v, heads, mask)

    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    if mask is not None:
        pad = 8 - q.shape[1] % 8
        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
        mask_out[:, :, :mask.shape[-1]] = mask
        mask = mask_out[:, :, :mask.shape[-1]]

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out
```

Forge calls `reshape` and `permute` explicitly, so no pattern string has to be parsed.
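To make the layout change concrete: A1111's pattern `'b n (h d) -> b n h d'` only splits the last dimension (equivalent to `reshape(b, n, h, d)`), while Forge's reshape/permute/reshape chain additionally folds the heads into the batch dimension, giving `(b * heads, n, dim_head)`. Below is a pure-Python sketch of Forge's element mapping on nested lists, for illustration only; `split_heads` and `merge_heads` are hypothetical names, and in torch this is done by the tensor ops in the Forge impl:

```python
def split_heads(x, heads):
    """(b, n, heads * d) -> (b * heads, n, d), same mapping as Forge's
    reshape(b, -1, heads, d).permute(0, 2, 1, 3).reshape(b * heads, -1, d)."""
    b, n, dim = len(x), len(x[0]), len(x[0][0])
    d = dim // heads
    # Outer loop over batch, inner over heads: output index is bi * heads + hi.
    return [
        [x[bi][ni][hi * d:(hi + 1) * d] for ni in range(n)]
        for bi in range(b)
        for hi in range(heads)
    ]


def merge_heads(x, heads):
    """(b * heads, n, d) -> (b, n, heads * d); the inverse of split_heads."""
    b = len(x) // heads
    n = len(x[0])
    return [
        [
            # Concatenate the heads of token ni back into one feature vector.
            [v for hi in range(heads) for v in x[bi * heads + hi][ni]]
            for ni in range(n)
        ]
        for bi in range(b)
    ]
```

The last dimension is read as `(heads, d)` with heads as the outer factor, which is exactly what the `(h d)` group in the einops pattern means.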
-
Top: A1111, bottom: Forge.

Profile settings
A1111:
Model: SD 1.5 (CLIP SKIP 1, VAE None)
Tool: NVIDIA Nsight Systems 2024.1.1
A1111 commandline:
Collect: CUDA trace
Environment Variables:

Check the
-
Problem 2: A1111's linear call in the attention block takes far longer than Forge's: 240 ms in A1111 vs. 71 ms in Forge. From the A1111 trace, you can see that there are a bunch of calls to …
A1111
Forge
These tables are inspecting a particular …
You can see that the costs of CLIP token processing are comparable, but A1111 has extra overhead in all subsequent runs.
Forge by default does not do any casting during inference. All dtype casting happens before inference starts.
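The benefit of casting once before inference grows with the number of steps, because a per-call cast repeats the same conversion on every forward pass. A toy sketch, assuming a model whose layers are described only by dtype labels (no real tensors; `count_weight_casts` is a hypothetical helper), that counts the weight conversions each strategy performs:

```python
def count_weight_casts(layer_dtypes, compute_dtype, steps, cast_up_front):
    """Count weight conversions for a toy model over `steps` denoising steps.

    layer_dtypes: stored dtype label per layer, e.g. ["fp32", "fp16"].
    cast_up_front=True: mismatched layers are converted once before sampling
    and the converted weights are kept.
    cast_up_front=False: each mismatched layer converts a temporary copy of
    its weight on every forward call, so the work repeats every step.
    """
    stored = list(layer_dtypes)
    casts = 0
    if cast_up_front:
        for i, dt in enumerate(stored):
            if dt != compute_dtype:
                stored[i] = compute_dtype  # converted weights are stored
                casts += 1
    for _ in range(steps):
        for dt in stored:
            if dt != compute_dtype:
                casts += 1  # temporary copy, discarded after the call
    return casts
```

With three fp32 layers and fp16 compute, 20 steps cost 60 conversions when casting per call, versus 3 when casting up front.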
-
Problem 3: Between calls to 2 …
Forge
A1111
Each of those extra args checks adds ~1 ms of overhead. Considering there are 70 calls to `checkpoint`:

```python
import torch


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
```
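A per-call overhead that looks small adds up when a function is hit many times: at ~1 ms each, 70 calls come to ~70 ms, which would land on every denoising step if all 70 calls happen within one step (the text gives the call count but not its scope). A minimal stdlib sketch for estimating such overhead with `timeit`; `per_call_overhead_ms` and `step_overhead_ms` are hypothetical helpers, and a CPU-side micro-benchmark like this will not capture GPU-side effects:

```python
import timeit


def per_call_overhead_ms(wrapped, direct, args=(), number=10_000):
    """Average extra milliseconds that wrapped(*args) costs over direct(*args)."""
    t_wrapped = timeit.timeit(lambda: wrapped(*args), number=number)
    t_direct = timeit.timeit(lambda: direct(*args), number=number)
    return (t_wrapped - t_direct) / number * 1000.0


def step_overhead_ms(per_call_ms, calls_per_step):
    """Total overhead per step if every call pays the same per-call cost."""
    return per_call_ms * calls_per_step


if __name__ == "__main__":
    def f(x):
        return x + 1

    def wrapper(x):
        # Extra tuple building around the real call, similar in spirit to
        # `args = tuple(inputs) + tuple(params)` in `checkpoint`.
        args = tuple([x]) + tuple(range(8))
        return f(*args[:1])

    print(f"{per_call_overhead_ms(wrapper, f, (1,)):.6f} ms per call")
    print(step_overhead_ms(1.0, 70), "ms per step at ~1 ms x 70 calls")
```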
-
With Problems 1 and 3 fixed, we are already going from 580 ms per iteration to 420 ms per iteration! 🎉
-
Problem 4: In each denoising step, A1111 seems to call …
Forge
A1111
-
With Problems 1, 3, and 4 fixed, we are going from 580 ms to 370 ms per iteration.
-
Problem 5: Obviously, the NaN test in A1111 is not free.
The solution is simple. Add …
Going from 345 ms/it to 325 ms/it: a 20 ms/it cut!
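The per-iteration figures quoted in this thread can be converted into relative savings and throughput. A small sketch; the helper `improvement` is hypothetical, and the inputs are the before/after pairs reported above:

```python
def improvement(before_ms, after_ms):
    """Return (ms saved per iteration, % reduction, it/s before, it/s after)."""
    saved = before_ms - after_ms
    return (
        saved,
        round(100.0 * saved / before_ms, 1),
        round(1000.0 / before_ms, 2),
        round(1000.0 / after_ms, 2),
    )


# Problems 1+3, Problems 1+3+4, and Problem 5, as reported in the thread.
for before, after in [(580, 420), (580, 370), (345, 325)]:
    print(f"{before} -> {after} ms/it:", improvement(before, after))
```

Going from 580 ms/it to 370 ms/it, for example, is roughly a 36% cut in time per iteration, or about 1.72 it/s to 2.70 it/s.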
-
Nice work! What's the final goal of this investigation? Are you planning to upstream all these fixes back to the A1111 repo?
-
Your rapid progress in this endeavor is startling. The emergence of Forge should have spurred Automatic1111 to conduct a similar investigation, finding the same flaws. Do they really need a flashing neon sign above the issues? This makes it look like either you dance circles around them from a coding perspective, or their priorities are completely out of whack.
-
If possible, could you look into layout optimization? Conv operations on NVIDIA cards with tensor cores have faster kernels available with the channels_last memory format. However, just setting the model and inputs to channels_last as a whole in eager mode often doesn't work well for anything that isn't a pure CNN, since other ops are better off in contiguous format and will be cast back; with SD this often means you break even. If you profile a compiled model, though, you will very plainly see that it does choose channels_last where appropriate. You could probably realize most of the benefit by converting only the conv layers and nearby bias/norm layers to channels_last, casting the intermediate tensors to channels_last before those layers and back to contiguous afterwards. Great work on these optimizations so far! I look forward to having both good optimizations and proper maintenance in one frontend.
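For reference, channels_last keeps a tensor's logical NCHW shape and only changes its strides so that the channel index becomes the innermost (fastest-varying) dimension in memory; in PyTorch the conversion is `x.to(memory_format=torch.channels_last)`, and `x.contiguous()` converts back. A pure-Python sketch of the two stride layouts for a 4-D shape (helper names are illustrative):

```python
from itertools import product


def contiguous_strides(shape):
    """Row-major strides: the last dimension (W for NCHW) is innermost."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def channels_last_strides(shape):
    """Strides for an NCHW-shaped tensor stored in NHWC order (C innermost)."""
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)


def offset(index, strides):
    """Linear memory offset of a logical (n, c, h, w) index."""
    return sum(i * s for i, s in zip(index, strides))


def all_offsets(shape, strides):
    """Offsets of every logical index, in logical (row-major) order."""
    return [offset(idx, strides) for idx in product(*(range(s) for s in shape))]
```

For a shape of (2, 3, 4, 5), the contiguous strides are (60, 20, 5, 1) and the channels_last strides are (60, 1, 15, 3): neighbouring channels of the same pixel sit next to each other in memory.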
-
Two blocking calls need to be eliminated to allow torch to properly fill its dispatch queue:

Replace:
```python
freqs = torch.exp(
    -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
```
with:
```python
freqs = torch.exp(
    -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
)
```
This creates the tensor directly on the device and avoids an unnecessary memory copy that would otherwise block dispatch.

If these changes are made, torch will be able to queue up several sampling steps' worth of instructions ahead of time, which will eliminate most of the overhead.
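The `freqs` values depend only on `half` and `max_period`, never on the timestep values, which is why the tensor can be created directly on the target device instead of on the CPU. A pure-Python sketch of what it holds, assuming the surrounding function is the usual sinusoidal timestep embedding from the ldm code (cosine half followed by sine half, zero-padded for odd `dim`); `timestep_freqs` and `timestep_embedding_1d` are hypothetical single-timestep helpers:

```python
import math


def timestep_freqs(half, max_period=10000):
    """Same values as `freqs`: exp(-ln(max_period) * i / half) for i in [0, half)."""
    return [math.exp(-math.log(max_period) * i / half) for i in range(half)]


def timestep_embedding_1d(t, dim, max_period=10000):
    """Sinusoidal embedding of one timestep: [cos(t*f)..., sin(t*f)...]."""
    half = dim // 2
    freqs = timestep_freqs(half, max_period)
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # assumed: odd dims are padded with one zero
    return emb
```

Since nothing here reads the timesteps, building the tensor on the device costs no host-to-device transfer at all.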
-
I have all PRs bundled here for anyone who wants to try it out: AUTOMATIC1111/stable-diffusion-webui#15821. It is unclear when AUTO will come back and get them merged, though.
-
I have to ask: is it possible to also mimic Forge's better memory handling? Aside from the speedups, the smarter moving of checkpoints around in memory was super useful for me on an RTX 2080, and is still useful sometimes on an RTX 3090. I sometimes see my VRAM spike to max when doing HR Fix on large images with WebUI, but Forge handles pretty much any size with ease (even with NeverOOM off in Forge, and Tiled VAE on in WebUI).
-
I've noticed that with some extensions installed, the whole UI becomes sluggish and laggy. Even the main queue dies (the generate button stops responding); in this case, the Agent Scheduler extension helps.
-
Requested: #801 (reply in thread)
System:
Sampling setup:
(none use xformers)
Above 10 steps, the profiling process began to slow down, which makes it unsuitable for long runs. (Fixed trace logs are in the reply below.)
-
If we're talking about weird crap with A1111 and the command flags, using --use-ipex on my NVIDIA 3060 Ti (8GB of VRAM) gives me a slight speed increase. A month ago I tested every single CLI argument that I thought would give me a speed increase. Here are the weird results. Tests were run with an SD 1.5 based model producing 1080x1920 images using KHRFix. XX = "don't use, no speed increase", while the numbered arguments did give a speed increase. The full test (which I haven't run yet) is 255 possible combinations of the numbered CLI arguments.
-
This is a gold thread, thank you for the work!
-
Hey all, this may be late to the party. I am a maintainer of Eclipse Trace Compass, a free, open-source trace viewer designed to scale. I would like to suggest trying out Trace Compass for this; it can ingest the JSON traces and correlate them well. I will make a custom stable-diffusion video soon to illustrate how to do it. Until then, check out https://www.youtube.com/watch?v=YCdzmcpOrK4 for a way to handle trace-event traces.
-
This thread will be used for performance profiling analysis of Forge/A1111/ComfyUI. Hopefully people can submit traces and screenshots to help us better understand why A1111 is slow.
I will start with a simple profiling task:
Hardware
GPU: RTX 4090 (24GB VRAM)
CPU: Ryzen 5 3600
Memory: 32 GB
A1111 Setup
Commandline args:
--opt-split-attention --xformers
Profile https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/1c0a0c4c26f78c32095ebc7f8af82f5c04fca8c0/modules/txt2img.py#L102-L120
Commandline report