System Info
CPU: x86_64
Memory: 128 GB
GPU: H100 80 GB
Docker: tritonserver:24.12-trtllm-python-py3
CUDA: 12.6
Driver: 535.216.01
TensorRT: 10.7.0
TensorRT-LLM: v0.16.0
Who can help?
@kaiyux @byshiue
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Steps to reproduce:
Build and deploy Mllama using the scripts from multimodal.md, except:
- max_encoder_input_len set to 6404 for the 11B Vision-Instruct model, as indicated by the TensorRT-LLM guide
- cross_kv_cache_fraction set to 0.5 in config.pbtxt (Triton won't start otherwise); a sketch of that entry is shown after the command below
The Triton command is:
tritonserver --model-repository=multimodal_ifb --model-control-mode=explicit --log-verbose=3 --load-model=tensorrt_llm --load-model=multimodal_encoders --load-model=ensemble --load-model=tensorrt_llm_bls --cuda-memory-pool-byte-size=0:300000000
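For reference, this is roughly what that entry looks like in the tensorrt_llm model's config.pbtxt, assuming the standard Triton "parameters" syntax used by the tensorrtllm backend; only the 0.5 value is taken from this report:

parameters: {
  key: "cross_kv_cache_fraction"
  value: {
    string_value: "0.5"
  }
}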
Expected behavior
When tested with:
python3 tensorrt_llm/examples/multimodal/run.py --visual_engine_dir /tmp/mllama/trt_engines/encoder/ \
--visual_engine_name visual_encoder.engine \
--llm_engine_dir /tmp/mllama/trt_engines/decoder/ \
--hf_model_dir Llama-3.2-11B-Vision-Instruct/ \
--image_path https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg \
--input_text "<|image|><|begin_of_text|>If I had to write a haiku for this one" \
--max_new_tokens 50 \
--batch_size 1
output:
", it would be:.\\nA rabbit in a coat.\\nA charming and dapper fellow.\\nHe's a stylish chap indeed. <OCR/> ርርርርርር
Works as expected.
actual behavior
When run with:
python3 tools/multimodal/client.py --model_type mllama --text "<|image|><|begin_of_text|>If I had to write a haiku for this one" --image https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg --top-p 0.7 --temperature 0.9 --top-k 40 --request-output-len 20
the result is:
[beam 0]:
<|image|><|begin_of_text|>If I had to write a haiku for this one, it would be:
“Golden sunsets fade
Gone, yet memories remain
Summer's
With a different image or different runtime parameters the output similarly ignores the image content (different image -> same output).
additional notes
I double-checked the tokenization output, verified that the image inputs are sent correctly (image_bytes), and checked that encoder_input_features and cross_attention_masks (the multimodal_encoder outputs) are in the same ballpark as those produced by tensorrt_llm/examples/multimodal/run.py, though by no means equal:
encoder_input_features in Triton:
encoder_input_features in the tensorrt_llm runner:
If that difference is not ok, should I look into that?
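For what it's worth, a minimal sketch of how that comparison could be quantified offline; the .npy dump paths are hypothetical and not part of the original setup:

import numpy as np

# Hypothetical dumps of the encoder output from the Triton pipeline and from
# examples/multimodal/run.py (the file names are placeholders).
triton_feats = np.load("encoder_input_features_triton.npy").astype(np.float32)
runner_feats = np.load("encoder_input_features_runner.npy").astype(np.float32)

print("shapes:", triton_feats.shape, runner_feats.shape)
print("max abs diff:", np.max(np.abs(triton_feats - runner_feats)))

# Cosine similarity of the flattened tensors is a scale-free "same ballpark" check.
a, b = triton_feats.ravel(), runner_feats.ravel()
print("cosine similarity:", float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))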
Aside from that, the BLS setup is also not working. The LLM itself seems to be working fine and gives correct responses.
The text was updated successfully, but these errors were encountered:

There seems to be a mistake in the multimodal encoder, specifically the line that sets skip_cross_attn_blocks unconditionally to torch.ones. Compare this with the multimodal runner in TensorRT-LLM, which sets it to ones only when there is no image data. After changing it to torch.zeros, the pipeline seems to work in my case.
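A minimal sketch of the conditional behaviour described above; the helper name, tensor shape, and device are assumptions for illustration, not the actual multimodal_encoders model code:

import torch

def build_skip_cross_attn_blocks(batch_images, device="cpu"):
    # Skip the cross-attention blocks only when a request carries no image
    # data, mirroring the TensorRT-LLM multimodal runner, instead of passing
    # torch.ones unconditionally (which makes the decoder ignore the image).
    has_image = batch_images is not None and len(batch_images) > 0
    return torch.tensor([not has_image], dtype=torch.bool, device=device)

With an image present this evaluates to a tensor of False (the torch.zeros behaviour that made the pipeline work here); for a text-only request it falls back to ones, as the runner does.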