Skip to content

Commit

Permalink
Fix MHA shape inference (#18009)
Browse files Browse the repository at this point in the history
The previous shape inference never had the chance to infer the outputs derived from the
past_key and past_value inputs, because each branch returned early after setting the shape of output 0.
  • Loading branch information
PatriceVignola authored and jchen351 committed Oct 18, 2023
1 parent b7d8cdc commit e4b421a
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
*output_shape.add_dim() = query_dims[1];
*output_shape.add_dim() = query_dims[2] * query_dims[4];
updateOutputShape(ctx, 0, output_shape);
return;
}

if (hasInputShape(ctx, 2)) {
} else if (hasInputShape(ctx, 2)) {
auto& value_shape = getInputShape(ctx, 2);
auto& value_dims = value_shape.dim();
if (value_dims.size() != 3 && value_dims.size() != 4) {
Expand All @@ -192,10 +189,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
? (dmmha_packing ? value_dims[2] / 3 : value_dims[2])
: value_dims[1] * value_dims[3];
updateOutputShape(ctx, 0, output_shape);
return;
}

if (hasInputShape(ctx, 1)) {
} else if (hasInputShape(ctx, 1)) {
auto& key_shape = getInputShape(ctx, 1);
if (key_shape.dim().size() == 5) { // packed KV
ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx);
Expand All @@ -217,7 +211,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
propagateElemTypeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
} else {
if (sequence_length > 0 && past_dims[2].has_dim_value()) {
int64_t total_sequence_length = sequence_length + past_shape.dim(3).dim_value();
int64_t total_sequence_length = sequence_length + past_dims[2].dim_value();

ONNX_NAMESPACE::TensorShapeProto present_shape;
for (auto& dim : past_dims) {
Expand Down

0 comments on commit e4b421a

Please sign in to comment.