Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
update base thread pool
Browse files Browse the repository at this point in the history
  • Loading branch information
yuchengliu1 committed Feb 4, 2024
1 parent 9727cfd commit d96013c
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 37 deletions.
4 changes: 2 additions & 2 deletions bestla/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ option(BTLA_UT_KERNEL_JIT "Enable unit test for jit kernels" OFF)
option(BTLA_UT_KERNEL_INTRIN "Enable unit test for intrinsic kernels" OFF)
option(BTLA_UT_KERNEL_WRAPPER "Enable unit test for runtime ISA kernels" OFF)
option(BTLA_UT_NOASAN "Disable sanitize" OFF)
option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF)
option(BTLA_UT_OPENMP "Use OpenMP" ON)
option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" ON)
option(BTLA_UT_OPENMP "Use OpenMP" OFF)

add_library(${PROJECT_NAME} INTERFACE)
add_library(neural_speed::${PROJECT_NAME} ALIAS ${PROJECT_NAME})
Expand Down
47 changes: 25 additions & 22 deletions bestla/bestla/bestla_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -626,27 +626,28 @@ class StdThreading : public IThreading {
explicit StdThreading(int nthreads) : IThreading(nthreads) { create_threads(); }
void parallel_for(const thread_func& func) override {
if (mThreadNum > 1) {
func_ = &func;
running.store(mThreadNum - 1);
for (size_t i = 0; i < mThreadNum - 1; i++) {
locks[i] = true;
func_[i] = &func;
}
func(0);
while (true) {
bool is_lock = false;
for (size_t i = 0; !is_lock && i < mThreadNum - 1; i++) {
is_lock |= locks[i];
}
if (!is_lock) break;
if (running.load() == 0)
break;
else
_mm_pause();
}
} else {
func(0);
}
}

void set_threads(int nthreads) override {
stop_threads();
mThreadNum = nthreads;
create_threads();
if (nthreads != mThreadNum) {
stop_threads();
mThreadNum = nthreads;
create_threads();
}
}

inline void sync() const override { assert(0); }
Expand All @@ -655,24 +656,25 @@ class StdThreading : public IThreading {

private:
void stop_threads() {
for (int i = 0; i < mThreadNum - 1; i++) stop[i] = true;
stop = true;
for (int i = 0; i < mThreadNum - 1; i++) thdset[i].join();
thdset.clear();
// printf("stop %d\n", mThreadNum);
}
void create_threads() {
thdset.clear();
// printf("create %d\n", mThreadNum);
thdset.resize(mThreadNum - 1);
locks.resize(mThreadNum - 1);
stop.resize(mThreadNum - 1);
stop = false;

for (size_t i = 0; i < mThreadNum - 1; i++) {
stop[i] = false;
locks[i] = false;
thdset[i] = std::thread(
[&](int tidx) {
while (!stop[tidx]) {
if (locks[tidx]) {
(*func_)(tidx + 1);
locks[tidx] = false;
while (true) {
if (stop.load() == true) break;
if (func_[tidx] != nullptr) {
(*func_[tidx])(tidx + 1);
func_[tidx] = nullptr;
running.fetch_sub(1);
} else {
_mm_pause();
}
Expand All @@ -683,8 +685,9 @@ class StdThreading : public IThreading {
}

std::vector<std::thread> thdset;
std::vector<bool> locks, stop;
const thread_func* func_ = nullptr;
std::atomic_bool stop;
std::atomic_int running;
const thread_func* func_[100];
};

class SingleThread : public StdThreading {
Expand Down
2 changes: 2 additions & 0 deletions bestla/bestla/ut/bestla_epilogue.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "bestla_epilogue.h"
#include "bestla_ut.h"

#ifdef BTLA_UT_EPILOGUE
namespace bestla {
using namespace utils;
namespace ut {
Expand Down Expand Up @@ -139,3 +140,4 @@ static UT_AlphaBetaProcessFp32 sUT_AlphaBetaProcessFp32;
#endif
} // namespace ut
} // namespace bestla
#endif
2 changes: 2 additions & 0 deletions bestla/bestla/ut/bestla_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "bestla_utils.h"
#include "bestla_ut.h"

#ifdef BTLA_UT_GEMM
namespace bestla {
using namespace utils;

Expand Down Expand Up @@ -1115,3 +1116,4 @@ static UT_GEMM_AMXINT8 sUT_GEMM_AMXINT8;
#endif
} // namespace ut
} // namespace bestla
#endif
3 changes: 3 additions & 0 deletions bestla/bestla/ut/bestla_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "bestla_gemm.h"
#include "bestla_ut.h"
#include "bestla_prologue_a.h"

#ifdef BTLA_UT_PARALLEL
namespace bestla {
using namespace utils;
namespace ut {
Expand Down Expand Up @@ -206,3 +208,4 @@ static UT_SchedulerGemmKBlockNew sUT_SchedulerGemmKBlockNew;
#endif
} // namespace ut
} // namespace bestla
#endif
2 changes: 2 additions & 0 deletions bestla/bestla/ut/bestla_prologue_a.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "bestla_ut.h"
#include "kernel_avx512f.h"

#ifdef BTLA_UT_PROLOGUE_A
namespace bestla {
using namespace utils;
namespace ut {
Expand Down Expand Up @@ -292,3 +293,4 @@ static UT_ShuffleActivationKblock sUT_ShuffleActivationKblock;
#endif
} // namespace ut
} // namespace bestla
#endif
2 changes: 2 additions & 0 deletions bestla/bestla/ut/bestla_prologue_b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "bestla_wrapper.h"
#include "bestla_ut.h"

#ifdef BTLA_UT_PROLOGUE_B
namespace bestla {
using namespace utils;
namespace ut {
Expand Down Expand Up @@ -1780,3 +1781,4 @@ static UT_CompFp16 sUT_CompFp16;
#endif
} // namespace ut
} // namespace bestla
#endif
10 changes: 10 additions & 0 deletions bestla/bestla/ut/bestla_ut.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
#include <stdio.h>
#include <bestla_parallel.h>

// Defines the DefaultThreading object for the unit tests in exactly one translation unit.
// NOTE(review): bestla_ut.h appears to declare the same object as `extern`; keep the
// type selection below in sync with that header.
namespace bestla {
namespace ut {
// The concrete threading type is chosen at compile time by whether _OPENMP is defined.
#ifdef _OPENMP
// NOTE(review): the argument 4 is presumably the thread count, as for StdThreading — confirm.
parallel::OMPThreading DefaultThreading(4);
#else
// StdThreading's constructor takes the number of threads (nthreads); 4 are requested here.
parallel::StdThreading DefaultThreading(4);
#endif // _OPENMP
} // namespace ut
} // namespace bestla
int main() {
printf("BesTLA UT done\n");
return 0;
Expand Down
6 changes: 4 additions & 2 deletions bestla/bestla/ut/bestla_ut.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

#include <random>
#include <stdexcept>
#include "bestla_utils.h"
Expand Down Expand Up @@ -25,9 +27,9 @@ using sAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>;
using sAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>;
using sAVX2 = gemm::SCoreRowNAvx2<24, 4>;
#ifdef _OPENMP
static parallel::OMPThreading DefaultThreading(4);
extern parallel::OMPThreading DefaultThreading;
#else
static parallel::StdThreading DefaultThreading(4);
extern parallel::StdThreading DefaultThreading;
#endif // _OPENMP

constexpr size_t CacheSize = size_t(100) << 10;
Expand Down
22 changes: 11 additions & 11 deletions bestla/bestla/ut/bestla_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ class UT_Fp32Fp32 {
public:
UT_Fp32Fp32() {
UT_START();
#ifdef JBLAS_UT_BENCHMARK
#ifdef BTLA_UT_BENCHMARK
benchmark_all(1, 4096, 4096, 32);
benchmark_all(1024, 4096, 4096, 32);
benchmark_all(2048, 4096, 4096, 32);
#endif // JBLAS_UT_BENCHMARK
#endif // BTLA_UT_BENCHMARK

CheckISA(AVX2);
ut<sAVX2>(1, 1, 1);
Expand Down Expand Up @@ -116,7 +116,7 @@ class UT_Fp32Fp32 {
}
}
};
#ifdef JBLAS_UT_WRAPPER
#ifdef BTLA_UT_WRAPPER
static UT_Fp32Fp32 sUT_Fp32Fp32;
#endif

Expand All @@ -125,7 +125,7 @@ class UT_U8S8S32 {
UT_U8S8S32() {
UT_START();
GetCPUDevice();
#ifdef JBLAS_UT_BENCHMARK
#ifdef BTLA_UT_BENCHMARK
benchmark_all(1024, 4096, 4096, 32);
benchmark_all(2048, 4096, 4096, 32);
#endif
Expand Down Expand Up @@ -269,7 +269,7 @@ class UT_U8S8S32 {
}
}
};
#ifdef JBLAS_UT_WRAPPER
#ifdef BTLA_UT_WRAPPER
static UT_U8S8S32 sUT_U8S8S32;
#endif

Expand All @@ -278,7 +278,7 @@ class UT_S8S8S32 {
UT_S8S8S32() {
UT_START();
GetCPUDevice();
#ifdef JBLAS_UT_BENCHMARK
#ifdef BTLA_UT_BENCHMARK
benchmark_all(1024, 4096, 4096, 32);
benchmark_all(2048, 4096, 4096, 32);
#endif
Expand Down Expand Up @@ -393,7 +393,7 @@ class UT_S8S8S32 {
}
}
};
#ifdef JBLAS_UT_WRAPPER
#ifdef BTLA_UT_WRAPPER
static UT_S8S8S32 sUT_S8S8S32;
#endif

Expand All @@ -403,7 +403,7 @@ class UT_Bf16Bf16Fp32 {
UT_START();
CheckISA(AMX_BF16);
request_perm_xtile_data();
#ifdef JBLAS_UT_BENCHMARK
#ifdef BTLA_UT_BENCHMARK
benchmark_all(1024, 4096, 4096, 32);
benchmark_all(2048, 4096, 4096, 32);
#endif
Expand Down Expand Up @@ -499,7 +499,7 @@ class UT_Bf16Bf16Fp32 {
}
}
};
#ifdef JBLAS_UT_WRAPPER
#ifdef BTLA_UT_WRAPPER
static UT_Bf16Bf16Fp32 sUT_Bf16Bf16Fp32;
#endif

Expand All @@ -508,7 +508,7 @@ class UT_Fp16Fp16Fp16 {
UT_Fp16Fp16Fp16() {
UT_START();
CheckISA(AVX512_FP16);
#ifdef JBLAS_UT_BENCHMARK
#ifdef BTLA_UT_BENCHMARK
benchmark_all(1024, 4096, 4096, 32);
benchmark_all(2048, 4096, 4096, 32);
#endif
Expand Down Expand Up @@ -602,7 +602,7 @@ class UT_Fp16Fp16Fp16 {
}
}
};
#ifdef JBLAS_UT_WRAPPER
#ifdef BTLA_UT_WRAPPER
static UT_Fp16Fp16Fp16 sUT_Fp16Fp16Fp16;
#endif
} // namespace ut
Expand Down

0 comments on commit d96013c

Please sign in to comment.