Skip to content

Commit

Permalink
jit: further accelerate compilation by spliting files and multi-threa…
Browse files Browse the repository at this point in the history
…ding (#628)

This PR accelerates JIT compilation by:
- Add a `parallel_load_modules` function to load necessary modules for a
model in parallel using python multi-threading.
- Splitting sampling.cu into renorm.cu and sampling.cu

The batch prefill attention template could be further split into
multiple instances to accelerate compilation, we leave that for future
work.
  • Loading branch information
yzh119 authored Nov 23, 2024
1 parent 9cba9fb commit f5842b8
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 60 deletions.
1 change: 1 addition & 0 deletions python/aot_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None:
"csrc/quantization.cu",
"csrc/rope.cu",
"csrc/sampling.cu",
"csrc/renorm.cu",
"csrc/activation.cu",
"csrc/batch_decode.cu",
"csrc/batch_prefill.cu",
Expand Down
79 changes: 79 additions & 0 deletions python/csrc/renorm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/sampling.cuh>

#include "pytorch_extension_utils.h"

using namespace flashinfer;

void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
int64_t cuda_stream) {
CHECK_INPUT(probs);
auto device = probs.device();
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
unsigned int batch_size = probs.size(0);
unsigned int vocab_size = probs.size(1);
bool has_top_p_arr = maybe_top_p_arr.has_value();

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TopPRenormProb<float>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
top_p_val, vocab_size, stream);
TORCH_CHECK(status == cudaSuccess,
"TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
}

void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
int64_t cuda_stream) {
CHECK_INPUT(probs);
auto device = probs.device();
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
unsigned int batch_size = probs.size(0);
unsigned int vocab_size = probs.size(1);
bool has_top_k_arr = maybe_top_k_arr.has_value();

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TopKRenormProb<float>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
top_k_val, vocab_size, stream);

TORCH_CHECK(status == cudaSuccess,
"TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
}

void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
int64_t cuda_stream) {
CHECK_INPUT(logits);
auto device = logits.device();
CHECK_DIM(2, logits); // logits: (batch_size, vocab_size)
unsigned int batch_size = logits.size(0);
unsigned int vocab_size = logits.size(1);
bool has_top_k_arr = maybe_top_k_arr.has_value();

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TopKMaskLogits<float>(
static_cast<float*>(logits.data_ptr()), static_cast<float*>(mask_logits.data_ptr()),
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
top_k_val, vocab_size, stream);

TORCH_CHECK(status == cudaSuccess,
"TopKMaskLogits failed with error code " + std::string(cudaGetErrorString(status)));
}
59 changes: 0 additions & 59 deletions python/csrc/sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -143,65 +143,6 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_sample
std::string(cudaGetErrorString(status)));
}

void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
int64_t cuda_stream) {
CHECK_INPUT(probs);
auto device = probs.device();
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
unsigned int batch_size = probs.size(0);
unsigned int vocab_size = probs.size(1);
bool has_top_p_arr = maybe_top_p_arr.has_value();

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TopPRenormProb<float>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
top_p_val, vocab_size, stream);
TORCH_CHECK(status == cudaSuccess,
"TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
}

void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
int64_t cuda_stream) {
CHECK_INPUT(probs);
auto device = probs.device();
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
unsigned int batch_size = probs.size(0);
unsigned int vocab_size = probs.size(1);
bool has_top_k_arr = maybe_top_k_arr.has_value();

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TopKRenormProb<float>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
top_k_val, vocab_size, stream);

TORCH_CHECK(status == cudaSuccess,
"TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
}

void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
int64_t cuda_stream) {
CHECK_INPUT(logits);
auto device = logits.device();
CHECK_DIM(2, logits); // logits: (batch_size, vocab_size)
unsigned int batch_size = logits.size(0);
unsigned int vocab_size = logits.size(1);
bool has_top_k_arr = maybe_top_k_arr.has_value();

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TopKMaskLogits<float>(
static_cast<float*>(logits.data_ptr()), static_cast<float*>(mask_logits.data_ptr()),
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
top_k_val, vocab_size, stream);

TORCH_CHECK(status == cudaSuccess,
"TopKMaskLogits failed with error code " + std::string(cudaGetErrorString(status)));
}

