perf: accelerate JIT compilation speed (#618)
JIT compilation is currently slow because we rely on the huge header
`<torch/extension.h>`, which is far heavier than our use case requires.

This PR refactors the codebase to include only the headers needed for
pybind and moves most Torch runtime API calls from C++ to Python.

For lightweight operators such as norm, compilation time drops from 48
seconds to 18 seconds.
yzh119 authored Nov 20, 2024
1 parent dd3c836 commit eaf73fd
Showing 76 changed files with 2,420 additions and 2,101 deletions.
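
To make the change concrete, here is a minimal sketch of the kind of JIT extension this enables. It is illustrative only (the module, function, and kernel names are hypothetical, not code from this PR): the C++/CUDA side includes nothing from Torch, and the Python caller passes raw device pointers obtained from `tensor.data_ptr()`.

// Hypothetical sketch: a JIT-compiled module that avoids <torch/extension.h>.
// Tensors stay on the Python side; only plain integers cross the boundary.
#include <cuda_runtime.h>
#include <pybind11/pybind11.h>

#include <cstdint>

__global__ void scale_kernel(float* x, float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= alpha;
}

// Python side (sketch): scale_ops.scale(t.data_ptr(), 2.0, t.numel(),
//                                       torch.cuda.current_stream().cuda_stream)
void scale(uintptr_t x_ptr, float alpha, int n, uintptr_t stream) {
  scale_kernel<<<(n + 255) / 256, 256, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
      reinterpret_cast<float*>(x_ptr), alpha, n);
}

PYBIND11_MODULE(scale_ops, m) {
  m.def("scale", &scale, "In-place scale of a float32 CUDA buffer");
}

Because such a translation unit never pulls in `<torch/extension.h>`, the compiler front-end has far less code to parse, which is consistent with the reported drop from 48 seconds to 18 seconds.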
2 changes: 1 addition & 1 deletion .gitignore
@@ -13,7 +13,7 @@ src/generated/
python/csrc/generated/
python/flashinfer/_build_meta.py
python/flashinfer/jit/aot_config.py
- python/csrc_aot/generated/
+ python/csrc-aot/generated/

# Package files
python/flashinfer/data/
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -52,7 +52,7 @@ repos:
- id: clang-format
types_or: [c++, c, cuda]
exclude: |
- (?x)^(3rdparty/.* src/generated/.* python/flashinfer/jit/aot_config.py python/csrc_aot/generated/.*)$
+ (?x)^(3rdparty/.* src/generated/.* python/flashinfer/jit/aot_config.py)$
- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.13
5 changes: 3 additions & 2 deletions include/flashinfer/allocator.h
@@ -18,7 +18,8 @@

#include <memory>
#include <sstream>
- #include <stdexcept>

+ #include "exception.h"

namespace flashinfer {

@@ -44,7 +44,7 @@ struct AlignedAllocator {
std::ostringstream oss;
oss << "Failed to allocate memory for " << name << " with size " << size << " and alignment "
<< alignment << " in AlignedAllocator";
- throw std::runtime_error(oss.str());
+ FLASHINFER_ERROR(oss.str());
}
return nullptr;
}
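
To illustrate the allocator this touches, a hypothetical usage sketch follows (the constructor and `aligned_alloc` signatures are inferred from the error message above, not quoted from the header):

#include <flashinfer/allocator.h>

// Carve an aligned sub-buffer out of one pre-allocated workspace buffer.
float* AllocTmpLse(void* workspace, size_t workspace_bytes, size_t num_rows) {
  flashinfer::AlignedAllocator allocator(workspace, workspace_bytes);
  // On exhaustion this now throws flashinfer::Error via FLASHINFER_ERROR,
  // instead of a bare std::runtime_error.
  return allocator.aligned_alloc<float>(num_rows * sizeof(float), /*alignment=*/16, "tmp_lse");
}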
2 changes: 1 addition & 1 deletion include/flashinfer/attention/decode.cuh
@@ -687,7 +687,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(typename AttentionVariant::ParamsT
if (nblks.x == 0 || nblks.y == 0) {
std::ostringstream err_msg;
err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")";
- throw std::runtime_error(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
}
dim3 nthrs = dim3(bdx, bdy, bdz);
float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM);
8 changes: 4 additions & 4 deletions include/flashinfer/attention/prefill.cuh
@@ -1375,7 +1375,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params
err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal "
"to qo_len, got kv_len"
<< kv_len << " and qo_len " << qo_len;
- throw std::invalid_argument(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
}

const uint32_t group_size = num_qo_heads / num_kv_heads;
@@ -1442,7 +1442,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params
<< " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
" and report the issue to the developers.";
- throw std::invalid_argument(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
} else {
constexpr uint32_t num_threads = (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE;
constexpr uint32_t num_rows_per_cta = NUM_FRAGS_Q * NUM_WARPS_Q * 16;
@@ -2165,7 +2165,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
<< " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
" and report the issue to the developers.";
- throw std::invalid_argument(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
} else {
// TODO(Zihao): fix the following computation
uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) +
@@ -2267,7 +2267,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
<< " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
" and report the issue to the developers.";
- throw std::invalid_argument(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
} else {
// TODO(Zihao): fix the following computation
uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) +
10 changes: 5 additions & 5 deletions include/flashinfer/attention/scheduler.cuh
@@ -333,7 +333,7 @@ struct DecodePlanInfo {
if (vec.size() != 10) {
std::ostringstream err_msg;
err_msg << "DecodePlanInfo::FromVector: vec.size() should be 10, but got " << vec.size();
- throw std::invalid_argument(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
}
padded_batch_size = vec[0];
v_offset = vec[1];
@@ -440,14 +440,14 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
std::ostringstream err_msg;
err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]"
<< qo_indptr_h[i] << " should be non-negative";
- throw std::invalid_argument(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
}
kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]);
if (kv_len_arr[i] < 0) {
std::ostringstream err_msg;
err_msg << "kv_indptr[" << i + 1 << "]" << kv_indptr_h[i + 1] << " - kv_indptr[" << i << "]"
<< kv_indptr_h[i] << " should be non-negative";
- throw std::invalid_argument(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
}
sum_packed_qo_len += packed_qo_len_arr[i];
}
@@ -570,7 +570,7 @@ struct PrefillPlanInfo {
if (vec.size() != 14) {
std::ostringstream err_msg;
err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 14, but got " << vec.size();
- throw std::invalid_argument(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
}
padded_batch_size = vec[0];
total_num_rows = vec[1];
@@ -601,7 +601,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
std::ostringstream err_msg;
err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
<< num_kv_heads;
- throw std::invalid_argument(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
}

// step 0: get the number of SMs
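
The `FromVector` checks above guard a flat-integer encoding: plan metadata crosses module boundaries as a plain vector of integers, which is cheap to hand to Python. A hypothetical round-trip sketch, assuming a `ToVector` counterpart that this hunk does not show:

#include <cstdint>
#include <vector>

#include <flashinfer/attention/scheduler.cuh>

void RoundTripPlanInfo(flashinfer::DecodePlanInfo plan_info) {
  std::vector<int64_t> vec = plan_info.ToVector();  // assumed counterpart; 10 entries
  flashinfer::DecodePlanInfo restored;
  restored.FromVector(vec);  // throws flashinfer::Error if vec.size() != 10
}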
48 changes: 48 additions & 0 deletions include/flashinfer/exception.h
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_EXCEPTION_H_
#define FLASHINFER_EXCEPTION_H_

#include <exception>
#include <sstream>

namespace flashinfer {

class Error : public std::exception {
private:
std::string message_;

public:
Error(const std::string& func, const std::string& file, int line, const std::string& message) {
std::ostringstream oss;
oss << "Error in function '" << func << "' "
<< "at " << file << ":" << line << ": " << message;
message_ = oss.str();
}

virtual const char* what() const noexcept override { return message_.c_str(); }
};

#define FLASHINFER_ERROR(message) throw Error(__FUNCTION__, __FILE__, __LINE__, message)

#define FLASHINFER_CHECK(condition, message) \
if (!(condition)) { \
FLASHINFER_ERROR(message); \
}

} // namespace flashinfer

#endif // FLASHINFER_EXCEPTION_H_
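
A short usage sketch for the new header (the function and check are hypothetical). Both macros expand to an unqualified `throw Error(...)`, so they are intended for code living inside `namespace flashinfer`:

#include <flashinfer/exception.h>

namespace flashinfer {

void CheckHeadConfig(int num_qo_heads, int num_kv_heads) {
  // Throws flashinfer::Error when the condition fails; what() reads roughly:
  //   Error in function 'CheckHeadConfig' at check.cc:7: num_qo_heads must be ...
  FLASHINFER_CHECK(num_kv_heads > 0 && num_qo_heads % num_kv_heads == 0,
                   "num_qo_heads must be divisible by num_kv_heads");
}

}  // namespace flashinfer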
20 changes: 11 additions & 9 deletions include/flashinfer/gemm/bmm_fp8.cuh
@@ -19,15 +19,17 @@
#include <cublasLt.h>
#include <cuda_fp8.h>

- #include <stdexcept>
+ #include <iostream>
+ #include <memory>
#include <type_traits>

- #define FLASHINFER_CUBLAS_CHECK(EXPR) \
- { \
- cublasStatus_t e = (EXPR); \
- if (e != CUBLAS_STATUS_SUCCESS) { \
- throw std::runtime_error("CUBLAS Error: " + std::string(cublasGetStatusString(e))); \
- } \
+ #include "../exception.h"
+
+ #define FLASHINFER_CUBLAS_CHECK(EXPR) \
+ { \
+ cublasStatus_t e = (EXPR); \
+ FLASHINFER_CHECK(e == CUBLAS_STATUS_SUCCESS, \
+ "CUBLAS Error: " + std::string(cublasGetStatusString(e))); \
}

#ifndef NDEBUG
@@ -131,7 +133,7 @@ cudaDataType_t get_cuda_data_type() {
} else if constexpr (std::is_same_v<T, half>) {
return CUDA_R_16F;
} else {
- throw std::runtime_error("Unsupported type");
+ FLASHINFER_ERROR("Unsupported type");
}
}

@@ -155,7 +157,7 @@ cublasStatus_t bmm_fp8_internal_cublaslt(void* workspace, size_t workspace_size_
cudaDataType_t b_type = get_cuda_data_type<BT>();
cudaDataType_t d_type = get_cuda_data_type<DT>();
if (std::is_same_v<AT, __nv_fp8_e5m2> && std::is_same_v<BT, __nv_fp8_e5m2>) {
- throw std::runtime_error("Unsupported combination: both A and B are e5m2");
+ FLASHINFER_ERROR("Unsupported combination: both A and B are e5m2");
}

auto a_desp = CuBlasLtMatrixLayout(a_type, m, k, k, true);
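
A hypothetical usage sketch for the reworked macro (the helper is invented; like the other FLASHINFER_* macros it assumes `namespace flashinfer` so that `Error` resolves):

#include <flashinfer/gemm/bmm_fp8.cuh>

namespace flashinfer {

cublasLtHandle_t CreateLtHandle() {
  cublasLtHandle_t handle = nullptr;
  // A failing cuBLASLt call now surfaces as flashinfer::Error with
  // function/file/line context instead of a raw cublasStatus_t.
  FLASHINFER_CUBLAS_CHECK(cublasLtCreate(&handle));
  return handle;
}

}  // namespace flashinfer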
4 changes: 2 additions & 2 deletions include/flashinfer/gemm/group_gemm.cuh
@@ -79,13 +79,13 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe
if (status != cutlass::Status::kSuccess) {
std::ostringstream err_msg;
err_msg << "cutlass group_gemm.initialize failed: " << cutlassGetStatusString(status);
- throw std::runtime_error(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
}
status = gemm.run(stream);
if (status != cutlass::Status::kSuccess) {
std::ostringstream err_msg;
err_msg << "cutlass group_gemm.run failed: " << cutlassGetStatusString(status);
- throw std::runtime_error(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
}
});

2 changes: 1 addition & 1 deletion include/flashinfer/gemm/group_gemm_sm90.cuh
@@ -73,7 +73,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si
sizeof(DTypeIn) == 1) {
std::ostringstream err_msg;
err_msg << "Row-major layout is not supported for fp8 data type";
- throw std::runtime_error(err_msg.str());
+ FLASHINFER_ERROR(err_msg.str());
} else {
using LayoutA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
2 changes: 2 additions & 0 deletions include/flashinfer/math.cuh
@@ -19,6 +19,8 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>

+ #include <cstdint>

namespace flashinfer {
namespace math {

22 changes: 11 additions & 11 deletions include/flashinfer/utils.cuh
@@ -23,10 +23,10 @@

#include <cstdint>
#include <iostream>
- #include <sstream>
- #include <stdexcept>
#include <vector>

#include "exception.h"

#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)

@@ -57,7 +57,7 @@

#define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, ...) \
if (allow_fp16_qk_reduction) { \
- throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \
+ FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \
} else { \
constexpr bool ALLOW_FP16_QK_REDUCTION = false; \
__VA_ARGS__ \
@@ -73,7 +73,7 @@
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported num_frags_q: " << num_frags_q; \
- throw std::invalid_argument(err_msg.str()); \
+ FLASHINFER_ERROR(err_msg.str()); \
}

#define DISPATCH_NUM_FRAGS_KV(max_frags_kv, NUM_FRAGS_KV, ...) \
@@ -92,7 +92,7 @@
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported max_frags_kv: " << max_frags_kv; \
- throw std::invalid_argument(err_msg.str()); \
+ FLASHINFER_ERROR(err_msg.str()); \
}

#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \
@@ -115,7 +115,7 @@
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \
- throw std::invalid_argument(err_msg.str()); \
+ FLASHINFER_ERROR(err_msg.str()); \
} \
}

@@ -138,7 +138,7 @@
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported group_size: " << group_size; \
- throw std::invalid_argument(err_msg.str()); \
+ FLASHINFER_ERROR(err_msg.str()); \
}

#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \
@@ -161,7 +161,7 @@
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported mask_mode: " << int(mask_mode); \
- throw std::invalid_argument(err_msg.str()); \
+ FLASHINFER_ERROR(err_msg.str()); \
} \
}

@@ -190,7 +190,7 @@
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported head_dim: " << head_dim; \
- throw std::invalid_argument(err_msg.str()); \
+ FLASHINFER_ERROR(err_msg.str()); \
} \
}

@@ -214,7 +214,7 @@
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
- throw std::invalid_argument(err_msg.str()); \
+ FLASHINFER_ERROR(err_msg.str()); \
} \
}

@@ -248,7 +248,7 @@
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
- throw std::invalid_argument(err_msg.str()); \
+ FLASHINFER_ERROR(err_msg.str()); \
} \
}

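
All of these DISPATCH_* macros follow one pattern: switch on a runtime value and bind it to a constexpr name in each case, so the block can use it as a template argument, with unsupported values now raising flashinfer::Error. A sketch using the DISPATCH_CTA_TILE_Q macro shown above (the dispatched kernel is a hypothetical stand-in):

#include <flashinfer/utils.cuh>

namespace flashinfer {

template <uint32_t CTA_TILE_Q>
cudaError_t PrefillKernelDispatched() {
  // Stand-in for a real kernel launch specialized on CTA_TILE_Q.
  return cudaSuccess;
}

cudaError_t RunPrefill(uint32_t cta_tile_q) {
  cudaError_t status = cudaSuccess;
  // cta_tile_q is known only at runtime; inside the block CTA_TILE_Q is constexpr.
  DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q,
                      { status = PrefillKernelDispatched<CTA_TILE_Q>(); });
  return status;
}

}  // namespace flashinfer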
1 change: 0 additions & 1 deletion python/aot_MANIFEST.in
@@ -2,7 +2,6 @@

prune */__pycache__
prune csrc
- prune csrc_aot
exclude aot_setup.py
exclude setup.py

14 changes: 7 additions & 7 deletions python/aot_setup.py
@@ -64,7 +64,7 @@ def write_if_different(path: pathlib.Path, content: str) -> None:


def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]:
- path = root / "python" / "csrc_aot" / "generated"
+ path = root / "python" / "csrc" / "generated"
path.mkdir(parents=True, exist_ok=True)

head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")
@@ -423,12 +423,12 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None:
"csrc/quantization.cu",
"csrc/rope.cu",
"csrc/sampling.cu",
"csrc_aot/activation.cu",
"csrc_aot/batch_decode.cu",
"csrc_aot/batch_prefill.cu",
"csrc_aot/flashinfer_ops.cu",
"csrc_aot/single_decode.cu",
"csrc_aot/single_prefill.cu",
"csrc/activation.cu",
"csrc/batch_decode.cu",
"csrc/batch_prefill.cu",
"csrc/single_decode.cu",
"csrc/single_prefill.cu",
"csrc/flashinfer_ops.cu",
]
+ files_decode
+ files_prefill,