From 6f3ade55ecd2b989612c1680b87ad9b6c557d1c9 Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Wed, 9 Feb 2022 14:23:17 -0800
Subject: [PATCH] Move QAttention/QEmbedLayerNormalization op defs to
 quantization_defs.cc (#10507)

---
 .../core/graph/contrib_ops/bert_defs.cc       | 286 +-----------------
 onnxruntime/core/graph/contrib_ops/ms_opset.h |   4 +-
 .../graph/contrib_ops/quantization_defs.cc    | 121 +++++++-
 .../contrib_ops/shape_inference_functions.cc  | 178 +++++++++++
 .../contrib_ops/shape_inference_functions.h   |  17 ++
 5 files changed, 317 insertions(+), 289 deletions(-)
 create mode 100644 onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc
 create mode 100644 onnxruntime/core/graph/contrib_ops/shape_inference_functions.h

diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index c9f12a7a1edb5..3da9e4fc8d3d5 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -5,177 +5,12 @@
 #include "core/graph/contrib_ops/contrib_defs.h"
 #include "core/graph/contrib_ops/quantization_defs.h"
 #include "core/graph/contrib_ops/onnx_function_util.h"
+#include "core/graph/contrib_ops/shape_inference_functions.h"
 
 using namespace ::ONNX_NAMESPACE;
 namespace onnxruntime {
 namespace contrib {
-void embedLayerNormalizationShapeInference(InferenceContext& ctx) {
-  propagateElemTypeFromInputToOutput(ctx, 2, 0);
-  propagateElemTypeFromInputToOutput(ctx, 0, 1);
-  if (!hasInputShape(ctx, 0)) {
-    // TODO(kreeger): In this case update the output to (?, ?, hidden_size).
-    return;
-  }
-
-  auto& input_ids_shape = getInputShape(ctx, 0);
-  auto& input_ids_dims = input_ids_shape.dim();
-
-  // Note that both batch size and sequence length could be symbolic.
-  // So we only check dimension size here.
-  if (input_ids_dims.size() != 2) {
-    fail_shape_inference("input_ids shall be 2 dimensions");
-  }
-
-  bool has_segment = hasInputShape(ctx, 1);
-  if (has_segment) {
-    // Ensure that segment_ids has the same shape.
-    auto& segment_ids_shape = getInputShape(ctx, 1);
-    auto& segment_ids_dims = segment_ids_shape.dim();
-    if (segment_ids_dims.size() != 2) {
-      fail_shape_inference("segment_ids input shall be 2 dimensions");
-    }
-  }
-
-  // get hidden_size from the last dimension of embedding
-  auto& word_embedding_shape = getInputShape(ctx, 2);
-  auto& word_embedding_dims = word_embedding_shape.dim();
-  if (word_embedding_dims.size() != 2 ||
-      !word_embedding_dims[1].has_dim_value() ||
-      word_embedding_shape.dim(1).dim_value() <= 0) {
-    fail_shape_inference("word_embedding should have 2 dimensions and dimension size is known.");
-  }
-  int64_t hidden_size = word_embedding_shape.dim(1).dim_value();
-
-  // Ensure that all embeddings + the gamma/beta tensors have the same hidden_size:
-  auto& position_embedding_shape = getInputShape(ctx, 3);
-  auto& position_embedding_dims = position_embedding_shape.dim();
-  if (position_embedding_dims.size() != 2 ||
-      !position_embedding_dims[1].has_dim_value() ||
-      position_embedding_shape.dim(1).dim_value() != hidden_size) {
-    fail_shape_inference(
-        "position_embedding should have 2 dimensions, dimension size known, "
-        "and same hidden size as word_embedding.");
-  }
-
-  if (has_segment) {
-    auto& segment_embedding_shape = getInputShape(ctx, 4);
-    auto& segment_embedding_dims = segment_embedding_shape.dim();
-    if (segment_embedding_dims.size() != 2 ||
-        !segment_embedding_dims[1].has_dim_value() ||
-        segment_embedding_shape.dim(1).dim_value() != hidden_size) {
-      fail_shape_inference(
-          "segment_embedding should have 2 dimensions, dimension size known, "
-          "and same hidden size as word_embedding.");
-    }
-  }
-
-  auto& gamma_shape = getInputShape(ctx, 5);
-  auto& gamma_dims = gamma_shape.dim();
-  if (gamma_dims.size() != 1 ||
-      !gamma_dims[0].has_dim_value() ||
-      gamma_shape.dim(0).dim_value() != hidden_size) {
-    fail_shape_inference(
-        "gamma should have 2 dimension, dimension size known, "
-        "and same hidden size as word_embedding.");
-  }
-
-  auto& beta_shape = getInputShape(ctx, 6);
-  auto& beta_dims = gamma_shape.dim();
-  if (beta_dims.size() != 1 ||
-      !beta_dims[0].has_dim_value() ||
-      beta_shape.dim(0).dim_value() != hidden_size) {
-    fail_shape_inference(
-        "beta should have 1 dimension, dimension size known, "
-        "and same hidden size as word_embedding.");
-  }
-
-  // input shape is (batch_size, sequence_length), output shape is (batch_size, sequence_length, hidden_size)
-  ONNX_NAMESPACE::TensorShapeProto output_shape;
-  *output_shape.add_dim() = input_ids_dims[0];
-  *output_shape.add_dim() = input_ids_dims[1];
-
-  output_shape.add_dim();
-  output_shape.mutable_dim(2)->set_dim_value(hidden_size);
-
-  updateOutputShape(ctx, 0, output_shape);
-
-  // mask_index shape is (batch_size)
-  ONNX_NAMESPACE::TensorShapeProto mask_index_shape;
-  *mask_index_shape.add_dim() = input_ids_dims[0];
-  updateOutputShape(ctx, 1, mask_index_shape);
-
-  if (ctx.getNumOutputs() > 2) {
-    updateOutputShape(ctx, 2, output_shape);
-    propagateElemTypeFromInputToOutput(ctx, 0, 2);
-  }
-}
-
-void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) {
-  // Type inference
-  ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0);
-  if (ctx.getNumOutputs() > 1) {
-    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 1);
-  }
-
-  // Shape inference
-  if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) {
-    auto& input_shape = getInputShape(ctx, 0);
-    auto& input_dims = input_shape.dim();
-    if (input_dims.size() != 3) {
-      fail_shape_inference("Inputs 0 shall be 3 dimensions");
-    }
be 3 dimensions"); - } - - auto& bias_shape = getInputShape(ctx, 2); - auto& bias_dims = bias_shape.dim(); - if (bias_dims.size() != 1) { - fail_shape_inference("Invalid bias shape"); - } - - std::vector qkv_hidden_sizes; - getRepeatedAttribute(ctx, "qkv_hidden_sizes", qkv_hidden_sizes); - - int64_t output_hidden_size; - if (qkv_hidden_sizes.size() != 0) { - if (qkv_hidden_sizes.size() != 3) { - fail_shape_inference("qkv_hidden_sizes should have 3 elements") - } - output_hidden_size = qkv_hidden_sizes[2]; - } else { - output_hidden_size = bias_shape.dim(0).dim_value() / 3; - } - - ONNX_NAMESPACE::TensorShapeProto output_shape; - for (auto& dim : input_dims) { - *output_shape.add_dim() = dim; - } - - output_shape.mutable_dim(2)->set_dim_value(output_hidden_size); - updateOutputShape(ctx, 0, output_shape); - - // TODO does the extra output need any changes? - if (ctx.getNumOutputs() > 1) { - if (hasInputShape(ctx, past_input_index)) { - auto& past_shape = getInputShape(ctx, past_input_index); - auto& past_dims = past_shape.dim(); - if (past_dims.size() != 5) { - fail_shape_inference("Inputs 4 shall be 5 dimensions"); - } - - if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) { - auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value(); - - ONNX_NAMESPACE::TensorShapeProto present_shape; - for (auto& dim : past_dims) { - *present_shape.add_dim() = dim; - } - present_shape.mutable_dim(3)->set_dim_value(all_sequence_length); - - updateOutputShape(ctx, 1, present_shape); - } - } - } - } -} void DecoderAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference @@ -255,86 +90,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Attention, 1, AttentionTypeAndShapeInference(ctx, past_input_index); })); -ONNX_MS_OPERATOR_SET_SCHEMA(QAttention, 1, - OpSchema() - .SetDoc("Quantization of Multi-Head Self Attention.") - .Attr("num_heads", "Number of attention heads", AttributeProto::INT) - .Attr("unidirectional", - "Whether every token can only attend to previous tokens. Default value is 0.", - AttributeProto::INT, - static_cast(0)) - .Input( - 0, - "input", - "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", - "T1") - .Input( - 1, - "weight", - "2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size", - "T2") - .Input( - 2, - "bias", - "1D input tensor with shape (3 * hidden_size)", - "T3") - .Input( - 3, - "input_scale", - "scale of quantized input tensor. It's a scalar, which means a per-tensor/layer quantization.", - "T3") - .Input( - 4, - "weight_scale", - "scale of weight scale. It's a scalar or a 1D tensor, which means a per-tensor/per-column quantization." - "Its size should be 3 * hidden_size if it is per-column quantization", - "T3") - .Input( - 5, - "mask_index", - "Attention mask index with shape (batch_size)", - "T4", - OpSchema::Optional) - .Input( - 6, - "input_zero_point", - "zero point of quantized input tensor.It's a scalar, which means a per-tensor/layer quantization.", - "T1", - OpSchema::Optional) - .Input( - 7, - "weight_zero_point", - "zero point of quantized weight tensor. It's a scalar or a 1D tensor, which means a per-tensor/per-column quantization." 
- "Its size should be 3 * hidden_size if it is per-column quantization", - "T2", - OpSchema::Optional) - .Input( - 8, - "past", - "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", - "T3", - OpSchema::Optional) - .Output( - 0, - "output", - "3D output tensor with shape (batch_size, sequence_length, hidden_size)", - "T3") - .Output( - 1, - "present", - "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", - "T3", - OpSchema::Optional) - .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.") - .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.") - .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") - .TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - constexpr int past_input_index = 8; - - AttentionTypeAndShapeInference(ctx, past_input_index); - })); - constexpr const char* Longformer_Attention_doc = R"DOC( Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token attends to its W previous tokens and W succeding tokens with W being the window length. A selected few tokens @@ -420,44 +175,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(EmbedLayerNormalization, 1, .Output(2, "embedding_sum", "sum of word_embedding and position_embedding without layer normalization", "T", OpSchema::Optional) .TypeConstraint("T1", {"tensor(int32)"}, "Constrain input and output integer tensors types") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output float tensors types.") - .TypeAndShapeInferenceFunction(embedLayerNormalizationShapeInference)); - -constexpr const char* QEmbedLayerNormalization_ver1_doc = R"DOC( -QEmbedLayerNormalization is the quantized fusion of embedding layer in BERT model, with optional mask processing. -The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, -and segment_emedding; the embeddings are added then applied layer normalization using gamma and beta tensors. The input_ids -and segment_ids remain int32. All embeddings, gamma, and beta tensors are converted to int8/uint8. The last input mask is optional. 
-If mask is provided, mask index (that is position of first 0 in mask, or number of words will be calculated.)DOC";
-
-ONNX_MS_OPERATOR_SET_SCHEMA(QEmbedLayerNormalization, 1,
-                            OpSchema()
-                                .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
-                                .SetDoc(QEmbedLayerNormalization_ver1_doc)
-                                .Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, kDefaultEmbedLayerNormEpsilon)
-                                .Input(0, "input_ids", "2D words IDs with shape (batch_size, sequence_length)", "T1")
-                                .Input(1, "segment_ids", "2D segment IDs with shape (batch_size, sequence_length)", "T1", OpSchema::Optional)
-                                .Input(2, "word_embedding_quant", "2D with shape (,hidden_size)", "T2")
-                                .Input(3, "position_embedding_quant", "2D with shape (, hidden_size)", "T2")
-                                .Input(4, "segment_embedding", "2D with shape (, hidden_size)", "T2", OpSchema::Optional)
-                                .Input(5, "gamma_quant", "1D gamma tensor for layer normalization with shape (hidden_size)", "T2")
-                                .Input(6, "beta_quant", "1D beta tensor for layer normalization with shape (hidden_size)", "T2")
-                                .Input(7, "mask", "Mask", "T1", OpSchema::Optional)
-                                .Input(8, "word_embedding_scale", "Scale for word embeddings", "T")
-                                .Input(9, "position_embedding_scale", "Scale for position embeddings", "T")
-                                .Input(10, "segment_embedding_scale", "Scale for segment embeddings", "T", OpSchema::Optional)
-                                .Input(11, "gamma_scale", "Scale for 1D gamma tensor", "T")
-                                .Input(12, "beta_scale", "Scale for 1D beta tensor", "T")
-                                .Input(13, "word_embedding_zero_point", "Zero point for word embeddings", "T2")
-                                .Input(14, "position_embedding_zero_point", "Zero point for position embeddings", "T2")
-                                .Input(15, "segment_embedding_zero_point", "Zero Point for segment embeddings", "T2", OpSchema::Optional)
-                                .Input(16, "gamma_zero_point", "Zero Point for 1D gamma tensor", "T2")
-                                .Input(17, "beta_zero_point", "Zero Point for 1D beta tensor", "T2")
-                                .Output(0, "layernorm_out", "LayerNorm Output", "T")
-                                .Output(1, "mask_index_out", "Mask Index Output", "T1")
-                                .TypeConstraint("T1", {"tensor(int32)"}, "Constrain mask index to integer types")
-                                .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.")
-                                .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float32 tensors.")
-                                .TypeAndShapeInferenceFunction(embedLayerNormalizationShapeInference));
+                                .TypeAndShapeInferenceFunction(EmbedLayerNormalizationShapeInference));
 
 constexpr const char* FastGelu_ver1_doc = R"DOC(
 GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)) with an optional input of bias that will be added to X before GELU.)DOC";
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 0bdeef32d22ae..1b45809d19184 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -19,6 +19,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicQuantizeLSTM);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicQuantizeMatMul);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulIntegerToFloat);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MulInteger);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QGemm);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAdd);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConcat);
@@ -60,8 +62,6 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
-class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention);
-class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 9144dddb75908..74dbcd86d4107 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -4,8 +4,7 @@
 #include "core/graph/contrib_ops/quantization_defs.h"
 #include "core/graph/constants.h"
 #include "core/graph/contrib_ops/contrib_defs.h"
-
-
+#include "core/graph/contrib_ops/shape_inference_functions.h"
 
 namespace ONNX_NAMESPACE {
 void RNNShapeInference(InferenceContext& ctx);
@@ -960,5 +959,121 @@ where the function `Sigmoid(x) = 1 / (1 + exp(-x))` )DOC";
                          second_input_shape.dim(transB ? 0 : 1)});
                      }
                    }));
-}  // namespace contrib
+
+ONNX_MS_OPERATOR_SET_SCHEMA(QAttention, 1,
+                            OpSchema()
+                                .SetDoc("Quantization of Multi-Head Self Attention.")
+                                .Attr("num_heads", "Number of attention heads", AttributeProto::INT)
+                                .Attr("unidirectional",
+                                      "Whether every token can only attend to previous tokens. Default value is 0.",
+                                      AttributeProto::INT,
+                                      static_cast<int64_t>(0))
+                                .Input(
+                                    0,
+                                    "input",
+                                    "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)",
+                                    "T1")
+                                .Input(
+                                    1,
+                                    "weight",
+                                    "2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size",
+                                    "T2")
+                                .Input(
+                                    2,
+                                    "bias",
+                                    "1D input tensor with shape (3 * hidden_size)",
+                                    "T3")
+                                .Input(
+                                    3,
+                                    "input_scale",
+                                    "scale of quantized input tensor. It's a scalar, which means a per-tensor/layer quantization.",
+                                    "T3")
+                                .Input(
+                                    4,
+                                    "weight_scale",
+                                    "scale of quantized weight tensor. It's a scalar or a 1D tensor, which means a per-tensor/per-column quantization. "
+                                    "Its size should be 3 * hidden_size if it is per-column quantization",
+                                    "T3")
+                                .Input(
+                                    5,
+                                    "mask_index",
+                                    "Attention mask index with shape (batch_size)",
+                                    "T4",
+                                    OpSchema::Optional)
+                                .Input(
+                                    6,
+                                    "input_zero_point",
+                                    "zero point of quantized input tensor. It's a scalar, which means a per-tensor/layer quantization.",
+                                    "T1",
+                                    OpSchema::Optional)
+                                .Input(
+                                    7,
+                                    "weight_zero_point",
+                                    "zero point of quantized weight tensor. It's a scalar or a 1D tensor, which means a per-tensor/per-column quantization. "
+ "Its size should be 3 * hidden_size if it is per-column quantization", + "T2", + OpSchema::Optional) + .Input( + 8, + "past", + "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", + "T3", + OpSchema::Optional) + .Output( + 0, + "output", + "3D output tensor with shape (batch_size, sequence_length, hidden_size)", + "T3") + .Output( + 1, + "present", + "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", + "T3", + OpSchema::Optional) + .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.") + .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.") + .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + constexpr int past_input_index = 8; + + AttentionTypeAndShapeInference(ctx, past_input_index); + })); + + constexpr const char* QEmbedLayerNormalization_ver1_doc = R"DOC( +QEmbedLayerNormalization is the quantized fusion of embedding layer in BERT model, with optional mask processing. +The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, +and segment_emedding; the embeddings are added then applied layer normalization using gamma and beta tensors. The input_ids +and segment_ids remain int32. All embeddings, gamma, and beta tensors are converted to int8/uint8. The last input mask is optional. +If mask is provided, mask index (that is position of first 0 in mask, or number of words will be calculated.)DOC"; + + ONNX_MS_OPERATOR_SET_SCHEMA(QEmbedLayerNormalization, 1, + OpSchema() + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc(QEmbedLayerNormalization_ver1_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, kDefaultEmbedLayerNormEpsilon) + .Input(0, "input_ids", "2D words IDs with shape (batch_size, sequence_length)", "T1") + .Input(1, "segment_ids", "2D segment IDs with shape (batch_size, sequence_length)", "T1", OpSchema::Optional) + .Input(2, "word_embedding_quant", "2D with shape (,hidden_size)", "T2") + .Input(3, "position_embedding_quant", "2D with shape (, hidden_size)", "T2") + .Input(4, "segment_embedding", "2D with shape (, hidden_size)", "T2", OpSchema::Optional) + .Input(5, "gamma_quant", "1D gamma tensor for layer normalization with shape (hidden_size)", "T2") + .Input(6, "beta_quant", "1D beta tensor for layer normalization with shape (hidden_size)", "T2") + .Input(7, "mask", "Mask", "T1", OpSchema::Optional) + .Input(8, "word_embedding_scale", "Scale for word embeddings", "T") + .Input(9, "position_embedding_scale", "Scale for position embeddings", "T") + .Input(10, "segment_embedding_scale", "Scale for segment embeddings", "T", OpSchema::Optional) + .Input(11, "gamma_scale", "Scale for 1D gamma tensor", "T") + .Input(12, "beta_scale", "Scale for 1D beta tensor", "T") + .Input(13, "word_embedding_zero_point", "Zero point for word embeddings", "T2") + .Input(14, "position_embedding_zero_point", "Zero point for position embeddings", "T2") + .Input(15, "segment_embedding_zero_point", "Zero Point for segment embeddings", "T2", OpSchema::Optional) + .Input(16, "gamma_zero_point", "Zero Point for 1D 
gamma tensor", "T2") + .Input(17, "beta_zero_point", "Zero Point for 1D beta tensor", "T2") + .Output(0, "layernorm_out", "LayerNorm Output", "T") + .Output(1, "mask_index_out", "Mask Index Output", "T1") + .TypeConstraint("T1", {"tensor(int32)"}, "Constrain mask index to integer types") + .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.") + .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float32 tensors.") + .TypeAndShapeInferenceFunction(EmbedLayerNormalizationShapeInference)); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc new file mode 100644 index 0000000000000..381b0755b8bfe --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc @@ -0,0 +1,178 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/contrib_ops/shape_inference_functions.h" +#include + +namespace onnxruntime { +namespace contrib { +void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 2, 0); + propagateElemTypeFromInputToOutput(ctx, 0, 1); + if (!hasInputShape(ctx, 0)) { + // TODO(kreeger): In this case update the output to (?, ?, hidden_size). + return; + } + + auto& input_ids_shape = getInputShape(ctx, 0); + auto& input_ids_dims = input_ids_shape.dim(); + + // Note that both batch size and sequence length could be symbolic. + // So we only check dimension size here. + if (input_ids_dims.size() != 2) { + fail_shape_inference("input_ids shall be 2 dimensions"); + } + + bool has_segment = hasInputShape(ctx, 1); + if (has_segment) { + // Ensure that segment_ids has the same shape. 
+    auto& segment_ids_shape = getInputShape(ctx, 1);
+    auto& segment_ids_dims = segment_ids_shape.dim();
+    if (segment_ids_dims.size() != 2) {
+      fail_shape_inference("segment_ids input shall be 2 dimensions");
+    }
+  }
+
+  // get hidden_size from the last dimension of embedding
+  auto& word_embedding_shape = getInputShape(ctx, 2);
+  auto& word_embedding_dims = word_embedding_shape.dim();
+  if (word_embedding_dims.size() != 2 ||
+      !word_embedding_dims[1].has_dim_value() ||
+      word_embedding_shape.dim(1).dim_value() <= 0) {
+    fail_shape_inference("word_embedding should have 2 dimensions and dimension size is known.");
+  }
+  int64_t hidden_size = word_embedding_shape.dim(1).dim_value();
+
+  // Ensure that all embeddings + the gamma/beta tensors have the same hidden_size:
+  auto& position_embedding_shape = getInputShape(ctx, 3);
+  auto& position_embedding_dims = position_embedding_shape.dim();
+  if (position_embedding_dims.size() != 2 ||
+      !position_embedding_dims[1].has_dim_value() ||
+      position_embedding_shape.dim(1).dim_value() != hidden_size) {
+    fail_shape_inference(
+        "position_embedding should have 2 dimensions, dimension size known, "
+        "and same hidden size as word_embedding.");
+  }
+
+  if (has_segment) {
+    auto& segment_embedding_shape = getInputShape(ctx, 4);
+    auto& segment_embedding_dims = segment_embedding_shape.dim();
+    if (segment_embedding_dims.size() != 2 ||
+        !segment_embedding_dims[1].has_dim_value() ||
+        segment_embedding_shape.dim(1).dim_value() != hidden_size) {
+      fail_shape_inference(
+          "segment_embedding should have 2 dimensions, dimension size known, "
+          "and same hidden size as word_embedding.");
+    }
+  }
+
+  auto& gamma_shape = getInputShape(ctx, 5);
+  auto& gamma_dims = gamma_shape.dim();
+  if (gamma_dims.size() != 1 ||
+      !gamma_dims[0].has_dim_value() ||
+      gamma_shape.dim(0).dim_value() != hidden_size) {
+    fail_shape_inference(
+        "gamma should have 1 dimension, dimension size known, "
+        "and same hidden size as word_embedding.");
+  }
+
+  auto& beta_shape = getInputShape(ctx, 6);
+  auto& beta_dims = beta_shape.dim();
+  if (beta_dims.size() != 1 ||
+      !beta_dims[0].has_dim_value() ||
+      beta_shape.dim(0).dim_value() != hidden_size) {
+    fail_shape_inference(
+        "beta should have 1 dimension, dimension size known, "
+        "and same hidden size as word_embedding.");
+  }
+
+  // input shape is (batch_size, sequence_length), output shape is (batch_size, sequence_length, hidden_size)
+  ONNX_NAMESPACE::TensorShapeProto output_shape;
+  *output_shape.add_dim() = input_ids_dims[0];
+  *output_shape.add_dim() = input_ids_dims[1];
+
+  output_shape.add_dim();
+  output_shape.mutable_dim(2)->set_dim_value(hidden_size);
+
+  updateOutputShape(ctx, 0, output_shape);
+
+  // mask_index shape is (batch_size)
+  ONNX_NAMESPACE::TensorShapeProto mask_index_shape;
+  *mask_index_shape.add_dim() = input_ids_dims[0];
+  updateOutputShape(ctx, 1, mask_index_shape);
+
+  if (ctx.getNumOutputs() > 2) {
+    updateOutputShape(ctx, 2, output_shape);
+    propagateElemTypeFromInputToOutput(ctx, 0, 2);
+  }
+}
+
+void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) {
+  // Type inference
+  ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0);
+  if (ctx.getNumOutputs() > 1) {
+    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 1);
+  }
+
+  // Shape inference
+  if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) {
+    auto& input_shape = getInputShape(ctx, 0);
+    auto& input_dims = input_shape.dim();
+    if (input_dims.size() != 3) {
+      fail_shape_inference("Input 0 shall be 3 dimensions");
+    }
+
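+    // bias is the packed QKV bias; per the Attention/QAttention schemas it is 1-D with length 3 * hidden_size.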
+    auto& bias_shape = getInputShape(ctx, 2);
+    auto& bias_dims = bias_shape.dim();
+    if (bias_dims.size() != 1) {
+      fail_shape_inference("Invalid bias shape");
+    }
+
+    std::vector<int64_t> qkv_hidden_sizes;
+    getRepeatedAttribute(ctx, "qkv_hidden_sizes", qkv_hidden_sizes);
+
+    int64_t output_hidden_size;
+    if (qkv_hidden_sizes.size() != 0) {
+      if (qkv_hidden_sizes.size() != 3) {
+        fail_shape_inference("qkv_hidden_sizes should have 3 elements");
+      }
+      output_hidden_size = qkv_hidden_sizes[2];
+    } else {
+      output_hidden_size = bias_shape.dim(0).dim_value() / 3;
+    }
+
+    ONNX_NAMESPACE::TensorShapeProto output_shape;
+    for (auto& dim : input_dims) {
+      *output_shape.add_dim() = dim;
+    }
+
+    output_shape.mutable_dim(2)->set_dim_value(output_hidden_size);
+    updateOutputShape(ctx, 0, output_shape);
+
+    // TODO does the extra output need any changes?
+    if (ctx.getNumOutputs() > 1) {
+      if (hasInputShape(ctx, past_input_index)) {
+        auto& past_shape = getInputShape(ctx, past_input_index);
+        auto& past_dims = past_shape.dim();
+        if (past_dims.size() != 5) {
+          fail_shape_inference("The past input shall be 5 dimensions");
+        }
+
+        if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) {
+          auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value();
+
+          ONNX_NAMESPACE::TensorShapeProto present_shape;
+          for (auto& dim : past_dims) {
+            *present_shape.add_dim() = dim;
+          }
+          present_shape.mutable_dim(3)->set_dim_value(all_sequence_length);
+
+          updateOutputShape(ctx, 1, present_shape);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace contrib
+}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h
new file mode 100644
index 0000000000000..fe0274c6e6d5a
--- /dev/null
+++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h
@@ -0,0 +1,17 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+// This file contains helper functions used to implement ONNX type/shape inference.
+
+namespace ONNX_NAMESPACE {
+struct InferenceContext;
+}
+
+namespace onnxruntime {
+namespace contrib {
+void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index);
+void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx);
+}  // namespace contrib
+}  // namespace onnxruntime
\ No newline at end of file
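--
Note (after the signature delimiter, so not part of the applied patch): both relocated
helpers are consumed through OpSchema::TypeAndShapeInferenceFunction, exactly as the
registrations above show. A minimal sketch, assuming the same OpSchema builder API;
"MyQuantOp" is a hypothetical op name used only for illustration:

  ONNX_MS_OPERATOR_SET_SCHEMA(MyQuantOp, 1,
                              OpSchema()
                                  // ... attributes, inputs and outputs elided ...
                                  .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
                                    // QAttention passes the index of its optional "past" input (8).
                                    constexpr int past_input_index = 8;
                                    AttentionTypeAndShapeInference(ctx, past_input_index);
                                  }));

EmbedLayerNormalizationShapeInference takes no extra arguments, so it can be passed
directly, as the EmbedLayerNormalization and QEmbedLayerNormalization schemas do above.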