From 3e4a471f368d8f1fe4e2536fb145cd69627c6be7 Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Thu, 23 Feb 2023 11:00:29 -0800
Subject: [PATCH] Cherry-pick 7 changes into 1.14.1 (#14762)

This cherry picks the following 7 changes into 1.14.1:

https://github.com/microsoft/onnxruntime/commit/1b7f65437efa74c8cd40a85250cebb3a6aa44aa4
https://github.com/microsoft/onnxruntime/commit/b539c364eec10cf238ffaf2e323ababaa05e433b
https://github.com/microsoft/onnxruntime/commit/12d91173c4478e0975771c06fd9d062a33c46339
https://github.com/microsoft/onnxruntime/commit/ff3aed85404ba76debec38d4e16e6376dcb53cbe
https://github.com/microsoft/onnxruntime/commit/3d79b1f06eff2c8c43ab6f549a49771307fb6e56
https://github.com/microsoft/onnxruntime/commit/c0d2472edeedabbb4332432f6cee8bae61e79f48
https://github.com/microsoft/onnxruntime/commit/e9ec4c098b07a0f9d7b00888a7c37d59d02b24fe

---------

Signed-off-by: Cliff Woolley
Co-authored-by: Sheil Kumar
Co-authored-by: Ye Wang <52801275+wangyems@users.noreply.github.com>
Co-authored-by: Ubuntu
Co-authored-by: Lei Zhang
Co-authored-by: Misha Chornyi <99709299+mc-nv@users.noreply.github.com>
Co-authored-by: Cliff Woolley
Co-authored-by: cao lei
Co-authored-by: Lei Cao
Co-authored-by: Tianlei Wu
Co-authored-by: Vincent Wang
Co-authored-by: Jian Chen
---
 cmake/onnxruntime_rocm_hipify.cmake | 4 +
 docs/ContribOperators.md | 63 +-
 docs/OperatorKernels.md | 36 +-
 .../onnxruntime/core/framework/ort_value.h | 2 +-
 include/onnxruntime/core/framework/tensor.h | 8 +
 java/build.gradle | 8 +-
 onnxruntime/contrib_ops/cpu/bert/attention.cc | 6 +-
 .../contrib_ops/cpu/bert/attention_base.cc | 48 +-
 .../contrib_ops/cpu/bert/attention_base.h | 4 +-
 .../contrib_ops/cpu/bert/attention_common.h | 30 +-
 .../contrib_ops/cpu/bert/attention_cpu_base.h | 38 +-
 .../cpu/bert/multihead_attention_helper.h | 32 +
 .../cpu/quantization/attention_quant.cc | 2 +-
 .../contrib_ops/cuda/bert/attention.cc | 26 +-
 onnxruntime/contrib_ops/cuda/bert/attention.h | 3 +-
 .../contrib_ops/cuda/bert/attention_impl.cu | 9 +-
 .../contrib_ops/cuda/bert/attention_impl.h | 2 +-
 .../contrib_ops/cuda/bert/attention_softmax.h | 10 +-
 .../cuda/bert/multihead_attention.cc | 13 +-
 .../cuda/bert/multihead_attention.h | 2 +-
 .../cuda/bert/relative_attn_bias.cc | 122 ++-
 .../cuda/bert/relative_attn_bias.h | 12 +
 .../cuda/bert/relative_attn_bias_impl.cu | 118 ++-
 .../cuda/bert/relative_attn_bias_impl.h | 15 +
 .../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 +
 .../quantization/attention_quantization.cc | 4 +-
 .../qordered_ops/qordered_attention.cc | 2 +-
 .../qordered_attention_input_enum.h | 2 +-
 .../contrib_ops/rocm/bert/attention.cc | 6 +-
 .../contrib_ops/rocm/bert/attention_impl.cu | 14 +-
 .../contrib_ops/rocm/bert/attention_impl.h | 2 +-
 onnxruntime/core/framework/TensorSeq.h | 56 +-
 onnxruntime/core/framework/tensor.cc | 6 +
 onnxruntime/core/framework/utils.cc | 118 +--
 .../core/graph/contrib_ops/bert_defs.cc | 44 +-
 onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 +
 .../graph/contrib_ops/quantization_defs.cc | 4 +-
 .../core/providers/cann/tensor/identity_op.h | 1 +
 .../core/providers/cpu/controlflow/loop.cc | 17 +-
 .../core/providers/cpu/cpu_provider_shared.cc | 4 +-
 .../core/providers/cpu/cpu_provider_shared.h | 2 +-
 .../providers/cpu/optional/optional_ops.cc | 14 +-
 .../cpu/sequence/concat_from_sequence.cc | 2 +-
 .../providers/cpu/sequence/sequence_ops.cc | 53 +-
 .../core/providers/cpu/tensor/concat.h | 1 -
 .../core/providers/cpu/tensor/identity_op.h | 18 +-
 .../cuda/activation/activations_impl.cu | 13 +-
 .../providers/cuda/cuda_execution_provider.cc | 1 +
 .../core/providers/cuda/tensor/identity_op.h | 1 +
 .../inc/DmlExecutionProvider.h | 7 +-
 .../inc/IWinmlExecutionProvider.h | 12 +-
 .../inc/MLOperatorAuthor.h | 267 +++---
 .../src/AbiCustomRegistry.cpp | 5 +-
 .../src/ExecutionProvider.cpp | 109 ++-
 .../src/ExecutionProvider.h | 1 +
 .../src/FusedGraphKernel.cpp | 3 +-
 .../src/IExecutionProvider.h | 3 +-
 .../src/MLOperatorAuthorImpl.cpp | 874 +++++++++++++++---
 .../src/MLOperatorAuthorImpl.h | 74 +-
 .../src/Operators/DmlOperator.cpp | 99 +-
 .../src/Operators/DmlOperator.h | 24 +
 .../src/Operators/DmlOperatorAttention.cpp | 6 +-
 .../DmlOperatorConcatFromSequence.cpp | 161 ++++
 .../src/Operators/DmlOperatorMemcpy.cpp | 44 +-
 .../src/Operators/OperatorRegistration.cpp | 186 +++-
 .../src/Operators/OperatorRegistration.h | 4 +-
 .../dml/DmlExecutionProvider/src/TensorDesc.h | 2 +
 .../dml/OperatorAuthorHelper/Attributes.h | 1 +
 .../MLOperatorAuthorHelper.h | 115 +++
 .../MLOperatorAuthorPrivate.h | 114 ++-
 .../dml/OperatorAuthorHelper/OperatorHelper.h | 6 +
 .../OperatorAuthorHelper/OperatorVersions.h | 1 +
 .../providers/rocm/rocm_execution_provider.cc | 1 +
 .../providers/shared_library/provider_api.h | 2 +-
 .../provider_bridge_provider.cc | 4 +-
 .../shared_library/provider_interfaces.h | 3 +
 .../shared_library/provider_wrappedtypes.h | 6 +-
 onnxruntime/core/session/onnxruntime_c_api.cc | 26 +-
 .../core/session/provider_bridge_ort.cc | 3 +
 .../core/session/standalone_op_invoker.cc | 1 +
 .../python/onnxruntime_pybind_mlvalue.cc | 14 +-
 .../python/onnxruntime_pybind_state.cc | 2 +-
 .../tools/transformers/benchmark_helper.py | 1 +
 .../tools/transformers/fusion_attention.py | 2 +-
 .../transformers/models/gpt2/gpt2_helper.py | 18 +-
 .../python/tools/transformers/onnx_model.py | 18 +-
 .../tools/transformers/onnx_model_tnlr.py | 6 +-
 .../test/contrib_ops/attention_op_test.cc | 39 +-
 .../multihead_attention_op_test.cc | 10 +-
 .../contrib_ops/qordered_attention_test.cc | 2 +-
 .../relative_attention_bias_test.cc | 265 +++++-
 .../sequence/concat_from_sequence_op_test.cc | 9 +-
 .../providers/cpu/tensor/identity_op_test.cc | 25 -
 .../test/providers/provider_test_utils.h | 16 +-
 .../python/orttraining_test_ortmodule_api.py | 33 +
 .../orttraining/training_api/optimizer.cc | 7 +-
 .../training_ops/cpu/optimizer/adamw/adamw.cc | 1 +
 .../clip_grad_norm/clip_grad_norm.cc | 17 +-
 .../training_ops/cpu/optimizer/common.cc | 1 +
 .../training_ops/cpu/optimizer/sgd/sgd.cc | 1 +
 .../github/android/setup_gradle_wrapper.sh | 19 -
 ...ndroid-x86_64-crosscompile-ci-pipeline.yml | 15 +-
 .../templates/android-java-api-aar-test.yml | 3 +-
 .../azure-pipelines/templates/mac-ci.yml | 2 +
 .../templates/react-native-ci.yml | 4 +-
 .../templates/set-up-gradle-wrapper-step.yml | 10 +
 106 files changed, 2927 insertions(+), 781 deletions(-)
 create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConcatFromSequence.cpp
 delete mode 100755 tools/ci_build/github/android/setup_gradle_wrapper.sh
 create mode 100644 tools/ci_build/github/azure-pipelines/templates/set-up-gradle-wrapper-step.yml

diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 85246ec8bd37c..d03bc93095103 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -19,6 +19,10 @@ set(contrib_ops_excluded_files
  "bert/fast_gelu_impl.h"
  "bert/fast_gelu.cc"
  "bert/fast_gelu.h"
+  "bert/relative_attn_bias.cc"
+
"bert/relative_attn_bias.h" + "bert/relative_attn_bias_impl.cu" + "bert/relative_attn_bias_impl.h" "bert/skip_layer_norm.cc" "bert/skip_layer_norm.h" "bert/skip_layer_norm_impl.cu" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 8cd6d4c9e26f1..f01a7ab14a61e 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -30,6 +30,7 @@ Do not modify directly.* * com.microsoft.FusedConv * com.microsoft.FusedGemm * com.microsoft.FusedMatMul + * com.microsoft.GatedRelativePositionBias * com.microsoft.GatherND * com.microsoft.Gelu * com.microsoft.GemmFastGelu @@ -152,7 +153,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
-
extra_add (optional) : T
+
relative_position_bias (optional) : T
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
past_sequence_length (optional) : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
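For reference, the renamed `relative_position_bias` input of `com.microsoft.Attention` is the sixth input, after `input`, `weights`, `bias`, `mask_index` and `past`. A minimal Python sketch of wiring it up with `onnx.helper` follows; the tensor names and the `num_heads` value are illustrative placeholders, and the shapes in the comments follow the spec in this diff.

```python
from onnx import helper

# Hypothetical tensor names; shapes per the Attention spec above:
#   input:                  (batch_size, sequence_length, input_hidden_size)
#   weights:                (input_hidden_size, hidden_size + hidden_size + v_hidden_size)
#   bias:                   (hidden_size + hidden_size + v_hidden_size)
#   relative_position_bias: (batch_size, num_heads, sequence_length, total_sequence_length)
attention_node = helper.make_node(
    "Attention",
    inputs=["input", "weights", "bias", "", "", "rel_pos_bias"],  # "" skips the optional mask_index and past
    outputs=["attention_output"],
    domain="com.microsoft",
    num_heads=12,  # example value
)
```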
@@ -1608,6 +1609,58 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GatedRelativePositionBias** + + query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) + gate_u, gate_r = torch.sigmoid( + self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0 + rel_pos_bias = gate_u_1 * rel_pos + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
num_heads : int (required)
+
Number of attention heads
+
+ +#### Inputs + +
+
query_layer : T
+
tensor with shape (batch_size, seq_len, num_heads x head_size)
+
query_bias : T
+
1-d tensor with shape (num_heads x head_size)
+
rel_pos : T
+
tensor with shape (1, num_head, seq_len, seq_len)
+
weight : T
+
gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2
+
bias : T
+
bias for the gated_ur_linear, shape (D)
+
eco_a : T
+
tensor of shape (1, num_heads, 1, 1)
+
+ +#### Outputs + +
+
output : T
+
output tensor with shape (batch_size, num_heads, seq_len, seq_len)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float tensors.
+
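The following is a minimal PyTorch sketch that mirrors the pseudocode and the input/output shapes listed above for GatedRelativePositionBias. The function name is an illustrative assumption, and `gate_ur_linear` is modeled as a plain matmul with the supplied `weight` and `bias`; it is meant as a reading aid for the spec, not the ONNX Runtime kernel.

```python
import torch

def gated_relative_position_bias(query_layer, query_bias, rel_pos, weight, bias, eco_a, num_heads):
    # query_layer: (batch_size, seq_len, num_heads * head_size), query_bias: (num_heads * head_size,)
    # rel_pos: (1, num_heads, seq_len, seq_len), weight: (head_size, D), bias: (D,), eco_a: (1, num_heads, 1, 1)
    batch_size, seq_len, hidden = query_layer.shape
    head_size = hidden // num_heads
    q = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2)
    gates = torch.matmul(q, weight) + bias                                # (batch_size, num_heads, seq_len, D)
    gates = gates.reshape(batch_size, num_heads, seq_len, 2, -1).sum(-1)  # (batch_size, num_heads, seq_len, 2)
    gate_u, gate_r = torch.sigmoid(gates).chunk(2, dim=-1)                # each (batch_size, num_heads, seq_len, 1)
    gate_u_1 = gate_u * (gate_r * eco_a - 1.0) + 2.0
    return gate_u_1 * rel_pos                                             # (batch_size, num_heads, seq_len, seq_len)
```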
+ + ### **com.microsoft.GatherND** Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather @@ -2222,7 +2275,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs (2 - 5) +#### Inputs (2 - 6)
query : T
@@ -2235,6 +2288,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)
+
relative_position_bias (optional) : T
+
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
#### Outputs @@ -3221,7 +3276,7 @@ This version of the operator has been available since version 1 of the 'com.micr left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. - Current version does not support past/present, extra_add and qkv_hidden_sizes. + Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. #### Version @@ -3286,7 +3341,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
past (optional) : Q
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
-
extra_add (optional) : S
+
relative_position_bias (optional) : S
additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d23f0e0e2180e..a034ac80bcf67 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -417,7 +417,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| |BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| @@ -785,7 +785,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| @@ -803,6 +803,7 @@ Do not modify directly.* |FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| @@ -810,11 +811,11 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| -|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* extra_add:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| +|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* relative_position_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLongformerAttention|*in* input:**Q**
*in* scale_input:**S**
*in* weight:**Q**
*in* scale_weight:**S**
*in* bias:**S**
*in* scale_bias:**S**
*in* scale_qkv_gemm:**S**
*in* mask:**F**
*in* global_weight:**Q**
*in* scale_global_weight:**S**
*in* global_bias:**S**
*in* scale_global_gemm:**S**
*in* global:**G**
*in* scale_output:**S**
*out* output:**Q**|1+|**F** = tensor(float16)
**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| @@ -885,6 +886,7 @@ Do not modify directly.* |Concat|*in* inputs:**T**
*out* concat_result:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||4+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| @@ -951,8 +953,8 @@ Do not modify directly.* |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1003,8 +1005,8 @@ Do not modify directly.* |MeanVarianceNormalization|*in* X:**T**
*out* Y:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||9+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||8+|**T** = tensor(float), tensor(float16)| @@ -1096,9 +1098,15 @@ Do not modify directly.* |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| -|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|SequenceAt|*in* input_sequence:**S**
*in* position:**I**
*out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|SequenceConstruct|*in* inputs:**T**
*out* output_sequence:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|SequenceEmpty|*out* output:**S**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| @@ -1106,8 +1114,8 @@ Do not modify directly.* |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| -|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| @@ -1158,7 +1166,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| diff --git a/include/onnxruntime/core/framework/ort_value.h b/include/onnxruntime/core/framework/ort_value.h index 7cbbdcc919ea7..48c4e4320dfd7 100644 --- a/include/onnxruntime/core/framework/ort_value.h +++ b/include/onnxruntime/core/framework/ort_value.h @@ -10,12 +10,12 @@ #include "core/framework/allocator.h" #include "core/framework/data_types.h" #include "core/framework/tensor.h" -#include "core/framework/TensorSeq.h" namespace onnxruntime { #if !defined(DISABLE_SPARSE_TENSORS) class SparseTensor; #endif +class TensorSeq; } // namespace onnxruntime #endif diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 975793a04ae02..7f3f26fa4aa02 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -112,6 +112,14 @@ class Tensor final { OrtValue& ort_value, gsl::span strides = {}); + /// + /// Creates an instance of Tensor on the heap using the appropriate __ctor and + /// initializes OrtValue with it. + /// + /// + /// + static void InitOrtValue(Tensor&& tensor, OrtValue& ort_value); + /** * Create tensor with given type, shape, pre-allocated memory and allocator which will be used to free the pre-allocated memory. * This function won't check if the preallocated buffer(p_data) has enough room for the shape. diff --git a/java/build.gradle b/java/build.gradle index b3a98f5cdabe4..2dcd821f3de65 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -42,7 +42,7 @@ jar { // Add explicit sources jar with pom file. task sourcesJar(type: Jar, dependsOn: classes) { - classifier = "sources" + archiveClassifier = "sources" from sourceSets.main.allSource into("META-INF/maven/$project.group/$mavenArtifactId") { from { generatePomFileForMavenPublication } @@ -52,7 +52,7 @@ task sourcesJar(type: Jar, dependsOn: classes) { // Add explicit javadoc jar with pom file task javadocJar(type: Jar, dependsOn: javadoc) { - classifier = "javadoc" + archiveClassifier = "javadoc" from javadoc.destinationDir into("META-INF/maven/$project.group/$mavenArtifactId") { from { generatePomFileForMavenPublication } @@ -187,8 +187,8 @@ test { jacocoTestReport { reports { - xml.enabled true - csv.enabled true + xml.required = true + csv.required = true html.destination file("${buildDir}/jacocoHtml") } } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 47db3fe558ce8..6aa0e726afe1b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -198,7 +198,7 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); const TensorShape& weights_shape = (weights ? 
weights->Shape() : weight_shape_); @@ -208,7 +208,7 @@ Status Attention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, ¶meters)); const int batch_size = parameters.batch_size; @@ -331,7 +331,7 @@ Status Attention::Compute(OpKernelContext* context) const { return ApplyAttention(Q, K, V, mask_index, past, output, batch_size, sequence_length, parameters.head_size, parameters.v_head_size, parameters.v_hidden_size, - extra_add_qk, context); + relative_position_bias, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index affe7cab1d858..e75f68ea53c7c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -12,7 +12,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const Tensor* past_seq_len) const { // Abbreviation and Meanings: @@ -37,7 +37,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // bias (Q/K/V) : (D + D + D_v) // mask_index : see below // past (K/V) : (2, B, N, P, H) or NULL - // extra_add_qk : (B, N, S, T) or NULL + // relative_position_bias : (B, N, S, T) or NULL // For mask_index, the following shapes are supported: // NULL, (B, 1), (1, 1) @@ -49,9 +49,9 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger // than hidden dimension of Q, K and V. - if (past != nullptr && extra_add_qk != nullptr) { - // past is used on GPT-2 model with past state, we don't have a case for extra add qk yet - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and extra_add_qk"); + if (past != nullptr && relative_position_bias != nullptr) { + // past is used on GPT-2 model with past state, we don't have a case for relative position bias yet + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and relative_position_bias"); } const auto& dims = input_shape.GetDims(); @@ -191,34 +191,34 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, } } - if (extra_add_qk != nullptr) { - const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims(); + if (relative_position_bias != nullptr) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - if (extra_add_qk_dims.size() != 4) { + if (relative_position_bias_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' is expected to have 4 dimensions, got ", - extra_add_qk_dims.size()); + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); } - if (extra_add_qk_dims[0] != batch_size) { + if (relative_position_bias_dims[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ", - extra_add_qk_dims[0]); + "Input 'relative_position_bias' dimension 0 should be same as batch_size, got ", + relative_position_bias_dims[0]); } - if (extra_add_qk_dims[1] != num_heads_) { + if (relative_position_bias_dims[1] != num_heads_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - 
"Input 'extra_add_qk' dimension 1 should be same as number of heads, got ", - extra_add_qk_dims[1]); + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); } - if (extra_add_qk_dims[2] != sequence_length) { + if (relative_position_bias_dims[2] != sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ", - extra_add_qk_dims[2]); + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); } - if (extra_add_qk_dims[3] != total_sequence_length) { + if (relative_position_bias_dims[3] != total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 3 should be same as total_sequence_length, got ", - extra_add_qk_dims[3]); + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); } } @@ -320,7 +320,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { @@ -328,7 +328,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, extra_add_qk, parameters, past_seq_len); + return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, relative_position_bias, parameters, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index 2c49f196d52d8..2e077da2853d0 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -18,7 +18,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, // for CUDA const Tensor* past_seq_len = nullptr) const; @@ -61,7 +61,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. 
const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const Tensor* past_seq_len = nullptr) const; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index d0137a2049825..bc5ce6e323beb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -25,7 +25,7 @@ enum AttentionQkvFormat { Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed }; -enum AttentionKernelType{ +enum AttentionKernelType { AttentionKernel_Unfused, AttentionKernel_TrtFusedAttention, AttentionKernel_TrtFlashAttention, @@ -38,15 +38,15 @@ enum AttentionKernelType{ struct AttentionParameters { int batch_size; int sequence_length; - int kv_sequence_length; // input sequence length of K or V - int past_sequence_length; // sequence length in past state of K or V - int total_sequence_length; // total sequence length of K or V - int max_sequence_length; // max sequence length from 4D mask - int input_hidden_size; // first dimension of weights for input projection - int hidden_size; // hidden size of Q or K - int head_size; // hidden size per head of Q or K - int v_hidden_size; // hidden size of V - int v_head_size; // hidden size per head of V + int kv_sequence_length; // input sequence length of K or V + int past_sequence_length; // sequence length in past state of K or V + int total_sequence_length; // total sequence length of K or V + int max_sequence_length; // max sequence length from 4D mask + int input_hidden_size; // first dimension of weights for input projection + int hidden_size; // hidden size of Q or K + int head_size; // hidden size per head of Q or K + int v_hidden_size; // hidden size of V + int v_head_size; // hidden size per head of V int num_heads; bool is_unidirectional; bool past_present_share_buffer; @@ -56,13 +56,17 @@ struct AttentionParameters { }; namespace attention { -// Environment variable to enable or disable fused self/causal attention kernel. Default is 0 (enabled). -constexpr const char* kDisableFusedAttention = "ORT_DISABLE_FUSED_ATTENTION"; +// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled). +constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION"; // Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled). constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION"; -// Environment variable to enable or disable TRT flash attention. Default is 0 (enabled). +// Environment variable to enable or disable TRT fused causal attention kernels. Default is 0 (disabled). +// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels. +constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION"; + +// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled). constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION"; // Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled). 
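Because the constants above are read through `ParseEnvironmentVariableWithDefault` when the CUDA attention kernels are constructed, these switches only take effect if they are set before the session is created. A hedged Python sketch of toggling them is shown below; the model path is a placeholder, and the variable names are taken from the attention_common.h hunk above.

```python
import os

# Kernel-selection switches read by the CUDA attention kernels (names from attention_common.h).
os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1"  # opt in to TRT fused causal attention (fp16 accumulation)
os.environ["ORT_DISABLE_FUSED_ATTENTION"] = "0"        # keep TRT fused self attention enabled (the default)
os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "0"    # keep TRT flash attention enabled (the default)

import onnxruntime as ort

# "model.onnx" is a placeholder for a model that uses com.microsoft.Attention.
session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
```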
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 0185fa9ea09a0..70d71ffb6ee40 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -19,18 +19,18 @@ class AttentionCPUBase : public AttentionBase { : AttentionBase(info, require_same_hidden_size) {} template - Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH - const T* K, // K data with shape BxNxSxH - const T* V, // V value with size BxNxSxH_v - const Tensor* mask_index, // mask index. nullptr if no mask or its size is B - const Tensor* past, // past state - Tensor* output, // output tensor - int batch_size, // batch size (B) - int sequence_length, // sequence length (S) - int qk_head_size, // head size of Q or K (H) - int v_head_size, // head size of V (H_v) - int v_hidden_size, // hidden size of V (D_v) - const Tensor* extra_add_qk, // extra add in QK. Its size is BxNxSxT + Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxSxH + const T* V, // V value with size BxNxSxH_v + const Tensor* mask_index, // mask index. nullptr if no mask or its size is B + const Tensor* past, // past state + Tensor* output, // output tensor + int batch_size, // batch size (B) + int sequence_length, // sequence length (S) + int qk_head_size, // head size of Q or K (H) + int v_head_size, // head size of V (H_v) + int v_hidden_size, // hidden size of V (D_v) + const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT OpKernelContext* context) const { const int kv_sequence_length = sequence_length; @@ -67,16 +67,16 @@ class AttentionCPUBase : public AttentionBase { const T* past_data = past != nullptr ? past->Data() : nullptr; T* present_data = present != nullptr ? present->MutableData() : nullptr; - const T* extra_add_qk_data = nullptr; - if (extra_add_qk != nullptr) { - extra_add_qk_data = extra_add_qk->Data(); + const T* relative_position_bias_data = nullptr; + if (relative_position_bias != nullptr) { + relative_position_bias_data = relative_position_bias->Data(); } ComputeAttentionProbs(static_cast(attention_probs), Q, K, mask_index_data, mask_index_dims, static_cast(mask_data), has_unidirectional, batch_size, sequence_length, past_sequence_length, qk_head_size == 0 ? 
v_head_size : qk_head_size, - past_data, present_data, tp, extra_add_qk_data); + past_data, present_data, tp, relative_position_bias_data); // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) auto out_tmp_data = @@ -112,7 +112,7 @@ class AttentionCPUBase : public AttentionBase { const T* past, // past state T* present, // present state ThreadPool* tp, // thread pool - const T* extra_add_qk_data // extra add matrix with shape BxNxSxT + const T* relative_position_bias_data // bias addition matrix with shape BxNxSxT ) const { const int total_sequence_length = past_sequence_length + sequence_length; // T = P + L const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // P x H @@ -175,9 +175,9 @@ class AttentionCPUBase : public AttentionBase { } } - if (extra_add_qk_data != nullptr) { + if (relative_position_bias_data != nullptr) { for (int j = 0; j < sequence_length * total_sequence_length; j++) { - output[j] += extra_add_qk_data[output_offset + j]; + output[j] += relative_position_bias_data[output_offset + j]; } } } diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 8c3af05972c95..ee1720b9f43bb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -17,6 +17,7 @@ Status CheckInputs(const T* query, const T* value, const T* bias, const T* key_padding_mask, + const T* relative_position_bias, void* parameters, int num_heads, float mask_filter_value, @@ -26,6 +27,7 @@ Status CheckInputs(const T* query, // value (V) : (B, L, D_v) // bias (Q/K/V) : (D + D + D_v) // key_padding_mask (K/V) : (B) or (B, L) or None + // relative_position_bias : (B, 1, S, L) // When packed kv is used: // key (K) : (B, L, N, 2, H) // value (V) : None @@ -120,6 +122,36 @@ Status CheckInputs(const T* query, v_hidden_size = static_cast(value_dims[2]); } + if (relative_position_bias != nullptr) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); + + if (relative_position_bias_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); + } + if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", + relative_position_bias_dims[0]); + } + if (relative_position_bias_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); + } + if (relative_position_bias_dims[2] != sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); + } + if (relative_position_bias_dims[3] != kv_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); + } + } + if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); output_parameters->batch_size = batch_size; diff --git 
a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 64c17b7767e4f..e7df84c1b0066 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -160,7 +160,7 @@ Status QAttention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past_tensor, - nullptr, // extra_add_qk + nullptr, // relative_position_bias nullptr // parameters )); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 4a6d2dc137139..258bab5390015 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -39,12 +39,15 @@ REGISTER_KERNEL_TYPED(MLFloat16) template Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) { - disable_fused_runner_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedAttention, false); + disable_fused_self_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); enable_trt_flash_attention_ = sizeof(T) == 2 && !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + enable_fused_causal_attention_ = sizeof(T) == 2 && + ParseEnvironmentVariableWithDefault(attention::kEnableFusedCausalAttention, false); + #if USE_FLASH_ATTENTION disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); #else @@ -59,7 +62,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(kPastInputIndex); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); @@ -69,7 +72,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, ¶meters, device_prop.maxThreadsPerBlock, past_seq_len)); @@ -97,15 +100,14 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { int sm = device_prop.major * 10 + device_prop.minor; bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - if (is_unidirectional_) { // GPT + if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT // GPT fused kernels requires left side padding. mask can be: // none (no padding), 1D sequence lengths or 2d mask. // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token // where past state is empty. 
bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; - bool use_causal_fused_runner = !disable_fused_runner_ && - (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == extra_add_qk && + bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && + nullptr == relative_position_bias && parameters.past_sequence_length == 0 && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, @@ -121,11 +123,11 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { fused_runner = fused_fp16_runner_.get(); } } else { // BERT - bool use_fused_runner = !disable_fused_runner_ && + bool use_fused_runner = !disable_fused_self_attention_ && (nullptr == mask_index || is_mask_1d_seq_len) && nullptr == past && nullptr == present && - nullptr == extra_add_qk && + nullptr == relative_position_bias && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, enable_trt_flash_attention_, false); @@ -151,7 +153,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == mask_index && // TODO: support 1D mask nullptr == past && nullptr == present && - nullptr == extra_add_qk && + nullptr == relative_position_bias && (sizeof(T) == 2 || // sequence length threshold is 0 in FP16 parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) && has_memory_efficient_attention(sm, sizeof(T) == 2); @@ -203,7 +205,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data()); - data.extra_add_qk = (nullptr == extra_add_qk) ? nullptr : reinterpret_cast(extra_add_qk->Data()); + data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index 13b2019b21d0d..ba7c56c04fdde 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -21,8 +21,9 @@ class Attention final : public CudaKernel, public AttentionBase { Status ComputeInternal(OpKernelContext* context) const override; protected: - bool disable_fused_runner_; + bool disable_fused_self_attention_; bool enable_trt_flash_attention_; + bool enable_fused_causal_attention_; bool disable_memory_efficient_attention_; mutable std::unique_ptr fused_fp16_runner_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 8c7ef9f919519..5b6c40a7f50fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -28,6 +28,7 @@ limitations under the License. 
#include #include +#include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -555,10 +556,12 @@ Status QkvToContext( if (use_fused_kernel || use_fused_causal) { int* sequence_offset = reinterpret_cast(scratch1); if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); } else { LaunchTrtSequenceOffset(sequence_offset, data.mask_index, batch_size, sequence_length, stream); } + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); CUDA_RETURN_IF_ERROR(cudaGetLastError()); FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(fused_runner); @@ -665,7 +668,7 @@ Status QkvToContext( T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, data.extra_add_qk, scratch1, scratch2, + mask_index, nullptr, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, mask_filter_value)); @@ -675,10 +678,10 @@ Status QkvToContext( const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, data.extra_add_qk, scratch1, scratch2, parameters.is_unidirectional)); + mask_index, mask_start, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( - ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.extra_add_qk, + ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional)); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index d98a0380c479b..2ecda71479c52 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -41,7 +41,7 @@ struct AttentionData { const int* mask_index; gsl::span mask_index_dims; const T* past; - const T* extra_add_qk; + const T* relative_position_bias; T* workspace; T* output; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h index bf13719030fca..2ac0695e8b1d1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h @@ -194,11 +194,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, float thread_data = -CUDART_INF_F; if (threadIdx.x < all_sequence_length) { - if (add_before_softmax == nullptr) { - thread_data = float(input[index]) * rsqrt_head_size; - } else { - thread_data = float(input[index] + add_before_softmax[index]) * rsqrt_head_size; - } + thread_data = float(input[index]) * rsqrt_head_size; const int sequence_index = blockIdx.x % sequence_length; if (is_unidirectional) { @@ -229,6 +225,10 @@ __device__ inline void 
SoftmaxWithRawMaskSmall(const int all_sequence_length, thread_data = -CUDART_INF_F; } } + + if (add_before_softmax != nullptr) { + thread_data += float(add_before_softmax[index]); + } } if (skip_softmax) { diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 93e5e59ed00ae..8fce4dcb7a51e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -40,8 +40,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); - disable_fused_runner_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedAttention, false); + disable_fused_self_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); enable_trt_flash_attention_ = sizeof(T) == 2 && !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); @@ -62,6 +62,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* bias = context->Input(3); const Tensor* key_padding_mask = context->Input(4); + const Tensor* relative_position_bias = context->Input(5); auto& device_prop = GetDeviceProp(); AttentionParameters parameters; @@ -70,6 +71,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { value, bias, key_padding_mask, + relative_position_bias, ¶meters, num_heads_, mask_filter_value_, @@ -94,6 +96,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_cross_attention = !disable_fused_cross_attention_ && nullptr == key_padding_mask && + nullptr == relative_position_bias && (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV parameters.hidden_size == parameters.v_hidden_size && has_fused_cross_attention_kernel(sm, parameters.head_size, @@ -110,8 +113,9 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } } - bool use_fused_runner = !disable_fused_runner_ && + bool use_fused_runner = !disable_fused_self_attention_ && fused_cross_attention_kernel == nullptr && + nullptr == relative_position_bias && value != nullptr && // fused runner requires packed qkv instead of packed kv (nullptr == key_padding_mask || is_mask_1d_seq_len) && parameters.hidden_size == parameters.v_hidden_size && @@ -143,6 +147,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { !disable_memory_efficient_attention_ && is_long_sequence && nullptr == key_padding_mask && // TODO: support 1D mask + nullptr == relative_position_bias && has_memory_efficient_attention(sm, sizeof(T) == 2); #else constexpr bool use_memory_efficient_attention = false; @@ -171,7 +176,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); data.past = nullptr; - data.extra_add_qk = nullptr; + data.relative_position_bias = (nullptr == relative_position_bias) ? 
nullptr : reinterpret_cast(relative_position_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index b4ac7f19597ea..e6ca36358402f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -23,7 +23,7 @@ class MultiHeadAttention final : public CudaKernel { protected: int num_heads_; // number of attention heads float mask_filter_value_; - bool disable_fused_runner_; + bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; bool disable_memory_efficient_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index af13efe0e2fbc..9627a1f7c3741 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -3,7 +3,15 @@ #include "core/providers/cuda/cuda_common.h" #include "relative_attn_bias.h" +#include "core/common/safeint.h" #include "relative_attn_bias_impl.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/add_bias_transpose.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + namespace onnxruntime { namespace contrib { @@ -20,7 +28,16 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 1) \ .InputMemoryType(OrtMemTypeCPUInput, 2) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - RelPosAttnBias); + RelPosAttnBias); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GatedRelativePositionBias, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + GatedRelativePositionBias); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) @@ -69,6 +86,109 @@ Status RelPosAttnBias::ComputeInternal(OpKernelContext* context) const { device_prop.maxThreadsPerBlock); } +template +GatedRelativePositionBias::GatedRelativePositionBias(const OpKernelInfo& info) : CudaKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = SafeInt(num_heads); +} + +template +Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) const { + const Tensor& query_tensor = *context->Input(0); + const Tensor& query_bias_tensor = *context->Input(1); + const Tensor& rel_pos_tensor = *context->Input(2); + const Tensor& weight_tensor = *context->Input(3); + const Tensor& bias_tensor = *context->Input(4); + const Tensor& eco_a_tensor = *context->Input(5); + + const auto& query_dims = query_tensor.Shape().GetDims(); + ORT_ENFORCE(query_dims.size() == 3); + ORT_ENFORCE(query_dims[2] > 0); + ORT_ENFORCE(query_dims[2] % num_heads_ == 0); + const auto batch_size = SafeInt(query_dims[0]); + const auto seq_len = SafeInt(query_dims[1]); + const auto head_size = SafeInt(query_dims[2] / num_heads_); + + ORT_ENFORCE(query_bias_tensor.Shape().NumDimensions() == 1); + ORT_ENFORCE(query_bias_tensor.Shape()[0] == query_dims[2]); + + const auto& rel_pos_dims = rel_pos_tensor.Shape().GetDims(); + ORT_ENFORCE(rel_pos_dims.size() == 4); + ORT_ENFORCE(rel_pos_dims[0] == 1); + ORT_ENFORCE(rel_pos_dims[1] == num_heads_); + ORT_ENFORCE(rel_pos_dims[2] == seq_len); 
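The remaining shape checks, the GEMM, and the kernel launch for GatedRelativePositionBias continue below. As a reference for what the new operator computes, here is a hedged NumPy sketch that follows the schema pseudocode added in bert_defs.cc further down in this patch; the function and variable names are illustrative only and not part of the ONNX Runtime API.

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def gated_relative_position_bias(query_layer, query_bias, rel_pos, weight, bias, eco_a, num_heads):
        # query_layer: (batch, seq_len, num_heads * head_size); rel_pos: (1, num_heads, seq_len, seq_len)
        # weight: (head_size, D) with D even; bias: (D,); eco_a: (1, num_heads, 1, 1)
        batch, seq_len, hidden = query_layer.shape
        head_size = hidden // num_heads
        q = (query_layer + query_bias).reshape(batch, seq_len, num_heads, head_size).transpose(0, 2, 1, 3)
        x = q @ weight + bias                                     # (batch, num_heads, seq_len, D)
        D = weight.shape[1]
        sums = x.reshape(batch, num_heads, seq_len, 2, D // 2).sum(-1)
        gate_u, gate_r = np.split(sigmoid(sums), 2, axis=-1)      # each (batch, num_heads, seq_len, 1)
        gate = gate_u * (gate_r * eco_a - 1.0) + 2.0
        return gate * rel_pos                                     # (batch, num_heads, seq_len, seq_len)

The CUDA kernel below fuses the two per-row reductions for gate_u and gate_r with warp shuffles, but its result should match this reference up to floating-point rounding.
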
+ ORT_ENFORCE(rel_pos_dims[3] == seq_len); + + const auto& weight_dims = weight_tensor.Shape().GetDims(); + ORT_ENFORCE(weight_dims.size() == 2); + ORT_ENFORCE(weight_dims[0] == head_size); + ORT_ENFORCE((weight_dims[1] > 0) && (weight_dims[1] % 2 == 0)); + + ORT_ENFORCE(bias_tensor.Shape().NumDimensions() == 1); + ORT_ENFORCE(bias_tensor.Shape()[0] == weight_dims[1]); + + const auto D = SafeInt(weight_dims[1]); + + const auto& eco_a_dims = eco_a_tensor.Shape().GetDims(); + ORT_ENFORCE(eco_a_dims.size() == 4); + ORT_ENFORCE(eco_a_dims[0] == 1); + ORT_ENFORCE(eco_a_dims[1] == num_heads_); + ORT_ENFORCE(eco_a_dims[2] == 1); + ORT_ENFORCE(eco_a_dims[3] == 1); + + Tensor* output = context->Output(0, {batch_size, num_heads_, seq_len, seq_len}); + + auto& device_prop = GetDeviceProp(); + cublasHandle_t cublas = GetCublasHandle(context); + + typedef typename ToCudaType::MappedType CudaT; + const auto BNS = batch_size * num_heads_ * seq_len; + const size_t elements_in_query = (size_t)BNS * (size_t)head_size; + const size_t elements_after_gemm = (size_t)BNS *(size_t)D; + bool reuse_output = (seq_len >= D); + size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm)); + auto workspace = GetScratchBuffer(workspace_size, context->GetComputeStream()); + + // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH) + constexpr int format = 1; + constexpr int total_maxtrix = 1; + constexpr int num_matrix_to_transpose = 1; + LaunchAddBiasTranspose(Stream(context), num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock, + batch_size, seq_len, num_heads_, head_size, + reinterpret_cast(query_tensor.template Data()), + reinterpret_cast(query_bias_tensor.template Data()), + reinterpret_cast(workspace.get()), + false, head_size, reinterpret_cast(static_cast(nullptr)), total_maxtrix); + + // reuse output if possible + CudaT* gemm_output = reuse_output ? reinterpret_cast(output->template MutableData()) + : (reinterpret_cast(workspace.get()) + elements_in_query); + int ld_gemm_output = reuse_output ? 
seq_len : D; + + const CudaT one = ToCudaType::FromFloat(1.0f); + const CudaT zero = ToCudaType::FromFloat(0.0f); + + // ([b*n*s, h] * [h, D]), CUDA assumes col-major + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + D, BNS, head_size, &one, + reinterpret_cast(weight_tensor.template Data()), (int)D, + reinterpret_cast(workspace.get()), (int)head_size, + &zero, gemm_output, ld_gemm_output, device_prop)); + + auto status = LaunchGatedRelativePositionBiasKernel( + device_prop, Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(rel_pos_tensor.template Data()), + reinterpret_cast(gemm_output), + reinterpret_cast(bias_tensor.template Data()), + reinterpret_cast(eco_a_tensor.template Data()), + batch_size, num_heads_, seq_len, D, ld_gemm_output); + + return status; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h index b9674f6f35091..3bf4e730e29f9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h @@ -22,6 +22,18 @@ class RelPosAttnBias final : public CudaKernel { bool is_bidirectional_; }; +template +class GatedRelativePositionBias final : public CudaKernel { + public: + GatedRelativePositionBias(const OpKernelInfo& op_kernel_info); + + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + int num_heads_; +}; + + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu index e333152cb5bcf..938496b058025 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu @@ -36,7 +36,7 @@ __global__ void buildRelativeAttentionBias(T* relative_attention_bias, const bool is_bidirectional, const int max_distance) { const int head_id = blockIdx.x; - for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) { + for (int seq_id = blockDim.x * blockIdx.y + threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) { int row_id = seq_id / seq_len; int col_id = seq_id % seq_len; @@ -149,6 +149,122 @@ template Status LaunchRelPosAttnBiasKernel(cudaStream_t stream, const bool is_bidirectional, const int max_threads_per_block); +template +__global__ void GatedRelativePositionBiasKernelSmallD( + T* output, // (batch_size, num_heads, seq_len, seq_len) + const T* rel_pos, // (1, num_heads, seq_len, seq_len) + const T* qw, // (batch_size, num_heads, seq_len, D) + const T* bias, // (D) + const T* eco_a, // (1, num_heads, 1, 1) + const int D, + const int ldqw) { + __shared__ float gate[1]; + + const int seq_len = gridDim.x; + const int num_heads = gridDim.y; + const int s = blockIdx.x; + const int n = blockIdx.y; + const int b = blockIdx.z; + + rel_pos += ((int64_t)n * seq_len + s) * seq_len; + output += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * seq_len; + qw += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * ldqw; + + float val = 0.0f; + if (threadIdx.x < D) { + val = (float)qw[threadIdx.x] + (bias ? (float)bias[threadIdx.x] : 0.0f); + } + + float u = (threadIdx.x < D / 2) ? 
val : 0.0f; +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + u += __shfl_down_sync(0xffffffff, u, offset); + } + + float r = (threadIdx.x >= D / 2) ? val : 0.0f; +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + r += __shfl_down_sync(0xffffffff, r, offset); + } + + if (threadIdx.x == 0) { + u = 1.0f / (1.0f + expf(-u)); + r = 1.0f / (1.0f + expf(-r)); + gate[0] = u * (r * (float)eco_a[n] - 1.0f) + 2.0f; + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < seq_len; idx += blockDim.x) { + output[idx] = (T)(gate[0] * (float)rel_pos[idx]); + } +} + +template +Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + T* output, + const T* rel_pos, + const T* qw, // query * weight + const T* bias, + const T* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw) { + ORT_ENFORCE(D <= 32 && D > 0 && (D % 2 == 0)); + ORT_ENFORCE(ldqw == seq_len || ldqw == D); + + int tpb = std::max(32, std::max(D, seq_len)); + tpb = std::min(tpb, device_prop.maxThreadsPerBlock); + + // round up tpb to power of 2 + --tpb; + tpb |= (tpb >> 1); + tpb |= (tpb >> 2); + tpb |= (tpb >> 4); + tpb |= (tpb >> 8); + tpb |= (tpb >> 16); + tpb++; + + dim3 block(tpb); + dim3 grid(seq_len, num_heads, batch_size); + + GatedRelativePositionBiasKernelSmallD<<>>( + output, rel_pos, qw, bias, eco_a, D, ldqw); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + float* output, + const float* rel_pos, + const float* qw, + const float* bias, + const float* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + +template Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + half* output, + const half* rel_pos, + const half* qw, + const half* bias, + const half* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h index 5a1a229ab6077..5c7c98f55f3f5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h @@ -22,6 +22,21 @@ Status LaunchRelPosAttnBiasKernel( const int max_threads_per_block ); +template +Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + T* output, + const T* rel_pos, + const T* qw, // from query * weight + const T* bias, + const T* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index a239e528af148..1254ccd7e1e17 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -32,6 +32,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 
1, float, RelativePositionBias); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding); @@ -162,6 +164,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index e5ea47a6a2a5b..7cd717efc9fba 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -52,7 +52,7 @@ Status QAttention::CheckInputs(const Tensor* input, auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past_tensor, - nullptr, // extra_add_qk + nullptr, // relative_position_bias parameters, device_prop.maxThreadsPerBlock)); @@ -198,7 +198,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast(past_tensor->Data()); - data.extra_add_qk = nullptr; // add_qk is not supported in quantized attention + data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = (nullptr == present) ? 
nullptr : reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 204c786cc2c5d..8122b2de5916b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -212,7 +212,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), merged_weights_shape, merged_bias_shape, mask_index, nullptr, // past - nullptr, // extra_add_qk + nullptr, // relative_position_bias nullptr, // parameters device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h index 5fe62ef127800..5fb31be5fe86f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h @@ -17,4 +17,4 @@ DefineQOrderedAttentionInput(Scale_QK_Softmax, scale_QKT_softmax, 15), DefineQOrderedAttentionInput(Scale_Values_Gemm, scale_values_gemm, 16), DefineQOrderedAttentionInput(Mask_Index, mask_index, 17), DefineQOrderedAttentionInput(Past, past, 18), -DefineQOrderedAttentionInput(Extra_Add, extra_add, 19) +DefineQOrderedAttentionInput(relative_position_bias, relative_position_bias, 19) diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cc b/onnxruntime/contrib_ops/rocm/bert/attention.cc index 756919834aef8..afc9fd9237ed7 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cc +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cc @@ -39,7 +39,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), @@ -47,7 +47,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, nullptr, device_prop.maxThreadsPerBlock)); @@ -129,7 +129,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), mask_filter_value_, nullptr == past ? nullptr : past->Data(), - nullptr == extra_add_qk ? nullptr : extra_add_qk->Data(), + nullptr == relative_position_bias ? nullptr : relative_position_bias->Data(), work_space.get(), output->MutableData(), nullptr == present ? 
nullptr : present->MutableData()); diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index 954a129be1c65..fa6cce6a64132 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -89,7 +89,7 @@ Status QkvToContext( bool is_unidirectional, int past_sequence_length, const T* past, - const T* extra_add_qk, + const T* relative_position_bias, T* present, bool use_persistent_softmax) { const int all_sequence_length = past_sequence_length + sequence_length; @@ -158,7 +158,7 @@ Status QkvToContext( T* persistent_softmax_workspace = scratch1; // replace Q*K' in place if persistent softmax is selected. ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, extra_add_qk, scratch1, scratch2, + mask_index, nullptr, relative_position_bias, scratch1, scratch2, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index @@ -166,10 +166,10 @@ Status QkvToContext( // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)); + mask_index, mask_start, relative_position_bias, scratch1, scratch2, is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, - extra_add_qk, scratch1, scratch2, is_unidirectional)); + relative_position_bias, scratch1, scratch2, is_unidirectional)); } // compute P*V (as V*P), and store in scratch3: BxNxSxH @@ -206,7 +206,7 @@ Status LaunchAttentionKernel( gsl::span mask_index_dims, const float mask_filter_value, const void* past, - const void* extra_add_qk, + const void* relative_position_bias, void* workspace, void* output, void* present) { @@ -225,7 +225,7 @@ Status LaunchAttentionKernel( is_unidirectional, past_sequence_length, reinterpret_cast(past), - reinterpret_cast(extra_add_qk), + reinterpret_cast(relative_position_bias), reinterpret_cast<__half*>(present), use_persistent_softmax); } else { @@ -240,7 +240,7 @@ Status LaunchAttentionKernel( is_unidirectional, past_sequence_length, reinterpret_cast(past), - reinterpret_cast(extra_add_qk), + reinterpret_cast(relative_position_bias), reinterpret_cast(present), use_persistent_softmax); } diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 7db692083f5e5..fdc46ce2e7729 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -42,7 +42,7 @@ Status LaunchAttentionKernel( gsl::span mask_index_dims, // Mask index shape const float mask_filter_value, // Mask value for filtered out positions const void* past, // Past state input - const void* extra_add_qk, // Additional Add + const void* relative_position_bias, // Additional Add void* workspace, // Temporary buffer void* output, // Output tensor void* present // Present state output diff --git a/onnxruntime/core/framework/TensorSeq.h b/onnxruntime/core/framework/TensorSeq.h index 
9ac6d22ec1b13..940e3e5920a26 100644 --- a/onnxruntime/core/framework/TensorSeq.h +++ b/onnxruntime/core/framework/TensorSeq.h @@ -3,6 +3,8 @@ #pragma once +#include "core/framework/data_types.h" +#include "core/framework/ort_value.h" #include "core/framework/tensor.h" #include #include @@ -17,7 +19,8 @@ class TensorSeq { SetType(elem_type); } - using const_iterator = std::vector::const_iterator; + using const_iterator = std::vector::const_iterator; + using iterator = std::vector::iterator; // Sets the element type after construction. // Expects sequence to be empty at the time. @@ -27,7 +30,7 @@ class TensorSeq { ORT_ENFORCE(elem_type_ != nullptr, "Tensor sequence must contain only primitive types"); } - void SetElements(std::vector&& tensors) { + void SetElements(std::vector&& tensors) { // The caller of this method ensures that : // (1) `elem_type` is set before invoking this method // (2) All tensors contain elements of the same primitive data type @@ -56,16 +59,57 @@ class TensorSeq { return tensors_.cend(); } - // Get by index + iterator begin() noexcept { + return tensors_.begin(); + } + + iterator end() noexcept { + return tensors_.end(); + } + + // Get onnxruntime::Tensor by index const Tensor& Get(size_t i) const { + return GetAt(i).Get(); + } + + // Get OrtValue by index + const OrtValue& GetAt(size_t i) const { ORT_ENFORCE(i < tensors_.size()); return tensors_[i]; } + void Add(const OrtValue& tensor) { + ORT_ENFORCE(IsSameDataType(tensor.Get()), + "TensorSeq: tensor to be added has a different data type."); + tensors_.push_back(tensor); + } + + void Add(OrtValue&& tensor) { + ORT_ENFORCE(IsSameDataType(tensor.Get()), + "TensorSeq: tensor to be added has a different data type."); + tensors_.push_back(std::move(tensor)); + } + void Add(Tensor&& tensor) { ORT_ENFORCE(IsSameDataType(tensor), "TensorSeq: tensor to be added has a different data type."); - tensors_.push_back(std::move(tensor)); + OrtValue value; + Tensor::InitOrtValue(std::move(tensor), value); + Add(std::move(value)); + } + + static void InitOrtValue(const TensorSeq& source_tensor_seq, std::shared_ptr allocator, OrtValue& ort_value) { + auto target_tensor_seq = std::make_unique(source_tensor_seq.DataType()); + target_tensor_seq->Reserve(source_tensor_seq.Size()); + for (auto iter = source_tensor_seq.begin(); iter != source_tensor_seq.end(); ++iter) { + const Tensor& tensor = iter->Get(); + OrtValue value; + Tensor::InitOrtValue(tensor.DataType(), tensor.Shape(), allocator, value); + target_tensor_seq->Add(std::move(value)); + } + + auto ml_tensor_seq = SequenceTensorTypeBase::Type(); + ort_value.Init(target_tensor_seq.release(), ml_tensor_seq, ml_tensor_seq->GetDeleteFunc()); } void Reserve(size_t capacity) { @@ -80,9 +124,7 @@ class TensorSeq { // and the SequenceInsert op expects validation of tensors to be added to the seq against this type. 
const PrimitiveDataTypeBase* elem_type_{}; - // TODO: optimization opportunity - if all tensors in the seq are scalars, we can potentially represent them - // as vector - std::vector tensors_; + std::vector tensors_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index 314c76c04adbe..af3bad3085e4a 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -105,6 +105,12 @@ void Tensor::InitOrtValue(MLDataType p_type, const TensorShape& shape, ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } +void Tensor::InitOrtValue(Tensor&& tensor, OrtValue& ort_value) { + auto ml_tensor = DataTypeImpl::GetType(); + auto p_tensor = std::make_unique(std::move(tensor)); + ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); +} + size_t Tensor::SizeInBytes() const { #ifdef ENABLE_STRIDED_TENSORS int64_t size = IsContiguous() ? shape_.Size() : GetSizeFromStrides(shape_, strides_); diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index f88d098454479..40b566b0983dc 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -117,14 +117,7 @@ static common::Status AllocateHelper(const AllocatorPtr& allocator, #endif } else if (source_mlvalue.IsTensorSequence()) { const TensorSeq& source_tensor_seq = source_mlvalue.Get(); - auto target_tensor_seq = std::make_unique(source_tensor_seq.DataType()); - std::vector tensors; - for (auto iter = source_tensor_seq.begin(); iter != source_tensor_seq.end(); ++iter) { - tensors.emplace_back(iter->DataType(), onnxruntime::TensorShape(iter->Shape()), allocator); - } - target_tensor_seq->SetElements(std::move(tensors)); - auto ml_tensor_seq = DataTypeImpl::GetType(); - target_mlvalue.Init(target_tensor_seq.release(), ml_tensor_seq, ml_tensor_seq->GetDeleteFunc()); + TensorSeq::InitOrtValue(source_tensor_seq, allocator, target_mlvalue); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported OrtValue type."); } @@ -207,17 +200,19 @@ static Status BatchOrCopyMLValue(const SessionState& session_state, std::unique_ptr target_tensor = std::make_unique(source_tensor.DataType(), source_tensor.Shape(), allocator); target_tensor_seq.Add(std::move(*target_tensor)); } + const auto& data_transfer_mgr = session_state.GetDataTransferMgr(); auto source_iter = source_tensor_seq.begin(); auto target_iter = target_tensor_seq.begin(); + while (source_iter != source_tensor_seq.end() && target_iter != target_tensor_seq.end()) { if (copy_tensor_pairs != nullptr) { - copy_tensor_pairs->push_back({*source_iter, const_cast(*target_iter), stream}); + copy_tensor_pairs->push_back({source_iter->Get(), *target_iter->GetMutable(), stream}); } else { if (stream) - ORT_RETURN_IF_ERROR(session_state.GetDataTransferMgr().CopyTensorAsync(*source_iter, const_cast(*target_iter), *stream)); + ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensorAsync(source_iter->Get(), *target_iter->GetMutable(), *stream)); else - ORT_RETURN_IF_ERROR(session_state.GetDataTransferMgr().CopyTensor(*source_iter, const_cast(*target_iter))); + ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_iter->Get(), *target_iter->GetMutable())); } ++source_iter; ++target_iter; @@ -507,6 +502,48 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state, return Status::OK(); } +#ifdef ORT_ENABLE_STREAM +struct DeviceStreamCollectionHolder { + DeviceStreamCollectionHolder( + const 
SessionState& session_state) : session_state_(session_state), + p_(session_state.AcquireDeviceStreamCollection()) { + } + + ~DeviceStreamCollectionHolder() { + if (p_) { + session_state_.RecycleDeviceStreamCollection(std::move(p_)); + } + } + + const SessionState& session_state_; + std::unique_ptr p_; +}; + +static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection, + Stream* parent_stream) { + if (parent_stream) { + // TODO: in theory, we should make current subgraph's stream depends on parent stream. + // but in current code structure, it causing issues with the resource sharing and stream + // lifetime. it also may cause additional cost of stream sync for single stream case. + // In first phase, let's just put all the subgraph execution on the parent stream. + for (size_t i = 0; i < device_stream_collection.NumStreams(); ++i) { + auto* stream = device_stream_collection.GetStream(i); + if (stream) { + // if current logic stream is not on the same EP instance as parent stream + // and the EP instance does have async streams (not EP like CPU) + // throw error as we don't have the code to setup the dependency at this moment. + if (stream->GetDevice() != parent_stream->GetDevice()) { + ORT_THROW("Subgraph has nodes running on device: ", stream->GetDevice().Type(), + " while parent graph node running on device: ", parent_stream->GetDevice().Type(), + ", this is not supported yet."); + } + device_stream_collection.SetDeviceStream(i, parent_stream); + } + } + } +} +#endif + // public method to do a single copy. used by external partners common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name, const OrtValue& orig_mlvalue, OrtValue& new_mlvalue) { @@ -526,8 +563,23 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons copy_info.source_device = orig_mlvalue.Get().Location().device; #endif + Stream* device_stream = nullptr; +#ifdef ORT_ENABLE_STREAM + DeviceStreamCollectionHolder device_stream_collection_holder(session_state); + if (device_stream_collection_holder.p_ != nullptr) { + DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get(); + gsl::span streams = device_stream_collection->GetStreams(); + for (Stream* stream : streams) { + if (stream && stream->GetDevice().Type() != OrtDevice::CPU) { + device_stream = stream; + break; + } + } + } +#endif + // copy_info.target_device is not set leaving to be equal to CPU. - return BatchOrCopyMLValue(session_state, copy_info, orig_mlvalue, new_mlvalue, nullptr); + return BatchOrCopyMLValue(session_state, copy_info, orig_mlvalue, new_mlvalue, device_stream); } static common::Status CopyOutputsAcrossDevices(const SessionState& session_state, @@ -566,48 +618,6 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state return Status::OK(); } -#ifdef ORT_ENABLE_STREAM -struct DeviceStreamCollectionHolder { - DeviceStreamCollectionHolder( - const SessionState& session_state) : session_state_(session_state), - p_(session_state.AcquireDeviceStreamCollection()) { - } - - ~DeviceStreamCollectionHolder() { - if (p_) { - session_state_.RecycleDeviceStreamCollection(std::move(p_)); - } - } - - const SessionState& session_state_; - std::unique_ptr p_; -}; - -static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection, - Stream* parent_stream) { - if (parent_stream) { - // TODO: in theory, we should make current subgraph's stream depends on parent stream. 
- // but in current code structure, it causing issues with the resource sharing and stream - // lifetime. it also may cause additional cost of stream sync for single stream case. - // In first phase, let's just put all the subgraph execution on the parent stream. - for (size_t i = 0; i < device_stream_collection.NumStreams(); ++i) { - auto* stream = device_stream_collection.GetStream(i); - if (stream) { - // if current logic stream is not on the same EP instance as parent stream - // and the EP instance does have async streams (not EP like CPU) - // throw error as we don't have the code to setup the dependency at this moment. - if (stream->GetDevice() != parent_stream->GetDevice()) { - ORT_THROW("Subgraph has nodes running on device: ", stream->GetDevice().Type(), - " while parent graph node running on device: ", parent_stream->GetDevice().Type(), - ", this is not supported yet."); - } - device_stream_collection.SetDeviceStream(i, parent_stream); - } - } - } -} -#endif - static common::Status ExecuteGraphImpl(const SessionState& session_state, const FeedsFetchesManager& feeds_fetches_manager, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 68e3985651123..6b00ac94bc10f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -243,7 +243,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T", OpSchema::Optional) .Input(5, - "extra_add", + "relative_position_bias", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", "T", OpSchema::Optional) @@ -313,6 +313,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)", "M", OpSchema::Optional) + .Input(5, + "relative_position_bias", + "relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)" + " or (1, num_heads, sequence_length, total_sequence_length)", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, v_hidden_size)", @@ -668,5 +674,41 @@ ONNX_MS_OPERATOR_SET_SCHEMA( RestorePaddingTypeAndShapeInference(ctx); })); +constexpr const char* GatedRelativePositionBias_ver1_doc = R"DOC( + query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) + gate_u, gate_r = torch.sigmoid( + self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0 + rel_pos_bias = gate_u_1 * rel_pos +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + GatedRelativePositionBias, 1, + OpSchema() + .SetDoc(GatedRelativePositionBias_ver1_doc) + .Attr("num_heads", "Number of attention heads", AttributeProto::INT) + .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T") + .Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T") + .Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T") + .Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T") + .Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T") + .Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T") + .Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to 
float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + int64_t num_heads = getAttribute(ctx, "num_heads", -1L); + if (hasInputShape(ctx, 0)) { + auto& query_layer_shape = getInputShape(ctx, 0); + TensorShapeProto output_shape; + *output_shape.add_dim() = query_layer_shape.dim(0); + output_shape.add_dim()->set_dim_value(num_heads); + *output_shape.add_dim() = query_layer_shape.dim(1); + *output_shape.add_dim() = query_layer_shape.dim(1); + updateOutputShape(ctx, 0, output_shape); + } + })); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index a511d01fe1624..bd8469909fe7f 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -81,6 +81,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBias); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); @@ -171,6 +172,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 6111afbd5d817..91e4f5d8ff81a 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -1140,7 +1140,7 @@ where value of each element is the end position, or valid length of actual seque left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. -Current version does not support past/present, extra_add and qkv_hidden_sizes. +Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. 
)DOC"; @@ -1202,7 +1202,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(18, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "Q", OpSchema::Optional) - .Input(19, "extra_add", + .Input(19, "relative_position_bias", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).", "S", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "Q") diff --git a/onnxruntime/core/providers/cann/tensor/identity_op.h b/onnxruntime/core/providers/cann/tensor/identity_op.h index e4d74149c6d1d..8d9802ed8b00c 100644 --- a/onnxruntime/core/providers/cann/tensor/identity_op.h +++ b/onnxruntime/core/providers/cann/tensor/identity_op.h @@ -66,6 +66,7 @@ class IdentityOp final : public CannKernel { "IdentityOp cann: unable to get an allocator."); } auto X_size = X->Size(); + Y->Reserve(X_size); for (size_t i = 0; i < X_size; ++i) { const Tensor& source_tensor = X->Get(i); std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index 69301f52aca29..43cd5787718f1 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -543,29 +543,26 @@ Status LoopImpl::Execute(const FeedsFetchesManager& ffm) { } else { // We can't move the Loop's inputs directly into the Loop's outputs // as operator inputs are read-only. Hence, we need to make a copy. - std::vector tensors; - auto& data = input.Get(); - output->SetType(data.DataType()); + output->Reserve(data.Size()); AllocatorPtr alloc; ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&alloc)); for (auto it = data.begin(), end = data.end(); it != end; ++it) { - Tensor tmp(it->DataType(), onnxruntime::TensorShape(it->Shape()), alloc); + Tensor tmp(it->Get().DataType(), it->Get().Shape(), alloc); // Safely use the IDataTransfer abstraction as we only allow using // Loop on CUDA if the copy stream is the same as the compute stream. // So there is no explicit sync required between the compute and copy streams // to avoid data races. 
- auto* data_transer = session_state_.GetDataTransferMgr().GetDataTransfer(it->Location().device, tmp.Location().device); + auto* data_transer = session_state_.GetDataTransferMgr().GetDataTransfer(it->Get().Location().device, tmp.Location().device); if (context_.GetComputeStream()) - ORT_RETURN_IF_ERROR(data_transer->CopyTensorAsync(*it, tmp, *context_.GetComputeStream())); + ORT_RETURN_IF_ERROR(data_transer->CopyTensorAsync(it->Get(), tmp, *context_.GetComputeStream())); else - ORT_RETURN_IF_ERROR(data_transer->CopyTensor(*it, tmp)); - tensors.push_back(std::move(tmp)); - } + ORT_RETURN_IF_ERROR(data_transer->CopyTensor(it->Get(), tmp)); - output->SetElements(std::move(tensors)); + output->Add(std::move(tmp)); + } } } diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index b4a92019992b5..c0a75fc50b07e 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -198,12 +198,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) override { return p->contrib::AttentionBase::CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, - extra_add_qk, + relative_position_bias, parameters, max_threads_per_block, past_seq_len); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 2490789dd31a2..f12e080adf30a 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -145,7 +145,7 @@ struct ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) = 0; diff --git a/onnxruntime/core/providers/cpu/optional/optional_ops.cc b/onnxruntime/core/providers/cpu/optional/optional_ops.cc index 946f0f335d201..33e9401935389 100644 --- a/onnxruntime/core/providers/cpu/optional/optional_ops.cc +++ b/onnxruntime/core/providers/cpu/optional/optional_ops.cc @@ -5,6 +5,7 @@ #include "optional_ops.h" #include "core/framework/ort_value.h" +#include "core/framework/TensorSeq.h" #include "core/providers/cpu/tensor/utils.h" namespace onnxruntime { @@ -69,18 +70,15 @@ static void CopySequenceTensor(AllocatorPtr alloc, } tgt->SetType(src->DataType()); - - std::vector output_tensors; - output_tensors.reserve(src->Size()); + tgt->Reserve(src->Size()); auto in_tensor = src->begin(); for (; in_tensor != src->end(); ++in_tensor) { - Tensor tmp(in_tensor->DataType(), onnxruntime::TensorShape(in_tensor->Shape()), alloc); - CopyCpuTensor(&*in_tensor, &tmp); - output_tensors.push_back(std::move(tmp)); + auto& tensor = in_tensor->Get(); + Tensor tmp(tensor.DataType(), tensor.Shape(), alloc); + CopyCpuTensor(&tensor, &tmp); + tgt->Add(std::move(tmp)); } - - tgt->SetElements(std::move(output_tensors)); } static Status PropagateInputOrtValueToFirstOutput(const OrtValue* input_ort_value, diff --git a/onnxruntime/core/providers/cpu/sequence/concat_from_sequence.cc b/onnxruntime/core/providers/cpu/sequence/concat_from_sequence.cc index 2e7ddbb98a744..b18ca6bd94dda 100644 --- a/onnxruntime/core/providers/cpu/sequence/concat_from_sequence.cc +++ 
b/onnxruntime/core/providers/cpu/sequence/concat_from_sequence.cc @@ -26,7 +26,7 @@ Status ConcatFromSequence::Compute(OpKernelContext* ctx) const { InlinedTensorsVector input_tensor_pointers; input_tensor_pointers.reserve(X->Size()); for (const auto& t : *X) { - input_tensor_pointers.push_back(&t); + input_tensor_pointers.push_back(&t.Get()); } // Validate inputs and prepare some metadata used during actual compute diff --git a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc index f8282a3b5481f..578a3472ac10c 100644 --- a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc +++ b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc @@ -98,7 +98,8 @@ Status SequenceAt::Compute(OpKernelContext* context) const { const Tensor& indexed_tensor = X->Get(onnxruntime::narrow(input_seq_idx)); auto* Y = context->Output(0, indexed_tensor.Shape().GetDims()); - CopyCpuTensor(&indexed_tensor, Y); + // Using DataTransferManager here allows other non-CPU EPs to use this implementation of the sequence ops + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(indexed_tensor, *Y)); return Status::OK(); } @@ -183,13 +184,13 @@ ONNX_CPU_OPERATOR_KERNEL( DataTypeImpl::GetTensorType()}), SequenceInsert); -Status CreateCopyAndAppendCpuTensor(const Tensor& in_tensor, OpKernelContext* context, std::vector& tensors) { +// Using DataTransferManager here allows other non-CPU EPs to use this implementation of the sequence ops +static Tensor CloneTensor(const Tensor& in_tensor, OpKernelContext* context, const DataTransferManager& dtm) { AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + ORT_THROW_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); Tensor tmp(in_tensor.DataType(), onnxruntime::TensorShape(in_tensor.Shape()), alloc); - CopyCpuTensor(&in_tensor, &tmp); - tensors.push_back(std::move(tmp)); - return Status::OK(); + ORT_THROW_IF_ERROR(dtm.CopyTensor(in_tensor, tmp)); + return tmp; } Status SequenceInsert::Compute(OpKernelContext* context) const { @@ -219,24 +220,23 @@ Status SequenceInsert::Compute(OpKernelContext* context) const { } auto* Y = context->Output(0); + Y->SetType(S->DataType()); + Y->Reserve(SafeInt(num_tensors_input_seq) + 1); - std::vector tensors; - tensors.reserve(SafeInt(num_tensors_input_seq) + 1); for (int i = 0; i < num_tensors_input_seq; ++i) { if (i == input_seq_idx) { - ORT_RETURN_IF_ERROR(CreateCopyAndAppendCpuTensor(*X, context, tensors)); - ORT_RETURN_IF_ERROR(CreateCopyAndAppendCpuTensor(S->Get(i), context, tensors)); + // Using DataTransferManager here allows other non-CPU EPs to use this implementation of the sequence ops + Y->Add(CloneTensor(*X, context, Info().GetDataTransferManager())); + Y->Add(S->GetAt(i)); } else { - ORT_RETURN_IF_ERROR(CreateCopyAndAppendCpuTensor(S->Get(i), context, tensors)); + Y->Add(S->GetAt(i)); } } if (input_seq_idx == num_tensors_input_seq) { - ORT_RETURN_IF_ERROR(CreateCopyAndAppendCpuTensor(*X, context, tensors)); + // Using DataTransferManager here allows other non-CPU EPs to use this implementation of the sequence ops + Y->Add(CloneTensor(*X, context, Info().GetDataTransferManager())); } - Y->SetType(S->DataType()); - Y->SetElements(std::move(tensors)); - return Status::OK(); } @@ -271,16 +271,14 @@ Status SequenceErase::Compute(OpKernelContext* context) const { auto* Y = context->Output(0); Y->SetType(S->DataType()); + Y->Reserve(SafeInt(num_tensors_input_seq) - 1); - std::vector tensors; - 
tensors.reserve(SafeInt(num_tensors_input_seq) - 1); for (int i = 0; i < num_tensors_input_seq; ++i) { if (i == input_seq_idx) { continue; } - ORT_RETURN_IF_ERROR(CreateCopyAndAppendCpuTensor(S->Get(i), context, tensors)); + Y->Add(S->GetAt(i)); } - Y->SetElements(std::move(tensors)); return Status::OK(); } @@ -311,13 +309,12 @@ Status SequenceConstruct::Compute(OpKernelContext* context) const { // now copy the tensors to the output sequence Y->SetType(first_dtype); - std::vector tensors; - tensors.reserve(num_inputs); + Y->Reserve(SafeInt(num_inputs)); for (int input_idx = 0; input_idx < num_inputs; ++input_idx) { const auto* X = context->Input(input_idx); - ORT_RETURN_IF_ERROR(CreateCopyAndAppendCpuTensor(*X, context, tensors)); + // Using DataTransferManager here allows other non-CPU EPs to use this implementation of the sequence ops + Y->Add(CloneTensor(*X, context, Info().GetDataTransferManager())); } - Y->SetElements(std::move(tensors)); return Status::OK(); } @@ -503,10 +500,12 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu is_uneven_split, num_remaining_splits, split_sizes)); + auto tseq = context.Output(0); + tseq->SetType(input.DataType()); + tseq->Reserve(static_cast(num_outputs)); // copy dimensions so we can update the selected axis in place auto output_dimensions = input_shape.AsShapeVector(); - std::vector tensors; int64_t input_offset = 0; const T* input_data = input.Data(); for (int i = 0; i < num_outputs; ++i) { @@ -550,13 +549,9 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu } // finally move the resulting tensor to the output sequence - tensors.push_back(std::move(output_tensor)); + tseq->Add(std::move(output_tensor)); } - auto& tseq = *context.Output(0); - tseq.SetType(input.DataType()); - tseq.SetElements(std::move(tensors)); - return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/concat.h b/onnxruntime/core/providers/cpu/tensor/concat.h index 227f98bb4f2ae..3272b5a961a6f 100644 --- a/onnxruntime/core/providers/cpu/tensor/concat.h +++ b/onnxruntime/core/providers/cpu/tensor/concat.h @@ -5,7 +5,6 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/util/math_cpuonly.h" #include "core/framework/tensor.h" #include "concatbase.h" diff --git a/onnxruntime/core/providers/cpu/tensor/identity_op.h b/onnxruntime/core/providers/cpu/tensor/identity_op.h index 6fa6c26b1e2aa..4e2a97d58861e 100644 --- a/onnxruntime/core/providers/cpu/tensor/identity_op.h +++ b/onnxruntime/core/providers/cpu/tensor/identity_op.h @@ -15,6 +15,7 @@ #include "core/framework/op_kernel.h" #include "core/framework/TensorSeq.h" #include "core/framework/tensorprotoutils.h" +#include "core/providers/cpu/tensor/utils.h" #include "core/providers/utils.h" namespace onnxruntime { @@ -52,7 +53,7 @@ class IdentityOp final : public OpKernel { const void* source = X->DataRaw(X_type); void* target = Y->MutableDataRaw(X_type); - //If source and target pointers are not equal, we need to copy the data. + // If source and target pointers are not equal, we need to copy the data. if (target != source) { if (!X->IsDataTypeString()) { memcpy(target, source, SafeInt(shape.Size()) * X_type->Size()); @@ -89,21 +90,20 @@ class IdentityOp final : public OpKernel { // processing Tensors. 
if (X != output) { output->SetType(X->DataType()); + output->Reserve(X->Size()); AllocatorPtr alloc; auto status = context->GetTempSpaceAllocator(&alloc); if (!status.IsOK()) { ORT_THROW("Unable to get an allocator"); } - std::vector tensors; + for (auto it = X->begin(), end = X->end(); it != end; ++it) { - Tensor tmp(it->DataType(), onnxruntime::TensorShape(it->Shape()), alloc); - size_t bytes = it->SizeInBytes(); - memcpy(tmp.MutableDataRaw(), it->DataRaw(), bytes); - tensors.push_back(std::move(tmp)); + auto& it_tensor = it->Get(); + Tensor tmp(it_tensor.DataType(), it_tensor.Shape(), alloc); + CopyCpuTensor(&it_tensor, &tmp); + output->Add(std::move(tmp)); } - - output->SetElements(std::move(tensors)); } } @@ -111,4 +111,4 @@ class IdentityOp final : public OpKernel { } }; -} //namespace onnxruntime +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/activation/activations_impl.cu b/onnxruntime/core/providers/cuda/activation/activations_impl.cu index 05f90ee05b275..0f89fe984372d 100644 --- a/onnxruntime/core/providers/cuda/activation/activations_impl.cu +++ b/onnxruntime/core/providers/cuda/activation/activations_impl.cu @@ -47,7 +47,18 @@ struct OP_Selu : public CtxSelu { template struct OP_Sigmoid : public CtxSigmoid { __device__ __inline__ T operator()(const T& a) const { - return a > T(0) ? (T)1 / ((T)1. + _Exp(-_Abs(a))) : (T)1 - (T)1 / ((T)1 + _Exp(-_Abs(a))); + return a > T(0) ? (T)1 / ((T)1. + _Exp(-a)) : (T)1 - (T)1 / ((T)1 + _Exp(a)); + } }; + +template <> +struct OP_Sigmoid : public CtxSigmoid { + __device__ __inline__ half operator()(const half& a) const { + // For some small negative values, it will lose precision if using half for compute. + // The cast between half and float below won't cause a perf issue since there is a cast inside _Exp if input is half. + float af = static_cast(a); + float res = af > 0.0f ?
1.0f / (1.0f + _Exp(-af)) : 1.0f - 1.0f / (1.0f + _Exp(af)); + return static_cast(res); } }; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 7a673a05858ca..a8e6dbe963c06 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -81,6 +81,7 @@ class Memcpy final : public OpKernel { } } auto X_size = X->Size(); + Y->Reserve(X_size); for (size_t i = 0; i < X_size; ++i) { const Tensor& source_tensor = X->Get(i); std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); diff --git a/onnxruntime/core/providers/cuda/tensor/identity_op.h b/onnxruntime/core/providers/cuda/tensor/identity_op.h index b1dea1e8b6e14..81fe93b48cc0b 100644 --- a/onnxruntime/core/providers/cuda/tensor/identity_op.h +++ b/onnxruntime/core/providers/cuda/tensor/identity_op.h @@ -68,6 +68,7 @@ class IdentityOp final : public CudaKernel { "IdentityOp cuda: unable to get an allocator."); } auto X_size = X->Size(); + Y->Reserve(X_size); for (size_t i = 0; i < X_size; ++i) { const Tensor& source_tensor = X->Get(i); std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h index 9dfbd0e7ea0e0..fd96bd812d056 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h @@ -31,12 +31,12 @@ namespace Dml bool enableMetacommands = true); ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr); - void FlushContext(onnxruntime::IExecutionProvider* provider); + void FlushContext(onnxruntime::IExecutionProvider* provider); void SetDefaultRoundingMode(onnxruntime::IExecutionProvider* provider, AllocatorRoundingMode roundingMode); void ReleaseCompletedReferences(onnxruntime::IExecutionProvider* provider); - + onnxruntime::common::Status CopyTensor( - onnxruntime::IExecutionProvider* provider, + onnxruntime::IExecutionProvider* provider, const onnxruntime::Tensor& src, onnxruntime::Tensor& dst ); @@ -44,5 +44,6 @@ namespace Dml void FreeGPUAllocation(void* ptr); void RegisterDmlOperators(IMLOperatorRegistry* registry); + void RegisterCpuOperatorsAsDml(onnxruntime::KernelRegistry* registry); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 501a66bdfa711..232a022d869f4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -29,7 +29,7 @@ namespace Windows::AI::MachineLearning::Adapter { public: // Hold a reference to an object until preceding work in the queue is complete. This - // only needs to be handled by providers which hide the asynchronous nature of + // only needs to be handled by providers which hide the asynchronous nature of // computation, and involve resoures which cannot be automatically by work in the // the provider's underlying queues. 
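The OP_Sigmoid half specialization in the activations hunk above widens to float and picks the branch whose exp() argument is non-positive, so neither branch overflows or loses precision for large-magnitude inputs. A host-side sketch of the same formulation, with std::exp standing in for the CUDA _Exp helper:

    #include <cmath>

    // Numerically stable sigmoid: for x > 0 use 1/(1+e^-x); otherwise use the
    // algebraically equivalent 1 - 1/(1+e^x), so exp() never receives a large
    // positive argument.
    inline float StableSigmoid(float x) {
      return x > 0.0f ? 1.0f / (1.0f + std::exp(-x))
                      : 1.0f - 1.0f / (1.0f + std::exp(x));
    }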
virtual void QueueReference(IUnknown *object) = 0; @@ -43,7 +43,7 @@ namespace Windows::AI::MachineLearning::Adapter bool isInternalOperator, IUnknown* data, IUnknown** abiData) const = 0; - + virtual uint64_t TryGetPooledAllocationId( IUnknown* data, bool isInternalOperator) = 0; @@ -63,12 +63,12 @@ namespace Windows::AI::MachineLearning::Adapter uint32_t resourceCount, IUnknown** resources) = 0; - // Waits for flushed work, discards unflushed work, and discards associated references to + // Waits for flushed work, discards unflushed work, and discards associated references to // prevent circular references. Must be the last call on the object before destruction. virtual void Close() = 0; }; - using MLOperatorTensorGetter = std::function(uint32_t index)>; + using MLOperatorTensorGetter = std::function, std::vector>>(uint32_t index)>; struct DmlOperatorParams { @@ -89,7 +89,7 @@ namespace Windows::AI::MachineLearning::Adapter }; using GraphNodeFactory = std::function>; -} \ No newline at end of file +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h index 5b85576a5079a..69181db6021cb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h @@ -21,7 +21,7 @@ static_assert(sizeof(bool) == 1, "Unsupported size for bool type"); //! \enum MLOperatorAttributeType //! \brief Specifies the type of an attribute. -enum class MLOperatorAttributeType : uint32_t +enum class MLOperatorAttributeType : uint32_t { //! Undefined (unused) Undefined = 0, @@ -53,7 +53,7 @@ enum class MLOperatorTensorDataType : uint32_t //! Undefined (unused). Undefined = 0, - //! IEEE 32 bit floating point + //! IEEE 32 bit floating point Float = 1, //! 8 bit unsigned integer @@ -101,28 +101,31 @@ enum class MLOperatorTensorDataType : uint32_t //! \enum MLOperatorEdgeType //! \brief Specifies the types of an input or output edge of an operator. -enum class MLOperatorEdgeType : uint32_t -{ +enum class MLOperatorEdgeType : uint32_t +{ Undefined = 0, Tensor = 1, + SequenceTensor = 2, + Primitive = 3, }; - + //! \struct MLOperatorEdgeDescription //! \brief Specifies the properties of an input or output edge of an operator. -struct MLOperatorEdgeDescription +struct MLOperatorEdgeDescription { //! The type of the edge. MLOperatorEdgeType edgeType; - - union + + union { uint64_t reserved; //! The data type of a tensor. Used when edgeType is set to Tensor. + //! The data type of each tensor in a sequence of tensors. Used when edgeType is set to Sequence. MLOperatorTensorDataType tensorDataType; }; }; - + //! \interface IMLOperatorAttributes //! \brief Represents the values of an operator's attributes, as determined by a model using the operator. //! This interface is called by implementations of custom operator kernels, and by implementations @@ -133,13 +136,13 @@ IMLOperatorAttributes : IUnknown //! Gets the count of elements in an attribute. //! This may be used to determine if an attribute exists, and to determine the //! count of elements within an attribute of an array type. - STDMETHOD(GetAttributeElementCount)( + STDMETHOD(GetAttributeElementCount)( _In_z_ const char* name, MLOperatorAttributeType type, _Out_ uint32_t* elementCount ) const noexcept PURE; - //! Gets the value of an attribute element which is of a numeric type. + //! Gets the value of an attribute element which is of a numeric type. //! 
For attributes which are of array types, this method queries //! an individual element within the attribute at the specified index. STDMETHOD(GetAttribute)( @@ -149,10 +152,10 @@ IMLOperatorAttributes : IUnknown size_t elementByteSize, _Out_writes_bytes_(elementCount * elementByteSize) void* value ) const noexcept PURE; - + //! Gets the length of an attribute element which is of a string type. //! For attributes which are string arrays, this method queries - //! the size of an individual element within the attribute at the + //! the size of an individual element within the attribute at the //! specified index. //! The string is in UTF-8 format. The size includes the null termination character. STDMETHOD(GetStringAttributeElementLength)( @@ -160,10 +163,10 @@ IMLOperatorAttributes : IUnknown uint32_t elementIndex, _Out_ uint32_t* attributeElementByteSize ) const noexcept PURE; - + //! Gets the value of an attribute element which is of a string type. //! For attributes which are string arrays, this method queries - //! the value of an individual element within the attribute at the + //! the value of an individual element within the attribute at the //! specified index. //! The string is in UTF-8 format. The size includes the null termination character. STDMETHOD(GetStringAttributeElement)( @@ -177,7 +180,7 @@ IMLOperatorAttributes : IUnknown //! \interface IMLOperatorTensorShapeDescription //! \brief Represents the set of input and output tensor shapes of an operator. //! This interface is called by the factory objects registered to create kernels. -//! It is available to these factory objects unless corresponding kernels are +//! It is available to these factory objects unless corresponding kernels are //! registered using the MLOperatorKernelOptions::AllowDynamicInputShapes flag. interface DECLSPEC_UUID("F20E8CBE-3B28-4248-BE95-F96FBC6E4643") DECLSPEC_NOVTABLE IMLOperatorTensorShapeDescription : IUnknown @@ -185,50 +188,50 @@ IMLOperatorTensorShapeDescription : IUnknown //! Gets the number of dimensions of a tensor input of the operator. //! Returns an error if the input at the specified index is not a tensor. STDMETHOD(GetInputTensorDimensionCount)( - uint32_t inputIndex, + uint32_t inputIndex, _Out_ uint32_t* dimensionCount ) const noexcept PURE; //! Gets the sizes of dimensions of an input tensor of the operator. //! Returns an error if the input at the specified index is not a tensor. STDMETHOD(GetInputTensorShape)( - uint32_t inputIndex, - uint32_t dimensionCount, + uint32_t inputIndex, + uint32_t dimensionCount, _Out_writes_(dimensionCount) uint32_t* dimensions ) const noexcept PURE; - - //! Returns true if output shapes may be queried using GetOutputTensorDimensionCount - //! and GetOutputTensorShape. This is true if the kernel was registered with a + + //! Returns true if output shapes may be queried using GetOutputTensorDimensionCount + //! and GetOutputTensorShape. This is true if the kernel was registered with a //! shape inferrer. STDMETHOD_(bool, HasOutputShapeDescription)() const noexcept PURE; //! Gets the number of dimensions of a tensor output of the operator. //! Returns an error if the output at the specified index is not a tensor. STDMETHOD(GetOutputTensorDimensionCount)( - uint32_t outputIndex, + uint32_t outputIndex, _Out_ uint32_t* dimensionCount ) const noexcept PURE; //! Gets the sizes of dimensions of a tensor output of the operator. //! Returns an error if the output at the specified index is not a tensor. 
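Because IMLOperatorTensorShapeDescription exposes shapes through a dimension-count query followed by a shape query, a kernel registered with a shape inferrer can read an input shape in two calls from its creation context (declared further below). A brief usage sketch against the interfaces in this header; the 'context' parameter and the helper name are assumptions, and error handling is reduced to HRESULT checks:

    #include <cstdint>
    #include <vector>
    #include <wrl/client.h>  // Microsoft::WRL::ComPtr

    // Reads the shape of input 0 while the kernel is being created. Assumes the
    // kernel was registered with a shape inferrer, so shape queries are allowed.
    std::vector<uint32_t> QueryInput0Shape(IMLOperatorKernelCreationContext* context) {
        std::vector<uint32_t> shape;
        if (!context->HasTensorShapeDescription()) {
            return shape;  // dynamic shapes: no shape description is available
        }

        Microsoft::WRL::ComPtr<IMLOperatorTensorShapeDescription> shapeDesc;
        if (FAILED(context->GetTensorShapeDescription(shapeDesc.GetAddressOf()))) {
            return shape;
        }

        uint32_t dimCount = 0;
        if (FAILED(shapeDesc->GetInputTensorDimensionCount(0, &dimCount))) {
            return shape;
        }

        shape.resize(dimCount);
        if (FAILED(shapeDesc->GetInputTensorShape(0, dimCount, shape.data()))) {
            shape.clear();
        }
        return shape;
    }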
STDMETHOD(GetOutputTensorShape)( - uint32_t outputIndex, - uint32_t dimensionCount, + uint32_t outputIndex, + uint32_t dimensionCount, _Out_writes_(dimensionCount) uint32_t* dimensions ) const noexcept PURE; }; - + //! \interface IMLOperatorKernelCreationContext //! \brief Provides information about an operator's usage while kernels are being created. interface DECLSPEC_UUID("5459B53D-A0FC-4665-ADDD-70171EF7E631") DECLSPEC_NOVTABLE -IMLOperatorKernelCreationContext : public IMLOperatorAttributes +IMLOperatorKernelCreationContext : public IMLOperatorAttributes { //! Gets the number of inputs to the operator. STDMETHOD_(uint32_t, GetInputCount)() const noexcept PURE; //! Gets the number of outputs to the operator. STDMETHOD_(uint32_t, GetOutputCount)() const noexcept PURE; - + //! Returns true if an input to the operator is valid. //! This always returns true if within GetInputCount except for optional inputs. STDMETHOD_(bool, IsInputValid)(uint32_t inputIndex) const noexcept PURE; @@ -239,38 +242,38 @@ IMLOperatorKernelCreationContext : public IMLOperatorAttributes //! Gets the description of the specified input edge of the operator. STDMETHOD(GetInputEdgeDescription)( - uint32_t inputIndex, + uint32_t inputIndex, _Out_ MLOperatorEdgeDescription* edgeDescription ) const noexcept PURE; //! Gets the description of the specified output edge of the operator. STDMETHOD(GetOutputEdgeDescription)( - uint32_t outputIndex, + uint32_t outputIndex, _Out_ MLOperatorEdgeDescription* edgeDescription ) const noexcept PURE; - + //! Returns true if the description of input and output shapes connected to //! operator edges may be queried using GetTensorShapeDescription. //! This returns true unless the operator was registered using //! the MLOperatorKernelOptions::AllowDynamicInputShapes flag. STDMETHOD_(bool, HasTensorShapeDescription)() const noexcept PURE; - + //! Gets the description of input and output shapes connected to //! operator edges. STDMETHOD(GetTensorShapeDescription)( _COM_Outptr_ IMLOperatorTensorShapeDescription** shapeDescription ) const noexcept PURE; - + //! Returns an object whose supported interfaces vary based on the kernel type. //! For kernels registered with MLOperatorExecutionType::Cpu, executionObject will - //! be set to nullptr. + //! be set to nullptr. //! For kernels registered with MLOperatorExecutionType::D3D12, executionObject will //! support the ID3D12GraphicsCommandList interface. STDMETHOD_(void, GetExecutionInterface)( _Outptr_result_maybenull_ IUnknown** executionObject ) const noexcept PURE; }; - + //! \interface IMLOperatorTensor //! \brief Representation of a tensor used during computation of custom operator kernels. interface DECLSPEC_UUID("7FE41F41-F430-440E-AECE-54416DC8B9DB") DECLSPEC_NOVTABLE @@ -287,24 +290,24 @@ IMLOperatorTensor : IUnknown //! Gets the data type of the tensor. STDMETHOD_(MLOperatorTensorDataType, GetTensorDataType)() const noexcept PURE; - + //! Indicates whether the memory used by the tensor is CPU-addressable. //! This is true when kernels are registered using MLOperatorExecutionType::Cpu. STDMETHOD_(bool, IsCpuData)() const noexcept PURE; - - //! Whether the contents of the tensor are represented by an interface type, - //! or byte-addressable memory. This returns true when kernels are registered + + //! Whether the contents of the tensor are represented by an interface type, + //! or byte-addressable memory. This returns true when kernels are registered //! using MLOperatorExecutionType::D3D12. 
STDMETHOD_(bool, IsDataInterface)() const noexcept PURE; - + //! Returns a pointer to byte-addressable memory for the tensor. This may be - //! used when IsDataInterface returns false, because the kernel was - //! registered using MLOperatorExecutionType::Cpu. The data size is derived + //! used when IsDataInterface returns false, because the kernel was + //! registered using MLOperatorExecutionType::Cpu. The data size is derived //! from the tensor's shape. It is fully packed in memory. - STDMETHOD_(void*, GetData)() noexcept PURE; - + STDMETHOD_(void*, GetData)() noexcept PURE; + //! Gets an interface pointer for the tensor. This may be - //! used when IsDataInterface returns true, because the kernel was + //! used when IsDataInterface returns true, because the kernel was //! registered using MLOperatorExecutionType::D3D12. The dataInterface //! object supports the ID3D12Resource interface, and is a GPU buffer. STDMETHOD_(void, GetDataInterface)( @@ -318,30 +321,30 @@ interface DECLSPEC_UUID("82536A28-F022-4769-9D3F-8B278F84C0C3") DECLSPEC_NOVTABL IMLOperatorKernelContext : IUnknown { //! Gets the input tensor of the operator at the specified index. - //! This sets tensor to nullptr for optional inputs which do not exist. + //! This sets tensor to nullptr for optional inputs which do not exist. //! Returns an error if the input at the specified index is not a tensor. STDMETHOD(GetInputTensor)( - uint32_t inputIndex, + uint32_t inputIndex, _COM_Outptr_result_maybenull_ IMLOperatorTensor** tensor ) const noexcept PURE; - + //! Gets the output tensor of the operator at the specified index. - //! This sets tensor to nullptr for optional outputs which do not exist. - //! If the operator kernel was registered without a shape inference method, - //! then the overload of GetOutputTensor which consumes the tensor's shape must - //! be called instead. Returns an error if the output at the specified index is + //! This sets tensor to nullptr for optional outputs which do not exist. + //! If the operator kernel was registered without a shape inference method, + //! then the overload of GetOutputTensor which consumes the tensor's shape must + //! be called instead. Returns an error if the output at the specified index is //! not a tensor. STDMETHOD(GetOutputTensor)( - uint32_t outputIndex, + uint32_t outputIndex, _COM_Outptr_result_maybenull_ IMLOperatorTensor** tensor ) noexcept PURE; //! Gets the output tensor of the operator at the specified index, while declaring - //! its shape. - //! This returns nullptr for optional outputs which do not exist. - //! If the operator kernel was registered with a shape inference method, + //! its shape. + //! This returns nullptr for optional outputs which do not exist. + //! If the operator kernel was registered with a shape inference method, //! then the overload of GetOutputTensor which doesn't consume a shape may also - //! be called. Returns an error if the output at the specified index is + //! be called. Returns an error if the output at the specified index is //! not a tensor. STDMETHOD(GetOutputTensor)( uint32_t outputIndex, @@ -349,19 +352,19 @@ IMLOperatorKernelContext : IUnknown _In_reads_(dimensionCount) const uint32_t* dimensionSizes, _COM_Outptr_result_maybenull_ IMLOperatorTensor** tensor ) noexcept PURE; - + //! Allocates temporary data which will be usable as intermediate memory for the duration //! of a call to IMLOperatorKernel::Compute. This may be used by kernels //! registered using MLOperatorExecutionType::D3D12. The data //! 
object supports the ID3D12Resource interface, and is a GPU buffer. STDMETHOD(AllocateTemporaryData)(size_t size, _COM_Outptr_ IUnknown** data) const = 0; - + //! Returns an object whose supported interfaces vary based on the kernel type. //! For kernels registered with MLOperatorExecutionType::Cpu, executionObject will - //! be set to nullptr. + //! be set to nullptr. //! For kernels registered with MLOperatorExecutionType::D3D12, executionObject will //! support the ID3D12GraphicsCommandList interface. This may be a different object - //! than was provided to IMLOperatorKernelCreationContext::GetExecutionInterface + //! than was provided to IMLOperatorKernelCreationContext::GetExecutionInterface //! when the kernel instance was created. STDMETHOD_(void, GetExecutionInterface)( _Outptr_result_maybenull_ IUnknown** executionObject @@ -370,7 +373,7 @@ IMLOperatorKernelContext : IUnknown //! \interface IMLOperatorKernel //! \brief Implemented by custom operator kernels. -//! A factory which creates interfaces of this interface is supplied when +//! A factory which creates interfaces of this interface is supplied when //! registering custom operator kernels using IMLOperatorKernelFactory::RegisterOperatorKernel. interface DECLSPEC_UUID("11C4B4A0-B467-4EAA-A1A6-B961D8D0ED79") DECLSPEC_NOVTABLE IMLOperatorKernel : IUnknown @@ -380,11 +383,11 @@ IMLOperatorKernel : IUnknown //! simultaneously on different threads. STDMETHOD(Compute)(IMLOperatorKernelContext* context) noexcept PURE; }; - + //! \enum MLOperatorParameterOptions //! \brief Specifies option flags of input and output edges of operators. //! These options are used while defining custom operator schema. -enum class MLOperatorParameterOptions : uint32_t +enum class MLOperatorParameterOptions : uint32_t { //! There is a single instance of the input or output. Single = 0, @@ -402,11 +405,11 @@ DEFINE_ENUM_FLAG_OPERATORS(MLOperatorParameterOptions); //! \enum MLOperatorSchemaEdgeTypeFormat //! \brief Specifies the manner in which types of input and output edges are described. //! This is used within MLOperatorSchemaEdgeDescription while defining custom operator schema. -enum class MLOperatorSchemaEdgeTypeFormat +enum class MLOperatorSchemaEdgeTypeFormat { //! The type is defined using MLOperatorEdgeDescription. EdgeDescription = 0, - + //! The type is defined by a type string constructed as in ONNX operator schema. Label = 1, }; @@ -418,10 +421,10 @@ struct MLOperatorSchemaEdgeDescription { //! Options of the parameter, including whether it is optional or variadic. MLOperatorParameterOptions options; - + //! The manner in which the type constraints and type mapping are defined. MLOperatorSchemaEdgeTypeFormat typeFormat; - union + union { const void* reserved; @@ -429,23 +432,23 @@ struct MLOperatorSchemaEdgeDescription //! This is used when typeFormat is MLOperatorSchemaEdgeTypeFormat::Label. _Field_z_ const char* typeLabel; - //! A structure describing type support. + //! A structure describing type support. //! This is used when typeFormat is MLOperatorSchemaEdgeTypeFormat::EdgeDescription. MLOperatorEdgeDescription edgeDescription; }; }; - + //! \struct MLOperatorEdgeTypeConstraint -//! \brief Specifies constraints upon the types of edges supported in custom operator kernels -//! and schema. The provided type label string corresponds to type labels in the ONNX -//! specification for the same operator. For custom schema, it corresponds to type labels +//! 
\brief Specifies constraints upon the types of edges supported in custom operator kernels +//! and schema. The provided type label string corresponds to type labels in the ONNX +//! specification for the same operator. For custom schema, it corresponds to type labels //! specified within MLOperatorSchemaEdgeDescription when registering the operator's schema. -struct MLOperatorEdgeTypeConstraint +struct MLOperatorEdgeTypeConstraint { //! The label of the type for which the constraint is being defined. //! This is constructed as in ONNX operator schema. For example, "T". _Field_z_ const char* typeLabel; - + //! The set of allowed types for the constraint. _Field_size_opt_(allowedTypeCount) const MLOperatorEdgeDescription* allowedTypes; uint32_t allowedTypeCount; @@ -457,7 +460,7 @@ using MLOperatorEdgeTypeConstrant = MLOperatorEdgeTypeConstraint; //! \interface IMLOperatorShapeInferenceContext //! \brief Provides information about an operator's usage while shape inferrers are being invoked. interface DECLSPEC_UUID("105B6B29-5408-4A68-9959-09B5955A3492") DECLSPEC_NOVTABLE -IMLOperatorShapeInferenceContext : public IMLOperatorAttributes +IMLOperatorShapeInferenceContext : public IMLOperatorAttributes { //! Gets the number of inputs to the operator. STDMETHOD_(uint32_t, GetInputCount)() const noexcept PURE; @@ -465,11 +468,11 @@ IMLOperatorShapeInferenceContext : public IMLOperatorAttributes //! Gets the number of outputs to the operator. STDMETHOD_(uint32_t, GetOutputCount)() const noexcept PURE; - //! Returns true if an input to the operator is valid. + //! Returns true if an input to the operator is valid. //! This always returns true except for optional inputs and invalid indices. STDMETHOD_(bool, IsInputValid)(uint32_t inputIndex) const noexcept PURE; - //! Returns true if an output to the operator is valid. + //! Returns true if an output to the operator is valid. //! This always returns true except for optional outputs and invalid indices. STDMETHOD_(bool, IsOutputValid)(uint32_t outputIndex) const noexcept PURE; @@ -496,8 +499,8 @@ IMLOperatorShapeInferenceContext : public IMLOperatorAttributes //! Sets the inferred shape of an output tensor. //! Returns an error if the output at the specified index is not a tensor. STDMETHOD(SetOutputTensorShape)( - uint32_t outputIndex, - uint32_t dimensionCount, + uint32_t outputIndex, + uint32_t dimensionCount, const uint32_t* dimensions ) noexcept PURE; }; @@ -505,7 +508,7 @@ IMLOperatorShapeInferenceContext : public IMLOperatorAttributes //! \interface IMLOperatorTypeInferenceContext //! \brief Provides information about an operator's usage while type inferrers are being invoked. interface DECLSPEC_UUID("EC893BB1-F938-427B-8488-C8DCF775F138") DECLSPEC_NOVTABLE -IMLOperatorTypeInferenceContext : public IMLOperatorAttributes +IMLOperatorTypeInferenceContext : public IMLOperatorAttributes { //! Gets the number of inputs to the operator. STDMETHOD_(uint32_t, GetInputCount)() const noexcept PURE; @@ -513,11 +516,11 @@ IMLOperatorTypeInferenceContext : public IMLOperatorAttributes //! Gets the number of outputs to the operator. STDMETHOD_(uint32_t, GetOutputCount)() const noexcept PURE; - //! Returns true if an input to the operator is valid. + //! Returns true if an input to the operator is valid. //! This always returns true except for optional inputs. STDMETHOD_(bool, IsInputValid)(uint32_t inputIndex) const noexcept PURE; - //! Returns true if an output to the operator is valid. + //! Returns true if an output to the operator is valid. //! 
This always returns true except for optional outputs. STDMETHOD_(bool, IsOutputValid)(uint32_t outputIndex) const noexcept PURE; @@ -529,31 +532,31 @@ IMLOperatorTypeInferenceContext : public IMLOperatorAttributes //! Sets the inferred type of an output edge. STDMETHOD(SetOutputEdgeDescription)( - uint32_t outputIndex, + uint32_t outputIndex, const MLOperatorEdgeDescription* edgeDescription ) const noexcept PURE; }; - + //! \interface IMLOperatorTypeInferrer //! \brief Implemented by type inferrers to infer types of an operator's output edges. -//! Type inferrers must be provided when registering schema of custom operators if -//! the MLOperatorSchemaDescription structure cannot express how output types are -//! determined. For example, such as when an attribute of the operator determines +//! Type inferrers must be provided when registering schema of custom operators if +//! the MLOperatorSchemaDescription structure cannot express how output types are +//! determined. For example, such as when an attribute of the operator determines //! the data type of one of that operator's outputs. interface DECLSPEC_UUID("781AEB48-9BCB-4797-BF77-8BF455217BEB") DECLSPEC_NOVTABLE IMLOperatorTypeInferrer : IUnknown { - //! Called to infer types of an operator's output edges + //! Called to infer types of an operator's output edges STDMETHOD(InferOutputTypes)( IMLOperatorTypeInferenceContext* context ) noexcept PURE; }; //! \interface IMLOperatorShapeInferrer -//! \brief Implemented by shape inferrers to infer shapes of an operator's -//! output tensor edges. Shape inferrers may be provided when registering custom -//! operator kernels to improve performance and to enable the kernel to query -//! the shape of its output tensors when it is created and computed. Shape +//! \brief Implemented by shape inferrers to infer shapes of an operator's +//! output tensor edges. Shape inferrers may be provided when registering custom +//! operator kernels to improve performance and to enable the kernel to query +//! the shape of its output tensors when it is created and computed. Shape //! inferrers may also be provided when registering custom operator schema to //! improve model validation. interface DECLSPEC_UUID("540BE5BE-A6C9-40EE-83F6-D2B8B40A7798") DECLSPEC_NOVTABLE @@ -563,14 +566,14 @@ IMLOperatorShapeInferrer : IUnknown STDMETHOD(InferOutputShapes)( IMLOperatorShapeInferenceContext* context ) noexcept PURE; -}; +}; -//! \struct MLOperatorAttribute +//! \struct MLOperatorAttribute //! \brief Specifies the name and properties of an attribute of a custom operator. //! This is used when registering custom operator kernels and custom operator schema. -struct MLOperatorAttribute +struct MLOperatorAttribute { - //! NULL-terminated UTF-8 string representing the name of the attribute in the + //! NULL-terminated UTF-8 string representing the name of the attribute in the //! associated operator type. _Field_z_ const char* name; @@ -580,13 +583,13 @@ struct MLOperatorAttribute //! Whether the attribute is required in any model using the associated operator type. bool required; }; - + //! \struct MLOperatorAttributeNameValue //! \brief Specifies the name and value(s) of an attribute of a custom operator. //! This is used when registering custom operator kernels and custom operator schema. -struct MLOperatorAttributeNameValue +struct MLOperatorAttributeNameValue { - //! NULL-terminated UTF-8 string representing the name of the attribute in the + //! 
NULL-terminated UTF-8 string representing the name of the attribute in the //! associated operator type. _Field_z_ const char* name; @@ -596,35 +599,35 @@ struct MLOperatorAttributeNameValue //! The number of elements in the attribute value. This must be one, except for attributes //! which are of array types. uint32_t valueCount; - - union + + union { const void* reserved; - //! 64 bit integer value(s). Used when the type field is + //! 64 bit integer value(s). Used when the type field is //! MLOperatorAttributeType::Int or MLOperatorAttributeType::IntArray. _Field_size_(valueCount) const int64_t* ints; - //! NULL-terminated UTF-8 string value(s). Used when the type field is + //! NULL-terminated UTF-8 string value(s). Used when the type field is //! MLOperatorAttributeType::String or MLOperatorAttributeType::StringArray. _Field_size_(valueCount) const char* const* strings; - //! 32 bit floating point value(s). Used when the type field is + //! 32 bit floating point value(s). Used when the type field is //! MLOperatorAttributeType::Float or MLOperatorAttributeType::FloatArray. _Field_size_(valueCount) const float* floats; }; }; - + //! \struct MLOperatorSchemaDescription //! \brief Description of a custom operator schema used to register that schema. struct MLOperatorSchemaDescription { //! NULL-terminated UTF-8 string representing the name of the operator. _Field_z_ const char* name; - + //! The operator set version at which this operator was introduced or last changed. int32_t operatorSetVersionAtLastChange; - + //! An array containing the descriptions of the operator's input edges. _Field_size_opt_(inputCount) const MLOperatorSchemaEdgeDescription* inputs; @@ -636,20 +639,20 @@ struct MLOperatorSchemaDescription //! The number of outputs of the operator. uint32_t outputCount; - + //! An array of type constraints. Each constraint restricts input and outputs //! associated with a type label string to one or more edge types. _Field_size_opt_(typeConstraintCount) const MLOperatorEdgeTypeConstraint* typeConstraints; //! The number of type constraints provided. uint32_t typeConstraintCount; - + //! The set of attributes supported by the operator type. _Field_size_opt_(attributeCount) const MLOperatorAttribute* attributes; //! The number of provided attributes. uint32_t attributeCount; - + //! The default values of attributes. These will be applied when the attributes are missing //! in a model containing the operator type. _Field_size_opt_(defaultAttributeCount) const MLOperatorAttributeNameValue* defaultAttributes; @@ -657,10 +660,10 @@ struct MLOperatorSchemaDescription //! The number of provided default attribute values. uint32_t defaultAttributeCount; }; - + //! \struct MLOperatorSetId //! \brief Specifies the identity of an operator set. -struct MLOperatorSetId +struct MLOperatorSetId { //! The domain of the operator, for example, "ai.onnx.ml", or an empty string //! for the ONNX domain. @@ -669,26 +672,26 @@ struct MLOperatorSetId //! The version of the operator domain. int32_t version; }; - + //! \enum MLOperatorKernelOptions //! \brief Specifies options used when registering custom operator kernels. -enum class MLOperatorKernelOptions : uint32_t +enum class MLOperatorKernelOptions : uint32_t { None = 0, - + //! Specifies whether the shapes of input tensors are allowed to vary among invocations //! of an operator kernel instance. If this is not set, kernel instances may query input //! tensor shapes during creation, and front-load initialization work which depends //! on those shapes. 
Setting this may improve performance if shapes vary dynamically between - //! inference operations, and the kernel implementation handles this efficiently. + //! inference operations, and the kernel implementation handles this efficiently. AllowDynamicInputShapes = 1, }; DEFINE_ENUM_FLAG_OPERATORS(MLOperatorKernelOptions); - + //! \enum MLOperatorExecutionType //! \brief Specifies whether a kernel uses the CPU or GPU for computation. -enum class MLOperatorExecutionType : uint32_t +enum class MLOperatorExecutionType : uint32_t { Undefined = 0, Cpu = 1, @@ -704,15 +707,15 @@ struct MLOperatorKernelDescription //! NULL-terminated UTF-8 string representing the name of the operator. _Field_z_ const char* name; - - //! The minimum version of the operator sets for which this kernel is valid. + + //! The minimum version of the operator sets for which this kernel is valid. //! The maximum version is inferred based on registrations of operator set schema for //! subsequent versions of the same domain. int32_t minimumOperatorSetVersion; - + //! Specifies whether a kernel uses the CPU or GPU for computation. MLOperatorExecutionType executionType; - + //! An array of type constraints. Each constraint restricts input and outputs //! associated with a type label string to one or more edge types. _Field_size_opt_(typeConstraintCount) const MLOperatorEdgeTypeConstraint* typeConstraints; @@ -726,14 +729,14 @@ struct MLOperatorKernelDescription //! The number of provided default attribute values. uint32_t defaultAttributeCount; - + //! Options for the kernel which apply to all execution provider types. MLOperatorKernelOptions options; - + //! Reserved for additional options. Must be zero. uint32_t executionOptions; }; - + //! \interface IMLOperatorKernelFactory //! \brief Implemented by the author of a custom operator kernel to create instances of that kernel. interface DECLSPEC_UUID("EF15AD6F-0DC9-4908-AB35-A575A30DFBF8") DECLSPEC_NOVTABLE @@ -746,7 +749,7 @@ IMLOperatorKernelFactory : IUnknown _COM_Outptr_ IMLOperatorKernel** kernel ) noexcept PURE; }; - + //! \interface IMLOperatorRegistry //! \brief Represents an instance of a registry for custom operator kernel and schema. //! Custom operators may be used with WinML APIs by returning @@ -757,7 +760,7 @@ IMLOperatorRegistry : IUnknown //! Registers a set of custom operator schema comprising an operator set. Operator sets follow //! the ONNX versioning design. Callers should provide schema for all operators that have changed //! between the specified baseline version and the version specified within operatorSetId. This - //! prevents older versions of kernels from being used in models which import the newer operator + //! prevents older versions of kernels from being used in models which import the newer operator //! set version. A type inferrer must be provided if the MLOperatorSchemaDescription structure //! cannot express how output types are determined. A shape inferrer may optionally be provided //! to enable model validation. @@ -770,7 +773,7 @@ IMLOperatorRegistry : IUnknown _In_opt_ IMLOperatorShapeInferrer* shapeInferrer ) const noexcept PURE; - //! Registers a custom operator kernel. + //! Registers a custom operator kernel. //! A shape inferrer may optionally be provided. This may improve performance and enables //! the kernel to query the shape of its output tensors when it is created and computed. 
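With SequenceTensor now part of MLOperatorEdgeType, a registration can constrain an edge to a sequence of tensors. A hypothetical sketch using the structs defined earlier in this header; the operator name, opset version, and the "S" type label are placeholders:

    // Fills a kernel description whose single type constraint "S" admits
    // sequences of float tensors. The out-parameters keep the nested pointers
    // valid for the lifetime of the registration call.
    void DescribeSequenceKernel(
        MLOperatorEdgeDescription& seqFloatEdge,
        MLOperatorEdgeTypeConstraint& constraint,
        MLOperatorKernelDescription& kernelDesc)
    {
        seqFloatEdge = {};
        seqFloatEdge.edgeType = MLOperatorEdgeType::SequenceTensor;
        seqFloatEdge.tensorDataType = MLOperatorTensorDataType::Float;

        constraint = {};
        constraint.typeLabel = "S";                 // placeholder constraint label
        constraint.allowedTypes = &seqFloatEdge;
        constraint.allowedTypeCount = 1;

        kernelDesc = {};
        kernelDesc.name = "MySequenceOp";           // placeholder operator name
        kernelDesc.minimumOperatorSetVersion = 11;  // placeholder opset version
        kernelDesc.executionType = MLOperatorExecutionType::Cpu;
        kernelDesc.typeConstraints = &constraint;
        kernelDesc.typeConstraintCount = 1;
        // kernelDesc is then passed to IMLOperatorRegistry::RegisterOperatorKernel
        // together with an IMLOperatorKernelFactory implementation.
    }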
STDMETHOD(RegisterOperatorKernel)( @@ -779,16 +782,16 @@ IMLOperatorRegistry : IUnknown _In_opt_ IMLOperatorShapeInferrer* shapeInferrer ) const noexcept PURE; }; - + extern "C" { //! \fn MLCreateOperatorRegistry //! Creates an instance of IMLOperatorRegistry which may be used to register custom - //! operator kernel and custom operator schema. + //! operator kernel and custom operator schema. HRESULT WINAPI MLCreateOperatorRegistry(_COM_Outptr_ IMLOperatorRegistry** registry); } #endif /* WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP) */ #endif /* defined(__cplusplus) */ -#endif /* defined(_MSC_VER) && (_MSC_VER >= 1700) */ \ No newline at end of file +#endif /* defined(_MSC_VER) && (_MSC_VER >= 1700) */ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index 7ef03316441ac..ede3e7f2c2257 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -430,13 +430,14 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( for (uint32_t j = 0; j < opKernel->typeConstraints[i].allowedTypeCount; ++j) { + auto edgeType = opKernel->typeConstraints[i].allowedTypes[j].edgeType; // TODO - handle non-tensor types - if (opKernel->typeConstraints[i].allowedTypes[j].edgeType != MLOperatorEdgeType::Tensor) + if (edgeType == MLOperatorEdgeType::Undefined) { ORT_THROW_IF_FAILED(E_NOTIMPL); } - types.push_back(ToTensorDataType(opKernel->typeConstraints[i].allowedTypes[j].tensorDataType)); + types.push_back(ToMLDataType(edgeType, opKernel->typeConstraints[i].allowedTypes[j].tensorDataType)); } builder.TypeConstraint(opKernel->typeConstraints[i].typeLabel, types); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 3ae8e1483141c..85674ada5eebe 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -60,6 +60,8 @@ namespace Dml auto customRegistry = *abiRegistry->GetRegistries().begin(); *registry = customRegistry->GetKernelRegistry(); *internalRegInfoMap = abiRegistry->GetInternalRegInfoMap(); + + Dml::RegisterCpuOperatorsAsDml(registry->get()); } ExecutionProvider::ExecutionProvider( @@ -477,6 +479,57 @@ namespace Dml ORT_CATCH_RETURN } + HRESULT __stdcall ExecutionProviderImpl::CopyTensors(gsl::span dst, gsl::span src) const noexcept + { + ORT_TRY + { + ORT_THROW_HR_IF(E_INVALIDARG, dst.size() != src.size()); + + // Source and destination for batched GPU -> CPU copies + std::vector srcDatas; + std::vector dstDatas; + std::vector dataSizesInBytes; + + assert(!m_closed); + auto provider = const_cast(this); + + for (uint32_t i = 0; i < dst.size(); ++i) + { + // This batching implementation only handles GPU -> CPU copies. Other copies do not require synchronization + // and are batched across multiple calls to CopyTensor. 
+ if (src[i]->IsCpuData() || !dst[i]->IsCpuData()) + { + ORT_THROW_IF_FAILED(CopyTensor(dst[i], src[i])); + continue; + } + + const size_t dataSizeInBytes = ComputeByteSizeFromTensor(*dst[i]); + ORT_THROW_HR_IF(E_INVALIDARG, dataSizeInBytes != ComputeByteSizeFromTensor(*src[i])); // Tensors must be the same size + + if (dataSizeInBytes == 0) + { + continue; + } + + dataSizesInBytes.push_back(static_cast(ComputeByteSizeFromTensor(*dst[i]))); + ORT_THROW_HR_IF(E_INVALIDARG, dataSizesInBytes[i] != ComputeByteSizeFromTensor(*src[i])); // Tensors must be the same size + + dstDatas.push_back(dst[i]->GetData()); + const AllocationInfo* srcAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(src[i]).GetDataInterface().Get()); + + srcDatas.push_back(srcAllocInfo->GetResource()); + } + + const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state + + // Performs a blocking call to synchronize and read back data from the GPU into the destination buffer + m_readbackHeap->ReadbackFromGpu(dstDatas, dataSizesInBytes, srcDatas, srcState); + + return S_OK; + } + ORT_CATCH_RETURN + } + HRESULT STDMETHODCALLTYPE ExecutionProviderImpl::FillTensorWithPattern( IMLOperatorTensor* dst, gsl::span rawValue // Data type agnostic rawValue, treated as raw bits @@ -524,18 +577,35 @@ namespace Dml bool TryGetTensorDataType( const onnxruntime::NodeArg& nodeArg, + _Out_ MLOperatorEdgeType* edgeType, _Out_ MLOperatorTensorDataType* onnxElementType ) { *onnxElementType = MLOperatorTensorDataType::Undefined; + *edgeType = MLOperatorEdgeType::Undefined; const ::onnx::TypeProto* typeProto = nodeArg.TypeAsProto(); - if (typeProto != nullptr && typeProto->has_tensor_type()) + if (typeProto != nullptr) { - const ::onnx::TypeProto_Tensor& tensorTypeProto = typeProto->tensor_type(); - if (tensorTypeProto.has_elem_type()) + const ::onnx::TypeProto_Tensor* tensorTypeProto; + if (typeProto->has_tensor_type()) + { + *edgeType = MLOperatorEdgeType::Tensor; + tensorTypeProto = &typeProto->tensor_type(); + } + else if (typeProto->has_sequence_type()) + { + *edgeType = MLOperatorEdgeType::SequenceTensor; + tensorTypeProto = &typeProto->sequence_type().elem_type().tensor_type(); + } + else + { + return false; + } + + if (tensorTypeProto->has_elem_type()) { - *onnxElementType = static_cast(tensorTypeProto.elem_type()); + *onnxElementType = static_cast(tensorTypeProto->elem_type()); return true; } } @@ -543,6 +613,27 @@ namespace Dml return false; } + bool IsCpuOnDmlOperator(const onnxruntime::Node& node) + { + auto sequence_ops = std::array{ + "SequenceAt", + "SequenceConstruct", + "SequenceEmpty", + "SequenceLength", + "SequenceErase", + "SequenceInsert", + }; + + for (auto& sequence_op : sequence_ops) + { + if (strcmp(sequence_op, node.OpType().c_str()) == 0) + { + return true; + } + } + return false; + } + bool DoesNodeContainSupportedDataTypes( const onnxruntime::Node& node, _In_opt_ const InternalRegistrationInfo* regInfo, @@ -576,8 +667,9 @@ namespace Dml // Use the enumeration from the proto instead of nodeArg.Type() which returns a string. // Reject node if undefined data type or non-tensor, as DML cannot handle it. 
+ MLOperatorEdgeType edgeType; MLOperatorTensorDataType onnxElementType; - if (!TryGetTensorDataType(nodeArg, &onnxElementType)) + if (!TryGetTensorDataType(nodeArg, &edgeType, &onnxElementType)) { // We shouldn't have arrived here because (1) no DML operators should have been // registered which use non-tensor types (2) ONNX validation should have already @@ -590,6 +682,13 @@ namespace Dml return; } + // Allow nodeArgs that are SequenceTensor when they are actually implemented by CPU Kernels. + if (edgeType == MLOperatorEdgeType::SequenceTensor && IsCpuOnDmlOperator(node)) + { + // Leave nodeContainsSupportedDataTypes alone. + return; + } + // Reject node for unknown DML data types. DML_TENSOR_DATA_TYPE dmlElementType = GetDmlDataTypeFromMlDataTypeNoThrow(onnxElementType); if (dmlElementType == DML_TENSOR_DATA_TYPE_UNKNOWN) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 048230f12723a..707838289455f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -72,6 +72,7 @@ namespace Dml ) const noexcept final; STDMETHOD(CopyTensor)(IMLOperatorTensor* dst, IMLOperatorTensor* src) const noexcept final; + STDMETHOD(CopyTensors)(gsl::span dst, gsl::span src) const noexcept final; STDMETHOD(FillTensorWithPattern)( IMLOperatorTensor* dst, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index e809a20cc0f4b..b7f24d49d19da 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -302,7 +302,8 @@ namespace Dml } else { - const onnxruntime::Tensor* tensor = kernelContext->Input(i); + assert(kernelContext->InputType(gsl::narrow_cast(i))->IsTensorType()); + const onnxruntime::Tensor* tensor = kernelContext->Input(gsl::narrow_cast(i)); uint64_t allocId; DmlGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &inputBindings[i].Buffer, &allocId); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index b4baf62ab73f5..d7a0a607cdec9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -37,7 +37,7 @@ namespace Dml _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding, gsl::span inputTensors ) const noexcept = 0; - + STDMETHOD(ExecuteOperator)( IDMLCompiledOperator* op, _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding, @@ -53,6 +53,7 @@ namespace Dml ) const noexcept = 0; STDMETHOD(CopyTensor)(IMLOperatorTensor* dst, IMLOperatorTensor* src) const noexcept = 0; + STDMETHOD(CopyTensors)(gsl::span dst, gsl::span src) const noexcept = 0; STDMETHOD(FillTensorWithPattern)( IMLOperatorTensor* dst, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 197c62283fba9..e811505ddb043 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -5,6 +5,8 @@ #include "core/framework/customregistry.h" #include "core/framework/execution_frame.h" +#include "core/framework/TensorSeq.h" + #include "core/session/onnxruntime_c_api.h" #include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h" @@ -269,31 +271,90 @@ namespace Windows::AI::MachineLearning::Adapter return onnxruntime::DataTypeImpl::GetTensorType(); \ } +#define ML_SEQUENCE_TENSOR_TYPE_CASE(x) \ + if (type == MLTypeTraits::TensorType) \ + { \ + return onnxruntime::DataTypeImpl::GetSequenceTensorType(); \ + } + +#define ML_PRIMITIVE_TYPE_CASE(x) \ + if (type == MLTypeTraits::TensorType) \ + { \ + return onnxruntime::DataTypeImpl::GetType(); \ + } + #pragma warning(push) #pragma warning(disable:4702) - onnxruntime::MLDataType ToTensorDataType(::MLOperatorTensorDataType type) - { - if (type == MLOperatorTensorDataType::String) - return onnxruntime::DataTypeImpl::GetTensorType(); + onnxruntime::MLDataType ToMLDataType(::MLOperatorEdgeType edgeType, ::MLOperatorTensorDataType type) + { + if (edgeType == ::MLOperatorEdgeType::Tensor) + { + if (type == MLOperatorTensorDataType::String) + return onnxruntime::DataTypeImpl::GetTensorType(); + + ML_TENSOR_TYPE_CASE(float); + ML_TENSOR_TYPE_CASE(uint8_t); + ML_TENSOR_TYPE_CASE(int8_t); + ML_TENSOR_TYPE_CASE(uint16_t); + ML_TENSOR_TYPE_CASE(int16_t); + ML_TENSOR_TYPE_CASE(int32_t); + ML_TENSOR_TYPE_CASE(int64_t); + ML_TENSOR_TYPE_CASE(bool); + ML_TENSOR_TYPE_CASE(double); + ML_TENSOR_TYPE_CASE(uint32_t); + ML_TENSOR_TYPE_CASE(uint64_t); + ML_TENSOR_TYPE_CASE(onnxruntime::MLFloat16); - ML_TENSOR_TYPE_CASE(float); - ML_TENSOR_TYPE_CASE(uint8_t); - ML_TENSOR_TYPE_CASE(int8_t); - ML_TENSOR_TYPE_CASE(uint16_t); - ML_TENSOR_TYPE_CASE(int16_t); - ML_TENSOR_TYPE_CASE(int32_t); - ML_TENSOR_TYPE_CASE(int64_t); - ML_TENSOR_TYPE_CASE(bool); - ML_TENSOR_TYPE_CASE(double); - ML_TENSOR_TYPE_CASE(uint32_t); - ML_TENSOR_TYPE_CASE(uint64_t); - ML_TENSOR_TYPE_CASE(onnxruntime::MLFloat16); + ORT_THROW_HR(E_NOTIMPL); + return onnxruntime::DataTypeImpl::GetTensorType(); + } + else if (edgeType == ::MLOperatorEdgeType::SequenceTensor) + { + if (type == MLOperatorTensorDataType::String) + return onnxruntime::DataTypeImpl::GetSequenceTensorType(); + + ML_SEQUENCE_TENSOR_TYPE_CASE(float); + ML_SEQUENCE_TENSOR_TYPE_CASE(uint8_t); + ML_SEQUENCE_TENSOR_TYPE_CASE(int8_t); + ML_SEQUENCE_TENSOR_TYPE_CASE(uint16_t); + ML_SEQUENCE_TENSOR_TYPE_CASE(int16_t); + ML_SEQUENCE_TENSOR_TYPE_CASE(int32_t); + ML_SEQUENCE_TENSOR_TYPE_CASE(int64_t); + ML_SEQUENCE_TENSOR_TYPE_CASE(bool); + ML_SEQUENCE_TENSOR_TYPE_CASE(double); + ML_SEQUENCE_TENSOR_TYPE_CASE(uint32_t); + ML_SEQUENCE_TENSOR_TYPE_CASE(uint64_t); + ML_SEQUENCE_TENSOR_TYPE_CASE(onnxruntime::MLFloat16); - ORT_THROW_HR(E_NOTIMPL); - return onnxruntime::DataTypeImpl::GetTensorType(); + ORT_THROW_HR(E_NOTIMPL); + return onnxruntime::DataTypeImpl::GetSequenceTensorType(); + } + else if (edgeType == ::MLOperatorEdgeType::Primitive) + { + if (type == MLOperatorTensorDataType::String) + return onnxruntime::DataTypeImpl::GetType(); + + ML_PRIMITIVE_TYPE_CASE(float); + ML_PRIMITIVE_TYPE_CASE(uint8_t); + ML_PRIMITIVE_TYPE_CASE(int8_t); + ML_PRIMITIVE_TYPE_CASE(uint16_t); + ML_PRIMITIVE_TYPE_CASE(int16_t); + ML_PRIMITIVE_TYPE_CASE(int32_t); + ML_PRIMITIVE_TYPE_CASE(int64_t); + ML_PRIMITIVE_TYPE_CASE(bool); + ML_PRIMITIVE_TYPE_CASE(double); + ML_PRIMITIVE_TYPE_CASE(uint32_t); + ML_PRIMITIVE_TYPE_CASE(uint64_t); + 
ML_PRIMITIVE_TYPE_CASE(onnxruntime::MLFloat16); + + ORT_THROW_HR(E_NOTIMPL); + return onnxruntime::DataTypeImpl::GetType(); + } #pragma warning(pop) + ORT_THROW_HR(E_NOTIMPL); } + #pragma warning(push) #pragma warning(disable:4702) ::MLOperatorTensorDataType ToMLTensorDataType(onnx::TensorProto_DataType type) @@ -358,6 +419,7 @@ namespace Windows::AI::MachineLearning::Adapter MLOperatorEdgeDescription ret = {}; ML_CHECK_BOOL(type->value_case() == onnx::TypeProto::kTensorType || + type->value_case() == onnx::TypeProto::kSequenceType || type->value_case() == onnx::TypeProto::VALUE_NOT_SET); if (type->value_case() == onnx::TypeProto::kTensorType) @@ -369,6 +431,15 @@ namespace Windows::AI::MachineLearning::Adapter ret.tensorDataType = ToMLTensorDataType(onnx::TensorProto_DataType(tensorType.elem_type())); } } + else if (type->value_case() == onnx::TypeProto::kSequenceType) + { + ret.edgeType = MLOperatorEdgeType::SequenceTensor; + const auto& tensorType = type->sequence_type().elem_type().tensor_type(); + if (tensorType.has_elem_type()) + { + ret.tensorDataType = ToMLTensorDataType(onnx::TensorProto_DataType(tensorType.elem_type())); + } + } return ret; } @@ -377,61 +448,117 @@ namespace Windows::AI::MachineLearning::Adapter #pragma warning(disable:4702) std::string ToTypeString(MLOperatorEdgeDescription desc) { - if (desc.edgeType != MLOperatorEdgeType::Tensor) + if (desc.edgeType == MLOperatorEdgeType::Tensor) { - ORT_THROW_HR(E_NOTIMPL); - } + switch (desc.tensorDataType) + { + case MLOperatorTensorDataType::Float: + return "tensor(float)"; + + case MLOperatorTensorDataType::UInt8: + return "tensor(uint8)"; + + case MLOperatorTensorDataType::Int8: + return "tensor(int8)"; + + case MLOperatorTensorDataType::UInt16: + return "tensor(uint16)"; + + case MLOperatorTensorDataType::Int16: + return "tensor(int16)"; + + case MLOperatorTensorDataType::Int32: + return "tensor(int32)"; + + case MLOperatorTensorDataType::Int64: + return "tensor(int64)"; + + case MLOperatorTensorDataType::String: + return "tensor(string)"; + + case MLOperatorTensorDataType::Bool: + return "tensor(bool)"; + + case MLOperatorTensorDataType::Float16: + return "tensor(float16)"; + + case MLOperatorTensorDataType::Double: + return "tensor(double)"; + + case MLOperatorTensorDataType::UInt32: + return "tensor(uint32)"; - switch (desc.tensorDataType) + case MLOperatorTensorDataType::UInt64: + return "tensor(uint64)"; + + case MLOperatorTensorDataType::Complex64: + return "tensor(complext64)"; + + case MLOperatorTensorDataType::Complex128: + return "tensor(complext128)"; + + default: + ORT_THROW_HR(E_NOTIMPL); + return ""; + } + } + else if (desc.edgeType == MLOperatorEdgeType::SequenceTensor) { - case MLOperatorTensorDataType::Float: - return "tensor(float)"; + switch (desc.tensorDataType) + { + case MLOperatorTensorDataType::Float: + return "seq(tensor(float))"; - case MLOperatorTensorDataType::UInt8: - return "tensor(uint8)"; + case MLOperatorTensorDataType::UInt8: + return "seq(tensor(uint8))"; - case MLOperatorTensorDataType::Int8: - return "tensor(int8)"; + case MLOperatorTensorDataType::Int8: + return "seq(tensor(int8))"; - case MLOperatorTensorDataType::UInt16: - return "tensor(uint16)"; + case MLOperatorTensorDataType::UInt16: + return "seq(tensor(uint16))"; - case MLOperatorTensorDataType::Int16: - return "tensor(int16)"; + case MLOperatorTensorDataType::Int16: + return "seq(tensor(int16))"; - case MLOperatorTensorDataType::Int32: - return "tensor(int32)"; + case MLOperatorTensorDataType::Int32: + return 
"seq(tensor(int32))"; - case MLOperatorTensorDataType::Int64: - return "tensor(int64)"; + case MLOperatorTensorDataType::Int64: + return "seq(tensor(int64))"; - case MLOperatorTensorDataType::String: - return "tensor(string)"; + case MLOperatorTensorDataType::String: + return "seq(tensor(string))"; - case MLOperatorTensorDataType::Bool: - return "tensor(bool)"; + case MLOperatorTensorDataType::Bool: + return "seq(tensor(bool))"; - case MLOperatorTensorDataType::Float16: - return "tensor(float16)"; + case MLOperatorTensorDataType::Float16: + return "seq(tensor(float16))"; - case MLOperatorTensorDataType::Double: - return "tensor(double)"; + case MLOperatorTensorDataType::Double: + return "seq(tensor(double))"; - case MLOperatorTensorDataType::UInt32: - return "tensor(uint32)"; + case MLOperatorTensorDataType::UInt32: + return "seq(tensor(uint32))"; - case MLOperatorTensorDataType::UInt64: - return "tensor(uint64)"; + case MLOperatorTensorDataType::UInt64: + return "seq(tensor(uint64))"; - case MLOperatorTensorDataType::Complex64: - return "tensor(complext64)"; + case MLOperatorTensorDataType::Complex64: + return "seq(tensor(complext64))"; - case MLOperatorTensorDataType::Complex128: - return "tensor(complext128)"; + case MLOperatorTensorDataType::Complex128: + return "seq(tensor(complext128))"; - default: + default: + ORT_THROW_HR(E_NOTIMPL); + return ""; + } + } + else + { ORT_THROW_HR(E_NOTIMPL); - return ""; } #pragma warning(pop) } @@ -446,9 +573,10 @@ namespace Windows::AI::MachineLearning::Adapter bool isInternalOperator, const AttributeMap* defaultAttributes, gsl::span requiredConstantCpuInputs, - MLOperatorTensorGetter& constantInputGetter + MLOperatorTensorGetter& constantInputGetter, + const onnxruntime::OpKernelContext* kernelContext ) - : OpNodeInfoWrapper(kerneInfo, inputShapeOverrides, defaultAttributes, requiredConstantCpuInputs, constantInputGetter), + : OpNodeInfoWrapper(kerneInfo, inputShapeOverrides, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, kernelContext), m_inferredOutputShapes(inferredOutputShapes), m_allowInputShapeQuery(allowInputShapeQuery), m_allowOutputShapeQuery(allowOutputShapeQuery), @@ -738,8 +866,7 @@ namespace Windows::AI::MachineLearning::Adapter *edgeDesc = ToMLEdgeDesc(type); assert(edgeDesc->edgeType != MLOperatorEdgeType::Undefined); - assert((edgeDesc->edgeType != MLOperatorEdgeType::Tensor /*&& edgeDesc->edgeType != MLOperatorEdgeType::TensorSequence*/) || - edgeDesc->tensorDataType != MLOperatorTensorDataType::Undefined); + assert(edgeDesc->tensorDataType != MLOperatorTensorDataType::Undefined); return S_OK; } @@ -807,6 +934,66 @@ namespace Windows::AI::MachineLearning::Adapter ORT_CATCH_RETURN } + template + HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetSequenceInputTensorShape(uint32_t inputIndex, uint32_t sequenceIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept + { + ORT_TRY + { + VerifyNotClosed(); + + memset(dimensions, 0, dimensionCount * sizeof(dimensions[0])); + if (inputIndex >= GetInputCount()) + { + return E_INVALIDARG; + } + + // Input shapes are determined either from the override or from the underlying proto + if (m_kernelContext) + { + assert(m_kernelContext->InputType(gsl::narrow_cast(inputIndex))->IsTensorSequenceType()); + auto inputTensorSeq = m_kernelContext->Input(gsl::narrow_cast(inputIndex)); + ML_CHECK_BOOL(inputTensorSeq != nullptr); + const auto& elemTensor = inputTensorSeq->Get(sequenceIndex); + const auto& shape = elemTensor.Shape(); + for (uint32_t i = 0; i < dimensionCount; 
++i) + { + dimensions[i] = static_cast(shape[i]); + } + } + else if (m_inputShapesOverride) + { + if (m_inputShapesOverride->GetShape(inputIndex).size() != dimensionCount) + { + return E_INVALIDARG; + } + + for (uint32_t i = 0; i < dimensionCount; ++i) + { + dimensions[i] = m_inputShapesOverride->GetShape(inputIndex)[i]; + } + } + else + { + const auto* inputType = m_impl->GetInputType(inputIndex); + assert(inputType->has_sequence_type()); + ML_CHECK_BOOL(inputType->has_sequence_type()); + + const auto& elemType = inputType->sequence_type().elem_type(); + + for (uint32_t i = 0; i < dimensionCount; ++i) + { + // Shape inference is only done when all dimensions of all inputs have known values, + // so the input tensors will always have shapes at this point. + assert(elemType.tensor_type().shape().dim(i).has_dim_value()); + dimensions[i] = static_cast(elemType.tensor_type().shape().dim(i).dim_value()); + } + } + + return S_OK; + } + ORT_CATCH_RETURN + } + template bool STDMETHODCALLTYPE OpNodeInfoWrapper::IsInputValid(uint32_t inputIndex) const noexcept { @@ -865,6 +1052,75 @@ namespace Windows::AI::MachineLearning::Adapter ORT_CATCH_RETURN } + template + HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetSequenceInputCount(uint32_t inputIndex, uint32_t* inputCount) const noexcept + { + ORT_TRY + { + VerifyNotClosed(); + + ML_CHECK_BOOL(inputIndex < GetInputCount()); + ML_CHECK_BOOL(m_kernelContext != nullptr); + + // Input shapes are determined either from the input tensor, override or from the underlying proto + assert(m_kernelContext->InputType(gsl::narrow_cast(inputIndex))->IsTensorSequenceType()); + ML_CHECK_BOOL(m_kernelContext->InputType(gsl::narrow_cast(inputIndex))->IsTensorSequenceType()); + auto inputTensorSeq = m_kernelContext->Input(gsl::narrow_cast(inputIndex)); + ML_CHECK_BOOL(inputTensorSeq != nullptr); + *inputCount = static_cast(inputTensorSeq->Size()); + + return S_OK; + } + ORT_CATCH_RETURN + } + + template + HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex, uint32_t* dimensionCount) const noexcept + { + ORT_TRY + { + VerifyNotClosed(); + + *dimensionCount = 0; + + if (inputIndex >= GetInputCount()) + { + return E_INVALIDARG; + } + + // Input shapes are determined either from the input tensor, override or from the underlying proto + if (m_kernelContext) + { + assert(m_kernelContext->InputType(gsl::narrow_cast(inputIndex))->IsTensorSequenceType()); + auto inputTensorSeq = m_kernelContext->Input(gsl::narrow_cast(inputIndex)); + ML_CHECK_BOOL(inputTensorSeq != nullptr); + const auto& elemTensor = inputTensorSeq->Get(sequenceIndex); + *dimensionCount = static_cast(elemTensor.Shape().NumDimensions()); + } + else if (m_inputShapesOverride) + { + *dimensionCount = gsl::narrow_cast(m_inputShapesOverride->GetShape(inputIndex).size()); + } + else + { + const auto* inputType = m_impl->GetInputType(inputIndex); + assert(inputType->has_sequence_type()); + ML_CHECK_BOOL(inputType->has_sequence_type()); + + const auto& elemType = inputType->sequence_type().elem_type(); + + // Shape inference is only done when all dimensions of all inputs have known values, + // so the input tensors will always have shapes at this point. 
+ assert(elemType.tensor_type().has_shape()); + + *dimensionCount = elemType.tensor_type().shape().dim_size(); + } + + return S_OK; + } + ORT_CATCH_RETURN + } + template HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { @@ -877,8 +1133,10 @@ namespace Windows::AI::MachineLearning::Adapter ORT_THROW_HR_IF(E_INVALIDARG, !inputRequiredAsConstant); - ComPtr tensorWrapper = m_constantInputGetter(inputIndex); + auto constantInput = m_constantInputGetter(inputIndex); + ORT_THROW_HR_IF(E_INVALIDARG, !std::holds_alternative>(constantInput)); + auto tensorWrapper = std::get>(constantInput); if (tensorWrapper == nullptr) { // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. @@ -1019,7 +1277,7 @@ namespace Windows::AI::MachineLearning::Adapter gsl::span requiredConstantCpuInputs, MLOperatorTensorGetter& constantInputGetter ) - : OpNodeInfoWrapper(protoHelper, nullptr, defaultAttributes, requiredConstantCpuInputs, constantInputGetter), + : OpNodeInfoWrapper(protoHelper, nullptr, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), m_inferredOutputShapes(inferredOutputShapes), m_internalOperator(isInternalOperator), m_graphNodeCreateInfo(graphNodeCreateInfo) @@ -1143,7 +1401,7 @@ namespace Windows::AI::MachineLearning::Adapter if (operatorGraphDesc->nodesAsOpDesc) { m_graphNodeCreateInfo->nodesAsOperatorDesc = std::vector>(); - for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) + for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) { auto* node = operatorGraphDesc->nodesAsOpDesc[nodeIndex]; assert(node != nullptr); @@ -1154,7 +1412,7 @@ namespace Windows::AI::MachineLearning::Adapter else { m_graphNodeCreateInfo->nodesAsIDMLOperator = std::vector>(); - for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) + for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) { auto* node = operatorGraphDesc->nodesAsIDMLOperator[nodeIndex]; assert(node != nullptr); @@ -1413,21 +1671,36 @@ namespace Windows::AI::MachineLearning::Adapter { if (m_winmlProvider->TransitionsRequiredForOperator(m_internalOperator)) { + uint32_t totalInputTensorCount = 0; + for (auto inputTensor : m_inputTensors) + { + totalInputTensorCount += static_cast(inputTensor.size()); + } std::vector resourcesToTransition; - resourcesToTransition.reserve(m_inputTensors.size() + m_outputTensors.size() + m_temporaryAllocations.size()); + resourcesToTransition.reserve(totalInputTensorCount + m_outputTensors.size() + m_temporaryAllocations.size()); for (uint32_t i = 0; i < m_inputTensors.size(); ++i) { - ComPtr tensor; - ORT_THROW_IF_FAILED(GetInputTensor(i, tensor.GetAddressOf())); - - if (tensor) + for (uint32_t j = 0; j < m_inputTensors[i].size(); ++j) { - ComPtr resource; - tensor->GetDataInterface(resource.GetAddressOf()); - if (resource) + ComPtr tensor; + if (m_inputTensors[i].size() == 1) { - resourcesToTransition.push_back(resource.Get()); + ORT_THROW_IF_FAILED(GetInputTensor(i, tensor.GetAddressOf())); + } + else + { + ORT_THROW_IF_FAILED(GetSequenceInputTensor(i, j, tensor.GetAddressOf())); + } + + if (tensor) + { + ComPtr resource; + tensor->GetDataInterface(resource.GetAddressOf()); + if (resource) + { + resourcesToTransition.push_back(resource.Get()); + } } } } @@ -1468,8 +1741,8 @@ namespace Windows::AI::MachineLearning::Adapter // Pre-size tensor 
arrays. Member methods return pointers to these which // are stored in these arrays, which would become stale if the vectors reallocate // their internal storage. - m_inputTensors.resize(context->InputCount()); - m_outputTensors.resize(context->OutputCount()); + m_inputTensors.resize(context->InputCount(), std::vector>(1)); + m_outputTensors.resize(context->OutputCount(), std::vector>(1)); const void* executionHandle = m_provider->GetExecutionHandle(); if (executionHandle) @@ -1511,19 +1784,25 @@ namespace Windows::AI::MachineLearning::Adapter TransitionResourcesForOperatorIfRequired(false); } - for (auto& tensor : m_inputTensors) + for (auto& tensors : m_inputTensors) { - if (tensor) + for (auto& tensor : tensors) { - tensor->Close(); + if (tensor) + { + tensor->Close(); + } } } - for (auto& tensor : m_outputTensors) + for (auto& tensors : m_outputTensors) { - if (tensor) + for (auto& tensor : tensors) { - tensor->Close(); + if (tensor) + { + tensor->Close(); + } } } @@ -1532,6 +1811,12 @@ namespace Windows::AI::MachineLearning::Adapter Closable::Close(); } + bool STDMETHODCALLTYPE OpKernelContextWrapper::IsSequenceInputTensor(uint32_t inputIndex) const noexcept + { + assert(inputIndex < gsl::narrow_cast(m_impl->InputCount())); + return m_impl->InputType(inputIndex)->IsTensorSequenceType(); + } + HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::GetInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { ORT_TRY @@ -1542,9 +1827,10 @@ namespace Windows::AI::MachineLearning::Adapter ML_CHECK_BOOL(inputIndex < m_inputTensors.size()); auto opKernelContextWrapper = const_cast(this); - if (m_inputTensors[inputIndex]->GetInterface() == nullptr) + if (m_inputTensors[inputIndex][0]->GetInterface() == nullptr) { - auto inputTensor = m_impl->Input(inputIndex); + assert(m_impl->InputType(gsl::narrow_cast(inputIndex))->IsTensorType()); + auto inputTensor = m_impl->Input(gsl::narrow_cast(inputIndex)); if (inputTensor != nullptr) { ComPtr tensorWrapper = wil::MakeOrThrow( @@ -1553,14 +1839,154 @@ namespace Windows::AI::MachineLearning::Adapter m_winmlProvider.Get(), m_internalOperator); - opKernelContextWrapper->m_inputTensors[inputIndex] = tensorWrapper; + opKernelContextWrapper->m_inputTensors[inputIndex][0] = tensorWrapper; + } + } + + if (opKernelContextWrapper->m_inputTensors[inputIndex][0] != nullptr) + { + opKernelContextWrapper->m_inputTensors[inputIndex][0].CopyTo(tensor); + } + return S_OK; + } + ORT_CATCH_RETURN + } + + HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::GetSequenceInputTensor(uint32_t inputIndex, uint32_t sequenceIndex, IMLOperatorTensor** tensor) const noexcept + { + ORT_TRY + { + VerifyNotClosed(); + *tensor = nullptr; + + auto opKernelContextWrapper = const_cast(this); + + ML_CHECK_BOOL(inputIndex < m_inputTensors.size()); + if (sequenceIndex >= m_inputTensors[inputIndex].size()) + { + opKernelContextWrapper->m_inputTensors[inputIndex].resize(sequenceIndex+1); + } + + if (m_inputTensors[inputIndex][sequenceIndex]->GetInterface() == nullptr) + { + assert(m_impl->InputType(gsl::narrow_cast(inputIndex))->IsTensorSequenceType()); + auto inputTensorSeq = m_impl->Input(gsl::narrow_cast(inputIndex)); + ML_CHECK_BOOL(inputTensorSeq != nullptr); + + auto elemTensor = const_cast(&inputTensorSeq->Get(sequenceIndex)); + if (elemTensor != nullptr) + { + ComPtr tensorWrapper = wil::MakeOrThrow( + elemTensor, + IsAllocationInterface(elemTensor->Location()), + m_winmlProvider.Get(), + m_internalOperator); + + 
opKernelContextWrapper->m_inputTensors[inputIndex][sequenceIndex] = tensorWrapper; + } + } + + if (opKernelContextWrapper->m_inputTensors[inputIndex][sequenceIndex] != nullptr) + { + opKernelContextWrapper->m_inputTensors[inputIndex][sequenceIndex].CopyTo(tensor); + } + return S_OK; + } + ORT_CATCH_RETURN + } + + HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::GetSequenceOutputTensor( + uint32_t outputIndex, + uint32_t sequenceIndex, + MLOperatorTensorDataType dataType, + uint32_t dimensions, + const uint32_t* dimensionSizes, + bool gpuOutput, + IMLOperatorTensor** tensor) const noexcept + { + ORT_TRY + { + VerifyNotClosed(); + *tensor = nullptr; + + auto opKernelContextWrapper = const_cast(this); + + ML_CHECK_BOOL(outputIndex < m_outputTensors.size()); + if (sequenceIndex >= m_outputTensors[outputIndex].size()) + { + opKernelContextWrapper->m_outputTensors[outputIndex].resize(sequenceIndex+1); + } + + // Verify that the provided shape matches the shape determined using the kernel's shape inference function. + if (m_outputTensors[outputIndex][sequenceIndex]->GetInterface() == nullptr) + { + auto outputTensorSeq = m_impl->Output(gsl::narrow_cast(outputIndex)); + ML_CHECK_BOOL(outputTensorSeq != nullptr); + + auto mlDataType = ToMLDataType(MLOperatorEdgeType::Primitive, dataType); + + if (outputTensorSeq->Size() == 0) + { + outputTensorSeq->SetType(mlDataType); + } + + onnxruntime::AllocatorPtr alloc; + if (gpuOutput) + { + auto status = m_impl->GetTempSpaceAllocator(&alloc); + ORT_THROW_HR_IF(E_INVALIDARG, !status.IsOK()); + } + else + { + auto status = m_impl->GetTempSpaceCPUAllocator(&alloc); + ORT_THROW_HR_IF(E_INVALIDARG, !status.IsOK()); + } + + std::vector shapeDims(dimensions); + for (uint32_t i = 0; i < dimensions; ++i) + { + shapeDims[i] = dimensionSizes[i]; + } + + auto target_tensor = onnxruntime::Tensor(mlDataType, onnxruntime::TensorShape(shapeDims), alloc); + outputTensorSeq->Add(std::move(target_tensor)); + + auto elemTensor = const_cast(&outputTensorSeq->Get(sequenceIndex)); + if (elemTensor != nullptr) + { + ComPtr tensorWrapper = wil::MakeOrThrow( + elemTensor, + IsAllocationInterface(elemTensor->Location()), + m_winmlProvider.Get(), + m_internalOperator); + + opKernelContextWrapper->m_outputTensors[outputIndex][sequenceIndex] = tensorWrapper; } } - if (opKernelContextWrapper->m_inputTensors[inputIndex] != nullptr) + if (opKernelContextWrapper->m_outputTensors[outputIndex][sequenceIndex] != nullptr) { - opKernelContextWrapper->m_inputTensors[inputIndex].CopyTo(tensor); + opKernelContextWrapper->m_outputTensors[outputIndex][sequenceIndex].CopyTo(tensor); } + + return S_OK; + } + ORT_CATCH_RETURN + } + + HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::GetSequenceInputCount(uint32_t inputIndex, uint32_t* inputCount) const noexcept + { + ORT_TRY + { + VerifyNotClosed(); + + ML_CHECK_BOOL(inputIndex < m_inputTensors.size()); + + assert(m_impl->InputType(gsl::narrow_cast(inputIndex))->IsTensorSequenceType()); + ML_CHECK_BOOL(m_impl->InputType(gsl::narrow_cast(inputIndex))->IsTensorSequenceType()); + auto inputTensorSeq = m_impl->Input(gsl::narrow_cast(inputIndex)); + ML_CHECK_BOOL(inputTensorSeq != nullptr); + *inputCount = static_cast(inputTensorSeq->Size()); return S_OK; } ORT_CATCH_RETURN @@ -1600,7 +2026,7 @@ namespace Windows::AI::MachineLearning::Adapter ML_CHECK_BOOL(outputIndex < m_outputTensors.size()); // Verify that the provided shape matches the shape determined using the kernel's shape inference function. 
- if (m_outputTensors[outputIndex]->GetInterface() == nullptr) + if (m_outputTensors[outputIndex][0]->GetInterface() == nullptr) { if (m_outputShapes) { @@ -1626,18 +2052,18 @@ namespace Windows::AI::MachineLearning::Adapter m_winmlProvider.Get(), m_internalOperator); - const_cast(this)->m_outputTensors[outputIndex] = tensorWrapper; + const_cast(this)->m_outputTensors[outputIndex][0] = tensorWrapper; } } - m_outputTensors[outputIndex].CopyTo(tensor); + m_outputTensors[outputIndex][0].CopyTo(tensor); return S_OK; } ORT_CATCH_RETURN } - HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::AllocateTemporaryData(size_t size, IUnknown** abiAllocation) const + HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::AllocateTemporaryData(size_t size, IUnknown** abiAllocation) const noexcept { ORT_TRY { @@ -1697,7 +2123,7 @@ namespace Windows::AI::MachineLearning::Adapter { ComPtr tensor; ORT_THROW_IF_FAILED(GetInputTensor(i, tensor.GetAddressOf())); - ret.push_back(m_inputTensors[i].Get()); + ret.push_back(m_inputTensors[i][0].Get()); } return ret; @@ -1719,7 +2145,7 @@ namespace Windows::AI::MachineLearning::Adapter outputShapes.GetShape(i).data(), tensor.GetAddressOf())); - ret.push_back(m_outputTensors[i].Get()); + ret.push_back(m_outputTensors[i][0].Get()); } return ret; @@ -1814,7 +2240,8 @@ namespace Windows::AI::MachineLearning::Adapter isInternalOperator, m_defaultAttributes, m_requiredConstantCpuInputs, - constantInputGetter); + constantInputGetter, + nullptr /*const onnxruntime::OpKernelContext* m_kernelContext*/); ORT_THROW_IF_FAILED(operatorFactory->CreateKernel(kernelInfoWrapper.Get(), m_kernel.GetAddressOf())); kernelInfoWrapper->Close(); @@ -1839,21 +2266,57 @@ namespace Windows::AI::MachineLearning::Adapter MLOperatorTensorGetter constantInputGetter = [context, winmlProviderCapture, internalOpCapture](uint32_t index) { - Microsoft::WRL::ComPtr tensorWrapper = nullptr; - const onnxruntime::Tensor* tensor = context->Input(static_cast(index)); - if (tensor != nullptr) + auto inputType = context->InputType(gsl::narrow_cast(index)); + + if (inputType != nullptr) { - tensorWrapper = wil::MakeOrThrow( - const_cast(tensor), - tensor ? 
IsAllocationInterface(tensor->Location()) : false, - winmlProviderCapture.Get(), - internalOpCapture); + if (inputType->IsTensorType()) + { + Microsoft::WRL::ComPtr tensorWrapper = nullptr; + + const auto* tensor = context->Input(gsl::narrow_cast(index)); + if (tensor != nullptr) + { + tensorWrapper = wil::MakeOrThrow( + const_cast(tensor), + IsAllocationInterface(tensor->Location()), + winmlProviderCapture.Get(), + internalOpCapture); + } + + return tensorWrapper; + } + else if (inputType->IsTensorSequenceType()) + { + std::vector> tensorWrappers; + + const auto* tensorSequence = context->Input(gsl::narrow_cast(index)); + if (tensorSequence != nullptr) + { + tensorWrappers.reserve(tensorSequence->Size()); + + for (uint32_t sequenceIndex = 0; sequenceIndex < tensorSequence->Size(); ++sequenceIndex) + { + auto& tensor = tensorSequence->Get(sequenceIndex); + auto tensorWrapper = wil::MakeOrThrow( + const_cast(&tensor), + IsAllocationInterface(tensor.Location()), + winmlProviderCapture.Get(), + internalOpCapture); + } + } + } + else + { + assert(false); + ORT_THROW_HR(E_INVALIDARG); + } } - return tensorWrapper; + return Microsoft::WRL::ComPtr(); }; - auto inferShapesAndCreateKernel = [&](const EdgeShapes& inputShapes, EdgeShapes& outputShapes) -> ComPtr { + auto inferShapesAndCreateKernel = [&, context](const EdgeShapes& inputShapes, EdgeShapes& outputShapes) -> ComPtr { // If the output size is not dynamic, infer it using the kernel. The result of inference is stored in m_inferredOutputShapes. if (m_requiresOutputShapesAtCreation) { @@ -1872,7 +2335,8 @@ namespace Windows::AI::MachineLearning::Adapter m_internalOperator, m_defaultAttributes, m_requiredConstantCpuInputs, - constantInputGetter); + constantInputGetter, + context /*const onnxruntime::OpKernelContext* m_kernelContext*/); ComPtr ret; ORT_THROW_IF_FAILED(m_operatorFactory->CreateKernel(kernelInfoWrapper.Get(), ret.GetAddressOf())); @@ -1893,29 +2357,16 @@ namespace Windows::AI::MachineLearning::Adapter m_constantInputTensorContentsOfKernel.resize(context->InputCount()); for (uint32_t index : m_requiredConstantCpuInputs) { - const onnxruntime::Tensor* weakTensor = context->Input(static_cast(index)); - - // Skip optional constant tensors. 
- if (weakTensor != nullptr) + if (index >= m_constantInputTensorContentsOfKernel.size()) { - MLOperatorTensor tensor = MLOperatorTensor(constantInputGetter(index).Get()); + continue; + } - if (index >= static_cast(context->InputCount())) - { - continue; - } - m_constantInputTensorContentsOfKernel[index].isValid = (tensor.GetInterface() != nullptr); + auto constantInput = constantInputGetter(index); - if (tensor.GetInterface() != nullptr) - { - m_constantInputTensorContentsOfKernel[index].shape = tensor.GetShape(); - m_constantInputTensorContentsOfKernel[index].type = tensor.GetTensorDataType(); - m_constantInputTensorContentsOfKernel[index].data.resize(tensor.GetUnalignedTensorByteSize()); - } - m_constantInputTensorContentsOfKernel[index].data.assign( - reinterpret_cast(tensor.GetByteData()), - reinterpret_cast(tensor.GetByteData()) + tensor.GetUnalignedTensorByteSize()); - } + std::visit([this, context, index](auto&& arg) { + FillConstantInputs(arg, context, index); + }, constantInput); } m_kernel = inferShapesAndCreateKernel(m_inputShapesOfKernelInference, m_inferredOutputShapes); @@ -1934,25 +2385,15 @@ namespace Windows::AI::MachineLearning::Adapter continue; } - const TensorContent& lastValue = m_constantInputTensorContentsOfKernel[index]; - MLOperatorTensor currentValue(constantInputGetter(index).Get()); + auto constantInput = constantInputGetter(index); + requiredCpuInputsChanged = std::visit([this, index](auto&& arg){ + return RequiredCpuInputChanged(arg, index); + }, constantInput); - if (lastValue.isValid != (currentValue.GetInterface() != nullptr)) + if (requiredCpuInputsChanged) { break; } - - if (lastValue.isValid) - { - if (lastValue.shape != currentValue.GetShape() || - lastValue.type != currentValue.GetTensorDataType() || - currentValue.GetUnalignedTensorByteSize() != lastValue.data.size() || - (memcmp(lastValue.data.data(), currentValue.GetByteData(), lastValue.data.size()) != 0)) - { - requiredCpuInputsChanged = true; - break; - } - } } // In the edge case that the input size is changing across invocations and the kernel requires @@ -2000,11 +2441,132 @@ namespace Windows::AI::MachineLearning::Adapter return onnxruntime::Status(); } + bool AbiOpKernel::RequiredCpuInputChanged(const ComPtr& constantTensor, uint32_t index) const + { + assert(std::holds_alternative(m_constantInputTensorContentsOfKernel[index])); + + auto lastValue = std::get(m_constantInputTensorContentsOfKernel[index]); + MLOperatorTensor currentValue(constantTensor.Get()); + + if (lastValue.isValid != (currentValue.GetInterface() != nullptr)) + { + return false; + } + + if (lastValue.isValid) + { + if (lastValue.shape != currentValue.GetShape() || + lastValue.type != currentValue.GetTensorDataType() || + currentValue.GetUnalignedTensorByteSize() != lastValue.data.size() || + (memcmp(lastValue.data.data(), currentValue.GetByteData(), lastValue.data.size()) != 0)) + { + return true; + } + } + + return false; + } + + bool AbiOpKernel::RequiredCpuInputChanged(const std::vector>& constantTensorSequence, uint32_t index) const + { + assert(std::holds_alternative>(m_constantInputTensorContentsOfKernel[index])); + auto lastValues = std::get>(m_constantInputTensorContentsOfKernel[index]); + + for (uint32_t sequenceIndex = 0; sequenceIndex < constantTensorSequence.size(); ++sequenceIndex) + { + const auto& lastValue = lastValues[sequenceIndex]; + MLOperatorTensor currentValue(constantTensorSequence[sequenceIndex].Get()); + + if (lastValue.isValid != (currentValue.GetInterface() != nullptr)) + { + return false; + } 
+ + if (lastValue.isValid) + { + if (lastValue.shape != currentValue.GetShape() || + lastValue.type != currentValue.GetTensorDataType() || + currentValue.GetUnalignedTensorByteSize() != lastValue.data.size() || + (memcmp(lastValue.data.data(), currentValue.GetByteData(), lastValue.data.size()) != 0)) + { + return true; + } + } + } + + return false; + } + + void AbiOpKernel::FillConstantInputs(const ComPtr& constantTensor, onnxruntime::OpKernelContext* context, uint32_t index) const + { + // Skip optional constant tensors. + if (constantTensor != nullptr) + { + MLOperatorTensor tensor = MLOperatorTensor(constantTensor.Get()); + + if (index >= static_cast(context->InputCount())) + { + return; + } + + TensorContent tensorContent{}; + tensorContent.isValid = (tensor.GetInterface() != nullptr); + + if (tensor.GetInterface() != nullptr) + { + tensorContent.shape = tensor.GetShape(); + tensorContent.type = tensor.GetTensorDataType(); + tensorContent.data.resize(tensor.GetUnalignedTensorByteSize()); + } + + tensorContent.data.assign( + reinterpret_cast(tensor.GetByteData()), + reinterpret_cast(tensor.GetByteData()) + tensor.GetUnalignedTensorByteSize()); + + m_constantInputTensorContentsOfKernel[index] = std::move(tensorContent); + } + } + + void AbiOpKernel::FillConstantInputs(const std::vector>& constantTensorSequence, onnxruntime::OpKernelContext* context, uint32_t index) const + { + std::vector tensorContent(constantTensorSequence.size()); + + for (uint32_t i = 0; i < constantTensorSequence.size(); ++i) + { + const ComPtr& constantTensor = constantTensorSequence[i]; + + // Skip optional constant tensors. + if (constantTensor == nullptr) + { + continue; + } + + MLOperatorTensor tensor = MLOperatorTensor(constantTensor.Get()); + + if (index >= static_cast(context->InputCount())) + { + continue; + } + tensorContent[i].isValid = (tensor.GetInterface() != nullptr); + + if (tensor.GetInterface() != nullptr) + { + tensorContent[i].shape = tensor.GetShape(); + tensorContent[i].type = tensor.GetTensorDataType(); + tensorContent[i].data.resize(tensor.GetUnalignedTensorByteSize()); + } + tensorContent[i].data.assign( + reinterpret_cast(tensor.GetByteData()), + reinterpret_cast(tensor.GetByteData()) + tensor.GetUnalignedTensorByteSize()); + } + + m_constantInputTensorContentsOfKernel[index] = std::move(tensorContent); + } + bool AbiOpKernel::InputTensorShapesDefined() const { onnxruntime::ProtoHelperNodeContext protoContext(Node()); onnxruntime::OpNodeProtoHelper info(&protoContext); - return InputTensorShapesDefinedOnNode(info); } @@ -2018,15 +2580,35 @@ namespace Windows::AI::MachineLearning::Adapter auto inputType = context->InputType(static_cast(i)); if (inputType != nullptr && inputType->IsTensorType()) { - const onnxruntime::Tensor* tensor = context->Input(static_cast(i)); - if (tensor) + if (context->InputType(gsl::narrow_cast(i))->IsTensorSequenceType()) { - ret.GetMutableShape(i).resize(tensor->Shape().GetDims().size()); - for (size_t j = 0; j < ret.GetMutableShape(i).size(); ++j) + auto inputTensorSeq = context->Input(gsl::narrow_cast(i)); + for (uint32_t sequenceIndex = 0; sequenceIndex < inputTensorSeq->Size(); ++sequenceIndex) { - ret.GetMutableShape(i)[j] = gsl::narrow_cast(tensor->Shape().GetDims()[j]); + const auto& tensor = inputTensorSeq->Get(sequenceIndex); + ret.GetMutableShape(i).resize(tensor.Shape().GetDims().size()); + for (size_t j = 0; j < ret.GetMutableShape(i).size(); ++j) + { + ret.GetMutableShape(i)[j] = gsl::narrow_cast(tensor.Shape().GetDims()[j]); + } } } + else if 
(context->InputType(gsl::narrow_cast(i))->IsTensorType()) + { + const onnxruntime::Tensor* tensor = context->Input(gsl::narrow_cast(i)); + if (tensor) + { + ret.GetMutableShape(i).resize(tensor->Shape().GetDims().size()); + for (size_t j = 0; j < ret.GetMutableShape(i).size(); ++j) + { + ret.GetMutableShape(i)[j] = gsl::narrow_cast(tensor->Shape().GetDims()[j]); + } + } + } + else + { + ORT_THROW_HR(E_INVALIDARG); + } } } @@ -2082,6 +2664,7 @@ namespace Windows::AI::MachineLearning::Adapter if (outputProto->value_case() != onnx::TypeProto::kTensorType) { + assert(outputShapes.GetShape(outputIndex).empty()); ML_CHECK_BOOL(outputShapes.GetShape(outputIndex).empty()); continue; } @@ -2091,6 +2674,7 @@ namespace Windows::AI::MachineLearning::Adapter if (tensorType.has_shape()) { const auto& shape = tensorType.shape(); + assert(static_cast(shape.dim_size()) == outputShapes.GetShape(outputIndex).size()); ML_CHECK_BOOL(static_cast(shape.dim_size()) == outputShapes.GetShape(outputIndex).size()); for (uint32_t output_dim = 0; output_dim < outputShapes.GetShape(outputIndex).size(); ++output_dim) @@ -2099,6 +2683,7 @@ namespace Windows::AI::MachineLearning::Adapter { int64_t expected_size = shape.dim(output_dim).dim_value(); int64_t actual_size = outputShapes.GetShape(outputIndex)[output_dim]; + assert(expected_size == actual_size); ML_CHECK_BOOL(expected_size == actual_size); } } @@ -2145,7 +2730,6 @@ namespace Windows::AI::MachineLearning::Adapter MLOperatorEdgeDescription edgeDesc; ORT_THROW_IF_FAILED(GetOutputEdgeDescription(outputIndex, &edgeDesc)); - ML_CHECK_BOOL(edgeDesc.edgeType == MLOperatorEdgeType::Undefined || edgeDesc.edgeType == MLOperatorEdgeType::Tensor); // In the process of calling mutable_tensor_type, the type may switch from undefined to tensor. 
// This is done here in case the dimension count is zero (scalar) @@ -2307,7 +2891,7 @@ namespace Windows::AI::MachineLearning::Adapter } std::tuple, size_t> UnpackTensor( - const onnx::TensorProto& initializer, + const onnx::TensorProto& initializer, const onnxruntime::Path& modelPath) { std::unique_ptr unpackedTensor; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index dd1b743587ab5..2915b8c915044 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -35,22 +35,30 @@ bool InputTensorShapesDefinedOnNode(const onnxruntime::OpNodeProtoHelper& nod for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) { - if (nodeInfo.GetInputType(inputIndex) && (nodeInfo.GetInputType(inputIndex)->value_case() == onnx::TypeProto::kTensorType)) + auto input = nodeInfo.GetInputType(inputIndex); + if (input) { - if (!nodeInfo.GetInputType(inputIndex)->tensor_type().has_shape()) + if (input->value_case() == onnx::TypeProto::kTensorType) { - return false; - } + if (!input->tensor_type().has_shape()) + { + return false; + } - const auto& shape = nodeInfo.GetInputType(inputIndex)->tensor_type().shape(); + const auto& shape = input->tensor_type().shape(); - for (int input_dim = 0; input_dim < shape.dim_size(); ++input_dim) - { - if (!shape.dim(input_dim).has_dim_value()) + for (int input_dim = 0; input_dim < shape.dim_size(); ++input_dim) { - return false; + if (!shape.dim(input_dim).has_dim_value()) + { + return false; + } } } + else if (input->value_case() == onnx::TypeProto::kSequenceType) + { + return false; + } } } @@ -164,9 +172,11 @@ class OpNodeInfoWrapper : public Base1_t, public Base2_t, public Closable const EdgeShapes* inputShapesOverride, const AttributeMap* defaultAttributes, gsl::span requiredConstantCpuInputs, - MLOperatorTensorGetter& constantInputGetter + MLOperatorTensorGetter& constantInputGetter, + const onnxruntime::OpKernelContext* kernelContext = nullptr ) : m_impl(impl), + m_kernelContext(kernelContext), m_inputShapesOverride(inputShapesOverride), m_constantInputGetter(constantInputGetter), m_defaultAttributes(defaultAttributes) @@ -217,6 +227,10 @@ class OpNodeInfoWrapper : public Base1_t, public Base2_t, public Closable HRESULT STDMETHODCALLTYPE GetInputTensorDimensionCount(uint32_t inputIndex, uint32_t* dimensionCount) const noexcept; HRESULT STDMETHODCALLTYPE GetInputTensorShape(uint32_t inputIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept; + HRESULT STDMETHODCALLTYPE GetSequenceInputCount(uint32_t inputIndex, uint32_t* inputCount) const noexcept; + HRESULT STDMETHODCALLTYPE GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex, uint32_t* dimensionCount) const noexcept; + HRESULT STDMETHODCALLTYPE GetSequenceInputTensorShape(uint32_t inputIndex, uint32_t sequenceIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept; + bool STDMETHODCALLTYPE IsInputValid(uint32_t inputIndex) const noexcept override; bool STDMETHODCALLTYPE IsOutputValid(uint32_t outputIndex) const noexcept override; @@ -228,6 +242,7 @@ class OpNodeInfoWrapper : public Base1_t, public Base2_t, public Closable protected: // Lifetime is managed by the caller and guaranteed to outlive this class const onnxruntime::OpNodeProtoHelper* m_impl = nullptr; + const onnxruntime::OpKernelContext* m_kernelContext 
= nullptr; private: template @@ -332,7 +347,7 @@ class OpKernelInfoWrapper : public OpNodeInfoWrapper< onnxruntime::ProtoHelperNodeContext, WRL::Base< Microsoft::WRL::ChainInterfaces, - IMLOperatorTensorShapeDescription, IMLOperatorAttributes1>, + IMLOperatorTensorShapeDescription, IMLOperatorTensorShapeDescriptionPrivate, IMLOperatorAttributes1>, onnxruntime::null_type> { public: @@ -346,7 +361,8 @@ class OpKernelInfoWrapper : public OpNodeInfoWrapper< bool isInternalOperator, const AttributeMap* defaultAttributes, gsl::span requiredConstantCpuInputs, - MLOperatorTensorGetter& constantInputGetter + MLOperatorTensorGetter& constantInputGetter, + const onnxruntime::OpKernelContext* kernelContext = nullptr ); // HasTensorShapeDescription returns false if and only if the kernel is registered using @@ -395,7 +411,7 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper< onnxruntime::ProtoHelperNodeContext, WRL::Base< Microsoft::WRL::ChainInterfaces, - IMLOperatorTensorShapeDescription, IMLOperatorAttributes1>, + IMLOperatorTensorShapeDescription, IMLOperatorTensorShapeDescriptionPrivate, IMLOperatorAttributes1>, onnxruntime::null_type> { public: @@ -441,18 +457,32 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper< DmlGraphNodeCreateInfo* m_graphNodeCreateInfo = nullptr; }; -class OpKernelContextWrapper : public WRL::Base, public Closable +class OpKernelContextWrapper : public WRL::Base, public Closable { public: ~OpKernelContextWrapper(); OpKernelContextWrapper(onnxruntime::OpKernelContext* context, const onnxruntime::IExecutionProvider* provider, bool isInternalOperator, const EdgeShapes* outputShapes); + bool STDMETHODCALLTYPE IsSequenceInputTensor(uint32_t inputIndex) const noexcept override; + HRESULT STDMETHODCALLTYPE GetSequenceInputCount(uint32_t inputIndex, uint32_t* inputCount) const noexcept override; + HRESULT STDMETHODCALLTYPE GetSequenceInputTensor(uint32_t inputIndex, uint32_t sequenceIndex, IMLOperatorTensor** tensor) const noexcept override; + + HRESULT STDMETHODCALLTYPE GetSequenceOutputTensor( + uint32_t outputIndex, + uint32_t sequenceIndex, + MLOperatorTensorDataType dataType, + uint32_t dimensions, + const uint32_t* dimensionSizes, + bool gpuOutput, + IMLOperatorTensor** tensor) const noexcept override; + HRESULT STDMETHODCALLTYPE GetInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept override; + HRESULT STDMETHODCALLTYPE GetOutputTensor(uint32_t outputIndex, IMLOperatorTensor** tensor) noexcept override; HRESULT STDMETHODCALLTYPE GetOutputTensor(uint32_t outputIndex, uint32_t dimensions, const uint32_t* dimensionSizes, IMLOperatorTensor** tensor) noexcept override; - HRESULT STDMETHODCALLTYPE AllocateTemporaryData(size_t size, IUnknown** data) const; + HRESULT STDMETHODCALLTYPE AllocateTemporaryData(size_t size, IUnknown** data) const noexcept override; HRESULT STDMETHODCALLTYPE AllocateTemporaryData(size_t size, IUnknown** data, uint64_t* allocId) const; void STDMETHODCALLTYPE GetExecutionInterface(IUnknown** executionInterface) const noexcept override; @@ -470,8 +500,8 @@ class OpKernelContextWrapper : public WRL::Base, publi onnxruntime::OpKernelContext* m_impl = nullptr; const EdgeShapes* m_outputShapes = nullptr; - std::vector> m_inputTensors; - std::vector> m_outputTensors; + std::vector>> m_inputTensors; + std::vector>> m_outputTensors; const onnxruntime::IExecutionProvider* m_provider = nullptr; ComPtr m_winmlProvider; @@ -535,7 +565,7 @@ class AbiOpKernel : public onnxruntime::OpKernel std::vector data; }; - mutable 
std::vector m_constantInputTensorContentsOfKernel; + mutable std::vector>> m_constantInputTensorContentsOfKernel; mutable std::mutex m_mutex; mutable EdgeShapes m_inferredOutputShapes; @@ -550,6 +580,12 @@ class AbiOpKernel : public onnxruntime::OpKernel ComPtr m_abiExecutionObject; const AttributeMap* m_defaultAttributes = nullptr; + +private: + bool RequiredCpuInputChanged(const ComPtr& constantTensor, uint32_t index) const; + bool RequiredCpuInputChanged(const std::vector>& constantTensorSequence, uint32_t index) const; + void FillConstantInputs(const ComPtr& constantTensor, onnxruntime::OpKernelContext* context, uint32_t index) const; + void FillConstantInputs(const std::vector>& constantTensor, onnxruntime::OpKernelContext* context, uint32_t index) const; }; class MLSchemaInferenceContext final : public OpNodeInfoWrapper< @@ -642,7 +678,7 @@ class MLSupportQueryContext final : public OpNodeInfoWrapper< // TODO - ... }; -onnxruntime::MLDataType ToTensorDataType(::MLOperatorTensorDataType type); +onnxruntime::MLDataType ToMLDataType(::MLOperatorEdgeType edgeType, ::MLOperatorTensorDataType type); std::string ToTypeString(MLOperatorEdgeDescription desc); onnx::AttributeProto_AttributeType ToProto(MLOperatorAttributeType type); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index 3ae29629efbcd..6ce937e2c0366 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -328,14 +328,11 @@ namespace Dml } } - void DmlOperator::InitializeWithShapes( + void DmlOperator::InitializeInputsWithShapes( const MLOperatorKernelCreationContext& kernelInfo, const std::optional>>& kernelInputIndices, - const std::optional>>& kernelOutputIndices, const std::optional>> inputShapes, - const std::optional>> outputShapes, - uint32_t minDimensionCount - ) + uint32_t minDimensionCount) { if (kernelInputIndices) { @@ -347,15 +344,6 @@ namespace Dml std::iota(m_kernelInputIndices.begin(), m_kernelInputIndices.end(), 0); } - if (kernelOutputIndices) - { - m_kernelOutputIndices = *kernelOutputIndices; - } - else - { - m_kernelOutputIndices.resize(kernelInfo.GetOutputCount()); - std::iota(m_kernelOutputIndices.begin(), m_kernelOutputIndices.end(), 0); - } for (uint32_t i = 0; i < m_kernelInputIndices.size(); i++) { @@ -403,6 +391,23 @@ namespace Dml m_inputTensorDescs.push_back(tensorDesc); } } + } + + void DmlOperator::InitializeOutputsWithShapes( + const MLOperatorKernelCreationContext& kernelInfo, + const std::optional>>& kernelOutputIndices, + const std::optional>> outputShapes, + uint32_t minDimensionCount) + { + if (kernelOutputIndices) + { + m_kernelOutputIndices = *kernelOutputIndices; + } + else + { + m_kernelOutputIndices.resize(kernelInfo.GetOutputCount()); + std::iota(m_kernelOutputIndices.begin(), m_kernelOutputIndices.end(), 0); + } for (uint32_t i = 0; i < m_kernelOutputIndices.size(); i++) { @@ -434,6 +439,19 @@ namespace Dml } } + void DmlOperator::InitializeWithShapes( + const MLOperatorKernelCreationContext& kernelInfo, + const std::optional>>& kernelInputIndices, + const std::optional>>& kernelOutputIndices, + const std::optional>> inputShapes, + const std::optional>> outputShapes, + uint32_t minDimensionCount + ) + { + InitializeInputsWithShapes(kernelInfo, kernelInputIndices, inputShapes, minDimensionCount); + 
InitializeOutputsWithShapes(kernelInfo, kernelOutputIndices, outputShapes, minDimensionCount); + } + void DmlOperator::Compute(const MLOperatorKernelContext& kernelContext) { std::vector inputTensors = GetInputTensorsForExecute(kernelContext); @@ -649,6 +667,59 @@ namespace Dml ); } + TensorSequenceDesc DmlOperator::CreateTensorSequenceDescFromInput( + const MLOperatorKernelCreationContext& kernelInfo, + uint32_t index, + int32_t coerceAxis, + int32_t placement, + int32_t leftAlignedDimensionCount, + std::optional> tensorShape, + uint32_t minDimensionCount + ) const + { + if (!kernelInfo.IsInputValid(index)) + { + // The tensor is optional. + return TensorSequenceDesc(); + } + + auto edgeDesc = kernelInfo.GetInputEdgeDescription(index); + assert(edgeDesc.edgeType == MLOperatorEdgeType::SequenceTensor); + ORT_THROW_HR_IF(E_INVALIDARG, edgeDesc.edgeType != MLOperatorEdgeType::SequenceTensor); + + const auto& shapeDescription = kernelInfo.GetTensorShapeDescription(); + const uint32_t numTensors = shapeDescription.GetSequenceInputCount(index); + + TensorSequenceDesc tensorDescs; + tensorDescs.reserve(numTensors); + + for (uint32_t sequenceIndex = 0; sequenceIndex < numTensors; ++sequenceIndex) + { + std::vector actualTensorShape; + if (kernelInfo.HasTensorShapeDescription()) + { + actualTensorShape = shapeDescription.GetSequenceInputTensorShape(index, sequenceIndex); + + tensorDescs.emplace_back( + edgeDesc.tensorDataType, + tensorShape ? *tensorShape : actualTensorShape, + actualTensorShape, + coerceAxis, + placement, + leftAlignedDimensionCount, + minDimensionCount, + 0); + } + else + { + // The tensor has delayed shape determination. + tensorDescs.push_back(TensorDesc()); + } + } + + return tensorDescs; + } + TensorDesc DmlOperator::CreateTensorDescFromOutput( const MLOperatorKernelCreationContext& kernelInfo, uint32_t index, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h index 493cc2e44577a..c1e8cf42a974c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h @@ -55,6 +55,20 @@ namespace Dml uint32_t minDimensionCount = NchwDimensionCount ); + void InitializeInputsWithShapes( + const MLOperatorKernelCreationContext& kernelInfo, + const std::optional>>& kernelInputIndices = std::nullopt, + const std::optional>> inputShapes = std::nullopt, + uint32_t minDimensionCount = NchwDimensionCount + ); + + void InitializeOutputsWithShapes( + const MLOperatorKernelCreationContext& kernelInfo, + const std::optional>>& kernelOutputIndices = std::nullopt, + const std::optional>> outputShapes = std::nullopt, + uint32_t minDimensionCount = NchwDimensionCount + ); + bool AllowHalfPrecisionComputation() const; DML_EXECUTION_FLAGS GetExecutionFlags() const; @@ -115,6 +129,16 @@ namespace Dml uint32_t minDimensionCount = NchwDimensionCount ) const; + TensorSequenceDesc CreateTensorSequenceDescFromInput( + const MLOperatorKernelCreationContext& kernelInfo, + uint32_t index, + int32_t coerceAxis = TensorAxis::DoNotCoerce, + int32_t placement = TensorAxis::W, + int32_t leftAlignedDimensionCount = TensorAxis::RightAligned, + std::optional> tensorShape = std::nullopt, + uint32_t minDimensionCount = NchwDimensionCount + ) const; + TensorDesc CreateTensorDescFromOutput( const MLOperatorKernelCreationContext& kernelInfo, uint32_t index, diff --git 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index 63bae80c51a67..af93808248032 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -401,15 +401,15 @@ class DmlOperatorAttention : public DmlOperator void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported) { *isSupported = false; - // Fall back to CPU if input 'past' and 'extra_add' is present because there is no current use case for this. + // Fall back to CPU if input 'past' and 'relative_position_bias' is present because there is no current use case for this. // and it will make the implementation more complex. // Also fall back to CPU if output 'present' is present for same reason as above. if (context->GetInputCount() > 4 || context->GetOutputCount() > 1) { return; } - // Checking input count alone is not sufficient to fallback to CPU if input 'past' and 'extra_add' is present - // because input 'mask_index', 'past', and 'extra_add' all are optional. + // Checking input count alone is not sufficient to fallback to CPU if input 'past' and 'relative_position_bias' is present + // because input 'mask_index', 'past', and 'relative_position_bias' all are optional. if (context->IsInputValid(4) || context->IsInputValid(5)) { return; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConcatFromSequence.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConcatFromSequence.cpp new file mode 100644 index 0000000000000..c9b38ccc908ee --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConcatFromSequence.cpp @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorConcatFromSequence : public DmlOperator +{ +public: + using Shape = std::vector; + using Self = DmlOperatorConcatFromSequence; + +private: + std::vector m_inputTensorDescs; + std::vector m_inputIndices; + TensorDesc m_outputTensorDesc; + Shape m_outputShape; + +public: + + DmlOperatorConcatFromSequence(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperator(kernelInfo) + { + auto new_axis = static_cast(kernelInfo.GetOptionalAttribute(AttrName::NewAxis, 0)); + ML_CHECK_VALID_ARGUMENT(1 == new_axis || 0 == new_axis); + + // Ensure there is only 1 input, and 1 output + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 1); + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + + // Ensure the singular input is a sequence of tensors + auto edgeDesc = kernelInfo.GetInputEdgeDescription(0); + assert(edgeDesc.edgeType == MLOperatorEdgeType::SequenceTensor); + auto sequenceInputDataType = edgeDesc.tensorDataType; + auto sequenceInputDmlDataType = Dml::GetDmlDataTypeFromMlDataTypeNoThrow(sequenceInputDataType); + + // Ensure the singular output is a tensor + edgeDesc = kernelInfo.GetOutputEdgeDescription(0); + assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor); + + // Get the number of tensors in the sequence of tensors input. + // If no tensors are in the sequence, then no evaluation should occur. + // In this case the output shape is default initialized to {}.
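For reference, the output-shape rule the constructor below implements is the standard ConcatFromSequence arithmetic: with new_axis == 1 a length-1 dimension is inserted at `axis` and each sequence element contributes one unit along it (stacking); with new_axis == 0 the element extents along an existing `axis` are summed. The following minimal sketch restates that arithmetic in isolation, assuming all elements share the same rank and matching non-axis extents; the function name ConcatFromSequenceOutputShape is hypothetical and this is not the DML kernel itself.

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Illustrative sketch only: computes the ConcatFromSequence output shape from a list of
// element shapes. new_axis = 0 concatenates along an existing axis; new_axis = 1 stacks
// along a newly inserted axis.
std::vector<uint32_t> ConcatFromSequenceOutputShape(
    const std::vector<std::vector<uint32_t>>& elementShapes,
    int64_t axis,
    bool newAxis)
{
    if (elementShapes.empty())
    {
        // No elements: leave the output shape default initialized, as the kernel below does.
        return {};
    }

    std::vector<uint32_t> outputShape = elementShapes.front();
    const int64_t rank = static_cast<int64_t>(outputShape.size()) + (newAxis ? 1 : 0);
    const int64_t resolvedAxis = (axis < 0) ? axis + rank : axis; // negative axes count from the end
    if (resolvedAxis < 0 || resolvedAxis >= rank)
    {
        throw std::invalid_argument("axis out of range (scalars cannot be concatenated without new_axis)");
    }

    if (newAxis)
    {
        // Stacking: insert a length-1 dimension at the axis; each element then contributes one unit.
        outputShape.insert(outputShape.begin() + static_cast<std::ptrdiff_t>(resolvedAxis), 1u);
    }

    uint32_t axisTotal = 0;
    for (const auto& shape : elementShapes)
    {
        axisTotal += newAxis ? 1u : shape[static_cast<size_t>(resolvedAxis)];
    }

    outputShape[static_cast<size_t>(resolvedAxis)] = axisTotal;
    return outputShape;
}

Scalars are only concatenable when new_axis == 1, which is why the kernel below rejects them otherwise.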
+ auto tensorShapeDescription = kernelInfo.GetTensorShapeDescription(); + auto numTensorsInSequence = tensorShapeDescription.GetSequenceInputCount(0); + if (numTensorsInSequence > 0) + { + uint32_t axis = 0; + uint32_t axisTotal = 0; + std::optional inputDimCount; + for (uint32_t i = 0; i < numTensorsInSequence; i++) + { + // Remove empty tensors and keep only the non-empty tensors + auto shape = tensorShapeDescription.GetSequenceInputTensorShape(0, i); + + auto r = static_cast(shape.size()); + auto is_scalar = r == 0; + if (!new_axis && is_scalar) + { + ORT_THROW("Cannot concatenate scalars"); + } + + const int32_t signedAxis = gsl::narrow_cast(kernelInfo.GetAttribute(AttrName::Axis)); + axis = static_cast(HandleNegativeAxis(signedAxis, r + new_axis, !is_scalar)); + if (new_axis) + { + ML_CHECK_VALID_ARGUMENT(axis < r + 1); + shape.insert(shape.begin() + axis, 1); + } + else + { + ML_CHECK_VALID_ARGUMENT(axis < r); + } + + // When processing the first tensor, initialize output shape, inputDimCount and axis. + if (!inputDimCount) + { + m_outputShape = shape; + inputDimCount = r + new_axis; + } + + axisTotal += shape[axis]; + + if (OperatorHelper::ContainsEmptyDimensions(shape)) + { + continue; + } + + ML_CHECK_BOOL(*inputDimCount == shape.size()); + m_inputTensorDescs.emplace_back(TensorDesc(sequenceInputDmlDataType, shape)); + m_inputIndices.push_back(i); + } + + m_outputShape[axis] = axisTotal; + + // We should only call join if there exists input tensors that are non-empty and non-scalar. + // In that case, the inputDimCount must be set and greater than 0. + if (m_inputIndices.size() > 0) + { + m_outputTensorDesc = TensorDesc(sequenceInputDmlDataType, m_outputShape); + auto dmlAxis = GetDmlAdjustedAxis(axis, *inputDimCount, m_outputTensorDesc.GetDimensionCount()); + + auto outputIndices = std::vector> { 0 }; + gsl::span outputShapes[1] = { m_outputShape }; + DmlOperator::InitializeOutputsWithShapes(kernelInfo, outputIndices, outputShapes, 1); + + auto outputDescs = std::vector { m_outputTensorDesc.GetDmlDesc() }; + auto inputDescs = std::vector(m_inputTensorDescs.size()); + for (int i = 0; i < inputDescs.size(); i++) + { + inputDescs[i] = m_inputTensorDescs[i].GetDmlDesc(); + } + + DML_JOIN_OPERATOR_DESC joinDesc = {}; + joinDesc.InputCount = gsl::narrow_cast(inputDescs.size()); + joinDesc.InputTensors = inputDescs.data(); + joinDesc.OutputTensor = outputDescs.data(); + joinDesc.Axis = dmlAxis; + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_JOIN, &joinDesc }; + + SetDmlOperatorDesc(opDesc, kernelInfo); + } + } + } + + void Compute(const MLOperatorKernelContext& kernelContext) + { + auto outputTensor = kernelContext.GetOutputTensor(0, m_outputShape).GetInterface().Get(); + + if (!m_inputIndices.size()) + { + return; + } + + ComPtr operatorKernelContext; + kernelContext.GetInterface().As(&operatorKernelContext); + auto inputTensors = std::vector(m_inputIndices.size()); + for (uint32_t i = 0; i < inputTensors.size(); i++) + { + assert(m_inputTensorDescs[i].IsValid()); + ComPtr inputTensor; + ORT_THROW_IF_FAILED(operatorKernelContext->GetSequenceInputTensor(0, m_inputIndices[i], &inputTensor)); + inputTensors[i] = inputTensor.Get(); + } + + auto outputTensors = gsl::span { &outputTensor, 1 }; + + ORT_THROW_IF_FAILED(m_executionProvider->ExecuteOperator( + m_compiledOperator.Get(), + m_persistentResourceBinding ? 
&*m_persistentResourceBinding : nullptr, + gsl::make_span(inputTensors), + outputTensors)); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(ConcatFromSequence, DmlOperatorConcatFromSequence); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMemcpy.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMemcpy.cpp index 1b84cc50afcc6..34ca280e5c80b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMemcpy.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMemcpy.cpp @@ -6,6 +6,7 @@ namespace Dml { +template class DmlOperatorMemcpy : public DmlOperator { public: @@ -16,24 +17,43 @@ class DmlOperatorMemcpy : public DmlOperator { ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1, "MemcpyFromHost/ToHost expects 1 input tensor."); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "MemcpyFromHost/ToHost expects 1 output tensor."); - - DmlOperator::Initialize(kernelCreationContext); } void Compute(const MLOperatorKernelContext& kernelContext) { std::vector inputTensors = GetInputTensors(kernelContext); - std::vector outputTensors = GetOutputTensors(kernelContext); - assert(inputTensors.size() == 1); - assert(outputTensors.size() == 1); + std::vector outputTensors = GetInputTensors(kernelContext); - if (!OperatorHelper::ContainsEmptyDimensions(MLOperatorTensor(inputTensors.front()).GetShape())) + if (kernelContext.IsSequenceInputTensor(0)) { - ORT_THROW_IF_FAILED(m_executionProvider->CopyTensor( - outputTensors.front(), - inputTensors.front() - )); + const uint32_t numTensors = kernelContext.GetSequenceInputCount(0); + inputTensors.reserve(numTensors); + + for (uint32_t sequenceIndex = 0; sequenceIndex < numTensors; ++sequenceIndex) + { + auto* inputTensor = kernelContext.GetSequenceInputTensor(0, sequenceIndex).GetInterface().Get(); + const uint32_t dimCount = inputTensor->GetDimensionCount(); + + std::vector dimensions(dimCount); + inputTensor->GetShape(dimCount, dimensions.data()); + + inputTensors.push_back(inputTensor); + outputTensors.push_back(kernelContext.GetSequenceOutputTensor( + 0, + sequenceIndex, + inputTensor->GetTensorDataType(), + dimCount, + dimensions.data(), + gpuOutput).GetInterface().Get()); + } } + else + { + inputTensors = { kernelContext.GetInputTensor(0).GetInterface().Get() }; + outputTensors = { kernelContext.GetOutputTensor(0).GetInterface().Get() }; + } + + ORT_THROW_IF_FAILED(m_executionProvider->CopyTensors(outputTensors, inputTensors)); } private: @@ -41,7 +61,7 @@ class DmlOperatorMemcpy : public DmlOperator // MemcpyToHost is a special case which is hardcoded in MLOperatorAuthorImpl.cpp. If name changes this must be updated. // Special case makes sure that the output resource is created using the CPU allocator. 
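Consolidated, the copy path above has a single shape: branch on whether input 0 is a tensor sequence, mirror each element's data type and shape into a corresponding output element, then issue one provider copy over the gathered pairs. The sketch below restates that flow for readability only; it assumes the wrapper types as used in this change (MLOperatorKernelContext, MLOperatorTensor, the Dml execution provider's CopyTensors), the free-function form and the name MemcpyComputeSketch are hypothetical, and error handling other than the final copy is elided.

#include "precomp.h" // assumed: brings in the MLOperatorAuthorHelper wrappers and Dml provider types

// Consolidated sketch of the sequence-aware Memcpy path above (illustrative only).
// The gpuOutput flag picks the allocator for each output element, which is what
// distinguishes the MemcpyFromHost and MemcpyToHost registrations.
void MemcpyComputeSketch(const MLOperatorKernelContext& kernelContext,
                         IExecutionProvider* executionProvider,
                         bool gpuOutput)
{
    std::vector<IMLOperatorTensor*> inputTensors;
    std::vector<IMLOperatorTensor*> outputTensors;

    if (kernelContext.IsSequenceInputTensor(0))
    {
        // Sequence input: copy element by element, creating an output element that mirrors
        // each input element's data type and shape.
        const uint32_t numTensors = kernelContext.GetSequenceInputCount(0);
        inputTensors.reserve(numTensors);
        outputTensors.reserve(numTensors);

        for (uint32_t sequenceIndex = 0; sequenceIndex < numTensors; ++sequenceIndex)
        {
            auto* inputTensor = kernelContext.GetSequenceInputTensor(0, sequenceIndex).GetInterface().Get();

            const uint32_t dimCount = inputTensor->GetDimensionCount();
            std::vector<uint32_t> dimensions(dimCount);
            inputTensor->GetShape(dimCount, dimensions.data());

            inputTensors.push_back(inputTensor);
            outputTensors.push_back(kernelContext.GetSequenceOutputTensor(
                0,                                // output index
                sequenceIndex,                    // element to materialize
                inputTensor->GetTensorDataType(), // same element type as the input
                dimCount,
                dimensions.data(),
                gpuOutput).GetInterface().Get());
        }
    }
    else
    {
        // Plain tensor input: a single source/destination pair.
        inputTensors = { kernelContext.GetInputTensor(0).GetInterface().Get() };
        outputTensors = { kernelContext.GetOutputTensor(0).GetInterface().Get() };
    }

    // One provider-side copy covers both the tensor and the tensor-sequence case.
    ORT_THROW_IF_FAILED(executionProvider->CopyTensors(outputTensors, inputTensors));
}

Presumably MemcpyFromHost instantiates the template with gpuOutput set for device output and MemcpyToHost with it cleared, in line with the comment above about creating the output resource with the CPU allocator.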
-DML_OP_DEFINE_CREATION_FUNCTION(MemcpyFromHost, DmlOperatorMemcpy); -DML_OP_DEFINE_CREATION_FUNCTION(MemcpyToHost, DmlOperatorMemcpy); +DML_OP_DEFINE_CREATION_FUNCTION(MemcpyFromHost, DmlOperatorMemcpy); +DML_OP_DEFINE_CREATION_FUNCTION(MemcpyToHost, DmlOperatorMemcpy); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 480b59c47acf9..bd9d65804c90c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -12,8 +12,99 @@ #include #include #include + using namespace Microsoft::WRL; +#include "core/framework/TensorSeq.h" +#include "core/providers/cpu/sequence/sequence_ops.h" +#include "core/providers/cpu/tensor/concatbase.h" + +namespace onnxruntime { + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceAt); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceConstruct); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceEmpty); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceLength); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, ConcatFromSequence); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceErase); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceInsert); + +} + +namespace onnxruntime { + +ONNX_OPERATOR_KERNEL_EX( + SequenceAt, + kOnnxDomain, + 11, + kDmlExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("I", std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SequenceAt); + +ONNX_OPERATOR_KERNEL_EX( + SequenceConstruct, + kOnnxDomain, + 11, + kDmlExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()), + SequenceConstruct); + +ONNX_OPERATOR_KERNEL_EX( + SequenceEmpty, + kOnnxDomain, + 11, + kDmlExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()), + SequenceEmpty); + +ONNX_OPERATOR_KERNEL_EX( + SequenceLength, + kOnnxDomain, + 11, + kDmlExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) + .TypeConstraint("I", DataTypeImpl::GetTensorType()), + SequenceLength); + +ONNX_OPERATOR_KERNEL_EX( + SequenceErase, + kOnnxDomain, + 11, + kDmlExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) + .TypeConstraint("I", std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SequenceErase); + +ONNX_OPERATOR_KERNEL_EX( + SequenceInsert, + kOnnxDomain, + 11, + kDmlExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .TypeConstraint("S", DataTypeImpl::AllFixedSizeSequenceTensorTypes()) + .TypeConstraint("I", std::vector{ + 
DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SequenceInsert); + +} + namespace Dml { @@ -35,6 +126,21 @@ enum class SupportedTensorDataTypes : uint32_t UInt64 = 1<<13, Complex64 = 1<<14, Complex128 = 1<<15, + SequenceFloat32 = 1<<16, + SequenceUInt8 = 1<<17, + SequenceInt8 = 1<<18, + SequenceUInt16 = 1<<19, + SequenceInt16 = 1<<20, + SequenceInt32 = 1<<21, + SequenceInt64 = 1<<22, + SequenceString = 1<<23, + SequenceBool = 1<<24, + SequenceFloat16 = 1<<25, + SequenceFloat64 = 1<<26, + SequenceUInt32 = 1<<27, + SequenceUInt64 = 1<<28, + SequenceComplex64 = 1<<29, + SequenceComplex128 = 1<<30, Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32, Ints32to64 = UInt32|Int32|UInt64|Int64, Ints8to64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64, @@ -44,6 +150,8 @@ enum class SupportedTensorDataTypes : uint32_t NumericDefault = Ints8to32|Float16to32, // Only simple numbers, not bool, complex, or string. Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool, AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16|Float32|Float64|Bool, + AllSequences = SequenceUInt8|SequenceInt8|SequenceUInt16|SequenceInt16|SequenceUInt32|SequenceInt32| + SequenceUInt64|SequenceInt64|SequenceFloat16|SequenceFloat32|SequenceFloat64|SequenceBool, Ints8Bit = UInt8|Int8, Ints16Bit = UInt16|Int16, Ints32Bit = UInt32|Int32, @@ -116,6 +224,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Split13); DML_OP_EXTERN_CREATION_FUNCTION(Transpose); DML_OP_EXTERN_CREATION_FUNCTION(Tile); DML_OP_EXTERN_CREATION_FUNCTION(Concat); +DML_OP_EXTERN_CREATION_FUNCTION(ConcatFromSequence); DML_OP_EXTERN_CREATION_FUNCTION(Slice7); DML_OP_EXTERN_CREATION_FUNCTION(Slice10); DML_OP_EXTERN_CREATION_FUNCTION(Slice11); @@ -284,6 +393,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(QLinearSigmoid); DML_OP_EXTERN_QUERY_FUNCTION(Attention); constexpr static std::array typeNameListDefault = {"T"}; +constexpr static std::array typeNameListDefaultV = {"V"}; constexpr static std::array typeNameListAttention = {"T", "M"}; constexpr static std::array typeNameListTwo = { "T1", "T2" }; constexpr static std::array typeNameListLayerNorm = { "T", "U" }; @@ -315,6 +425,7 @@ constexpr static std::array supportedTypeListFloat1 constexpr static std::array supportedTypeListUInt8to64 = {SupportedTensorDataTypes::UInt8to64}; constexpr static std::array supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault }; constexpr static std::array supportedTypeListAllScalars = {SupportedTensorDataTypes::AllScalars}; +constexpr static std::array supportedTypeListAllScalarsAndSequences = {SupportedTensorDataTypes::AllScalars | SupportedTensorDataTypes::AllSequences}; constexpr static std::array supportedTypeListEyeLike = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars}; constexpr static std::array supportedTypeListBool = {SupportedTensorDataTypes::Bool}; constexpr static std::array supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault}; @@ -457,6 +568,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis. {REG_INFO( 13, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, // Adds negative axis. 
+ {REG_INFO_DYNAMIC_OUTPUTS(11, ConcatFromSequence, typeNameListDefault, supportedTypeListAllScalarsAndSequences,DmlGraphSupport::NotSupported)}, // Adds negative axis. {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, @@ -496,8 +608,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation // Data reorganization that merely changes the dimensions while keeping the data identical. {REG_INFO_COPY( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(13, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, - {REG_INFO_COPY(14, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, - {REG_INFO_COPY(16, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(14, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(16, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, @@ -763,11 +875,38 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation }; template -MLOperatorEdgeDescription EdgeDesc() +MLOperatorEdgeDescription TensorEdgeDesc() { return {MLOperatorEdgeType::Tensor, static_cast(MLTypeTraits::TensorType)}; } +template +MLOperatorEdgeDescription SequenceEdgeDesc() +{ + return {MLOperatorEdgeType::SequenceTensor, static_cast(MLTypeTraits::TensorType)}; +} + +void RegisterCpuOperatorsAsDml(onnxruntime::KernelRegistry* registry) +{ + using namespace onnxruntime; + + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; + + for (auto& function_table_entry : function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_THROW_IF_ERROR(registry->Register(std::move(info))); + } + } +} + void RegisterDmlOperators(IMLOperatorRegistry* registry) { ComPtr registryPrivate; @@ -817,20 +956,35 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) std::vector supportedTypeList; SupportedTensorDataTypes supportedTypes = information.supportedTensorDataTypes[i]; - if (bool(supportedTypes & SupportedTensorDataTypes::Float32)) edgeDescs.push_back(EdgeDesc()); - if (bool(supportedTypes & SupportedTensorDataTypes::UInt8 )) edgeDescs.push_back(EdgeDesc()); - if (bool(supportedTypes & SupportedTensorDataTypes::Int8 )) edgeDescs.push_back(EdgeDesc()); - if (bool(supportedTypes & SupportedTensorDataTypes::UInt16 )) edgeDescs.push_back(EdgeDesc()); - if (bool(supportedTypes & SupportedTensorDataTypes::Int16 )) 
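
Each entry of the function table in RegisterCpuOperatorsAsDml above is the BuildKernelCreateInfo template instantiated for one of the kernel classes declared earlier in this file. Restored as an assumption, the first two entries would look like:

static const BuildKernelCreateInfoFn function_table[] = {
    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceAt)>,
    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kDmlExecutionProvider, kOnnxDomain, 11, SequenceConstruct)>,
    // ... one entry per sequence kernel registered with ONNX_OPERATOR_KERNEL_EX above ...
};
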
edgeDescs.push_back(EdgeDesc()); - if (bool(supportedTypes & SupportedTensorDataTypes::Int32 )) edgeDescs.push_back(EdgeDesc()); - if (bool(supportedTypes & SupportedTensorDataTypes::Int64 )) edgeDescs.push_back(EdgeDesc()); - //if (bool(supportedTypes & SupportedTensorDataTypes::String )) edgeDescs.push_back(EdgeDesc()); - if (bool(supportedTypes & SupportedTensorDataTypes::Bool )) edgeDescs.push_back(EdgeDesc()); - if (bool(supportedTypes & SupportedTensorDataTypes::Float16)) edgeDescs.push_back(EdgeDesc<::MLFloat16>()); - if (bool(supportedTypes & SupportedTensorDataTypes::Float64)) edgeDescs.push_back(EdgeDesc()); - if (bool(supportedTypes & SupportedTensorDataTypes::UInt32 )) edgeDescs.push_back(EdgeDesc()); - if (bool(supportedTypes & SupportedTensorDataTypes::UInt64 )) edgeDescs.push_back(EdgeDesc()); + // Scalars + if (bool(supportedTypes & SupportedTensorDataTypes::Float32)) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::UInt8 )) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::Int8 )) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::UInt16 )) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::Int16 )) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::Int32 )) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::Int64 )) edgeDescs.push_back(TensorEdgeDesc()); + //if (bool(supportedTypes & SupportedTensorDataTypes::String )) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::Bool )) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::Float16)) edgeDescs.push_back(TensorEdgeDesc<::MLFloat16>()); + if (bool(supportedTypes & SupportedTensorDataTypes::Float64)) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::UInt32 )) edgeDescs.push_back(TensorEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::UInt64 )) edgeDescs.push_back(TensorEdgeDesc()); + // Sequences + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceFloat32)) edgeDescs.push_back(SequenceEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceUInt8 )) edgeDescs.push_back(SequenceEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceInt8 )) edgeDescs.push_back(SequenceEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceUInt16 )) edgeDescs.push_back(SequenceEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceInt16 )) edgeDescs.push_back(SequenceEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceInt32 )) edgeDescs.push_back(SequenceEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceInt64 )) edgeDescs.push_back(SequenceEdgeDesc()); + //if (bool(supportedTypes & SupportedTensorDataTypes::SequenceString )) edgeDescs.push_back(SequenceEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceBool )) edgeDescs.push_back(SequenceEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceFloat16)) edgeDescs.push_back(SequenceEdgeDesc<::MLFloat16>()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceFloat64)) edgeDescs.push_back(SequenceEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceUInt32 )) 
edgeDescs.push_back(SequenceEdgeDesc()); + if (bool(supportedTypes & SupportedTensorDataTypes::SequenceUInt64 )) edgeDescs.push_back(SequenceEdgeDesc()); typeConstraints[i].allowedTypeCount = static_cast(edgeDescs.size() - lastEdgeDescSize); lastEdgeDescSize = edgeDescs.size(); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h index 1032dfc3a24dc..120407467a834 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h @@ -12,7 +12,7 @@ class MLOperatorKernelCreationContext; #define DML_OP_EXTERN_QUERY_FUNCTION(operatorName) extern void CALLBACK Query##operatorName(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported); // A specific opset version for registration. -// e.g. +// e.g. // DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel); template class VersionedKernel : public BaseClass @@ -31,7 +31,7 @@ class VersionedKernel : public BaseClass // // Note the second parameter is the class name, but templated parameters with // commas in them break the macro, and so they are stuffed into the VA_ARGS. -// +// #define DML_OP_DEFINE_CREATION_FUNCTION(operatorName, ...)\ extern void CALLBACK Create##operatorName(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel)\ {\ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h index c5cbc59f3268b..ff70dec5b8871 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h @@ -167,4 +167,6 @@ namespace Dml uint32_t m_minDimensionCount = NchwDimensionCount; uint32_t m_guaranteedBaseOffsetAlignment = 0; }; + + using TensorSequenceDesc = std::vector; } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index fe59778d08d81..4373a4119f77b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -60,6 +60,7 @@ namespace AttrName static constexpr const char* Min = "min"; static constexpr const char* Mode = "mode"; static constexpr const char* NearestMode = "nearest_mode"; + static constexpr const char* NewAxis = "new_axis"; static constexpr const char* NoopWithEmptyAxes = "noop_with_empty_axes"; static constexpr const char* NormalizeVariance = "normalize_variance"; static constexpr const char* P = "p"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index d79b2fb4e7c2a..ba0623142ffb1 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -168,6 +168,38 @@ class MLOperatorTensorShapeDescription return ret; } + uint32_t GetSequenceInputCount(uint32_t inputIndex) const + { + Microsoft::WRL::ComPtr private_impl; + m_impl.As(&private_impl); + uint32_t inputCount = 0; + ORT_THROW_IF_FAILED(private_impl->GetSequenceInputCount(inputIndex, &inputCount)); + return inputCount; + } + + uint32_t 
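
The per-type checks in RegisterDmlOperators above map each SupportedTensorDataTypes flag to an edge description, with the element type passed as a template argument. Shown here for float32 as an assumption about the intended instantiation:

// Scalar tensor edge vs. sequence-of-tensors edge for the same element type.
if (bool(supportedTypes & SupportedTensorDataTypes::Float32))         edgeDescs.push_back(TensorEdgeDesc<float>());
if (bool(supportedTypes & SupportedTensorDataTypes::SequenceFloat32)) edgeDescs.push_back(SequenceEdgeDesc<float>());
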
GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex) const + { + Microsoft::WRL::ComPtr private_impl; + m_impl.As(&private_impl); + + uint32_t ret; + ORT_THROW_IF_FAILED(private_impl->GetSequenceInputTensorDimensionCount(inputIndex, sequenceIndex, &ret)); + return ret; + } + + std::vector GetSequenceInputTensorShape(uint32_t inputIndex, uint32_t sequenceIndex) const + { + Microsoft::WRL::ComPtr private_impl; + m_impl.As(&private_impl); + + std::vector ret; + uint32_t dimensionCount = GetSequenceInputTensorDimensionCount(inputIndex, sequenceIndex); + ret.resize(dimensionCount); + + ORT_THROW_IF_FAILED(private_impl->GetSequenceInputTensorShape(inputIndex, sequenceIndex, dimensionCount, ret.data())); + return ret; + } + bool HasOutputShapeDescription() const noexcept { return m_impl->HasOutputShapeDescription(); @@ -622,6 +654,39 @@ class MLShapeInferenceContext : public MLOperatorAttributes return ret; } + uint32_t GetSequenceInputCount(uint32_t inputIndex) const + { + Microsoft::WRL::ComPtr private_impl; + m_impl.As(&private_impl); + uint32_t inputCount = 0; + ORT_THROW_IF_FAILED(private_impl->GetSequenceInputCount(inputIndex, &inputCount)); + return inputCount; + } + + uint32_t GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex) const + { + Microsoft::WRL::ComPtr private_impl; + m_impl.As(&private_impl); + + uint32_t ret; + ORT_THROW_IF_FAILED(private_impl->GetSequenceInputTensorDimensionCount(inputIndex, sequenceIndex, &ret)); + return ret; + } + + std::vector GetSequenceInputTensorShape(uint32_t inputIndex, uint32_t sequenceIndex) const + { + Microsoft::WRL::ComPtr private_impl; + m_impl.As(&private_impl); + + std::vector ret; + uint32_t dimensionCount = GetSequenceInputTensorDimensionCount(inputIndex, sequenceIndex); + ret.resize(dimensionCount); + + ORT_THROW_IF_FAILED(private_impl->GetSequenceInputTensorShape(inputIndex, sequenceIndex, dimensionCount, ret.data())); + return ret; + } + + void SetOutputTensorShape(uint32_t outputIndex, const std::vector& outputDimensions) { ORT_THROW_IF_FAILED(m_impl->SetOutputTensorShape(outputIndex, static_cast(outputDimensions.size()), outputDimensions.data())); @@ -691,6 +756,56 @@ class MLOperatorKernelContext return m_impl; } + bool IsSequenceInputTensor(uint32_t inputIndex) const + { + Microsoft::WRL::ComPtr operatorKernelContext; + m_impl.As(&operatorKernelContext); + return operatorKernelContext->IsSequenceInputTensor(inputIndex); + } + + uint32_t GetSequenceInputCount(uint32_t inputIndex) const + { + Microsoft::WRL::ComPtr operatorKernelContext; + m_impl.As(&operatorKernelContext); + uint32_t inputCount = 0; + ORT_THROW_IF_FAILED(operatorKernelContext->GetSequenceInputCount(inputIndex, &inputCount)); + return inputCount; + } + + MLOperatorTensor GetSequenceInputTensor(uint32_t inputIndex, uint32_t sequenceIndex) const + { + Microsoft::WRL::ComPtr operatorKernelContext; + m_impl.As(&operatorKernelContext); + + Microsoft::WRL::ComPtr tensor; + ORT_THROW_HR_IF(E_INVALIDARG, !operatorKernelContext->IsSequenceInputTensor(inputIndex)); + ORT_THROW_IF_FAILED(operatorKernelContext->GetSequenceInputTensor(inputIndex, sequenceIndex, &tensor)); + return tensor.Get(); + } + + MLOperatorTensor GetSequenceOutputTensor( + uint32_t outputIndex, + uint32_t sequenceIndex, + MLOperatorTensorDataType dataType, + uint32_t dimensions, + const uint32_t* dimensionSizes, + bool gpuOutput) const + { + Microsoft::WRL::ComPtr operatorKernelContext; + m_impl.As(&operatorKernelContext); + + Microsoft::WRL::ComPtr 
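
MLShapeInferenceContext now exposes the sequence accessors added above. As a rough sketch of how they compose for a ConcatFromSequence-style operator (the helper function, its pre-normalized axis, and the attribute handling are assumptions, not code from this patch):

std::vector<uint32_t> InferConcatFromSequenceShape(
    const MLShapeInferenceContext& shapeInfo,
    uint32_t axis,     // assumed already normalized to a non-negative index
    bool newAxis)      // the operator's new_axis attribute
{
    const uint32_t elementCount = shapeInfo.GetSequenceInputCount(0);
    std::vector<uint32_t> outputShape = shapeInfo.GetSequenceInputTensorShape(0, 0);

    if (newAxis)
    {
        // new_axis == 1: stack the elements along a new dimension of size elementCount.
        outputShape.insert(outputShape.begin() + axis, elementCount);
    }
    else
    {
        // new_axis == 0: concatenate along the existing axis.
        for (uint32_t i = 1; i < elementCount; ++i)
        {
            outputShape[axis] += shapeInfo.GetSequenceInputTensorShape(0, i)[axis];
        }
    }
    return outputShape;
}
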
tensor; + ORT_THROW_IF_FAILED(operatorKernelContext->GetSequenceOutputTensor( + outputIndex, + sequenceIndex, + dataType, + dimensions, + dimensionSizes, + gpuOutput, + &tensor)); + return tensor.Get(); + } + MLOperatorTensor GetInputTensor(uint32_t inputIndex) const { Microsoft::WRL::ComPtr tensor; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index afcfc587ecade..bcf140cc4b02a 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -11,8 +11,8 @@ struct DML_OUTPUT_GRAPH_EDGE_DESC; struct DML_INTERMEDIATE_GRAPH_EDGE_DESC; // Either nodesAsOpDesc or nodesAsIDMLOperator is present. -// 1) Operator kernels which implement operators using only a single DML operator will pass a DML_OPERATOR_DESC. -// These kernels pass DML_OPERATOR_DESC, because while building Dml graph (inside FusedGraphKernel.cpp) we can change the +// 1) Operator kernels which implement operators using only a single DML operator will pass a DML_OPERATOR_DESC. +// These kernels pass DML_OPERATOR_DESC, because while building Dml graph (inside FusedGraphKernel.cpp) we can change the // the flag of constant inputs to DML_TENSOR_FLAG_OWNED_BY_DML. // 2) Operator kernels which implement operators using DMLX graph, they will pass IDMLOperator and won't be able // to use DML_TENSOR_FLAG_OWNED_BY_DML. @@ -34,22 +34,44 @@ struct MLOperatorGraphDesc interface __declspec(uuid("aa2173bb-6684-4de8-abf2-9acbdf88b426")) -IMLOperatorShapeInferenceContextPrivate : public IMLOperatorShapeInferenceContext +IMLOperatorShapeInferenceContextPrivate : public IMLOperatorShapeInferenceContext { STDMETHOD(GetConstantInputTensor)( - uint32_t inputIndex, + uint32_t inputIndex, _Outptr_ IMLOperatorTensor** tensor ) const noexcept PURE; + + //! Gets the number of dimensions of a tensor output of the operator. + STDMETHOD(GetSequenceInputCount)( + uint32_t inputIndex, + _Out_ uint32_t* inputCount + ) const noexcept PURE; + + //! Gets the number of dimensions of a tensor output of the operator. + STDMETHOD(GetSequenceInputTensorDimensionCount)( + uint32_t inputIndex, + uint32_t sequenceIndex, + _Out_ uint32_t* dimensionCount + ) const noexcept PURE; + + //! Gets the sizes of dimensions of an input tensor of the operator. + //! Returns an error if the input at the specified index is not a tensor. + STDMETHOD(GetSequenceInputTensorShape)( + uint32_t inputIndex, + uint32_t sequenceIndex, + uint32_t dimensionCount, + _Out_writes_(dimensionCount) uint32_t* dimensions + ) const noexcept PURE; }; interface __declspec(uuid("63bff199-0203-43c7-86c4-f442a599df4c")) -IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContext +IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContext { STDMETHOD(GetConstantInputTensor)( - uint32_t inputIndex, + uint32_t inputIndex, _Outptr_ IMLOperatorTensor** tensor ) const noexcept PURE; - + STDMETHOD_(bool, IsDmlGraphNode)() const noexcept PURE; STDMETHOD(SetDmlOperator)( @@ -81,7 +103,7 @@ IMLOperatorSupportQueryContextPrivate : public IMLOperatorAttributes1 //! Gets the number of outputs to the operator. STDMETHOD_(uint32_t, GetOutputCount)() const noexcept PURE; - //! Returns true if an input to the operator is valid. + //! Returns true if an input to the operator is valid. //! 
This always returns true except for optional inputs and invalid indices. STDMETHOD_(bool, IsInputValid)(uint32_t inputIndex) const noexcept PURE; @@ -91,13 +113,13 @@ IMLOperatorSupportQueryContextPrivate : public IMLOperatorAttributes1 //! Gets the description of the specified input edge of the operator. STDMETHOD(GetInputEdgeDescription)( - uint32_t inputIndex, + uint32_t inputIndex, _Out_ MLOperatorEdgeDescription* edgeDescription ) const noexcept PURE; //! Gets the description of the specified output edge of the operator. STDMETHOD(GetOutputEdgeDescription)( - uint32_t outputIndex, + uint32_t outputIndex, _Out_ MLOperatorEdgeDescription* edgeDescription ) const noexcept PURE; }; @@ -128,6 +150,78 @@ IMLOperatorRegistryPrivate : public IUnknown ) const noexcept PURE; }; +//! \interface IMLOperatorTensorShapeDescription1 +//! \brief Represents the set of input and output tensor shapes of an operator. +//! This interface is called by the factory objects registered to create kernels. +//! It is available to these factory objects unless corresponding kernels are +//! registered using the MLOperatorKernelOptions::AllowDynamicInputShapes flag. +interface DECLSPEC_UUID("440DA47C-018B-41F6-80A4-13FCF0544F37") DECLSPEC_NOVTABLE +IMLOperatorTensorShapeDescriptionPrivate : IUnknown +{ + //! Gets the number of dimensions of a tensor output of the operator. + STDMETHOD(GetSequenceInputCount)( + uint32_t inputIndex, + _Out_ uint32_t* inputCount + ) const noexcept PURE; + + //! Gets the number of dimensions of a tensor input of the operator. + //! Returns an error if the input at the specified index is not a tensor. + STDMETHOD(GetSequenceInputTensorDimensionCount)( + uint32_t inputIndex, + uint32_t sequenceIndex, + _Out_ uint32_t* dimensionCount + ) const noexcept PURE; + + //! Gets the sizes of dimensions of an input tensor of the operator. + //! Returns an error if the input at the specified index is not a tensor. + STDMETHOD(GetSequenceInputTensorShape)( + uint32_t inputIndex, + uint32_t sequenceIndex, + uint32_t dimensionCount, + _Out_writes_(dimensionCount) uint32_t* dimensions + ) const noexcept PURE; + +}; + +//! \interface IMLOperatorKernelContext +//! \brief Provides information about an operator's usage while kernels are being computed. +interface DECLSPEC_UUID("AFEED22E-B1B4-4DCE-BE09-27B95B7AD5AF") DECLSPEC_NOVTABLE +IMLOperatorKernelContextPrivate : IUnknown +{ + //! Gets the input tensor of the operator at the specified index. + //! This sets tensor to nullptr for optional inputs which do not exist. + //! Returns an error if the input at the specified index is not a tensor. + STDMETHOD(GetSequenceInputTensor)( + uint32_t inputIndex, + uint32_t sequenceIndex, + _COM_Outptr_result_maybenull_ IMLOperatorTensor** tensor + ) const noexcept PURE; + + //! Gets the output tensor of the operator at the specified index. + //! This sets tensor to nullptr for optional outputs which do not exist. + //! Returns an error if the output at the specified index is not a tensor. + STDMETHOD(GetSequenceOutputTensor)( + uint32_t outputIndex, + uint32_t sequenceIndex, + MLOperatorTensorDataType dataType, + uint32_t dimensions, + const uint32_t* dimensionSizes, + bool gpuOutput, + _COM_Outptr_result_maybenull_ IMLOperatorTensor** tensor + ) const noexcept PURE; + + //! Gets the input tensor of the operator at the specified index. + //! This sets tensor to nullptr for optional inputs which do not exist. + //! Returns an error if the input at the specified index is not a tensor. 
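
On the kernel side, the MLOperatorKernelContext wrappers shown above give a simple iteration pattern over a sequence input. A minimal sketch of a pass-through kernel body (the function itself and the shape/data-type accessors on MLOperatorTensor are assumptions for illustration):

void CopySequenceInputToOutput(const MLOperatorKernelContext& context)
{
    if (!context.IsSequenceInputTensor(0))
    {
        return;  // tensor (non-sequence) path omitted in this sketch
    }

    const uint32_t elementCount = context.GetSequenceInputCount(0);
    for (uint32_t i = 0; i < elementCount; ++i)
    {
        MLOperatorTensor input = context.GetSequenceInputTensor(0, i);

        std::vector<uint32_t> shape = input.GetShape();
        MLOperatorTensor output = context.GetSequenceOutputTensor(
            0,                                        // outputIndex
            i,                                        // sequenceIndex
            input.GetTensorDataType(),
            static_cast<uint32_t>(shape.size()),
            shape.data(),
            true /*gpuOutput*/);

        // A real kernel would now record a GPU copy (or DML dispatch) from input to output.
        (void)output;
    }
}
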
+ STDMETHOD(GetSequenceInputCount)( + uint32_t inputIndex, + _Out_ uint32_t* inputCount + ) const noexcept PURE; + + //! Returns whether the tensor at inputIndex is a sequence tensor or not + STDMETHOD_(bool, IsSequenceInputTensor)(uint32_t inputIndex) const = 0; +}; + // Declare private enum MLOperatorAttributeType::Tensor. // // enum class MLOperatorAttributeType : uint32_t diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index cd11eb4a31b85..327d98a3d864c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -306,6 +306,9 @@ struct IShapeInformationAdapter { virtual uint32_t GetInputTensorDimensionCount(uint32_t inputIndex) const = 0; virtual std::vector GetInputTensorShape(uint32_t inputIndex) const = 0; + virtual uint32_t GetSequenceInputCount(uint32_t inputIndex) const = 0; + virtual uint32_t GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex) const = 0; + virtual std::vector GetSequenceInputTensorShape(uint32_t inputIndex, uint32_t sequenceIndex) const = 0; virtual ~IShapeInformationAdapter() {} }; @@ -341,6 +344,9 @@ struct ShapeInformationAdapter : IShapeInformationAdapter virtual uint32_t GetInputTensorDimensionCount(uint32_t inputIndex) const { return m_informationSource.GetInputTensorDimensionCount(inputIndex); } virtual std::vector GetInputTensorShape(uint32_t inputIndex) const { return m_informationSource.GetInputTensorShape(inputIndex); } + virtual uint32_t GetSequenceInputCount(uint32_t inputIndex) const { return m_informationSource.GetSequenceInputCount(inputIndex); } + virtual uint32_t GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex) const { return m_informationSource.GetSequenceInputTensorDimensionCount(inputIndex, sequenceIndex); } + virtual std::vector GetSequenceInputTensorShape(uint32_t inputIndex, uint32_t sequenceIndex) const { return m_informationSource.GetSequenceInputTensorShape(inputIndex, sequenceIndex); } virtual ~ShapeInformationAdapter() {} InformationSourceType& m_informationSource; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index a673da6176838..7784248bd87fa 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -245,6 +245,7 @@ namespace OperatorHelper static const int sc_sinceVer_Squeeze = 11; static const int sc_sinceVer_TopK = 11; static const int sc_sinceVer_Unsqueeze = 11; + static const int sc_sinceVer_ConcatFromSequence = 11; } // namespace OnnxOperatorSet11 namespace OnnxOperatorSet12 diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 8c86fa51dd4b2..54bc21d75906a 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -56,6 +56,7 @@ class Memcpy final : public OpKernel { "Memcpy rocm: unable to get an allocator."); } auto X_size = X->Size(); + Y->Reserve(X_size); for (size_t i = 0; i < X_size; ++i) { const Tensor& source_tensor = X->Get(i); std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); diff --git 
a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 4145e4dcbb162..727013359d30f 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -148,7 +148,7 @@ struct OpKernelInfo; struct PrimitiveDataTypeBase; struct Tensor; struct SparseTensor; -struct TensorSeq; +class TensorSeq; class SessionState; class If; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index b772fe95d6ecc..30be10ea7e15f 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -563,12 +563,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { return g_host_cpu.AttentionBase__CheckInputs(this, input_shape, weights_shape, bias_shape, - mask_index, past, extra_add_qk, parameters, + mask_index, past, relative_position_bias, parameters, max_threads_per_block, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, const Tensor* past, int batch_size, int head_size, diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 1fbe938c2ae56..b8968f087d600 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -847,6 +847,9 @@ struct ProviderHost { virtual void TensorSeq__SetType(TensorSeq* p, MLDataType data_type) = 0; virtual size_t TensorSeq__Size(const TensorSeq* p) noexcept = 0; virtual const Tensor& TensorSeq__Get(const TensorSeq* p, size_t i) = 0; + virtual const OrtValue& TensorSeq__GetAt(const TensorSeq* p, size_t i) = 0; + virtual void TensorSeq__Add(TensorSeq* p, const OrtValue& tensor) = 0; + virtual void TensorSeq__Add(TensorSeq* p, OrtValue&& tensor) = 0; virtual void TensorSeq__Add(TensorSeq* p, Tensor&& tensor) = 0; virtual void TensorSeq__Reserve(TensorSeq* p, size_t capacity) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index afe3a35f7c787..67c6cac4a0965 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1060,11 +1060,15 @@ struct SparseTensor final { #endif // TensorSeq -struct TensorSeq final { +class TensorSeq final { +public: MLDataType DataType() const noexcept { return g_host->TensorSeq__DataType(this); } void SetType(MLDataType elem_type) { g_host->TensorSeq__SetType(this, elem_type); } size_t Size() const noexcept { return g_host->TensorSeq__Size(this); } const Tensor& Get(size_t i) const { return g_host->TensorSeq__Get(this, i); } + const OrtValue& GetAt(size_t i) const { return g_host->TensorSeq__GetAt(this, i); } + void Add(const OrtValue& tensor) { g_host->TensorSeq__Add(this, tensor); } + void Add(OrtValue&& tensor) { g_host->TensorSeq__Add(this, std::move(tensor)); } void Add(Tensor&& tensor) { g_host->TensorSeq__Add(this, std::move(tensor)); } void Reserve(size_t capacity) { 
g_host->TensorSeq__Reserve(this, capacity); } }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 265823b09a913..5c7afe1453b5a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1803,43 +1803,29 @@ static ORT_STATUS_PTR OrtCreateValueImplSeqHelperMap(const OrtValue* const* in, } #endif -static ORT_STATUS_PTR OrtCreateValueImplSeqHelperTensor(const Tensor& tensor, Tensor& out) { - auto data_type = tensor.DataType(); - ORT_API_RETURN_IF_ERROR(CreateTensorImplForSeq(data_type, - tensor.Shape().GetDims().data(), tensor.Shape().NumDimensions(), - out)); - size_t num_elements = narrow(tensor.Shape().Size()); - ORT_API_RETURN_IF_ERROR(c_api_internal::PopulateTensorWithData(out, tensor.IsDataTypeString(), - tensor.DataRaw(), num_elements, data_type->Size())); - return nullptr; -} - static ORT_STATUS_PTR OrtCreateValueImplSeqHelper(const OrtValue* const* in, size_t num_values, _Outptr_ OrtValue** out) { using namespace c_api_internal; - std::vector tensors; - tensors.resize(num_values); - auto dtype = static_cast(in[0])->Get().DataType(); + auto dtype = in[0]->Get().DataType(); + auto seq_ptr = std::make_unique(dtype); + seq_ptr->Reserve(num_values); for (size_t idx = 0; idx < num_values; ++idx) { ORT_ENFORCE(in[idx]->IsTensor(), "Expecting all elements to be tensors. Got: ", DataTypeImpl::ToString(in[idx]->Type())); - auto& one_tensor = static_cast(in[idx])->Get(); - auto tensor_elem_type = one_tensor.DataType(); + auto tensor_elem_type = in[idx]->Get().DataType(); // sequences must have tensors of the same data type - if (idx > 0 && (tensor_elem_type != dtype)) { + if (tensor_elem_type != dtype) { return OrtApis::CreateStatus(ORT_FAIL, "Sequences must have tensors of the same data type. 
There was at least one tensor in the input that was different."); } - ORT_API_RETURN_IF_ERROR(OrtCreateValueImplSeqHelperTensor(one_tensor, tensors[idx])); + seq_ptr->Add(*in[idx]); } // create OrtValue with this vector auto value = std::make_unique(); auto ml_type = DataTypeImpl::GetType(); - auto seq_ptr = std::make_unique(dtype); - seq_ptr->SetElements(std::move(tensors)); value->Init(seq_ptr.release(), ml_type, ml_type->GetDeleteFunc()); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 18079d3ee256d..81510120f46a9 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -972,6 +972,9 @@ struct ProviderHostImpl : ProviderHost { void TensorSeq__SetType(TensorSeq* p, MLDataType data_type) override { p->SetType(data_type); } size_t TensorSeq__Size(const TensorSeq* p) noexcept override { return p->Size(); } const Tensor& TensorSeq__Get(const TensorSeq* p, size_t i) override { return p->Get(i); } + const OrtValue& TensorSeq__GetAt(const TensorSeq* p, size_t i) override { return p->GetAt(i); } + void TensorSeq__Add(TensorSeq* p, const OrtValue& tensor) override { p->Add(tensor); } + void TensorSeq__Add(TensorSeq* p, OrtValue&& tensor) override { p->Add(std::move(tensor)); } void TensorSeq__Add(TensorSeq* p, Tensor&& tensor) override { p->Add(std::move(tensor)); } void TensorSeq__Reserve(TensorSeq* p, size_t capacity) override { p->Reserve(capacity); } diff --git a/onnxruntime/core/session/standalone_op_invoker.cc b/onnxruntime/core/session/standalone_op_invoker.cc index 1020e07f5ebc8..48cfdd60d2eba 100644 --- a/onnxruntime/core/session/standalone_op_invoker.cc +++ b/onnxruntime/core/session/standalone_op_invoker.cc @@ -4,6 +4,7 @@ #include "core/session/inference_session.h" #include "core/framework/kernel_registry.h" #include "core/framework/error_code_helper.h" +#include "core/framework/TensorSeq.h" #include "core/session/ort_apis.h" #include diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 806d01521c62a..ff4c2837a647a 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -495,26 +495,24 @@ static void CreateSequenceOfTensors(AllocatorPtr alloc, const std::string& name_ throw std::runtime_error("Input is not of sequence type"); } + // set the seq type + MLDataType seq_dtype = OrtTypeInfo::ElementTypeFromProto( + static_cast(type_proto.sequence_type().elem_type().tensor_type().elem_type())); + auto p_seq_tensors = std::make_unique(seq_dtype); + // populate the seq - std::vector tensors; auto list_size = PyList_Size(pylist_obj); if (list_size > 0) { - tensors.resize(list_size); for (Py_ssize_t i = 0; i < list_size; ++i) { auto* py_obj = PyList_GetItem(pylist_obj, i); if (!PyObjectCheck_NumpyArray(py_obj)) { throw std::runtime_error("CreateSequenceOfTensors: Input is not a tensor"); } auto p_tensor = CreateTensor(alloc, name_input, reinterpret_cast(py_obj)); - tensors[i] = std::move(*p_tensor); + p_seq_tensors->Add(std::move(*p_tensor)); } } - // set the seq type - MLDataType seq_dtype = OrtTypeInfo::ElementTypeFromProto( - static_cast(type_proto.sequence_type().elem_type().tensor_type().elem_type())); - auto p_seq_tensors = std::make_unique(seq_dtype); - p_seq_tensors->SetElements(std::move(tensors)); auto ml_tensor_sequence = DataTypeImpl::GetType(); p_mlvalue->Init(p_seq_tensors.release(), ml_tensor_sequence, diff --git 
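
The C API helper above now builds the sequence directly out of OrtValues instead of deep-copying each tensor. The same Reserve-then-Add pattern, sketched standalone (the helper function and its arguments are assumptions; TensorSeq's constructor, Reserve, and Add(const OrtValue&) are the members shown in this patch):

std::unique_ptr<TensorSeq> MakeTensorSequence(const std::vector<OrtValue>& elements,
                                              MLDataType element_type) {
  auto seq = std::make_unique<TensorSeq>(element_type);
  seq->Reserve(elements.size());
  for (const OrtValue& v : elements) {
    seq->Add(v);  // stores the OrtValue itself; no per-element tensor copy
  }
  return seq;
}
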
a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 2d62d2e3d70a8..f61fe7b8788ef 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -198,7 +198,7 @@ py::object AddNonTensor(const OrtValue& val, py::list py_list; for (const auto& rtensor : seq_tensors) { py::object obj; - GetPyObjFromTensor(rtensor, obj, data_transfer_manager, mem_cpy_to_host_functions); + GetPyObjFromTensor(rtensor.Get(), obj, data_transfer_manager, mem_cpy_to_host_functions); py_list.append(obj); } // XToolChain kills the build diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 2b5c53b867257..bf8dd931c6927 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -543,6 +543,7 @@ def get_ort_environment_variables(): # Environment variables might impact ORT performance on transformer models. Note that they are for testing only. env_names = [ "ORT_DISABLE_FUSED_ATTENTION", + "ORT_ENABLE_FUSED_CAUSAL_ATTENTION", "ORT_DISABLE_FUSED_CROSS_ATTENTION", "ORT_DISABLE_TRT_FLASH_ATTENTION", "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 245ea9322ad61..342d43306e699 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -337,7 +337,7 @@ def create_attention_node( # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights. if self.use_multi_head_attention: if add_qk_str is not None: - logger.debug("MultiHeadAttention does not support extra_add_qk: cannot fuse the attention.") + logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.") return None attention_inputs = [ diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py index f5de69e8f0524..7eec8575f79a4 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py @@ -24,7 +24,6 @@ from benchmark_helper import Precision from float16 import float_to_float16_max_diff -from fusion_options import AttentionMaskFormat from io_binding_helper import IOBindingHelper from onnx_model import OnnxModel from torch_onnx_export_helper import torch_onnx_export @@ -188,6 +187,7 @@ def get_dummy_inputs( input_ids_dtype: torch.dtype = torch.int32, position_ids_dtype: torch.dtype = torch.int32, attention_mask_dtype: torch.dtype = torch.int32, + left_side_padding: bool = True, ) -> Gpt2Inputs: """Create random inputs for GPT2 model. Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors. @@ -218,9 +218,14 @@ def get_dummy_inputs( dtype=attention_mask_dtype, device=device, ) + if total_sequence_length >= 2: - padding_position = random.randint(0, total_sequence_length - 1) # test input with padding. 
- attention_mask[:, padding_position] = 0 + for i in range(batch_size): + padding_length = random.randint(0, total_sequence_length - 1) + if left_side_padding: + attention_mask[i, :padding_length] = 0 + else: # right side padding + attention_mask[i, total_sequence_length - padding_length :] = 0 # Deduce position_ids from attention mask position_ids = None @@ -517,11 +522,6 @@ def optimize_onnx( optimization_options = FusionOptions("gpt2") - if is_float16 and stage == 1: - # For init_decoder, enable mask index to use fused causal cuda kernel. - # Potentially, we can add other optimization like unpad for effective transformer - optimization_options.attention_mask_format = AttentionMaskFormat.MaskIndexEnd - # TODO(hasesh): Investigate parity issue for GPT-2 fp16 when SkipLayerNormalization # is enabled if is_float16: @@ -841,6 +841,7 @@ def test_parity( input_ids_dtype=input_ids_dtype, position_ids_dtype=position_ids_dtype, attention_mask_dtype=attention_mask_dtype, + left_side_padding=True, ) outputs = Gpt2Helper.pytorch_inference(model, dummy_inputs) if use_io_binding: @@ -868,6 +869,7 @@ def test_parity( max_abs_diff_list.append(max_abs_diff) if is_all_close: passed_test_cases += 1 + if is_top1_matched: top1_matched_cases += 1 top1_matched_cases_per_run[run_id] += 1 diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 42fd4d5909a30..276a9428ecf72 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -327,7 +327,7 @@ def match_parent_path( self, node, parent_op_types, - parent_input_index, + parent_input_index=None, output_name_to_node=None, return_indice=None, ): @@ -347,7 +347,8 @@ def match_parent_path( Returns: parents: a list of matched parent node. 
""" - assert len(parent_input_index) == len(parent_op_types) + if parent_input_index is not None: + assert len(parent_input_index) == len(parent_op_types) if output_name_to_node is None: output_name_to_node = self.output_name_to_node() @@ -358,16 +359,19 @@ def match_parent_path( matched_parent = self.match_parent( current_node, op_type, - parent_input_index[i], + parent_input_index[i] if parent_input_index is not None else None, output_name_to_node, exclude=[], return_indice=return_indice, ) if matched_parent is None: - logger.debug( - f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}", - stack_info=True, - ) + if parent_input_index is not None: + logger.debug( + f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}", + stack_info=True, + ) + else: + logger.debug(f"Failed to match index={i} op_type={op_type}", stack_info=True) return None matched_parents.append(matched_parent) diff --git a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py index dc8f6810914a7..85e510a828990 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py @@ -172,8 +172,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add = k_nodes[-2] matmul = k_nodes[-1] - extra_add_qk_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0]) - if extra_add_qk_nodes is None: + relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0]) + if relative_position_bias_nodes is None: return if matmul.input[0] == root_input: @@ -189,7 +189,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.hidden_size, root_input, attention_last_node.output[0], - extra_add_qk_nodes[0].input[0], + relative_position_bias_nodes[0].input[0], ) if new_node is None: return diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index fb1d8fcfe451a..daeec7a64c2f8 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -59,7 +59,7 @@ static void RunAttentionTest( const bool disable_cuda = false, const bool disable_rocm = false, std::vector qkv_sizes = {}, - const std::vector& extra_add_data = {}, + const std::vector& relative_position_bias_data = {}, int kv_sequence_length = 0, bool past_present_share_buffer = false, bool use_scale = false) { @@ -199,12 +199,12 @@ static void RunAttentionTest( } } - std::vector extra_add_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; - if (extra_add_data.size() > 0) { + std::vector relative_position_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; + if (relative_position_bias_data.size() > 0) { if (use_float16) { - tester.AddInput("extra_add_qk", extra_add_data_dims, ToFloat16(extra_add_data)); + tester.AddInput("relative_position_bias", relative_position_bias_data_dims, ToFloat16(relative_position_bias_data)); } else { - tester.AddInput("extra_add_qk", extra_add_data_dims, extra_add_data); + tester.AddInput("relative_position_bias", relative_position_bias_data_dims, relative_position_bias_data); } } else { if (use_float16) { @@ -264,7 +264,7 @@ static void RunAttentionTest( const bool disable_cuda = false, const bool disable_rocm = false, const std::vector qkv_sizes = {}, - const std::vector& extra_add_data = {}, + 
const std::vector& relative_position_bias_data = {}, int kv_sequence_length = 0, bool past_present_share_buffer = false, bool use_scale = false) { @@ -272,13 +272,13 @@ static void RunAttentionTest( batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data, + disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias_data, kv_sequence_length, past_present_share_buffer, use_scale); RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data, + disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias_data, kv_sequence_length, past_present_share_buffer, use_scale); } @@ -390,7 +390,7 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { 0, false, false, disable_rocm, qkv_sizes); } -TEST(AttentionTest, AttentionBatch1ExtraAdd) { +TEST(AttentionTest, AttentionBatch1RelativePositionBias) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -414,7 +414,7 @@ TEST(AttentionTest, AttentionBatch1ExtraAdd) { std::vector mask_index_data = {2L}; - std::vector extra_add_qk = { + std::vector relative_position_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; std::vector output_data = { @@ -427,10 +427,10 @@ TEST(AttentionTest, AttentionBatch1ExtraAdd) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_qk); + 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias); } -TEST(AttentionTest, AttentionBatch2ExtraAdd) { +TEST(AttentionTest, AttentionBatch2RelativePositionBias) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -456,7 +456,7 @@ TEST(AttentionTest, AttentionBatch2ExtraAdd) { std::vector mask_index_data = {2L, 2L}; - std::vector extra_add_qk = { + std::vector relative_position_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f, 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; @@ -472,7 +472,7 @@ TEST(AttentionTest, AttentionBatch2ExtraAdd) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_qk); + 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias); } TEST(AttentionTest, AttentionBatch1_Float16) { @@ -930,7 +930,8 @@ TEST(AttentionTest, Causal_EmptyPastState) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "1"}}}; + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}}}; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, 
output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, &past_data, &present_data); @@ -941,7 +942,8 @@ TEST(AttentionTest, Causal_EmptyPastState) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}}; + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}}}; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, &past_data, &present_data); @@ -952,7 +954,8 @@ TEST(AttentionTest, Causal_EmptyPastState) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}}; + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}}}; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, &past_data, &present_data); @@ -1709,7 +1712,7 @@ TEST(AttentionTest, AttentionWithNormFactor) { use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/, false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, {} /*qkv_sizes*/, - {} /*extra_add_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, + {} /*relative_position_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, true /*use_scale*/); } diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index eddf9ea4c7c4e..263a3495758ff 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -140,7 +140,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( @@ -154,7 +154,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( @@ -168,7 +168,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "1"}, + 
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( @@ -183,7 +183,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( @@ -198,7 +198,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( diff --git a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc index cf14a7f9918c6..5257fbdb08809 100644 --- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc @@ -278,7 +278,7 @@ TEST(QOrderedTest, Attention_WithData_ROW_ORDER) { test_qorder.AddInput("scale_values_gemm", {}, {attn_out_scale}, true); test_qorder.AddInput("mask_index", {batch_size, sequence_len}, input_mask.data(), input_mask.size()); test_qorder.AddOptionalInputEdge(); // past - test_qorder.AddOptionalInputEdge(); // extra_add + test_qorder.AddOptionalInputEdge(); // relative_position_bias test_qorder.AddOutput("output", {batch_size, sequence_len, hidden_size}, attn_out_q8.data(), attn_out_q8.size()); diff --git a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc index 7722291bee653..ba0299e4f3808 100644 --- a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc +++ b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc @@ -10,9 +10,9 @@ namespace onnxruntime { namespace test { static void RunRelativePositionBiasTest( - const std::vector& bias_table, // Shape = [num_buckets, num_heads] - const std::vector& sequence_length, // Shape = [1] - const std::vector& output_data, // Shape = [1, num_heads, sequence_length, sequence_length] + const std::vector& bias_table, // Shape = [num_buckets, num_heads] + const std::vector& sequence_length, // Shape = [1] + const std::vector& output_data, // Shape = [1, num_heads, sequence_length, sequence_length] int max_distance, int num_buckets, int num_heads, @@ -155,5 +155,264 @@ TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16_No_Bidirectional) { true); } +/***************Following scripts is used to generate test data, for your reference************* +import torch + +batch_size = 2 +num_heads = 2 +seq_len = 3 +head_size = 4 +D = 8 + +def dim_string_of(tensor): + return "{" + ", ".join([str(d) for d in tensor.shape]) + "}" + +def value_string_of(tensor): + arr = tensor.flatten().numpy() + lines = ["f, ".join([str(v) for v in arr[i : min(i+8, arr.size)]]) for i in range(0, arr.size, 8)] + return "{\n " + "f,\n ".join(lines) + "f}" + +def print_tensor(name, 
tensor): + print(f"const std::vector {name}_dim = {dim_string_of(tensor)};") + print(f"const std::vector {name} = {value_string_of(tensor)};") + +torch.manual_seed(0) +query_layer = torch.rand(batch_size, seq_len, num_heads * head_size) +query_bias = torch.rand(num_heads * head_size) +rel_pos = torch.rand(1, num_heads, seq_len, seq_len) +weight = torch.rand(head_size, D) +bias = torch.rand(D) +eco_a = torch.rand(1, num_heads, 1, 1) + +qw = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) +gate_u,gate_r = torch.sigmoid( + (torch.matmul(qw, weight) + bias).view(batch_size, num_heads, seq_len,2, D//2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) +gate_u_1 = gate_u * (gate_r * eco_a - 1.0) + 2.0 +output = gate_u_1 * rel_pos + +# output for test case +print(f"const int batch_size = {batch_size};") +print(f"const int num_heads = {num_heads};") +print(f"const int seq_len = {seq_len};") +print(f"const int head_size = {head_size};") +print(f"const int D = {D};") + +print_tensor("query_layer", query_layer) +print_tensor("query_bias", query_bias) +print_tensor("rel_pos", rel_pos) +print_tensor("weight", weight) +print_tensor("bias", bias) +print_tensor("eco_a", eco_a) +print_tensor("output", output) +****************/ + +// .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T") +// .Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T") +// .Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T") +// .Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T") +// .Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T") +// .Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T") +// .Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T") +static void RunGatedRelativePositionBiasTest( + const std::vector& query_layer, + const std::vector& query_bias, + const std::vector& rel_pos, + const std::vector& weight, + const std::vector& bias, + const std::vector& eco_a, + const std::vector& output, + int batch_size, + int seq_len, + int num_heads, + int head_size, + int D, + bool use_float16 = false) { + int min_cuda_architecture = use_float16 ? 
530 : 0; + + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + if (enable_cuda) { + OpTester tester("GatedRelativePositionBias", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + + std::vector query_layer_dims = {batch_size, seq_len, num_heads * head_size}; + std::vector query_bias_dims = {num_heads * head_size}; + std::vector rel_pos_dims = {1, num_heads, seq_len, seq_len}; + std::vector weight_dims = {head_size, D}; + std::vector bias_dims = {D}; + std::vector eco_a_dims = {1, num_heads, 1, 1}; + std::vector output_dims = {batch_size, num_heads, seq_len, seq_len}; + + if (use_float16) { + tester.AddInput("query_layer", query_layer_dims, ToFloat16(query_layer)); + tester.AddInput("query_bias", query_bias_dims, ToFloat16(query_bias)); + tester.AddInput("rel_pos", rel_pos_dims, ToFloat16(rel_pos)); + tester.AddInput("weight", weight_dims, ToFloat16(weight)); + tester.AddInput("bias", bias_dims, ToFloat16(bias)); + tester.AddInput("eco_a", eco_a_dims, ToFloat16(eco_a)); + tester.AddOutput("output", output_dims, ToFloat16(output)); + } else { + tester.AddInput("query_layer", query_layer_dims, query_layer); + tester.AddInput("query_bias", query_bias_dims, query_bias); + tester.AddInput("rel_pos", rel_pos_dims, rel_pos); + tester.AddInput("weight", weight_dims, weight); + tester.AddInput("bias", bias_dims, bias); + tester.AddInput("eco_a", eco_a_dims, eco_a); + tester.AddOutput("output", output_dims, output); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(GatedRelativePositionBiasTest, FP16_BSNHD_1x3x2x4x8) { + constexpr int batch_size = 1; + constexpr int num_heads = 2; + constexpr int seq_len = 3; + constexpr int head_size = 4; + constexpr int D = 8; + const std::vector query_layer_dim = {1, 3, 8}; + const std::vector query_layer = { + 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, + 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f, + 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f}; + const std::vector query_bias_dim = {8}; + const std::vector query_bias = { + 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f}; + const std::vector rel_pos_dim = {1, 2, 3, 3}; + const std::vector rel_pos = { + 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f, + 0.5262958f, 0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f, + 0.27825248f, 0.4819588f}; + const std::vector weight_dim = {4, 8}; + const std::vector weight = { + 0.81978035f, 0.99706656f, 0.6984411f, 0.5675464f, 0.83524317f, 0.20559883f, 0.593172f, 0.112347245f, + 0.15345693f, 0.24170822f, 0.7262365f, 0.7010802f, 0.20382375f, 0.65105355f, 0.774486f, 0.43689132f, + 0.5190908f, 0.61585236f, 0.8101883f, 0.98009706f, 0.11468822f, 0.31676513f, 0.69650495f, 0.9142747f, + 0.93510365f, 0.9411784f, 0.5995073f, 0.06520867f, 0.54599625f, 0.18719733f, 0.034022927f, 0.94424623f}; + const std::vector bias_dim = {8}; + const std::vector bias = { + 0.8801799f, 0.0012360215f, 0.593586f, 0.41577f, 0.41771942f, 0.27112156f, 0.6922781f, 0.20384824f}; + const std::vector eco_a_dim = {1, 2, 1, 1}; + const std::vector eco_a = { + 0.68329567f, 0.75285405f}; + const 
std::vector output_dim = {1, 2, 3, 3}; + const std::vector output = { + 0.29608122f, 0.45416728f, 0.25361493f, 0.053390637f, 0.3503264f, 1.5650483f, 1.2171557f, 1.2495192f, + 0.88587445f, 0.42708054f, 1.0246648f, 0.05810945f, 0.2430356f, 0.4244021f, 1.428723f, 1.3902748f, + 0.48772895f, 0.84479123f}; + + RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, + batch_size, seq_len, num_heads, head_size, D, true); +} + +TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) { + constexpr int batch_size = 2; + constexpr int num_heads = 2; + constexpr int seq_len = 3; + constexpr int head_size = 4; + constexpr int D = 8; + const std::vector query_layer_dim = {2, 3, 8}; + const std::vector query_layer = { + 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, + 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f, + 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f, + 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f, + 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f, + 0.5262958f, 0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f}; + const std::vector query_bias_dim = {8}; + const std::vector query_bias = { + 0.27825248f, 0.4819588f, 0.81978035f, 0.99706656f, 0.6984411f, 0.5675464f, 0.83524317f, 0.20559883f}; + const std::vector rel_pos_dim = {1, 2, 3, 3}; + const std::vector rel_pos = { + 0.593172f, 0.112347245f, 0.15345693f, 0.24170822f, 0.7262365f, 0.7010802f, 0.20382375f, 0.65105355f, + 0.774486f, 0.43689132f, 0.5190908f, 0.61585236f, 0.8101883f, 0.98009706f, 0.11468822f, 0.31676513f, + 0.69650495f, 0.9142747f}; + const std::vector weight_dim = {4, 8}; + const std::vector weight = { + 0.93510365f, 0.9411784f, 0.5995073f, 0.06520867f, 0.54599625f, 0.18719733f, 0.034022927f, 0.94424623f, + 0.8801799f, 0.0012360215f, 0.593586f, 0.41577f, 0.41771942f, 0.27112156f, 0.6922781f, 0.20384824f, + 0.68329567f, 0.75285405f, 0.8579358f, 0.6869556f, 0.005132377f, 0.17565155f, 0.7496575f, 0.6046507f, + 0.10995799f, 0.21209025f, 0.97037464f, 0.83690894f, 0.28198743f, 0.3741576f, 0.023700953f, 0.49101293f}; + const std::vector bias_dim = {8}; + const std::vector bias = { + 0.123470545f, 0.11432165f, 0.4724502f, 0.5750725f, 0.29523486f, 0.7966888f, 0.19573045f, 0.95368505f}; + const std::vector eco_a_dim = {1, 2, 1, 1}; + const std::vector eco_a = { + 0.84264994f, 0.07835853f}; + const std::vector output_dim = {2, 2, 3, 3}; + const std::vector output = { + 1.0928818f, 0.20699267f, 0.28273466f, 0.44534987f, 1.3380982f, 1.2917475f, 0.3755537f, 1.1995932f, + 1.4270226f, 0.47112367f, 0.5597638f, 0.6641071f, 0.87368786f, 1.0569134f, 0.12367705f, 0.34158573f, + 0.75108063f, 0.98591405f, 1.0929474f, 0.2070051f, 0.28275162f, 0.4451845f, 1.3376014f, 1.2912678f, + 0.37552574f, 1.1995038f, 1.4269164f, 0.47112313f, 0.5597632f, 0.6641063f, 0.87367094f, 1.056893f, + 0.12367466f, 0.34158388f, 0.7510766f, 0.98590875f}; + + RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, + batch_size, seq_len, num_heads, head_size, D, false); +} + +TEST(GatedRelativePositionBiasTest, FP32_LongSeq_BSNHD_2x5x2x4x4) { + constexpr int batch_size = 2; + constexpr int num_heads = 2; + constexpr int seq_len = 5; + constexpr int head_size = 4; + constexpr int D = 4; + const 
std::vector query_layer_dim = {2, 5, 8}; + const std::vector query_layer = { + 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, + 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f, + 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f, + 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f, + 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f, + 0.5262958f, 0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f, + 0.27825248f, 0.4819588f, 0.81978035f, 0.99706656f, 0.6984411f, 0.5675464f, 0.83524317f, 0.20559883f, + 0.593172f, 0.112347245f, 0.15345693f, 0.24170822f, 0.7262365f, 0.7010802f, 0.20382375f, 0.65105355f, + 0.774486f, 0.43689132f, 0.5190908f, 0.61585236f, 0.8101883f, 0.98009706f, 0.11468822f, 0.31676513f, + 0.69650495f, 0.9142747f, 0.93510365f, 0.9411784f, 0.5995073f, 0.06520867f, 0.54599625f, 0.18719733f}; + const std::vector query_bias_dim = {8}; + const std::vector query_bias = { + 0.034022927f, 0.94424623f, 0.8801799f, 0.0012360215f, 0.593586f, 0.41577f, 0.41771942f, 0.27112156f}; + const std::vector rel_pos_dim = {1, 2, 5, 5}; + const std::vector rel_pos = { + 0.6922781f, 0.20384824f, 0.68329567f, 0.75285405f, 0.8579358f, 0.6869556f, 0.005132377f, 0.17565155f, + 0.7496575f, 0.6046507f, 0.10995799f, 0.21209025f, 0.97037464f, 0.83690894f, 0.28198743f, 0.3741576f, + 0.023700953f, 0.49101293f, 0.123470545f, 0.11432165f, 0.4724502f, 0.5750725f, 0.29523486f, 0.7966888f, + 0.19573045f, 0.95368505f, 0.84264994f, 0.07835853f, 0.37555784f, 0.5225613f, 0.57295054f, 0.61858714f, + 0.69621414f, 0.5299501f, 0.25603563f, 0.7365945f, 0.02037555f, 0.20364666f, 0.37483507f, 0.25644332f, + 0.32508332f, 0.09018916f, 0.39364243f, 0.6068782f, 0.17426711f, 0.47434032f, 0.8579254f, 0.44859987f, + 0.5138961f, 0.45686555f}; + const std::vector weight_dim = {4, 4}; + const std::vector weight = { + 0.6011907f, 0.81791973f, 0.9736231f, 0.81752795f, 0.97470677f, 0.46383917f, 0.050839245f, 0.2629614f, + 0.8404526f, 0.49675876f, 0.25147682f, 0.11684412f, 0.032073975f, 0.0779959f, 0.39858162f, 0.774203f}; + const std::vector bias_dim = {4}; + const std::vector bias = { + 0.77032053f, 0.017784059f, 0.811891f, 0.10874528f}; + const std::vector eco_a_dim = {1, 2, 1, 1}; + const std::vector eco_a = { + 0.39429486f, 0.29726368f}; + const std::vector output_dim = {2, 2, 5, 5}; + const std::vector output = { + 0.9534052f, 0.28073975f, 0.9410346f, 1.0368304f, 1.181549f, 0.94923383f, 0.0070919087f, 0.24271497f, + 1.0358753f, 0.8355051f, 0.15224966f, 0.29366368f, 1.3435968f, 1.158798f, 0.3904445f, 0.5147038f, + 0.03260383f, 0.67545396f, 0.16985025f, 0.15726471f, 0.64280313f, 0.7824283f, 0.40168867f, 1.0839535f, + 0.26630563f, 1.2391479f, 1.0948771f, 0.101813294f, 0.48797214f, 0.6789776f, 0.7492329f, 0.8089107f, + 0.91042155f, 0.6930023f, 0.3348113f, 0.95611423f, 0.026447866f, 0.2643374f, 0.48654333f, 0.3328685f, + 0.4239932f, 0.117630124f, 0.5134121f, 0.7915271f, 0.22728965f, 0.61497897f, 1.1122944f, 0.5816067f, + 0.6662628f, 0.59232306f, 0.95294285f, 0.2806036f, 0.9405782f, 1.0363276f, 1.1809759f, 0.95289487f, + 0.007119261f, 0.24365108f, 1.0398705f, 0.83872753f, 0.15201466f, 0.29321042f, 1.3415229f, 1.1570094f, + 0.38984182f, 0.51978874f, 0.032925934f, 0.682127f, 0.17152825f, 0.15881838f, 0.6571103f, 0.79984313f, + 0.4106292f, 1.1080796f, 
0.2722329f, 1.2398669f, 1.0955123f, 0.101872355f, 0.4882552f, 0.6793715f, + 0.7427765f, 0.8019401f, 0.9025762f, 0.6870305f, 0.33192614f, 0.9568577f, 0.026468432f, 0.26454294f, + 0.48692167f, 0.33312735f, 0.4217717f, 0.117013805f, 0.5107221f, 0.78737986f, 0.22609876f, 0.6166911f, + 1.1153911f, 0.5832259f, 0.6681177f, 0.59397215f}; + + RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, + batch_size, seq_len, num_heads, head_size, D, false); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/sequence/concat_from_sequence_op_test.cc b/onnxruntime/test/providers/cpu/sequence/concat_from_sequence_op_test.cc index dd5393f8cb5fc..e4fd147a535ea 100644 --- a/onnxruntime/test/providers/cpu/sequence/concat_from_sequence_op_test.cc +++ b/onnxruntime/test/providers/cpu/sequence/concat_from_sequence_op_test.cc @@ -105,7 +105,9 @@ TEST(SequenceOpsTest, ConcatFromSequence_Concat_Axis2) { input.AddTensor({1, 2}, {2, 3}); test.AddSeqInput("S", input); test.AddOutput("I", {1, 2, 2}, {0, 2, 1, 3}); - test.Run(OpTester::ExpectResult::kExpectFailure, "axis 2 is not in valid range [-2,1]"); + + // TODO: Unskip when fixed #41968513 + test.Run(OpTester::ExpectResult::kExpectFailure, "axis 2 is not in valid range [-2,1]", {kDmlExecutionProvider}); } TEST(SequenceOpsTest, ConcatFromSequence_Concat_Axis1_WithEmptyInput) { @@ -134,8 +136,11 @@ TEST(SequenceOpsTest, ConcatFromSequence_Concat_ScalarInputs) { input.AddTensor({}, {3}); test.AddSeqInput("S", input); test.AddOutput("I", {3}, {1, 2, 3}); + + // TODO: Unskip when fixed #41968513 test.Run(OpTester::ExpectResult::kExpectFailure, - "Cannot concatenate scalars"); + "Cannot concatenate scalars", + {kDmlExecutionProvider}); } } // namespace test diff --git a/onnxruntime/test/providers/cpu/tensor/identity_op_test.cc b/onnxruntime/test/providers/cpu/tensor/identity_op_test.cc index 11b4b9a99e3b6..c443c306e2b5e 100644 --- a/onnxruntime/test/providers/cpu/tensor/identity_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/identity_op_test.cc @@ -25,11 +25,6 @@ TEST(Identity, StringType) { } TEST(Identity, SequenceType) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: Failed to find args for kernel type string 'T'"; - } - OpTester test("Identity", 14, kOnnxDomain); SeqTensors input; input.AddTensor({3, 2}, {1, 2, 3, 4, 5, 6}); @@ -42,11 +37,6 @@ TEST(Identity, SequenceType) { #if !defined(DISABLE_OPTIONAL_TYPE) TEST(Identity, OptionalTensorType_NonNone) { - // TODO: Unskip when fixed #42638109 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: Failed to find args for kernel type string 'T'"; - } - OpTester test("Identity", 16, kOnnxDomain); // Since this test is being written at a time when only opset 15 has been released, we set // `test_allow_released_onnx_opset_only_` to 'false' to allow this test to run @@ -59,11 +49,6 @@ TEST(Identity, OptionalTensorType_NonNone) { } TEST(Identity, OptionalTensorType_None) { - // TODO: Unskip when fixed #42638109 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: Failed to find args for kernel type string 'T'"; - } - OpTester test("Identity", 16, kOnnxDomain); // Since this test is being written at a time when only opset 15 has been released, we set // `test_allow_released_onnx_opset_only_` to 'false' to 
allow this test to run
@@ -75,11 +60,6 @@ TEST(Identity, OptionalTensorType_None) {
 }
 
 TEST(Identity, OptionalTensorSequenceType_NonNone) {
-  // TODO: Unskip when fixed #42638109
-  if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: Failed to find args for kernel type string 'T'";
-  }
-
   OpTester test("Identity", 16, kOnnxDomain);
   // Since this test is being written at a time when only opset 15 has been released, we set
   // `test_allow_released_onnx_opset_only_` to 'false' to allow this test to run
@@ -95,11 +75,6 @@ TEST(Identity, OptionalTensorSequenceType_NonNone) {
 }
 
 TEST(Identity, OptionalTensorSequenceType_None) {
-  // TODO: Unskip when fixed #42638109
-  if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: Failed to find args for kernel type string 'T'";
-  }
-
   OpTester test("Identity", 16, kOnnxDomain);
   // Since this test is being written at a time when only opset 15 has been released, we set
   // `test_allow_released_onnx_opset_only_` to 'false' to allow this test to run
diff --git a/onnxruntime/test/providers/provider_test_utils.h b/onnxruntime/test/providers/provider_test_utils.h
index a5d17f1fb6d2d..a06efb5f2a62e 100644
--- a/onnxruntime/test/providers/provider_test_utils.h
+++ b/onnxruntime/test/providers/provider_test_utils.h
@@ -3,6 +3,7 @@
 #pragma once
 
+#include
 #include
 #include
@@ -1003,10 +1004,10 @@ class OpTester {
     if (seq_tensors) {
       auto num_tensors = seq_tensors->tensors.size();
-      std::vector<Tensor> tensors;
-      tensors.resize(num_tensors);
       auto elem_type = DataTypeImpl::GetType<T>();
+      ptr = std::make_unique<TensorSeq>(elem_type);
+      ptr->Reserve(num_tensors);
       for (size_t i = 0; i < num_tensors; ++i) {
         TensorShape shape{seq_tensors->tensors[i].shape};
         auto values_count = static_cast<int64_t>(seq_tensors->tensors[i].data.size());
@@ -1014,17 +1015,15 @@ class OpTester {
                     " input values doesn't match tensor size of ", shape.Size());
         auto allocator = test::AllocatorManager::Instance().GetAllocator(CPU);
-        auto& tensor = tensors[i];
-
-        tensor = Tensor(elem_type,
-                        shape,
-                        allocator);
+        Tensor tensor(elem_type, shape, allocator);
         auto* data_ptr = tensor.MutableData<T>();
         for (int64_t x = 0; x < values_count; ++x) {
           data_ptr[x] = seq_tensors->tensors[i].data[x];
         }
+        ptr->Add(std::move(tensor));
+
         if (add_shape_to_tensor_data_) {
           auto* output_tensor_type = sequence_tensor_proto.proto.mutable_sequence_type()
                                          ->mutable_elem_type()
@@ -1047,9 +1046,6 @@ class OpTester {
        }
      }
    }
-
-    ptr = std::make_unique<TensorSeq>(elem_type);
-    ptr->SetElements(std::move(tensors));
  }
 
  OrtValue value;
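The hunk above, and the optimizer.cc hunk below, move TensorSeq construction away from filling a std::vector<Tensor> and calling SetElements(), to Reserve() followed by per-element Add(). A condensed sketch of the new pattern, using only the internal TensorSeq/Tensor calls that appear in these hunks (MakeTensorSeq is a hypothetical helper, not part of this patch):

#include <memory>
#include <utility>
#include <vector>

#include "core/framework/TensorSeq.h"
#include "core/framework/tensor.h"

namespace onnxruntime {

// Builds a TensorSeq by moving each Tensor in; Add() checks that every
// element matches the sequence's element type.
std::unique_ptr<TensorSeq> MakeTensorSeq(std::vector<Tensor>&& tensors) {
  auto seq = std::make_unique<TensorSeq>(tensors.front().DataType());
  seq->Reserve(tensors.size());
  for (auto& tensor : tensors) {
    seq->Add(std::move(tensor));
  }
  return seq;
}

}  // namespace onnxruntime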
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 7758603c484fc..29484986bd6ab 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -5598,3 +5598,36 @@ def forward(self, non_contiguous_tensor):
     x_copy = copy.deepcopy(x)
     assert not x.is_contiguous()
     _test_helpers.assert_values_are_close(pt_model(x), ort_model(x_copy))
+
+
+def test_gradient_correctness_bce_with_logits():
+    class NeuralNetBCEWithLogitsLoss(torch.nn.Module):
+        def __init__(self, input_size, hidden_size):
+            super(NeuralNetBCEWithLogitsLoss, self).__init__()
+            self.linear = torch.nn.Linear(input_size, hidden_size)
+
+        def forward(self, input, target):
+            loss_fct = torch.nn.BCEWithLogitsLoss()
+            return loss_fct(self.linear(input), target)
+
+    N, D, H = 16, 256, 128
+    device = "cuda"
+    pt_model = NeuralNetBCEWithLogitsLoss(D, H).to(device)
+    ort_model = ORTModule(copy.deepcopy(pt_model))
+
+    def run_step(model, input, target):
+        prediction = model(input, target)
+        loss = prediction.sum()
+        loss.backward()
+        return prediction
+
+    for _ in range(10):
+        pt_input = torch.rand((N, D), device=device, requires_grad=True)
+        ort_input = copy.deepcopy(pt_input)
+        pt_target = torch.rand((N, H), device=device)
+        ort_target = copy.deepcopy(pt_target)
+        pt_prediction = run_step(pt_model, pt_input, pt_target)
+        ort_prediction = run_step(ort_model, ort_input, ort_target)
+
+        _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
+        _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc
index efe6bbd2c9027..5d2d3c12259c4 100644
--- a/orttraining/orttraining/training_api/optimizer.cc
+++ b/orttraining/orttraining/training_api/optimizer.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "core/framework/execution_provider.h"
+#include "core/framework/TensorSeq.h"
 #include "core/providers/cpu/cpu_execution_provider.h"
 #include "core/session/inference_session.h"
 #include "core/session/environment.h"
@@ -131,7 +132,11 @@ Status Optimizer::ConstructInputs() {
       ORT_ENFORCE(!tensors.empty(), "Tensors vector cannot be empty while building a tensor sequence.");
 
       auto tensor_seq = std::make_unique<TensorSeq>(tensors.front().DataType());
-      tensor_seq->SetElements(std::move(tensors));
+      tensor_seq->Reserve(tensors.size());
+      for (auto& tensor : tensors)
+      {
+        tensor_seq->Add(std::move(tensor));
+      }
       inputs->emplace_back(
           OrtValue(tensor_seq.release(), DataTypeImpl::GetType<TensorSeq>(),
                    DataTypeImpl::GetType<TensorSeq>()->GetDeleteFunc()));
diff --git a/orttraining/orttraining/training_ops/cpu/optimizer/adamw/adamw.cc b/orttraining/orttraining/training_ops/cpu/optimizer/adamw/adamw.cc
index 9682373dd8a3a..8fe0cb89255ab 100644
--- a/orttraining/orttraining/training_ops/cpu/optimizer/adamw/adamw.cc
+++ b/orttraining/orttraining/training_ops/cpu/optimizer/adamw/adamw.cc
@@ -4,6 +4,7 @@
 #include "orttraining/training_ops/cpu/optimizer/adamw/adamw.h"
 #include "orttraining/training_ops/cpu/optimizer/common.h"
 #include "core/framework/op_kernel.h"
+#include "core/framework/TensorSeq.h"
 #include "core/platform/threadpool.h"
 #include "core/providers/common.h"
 #include "core/providers/cpu/math/element_wise_ops.h"
diff --git a/orttraining/orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.cc b/orttraining/orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.cc
index 8d441e18d448b..a46744da2beec 100644
--- a/orttraining/orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.cc
+++ b/orttraining/orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.cc
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "core/framework/TensorSeq.h"
 #include "orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.h"
 #include "core/providers/cpu/math/element_wise_ops.h"
 #include "core/providers/cpu/tensor/utils.h"
@@ -18,7 +19,7 @@ T GetL2Norm(const TensorSeq& gradients) {
   T l2_norm = 0;
   for (const auto& tensor : gradients) {
     l2_norm +=
-        ReduceAggregatorSumSquare<T>(tensor.Shape().Size(), *tensor.Data<T>()).aggall(tensor.Data<T>());
+        ReduceAggregatorSumSquare<T>(tensor.Get().Shape().Size(), *tensor.Get().Data<T>()).aggall(tensor.Get().Data<T>());
   }
   return reduce_sqrt(l2_norm);
 }
@@ -27,7 +28,10 @@ template <typename T>
 void ClipGradNorm(T total_norm, T max_norm, TensorSeq& gradients) {
   const T clip_coefficient = std::min(max_norm / (total_norm + static_cast<T>(Epsilon)), static_cast<T>(1.0f));
-  for (const auto& grad : gradients) {
+  auto gradients_size = gradients.Size();
+  for (size_t i = 0; i < gradients_size; i++)
+  {
+    const auto& grad = gradients.Get(i);
     auto& tensor = const_cast<Tensor&>(grad);
     MakeEigenArrayMap<T>(tensor) *= clip_coefficient;
   }
@@ -42,8 +46,13 @@ Status PopulateOutput(OpKernelContext* ctx, const TensorSeq* gradients, TensorSe
   ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
 
   clipped_gradients->SetType(gradients->DataType());
-  clipped_gradients->Reserve(gradients->Size());
-  for (const auto& grad : *gradients) {
+  clipped_gradients->SetElements({});
+
+  auto gradients_size = gradients->Size();
+  clipped_gradients->Reserve(gradients_size);
+  for (size_t i = 0; i < gradients_size; i++)
+  {
+    const auto& grad = gradients->Get(i);
     Tensor target_tensor(grad.DataType(), grad.Shape(), alloc);
     CopyCpuTensor(&grad, &target_tensor);
     clipped_gradients->Add(std::move(target_tensor));  // Add will check for type consistency
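The ClipGradNorm hunks above keep the same math while switching to index-based iteration over the TensorSeq: the combined L2 norm of all gradients is computed first, and every element is then scaled by min(max_norm / (total_norm + epsilon), 1). A standalone sketch of that computation on plain std::vector<float> buffers (the epsilon value here is assumed; the kernel uses its own Epsilon constant):

#include <algorithm>
#include <cmath>
#include <vector>

// Scales all gradient buffers in place so that their combined L2 norm
// does not exceed max_norm.
void ClipGradNormSketch(std::vector<std::vector<float>>& gradients, float max_norm) {
  // Global L2 norm across every gradient buffer.
  float sum_of_squares = 0.0f;
  for (const auto& grad : gradients) {
    for (float v : grad) sum_of_squares += v * v;
  }
  const float total_norm = std::sqrt(sum_of_squares);

  // clip_coefficient = min(max_norm / (total_norm + epsilon), 1.0)
  constexpr float epsilon = 1e-6f;  // assumed value, standing in for the kernel's Epsilon constant
  const float clip_coefficient = std::min(max_norm / (total_norm + epsilon), 1.0f);

  for (auto& grad : gradients) {
    for (float& v : grad) v *= clip_coefficient;
  }
}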
+#include "core/framework/TensorSeq.h" #include "orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.h" #include "core/providers/cpu/math/element_wise_ops.h" #include "core/providers/cpu/tensor/utils.h" @@ -18,7 +19,7 @@ T GetL2Norm(const TensorSeq& gradients) { T l2_norm = 0; for (const auto& tensor : gradients) { l2_norm += - ReduceAggregatorSumSquare(tensor.Shape().Size(), *tensor.Data()).aggall(tensor.Data()); + ReduceAggregatorSumSquare(tensor.Get().Shape().Size(), *tensor.Get().Data()).aggall(tensor.Get().Data()); } return reduce_sqrt(l2_norm); } @@ -27,7 +28,10 @@ template void ClipGradNorm(T total_norm, T max_norm, TensorSeq& gradients) { const T clip_coefficient = std::min(max_norm / (total_norm + static_cast(Epsilon)), static_cast(1.0f)); - for (const auto& grad : gradients) { + auto gradients_size = gradients.Size(); + for (size_t i = 0; i < gradients_size; i++) + { + const auto& grad = gradients.Get(i); auto& tensor = const_cast(grad); MakeEigenArrayMap(tensor) *= clip_coefficient; } @@ -42,8 +46,13 @@ Status PopulateOutput(OpKernelContext* ctx, const TensorSeq* gradients, TensorSe ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc)); clipped_gradients->SetType(gradients->DataType()); - clipped_gradients->Reserve(gradients->Size()); - for (const auto& grad : *gradients) { + clipped_gradients->SetElements({}); + + auto gradients_size = gradients->Size(); + clipped_gradients->Reserve(gradients_size); + for (size_t i = 0; i < gradients_size; i++) + { + const auto& grad = gradients->Get(i); Tensor target_tensor(grad.DataType(), grad.Shape(), alloc); CopyCpuTensor(&grad, &target_tensor); clipped_gradients->Add(std::move(target_tensor)); // Add will check for type consistency diff --git a/orttraining/orttraining/training_ops/cpu/optimizer/common.cc b/orttraining/orttraining/training_ops/cpu/optimizer/common.cc index 250d994609e12..6cd3eb857fef1 100644 --- a/orttraining/orttraining/training_ops/cpu/optimizer/common.cc +++ b/orttraining/orttraining/training_ops/cpu/optimizer/common.cc @@ -3,6 +3,7 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" +#include "core/framework/TensorSeq.h" #include "core/providers/cpu/tensor/utils.h" #include "orttraining/training_ops/cpu/optimizer/common.h" diff --git a/orttraining/orttraining/training_ops/cpu/optimizer/sgd/sgd.cc b/orttraining/orttraining/training_ops/cpu/optimizer/sgd/sgd.cc index 6c20214dd8306..3b4ed5b0466ea 100644 --- a/orttraining/orttraining/training_ops/cpu/optimizer/sgd/sgd.cc +++ b/orttraining/orttraining/training_ops/cpu/optimizer/sgd/sgd.cc @@ -4,6 +4,7 @@ #include "orttraining/training_ops/cpu/optimizer/sgd/sgd.h" #include "orttraining/training_ops/cpu/optimizer/common.h" #include "core/framework/op_kernel.h" +#include "core/framework/TensorSeq.h" #include "core/providers/common.h" #include "core/providers/cpu/math/element_wise_ops.h" diff --git a/tools/ci_build/github/android/setup_gradle_wrapper.sh b/tools/ci_build/github/android/setup_gradle_wrapper.sh deleted file mode 100755 index b5fa40dd25144..0000000000000 --- a/tools/ci_build/github/android/setup_gradle_wrapper.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -# This script will setup gradlew to use gradle version 6.8.3 for Android CI, -# since the macOS pipeline is using gradle 7.0 which will fail the java build -# See, https://github.com/actions/virtual-environments/issues/3195 - -set -e -set -x - -if [ $# -ne 1 ]; then - echo "One command line argument, the ORT root directory, is expected" -fi - -ORT_ROOT=$1 - 
-pushd ${ORT_ROOT}/java -gradle wrapper --gradle-version 6.8.3 --no-daemon --no-watch-fs -./gradlew --version -popd diff --git a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml index 182413df4cedc..f903cda3682d6 100644 --- a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml @@ -39,8 +39,7 @@ stages: - template: "templates/use-android-ndk.yml" - - script: /bin/bash tools/ci_build/github/android/setup_gradle_wrapper.sh $(pwd) - displayName: Setup gradle wrapper to use gradle 6.8.3 + - template: templates/set-up-gradle-wrapper-step.yml # We build the host protoc to /protobuf_install - script: | @@ -126,8 +125,7 @@ stages: - template: "templates/use-android-ndk.yml" - - script: /bin/bash tools/ci_build/github/android/setup_gradle_wrapper.sh $(pwd) - displayName: Setup gradle wrapper to use gradle 6.8.3 + - template: templates/set-up-gradle-wrapper-step.yml # We build the host protoc to /protobuf_install - script: | @@ -231,8 +229,7 @@ stages: - template: "templates/use-android-ndk.yml" - - script: /bin/bash tools/ci_build/github/android/setup_gradle_wrapper.sh $(pwd) - displayName: Setup gradle wrapper to use gradle 6.8.3 + - template: templates/set-up-gradle-wrapper-step.yml - script: | python3 tools/python/run_android_emulator.py \ @@ -306,8 +303,7 @@ stages: - template: "templates/use-android-ndk.yml" - - script: /bin/bash tools/ci_build/github/android/setup_gradle_wrapper.sh $(pwd) - displayName: Setup gradle wrapper to use gradle 6.8.3 + - template: templates/set-up-gradle-wrapper-step.yml - script: | python3 tools/python/run_android_emulator.py \ @@ -391,8 +387,7 @@ stages: - template: "templates/use-android-ndk.yml" - - script: /bin/bash tools/ci_build/github/android/setup_gradle_wrapper.sh $(pwd) - displayName: Setup gradle wrapper to use gradle 6.8.3 + - template: templates/set-up-gradle-wrapper-step.yml # used by Build Minimal ORT - script: brew install coreutils ninja diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 5c64f4b619c4a..445a633d89d82 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -39,8 +39,7 @@ jobs: - template: use-android-ndk.yml - - script: /bin/bash tools/ci_build/github/android/setup_gradle_wrapper.sh $(Build.SourcesDirectory) - displayName: Setup gradle wrapper to use gradle 6.8.3 + - template: set-up-gradle-wrapper-step.yml - script: | python3 $(Build.SourcesDirectory)/tools/python/run_android_emulator.py \ diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-ci.yml b/tools/ci_build/github/azure-pipelines/templates/mac-ci.yml index 280c321033093..ea2f413b2b8b8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-ci.yml @@ -50,6 +50,8 @@ jobs: - template: set-version-number-variables-step.yml + - template: set-up-gradle-wrapper-step.yml + - script: | brew install ccache echo "##vso[task.prependpath]/usr/local/opt/ccache/libexec" diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 
33f6d1f871512..4aa15af4be186 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -60,9 +60,7 @@ jobs: brew install coreutils ninja npm yarn displayName: Install coreutils, ninja, npm, and yarn - - script: - /bin/bash $(Build.SourcesDirectory)/tools/ci_build/github/android/setup_gradle_wrapper.sh $(pwd) - displayName: Setup gradle wrapper to use gradle 6.8.3 + - template: set-up-gradle-wrapper-step.yml - script: | python3 -m pip install -q flatbuffers diff --git a/tools/ci_build/github/azure-pipelines/templates/set-up-gradle-wrapper-step.yml b/tools/ci_build/github/azure-pipelines/templates/set-up-gradle-wrapper-step.yml new file mode 100644 index 0000000000000..8409288246345 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/set-up-gradle-wrapper-step.yml @@ -0,0 +1,10 @@ +parameters: +- name: GradleWrapperVersion + type: string + default: "6.8.3" + +steps: +- script: | + gradle wrapper --gradle-version="${{ parameters.GradleWrapperVersion }}" --no-daemon --no-watch-fs + displayName: Set up gradle ${{ parameters.GradleWrapperVersion }} wrapper + workingDirectory: $(Build.SourcesDirectory)/java
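The pipelines above include the new template without arguments and therefore pick up the 6.8.3 default. For reference, a top-level pipeline could also pin the wrapper version explicitly through the template parameter; a hypothetical usage sketch:

steps:
- template: templates/set-up-gradle-wrapper-step.yml
  parameters:
    GradleWrapperVersion: "6.8.3"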