Skip to content

Commit

Permalink
feat: add rotary_dim argument to rope APIs for partial apply rope (#…
Browse files Browse the repository at this point in the history
…599)

This PR implements the final piece of #530 , so that we can partially
apply rotary embedding to first head dimensions instead of entire head
dimensions.

We also add a simple benchmark for RoPE, below is the result on H100:
```python
batch_size:   1, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 23us, throughput:   0.876GB/s
batch_size:   1, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 26us, throughput:   0.801GB/s
batch_size:   1, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  95.735GB/s
batch_size:   1, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  95.639GB/s
batch_size:   1, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 31us, throughput: 672.889GB/s
batch_size:   1, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 32us, throughput: 662.972GB/s
---
batch_size:  19, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  14.559GB/s
batch_size:  19, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  14.435GB/s
batch_size:  19, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 37us, throughput: 1339.450GB/s
batch_size:  19, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 37us, throughput: 1340.399GB/s
batch_size:  19, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 148us, throughput: 2696.563GB/s
batch_size:  19, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 148us, throughput: 2689.104GB/s
---
batch_size:  99, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  74.186GB/s
batch_size:  99, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  74.452GB/s
batch_size:  99, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 110us, throughput: 2350.830GB/s
batch_size:  99, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 110us, throughput: 2359.814GB/s
batch_size:  99, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 717us, throughput: 2895.389GB/s
batch_size:  99, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 718us, throughput: 2891.385GB/s
---
batch_size: 128, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  95.449GB/s
batch_size: 128, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  95.646GB/s
batch_size: 128, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 130us, throughput: 2576.101GB/s
batch_size: 128, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 130us, throughput: 2582.447GB/s
batch_size: 128, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 924us, throughput: 2906.154GB/s
batch_size: 128, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 925us, throughput: 2903.484GB/s
```
  • Loading branch information
yzh119 authored Nov 10, 2024
1 parent 2043ca2 commit eb9bc71
Show file tree
Hide file tree
Showing 8 changed files with 493 additions and 197 deletions.
93 changes: 93 additions & 0 deletions benchmarks/bench_rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import argparse
from typing import cast

import torch
from triton.testing import do_bench

import flashinfer


def generate_cos_sin_f32_cache(max_seq_len, head_dim, theta=1e4):
position = torch.arange(max_seq_len).float().unsqueeze(1)
freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.cat([freqs, freqs], dim=-1).contiguous()
args = position * freqs
sin_cache = torch.sin(args)
cos_cache = torch.cos(args)
return cos_cache, sin_cache


@torch.inference_mode()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 19, 99, 128])
parser.add_argument("--append-len", nargs="+", type=int, default=[1, 128, 1024])
parser.add_argument("--num-qo-heads", type=int, default=32)
parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-dim", type=int, default=128)
args = parser.parse_args()

eps = 1e-6
dtype = torch.float16
num_qo_heads = args.num_qo_heads
num_kv_heads = args.num_kv_heads
head_dim = args.head_dim

# Loop over each combination of batch_size, hidden_size, and dtype
for batch_size in args.batch_sizes:
for append_len in args.append_len:
for use_cos_sin_cache in [False, True]:
# Define tensors with the correct dtype

q = torch.randn(
(batch_size * append_len, num_qo_heads, args.head_dim),
dtype=dtype,
device="cuda",
)
k = torch.randn(
(batch_size * append_len, num_kv_heads, args.head_dim),
dtype=dtype,
device="cuda",
)
pos_ids = torch.repeat_interleave(
torch.arange(append_len, dtype=torch.int32, device=q.device),
batch_size,
)
cos_cache, sin_cache = generate_cos_sin_f32_cache(4096, head_dim)
cos_cache = cos_cache.to(q.device)
sin_cache = sin_cache.to(q.device)

@torch.cuda.nvtx.range(
f"apply_rope batch_size={batch_size}, append_len={append_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}"
)
def fn() -> None:
if use_cos_sin_cache:
flashinfer.apply_rope_with_cos_sin_cache(
q, k, cos_cache, sin_cache, pos_ids
)
else:
flashinfer.apply_rope_pos_ids(q, k, pos_ids)

# Run benchmarking
latency_ms = cast(float, do_bench(fn))
throughput = (
q.numel() * q.element_size() * 2 + k.numel() * k.element_size() * 2
) / (latency_ms * 1e-3)
print(
f"batch_size: {batch_size:3},",
f"append_len: {append_len:5},",
f"num_qo_heads: {num_qo_heads:5},",
f"num_kv_heads: {num_kv_heads:5},",
f"head_dim: {head_dim:5},",
f"use_cos_sin_cache: {use_cos_sin_cache},",
f"latency: {latency_ms*1e3:2.0f}us,",
f"throughput: {throughput*1e-9:7.3f}GB/s",
)

print("---")

torch.cuda.profiler.stop()


if __name__ == "__main__":
main()
3 changes: 0 additions & 3 deletions include/flashinfer/cutlass_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
#ifndef FLASHINFER_CUTLASS_UTILS_CUH_
#define FLASHINFER_CUTLASS_UTILS_CUH_

#include <cuda_runtime.h>
#include <cutlass/cutlass.h>

#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
Expand Down
Loading

0 comments on commit eb9bc71

Please sign in to comment.