Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create branch according to cpu core uarch #10521

Merged
merged 1 commit into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,21 @@ CPUIDInfo::CPUIDInfo() {

}

/**
 * Looks up the micro-architecture of the core the calling thread is
 * currently running on.
 *
 * @return the `uarch` field of the cpuinfo entry for the current core, or -1
 *         when cpuinfo support is compiled out, cpuinfo was not initialized,
 *         or no entry exists for the current uarch index.
 */
int32_t CPUIDInfo::GetCurrentUarch() const {
#if (defined(CPUIDINFO_ARCH_X86) || defined(CPUIDINFO_ARCH_ARM)) && defined(CPUINFO_SUPPORTED)
  if (pytorch_cpuinfo_init_) {
    // Resolve the index of the current core's uarch, then fetch its descriptor.
    const struct cpuinfo_uarch_info* const current =
        cpuinfo_get_uarch(cpuinfo_get_current_uarch_index());
    if (current != nullptr) {
      return current->uarch;
    }
  }
#endif
  // Shared fallback for every "cannot tell" path.
  return -1;
}

} // namespace onnxruntime
5 changes: 5 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class CPUIDInfo {
// ARM
bool HasArmNeonDot() const { return has_arm_neon_dot_; }

/**
 * @return CPU core micro-architecture running the current thread,
 *         or -1 when the micro-architecture cannot be determined
 */
int32_t GetCurrentUarch() const;

private:
CPUIDInfo();
bool has_avx_{false};
Expand Down
69 changes: 69 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1987,3 +1987,72 @@ MlasReadTimeStampCounter(void)
#endif
#endif
}

/**
 * @brief IDs for cpu microarchitectures.
 *
 * Copied from the pytorch cpuinfo library (enum cpuinfo_uarch). Can't use
 * the definition from cpuinfo directly as it causes lots of compilation
 * issues in many platforms that we support.
 *
 * The numeric values are compared against the integer uarch ID reported
 * at runtime, so they must stay identical to the upstream definitions.
 */
enum MlasUArch {
    cpuinfo_uarch_unknown = 0,

    /** ARM Cortex-A32. */
    cpuinfo_uarch_cortex_a32 = 0x00300332,
    /** ARM Cortex-A35. */
    cpuinfo_uarch_cortex_a35 = 0x00300335,
    /** ARM Cortex-A53. */
    cpuinfo_uarch_cortex_a53 = 0x00300353,
    /** ARM Cortex-A55 revision 0 (restricted dual-issue capabilities compared to revision 1+). */
    cpuinfo_uarch_cortex_a55r0 = 0x00300354,
    /** ARM Cortex-A55. */
    cpuinfo_uarch_cortex_a55 = 0x00300355,
    /** ARM Cortex-A57. */
    cpuinfo_uarch_cortex_a57 = 0x00300357,
    /** ARM Cortex-A65. */
    cpuinfo_uarch_cortex_a65 = 0x00300365,
    /** ARM Cortex-A72. */
    cpuinfo_uarch_cortex_a72 = 0x00300372,
    /** ARM Cortex-A73. */
    cpuinfo_uarch_cortex_a73 = 0x00300373,
    /** ARM Cortex-A75. */
    cpuinfo_uarch_cortex_a75 = 0x00300375,
    /** ARM Cortex-A76. */
    cpuinfo_uarch_cortex_a76 = 0x00300376,
    /** ARM Cortex-A77. */
    cpuinfo_uarch_cortex_a77 = 0x00300377,
    /** ARM Cortex-A78. */
    cpuinfo_uarch_cortex_a78 = 0x00300378,
};

/**
 * @brief Coarse classification of the CPU core a thread is running on.
 */
enum MlasCoreType {
    /** Core type has not been determined yet. */
    mlas_core_unknown = 0,
    /** Little core with narrow memory load. */
    mlas_core_little = 2,
    /** Big core with wider memory load. */
    mlas_core_big = 3
};

