Mllama ignores input image when deployed in triton #692

Open
mutkach opened this issue Feb 5, 2025 · 2 comments

Labels
bug Something isn't working

mutkach commented Feb 5, 2025

System Info

CPU: x86_64
Memory: 128G
GPU: H100 80G
Docker: tritonserver:24.12-trtllm-python-py3
CUDA: 12.6
Driver: 535.216.01
TensorRT: 10.7.0
TensorRT-LLM: v0.16.0

Who can help?

@kaiyux @byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce:
Using the scripts for the Mllama build and deployment from multimodal.md, except:

  • use Llama-3.2-11B-Vision-Instruct instead of Llama-3.2-11B-Vision
  • set max_encoder_input_len to 6404 for the Vision-Instruct model, as indicated by the TensorRT-LLM guide
  • set batch size to 1 for testing purposes
  • check out the v0.16.0 tag of TensorRT-LLM (there are discrepancies when converting the checkpoint otherwise)
  • fill in cross_kv_cache_fraction = 0.5 in config.pbtxt (Triton won't start otherwise); see the snippet after the launch command below
  • start Triton manually with the command below
  • load the ensemble model (the e2e setup would not work otherwise)

The Triton launch command is:

tritonserver --model-repository=multimodal_ifb --model-control-mode=explicit --log-verbose=3 --load-model=tensorrt_llm --load-model=multimodal_encoders --load-model=ensemble --load-model=tensorrt_llm_bls --cuda-memory-pool-byte-size=0:300000000
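
For reference, this is a minimal sketch of the cross_kv_cache_fraction entry I added to the tensorrt_llm model's config.pbtxt (a standard Triton string parameter; the exact position in the file may differ):

parameters: {
  key: "cross_kv_cache_fraction"
  value: {
    string_value: "0.5"
  }
}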

Expected behavior

When tested with ...

python3 tensorrt_llm/examples/multimodal/run.py --visual_engine_dir /tmp/mllama/trt_engines/encoder/ \
                                   --visual_engine_name visual_encoder.engine \
                                   --llm_engine_dir /tmp/mllama/trt_engines/decoder/ \
                                   --hf_model_dir Llama-3.2-11B-Vision-Instruct/ \
                                   --image_path https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg \
                                   --input_text "<|image|><|begin_of_text|>If I had to write a haiku for this one" \
                                   --max_new_tokens 50 \
                                   --batch_size 1

output:

", it would be:.\\nA rabbit in a coat.\\nA charming and dapper fellow.\\nHe's a stylish chap indeed. <OCR/> ርርርርርር

Works as expected.

Actual behavior

When run with:

python3 tools/multimodal/client.py --model_type mllama --text "<|image|><|begin_of_text|>If I had to write a haiku for this one" --image https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg --top-p 0.7 --temperature 0.9 --top-k 40 --request-output-len 20

the result is:

[beam 0 ]:
<|image|><|begin_of_text|>If I had to write a haiku for this one, it would be:
“Golden sunsets fade
Gone, yet memories remain
Summer's

When shown a different image or given different runtime parameters, the pipeline similarly ignores the image content (different image -> same output).

Additional notes

I double-checked the tokenization output, verified that the image inputs are sent correctly (image_bytes), and confirmed that encoder_input_features and cross_attention_masks (the multimodal_encoder outputs) are in the same ballpark as (though by no means equal to) the ones produced by tensorrt_llm/examples/multimodal/run.py.

encoder_input_features in triton:

tensor([[[  8.1875,  12.3750,  -4.5938,  ..., -12.1875,  -4.4062,   5.1250],
         [ -1.0625,  13.5000,   7.4375,  ...,  -2.3125,  -3.0625, -13.2500],
         [-12.5000,   7.0625,   8.5625,  ...,   3.1875,  -0.1836,  -8.4375],
         ...,
         [ -3.8906,  -2.5625,  -6.0938,  ...,  -2.2812,  -8.1875,  -3.0312],
         [  2.7031,   7.0938,  -7.6875,  ...,  -8.5625,  -4.4062, -22.2500],
         [  4.2500,   1.2734,   1.5156,  ...,  -1.8359,  -2.5312,   1.5625]]],
       device='cuda:0', dtype=torch.bfloat16)

encoder_input_features in the tensorrt_llm runner:

tensor([[  8.1250,  12.3750,  -4.6875,  ..., -12.1875,  -4.3438,   5.2500],
        [ -1.1328,  13.3125,   7.4062,  ...,  -2.6250,  -2.9531, -13.1875],
        [-12.3750,   6.9688,   8.5625,  ...,   2.9688,  -0.2139,  -8.6250],
        ...,
        [ -5.4375,  -2.8125,  -6.9375,  ...,  -3.4375,  -7.8125,  -3.7969],
        [  1.1641,   6.9062,  -3.5000,  ...,  -3.0625,  -2.9688, -27.2500],
        [  4.6562,   1.3906,   1.6953,  ...,  -1.6484,  -2.9375,   1.3281]],
       device='cuda:0', dtype=torch.bfloat16)
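
To put a number on how far apart these two dumps are, I could use a rough comparison along these lines (feats_triton and feats_runner are hypothetical names for the two tensors above, captured for the same prompt and image):

import torch

def compare_features(feats_triton: torch.Tensor, feats_runner: torch.Tensor) -> None:
    # Flatten both dumps and compare them elementwise; bfloat16 is upcast to
    # float32 so the statistics are not dominated by rounding.
    a = feats_triton.reshape(-1).float()
    b = feats_runner.reshape(-1).float()
    max_abs_diff = (a - b).abs().max().item()
    cosine = torch.nn.functional.cosine_similarity(a, b, dim=0).item()
    print(f"max |a - b| = {max_abs_diff:.4f}, cosine similarity = {cosine:.6f}")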

If that difference is not OK, should I look into it?
Aside from that, the BLS setup is also not working. The LLM itself seems to be working fine and gives correct responses.


mutkach commented Feb 5, 2025

By the way, thanks for the great work! I would appreciate any helpful directions regarding this issue.


mutkach commented Feb 6, 2025

There seems to be a mistake in the multimodal encoder model: the line that sets skip_cross_attn_blocks unconditionally to torch.ones. Compare with the multimodal runner in TRT-LLM, which sets it to ones only if there is no image data. After changing it to torch.zeros, the pipeline seems to be working in my case.
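
For clarity, here is a minimal sketch of the conditional I have in mind (variable names and tensor shape are illustrative, not the exact code in the Triton multimodal encoder model):

import torch

def make_skip_cross_attn_blocks(batch_size: int, has_image: bool, device) -> torch.Tensor:
    # Skip the cross-attention blocks only when the request carries no image,
    # as the TRT-LLM multimodal runner does; returning ones unconditionally
    # makes the decoder ignore the encoder output entirely.
    return torch.full((batch_size, 1), not has_image, dtype=torch.bool, device=device)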
