diff --git a/.github/workflows/scripts/formatScan/clangtidy.sh b/.github/workflows/scripts/formatScan/clangtidy.sh index a3795d6fc..e776c1ef2 100644 --- a/.github/workflows/scripts/formatScan/clangtidy.sh +++ b/.github/workflows/scripts/formatScan/clangtidy.sh @@ -11,7 +11,7 @@ log_path=${log_dir}/clangtidy.log cd ${REPO_DIR} mkdir build cd build -cmake .. -G Ninja -DNS_USE_CLANG_TIDY=CHECK -DBTLA_USE_OPENMP=OFF +cmake .. -G Ninja -DNS_USE_CLANG_TIDY=CHECK -DBTLA_ENABLE_OPENMP=OFF -DNS_USE_OMP=OFF ninja 2>&1 | tee ${log_path} if [[ ! -f ${log_path} ]] || [[ $(grep -c "warning:" ${log_path}) != 0 ]] || [[ $(grep -c "error" ${log_path}) != 0 ]]; then diff --git a/CMakeLists.txt b/CMakeLists.txt index c8e4d6c82..c341d83c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,9 +60,9 @@ option(NS_AVX512_VBMI "neural_speed: enable AVX512-VBMI" option(NS_AVX512_VNNI "neural_speed: enable AVX512-VNNI" OFF) option(NS_FMA "neural_speed: enable FMA" ON) option(NS_AMX "neural_speed: enable AMX" OFF) +option(NS_USE_OMP "neural_speed: use OpenMP thread pool." ON) option(NS_BUILD_TESTS "neural_speed: build tests" ${NS_STANDALONE}) -option(NS_BTLA_UT "enable BesTLA's unit tests" OFF) option(NS_BUILD_EXAMPLES "neural_speed: build examples" ${NS_STANDALONE}) option(NS_USE_CLANG_TIDY "neural_speed: clang-tidy check" OFF) @@ -135,12 +135,13 @@ if (NS_PYTHON_API) add_subdirectory(third_party/pybind11) endif() -if (NS_BTLA_UT) - set(BTLA_UT_ALL ON) +if(NS_USE_OMP) + include(FindOpenMP) + # compile BesTLA's OMPTheading class, then it can be used in ne_layers + set(BTLA_ENABLE_OPENMP ON CACHE BOOL "BesTLA enable compiling OpenMP threading") + add_compile_definitions(NS_USE_OMP) endif() -include(FindOpenMP) -set(BTLA_USE_OPENMP ON CACHE BOOL "BesTLA use OpenMP") add_subdirectory(bestla) add_subdirectory(neural_speed) diff --git a/CMakePresets.json b/CMakePresets.json index 2ba85fd8f..a9343e484 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -23,6 +23,16 @@ "inherits": "linux-debug", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, + { + "name": "linux-release-thread", + "displayName": "Linux Release Thread Pool", + "description": "Release", + "inherits": "linux-debug", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "NS_USE_OMP": "OFF" + } + }, { "name": "windows-base", "description": "Target Windows with the Visual Studio development environment.", @@ -49,23 +59,51 @@ "value": "x64", "strategy": "external" }, - "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "NS_PROFILING": "ON", + "NS_USE_OMP": "ON", + "BTLA_UT_DEBUG": "ON" + } }, { "name": "x64-release", "displayName": "x64 Release", "description": "Target Windows (64-bit) with the Visual Studio development environment. (RelWithDebInfo)", "inherits": "x64-debug", - "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "BTLA_UT_DEBUG": "OFF" + } + }, + { + "name": "x64-release-thread", + "displayName": "x64 Release without OpenMP", + "description": "Target Windows (64-bit) with the Visual Studio development environment. (RelWithDebInfo)", + "inherits": "x64-release", + "cacheVariables": { + "NS_USE_OMP": "OFF" + } }, { "name": "x64-bestla-UT", "displayName": "x64 BesTLA unit test", "description": "Target Windows (64-bit) with the Visual Studio development environment. (RelWithDebInfo)", - "inherits": "x64-debug", + "inherits": "x64-release", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "NS_BTLA_UT": "ON" + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "BTLA_UT_ALL": "ON", + "BTLA_UT_BENCHMARK": "ON", + "BTLA_UT_OPENMP": "ON" + } + }, + { + "name": "x64-ut-thread", + "displayName": "x64 BesTLA UT without OpenMP", + "description": "Target Windows (64-bit) with the Visual Studio development environment. (RelWithDebInfo)", + "inherits": "x64-bestla-UT", + "cacheVariables": { + "BTLA_UT_OPENMP": "OFF" } } ] diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt index 2b17a8603..a3082acca 100644 --- a/bestla/CMakeLists.txt +++ b/bestla/CMakeLists.txt @@ -4,7 +4,7 @@ project(bestla LANGUAGES CXX VERSION 0.1.0) file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp) file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) -option(BTLA_USE_OPENMP "Enable OpenMP thread pool" OFF) +option(BTLA_ENABLE_OPENMP "Compile OpenMP thread pool if OMP can be found" OFF) option(BTLA_UT_ALL "Enable all unit tests" OFF) option(BTLA_UT_DEBUG "Enable debug unit tests" OFF) @@ -19,7 +19,7 @@ option(BTLA_UT_KERNEL_INTRIN "Enable unit test for intrinsic kernels" OFF) option(BTLA_UT_KERNEL_WRAPPER "Enable unit test for runtime ISA kernels" OFF) option(BTLA_UT_NOASAN "Disable sanitize" OFF) option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF) -option(BTLA_UT_OPENMP "Use OpenMP" ON) +option(BTLA_UT_OPENMP "Use OpenMP for UT tests" OFF) add_library(${PROJECT_NAME} INTERFACE) add_library(neural_speed::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) @@ -30,10 +30,10 @@ target_include_directories( ) -if(BTLA_USE_OPENMP) - message(STATUS "BesTLA using OpenMP") +if(BTLA_ENABLE_OPENMP) + message(STATUS "BesTLA enable OpenMP ThreadPool") target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP) -endif(BTLA_USE_OPENMP) +endif(BTLA_ENABLE_OPENMP) if(WIN32) target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX) @@ -64,12 +64,14 @@ endif() function(add_ut_flag UT_OPTION) if(${${UT_OPTION}}) - target_compile_definitions(${PROJECT_NAME}_ut PRIVATE ${UT_OPTION}) + # target_compile_definitions(${PROJECT_NAME}_ut PRIVATE ${UT_OPTION}) + add_compile_definitions(${UT_OPTION}) endif() endfunction() if(UT_BUILD) file(GLOB srcs ${PROJECT_NAME}/ut/*.cc ${PROJECT_NAME}/ut/*.cpp) #compile everything even run parts of UTs + list(REMOVE_ITEM srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/bestla_benchmark.cpp) file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h) include_directories(${PROJECT_NAME}) add_executable(${PROJECT_NAME}_ut ${srcs} ${headers} ${ut_headers}) @@ -96,8 +98,21 @@ if(UT_BUILD) add_ut_flag(BTLA_UT_KERNEL_INTRIN) add_ut_flag(BTLA_UT_KERNEL_JIT) add_ut_flag(BTLA_UT_KERNEL_WRAPPER) - add_ut_flag(BTLA_UT_BENCHMARK) - target_link_libraries(${PROJECT_NAME}_ut PRIVATE ${PROJECT_NAME}) endif(UT_BUILD) +if(BTLA_UT_BENCHMARK) + file(GLOB srcs ${PROJECT_NAME}/ut/bestla_benchmark.cpp) #compile everything even run parts of UTs + file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h) + include_directories(${PROJECT_NAME}) + add_executable(${PROJECT_NAME}_benchmark ${srcs} ${headers} ${ut_headers}) + if(BTLA_UT_OPENMP) + include(FindOpenMP) + target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP) + target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE OpenMP::OpenMP_CXX) + endif() + if(NOT WIN32) + target_link_options(${PROJECT_NAME}_benchmark PRIVATE -lpthread) + endif() + target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME}) +endif(BTLA_UT_BENCHMARK) diff --git a/bestla/bestla/bestla.h b/bestla/bestla/bestla.h index d3327a656..7942eb6b7 100644 --- a/bestla/bestla/bestla.h +++ b/bestla/bestla/bestla.h @@ -31,6 +31,7 @@ enum class BTLA_ISA : uint8_t { AMX_INT8, AVX512_FP16, AVX512_BF16, + ISA_COUNT, }; enum class BTLA_DTYPE : uint32_t { EleBitsMask = 0xff, diff --git a/bestla/bestla/bestla_device.h b/bestla/bestla/bestla_device.h index 921305e7c..ca161fd6d 100644 --- a/bestla/bestla/bestla_device.h +++ b/bestla/bestla/bestla_device.h @@ -215,6 +215,7 @@ class CpuDevice { public: inline int getThreads() { return numthreads; } inline int getCores() { return numcores; } + inline uint32_t getL3CacheSize() { return L3Cache; } inline uint32_t getL2CacheSize() { return L2Cache; } inline uint32_t getL1CacheSize() { return L1Cache; } inline uint32_t getL2CacheSize_E() { return E_L2Cache; } @@ -228,7 +229,7 @@ class CpuDevice { inline bool AMX_BF16() { return mHasAMX_BF16; } inline bool AVX512_BF16() { return mHasAVX512_BF16; } inline bool AVX512_FP16() { return mHasAVX512_FP16; } - inline float getPE() { return (P_core.size() * P_power) / (E_core.size() * E_power); } + inline float* const getPE() { return PE; } inline size_t getPcoreNum() { return P_core.size(); } inline size_t getEcoreNum() { return E_core.size(); } inline size_t getSMTcoreNum() { return SMT_core.size(); } @@ -328,12 +329,40 @@ class CpuDevice { } } numcores = P_core.size() + E_core.size(); - numthreads = P_core.size() * 2 + E_core.size(); + numthreads = P_core.size() + E_core.size() + SMT_core.size(); + + { + // set PE + uint32_t tmp[4]; + _cpu.getCpuid(1, tmp); + if (p) printf("!!!\t%x\t%x\t%x\t%x!!!\n", tmp[0], tmp[1], tmp[2], tmp[3]); + const int famliy = (tmp[0] >> 8) & ((1u << 4) - 1); // cpu.extractBit(a[0], 8, 11); + const int extendedModel = (tmp[0] >> 16) & ((1u << 4) - 1); // cpu.extractBit(a[0], 16, 24); + { + for (int i = 0; i < int(BTLA_ISA::ISA_COUNT); i++) PE[i] = 1.0f; + // CPU identification refer to: https://en.wikichip.org/wiki/intel/cpuid + if (famliy == 6) switch (extendedModel) { + case 9: // ALD + PE[int(BTLA_ISA::AVX2)] = 3.0f; + PE[int(BTLA_ISA::AVX_VNNI)] = 5.0f; + break; + case 10: // MTL + PE[int(BTLA_ISA::AVX2)] = 2.2f; + PE[int(BTLA_ISA::AVX_VNNI)] = 3.0f; + break; + case 11: // RPL + PE[int(BTLA_ISA::AVX2)] = 1.8f; + PE[int(BTLA_ISA::AVX_VNNI)] = 2.6f; + break; + } + } + } } else { L1Cache = _cpu.getDataCacheSize(0); L2Cache = _cpu.getDataCacheSize(1); numthreads = numcores; } + L3Cache = _cpu.getDataCacheSize(2); #if FIXED_CACHE L2Cache = L2Cache >= FIXED_CACHE_SIZE ? FIXED_CACHE_SIZE : L2Cache; E_L2Cache = E_L2Cache >= FIXED_CACHE_SIZE ? FIXED_CACHE_SIZE : E_L2Cache; @@ -357,7 +386,7 @@ class CpuDevice { Xbyak::util::Cpu cpu; uint32_t tmp[4]; cpu.getCpuid(0x1A, tmp); - int core_type = (tmp[0] >> 24) & ((1u << 7) - 1); // cpu.extractBit(a[0], 24, 31); + int core_type = (tmp[0] >> 24) & ((1u << 8) - 1); // cpu.extractBit(a[0], 24, 31); switch (core_type) { case 32: // printf("Atom\n"); @@ -407,7 +436,7 @@ class CpuDevice { } static void core_bond(int core) { #ifdef _WIN32 - SetThreadAffinityMask(GetCurrentThread(), 1 << core); + SetThreadAffinityMask(GetCurrentThread(), 1LL << core); #else cpu_set_t cpuset; CPU_ZERO(&cpuset); @@ -420,7 +449,7 @@ class CpuDevice { static void core_bond(std::thread& thread, int core) { #ifdef _WIN32 HANDLE handle = thread.native_handle(); - SetThreadAffinityMask(handle, 1 << core); + SetThreadAffinityMask(handle, 1LL << core); #else cpu_set_t cpuset; CPU_ZERO(&cpuset); @@ -434,7 +463,7 @@ class CpuDevice { bool isHybrid() { return mHybrid; } protected: - uint32_t L2Cache, L1Cache; + uint32_t L2Cache, L1Cache, L3Cache; bool mHybrid = false; bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16, mHasAVX512_FP16; @@ -442,21 +471,61 @@ class CpuDevice { int numthreads; std::vector P_core, E_core, SMT_core; uint32_t E_L2Cache, E_L1Cache; - float P_power = 4.8, E_power = 2.3; + float PE[int(BTLA_ISA::ISA_COUNT)]; }; #define GetCPUDevice() auto _cd = bestla::device::CpuDevice::getInstance(); -class CpuBase { +class CpuRuntime { public: - CpuBase() { + CpuRuntime() = default; + static CpuRuntime& getInstance(int thread) { + static std::map instances; + if (instances.count(thread) == 0) instances[thread] = CpuRuntime(thread); + return instances[thread]; + } + + inline float getPE(const BTLA_ISA isa) { + // printf("GET:%d\t%f\n",int(isa), *cur_PE); + return PE[int(isa)] * P_core_num / E_core_num; + } + + inline void adjustPE(const BTLA_ISA isa, const float PE_) { + // printf("Adjust:%d,%f\n",int(isa),PE_); + PE[int(isa)] *= PE_; + } + + size_t mL2Cache, mL1Cache, mL2Cache_P = 0, mL1Cache_P = 0, mL2Cache_E = 0, mL1Cache_E = 0; + int P_core_num = 0, E_core_num = 0; + bool mHybrid = false; + + private: + CpuRuntime(int thread) { GetCPUDevice(); mL2Cache = _cd->getL2CacheSize(); mL1Cache = _cd->getL1CacheSize(); - mNumThreads = _cd->getThreads(); + maxThreads = _cd->getThreads(); + mHybrid = false; + if (_cd->isHybrid() && thread > _cd->getPcoreNum()) { + if (thread > _cd->getPcoreNum() + _cd->getEcoreNum()) { + mL1Cache_P = mL1Cache / 2; + mL2Cache_P = mL2Cache / 2; + P_core_num = _cd->getPcoreNum(); + E_core_num = _cd->getEcoreNum(); + } else { + mL1Cache_P = mL1Cache; + mL2Cache_P = mL2Cache; + P_core_num = _cd->getPcoreNum(); + E_core_num = thread - P_core_num; + } + mL1Cache_E = _cd->getL1CacheSize_E(); + mL2Cache_E = _cd->getL2CacheSize_E(); + mHybrid = true; + memcpy(PE, _cd->getPE(), int(BTLA_ISA::ISA_COUNT) * sizeof(float)); + } } - size_t mL2Cache, mL1Cache; - int mNumThreads; + float PE[int(BTLA_ISA::ISA_COUNT)]; + int maxThreads; }; } // namespace device } // namespace bestla diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index c780aa4cf..7a996ae05 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#include #include #include #include @@ -23,10 +24,244 @@ namespace bestla { namespace parallel { + +using thread_func = std::function; + +class IThreading { + public: + explicit IThreading(int nthreads, bool supportPE) : mThreadNum(nthreads), isSupportPE(supportPE) {} + virtual void parallel_for(const thread_func& func) = 0; + virtual inline void sync(int tidx, int idx = 0) = 0; + virtual int num_threads() const { return mThreadNum; }; + virtual int is_support_PE() const { return isSupportPE; }; + virtual void set_threads(int nthreads) = 0; + virtual std::pair get_PEtime() const { return {0.0f, 0.0f}; }; + + protected: + int mThreadNum; + const bool isSupportPE; +}; + +#if BTLA_OPENMP +class OMPThreading : public IThreading { + public: + explicit OMPThreading(int nthreads) : IThreading(nthreads, false) { + // printf("Using OMP\n"); + omp_set_num_threads(nthreads); + } + void parallel_for(const thread_func& func) override { + if (mThreadNum > 1) { +#pragma omp parallel + { + int tidx = omp_get_thread_num(); + func(tidx); + } + } else { + func(0); + } + } + virtual void set_threads(int nthreads) override { + mThreadNum = nthreads; + omp_set_num_threads(nthreads); + } + virtual inline void sync(int tidx, int idx = 0) override { + (void)(tidx); + (void)(idx); +#pragma omp barrier + (void)(0); // make msvc happy with c++20 + } +}; +#endif + +class StdThreading : public IThreading { + public: + using Timer_T = utils::timer; + explicit StdThreading(int nthreads) : IThreading(nthreads, true) { + // printf("Using Std\n"); + cr = &device::CpuRuntime::getInstance(nthreads); + create_threads(); + } + void parallel_for(const thread_func& func) override { + time_per_p = 0; + time_per_e = 0; + Timer_T tm; + if (mThreadNum > 1) { + running.store(mThreadNum - 1); + for (int i = 0; i < 10; i++) flag[i].store(mThreadNum); + if (cr->mHybrid) { + int time_p = 0, time_e = 0; + + for (size_t i = 0; i < mThreadNum - 1; i++) func_[i] = &func; + thread_time[0] = 0; + tm.start(); + func(0); + thread_time[0] += int(tm.stop()); + while (true) { + if (running.load() == 0) + break; + else + _mm_pause(); + } + for (int i = 0; i < mThreadNum; i++) + if (i >= cr->P_core_num && i < cr->P_core_num + cr->E_core_num) + time_e += thread_time[i]; + else + time_p += thread_time[i]; + time_per_p = (time_p) / (1.0 * (mThreadNum - cr->E_core_num)); + time_per_e = (time_e) / (1.0 * cr->E_core_num); + // printf("%d %d %f %f\n", time_p, time_e, time_per_p, time_per_e); + } else { + for (size_t i = 0; i < mThreadNum - 1; i++) { + func_[i] = &func; + } + func(0); + while (true) { + if (running.load() == 0) + break; + else + _mm_pause(); + } + } + } else { + func(0); + } + } + + void set_threads(int nthreads) override { + if (nthreads != mThreadNum) { + stop_threads(); + mThreadNum = nthreads; + cr = &device::CpuRuntime::getInstance(nthreads); + create_threads(); + } + } + + inline void sync(int tidx, int idx = 0) override { + flag[idx].fetch_sub(1); + if (cr->mHybrid) { + Timer_T tm; + tm.start(); + while (true) { + if (flag[idx].load() == 0) + break; + else + _mm_pause(); + } + thread_time[tidx] -= int(tm.stop()); + } else { + while (true) { + if (flag[idx].load() == 0) + break; + else + _mm_pause(); + } + } + } + + std::pair get_PEtime() const override { return {time_per_p, time_per_e}; }; + + ~StdThreading() { stop_threads(); } + + private: + void stop_threads() { + stop = true; + for (int i = 0; i < mThreadNum - 1; i++) thdset[i].join(); + thdset.clear(); + // printf("stop %d\n", mThreadNum); + } + void create_threads() { + // printf("create %d\n", mThreadNum); + thdset.resize(mThreadNum - 1); + stop = false; + GetCPUDevice(); + std::vector core_order; + if (_cd->isHybrid()) { + core_order.resize(_cd->getThreads()); + memcpy(reinterpret_cast(core_order.data()), reinterpret_cast(_cd->getPCores()), + _cd->getPcoreNum() * sizeof(int)); + memcpy(reinterpret_cast(core_order.data() + _cd->getPcoreNum()), reinterpret_cast(_cd->getECores()), + _cd->getEcoreNum() * sizeof(int)); + memcpy(reinterpret_cast(core_order.data() + _cd->getPcoreNum() + _cd->getEcoreNum()), + reinterpret_cast(_cd->getSMTCores()), _cd->getSMTcoreNum() * sizeof(int)); + } else { + core_order.resize(mThreadNum); + for (int i = 0; i < mThreadNum; i++) core_order[i] = i; + } + _cd->core_bond(core_order[0]); + if (cr->mHybrid) { + thread_time.resize(mThreadNum); + for (size_t i = 0; i < mThreadNum - 1; i++) { + thdset[i] = std::thread( + [&](int tidx, int core_id) { + _cd->core_bond(core_id); + Timer_T tm; + while (true) { + if (stop.load() == true) break; + if (func_[tidx] != nullptr) { + thread_time[tidx + 1] = 0; + tm.start(); + (*func_[tidx])(tidx + 1); + func_[tidx] = nullptr; + thread_time[tidx + 1] += int(tm.stop()); + running.fetch_sub(1); + } else { + _mm_pause(); + } + } + }, + int(i), core_order[i + 1]); + } + } else + for (size_t i = 0; i < mThreadNum - 1; i++) { + thdset[i] = std::thread( + [&](int tidx, int core_id) { + _cd->core_bond(core_id); + while (true) { + if (stop.load() == true) break; + if (func_[tidx] != nullptr) { + (*func_[tidx])(tidx + 1); + func_[tidx] = nullptr; + running.fetch_sub(1); + } else { + _mm_pause(); + } + } + }, + int(i), core_order[i + 1]); + } + } + device::CpuRuntime* cr; + std::vector thread_time; + float time_per_p, time_per_e; + std::vector thdset; + std::atomic_bool stop; + std::atomic_int running; + std::atomic_int flag[10]; + const thread_func* func_[100]; +}; + +class SingleThread : public IThreading { + public: + SingleThread() : IThreading(1, false) {} + + void set_threads(int nthreads) override { + assert(0); + (void)(nthreads); + } + + inline void parallel_for(const thread_func& func) override { func(0); } + + inline void sync(int tidx, int idx = 0) override { + (void)(tidx); + (void)(idx); + } +}; + struct Config2D { int threads; int size[2]; int step[2]; + int offset[2]; }; struct ThreadProblem2D { int tid; @@ -60,6 +295,8 @@ class Scheduler2D { problem.loc[1] = problem.tidx[1] * mThdSize[1]; problem.size[0] = utils::remainsize(problem.loc[0], mSize[0], mThdSize[0]); problem.size[1] = utils::remainsize(problem.loc[1], mSize[1], mThdSize[1]); + problem.loc[0] += moffset[0]; + problem.loc[1] += moffset[1]; problem.valid = true; } @@ -68,10 +305,13 @@ class Scheduler2D { for (size_t i = 0; i < 2; i++) { mSize[i] = config.size[i]; mStep[i] = config.step[i]; + moffset[i] = config.offset[i]; } schedule(); } + constexpr static BTLA_ISA gemm_ISA() { return BTLA_ISA::NoSIMD; } + void print() { printf("Thread Block:(%d,%d)\n", mThdSize[0], mThdSize[1]); printf("Thread in use:%d of %d, Nx%d\n", mThdValid, mThdCount, mThdPerRow); @@ -111,6 +351,7 @@ class Scheduler2D { int mThdPerRow = 0; int mThdValid = 0; int mThdCount = 0; + int moffset[2] = {0, 0}; private: int mThdSize[2] = {0, 0}; @@ -123,6 +364,7 @@ namespace gemm { struct Config { const int threads; const utils::GemmProblem problem; + const int offset[2]; const size_t l2cache = 1024ULL * 1024; const size_t l1cache = 32ULL * 1024; }; @@ -156,6 +398,8 @@ class SchedulerBase : public Scheduler2D { mThdCount = config.threads; mL2Size = config.l2cache; mL1Size = config.l1cache; + Scheduler2D::moffset[0] = config.offset[0]; + Scheduler2D::moffset[1] = config.offset[1]; if (mSize[0] <= 0 || mSize[1] <= 0 || mSize[2] <= 0) { return; } @@ -166,6 +410,8 @@ class SchedulerBase : public Scheduler2D { assert(this->mBlock[2] > 0); } + constexpr static BTLA_ISA gemm_ISA() { return _GemmCore_T::ISA; } + constexpr int valid_theads() { return mThdValid; } virtual void print() { @@ -175,6 +421,9 @@ class SchedulerBase : public Scheduler2D { printf("Cache Size:%zu used:%zu\n", mL2Size, mL2Use); } + template + friend class SchedulerDispatcher; + protected: virtual void schedule() { int rownum = utils::updiv(mSize[0], mStep[0]); @@ -204,7 +453,7 @@ class SchedulerBase : public Scheduler2D { mL2Use += static_cast(mBlock[1]) * mBlock[2] * mEleSize[1]; mL2Use += static_cast(mStep[0]) * mBlock[2] * mEleSize[0]; } - const float DensityThres = 16; + static float constexpr DensityThres = 16; static size_t constexpr ReservedSize = 32ULL * 1024ULL; virtual float calculate_score() { @@ -279,8 +528,6 @@ class SchedulerBase : public Scheduler2D { size_t mL2Size = 0, mL1Size = 0, mL2Use = 0; float mDensity = 0.f; - - protected: int mSize[3] = {0, 0, 0}; int mThdSize[3] = {0, 0, 0}; static constexpr int mStep[3] = {_GemmCore_T::MTILE, _GemmCore_T::NTILE, _GemmCore_T::KTILE}; @@ -315,6 +562,8 @@ class SchedulerKBlock : public Scheduler2D { mThdCount = config.threads; mL2Size = config.l2cache; mL1Size = config.l1cache; + moffset[0] = config.offset[0]; + moffset[1] = config.offset[1]; mKBlock = config.problem.dims[4]; if (mSize[0] <= 0 || mSize[1] <= 0 || mSize[2] <= 0) { return; @@ -326,6 +575,8 @@ class SchedulerKBlock : public Scheduler2D { assert(this->mBlock[2] > 0); } + constexpr static BTLA_ISA gemm_ISA() { return _GemmCore_T::ISA; } + constexpr int valid_theads() { return mThdValid; } void print() { @@ -335,6 +586,9 @@ class SchedulerKBlock : public Scheduler2D { printf("Cache Size:%zu used:%zu\n", mL2Size, mL2Use); } + template + friend class SchedulerDispatcher; + protected: void schedule() { int rownum = utils::updiv(mSize[0], mStep[0]); @@ -364,7 +618,7 @@ class SchedulerKBlock : public Scheduler2D { mL2Use += static_cast(mBlock[1]) * mBlock[2] * mEleSize[1]; mL2Use += static_cast(mStep[0]) * mBlock[2] * mEleSize[0]; } - const float DensityThres = 16; + static float constexpr DensityThres = 16; float calculate_score() { int tmpnstep = mThdSize[1] < _GemmCore_T::PREFERRED_N ? mThdSize[1] : _GemmCore_T::PREFERRED_N; @@ -495,8 +749,13 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> { assert(this->mBlock[2] % _GemmCore_T::KTILE == 0); } + constexpr static BTLA_ISA gemm_ISA() { return _GemmCore_T::ISA; } + + template + friend class SchedulerDispatcher; + protected: - const float DensityThres = 16; + static float constexpr DensityThres = 16; static size_t constexpr ReservedSize = 32ULL * 1024ULL; void cache_blocking_compute() override { @@ -581,84 +840,127 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> { int mKBlock{0}; }; -} // namespace gemm -using thread_func = std::function; - -class IThreading { - public: - explicit IThreading(int nthreads) : mThreadNum(nthreads) {} - virtual void parallel_for(const thread_func& func) const = 0; - virtual inline void sync() const { assert(0); }; - virtual int num_threads() const { return mThreadNum; }; - virtual void set_threads(int nthreads) = 0; - - protected: - int mThreadNum; -}; -#if BTLA_OPENMP -class OMPThreading : public IThreading { +template +class SchedulerDispatcher { public: - explicit OMPThreading(int nthreads) : IThreading(nthreads) { omp_set_num_threads(nthreads); } - void parallel_for(const thread_func& func) const override { - if (mThreadNum > 1) { -#pragma omp parallel - { - int tidx = omp_get_thread_num(); - func(tidx); - } + using ThreadProblem = ThreadProblemBase; + SchedulerDispatcher() = default; + ~SchedulerDispatcher() { + std::pair PEtime = th_->get_PEtime(); + if (needDispach && int(PEtime.first) > 0 && int(PEtime.second) > 0) + cr->adjustPE(Scheduler::gemm_ISA(), PEtime.second / PEtime.first); + } + SchedulerDispatcher(const IThreading* th, const utils::GemmProblem& problem) { + th_ = th; + cr = &device::CpuRuntime::getInstance(th->num_threads()); + needDispach = cr->mHybrid && th->is_support_PE(); + if (!needDispach) { + Scheduler_P = std::move(Scheduler({th->num_threads(), problem, {0, 0}, cr->mL2Cache, cr->mL1Cache})); } else { - func(0); + Pcore_num = cr->P_core_num; + Ecore_num = cr->E_core_num; + utils::GemmProblem problem_P = problem, problem_E = problem; + const int N = problem.dims[2]; + auto PE_Ratio = cr->getPE(Scheduler::gemm_ISA()); + const int N_offset = utils::padto(N - int(N / (1 + PE_Ratio)), Scheduler::mStep[1]); + problem_P.dims[2] = N_offset; + Scheduler_P = + std::move(Scheduler({th->num_threads() - cr->E_core_num, problem_P, {0, 0}, cr->mL2Cache_P, cr->mL1Cache_P})); + problem_E.dims[2] = N - N_offset; + Scheduler_E = std::move(Scheduler({cr->E_core_num, problem_E, {0, N_offset}, cr->mL2Cache_E, cr->mL1Cache_E})); } } - virtual void set_threads(int nthreads) override { - mThreadNum = nthreads; - omp_set_num_threads(nthreads); + + void getIndex(ThreadProblem& problem) { + if (!needDispach) { + Scheduler_P.getIndex(problem); + } else { + if (problem.tid >= Pcore_num + Ecore_num) { + problem.tid -= Ecore_num; + Scheduler_P.getIndex(problem); + } else if (problem.tid >= Pcore_num) { + problem.tid -= Pcore_num; + Scheduler_E.getIndex(problem); + } else { + Scheduler_P.getIndex(problem); + } + } } - virtual inline void sync() const override { -#pragma omp barrier - (void)(0); // make msvc happy with c++20 + + void print() { + printf("dispatch to hybrid:%d\n", needDispach); + Scheduler_P.print(); + if (needDispach) Scheduler_E.print(); } + + private: + Scheduler Scheduler_P, Scheduler_E; + const IThreading* th_; + device::CpuRuntime* cr; + bool needDispach = false; + int Pcore_num = 0, Ecore_num = 0; }; -#endif -class StdThreading : public IThreading { +template <> +class SchedulerDispatcher { public: - explicit StdThreading(int nthreads) : IThreading(nthreads) {} - void parallel_for(const thread_func& func) const override { - if (mThreadNum > 1) { - std::vector thdset(mThreadNum - 1); - for (size_t i = 0; i < mThreadNum - 1; i++) { - thdset[i] = std::thread([&](int tidx) { func(tidx); }, int(i + 1)); - } - func(0); - for (size_t i = 0; i < mThreadNum - 1; i++) { - thdset[i].join(); - } + using ThreadProblem = ThreadProblem2D; + SchedulerDispatcher() = default; + ~SchedulerDispatcher() {} + SchedulerDispatcher(const IThreading* th, const Config2D& config) { + device::CpuRuntime& cr = device::CpuRuntime::getInstance(config.threads); + needDispach = cr.mHybrid && th->is_support_PE(); + if (!needDispach) { + Scheduler_P = std::move(Scheduler2D(config)); } else { - func(0); + Pcore_num = cr.P_core_num; + Ecore_num = cr.E_core_num; + Config2D config_P = config, config_E = config; + const int N = config.size[1]; + const int N_offset = utils::padto(N - int(N / (1 + cr.getPE(BTLA_ISA::NoSIMD))), config.step[1]); + config_P.threads = config.threads - cr.E_core_num; + config_P.size[1] = N_offset; + Scheduler_P = std::move(Scheduler2D(config_P)); + config_E.threads = cr.E_core_num; + config_E.size[1] = N - N_offset; + config_E.offset[1] += N_offset; + Scheduler_E = std::move(Scheduler2D(config_E)); } } - void set_threads(int nthreads) override { mThreadNum = nthreads; } + void getIndex(ThreadProblem& problem) { + if (!needDispach) { + Scheduler_P.getIndex(problem); + } else { + if (problem.tid >= Pcore_num + Ecore_num) { + problem.tid -= Ecore_num; + Scheduler_P.getIndex(problem); + } else if (problem.tid >= Pcore_num) { + problem.tid -= Pcore_num; + Scheduler_E.getIndex(problem); + } else { + Scheduler_P.getIndex(problem); + } + } + } - inline void sync() const override { assert(0); } + void print() { + printf("dispatch to hybrid:%d\n", needDispach); + Scheduler_P.print(); + if (needDispach) Scheduler_E.print(); + } private: + Scheduler2D Scheduler_P, Scheduler_E; + bool needDispach = false; + int Pcore_num = 0, Ecore_num = 0; }; -class SingleThread : public StdThreading { - public: - SingleThread() : StdThreading(1) {} - - void set_threads(int nthreads) override { (void)(nthreads); } - - inline void sync() const override {} -}; +} // namespace gemm template void GemmRun(Launch_T& launcher, const typename Launch_T::Param& args, parallel::IThreading* th) { - device::CpuBase cb; - Parallel_T para({th->num_threads(), args.problem, cb.mL2Cache, cb.mL1Cache}); + gemm::SchedulerDispatcher para(th, args.problem); static bool flag = false; if (flag) { printf("%s\n", __FUNCTION__); @@ -676,10 +978,9 @@ void GemmRun(Launch_T& launcher, const typename Launch_T::Param& args, parallel: template void GemmRunWithA(Launch_T& launcher, const typename Launch_T::Param& args, parallel::IThreading* th) { - device::CpuBase cb; - Parallel_T para({th->num_threads(), args.problem, cb.mL2Cache, cb.mL1Cache}); + gemm::SchedulerDispatcher para(th, args.problem); using AParall = typename Launch_T::PrologueA::Parallel; - auto apara = launcher.mProA.createParallel(th->num_threads(), args.problem); + AParall apara = launcher.mProA.createParallel(th->num_threads(), args.problem); static bool flag = false; if (flag) { printf("%s\n", __FUNCTION__); @@ -692,7 +993,7 @@ void GemmRunWithA(Launch_T& launcher, const typename Launch_T::Param& args, para if (thdpA.valid) { launcher.mProA.run(args.paramA, thdpA); } - th->sync(); + th->sync(tidx); typename Parallel_T::ThreadProblem thdp{tidx}; para.getIndex(thdp); if (thdp.valid) { diff --git a/bestla/bestla/bestla_utils.h b/bestla/bestla/bestla_utils.h index 8804a4436..e6636ff74 100644 --- a/bestla/bestla/bestla_utils.h +++ b/bestla/bestla/bestla_utils.h @@ -405,7 +405,7 @@ inline float get_mxfp_maxnorm(const BTLA_DTYPE t, int ebits, int mantissa_bits) } else { max_norm *= 1.75; } - return max_norm; + return static_cast(max_norm); } #ifndef _WIN32 @@ -667,12 +667,13 @@ class timer_statistics_logger { } float min_val, max_val, avg_val; - private: void record() { min_val = statis.min_val / log_ratio; max_val = statis.max_val / log_ratio; avg_val = statis.avg_val / log_ratio; } + + private: float log_ratio; char str[256]; timer<_PRECISION> tm; @@ -727,28 +728,30 @@ static float nf4_dequant_fp32_LUT[] = {0.f, // For more details pls refer // (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] static float dq8_bnb_LUT[] = { - -0.99297, -0.97891, -0.96484, -0.95078, -0.93672, -0.92266, -0.90859, -0.89453, -0.88047, -0.86641, -0.85234, - -0.83828, -0.82422, -0.81016, -0.79609, -0.78203, -0.76797, -0.75391, -0.73984, -0.72578, -0.71172, -0.69766, - -0.68359, -0.66953, -0.65547, -0.64141, -0.62734, -0.61328, -0.59922, -0.58516, -0.57109, -0.55703, -0.54297, - -0.52891, -0.51484, -0.50078, -0.48672, -0.47266, -0.45859, -0.44453, -0.43047, -0.41641, -0.40234, -0.38828, - -0.37422, -0.36016, -0.34609, -0.33203, -0.31797, -0.30391, -0.28984, -0.27578, -0.26172, -0.24766, -0.23359, - -0.21953, -0.20547, -0.19141, -0.17734, -0.16328, -0.14922, -0.13516, -0.12109, -0.10703, -0.09859, -0.09578, - -0.09297, -0.09016, -0.08734, -0.08453, -0.08172, -0.07891, -0.07609, -0.07328, -0.07047, -0.06766, -0.06484, - -0.06203, -0.05922, -0.05641, -0.05359, -0.05078, -0.04797, -0.04516, -0.04234, -0.03953, -0.03672, -0.03391, - -0.03109, -0.02828, -0.02547, -0.02266, -0.01984, -0.01703, -0.01422, -0.01141, -0.00972, -0.00916, -0.00859, - -0.00803, -0.00747, -0.00691, -0.00634, -0.00578, -0.00522, -0.00466, -0.00409, -0.00353, -0.00297, -0.00241, - -0.00184, -0.00128, -0.00094, -0.00083, -0.00072, -0.00061, -0.00049, -0.00038, -0.00027, -0.00016, -0.00009, - -0.00007, -0.00004, -0.00002, -0.00001, -0.00000, -0.00000, 0.00000, 0.00000, 0.00000, 0.00001, 0.00002, - 0.00004, 0.00007, 0.00009, 0.00016, 0.00027, 0.00038, 0.00049, 0.00061, 0.00072, 0.00083, 0.00094, - 0.00128, 0.00184, 0.00241, 0.00297, 0.00353, 0.00409, 0.00466, 0.00522, 0.00578, 0.00634, 0.00691, - 0.00747, 0.00803, 0.00859, 0.00916, 0.00972, 0.01141, 0.01422, 0.01703, 0.01984, 0.02266, 0.02547, - 0.02828, 0.03109, 0.03391, 0.03672, 0.03953, 0.04234, 0.04516, 0.04797, 0.05078, 0.05359, 0.05641, - 0.05922, 0.06203, 0.06484, 0.06766, 0.07047, 0.07328, 0.07609, 0.07891, 0.08172, 0.08453, 0.08734, - 0.09016, 0.09297, 0.09578, 0.09859, 0.10703, 0.12109, 0.13516, 0.14922, 0.16328, 0.17734, 0.19141, - 0.20547, 0.21953, 0.23359, 0.24766, 0.26172, 0.27578, 0.28984, 0.30391, 0.31797, 0.33203, 0.34609, - 0.36016, 0.37422, 0.38828, 0.40234, 0.41641, 0.43047, 0.44453, 0.45859, 0.47266, 0.48672, 0.50078, - 0.51484, 0.52891, 0.54297, 0.55703, 0.57109, 0.58516, 0.59922, 0.61328, 0.62734, 0.64141, 0.65547, - 0.66953, 0.68359, 0.69766, 0.71172, 0.72578, 0.73984, 0.75391, 0.76797, 0.78203, 0.79609, 0.81016, - 0.82422, 0.83828, 0.85234, 0.86641, 0.88047, 0.89453, 0.90859, 0.92266, 0.93672, 0.95078, 0.96484, - 0.97891, 0.99297, 1.00000}; + -0.99297f, -0.97891f, -0.96484f, -0.95078f, -0.93672f, -0.92266f, -0.90859f, -0.89453f, -0.88047f, -0.86641f, + -0.85234f, -0.83828f, -0.82422f, -0.81016f, -0.79609f, -0.78203f, -0.76797f, -0.75391f, -0.73984f, -0.72578f, + -0.71172f, -0.69766f, -0.68359f, -0.66953f, -0.65547f, -0.64141f, -0.62734f, -0.61328f, -0.59922f, -0.58516f, + -0.57109f, -0.55703f, -0.54297f, -0.52891f, -0.51484f, -0.50078f, -0.48672f, -0.47266f, -0.45859f, -0.44453f, + -0.43047f, -0.41641f, -0.40234f, -0.38828f, -0.37422f, -0.36016f, -0.34609f, -0.33203f, -0.31797f, -0.30391f, + -0.28984f, -0.27578f, -0.26172f, -0.24766f, -0.23359f, -0.21953f, -0.20547f, -0.19141f, -0.17734f, -0.16328f, + -0.14922f, -0.13516f, -0.12109f, -0.10703f, -0.09859f, -0.09578f, -0.09297f, -0.09016f, -0.08734f, -0.08453f, + -0.08172f, -0.07891f, -0.07609f, -0.07328f, -0.07047f, -0.06766f, -0.06484f, -0.06203f, -0.05922f, -0.05641f, + -0.05359f, -0.05078f, -0.04797f, -0.04516f, -0.04234f, -0.03953f, -0.03672f, -0.03391f, -0.03109f, -0.02828f, + -0.02547f, -0.02266f, -0.01984f, -0.01703f, -0.01422f, -0.01141f, -0.00972f, -0.00916f, -0.00859f, -0.00803f, + -0.00747f, -0.00691f, -0.00634f, -0.00578f, -0.00522f, -0.00466f, -0.00409f, -0.00353f, -0.00297f, -0.00241f, + -0.00184f, -0.00128f, -0.00094f, -0.00083f, -0.00072f, -0.00061f, -0.00049f, -0.00038f, -0.00027f, -0.00016f, + -0.00009f, -0.00007f, -0.00004f, -0.00002f, -0.00001f, -0.00000f, -0.00000f, 0.00000f, 0.00000f, 0.00000f, + 0.00001f, 0.00002f, 0.00004f, 0.00007f, 0.00009f, 0.00016f, 0.00027f, 0.00038f, 0.00049f, 0.00061f, + 0.00072f, 0.00083f, 0.00094f, 0.00128f, 0.00184f, 0.00241f, 0.00297f, 0.00353f, 0.00409f, 0.00466f, + 0.00522f, 0.00578f, 0.00634f, 0.00691f, 0.00747f, 0.00803f, 0.00859f, 0.00916f, 0.00972f, 0.01141f, + 0.01422f, 0.01703f, 0.01984f, 0.02266f, 0.02547f, 0.02828f, 0.03109f, 0.03391f, 0.03672f, 0.03953f, + 0.04234f, 0.04516f, 0.04797f, 0.05078f, 0.05359f, 0.05641f, 0.05922f, 0.06203f, 0.06484f, 0.06766f, + 0.07047f, 0.07328f, 0.07609f, 0.07891f, 0.08172f, 0.08453f, 0.08734f, 0.09016f, 0.09297f, 0.09578f, + 0.09859f, 0.10703f, 0.12109f, 0.13516f, 0.14922f, 0.16328f, 0.17734f, 0.19141f, 0.20547f, 0.21953f, + 0.23359f, 0.24766f, 0.26172f, 0.27578f, 0.28984f, 0.30391f, 0.31797f, 0.33203f, 0.34609f, 0.36016f, + 0.37422f, 0.38828f, 0.40234f, 0.41641f, 0.43047f, 0.44453f, 0.45859f, 0.47266f, 0.48672f, 0.50078f, + 0.51484f, 0.52891f, 0.54297f, 0.55703f, 0.57109f, 0.58516f, 0.59922f, 0.61328f, 0.62734f, 0.64141f, + 0.65547f, 0.66953f, 0.68359f, 0.69766f, 0.71172f, 0.72578f, 0.73984f, 0.75391f, 0.76797f, 0.78203f, + 0.79609f, 0.81016f, 0.82422f, 0.83828f, 0.85234f, 0.86641f, 0.88047f, 0.89453f, 0.90859f, 0.92266f, + 0.93672f, 0.95078f, 0.96484f, 0.97891f, 0.99297f, 1.00000f}; } // namespace bestla diff --git a/bestla/bestla/ut/bestla_benchmark.cpp b/bestla/bestla/ut/bestla_benchmark.cpp new file mode 100644 index 000000000..eadb03511 --- /dev/null +++ b/bestla/bestla/ut/bestla_benchmark.cpp @@ -0,0 +1,914 @@ +#include +#include "bestla_wrapper.h" +#include "bestla_ut.h" + +namespace bestla { +using namespace utils; +namespace ut { +int constexpr TestMs = 500; +class Benchmark_Fp32Fp32 { + public: + Benchmark_Fp32Fp32() { + UT_START(); + benchmark_all(1, 4096, 4096); + benchmark_all(1024, 4096, 4096); + benchmark_all(2048, 4096, 4096); + } + + using AType = float; + using BType = float; + using CType = float; + template + void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + auto tmpB = kernel.mProB.createStorage(n, k); + std::vector packBs(batch, 0); + avector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + utils::GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + log.stop(); + if (tm.stop() >= timems) { + break; + } + } + } + log.record(); + double flops = double(psize) / log.min_val / 1e6; + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, + flops / threads); + } + + void benchmark_all(int m, int n, int k) { + auto memsize = gemm_memsize(m, n, k, BTLA_DTYPE::F32, BTLA_DTYPE::F32, BTLA_DTYPE::F32); + auto batch = auto_batch(memsize); + printf("%d %d %d %d %s %s %s\n", m, n, k, batch, bestla_dtype_str(BTLA_DTYPE::F32), + bestla_dtype_str(BTLA_DTYPE::F32), bestla_dtype_str(BTLA_DTYPE::F32)); + avector A(size_t(m) * k * batch); + avector B(size_t(k) * n * batch); + avector C(size_t(m) * n * batch, 0); + fill_buffer_randn(A.data(), m * k, -0.5f, 0.5f); + fill_buffer_randn(B.data(), n * k, -0.5f, 0.5f); + for (size_t i = 0; i < batch - 1; i++) { + memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType)); + memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(BType)); + } + using LOG = timer_statistics_logger; + float testtime = float(TestMs); + GetCPUDevice(); + auto threads_cfg = UT_Threading::get_threads_config(); + for (auto threads : threads_cfg) { + if (_cd->AVX512F()) { + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + } + if (_cd->AVX2()) { + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + } + } + } +}; +#ifdef BTLA_UT_WRAPPER +static Benchmark_Fp32Fp32 sBenchmark_Fp32Fp32; +#endif + +class Benchmark_U8S8S32 { + public: + Benchmark_U8S8S32() { + UT_START(); + benchmark_all(1, 4096, 4096); + benchmark_all(1024, 4096, 4096); + benchmark_all(2048, 4096, 4096); + } + + using AType = uint8_t; + using BType = int8_t; + using CType = int; + template + void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + static Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + auto tmpB = kernel.mProB.createStorage(n, k); + std::vector packBs(batch, 0); + avector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + utils::GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + log.stop(); + if (tm.stop() >= timems) { + break; + } + } + } + log.record(); + double flops = double(psize) / log.min_val / 1e6; + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, + flops / threads); + } + + void benchmark_all(int m, int n, int k) { + auto memsize = gemm_memsize(m, n, k, BTLA_DTYPE::U8, BTLA_DTYPE::S8, BTLA_DTYPE::S32); + auto batch = auto_batch(memsize); + printf("%d %d %d %d %s %s %s\n", m, n, k, batch, bestla_dtype_str(BTLA_DTYPE::U8), bestla_dtype_str(BTLA_DTYPE::S8), + bestla_dtype_str(BTLA_DTYPE::S32)); + avector A(size_t(m) * k * batch); + avector B(size_t(k) * n * batch); + avector C(size_t(m) * n * batch); + fill_buffer_randn(A.data(), m * k, AType(0), AType(255)); + fill_buffer_randn(B.data(), k * n, BType(-127), BType(127)); + for (size_t i = 0; i < batch - 1; i++) { + memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType)); + memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(BType)); + } + using LOG = timer_statistics_logger; + float testtime = float(TestMs); + GetCPUDevice(); + auto threads_cfg = UT_Threading::get_threads_config(); + for (auto threads : threads_cfg) { + if (_cd->AMX_INT8()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + } + if (_cd->AVX512_VNNI()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + threads); + } + if (_cd->AVX_VNNI()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + } + } + } +}; +#ifdef BTLA_UT_WRAPPER +static Benchmark_U8S8S32 sBenchmark_U8S8S32; +#endif + +class Benchmark_S8S8S32 { + public: + Benchmark_S8S8S32() { + UT_START(); + benchmark_all(1, 4096, 4096); + benchmark_all(1024, 4096, 4096); + benchmark_all(2048, 4096, 4096); + } + + using AType = int8_t; + using BType = int8_t; + using CType = int; + template + void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + auto tmpB = kernel.mProB.createStorage(n, k); + std::vector packBs(batch, 0); + avector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + utils::GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + log.stop(); + if (tm.stop() >= timems) { + break; + } + } + } + log.record(); + double flops = double(psize) / log.min_val / 1e6; + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, + flops / threads); + } + + void benchmark_all(int m, int n, int k) { + auto memsize = gemm_memsize(m, n, k, BTLA_DTYPE::S8, BTLA_DTYPE::S8, BTLA_DTYPE::S32); + auto batch = auto_batch(memsize); + printf("%d %d %d %d %s %s %s\n", m, n, k, batch, bestla_dtype_str(BTLA_DTYPE::S8), bestla_dtype_str(BTLA_DTYPE::S8), + bestla_dtype_str(BTLA_DTYPE::S32)); + avector A(size_t(m) * k * batch); + avector B(size_t(k) * n * batch); + avector C(size_t(m) * n * batch); + fill_buffer_randn(A.data(), m * k, AType(0), AType(255)); + fill_buffer_randn(B.data(), k * n, BType(-127), BType(127)); + for (size_t i = 0; i < batch - 1; i++) { + memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType)); + memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(AType)); + } + using LOG = timer_statistics_logger; + float testtime = float(TestMs); + GetCPUDevice(); + auto threads_cfg = UT_Threading::get_threads_config(); + for (auto threads : threads_cfg) { + if (_cd->AMX_INT8()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + threads); + } + } + } +}; +#ifdef BTLA_UT_WRAPPER +static Benchmark_S8S8S32 sBenchmark_S8S8S32; +#endif + +class Benchmark_Bf16Bf16Fp32 { + public: + Benchmark_Bf16Bf16Fp32() { + UT_START(); + benchmark_all(1, 4096, 4096); + benchmark_all(1024, 4096, 4096); + benchmark_all(2048, 4096, 4096); + } + + using AType = utils::bf16; + using BType = utils::bf16; + using CType = float; + template + void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + auto tmpB = kernel.mProB.createStorage(n, k); + std::vector packBs(batch, 0); + avector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + utils::GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + log.stop(); + if (tm.stop() >= timems) { + break; + } + } + } + log.record(); + double flops = double(psize) / log.min_val / 1e6; + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, + flops / threads); + } + + void benchmark_all(int m, int n, int k) { + auto memsize = gemm_memsize(m, n, k, BTLA_DTYPE::BF16, BTLA_DTYPE::BF16, BTLA_DTYPE::F32); + auto batch = auto_batch(memsize); + printf("%d %d %d %d %s %s %s\n", m, n, k, batch, bestla_dtype_str(BTLA_DTYPE::BF16), + bestla_dtype_str(BTLA_DTYPE::BF16), bestla_dtype_str(BTLA_DTYPE::F32)); + avector A(size_t(m) * k * batch); + avector B(size_t(k) * n * batch); + avector C(size_t(m) * n * batch); + fill_buffer_randn(A.data(), k * m, AType(-0.5f), AType(0.5f)); + fill_buffer_randn(B.data(), k * n, BType(-0.5f), BType(0.5f)); + for (size_t i = 0; i < batch - 1; i++) { + memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType)); + memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(BType)); + } + using LOG = timer_statistics_logger; + float testtime = float(TestMs); + GetCPUDevice(); + auto threads_cfg = UT_Threading::get_threads_config(); + for (auto threads : threads_cfg) { + if (_cd->AMX_BF16()) { + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + } + } + } +}; +#ifdef BTLA_UT_WRAPPER +static Benchmark_Bf16Bf16Fp32 sBenchmark_Bf16Bf16Fp32; +#endif + +class Benchmark_Fp16Fp16Fp16 { + public: + Benchmark_Fp16Fp16Fp16() { + UT_START(); + benchmark_all(1, 4096, 4096); + benchmark_all(1024, 4096, 4096); + benchmark_all(2048, 4096, 4096); + } + + using AType = utils::fp16; + using BType = utils::fp16; + using CType = utils::fp16; + template + void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + auto tmpB = kernel.mProB.createStorage(n, k); + std::vector packBs(batch, 0); + avector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get()); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + for (size_t i = 0; i < batch; i++) { + log.start(); + GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + log.stop(); + if (tm.stop() >= timems) { + break; + } + } + } + log.record(); + double flops = double(psize) / log.min_val / 1e6; + printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops, + flops / threads); + } + + void benchmark_all(int m, int n, int k) { + auto memsize = gemm_memsize(m, n, k, BTLA_DTYPE::F16, BTLA_DTYPE::F16, BTLA_DTYPE::F16); + auto batch = auto_batch(memsize); + printf("%d %d %d %d %s %s %s\n", m, n, k, batch, bestla_dtype_str(BTLA_DTYPE::F16), + bestla_dtype_str(BTLA_DTYPE::F16), bestla_dtype_str(BTLA_DTYPE::F16)); + avector A(size_t(m) * k * batch); + avector B(size_t(k) * n * batch); + avector C(size_t(m) * n * batch); + fill_buffer_randn(A.data(), k * m, AType(-0.5f), AType(0.5f)); + fill_buffer_randn(B.data(), k * n, AType(-0.5f), AType(0.5f)); + for (size_t i = 0; i < batch - 1; i++) { + memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType)); + memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(BType)); + } + using LOG = timer_statistics_logger; + float testtime = float(TestMs); + GetCPUDevice(); + auto threads_cfg = UT_Threading::get_threads_config(); + for (auto threads : threads_cfg) { + if (_cd->AVX512_FP16()) { + benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads); + benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, + threads); + } + } + } +}; +#ifdef BTLA_UT_WRAPPER +static Benchmark_Fp16Fp16Fp16 sBenchmark_Fp16Fp16Fp16; +#endif + +class UTWOQ_CompFp32 { + public: + UTWOQ_CompFp32() { + UT_START(); + ut_s4(); + ut_s8(); + ut_f4(); + } + + void ut_s4() { + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S4_CLIP); + benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S4_CLIP); + } + + void ut_s8() { + benchmark_all(1, 4096, 4096, BTLA_DTYPE::S8); + benchmark_all(1024, 4096, 4096, BTLA_DTYPE::S8); + } + + void ut_f4() { + benchmark_all(1, 4096, 4096, BTLA_DTYPE::F4_BNB); + benchmark_all(1024, 4096, 4096, BTLA_DTYPE::F4_BNB); + } + + template class Wei, typename Scale_T> + void benchmark(int m, int n, int k, int batch, int blocksize, float* A, float* B, float* C, float timems, int threads, + BTLA_DTYPE qtype) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = wrapper::gemm::LauncherBase; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + using WType = typename Wei::StorageWeight; + WType tmpB(0); + if constexpr (std::is_same_v, + prologue_b::gemm::WeightKBlockNInteger>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, false); + + } else if constexpr (std::is_same_v, + prologue_b::gemm::WeightKBlockNFloat>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype); + } + std::vector packBs(batch, 0); + avector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + } + kernel.mProB.packWeight(n, k, B, n, &packBs[0], UT_Threading::get()); + for (size_t i = 1; i < batch; i++) { + memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); + memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); + } + auto psize = (size_t)m * n * k * 2; + auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float); + tm.start(); + while (tm.stop() < timems) { + for (int i = 0; i < batch; i++) { + log.start(); + GemmProblem gp(1, m, n, k); + typename Launcher::Param args{gp, {A + i * m * k, k}, {&packBs[i]}, {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + log.stop(); + if (tm.stop() >= timems) { + break; + } + } + } + log.record(); + double flops = double(psize) / log.min_val / 1e6; + double band = double(memsize) / log.min_val / 1e6; + printf("Threads %d Block %d %s %s Flops:%.3fG PerCoreFlops:%.3fG MemoryBandwidth:%.3fGB/s\n", threads, blocksize, + corestr, log.get_log_str(), flops, flops / threads, band); + } + + template class Wei, typename Scale_T> + void benchmark_mem(int m, int n, int k, int batch, int blocksize, float* A, float* B, float* C, float timems, + int threads, BTLA_DTYPE qtype) { + LOG_T log; + using Parallel = parallel::gemm::SchedulerKBlock; + using Launcher = + wrapper::gemm::LauncherKBlock; + Launcher kernel; + UT_Threading::set_threads(threads); + auto corestr = gemm::CoreAttr::to_str(Core_T::ID); + utils::timer tm; + using WType = typename Wei::StorageWeight; + WType tmpB(0); + if constexpr (std::is_same_v, + prologue_b::gemm::WeightKBlockNInteger>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype, bestla_dtype, false); + + } else if constexpr (std::is_same_v, + prologue_b::gemm::WeightKBlockNFloat>) { + tmpB = kernel.mProB.createStorage(n, k, blocksize, qtype, bestla_dtype); + } + auto memsize = (size_t)tmpB.mSize + (m * k + m * n) * sizeof(float); + std::vector packBs(batch, 0); + avector bufB(tmpB.mSize * batch); + for (size_t i = 0; i < batch; i++) { + packBs[i] = tmpB; + packBs[i].assign(bufB.data() + i * tmpB.mSize); + } + kernel.mProB.packWeight(n, k, B, n, &packBs[0], UT_Threading::get()); + for (size_t i = 1; i < batch; i++) { + memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); + memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); + } + auto psize = (size_t)m * n * k * 2; + tm.start(); + while (tm.stop() < timems) { + log.start(); + for (size_t i = 0; i < batch; i++) { + GemmProblem gp(1, m, n, k, blocksize); + typename Launcher::Param args{gp, + {A + i * m * k, k}, + {&packBs[i]}, + {packBs[i].template SPtr(), packBs[i].SDtype(), packBs[i].CStep()}, + {C + i * m * n, n}}; + parallel::GemmRun(kernel, args, UT_Threading::get()); + } + log.stop(); + } + log.record(); + double t = log.min_val / batch; + double flops = double(psize) / t / 1e6; + double band = double(memsize) / t / 1e6; + printf("Threads %d Block %d %s Flops:%.3fG PerCoreFlops:%.3fG MemoryBandwidth:%.3fGB/s\n", threads, blocksize, + corestr, flops, flops / threads, band); + } + + template