Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
using singleton instead of extern
Browse files Browse the repository at this point in the history
  • Loading branch information
yuchengliu1 committed Feb 20, 2024
1 parent ea90d51 commit f58d0e1
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 118 deletions.
10 changes: 8 additions & 2 deletions bestla/bestla/bestla_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,10 @@ class IThreading {
#if BTLA_OPENMP
class OMPThreading : public IThreading {
public:
explicit OMPThreading(int nthreads) : IThreading(nthreads) { omp_set_num_threads(nthreads); }
explicit OMPThreading(int nthreads) : IThreading(nthreads) {
printf("Using OMP\n");
omp_set_num_threads(nthreads);
}
void parallel_for(const thread_func& func) override {
if (mThreadNum > 1) {
#pragma omp parallel
Expand All @@ -624,7 +627,10 @@ class OMPThreading : public IThreading {

class StdThreading : public IThreading {
public:
explicit StdThreading(int nthreads) : IThreading(nthreads) { create_threads(); }
explicit StdThreading(int nthreads) : IThreading(nthreads) {
printf("Using Std\n");
create_threads();
}
void parallel_for(const thread_func& func) override {
if (mThreadNum > 1) {
running.store(mThreadNum - 1);
Expand Down
4 changes: 2 additions & 2 deletions bestla/bestla/ut/bestla_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class UT_OMPThreading {
kernel::wrapper::Transpose2D<float>::template forward<BTLA_ISA::AVX512F>(src.data(), ref.data(), row, col, col,
row);
parallel::Scheduler2D _para({threads, row, col, 1, 1});
DefaultThreading.parallel_for([&](int tidx) {
UT_Threading::get()->parallel_for([&](int tidx) {
parallel::ThreadProblem2D thdp{tidx};
_para.getIndex(thdp);
if (thdp.valid) {
Expand Down Expand Up @@ -61,7 +61,7 @@ class UT_StdThreading {
kernel::wrapper::Transpose2D<float>::template forward<BTLA_ISA::AVX512F>(src.data(), ref.data(), row, col, col,
row);
parallel::Scheduler2D _para({threads, row, col, 1, 1});
DefaultThreading.parallel_for([&](int tidx) {
UT_Threading::get()->parallel_for([&](int tidx) {
parallel::ThreadProblem2D thdp{tidx};
_para.getIndex(thdp);
if (thdp.valid) {
Expand Down
8 changes: 4 additions & 4 deletions bestla/bestla/ut/bestla_prologue_a.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class UT_ActivationU8KBlockQuantize {
auto quanAct = actA.createStorage(m, k, kblock, hasreduce);
avector<int8_t> bufA(quanAct.mSize);
quanAct.assign(bufA.data());
actA.quantize({raw.data(), lda, &quanAct}, m, k, &DefaultThreading);
actA.quantize({raw.data(), lda, &quanAct}, m, k, UT_Threading::get());

ut::buffer_error(q.data(), quanAct.template APtr<uint8_t>(), q.size(), uint8_t(1));
ut::buffer_error(zp.data(), quanAct.template ZPtr<uint8_t>(), zp.size(), uint8_t(1));
Expand Down Expand Up @@ -186,7 +186,7 @@ class UT_ActivationS8KBlockQuantize {
auto quanAct = actA.createStorage(m, k, kblock, hasreduce);
avector<int8_t> bufA(quanAct.mSize);
quanAct.assign(bufA.data());
actA.quantize({raw.data(), k, &quanAct}, m, k, &DefaultThreading);
actA.quantize({raw.data(), k, &quanAct}, m, k, UT_Threading::get());
ut::buffer_error(q.data(), quanAct.template APtr<int8_t>(), q.size(), int8_t(1));
if (hasreduce) {
avector<float> redref(reduce.size(), 0.f), redqref(reduce.size(), 0.f);
Expand Down Expand Up @@ -235,7 +235,7 @@ class UT_ShuffleActivationKblock {
auto reordA = kernel.createReorderStorage(m, k, 32);
avector<int8_t> bufA(reordA.mSize);
reordA.assign(bufA.data());
kernel.preprocess({src.data(), k, nullptr, indices.data(), &reordA}, m, k, 32, &DefaultThreading);
kernel.preprocess({src.data(), k, nullptr, indices.data(), &reordA}, m, k, 32, UT_Threading::get());

kernel.getActivation(&dstptr, &dststride, {src.data(), k, nullptr, indices.data(), &reordA}, m, kpad, 0, 0, cache,
CacheSize);
Expand Down Expand Up @@ -272,7 +272,7 @@ class UT_ShuffleActivationKblock {
avector<int8_t> bufA(quanAct.mSize + reordAct.mSize);
quanAct.assign(bufA.data());
reordAct.assign(bufA.data() + quanAct.mSize);
actA.quantize({raw_cp.data(), k, &quanAct, indices.data(), &reordAct}, m, k, &DefaultThreading);
actA.quantize({raw_cp.data(), k, &quanAct, indices.data(), &reordAct}, m, k, UT_Threading::get());
ut::buffer_error(quanAct.template APtr<int8_t>(), q.data(), q.size(), int8_t(1));
if (hasreduce) {
avector<float> redref(reduce.size(), 0.f), redqref(reduce.size(), 0.f);
Expand Down
130 changes: 65 additions & 65 deletions bestla/bestla/ut/bestla_prologue_b.cpp

Large diffs are not rendered by default.

9 changes: 0 additions & 9 deletions bestla/bestla/ut/bestla_ut.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
#include <stdio.h>
#include <bestla_parallel.h>

namespace bestla {
namespace ut {
#ifdef _OPENMP
parallel::OMPThreading DefaultThreading(4);
#else
parallel::StdThreading DefaultThreading(4);
#endif // _OPNEMP
} // namespace ut
} // namespace bestla
int main() {
printf("BesTLA UT done\n");
return 0;
Expand Down
25 changes: 17 additions & 8 deletions bestla/bestla/ut/bestla_ut.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,20 @@ using sAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>;
using sAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>;
using sAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>;
using sAVX2 = gemm::SCoreRowNAvx2<24, 4>;
#ifdef _OPENMP
extern parallel::OMPThreading DefaultThreading;

class UT_Threading {
public:
static bestla::parallel::IThreading* get() {
#if BTLA_UT_OPENMP
static bestla::parallel::OMPThreading DefaultThreading(4);
#else
extern parallel::StdThreading DefaultThreading;
static bestla::parallel::StdThreading DefaultThreading(4);
#endif // _OPNEMP
return &DefaultThreading;
}

static void set_threads(int n_thread) { get()->set_threads(n_thread); }
};

constexpr size_t CacheSize = size_t(100) << 10;
static int8_t cache[CacheSize];
Expand Down Expand Up @@ -129,11 +138,11 @@ utils::aligned_vector<_T> readFile2Buffer(const char* filepath) {
return buf;
}

#define UT_START() \
{ \
GetCPUDevice(); \
ut::DefaultThreading.set_threads(_cd->getThreads()); \
printf("Test Class: %s\n", __FUNCTION__); \
#define UT_START() \
{ \
GetCPUDevice(); \
ut::UT_Threading::set_threads(_cd->getThreads()); \
printf("Test Class: %s\n", __FUNCTION__); \
}
template <typename _T>
static double buffer_error(_T* ref, _T* tar, size_t size, _T thres = _T(0)) {
Expand Down
50 changes: 25 additions & 25 deletions bestla/bestla/ut/bestla_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ class UT_Fp32Fp32 {
auto packw = launcher.mProB.createStorage(n, k);
avector<int8_t> buffer(packw.mSize);
packw.assign(buffer.data());
launcher.mProB.packWeight(n, k, {matB.data(), n, &packw}, &DefaultThreading);
launcher.mProB.packWeight(n, k, {matB.data(), n, &packw}, UT_Threading::get());
utils::GemmProblem gp(1, m, n, k);
typename Launcher::Param args{gp, {matA.data(), k}, {matB.data(), n, &packw}, {matC.data(), n}};
parallel::GemmRun<Parallel>(launcher, args, &DefaultThreading);
parallel::GemmRun<Parallel>(launcher, args, UT_Threading::get());
ut::buffer_error(ref.data(), matC.data(), ref.size(), 0.001f);
}

Expand All @@ -65,7 +65,7 @@ class UT_Fp32Fp32 {
wrapper::gemm::LauncherBase<Core_T::ISA, Core_T, prologue_a::gemm::ActivationBase, prologue_b::gemm::WeightPack,
epilogue::gemm::AccumulatorWriteBackFp32>;
Launcher kernel;
DefaultThreading.set_threads(threads);
UT_Threading::set_threads(threads);
auto corestr = gemm::CoreAttr::to_str(Core_T::ID);
utils::timer<std::chrono::milliseconds> tm;
auto tmpB = kernel.mProB.createStorage(n, k);
Expand All @@ -74,7 +74,7 @@ class UT_Fp32Fp32 {
for (size_t i = 0; i < batch; i++) {
packBs[i] = tmpB;
packBs[i].assign(bufB.data() + i * tmpB.mSize);
kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, &DefaultThreading);
kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get());
}
auto psize = (size_t)m * n * k * 2;
tm.start();
Expand All @@ -83,7 +83,7 @@ class UT_Fp32Fp32 {
log.start();
utils::GemmProblem gp(1, m, n, k);
typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}};
parallel::GemmRun<Parallel>(kernel, args, &DefaultThreading);
parallel::GemmRun<Parallel>(kernel, args, UT_Threading::get());
if (log.stop()) {
double flops = double(psize) / log.avg_val / 1e6;
printf("%s %s Flops:%.3f PerCoreFlops:%.3f\n ", corestr, log.get_log_str(), flops, flops / threads);
Expand Down Expand Up @@ -190,14 +190,14 @@ class UT_U8S8S32 {
auto packw = launcher.mProB.createStorage(n, k);
avector<int8_t> buffer(packw.mSize);
packw.assign(buffer.data());
launcher.mProB.packWeight(n, k, {matBs8.data(), n, &packw}, &DefaultThreading);
launcher.mProB.packWeight(n, k, {matBs8.data(), n, &packw}, UT_Threading::get());
utils::GemmProblem gp(1, m, n, k);
typename Launcher::Param args{
gp,
{matAu8.data(), k},
{matBs8.data(), n, &packw},
{matC.data(), n, 1, scaleAf32.data(), scaleBf32.data(), zpAu8.data(), reduceB.data()}};
parallel::GemmRun<Parallel>(launcher, args, &DefaultThreading);
parallel::GemmRun<Parallel>(launcher, args, UT_Threading::get());
ut::buffer_error(refC.data(), matC.data(), refC.size(), 0.001f);
}

Expand All @@ -212,7 +212,7 @@ class UT_U8S8S32 {
wrapper::gemm::LauncherBase<Core_T::ISA, Core_T, prologue_a::gemm::ActivationBase, prologue_b::gemm::WeightPack,
epilogue::gemm::AccumulatorWriteBackInt32>;
Launcher kernel;
DefaultThreading.set_threads(threads);
UT_Threading::set_threads(threads);
auto corestr = gemm::CoreAttr::to_str(Core_T::ID);
utils::timer<std::chrono::milliseconds> tm;
auto tmpB = kernel.mProB.createStorage(n, k);
Expand All @@ -221,7 +221,7 @@ class UT_U8S8S32 {
for (size_t i = 0; i < batch; i++) {
packBs[i] = tmpB;
packBs[i].assign(bufB.data() + i * tmpB.mSize);
kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, &DefaultThreading);
kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get());
}
auto psize = (size_t)m * n * k * 2;
tm.start();
Expand All @@ -230,7 +230,7 @@ class UT_U8S8S32 {
log.start();
utils::GemmProblem gp(1, m, n, k);
typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}};
parallel::GemmRun<Parallel>(kernel, args, &DefaultThreading);
parallel::GemmRun<Parallel>(kernel, args, UT_Threading::get());
if (log.stop()) {
double flops = double(psize) / log.avg_val / 1e6;
printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops,
Expand Down Expand Up @@ -324,11 +324,11 @@ class UT_S8S8S32 {
auto packw = launcher.mProB.createStorage(n, k);
avector<int8_t> buffer(packw.mSize);
packw.assign(buffer.data());
launcher.mProB.packWeight(n, k, {matBs8.data(), n, &packw}, &DefaultThreading);
launcher.mProB.packWeight(n, k, {matBs8.data(), n, &packw}, UT_Threading::get());
utils::GemmProblem gp(1, m, n, k);
typename Launcher::Param args{
gp, {matAu8.data(), k}, {matBs8.data(), n, &packw}, {matC.data(), n, 1, scaleAf32.data(), scaleBf32.data()}};
parallel::GemmRun<Parallel>(launcher, args, &DefaultThreading);
parallel::GemmRun<Parallel>(launcher, args, UT_Threading::get());
ut::buffer_error(refC.data(), matC.data(), refC.size(), 0.001f);
}

Expand All @@ -343,7 +343,7 @@ class UT_S8S8S32 {
wrapper::gemm::LauncherBase<Core_T::ISA, Core_T, prologue_a::gemm::ActivationBase, prologue_b::gemm::WeightPack,
epilogue::gemm::AccumulatorWriteBackInt32>;
Launcher kernel;
DefaultThreading.set_threads(threads);
UT_Threading::set_threads(threads);
auto corestr = gemm::CoreAttr::to_str(Core_T::ID);
utils::timer<std::chrono::milliseconds> tm;
auto tmpB = kernel.mProB.createStorage(n, k);
Expand All @@ -352,7 +352,7 @@ class UT_S8S8S32 {
for (size_t i = 0; i < batch; i++) {
packBs[i] = tmpB;
packBs[i].assign(bufB.data() + i * tmpB.mSize);
kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, &DefaultThreading);
kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get());
}
auto psize = (size_t)m * n * k * 2;
tm.start();
Expand All @@ -361,7 +361,7 @@ class UT_S8S8S32 {
log.start();
utils::GemmProblem gp(1, m, n, k);
typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}};
parallel::GemmRun<Parallel>(kernel, args, &DefaultThreading);
parallel::GemmRun<Parallel>(kernel, args, UT_Threading::get());
if (log.stop()) {
double flops = double(psize) / log.avg_val / 1e6;
printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops,
Expand Down Expand Up @@ -430,11 +430,11 @@ class UT_Bf16Bf16Fp32 {
fill_buffer_randn(matAbf16.data(), matAbf16.size(), utils::bf16(-0.5f), utils::bf16(0.5f));
fill_buffer_randn(matBbf16.data(), matBbf16.size(), utils::bf16(-0.5f), utils::bf16(0.5f));
avector<float> matC(m * n), refC(m * n);
launcher.mProB.packWeight(n, k, {matBbf16.data(), n, &packw}, &DefaultThreading);
launcher.mProB.packWeight(n, k, {matBbf16.data(), n, &packw}, UT_Threading::get());
gemmref_bf16bf16fp32(m, n, k, matAbf16.data(), matBbf16.data(), refC.data(), k, n, n);
utils::GemmProblem gp(1, m, n, k);
typename Launcher::Param args{gp, {matAbf16.data(), k}, {matBbf16.data(), n, &packw}, {matC.data(), n}};
parallel::GemmRun<Parallel>(launcher, args, &DefaultThreading);
parallel::GemmRun<Parallel>(launcher, args, UT_Threading::get());
buffer_error(refC.data(), matC.data(), refC.size(), 0.05f);
}

Expand All @@ -449,7 +449,7 @@ class UT_Bf16Bf16Fp32 {
wrapper::gemm::LauncherBase<Core_T::ISA, Core_T, prologue_a::gemm::ActivationBase, prologue_b::gemm::WeightPack,
epilogue::gemm::AccumulatorWriteBackFp32>;
Launcher kernel;
DefaultThreading.set_threads(threads);
UT_Threading::set_threads(threads);
auto corestr = gemm::CoreAttr::to_str(Core_T::ID);
utils::timer<std::chrono::milliseconds> tm;
auto tmpB = kernel.mProB.createStorage(n, k);
Expand All @@ -458,7 +458,7 @@ class UT_Bf16Bf16Fp32 {
for (size_t i = 0; i < batch; i++) {
packBs[i] = tmpB;
packBs[i].assign(bufB.data() + i * tmpB.mSize);
kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, &DefaultThreading);
kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get());
}
auto psize = (size_t)m * n * k * 2;
tm.start();
Expand All @@ -467,7 +467,7 @@ class UT_Bf16Bf16Fp32 {
log.start();
utils::GemmProblem gp(1, m, n, k);
typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}};
parallel::GemmRun<Parallel>(kernel, args, &DefaultThreading);
parallel::GemmRun<Parallel>(kernel, args, UT_Threading::get());
if (log.stop()) {
double flops = double(psize) / log.avg_val / 1e6;
printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops,
Expand Down Expand Up @@ -534,11 +534,11 @@ class UT_Fp16Fp16Fp16 {
avector<utils::fp16> matAbf16(m * k), matBbf16(k * n), matC(m * n), refC(m * n);
fill_buffer_randn(matAbf16.data(), matAbf16.size(), utils::fp16(-0.5f), utils::fp16(0.5f));
fill_buffer_randn(matBbf16.data(), matBbf16.size(), utils::fp16(-0.5f), utils::fp16(0.5f));
launcher.mProB.packWeight(n, k, {matBbf16.data(), n, &packw}, &DefaultThreading);
launcher.mProB.packWeight(n, k, {matBbf16.data(), n, &packw}, UT_Threading::get());
gemmref_fp16fp16fp16(m, n, k, matAbf16.data(), matBbf16.data(), refC.data(), k, n, n);
GemmProblem gp(1, m, n, k);
typename Launcher::Param args{gp, {matAbf16.data(), k}, {matBbf16.data(), n, &packw}, {matC.data(), n}};
parallel::GemmRun<Parallel>(launcher, args, &DefaultThreading);
parallel::GemmRun<Parallel>(launcher, args, UT_Threading::get());
buffer_error(refC.data(), matC.data(), refC.size(), utils::fp16(0.0002f * k));
}

Expand All @@ -553,7 +553,7 @@ class UT_Fp16Fp16Fp16 {
wrapper::gemm::LauncherBase<Core_T::ISA, Core_T, prologue_a::gemm::ActivationBase, prologue_b::gemm::WeightPack,
epilogue::gemm::AccumulatorWriteBackFp16>;
Launcher kernel;
DefaultThreading.set_threads(threads);
UT_Threading::set_threads(threads);
auto corestr = gemm::CoreAttr::to_str(Core_T::ID);
utils::timer<std::chrono::milliseconds> tm;
auto tmpB = kernel.mProB.createStorage(n, k);
Expand All @@ -562,7 +562,7 @@ class UT_Fp16Fp16Fp16 {
for (size_t i = 0; i < batch; i++) {
packBs[i] = tmpB;
packBs[i].assign(bufB.data() + i * tmpB.mSize);
kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, &DefaultThreading);
kernel.mProB.packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get());
}
auto psize = (size_t)m * n * k * 2;
tm.start();
Expand All @@ -571,7 +571,7 @@ class UT_Fp16Fp16Fp16 {
log.start();
GemmProblem gp(1, m, n, k);
typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}};
parallel::GemmRun<Parallel>(kernel, args, &DefaultThreading);
parallel::GemmRun<Parallel>(kernel, args, UT_Threading::get());
if (log.stop()) {
double flops = double(psize) / log.avg_val / 1e6;
printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops,
Expand Down
6 changes: 3 additions & 3 deletions bestla/bestla/ut/kernel_jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class UT_Memcpy2D_AVX512F {
kernel::jit::JitMemcpy2DAvx512f::forward<float, float>(src.data(), dst.data(), row, col, srcstep, dststep);
}
tm.start();
parallel::Scheduler2D para({DefaultThreading.num_threads(), row, col, 4, 64});
parallel::Scheduler2D para({UT_Threading::get()->num_threads(), row, col, 4, 64});
for (size_t i = 0; i < TestLoop; i++) {
DefaultThreading.parallel_for([&](int tidx) {
UT_Threading::get()->parallel_for([&](int tidx) {
parallel::ThreadProblem2D thdp{tidx};
para.getIndex(thdp);
if (thdp.valid) {
Expand All @@ -47,7 +47,7 @@ class UT_Memcpy2D_AVX512F {

tm.start();
for (size_t i = 0; i < TestLoop; i++) {
DefaultThreading.parallel_for([&](int tidx) {
UT_Threading::get()->parallel_for([&](int tidx) {
parallel::ThreadProblem2D thdp{tidx};
para.getIndex(thdp);
if (thdp.valid) {
Expand Down

0 comments on commit f58d0e1

Please sign in to comment.