Contrib ops for TRT plugin: Disentangled Attention Plugin (#11287)
* Add disentangled attention TRT plugin as contrib op

* update plugin name & remove null character

* update onnx-tensorrt submodule with my beta version

* use suggested plugin name & simpler shape propagation

* update onnx-tensorrt git submodule to temporary fork

* update onnx-tensorrt to temporary commit

* redirect submodule back to latest 8.2-GA release of onnx-tensorrt repo

Co-authored-by: HHH-ComputeLab <[email protected]>
symphonylyh authored May 8, 2022
1 parent 70e5018 commit c2de603
Showing 3 changed files with 43 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitmodules
@@ -72,3 +72,4 @@
[submodule "cmake/external/onnx-tensorrt"]
path = cmake/external/onnx-tensorrt
url = https://github.com/onnx/onnx-tensorrt.git
branch = 8.2-GA
2 changes: 1 addition & 1 deletion cmake/external/onnx-tensorrt
42 changes: 41 additions & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -2339,6 +2339,46 @@ void RegisterContribSchemas() {
updateOutputShape(ctx, 0, output_shape);
});

static const char* DisentangledAttention_TRT_ver1_doc =
R"DOC(Disentangled Attention TensorRT Plugin.)DOC";

ONNX_CONTRIB_OPERATOR_SCHEMA(DisentangledAttention_TRT)
.SetDomain(kOnnxDomain)
.SinceVersion(1)
.SetDoc(DisentangledAttention_TRT_ver1_doc)
.Input(0, "c2c_attention", "content-to-content attention tensor, QcKc^T.", "T")
.Input(1, "c2p_attention", "content-to-position attention tensor, QcKr^T.", "T")
.Input(2, "p2c_attention", "position-to-content attention tensor, KcQr^T.", "T")
.Output(0, "disentangled_attention", "The disentangled attention output tensor.", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.Attr("span", "Maximum relative distance, k.", AttributeProto::INT)
.Attr("factor", "Scaling factor applied to attention values, 1/sqrt(3d). d is hidden size per head = H/N. H is hidden size, N is number of heads.", AttributeProto::FLOAT)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
using namespace ONNX_NAMESPACE;
propagateElemTypeFromInputToOutput(ctx, 0, 0);

// Shape Inference
if (!hasInputShape(ctx, 0)) {
return;
}

auto& input0_shape = getInputShape(ctx, 0);
auto& input0_dims = input0_shape.dim();
if (input0_dims.size() != 3) {
fail_shape_inference("Input 0 shall be 3 dimensions");
}

// Output dims are the same as input[0] dims, i.e., regular c2c attention dims
// ONNX_NAMESPACE::TensorShapeProto disentangled_attention_shape;
// for (auto& dim : input0_dims) {
// *disentangled_attention_shape.add_dim() = dim;
// }
// updateOutputShape(ctx, 0, disentangled_attention_shape);
propagateShapeFromInputToOutput(ctx, 0, 0);

});

#ifndef _OPSCHEMA_LIB_
// Register the NCHWc schemas if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {
@@ -2352,4 +2392,4 @@ void RegisterContribSchemas() {
}

} // namespace contrib
} // namespace onnxruntime
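
For context, here is a minimal sketch (not part of the commit) of how a graph could carry this contrib op, built with the standard onnx Python helpers. The tensor names, shapes, and model dimensions are illustrative assumptions; the schema above only requires input 0 to be 3-dimensional, and at this point in the history the op is registered under the default ONNX domain (kOnnxDomain).

import math
from onnx import TensorProto, helper

# Hypothetical dimensions -- not prescribed by the schema.
batch_heads = 24      # B * N (batch size times number of heads)
seq_len = 128         # sequence length S
hidden_size = 768     # H
num_heads = 12        # N
span = 64             # attr "span": maximum relative distance k

# attr "factor": 1/sqrt(3d), where d = H/N is the hidden size per head.
factor = 1.0 / math.sqrt(3 * hidden_size / num_heads)

node = helper.make_node(
    "DisentangledAttention_TRT",
    inputs=["c2c_attention", "c2p_attention", "p2c_attention"],
    outputs=["disentangled_attention"],
    span=span,
    factor=factor,
)

# Assumed shapes: (B*N, S, S) for c2c and (B*N, S, 2k) for the two
# relative-position inputs, per the DeBERTa formulation; note that only
# input 0's shape is propagated to the output by the inference function above.
graph = helper.make_graph(
    [node],
    "disentangled_attention_example",
    inputs=[
        helper.make_tensor_value_info("c2c_attention", TensorProto.FLOAT,
                                      [batch_heads, seq_len, seq_len]),
        helper.make_tensor_value_info("c2p_attention", TensorProto.FLOAT,
                                      [batch_heads, seq_len, 2 * span]),
        helper.make_tensor_value_info("p2c_attention", TensorProto.FLOAT,
                                      [batch_heads, seq_len, 2 * span]),
    ],
    outputs=[
        helper.make_tensor_value_info("disentangled_attention", TensorProto.FLOAT,
                                      [batch_heads, seq_len, seq_len]),
    ],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 1)])

As a further reference point (also not from the commit), the combination this plugin targets is the disentangled attention of the DeBERTa paper (arXiv:2006.03654): the content-to-content score plus the content-to-position and position-to-content scores gathered at clamped relative positions, scaled by 1/sqrt(3d). The index convention and the placement of the scaling below follow the paper, not necessarily the plugin's kernel, and should be read as assumptions:

import numpy as np

def rel_index(i, j, k):
    # Clamped relative position from the DeBERTa paper, mapped into [0, 2k).
    d = i - j
    if d <= -k:
        return 0
    if d >= k:
        return 2 * k - 1
    return d + k

def disentangled_attention_ref(c2c, c2p, p2c, span, factor):
    # c2c: (B*N, S, S); c2p, p2c: (B*N, S, 2*span). Pure-Python reference,
    # deliberately unoptimized for readability.
    bn, s, _ = c2c.shape
    out = c2c.astype(np.float64).copy()
    for b in range(bn):
        for i in range(s):
            for j in range(s):
                out[b, i, j] += c2p[b, i, rel_index(i, j, span)]  # QcKr^T term
                out[b, i, j] += p2c[b, j, rel_index(j, i, span)]  # KcQr^T term
    return factor * out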
