diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index f985cf10ded60..5e38789b65137 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -305,6 +305,7 @@ Do not modify directly.*
|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
+|RegexFullMatch|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(string)
**T2** = tensor(bool)|
|Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8)|
|||13|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
@@ -382,6 +383,7 @@ Do not modify directly.*
|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**
or
*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|StringConcat|*in* X:**T**
*in* Y:**T**
*out* Z:**T**|20+|**T** = tensor(string)|
|StringNormalizer|*in* X:**tensor(string)**
*out* Y:**tensor(string)**|10+|**X** = tensor(string)|
|Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
diff --git a/include/onnxruntime/core/common/logging/logging.h b/include/onnxruntime/core/common/logging/logging.h
index bea3fa1d09cc2..2b9912ea77389 100644
--- a/include/onnxruntime/core/common/logging/logging.h
+++ b/include/onnxruntime/core/common/logging/logging.h
@@ -75,6 +75,21 @@ struct Category {
// TODO: What other high level categories are meaningful? Model? Optimizer? Execution?
};
+///
+/// ORT TraceLogging keywords for categories of dynamic logging enablement
+///
+enum class ORTTraceLoggingKeyword : uint64_t {
+ Session = 0x1, // ORT Session TraceLoggingWrite
+ Logs = 0x2, // LOGS() Macro ORT logs. Pair with an appropriate level depending on detail required
+ Reserved1 = 0x4, // Reserved if we want to add some specific sub-categories instead of just LOGS() or other uses
+ Reserved2 = 0x8,
+ Reserved3 = 0x10,
+ Reserved4 = 0x20,
+ Reserved5 = 0x40,
+ Reserved6 = 0x80,
+ Profiling = 0x100 // Enables profiling. At higher levels >5 can impact inference performance
+};
+
class ISink;
class Logger;
class Capture;
@@ -333,5 +348,17 @@ unsigned int GetThreadId();
*/
unsigned int GetProcessId();
+/**
+ If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then adds to the existing logger.
+*/
+std::unique_ptr EnhanceLoggerWithEtw(std::unique_ptr existingLogger, logging::Severity originalSeverity,
+ logging::Severity etwSeverity);
+
+/**
+ If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then can override the logging level.
+ But this overrided level only applies to the ETW sink. The original logger(s) retain their original logging level
+*/
+Severity OverrideLevelWithEtw(Severity originalSeverity);
+
} // namespace logging
} // namespace onnxruntime
diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h
index d73d551920d47..9416fad5f1448 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_context.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -28,38 +28,45 @@ struct CudaContext : public CustomOpContext {
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};
OrtAllocator* deferred_cpu_allocator = {};
+ // below are cuda ep options
+ int16_t device_id = 0;
+ int32_t arena_extend_strategy = 0;
+ int32_t cudnn_conv_algo_search = 0;
+ bool cudnn_conv_use_max_workspace = true;
+ bool cudnn_conv1d_pad_to_nc1d = false;
+ bool enable_skip_layer_norm_strict_mode = false;
+ bool prefer_nhwc = false;
void Init(const OrtKernelContext& kernel_ctx) {
- const auto& ort_api = Ort::GetApi();
- void* resource = {};
- OrtStatus* status = nullptr;
-
- status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cuda_stream_t, &resource);
- if (status) {
- ORT_CXX_API_THROW("failed to fetch cuda stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
- }
- cuda_stream = reinterpret_cast(resource);
-
- resource = {};
- status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cudnn_handle_t, &resource);
- if (status) {
- ORT_CXX_API_THROW("failed to fetch cudnn handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
- }
- cudnn_handle = reinterpret_cast(resource);
+ cuda_stream = FetchResource(kernel_ctx, CudaResource::cuda_stream_t);
+ cudnn_handle = FetchResource(kernel_ctx, CudaResource::cudnn_handle_t);
+ cublas_handle = FetchResource(kernel_ctx, CudaResource::cublas_handle_t);
+ deferred_cpu_allocator = FetchResource(kernel_ctx, CudaResource::deferred_cpu_allocator_t);
+
+ device_id = FetchResource(kernel_ctx, CudaResource::device_id_t);
+ arena_extend_strategy = FetchResource(kernel_ctx, CudaResource::arena_extend_strategy_t);
+ cudnn_conv_algo_search = FetchResource(kernel_ctx, CudaResource::cudnn_conv_algo_search_t);
+ cudnn_conv_use_max_workspace = FetchResource(kernel_ctx, CudaResource::cudnn_conv_use_max_workspace_t);
+
+ cudnn_conv1d_pad_to_nc1d = FetchResource(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
+ enable_skip_layer_norm_strict_mode = FetchResource(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
+ prefer_nhwc = FetchResource(kernel_ctx, CudaResource::prefer_nhwc_t);
+ }
- resource = {};
- status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cublas_handle_t, &resource);
- if (status) {
- ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+ template
+ T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
+ if (sizeof(T) > sizeof(void*)) {
+ ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT);
}
- cublas_handle = reinterpret_cast(resource);
-
- resource = {};
- status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
+ const auto& ort_api = Ort::GetApi();
+ void* resource = {};
+ OrtStatus* status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, resource_type, &resource);
if (status) {
- ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+ ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resouce type: " + std::to_string(resource_type), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
- deferred_cpu_allocator = reinterpret_cast(resource);
+ T t = {};
+ memcpy(&t, &resource, sizeof(T));
+ return t;
}
void* AllocDeferredCpuMem(size_t size) const {
diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h
index 8c3ed46ade6a1..c0e6328f27122 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_resource.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -3,11 +3,19 @@
#include "core/providers/resource.h"
-#define ORT_CUDA_RESOUCE_VERSION 2
+#define ORT_CUDA_RESOUCE_VERSION 3
enum CudaResource : int {
- cuda_stream_t = cuda_resource_offset,
+ cuda_stream_t = cuda_resource_offset, // 10000
cudnn_handle_t,
cublas_handle_t,
deferred_cpu_allocator_t,
+ // below are cuda ep options
+ device_id_t, // 10004
+ arena_extend_strategy_t,
+ cudnn_conv_algo_search_t,
+ cudnn_conv_use_max_workspace_t,
+ cudnn_conv1d_pad_to_nc1d_t,
+ enable_skip_layer_norm_strict_mode_t,
+ prefer_nhwc_t,
};
\ No newline at end of file
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 8cd0d0051d1eb..504f1db7b4420 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -4418,7 +4418,7 @@ struct OrtApi {
ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);
/**
- * Get a EP resoure.
+ * Get a EP resource.
* E.g. a cuda stream or a cublas handle
*
* \param context - Kernel context
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
index cb40a9f08d2d7..7af2c5db49f40 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
@@ -138,8 +138,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
// const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
const isChannelsLast = attributes.format === 'NHWC';
if (attributes.group !== 1) {
- if (isChannelsLast && inputs[1].dims[0] === attributes.group && inputs[1].dims[1] === 1 &&
- attributes.dilations[0] === 1 && attributes.dilations[1] === 1) {
+ // Temporarily disable createGroupedConvVectorizeProgramInfo path due to bots failures with below two cases:
+ // [webgpu]Conv - conv - vectorize group - B
+ // [webgpu]Conv - conv - vectorize group - D
+ const disableGroupedConvVectorize = true;
+ if (!disableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group &&
+ inputs[1].dims[1] === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1) {
const outputShape = calculateOutputShape(
inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides,
isChannelsLast);
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts
index 687ee054096cc..2ef9637bcda5e 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts
@@ -76,7 +76,6 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC));
let outputShape = dimsA;
let outputSize = ShapeUtil.size(dimsA);
- const vecSize = Math.ceil(outputSize / 4);
// TODO: deal with zero-sized tensors (eg. dims=[1,0])
if (isBroadcast) {
@@ -88,6 +87,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
outputSize = ShapeUtil.size(outputShape);
}
+ const vecSize = Math.ceil(outputSize / 4);
+
return {
name: 'Where',
shaderCache: {inputDependencies: ['rank', 'rank', 'rank']},
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index c9ed23895b60c..da489a6901512 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -133,6 +133,10 @@ constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FL
// Default value for the above setting.
constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513;
+// Environment variable to enable loading more KV data in flight in
+// DecoderMaskedMultiHeadAttention/DecoderMaskedSelfAttention kernels
+constexpr const char* kDecoderMaskedAttentionLoadKVDataInFlight = "ORT_DECODER_MASKED_ATTENTION_LOAD_KV_DATA_IN_FLIGHT";
+
} // namespace attention
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc
index 54aad9cbaf387..a9b60da0c96ca 100644
--- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc
@@ -70,6 +70,10 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext*
auto& device_prop = GetDeviceProp();
DecoderMaskedMultiHeadAttentionParams parameters;
+
+ parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault(
+ attention::kDecoderMaskedAttentionLoadKVDataInFlight, false);
+
bool is_dmmha_packing = (key == nullptr && value == nullptr);
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query,
key,
diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc
index 69ed07101e647..72ede2e22b557 100644
--- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc
@@ -52,6 +52,10 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont
auto& device_prop = GetDeviceProp();
DecoderMaskedMultiHeadAttentionParams parameters;
+
+ parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault(
+ attention::kDecoderMaskedAttentionLoadKVDataInFlight, false);
+
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
index 33e7a33494778..9efb6f08e8e99 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
@@ -344,52 +344,148 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
bool has_beams = params.cache_indir != nullptr && !params.is_cross_attention;
const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_max_seq_length] : nullptr;
- for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
- bool is_masked = (params.mask != nullptr) && (params.mask[bi_total_seq_length + ti] == 0);
+ if (!params.kv_data_in_flight) {
+ for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
+ bool is_masked = (params.mask != nullptr) && (params.mask[bi_total_seq_length + ti] == 0);
- // The keys loaded from the key cache.
- K_vec_k k_vec[K_VECS_PER_THREAD];
- if (ti < tlength) {
- if (has_beams) {
- const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;
+ // The keys loaded from the key cache.
+ K_vec_k k_vec[K_VECS_PER_THREAD];
+ if (ti < tlength) {
+ if (has_beams) {
+ const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;
#pragma unroll
- for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
- int jj = ii * params.max_sequence_length + ti;
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+ int jj = ii * params.max_sequence_length + ti;
- k_vec[ii] = vec_conversion(
- (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
- }
- } else {
+ k_vec[ii] = vec_conversion(
+ (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
+ }
+ } else {
#pragma unroll
- for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
- int jj = ii * params.max_sequence_length + ti;
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+ int jj = ii * params.max_sequence_length + ti;
- k_vec[ii] = vec_conversion(
- (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B])));
+ k_vec[ii] = vec_conversion(
+ (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B])));
+ }
}
}
- }
- // Perform the dot product and normalize qk.
- // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
- float qk = Qk_dot::dot(q_vec, k_vec) * inv_sqrt_dh;
+ // Perform the dot product and normalize qk.
+ // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
+ float qk = Qk_dot::dot(q_vec, k_vec) * inv_sqrt_dh;
- // This is a deviation from FasterTransformer kernel implementation
- // but this aligns with ORT's other Attention kernels which strives to
- // mimic PyTorch when dealing with mask filter values
- if (is_masked) {
- qk += params.mask_filter_value;
+ // This is a deviation from FasterTransformer kernel implementation
+ // but this aligns with ORT's other Attention kernels which strives to
+ // mimic PyTorch when dealing with mask filter values
+ if (is_masked) {
+ qk += params.mask_filter_value;
+ }
+
+ // Store the product to shared memory. There's one qk value per timestep. Update the max.
+ if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
+ if (params.relative_attention_bias != nullptr) {
+ qk = add_vec(qk,
+ reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + ti]);
+ }
+ qk_max = fmaxf(qk_max, qk);
+ qk_smem[ti] = qk;
+ }
}
+ } else {
+ // TODO(hasesh): Tune this value for different workloads. Currently, it is tuned for Whisper model
+ // Also tune it for different architectures. This works best for Whisper on 80GB A100.
+ constexpr int K_CACHE_DATA_LOAD_UNROLL = 4;
- // Store the product to shared memory. There's one qk value per timestep. Update the max.
- if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
- if (params.relative_attention_bias != nullptr) {
- qk = add_vec(qk,
- reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + ti]);
+ for (int ti = ko; ti < ti_end; ti += (K_CACHE_DATA_LOAD_UNROLL * K_PER_ITER)) {
+ int is_masked[K_CACHE_DATA_LOAD_UNROLL];
+ int beam_offset[K_CACHE_DATA_LOAD_UNROLL];
+ int time_step[K_CACHE_DATA_LOAD_UNROLL];
+ bool time_bounds_cond[K_CACHE_DATA_LOAD_UNROLL];
+
+#pragma unroll
+ for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
+ is_masked[k_unroll] = 1;
+ beam_offset[k_unroll] = 0;
+ time_step[k_unroll] = ti + k_unroll * K_PER_ITER;
+ time_bounds_cond[k_unroll] = (time_step[k_unroll] < tlength);
+ }
+
+#pragma unroll
+ for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
+ if (time_bounds_cond[k_unroll] && params.mask != nullptr) {
+ is_masked[k_unroll] = params.mask[bi_total_seq_length + time_step[k_unroll]];
+ }
+ }
+
+ if (has_beams) {
+ int head_maxlength_headsize_prod = params.num_heads * params.max_sequence_length * head_size;
+
+#pragma unroll
+ for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
+ if (time_bounds_cond[k_unroll]) {
+ beam_offset[k_unroll] = beam_indices[time_step[k_unroll]] * head_maxlength_headsize_prod;
+ }
+ }
+ }
+
+ // The keys loaded from the key cache.
+ K_vec_k k_vec[K_CACHE_DATA_LOAD_UNROLL][K_VECS_PER_THREAD];
+
+#pragma unroll
+ for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
+ if (time_bounds_cond[k_unroll]) {
+ if (has_beams) {
+#pragma unroll
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+ int jj = ii * params.max_sequence_length + time_step[k_unroll];
+
+ k_vec[k_unroll][ii] = vec_conversion(
+ (*reinterpret_cast(&k_cache_batch[beam_offset[k_unroll] + jj * QK_ELTS_IN_16B])));
+ }
+ } else {
+#pragma unroll
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+ int jj = ii * params.max_sequence_length + time_step[k_unroll];
+
+ k_vec[k_unroll][ii] = vec_conversion(
+ (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B])));
+ }
+ }
+ }
+ }
+
+ // Perform the dot product and normalize qk.
+ // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
+ float qk[K_CACHE_DATA_LOAD_UNROLL];
+#pragma unroll
+ for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
+ qk[k_unroll] = Qk_dot::dot(q_vec, k_vec[k_unroll]) * inv_sqrt_dh;
+ }
+
+// This is a deviation from FasterTransformer kernel implementation
+// but this aligns with ORT's other Attention kernels which strives to
+// mimic PyTorch when dealing with mask filter values
+#pragma unroll
+ for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
+ if (time_bounds_cond[k_unroll] && is_masked[k_unroll] == 0) {
+ qk[k_unroll] += params.mask_filter_value;
+ }
+ }
+
+// Store the product to shared memory. There's one qk value per timestep. Update the max.
+#pragma unroll
+ for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
+ if (time_bounds_cond[k_unroll] && (tidx % THREADS_PER_KEY == 0)) {
+ if (params.relative_attention_bias != nullptr) {
+ qk[k_unroll] = add_vec(qk[k_unroll],
+ reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + time_step[k_unroll]]);
+ }
+ qk_max = fmaxf(qk_max, qk[k_unroll]);
+ qk_smem[time_step[k_unroll]] = qk[k_unroll];
+ }
}
- qk_max = fmaxf(qk_max, qk);
- qk_smem[ti] = qk;
}
}
@@ -504,18 +600,80 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
V_vec_acum out;
zero(out);
- // Loop over the timesteps to compute the partial outputs.
- for (int ti = vo; ti < tlength; ti += V_PER_ITER) {
- // Fetch offset based on cache_indir when beam sampling
- const int beam_src = has_beams ? params.cache_indir[bi_max_seq_length + ti] : 0;
- const int beam_offset = has_beams ? beam_src * params.num_heads * params.max_sequence_length * head_size : 0;
+ if (!params.kv_data_in_flight) {
+ // Loop over the timesteps to compute the partial outputs.
+ for (int ti = vo; ti < tlength; ti += V_PER_ITER) {
+ // Fetch offset based on cache_indir when beam sampling
+ const int beam_src = has_beams ? params.cache_indir[bi_max_seq_length + ti] : 0;
+ const int beam_offset = has_beams ? beam_src * params.num_heads * params.max_sequence_length * head_size : 0;
+
+ // Load the values from the cache.
+ V_vec_k v = vec_conversion(*reinterpret_cast(&v_cache_batch[beam_offset + ti * head_size]));
+
+ // Load the logits from shared memory.
+ T logit = logits_smem[ti];
+ out = fma(logit, v, out);
+ }
+ } else {
+ // Loop over the timesteps to compute the partial outputs.
+
+ // TODO(hasesh): Tune this value for different workloads. Currently, it is tuned for Whisper model
+ // Also tune it for different architectures. This works best for Whisper on 80GB A100.
+ constexpr int V_CACHE_DATA_LOAD_UNROLL = 8;
+
+ for (int ti = vo; ti < tlength; ti += V_CACHE_DATA_LOAD_UNROLL * V_PER_ITER) {
+ int beam_src[V_CACHE_DATA_LOAD_UNROLL];
+ int beam_offset[V_CACHE_DATA_LOAD_UNROLL];
+ int time_step[V_CACHE_DATA_LOAD_UNROLL];
+ bool time_bounds_cond[V_CACHE_DATA_LOAD_UNROLL];
+
+#pragma unroll
+ for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) {
+ beam_src[v_unroll] = 0;
+ beam_offset[v_unroll] = 0;
+ time_step[v_unroll] = ti + v_unroll * V_PER_ITER;
+ time_bounds_cond[v_unroll] = (time_step[v_unroll] < tlength);
+ }
+
+ int head_maxlength_headsize_prod = params.num_heads * params.max_sequence_length * head_size;
+
+ if (has_beams) {
+// Do the global memory read and corresponding compute in separate unrolled loops
+#pragma unroll
+ for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) {
+ if (time_bounds_cond[v_unroll]) {
+ beam_src[v_unroll] = params.cache_indir[bi_max_seq_length + time_step[v_unroll]];
+ }
+ }
+
+#pragma unroll
+ for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) {
+ if (time_bounds_cond[v_unroll]) {
+ beam_offset[v_unroll] = beam_src[v_unroll] * head_maxlength_headsize_prod;
+ }
+ }
+ }
- // Load the values from the cache.
- V_vec_k v = vec_conversion(*reinterpret_cast(&v_cache_batch[beam_offset + ti * head_size]));
+ // Load the values from the V-cache and logits from shared memory.
+ V_vec_k v[V_CACHE_DATA_LOAD_UNROLL];
+ T logits[V_CACHE_DATA_LOAD_UNROLL];
- // Load the logits from shared memory.
- T logit = logits_smem[ti];
- out = fma(logit, v, out);
+// Do the global memory read and compute in separate unrolled loops
+#pragma unroll
+ for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) {
+ if (time_bounds_cond[v_unroll]) {
+ v[v_unroll] = vec_conversion(*reinterpret_cast(&v_cache_batch[beam_offset[v_unroll] + time_step[v_unroll] * head_size]));
+ logits[v_unroll] = logits_smem[time_step[v_unroll]];
+ }
+ }
+
+#pragma unroll
+ for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) {
+ if (time_bounds_cond[v_unroll]) {
+ out = fma(logits[v_unroll], v[v_unroll], out);
+ }
+ }
+ }
}
// One group of threads computes the product(s) for the current timestep.
diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h
index 4b408dafa2d81..1a17757d1ec2d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h
@@ -22,6 +22,12 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
bool is_cross_attention = false;
bool is_packed_qkv = false;
+ // Useful to better use global memory bandwidth on certain CUDA architectures.
+ // Turned off by default for now until we fully understand performance implications
+ // for all types of workloads.
+ // Can be turned on by appropriate environment variable (see attention_common.h).
+ bool kv_data_in_flight = false;
+
void* q = nullptr;
void* q_bias = nullptr;
@@ -62,4 +68,4 @@ void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cud
} // namespace cuda
} // namespace contrib
-} // namespace onnxruntime
+} // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/common/logging/logging.cc b/onnxruntime/core/common/logging/logging.cc
index 6c6e2f48557ef..eac9a7fa08081 100644
--- a/onnxruntime/core/common/logging/logging.cc
+++ b/onnxruntime/core/common/logging/logging.cc
@@ -12,6 +12,8 @@
#ifdef _WIN32
#include
+#include "core/platform/windows/logging/etw_sink.h"
+#include "core/common/logging/sinks/composite_sink.h"
#else
#include
#if defined(__MACH__) || defined(__wasm__) || defined(_AIX)
@@ -243,5 +245,36 @@ unsigned int GetProcessId() {
#endif
}
+std::unique_ptr EnhanceLoggerWithEtw(std::unique_ptr existingLogger, logging::Severity originalSeverity,
+ logging::Severity etwSeverity) {
+#ifdef _WIN32
+ auto& manager = EtwRegistrationManager::Instance();
+ if (manager.IsEnabled()) {
+ auto compositeSink = std::make_unique();
+ compositeSink->AddSink(std::move(existingLogger), originalSeverity);
+ compositeSink->AddSink(std::make_unique(), etwSeverity);
+ return compositeSink;
+ } else {
+ return existingLogger;
+ }
+#else
+ // On non-Windows platforms, just return the existing logger
+ (void)originalSeverity;
+ (void)etwSeverity;
+ return existingLogger;
+#endif // _WIN32
+}
+
+Severity OverrideLevelWithEtw(Severity originalSeverity) {
+#ifdef _WIN32
+ auto& manager = logging::EtwRegistrationManager::Instance();
+ if (manager.IsEnabled() &&
+ (manager.Keyword() & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) {
+ return manager.MapLevelToSeverity();
+ }
+#endif // _WIN32
+ return originalSeverity;
+}
+
} // namespace logging
} // namespace onnxruntime
diff --git a/onnxruntime/core/common/logging/sinks/composite_sink.h b/onnxruntime/core/common/logging/sinks/composite_sink.h
index f27abb9e6aad5..9d18eb527ffdd 100644
--- a/onnxruntime/core/common/logging/sinks/composite_sink.h
+++ b/onnxruntime/core/common/logging/sinks/composite_sink.h
@@ -5,6 +5,8 @@
#include
#include
+#include
+#include
#include "core/common/logging/isink.h"
#include "core/common/logging/logging.h"
@@ -27,20 +29,31 @@ class CompositeSink : public ISink {
/// Adds a sink. Takes ownership of the sink (so pass unique_ptr by value).
///
/// The sink.
+ /// The min severity to send a message to that sink
/// This instance to allow chaining.
- CompositeSink& AddSink(std::unique_ptr sink) {
- sinks_.push_back(std::move(sink));
+ CompositeSink& AddSink(std::unique_ptr sink, logging::Severity severity) {
+ sinks_with_severity_.emplace_back(std::move(sink), severity);
return *this;
}
+ ///
+ /// Gets a const reference to the collection of sinks and min severity for that sink
+ ///
+ /// A const reference to the vector pair of unique_ptr to ISink and severity.
+ const std::vector, logging::Severity>>& GetSinks() const {
+ return sinks_with_severity_;
+ }
+
private:
void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) override {
- for (auto& sink : sinks_) {
- sink->Send(timestamp, logger_id, message);
+ for (auto& sink_pair : sinks_with_severity_) {
+ if (message.Severity() >= sink_pair.second) {
+ sink_pair.first->Send(timestamp, logger_id, message);
+ }
}
}
- std::vector> sinks_;
+ std::vector, logging::Severity>> sinks_with_severity_;
};
} // namespace logging
} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h
index d97953fd9d5ea..61147e4367876 100644
--- a/onnxruntime/core/framework/execution_providers.h
+++ b/onnxruntime/core/framework/execution_providers.h
@@ -13,6 +13,7 @@
#include "core/graph/graph_viewer.h"
#include "core/common/logging/logging.h"
#ifdef _WIN32
+#include
#include "core/platform/tracing.h"
#endif
@@ -47,6 +48,8 @@ class ExecutionProviders {
TraceLoggingWrite(
telemetry_provider_handle,
"ProviderOptions",
+ TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingString(provider_id.c_str(), "ProviderId"),
TraceLoggingString(config_pair.first.c_str(), "Key"),
TraceLoggingString(config_pair.second.c_str(), "Value"));
diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc
index ba68bc1d7d834..ea7f1397c961b 100644
--- a/onnxruntime/core/framework/sequential_executor.cc
+++ b/onnxruntime/core/framework/sequential_executor.cc
@@ -181,7 +181,7 @@ class SessionScope {
}
auto& logger = session_state_.Logger();
- LOGS(logger, VERBOSE) << "Begin execution";
+ VLOGS(logger, 0) << "Begin execution";
const SequentialExecutionPlan& seq_exec_plan = *session_state_.GetExecutionPlan();
const auto& exec_plan_vec = seq_exec_plan.execution_plan;
VLOGS(logger, 1) << "Size of execution plan vector: " << exec_plan_vec.size();
@@ -515,7 +515,7 @@ onnxruntime::Status ExecuteKernel(StreamExecutionContext& ctx,
return Status(status.Category(), status.Code(), msg_string);
}
ctx.RecycleNodeInputs(idx);
- LOGS(logger, VERBOSE) << "stream " << stream_idx << " launch kernel with idx " << idx;
+ VLOGS(logger, 0) << "stream " << stream_idx << " launch kernel with idx " << idx;
return Status::OK();
}
@@ -531,7 +531,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<
const bool only_execute_path_to_fetches,
bool single_thread_mode) {
auto* execution_plan = session_state.GetExecutionPlan();
- LOGS(logger, VERBOSE) << "Number of streams: " << execution_plan->execution_plan.size();
+ VLOGS(logger, 0) << "Number of streams: " << execution_plan->execution_plan.size();
int32_t valid_streams = 0;
for (auto& stream : execution_plan->execution_plan) {
if (stream && stream->steps_.size() > 0)
diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc
index 4ff5ee5db865d..875e7f395bfa8 100644
--- a/onnxruntime/core/framework/stream_execution_context.cc
+++ b/onnxruntime/core/framework/stream_execution_context.cc
@@ -168,7 +168,7 @@ void StreamExecutionContext::RecycleNodeInputs(onnxruntime::NodeIndex node_index
for (auto idx : execution_plan->node_release_list[node_index]) {
if (--release_plan_[idx] == 0) {
ORT_ENFORCE(frame_.ReleaseMLValue(static_cast(execution_plan->release_actions[idx].value_index)).IsOK());
- LOGS(*logger_, VERBOSE) << "ort value " << execution_plan->release_actions[idx].value_index << " released";
+ VLOGS(*logger_, 0) << "ort value " << execution_plan->release_actions[idx].value_index << " released";
}
}
}
diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
index 221c06d7c8dcf..b1ab641a23256 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
@@ -54,9 +54,26 @@ bool IsQDQPairSupported(
Initializer dq_zp(*dq_zp_tensor_proto, model_path);
Initializer dq_scale(*dq_scale_tensor_proto, model_path);
- return q_zp.data_type() == dq_zp.data_type() &&
- SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan()) &&
- *q_scale.data() == *dq_scale.data();
+ if (q_zp.data_type() != dq_zp.data_type() ||
+ q_scale.data_type() != q_scale.data_type() ||
+ !SpanEq(q_zp.DataAsByteSpan(), dq_zp.DataAsByteSpan())) {
+ return false;
+ }
+
+ switch (q_scale.data_type()) {
+ case ONNX_NAMESPACE::TensorProto::FLOAT:
+ return *q_scale.data() == *dq_scale.data();
+
+ case ONNX_NAMESPACE::TensorProto::FLOAT16:
+ return *q_scale.data() == *dq_scale.data();
+
+ case ONNX_NAMESPACE::TensorProto::BFLOAT16:
+ return *q_scale.data() == *dq_scale.data();
+
+ default:
+ assert(false);
+ return false;
+ }
}
bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) {
diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc
index a99261d1d1caa..dc3b011cc7968 100644
--- a/onnxruntime/core/platform/telemetry.cc
+++ b/onnxruntime/core/platform/telemetry.cc
@@ -12,6 +12,21 @@ void LogRuntimeError(uint32_t sessionId, const common::Status& status, const cha
env.GetTelemetryProvider().LogRuntimeError(sessionId, status, file, function, line);
}
+bool Telemetry::IsEnabled() const {
+ return false;
+}
+
+// Get the current logging level
+// The Level defined as uchar is coming from the ETW Enable callback in TraceLoggingRegisterEx.
+unsigned char Telemetry::Level() const {
+ return 0;
+}
+
+// Get the current keyword
+uint64_t Telemetry::Keyword() const {
+ return 0;
+}
+
void Telemetry::EnableTelemetryEvents() const {
}
diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h
index da808e73d97c3..7b61de9d54073 100644
--- a/onnxruntime/core/platform/telemetry.h
+++ b/onnxruntime/core/platform/telemetry.h
@@ -38,6 +38,14 @@ class Telemetry {
virtual void DisableTelemetryEvents() const;
virtual void SetLanguageProjection(uint32_t projection) const;
+ virtual bool IsEnabled() const;
+
+ // Get the current logging level
+ virtual unsigned char Level() const;
+
+ // Get the current keyword
+ virtual uint64_t Keyword() const;
+
virtual void LogProcessInfo() const;
virtual void LogSessionCreationStart() const;
diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.cc b/onnxruntime/core/platform/windows/logging/etw_sink.cc
index 396695e6c570c..5fb7f7a65161d 100644
--- a/onnxruntime/core/platform/windows/logging/etw_sink.cc
+++ b/onnxruntime/core/platform/windows/logging/etw_sink.cc
@@ -58,42 +58,107 @@ TRACELOGGING_DEFINE_PROVIDER(etw_provider_handle, "ONNXRuntimeTraceLoggingProvid
#pragma warning(pop)
#endif
-// Class to unregister ETW provider at shutdown.
-// We expect one static instance to be created for the lifetime of the program.
-class EtwRegistrationManager {
- public:
- static EtwRegistrationManager& Register() {
- const HRESULT etw_status = ::TraceLoggingRegister(etw_provider_handle);
-
- if (FAILED(etw_status)) {
- ORT_THROW("ETW registration failed. Logging will be broken: " + std::to_string(etw_status));
- }
+EtwRegistrationManager& EtwRegistrationManager::Instance() {
+ static EtwRegistrationManager instance;
+ instance.LazyInitialize();
+ return instance;
+}
- // return an instance that is just used to unregister as the program exits
- static EtwRegistrationManager instance(etw_status);
- return instance;
- }
+bool EtwRegistrationManager::IsEnabled() const {
+ std::lock_guard lock(provider_change_mutex_);
+ return is_enabled_;
+}
+
+UCHAR EtwRegistrationManager::Level() const {
+ std::lock_guard lock(provider_change_mutex_);
+ return level_;
+}
- const HRESULT Status() const noexcept {
- return etw_status_;
+Severity EtwRegistrationManager::MapLevelToSeverity() {
+ switch (level_) {
+ case TRACE_LEVEL_NONE:
+ return Severity::kFATAL; // There is no none severity option
+ case TRACE_LEVEL_VERBOSE:
+ return Severity::kVERBOSE;
+ case TRACE_LEVEL_INFORMATION:
+ return Severity::kINFO;
+ case TRACE_LEVEL_WARNING:
+ return Severity::kWARNING;
+ case TRACE_LEVEL_ERROR:
+ return Severity::kERROR;
+ case TRACE_LEVEL_CRITICAL:
+ return Severity::kFATAL;
+ default:
+ return Severity::kVERBOSE;
}
+}
+
+ULONGLONG EtwRegistrationManager::Keyword() const {
+ std::lock_guard lock(provider_change_mutex_);
+ return keyword_;
+}
- ~EtwRegistrationManager() {
- ::TraceLoggingUnregister(etw_provider_handle);
+HRESULT EtwRegistrationManager::Status() const {
+ return etw_status_;
+}
+
+void EtwRegistrationManager::RegisterInternalCallback(const EtwInternalCallback& callback) {
+ std::lock_guard lock(callbacks_mutex_);
+ callbacks_.push_back(callback);
+}
+
+void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback(
+ _In_ LPCGUID SourceId,
+ _In_ ULONG IsEnabled,
+ _In_ UCHAR Level,
+ _In_ ULONGLONG MatchAnyKeyword,
+ _In_ ULONGLONG MatchAllKeyword,
+ _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData,
+ _In_opt_ PVOID CallbackContext) {
+ auto& manager = EtwRegistrationManager::Instance();
+ {
+ std::lock_guard lock(manager.provider_change_mutex_);
+ manager.is_enabled_ = (IsEnabled != 0);
+ manager.level_ = Level;
+ manager.keyword_ = MatchAnyKeyword;
}
+ manager.InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
+}
+
+EtwRegistrationManager::~EtwRegistrationManager() {
+ ::TraceLoggingUnregister(etw_provider_handle);
+}
- private:
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(EtwRegistrationManager);
+EtwRegistrationManager::EtwRegistrationManager() {
+}
- EtwRegistrationManager(const HRESULT status) noexcept : etw_status_{status} {}
- const HRESULT etw_status_;
-};
+void EtwRegistrationManager::LazyInitialize() {
+ if (!initialized_) {
+ std::lock_guard lock(init_mutex_);
+ if (!initialized_) { // Double-check locking pattern
+ initialized_ = true;
+ etw_status_ = ::TraceLoggingRegisterEx(etw_provider_handle, ORT_TL_EtwEnableCallback, nullptr);
+ if (FAILED(etw_status_)) {
+ ORT_THROW("ETW registration failed. Logging will be broken: " + std::to_string(etw_status_));
+ }
+ }
+ }
+}
+
+void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
+ ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData,
+ PVOID CallbackContext) {
+ std::lock_guard lock(callbacks_mutex_);
+ for (const auto& callback : callbacks_) {
+ callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
+ }
+}
void EtwSink::SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) {
UNREFERENCED_PARAMETER(timestamp);
// register on first usage
- static EtwRegistrationManager& etw_manager = EtwRegistrationManager::Register();
+ static EtwRegistrationManager& etw_manager = EtwRegistrationManager::Instance();
// do something (not that meaningful) with etw_manager so it doesn't get optimized out
// as we want an instance around to do the unregister
@@ -101,9 +166,8 @@ void EtwSink::SendImpl(const Timestamp& timestamp, const std::string& logger_id,
return;
}
- // Do we want to output Verbose level messages via ETW at any point it time?
// TODO: Validate if this filtering makes sense.
- if (message.Severity() <= Severity::kVERBOSE || message.DataType() == DataType::USER) {
+ if (message.DataType() == DataType::USER) {
return;
}
@@ -114,11 +178,13 @@ void EtwSink::SendImpl(const Timestamp& timestamp, const std::string& logger_id,
// TraceLoggingWrite requires (painfully) a compile time constant for the TraceLoggingLevel,
// forcing us to use an ugly macro for the call.
#define ETW_EVENT_NAME "ONNXRuntimeLogEvent"
-#define TRACE_LOG_WRITE(level) \
- TraceLoggingWrite(etw_provider_handle, ETW_EVENT_NAME, TraceLoggingLevel(level), \
- TraceLoggingString(logger_id.c_str(), "logger"), \
- TraceLoggingString(message.Category(), "category"), \
- TraceLoggingString(message.Location().ToString().c_str(), "location"), \
+#define TRACE_LOG_WRITE(level) \
+ TraceLoggingWrite(etw_provider_handle, ETW_EVENT_NAME, \
+ TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)), \
+ TraceLoggingLevel(level), \
+ TraceLoggingString(logger_id.c_str(), "logger"), \
+ TraceLoggingString(message.Category(), "category"), \
+ TraceLoggingString(message.Location().ToString().c_str(), "location"), \
TraceLoggingString(message.Message().c_str(), "message"))
const auto severity{message.Severity()};
diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.h b/onnxruntime/core/platform/windows/logging/etw_sink.h
index 1e4f49a619302..143c3fcfdfc52 100644
--- a/onnxruntime/core/platform/windows/logging/etw_sink.h
+++ b/onnxruntime/core/platform/windows/logging/etw_sink.h
@@ -3,7 +3,9 @@
#pragma once
+#include
#include
+#include
// check for Windows 10 SDK or later
// https://stackoverflow.com/questions/2665755/how-can-i-determine-the-version-of-the-windows-sdk-installed-on-my-computer
@@ -18,9 +20,11 @@
#include
#include
#include
+#include
#include "core/common/logging/capture.h"
#include "core/common/logging/isink.h"
+#include "core/platform/ort_mutex.h"
namespace onnxruntime {
namespace logging {
@@ -41,6 +45,62 @@ class EtwSink : public ISink {
// EtwTracingManager to ensure we cleanly unregister it
static std::atomic_flag have_instance_;
};
+
+class EtwRegistrationManager {
+ public:
+ using EtwInternalCallback = std::function;
+
+ // Singleton instance access
+ static EtwRegistrationManager& Instance();
+
+ // Check if ETW logging is enabled
+ bool IsEnabled() const;
+
+ // Get the current logging level
+ UCHAR Level() const;
+
+ Severity MapLevelToSeverity();
+
+ // Get the current keyword
+ uint64_t Keyword() const;
+
+ // Get the ETW registration status
+ HRESULT Status() const;
+
+ void RegisterInternalCallback(const EtwInternalCallback& callback);
+
+ private:
+ EtwRegistrationManager();
+ ~EtwRegistrationManager();
+ void LazyInitialize();
+
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(EtwRegistrationManager);
+
+ void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
+ ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext);
+
+ static void NTAPI ORT_TL_EtwEnableCallback(
+ _In_ LPCGUID SourceId,
+ _In_ ULONG IsEnabled,
+ _In_ UCHAR Level,
+ _In_ ULONGLONG MatchAnyKeyword,
+ _In_ ULONGLONG MatchAllKeyword,
+ _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData,
+ _In_opt_ PVOID CallbackContext);
+
+ std::vector callbacks_;
+ OrtMutex callbacks_mutex_;
+ mutable OrtMutex provider_change_mutex_;
+ OrtMutex init_mutex_;
+ bool initialized_ = false;
+ bool is_enabled_;
+ UCHAR level_;
+ ULONGLONG keyword_;
+ HRESULT etw_status_;
+};
+
} // namespace logging
} // namespace onnxruntime
diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc
index ec49c2edc2125..a9849873fd060 100644
--- a/onnxruntime/core/platform/windows/telemetry.cc
+++ b/onnxruntime/core/platform/windows/telemetry.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/platform/windows/telemetry.h"
+#include "core/common/logging/logging.h"
#include "onnxruntime_config.h"
// ETW includes
@@ -16,6 +17,7 @@
#include
#include
+#include
// Seems this workaround can be dropped when we drop support for VS2017 toolchains
// https://developercommunity.visualstudio.com/content/problem/85934/traceloggingproviderh-is-incompatible-with-utf-8.html
@@ -55,15 +57,18 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim
#endif
OrtMutex WindowsTelemetry::mutex_;
+OrtMutex WindowsTelemetry::provider_change_mutex_;
uint32_t WindowsTelemetry::global_register_count_ = 0;
bool WindowsTelemetry::enabled_ = true;
uint32_t WindowsTelemetry::projection_ = 0;
+UCHAR WindowsTelemetry::level_ = 0;
+UINT64 WindowsTelemetry::keyword_ = 0;
WindowsTelemetry::WindowsTelemetry() {
std::lock_guard lock(mutex_);
if (global_register_count_ == 0) {
// TraceLoggingRegister is fancy in that you can only register once GLOBALLY for the whole process
- HRESULT hr = TraceLoggingRegister(telemetry_provider_handle);
+ HRESULT hr = TraceLoggingRegisterEx(telemetry_provider_handle, ORT_TL_EtwEnableCallback, nullptr);
if (SUCCEEDED(hr)) {
global_register_count_ += 1;
}
@@ -80,6 +85,44 @@ WindowsTelemetry::~WindowsTelemetry() {
}
}
+bool WindowsTelemetry::IsEnabled() const {
+ std::lock_guard lock(provider_change_mutex_);
+ return enabled_;
+}
+
+UCHAR WindowsTelemetry::Level() const {
+ std::lock_guard lock(provider_change_mutex_);
+ return level_;
+}
+
+UINT64 WindowsTelemetry::Keyword() const {
+ std::lock_guard lock(provider_change_mutex_);
+ return keyword_;
+}
+
+// HRESULT WindowsTelemetry::Status() {
+// return etw_status_;
+// }
+
+void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback(
+ _In_ LPCGUID SourceId,
+ _In_ ULONG IsEnabled,
+ _In_ UCHAR Level,
+ _In_ ULONGLONG MatchAnyKeyword,
+ _In_ ULONGLONG MatchAllKeyword,
+ _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData,
+ _In_opt_ PVOID CallbackContext) {
+ (void)SourceId;
+ (void)MatchAllKeyword;
+ (void)FilterData;
+ (void)CallbackContext;
+
+ std::lock_guard lock(provider_change_mutex_);
+ enabled_ = (IsEnabled != 0);
+ level_ = Level;
+ keyword_ = MatchAnyKeyword;
+}
+
void WindowsTelemetry::EnableTelemetryEvents() const {
enabled_ = true;
}
@@ -110,6 +153,7 @@ void WindowsTelemetry::LogProcessInfo() const {
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingString(ORT_VERSION, "runtimeVersion"),
@@ -126,7 +170,8 @@ void WindowsTelemetry::LogSessionCreationStart() const {
"SessionCreationStart",
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
- TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES));
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO));
}
void WindowsTelemetry::LogEvaluationStop() const {
@@ -199,6 +244,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
@@ -227,6 +274,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingHResult(hr, "hResult"),
@@ -243,6 +291,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h
index 08e48214c85b3..c3798943d491d 100644
--- a/onnxruntime/core/platform/windows/telemetry.h
+++ b/onnxruntime/core/platform/windows/telemetry.h
@@ -3,6 +3,8 @@
#pragma once
#include "core/platform/telemetry.h"
+#include
+#include
#include "core/platform/ort_mutex.h"
#include "core/platform/windows/TraceLoggingConfig.h"
#include
@@ -22,6 +24,17 @@ class WindowsTelemetry : public Telemetry {
void DisableTelemetryEvents() const override;
void SetLanguageProjection(uint32_t projection) const override;
+ bool IsEnabled() const override;
+
+ // Get the current logging level
+ unsigned char Level() const override;
+
+ // Get the current keyword
+ UINT64 Keyword() const override;
+
+ // Get the ETW registration status
+ // static HRESULT Status();
+
void LogProcessInfo() const override;
void LogSessionCreationStart() const override;
@@ -50,6 +63,19 @@ class WindowsTelemetry : public Telemetry {
static uint32_t global_register_count_;
static bool enabled_;
static uint32_t projection_;
+
+ static OrtMutex provider_change_mutex_;
+ static UCHAR level_;
+ static ULONGLONG keyword_;
+
+ static void NTAPI ORT_TL_EtwEnableCallback(
+ _In_ LPCGUID SourceId,
+ _In_ ULONG IsEnabled,
+ _In_ UCHAR Level,
+ _In_ ULONGLONG MatchAnyKeyword,
+ _In_ ULONGLONG MatchAllKeyword,
+ _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData,
+ _In_opt_ PVOID CallbackContext);
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index f60c7ddac5c05..9cd0b3d0620af 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -989,6 +989,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN);
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch);
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
@@ -2447,6 +2449,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
#endif
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
};
for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/core/providers/cpu/text/regex_full_match.cc b/onnxruntime/core/providers/cpu/text/regex_full_match.cc
new file mode 100644
index 0000000000000..cc4a5a9ae4e61
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/text/regex_full_match.cc
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "regex_full_match.h"
+#include "core/common/common.h"
+
+namespace onnxruntime {
+ONNX_CPU_OPERATOR_KERNEL(
+ RegexFullMatch,
+ 20,
+ KernelDefBuilder()
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ RegexFullMatch);
+
+RegexFullMatch::RegexFullMatch(const OpKernelInfo& info) : OpKernel(info), re_{info.GetAttr("pattern")} {
+ ORT_ENFORCE(re_.ok(), "Invalid regex pattern: ", re_.pattern());
+}
+
+Status RegexFullMatch::Compute(OpKernelContext* context) const {
+ const auto* input_tensor = context->Input(0);
+ const auto input_data = input_tensor->template DataAsSpan();
+ auto* output_tensor = context->Output(0, input_tensor->Shape());
+ auto output_data = output_tensor->template MutableDataAsSpan();
+ auto output_iter = output_data.begin();
+ auto input_iter = input_data.begin();
+ while (input_iter != input_data.end()) {
+ *output_iter = RE2::FullMatch(*input_iter, re_);
+ input_iter++;
+ output_iter++;
+ }
+ return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/text/regex_full_match.h b/onnxruntime/core/providers/cpu/text/regex_full_match.h
new file mode 100644
index 0000000000000..0d3f1f4b4b824
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/text/regex_full_match.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/framework/op_kernel.h"
+#include "re2/re2.h"
+
+namespace onnxruntime {
+
+class RegexFullMatch final : public OpKernel {
+ public:
+ explicit RegexFullMatch(const OpKernelInfo& info);
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ RE2 re_;
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/text/string_concat.cc b/onnxruntime/core/providers/cpu/text/string_concat.cc
new file mode 100644
index 0000000000000..bc626f8e055aa
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/text/string_concat.cc
@@ -0,0 +1,60 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "string_concat.h"
+#include "core/providers/cpu/math/element_wise_ops.h"
+#include "core/common/common.h"
+
+namespace onnxruntime {
+ONNX_CPU_OPERATOR_KERNEL(StringConcat, 20,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
+ StringConcat);
+
+Status StringConcat::Compute(OpKernelContext* context) const {
+ ProcessBroadcastSpanFuncs broadcast_funcs{[](BroadcastHelper& broadcast_helper) {
+ auto x = broadcast_helper.ScalarInput0();
+ auto y = broadcast_helper.SpanInput1();
+ auto y_iter = y.begin();
+ auto output_iter = broadcast_helper.OutputSpan().begin();
+ const auto x_size = x.length();
+ while (y_iter != y.end()) {
+ output_iter->reserve(x_size + y_iter->length());
+ output_iter->append(x);
+ output_iter->append(*y_iter);
+ y_iter++;
+ output_iter++;
+ }
+ },
+ [](BroadcastHelper& broadcast_helper) {
+ auto x = broadcast_helper.SpanInput0();
+ auto x_iter = x.begin();
+ auto y = broadcast_helper.ScalarInput1();
+ auto output_iter = broadcast_helper.OutputSpan().begin();
+ const auto y_size = y.length();
+ while (x_iter != x.end()) {
+ output_iter->reserve(y_size + x_iter->length());
+ output_iter->append(*x_iter);
+ output_iter->append(y);
+ x_iter++;
+ output_iter++;
+ }
+ },
+ [](BroadcastHelper& broadcast_helper) {
+ auto x_iter = broadcast_helper.SpanInput0().begin();
+ auto y_iter = broadcast_helper.SpanInput1().begin();
+ auto output = broadcast_helper.OutputSpan();
+ auto output_iter = output.begin();
+ while (output_iter != output.end()) {
+ output_iter->reserve(x_iter->length() + y_iter->length());
+ output_iter->append(*x_iter);
+ output_iter->append(*y_iter);
+ x_iter++;
+ y_iter++;
+ output_iter++;
+ }
+ }};
+ UntypedBroadcastTwo(*context, broadcast_funcs);
+ return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/text/string_concat.h b/onnxruntime/core/providers/cpu/text/string_concat.h
new file mode 100644
index 0000000000000..63c1ea8a41146
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/text/string_concat.h
@@ -0,0 +1,17 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/framework/op_kernel.h"
+
+namespace onnxruntime {
+
+class StringConcat final : public OpKernel {
+ public:
+ StringConcat(const OpKernelInfo& info) : OpKernel(info) {}
+
+ Status Compute(OpKernelContext* context) const override;
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/nn/string_normalizer.cc b/onnxruntime/core/providers/cpu/text/string_normalizer.cc
similarity index 100%
rename from onnxruntime/core/providers/cpu/nn/string_normalizer.cc
rename to onnxruntime/core/providers/cpu/text/string_normalizer.cc
diff --git a/onnxruntime/core/providers/cpu/nn/string_normalizer.h b/onnxruntime/core/providers/cpu/text/string_normalizer.h
similarity index 100%
rename from onnxruntime/core/providers/cpu/nn/string_normalizer.h
rename to onnxruntime/core/providers/cpu/text/string_normalizer.h
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index d8a0792209b0f..f7b23f12e8193 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -2465,7 +2465,8 @@ void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry&
stream_,
use_ep_level_unified_stream_,
GetPerThreadContext().CudnnHandle(),
- GetPerThreadContext().CublasHandle());
+ GetPerThreadContext().CublasHandle(),
+ info_);
}
OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
index 9aad461b1d1c1..7c866395ecf6e 100644
--- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
+++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
@@ -62,11 +62,13 @@ CudaStream::CudaStream(cudaStream_t stream,
bool release_cpu_buffer_on_cuda_stream,
bool own_flag,
cudnnHandle_t external_cudnn_handle,
- cublasHandle_t external_cublas_handle) : Stream(stream, device),
- own_stream_(own_flag),
- cpu_allocator_(cpu_allocator),
- release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
- deferred_cpu_allocator_(*this) {
+ cublasHandle_t external_cublas_handle,
+ const CUDAExecutionProviderInfo& ep_info) : Stream(stream, device),
+ own_stream_(own_flag),
+ cpu_allocator_(cpu_allocator),
+ release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
+ deferred_cpu_allocator_(*this),
+ ep_info_(ep_info) {
if (own_flag) {
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
@@ -185,6 +187,27 @@ void* CudaStream::GetResource(int version, int id) const {
case CudaResource::deferred_cpu_allocator_t:
return const_cast(&deferred_cpu_allocator_);
break;
+ case CudaResource::device_id_t:
+ return reinterpret_cast(ep_info_.device_id);
+ break;
+ case CudaResource::arena_extend_strategy_t:
+ return reinterpret_cast(ep_info_.arena_extend_strategy);
+ break;
+ case CudaResource::cudnn_conv_algo_search_t:
+ return reinterpret_cast(ep_info_.cudnn_conv_algo_search);
+ break;
+ case CudaResource::cudnn_conv_use_max_workspace_t:
+ return reinterpret_cast(ep_info_.cudnn_conv_use_max_workspace);
+ break;
+ case CudaResource::cudnn_conv1d_pad_to_nc1d_t:
+ return reinterpret_cast(ep_info_.cudnn_conv1d_pad_to_nc1d);
+ break;
+ case CudaResource::enable_skip_layer_norm_strict_mode_t:
+ return reinterpret_cast(ep_info_.enable_skip_layer_norm_strict_mode);
+ break;
+ case CudaResource::prefer_nhwc_t:
+ return reinterpret_cast(ep_info_.prefer_nhwc);
+ break;
default:
break;
}
@@ -207,26 +230,28 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
cudaStream_t external_stream,
bool use_existing_stream,
cudnnHandle_t external_cudnn_handle,
- cublasHandle_t external_cublas_handle) {
+ cublasHandle_t external_cublas_handle,
+ const CUDAExecutionProviderInfo& ep_info) {
// wait cuda notification on cuda ep
stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitCudaNotificationOnDevice);
// wait cuda notification on cpu ep
stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitCudaNotificationOnHost);
if (!use_existing_stream)
- stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_cuda_stream](const OrtDevice& device) {
+ stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_cuda_stream, ep_info](const OrtDevice& device) {
CUDA_CALL_THROW(cudaSetDevice(device.Id()));
cudaStream_t stream = nullptr;
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
// CUDA_CALL_THROW(cudaStreamCreate(&stream));
- return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, true, nullptr, nullptr);
+ return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, true, nullptr, nullptr, ep_info);
});
else
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator,
release_cpu_buffer_on_cuda_stream,
external_stream,
external_cudnn_handle,
- external_cublas_handle](const OrtDevice& device) {
- return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, false, external_cudnn_handle, external_cublas_handle);
+ external_cublas_handle,
+ ep_info](const OrtDevice& device) {
+ return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, false, external_cudnn_handle, external_cublas_handle, ep_info);
});
}
diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h
index 917702fae08f1..b02c167e9e9ec 100644
--- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h
+++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h
@@ -6,6 +6,7 @@
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/framework/stream_handles.h"
+#include "core/providers/cuda/cuda_execution_provider_info.h"
namespace onnxruntime {
@@ -23,7 +24,8 @@ struct CudaStream : Stream {
bool release_cpu_buffer_on_cuda_stream,
bool own_flag,
cudnnHandle_t external_cudnn_handle,
- cublasHandle_t external_cublass_handle);
+ cublasHandle_t external_cublass_handle,
+ const CUDAExecutionProviderInfo& ep_info);
~CudaStream();
@@ -50,6 +52,7 @@ struct CudaStream : Stream {
AllocatorPtr cpu_allocator_;
bool release_cpu_buffer_on_cuda_stream_{true};
DeferredCpuAllocator deferred_cpu_allocator_;
+ const CUDAExecutionProviderInfo ep_info_;
};
void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
@@ -59,6 +62,7 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
cudaStream_t external_stream,
bool use_existing_stream,
cudnnHandle_t external_cudnn_handle,
- cublasHandle_t external_cublass_handle);
+ cublasHandle_t external_cublass_handle,
+ const CUDAExecutionProviderInfo& ep_info);
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
index 835d43037eaee..ab8ddbfe91bf0 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
@@ -558,7 +558,11 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
{
ML_CHECK_VALID_ARGUMENT(axis < outputShapeDimCount);
uint32_t broadcastAxisLength = outputShape[axis];
- ML_CHECK_VALID_ARGUMENT(inputTensorShape[0] == broadcastAxisLength);
+ ML_CHECK_VALID_ARGUMENT(
+ (inputTensorShape[0] == broadcastAxisLength) ||
+ // Treat as broadcast dimension to match CPU behavior.
+ (inputTensorShape[0] == 1)
+ );
inputTensorShape.insert(inputTensorShape.begin(), axis, 1);
inputTensorShape.insert(inputTensorShape.end(), outputShapeDimCount - 1 - axis, 1);
}
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
index 38d74909db86b..ca6a2238e520d 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -18,6 +18,11 @@
#include "core/common/logging/capture.h"
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
+#ifdef _WIN32
+#include
+#include "core/platform/tracing.h"
+#endif
+
// Flag to determine if Backend should do node validation for each opNode added
#define DO_GRAPH_NODE_VALIDATIONS 1
@@ -843,28 +848,46 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() {
LOGS(*logger_, VERBOSE) << "The QNN backend does not support extended event data.";
}
- // Write to CSV in append mode
- const char* profilingCsvFilename = "qnn-profiling-data.csv";
- std::ifstream infile(profilingCsvFilename);
- bool exists = infile.good();
- infile.close();
-
- std::ofstream outfile(profilingCsvFilename, std::ios_base::app);
- ORT_RETURN_IF(!outfile.is_open(), "Failed to open qnn-profiling-data.csv");
- // If file didn't exist before, write the header
- if (!exists) {
- outfile << "Msg Timestamp,Message,Time,Unit of Measurement,Timing Source,Event Level,Event Identifier\n";
+ bool tracelogging_provider_ep_enabled = false;
+ const Env& env = Env::Default();
+ auto& provider = env.GetTelemetryProvider();
+ if (provider.IsEnabled()) {
+ auto keyword = provider.Keyword();
+ if ((keyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) {
+ tracelogging_provider_ep_enabled = true;
+ }
+ }
+ std::ofstream outfile;
+ if (!tracelogging_provider_ep_enabled) {
+ // Write to CSV in append mode
+ const char* profilingCsvFilename = "qnn-profiling-data.csv";
+ std::ifstream infile(profilingCsvFilename);
+ bool exists = infile.good();
+ infile.close();
+
+ outfile.open(profilingCsvFilename, std::ios_base::app);
+ ORT_RETURN_IF(!outfile.is_open(), "Failed to open qnn-profiling-data.csv");
+ // If file didn't exist before, write the header
+ if (!exists) {
+ outfile << "Msg Timestamp,Message,Time,Unit of Measurement,Timing Source,Event Level,Event Identifier\n";
+ }
}
for (size_t event_idx = 0; event_idx < num_events; event_idx++) {
ORT_RETURN_IF_ERROR(
- ExtractProfilingEvent(*(profile_events + event_idx), "ROOT", outfile, backendSupportsExtendedEventData));
+ ExtractProfilingEvent(*(profile_events + event_idx), "ROOT", outfile, backendSupportsExtendedEventData,
+ tracelogging_provider_ep_enabled));
ORT_RETURN_IF_ERROR(
- ExtractProfilingSubEvents(*(profile_events + event_idx), outfile, backendSupportsExtendedEventData));
+ ExtractProfilingSubEvents(*(profile_events + event_idx), outfile, backendSupportsExtendedEventData,
+ tracelogging_provider_ep_enabled));
}
- outfile.close();
- LOGS(*logger_, INFO) << "Wrote QNN profiling events (" << num_events << ") to qnn-profiling-data.csv";
+ if (!tracelogging_provider_ep_enabled) {
+ outfile.close();
+ LOGS(*logger_, VERBOSE) << "Wrote QNN profiling events (" << num_events << ") to qnn-profiling-data.csv";
+ } else {
+ LOGS(*logger_, VERBOSE) << "Wrote QNN profiling events (" << num_events << ") to ETW";
+ }
}
return Status::OK();
@@ -873,7 +896,8 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() {
Status QnnBackendManager::ExtractProfilingSubEvents(
QnnProfile_EventId_t profile_event_id,
std::ofstream& outfile,
- bool useExtendedEventData) {
+ bool useExtendedEventData,
+ bool tracelogging_provider_ep_enabled) {
const QnnProfile_EventId_t* profile_sub_events{nullptr};
uint32_t num_sub_events{0};
auto result = qnn_interface_.profileGetSubEvents(profile_event_id, &profile_sub_events, &num_sub_events);
@@ -884,12 +908,14 @@ Status QnnBackendManager::ExtractProfilingSubEvents(
for (size_t sub_event_idx = 0; sub_event_idx < num_sub_events; sub_event_idx++) {
ORT_RETURN_IF_ERROR(
- ExtractProfilingEvent(*(profile_sub_events + sub_event_idx), "SUB-EVENT", outfile, useExtendedEventData));
+ ExtractProfilingEvent(*(profile_sub_events + sub_event_idx), "SUB-EVENT", outfile, useExtendedEventData,
+ tracelogging_provider_ep_enabled));
ORT_RETURN_IF_ERROR(
- ExtractProfilingSubEvents(*(profile_sub_events + sub_event_idx), outfile, useExtendedEventData));
+ ExtractProfilingSubEvents(*(profile_sub_events + sub_event_idx), outfile, useExtendedEventData,
+ tracelogging_provider_ep_enabled));
}
- LOGS(*logger_, INFO) << "Wrote QNN profiling sub events (" << num_sub_events << ") to qnn-profiling-data.csv";
+ LOGS(*logger_, VERBOSE) << "Wrote QNN profiling sub events (" << num_sub_events << ")";
}
return Status::OK();
@@ -899,18 +925,20 @@ Status QnnBackendManager::ExtractProfilingEvent(
QnnProfile_EventId_t profile_event_id,
const std::string& eventLevel,
std::ofstream& outfile,
- bool useExtendedEventData) {
+ bool useExtendedEventData,
+ bool tracelogging_provider_ep_enabled) {
if (useExtendedEventData) {
- return ExtractProfilingEventExtended(profile_event_id, eventLevel, outfile);
+ return ExtractProfilingEventExtended(profile_event_id, eventLevel, outfile, tracelogging_provider_ep_enabled);
} else {
- return ExtractProfilingEventBasic(profile_event_id, eventLevel, outfile);
+ return ExtractProfilingEventBasic(profile_event_id, eventLevel, outfile, tracelogging_provider_ep_enabled);
}
}
Status QnnBackendManager::ExtractProfilingEventBasic(
QnnProfile_EventId_t profile_event_id,
const std::string& eventLevel,
- std::ofstream& outfile) {
+ std::ofstream& outfile,
+ bool tracelogging_provider_ep_enabled) {
QnnProfile_EventData_t event_data;
auto result = qnn_interface_.profileGetEventData(profile_event_id, &event_data);
QnnProfile_Error_t errorCode = static_cast(result & 0xFFFF);
@@ -919,15 +947,32 @@ Status QnnBackendManager::ExtractProfilingEventBasic(
std::string message = GetEventTypeString(event_data.type);
std::string unit = GetUnitString(event_data.unit);
- outfile << "UNKNOWN"
- << ","
- << message << ","
- << event_data.value << ","
- << unit << ","
- << "BACKEND"
- << ","
- << eventLevel << ","
- << (event_data.identifier ? event_data.identifier : "NULL") << "\n";
+#ifndef _WIN32
+ tracelogging_provider_ep_enabled = false;
+#endif
+
+ if (!tracelogging_provider_ep_enabled) {
+ outfile << "UNKNOWN"
+ << ","
+ << message << ","
+ << event_data.value << ","
+ << unit << ","
+ << "BACKEND"
+ << ","
+ << eventLevel << ","
+ << (event_data.identifier ? event_data.identifier : "NULL") << "\n";
+ } else {
+#ifdef _WIN32
+ LogQnnProfileEventAsTraceLogging(
+ (uint64_t)0,
+ message,
+ std::to_string(event_data.value),
+ unit,
+ "BACKEND",
+ eventLevel,
+ (event_data.identifier ? event_data.identifier : "NULL"));
+#endif
+ }
return Status::OK();
}
@@ -935,7 +980,8 @@ Status QnnBackendManager::ExtractProfilingEventBasic(
Status QnnBackendManager::ExtractProfilingEventExtended(
QnnProfile_EventId_t profile_event_id,
const std::string& eventLevel,
- std::ofstream& outfile) {
+ std::ofstream& outfile,
+ bool tracelogging_provider_ep_enabled) {
QnnProfile_ExtendedEventData_t event_data_extended;
auto resultGetExtendedEventData = qnn_interface_.profileGetExtendedEventData(profile_event_id, &event_data_extended);
QnnProfile_Error_t errorCode = static_cast(resultGetExtendedEventData & 0xFFFF);
@@ -944,20 +990,61 @@ Status QnnBackendManager::ExtractProfilingEventExtended(
std::string message = GetEventTypeString(event_data_extended.v1.type);
std::string unit = GetUnitString(event_data_extended.v1.unit);
- if (event_data_extended.version == QNN_PROFILE_DATA_VERSION_1) {
- outfile << event_data_extended.v1.timestamp << ","
- << message << ","
- << ExtractQnnScalarValue(event_data_extended.v1.value) << ","
- << unit << ","
- << "BACKEND"
- << ","
- << eventLevel << ","
- << (event_data_extended.v1.identifier ? event_data_extended.v1.identifier : "NULL") << "\n";
+#ifndef _WIN32
+ tracelogging_provider_ep_enabled = false;
+#endif
+
+ if (!tracelogging_provider_ep_enabled) {
+ if (event_data_extended.version == QNN_PROFILE_DATA_VERSION_1) {
+ outfile << event_data_extended.v1.timestamp << ","
+ << message << ","
+ << ExtractQnnScalarValue(event_data_extended.v1.value) << ","
+ << unit << ","
+ << "BACKEND"
+ << ","
+ << eventLevel << ","
+ << (event_data_extended.v1.identifier ? event_data_extended.v1.identifier : "NULL") << "\n";
+ }
+ } else {
+#ifdef _WIN32
+ LogQnnProfileEventAsTraceLogging(
+ event_data_extended.v1.timestamp,
+ message,
+ ExtractQnnScalarValue(event_data_extended.v1.value),
+ unit,
+ "BACKEND",
+ eventLevel,
+ (event_data_extended.v1.identifier ? event_data_extended.v1.identifier : "NULL"));
+#endif
}
return Status::OK();
}
+#ifdef _WIN32
+void QnnBackendManager::LogQnnProfileEventAsTraceLogging(
+ uint64_t timestamp,
+ const std::string& message,
+ const std::string& qnnScalarValue,
+ const std::string& unit,
+ const std::string& timingSource,
+ const std::string& eventLevel,
+ const char* eventIdentifier) {
+ TraceLoggingWrite(
+ telemetry_provider_handle,
+ "QNNProfilingEvent",
+ TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)),
+ TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
+ TraceLoggingValue(timestamp, "Timestamp"),
+ TraceLoggingString(message.c_str(), "Message"),
+ TraceLoggingString(qnnScalarValue.c_str(), "Value"),
+ TraceLoggingString(unit.c_str(), "Unit of Measurement"),
+ TraceLoggingString(timingSource.c_str(), "Timing Source"),
+ TraceLoggingString(eventLevel.c_str(), "Event Level"),
+ TraceLoggingString(eventIdentifier, "Event Identifier"));
+}
+#endif
+
const std::string& QnnBackendManager::GetUnitString(QnnProfile_EventUnit_t unitType) {
const auto& unitStringMap = GetUnitStringMap();
auto it = unitStringMap.find(unitType);
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
index bc05820da2f73..58f207efb9e95 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -6,10 +6,15 @@
#include
#include
#include
+#include
#else
#include
#endif
+#include
+#include
+#include
+#include
#include "HTP/QnnHtpDevice.h"
#include "QnnLog.h"
#include "System/QnnSystemInterface.h"
@@ -117,8 +122,11 @@ class QnnBackendManager {
void Split(std::vector& split_string, const std::string& tokenized_string, const char separator);
Status ExtractBackendProfilingInfo();
- Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile, bool backendSupportsExtendedEventData);
- Status ExtractProfilingEvent(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel, std::ofstream& outfile, bool backendSupportsExtendedEventData);
+ Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile,
+ bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled);
+ Status ExtractProfilingEvent(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel,
+ std::ofstream& outfile, bool backendSupportsExtendedEventData,
+ bool tracelogging_provider_ep_enabled);
void SetQnnBackendType(uint32_t backend_id);
QnnBackendType GetQnnBackendType() { return qnn_backend_type_; }
@@ -175,13 +183,25 @@ class QnnBackendManager {
return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id));
}
- Status ExtractProfilingEventBasic(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel, std::ofstream& outfile);
- Status ExtractProfilingEventExtended(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel, std::ofstream& outfile);
+ Status ExtractProfilingEventBasic(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel,
+ std::ofstream& outfile, bool tracelogging_provider_ep_enabled);
+ Status ExtractProfilingEventExtended(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel,
+ std::ofstream& outfile, bool tracelogging_provider_ep_enabled);
static const std::string& GetUnitString(QnnProfile_EventUnit_t unitType);
static const std::unordered_map& GetUnitStringMap();
static const std::string GetEventTypeString(QnnProfile_EventType_t eventType);
static const std::string ExtractQnnScalarValue(const Qnn_Scalar_t& scalar);
const char* QnnProfileErrorToString(QnnProfile_Error_t error);
+#ifdef _WIN32
+ void LogQnnProfileEventAsTraceLogging(
+ uint64_t timestamp,
+ const std::string& message,
+ const std::string& qnnScalarValue,
+ const std::string& unit,
+ const std::string& timingSource,
+ const std::string& eventLevel,
+ const char* eventIdentifier);
+#endif
private:
const std::string backend_path_;
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index c72012fd4a19b..e5856e85e19e8 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -4,12 +4,13 @@
#include "qnn_execution_provider.h"
#include
-#include "core/providers/common.h"
#include "core/framework/compute_capability.h"
#include "core/graph/graph_viewer.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/framework/kernel_registry.h"
+#include "core/platform/env.h"
+#include "core/providers/common.h"
#include "core/providers/partitioning_utils.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/partitioning_utils.h"
@@ -28,7 +29,7 @@ static void ParseProfilingLevel(std::string profiling_level_string,
profiling_level_string.end(),
profiling_level_string.begin(),
[](unsigned char c) { return static_cast(std::tolower(c)); });
- LOGS_DEFAULT(VERBOSE) << "profiling_level: " << profiling_level_string;
+ LOGS_DEFAULT(INFO) << "profiling_level: " << profiling_level_string;
if (profiling_level_string == "off") {
profiling_level = qnn::ProfilingLevel::OFF;
} else if (profiling_level_string == "basic") {
@@ -146,9 +147,30 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
static const std::string PROFILING_LEVEL = "profiling_level";
qnn::ProfilingLevel profiling_level = qnn::ProfilingLevel::OFF;
- auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL);
- if (profiling_level_pos != provider_options_map.end()) {
- ParseProfilingLevel(profiling_level_pos->second, profiling_level);
+ const Env& env = Env::Default();
+ auto& provider = env.GetTelemetryProvider();
+ if (provider.IsEnabled()) {
+ auto level = provider.Level();
+ auto keyword = provider.Keyword();
+ if ((keyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) {
+ if (level != 0) {
+ if (level == 5) {
+ LOGS_DEFAULT(INFO) << "Overriding profiling to basic based on ETW level: " << static_cast(level);
+ ParseProfilingLevel("basic", profiling_level);
+ } else if (level < 5) {
+ LOGS_DEFAULT(INFO) << "QNN Profiler ETW level not supported below level 5. Level: "
+ << static_cast(level);
+ } else {
+ LOGS_DEFAULT(INFO) << "Overriding profiling to detailed based on ETW level: " << static_cast(level);
+ ParseProfilingLevel("detailed", profiling_level);
+ }
+ }
+ }
+ } else {
+ auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL);
+ if (profiling_level_pos != provider_options_map.end()) {
+ ParseProfilingLevel(profiling_level_pos->second, profiling_level);
+ }
}
static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency";
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 684303a8b6448..4ece068b50fd1 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1315,6 +1315,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
InitProviderOrtApi();
CUDA_CALL_THROW(cudaSetDevice(device_id_));
+ cudaDeviceProp prop;
+ CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_));
+ compute_capability_ = GetComputeCapacity(prop);
if (info.has_user_compute_stream) {
external_stream_ = true;
stream_ = static_cast(info.user_compute_stream);
@@ -2778,19 +2781,15 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine_cache_path, trt_state->trt_node_name_with_precision);
- const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine";
+ const std::string engine_cache_path = cache_path + "_sm" + compute_capability_ + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
- const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile";
+ const std::string profile_cache_path = cache_path + "_sm" + compute_capability_ + ".profile";
std::string timing_cache_path = "";
if (timing_cache_enable_) {
- timing_cache_path = GetTimingCachePath(global_cache_path_, prop);
+ timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
}
// Load serialized engine
@@ -3473,7 +3468,8 @@ void TensorrtExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis
stream_,
external_stream_ /* use_existing_stream */,
external_cudnn_handle_,
- external_cublas_handle_);
+ external_cublas_handle_,
+ {});
}
OrtDevice TensorrtExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index 7eefdd3cba9e2..bacdf0f3c996c 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -258,6 +258,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
std::unique_ptr runtime_ = nullptr;
OrtMutex tensorrt_mu_;
int device_id_;
+ std::string compute_capability_;
bool context_memory_sharing_enable_ = false;
bool layer_norm_fp32_fallback_ = false;
size_t max_ctx_mem_size_ = 0;
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h
index 6bbeab7e94ce4..c69299d0ecdeb 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h
@@ -456,10 +456,10 @@ std::string GetComputeCapacity(const cudaDeviceProp& prop) {
* Get Timing by compute capability
*
*/
-std::string GetTimingCachePath(const std::string& root, cudaDeviceProp prop) {
+std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) {
// append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache
const std::string timing_cache_name = "TensorrtExecutionProvider_cache_sm" +
- GetComputeCapacity(prop) + ".timing";
+ compute_cap + ".timing";
return GetCachePath(root, timing_cache_name);
}
diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc
index 984fdd6bce325..d653a27c577b0 100644
--- a/onnxruntime/core/session/custom_ops.cc
+++ b/onnxruntime/core/session/custom_ops.cc
@@ -374,9 +374,6 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelCont
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Failed to fetch a stream hosting the requested resource");
}
*resource = stream->GetResource(resource_version, resource_id);
- if (!(*resource)) {
- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Requested resource does not exist");
- }
return nullptr;
API_IMPL_END
};
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 665cdbc36a963..93877c8dd66bd 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -69,6 +69,11 @@
#include "core/util/protobuf_parsing_utils.h"
#include "core/util/thread_utils.h"
+#ifdef _WIN32
+#include "core/platform/windows/logging/etw_sink.h"
+#include "core/common/logging/sinks/composite_sink.h"
+#endif
+
// custom ops are not available in a minimal build unless ORT_MINIMAL_BUILD_CUSTOM_OPS is set
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
#include "core/framework/customregistry.h"
@@ -307,6 +312,7 @@ static Status FinalizeSessionOptions(const SessionOptions& user_provided_session
logging::Severity GetSeverity(const SessionOptions& session_options) {
logging::Severity severity = logging::Severity::kWARNING;
+
if (session_options.session_log_severity_level == -1) {
severity = logging::LoggingManager::DefaultLogger().GetSeverity();
} else {
@@ -322,11 +328,17 @@ logging::Severity GetSeverity(const SessionOptions& session_options) {
void InferenceSession::SetLoggingManager(const SessionOptions& session_options,
const Environment& session_env) {
logging_manager_ = session_env.GetLoggingManager();
+ std::unique_ptr sink;
+
if (session_options.user_logging_function) {
- std::unique_ptr user_sink = std::make_unique(session_options.user_logging_function,
- session_options.user_logging_param);
- user_logging_manager_ = std::make_unique(std::move(user_sink),
- GetSeverity(session_options),
+ sink = std::make_unique(session_options.user_logging_function,
+ session_options.user_logging_param);
+ auto sessionSeverity = GetSeverity(session_options);
+ auto etwOverrideSeverity = logging::OverrideLevelWithEtw(sessionSeverity);
+ sink = EnhanceLoggerWithEtw(std::move(sink), sessionSeverity, etwOverrideSeverity);
+
+ user_logging_manager_ = std::make_unique(std::move(sink),
+ std::min(sessionSeverity, etwOverrideSeverity),
false,
logging::LoggingManager::InstanceType::Temporal,
&session_options.session_logid);
@@ -467,6 +479,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
#ifdef _WIN32
TraceLoggingWrite(telemetry_provider_handle,
"SessionOptions",
+ TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingUInt8(static_cast(session_options.execution_mode), "execution_mode"),
TraceLoggingUInt8(static_cast(session_options.execution_order), "execution_order"),
TraceLoggingBoolean(session_options.enable_profiling, "enable_profiling"),
@@ -487,6 +501,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
TraceLoggingWrite(
telemetry_provider_handle,
"SessionOptions_IntraOrtThreadPoolParams",
+ TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingInt32(session_options.intra_op_param.thread_pool_size, "thread_pool_size"),
TraceLoggingBoolean(session_options.intra_op_param.auto_set_affinity, "auto_set_affinity"),
TraceLoggingBoolean(session_options.intra_op_param.allow_spinning, "allow_spinning"),
@@ -499,6 +515,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
TraceLoggingWrite(
telemetry_provider_handle,
"SessionOptions_ConfigEntry",
+ TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingString(config_pair.first.c_str(), "Key"),
TraceLoggingString(config_pair.second.c_str(), "Value"));
}
diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc
index e3957baa990f8..331f1db26a029 100644
--- a/onnxruntime/core/session/ort_env.cc
+++ b/onnxruntime/core/session/ort_env.cc
@@ -39,23 +39,23 @@ OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_inf
if (!p_instance_) {
std::unique_ptr lmgr;
std::string name = lm_info.logid;
+
+ std::unique_ptr sink = nullptr;
if (lm_info.logging_function) {
- std::unique_ptr logger = std::make_unique(lm_info.logging_function,
- lm_info.logger_param);
- lmgr = std::make_unique(std::move(logger),
- static_cast(lm_info.default_warning_level),
- false,
- LoggingManager::InstanceType::Default,
- &name);
- } else {
- auto sink = MakePlatformDefaultLogSink();
+ sink = std::make_unique(lm_info.logging_function, lm_info.logger_param);
- lmgr = std::make_unique(std::move(sink),
- static_cast(lm_info.default_warning_level),
- false,
- LoggingManager::InstanceType::Default,
- &name);
+ } else {
+ sink = MakePlatformDefaultLogSink();
}
+ auto etwOverrideSeverity = logging::OverrideLevelWithEtw(static_cast(lm_info.default_warning_level));
+ sink = EnhanceLoggerWithEtw(std::move(sink), static_cast(lm_info.default_warning_level),
+ etwOverrideSeverity);
+ lmgr = std::make_unique(std::move(sink),
+ std::min(static_cast(lm_info.default_warning_level), etwOverrideSeverity),
+ false,
+ LoggingManager::InstanceType::Default,
+ &name);
+
std::unique_ptr env;
if (!tp_options) {
status = onnxruntime::Environment::Create(std::move(lmgr), env);
diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc
index 2e9af9f1f9bb2..b012406bd026a 100644
--- a/onnxruntime/core/session/provider_registration.cc
+++ b/onnxruntime/core/session/provider_registration.cc
@@ -4,6 +4,7 @@
#include
#include "core/common/common.h"
+#include "core/common/logging/logging.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/provider_options.h"
#include "core/providers/provider_factory_creators.h"
@@ -13,6 +14,7 @@
#include "core/providers/openvino/openvino_provider_factory_creator.h"
#ifdef _WIN32
+#include
#include "core/platform/tracing.h"
#endif
@@ -75,6 +77,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
TraceLoggingWrite(
telemetry_provider_handle,
"ProviderOptionsAppendExecutionProvider",
+ TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
+ TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingString(provider_name, "ProviderName"),
TraceLoggingString(config_pair.first.c_str(), "Key"),
TraceLoggingString(config_pair.second.c_str(), "Value"));
diff --git a/onnxruntime/test/common/logging/sinks_test.cc b/onnxruntime/test/common/logging/sinks_test.cc
index 28fb407bc2f0e..7ca8d5fc1152c 100644
--- a/onnxruntime/test/common/logging/sinks_test.cc
+++ b/onnxruntime/test/common/logging/sinks_test.cc
@@ -156,7 +156,7 @@ TEST(LoggingTests, TestCompositeSink) {
EXPECT_CALL(*sink_ptr2, SendImpl(testing::_, testing::_, testing::_)).Times(1);
CompositeSink* sink = new CompositeSink();
- sink->AddSink(std::unique_ptr{sink_ptr1}).AddSink(std::unique_ptr{sink_ptr2});
+ sink->AddSink(std::unique_ptr{sink_ptr1}, min_log_level).AddSink(std::unique_ptr{sink_ptr2}, min_log_level);
LoggingManager manager{std::unique_ptr(sink), min_log_level, false, InstanceType::Temporal};
auto logger = manager.CreateLogger(logid);
diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
index d9c870a7dc52a..6afb61bd1f0a1 100644
--- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
@@ -738,10 +738,23 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) {
tester.AddOutput("present", past_dims, present);
- // Run
- std::vector> execution_providers;
- execution_providers.push_back(DefaultCudaExecutionProvider());
- tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ // Run - Regular kernel execution path
+ {
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ }
+
+ // Test alternate kernel path of loading more KV data "in flight"
+ {
+ ScopedEnvironmentVariables scoped_env_vars{
+ EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};
+
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ }
}
}
}
@@ -852,10 +865,22 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {
tester.AddOutput("present", past_dims, present);
- // Run
- std::vector> execution_providers;
- execution_providers.push_back(DefaultCudaExecutionProvider());
- tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ // Run - Regular kernel execution path
+ {
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ }
+
+ // Test alternate kernel path of loading more KV data "in flight"
+ {
+ ScopedEnvironmentVariables scoped_env_vars{
+ EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};
+
+ std::vector> execution_providers;
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ }
}
}
}
diff --git a/onnxruntime/test/contrib_ops/quantize_ops_test.cc b/onnxruntime/test/contrib_ops/quantize_ops_test.cc
index 64a97ed4f945b..db685967ae5ff 100644
--- a/onnxruntime/test/contrib_ops/quantize_ops_test.cc
+++ b/onnxruntime/test/contrib_ops/quantize_ops_test.cc
@@ -76,6 +76,16 @@ TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int32_cpu) {
test.Run();
}
+TEST(DequantizeLinearOpTest, DequantizeLinearOpTest_BroadcastTensorOfOne) {
+ OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain);
+
+ test.AddInput("x", {4}, {-30, -3, 100, 127});
+ test.AddInput("x_scale", {1}, {2.0f}, true);
+ test.AddInput("zero_point", {1}, {0}, true);
+ test.AddOutput("y", {4}, {-60.f, -6.f, 200.f, 254.f});
+ test.Run();
+}
+
#ifdef USE_CUDA
TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_half_uint8) {
OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain);
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index ef6e2d531bc1a..5adcb3c150b8d 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4602,38 +4602,43 @@ TEST_F(GraphTransformationTests, GeluApproximation_SessionOptionConfig) {
}
// Test DoubleQDQPairsRemover to remove unnecessary DQ->Q nodes in the middle
-TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ) {
- constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "qdq_optimization/dup_qdq.onnx";
- std::shared_ptr p_model;
- ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
- Graph& graph = p_model->MainGraph();
+TEST_F(GraphTransformationTests, DoublQDQRemover_RemoveDupQDQ_Float16) {
+ auto RunTest = [this](const ORTCHAR_T* model_uri) {
+ std::shared_ptr p_model;
+ ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+ Graph& graph = p_model->MainGraph();
- onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
- ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1));
- ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+ onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+ ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1));
+ ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
- std::map op_to_count = CountOpsInGraph(graph);
- EXPECT_EQ(op_to_count["QuantizeLinear"], 3);
- EXPECT_EQ(op_to_count["DequantizeLinear"], 4);
+ std::map op_to_count = CountOpsInGraph(graph);
+ EXPECT_EQ(op_to_count["QuantizeLinear"], 3);
+ EXPECT_EQ(op_to_count["DequantizeLinear"], 4);
- std::string dq_scale_name_before_reshape_node;
- std::string zp_name_before_reshape_node;
- std::string dq_scale_name_after_reshape_node;
- std::string zp_name_after_reshape_node;
- for (auto& node : graph.Nodes()) {
- if (node.Name() == "dq_2") {
- dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
- zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
- }
- if (node.Name() == "q_3") {
- dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
- zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
+ std::string dq_scale_name_before_reshape_node;
+ std::string zp_name_before_reshape_node;
+ std::string dq_scale_name_after_reshape_node;
+ std::string zp_name_after_reshape_node;
+ for (auto& node : graph.Nodes()) {
+ if (node.Name() == "dq_2") {
+ dq_scale_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
+ zp_name_before_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
+ }
+ if (node.Name() == "q_3") {
+ dq_scale_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::SCALE_ID]->Name();
+ zp_name_after_reshape_node = node.InputDefs()[QDQ::InputIndex::ZERO_POINT_ID]->Name();
+ }
}
- }
- EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false);
- EXPECT_EQ(zp_name_before_reshape_node.empty(), false);
- EXPECT_EQ(dq_scale_name_before_reshape_node, dq_scale_name_after_reshape_node);
- EXPECT_EQ(zp_name_before_reshape_node, zp_name_after_reshape_node);
+ EXPECT_EQ(dq_scale_name_before_reshape_node.empty(), false);
+ EXPECT_EQ(zp_name_before_reshape_node.empty(), false);
+ EXPECT_EQ(dq_scale_name_before_reshape_node, dq_scale_name_after_reshape_node);
+ EXPECT_EQ(zp_name_before_reshape_node, zp_name_after_reshape_node);
+ };
+
+ RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq.onnx");
+ RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq_float16.onnx");
+ RunTest(MODEL_FOLDER "qdq_optimization/dup_qdq_bfloat16.onnx");
}
// Test Gelu -> FastGelu
diff --git a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md
index 2f8d06d66d576..59fe946b929f2 100644
--- a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md
+++ b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md
@@ -1,10 +1,19 @@
-## Validating ETW Sink unit test output
+## About the ETW Sink
-## Setup
-Install Windows Performance Toolkit from
-You get to select components when installing, so can select just the performance toolkit.
+The ETW Sink (ONNXRuntimeTraceLoggingProvider) allows ONNX semi-structured printf style logs to be output via ETW.
-Overview of the steps is at if you want more detail.
+ETW makes it easy and useful to only enable and listen for events with great performance, and when you need them instead of only at compile time.
+Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](docs/FAQ.md?plain=1#L7).
+
+However, when the provider is enabled a new ETW logger sink will also be added and the severity separately controlled via ETW dynamically.
+
+- Provider GUID: 929DD115-1ECB-4CB5-B060-EBD4983C421D
+- Keyword: Logs (0x2) keyword per [logging.h](include\onnxruntime\core\common\logging\logging.h)
+- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](onnxruntime\core\platform\windows\logging\etw_sink.cc) to [ONNX severity](include\onnxruntime\core\common\logging\severity.h) in an intuitive manner
+
+Notes:
+- The ETW provider must be enabled prior to session creation, as that as when internal logging setup is complete
+- Other structured ETW logs are output via the other Microsoft.ML.ONNXRuntime ETW provider. Both used together are recommended
## Capturing ETW trace output
@@ -25,9 +34,17 @@ Run the ETW sink unit tests
Stop the ETW tracing
`\onnxruntime\test\platform\windows\logging> wpr -stop TraceCaptureFile.etl EtwSinkTest`
-## View the output
+## View the trace output
+
+### Setup
+- Install Windows Performance Analyzer (Preview) from the Windows Store -
+- Or from the ADK
+ - You get to select components when installing, so can select just the performance toolkit.
+ - Overview of the steps is at if you want more detail.
+
+### Viewing
-Open TraceCaptureFile.etl file Windows Performance Analyzer.
+Open TraceCaptureFile.etl file in Windows Performance Analyzer.
Expand the "System Activity" dropdown in the left pane, and double-click "Generic Events".
That should open events in an Analysis window in the right pane. You should see an event
diff --git a/onnxruntime/test/platform/windows/logging/etw_sink_test.cc b/onnxruntime/test/platform/windows/logging/etw_sink_test.cc
index 7436ac5bd1729..05ef81d05f4ef 100644
--- a/onnxruntime/test/platform/windows/logging/etw_sink_test.cc
+++ b/onnxruntime/test/platform/windows/logging/etw_sink_test.cc
@@ -47,8 +47,8 @@ TEST(LoggingTests, TestEtwSink) {
///
TEST(LoggingTests, TestEtwSinkCtor) {
CompositeSink* sinks = new CompositeSink();
- sinks->AddSink(std::unique_ptr(new EtwSink()))
- .AddSink(std::unique_ptr(new EtwSink()));
+ sinks->AddSink(std::unique_ptr(new EtwSink()), Severity::kWARNING)
+ .AddSink(std::unique_ptr(new EtwSink()), Severity::kWARNING);
LoggingManager manager{std::unique_ptr{sinks},
Severity::kWARNING,
diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc
index f4b21823a487b..026bb07edf44c 100644
--- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc
@@ -47,6 +47,16 @@ TEST(DequantizeLinearOpTest, Int32) {
test.Run();
}
+TEST(DequantizeLinearOpTest_BroadcastTensor, Int32) {
+ OpTester test("DequantizeLinear", 13);
+ test.AddInput("x", {4}, {-30, -3, 100, 127});
+ test.AddAttribute("axis", 0);
+ test.AddInput("x_scale", {1}, {2.0f});
+ test.AddInput("x_zero_point", {1}, {0});
+ test.AddOutput("y", {4}, {-60.f, -6.f, 200.f, 254.f});
+ test.Run();
+}
+
// 2d inputs
TEST(DequantizeLinearOpTest, 2D) {
OpTester test("DequantizeLinear", 10);
diff --git a/onnxruntime/test/providers/cpu/text/regex_full_match_test.cc b/onnxruntime/test/providers/cpu/text/regex_full_match_test.cc
new file mode 100644
index 0000000000000..4aa5a0d44b678
--- /dev/null
+++ b/onnxruntime/test/providers/cpu/text/regex_full_match_test.cc
@@ -0,0 +1,119 @@
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+namespace onnxruntime {
+namespace test {
+
+static void RunTest(const std::initializer_list& dims, const std::initializer_list& input, const std::string& pattern, const std::initializer_list& output) {
+ OpTester test("RegexFullMatch", 20, kOnnxDomain);
+ test.AddAttribute("pattern", pattern);
+ test.AddInput("Input", dims, input);
+ test.AddOutput("Output", dims, output);
+ test.Run();
+}
+
+TEST(RegexFullMatch, WebsiteMatch) {
+ RunTest({3, 1}, {"www.google.com", "www.facebook.com", "www.bbc.co.uk"}, R"(www\.[\w.-]+\.\bcom\b)", {true, true, false});
+}
+
+TEST(RegexFullMatch, EmailMatch) {
+ RunTest({2, 2}, {"account@gmail.com", "account@hotmail.com", "not email", "account@yahoo.com"}, R"((\W|^)[\w.\-]{0,25}@(yahoo|gmail)\.com(\W|$))", {true, false, false, true});
+}
+
+TEST(RegexFullMatch, MultibyteMatch) {
+ RunTest({1, 2}, {"ä", "a"}, "ä", {true, false});
+ RunTest({
+ 1,
+ },
+ {"une cédille like in Besançon"}, R"(.*Besançon.*)", {
+ true,
+ });
+ RunTest({
+ 1,
+ },
+ {"une cédille like in Besançon"}, R"(.*Besancon.*)", {
+ false,
+ });
+ RunTest({
+ 1,
+ },
+ {"Mit freundlichen Grüßen"}, R"(.*Grüßen$)", {
+ true,
+ });
+ RunTest({
+ 1,
+ },
+ {"Mit freundlichen Grüßen"}, R"(.*Grußen$)", {
+ false,
+ });
+ RunTest({
+ 3,
+ },
+ {"HПонедельник", "Понедельник", "недельник"}, R"(^Понед.*)", {
+ false,
+ true,
+ false,
+ });
+ RunTest({
+ 3,
+ },
+ {"thank you", "どうもありがとうございます", "こんにちは世界"}, R"(^こんにちは世界.*)", {
+ false,
+ false,
+ true,
+ });
+ RunTest({
+ 3,
+ },
+ {"नमस्ते, आपसे मिलकर अच्छा लगा", "नमस्ते", "स्वागत एवं नमस्ते"}, R"(.+नमस्ते$)", {
+ false,
+ false,
+ true,
+ });
+ RunTest({
+ 3,
+ },
+ {"你好,你好吗?", "你好呀", "你好呀!"}, R"(^你好.*\?$)", {
+ true,
+ false,
+ false,
+ });
+}
+
+TEST(RegexFullMatch, InvalidPattern) {
+ OpTester test("RegexFullMatch", 20, kOnnxDomain);
+ test.AddAttribute("pattern", R"([a-z)");
+ test.AddInput("Input", {
+ 1,
+ },
+ {
+ "abcdef",
+ });
+ test.AddOutput("Output", {
+ 1,
+ },
+ {
+ false,
+ });
+ test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern: [a-z");
+}
+
+TEST(RegexFullMatch, NonUtf8Pattern) {
+ uint8_t invalid_bytes[] = {0xC0, 0xC1, 0x41, 0x42, 0xC3, 0x80, 0xC2, 0x80, 0xC2, 0xC3, 0xC4, 0x00};
+ OpTester test("RegexFullMatch", 20, kOnnxDomain);
+ test.AddAttribute("pattern", std::string((char*)invalid_bytes, sizeof(invalid_bytes)));
+ test.AddInput("Input", {
+ 1,
+ },
+ {
+ "abcd",
+ });
+ test.AddOutput("Output", {
+ 1,
+ },
+ {
+ false,
+ });
+ test.Run(BaseTester::ExpectResult::kExpectFailure, "Invalid regex pattern");
+}
+} // namespace test
+} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/text/string_concat_test.cc b/onnxruntime/test/providers/cpu/text/string_concat_test.cc
new file mode 100644
index 0000000000000..2bfa3dc5615e1
--- /dev/null
+++ b/onnxruntime/test/providers/cpu/text/string_concat_test.cc
@@ -0,0 +1,76 @@
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+static void RunTest(const std::vector& dims, const std::vector& input1,
+ const std::vector& input2, const std::vector& output) {
+ OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
+ test.AddInput("X", dims, input1);
+ test.AddInput("Y", dims, input2);
+ test.AddOutput("Z", dims, output);
+ test.Run();
+}
+
+TEST(StringConcat, BasicConcatenation) {
+ RunTest({1, 2}, {"Hello", "World"}, {"Hello", "World"}, {"HelloHello", "WorldWorld"});
+}
+
+TEST(StringConcat, TwoDimensionalConcatenation) {
+ RunTest({2, 2}, {"Hello", "World", "ONNX", "onnxruntime"}, {"Hello", "World", "ONNX", "onnxruntime"},
+ {"HelloHello", "WorldWorld", "ONNXONNX", "onnxruntimeonnxruntime"});
+}
+
+TEST(StringConcat, LeftToRightBroadcastingConcatenation) {
+ OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
+ test.AddInput("X", {2, 2}, {"Hello", "World", "ONNX", "onnxruntime"});
+ test.AddInput("Y", {1}, {"!"});
+ test.AddOutput("Z", {2, 2}, {"Hello!", "World!", "ONNX!", "onnxruntime!"});
+ test.Run();
+}
+
+TEST(StringConcat, RightToLeftBroadcastingConcatenation) {
+ OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
+ test.AddInput("X", {1}, {"!"});
+ test.AddInput("Y", {2, 2}, {"Hello", "World", "ONNX", "onnxruntime"});
+ test.AddOutput("Z", {2, 2}, {"!Hello", "!World", "!ONNX", "!onnxruntime"});
+ test.Run();
+}
+
+TEST(StringConcat, BidirectionalBroadcastingConcatenation) {
+ OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
+ test.AddInput("X", {2, 1, 3}, {"a", "b", "c", "d", "e", "f"});
+ test.AddInput("Y", {1, 4, 3}, {"a", "b", "c", "d", "e", "f", "g", "h", "i", "k", "l", "m"});
+ test.AddOutput("Z", {2, 4, 3},
+ {
+ "aa",
+ "bb",
+ "cc",
+ "ad",
+ "be",
+ "cf",
+ "ag",
+ "bh",
+ "ci",
+ "ak",
+ "bl",
+ "cm",
+ "da",
+ "eb",
+ "fc",
+ "dd",
+ "ee",
+ "ff",
+ "dg",
+ "eh",
+ "fi",
+ "dk",
+ "el",
+ "fm",
+ });
+ test.Run();
+}
+
+} // namespace test
+} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/nn/string_normalizer_test.cc b/onnxruntime/test/providers/cpu/text/string_normalizer_test.cc
similarity index 100%
rename from onnxruntime/test/providers/cpu/nn/string_normalizer_test.cc
rename to onnxruntime/test/providers/cpu/text/string_normalizer_test.cc
diff --git a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
index 3d561d378cb8c..43795921f17da 100644
--- a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
+++ b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
@@ -28,14 +28,14 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
const Ort::Custom::Tensor& X,
const Ort::Custom::Tensor& Y,
Ort::Custom::Tensor