From d2a00084668e8279b96f703ee0c3279ed2549928 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 8 Jul 2024 16:31:51 -0700 Subject: [PATCH] fixing arm ut --- .../selectors_actions/qdq_actions.cc | 16 ++++++++++------ .../selectors_actions/qdq_actions.h | 1 + 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 657218b44dc91..fe310d2ea453f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -288,6 +288,12 @@ DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_ }()}, intra_op_thread_pool_{intra_op_thread_pool} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); + + if (!intra_op_thread_pool) { + OrtThreadPoolParams to; + intra_op_thread_pool_optional_ = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + } } NodeAttributes @@ -311,8 +317,6 @@ DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const RuntimeState& runtime_stat Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { - ORT_ENFORCE(intra_op_thread_pool_, "Intra op thread pool cannot be null"); - const auto* dq_node = selected_nodes.Input(0); const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -373,7 +377,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, static_cast(K), static_cast(N), static_cast(block_size), - intra_op_thread_pool_); + intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get()); } else { MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), @@ -386,7 +390,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, static_cast(K), static_cast(N), static_cast(block_size), - intra_op_thread_pool_); + intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get()); } } else { if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { @@ -401,7 +405,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, static_cast(K), static_cast(N), static_cast(block_size), - intra_op_thread_pool_); + intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get()); } else { MlasQDQTransposeBlockwiseQuantized( @@ -415,7 +419,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, static_cast(K), static_cast(N), static_cast(block_size), - intra_op_thread_pool_); + intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get()); } } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index d80c3f9d183bf..52ae745186b53 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -103,6 +103,7 @@ struct DQMatMulReplaceWithMatMulNBits : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; concurrency::ThreadPool* intra_op_thread_pool_; + std::optional> intra_op_thread_pool_optional_; }; struct GemmReplaceWithQuant : public Action {