Skip to content

Commit

Permalink
Move QAttention/QEmbedLayerNormalization op defs to quantization_defs.cc (#10507)
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn authored Feb 9, 2022
1 parent c9fbd0b commit 6f3ade5
Show file tree
Hide file tree
Showing 5 changed files with 317 additions and 289 deletions.
286 changes: 2 additions & 284 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,177 +5,12 @@
#include "core/graph/contrib_ops/contrib_defs.h"
#include "core/graph/contrib_ops/quantization_defs.h"
#include "core/graph/contrib_ops/onnx_function_util.h"
#include "core/graph/contrib_ops/shape_inference_functions.h"

using namespace ::ONNX_NAMESPACE;

namespace onnxruntime {
namespace contrib {
// Type and shape inference for the (Q)EmbedLayerNormalization contrib ops.
//
// Validates that:
//   - input_ids (input 0) and, if present, segment_ids (input 1) are rank 2;
//   - word_embedding (input 2) is rank 2 with a concrete last dim, which
//     defines hidden_size;
//   - position_embedding (input 3), optional segment_embedding (input 4),
//     gamma (input 5), and beta (input 6) all agree with that hidden_size.
//
// Produces:
//   - output 0: (batch_size, sequence_length, hidden_size), element type of
//     the embeddings (input 2);
//   - output 1: mask_index (batch_size), element type of input_ids (input 0);
//   - output 2 (if declared): same shape as output 0, element type of input 0.
void embedLayerNormalizationShapeInference(InferenceContext& ctx) {
  propagateElemTypeFromInputToOutput(ctx, 2, 0);
  propagateElemTypeFromInputToOutput(ctx, 0, 1);
  if (!hasInputShape(ctx, 0)) {
    // TODO(kreeger): In this case update the output to (?, ?, hidden_size).
    return;
  }

  auto& input_ids_shape = getInputShape(ctx, 0);
  auto& input_ids_dims = input_ids_shape.dim();

  // Note that both batch size and sequence length could be symbolic.
  // So we only check dimension size here.
  if (input_ids_dims.size() != 2) {
    fail_shape_inference("input_ids shall be 2 dimensions");
  }

  bool has_segment = hasInputShape(ctx, 1);
  if (has_segment) {
    // Ensure that segment_ids has the same rank as input_ids.
    auto& segment_ids_shape = getInputShape(ctx, 1);
    auto& segment_ids_dims = segment_ids_shape.dim();
    if (segment_ids_dims.size() != 2) {
      fail_shape_inference("segment_ids input shall be 2 dimensions");
    }
  }

  // get hidden_size from the last dimension of embedding
  auto& word_embedding_shape = getInputShape(ctx, 2);
  auto& word_embedding_dims = word_embedding_shape.dim();
  if (word_embedding_dims.size() != 2 ||
      !word_embedding_dims[1].has_dim_value() ||
      word_embedding_shape.dim(1).dim_value() <= 0) {
    fail_shape_inference("word_embedding should have 2 dimensions and dimension size is known.");
  }
  int64_t hidden_size = word_embedding_shape.dim(1).dim_value();

  // Ensure that all embeddings + the gamma/beta tensors have the same hidden_size:
  auto& position_embedding_shape = getInputShape(ctx, 3);
  auto& position_embedding_dims = position_embedding_shape.dim();
  if (position_embedding_dims.size() != 2 ||
      !position_embedding_dims[1].has_dim_value() ||
      position_embedding_shape.dim(1).dim_value() != hidden_size) {
    fail_shape_inference(
        "position_embedding should have 2 dimensions, dimension size known, "
        "and same hidden size as word_embedding.");
  }

  if (has_segment) {
    auto& segment_embedding_shape = getInputShape(ctx, 4);
    auto& segment_embedding_dims = segment_embedding_shape.dim();
    if (segment_embedding_dims.size() != 2 ||
        !segment_embedding_dims[1].has_dim_value() ||
        segment_embedding_shape.dim(1).dim_value() != hidden_size) {
      fail_shape_inference(
          "segment_embedding should have 2 dimensions, dimension size known, "
          "and same hidden size as word_embedding.");
    }
  }

  auto& gamma_shape = getInputShape(ctx, 5);
  auto& gamma_dims = gamma_shape.dim();
  if (gamma_dims.size() != 1 ||
      !gamma_dims[0].has_dim_value() ||
      gamma_shape.dim(0).dim_value() != hidden_size) {
    // BUG FIX: message said "2 dimension" while the check requires rank 1.
    fail_shape_inference(
        "gamma should have 1 dimension, dimension size known, "
        "and same hidden size as word_embedding.");
  }

  auto& beta_shape = getInputShape(ctx, 6);
  // BUG FIX: previously read gamma_shape.dim() here, so beta's rank and
  // first dimension were never actually validated.
  auto& beta_dims = beta_shape.dim();
  if (beta_dims.size() != 1 ||
      !beta_dims[0].has_dim_value() ||
      beta_shape.dim(0).dim_value() != hidden_size) {
    fail_shape_inference(
        "beta should have 1 dimension, dimension size known, "
        "and same hidden size as word_embedding.");
  }

  // input shape is (batch_size, sequence_length), output shape is (batch_size, sequence_length, hidden_size)
  ONNX_NAMESPACE::TensorShapeProto output_shape;
  *output_shape.add_dim() = input_ids_dims[0];
  *output_shape.add_dim() = input_ids_dims[1];

  output_shape.add_dim();
  output_shape.mutable_dim(2)->set_dim_value(hidden_size);

  updateOutputShape(ctx, 0, output_shape);

  // mask_index shape is (batch_size)
  ONNX_NAMESPACE::TensorShapeProto mask_index_shape;
  *mask_index_shape.add_dim() = input_ids_dims[0];
  updateOutputShape(ctx, 1, mask_index_shape);

  if (ctx.getNumOutputs() > 2) {
    // Optional embedding_sum output has the same shape as output 0.
    updateOutputShape(ctx, 2, output_shape);
    propagateElemTypeFromInputToOutput(ctx, 0, 2);
  }
}
// Type and shape inference shared by the Attention and QAttention schemas.
// `past_input_index` selects which input holds the optional past key/value
// state, since the two ops declare it at different input positions.
void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) {
  // Type inference
  // Outputs take the element type of the bias (input 2) — the input that is a
  // float tensor in both the float and the quantized variants of the op.
  ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0);
  if (ctx.getNumOutputs() > 1) {
    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 1);
  }

  // Shape inference
  if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) {
    auto& input_shape = getInputShape(ctx, 0);
    auto& input_dims = input_shape.dim();
    if (input_dims.size() != 3) {
      fail_shape_inference("Inputs 0 shall be 3 dimensions");
    }

    auto& bias_shape = getInputShape(ctx, 2);
    auto& bias_dims = bias_shape.dim();
    if (bias_dims.size() != 1) {
      fail_shape_inference("Invalid bias shape");
    }

    std::vector<int64_t> qkv_hidden_sizes;
    getRepeatedAttribute(ctx, "qkv_hidden_sizes", qkv_hidden_sizes);

    // The output hidden size comes from the last element of the
    // qkv_hidden_sizes attribute when present; otherwise Q/K/V are assumed
    // equal-sized and it is one third of the bias length.
    int64_t output_hidden_size;
    if (qkv_hidden_sizes.size() != 0) {
      if (qkv_hidden_sizes.size() != 3) {
        // NOTE(review): no trailing ';' — relies on the fail_shape_inference
        // macro expansion supplying one; confirm against the ONNX headers.
        fail_shape_inference("qkv_hidden_sizes should have 3 elements")
      }
      output_hidden_size = qkv_hidden_sizes[2];
    } else {
      // NOTE(review): if bias dim 0 is symbolic, dim_value() yields 0 here,
      // making the inferred hidden size 0 — presumably callers always provide
      // a concrete bias shape; verify.
      output_hidden_size = bias_shape.dim(0).dim_value() / 3;
    }

    // Output 0: (batch_size, sequence_length, output_hidden_size) — copy the
    // input dims, then overwrite the last one.
    ONNX_NAMESPACE::TensorShapeProto output_shape;
    for (auto& dim : input_dims) {
      *output_shape.add_dim() = dim;
    }

    output_shape.mutable_dim(2)->set_dim_value(output_hidden_size);
    updateOutputShape(ctx, 0, output_shape);

    // TODO does the extra output need any changes?
    if (ctx.getNumOutputs() > 1) {
      if (hasInputShape(ctx, past_input_index)) {
        // past has shape (2, batch_size, num_heads, past_sequence_length,
        // head_size); present keeps the same layout with dim 3 grown by the
        // current sequence length.
        auto& past_shape = getInputShape(ctx, past_input_index);
        auto& past_dims = past_shape.dim();
        if (past_dims.size() != 5) {
          fail_shape_inference("Inputs 4 shall be 5 dimensions");
        }

        // Only infer the present shape when both contributing lengths are
        // concrete; symbolic dims are left unresolved.
        if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) {
          auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value();

          ONNX_NAMESPACE::TensorShapeProto present_shape;
          for (auto& dim : past_dims) {
            *present_shape.add_dim() = dim;
          }
          present_shape.mutable_dim(3)->set_dim_value(all_sequence_length);

          updateOutputShape(ctx, 1, present_shape);
        }
      }
    }
  }
}

void DecoderAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
Expand Down Expand Up @@ -255,86 +90,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Attention, 1,
AttentionTypeAndShapeInference(ctx, past_input_index);
}));

// QAttention: quantized fusion of multi-head self-attention. The activation
// (input 0) and weight (input 1) are int8/uint8 tensors accompanied by their
// quantization scales (inputs 3, 4) and optional zero points (inputs 6, 7);
// bias, past state, and both outputs stay in float/float16 (T3).
// Doc-string fixes only: missing spaces between/inside sentence fragments and
// the "scale of weight scale" typo; op behavior and types are unchanged.
ONNX_MS_OPERATOR_SET_SCHEMA(QAttention, 1,
                            OpSchema()
                                .SetDoc("Quantization of Multi-Head Self Attention.")
                                .Attr("num_heads", "Number of attention heads", AttributeProto::INT)
                                .Attr("unidirectional",
                                      "Whether every token can only attend to previous tokens. Default value is 0.",
                                      AttributeProto::INT,
                                      static_cast<int64_t>(0))
                                .Input(0, "input",
                                       "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)",
                                       "T1")
                                .Input(1, "weight",
                                       "2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size",
                                       "T2")
                                .Input(2, "bias",
                                       "1D input tensor with shape (3 * hidden_size)",
                                       "T3")
                                .Input(3, "input_scale",
                                       "scale of quantized input tensor. It's a scalar, which means a per-tensor/layer quantization.",
                                       "T3")
                                .Input(4, "weight_scale",
                                       "scale of weight tensor. It's a scalar or a 1D tensor, which means a per-tensor/per-column quantization. "
                                       "Its size should be 3 * hidden_size if it is per-column quantization",
                                       "T3")
                                .Input(5, "mask_index",
                                       "Attention mask index with shape (batch_size)",
                                       "T4",
                                       OpSchema::Optional)
                                .Input(6, "input_zero_point",
                                       "zero point of quantized input tensor. It's a scalar, which means a per-tensor/layer quantization.",
                                       "T1",
                                       OpSchema::Optional)
                                .Input(7, "weight_zero_point",
                                       "zero point of quantized weight tensor. It's a scalar or a 1D tensor, which means a per-tensor/per-column quantization. "
                                       "Its size should be 3 * hidden_size if it is per-column quantization",
                                       "T2",
                                       OpSchema::Optional)
                                .Input(8, "past",
                                       "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).",
                                       "T3",
                                       OpSchema::Optional)
                                .Output(0, "output",
                                        "3D output tensor with shape (batch_size, sequence_length, hidden_size)",
                                        "T3")
                                .Output(1, "present",
                                        "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)",
                                        "T3",
                                        OpSchema::Optional)
                                .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.")
                                .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.")
                                .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
                                .TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types")
                                .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
                                  // The past key/value state is input 8 for QAttention
                                  // (the quantization inputs shift it from Attention's slot).
                                  constexpr int past_input_index = 8;

                                  AttentionTypeAndShapeInference(ctx, past_input_index);
                                }));

