I noticed that vLLM implemented Cascade Attention in PR #11635 ("[V1] Implement Cascade Attention"). I ran some benchmarks with the Qwen2-0.5B model on an A100 to find out whether it helps when most requests in a batch share a long common prefix.
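The idea behind cascade attention can be sketched in NumPy. Attention over the shared prefix is computed once for the queries of all requests. Attention over each request's own suffix is computed separately. The two partial results are then merged using their log-sum-exp values. This is a minimal illustration that assumes one non-causal (decode-style) query per request. It is not vLLM's or FlashInfer's kernel, and all function names are hypothetical.

```python
import numpy as np


def attention_with_lse(q, k, v):
    """Softmax attention of q over (k, v); also returns the per-query log-sum-exp."""
    scores = q @ k.T / np.sqrt(q.shape[-1])          # (num_q, num_kv)
    m = scores.max(axis=-1, keepdims=True)           # subtract max for stability
    p = np.exp(scores - m)
    s = p.sum(axis=-1, keepdims=True)
    out = (p / s) @ v                                # (num_q, head_dim)
    lse = (m + np.log(s)).squeeze(-1)                # (num_q,)
    return out, lse


def merge_attention(out_a, lse_a, out_b, lse_b):
    """Combine two partial attentions computed over disjoint KV sets."""
    lse = np.logaddexp(lse_a, lse_b)
    w_a = np.exp(lse_a - lse)[:, None]               # share of softmax mass in set A
    w_b = np.exp(lse_b - lse)[:, None]
    return w_a * out_a + w_b * out_b


def cascade_attention(queries, prefix_k, prefix_v, suffix_kv):
    """queries: per-request (num_q_i, d) arrays; suffix_kv: per-request (k_i, v_i)."""
    # Stage 1: a single pass over the shared prefix KV for every request's queries.
    all_q = np.concatenate(queries, axis=0)
    prefix_out, prefix_lse = attention_with_lse(all_q, prefix_k, prefix_v)
    outputs, start = [], 0
    for q, (k, v) in zip(queries, suffix_kv):
        end = start + q.shape[0]
        # Stage 2: each request attends only to its own, non-shared suffix.
        sfx_out, sfx_lse = attention_with_lse(q, k, v)
        outputs.append(merge_attention(prefix_out[start:end], prefix_lse[start:end],
                                       sfx_out, sfx_lse))
        start = end
    return outputs


# Usage: the result should match ordinary attention over the concatenated KV.
rng = np.random.default_rng(0)
d, prefix_len, suffix_len, batch = 64, 512, 32, 4
prefix_k = rng.standard_normal((prefix_len, d))
prefix_v = rng.standard_normal((prefix_len, d))
queries = [rng.standard_normal((1, d)) for _ in range(batch)]
suffix_kv = [(rng.standard_normal((suffix_len, d)), rng.standard_normal((suffix_len, d)))
             for _ in range(batch)]

cascaded = cascade_attention(queries, prefix_k, prefix_v, suffix_kv)
reference = [
    attention_with_lse(q, np.concatenate([prefix_k, k]), np.concatenate([prefix_v, v]))[0]
    for q, (k, v) in zip(queries, suffix_kv)
]
print(all(np.allclose(a, b) for a, b in zip(cascaded, reference)))
```

The merge step is exact: each partial output is weighted by its share of the total softmax normalizer. Only the number of times the prefix KV is read changes, not the result.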
I'm testing Qwen2-0.5B-Instruct with an input sequence length of 1024 tokens, common prefix lengths from 0 to 1024 tokens, batch sizes of 8 and 7, and an output length of 1 token. The runs use the vLLM main branch at commit 1f1542afa915e0975d2b63559424403e5e8aae2c.
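The next snippet is a rough back-of-the-envelope model of where cascade attention could help in this setup. It counts the KV-cache tokens read by one attention layer. It assumes that with cascade the shared prefix is read once per batch, and without cascade it is read once per request. The count is idealized, and the function name is hypothetical. It ignores FLOPs, which sharing the prefix does not reduce, and it ignores merge and launch overheads. Treat it as a best case for memory-traffic savings, not as a latency prediction.

```python
def kv_tokens_read(batch_size, seq_len, prefix_len, cascade):
    """Idealized number of KV-cache tokens read by one attention layer.

    Without cascade, each request reads its full sequence, including its own
    copy of the prefix. With cascade, the shared prefix is read once per batch.
    """
    if cascade:
        return prefix_len + batch_size * (seq_len - prefix_len)
    return batch_size * seq_len


batch_size, seq_len = 8, 1024
for prefix_len in [0, 64, 128, 256, 512, 1024]:
    plain = kv_tokens_read(batch_size, seq_len, prefix_len, cascade=False)
    shared = kv_tokens_read(batch_size, seq_len, prefix_len, cascade=True)
    print(f"prefix={prefix_len:5d}  plain={plain:6d}  cascade={shared:6d}  "
          f"ratio={shared / plain:.3f}")
```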
However, Cascade Attention did not show a noticeable improvement. Its latency also had a much larger standard deviation than the vLLM V0 implementation.
Baseline (V0 implementation): [figure: average latency vs. common prefix length]
Cascade Attention (VLLM_USE_V1=1): [figure: average latency vs. common prefix length]
And here is the benchmark script:
from vllm import LLM, SamplingParams
import random
import time
import statistics
import matplotlib.pyplot as plt
import argparse
import re


def generate_prompt_batch(length, common_prefix_length, batch_size, vocab_size):
    # Generate a common prefix shared by every prompt in the batch
    common_prefix = [random.randint(0, vocab_size - 1) for _ in range(common_prefix_length)]
    # Append a random per-request suffix so that every prompt is `length` tokens long
    batch = []
    for _ in range(batch_size):
        random_tokens = [random.randint(0, vocab_size - 1) for _ in range(length - common_prefix_length)]
        prompt = common_prefix + random_tokens
        batch.append(prompt)
    return batch


def benchmark_generate_batch(batch_size, length, common_prefix_length, run_times, llm, tokenizer, sampling_params):
    latencies = []
    for _ in range(run_times):
        # A fresh batch (with a fresh common prefix) is generated for every run
        test_batch = generate_prompt_batch(length, common_prefix_length, batch_size, tokenizer.vocab_size)
        start = time.perf_counter()
        responses = llm.generate(prompt_token_ids=test_batch, sampling_params=sampling_params, use_tqdm=False)
        end = time.perf_counter()
        latencies.append((end - start) * 1000)  # Convert to milliseconds
    avg_latency = statistics.mean(latencies)
    std_latency = statistics.stdev(latencies)
    print(f"Batch size: {batch_size}, Prompt length: {length}, Common prefix length: {common_prefix_length}, Number of runs: {run_times}")
    print(f"Average latency: {avg_latency:.2f} milliseconds, std: {std_latency:.2f} milliseconds")
    if std_latency > 5:
        # Dump the raw samples when the spread is large
        print(latencies)
    return avg_latency, std_latency


def main():
    parser = argparse.ArgumentParser(description="Benchmark LLM generation latencies.")
    parser.add_argument('--model_name', type=str, required=True, help='Name of the model to use.')
    parser.add_argument('--batch_size_ls', type=int, nargs='+', required=True, help='List of batch sizes to test.')
    # This flag only labels the output file; the engine itself is selected
    # through the VLLM_USE_V1 environment variable.
    parser.add_argument('--use_v1', type=str, required=True, help='Whether vLLM V1 is in use (used only in the plot file name).')
    args = parser.parse_args()
    print(args)

    llm = LLM(model=args.model_name, max_num_batched_tokens=8*1024+512, enable_chunked_prefill=True, max_model_len=1025)
    tokenizer = llm.get_tokenizer()  # Access the tokenizer from the LLM object
    sampling_params = SamplingParams(temperature=0.0, max_tokens=1)

    common_prefix_ls = [0, 64, 128, 256, 512, 1024]
    length = 1024
    run_times = 20

    # Warmup
    benchmark_generate_batch(1, length, 0, run_times, llm, tokenizer, sampling_params)

    avg_latencies = {size: [] for size in args.batch_size_ls}
    for common_prefix_length in common_prefix_ls:
        for batch_size in args.batch_size_ls:
            avg_latency, _ = benchmark_generate_batch(batch_size, length, common_prefix_length, run_times, llm, tokenizer, sampling_params)
            avg_latencies[batch_size].append(avg_latency)

    # Draw bar chart with lines connecting the bars
    x = range(len(common_prefix_ls))
    width = 0.2
    plt.figure(figsize=(10, 6))
    bars = {}
    colors = ['blue', 'green', 'red', 'pink', 'purple', 'orange', 'brown', 'gray', 'cyan', 'magenta']
    for i, batch_size in enumerate(args.batch_size_ls):
        bars[batch_size] = plt.bar(
            [p + width * i for p in x], avg_latencies[batch_size], width=width,
            label=f'Batch size {batch_size}', color=colors[i % len(colors)], align='center'
        )
        plt.plot(
            [p + width * i for p in x], avg_latencies[batch_size], color=colors[i % len(colors)],
            marker='o', linestyle='-', linewidth=2, markersize=5
        )

    # Add text labels centered above the bars
    for batch_size in args.batch_size_ls:
        for bar in bars[batch_size]:
            yval = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2.0, yval, f'{yval:.2f}', ha='center', va='bottom')  # ha/va: horizontal/vertical alignment

    plt.xlabel('Common Prefix Length')
    plt.ylabel('Average Latency (ms)')
    plt.title('Average Latency vs Common Prefix Length')
    # Center each tick under its group of bars, for any number of batch sizes
    group_offset = width * (len(args.batch_size_ls) - 1) / 2
    plt.xticks([p + group_offset for p in x], common_prefix_ls)
    plt.legend()

    # Save the plot to a file named after the model and batch sizes
    batch_sizes_str = ','.join(map(str, args.batch_size_ls))
    safe_model_name = re.sub(r'[^\w\-_\. ]', '_', args.model_name)
    safe_batch_sizes_str = re.sub(r'[^\w\-_\. ]', '_', batch_sizes_str)
    plt.savefig(f'latency_plot_{safe_model_name}_{safe_batch_sizes_str}_v1_{args.use_v1}.png')
    plt.show()


if __name__ == "__main__":
    main()
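The script's `--use_v1` flag only appears in the output file name. The engine itself is chosen through the `VLLM_USE_V1` environment variable. To collect both sets of numbers, you can launch the script twice in separate processes, each with the variable set in its own environment. That way the variable is already in place when vllm is imported. The following is a minimal sketch, assuming the script above is saved under the hypothetical name `benchmark_cascade.py`. The helper names are also hypothetical.

```python
import os
import subprocess
import sys

# Hypothetical file name for the benchmark script above.
SCRIPT = "benchmark_cascade.py"


def build_command(use_v1, model_name, batch_sizes, script=SCRIPT):
    """Return (argv, env) for one benchmark run with V0 or V1 selected."""
    env = dict(os.environ)
    env["VLLM_USE_V1"] = "1" if use_v1 else "0"
    argv = [
        sys.executable, script,
        "--model_name", model_name,
        "--batch_size_ls", *[str(b) for b in batch_sizes],
        "--use_v1", str(int(use_v1)),   # keeps the plot file name in sync
    ]
    return argv, env


def run_both(model_name, batch_sizes):
    # One fresh process per engine, so the environment variable is already
    # set when vllm is imported inside the child process.
    for use_v1 in (False, True):
        argv, env = build_command(use_v1, model_name, batch_sizes)
        subprocess.run(argv, env=env, check=True)
```

For the configuration in this report, the call would be `run_both("Qwen/Qwen2-0.5B-Instruct", [8, 7])`.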
Your current environment (if you think it is necessary)
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.12.8 (main, Dec 4 2024, 08:54:12) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-1078-azure-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 560.35.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 96
On-line CPU(s) list: 0-95
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7V13 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 2
Stepping: 1
BogoMIPS: 4890.87
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core invpcid_single vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr rdpru arat umip vaes vpclmulqdq rdpid fsrm
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 3 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 48 MiB (96 instances)
L3 cache: 384 MiB (12 instances)
NUMA node(s): 4
NUMA node0 CPU(s): 0-23
NUMA node1 CPU(s): 24-47
NUMA node2 CPU(s): 48-71
NUMA node3 CPU(s): 72-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] flashinfer==0.1.6+cu121torch2.4
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.0
[pip3] sentence-transformers==3.2.1
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.47.0
[pip3] transformers-stream-generator==0.0.5
[pip3] triton==3.1.0
[pip3] tritonclient==2.51.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.6.post2.dev304+g1f1542af
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 NIC0 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NODE 0-23 0 N/A
NIC0 NODE X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NVIDIA_VISIBLE_DEVICES=all
NVIDIA_REQUIRE_CUDA=cuda>=12.4 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 brand=tesla,driver>=525,driver<526 brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 brand=tesla,driver>=535,driver<536 brand=unknown,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=geforce,driver>=535,driver<536 brand=geforcertx,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=titan,driver>=535,driver<536 brand=titanrtx,driver>=535,driver<536
NVIDIA_DRIVER_CAPABILITIES=compute,utility
CUDA_VERSION=12.4.1
LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/cv2/../../lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
VLLM_USE_V1=0
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY