From 8c835856f69ac9c14effe79603acd3dc7cde29fd Mon Sep 17 00:00:00 2001
From: Chen Fu
Date: Thu, 10 Feb 2022 14:30:22 -0800
Subject: [PATCH] Create branch according to cpu core uarch

---
Note: line breaks, indentation and blank context lines below were reconstructed
from a whitespace-collapsed copy of this patch, and text that was inside angle
brackets (e.g. the author e-mail) is missing. Check against the upstream commit
before applying.

 onnxruntime/core/common/cpuid_info.cc         | 17 +++++
 onnxruntime/core/common/cpuid_info.h          |  5 ++
 onnxruntime/core/mlas/lib/mlasi.h             | 69 +++++++++++++++++++
 onnxruntime/core/mlas/lib/qgemm.cpp           |  8 ++-
 onnxruntime/core/mlas/lib/qgemm.h             |  3 +-
 .../core/mlas/lib/qgemm_kernel_neon.cpp       |  1 +
 .../core/mlas/lib/qgemm_kernel_sdot.cpp       |  1 +
 7 files changed, 102 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc
index f9b084b682e00..f94720f05bad8 100644
--- a/onnxruntime/core/common/cpuid_info.cc
+++ b/onnxruntime/core/common/cpuid_info.cc
@@ -118,4 +118,21 @@ CPUIDInfo::CPUIDInfo() {
 
 }
 
+int32_t CPUIDInfo::GetCurrentUarch() const {
+#if (defined(CPUIDINFO_ARCH_X86) || defined(CPUIDINFO_ARCH_ARM)) && defined(CPUINFO_SUPPORTED)
+  if (!pytorch_cpuinfo_init_) {
+    return -1;
+  }
+  const auto uarchIdx = cpuinfo_get_current_uarch_index();
+  const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(uarchIdx);
+  if (uarch_info == nullptr) {
+    return -1;
+  }
+  return uarch_info->uarch;
+
+#else
+  return -1;
+#endif
+}
+
 } // namespace onnxruntime
diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h
index 66fc21ff55562..aa0cc485e9d84 100644
--- a/onnxruntime/core/common/cpuid_info.h
+++ b/onnxruntime/core/common/cpuid_info.h
@@ -34,6 +34,11 @@ class CPUIDInfo {
   // ARM
   bool HasArmNeonDot() const { return has_arm_neon_dot_; }
 
+  /**
+   * @return cpuinfo uarch id of the CPU core running the current thread, or -1 when it cannot be determined
+   */
+  int32_t GetCurrentUarch() const;
+
  private:
   CPUIDInfo();
   bool has_avx_{false};
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index ca9fdef577f7b..c8959e29818dd 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -1987,3 +1987,72 @@ MlasReadTimeStampCounter(void)
 #endif
 #endif
 }
