From b7d3d720b0ad717949c23f625a63fd6bf45e59d1 Mon Sep 17 00:00:00 2001 From: "Luo, Yu" Date: Mon, 15 Jan 2024 17:15:29 +0800 Subject: [PATCH] remove platform code as it's implemented by neural_speed --- .../cpu/quantization/neural_speed_defs.h | 6 ------ .../cpu/quantization/neural_speed_gemm.cc | 15 --------------- 2 files changed, 21 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h index 3ffb004357843..864abffd131fe 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h @@ -42,10 +42,4 @@ class ORTThreading : public parallel::IThreading { void* mTp; }; -class Platform { - public: - Platform(); - static Platform* get(); -}; - } // namespace bestla diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc index afac65d8cd3b6..73aaa4ae61a6e 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc @@ -29,18 +29,6 @@ void ORTThreading::parallel_for(const parallel::thread_func& func) const { [&](ptrdiff_t tid) { func(static_cast(tid)); }); } -Platform::Platform() { - GetCPUDevice(); - if (_cd->AMX_INT8() || _cd->AMX_BF16()) { - utils::request_perm_xtile_data(); - } -} - -Platform* Platform::get() { - static Platform instance; - return &instance; -} - template static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda, storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, @@ -442,9 +430,6 @@ size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const siz void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, void* ThreadPool) { - // 
Prepare system config for AMX instructions - auto p = Platform::get(); - (void)(p); // only nbits=4 can be packed, so not necessary to check the nbits in DataParams if (NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast<int8_t*>(WorkSpace), ThreadPool)) { // PackedWeight is created by bestla