Hey, I have observed in my timing tests that version 2.6.3 is faster than some later commits (including 2.7.0.post2) for the input sizes below. For example, for small batch sizes (== 2) and relatively short sequences, 2.6.3 is as much as 2x faster for me in the forward pass.
My setup: 4070 Laptop (CUDA 12) and A100 (CUDA 11), Torch 2.4. Both flash-attn versions were installed via pip directly from PyPI. The results below were measured with a custom Python script (test_min_example.py) with proper CUDA synchronization.
Minimal instructions to replicate:
# set up an environment for flash attention and install torch
pip install loguru
pip install pytest
pip install flash-attn==2.6.3 --no-build-isolation
pytest -s test_min_example.py
pip uninstall flash-attn
pip install flash-attn==2.7.0.post2 --no-build-isolation
pytest -s test_min_example.py
test_min_example.py:
import time
from abc import ABC, abstractmethod
from math import sqrt
from typing import Optional, Tuple

import pytest
import torch
from loguru import logger
from torch import Tensor
from torch.nn import Module

try:
    from flash_attn import flash_attn_func
except ImportError as e:
    logger.error(f"ImportError: {e}")
    flash_attn_func = None

# Tensor dimensions: (batch_size, seq_len (same for q and k), num_heads, embed_dim)
dimensions = [
    (512, 50, 16, 256),
    (512, 150, 16, 256),
    (512, 300, 16, 256),
    (2, 50, 16, 256),
    (2, 150, 16, 256),
    (2, 300, 16, 256),
]
dtypes = [
    torch.float16,
    # torch.bfloat16
]


class AttentionBackend(Module, ABC):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = .1):
        """
        :param embed_dim: the size of each embedding vector
        :param num_heads: number of heads
        :param dropout: attention dropout
        """
        assert not embed_dim % num_heads, 'embed_dim must be divisible by num_heads'
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self._scale = 1 / sqrt(embed_dim / num_heads)

    def unflatten(self, q: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        q = q.unflatten(2, (self.num_heads, -1))
        k = k.unflatten(2, (self.num_heads, -1))
        v = v.unflatten(2, (self.num_heads, -1))
        return q, k, v

    @abstractmethod
    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        raise NotImplementedError("Forward method not implemented in subclass.")


class DummyFlashAttentionBackend(AttentionBackend):
    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        q, k, v = self.unflatten(q, k, v)
        o = flash_attn_func(q, k, v, softmax_scale=self._scale, dropout_p=self.dropout)
        o = o.flatten(2)
        return o, None


def measure_forward_time(q, k, v, backend, num_runs=10):
    times = []
    for _ in range(2):  # Warm-up runs
        output, _ = backend(q, k, v)
    torch.cuda.synchronize()
    for _ in range(num_runs):  # Timed runs
        torch.cuda.synchronize()
        start_time = time.time()
        output, _ = backend(q, k, v)
        torch.cuda.synchronize()
        end_time = time.time()
        times.append(end_time - start_time)
    secs_to_microseconds = 1000000
    return (sum(times) / num_runs) * secs_to_microseconds


def measure_backward_time(q, k, v, backend, num_runs=10):
    times = []
    for _ in range(2):  # Warm-up runs
        output, _ = backend(q, k, v)
        loss = torch.sum(output)  # Dummy loss for backward pass
        loss.backward(retain_graph=True)
    torch.cuda.synchronize()
    for _ in range(num_runs):  # Timed runs
        torch.cuda.synchronize()
        output, _ = backend(q, k, v)
        loss = torch.sum(output)  # Dummy loss for backward pass
        torch.cuda.synchronize()
        start_time = time.time()
        loss.backward(retain_graph=True)
        torch.cuda.synchronize()
        end_time = time.time()
        times.append(end_time - start_time)
    secs_to_microseconds = 1000000
    return (sum(times) / num_runs) * secs_to_microseconds


def create_random_tensors_with_embeddings(batch_size, seq_len, embed_dim, num_heads, device, dtype):
    q = torch.randn((batch_size, seq_len, embed_dim), device=device, dtype=dtype, requires_grad=True)
    k = torch.randn((batch_size, seq_len, embed_dim), device=device, dtype=dtype, requires_grad=True)
    v = torch.randn((batch_size, seq_len, embed_dim), device=device, dtype=dtype, requires_grad=True)
    return q, k, v


@pytest.mark.parametrize("batch_size, seq_len, num_heads, embed_dim", dimensions)
@pytest.mark.parametrize("dtype", dtypes)
def test_timing_forward_pass(batch_size, seq_len, num_heads, embed_dim, dtype):
    # Compare flash attention timings
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    torch.manual_seed(42)
    device = torch.device('cuda')
    dropout = 0  # disable dropout
    q, k, v = create_random_tensors_with_embeddings(
        batch_size, seq_len, embed_dim, num_heads, device, dtype
    )
    dummy_kernel_backend = DummyFlashAttentionBackend(embed_dim, num_heads, dropout=dropout).to(device)
    avg_dummy_time = measure_forward_time(q, k, v, dummy_kernel_backend)
    del dummy_kernel_backend
    # Log the results
    logger.info(f"Configuration: batch_size={batch_size}, seq_len={seq_len}, "
                f"num_heads={num_heads}, embed_dim={embed_dim}, dtype={dtype}")
    logger.info(f"[Forward] Average time for Dummy backend: {avg_dummy_time:.2f} microsecs")
    # Pass the test since it's for timing comparison, no correctness checking
    assert True  # As noted, the output comparison is not relevant here


@pytest.mark.parametrize("batch_size, seq_len, num_heads, embed_dim", dimensions)
@pytest.mark.parametrize("dtype", dtypes)
def test_timing_backward_pass(batch_size, seq_len, num_heads, embed_dim, dtype):
    # Compare flash attention timings
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    torch.manual_seed(42)
    device = torch.device('cuda')
    dropout = 0  # disable dropout
    q, k, v = create_random_tensors_with_embeddings(
        batch_size, seq_len, embed_dim, num_heads, device, dtype
    )
    dummy_kernel_backend = DummyFlashAttentionBackend(embed_dim, num_heads, dropout=dropout).to(device)
    avg_dummy_time = measure_backward_time(q, k, v, dummy_kernel_backend)
    del dummy_kernel_backend
    # Log the results
    logger.info(f"Configuration: batch_size={batch_size}, seq_len={seq_len}, "
                f"num_heads={num_heads}, embed_dim={embed_dim}, dtype={dtype}")
    logger.info(f"[Backward] Average time for Dummy backend: {avg_dummy_time:.2f} microsecs")
    # Pass the test since it's for timing comparison, no correctness checking
    assert True  # As noted, the output comparison is not relevant here
Could you please help me understand what might be the source of these timing differences? Going through the source code, it seems to me that the kernel code is the same and the CUTLASS submodule pointer is the same; the only changes are in the C++/Python API, relating to head, head_size_og, and padding. Also, my embedding sizes and head counts are divisible by 8.
I'm guessing it's because we moved some of the checks and padding (i.e. checking if headdim is not a multiple of 8) from C++ to Python for compatibility with torch.compile. This might add a bit more Python overhead, so it's noticeable for small batches and short sequences (since the kernel itself will be very fast there).
You can try torch compiling it to reduce the overhead in this case.
What would be helpful is a profiler result (e.g. PyTorch profiler or Nsight Systems) showing the kernel time. If the kernel time stays the same, then we can say it's because of Python overhead; if the kernel time is very different, then we'll need to investigate.
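As a minimal sketch of the PyTorch profiler route (assuming the dummy_kernel_backend and q, k, v tensors from the test script above; names are illustrative, not part of the original thread):
import torch
from torch.profiler import profile, ProfilerActivity

# Profile a few forward calls and separate CUDA kernel time from CPU-side
# (Python/dispatch) time; identical kernel times across versions would point
# to pure Python overhead.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        output, _ = dummy_kernel_backend(q, k, v)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))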
I've been profiling with nsys on the A100 and can conclude that it's likely Python overhead, as the kernel times appear identical for versions 2.6.3 and 2.7.0.post2. I'm checking forward/backward passes for the same dimensions as mentioned earlier. Unfortunately, the Python overhead becomes quite significant, especially for smaller Q/K lengths and/or batch sizes.
You can try torch compiling it to reduce the overhead in this case.
Yeah, we should introduce it as a baseline I guess. Will test it soon. ATM, this thread can be closed :) Thanks!
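For reference, a rough sketch of what such a torch.compile baseline could look like, assuming a thin wrapper around flash_attn_func (the wrapper names are hypothetical, and whether the call compiles without graph breaks depends on the installed flash-attn/torch versions):
import torch
from flash_attn import flash_attn_func

def _attention(q, k, v, scale):
    # Thin Python wrapper; torch.compile can trace it and amortize the
    # per-call Python-side checks for small batches and short sequences.
    return flash_attn_func(q, k, v, softmax_scale=scale, dropout_p=0.0)

compiled_attention = torch.compile(_attention)

# Expected usage mirrors DummyFlashAttentionBackend.forward: q, k, v shaped
# (batch, seq_len, num_heads, head_dim) after unflatten(), e.g.
# out = compiled_attention(q, k, v, 1 / sqrt(head_dim))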