This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Improve batch gemm performance using MKL #342

Merged 6 commits on Jun 23, 2018
Changes from 5 commits
75 changes: 75 additions & 0 deletions mshadow/dot_engine-inl.h
@@ -7,6 +7,7 @@
#ifndef MSHADOW_DOT_ENGINE_INL_H_
#define MSHADOW_DOT_ENGINE_INL_H_

#include <vector>
#include "./base.h"
#include "./extension/implicit_gemm.h"

@@ -291,11 +292,48 @@ struct BLASEngine<cpu, float> {
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count,
float **workspace) {
#if MSHADOW_USE_MKL
Review comment (Member):
Are cblas_sgemm_batch and cblas_dgemm_batch generally supported in MKL? Do we need to check the version?

Reply (Member, author):
According to this page, Intel MKL 11.3 Beta (part of Intel® Parallel Studio XE 2016 Beta) includes a new flavor of GEMM called "Batch GEMM".
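
If an explicit check is wanted, one option is a stricter preprocessor guard. This is only a sketch, assuming MKL's mkl_version.h is on the include path and defines INTEL_MKL_VERSION (encoded as 110300 for MKL 11.3); the exact threshold is an assumption, not part of this patch:

#include <mkl_version.h>  // defines INTEL_MKL_VERSION, e.g. 110300 for MKL 11.3

// Only take the batched-GEMM path when the MKL in use is new enough to ship
// cblas_sgemm_batch / cblas_dgemm_batch; otherwise keep the per-matrix loop.
#if MSHADOW_USE_MKL && defined(INTEL_MKL_VERSION) && (INTEL_MKL_VERSION >= 110300)
  // grouped cblas_*gemm_batch call goes here
#else
  // fallback: loop over single gemm calls
#endif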

std::vector<int> p_m(batch_count, m);
std::vector<int> p_n(batch_count, n);
std::vector<int> p_k(batch_count, k);
std::vector<int> p_lda(batch_count, lda);
std::vector<int> p_ldb(batch_count, ldb);
std::vector<int> p_ldc(batch_count, ldc);
std::vector<float> p_alpha(batch_count, alpha);
std::vector<float> p_beta(batch_count, beta);
std::vector<const float*> pp_A;
std::vector<const float*> pp_B;
std::vector<float*> pp_C;

CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);

std::vector<int> p_group_sizeb(batch_count, batch_count);
std::vector<CBLAS_TRANSPOSE> p_transa(batch_count, cblas_a_trans);
std::vector<CBLAS_TRANSPOSE> p_transb(batch_count, cblas_b_trans);

auto m_k = m * k;
auto k_n = k * n;
auto m_n = m * n;

for (int i = 0; i < batch_count; i++) {
pp_A.push_back(A + i * m_k);
pp_B.push_back(B + i * k_n);
pp_C.push_back(C + i * m_n);
}

cblas_sgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(),
p_m.data(), p_n.data(), p_k.data(),
p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(),
p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(),
1, p_group_sizeb.data());
#else
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
#endif
}
inline static void gemv(Stream<cpu> *stream,
bool trans, int m, int n,
@@ -361,11 +399,48 @@ struct BLASEngine<cpu, double> {
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count,
double **workspace) {
#if MSHADOW_USE_MKL
std::vector<int> p_m(batch_count, m);
std::vector<int> p_n(batch_count, n);
std::vector<int> p_k(batch_count, k);
std::vector<int> p_lda(batch_count, lda);
std::vector<int> p_ldb(batch_count, ldb);
std::vector<int> p_ldc(batch_count, ldc);
std::vector<double> p_alpha(batch_count, alpha);
std::vector<double> p_beta(batch_count, beta);
std::vector<const double*> pp_A;
std::vector<const double*> pp_B;
std::vector<double*> pp_C;

CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);

std::vector<int> p_group_sizeb(batch_count, batch_count);
std::vector<CBLAS_TRANSPOSE> p_transa(batch_count, cblas_a_trans);
std::vector<CBLAS_TRANSPOSE> p_transb(batch_count, cblas_b_trans);

auto m_k = m * k;
auto k_n = k * n;
auto m_n = m * n;

for (int i = 0; i < batch_count; i++) {
pp_A.push_back(A + i * m_k);
pp_B.push_back(B + i * k_n);
pp_C.push_back(C + i * m_n);
}

cblas_dgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(),
p_m.data(), p_n.data(), p_k.data(),
p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(),
p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(),
1, p_group_sizeb.data());
#else
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
#endif
}
inline static void gemv(Stream<cpu> *stream,
bool trans, int m, int n, double alpha,
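
For reference, below is a minimal standalone sketch of the pattern the patch uses: a single group of identically sized GEMMs dispatched through one cblas_sgemm_batch call with group_count = 1. It assumes an MKL build (link against MKL, e.g. via mkl_rt); the matrix sizes and values are illustrative only, not taken from the PR.

#include <mkl.h>
#include <cstdio>
#include <vector>

int main() {
  const int batch = 2, m = 2, n = 2, k = 2;
  // Contiguous column-major storage: one m*k / k*n / m*n slab per batch item.
  std::vector<float> A(batch * m * k, 1.0f);
  std::vector<float> B(batch * k * n, 2.0f);
  std::vector<float> C(batch * m * n, 0.0f);

  // Per-group arguments; with group_count == 1 every matrix shares them.
  CBLAS_TRANSPOSE transa = CblasNoTrans, transb = CblasNoTrans;
  MKL_INT mm = m, nn = n, kk = k, lda = m, ldb = k, ldc = m;
  float alpha = 1.0f, beta = 0.0f;
  MKL_INT group_size = batch;

  // Pointer arrays: one entry per matrix in the (single) group.
  std::vector<const float*> a_ptrs, b_ptrs;
  std::vector<float*> c_ptrs;
  for (int i = 0; i < batch; ++i) {
    a_ptrs.push_back(A.data() + i * m * k);
    b_ptrs.push_back(B.data() + i * k * n);
    c_ptrs.push_back(C.data() + i * m * n);
  }

  cblas_sgemm_batch(CblasColMajor, &transa, &transb,
                    &mm, &nn, &kk, &alpha,
                    a_ptrs.data(), &lda, b_ptrs.data(), &ldb,
                    &beta, c_ptrs.data(), &ldc,
                    /*group_count=*/1, &group_size);

  // Every entry of each 2x2 result should be 4 (= k * 1.0f * 2.0f).
  std::printf("C[first] = %g, C[last] = %g\n", C.front(), C.back());
  return 0;
}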