From e4b421a90dc4e173815aefd1bee98d1e5ac2f778 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 17 Oct 2023 21:19:57 -0700 Subject: [PATCH] Fix MHA shape inference (#18009) The previous shape inference never had the chance to infer the past_key and past_value outputs because we were returning early. --- onnxruntime/core/graph/contrib_ops/bert_defs.cc | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index d1269b9c2c23c..3a75b29ffe3c7 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -171,10 +171,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2] * query_dims[4]; updateOutputShape(ctx, 0, output_shape); - return; - } - - if (hasInputShape(ctx, 2)) { + } else if (hasInputShape(ctx, 2)) { auto& value_shape = getInputShape(ctx, 2); auto& value_dims = value_shape.dim(); if (value_dims.size() != 3 && value_dims.size() != 4) { @@ -192,10 +189,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c ? (dmmha_packing ? value_dims[2] / 3 : value_dims[2]) : value_dims[1] * value_dims[3]; updateOutputShape(ctx, 0, output_shape); - return; - } - - if (hasInputShape(ctx, 1)) { + } else if (hasInputShape(ctx, 1)) { auto& key_shape = getInputShape(ctx, 1); if (key_shape.dim().size() == 5) { // packed KV ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx); @@ -217,7 +211,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); } else { if (sequence_length > 0 && past_dims[2].has_dim_value()) { - int64_t total_sequence_length = sequence_length + past_shape.dim(3).dim_value(); + int64_t total_sequence_length = sequence_length + past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) {