## Latest News
+* [2024/09/29] 🌟 AI at Meta PyTorch + TensorRT v2.4 🌟 ⚡TensorRT 10.1 ⚡PyTorch 2.4 ⚡CUDA 12.4 ⚡Python 3.12
+[➡️ link](https://github.com/pytorch/TensorRT/releases/tag/v2.4.0)
+
+
+
+
+* [2024/09/17] ✨ NVIDIA TensorRT-LLM Meetup
+[➡️ link](https://drive.google.com/file/d/1RR8GqC-QbuaKuHj82rZcXb3MS20SWo6F/view?usp=share_link)
+
+* [2024/09/17] ✨ Accelerating LLM Inference at Databricks with TensorRT-LLM
+[➡️ link](https://drive.google.com/file/d/1NeSmrLaWRJAY1rxD9lJmzpB9rzr38j8j/view?usp=sharing)
+
+* [2024/09/17] ✨ TensorRT-LLM @ Baseten
+[➡️ link](https://drive.google.com/file/d/1Y7L2jqW-aRmt31mCdqhwvGMmCSOzBUjG/view?usp=share_link)
+
+* [2024/09/04] 🏎️🏎️🏎️ Best Practices for Tuning TensorRT-LLM for Optimal Serving with BentoML
+[➡️ link](https://www.bentoml.com/blog/tuning-tensor-rt-llm-for-optimal-serving-with-bentoml)
+
* [2024/08/20] 🏎️SDXL with #TensorRT Model Optimizer ⏱️⚡ 🏁 cache diffusion 🏁 quantization aware training 🏁 QLoRA 🏁 #Python 3.12
[➡️ link](https://developer.nvidia.com/blog/nvidia-tensorrt-model-optimizer-v0-15-boosts-inference-performance-and-expands-model-support/)
@@ -43,6 +61,9 @@ TensorRT-LLM
* [2024/07/02] Let the @MistralAI MoE tokens fly 📈 🚀 #Mixtral 8x7B with NVIDIA #TensorRT #LLM on #H100.
[➡️ Tech blog](https://developer.nvidia.com/blog/achieving-high-mixtral-8x7b-performance-with-nvidia-h100-tensor-core-gpus-and-tensorrt-llm?ncid=so-twit-928467)
+
+Previous News
+
* [2024/06/24] Enhanced with NVIDIA #TensorRT #LLM, @upstage.ai’s solar-10.7B-instruct is ready to power your developer projects through our API catalog 🏎️. ✨[➡️ link](https://build.nvidia.com/upstage/solar-10_7b-instruct?snippet_tab=Try )
* [2024/06/18] ICYMI: 🤩 Stable Diffusion 3 dropped last week 🎊 🏎️ Speed up your SD3 with #TensorRT INT8 Quantization [➡️ link](https://build.nvidia.com/upstage/solar-10_7b-instruct?snippet_tab=Try )
@@ -55,10 +76,6 @@ Technical Deep Dive for serious coders ✅+99% compression ✅1 set of weights
* [2024/06/04] ✨ #TensorRT and GeForce #RTX unlock ComfyUI SD superhero powers 🦸⚡ 🎥 Demo: [➡️ link](https://youtu.be/64QEVfbPHyg)
📗 DIY notebook: [➡️ link](https://console.brev.dev/launchable/deploy?userID=2x2sil999&orgID=ktj33l4xj&name=ComfyUI_TensorRT&instance=L4%40g2-standard-4%3Anvidia-l4%3A1&diskStorage=500&cloudID=GCP&baseImage=docker.io%2Fpytorch%2Fpytorch%3A2.2.0-cuda12.1-cudnn8-runtime&ports=ComfUI%3A8188&file=https%3A%2F%2Fgithub.com%2Fbrevdev%2Fnotebooks%2Fblob%2Fmain%2Ftensorrt-comfyui.ipynb&launchableID=env-2hQX3n7ae5mq3NjNZ32DfAG0tJf)
-
-Previous News
-
-
* [2024/05/28] ✨#TensorRT weight stripping for ResNet-50 ✨ ✅+99% compression
✅1 set of weights → ** GPUs\ ✅0 performance loss ✅** models…LLM, CNN, etc
👀 📚 DIY [➡️ link](https://console.brev.dev/launchable/deploy?userID=2x2sil999&orgID=ktj33l4xj&launchableID=env-2h6bym7h5GFNho3vpWQQeUYMwTM&instance=L4%40g6.xlarge&diskStorage=500&cloudID=devplane-brev-1&baseImage=nvcr.io%2Fnvidia%2Ftensorrt%3A24.05-py3&file=https%3A%2F%2Fgithub.com%2FNVIDIA%2FTensorRT%2Fblob%2Frelease%2F10.0%2Fsamples%2Fpython%2Fsample_weight_stripping%2Fnotebooks%2Fweight_stripping.ipynb&name=tensorrt_weight_stripping_resnet50)
@@ -68,10 +85,8 @@ Serverless TensorRT-LLM (LLaMA 3 8B) | Modal Docs [➡️ link](https://modal.co
* [2024/05/08] NVIDIA TensorRT Model Optimizer -- the newest member of the #TensorRT ecosystem is a library of post-training and training-in-the-loop model optimization techniques ✅quantization ✅sparsity ✅QAT [➡️ blog](https://developer.nvidia.com/blog/accelerate-generative-ai-inference-performance-with-nvidia-tensorrt-model-optimizer-now-publicly-available/)
-
* [2024/05/07] 🦙🦙🦙 24,000 tokens per second 🛫Meta Llama 3 takes off with #TensorRT #LLM 📚[➡️ link](https://blogs.nvidia.com/blog/meta-llama3-inference-acceleration/)
-
* [2024/02/06] [🚀 Speed up inference with SOTA quantization techniques in TRT-LLM](./docs/source/blogs/quantization-in-TRT-LLM.md)
* [2024/01/30] [ New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget](./docs/source/blogs/XQA-kernel.md)
* [2023/12/04] [Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100](./docs/source/blogs/Falcon180B-H200.md)
@@ -88,7 +103,7 @@ Serverless TensorRT-LLM (LLaMA 3 8B) | Modal Docs [➡️ link](https://modal.co
## TensorRT-LLM Overview
TensorRT-LLM is a library for optimizing Large Language Model (LLM) inference.
-It provides state-of-the-art optimziations, including custom attention kernels, inflight batching, paged KV caching, quantization (FP8, INT4 [AWQ](https://arxiv.org/abs/2306.00978), INT8 [SmoothQuant](https://arxiv.org/abs/2211.10438), ++) and much more, to perform inference efficiently on NVIDIA GPUs
+It provides state-of-the-art optimizations, including custom attention kernels, inflight batching, paged KV caching, quantization (FP8, INT4 [AWQ](https://arxiv.org/abs/2306.00978), INT8 [SmoothQuant](https://arxiv.org/abs/2211.10438), ++) and much more, to perform inference efficiently on NVIDIA GPUs
TensorRT-LLM provides a Python API to build LLMs into optimized
[TensorRT](https://developer.nvidia.com/tensorrt) engines.
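
The resulting engines are consumed at runtime through the C++ `executor` API that the benchmark changes below exercise. A minimal sketch of that flow, assuming a prebuilt decoder-only engine directory (the path and the token ids are placeholders, not part of this diff):

```cpp
// Minimal sketch of running a prebuilt TensorRT-LLM engine through the C++ executor API.
// The engine directory and the prompt token ids are placeholders.
#include "tensorrt_llm/executor/executor.h"

#include <filesystem>
#include <iostream>

namespace texec = tensorrt_llm::executor;

int main()
{
    texec::ExecutorConfig executorConfig;
    texec::Executor executor(
        std::filesystem::path{"/path/to/engine_dir"}, texec::ModelType::kDECODER_ONLY, executorConfig);

    // One request: prompt token ids plus the maximum number of tokens to generate.
    texec::Request request(texec::VecTokens{1, 2, 3, 4}, /*maxTokens=*/32);
    auto const requestId = executor.enqueueRequest(request);

    for (auto const& response : executor.awaitResponses(requestId))
    {
        if (!response.hasError())
        {
            auto const& beams = response.getResult().outputTokenIds;
            std::cout << "generated " << beams.front().size() << " tokens" << std::endl;
        }
    }
    return 0;
}
```
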
diff --git a/benchmarks/README.md b/benchmarks/README.md
index 00f450319..b368a6621 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -7,5 +7,6 @@ There are currently three workflows to benchmark TensorRT-LLM:
- The recommended workflow that uses TensorRT-LLM C++ API and can take advantage of the latest features of TensorRT-LLM.
* [Python benchmarks](./python)
- The Python benchmarking scripts can only benchmark the Python runtime, which does not support the latest features, such as in-flight batching.
-* [The Python benchmarking suite](./Suite.md)
- - This benchmarking suite is a current work in progress and is prone to large changes.
+* [The Python benchmarking suite](../docs/source/performance/perf-benchmarking.md)
+ - This benchmarking suite is native to TensorRT-LLM and is written in Python to reproduce and test TensorRT-LLM performance.
+ - _NOTE_: This benchmarking suite is a current work in progress and is prone to large changes.
diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp
index 45632350c..8e5d94a12 100644
--- a/benchmarks/cpp/gptManagerBenchmark.cpp
+++ b/benchmarks/cpp/gptManagerBenchmark.cpp
@@ -145,6 +145,7 @@ struct BenchmarkParams
{
std::optional<SizeType32> maxTokensInPagedKvCache{std::nullopt};
std::optional<float> freeGpuMemoryFraction{std::nullopt};
+ std::optional<float> crossKvCacheFraction{std::nullopt};
bool enableTrtOverlap{false};
bool enableBlockReuse{false};
bool enableChunkedContext{false};
@@ -159,6 +160,8 @@ struct BenchmarkParams
std::optional<SizeType32> sinkTokenLength{std::nullopt};
bool multiBlockMode{true};
bool enableContextFMHAFP32Acc{false};
+ bool cudaGraphMode{false};
+ SizeType32 cudaGraphCacheSize{0};
// lora / peft params
std::optional<std::string> loraDir{std::nullopt};
@@ -470,7 +473,38 @@ class Recorder
mRequestBenchInfos[requestId].firstTokenSeen = true;
}
- mRequestBenchInfos[requestId].outputLength += 1;
+ mRequestBenchInfos[requestId].decodingIter += 1;
+ }
+
+ void recordToken(uint64_t requestId, std::list<NamedTensor> const& responseTensors)
+ {
+ int32_t outputLength = 1;
+ for (auto& tensor : responseTensors)
+ {
+ if (tensor.name == inference_request::kSequenceLengthTensorName)
+ {
+ // Tensor of shape nBeams, and we only need the first one
+ outputLength = *(bufferCast<int32_t>(*(tensor.tensor)));
+ break;
+ }
+ }
+
+ mRequestBenchInfos[requestId].outputLength += outputLength;
+ this->recordToken(requestId);
+ }
+
+ void recordToken(uint64_t requestId, texec::Response const& response)
+ {
+ auto outputTokenIds = response.getResult().outputTokenIds;
+
+ int32_t outputLength = 1;
+ for (auto const& beam : outputTokenIds)
+ {
+ outputLength = std::max(static_cast<int32_t>(beam.size()), outputLength);
+ }
+
+ mRequestBenchInfos[requestId].outputLength += outputLength;
+ this->recordToken(requestId);
}
void recordEnd(uint64_t requestId, std::list<NamedTensor> const& responseTensors, bool hasError)
@@ -500,7 +534,7 @@ class Recorder
}
else
{
- this->recordToken(requestId);
+ this->recordToken(requestId, responseTensors);
}
}
@@ -532,7 +566,7 @@ class Recorder
}
else
{
- this->recordToken(requestId);
+ this->recordToken(requestId, response);
}
}
}
@@ -818,11 +852,13 @@ class ExecutorServer
texec::SchedulerConfig schedulerConfig(capacitySchedulerPolicy);
texec::KvCacheConfig kvCacheConfig(benchmarkParams.enableBlockReuse, benchmarkParams.maxTokensInPagedKvCache,
benchmarkParams.maxAttentionWindowVec, benchmarkParams.sinkTokenLength,
- benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks);
+ benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks,
+ benchmarkParams.crossKvCacheFraction);
texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8,
std::nullopt, benchmarkParams.loraHostCacheSize);
- texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(
- benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
+ texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
+ benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode,
+ benchmarkParams.cudaGraphCacheSize);
texec::ExecutorConfig executorConfig(
maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true);
executorConfig.setGpuWeightsPercent(benchmarkParams.gpuWeightsPercent);
@@ -940,7 +976,7 @@ class ExecutorServer
{
if (!warmup && !response.hasError())
{
- mRecorder->recordToken(reqId);
+ mRecorder->recordToken(reqId, response);
}
}
}
@@ -1228,7 +1264,7 @@ class GptServer
{
if (errMsg.empty())
{
- mRecorder->recordToken(requestId);
+ mRecorder->recordToken(requestId, response_tensors);
}
}
}
@@ -1430,6 +1466,10 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
{
optionalParams.kvCacheConfig.freeGpuMemoryFraction = benchmarkParams.freeGpuMemoryFraction;
}
+ if (benchmarkParams.crossKvCacheFraction)
+ {
+ optionalParams.kvCacheConfig.crossKvCacheFraction = benchmarkParams.crossKvCacheFraction;
+ }
if (benchmarkParams.maxAttentionWindowVec)
{
optionalParams.kvCacheConfig.maxAttentionWindowVec = benchmarkParams.maxAttentionWindowVec;
@@ -1458,8 +1498,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
: texec::DecodingMode::Auto(),
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices);
- optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(
- benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
+ optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
+ benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode, benchmarkParams.cudaGraphCacheSize);
auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
auto const worldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
@@ -1874,6 +1914,8 @@ int main(int argc, char* argv[])
"random_seed", "integer random seed for exponential time delays.", cxxopts::value()->default_value("420"));
options.add_options()(
"kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value());
+ options.add_options()(
+ "cross_kv_cache_fraction", "Cross K-V Cache Fraction (from 0.0 to 1.0).", cxxopts::value());
options.add_options()("request_rate",
"request rate in reqs/sec. Skipping this arg or negative value will trigger offline/0-delay.",
cxxopts::value<float>());
@@ -1895,7 +1937,8 @@ int main(int argc, char* argv[])
options.add_options()("return_generation_logits", "Whether to return generation logits.",
cxxopts::value()->default_value("false"));
- options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
+ options.add_options()("scheduler_policy",
+ "Choose scheduler policy between max_utilization/guaranteed_no_evict/static_batch.",
cxxopts::value()->default_value("guaranteed_no_evict"));
options.add_options()("first_batch_delay",
@@ -1946,6 +1989,12 @@ int main(int argc, char* argv[])
cxxopts::value()->default_value("true"));
options.add_options()(
"encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value());
+ options.add_options()("cuda_graph_mode", "When enabled, inference is executed with cuda graph.",
+ cxxopts::value()->default_value("false"));
+ options.add_options()("cuda_graph_cache_size",
+ "Specify how many cuda graphs are cached in the runtime. Larger cache gives better perf, but consumes more GPU "
+ "memory.",
+ cxxopts::value()->default_value("0"));
options.add_options()("enable_context_fmha_fp32_acc", "Enable FMHA runner FP32 accumulation",
cxxopts::value()->default_value("false"));
@@ -2040,6 +2089,20 @@ int main(int argc, char* argv[])
{
benchmarkParams.freeGpuMemoryFraction = result["kv_cache_free_gpu_mem_fraction"].as<float>();
}
+ // Argument: K-V Cache Cross Attention Fraction. Only applicable to enc-dec models.
+ if (result.count("encoder_engine_dir") && result.count("decoder_engine_dir"))
+ {
+ if (result.count("cross_kv_cache_fraction"))
+ {
+ benchmarkParams.crossKvCacheFraction = result["cross_kv_cache_fraction"].as<float>();
+ }
+ else
+ {
+ benchmarkParams.crossKvCacheFraction
+ = 0.5f; // default value if not set; non enc-dec models should not have this param set at all
+ }
+ }
+
// Argument: Enable TRT overlap
benchmarkParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
@@ -2131,6 +2194,12 @@ int main(int argc, char* argv[])
// Argument: enable_context_fmha_fp32_acc
benchmarkParams.enableContextFMHAFP32Acc = result["enable_context_fmha_fp32_acc"].as<bool>();
+ // Argument: cuda_graph_mode
+ benchmarkParams.cudaGraphMode = result["cuda_graph_mode"].as<bool>();
+
+ // Argument: cuda_graph_cache_size
+ benchmarkParams.cudaGraphCacheSize = result["cuda_graph_cache_size"].as<SizeType32>();
+
std::optional padId;
// Argument: Padding token id
if (result.count("pad_id"))
@@ -2168,6 +2237,10 @@ int main(int argc, char* argv[])
{
capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT;
}
+ else if (capacitySchedulerPolicyArg == "static_batch")
+ {
+ capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kSTATIC_BATCH;
+ }
else
{
TLLM_LOG_ERROR("Unexpected scheduler policy: " + capacitySchedulerPolicyArg);
@@ -2246,14 +2319,14 @@ int main(int argc, char* argv[])
{
texec::ModelType executorModelType;
std::optional<std::filesystem::path> decoderEngineDir = std::nullopt, encoderEngineDir = std::nullopt;
- if (result.count("encoder_engine_dir") && result.count("engine_dir"))
+ if (result.count("encoder_engine_dir") && result.count("decoder_engine_dir"))
{
TLLM_CHECK_WITH_INFO(api == "executor", "encoder-decoder only support executor api.");
TLLM_CHECK_WITH_INFO(
modelType == TrtGptModelType::InflightFusedBatching, "encoder-decoder only support inflight batching.");
executorModelType = texec::ModelType::kENCODER_DECODER;
- decoderEngineDir = result["engine_dir"].as<std::string>();
encoderEngineDir = result["encoder_engine_dir"].as<std::string>();
+ decoderEngineDir = result["decoder_engine_dir"].as<std::string>();
}
else if (result.count("engine_dir"))
{
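
Taken together, the benchmark changes above thread three new knobs into the executor configuration: `cross_kv_cache_fraction` for encoder-decoder KV-cache partitioning, plus `cuda_graph_mode` and `cuda_graph_cache_size` for CUDA-graph execution. A sketch of how those values land in the executor objects follows; the values are illustrative, and the `ExecutorConfig`/`KvCacheConfig` setter names are assumed here, while the `ExtendedRuntimePerfKnobConfig` constructor mirrors the calls in the diff.

```cpp
// Sketch of wiring the new benchmark flags into the executor configuration.
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

texec::ExecutorConfig makeConfigSketch()
{
    bool const multiBlockMode = true;
    bool const enableContextFMHAFP32Acc = false;
    bool const cudaGraphMode = true;                 // new: run generation steps through CUDA graphs
    texec::SizeType32 const cudaGraphCacheSize = 4;  // new: number of CUDA graphs kept cached in the runtime

    // The two CUDA-graph knobs are appended to the existing ExtendedRuntimePerfKnobConfig arguments.
    texec::ExtendedRuntimePerfKnobConfig perfKnobConfig(
        multiBlockMode, enableContextFMHAFP32Acc, cudaGraphMode, cudaGraphCacheSize);

    // For encoder-decoder engines, a fraction of the KV-cache budget is reserved for cross attention.
    texec::KvCacheConfig kvCacheConfig;
    kvCacheConfig.setCrossKvCacheFraction(0.5f); // assumed setter; mirrors the new constructor argument

    texec::ExecutorConfig executorConfig;
    executorConfig.setKvCacheConfig(kvCacheConfig);
    executorConfig.setExtendedRuntimePerfKnobConfig(perfKnobConfig);
    return executorConfig;
}
```
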
diff --git a/benchmarks/cpp/utils/prepare_real_data.py b/benchmarks/cpp/utils/prepare_real_data.py
index 5f14f6747..94383cfa2 100644
--- a/benchmarks/cpp/utils/prepare_real_data.py
+++ b/benchmarks/cpp/utils/prepare_real_data.py
@@ -231,8 +231,6 @@ def dataset(root_args, **kwargs):
}, root_args.output)
else:
print_dataset(
- task_ids,
input_ids,
output_lens,
- tokenizer=None,
)
diff --git a/benchmarks/python/gpt_benchmark.py b/benchmarks/python/gpt_benchmark.py
index 04ba2ab0f..ce06c9f9f 100644
--- a/benchmarks/python/gpt_benchmark.py
+++ b/benchmarks/python/gpt_benchmark.py
@@ -80,7 +80,7 @@ def __init__(self, args, batch_sizes, in_out_lens, gpu_weights_percents,
kv_cache_type = KVCacheType.CONTINUOUS
if hasattr(self, 'kv_cache_type'):
- kv_cache_type = self.kv_cache_type
+ kv_cache_type = KVCacheType(self.kv_cache_type)
else:
if hasattr(self, 'paged_kv_cache'):
kv_cache_type = KVCacheType.PAGED if self.paged_kv_cache == True else KVCacheType.CONTINUOUS
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 32e89ae17..125526f7e 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -316,6 +316,8 @@ endif()
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)
set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
+add_subdirectory(${3RDPARTY_DIR}/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/pybind11)
+
include_directories(
${CUDAToolkit_INCLUDE_DIRS}
${CUDNN_ROOT_DIR}/include
@@ -323,7 +325,8 @@ include_directories(
${3RDPARTY_DIR}/cutlass/include
${3RDPARTY_DIR}/cutlass/tools/util/include
${3RDPARTY_DIR}/NVTX/include
- ${3RDPARTY_DIR}/json/include)
+ ${3RDPARTY_DIR}/json/include
+ ${3RDPARTY_DIR}/pybind11/include)
# TRT dependencies
set_ifndef(TRT_LIB_DIR ${CMAKE_BINARY_DIR})
@@ -381,7 +384,7 @@ endif()
# set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
set(CMAKE_CXX_FLAGS
- "${CMAKE_CXX_FLAGS} -DBUILD_SYSTEM=cmake_oss -DENABLE_MULTI_DEVICE=${ENABLE_MULTI_DEVICE} -DENABLE_UCX=${ENABLE_UCX}"
+ "${CMAKE_CXX_FLAGS} -DBUILD_SYSTEM=cmake_oss -DENABLE_MULTI_DEVICE=${ENABLE_MULTI_DEVICE}"
)
# Fix linking issue with TRT 10, the detailed description about `--mcmodel` can
@@ -561,6 +564,7 @@ if(ENABLE_UCX)
NO_DEFAULT_PATH)
endif()
endif()
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_UCX=${ENABLE_UCX}")
file(STRINGS "${TRT_INCLUDE_DIR}/NvInferVersion.h" VERSION_STRINGS
REGEX "#define NV_TENSORRT_.*")
diff --git a/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h b/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
new file mode 100644
index 000000000..a08544e2a
--- /dev/null
+++ b/cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
@@ -0,0 +1,187 @@
+/*
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "common.h"
+#include "tensorrt_llm/batch_manager/llmRequest.h"
+#include "tensorrt_llm/common/algorithm.h"
+#include "tensorrt_llm/runtime/common.h"
+#include <variant>
+
+namespace tensorrt_llm::batch_manager
+{
+namespace kv_cache_manager
+{
+class KVCacheManager;
+}
+class BasePeftCacheManager;
+} // namespace tensorrt_llm::batch_manager
+
+namespace tensorrt_llm::batch_manager
+{
+
+using tensorrt_llm::runtime::SizeType32;
+
+/// @brief This scheduler takes into account the given request capacity and the KV cache capacity.
+/// Depending on the CapacitySchedulerPolicy it will schedule already started and new requests,
+/// or even pause previously started requests.
+class BaseCapacityScheduler
+{
+public:
+ explicit BaseCapacityScheduler(LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
+ : mNoScheduleUntilState(noScheduleUntilState)
+ , mNoScheduleAfterState(noScheduleAfterState)
+ {
+ }
+
+ [[nodiscard]] LlmRequestState constexpr getNoScheduleUntilState() const noexcept
+ {
+ return mNoScheduleUntilState;
+ }
+
+ [[nodiscard]] LlmRequestState constexpr getNoScheduleAfterState() const noexcept
+ {
+ return mNoScheduleAfterState;
+ }
+
+private:
+ /// The state until/after which the scheduler should not schedule requests
+ LlmRequestState mNoScheduleUntilState;
+ LlmRequestState mNoScheduleAfterState;
+};
+
+/// @brief Schedule up to maxNumRequests requests
+class MaxRequestsScheduler : public BaseCapacityScheduler
+{
+public:
+ explicit MaxRequestsScheduler(SizeType32 maxNumRequests,
+ std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
+ std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
+ LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
+ LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
+
+ /// @brief Takes as input a sorted list of requests and outputs the sorted lists of requests
+ /// to update for this current iteration, and a list of requests to pause
+ [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
+
+private:
+ SizeType32 mMaxNumRequests;
+ std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
+ std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
+};
+
+/// @brief Schedule requests using the MAX_UTILIZATION policy
+/// @details Try reserving resources to advance requests by one step,
+/// may pause previously started requests.
+class MaxUtilizationScheduler : public BaseCapacityScheduler
+{
+public:
+ MaxUtilizationScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
+ std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
+ std::shared_ptr<BasePeftCacheManager> peftCacheManager, bool manyMicroBatches,
+ LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
+ LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
+
+ [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
+
+private:
+ /// @return {fitsKvCache, fitsPeft}
+ std::pair<bool, bool> trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req,
+ RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
+ std::unordered_set<uint64_t>& seenTaskIds) const;
+
+ SizeType32 mMaxNumRequests;
+ std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
+ std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
+ std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
+ /// @brief Boolean that indicates if multiple micro batches might be in flight
+ bool mManyMicroBatches;
+};
+
+/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
+class GuaranteedNoEvictScheduler : public BaseCapacityScheduler
+{
+public:
+ GuaranteedNoEvictScheduler(SizeType32 maxNumRequests,
+ std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
+ std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
+ std::shared_ptr<BasePeftCacheManager> peftCacheManager,
+ LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
+ LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
+
+ [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
+
+protected:
+ [[nodiscard]] std::tuple<RequestVector, RequestVector> forwardImpl(
+ RequestList const& activeRequests, bool staticBatchScheduling) const;
+
+private:
+ SizeType32 mMaxNumRequests;
+ std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
+ std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
+ std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
+};
+
+/// @brief Schedule requests using the STATIC_BATCH policy
+class StaticBatchScheduler : public GuaranteedNoEvictScheduler
+{
+public:
+ StaticBatchScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
+ std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
+ std::shared_ptr<BasePeftCacheManager> peftCacheManager,
+ LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
+ LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
+
+ [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
+};
+
+class CapacityScheduler : public Algorithm
+{
+public:
+ constexpr static auto name{"CapacityScheduler"};
+
+ CapacityScheduler() = default;
+
+ CapacityScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
+ std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
+ std::shared_ptr<BasePeftCacheManager> peftCacheManager,
+ executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
+ LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
+ LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
+
+ static CapacityScheduler make(SizeType32 maxNumRequests,
+ std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
+ std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
+ std::shared_ptr<BasePeftCacheManager> peftCacheManager,
+ executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
+ LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
+ LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
+ {
+ return CapacityScheduler{maxNumRequests, std::move(kvCacheManager), std::move(crossKvCacheManager),
+ std::move(peftCacheManager), capacitySchedulerPolicy, manyMicroBatches, noScheduleUntilState,
+ noScheduleAfterState};
+ }
+
+ [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
+
+private:
+ std::variant<std::monostate, MaxRequestsScheduler, MaxUtilizationScheduler, GuaranteedNoEvictScheduler, StaticBatchScheduler>
+ mScheduler;
+};
+
+} // namespace tensorrt_llm::batch_manager
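
A short usage sketch for the header above: the `make` factory picks the concrete scheduler held by the internal `std::variant` from the requested policy, and one call per iteration returns the requests to run and the requests to pause. The cache-manager arguments are placeholders supplied by the caller; only the interfaces declared above are used.

```cpp
// Sketch: constructing the CapacityScheduler and running one scheduling step.
#include "tensorrt_llm/batch_manager/capacityScheduler.h"

#include <memory>
#include <tuple>
#include <utility>

using namespace tensorrt_llm::batch_manager;

std::tuple<RequestVector, RequestVector> scheduleOnce(SizeType32 maxNumRequests,
    std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
    std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
    std::shared_ptr<BasePeftCacheManager> peftCacheManager, RequestList const& activeRequests)
{
    // The policy decides which concrete scheduler the internal std::variant holds;
    // kSTATIC_BATCH is the policy added by this change.
    auto const scheduler = CapacityScheduler::make(maxNumRequests, std::move(kvCacheManager),
        std::move(crossKvCacheManager), std::move(peftCacheManager),
        tensorrt_llm::executor::CapacitySchedulerPolicy::kSTATIC_BATCH);

    // First element: requests to advance this iteration; second element: requests to pause.
    return scheduler(activeRequests);
}
```
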
diff --git a/cpp/include/tensorrt_llm/batch_manager/common.h b/cpp/include/tensorrt_llm/batch_manager/common.h
new file mode 100644
index 000000000..6e4a76bc4
--- /dev/null
+++ b/cpp/include/tensorrt_llm/batch_manager/common.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tensorrt_llm/runtime/common.h"
+#include <cstdint>
+#include <functional>
+#include <list>
+#include <memory>
+#include <unordered_set>
+#include <vector>
+
+namespace tensorrt_llm::executor
+{
+class RequestWithId;
+}
+
+namespace tensorrt_llm::batch_manager
+{
+class LlmRequest;
+
+using RequestList = std::list<std::shared_ptr<LlmRequest>>;
+using RequestIdType = std::uint64_t;
+using RequestVector = std::vector<std::shared_ptr<LlmRequest>>;
+using ReqIdsSet = std::unordered_set<RequestIdType>;
+
+class ScheduledRequests
+{
+public:
+ /// @brief context phase requests (for decoder-only models) or encoder phase requests (for encoder-decoder models
+ /// and encoder-only models)
+ RequestVector contextRequests;
+
+ /// @brief generation phase requests (for decoder-only models) or empty for others
+ RequestVector generationRequests;
+
+ ScheduledRequests() = default;
+
+ explicit ScheduledRequests(RequestVector contextRequests, RequestVector generationRequests)
+ : contextRequests{std::move(contextRequests)}
+ , generationRequests{std::move(generationRequests)}
+ {
+ }
+
+ [[nodiscard]] bool empty() const
+ {
+ return contextRequests.empty() && generationRequests.empty();
+ }
+
+ [[nodiscard]] std::size_t size() const
+ {
+ return contextRequests.size() + generationRequests.size();
+ }
+};
+
+class BatchState
+{
+public:
+ BatchState() = default;
+
+ BatchState(runtime::SizeType32 numCtxRequests, runtime::SizeType32 numGenRequests, runtime::SizeType32 numTokens,
+ runtime::SizeType32 maxKvCacheLength)
+ : mNumCtxRequests{numCtxRequests}
+ , mNumGenRequests{numGenRequests}
+ , mNumTokens{numTokens}
+ , mMaxKvCacheLength{maxKvCacheLength}
+ {
+ }
+
+ bool isAnyContext() const
+ {
+ return mNumCtxRequests > 0;
+ }
+
+ bool operator==(BatchState const& other) const
+ {
+ return mNumCtxRequests == other.mNumCtxRequests && mNumGenRequests == other.mNumGenRequests
+ && mNumTokens == other.mNumTokens && mMaxKvCacheLength == other.mMaxKvCacheLength;
+ }
+
+ size_t hash() const
+ {
+ size_t h1 = std::hash<runtime::SizeType32>{}(mNumCtxRequests);
+ size_t h2 = std::hash<runtime::SizeType32>{}(mNumGenRequests);
+ size_t h3 = std::hash<runtime::SizeType32>{}(mNumTokens);
+ size_t h4 = std::hash<runtime::SizeType32>{}(mMaxKvCacheLength);
+ return h1 ^ h2 ^ h3 ^ h4;
+ }
+
+ runtime::SizeType32 mNumCtxRequests;
+ runtime::SizeType32 mNumGenRequests;
+ runtime::SizeType32 mNumTokens;
+ runtime::SizeType32 mMaxKvCacheLength;
+};
+
+struct BatchStateHash
+{
+ size_t operator()(BatchState const& bs) const
+ {
+ return bs.hash();
+ }
+};
+
+} // namespace tensorrt_llm::batch_manager
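
`BatchState` together with `BatchStateHash` is shaped for use as an unordered-map key, for example to look up resources cached per batch shape (a CUDA-graph cache keyed this way is one plausible use, though that is an assumption here). A minimal sketch with placeholder values:

```cpp
// Sketch: BatchState + BatchStateHash as an unordered_map key.
#include "tensorrt_llm/batch_manager/common.h"

#include <string>
#include <unordered_map>

using tensorrt_llm::batch_manager::BatchState;
using tensorrt_llm::batch_manager::BatchStateHash;

int main()
{
    std::unordered_map<BatchState, std::string, BatchStateHash> cache;

    // Two batches with the same shape hash to the same bucket and compare equal via operator==.
    BatchState a{/*numCtxRequests=*/0, /*numGenRequests=*/8, /*numTokens=*/8, /*maxKvCacheLength=*/4096};
    BatchState b{0, 8, 8, 4096};

    cache[a] = "resource cached for this batch shape";
    return cache.count(b) == 1 ? 0 : 1; // b finds the entry stored under a
}
```
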
diff --git a/cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h b/cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h
new file mode 100644
index 000000000..a7326eee7
--- /dev/null
+++ b/cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tensorrt_llm/batch_manager/kvCacheManager.h"
+
+#include
+
+using namespace tensorrt_llm::batch_manager::kv_cache_manager;
+
+namespace tensorrt_llm::batch_manager::eviction_policy
+{
+
+class BaseEvictionPolicy
+{
+public:
+ virtual ~BaseEvictionPolicy() = default;
+
+ virtual void initialize(
+ std::vector<BlockPtr>& mAllBlocksById, SizeType32 numPrimaryBlocks, SizeType32 numSecondaryBlocks)
+ = 0;
+
+ // Get a free block from the primary memory pool
+ virtual BlockPtr getFreePrimaryBlock() = 0;
+ // Get a free block from the secondary memory pool
+ virtual BlockPtr getFreeSecondaryBlock() = 0;
+ // Release a block. Prioritize the block for eviction if toFront=true
+ virtual void releaseBlock(BlockPtr block, bool toFront = false) = 0;
+ // Get the amount of free blocks in the primary memory pool
+ virtual SizeType32 getNumFreePrimaryBlocks() = 0;
+ // Get the amount of free blocks in the secondary memory pool
+ virtual SizeType32 getNumFreeSecondaryBlocks() = 0;
+ // Claim a free block. Called when the cache manager allocates or reuses a new block
+ virtual void claimBlock(KVCacheBlock block) = 0;
+};
+
+class LRUEvictionPolicy : public BaseEvictionPolicy
+{
+public:
+ void initialize(
+ std::vector<BlockPtr>& mAllBlocksById, SizeType32 numPrimaryBlocks, SizeType32 numSecondaryBlocks) override;
+ BlockPtr getFreePrimaryBlock() override;
+ BlockPtr getFreeSecondaryBlock() override;
+ void releaseBlock(BlockPtr block, bool toFront = false) override;
+ SizeType32 getNumFreePrimaryBlocks() override;
+ SizeType32 getNumFreeSecondaryBlocks() override;
+
+ void claimBlock(KVCacheBlock block);
+
+private:
+ FreeBlocksQueue mFreePrimaryBlocks;
+ FreeBlocksQueue mFreeSecondaryBlocks;
+
+ std::vector<std::optional<FreeBlocksQueue::iterator>> mFreeBlockIterators;
+
+ SizeType32 mFreePrimaryBlocksSize;
+ SizeType32 mFreeSecondaryBlocksSize;
+};
+
+} // namespace tensorrt_llm::batch_manager::eviction_policy
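
The intended call protocol against the interface above, as a sketch: take a free block from the desired memory tier, claim it when it goes into use, and release it once the sequence no longer needs it, optionally marking it as the preferred eviction candidate. The helper names are placeholders; `LRUEvictionPolicy` would be the concrete policy behind the reference.

```cpp
// Sketch: how a block manager is expected to drive BaseEvictionPolicy.
#include "tensorrt_llm/batch_manager/evictionPolicy.h"

#include <utility>

using namespace tensorrt_llm::batch_manager::eviction_policy;

BlockPtr allocatePrimaryBlock(BaseEvictionPolicy& policy)
{
    // The policy decides which free primary (GPU) block is least likely to be reused.
    auto block = policy.getFreePrimaryBlock();

    // Claiming tells the policy the block is in use and no longer an eviction candidate.
    policy.claimBlock(*block);
    return block;
}

void releaseWhenDone(BaseEvictionPolicy& policy, BlockPtr block, bool evictSoon)
{
    // toFront = true puts the block at the front of the eviction order.
    policy.releaseBlock(std::move(block), /*toFront=*/evictSoon);
}
```
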
diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h
index 0aa80adfe..b7295650a 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h
@@ -41,7 +41,8 @@ class KvCacheConfig
std::optional<std::vector<SizeType32>> maxAttentionWindowVec = std::nullopt,
std::optional<SizeType32> sinkTokenLength = std::nullopt,
std::optional<float> freeGpuMemoryFraction = std::nullopt, bool enableBlockReuse = false, bool useUvm = false,
- std::optional<size_t> hostCacheSize = std::nullopt, bool onboardBlocks = true)
+ std::optional<size_t> hostCacheSize = std::nullopt, bool onboardBlocks = true,
+ std::optional<float> crossKvCacheFraction = std::nullopt)
: maxTokens{maxTokens}
, maxAttentionWindowVec{maxAttentionWindowVec}
, sinkTokenLength{sinkTokenLength}
@@ -50,6 +51,7 @@ class KvCacheConfig
, useUvm(useUvm)
, hostCacheSize(hostCacheSize)
, onboardBlocks(onboardBlocks)
+ , crossKvCacheFraction{crossKvCacheFraction}
{
}
@@ -57,7 +59,7 @@ class KvCacheConfig
: KvCacheConfig(kvCacheConfig.getMaxTokens(), kvCacheConfig.getMaxAttentionWindowVec(),
kvCacheConfig.getSinkTokenLength(), kvCacheConfig.getFreeGpuMemoryFraction(),
kvCacheConfig.getEnableBlockReuse(), false, kvCacheConfig.getHostCacheSize(),
- kvCacheConfig.getOnboardBlocks())
+ kvCacheConfig.getOnboardBlocks(), kvCacheConfig.getCrossKvCacheFraction())
{
}
@@ -66,7 +68,8 @@ class KvCacheConfig
return maxTokens == other.maxTokens && maxAttentionWindowVec == other.maxAttentionWindowVec
&& sinkTokenLength == other.sinkTokenLength && freeGpuMemoryFraction == other.freeGpuMemoryFraction
&& enableBlockReuse == other.enableBlockReuse && useUvm == other.useUvm
- && hostCacheSize == other.hostCacheSize && onboardBlocks == other.onboardBlocks;
+ && hostCacheSize == other.hostCacheSize && onboardBlocks == other.onboardBlocks
+ && crossKvCacheFraction == other.crossKvCacheFraction;
}
friend std::ostream& operator<<(std::ostream& os, KvCacheConfig const& self);
@@ -80,5 +83,7 @@ class KvCacheConfig
bool useUvm;
std::optional<size_t> hostCacheSize;
bool onboardBlocks;
+ // Cross attention will use crossKvCacheFraction of the KV cache and self attention will use the rest.
+ std::optional<float> crossKvCacheFraction;
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
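
A small sketch of how an encoder-decoder setup might fill in the new field. The values are illustrative; fields left untouched keep their defaults, and the public members are assigned directly, just as the benchmark does with `optionalParams.kvCacheConfig` earlier in this diff.

```cpp
// Sketch: splitting the KV-cache budget between self and cross attention.
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"

using tensorrt_llm::batch_manager::kv_cache_manager::KvCacheConfig;

KvCacheConfig makeEncDecKvCacheConfig()
{
    KvCacheConfig config;
    config.freeGpuMemoryFraction = 0.9f; // use up to 90% of free GPU memory for the KV cache
    config.enableBlockReuse = false;
    // 50% of that budget goes to the cross-attention cache, the rest to self attention.
    config.crossKvCacheFraction = 0.5f;
    return config;
}
```
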
diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index 38b49bd23..cc7aa9374 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -22,6 +22,7 @@
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
+#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
@@ -29,13 +30,18 @@
#include
#include
-#include
+#include
#include
#include
#include
#include
#include
+namespace tensorrt_llm::batch_manager::eviction_policy
+{
+class BaseEvictionPolicy;
+}
+
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
@@ -124,6 +130,8 @@ class KVCacheBlock
[[nodiscard]] IdType getBlockId() const;
+ [[nodiscard]] NextBlockMap getNextBlocks() const;
+
[[nodiscard]] kernels::KVCacheIndex::UnderlyingType getMemoryPoolBlockIndex() const;
[[nodiscard]] bool isPrimary() const;
@@ -144,22 +152,12 @@ class KVCacheBlock
[[nodiscard]] VecUniqueTokens const& getUniqueTokens() const;
- void setFreeBlockIterator(FreeBlocksQueue::iterator freeBlockIterator);
-
- void resetFreeBlockIterator();
-
- [[nodiscard]] std::optional const& getFreeBlockIterator() const;
-
void setPrevBlock(BlockPtr prevBlock);
void addNextBlock(BlockKey const& blockKey, BlockPtr block);
void removeNextBlock(BlockKey const& blockKey);
- static std::shared_ptr findBestGPUBlockToFree(std::shared_ptr searchStart);
-
- static std::shared_ptr findLeafBlock(std::shared_ptr searchStart);
-
[[nodiscard]] BlockPtr findMatchingBlock(BlockKey const& blockKey) const;
//! \brief Free block from previous block if present.
@@ -203,14 +201,21 @@ class GenerationRequest
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;
- using SharedPtr = std::shared_ptr;
- explicit GenerationRequest(SizeType32 seqSlotIdx, SizeType32 numTokens, SizeType32 beamWidth)
- : mSeqSlotIdx(seqSlotIdx)
+ explicit GenerationRequest(LlmRequest::RequestIdType requestId, SizeType32 numTokens, SizeType32 beamWidth,
+ SizeType32 maxBlocks, SizeType32 numPools = 1)
+ : mRequestId(requestId)
, mNumTokens(numTokens)
, mBeamWidth(beamWidth)
, mCacheBlockIds(beamWidth)
+ , mCacheBlockIndices{
+ runtime::BufferManager::cpu(runtime::ITensor::makeShape({numPools, beamWidth, 2, maxBlocks}),
+ runtime::TRTDataType<tensorrt_llm::kernels::KVCacheIndex>::value)}
{
+ auto cacheBlockIdsRange = runtime::BufferRange<tensorrt_llm::kernels::KVCacheIndex>(*mCacheBlockIndices);
+ std::fill(cacheBlockIdsRange.begin(), cacheBlockIdsRange.end(),
+ tensorrt_llm::kernels::KVCacheIndex{
+ std::numeric_limits<tensorrt_llm::kernels::KVCacheIndex::UnderlyingType>::max()});
}
void addNewTokens(SizeType32 n)
@@ -225,9 +230,9 @@ class GenerationRequest
mNumTokens -= n;
}
- [[nodiscard]] SizeType32 getSequenceSlotIdx() const
+ [[nodiscard]] LlmRequest::RequestIdType getRequestId() const
{
- return mSeqSlotIdx;
+ return mRequestId;
}
[[nodiscard]] SizeType32 getNumTokens() const
@@ -245,6 +250,16 @@ class GenerationRequest
return mCacheBlockIds;
}
+ [[nodiscard]] runtime::ITensor& getCacheBlockIndices()
+ {
+ return *mCacheBlockIndices;
+ }
+
+ [[nodiscard]] runtime::ITensor const& getCacheBlockIndices() const
+ {
+ return *mCacheBlockIndices;
+ }
+
void addCacheBlock(SizeType32 beamIdx, KVCacheBlock::IdType blockId)
{
mCacheBlockIds.at(beamIdx).push_back(blockId);
@@ -272,37 +287,64 @@ class GenerationRequest
}
private:
- // Slot id of the sequence
- SizeType32 mSeqSlotIdx;
+ // Request id of the sequence
+ LlmRequest::RequestIdType mRequestId;
// Current number of generated tokens
SizeType32 mNumTokens;
// Number of beams
SizeType32 mBeamWidth;
- // List of blocks allocated for each beam of the sequence
+ // List of block ids allocated for each beam of the sequence
std::vector<std::vector<KVCacheBlock::IdType>> mCacheBlockIds;
+ // Tensor of block indices allocated for each beam of the sequence
+ runtime::ITensor::SharedPtr mCacheBlockIndices;
};
-// BlockManager manages overall metadata of KVCacheBlocks in a layer of the
-// network. Layers are expected to be symmetric, so the metadata can be
-// reused for all layers of the network.
-// The array of cache blocks for a layer is called a pool.
-// Each pool has shape [max_blocks, 2, num_heads, tokens_per_block, head_size].
-// Size per block and number of blocks per pool are pre-determined and set in
-// constructor. These should not be changed after.
-// Block shape is [2, num_heads, tokens_per_block, head_size].
+// attach metadata to a pool pointer
+class KVCacheBlockPool
+{
+public:
+ SizeType32 numKvHeads;
+ SizeType32 numLayers;
+ SizeType32 blockSize;
+
+ // Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
+ runtime::ITensor::SharedPtr primaryPtr;
+ runtime::ITensor::SharedPtr secondaryPtr;
+
+ KVCacheBlockPool(SizeType32 numKvHeads, SizeType32 numLayers, SizeType32 blockSize,
+ runtime::ITensor::SharedPtr primaryPtr = nullptr, runtime::ITensor::SharedPtr secondaryPtr = nullptr)
+ : numKvHeads(numKvHeads)
+ , numLayers(numLayers)
+ , blockSize(blockSize)
+ , primaryPtr(std::move(primaryPtr))
+ , secondaryPtr(std::move(secondaryPtr))
+ {
+ }
+};
+
+// The BlockManager manages the metadata of KVCacheBlocks.
+// It manages multiple arrays of cache blocks called pools.
+// Layers with the same number of kv heads are grouped under the same pool.
+// Each pool has shape [max_blocks, num_layers, 2, num_kv_heads, tokens_per_block, head_size], where num_layers refers
+// to the number of layers with the same num_kv_heads that share that pool.
+// The metadata of KVCacheBlocks is shared between layers, so each block spans all of the managed pools - an allocated
+// block matches some chunk of memory in each pool. The shape of the chunk in every pool is [2, num_kv_heads,
+// tokens_per_block, head_size]. The size per block and number of blocks are pre-determined and set in the constructor.
// BlockManager maintains a list of free blocks at any time.
// Alloc pops off the block at the front, and Free pushes it back to the vector.
-// BlockManager maintains a vector of lists of seqSlotIdx to allocated blocks
+// BlockManager maintains a vector of lists of request ids to allocated blocks
// per sequence. This can be used to Free all blocks belonging to a sequence.
class BlockManager
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
+ using BaseEvictionPolicy = tensorrt_llm::batch_manager::eviction_policy::BaseEvictionPolicy;
- explicit BlockManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead,
+ explicit BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
- std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType = CacheType::kSELF);
+ SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks,
+ CacheType cacheType = CacheType::kSELF);
~BlockManager();
@@ -317,10 +359,6 @@ class BlockManager
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
void addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx);
- //! \brief Release block, which puts it back onto free blocks queue.
- //! \details Block appended by default, will be put at front if toFront is true.
- void releaseBlock(std::shared_ptr block, bool toFront = false);
-
//! \brief Allocate new block for each beam of the sequence.
//! \details Might free cached blocks if no free blocks are available.
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams = false);
@@ -336,10 +374,7 @@ class BlockManager
//! \brief Release last block in the sequence
void releaseLastBlock(GenerationRequest& sequence);
- [[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept
- {
- return mFreePrimaryBlocks.size();
- }
+ [[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept;
[[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
{
@@ -381,21 +416,26 @@ class BlockManager
return mTokensPerBlock;
}
- //! \brief Get size of one K/V cache block in one layer.
- //! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead]
- [[nodiscard]] SizeType32 getBlockSize() const
+ //! \brief Get size of one K/V cache block in one layer for the specified pool.
+ //! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead] in the specified pool.
+ [[nodiscard]] SizeType32 getBlockSize(SizeType32 poolIdx) const
{
- return mBlockSize;
+ return mPools.at(poolIdx).blockSize;
}
- [[nodiscard]] runtime::ITensor::SharedPtr getPrimaryPool() const noexcept
+ [[nodiscard]] SizeType32 getNumPools() const noexcept
{
- return mPrimaryPool;
+ return mPools.size();
}
- [[nodiscard]] runtime::ITensor::SharedPtr getSecondaryPool() const noexcept
+ [[nodiscard]] runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 poolIdx) const
{
- return mSecondaryPool;
+ return mPools.at(poolIdx).primaryPtr;
+ }
+
+ [[nodiscard]] runtime::ITensor::SharedPtr getSecondaryPool(SizeType32 poolIdx) const
+ {
+ return mPools.at(poolIdx).secondaryPtr;
}
[[nodiscard]] SizeType32 getNumLayers() const
@@ -403,10 +443,32 @@ class BlockManager
return mNumLayers;
}
+ [[nodiscard]] SizeType32 getNumPrimaryBlocks() const
+ {
+ return mNumPrimaryBlocks;
+ }
+
+ [[nodiscard]] SizeType32 getNumSecondaryBlocks() const
+ {
+ return mNumSecondaryBlocks;
+ }
+
+ [[nodiscard]] CacheType getCacheType() const
+ {
+ return mCacheType;
+ }
+
+ [[nodiscard]] SizeType32 getLayerPoolIdx(SizeType32 layerIdx) const
+ {
+ return mLayerToPool.at(layerIdx);
+ }
+
//! \brief Get index in pool to K or V block.
//! \param blockId the blockId as returned by getBlockId()
//! \param fieldIdx either 0 (K) or 1 (V),
- [[nodiscard]] kernels::KVCacheIndex getKOrVBlockIndex(KVCacheBlock::IdType blockId, SizeType32 fieldIdx) const;
+ //! \param poolIdx the index of the pool for which the index is calculated (each pool has different strides)
+ [[nodiscard]] kernels::KVCacheIndex getKOrVBlockIndex(
+ KVCacheBlock::IdType blockId, SizeType32 fieldIdx, SizeType32 poolIdx) const;
//! \brief Bring offloaded block from secondary to primary memory.
//! \details Does nothing if block is already in primary memory.
@@ -417,6 +479,11 @@ class BlockManager
BlockKey findNewContextBlock(
VecUniqueTokens const& uniqueTokens, std::shared_ptr<LlmRequest> const& llmRequest) const;
+ [[nodiscard]] runtime::BufferManager const& getBufferManager() const
+ {
+ return mBufferManager;
+ }
+
private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -436,22 +503,15 @@ class BlockManager
SizeType32 loadOrAllocateBlocks(
std::list<BlockKey> const& blockKeys, SizeType32 numContextBlocks, GenerationRequest& sequence);
- //! \brief Find best primary block to free.
- //! \details The best primary block to free is the primary block that appears first in the queue and have no primary
- //! block descendants
- [[nodiscard]] std::shared_ptr findBestGPUBlockToFree();
-
//! \brief Find block least likely to be reused, free it if necessary and return.
[[nodiscard]] BlockPtr getFreeBlock();
- //! \brief Claim block if it is in free blocks list.
- void claimBlock(KVCacheBlock& block);
-
//! \brief Free block from previous block and claim it from free blocks list.
void claimLeafBlock(KVCacheBlock& block);
//! \brief Compute pointer to raw KV block (K & V, all layers).
- [[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer(std::shared_ptr<KVCacheBlock> block) const;
+ [[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer(
+ std::shared_ptr<KVCacheBlock> block, SizeType32 poolIdx) const;
//! \brief Copy content of src block to dst.
void copyBlock(BlockPtr src, BlockPtr dst);
@@ -460,23 +520,24 @@ class BlockManager
// Number of blocks in pools
SizeType32 mNumPrimaryBlocks;
SizeType32 mNumSecondaryBlocks;
- // List of free blocks. Blocks are either backed by fast primary memory or slow secondary memory,
- // we maintain separate queues for these.
- FreeBlocksQueue mFreePrimaryBlocks;
- FreeBlocksQueue mFreeSecondaryBlocks;
+
// List of allocated blocks for each sequences
- std::vector> mAllocatedBlocksPerSeq;
- // Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
- runtime::ITensor::SharedPtr mPrimaryPool;
- runtime::ITensor::SharedPtr mSecondaryPool;
+ std::unordered_map<LlmRequest::RequestIdType, std::vector<BlockPtr>> mAllocatedBlocksPerSeq;
+
+ // Pool per unique numKvHeads in the model
+ std::vector mPools;
+ // Matching of model layers to their pools
+ std::vector mLayerToPool;
+
// Whether offloaded blocks should be onboarded before reuse.
bool mOnboardBlocks;
// Buffer manager
runtime::BufferManager mBufferManager;
+
+ // Size of a single KV head
+ SizeType32 mSizePerHead;
// Number of layers
SizeType32 mNumLayers;
- // Volume of [numKvHeads, tokensPerBlock, sizePerHead]
- SizeType32 mBlockSize;
// Used to keep track of number of free blocks during scheduling
SizeType32 mSchedulingNumFreeBlocks;
// Number of tokens per one block
@@ -489,6 +550,8 @@ class BlockManager
std::size_t mAllocTotalBlocks, mAllocNewBlocks, mReusedBlocks;
// KV cache type (self or cross)
CacheType mCacheType;
+ // Eviction Policy
+ std::shared_ptr<BaseEvictionPolicy> mEvictionPolicy;
private:
friend class KVCacheManager;
@@ -497,17 +560,24 @@ class BlockManager
class KVCacheManager
{
public:
+ friend class KVCacheManagerBindings;
+
using SizeType32 = tensorrt_llm::runtime::SizeType32;
- using SequencesPtr = GenerationRequest::SharedPtr;
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
- KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
+ KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock,
CudaStreamPtr stream, bool enableBlockReuse = false, bool onboardBlocks = true,
CacheType cacheType = CacheType::kSELF);
+ KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
+ SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
+ SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock,
+ CudaStreamPtr stream, bool enableBlockReuse = true, bool onboardBlocks = true,
+ CacheType cacheType = CacheType::kSELF);
+
void allocatePools(nvinfer1::DataType dtype, bool useUvm = false);
void startScheduling();
@@ -583,10 +653,10 @@ class KVCacheManager
/// @return The number of blocks
[[nodiscard]] SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req) const;
- void addContextTokens(SizeType32 seqSlotIdx, SizeType32 numTokens);
+ void addContextTokens(LlmRequest::RequestIdType requestId, SizeType32 numTokens);
- /// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed.
- void addToken(SizeType32 seqSlotIdx);
+ /// @brief Increase size for request with requestId. Allocate new KV cache block(s) if needed.
+ void addToken(LlmRequest::RequestIdType requestId);
/// @brief Add new request to the KV cache manager.
/// @param inputLength Input length for which KV cache need to be allocated.
@@ -594,34 +664,40 @@ class KVCacheManager
/// @param llmRequest Optional request to use for KV cache lookup.
/// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for
/// inputLength - 1 tokens and populate prepopulatedPromptLen.
- void addSequence(SizeType32 seqSlotIdx, SizeType32 inputLength, SizeType32 beamWidth,
+ void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
- void removeSequence(SizeType32 seqSlotIdx, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
+ void removeSequence(LlmRequest::RequestIdType requestId, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
- void schedulingRemoveSequence(SizeType32 seqSlotIdx);
+ void schedulingRemoveSequence(LlmRequest::RequestIdType requestId);
- [[nodiscard]] runtime::ITensor::UniquePtr getBlockPoolPointers() const;
+ [[nodiscard]] runtime::ITensor::SharedPtr getBlockPoolPointers() const
+ {
+ return mBlockPoolPointers;
+ }
+
+ [[nodiscard]] runtime::ITensor::SharedPtr getLayerToPoolMapping() const
+ {
+ return mLayerToPoolMapping;
+ }
void getBlockOffsetsOfBatch(
runtime::ITensor& output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, SizeType32 beamWidth) const;
//! @return maxBlockCount of all beams
SizeType32 copyBlockOffsets(
- runtime::ITensor& output, SizeType32 outputSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const;
-
- // Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
- [[nodiscard]] static SizeType32 constexpr calculatePageSize(tensorrt_llm::runtime::ModelConfig const& modelConfig)
- {
- return 2 * modelConfig.getNbKvHeads() * modelConfig.getTokensPerBlock() * modelConfig.getSizePerHead();
- }
+ runtime::ITensor& output, SizeType32 outputSlotOffset, LlmRequest::RequestIdType requestId) const;
- // numLayers * 2 * numKvHeads * sizePerHead
- [[nodiscard]] static SizeType32 constexpr calculateCacheSizePerToken(
+ // Sum of numLayers * 2 * numKvHeads * sizePerHead for each pool
+ [[nodiscard]] static SizeType32 calculateCacheSizePerToken(
tensorrt_llm::runtime::ModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig)
{
- return modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism()) * 2 * modelConfig.getNbKvHeads()
- * modelConfig.getSizePerHead();
+ // NOTE: We expect the initialization of modelConfig to have already taken the tp size into account and do not
+ // address it here
+ // consider only local layers for the calculation
+ return modelConfig.getSumLocalKvHeads(
+ worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank())
+ * 2 * modelConfig.getSizePerHead();
}
[[nodiscard]] static std::tuple const calculateMaxNumBlocks(KvCacheConfig const& config,
@@ -633,14 +709,14 @@ class KVCacheManager
return mEnableBlockReuse;
}
- void removeToken(SizeType32 seqSlotIdx);
- void rewindKVCache(SizeType32 seqSlotIdx, SizeType32 rewindLengths);
+ void removeToken(LlmRequest::RequestIdType requestId);
+ void rewindKVCache(LlmRequest::RequestIdType requestId, SizeType32 rewindLengths);
- [[nodiscard]] GenerationRequest const& getSequence(SizeType32 seqSlotIdx) const;
+ [[nodiscard]] GenerationRequest const& getSequence(LlmRequest::RequestIdType requestId) const;
[[nodiscard]] bool isCrossKv() const
{
- return mCacheType == CacheType::kCROSS;
+ return mBlockManager.getCacheType() == CacheType::kCROSS;
}
//! \brief Find first new block that must be allocated for context phase and return it's concatenated token vector.
@@ -650,7 +726,7 @@ class KVCacheManager
//! \brief Store full context blocks contributed by llmRequest.
//! \details These blocks become reusable from next step.
- void storeContextBlocks(SizeType32 seqSlotIdx, std::shared_ptr const& llmRequest);
+ void storeContextBlocks(std::shared_ptr<LlmRequest> const& llmRequest);
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@@ -658,14 +734,13 @@ class KVCacheManager
SizeType32 tokensPerBlock, SizeType32 maxBeamWidth, SizeType32 sinkTokenLen, bool useOneMoreBlock);
private:
- void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 seqSlotIdx,
- SizeType32 beamIdx, SizeType32 blockIdx, KVCacheBlock::IdType blockId) const;
+ void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
+ SizeType32 blockIdx, KVCacheBlock::IdType blockId) const;
- void resetBlockOffsets(SizeType32 seqSlotIdx, SizeType32 beamWidth);
- void cacheBlockOffsets(GenerationRequest const& seq, SizeType32 seqSlotIdx);
- void cacheNewBlockOffsets(GenerationRequest const& seq, SizeType32 seqSlotIdx);
- void updateNewBlockPointer(GenerationRequest const& seq, SizeType32 seqSlotIdx, SizeType32 blockIdx);
- void updateToken(SizeType32 seqSlotIdx, bool addToken);
+ void cacheBlockOffsets(GenerationRequest& seq);
+ void cacheNewBlockOffsets(GenerationRequest& seq);
+ void updateNewBlockPointer(GenerationRequest& seq, SizeType32 blockIdx);
+ void updateToken(GenerationRequest& sequence, bool addToken);
private:
// Maximum number of sequences
@@ -685,14 +760,13 @@ class KVCacheManager
SizeType32 mSinkBlockTokenLength;
// Block manager
BlockManager mBlockManager;
- // List of all sequences
- std::vector mSequences;
- // buffer for block indices for all managed sequences
- runtime::ITensor::SharedPtr mSequenceBlockIndices;
+ // Map of all sequences
+ std::unordered_map<LlmRequest::RequestIdType, GenerationRequest> mSequences;
// Whether to cache KV pages for reuse
bool mEnableBlockReuse;
- // KV cache type (self or cross)
- CacheType mCacheType;
+ // buffers for static tensors, will be created after allocating pools
+ runtime::ITensor::SharedPtr mBlockPoolPointers;
+ runtime::ITensor::SharedPtr mLayerToPoolMapping;
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
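
To make the pool layout described in the BlockManager comments concrete, here is a standalone sketch (explanatory only, not the library's code) of grouping layers by KV-head count into pools and of the per-layer block size of a pool:

```cpp
// Sketch: layers with the same number of KV heads share one pool (cf. mPools / mLayerToPool).
#include <cstdint>
#include <unordered_map>
#include <vector>

using SizeType32 = std::int32_t;

struct PoolLayout
{
    std::vector<SizeType32> poolNumKvHeads; // one entry per pool
    std::vector<SizeType32> layerToPool;    // pool index for every layer
};

PoolLayout groupLayersIntoPools(std::vector<SizeType32> const& numKvHeadsPerLayer)
{
    PoolLayout layout;
    std::unordered_map<SizeType32, SizeType32> headsToPool;
    for (auto numKvHeads : numKvHeadsPerLayer)
    {
        // Each distinct KV-head count gets its own pool; every layer records its pool index.
        auto [it, inserted]
            = headsToPool.try_emplace(numKvHeads, static_cast<SizeType32>(layout.poolNumKvHeads.size()));
        if (inserted)
        {
            layout.poolNumKvHeads.push_back(numKvHeads);
        }
        layout.layerToPool.push_back(it->second);
    }
    return layout;
}

// Per-layer block size of a pool: volume of [numKvHeads, tokensPerBlock, sizePerHead].
SizeType32 blockSize(SizeType32 numKvHeads, SizeType32 tokensPerBlock, SizeType32 sizePerHead)
{
    return numKvHeads * tokensPerBlock * sizePerHead;
}
```
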
diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
index 81b91e24a..69ca1963b 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
@@ -65,6 +65,11 @@ class BlockIterator
return ret;
}
+ operator runtime::ITensor::SharedPtr()
+ {
+ return mCurrent;
+ }
+
[[nodiscard]] bool operator==(BlockIterator const& other) const
{
return mIdx == other.mIdx && mPool.get() == other.mPool.get();
@@ -91,9 +96,9 @@ class BlockIterator
};
[[nodiscard]] BlockIterator getBlockBeginIt(
- KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam);
+ KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam, SizeType32 poolIdx);
[[nodiscard]] BlockIterator getBlockEndIt(
- KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam);
+ KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam, SizeType32 poolIdx);
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
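
With the pool index added to `getBlockBeginIt`/`getBlockEndIt`, callers walk a request's blocks pool by pool. A sketch using only what the excerpt above declares (the `ITensor::SharedPtr` conversion, `operator==`, and the iterator increment); `numPools` is supplied by the caller here because the lookup for it is not shown in this header:

```cpp
// Sketch: collecting the KV-cache block views of one request, per pool.
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"

#include <vector>

namespace kvcm = tensorrt_llm::batch_manager::kv_cache_manager;
using tensorrt_llm::runtime::SizeType32;

std::vector<tensorrt_llm::runtime::ITensor::SharedPtr> collectBlocks(kvcm::KVCacheManager const& cacheManager,
    tensorrt_llm::batch_manager::LlmRequest const& request, SizeType32 beam, SizeType32 numPools)
{
    std::vector<tensorrt_llm::runtime::ITensor::SharedPtr> blocks;
    for (SizeType32 poolIdx = 0; poolIdx < numPools; ++poolIdx)
    {
        auto it = kvcm::getBlockBeginIt(cacheManager, request, beam, poolIdx);
        auto const end = kvcm::getBlockEndIt(cacheManager, request, beam, poolIdx);
        while (!(it == end))
        {
            blocks.push_back(it); // uses the new ITensor::SharedPtr conversion operator
            it++;
        }
    }
    return blocks;
}
```
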
diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 0124592e8..475970b7b 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -26,6 +26,7 @@
#include "tensorrt_llm/runtime/samplingConfig.h"
#include
+#include
#include
#include
#include
@@ -39,24 +40,22 @@ namespace tensorrt_llm::batch_manager
* @brief The state of the request.
*
* Enum order must follow chronological order for state dependency check, @see hasReachedState().
- *
- * @todo(rkobus): refactor
*/
-enum LlmRequestState_t
+enum class LlmRequestState : int32_t
{
- REQUEST_STATE_UNKNOWN = 0, ///< Unknown state
- REQUEST_STATE_ENCODER_INIT = 1, ///< Encoder phase starts (for encoder-decoder models)
- REQUEST_STATE_CONTEXT_INIT = 2, ///< Context phase starts
- REQUEST_STATE_GENERATION_IN_PROGRESS = 3, ///< Generation phase is in progress
- REQUEST_STATE_GENERATION_TO_COMPLETE = 4, ///< Generation phase is to be completed
- REQUEST_STATE_GENERATION_COMPLETE = 5, ///< Generation phase completed
- REQUEST_STATE_DISAGG_GENERATION_INIT = 6, ///< For disaggregated serving only:
- /// new Generation request arrived at generation model
- REQUEST_STATE_DISAGG_CONTEXT_TRANS_IN_PROGRESS = 7, ///< For disaggregated serving only:
- /// Waiting context-only request transmitting the kv cache
- REQUEST_STATE_DISAGG_CONTEXT_COMPLETE = 8, ///< Context-only request finished kv cache transmission.
- REQUEST_STATE_DISAGG_GENERATION_TRANS_IN_PROGRESS
- = 9, ///< For disaggregated serving only: transmitting the kv cache
+ kUNKNOWN = 0, ///< Unknown state
+ kENCODER_INIT = 1, ///< Encoder phase starts (for encoder-decoder models)
+ kCONTEXT_INIT = 2, ///< Context phase starts
+ kGENERATION_IN_PROGRESS = 3, ///< Generation phase is in progress
+ kGENERATION_TO_COMPLETE = 4, ///< Generation phase is to be completed
+ kGENERATION_COMPLETE = 5, ///< Generation phase completed
+ kDISAGG_GENERATION_INIT = 6, ///< For disaggregated serving only:
+ /// new Generation request arrived at generation model
+ kDISAGG_CONTEXT_TRANS_IN_PROGRESS = 7, ///< For disaggregated serving only:
+ /// Waiting context-only request transmitting the kv cache
+ kDISAGG_CONTEXT_COMPLETE = 8, ///< Context-only request finished kv cache transmission.
+ kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< For disaggregated serving only: transmitting the kv cache
+ kWAITING_TO_SEND_LOGITS = 10, ///< Generation phase completed, logits not sent yet
};
enum LlmRequestType
@@ -114,7 +113,7 @@ class GenericLlmRequest
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
, mSamplingConfig(samplingConfig)
- , mState(REQUEST_STATE_CONTEXT_INIT)
+ , mState(LlmRequestState::kCONTEXT_INIT)
, mEndId(endId)
, mPadId(padId)
, mLogitsPostProcessor(logitsPostProcessor)
@@ -134,8 +133,7 @@ class GenericLlmRequest
, mLoraWeights(std::move(loraWeights))
, mLoraConfig(std::move(loraConfig))
, mLookaheadConfig(std::move(lookaheadConfig))
- , mContextChunkSize(std::nullopt)
- , mContextCurrentPosition(0)
+ , mContextChunkSize{mPromptLen}
, mLogProbs(samplingConfig.beamWidth)
, mCumLogProbs(samplingConfig.beamWidth)
, mDraftTokens(draftTokens.value_or(std::make_shared()))
@@ -159,7 +157,7 @@ class GenericLlmRequest
{
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
{
- mState = REQUEST_STATE_ENCODER_INIT;
+ mState = LlmRequestState::kENCODER_INIT;
}
initialize(*inputTokens, returnLogProbs);
@@ -170,7 +168,7 @@ class GenericLlmRequest
, mPromptLen(req.getInputTokenIds().size())
, mMaxNewTokens(req.getMaxTokens())
, mSamplingConfig(req.getSamplingConfig(), req.getExternalDraftTokensConfig())
- , mState(REQUEST_STATE_CONTEXT_INIT)
+ , mState(LlmRequestState::kCONTEXT_INIT)
, mEndId(req.getEndId())
, mPadId(req.getPadId())
, mClientId(req.getClientId())
@@ -188,8 +186,7 @@ class GenericLlmRequest
, mLoraWeights(std::nullopt)
, mLoraConfig(std::nullopt)
, mLookaheadConfig(std::nullopt)
- , mContextChunkSize(std::nullopt)
- , mContextCurrentPosition(0)
+ , mContextChunkSize{mPromptLen}
, mLogProbs(mSamplingConfig.beamWidth)
, mCumLogProbs(mSamplingConfig.beamWidth)
, mDraftTokens(std::make_shared())
@@ -212,7 +209,7 @@ class GenericLlmRequest
{
if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY)
{
- mState = REQUEST_STATE_DISAGG_GENERATION_INIT;
+ mState = LlmRequestState::kDISAGG_GENERATION_INIT;
}
if (mIsStreaming && mSamplingConfig.beamWidth > 1 && !mReturnAllGeneratedTokens)
{
@@ -236,7 +233,7 @@ class GenericLlmRequest
if (req.getEncoderInputTokenIds().has_value() || req.getEncoderInputFeatures().has_value())
{
- mState = REQUEST_STATE_ENCODER_INIT;
+ mState = LlmRequestState::kENCODER_INIT;
if (req.getEncoderInputTokenIds().has_value())
{
mEncoderTokens = std::make_shared(req.getEncoderInputTokenIds().value());
@@ -394,6 +391,15 @@ class GenericLlmRequest
mMaxNewTokens = maxNewTokens;
}
+ if (mNumReturnSequences > 1 && mSamplingConfig.beamWidth > 1)
+ {
+ TLLM_THROW(
+ "Using mNumReturnSequences (%d) > 1 with beam search is currently disabled, since TensorRT-LLM returns "
+ "a total of mNumReturnSequences x beamWidth beams, rather than limiting the number of returned beams "
+ "to mNumReturnSequences. This restriction will be removed once the issue is resolved.",
+ mNumReturnSequences);
+ }
+
TLLM_CHECK_WITH_INFO(mSamplingConfig.validate(), "Incorrect sampling config");
// validate extra ids when enabling kv cache reuse with prompt table
@@ -402,7 +408,8 @@ class GenericLlmRequest
TLLM_CHECK_WITH_INFO(mInputTokenExtraIds.has_value() && mInputTokenExtraIds.value(),
"Input token extra ids must be provided when enabling kv cache reuse with prompt table");
TLLM_CHECK_WITH_INFO(mInputTokenExtraIds.value()->size() == static_cast(mOrigPromptLen),
- "inputTokenExtraIds vector size must be the same as input token vector size.");
+ "inputTokenExtraIds vector size (%lu) must be the same as input token vector size (%lu).",
+ mInputTokenExtraIds.value()->size(), static_cast(mOrigPromptLen));
}
}
@@ -413,7 +420,7 @@ class GenericLlmRequest
/// @brief Get the params of the context
/// @return The params of the context
- std::optional const& getContextPhaseParams() const noexcept
+ [[nodiscard]] std::optional const& getContextPhaseParams() const noexcept
{
return mContextPhaseParams;
}
@@ -425,10 +432,10 @@ class GenericLlmRequest
/// @brief Get the state params of the context
/// @return The state params of the context
- executor::ContextPhaseState const& getContextPhaseState() const
+ [[nodiscard]] executor::DataTransceiverState const& getDataTransceiverState() const
{
TLLM_CHECK(mContextPhaseParams.has_value());
- return *static_cast(mContextPhaseParams.value().getState());
+ return *static_cast(mContextPhaseParams.value().getState());
}
/// @brief Get total number of tokens for this req (prompt + generated)
@@ -661,6 +668,11 @@ class GenericLlmRequest
return mSequenceIndex > 0;
}
+ [[nodiscard]] RequestIdType getParentRequestId() const
+ {
+ return mParentRequestId;
+ }
+
/// @brief Return a vector of the last-generated tokens of shape [num_beams]
[[nodiscard]] VecTokens const& getLastTokens()
{
@@ -715,10 +727,10 @@ class GenericLlmRequest
}
// for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase
- mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? REQUEST_STATE_ENCODER_INIT
- : REQUEST_STATE_CONTEXT_INIT;
+ mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
+ : LlmRequestState::kCONTEXT_INIT;
mContextCurrentPosition = 0;
- mContextChunkSize = std::nullopt;
+ mContextChunkSize = mPromptLen;
mSeqSlot.reset();
}
@@ -860,9 +872,9 @@ class GenericLlmRequest
return mOrigPromptLen;
}
- void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen)
+ [[nodiscard]] SizeType32 getPromptLen() const
{
- mPrepopulatedPromptLen = prepopulatedPromptLen;
+ return mPromptLen;
}
[[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
@@ -870,6 +882,37 @@ class GenericLlmRequest
return mPrepopulatedPromptLen;
}
+ void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
+ {
+ auto const promptLen = getPromptLen();
+ TLLM_CHECK(prepopulatedPromptLen < promptLen);
+ mPrepopulatedPromptLen = prepopulatedPromptLen;
+
+ if (prepopulatedPromptLen > 0)
+ {
+ // Currently, the runtime process is to apply for cache first and then determine prepopulation.
+ // Use the prepopulated length to advance the context position and decrease chunk size if necessary.
+ auto chunkSize = getContextChunkSize();
+ if (prepopulatedPromptLen + chunkSize < promptLen)
+ {
+ // make sure to end at block boundary after current chunk
+ auto const flooredEndPosition
+ = (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock;
+ chunkSize = flooredEndPosition - prepopulatedPromptLen;
+ TLLM_CHECK(chunkSize <= getContextChunkSize());
+ }
+ setContextCurrentPosition(prepopulatedPromptLen);
+ setContextChunkSize(chunkSize);
+
+ if (!isLastContextChunk())
+ {
+ TLLM_CHECK_WITH_INFO((getContextCurrentPosition() + getContextChunkSize()) % kvTokensPerBlock == 0,
+ "To prevent cache fragmentation, the context position after current chunk should be divisible "
+ "by the number of tokens per block, except for the last chunk.");
+ }
+ }
+ }
+
void setDraftTokens(std::shared_ptr const& draftTokens)
{
mDraftTokens = draftTokens;
@@ -1100,44 +1143,49 @@ class GenericLlmRequest
mGenerationLogitsFragments.clear();
}
- [[nodiscard]] bool hasReachedState(LlmRequestState_t state) const noexcept
+ [[nodiscard]] bool hasReachedState(LlmRequestState state) const noexcept
{
return mState >= state;
}
[[nodiscard]] bool isEncoderInitState() const noexcept
{
- return mState == REQUEST_STATE_ENCODER_INIT;
+ return mState == LlmRequestState::kENCODER_INIT;
}
[[nodiscard]] bool isContextInitState() const noexcept
{
- return mState == REQUEST_STATE_CONTEXT_INIT;
+ return mState == LlmRequestState::kCONTEXT_INIT;
}
[[nodiscard]] bool isGenerationInProgressState() const noexcept
{
- return mState == REQUEST_STATE_GENERATION_IN_PROGRESS || mState == REQUEST_STATE_GENERATION_TO_COMPLETE;
+ return mState == LlmRequestState::kGENERATION_IN_PROGRESS || mState == LlmRequestState::kGENERATION_TO_COMPLETE;
}
[[nodiscard]] bool isGenerationCompleteState() const noexcept
{
- return mState == REQUEST_STATE_GENERATION_COMPLETE;
+ return mState == LlmRequestState::kGENERATION_COMPLETE;
}
[[nodiscard]] bool isDisaggGenerationInitState() const noexcept
{
- return mState == REQUEST_STATE_DISAGG_GENERATION_INIT;
+ return mState == LlmRequestState::kDISAGG_GENERATION_INIT;
}
[[nodiscard]] bool isDisaggContextTransmissionState() const noexcept
{
- return mState == REQUEST_STATE_DISAGG_CONTEXT_TRANS_IN_PROGRESS;
+ return mState == LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS;
}
[[nodiscard]] bool isDisaggContextCompleteState() const noexcept
{
- return mState == REQUEST_STATE_DISAGG_CONTEXT_COMPLETE;
+ return mState == LlmRequestState::kDISAGG_CONTEXT_COMPLETE;
+ }
+
+ [[nodiscard]] bool isCompleteWaitingToSendLogits() const noexcept
+ {
+ return mState == LlmRequestState::kWAITING_TO_SEND_LOGITS;
}
/// To determine whether the context is unchunked. When a context is chunked into only a part, it
@@ -1152,6 +1200,11 @@ class GenericLlmRequest
return mLlmRequestType == LlmRequestType::LLMREQUEST_TYPE_CONTEXT_ONLY;
}
+ [[nodiscard]] bool isGenerationOnlyRequest() const noexcept
+ {
+ return mLlmRequestType == LlmRequestType::LLMREQUEST_TYPE_GENERATION_ONLY;
+ }
+
void setContextCurrentPosition(SizeType32 contextCurrentPosition)
{
mContextCurrentPosition = contextCurrentPosition;
@@ -1170,12 +1223,11 @@ class GenericLlmRequest
return mPromptLen - getContextCurrentPosition();
}
- /// To retrieve the context chunk size, throw an exception when the context is not chunked.
[[nodiscard]] SizeType32 getContextChunkSize() const
{
- TLLM_CHECK_WITH_INFO(
- isContextInitState() && mContextChunkSize, "The current request is not in context chunking state.");
- return mContextChunkSize.value();
+ TLLM_CHECK_WITH_INFO(isContextInitState() || isDisaggGenerationInitState(),
+ "getContextChunkSize is only possible during the context phase.");
+ return mContextChunkSize;
}
/// To set the context chunk size, throw an exception when the chunk size is negative. If the chunk
@@ -1183,45 +1235,34 @@ class GenericLlmRequest
/// remaining length.
void setContextChunkSize(SizeType32 size)
{
- TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
+ TLLM_CHECK_WITH_INFO(isContextInitState(), "setContextChunkSize is only possible during the context phase.");
TLLM_CHECK_WITH_INFO(size >= 0, "The chunk size of context (%d) can't be negative.", size);
mContextChunkSize = std::min(size, getContextRemainingLength());
}
/// Determines whether the current position is only one chunk away from the end of the context.
- /// It will return true when the context is not chunked.
[[nodiscard]] bool isLastContextChunk() const noexcept
{
- return isFullContextRequest()
- || (isContextInitState() && getContextCurrentPosition() + getContextChunkSize() == mPromptLen);
+ return isDisaggGenerationInitState() || getContextCurrentPosition() + getContextChunkSize() == mPromptLen;
}
- /// Returns whether the position is at the beginning of the context. It will return true when the
- /// context is not chunked.
+ /// Returns whether the position is at the beginning of the context.
[[nodiscard]] bool isFirstContextChunk() const noexcept
{
- return isFullContextRequest() || getContextCurrentPosition() == 0;
- }
-
- [[nodiscard]] executor::PriorityType priority() const noexcept
- {
- return mPriority;
+ return getContextCurrentPosition() == 0;
}
/// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
void moveToNextContextChunk()
{
TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
- if (mContextChunkSize)
- {
- mContextCurrentPosition += getContextChunkSize();
- setContextChunkSize(0);
- }
- else
- {
- TLLM_CHECK_WITH_INFO(mContextCurrentPosition == 0, "Full context out of bounds.");
- mContextCurrentPosition = mPromptLen;
- }
+ mContextCurrentPosition += getContextChunkSize();
+ setContextChunkSize(0);
+ }
+
+ [[nodiscard]] executor::PriorityType priority() const noexcept
+ {
+ return mPriority;
}
/// Increment the counter of decoding iterations.
@@ -1241,20 +1282,24 @@ class GenericLlmRequest
return static_cast(getMaxNumGeneratedTokens()) / mDecodingIter;
}
+ [[nodiscard]] bool isFinished() const noexcept
+ {
+ return isGenerationCompleteState() || isDisaggContextTransmissionState() || isCompleteWaitingToSendLogits();
+ }
+
/// @brief Create a Response from the current state of the request
/// @return An optional Response
- std::optional createResponse()
+ std::optional createResponse(bool useFastLogits = false, int32_t mpiWorldRank = 0)
{
TLLM_CHECK(!isDisaggContextCompleteState());
- if (isGenerationCompleteState() || (mIsStreaming && isGenerationInProgressState())
- || isDisaggContextTransmissionState())
+ if (isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS))
{
TLLM_LOG_DEBUG("Creating response for request %lu", mRequestId);
executor::Result result;
result.sequenceIndex = mSequenceIndex;
- result.isSequenceFinal = isGenerationCompleteState() || isDisaggContextTransmissionState();
+ result.isSequenceFinal = isFinished();
mSequenceFinalVec->at(mSequenceIndex) = result.isSequenceFinal;
result.isFinal = std::all_of(mSequenceFinalVec->begin(), mSequenceFinalVec->end(),
@@ -1273,7 +1318,7 @@ class GenericLlmRequest
}
// TODO: fill the rank ids
result.contextPhaseParams = executor::ContextPhaseParams{
- std::move(firstGenTokens), mContextPhaseParams.value().releaseState()};
+ std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState()};
}
auto const calculateNbTokensOut = [this](SizeType32 maxNbTokens)
@@ -1292,8 +1337,7 @@ class GenericLlmRequest
auto const startTokenPos = maxNbTokens - maxNbTokensOut;
- auto const shouldSendResponse = isGenerationCompleteState()
- || (mIsStreaming && maxNbTokens > getMaxSentTokenLen()) || isDisaggContextTransmissionState();
+ auto const shouldSendResponse = isFinished() || (mIsStreaming && maxNbTokens > getMaxSentTokenLen());
if (!shouldSendResponse)
{
@@ -1333,6 +1377,11 @@ class GenericLlmRequest
= runtime::ITensor::slice(getGenerationLogitsHost(), startGenTokenPos, maxNbTokensOut);
result.generationLogits = executor::detail::ofITensor(generationLogitsHostCurrentStep);
}
+ else if (useFastLogits)
+ {
+ result.specDecFastLogitsInfo
+ = executor::SpeculativeDecodingFastLogitsInfo{mRequestId, mpiWorldRank};
+ }
else
{
result.generationLogits = executor::detail::ofITensor(getGenerationLogitsHost());
@@ -1351,7 +1400,7 @@ class GenericLlmRequest
setMaxSentTokenLen(maxNbTokens);
auto requestId = isChild() ? mParentRequestId : mRequestId;
- auto response = executor::Response(requestId, std::move(result));
+ auto response = executor::Response(requestId, std::move(result), mClientId);
return response;
}
@@ -1372,12 +1421,29 @@ class GenericLlmRequest
mDecodingIter = iter;
}
+ void setKvCacheTransferStart(std::chrono::time_point const& time)
+ {
+ mKvCacheTransferStart = time;
+ }
+
+ void setKvCacheTransferEnd(std::chrono::time_point const& time)
+ {
+ mKvCacheTransferEnd = time;
+ }
+
+ [[nodiscard]] double getKvCacheTransferTimeMS() const
+ {
+ // get max with 0 in case this function is called while end time is not recorded
+ return std::max(
+ 0.0, std::chrono::duration(mKvCacheTransferEnd - mKvCacheTransferStart).count());
+ }
+
RequestIdType mRequestId;
SizeType32 mPromptLen;
SizeType32 mMaxNewTokens;
// Tokens [beam_size, mPromptLen + getMaxNumGeneratedTokens()]
runtime::SamplingConfig mSamplingConfig;
- LlmRequestState_t mState;
+ LlmRequestState mState;
std::optional mEndId;
std::optional mPadId;
std::optional mSeqSlot;
@@ -1425,8 +1491,8 @@ class GenericLlmRequest
// To enable chunked context, the FHMA paged kv-cache also needs to be enabled. Except for the last one,
// the size of the context chunk needs to be an integer multiple of the kv-cache block size. The meaning
// of null value is that the context is not chunked.
- std::optional mContextChunkSize;
- SizeType32 mContextCurrentPosition;
+ SizeType32 mContextChunkSize{0};
+ SizeType32 mContextCurrentPosition{0};
std::vector mLogProbs; // [beamSize, seqLen]
VecLogProbs mCumLogProbs; // [beamSize]
@@ -1476,6 +1542,9 @@ class GenericLlmRequest
RequestIdType mParentRequestId;
std::shared_ptr> mSequenceFinalVec; // Indicators whether each sibling completes generation.
+ std::chrono::time_point mKvCacheTransferStart;
+ std::chrono::time_point mKvCacheTransferEnd;
+
private:
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
{
@@ -1490,8 +1559,8 @@ class GenericLlmRequest
{
if (mInputTokenExtraIds.value()->size() != inputTokens.size())
{
- std::string errStr = "inputTokenExtraIds vector size must be the same as input token vector size.";
- TLLM_THROW(errStr);
+ TLLM_THROW("inputTokenExtraIds vector size (%lu) must be the same as input token vector size (%lu).",
+ mInputTokenExtraIds.value()->size(), inputTokens.size());
}
VecTokenExtraIds tokenExtraIds = *mInputTokenExtraIds.value();
for (std::size_t i = 0; i < inputTokens.size(); ++i)
@@ -1575,6 +1644,8 @@ class GenericLlmRequest
class LlmRequest : public GenericLlmRequest
{
+ friend class LlmRequestBindings;
+
public:
using Base = GenericLlmRequest;
using TensorPtr = Base::TensorPtr;
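
The new `setPrepopulatedPromptLen(prepopulatedPromptLen, kvTokensPerBlock)` advances the context position past tokens recovered from reused KV blocks and then shrinks the current chunk so it still ends on a block boundary. Below is a small, self-contained walk-through of that flooring arithmetic; the numbers are illustrative only.

```cpp
#include <cassert>
#include <cstdio>

int main()
{
    int const kvTokensPerBlock = 64;
    int const promptLen = 300;
    int const prepopulatedPromptLen = 96; // tokens recovered from reused KV cache blocks
    int chunkSize = 128;                  // chunk size chosen by the scheduler

    if (prepopulatedPromptLen + chunkSize < promptLen)
    {
        // Floor the end position of the current chunk to a block boundary,
        // exactly as in the diff: (pre + chunk) / block * block - pre.
        int const flooredEndPosition = (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock;
        chunkSize = flooredEndPosition - prepopulatedPromptLen;
    }

    // 96 + 128 = 224 -> floored to 192, so the chunk shrinks from 128 to 96 tokens
    // and the chunk still ends at a multiple of 64, avoiding cache fragmentation.
    assert((prepopulatedPromptLen + chunkSize) % kvTokensPerBlock == 0);
    std::printf("adjusted chunk size: %d\n", chunkSize);
    return 0;
}
```
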
diff --git a/cpp/include/tensorrt_llm/batch_manager/microBatchScheduler.h b/cpp/include/tensorrt_llm/batch_manager/microBatchScheduler.h
new file mode 100644
index 000000000..2e932ba23
--- /dev/null
+++ b/cpp/include/tensorrt_llm/batch_manager/microBatchScheduler.h
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "common.h"
+#include "tensorrt_llm/batch_manager/llmRequest.h"
+#include "tensorrt_llm/common/algorithm.h"
+#include "tensorrt_llm/runtime/common.h"
+
+namespace tensorrt_llm::batch_manager
+{
+
+namespace batch_scheduler
+{
+
+struct ContextChunkingConfig
+{
+ ContextChunkingConfig() = default;
+
+ executor::ContextChunkingPolicy chunkingPolicy;
+ /// The minimum size, also known as the chunk unit size. It generally
+ /// needs to be equal to the size of the kv cache block or its integer
+ /// multiples (except for the last context chunk) to avoid fragmentation.
+ /// When set to null, it indicates that the context chunk is disabled.
+ tensorrt_llm::runtime::SizeType32 chunkUnitSize;
+};
+
+} // namespace batch_scheduler
+
+/// @brief This scheduler takes into account the desired batch size and limits of the TRT engine to schedule requests.
+class MicroBatchScheduler : Algorithm
+{
+public:
+ constexpr static auto name{"MicroBatchScheduler"};
+
+ using SizeType32 = tensorrt_llm::runtime::SizeType32;
+ using ContextChunkingPolicy = tensorrt_llm::executor::ContextChunkingPolicy;
+
+ MicroBatchScheduler() = default;
+
+ explicit MicroBatchScheduler(SizeType32 maxBatchSize, std::optional maxNumTokens = std::nullopt,
+ std::optional ctxChunkConfig = std::nullopt,
+ std::optional maxContextLength = std::nullopt,
+ LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
+ LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
+
+ static MicroBatchScheduler make(SizeType32 maxBatchSize, std::optional maxNumTokens = std::nullopt,
+ std::optional ctxChunkConfig = std::nullopt,
+ std::optional maxContextLength = std::nullopt,
+ LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
+ LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
+ {
+ return MicroBatchScheduler{
+ maxBatchSize, maxNumTokens, ctxChunkConfig, maxContextLength, noScheduleUntilState, noScheduleAfterState};
+ }
+
+ std::tuple operator()(
+ RequestVector const& activeRequests, ReqIdsSet const& inflightReqIds);
+
+ static void setCtxRequestsChunkSize(RequestVector const& contextsToBeChunked, ContextChunkingPolicy ctxChunkPolicy,
+ std::optional ctxTokensCapacity, SizeType32 chunkUnitSize,
+ std::optional const& maxContextLength);
+
+private:
+ template
+ static void setCtxRequestsChunkSize(RequestVector const& contextsToBeChunked,
+ std::optional ctxTokensCapacity, SizeType32 chunkUnitSize,
+ std::optional const& maxContextLength);
+
+ /// After the chunk sizes have been determined, this function will discard
+ /// any draft tokens that don't fit.
+ static void fitDraftTokens(RequestVector const& contextsToBeChunked, std::optional ctxTokensCapacity,
+ SizeType32 chunkUnitSize, std::optional const& maxContextLength);
+
+ /// The maximum number of requests returned by scheduleRequests
+ SizeType32 mMaxBatchSize;
+
+ /// The maximum number of tokens to include in a batch
+ std::optional mMaxNumTokens;
+
+ /// The maximum length of the context. If the context exceeds this length,
+ /// it must be chunked, otherwise it cannot be processed. Therefore, it
+ /// needs to be set together with the chunk unit size to make sense.
+ /// When set to null, it indicates that context length is unlimited.
+ std::optional mMaxContextLength;
+
+ std::optional mCtxChunkConfig;
+
+ /// The state until/after which the scheduler should not schedule requests
+ LlmRequestState mNoScheduleUntilState;
+ LlmRequestState mNoScheduleAfterState;
+};
+
+} // namespace tensorrt_llm::batch_manager
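
A hypothetical usage sketch of the new `MicroBatchScheduler` follows. Two things are assumptions not guaranteed by the excerpt above: the enumerator `executor::ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED`, and that `operator()` returns the scheduled context and generation request vectors (its tuple element types are elided in this excerpt); the batch and token limits are placeholders.

```cpp
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"

namespace tb = tensorrt_llm::batch_manager;

void scheduleOneIteration(tb::RequestVector const& activeRequests, tb::ReqIdsSet const& inflightReqIds)
{
    tb::batch_scheduler::ContextChunkingConfig chunkingConfig;
    chunkingConfig.chunkingPolicy = tensorrt_llm::executor::ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED;
    chunkingConfig.chunkUnitSize = 64; // typically the KV cache block size or a multiple of it

    // At most 8 requests and 8192 total tokens per micro batch; contexts longer
    // than 4096 tokens must be chunked.
    auto scheduler = tb::MicroBatchScheduler::make(
        /*maxBatchSize=*/8, /*maxNumTokens=*/8192, chunkingConfig, /*maxContextLength=*/4096);

    auto [contextRequests, generationRequests] = scheduler(activeRequests, inflightReqIds);
    // Hand the two micro batches to the context and generation execution paths.
}
```
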
diff --git a/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h
index 65808134b..f86e76b4b 100644
--- a/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h
@@ -51,6 +51,8 @@ class PeftTaskNotCachedException : public runtime::LoraExpectedException
class BasePeftCacheManager
{
public:
+ friend class BasePeftCacheManagerBindings;
+
using LlmRequestPtr = std::shared_ptr;
using RequestVector = std::vector;
using PeftTable = std::map>>;
diff --git a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h
index fc61fd581..4a430d8c1 100644
--- a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h
+++ b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h
@@ -46,7 +46,9 @@ class TrtGptModelOptionalParams
executor::SchedulerConfig const& schedulerConfig = executor::SchedulerConfig{},
executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig
= executor::ExtendedRuntimePerfKnobConfig{},
- std::optional debugConfig = std::nullopt, uint64_t maxSeqIdleMicroseconds = 180000000)
+ std::optional debugConfig = std::nullopt, uint64_t maxSeqIdleMicroseconds = 180000000,
+ std::optional specDecConfig = std::nullopt,
+ bool isLeaderInOrchMode = false)
: kvCacheConfig{kvCacheConfig}
, enableTrtOverlap{enableTrtOverlap}
, deviceIds(deviceIds)
@@ -62,10 +64,12 @@ class TrtGptModelOptionalParams
, extendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig)
, debugConfig{std::move(debugConfig)}
, maxSeqIdleMicroseconds{maxSeqIdleMicroseconds}
+ , speculativeDecodingConfig{std::move(specDecConfig)}
+ , isLeaderInOrchMode{isLeaderInOrchMode}
{
}
- explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig)
+ explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig, bool isLeaderInOrchMode)
: TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), false,
executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
@@ -74,16 +78,7 @@ class TrtGptModelOptionalParams
executorConfig.getGpuWeightsPercent(), executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize(),
executorConfig.getMaxNumTokens(), executorConfig.getSchedulerConfig(),
executorConfig.getExtendedRuntimePerfKnobConfig(), executorConfig.getDebugConfig(),
- executorConfig.getMaxSeqIdleMicroseconds())
- {
- }
-
- // Copy constructor
- TrtGptModelOptionalParams(TrtGptModelOptionalParams const& other)
- : TrtGptModelOptionalParams(other.kvCacheConfig, other.enableTrtOverlap, other.deviceIds,
- other.normalizeLogProbs, other.enableChunkedContext, other.peftCacheManagerConfig, other.decodingConfig,
- other.gpuWeightsPercent, other.maxBeamWidth, other.maxBatchSize, other.maxNumTokens, other.schedulerConfig,
- other.extendedRuntimePerfKnobConfig, other.debugConfig, other.maxSeqIdleMicroseconds)
+ executorConfig.getMaxSeqIdleMicroseconds(), executorConfig.getSpecDecConfig(), isLeaderInOrchMode)
{
}
@@ -103,6 +98,8 @@ class TrtGptModelOptionalParams
&& extendedRuntimePerfKnobConfig == other.extendedRuntimePerfKnobConfig //
&& debugConfig == other.debugConfig //
&& maxSeqIdleMicroseconds == other.maxSeqIdleMicroseconds //
+ && speculativeDecodingConfig == other.speculativeDecodingConfig //
+ && isLeaderInOrchMode == other.isLeaderInOrchMode //
;
}
@@ -126,6 +123,9 @@ class TrtGptModelOptionalParams
std::optional debugConfig;
// Sequence is considered idle if not updated for this amount of time.
uint64_t maxSeqIdleMicroseconds;
+ std::optional speculativeDecodingConfig;
+ // This rank is the leader worker in orchestrator mode
+ bool isLeaderInOrchMode;
};
} // namespace tensorrt_llm::batch_manager
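
With the explicit copy constructor removed and the `ExecutorConfig` constructor now taking `isLeaderInOrchMode`, the optional params are built once from the executor configuration and carry the speculative-decoding config along. A brief sketch of that construction path, with the leader flag supplied by whatever orchestration logic the caller uses:

```cpp
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
#include "tensorrt_llm/executor/executor.h"

namespace tb = tensorrt_llm::batch_manager;
namespace exec = tensorrt_llm::executor;

tb::TrtGptModelOptionalParams makeOptionalParams(exec::ExecutorConfig const& executorConfig, bool isOrchestratorLeader)
{
    // The new second argument replaces the copy constructor removed in the diff:
    // orchestration-related state is fixed at construction time.
    return tb::TrtGptModelOptionalParams(executorConfig, isOrchestratorLeader);
}
```
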
diff --git a/cpp/include/tensorrt_llm/common/algorithm.h b/cpp/include/tensorrt_llm/common/algorithm.h
new file mode 100644
index 000000000..9363504f7
--- /dev/null
+++ b/cpp/include/tensorrt_llm/common/algorithm.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace tensorrt_llm
+{
+
+// Base class for algorithms
+struct Algorithm
+{
+ Algorithm() = default;
+ Algorithm(Algorithm&&) = default;
+ Algorithm& operator=(Algorithm&&) = default;
+ Algorithm(Algorithm const&) = delete;
+ Algorithm& operator=(Algorithm const&) = delete;
+};
+
+} // namespace tensorrt_llm
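
The new `Algorithm` base makes every derived algorithm movable but not copyable, matching the construct-once style of `MicroBatchScheduler` above. A minimal sketch; the base is reproduced from the header so the snippet stands alone.

```cpp
#include <type_traits>

namespace tensorrt_llm
{
// Reproduced from cpp/include/tensorrt_llm/common/algorithm.h above.
struct Algorithm
{
    Algorithm() = default;
    Algorithm(Algorithm&&) = default;
    Algorithm& operator=(Algorithm&&) = default;
    Algorithm(Algorithm const&) = delete;
    Algorithm& operator=(Algorithm const&) = delete;
};
} // namespace tensorrt_llm

struct MyAlgorithm : tensorrt_llm::Algorithm
{
    constexpr static auto name{"MyAlgorithm"};

    int operator()(int x) const
    {
        return x + 1; // stand-in for a real scheduling/assignment pass
    }
};

// Deriving from Algorithm deletes the copy operations but keeps move available.
static_assert(std::is_move_constructible_v<MyAlgorithm>);
static_assert(!std::is_copy_constructible_v<MyAlgorithm>);
```
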
diff --git a/cpp/include/tensorrt_llm/common/cudaUtils.h b/cpp/include/tensorrt_llm/common/cudaUtils.h
index 71657c0bb..023f97d87 100644
--- a/cpp/include/tensorrt_llm/common/cudaUtils.h
+++ b/cpp/include/tensorrt_llm/common/cudaUtils.h
@@ -161,7 +161,7 @@ inline std::optional isCudaLaunchBlocking()
return result;
}
-inline void syncAndCheck(char const* const file, int const line)
+inline bool doCheckError()
{
auto const cudaLaunchBlocking = isCudaLaunchBlocking();
#ifndef NDEBUG
@@ -170,7 +170,12 @@ inline void syncAndCheck(char const* const file, int const line)
bool const checkError = cudaLaunchBlocking.value_or(false);
#endif
- if (checkError)
+ return checkError;
+}
+
+inline void syncAndCheck(char const* const file, int const line)
+{
+ if (doCheckError())
{
check(cudaGetLastError(), "cudaGetLastError", file, line);
check(cudaDeviceSynchronize(), "cudaDeviceSynchronize", file, line);
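
Splitting `doCheckError()` out of `syncAndCheck()` lets other debug-only checks reuse the same policy (honor `CUDA_LAUNCH_BLOCKING` at runtime, always check in debug builds). A small sketch of that reuse, assuming these helpers live in `tensorrt_llm::common` as the header's location suggests:

```cpp
#include <cuda_runtime_api.h>

#include "tensorrt_llm/common/cudaUtils.h"

inline void checkLastKernelIfRequested(char const* file, int line)
{
    namespace tc = tensorrt_llm::common;
    // Same policy and same checks as syncAndCheck(), but callable from any
    // site that wants an optional synchronization point.
    if (tc::doCheckError())
    {
        tc::check(cudaGetLastError(), "cudaGetLastError", file, line);
        tc::check(cudaDeviceSynchronize(), "cudaDeviceSynchronize", file, line);
    }
}
```
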
diff --git a/cpp/include/tensorrt_llm/common/mpiUtils.h b/cpp/include/tensorrt_llm/common/mpiUtils.h
index edf3da004..d5801f36c 100644
--- a/cpp/include/tensorrt_llm/common/mpiUtils.h
+++ b/cpp/include/tensorrt_llm/common/mpiUtils.h
@@ -99,7 +99,6 @@ struct MpiTypeConverter
};
template <>
-
struct MpiTypeConverter
{
@@ -380,9 +379,14 @@ class MpiComm
void allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const;
void allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const;
+
+ void allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
+ std::vector const& recvcounts, std::vector const& displs, MpiType recvtype) const;
+
void barrier() const;
void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
+ bool improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
//! \brief Returns if a message with the specified source and tag is available
bool iprobe(int source, int tag, MPI_Status* status) const;
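
The new `MpiComm::allgatherv` wrapper exposes a variable-length all-gather, where `recvcounts` and `displs` describe how many elements each rank contributes and where its chunk lands in the gathered buffer. A plain-MPI illustration of those semantics (using the raw MPI C API directly rather than the wrapper):

```cpp
#include <mpi.h>

#include <vector>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);

    int rank = 0;
    int worldSize = 1;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &worldSize);

    // Rank r contributes r + 1 integers.
    std::vector<int> sendBuf(rank + 1, rank);

    // Every rank needs the counts and displacements of all ranks.
    std::vector<int> recvCounts(worldSize);
    std::vector<int> displs(worldSize);
    int total = 0;
    for (int r = 0; r < worldSize; ++r)
    {
        recvCounts[r] = r + 1;
        displs[r] = total;
        total += recvCounts[r];
    }

    std::vector<int> recvBuf(total);
    MPI_Allgatherv(sendBuf.data(), static_cast<int>(sendBuf.size()), MPI_INT, recvBuf.data(), recvCounts.data(),
        displs.data(), MPI_INT, MPI_COMM_WORLD);

    MPI_Finalize();
    return 0;
}
```
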
diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index a96c24d43..e6e5e1e0e 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -43,7 +43,7 @@ char const* version() noexcept;
class Model;
class Serialization;
-class ContextPhaseState;
+class DataTransceiverState;
/// @brief Sampling configuration
class SamplingConfig
@@ -186,11 +186,13 @@ class ExternalDraftTokensConfig
{
public:
explicit ExternalDraftTokensConfig(VecTokens tokens, std::optional logits = std::nullopt,
- std::optional const& acceptanceThreshold = std::nullopt);
+ std::optional const& acceptanceThreshold = std::nullopt,
+ std::optional const& fastLogits = std::nullopt);
[[nodiscard]] VecTokens getTokens() const;
[[nodiscard]] std::optional getLogits() const;
[[nodiscard]] std::optional getAcceptanceThreshold() const;
+ [[nodiscard]] std::optional getFastLogits() const;
private:
friend class Serialization;
@@ -200,6 +202,8 @@ class ExternalDraftTokensConfig
std::optional mLogits;
/// @brief The acceptance threshold. Must be > 0.f and <= 1.f
std::optional mAcceptanceThreshold;
+ /// @brief Use direct transfer for draft logits
+ std::optional mFastLogits;
};
/// @brief Configuration for prompt tuning
@@ -283,8 +287,10 @@ struct LookaheadDecodingConfig
class ContextPhaseParams
{
public:
- explicit ContextPhaseParams(VecTokens firstGenTokens);
- ContextPhaseParams(VecTokens firstGenTokens, void* state);
+ using RequestIdType = std::uint64_t;
+
+ explicit ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId);
+ ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state);
ContextPhaseParams(ContextPhaseParams const&);
ContextPhaseParams(ContextPhaseParams&&);
@@ -295,6 +301,8 @@ class ContextPhaseParams
[[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept;
[[nodiscard]] VecTokens popFirstGenTokens() && noexcept;
+ [[nodiscard]] RequestIdType getReqId() const noexcept;
+
[[nodiscard]] void const* getState() const noexcept;
[[nodiscard]] void* getState() noexcept;
[[nodiscard]] void* releaseState() noexcept;
@@ -304,6 +312,9 @@ class ContextPhaseParams
static void deleter(void const* data);
using StatePtr = std::unique_ptr;
+ /// @brief This request corresponds to the request ID in the context phase.
+ RequestIdType mReqId{0};
+
/// @brief The first tokens generated by context executor
VecTokens mFirstGenTokens;
@@ -311,6 +322,18 @@ class ContextPhaseParams
StatePtr mState{nullptr, deleter};
};
+/// @brief Configuration for speculative decoding (both draft and target models)
+class SpeculativeDecodingConfig
+{
+public:
+ explicit SpeculativeDecodingConfig(bool fastLogits);
+
+ bool operator==(SpeculativeDecodingConfig const& other) const;
+
+ /// @brief Send logits tensor directly from draft to target model.
+ bool fastLogits;
+};
+
/// @brief A class that holds information about the request
class Request
{
@@ -430,6 +453,16 @@ class Request
std::unique_ptr mImpl;
};
+/// @brief Struct that holds the logits information when using direct transfer
+struct SpeculativeDecodingFastLogitsInfo
+{
+ /// @brief Draft request id
+ uint64_t draftRequestId;
+
+ /// @brief MPI world rank of the draft model leader
+ int32_t draftParticipantId;
+};
+
/// @brief Struct that holds the generation result
struct Result
{
@@ -448,11 +481,14 @@ struct Result
/// @brief The context logits. Size [promptLen, vocabSizePadded]
std::optional contextLogits;
- /// @brief The context logits. Size [beamSize, maxNewTokens, vocabSizePadded] (non-streaming)
+ /// @brief The generation logits. Size [beamSize, maxNewTokens, vocabSizePadded] (non-streaming)
/// or [maxNewTokens, beamSize, vocabSizePadded] (streaming and allGeneratedTokens)
/// or [1, beamSize, vocabSizePadded] (streaming and non-allGeneratedTokens)
std::optional generationLogits;
+ /// @brief Logits information for direct transfer when using fast logits
+ std::optional specDecFastLogitsInfo;
+
/// @brief The encoder output. Size [encoderLen, hiddenSize]
std::optional encoderOutput;
@@ -477,8 +513,8 @@ struct Result
class Response
{
public:
- Response(IdType requestId, std::string errorMsg);
- Response(IdType requestId, Result Result);
+ Response(IdType requestId, std::string errorMsg, std::optional clientId = std::nullopt);
+ Response(IdType requestId, Result Result, std::optional clientId = std::nullopt);
~Response();
Response(Response const& other);
@@ -489,6 +525,9 @@ class Response
/// @brief Get the id of the request for which this response was generated
[[nodiscard]] IdType getRequestId() const;
+ /// @brief Get the client id of the request for which this response was generated
+ [[nodiscard]] std::optional getClientId() const;
+
/// @brief Indicates if this response has an error or not
[[nodiscard]] bool hasError() const;
@@ -538,13 +577,15 @@ class KvCacheConfig
std::optional> const& maxAttentionWindowVec = std::nullopt,
std::optional const& sinkTokenLength = std::nullopt,
std::optional const& freeGpuMemoryFraction = std::nullopt,
- std::optional const& hostCacheSize = std::nullopt, bool onboardBlocks = true);
+ std::optional const& hostCacheSize = std::nullopt, bool onboardBlocks = true,
+ std::optional const& crossKvCacheFraction = std::nullopt);
[[nodiscard]] bool getEnableBlockReuse() const;
[[nodiscard]] std::optional getMaxTokens() const;
[[nodiscard]] std::optional> getMaxAttentionWindowVec() const;
[[nodiscard]] std::optional getSinkTokenLength() const;
[[nodiscard]] std::optional getFreeGpuMemoryFraction() const;
+ [[nodiscard]] std::optional getCrossKvCacheFraction() const;
[[nodiscard]] std::optional getHostCacheSize() const;
[[nodiscard]] bool getOnboardBlocks() const;
@@ -553,6 +594,7 @@ class KvCacheConfig
void setMaxAttentionWindowVec(std::vector maxAttentionWindowVec);
void setSinkTokenLength(SizeType32 sinkTokenLength);
void setFreeGpuMemoryFraction(FloatType freeGpuMemoryFraction);
+ void setCrossKvCacheFraction(FloatType crossKvCacheFraction);
void setHostCacheSize(size_t hostCacheSize);
void setOnboardBlocks(bool onboardBlocks);
@@ -581,6 +623,12 @@ class KvCacheConfig
/// allocated.
std::optional mFreeGpuMemoryFraction;
+ /// @brief The fraction of the KV Cache memory should be reserved for cross attention
+ /// If set to p, self attention will use 1-p of KV Cache memory and cross attention
+ /// will use p of KV Cache memory. Default is 50%.
+ /// Should only be set when using encoder-decoder model.
+ std::optional mCrossKvCacheFraction;
+
/// @brief Size of secondary memory pool in bytes. Default is 0.
/// Having a secondary memory pool increases KV cache block reuse potential.
std::optional mHostCacheSize;
@@ -593,18 +641,24 @@ class KvCacheConfig
class ExtendedRuntimePerfKnobConfig
{
public:
- explicit ExtendedRuntimePerfKnobConfig(bool multiBlockMode = true, bool enableContextFMHAFP32Acc = false);
+ explicit ExtendedRuntimePerfKnobConfig(bool multiBlockMode = true, bool enableContextFMHAFP32Acc = false,
+ bool cudaGraphMode = false, SizeType32 cudaGraphCacheSize = 0);
bool operator==(ExtendedRuntimePerfKnobConfig const& other) const
{
- return mMultiBlockMode == other.mMultiBlockMode && mEnableContextFMHAFP32Acc == other.mEnableContextFMHAFP32Acc;
+ return mMultiBlockMode == other.mMultiBlockMode && mEnableContextFMHAFP32Acc == other.mEnableContextFMHAFP32Acc
+ && mCudaGraphMode == other.mCudaGraphMode && mCudaGraphCacheSize == other.mCudaGraphCacheSize;
}
[[nodiscard]] bool getMultiBlockMode() const;
[[nodiscard]] bool getEnableContextFMHAFP32Acc() const;
+ [[nodiscard]] bool getCudaGraphMode() const;
+ [[nodiscard]] SizeType32 getCudaGraphCacheSize() const;
void setMultiBlockMode(bool multiBlockMode);
void setEnableContextFMHAFP32Acc(bool enableContextFMHAFP32Acc);
+ void setCudaGraphMode(bool cudaGraphMode);
+ void setCudaGraphCacheSize(SizeType32 cacheSize);
private:
friend class Serialization;
@@ -614,6 +668,13 @@ class ExtendedRuntimePerfKnobConfig
/// @brief If enable FMHA runner FP32 accumulation.
bool mEnableContextFMHAFP32Acc;
+
+ /// @brief Control if enable cuda graph.
+ bool mCudaGraphMode;
+
+ /// @brief Number of cuda graphs to be cached in the runtime.
+ /// The larger the cache, the better the perf, but more GPU memory is consumed.
+ SizeType32 mCudaGraphCacheSize;
};
/// @brief Configuration class for debugging output
@@ -622,27 +683,33 @@ class DebugConfig
using StringVec = std::vector;
public:
- explicit DebugConfig(bool dumpInputTensors = false, bool dumpOuputTensors = false, StringVec debugTensorNames = {});
+ explicit DebugConfig(bool debugInputTensors = false, bool debugOutputTensors = false,
+ StringVec debugTensorNames = {}, SizeType32 debugTensorsMaxIterations = 0);
bool operator==(DebugConfig const& other) const;
- [[nodiscard]] bool getDumpInputTensors() const;
- [[nodiscard]] bool getDumpOutputTensors() const;
+ [[nodiscard]] bool getDebugInputTensors() const;
+ [[nodiscard]] bool getDebugOutputTensors() const;
[[nodiscard]] StringVec const& getDebugTensorNames() const;
+ [[nodiscard]] SizeType32 getDebugTensorsMaxIterations() const;
- void setDumpInputTensors(bool dumpInputTensors);
- void setDumpOuputTensors(bool dumpOuputTensors);
+ void setDebugInputTensors(bool debugInputTensors);
+ void setDebugOutputTensors(bool debugOutputTensors);
void setDebugTensorNames(StringVec const& debugTensorNames);
+ void setDebugTensorsMaxIterations(SizeType32 debugTensorsMaxIterations);
private:
friend class Serialization;
- /// @brief If true, dump all input tensors.
- bool mDumpInputTensors;
- /// @brief If true, dump all output tensors.
- bool mDumpOuputTensors;
- /// @brief If not empty, only dump tensors in this list.
+ /// @brief If true, debug all input tensors.
+ bool mDebugInputTensors;
+ /// @brief If true, debug all output tensors.
+ bool mDebugOutputTensors;
+ /// @brief If not empty, only debug tensors in this list.
StringVec mDebugTensorNames;
+ /// @brief If > 0, provide debug tensors for at most debugTensorsMaxIterations past iterations,
+ /// else dump them to files.
+ SizeType32 mDebugTensorsMaxIterations;
};
SizeType32 const kDefaultIterStatsMaxIterations = 1000;
@@ -847,7 +914,8 @@ class ExecutorConfig
std::optional maxQueueSize = std::nullopt,
ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig = ExtendedRuntimePerfKnobConfig(),
std::optional debugConfig = std::nullopt, SizeType32 recvPollPeriodMs = 0,
- uint64_t maxSeqIdleMicroseconds = 180000000);
+ uint64_t maxSeqIdleMicroseconds = 180000000,
+ std::optional specDecConfig = std::nullopt);
[[nodiscard]] SizeType32 getMaxBeamWidth() const;
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
@@ -869,6 +937,7 @@ class ExecutorConfig
[[nodiscard]] std::optional getDebugConfig() const;
[[nodiscard]] SizeType32 getRecvPollPeriodMs() const;
[[nodiscard]] uint64_t getMaxSeqIdleMicroseconds() const;
+ [[nodiscard]] std::optional getSpecDecConfig() const;
void setMaxBeamWidth(SizeType32 maxBeamWidth);
void setMaxBatchSize(SizeType32 maxBatchSize);
@@ -890,6 +959,7 @@ class ExecutorConfig
void setDebugConfig(DebugConfig const& debugConfig);
void setRecvPollPeriodMs(SizeType32 const& recvPollPeriodMs);
void setMaxSeqIdleMicroseconds(uint64_t maxNumTokens);
+ void setSpecDecConfig(SpeculativeDecodingConfig const& specDecConfig);
private:
friend class Serialization;
@@ -952,6 +1022,9 @@ class ExecutorConfig
/// @brief The maximum time in microseconds a scheduled request can remain idle before getting terminated. Default
/// is 3 minutes.
uint64_t mMaxSeqIdleMicroseconds;
+
+ /// @brief The speculative decoding configuration
+ std::optional mSpeculativeDecodingConfig;
};
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
@@ -1032,23 +1105,31 @@ class Executor
/// @param id The request id for which to cancel the response
void cancelRequest(IdType requestId);
- /// @brief Signals the server to shutdown
- /// This call is blocking. Only returns when all requests have terminated or timeout has been reached
+ /// @brief Signals the server to shutdown.
+ /// @details This call is blocking. Only returns when all requests have terminated or timeout has been reached
void shutdown();
- /// @brief Returns the per-iterations statistics computed since last call to getLatestIterationStats
- /// Contains at most iterStatsMaxIterations iterations
+ /// @brief Returns the per-iterations statistics computed since last call to getLatestIterationStats.
+ /// Contains at most iterStatsMaxIterations iterations.
/// @return Iteration stats
std::deque getLatestIterationStats();
- /// @brief Returns the request stats of each iteration computed since last call to getLatestRequestStats
- /// Contains at most requestStatsMaxIterations iterations
+ /// @brief Returns the request stats of each iteration computed since last call to getLatestRequestStats.
+ /// Contains at most requestStatsMaxIterations iterations.
/// @return Request stats grouped by iterations
std::deque getLatestRequestStats();
+ /// @brief Returns the debug tensors of each iteration computed since last call to getLatestDebugTensors.
+ /// Contains at most debugTensorsMaxIterations iterations.
+ /// @return Request debug tensors grouped by iterations
+ std::deque getLatestDebugTensors();
+
/// @brief Indicates if the current process is allowed to enqueueRequests
[[nodiscard]] bool canEnqueueRequests() const;
+ /// @brief Indicates if the current process participates in this executor instance
+ [[nodiscard]] bool isParticipant() const;
+
private:
class Impl;
std::unique_ptr mImpl;
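
A hedged sketch of wiring up the options added above: the speculative-decoding fast-logits path, the CUDA graph knobs, and in-memory debug tensors. Only constructors and setters shown in this excerpt are used, except `setExtendedRuntimePerfKnobConfig`, which is assumed to exist; all values are placeholders.

```cpp
#include "tensorrt_llm/executor/executor.h"

namespace exec = tensorrt_llm::executor;

exec::ExecutorConfig makeExecutorConfig()
{
    exec::ExecutorConfig config;

    // Draft and target models exchange draft logits directly instead of
    // copying them through the response path.
    config.setSpecDecConfig(exec::SpeculativeDecodingConfig(/*fastLogits=*/true));

    // Enable CUDA graphs and keep up to 8 captured graphs cached in the runtime.
    exec::ExtendedRuntimePerfKnobConfig perfKnobs;
    perfKnobs.setCudaGraphMode(true);
    perfKnobs.setCudaGraphCacheSize(8);
    config.setExtendedRuntimePerfKnobConfig(perfKnobs); // setter not shown in this excerpt; assumed to exist

    // Keep debug tensors for the last 3 iterations in memory instead of dumping them to files.
    exec::DebugConfig debugConfig;
    debugConfig.setDebugTensorsMaxIterations(3);
    config.setDebugConfig(debugConfig);

    return config;
}
```
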
diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h
index 11d22c3f0..28aba9dc1 100644
--- a/cpp/include/tensorrt_llm/executor/serialization.h
+++ b/cpp/include/tensorrt_llm/executor/serialization.h
@@ -75,10 +75,10 @@ class Serialization
static void serialize(kv_cache::CacheState const& state, std::ostream& os);
[[nodiscard]] static size_t serializedSize(kv_cache::CacheState const& state);
- // ContextPhaseState
- [[nodiscard]] static ContextPhaseState deserializeContextPhaseState(std::istream& is);
- static void serialize(ContextPhaseState const& contextPhaseState, std::ostream& os);
- [[nodiscard]] static size_t serializedSize(ContextPhaseState const& contextPhaseState);
+ // DataTransceiverState
+ [[nodiscard]] static DataTransceiverState deserializeDataTransceiverState(std::istream& is);
+ static void serialize(DataTransceiverState const& dataTransceiverState, std::ostream& os);
+ [[nodiscard]] static size_t serializedSize(DataTransceiverState const& dataTransceiverState);
// ContextPhaseParams
[[nodiscard]] static ContextPhaseParams deserializeContextPhaseParams(std::istream& is);
@@ -95,6 +95,11 @@ class Serialization
static void serialize(Tensor const& tensor, std::ostream& os);
[[nodiscard]] static size_t serializedSize(Tensor const& tensor);
+ // SpeculativeDecodingFastLogitsInfo
+ [[nodiscard]] static SpeculativeDecodingFastLogitsInfo deserializeSpecDecFastLogitsInfo(std::istream& is);
+ static void serialize(SpeculativeDecodingFastLogitsInfo const& info, std::ostream& os);
+ [[nodiscard]] static size_t serializedSize(SpeculativeDecodingFastLogitsInfo const& info);
+
// Result
[[nodiscard]] static Result deserializeResult(std::istream& is);
static void serialize(Result const& result, std::ostream& os);
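
A hypothetical round trip through the new `SpeculativeDecodingFastLogitsInfo` serialization entry points; the struct and the static methods come from the headers above, while the stream handling is ordinary iostreams.

```cpp
#include <sstream>

#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h"

namespace exec = tensorrt_llm::executor;

exec::SpeculativeDecodingFastLogitsInfo roundTrip(exec::SpeculativeDecodingFastLogitsInfo const& info)
{
    // Serialize into an in-memory buffer...
    std::ostringstream os;
    exec::Serialization::serialize(info, os);

    // ...and reconstruct the struct from the same bytes.
    std::istringstream is(os.str());
    return exec::Serialization::deserializeSpecDecFastLogitsInfo(is);
}
```
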
diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h
index e07c539a9..5a8525caf 100644
--- a/cpp/include/tensorrt_llm/executor/types.h
+++ b/cpp/include/tensorrt_llm/executor/types.h
@@ -18,6 +18,7 @@
#include
#include
+#include