Skip to content

Commit

Permalink
Fixing ARM unit tests (UT)
Browse files (browse the repository at this point in the history)
  • Loading branch information
fajin-corp committed Jul 8, 2024
1 parent 48751ed commit d2a0008
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_
}()},
intra_op_thread_pool_{intra_op_thread_pool} {
ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4");

if (!intra_op_thread_pool) {
OrtThreadPoolParams to;
intra_op_thread_pool_optional_ = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
concurrency::ThreadPoolType::INTRA_OP);
}
}

NodeAttributes
Expand All @@ -311,8 +317,6 @@ DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const RuntimeState& runtime_stat
Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
const NodesToOptimize& selected_nodes,
Node& replacement_node) const {
ORT_ENFORCE(intra_op_thread_pool_, "Intra op thread pool cannot be null");

const auto* dq_node = selected_nodes.Input(0);
const auto* weight_arg = dq_node->InputDefs()[0];
const auto* scale_arg = dq_node->InputDefs()[1];
Expand Down Expand Up @@ -373,7 +377,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);
intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get());
} else {
MlasQDQTransposeBlockwiseQuantized<float, 4, false>(
weight_src.DataAsByteSpan().data(),
Expand All @@ -386,7 +390,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);
intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get());
}
} else {
if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) {
Expand All @@ -401,7 +405,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);
intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get());

} else {
MlasQDQTransposeBlockwiseQuantized<MLFloat16, 4, false>(
Expand All @@ -415,7 +419,7 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);
intra_op_thread_pool_ ? intra_op_thread_pool_ : intra_op_thread_pool_optional_.value().get());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ struct DQMatMulReplaceWithMatMulNBits : public ReplaceWithNew {
const std::string op_type_;
const std::vector<NodeAndMoveInfo> value_moves_;
concurrency::ThreadPool* intra_op_thread_pool_;
std::optional<std::unique_ptr<concurrency::ThreadPool>> intra_op_thread_pool_optional_;
};

struct GemmReplaceWithQuant : public Action {
Expand Down

0 comments on commit d2a0008

Please sign in to comment.