-
Notifications
You must be signed in to change notification settings - Fork 148
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
jit: further accelerate compilation by spliting files and multi-threa…
…ding (#628) This PR accelerates JIT compilation by: - Add a `parallel_load_modules` function to load necessary modules for a model in parallel using python multi-threading. - Splitting sampling.cu into renorm.cu and sampling.cu The batch prefill attention template could be further split into multiple instances to accelerate compilation, we leave that for future work.
- Loading branch information
Showing
9 changed files
with
173 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/* | ||
* Copyright (c) 2024 by FlashInfer team. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <flashinfer/sampling.cuh> | ||
|
||
#include "pytorch_extension_utils.h" | ||
|
||
using namespace flashinfer; | ||
|
||
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, | ||
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, | ||
int64_t cuda_stream) { | ||
CHECK_INPUT(probs); | ||
auto device = probs.device(); | ||
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) | ||
unsigned int batch_size = probs.size(0); | ||
unsigned int vocab_size = probs.size(1); | ||
bool has_top_p_arr = maybe_top_p_arr.has_value(); | ||
|
||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream); | ||
cudaError_t status = sampling::TopPRenormProb<float>( | ||
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()), | ||
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size, | ||
top_p_val, vocab_size, stream); | ||
TORCH_CHECK(status == cudaSuccess, | ||
"TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status))); | ||
} | ||
|
||
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, | ||
std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val, | ||
int64_t cuda_stream) { | ||
CHECK_INPUT(probs); | ||
auto device = probs.device(); | ||
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) | ||
unsigned int batch_size = probs.size(0); | ||
unsigned int vocab_size = probs.size(1); | ||
bool has_top_k_arr = maybe_top_k_arr.has_value(); | ||
|
||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream); | ||
cudaError_t status = sampling::TopKRenormProb<float>( | ||
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()), | ||
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size, | ||
top_k_val, vocab_size, stream); | ||
|
||
TORCH_CHECK(status == cudaSuccess, | ||
"TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status))); | ||
} | ||
|
||
void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits, | ||
std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val, | ||
int64_t cuda_stream) { | ||
CHECK_INPUT(logits); | ||
auto device = logits.device(); | ||
CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) | ||
unsigned int batch_size = logits.size(0); | ||
unsigned int vocab_size = logits.size(1); | ||
bool has_top_k_arr = maybe_top_k_arr.has_value(); | ||
|
||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream); | ||
cudaError_t status = sampling::TopKMaskLogits<float>( | ||
static_cast<float*>(logits.data_ptr()), static_cast<float*>(mask_logits.data_ptr()), | ||
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size, | ||
top_k_val, vocab_size, stream); | ||
|
||
TORCH_CHECK(status == cudaSuccess, | ||
"TopKMaskLogits failed with error code " + std::string(cudaGetErrorString(status))); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
""" | ||
Copyright (c) 2024 by FlashInfer team. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import torch | ||
from flashinfer.jit import parallel_load_modules | ||
from flashinfer.utils import PosEncodingMode | ||
|
||
import flashinfer | ||
|
||
|
||
def test_warmpup_llama(): | ||
parallel_load_modules( | ||
[ | ||
lambda: flashinfer.activation.get_act_and_mul_module("silu"), | ||
flashinfer.norm.get_norm_module, | ||
flashinfer.sampling.get_sampling_module, | ||
flashinfer.quantization.get_quantization_module, | ||
flashinfer.page.get_page_module, | ||
lambda: flashinfer.decode.get_batch_decode_module( | ||
torch.float16, | ||
torch.float16, | ||
torch.float16, | ||
torch.int32, | ||
128, | ||
PosEncodingMode.NONE.value, | ||
False, # use_sliding_window | ||
False, # use_logits_soft_cap | ||
), | ||
lambda: flashinfer.prefill.gen_batch_prefill_module( | ||
torch.float16, | ||
torch.float16, | ||
torch.float16, | ||
torch.int32, | ||
128, | ||
PosEncodingMode.NONE.value, | ||
False, # use_sliding_window | ||
False, # use_logits_soft_cap | ||
False, # allow_fp16_qk_reduction | ||
), | ||
] | ||
) |