diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 407c3b80e153f..1601b1a203b9a 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -252,7 +252,8 @@ def fetch_onnx_inputs_outputs_name( elif name == "attention_mask": attn_mask = onnx_inputs[idx] onnx_inputs[idx] = torch.cat( - (attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device)), dim=1 + (attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device, dtype=attn_mask.dtype)), + dim=1, ) elif name == "input_ids": input_ids = onnx_inputs[idx]