constexpr const char* Longformer_Attention_doc = R"DOC(
Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token
attends to its W previous tokens and W succeding tokens with W being the window length. A selected few tokens
Expand Down Expand Up @@ -420,44 +175,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(EmbedLayerNormalization, 1,
.Output(2, "embedding_sum", "sum of word_embedding and position_embedding without layer normalization", "T", OpSchema::Optional)
.TypeConstraint("T1", {"tensor(int32)"}, "Constrain input and output integer tensors types")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output float tensors types.")
.TypeAndShapeInferenceFunction(embedLayerNormalizationShapeInference));

// Doc-string fixes only: "segment_emedding" typo, grammar, and the garbled
// final sentence (the closing parenthesis had swallowed the verb).
constexpr const char* QEmbedLayerNormalization_ver1_doc = R"DOC(
QEmbedLayerNormalization is the quantized fusion of embedding layer in BERT model, with optional mask processing.
The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding,
and segment_embedding; the embeddings are added, then layer normalization is applied using gamma and beta tensors. The input_ids
and segment_ids remain int32. All embeddings, gamma, and beta tensors are converted to int8/uint8. The last input mask is optional.
If mask is provided, mask index (that is, the position of the first 0 in mask, or the number of words) will be calculated.)DOC";

// QEmbedLayerNormalization: quantized EmbedLayerNormalization. Embedding,
// gamma, and beta tensors arrive as int8/uint8 (T2) with per-tensor float
// scales (T) and matching zero points (T2); IDs and mask stay int32 (T1).
// Shares shape inference with the float op since the tensor ranks match.
ONNX_MS_OPERATOR_SET_SCHEMA(QEmbedLayerNormalization, 1,
                            OpSchema()
                                .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
                                .SetDoc(QEmbedLayerNormalization_ver1_doc)
                                .Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, kDefaultEmbedLayerNormEpsilon)
                                .Input(0, "input_ids", "2D words IDs with shape (batch_size, sequence_length)", "T1")
                                .Input(1, "segment_ids", "2D segment IDs with shape (batch_size, sequence_length)", "T1", OpSchema::Optional)
                                .Input(2, "word_embedding_quant", "2D with shape (,hidden_size)", "T2")
                                .Input(3, "position_embedding_quant", "2D with shape (, hidden_size)", "T2")
                                .Input(4, "segment_embedding", "2D with shape (, hidden_size)", "T2", OpSchema::Optional)
                                .Input(5, "gamma_quant", "1D gamma tensor for layer normalization with shape (hidden_size)", "T2")
                                .Input(6, "beta_quant", "1D beta tensor for layer normalization with shape (hidden_size)", "T2")
                                .Input(7, "mask", "Mask", "T1", OpSchema::Optional)
                                .Input(8, "word_embedding_scale", "Scale for word embeddings", "T")
                                .Input(9, "position_embedding_scale", "Scale for position embeddings", "T")
                                .Input(10, "segment_embedding_scale", "Scale for segment embeddings", "T", OpSchema::Optional)
                                .Input(11, "gamma_scale", "Scale for 1D gamma tensor", "T")
                                .Input(12, "beta_scale", "Scale for 1D beta tensor", "T")
                                .Input(13, "word_embedding_zero_point", "Zero point for word embeddings", "T2")
                                .Input(14, "position_embedding_zero_point", "Zero point for position embeddings", "T2")
                                .Input(15, "segment_embedding_zero_point", "Zero Point for segment embeddings", "T2", OpSchema::Optional)
                                .Input(16, "gamma_zero_point", "Zero Point for 1D gamma tensor", "T2")
                                .Input(17, "beta_zero_point", "Zero Point for 1D beta tensor", "T2")
                                .Output(0, "layernorm_out", "LayerNorm Output", "T")
                                .Output(1, "mask_index_out", "Mask Index Output", "T1")
                                .TypeConstraint("T1", {"tensor(int32)"}, "Constrain mask index to integer types")
                                .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.")
                                .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float32 tensors.")
                                .TypeAndShapeInferenceFunction(embedLayerNormalizationShapeInference));
.TypeAndShapeInferenceFunction(EmbedLayerNormalizationShapeInference));

constexpr const char* FastGelu_ver1_doc = R"DOC(
GELU (Gaussian Error Linear Unit) approximation: Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)) with an optional input of bias that will be added to X before GELU.)DOC";
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/graph/contrib_ops/ms_opset.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicQuantizeLSTM);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicQuantizeMatMul);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulIntegerToFloat);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MulInteger);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QGemm);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAdd);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConcat);
Expand Down Expand Up @@ -60,8 +62,6 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
Expand Down
Loading

0 comments on commit 6f3ade5

Please sign in to comment.