Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Port BRGEMM (#20910)
Browse files Browse the repository at this point in the history
  • Loading branch information
bgawrych authored Mar 31, 2022
1 parent 67467f8 commit 76fc3ef
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
4 changes: 4 additions & 0 deletions docs/static_site/src/pages/api/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
If no such algorithm exists given other constraints, MXNet will error out. This variable affects the choice
of CUDNN convolution algorithms. Please see [CUDNN developer guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html) for more details.

* MXNET_MKLDNN_FORCE_FC_AB_FORMAT
- Values: 0, 1 ```(default=0)```
- If set to true, FullyConnected will use only AB format for weights, thus MXNet won't use BRGEMM implementation of FC on machines with AVX512-VNNI support which requires special weights format.

* MXNET_CPU_PARALLEL_SIZE
- Values: Int ```(default=200000)```
- The minimum size to call parallel operations by OpenMP for CPU context.
Expand Down
17 changes: 14 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,15 +285,26 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int dtype = -1
return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
}

inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray &arr, int dtype = -1) {
inline static bool ChooseBRGEMMImpl(const mkldnn::memory::dims& weight_dims, size_t batch_size) {
// Conditions based on measurement results done on CLX8280
// https://github.com/apache/incubator-mxnet/pull/20533
return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 &&
weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
}

inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr,
size_t batch_size,
int dtype = -1) {
int ndim = arr.shape().ndim();
mkldnn::memory::dims dims(ndim);
dtype = (dtype == -1) ? arr.dtype() : dtype;
for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i];
auto format = mkldnn::memory::format_tag::any;
// for batch 256 alexnet benchmark test
const bool force_fc_ab_format = dmlc::GetEnv("MXNET_MKLDNN_FORCE_FC_AB_FORMAT", false);
if (dims.size() == 2) {
format = mkldnn::memory::format_tag::ab;
if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
format = mkldnn::memory::format_tag::ab;
}
}

return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format};
Expand Down
16 changes: 9 additions & 7 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
const MKLDNNFCFullParam &full_param, const bool is_train,
const NDArray &data, const NDArray &weight, const NDArray *bias,
const mkldnn::memory::desc &out_md) {
auto data_md = GetMemDesc(data);
auto weight_md = full_param.mkldnn_param.quantized ?
GetFCWeightDesc(weight, mshadow::kInt8) : GetFCWeightDesc(weight);
auto engine = CpuEngine::Get()->get_engine();
auto propagation =
is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
auto data_md = GetMemDesc(data);
auto weight_md =
full_param.mkldnn_param.quantized
? GetFCWeightDesc(weight, data.shape()[0], mshadow::kInt8)
: GetFCWeightDesc(weight, data.shape()[0]);
auto propagation = is_train ? mkldnn::prop_kind::forward_training
: mkldnn::prop_kind::forward_scoring;

mkldnn::primitive_attr attr;
mkldnn::post_ops ops;
Expand Down Expand Up @@ -91,7 +93,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
const NDArray &data, const NDArray &weight, const NDArray &output,
mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetFCWeightDesc(weight);
auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
Expand All @@ -102,7 +104,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
const NDArray &data, const NDArray &weight, const NDArray *bias,
const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetFCWeightDesc(weight);
auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
if (bias) {
Expand Down

0 comments on commit 76fc3ef

Please sign in to comment.