+
+/**
+ * @brief IDs for cpu microarchitectures.
+ *
+ * Copied from the pytorch cpuinfo package. Can't use the definition
+ * from cpuinfo directly as it causes lots of compilation issues
+ * in many platforms that we support.
+ */
+enum MlasUArch {
+    cpuinfo_uarch_unknown = 0,
+
+    /** ARM Cortex-A32. */
+    cpuinfo_uarch_cortex_a32 = 0x00300332,
+    /** ARM Cortex-A35. */
+    cpuinfo_uarch_cortex_a35 = 0x00300335,
+    /** ARM Cortex-A53. */
+    cpuinfo_uarch_cortex_a53 = 0x00300353,
+    /** ARM Cortex-A55 revision 0 (restricted dual-issue capabilities compared to revision 1+). */
+    cpuinfo_uarch_cortex_a55r0 = 0x00300354,
+    /** ARM Cortex-A55. */
+    cpuinfo_uarch_cortex_a55 = 0x00300355,
+    /** ARM Cortex-A57. */
+    cpuinfo_uarch_cortex_a57 = 0x00300357,
+    /** ARM Cortex-A65. */
+    cpuinfo_uarch_cortex_a65 = 0x00300365,
+    /** ARM Cortex-A72. */
+    cpuinfo_uarch_cortex_a72 = 0x00300372,
+    /** ARM Cortex-A73. */
+    cpuinfo_uarch_cortex_a73 = 0x00300373,
+    /** ARM Cortex-A75. */
+    cpuinfo_uarch_cortex_a75 = 0x00300375,
+    /** ARM Cortex-A76. */
+    cpuinfo_uarch_cortex_a76 = 0x00300376,
+    /** ARM Cortex-A77. */
+    cpuinfo_uarch_cortex_a77 = 0x00300377,
+    /** ARM Cortex-A78. */
+    cpuinfo_uarch_cortex_a78 = 0x00300378,
+};
+
+enum MlasCoreType { mlas_core_unknown = 0, mlas_core_little = 2, mlas_core_big = 3 };
+
+/**
+ * @return mlas_core_little (2) when the current core is a little core with narrow memory load (e.g. ARMv8 a53),
+ *         mlas_core_big (3) otherwise, i.e. a big core with wider load (e.g. ARMv8 a72) or an undetected uarch
+ */
+MLAS_FORCEINLINE
+int32_t
+MlasGetCoreUArch()
+{
+    thread_local int32_t core_type = mlas_core_unknown;
+    if (core_type == mlas_core_unknown) {
+        // first call on this thread: detect once, then reuse the cached value (never refreshed)
+#if defined(MLAS_TARGET_ARM64) && defined(__linux__)
+        auto uarch = MLAS_CPUIDINFO::GetCPUIDInfo().GetCurrentUarch();
+        if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 ||
+            uarch == cpuinfo_uarch_cortex_a55) {
+            core_type = mlas_core_little;
+        } else {
+            core_type = mlas_core_big;
+        }
+#else
+        core_type = mlas_core_big;
+#endif  // MLAS_TARGET_ARM64 && __linux__
+
+    }
+    return core_type;
+}
+
+
diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp
index 33ac5cfc27466..772d28cb9875a 100644
--- a/onnxruntime/core/mlas/lib/qgemm.cpp
+++ b/onnxruntime/core/mlas/lib/qgemm.cpp
@@ -205,11 +205,13 @@ MlasSymmQgemmBatch(
     const size_t N = Shape.N;
     const size_t K = Shape.K;
     const MLAS_SYMM_QGEMM_DISPATCH* dispatch = GetMlasPlatform().SymmQgemmDispatch;
-    MLAS_SYMM_QGEMM_OPERATION* operation = dispatch->Operation;
 
     if (ThreadPool == nullptr) {
         // So our caller handles threaded job partition.
         // Call single threaded operation directly
+        auto uarch = MlasGetCoreUArch();
+        MLAS_SYMM_QGEMM_OPERATION* operation =
+            uarch == mlas_core_little ? dispatch->LitOperation : dispatch->BigOperation;
 
         for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
             auto Data = &DataParams[gemm_i];
@@ -258,6 +260,10 @@ MlasSymmQgemmBatch(
     ThreadsPerGemm = ThreadCountM * ThreadCountN;
 
     MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
+        auto uarch = MlasGetCoreUArch();
+        MLAS_SYMM_QGEMM_OPERATION* operation =
+            uarch == mlas_core_little ? dispatch->LitOperation : dispatch->BigOperation;
+
         const auto gemm_i = tid / ThreadsPerGemm;
         const auto blk_i = tid % ThreadsPerGemm;
         auto Data = &DataParams[gemm_i];
diff --git a/onnxruntime/core/mlas/lib/qgemm.h b/onnxruntime/core/mlas/lib/qgemm.h
index ab566c5b50645..2f6168c527050 100644
--- a/onnxruntime/core/mlas/lib/qgemm.h
+++ b/onnxruntime/core/mlas/lib/qgemm.h
@@ -802,7 +802,8 @@ struct MLAS_GEMM_QUANT_DISPATCH {
 };
 
 struct MLAS_SYMM_QGEMM_DISPATCH {
-    MLAS_SYMM_QGEMM_OPERATION* Operation;
+    MLAS_SYMM_QGEMM_OPERATION* LitOperation; /**< running on little cores with narrow memory load */
+    MLAS_SYMM_QGEMM_OPERATION* BigOperation; /**< running on big cores with wider memory load */
     MLAS_GEMM_QUANT_COPY_PACKB_ROUTINE* CopyPackBRoutine;
     size_t StrideM; /**< num of rows processed by kernel at a time */
     size_t PackedK;
diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp
index ce9460be9b501..0b747bc7cc84b 100644
--- a/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp
@@ -1217,6 +1217,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon = {
 };
 
 const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchNeon = {
+    MlasSymmQGemmPackedOperation,
     MlasSymmQGemmPackedOperation,
     MlasGemmQuantCopyPackB,
     4, // StrideM
diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp
index 5b2d9a6e6cce3..604986cf9f662 100644
--- a/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp
+++ b/onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp
@@ -1027,6 +1027,7 @@ size_t MlasSymmQGemmKernel(
 }
 
 const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchSdot = {
+    MlasSymmQGemmPackedOperation,
     MlasSymmQGemmPackedOperation,
     MlasGemmQuantCopyPackB,
     4, // StrideM