From f5842b86002d632e7ed230c12cdc3199576f19b1 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Sat, 23 Nov 2024 14:39:46 -0800
Subject: [PATCH] jit: further accelerate compilation by splitting files and multi-threading (#628)

This PR accelerates JIT compilation by:

- Adding a `parallel_load_modules` function that loads the modules a model needs in parallel, using Python multi-threading.
- Splitting sampling.cu into renorm.cu and a smaller sampling.cu.

The batch prefill attention template could be further split into multiple instances to accelerate compilation; we leave that for future work.
---
 python/aot_setup.py                          |  1 +
 python/csrc/renorm.cu                        | 79 ++++++++++++++++++++
 python/csrc/sampling.cu                      | 59 ---------------
 python/flashinfer/jit/__init__.py            |  1 +
 python/flashinfer/jit/batch_prefill_templ.py |  2 +
 python/flashinfer/jit/core.py                |  4 +-
 python/flashinfer/jit/utils.py               | 32 ++++++++
 python/flashinfer/sampling.py                |  1 +
 tests/test_jit_warmup.py                     | 54 +++++++++++++
 9 files changed, 173 insertions(+), 60 deletions(-)
 create mode 100644 python/csrc/renorm.cu
 create mode 100644 tests/test_jit_warmup.py

diff --git a/python/aot_setup.py b/python/aot_setup.py
index e9c73159..e9efc269 100644
--- a/python/aot_setup.py
+++ b/python/aot_setup.py
@@ -427,6 +427,7 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None:
     "csrc/quantization.cu",
     "csrc/rope.cu",
     "csrc/sampling.cu",
+    "csrc/renorm.cu",
     "csrc/activation.cu",
     "csrc/batch_decode.cu",
     "csrc/batch_prefill.cu",
diff --git a/python/csrc/renorm.cu b/python/csrc/renorm.cu
new file mode 100644
index 00000000..4a17ce2e
--- /dev/null
+++ b/python/csrc/renorm.cu
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <flashinfer/sampling.cuh>
+
+#include "pytorch_extension_utils.h"
+
+using namespace flashinfer;
+
+void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
+                        std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
+                        int64_t cuda_stream) {
+  CHECK_INPUT(probs);
+  auto device = probs.device();
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  bool has_top_p_arr = maybe_top_p_arr.has_value();
+
+  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaError_t status = sampling::TopPRenormProb(
+      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
+      has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
+      top_p_val, vocab_size, stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
+}
+
+void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
+                        std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
+                        int64_t cuda_stream) {
+  CHECK_INPUT(probs);
+  auto device = probs.device();
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  bool has_top_k_arr = maybe_top_k_arr.has_value();
+
+  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaError_t status = sampling::TopKRenormProb(
+      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
+      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
+      top_k_val, vocab_size, stream);
+
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
+}
+
+void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
+                       std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
+                       int64_t cuda_stream) {
+  CHECK_INPUT(logits);
+  auto device = logits.device();
+  CHECK_DIM(2, logits);  // logits: (batch_size, vocab_size)
+  unsigned int batch_size = logits.size(0);
+  unsigned int vocab_size = logits.size(1);
+  bool has_top_k_arr = maybe_top_k_arr.has_value();
+
+  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  cudaError_t status = sampling::TopKMaskLogits(
+      static_cast<float*>(logits.data_ptr()), static_cast<float*>(mask_logits.data_ptr()),
+      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
+      top_k_val, vocab_size, stream);
+
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKMaskLogits failed with error code " + std::string(cudaGetErrorString(status)));
+}
diff --git a/python/csrc/sampling.cu b/python/csrc/sampling.cu
index 0f0021a1..484ff58d 100644
--- a/python/csrc/sampling.cu
+++ b/python/csrc/sampling.cu
@@ -143,65 +143,6 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_sample
                          std::string(cudaGetErrorString(status)));
 }
 
-void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
-                        std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
-                        int64_t cuda_stream) {
-  CHECK_INPUT(probs);
-  auto device = probs.device();
-  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
-  unsigned int batch_size = probs.size(0);
-  unsigned int vocab_size = probs.size(1);
-  bool has_top_p_arr = maybe_top_p_arr.has_value();
-
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-  cudaError_t status = sampling::TopPRenormProb(
-      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
-      has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
-      top_p_val, vocab_size, stream);
-  TORCH_CHECK(status == cudaSuccess,
-              "TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
-}
-
-void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
-                        std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
-                        int64_t cuda_stream) {
-  CHECK_INPUT(probs);
-  auto device = probs.device();
-  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
-  unsigned int batch_size = probs.size(0);
-  unsigned int vocab_size = probs.size(1);
-  bool has_top_k_arr = maybe_top_k_arr.has_value();
-
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-  cudaError_t status = sampling::TopKRenormProb(
-      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
-      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
-      top_k_val, vocab_size, stream);
-
-  TORCH_CHECK(status == cudaSuccess,
-              "TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
-}
-
-void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
-                       std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
-                       int64_t cuda_stream) {
-  CHECK_INPUT(logits);
-  auto device = logits.device();
-  CHECK_DIM(2, logits);  // logits: (batch_size, vocab_size)
-  unsigned int batch_size = logits.size(0);
-  unsigned int vocab_size = logits.size(1);
-  bool has_top_k_arr = maybe_top_k_arr.has_value();
-
-  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-  cudaError_t status = sampling::TopKMaskLogits(
-      static_cast<float*>(logits.data_ptr()), static_cast<float*>(mask_logits.data_ptr()),
-      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
-      top_k_val, vocab_size, stream);
-
-  TORCH_CHECK(status == cudaSuccess,
-              "TopKMaskLogits failed with error code " + std::string(cudaGetErrorString(status)));
-}
-
 void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
                                 at::Tensor uniform_samples, at::Tensor target_probs,
                                 at::Tensor output_token_ids, at::Tensor output_accepted_token_num,
diff --git a/python/flashinfer/jit/__init__.py b/python/flashinfer/jit/__init__.py
index ee95117c..e184e73f 100644
--- a/python/flashinfer/jit/__init__.py
+++ b/python/flashinfer/jit/__init__.py
@@ -35,6 +35,7 @@
 from .attention import get_single_prefill_uri as get_single_prefill_uri
 from .core import clear_cache_dir, load_cuda_ops
 from .env import *
+from .utils import parallel_load_modules as parallel_load_modules
 
 try:
     from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri  # type: ignore[import]
diff --git a/python/flashinfer/jit/batch_prefill_templ.py b/python/flashinfer/jit/batch_prefill_templ.py
index 8e00ebf3..b6be762c 100644
--- a/python/flashinfer/jit/batch_prefill_templ.py
+++ b/python/flashinfer/jit/batch_prefill_templ.py
@@ -14,6 +14,8 @@ limitations under the License.
 """
 
+import itertools
+
 batch_prefill_suffix = [
     "_plan.cu",
     "_ragged_run.cu",
diff --git a/python/flashinfer/jit/core.py b/python/flashinfer/jit/core.py
index 93f946ed..8b9490ca 100644
--- a/python/flashinfer/jit/core.py
+++ b/python/flashinfer/jit/core.py
@@ -108,7 +108,7 @@ def load_cuda_ops(
     ] + CUTLASS_INCLUDE_DIRS
     lock = FileLock(FLASHINFER_JIT_DIR / f"{name}.lock", thread_local=False)
     with lock:
-        return torch_cpp_ext.load(
+        module = torch_cpp_ext.load(
             name,
             list(map(lambda _: str(_), sources)),
             extra_cflags=cflags,
@@ -119,3 +119,5 @@ def load_cuda_ops(
             verbose=verbose,
             with_cuda=True,
         )
+        logger.info(f"Finished loading JIT ops: {name}")
+        return module
diff --git a/python/flashinfer/jit/utils.py b/python/flashinfer/jit/utils.py
index b35d216e..63a87c5f 100644
--- a/python/flashinfer/jit/utils.py
+++ b/python/flashinfer/jit/utils.py
@@ -15,9 +15,13 @@
 """
 
 import pathlib
+import threading
+from typing import Callable, List
 
 import torch
 
+from .core import logger
+
 
 def write_if_different(path: pathlib.Path, content: str) -> None:
     if path.exists():
@@ -30,6 +34,34 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
         f.write(content)
 
 
+def parallel_load_modules(
+    load_module_funcs: List[Callable],
+):
+    threads = []
+    exceptions = []
+
+    def wrapper(func):
+        try:
+            func()
+        except Exception as e:
+            exceptions.append((func, e))
+
+    for func in load_module_funcs:
+        thread = threading.Thread(target=wrapper, args=(func,))
+        thread.start()
+        threads.append(thread)
+
+    for thread in threads:
+        thread.join()
+
+    if exceptions:
+        for func, e in exceptions:
+            print(f"Exception occurred in {func.__name__}: {e}")
+        raise RuntimeError("One or more exceptions occurred during module loading")
+
+    logger.info("Finished loading modules")
+
+
 dtype_map = {
     torch.float16: "half",
     torch.bfloat16: "nv_bfloat16",
diff --git a/python/flashinfer/sampling.py b/python/flashinfer/sampling.py
index 896721c6..a91b184c 100644
--- a/python/flashinfer/sampling.py
+++ b/python/flashinfer/sampling.py
@@ -37,6 +37,7 @@ def get_sampling_module():
         "sampling",
         [
             FLASHINFER_CSRC_DIR / "sampling.cu",
+            FLASHINFER_CSRC_DIR / "renorm.cu",
             FLASHINFER_CSRC_DIR / "flashinfer_sampling_ops.cu",
         ],
     )
diff --git a/tests/test_jit_warmup.py b/tests/test_jit_warmup.py
new file mode 100644
index 00000000..a664c026
--- /dev/null
+++ b/tests/test_jit_warmup.py
@@ -0,0 +1,54 @@
+"""
+Copyright (c) 2024 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import torch
+from flashinfer.jit import parallel_load_modules
+from flashinfer.utils import PosEncodingMode
+
+import flashinfer
+
+
+def test_warmup_llama():
+    parallel_load_modules(
+        [
+            lambda: flashinfer.activation.get_act_and_mul_module("silu"),
+            flashinfer.norm.get_norm_module,
+            flashinfer.sampling.get_sampling_module,
+            flashinfer.quantization.get_quantization_module,
+            flashinfer.page.get_page_module,
+            lambda: flashinfer.decode.get_batch_decode_module(
+                torch.float16,
+                torch.float16,
+                torch.float16,
+                torch.int32,
+                128,
+                PosEncodingMode.NONE.value,
+                False,  # use_sliding_window
+                False,  # use_logits_soft_cap
+            ),
+            lambda: flashinfer.prefill.gen_batch_prefill_module(
+                torch.float16,
+                torch.float16,
+                torch.float16,
+                torch.int32,
+                128,
+                PosEncodingMode.NONE.value,
+                False,  # use_sliding_window
+                False,  # use_logits_soft_cap
+                False,  # allow_fp16_qk_reduction
+            ),
+        ]
+    )
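
Usage note (illustrative sketch, not part of the patch): `parallel_load_modules` takes a list of zero-argument callables, runs each JIT load on its own Python thread, joins them all, and raises a `RuntimeError` if any loader failed. A minimal warm-up sketch, using only module getters that already appear in the test above:

    import flashinfer
    from flashinfer.jit import parallel_load_modules

    # Compile/load several extensions concurrently instead of one after another;
    # later loads of the same modules reuse the cached builds.
    parallel_load_modules(
        [
            flashinfer.norm.get_norm_module,
            flashinfer.page.get_page_module,
            flashinfer.sampling.get_sampling_module,  # compiles sampling.cu and renorm.cu
            lambda: flashinfer.activation.get_act_and_mul_module("silu"),
        ]
    )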