From 5ad87af45b5bb2409f33dd031be2865b91af15a2 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 29 Nov 2024 13:11:19 +0000 Subject: [PATCH 1/2] Only cast `cu_seqlens` when tracing --- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 5455fcdf3c51..2c93c8d0d186 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1025,7 +1025,12 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. rotary_pos_emb = self.rot_pos_emb(grid_thw) cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, dtype=grid_thw.dtype + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32 ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) From fb8eff7e2209d1050551e7938d2658cb25867c75 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 29 Nov 2024 13:15:49 +0000 Subject: [PATCH 2/2] Formatting --- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 2c93c8d0d186..f7648f4a53d1 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1030,7 +1030,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. # - FA2 requires that cu_seqlens_q must have dtype int32 # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32 + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)