Skip to content

Commit

Permalink
Create branch according to cpu core uarch
Browse files Browse the repository at this point in the history
  • Loading branch information
Chen Fu committed Feb 14, 2022
1 parent dd33ce0 commit 8c83585
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 2 deletions.
17 changes: 17 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,21 @@ CPUIDInfo::CPUIDInfo() {

}

int32_t CPUIDInfo::GetCurrentUarch() const {
#if (defined(CPUIDINFO_ARCH_X86) || defined(CPUIDINFO_ARCH_ARM)) && defined(CPUINFO_SUPPORTED)
if (!pytorch_cpuinfo_init_) {
return -1;
}
const auto uarchIdx = cpuinfo_get_current_uarch_index();
const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(uarchIdx);
if (uarch_info == NULL) {
return -1;
}
return uarch_info->uarch;

#else
return -1;
#endif
}

} // namespace onnxruntime
5 changes: 5 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class CPUIDInfo {
// ARM
bool HasArmNeonDot() const { return has_arm_neon_dot_; }

/**
* @return CPU core micro-architecture running the current thread
*/
int32_t GetCurrentUarch() const;

private:
CPUIDInfo();
bool has_avx_{false};
Expand Down
69 changes: 69 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1987,3 +1987,72 @@ MlasReadTimeStampCounter(void)
#endif
#endif
}

/**
* @brief IDs for cpu microarchitectures.
*
* Copied from python cpuinfo package. Can't use the definition
* from cpuinfo directly as it causes lots of compilation issues
* in many platforms that we support.
*/
enum MlasUArch {
cpuinfo_uarch_unknown = 0,

/** ARM Cortex-A32. */
cpuinfo_uarch_cortex_a32 = 0x00300332,
/** ARM Cortex-A35. */
cpuinfo_uarch_cortex_a35 = 0x00300335,
/** ARM Cortex-A53. */
cpuinfo_uarch_cortex_a53 = 0x00300353,
/** ARM Cortex-A55 revision 0 (restricted dual-issue capabilities compared to revision 1+). */
cpuinfo_uarch_cortex_a55r0 = 0x00300354,
/** ARM Cortex-A55. */
cpuinfo_uarch_cortex_a55 = 0x00300355,
/** ARM Cortex-A57. */
cpuinfo_uarch_cortex_a57 = 0x00300357,
/** ARM Cortex-A65. */
cpuinfo_uarch_cortex_a65 = 0x00300365,
/** ARM Cortex-A72. */
cpuinfo_uarch_cortex_a72 = 0x00300372,
/** ARM Cortex-A73. */
cpuinfo_uarch_cortex_a73 = 0x00300373,
/** ARM Cortex-A75. */
cpuinfo_uarch_cortex_a75 = 0x00300375,
/** ARM Cortex-A76. */
cpuinfo_uarch_cortex_a76 = 0x00300376,
/** ARM Cortex-A77. */
cpuinfo_uarch_cortex_a77 = 0x00300377,
/** ARM Cortex-A78. */
cpuinfo_uarch_cortex_a78 = 0x00300378,
};

enum MlasCoreType { mlas_core_unknown = 0, mlas_core_little = 2, mlas_core_big = 3 };

/**
* @return 2 current core is little core with narrow memory load (e.g. ARMv8 a53)
* 3 current core is big core with wider load (e.g. ARMv8 a72)
*/
MLAS_FORCEINLINE
int32_t
MlasGetCoreUArch()
{
thread_local int32_t core_type = mlas_core_unknown;
if (core_type == mlas_core_unknown) {
// initialization needed
#if defined(MLAS_TARGET_ARM64) && defined(__linux__)
auto uarch = MLAS_CPUIDINFO::GetCPUIDInfo().GetCurrentUarch();
if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 ||
uarch == cpuinfo_uarch_cortex_a55) {
core_type = mlas_core_little;
} else {
core_type = mlas_core_big;
}
#else
core_type = mlas_core_big;
#endif // MLAS_TARGET_ARM64

}
return core_type;
}


8 changes: 7 additions & 1 deletion onnxruntime/core/mlas/lib/qgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,13 @@ MlasSymmQgemmBatch(
const size_t N = Shape.N;
const size_t K = Shape.K;
const MLAS_SYMM_QGEMM_DISPATCH* dispatch = GetMlasPlatform().SymmQgemmDispatch;
MLAS_SYMM_QGEMM_OPERATION* operation = dispatch->Operation;

if (ThreadPool == nullptr) {
// So our caller handles threaded job partition.
// Call single threaded operation directly
auto uarch = MlasGetCoreUArch();
MLAS_SYMM_QGEMM_OPERATION* operation =
uarch == mlas_core_little ? dispatch->LitOperation : dispatch->BigOperation;

for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
auto Data = &DataParams[gemm_i];
Expand Down Expand Up @@ -258,6 +260,10 @@ MlasSymmQgemmBatch(
ThreadsPerGemm = ThreadCountM * ThreadCountN;

MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
auto uarch = MlasGetCoreUArch();
MLAS_SYMM_QGEMM_OPERATION* operation =
uarch == mlas_core_little ? dispatch->LitOperation : dispatch->BigOperation;

const auto gemm_i = tid / ThreadsPerGemm;
const auto blk_i = tid % ThreadsPerGemm;
auto Data = &DataParams[gemm_i];
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/mlas/lib/qgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,8 @@ struct MLAS_GEMM_QUANT_DISPATCH {
};

struct MLAS_SYMM_QGEMM_DISPATCH {
MLAS_SYMM_QGEMM_OPERATION* Operation;
MLAS_SYMM_QGEMM_OPERATION* LitOperation; /// running on little cores with narrow memory load
MLAS_SYMM_QGEMM_OPERATION* BigOperation; /// running on big cores with wider memory load
MLAS_GEMM_QUANT_COPY_PACKB_ROUTINE* CopyPackBRoutine;
size_t StrideM; /**< num of rows processed by kernel at a time */
size_t PackedK;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon = {
};

const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchNeon = {
MlasSymmQGemmPackedOperation<MLAS_GEMM_X8S8_KERNEL_NEON>,
MlasSymmQGemmPackedOperation<MLAS_GEMM_X8S8_KERNEL_NEON>,
MlasGemmQuantCopyPackB<MLAS_GEMM_X8S8_KERNEL_NEON>,
4, // StrideM
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,7 @@ size_t MlasSymmQGemmKernel<MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT>(
}

const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchSdot = {
MlasSymmQGemmPackedOperation<MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT>,
MlasSymmQGemmPackedOperation<MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT>,
MlasGemmQuantCopyPackB<MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT>,
4, // StrideM
Expand Down

0 comments on commit 8c83585

Please sign in to comment.