/**
 * @brief Classify the core running the calling thread as little or big.
 *
 * The classification is computed on the first call made by each thread and
 * stored in a thread_local variable; later calls from the same thread return
 * the stored value without querying the CPU again.
 *
 * @return 2 (mlas_core_little) current core is little core with narrow memory load (e.g. ARMv8 a53)
 *         3 (mlas_core_big) current core is big core with wider load (e.g. ARMv8 a72)
 */
MLAS_FORCEINLINE
int32_t
MlasGetCoreUArch()
{
    thread_local int32_t core_type = mlas_core_unknown;
    if (core_type == mlas_core_unknown) {
        // initialization needed
#if defined(MLAS_TARGET_ARM64) && defined(__linux__)
        const auto uarch = MLAS_CPUIDINFO::GetCPUIDInfo().GetCurrentUarch();
        if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 ||
            uarch == cpuinfo_uarch_cortex_a55) {
            core_type = mlas_core_little;
        } else {
            // Every other reported value, recognized or not, is treated as a big core.
            core_type = mlas_core_big;
        }
#else
        // No per-core detection on other targets: always report a big core.
        core_type = mlas_core_big;
#endif  // MLAS_TARGET_ARM64 && __linux__
    }
    return core_type;
}


8 changes: 7 additions & 1 deletion onnxruntime/core/mlas/lib/qgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,13 @@ MlasSymmQgemmBatch(
const size_t N = Shape.N;
const size_t K = Shape.K;
const MLAS_SYMM_QGEMM_DISPATCH* dispatch = GetMlasPlatform().SymmQgemmDispatch;
MLAS_SYMM_QGEMM_OPERATION* operation = dispatch->Operation;

if (ThreadPool == nullptr) {
// So our caller handles threaded job partition.
// Call single threaded operation directly
auto uarch = MlasGetCoreUArch();
MLAS_SYMM_QGEMM_OPERATION* operation =
uarch == mlas_core_little ? dispatch->LitOperation : dispatch->BigOperation;

for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) {
auto Data = &DataParams[gemm_i];
Expand Down Expand Up @@ -258,6 +260,10 @@ MlasSymmQgemmBatch(
ThreadsPerGemm = ThreadCountM * ThreadCountN;

MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
auto uarch = MlasGetCoreUArch();
MLAS_SYMM_QGEMM_OPERATION* operation =
uarch == mlas_core_little ? dispatch->LitOperation : dispatch->BigOperation;

const auto gemm_i = tid / ThreadsPerGemm;
const auto blk_i = tid % ThreadsPerGemm;
auto Data = &DataParams[gemm_i];
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/mlas/lib/qgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,8 @@ struct MLAS_GEMM_QUANT_DISPATCH {
};

struct MLAS_SYMM_QGEMM_DISPATCH {
MLAS_SYMM_QGEMM_OPERATION* Operation;
MLAS_SYMM_QGEMM_OPERATION* LitOperation; /// running on little cores with narrow memory load
MLAS_SYMM_QGEMM_OPERATION* BigOperation; /// running on big cores with wider memory load
MLAS_GEMM_QUANT_COPY_PACKB_ROUTINE* CopyPackBRoutine;
size_t StrideM; /**< num of rows processed by kernel at a time */
size_t PackedK;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/qgemm_kernel_neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon = {
};

const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchNeon = {
MlasSymmQGemmPackedOperation<MLAS_GEMM_X8S8_KERNEL_NEON>,
MlasSymmQGemmPackedOperation<MLAS_GEMM_X8S8_KERNEL_NEON>,
MlasGemmQuantCopyPackB<MLAS_GEMM_X8S8_KERNEL_NEON>,
4, // StrideM
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/qgemm_kernel_sdot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,7 @@ size_t MlasSymmQGemmKernel<MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT>(
}

const MLAS_SYMM_QGEMM_DISPATCH MlasSymmQgemmS8DispatchSdot = {
MlasSymmQGemmPackedOperation<MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT>,
MlasSymmQGemmPackedOperation<MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT>,
MlasGemmQuantCopyPackB<MLAS_SYMM_GEMM_S8S8_KERNEL_SDOT>,
4, // StrideM
Expand Down