Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contrib ops for TRT plugin: Disentangled Attention Plugin #11287

Merged
merged 8 commits into from
May 8, 2022
40 changes: 39 additions & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2339,6 +2339,44 @@ void RegisterContribSchemas() {
updateOutputShape(ctx, 0, output_shape);
});

// Doc string for the DisentangledAttention_TRT contrib op (TensorRT plugin wrapper).
static const char* DisentangledAttention_TRT_ver1_doc =
    R"DOC(Disentangled Attention TensorRT Plugin.)DOC";

// Schema for the disentangled (DeBERTa-style) attention plugin op.
// Takes the three attention score tensors (content-to-content, content-to-position,
// position-to-content) and produces the combined disentangled attention tensor.
ONNX_CONTRIB_OPERATOR_SCHEMA(DisentangledAttentionPlugin)
    .SetDomain(kOnnxDomain)
    .SinceVersion(1)
    .SetDoc(DisentangledAttention_TRT_ver1_doc)
    .Input(0, "c2c_attention", "content-to-content attention tensor, QcKc^T.", "T")
    .Input(1, "c2p_attention", "content-to-position attention tensor, QcKr^T.", "T")
    .Input(2, "p2c_attention", "position-to-content attention tensor, KcQr^T.", "T")
    .Output(0, "disentangled_attention", "The disentangled attention output tensor.", "T")
    .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
    .Attr("span", "Maximum relative distance, k.", AttributeProto::INT)
    .Attr("factor", "Scaling factor applied to attention values, 1/sqrt(3d). d is hidden size per head = H/N. H is hidden size, N is number of heads.", AttributeProto::FLOAT)
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      // Type inference: output element type matches input 0 (c2c_attention).
      using namespace ONNX_NAMESPACE;
      propagateElemTypeFromInputToOutput(ctx, 0, 0);

      // Shape inference: nothing to infer if input 0's shape is unknown.
      if (!hasInputShape(ctx, 0)) {
        return;
      }

      auto& input0_shape = getInputShape(ctx, 0);
      auto& input0_dims = input0_shape.dim();
      if (input0_dims.size() != 3) {
        fail_shape_inference("Input 0 shall be 3 dimensions");
      }

      // Output dims are the same as input[0] dims, i.e., the regular c2c attention dims.
      ONNX_NAMESPACE::TensorShapeProto disentangled_attention_shape;
      for (const auto& dim : input0_dims) {
        *disentangled_attention_shape.add_dim() = dim;
      }
      updateOutputShape(ctx, 0, disentangled_attention_shape);
    });

#ifndef _OPSCHEMA_LIB_
// Register the NCHWc schemas if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {
Expand All @@ -2352,4 +2390,4 @@ void RegisterContribSchemas() {
}

} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime