feat: JIT compilation (#507)
This PR implements JIT compilation (#170) for flashinfer. After this
PR, flashinfer compiles kernels just-in-time for different input
data types and shapes and caches the compiled kernels on disk, instead of
pre-compiling a fixed set of kernels into the wheel.

# Motivation
The pip wheel size is exploding as we add support for more data types,
more head dimensions, more attention variants, and more kernel
implementations. Pre-compiling everything is not sustainable and impedes
development speed.

This PR refactors the codebase to use torch's [JIT Compiling
Extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions)
feature instead of pre-compiling kernels in the wheel.
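
For context, `torch.utils.cpp_extension.load()` compiles a C++/CUDA source the first time it is requested and caches the built shared library on disk, so later runs reuse it. Below is a minimal, hypothetical sketch of what such a JIT-compiled extension source looks like (this is not flashinfer's actual kernel or binding code; the `scale` kernel and function names are illustrative only):

```cuda
// Hypothetical example source (e.g. "scale.cu"), NOT part of flashinfer.
// torch.utils.cpp_extension.load(name="scale_ext", sources=["scale.cu"]) compiles
// this on first call and caches the built .so under torch's extension directory.
#include <torch/extension.h>

__global__ void scale_kernel(const float* __restrict__ in, float* __restrict__ out, float alpha,
                             int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = alpha * in[i];
}

torch::Tensor scale(torch::Tensor x, double alpha) {
  TORCH_CHECK(x.is_cuda() && x.is_contiguous() && x.scalar_type() == torch::kFloat32,
              "expect a contiguous float32 CUDA tensor");
  auto y = torch::empty_like(x);
  int n = static_cast<int>(x.numel());
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  // Launches on the default stream for brevity; real code should use torch's current stream.
  scale_kernel<<<blocks, threads>>>(x.data_ptr<float>(), y.data_ptr<float>(),
                                    static_cast<float>(alpha), n);
  return y;
}

// TORCH_EXTENSION_NAME is defined by the JIT build from the `name` passed to load().
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("scale", &scale, "scale a float32 CUDA tensor"); }
```

On the Python side, the module returned by `load()` is used like any compiled extension; flashinfer presumably wraps this mechanism in the new `flashinfer.jit` package to generate, compile, and cache its attention kernels.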

## Attention Variants
We learned from [FlexAttention](https://pytorch.org/blog/flexattention/)
and describe every attention variant as a template class. Each instance
of the struct can carry closure variables defined in local memory or
shared memory. Below are two examples (logits soft cap and ALiBi
attention); the programming interface is tentative and will be updated as
we improve the programmability of the JIT template:

```cuda
template <typename ParamsT>
struct LogitsSoftCap {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ LogitsSoftCap(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    // Pre-scale the query by sm_scale / logits_soft_cap, so q·k yields logits / soft_cap.
    return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    // Soft cap: soft_cap * tanh(logits / soft_cap), scaled by log2(e) for the exp2-based softmax.
    return params.logits_soft_cap * math::log2e * float(math::tanh(logits));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;
  }
};

template <typename ParamsT>
struct ALIBIAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;
  using IdType = typename ParamsT::IdType;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ ALIBIAttention(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::log2e;
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return logits + params.alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;
  }
};
```
Users can define their own `ParamsT` class and variant classes to
create custom attention variants; we hope this refactor makes the
codebase more concise and extensible.
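
For illustration, a variant that applies only the standard softmax scaling would look like the sketch below. It mirrors the member set of the two examples above (constructor, `QueryTransform`, `LogitsTransform`, `LogitsMask`); the `IdentityAttention` name is hypothetical and the exact member set is inferred from those examples, since the interface is still tentative:

```cuda
// Hypothetical user-defined variant: standard softmax scaling, no soft cap,
// no positional bias, no extra masking.
template <typename ParamsT>
struct IdentityAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ IdentityAttention(const ParamsT& params, uint32_t batch_idx,
                                        uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;  // attend to the full KV range (no sliding window)
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    // Pre-scale the query, as in the examples above (log2 space).
    return float(q) * params.sm_scale * math::log2e;
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return logits;  // leave logits unchanged
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;  // keep every position (no extra masking)
  }
};
```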

# Roadmap

After this PR, we will add support for:
1. PyPI wheels #153 
2. fp8 tensor cores attention: #502
3. different head dimensions: #142 #454 #455
4. flashattention3 #369 
5. multi-head latent attention #237
6. generating `ParamsT` and attention variant descriptions from a Python DSL

The development of these features has been blocked by the wheel size
limitation (a binary size >= 2GB triggers linking issues); I hope
this PR will make future development easier.
yzh119 authored Oct 7, 2024
1 parent 2043692 commit 3613a5b
Showing 137 changed files with 6,986 additions and 6,122 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -12,6 +12,8 @@ src/dispatch.inc
src/generated/
python/csrc/generated/
python/flashinfer/_build_meta.py
python/flashinfer/jit/aot_config.py
flashinfer-aot/csrc_aot/generated/

# Generated documentation files
docs/generated
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 360 files
242 changes: 114 additions & 128 deletions CMakeLists.txt

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion cmake/config.cmake
@@ -24,7 +24,6 @@ set(FLASHINFER_FASTDEQUANT_TEST ON)
set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
1 change: 1 addition & 0 deletions flashinfer-aot/3rdparty
12 changes: 12 additions & 0 deletions flashinfer-aot/MANIFEST.in
@@ -0,0 +1,12 @@
# sdist & wheel
include version.txt
recursive-include include *
recursive-include csrc *
recursive-include 3rdparty/cutlass *

# wheel-only
exclude flashinfer/_build_meta.py

# Unneeded files
prune */__pycache__
global-exclude *.so
1 change: 1 addition & 0 deletions flashinfer-aot/csrc
@@ -18,11 +18,25 @@

#include <flashinfer/activation.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

__device__ __forceinline__ float silu(const float& val) {
  return val / (1.0f + __expf(-val));
}

__device__ __forceinline__ float gelu(const float& val) {
  constexpr float kAlpha = M_SQRT1_2;
  return val * 0.5f * (1.0f + ::erf(val * kAlpha));
}

__device__ __forceinline__ float gelu_tanh(const float& val) {
  const float cdf =
      0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
  return val * cdf;
}

void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
@@ -33,7 +47,7 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
    flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::silu_kernel>
    flashinfer::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
                                     static_cast<c_type*>(input.data_ptr()), d);

@@ -51,7 +65,7 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
    flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_tanh_kernel>
    flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
                                     static_cast<c_type*>(input.data_ptr()), d);

@@ -69,7 +83,7 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
    flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_kernel>
    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
                                     static_cast<c_type*>(input.data_ptr()), d);

205 changes: 205 additions & 0 deletions flashinfer-aot/csrc_aot/batch_decode.cu
@@ -0,0 +1,205 @@
/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>

#include <flashinfer/attention/decode_params.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/variants.cuh>
#include <optional>

#include "pytorch_extension_utils.h"

namespace flashinfer {

template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params,
                                                  typename AttentionVariant::DTypeO* tmp_v,
                                                  float* tmp_s, cudaStream_t stream);

} // namespace flashinfer

std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
    bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data,
    torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer,
    torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer,
    torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads,
    unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
  auto device = float_workspace_buffer.device();
  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  indptr = indptr.to(torch::kCPU);

  DecodePlanInfo plan_info;

  using IdType = int32_t;
  // check indptr has idtype int32
  TORCH_CHECK(indptr.scalar_type() == torch::kInt32, "indptr must be int32");
  constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

  auto q_scalar_type = empty_q_data.scalar_type();
  auto kv_scalar_type = empty_kv_data.scalar_type();

  DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
    using DTypeQ = q_type;
    using DTypeKV = kv_type;
    using DTypeO = DTypeQ;
    return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
      return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
        using ParamsT = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
        using AttentionVariant =
            ComposedAttention<ParamsT, get_variant_code(/*use_custom_mask=*/false,
                                                        /*use_sliding_window=*/true,
                                                        USE_LOGITS_SOFT_CAP, /*use_alibi=*/false)>;

        cudaError_t status = DecodePlan<HEAD_DIM, POS_ENCODING_MODE, AttentionVariant>(
            static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
            static_cast<void*>(int_workspace_buffer.data_ptr()),
            static_cast<void*>(page_locked_int_workspace_buffer.data_ptr()),
            int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(indptr.data_ptr()),
            batch_size, num_qo_heads, num_kv_heads, page_size, enable_cuda_graph,
            /*stream=*/torch_current_stream);

        TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
                    cudaGetErrorString(status));
        return true;
      });
    });
  });

  return plan_info.ToVector();
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, torch::Tensor q,
    std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
    std::optional<torch::Tensor> paged_v_cache, torch::Tensor paged_kv_indptr,
    torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
    std::optional<torch::Tensor> alibi_slopes, unsigned int kv_layout_code, int window_left,
    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
  DecodePlanInfo plan_info;
  plan_info.FromVector(plan_info_vec);
  QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
  bool paged_kv_defined = paged_kv_cache.has_value();
  auto device = q.device();
  int64_t batch_size = q.size(0);
  int64_t num_qo_heads = q.size(1);
  int64_t num_kv_heads, page_size;
  if (paged_kv_defined) {
    if (kv_layout == QKVLayout::kHND) {
      num_kv_heads = paged_kv_cache->size(2);
      page_size = paged_kv_cache->size(3);
    } else {
      page_size = paged_kv_cache->size(2);
      num_kv_heads = paged_kv_cache->size(3);
    }
  } else {
    if (kv_layout == QKVLayout::kHND) {
      num_kv_heads = paged_k_cache->size(1);
      page_size = paged_k_cache->size(2);
    } else {
      page_size = paged_k_cache->size(1);
      num_kv_heads = paged_k_cache->size(2);
    }
  }
  uint32_t head_dim = q.size(2);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  torch::Tensor o = torch::empty_like(q);
  torch::Tensor lse;
  if (return_lse) {
    lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32)));
  }

  TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");

  void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
  void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());

  using IdType = int32_t;
  constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

  // get q_scalar_type and kv_scalar_type
  auto q_scalar_type = q.scalar_type();
  auto kv_scalar_type =
      paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();

  DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
    using DTypeQ = q_type;
    using DTypeKV = kv_type;
    using DTypeO = DTypeQ;
    return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
      return DISPATCH_LOGITS_SOFT_CAP(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] {
        using ParamsT = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
        using AttentionVariant =
            ComposedAttention<ParamsT, get_variant_code(/*use_custom_mask=*/false,
                                                        /*use_sliding_window=*/true,
                                                        USE_LOGITS_SOFT_CAP, /*use_alibi=*/false)>;

        paged_kv_t<DTypeKV, IdType> paged_kv(
            num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
            static_cast<DTypeKV*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
                                                             : nullptr),
            static_cast<DTypeKV*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr),
            static_cast<DTypeKV*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr),
            static_cast<IdType*>(paged_kv_indices.data_ptr()),
            static_cast<IdType*>(paged_kv_indptr.data_ptr()),
            static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
        ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
                       /*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
                       /*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
                       /*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap,
                       sm_scale, rope_scale, rope_theta);

        DTypeO* tmp_v = nullptr;
        float* tmp_s = nullptr;
        params.request_indices =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
        params.kv_tile_indices =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
        params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
        params.kv_chunk_size_ptr =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_chunk_size_ptr_offset);
        if (plan_info.split_kv) {
          tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
          tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
          if (plan_info.enable_cuda_graph) {
            params.block_valid_mask =
                GetPtrFromBaseOffset<bool>(int_buffer, plan_info.block_valid_mask_offset);
          }
        }
        params.padded_batch_size = plan_info.padded_batch_size;

        cudaError_t status =
            flashinfer::BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, POS_ENCODING_MODE,
                                                              AttentionVariant>(
                params, tmp_v, tmp_s, /*stream=*/torch_current_stream);
        TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
                    cudaGetErrorString(status));
        return true;
      });
    });
  });

  if (return_lse) {
    return {o, lse};
  } else {
    return {o};
  }
}