void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
at::Tensor uniform_samples, at::Tensor target_probs,
at::Tensor output_token_ids, at::Tensor output_accepted_token_num,
Expand Down
1 change: 1 addition & 0 deletions python/flashinfer/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .attention import get_single_prefill_uri as get_single_prefill_uri
from .core import clear_cache_dir, load_cuda_ops
from .env import *
from .utils import parallel_load_modules as parallel_load_modules

try:
from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri # type: ignore[import]
Expand Down
2 changes: 2 additions & 0 deletions python/flashinfer/jit/batch_prefill_templ.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
limitations under the License.
"""

import itertools

batch_prefill_suffix = [
"_plan.cu",
"_ragged_run.cu",
Expand Down
4 changes: 3 additions & 1 deletion python/flashinfer/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def load_cuda_ops(
] + CUTLASS_INCLUDE_DIRS
lock = FileLock(FLASHINFER_JIT_DIR / f"{name}.lock", thread_local=False)
with lock:
return torch_cpp_ext.load(
module = torch_cpp_ext.load(
name,
list(map(lambda _: str(_), sources)),
extra_cflags=cflags,
Expand All @@ -119,3 +119,5 @@ def load_cuda_ops(
verbose=verbose,
with_cuda=True,
)
logger.info(f"Finished loading JIT ops: {name}")
return module
32 changes: 32 additions & 0 deletions python/flashinfer/jit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
"""

import pathlib
import threading
from typing import Callable, List

import torch

from .core import logger


def write_if_different(path: pathlib.Path, content: str) -> None:
if path.exists():
Expand All @@ -30,6 +34,34 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
f.write(content)


def parallel_load_modules(
load_module_funcs: List[Callable],
):
threads = []
exceptions = []

def wrapper(func):
try:
func()
except Exception as e:
exceptions.append((func, e))

for func in load_module_funcs:
thread = threading.Thread(target=wrapper, args=(func,))
thread.start()
threads.append(thread)

for thread in threads:
thread.join()

if exceptions:
for func, e in exceptions:
print(f"Exception occurred in {func.__name__}: {e}")
raise RuntimeError("One or more exceptions occurred during module loading")

logger.info("Finished loading modules")


dtype_map = {
torch.float16: "half",
torch.bfloat16: "nv_bfloat16",
Expand Down
1 change: 1 addition & 0 deletions python/flashinfer/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_sampling_module():
"sampling",
[
FLASHINFER_CSRC_DIR / "sampling.cu",
FLASHINFER_CSRC_DIR / "renorm.cu",
FLASHINFER_CSRC_DIR / "flashinfer_sampling_ops.cu",
],
)
Expand Down
54 changes: 54 additions & 0 deletions tests/test_jit_warmup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
from flashinfer.jit import parallel_load_modules
from flashinfer.utils import PosEncodingMode

import flashinfer


def test_warmpup_llama():
parallel_load_modules(
[
lambda: flashinfer.activation.get_act_and_mul_module("silu"),
flashinfer.norm.get_norm_module,
flashinfer.sampling.get_sampling_module,
flashinfer.quantization.get_quantization_module,
flashinfer.page.get_page_module,
lambda: flashinfer.decode.get_batch_decode_module(
torch.float16,
torch.float16,
torch.float16,
torch.int32,
128,
PosEncodingMode.NONE.value,
False, # use_sliding_window
False, # use_logits_soft_cap
),
lambda: flashinfer.prefill.gen_batch_prefill_module(
torch.float16,
torch.float16,
torch.float16,
torch.int32,
128,
PosEncodingMode.NONE.value,
False, # use_sliding_window
False, # use_logits_soft_cap
False, # allow_fp16_qk_reduction
),
]
)

0 comments on commit f5842b8

Please sign in to comment.