From 129192a90f7aa15d98f0fa1c2bb523e35a715140 Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Mon, 4 Jan 2021 17:20:37 -0800
Subject: [PATCH] Improve add_bias_kernel for small bias length

---
 src/operator/nn/fully_connected-inl.h | 40 +++++++++++++++++++--------
 1 file changed, 29 insertions(+), 11 deletions(-)

diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index d12948ea74aa..c90e8ce014e7 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -133,20 +133,36 @@ namespace {
   inline int ceil_div(int x, int y) {
     return (x + y - 1) / y;
   }
+
+  inline int FindNumRowsPerBlock(size_t bias_length, size_t lead_dim) {
+    int ret = 1;
+    while (bias_length < nthreads_addbias &&
+           lead_dim % 2 == 0) {
+      bias_length *= 2;
+      ret *= 2;
+      lead_dim /= 2;
+    }
+    return ret;
+  }
 }  // namespace
 
 template<typename DType, typename LType>
-__global__ void add_bias_kernel(DType* mat, DType* bias, size_t lead_dim, size_t bias_length) {
+__global__ void add_bias_kernel(DType* const mat, const DType* const bias,
+                                const size_t lead_dim, const size_t bias_length,
+                                const int rows) {
   __shared__ LType scratch[nthreads_addbias * 2];
+  const int threads_per_row = nthreads_addbias / rows;
+  const int threadId_in_row = threadIdx.x & (threads_per_row - 1);
+  const int row_id = threadIdx.x * rows / nthreads_addbias;
   const index_t N = bias_length * sizeof(DType)/sizeof(LType);
-  const index_t base = blockIdx.x * N;
+  const index_t base = (blockIdx.x * rows + row_id) * N;
   LType* const mat_aligned = reinterpret_cast<LType*>(mat) + base;
-  const LType* const bias_aligned = reinterpret_cast<LType*>(bias);
+  const LType* const bias_aligned = reinterpret_cast<const LType*>(bias);
   LType* const scratch_bias_load = scratch + threadIdx.x;
   DType* const scratch_bias = reinterpret_cast<DType*>(scratch_bias_load);
   LType* const scratch_mat_load = scratch_bias_load + nthreads_addbias;
   DType* const scratch_mat = reinterpret_cast<DType*>(scratch_mat_load);
-  for (index_t i = threadIdx.x; i < N; i += blockDim.x) {
+  for (index_t i = threadId_in_row; i < N; i += threads_per_row) {
     *scratch_bias_load = bias_aligned[i];
     *scratch_mat_load = mat_aligned[i];
 #pragma unroll
@@ -162,13 +178,15 @@ void AddBias(Tensor<gpu, 1, DType> bias, Tensor<gpu, 2, DType> data,
              Tensor<gpu, 2, DType> out, Stream<gpu>* s) {
   int ltype = mxnet::common::cuda::get_load_type(bias.shape_[0] * sizeof(DType));
   MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
-    add_bias_kernel<DType, LType><<<data.size(0),
-                                    nthreads_addbias,
-                                    0,
-                                    Stream<gpu>::GetStream(s)>>>(out.dptr_,
-                                                                 bias.dptr_,
-                                                                 data.size(0),
-                                                                 bias.shape_[0]);
+    int rows = FindNumRowsPerBlock(bias.shape_[0] * sizeof(DType) / sizeof(LType), data.size(0));
+    add_bias_kernel<DType, LType><<<ceil_div(data.size(0), rows),
+                                    nthreads_addbias,
+                                    0,
+                                    Stream<gpu>::GetStream(s)>>>(out.dptr_,
+                                                                 bias.dptr_,
+                                                                 data.size(0),
+                                                                 bias.shape_[0],
+                                                                 rows);
   });
 }
 
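
Note (illustration, not part of the patch): the host-side sketch below shows how the new FindNumRowsPerBlock heuristic packs several output rows into one thread block when the bias is short. It assumes nthreads_addbias (the threads-per-block constant) is 256; that constant is defined outside the quoted hunks, so the value is an assumption here.

    // Standalone sketch of the row-packing heuristic introduced by the patch.
    // Assumption: nthreads_addbias (threads per block) is 256.
    #include <cstddef>
    #include <cstdio>

    constexpr size_t nthreads_addbias = 256;

    int FindNumRowsPerBlock(size_t bias_length, size_t lead_dim) {
      int ret = 1;
      // Double the rows handled per block while one row's bias (counted in
      // vectorized LType elements) would leave most threads idle and the
      // leading (batch) dimension can still be split evenly in half.
      while (bias_length < nthreads_addbias && lead_dim % 2 == 0) {
        bias_length *= 2;
        ret *= 2;
        lead_dim /= 2;
      }
      return ret;
    }

    int main() {
      printf("%d\n", FindNumRowsPerBlock(16, 1024));   // 16 rows share one block
      printf("%d\n", FindNumRowsPerBlock(512, 1024));  // 1: bias already fills the block
      printf("%d\n", FindNumRowsPerBlock(16, 1023));   // 1: odd batch size stops the doubling
      return 0;
    }

With rows = 16 and 256 threads per block, the kernel assigns each row 256 / 16 = 16 consecutive threads: threadId_in_row = threadIdx.x & 15 and row_id = threadIdx.x * 16 / 256 = threadIdx.x / 16, so the vectorized loads stay contiguous within each row while far fewer threads sit idle for short bias vectors.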