Update TensorRT-LLM v0.14.0 #2401

Merged: 1 commit, Nov 1, 2024
2 changes: 1 addition & 1 deletion .gitignore
@@ -37,7 +37,7 @@ tensorrt_llm/bindings.pyi
tensorrt_llm/bindings/*.pyi
*docs/cpp_docs*
*docs/source/_cpp_gen*
docs/source/llm-api
docs/source/llm-api/*.rst
docs/source/llm-api-examples/llm_*.rst
*.swp

5 changes: 4 additions & 1 deletion .gitmodules
@@ -13,4 +13,7 @@
url = https://github.com/NVIDIA/NVTX.git
[submodule "3rdparty/ucxx"]
path = 3rdparty/ucxx
url = https://github.com/GuanLuo/ucxx.git
url = https://github.com/rapidsai/ucxx.git
[submodule "3rdparty/pybind11"]
path = 3rdparty/pybind11
url = https://github.com/pybind/pybind11.git
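(With this change, the ucxx submodule points at the upstream rapidsai repository instead of the GuanLuo fork, and pybind11 is vendored as a new submodule. Existing clones will likely need git submodule sync followed by git submodule update --init --recursive to pick up both; treat those commands as the usual submodule workflow rather than anything this PR prescribes.)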
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -46,5 +46,5 @@ repos:
args:
- --skip=".git,3rdparty"
- --exclude-file=examples/whisper/tokenizer.py
- --ignore-words-list=rouge,inout,atleast,strat,nd,subtile,thrid
- --ignore-words-list=rouge,inout,atleast,strat,nd,subtile,thrid,improbe
exclude: 'tests/llm-test-defs/turtle/test_input_files'
1 change: 1 addition & 0 deletions 3rdparty/pybind11
Submodule pybind11 added at f99ffd
31 changes: 23 additions & 8 deletions README.md
@@ -8,7 +8,7 @@ TensorRT-LLM
[![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.5.1-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.4.0-green)](https://developer.nvidia.com/tensorrt)
[![version](https://img.shields.io/badge/release-0.13.0-green)](./tensorrt_llm/version.py)
[![version](https://img.shields.io/badge/release-0.14.0-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)

[Architecture](./docs/source/architecture/overview.md)   |   [Results](./docs/source/performance/perf-overview.md)   |   [Examples](./examples/)   |   [Documentation](./docs/source/)
@@ -17,6 +17,24 @@ TensorRT-LLM
<div align="left">

## Latest News
* [2024/09/29] 🌟 AI at Meta PyTorch + TensorRT v2.4 🌟 ⚡TensorRT 10.1 ⚡PyTorch 2.4 ⚡CUDA 12.4 ⚡Python 3.12
[➡️ link](https://github.com/pytorch/TensorRT/releases/tag/v2.4.0)
<div align="center">
<img src="docs/source/media/image-09-29-2024.png" width="50%">
<div align="left">

* [2024/09/17] ✨ NVIDIA TensorRT-LLM Meetup
[➡️ link](https://drive.google.com/file/d/1RR8GqC-QbuaKuHj82rZcXb3MS20SWo6F/view?usp=share_link)

* [2024/09/17] ✨ Accelerating LLM Inference at Databricks with TensorRT-LLM
[➡️ link](https://drive.google.com/file/d/1NeSmrLaWRJAY1rxD9lJmzpB9rzr38j8j/view?usp=sharing)

* [2024/09/17] ✨ TensorRT-LLM @ Baseten
[➡️ link](https://drive.google.com/file/d/1Y7L2jqW-aRmt31mCdqhwvGMmCSOzBUjG/view?usp=share_link)

* [2024/09/04] 🏎️🏎️🏎️ Best Practices for Tuning TensorRT-LLM for Optimal Serving with BentoML
[➡️ link](https://www.bentoml.com/blog/tuning-tensor-rt-llm-for-optimal-serving-with-bentoml)

* [2024/08/20] 🏎️SDXL with #TensorRT Model Optimizer ⏱️⚡ 🏁 cache diffusion 🏁 quantization aware training 🏁 QLoRA 🏁 #Python 3.12
[➡️ link](https://developer.nvidia.com/blog/nvidia-tensorrt-model-optimizer-v0-15-boosts-inference-performance-and-expands-model-support/)

@@ -43,6 +61,9 @@ TensorRT-LLM
* [2024/07/02] Let the @MistralAI MoE tokens fly 📈 🚀 #Mixtral 8x7B with NVIDIA #TensorRT #LLM on #H100.
[➡️ Tech blog](https://developer.nvidia.com/blog/achieving-high-mixtral-8x7b-performance-with-nvidia-h100-tensor-core-gpus-and-tensorrt-llm?ncid=so-twit-928467)

<details close>
<summary>Previous News</summary>

* [2024/06/24] Enhanced with NVIDIA #TensorRT #LLM, @upstage.ai’s solar-10.7B-instruct is ready to power your developer projects through our API catalog 🏎️. ✨[➡️ link](https://build.nvidia.com/upstage/solar-10_7b-instruct?snippet_tab=Try )

* [2024/06/18] CYMI: 🤩 Stable Diffusion 3 dropped last week 🎊 🏎️ Speed up your SD3 with #TensorRT INT8 Quantization[➡️ link](https://build.nvidia.com/upstage/solar-10_7b-instruct?snippet_tab=Try )
@@ -55,10 +76,6 @@ Technical Deep Dive for serious coders ✅+99% compression ✅1 set of weights
* [2024/06/04] ✨ #TensorRT and GeForce #RTX unlock ComfyUI SD superhero powers 🦸⚡ 🎥 Demo: [➡️ link](https://youtu.be/64QEVfbPHyg)
📗 DIY notebook: [➡️ link](https://console.brev.dev/launchable/deploy?userID=2x2sil999&orgID=ktj33l4xj&name=ComfyUI_TensorRT&instance=L4%40g2-standard-4%3Anvidia-l4%3A1&diskStorage=500&cloudID=GCP&baseImage=docker.io%2Fpytorch%2Fpytorch%3A2.2.0-cuda12.1-cudnn8-runtime&ports=ComfUI%3A8188&file=https%3A%2F%2Fgithub.com%2Fbrevdev%2Fnotebooks%2Fblob%2Fmain%2Ftensorrt-comfyui.ipynb&launchableID=env-2hQX3n7ae5mq3NjNZ32DfAG0tJf)

<details close>
<summary>Previous News</summary>


* [2024/05/28] ✨#TensorRT weight stripping for ResNet-50 ✨ ✅+99% compression
✅1 set of weights → ** GPUs\ ✅0 performance loss ✅** models…LLM, CNN, etc
👀 📚 DIY [➡️ link](https://console.brev.dev/launchable/deploy?userID=2x2sil999&orgID=ktj33l4xj&launchableID=env-2h6bym7h5GFNho3vpWQQeUYMwTM&instance=L4%40g6.xlarge&diskStorage=500&cloudID=devplane-brev-1&baseImage=nvcr.io%2Fnvidia%2Ftensorrt%3A24.05-py3&file=https%3A%2F%2Fgithub.com%2FNVIDIA%2FTensorRT%2Fblob%2Frelease%2F10.0%2Fsamples%2Fpython%2Fsample_weight_stripping%2Fnotebooks%2Fweight_stripping.ipynb&name=tensorrt_weight_stripping_resnet50)
@@ -68,10 +85,8 @@ Serverless TensorRT-LLM (LLaMA 3 8B) | Modal Docs [➡️ link](https://modal.co

* [2024/05/08] NVIDIA TensorRT Model Optimizer -- the newest member of the #TensorRT ecosystem is a library of post-training and training-in-the-loop model optimization techniques ✅quantization ✅sparsity ✅QAT [➡️ blog](https://developer.nvidia.com/blog/accelerate-generative-ai-inference-performance-with-nvidia-tensorrt-model-optimizer-now-publicly-available/)


* [2024/05/07] 🦙🦙🦙 24,000 tokens per second 🛫Meta Llama 3 takes off with #TensorRT #LLM 📚[➡️ link](https://blogs.nvidia.com/blog/meta-llama3-inference-acceleration/)


* [2024/02/06] [🚀 Speed up inference with SOTA quantization techniques in TRT-LLM](./docs/source/blogs/quantization-in-TRT-LLM.md)
* [2024/01/30] [ New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget](./docs/source/blogs/XQA-kernel.md)
* [2023/12/04] [Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100](./docs/source/blogs/Falcon180B-H200.md)
@@ -88,7 +103,7 @@ Serverless TensorRT-LLM (LLaMA 3 8B) | Modal Docs [➡️ link](https://modal.co
## TensorRT-LLM Overview

TensorRT-LLM is a library for optimizing Large Language Model (LLM) inference.
It provides state-of-the-art optimziations, including custom attention kernels, inflight batching, paged KV caching, quantization (FP8, INT4 [AWQ](https://arxiv.org/abs/2306.00978), INT8 [SmoothQuant](https://arxiv.org/abs/2211.10438), ++) and much more, to perform inference efficiently on NVIDIA GPUs
It provides state-of-the-art optimizations, including custom attention kernels, inflight batching, paged KV caching, quantization (FP8, INT4 [AWQ](https://arxiv.org/abs/2306.00978), INT8 [SmoothQuant](https://arxiv.org/abs/2211.10438), ++) and much more, to perform inference efficiently on NVIDIA GPUs

TensorRT-LLM provides a Python API to build LLMs into optimized
[TensorRT](https://developer.nvidia.com/tensorrt) engines.
5 changes: 3 additions & 2 deletions benchmarks/README.md
@@ -7,5 +7,6 @@ There are currently three workflows to benchmark TensorRT-LLM:
- The recommended workflow that uses TensorRT-LLM C++ API and can take advantage of the latest features of TensorRT-LLM.
* [Python benchmarks](./python)
- The Python benchmarking scripts can only benchmark the Python runtime, which do not support the latest features, such as in-flight batching.
* [The Python benchmarking suite](./Suite.md)
- This benchmarking suite is a current work in progress and is prone to large changes.
* [The Python benchmarking suite](../docs/source/performance/perf-benchmarking.md)
- This benchmarking suite is native to TensorRT-LLM and provides a Python workflow for reproducing and testing TensorRT-LLM performance.
- _NOTE_: This benchmarking suite is a current work in progress and is prone to large changes.
99 changes: 86 additions & 13 deletions benchmarks/cpp/gptManagerBenchmark.cpp
@@ -145,6 +145,7 @@ struct BenchmarkParams
{
std::optional<SizeType32> maxTokensInPagedKvCache{std::nullopt};
std::optional<float> freeGpuMemoryFraction{std::nullopt};
std::optional<float> crossKvCacheFraction{std::nullopt};
bool enableTrtOverlap{false};
bool enableBlockReuse{false};
bool enableChunkedContext{false};
@@ -159,6 +160,8 @@ struct BenchmarkParams
std::optional<int> sinkTokenLength{std::nullopt};
bool multiBlockMode{true};
bool enableContextFMHAFP32Acc{false};
bool cudaGraphMode{false};
SizeType32 cudaGraphCacheSize{0};

// lora / peft params
std::optional<std::string> loraDir{std::nullopt};
@@ -470,7 +473,38 @@ class Recorder
mRequestBenchInfos[requestId].firstTokenSeen = true;
}

mRequestBenchInfos[requestId].outputLength += 1;
mRequestBenchInfos[requestId].decodingIter += 1;
}

void recordToken(uint64_t requestId, std::list<NamedTensor> const& responseTensors)
{
int32_t outputLength = 1;
for (auto& tensor : responseTensors)
{
if (tensor.name == inference_request::kSequenceLengthTensorName)
{
// Tensor of shape nBeams, and we only need the first one
outputLength = *(bufferCast<int32_t>(*(tensor.tensor)));
break;
}
}

mRequestBenchInfos[requestId].outputLength += outputLength;
this->recordToken(requestId);
}

void recordToken(uint64_t requestId, texec::Response const& response)
{
auto outputTokenIds = response.getResult().outputTokenIds;

int32_t outputLength = 1;
for (auto const& beam : outputTokenIds)
{
outputLength = std::max(static_cast<int32_t>(beam.size()), outputLength);
}

mRequestBenchInfos[requestId].outputLength += outputLength;
this->recordToken(requestId);
}
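    // Note: the single-argument recordToken above now counts decoding
    // iterations, while these two overloads recover the per-iteration token
    // count (from the sequence-length tensor in the GptManager path, from the
    // longest beam in the executor path) and fold it into outputLength before
    // delegating. That keeps reported output lengths correct when one
    // iteration emits more than one token, e.g. with beam search or
    // speculative decoding.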

void recordEnd(uint64_t requestId, std::list<NamedTensor> const& responseTensors, bool hasError)
@@ -500,7 +534,7 @@ class Recorder
}
else
{
this->recordToken(requestId);
this->recordToken(requestId, responseTensors);
}
}

@@ -532,7 +566,7 @@ class Recorder
}
else
{
this->recordToken(requestId);
this->recordToken(requestId, response);
}
}
}
@@ -818,11 +852,13 @@ class ExecutorServer
texec::SchedulerConfig schedulerConfig(capacitySchedulerPolicy);
texec::KvCacheConfig kvCacheConfig(benchmarkParams.enableBlockReuse, benchmarkParams.maxTokensInPagedKvCache,
benchmarkParams.maxAttentionWindowVec, benchmarkParams.sinkTokenLength,
benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks);
benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks,
benchmarkParams.crossKvCacheFraction);
texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8,
std::nullopt, benchmarkParams.loraHostCacheSize);
texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(
benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode,
benchmarkParams.cudaGraphCacheSize);
texec::ExecutorConfig executorConfig(
maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true);
executorConfig.setGpuWeightsPercent(benchmarkParams.gpuWeightsPercent);
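Both configs touched in this hunk gain their new knobs as trailing constructor arguments. A minimal sketch of wiring them up directly, assuming the argument order shown in this diff, with purely illustrative values:

// Sketch (illustrative values): an enc-dec model reserving half the KV-cache
// memory for cross attention, with CUDA-graph execution and a small graph cache.
texec::KvCacheConfig kvCacheConfig(
    /*enableBlockReuse=*/false, /*maxTokensInPagedKvCache=*/std::nullopt,
    /*maxAttentionWindowVec=*/std::nullopt, /*sinkTokenLength=*/std::nullopt,
    /*freeGpuMemoryFraction=*/0.9f, /*hostCacheSize=*/std::nullopt,
    /*onboardBlocks=*/true, /*crossKvCacheFraction=*/0.5f);
texec::ExtendedRuntimePerfKnobConfig perfKnobConfig(
    /*multiBlockMode=*/true, /*enableContextFMHAFP32Acc=*/false,
    /*cudaGraphMode=*/true, /*cudaGraphCacheSize=*/10);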
@@ -940,7 +976,7 @@ class ExecutorServer
{
if (!warmup && !response.hasError())
{
mRecorder->recordToken(reqId);
mRecorder->recordToken(reqId, response);
}
}
}
@@ -1228,7 +1264,7 @@ class GptServer
{
if (errMsg.empty())
{
mRecorder->recordToken(requestId);
mRecorder->recordToken(requestId, response_tensors);
}
}
}
@@ -1430,6 +1466,10 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
{
optionalParams.kvCacheConfig.freeGpuMemoryFraction = benchmarkParams.freeGpuMemoryFraction;
}
if (benchmarkParams.crossKvCacheFraction)
{
optionalParams.kvCacheConfig.crossKvCacheFraction = benchmarkParams.crossKvCacheFraction;
}
if (benchmarkParams.maxAttentionWindowVec)
{
optionalParams.kvCacheConfig.maxAttentionWindowVec = benchmarkParams.maxAttentionWindowVec;
@@ -1458,8 +1498,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
: texec::DecodingMode::Auto(),
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices);
optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(
benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode, benchmarkParams.cudaGraphCacheSize);

auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
auto const worldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
@@ -1874,6 +1914,8 @@ int main(int argc, char* argv[])
"random_seed", "integer random seed for exponential time delays.", cxxopts::value<int>()->default_value("420"));
options.add_options()(
"kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
options.add_options()(
"cross_kv_cache_fraction", "Cross K-V Cache Fraction (from 0.0 to 1.0).", cxxopts::value<float>());
options.add_options()("request_rate",
"request rate in reqs/sec. Skipping this arg or negative value will trigger offline/0-delay.",
cxxopts::value<float>());
@@ -1895,7 +1937,8 @@
options.add_options()("return_generation_logits", "Whether to return generation logits.",
cxxopts::value<bool>()->default_value("false"));

options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
options.add_options()("scheduler_policy",
"Choose scheduler policy between max_utilization/guaranteed_no_evict/static_batch.",
cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));

options.add_options()("first_batch_delay",
@@ -1946,6 +1989,12 @@
cxxopts::value<bool>()->default_value("true"));
options.add_options()(
"encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value<std::string>());
options.add_options()("cuda_graph_mode", "When enabled, inference is executed with cuda graph.",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("cuda_graph_cache_size",
"Specify how many cuda graphs are cached in the runtime. Larger cache gives better perf, but consumes more GPU "
"memory.",
cxxopts::value<SizeType32>()->default_value("0"));
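    // Note (illustrative invocation, not from this PR): these knobs are driven
    // from the CLI roughly as
    //   gptManagerBenchmark --engine_dir <dir> --cuda_graph_mode=true --cuda_graph_cache_size=10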

options.add_options()("enable_context_fmha_fp32_acc", "Enable FMHA runner FP32 accumulation",
cxxopts::value<bool>()->default_value("false"));
@@ -2040,6 +2089,20 @@
{
benchmarkParams.freeGpuMemoryFraction = result["kv_cache_free_gpu_mem_fraction"].as<float>();
}
// Argument: K-V Cache Cross Attention Fraction. Only applicable to enc-dec models.
if (result.count("encoder_engine_dir") && result.count("decoder_engine_dir"))
{
if (result.count("cross_kv_cache_fraction"))
{
benchmarkParams.crossKvCacheFraction = result["cross_kv_cache_fraction"].as<float>();
}
else
{
benchmarkParams.crossKvCacheFraction
= 0.5f; // default value if not set. but non enc-dec should not even have this param set
}
}
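    // Note: cross_kv_cache_fraction therefore only takes effect when both
    // encoder and decoder engine dirs are supplied; enc-dec runs that omit it
    // fall back to 0.5, presumably splitting KV-cache memory evenly between
    // self- and cross-attention, while decoder-only runs leave it unset.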

// Argument: Enable TRT overlap
benchmarkParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();

@@ -2131,6 +2194,12 @@
// Argument: enable_context_fmha_fp32_acc
benchmarkParams.enableContextFMHAFP32Acc = result["enable_context_fmha_fp32_acc"].as<bool>();

// Argument: cuda_graph_mode
benchmarkParams.cudaGraphMode = result["cuda_graph_mode"].as<bool>();

// Argument: cuda_graph_cache_size
benchmarkParams.cudaGraphCacheSize = result["cuda_graph_cache_size"].as<SizeType32>();

std::optional<TokenIdType> padId;
// Argument: Padding token id
if (result.count("pad_id"))
@@ -2168,6 +2237,10 @@
{
capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT;
}
else if (capacitySchedulerPolicyArg == "static_batch")
{
capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kSTATIC_BATCH;
}
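    // Note: static_batch is the new third policy, mapping to kSTATIC_BATCH.
    // As the name suggests, it appears to schedule requests as fixed batches
    // rather than admitting new requests in flight, which makes it a useful
    // baseline against the in-flight batching policies.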
else
{
TLLM_LOG_ERROR("Unexpected scheduler policy: " + capacitySchedulerPolicyArg);
@@ -2246,14 +2319,14 @@
{
texec::ModelType executorModelType;
std::optional<std::string> decoderEngineDir = std::nullopt, encoderEngineDir = std::nullopt;
if (result.count("encoder_engine_dir") && result.count("engine_dir"))
if (result.count("encoder_engine_dir") && result.count("decoder_engine_dir"))
{
TLLM_CHECK_WITH_INFO(api == "executor", "encoder-decoder only support executor api.");
TLLM_CHECK_WITH_INFO(
modelType == TrtGptModelType::InflightFusedBatching, "encoder-decoder only support inflight batching.");
executorModelType = texec::ModelType::kENCODER_DECODER;
decoderEngineDir = result["engine_dir"].as<std::string>();
encoderEngineDir = result["encoder_engine_dir"].as<std::string>();
decoderEngineDir = result["decoder_engine_dir"].as<std::string>();
}
else if (result.count("engine_dir"))
{
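This last hunk also fixes enc-dec engine selection: ModelType::kENCODER_DECODER is now chosen when both --encoder_engine_dir and --decoder_engine_dir are present, and the decoder engine path is read from --decoder_engine_dir rather than being conflated with the decoder-only --engine_dir.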
2 changes: 0 additions & 2 deletions benchmarks/cpp/utils/prepare_real_data.py
@@ -231,8 +231,6 @@ def dataset(root_args, **kwargs):
}, root_args.output)
else:
print_dataset(
task_ids,
input_ids,
output_lens,
tokenizer=None,
)
2 changes: 1 addition & 1 deletion benchmarks/python/gpt_benchmark.py
@@ -80,7 +80,7 @@ def __init__(self, args, batch_sizes, in_out_lens, gpu_weights_percents,

kv_cache_type = KVCacheType.CONTINUOUS
if hasattr(self, 'kv_cache_type'):
kv_cache_type = self.kv_cache_type
kv_cache_type = KVCacheType(self.kv_cache_type)
else:
if hasattr(self, 'paged_kv_cache'):
kv_cache_type = KVCacheType.PAGED if self.paged_kv_cache == True else KVCacheType.CONTINUOUS
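                # Note: wrapping the attribute in KVCacheType(...) above coerces a raw
                # config value into the enum, presumably so later comparisons against
                # KVCacheType members stay well-typed when kv_cache_type is not
                # already an enum member.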