diff --git a/.gitmodules b/.gitmodules
index 770156842e28c..95a9344ebf919 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -72,3 +72,4 @@
 [submodule "cmake/external/onnx-tensorrt"]
 	path = cmake/external/onnx-tensorrt
 	url = https://github.com/onnx/onnx-tensorrt.git
+	branch = 8.2-GA
\ No newline at end of file
diff --git a/cmake/external/onnx-tensorrt b/cmake/external/onnx-tensorrt
index 4f54a1950e117..f42daeee49f25 160000
--- a/cmake/external/onnx-tensorrt
+++ b/cmake/external/onnx-tensorrt
@@ -1 +1 @@
-Subproject commit 4f54a1950e1174dca490900eb7b07cc374f53d41
+Subproject commit f42daeee49f2517a954c5601f0f76bef9ed94b62
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index a433951987f5f..b0d64c01c5d5b 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -2339,6 +2339,46 @@ void RegisterContribSchemas() {
         updateOutputShape(ctx, 0, output_shape);
       });
 
+  static const char* DisentangledAttention_TRT_ver1_doc =
+      R"DOC(Disentangled Attention TensorRT Plugin.)DOC";
+
+  ONNX_CONTRIB_OPERATOR_SCHEMA(DisentangledAttention_TRT)
+      .SetDomain(kOnnxDomain)
+      .SinceVersion(1)
+      .SetDoc(DisentangledAttention_TRT_ver1_doc)
+      .Input(0, "c2c_attention", "content-to-content attention tensor, QcKc^T.", "T")
+      .Input(1, "c2p_attention", "content-to-position attention tensor, QcKr^T.", "T")
+      .Input(2, "p2c_attention", "position-to-content attention tensor, KcQr^T.", "T")
+      .Output(0, "disentangled_attention", "The disentangled attention output tensor.", "T")
+      .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
+      .Attr("span", "Maximum relative distance, k.", AttributeProto::INT)
+      .Attr("factor", "Scaling factor applied to attention values, 1/sqrt(3d). d is hidden size per head = H/N. H is hidden size, N is number of heads.", AttributeProto::FLOAT)
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        // Type inference
+        using namespace ONNX_NAMESPACE;
+        propagateElemTypeFromInputToOutput(ctx, 0, 0);
+
+        // Shape Inference
+        if (!hasInputShape(ctx, 0)) {
+          return;
+        }
+
+        auto& input0_shape = getInputShape(ctx, 0);
+        auto& input0_dims = input0_shape.dim();
+        if (input0_dims.size() != 3) {
+          fail_shape_inference("Input 0 shall be 3 dimensions");
+        }
+
+        // output dims is same as input[0] dims, i.e., regular c2c attention dims
+        // ONNX_NAMESPACE::TensorShapeProto disentangled_attention_shape;
+        // for (auto& dim : input0_dims) {
+        //   *disentangled_attention_shape.add_dim() = dim;
+        // }
+        // updateOutputShape(ctx, 0, disentangled_attention_shape);
+        propagateShapeFromInputToOutput(ctx, 0, 0);
+
+      });
+
 #ifndef _OPSCHEMA_LIB_
   // Register the NCHWc schemas if supported by the platform.
   if (MlasNchwcGetBlockSize() > 1) {
@@ -2352,4 +2392,4 @@ void RegisterContribSchemas() {
 }
 
 }  // namespace contrib
-}  // namespace onnxruntime
+}  // namespace onnxruntime
\ No newline at end of file
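
Usage illustration (not part of the patch): a minimal Python sketch of how a graph containing the newly registered DisentangledAttention_TRT op might be built with the onnx helper API. All tensor names, shapes, and attribute values below are hypothetical; the schema itself only requires input 0 to be 3-D and propagates its shape to the output. Note that because this schema lives in onnxruntime's contrib registry rather than the onnx package, onnx.checker and onnx.shape_inference will not recognize the op; a node like this is meant to be consumed by the TensorRT execution provider.

import onnx
from onnx import TensorProto, helper

# Hypothetical sizes: a collapsed batch-times-heads dim B*N, sequence length S,
# and maximum relative distance k (the "span" attribute).
batch_x_heads, seq_len, span = 12, 128, 256

node = helper.make_node(
    "DisentangledAttention_TRT",
    inputs=["c2c_attention", "c2p_attention", "p2c_attention"],
    outputs=["disentangled_attention"],
    span=span,       # maximum relative distance, k (INT attribute)
    factor=0.0722,   # illustrative 1/sqrt(3d) value for d = 64 hidden units per head
)

graph = helper.make_graph(
    [node],
    "disentangled_attention_example",
    inputs=[
        helper.make_tensor_value_info("c2c_attention", TensorProto.FLOAT,
                                      [batch_x_heads, seq_len, seq_len]),
        helper.make_tensor_value_info("c2p_attention", TensorProto.FLOAT,
                                      [batch_x_heads, seq_len, 2 * span]),
        helper.make_tensor_value_info("p2c_attention", TensorProto.FLOAT,
                                      [batch_x_heads, seq_len, 2 * span]),
    ],
    # Per the inference function above, the output shape mirrors input 0 (c2c).
    outputs=[
        helper.make_tensor_value_info("disentangled_attention", TensorProto.FLOAT,
                                      [batch_x_heads, seq_len, seq_len]),
    ],
)
model = helper.make_model(graph)
onnx.save(model, "disentangled_attention_example.onnx")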