From 757a91c3ca4f5ebf4879739c0871d2d5534465ac Mon Sep 17 00:00:00 2001
From: Xinyu Chen
Date: Sun, 24 Jun 2018 06:28:10 +0800
Subject: [PATCH] Improve batch gemm performance using MKL (#342)

* improve batch_dot performance by using MKL

* reduce for loop

* improve double batch gemm

* remove unnecessary reserve

* fix lint

* add MKL version check
---
 mshadow/base.h           |  1 +
 mshadow/dot_engine-inl.h | 75 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+)

diff --git a/mshadow/base.h b/mshadow/base.h
index 32aaa051..23cb49ac 100755
--- a/mshadow/base.h
+++ b/mshadow/base.h
@@ -166,6 +166,7 @@ extern "C" {
   #include <mkl_cblas.h>
   #include <mkl_vsl.h>
   #include <mkl_vsl_functions.h>
+  #include <mkl_version.h>
 }
 #endif
 #if MSHADOW_USE_CUDA
diff --git a/mshadow/dot_engine-inl.h b/mshadow/dot_engine-inl.h
index f47fa97c..5363974f 100644
--- a/mshadow/dot_engine-inl.h
+++ b/mshadow/dot_engine-inl.h
@@ -7,6 +7,7 @@
 #ifndef MSHADOW_DOT_ENGINE_INL_H_
 #define MSHADOW_DOT_ENGINE_INL_H_
 
+#include <vector>
 #include "./base.h"
 #include "./extension/implicit_gemm.h"
 
@@ -291,11 +292,48 @@ struct BLASEngine<cpu, float> {
                                 const float *A, int lda, const float *B, int ldb,
                                 float beta, float *C, int ldc, int batch_count,
                                 float **workspace) {
+#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
+    std::vector<MKL_INT> p_m(batch_count, m);
+    std::vector<MKL_INT> p_n(batch_count, n);
+    std::vector<MKL_INT> p_k(batch_count, k);
+    std::vector<MKL_INT> p_lda(batch_count, lda);
+    std::vector<MKL_INT> p_ldb(batch_count, ldb);
+    std::vector<MKL_INT> p_ldc(batch_count, ldc);
+    std::vector<float> p_alpha(batch_count, alpha);
+    std::vector<float> p_beta(batch_count, beta);
+    std::vector<const float*> pp_A;
+    std::vector<const float*> pp_B;
+    std::vector<float*> pp_C;
+
+    CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
+    CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
+
+    std::vector<MKL_INT> p_group_sizeb(batch_count, batch_count);
+    std::vector<CBLAS_TRANSPOSE> p_transa(batch_count, cblas_a_trans);
+    std::vector<CBLAS_TRANSPOSE> p_transb(batch_count, cblas_b_trans);
+
+    auto m_k = m * k;
+    auto k_n = k * n;
+    auto m_n = m * n;
+
+    for (int i = 0; i < batch_count; i++) {
+      pp_A.push_back(A + i * m_k);
+      pp_B.push_back(B + i * k_n);
+      pp_C.push_back(C + i * m_n);
+    }
+
+    cblas_sgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(),
+                      p_m.data(), p_n.data(), p_k.data(),
+                      p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(),
+                      p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(),
+                      1, p_group_sizeb.data());
+#else
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
+#endif
   }
   inline static void gemv(Stream<cpu> *stream,
                           bool trans, int m, int n,
@@ -361,11 +399,48 @@ struct BLASEngine<cpu, double> {
                                 const double *A, int lda, const double *B, int ldb,
                                 double beta, double *C, int ldc, int batch_count,
                                 double **workspace) {
+#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
+    std::vector<MKL_INT> p_m(batch_count, m);
+    std::vector<MKL_INT> p_n(batch_count, n);
+    std::vector<MKL_INT> p_k(batch_count, k);
+    std::vector<MKL_INT> p_lda(batch_count, lda);
+    std::vector<MKL_INT> p_ldb(batch_count, ldb);
+    std::vector<MKL_INT> p_ldc(batch_count, ldc);
+    std::vector<double> p_alpha(batch_count, alpha);
+    std::vector<double> p_beta(batch_count, beta);
+    std::vector<const double*> pp_A;
+    std::vector<const double*> pp_B;
+    std::vector<double*> pp_C;
+
+    CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
+    CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
+
+    std::vector<MKL_INT> p_group_sizeb(batch_count, batch_count);
+    std::vector<CBLAS_TRANSPOSE> p_transa(batch_count, cblas_a_trans);
+    std::vector<CBLAS_TRANSPOSE> p_transb(batch_count, cblas_b_trans);
+
+    auto m_k = m * k;
+    auto k_n = k * n;
+    auto m_n = m * n;
+
+    for (int i = 0; i < batch_count; i++) {
+      pp_A.push_back(A + i * m_k);
+      pp_B.push_back(B + i * k_n);
+      pp_C.push_back(C + i * m_n);
+    }
+
+    cblas_dgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(),
+                      p_m.data(), p_n.data(), p_k.data(),
+                      p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(),
+                      p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(),
+                      1, p_group_sizeb.data());
+#else
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
+#endif
   }
   inline static void gemv(Stream<cpu> *stream,
                           bool trans, int m, int n, double alpha,
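
For readers unfamiliar with MKL's grouped batch interface, a minimal standalone
sketch follows. It is not part of the patch; the sizes, the all-ones data, and
the main() harness are illustrative assumptions. It shows the one-group
degenerate case that batch_gemm exploits: every GEMM in the batch shares
m/n/k, leading dimensions, alpha/beta, and transpose flags, so each per-group
parameter array needs only group_count (here 1) entries, while the pointer
arrays carry one entry per matrix.

// Minimal sketch of MKL's grouped batch GEMM; sizes are hypothetical.
#include <mkl.h>
#include <vector>

int main() {
  // One group of identical GEMMs: C_i = A_i * B_i for i in [0, batch).
  const MKL_INT batch = 4, m = 2, n = 3, k = 2;
  // Matrices stored back to back, column-major, as batch_gemm lays them out.
  std::vector<float> A(batch * m * k, 1.0f);
  std::vector<float> B(batch * k * n, 1.0f);
  std::vector<float> C(batch * m * n, 0.0f);

  // With group_count == 1, each per-group array holds a single entry...
  CBLAS_TRANSPOSE trans = CblasNoTrans;
  const float alpha = 1.0f, beta = 0.0f;
  const MKL_INT lda = m, ldb = k, ldc = m;
  const MKL_INT group_size = batch;

  // ...while the pointer arrays hold one entry per matrix in the batch.
  std::vector<const float*> a_ptrs, b_ptrs;
  std::vector<float*> c_ptrs;
  for (MKL_INT i = 0; i < batch; ++i) {
    a_ptrs.push_back(A.data() + i * m * k);
    b_ptrs.push_back(B.data() + i * k * n);
    c_ptrs.push_back(C.data() + i * m * n);
  }

  cblas_sgemm_batch(CblasColMajor, &trans, &trans, &m, &n, &k,
                    &alpha, a_ptrs.data(), &lda, b_ptrs.data(), &ldb,
                    &beta, c_ptrs.data(), &ldc,
                    /*group_count=*/1, &group_size);
  return 0;  // each entry of C is now k (dot products of all-ones vectors)
}

This also explains why the batch_count-sized parameter vectors in the patch
are safe but larger than strictly needed: with group_count = 1, MKL reads only
the first element of each parameter array and of p_group_sizeb, whereas pp_A,
pp_B, and pp_C are consumed in full.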