LLAVA is slow due to unnecessary output tokens #1118

Closed
1 of 4 tasks
Gutianpei opened this issue Feb 20, 2024 · 18 comments
@Gutianpei

System Info

  • H100

Who can help?

@kaiy

Information

  • [x] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

Use the official code to run LLaVA-1.5-13B

Expected behavior

Much higher throughput -- currently I get ~9.8 img/sec with batch size 48, whereas sglang gets 18.6 img/sec. TensorRT-LLM should be at least 2x faster than sglang or vllm.

Actual behavior

See above

Additional notes

I also benchmarked llama2 and its throughput is as expected. Looking into the code, I found that the output ids contain all the image tokens, whereas the official llava code only contains the text tokens. Is it possible that the LLM part is predicting the image tokens as well, causing the slowdown?

Gutianpei added the bug label on Feb 20, 2024
@white-wolf-tech

I have tested with vLLM, and TRT-LLM is much slower than vLLM at large batch sizes.

@Gutianpei
Author

> I have tested with vLLM, and TRT-LLM is much slower than vLLM at large batch sizes.

Did you test multi-modal models or just language models?

@white-wolf-tech

> > I have tested with vLLM, and TRT-LLM is much slower than vLLM at large batch sizes.
>
> Did you test multi-modal models or just language models?

just LLM

@springsprite


What version of tensorrt_llm are you using? You can check it with the following command:

python -c "import tensorrt_llm;print(tensorrt_llm.__version__)"

@Gutianpei
Author

> python -c "import tensorrt_llm;print(tensorrt_llm.__version__)"

It's [TensorRT-LLM] TensorRT-LLM version: 0.9.0.dev2024020600

@kaiyux
Member

kaiyux commented Feb 26, 2024

@Gutianpei Can you please share the steps to reproduce the numbers you get?

@symphonylyh
Collaborator

symphonylyh commented Feb 26, 2024

@Gutianpei @x-transformers @springsprite we have observed a similar behavior (i.e., running non-stop until reaching max_new_tokens) and have fixed that in the upcoming release. Can you please test again on top of the main branch after tomorrow's weekly release?
cross-ref #1123

@Gutianpei
Author

Gutianpei commented Feb 27, 2024

> @Gutianpei Can you please share the steps to reproduce the numbers you get?

Thanks for the help. This is the script I used; I just changed the logging in run.py to report img/sec throughput: `f'Generate latency: {num_iters * batch_size / profiler.elapsed_time_in_sec("Generate")} sec'`

export MODEL_NAME="llava-1.5-13b-hf"
export BATCHSIZE=48

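# Convert the HF checkpoint to a TRT-LLM checkpoint with int8 weight-only quantization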
python3 ../llama/convert_checkpoint.py \
    --model_dir tmp/hf_models/${MODEL_NAME} \
    --dtype float16 \
    --output_dir tmp/trt_models/${MODEL_NAME}/int8_weightonly/1-gpu  \
    --use_weight_only \
    --weight_only_precision int8

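# Build the TensorRT engine for the vision encoder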
 python3 build_visual_engine.py \
    --model_name ${MODEL_NAME} \
    --model_path tmp/hf_models/${MODEL_NAME} \
    --max_batch_size ${BATCHSIZE}

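# Build the LLM engine; max_multimodal_len = batch size x 576 visual tokens per image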
trtllm-build  \
    --checkpoint_dir tmp/trt_models/${MODEL_NAME}/int8_weightonly/1-gpu \
    --output_dir trt_engines/${MODEL_NAME}/int8_weightonly/1-gpu \
    --gpt_attention_plugin float16 \
    --gemm_plugin float16 \
    --max_batch_size ${BATCHSIZE} \
    --max_input_len 924 \
    --max_output_len 128 \
    --max_multimodal_len $((${BATCHSIZE} * 576)) \
    --paged_kv_cache enable \
    --multi_block_mode enable

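# Run end-to-end multimodal inference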
python3 run.py \
    --max_new_tokens 128 \
    --input_text "Question: Describe this image. Answer:" \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir visual_engines/${MODEL_NAME} \
    --llm_engine_dir trt_engines/${MODEL_NAME}/int8_weightonly/1-gpu  \
    --decoder_llm \
    --batch_size ${BATCHSIZE}
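
For reference, the logging change mentioned above boils down to the following throughput computation. This is only a minimal standalone sketch, not the actual run.py code: in run.py the elapsed time comes from profiler.elapsed_time_in_sec("Generate"), while here it is passed in directly so the snippet is self-contained.

def images_per_second(num_iters: int, batch_size: int, elapsed_sec: float) -> float:
    """Images processed per second over num_iters batches of batch_size images."""
    return num_iters * batch_size / elapsed_sec

# Example: one batch of 48 images generated in ~3.8 s gives ~12.6 img/sec.
print(f"{images_per_second(1, 48, 3.8):.2f} img/sec")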

@Gutianpei
Author

> @Gutianpei @x-transformers @springsprite we have observed a similar behavior (i.e., running non-stop until reaching max_new_tokens) and have fixed that in the upcoming release. Can you please test again on top of the main branch after tomorrow's weekly release? cross-ref #1123

Thanks for the help. It does not look like running non-stop until reaching max_new_tokens caused the issue, as I set max_new_tokens to 100 in my experiment and the throughput is still much slower than expected even when the model outputs 100 new tokens every time. I'll try the latest release once it's published, and please also take a look at my testing script above for reproducing the issue.

@kaiyux
Member

kaiyux commented Feb 28, 2024

@Gutianpei we pushed an update to the main branch, can you please try again on the latest main branch and see if the issue persists? Thank you.

@Gutianpei
Author

Gutianpei commented Feb 28, 2024

> @Gutianpei we pushed an update to the main branch, can you please try again on the latest main branch and see if the issue persists? Thank you.

Thanks for the fix! The throughput is clearly improved: I now get 12.38 img/sec versus 9.8 previously. Unfortunately, I think it's still much slower than it should be. I can get 11.2 img/sec from sglang and much higher throughput from vllm, and in theory an int8/fp8 TRT-LLM engine should be much faster. Can you take a look at the script I used above to see whether I got any parameters wrong? Also, do you think the very long output token length slows down the generation? Thank you!

symphonylyh self-assigned this on Feb 29, 2024
@symphonylyh
Collaborator

symphonylyh commented Feb 29, 2024

@Gutianpei can you disable --paged_kv_cache? It's not recommended in the Python runtime, as it's more relevant to inflight batching serving. Please give it a try and see if it helps.

> I can get 11.2 img/sec from sglang and much higher throughput from vllm

Regarding this, can you share some measurement stats for this observation? For example, is the comparison apples-to-apples in terms of batch size, precision, etc.? As for throughput, we're working on inflight batching serving for enc-dec and multimodal models, which will help with throughput.

@symphonylyh
Collaborator

Also, as we discussed under #1123 (comment), we compared with HF transformers and found that the correct speed should be ~5x faster for llava. Is this OK for your use case?

@Gutianpei
Author

> @Gutianpei can you disable --paged_kv_cache? It's not recommended in the Python runtime, as it's more relevant to inflight batching serving. Please give it a try and see if it helps.
>
> > I can get 11.2 img/sec from sglang and much higher throughput from vllm
>
> Regarding this, can you share some measurement stats for this observation? For example, is the comparison apples-to-apples in terms of batch size, precision, etc.? As for throughput, we're working on inflight batching serving for enc-dec and multimodal models, which will help with throughput.

Thanks for the reply. The fix you pushed this week definitely speeds up generation; it's about 1.5x faster on my end.

  1. Disabling paged_kv_cache does not make any difference.
  2. Yes, the comparison is conducted on the same single H100 GPU. vllm and sglang have their own batching strategies to automatically maximize utilization; I manually tested and set the max batch size for TRT-LLM for a fair comparison. Precision for TRT-LLM is int8, and for the others it is float16.
  3. I can reproduce the result in #1123 (Why is llava trt-llm not much faster than transformers?), but I think in theory TRT-LLM should still be much faster.

Here is the logging from the above script I used (paged_kv_cache disabled):

[02/29/2024-06:55:50] [TRT-LLM] [I] TensorRT vision encoder latency: 0.04693348407745361 sec
[02/29/2024-06:55:50] [TRT-LLM] [I] TensorRT-LLM LLM latency: 3.345199966430664 sec
[02/29/2024-06:55:50] [TRT-LLM] [I] Generate latency: 3.7880650758743286 sec

I set max_new_tokens to 128; the actual generation does not produce 128 output tokens every time, but let's just use 128 for simplicity. From the log, the LLM latency is 3.34 sec/batch with a batch size of 48 images, so that's 14.37 img/sec, or 1839 decoding tok/sec. This is 15% slower than my vllm-llava implementation with exactly the same settings, which confuses me a lot, since in theory TRT-LLM with int8 should be 2x faster than other serving frameworks. Does the throughput on the LLM part make sense to you?
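
As a quick sanity check on these numbers, here is the arithmetic from the paragraph above spelled out. It is not TRT-LLM code, and it assumes the full 128 tokens per image as stated.

# Sanity check of the throughput figures quoted above.
batch_size = 48         # images per batch
llm_latency = 3.34      # sec per batch, from the "TensorRT-LLM LLM latency" log line
max_new_tokens = 128    # assumed tokens generated per image, as stated above

img_per_sec = batch_size / llm_latency
tok_per_sec = batch_size * max_new_tokens / llm_latency
# Prints roughly 14.37 img/sec and 1839.5 tok/sec, in line with the figures quoted above.
print(f"{img_per_sec:.2f} img/sec, {tok_per_sec:.1f} tok/sec")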

However, I think there is a lot of room for improving TRT-LLM; I just want to make sure there is no bug and all my settings are correct. Excited to check out inflight batching in the future!

@symphonylyh
Collaborator

@Gutianpei, some explanation on the perf question you have: the first thing to clarify is latency vs. throughput. Latency is absolute kernel/model performance, where I believe TRT-LLM is SOTA. Throughput-wise, TRT-LLM is also SOTA for those models with inflight batching ENABLED. What's the difference between (1) your current run of TRT-LLM llava at a certain batch size and (2) a future run of TRT-LLM llava with inflight batching enabled, or a vllm run (I'm less familiar with sglang)? Serving optimization matters a lot for throughput. Think of a batch of input images that will generate output lengths of 10, 50, and 100: (1) has to wait until the entire batch finishes, while (2) can continuously process new batches/images as the 10- and 50-token requests finish earlier. So we should keep this in mind for throughput comparison.

Meanwhile, for absolute kernel/model perf, your message is well received and we're working on improvements. For example, we found that the transfer of visual embeddings from the visual engine to the LLM engine can be optimized, and we expect that to narrow the gap you observed.
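
To make the static-vs-inflight point concrete with the 10/50/100 output lengths mentioned above, here is a toy calculation. It is only a sketch under the idealized assumptions that every decode step costs the same and a freed slot is refilled instantly; it is not TRT-LLM scheduling code.

# Toy comparison of static batching vs. inflight (continuous) batching throughput.
output_lens = [10, 50, 100]   # output lengths of the requests sharing one batch
slots = len(output_lens)

# Static batching: all slots stay occupied until the longest request finishes,
# so `slots` requests complete every max(output_lens) decode steps.
static_req_per_step = slots / max(output_lens)

# Inflight batching: a slot is refilled the moment its request finishes, so in
# steady state each slot turns over once per *average* output length.
inflight_req_per_step = slots / (sum(output_lens) / slots)

print(f"inflight batching: ~{inflight_req_per_step / static_req_per_step:.1f}x the static throughput")  # ~1.9x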

@amukkara

amukkara commented Mar 1, 2024

@Gutianpei
For this model, INT8 does not give a speedup over FP16 in TRT-LLM. Output tokens are the same for INT8 and FP16 in my tests, so the issue is not extra output tokens. With the same output token count, TRT-LLM FP16 is 4x faster than transformers.

Can you try other quantisation methods like SmoothQuant or INT4 AWQ? examples/llama/README.md lists the extra arguments to enable these methods for LLaMA. They should apply to LLaVA as well.

@Gutianpei
Author

@symphonylyh
Thanks for the explanation! In my experiments I used exactly the same images as input, so the output lengths should be the same for all methods. In my use case, I care more about throughput than latency. For your reference, here are the full results of my experiments:
max_new_tokens=48, model=llava-1.5-13b

vllm throughput: 19 img/sec
sglang throughput: 18.2 img/sec
trt-llm throughput: 17.7 img/sec

I really appreciate your help and the trt-llm support. I will definitely try inflight batching once it's available. I'm closing this issue since my concerns are all addressed. Thank you so much!

@amukkara
Thanks for the suggestions! I tried both and didn't see a clear speedup, though; I think inflight batching might be more helpful for my use case.

@bleedingfight

> @Gutianpei For this model, INT8 does not give a speedup over FP16 in TRT-LLM. Output tokens are the same for INT8 and FP16 in my tests, so the issue is not extra output tokens. With the same output token count, TRT-LLM FP16 is 4x faster than transformers.
>
> Can you try other quantisation methods like SmoothQuant or INT4 AWQ? examples/llama/README.md lists the extra arguments to enable these methods for LLaMA. They should apply to LLaVA as well.

May I ask how you verify that the inference results from TRT are correct? I can generate output using the method provided by the demo, but the output is different from transformers. How can I ensure that the TRT output is correct?
