jagged_dense_bmm operator optimization (pytorch#1643)
Summary:
Pull Request resolved: pytorch#1643

This diff optimizes the jagged_dense_bmm operator with the following optimizations (a generic sketch of the two ideas follows the list):
* tiling across thread blocks, using GPU shared memory within each thread block
* tiling across threads within a thread block, using registers within each thread
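
For intuition only, here is a minimal CUDA sketch of these two optimizations applied to a plain dense GEMM C = A * B. The kernel name, tile parameters (TM, TN, TK, RM, RN), and launch shapes are illustrative assumptions, not the FBGEMM kernel in this diff, which additionally masks rows by the jagged lengths taken from x_offsets.

#include <cuda_runtime.h>

// Each block computes a TM x TN tile of C through shared memory; each thread
// computes an RM x RN sub-tile of that tile through registers.
template <int TM, int TN, int TK, int RM, int RN>
__global__ void tiled_gemm_sketch(
    const float* A, const float* B, float* C, int M, int N, int K) {
  __shared__ float As[TM][TK];
  __shared__ float Bs[TK][TN];

  constexpr int THREADS_X = TN / RN;   // threads along the N dimension
  const int thread_row = threadIdx.x / THREADS_X;
  const int thread_col = threadIdx.x % THREADS_X;
  const int row0 = blockIdx.y * TM;    // first C row owned by this block
  const int col0 = blockIdx.x * TN;    // first C column owned by this block

  float acc[RM][RN] = {0.0f};          // per-thread accumulators in registers

  for (int k0 = 0; k0 < K; k0 += TK) {
    // Cooperatively stage the A and B tiles in shared memory (zero-padded at edges).
    for (int i = threadIdx.x; i < TM * TK; i += blockDim.x) {
      const int r = i / TK, c = i % TK;
      As[r][c] = (row0 + r < M && k0 + c < K) ? A[(row0 + r) * K + k0 + c] : 0.0f;
    }
    for (int i = threadIdx.x; i < TK * TN; i += blockDim.x) {
      const int r = i / TN, c = i % TN;
      Bs[r][c] = (k0 + r < K && col0 + c < N) ? B[(k0 + r) * N + col0 + c] : 0.0f;
    }
    __syncthreads();

    // Multiply the staged tiles, keeping per-thread fragments in registers.
    for (int k = 0; k < TK; ++k) {
      float a_frag[RM], b_frag[RN];
      for (int r = 0; r < RM; ++r) a_frag[r] = As[thread_row * RM + r][k];
      for (int c = 0; c < RN; ++c) b_frag[c] = Bs[k][thread_col * RN + c];
      for (int r = 0; r < RM; ++r)
        for (int c = 0; c < RN; ++c)
          acc[r][c] += a_frag[r] * b_frag[c];
    }
    __syncthreads();
  }

  // Write the per-thread sub-tile back to global memory.
  for (int r = 0; r < RM; ++r)
    for (int c = 0; c < RN; ++c) {
      const int row = row0 + thread_row * RM + r;
      const int col = col0 + thread_col * RN + c;
      if (row < M && col < N) C[row * N + col] = acc[r][c];
    }
}

// Assumed launch: one thread per RM x RN sub-tile, e.g.
//   dim3 block((TM / RM) * (TN / RN));
//   dim3 grid((N + TN - 1) / TN, (M + TM - 1) / TM);
//   tiled_gemm_sketch<64, 64, 8, 4, 4><<<grid, block>>>(dA, dB, dC, M, N, K);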

Differential Revision: D43674845

fbshipit-source-id: a19be6fac4db0e78668c60b72b408ef3a02f1684
Rengan Xu authored and facebook-github-bot committed Mar 17, 2023
1 parent c7cddec commit a219eaf
Showing 1 changed file with 138 additions and 29 deletions.
167 changes: 138 additions & 29 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -2071,36 +2071,132 @@ Tensor jagged_jagged_bmm_forward(
return output;
}
-template <typename index_t, typename scalar_t>
template <
const int BLOCK_TILE_M,
const int BLOCK_TILE_N,
const int BLOCK_TILE_K,
const int THREAD_TILE_M,
const int THREAD_TILE_N,
typename index_t,
typename scalar_t>
__global__ __launch_bounds__(kMaxThreads) void jagged_dense_bmm_kernel(
-const at::PackedTensorAccessor32<scalar_t, 2> x_values,
-const at::PackedTensorAccessor32<index_t, 1> x_offsets,
-const at::PackedTensorAccessor32<scalar_t, 3> y,
const at::PackedTensorAccessor32<scalar_t, 2> __restrict__ x_values,
const at::PackedTensorAccessor32<index_t, 1> __restrict__ x_offsets,
const at::PackedTensorAccessor32<scalar_t, 3> __restrict__ y,
at::PackedTensorAccessor32<scalar_t, 2> output,
const int max_L) {
const int B = x_offsets.size(0) - 1;
const int K = x_values.size(1);
const int N = y.size(2);
-const int b_l_begin = blockIdx.x * blockDim.y + threadIdx.y;
-const int b_l_step = gridDim.x * blockDim.y;
-for (int b_l = b_l_begin; b_l < B * max_L; b_l += b_l_step) {
-const int b = b_l / max_L;
-const int l = b_l % max_L;
const auto block_row = blockIdx.y;
const auto block_col = blockIdx.x;
const int THREADS_X_PER_BLOCK = BLOCK_TILE_N / THREAD_TILE_N;
const int THREADS_Y_PER_BLOCK = BLOCK_TILE_M / THREAD_TILE_M;
const int THREADS_PER_BLOCK = THREADS_X_PER_BLOCK * THREADS_Y_PER_BLOCK;
const auto thread_row = threadIdx.x / THREADS_X_PER_BLOCK;
const auto thread_col = threadIdx.x % THREADS_X_PER_BLOCK;
const auto NUM_K_BLOCKS = (K + BLOCK_TILE_K - 1) / BLOCK_TILE_K;
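// Each thread block computes one BLOCK_TILE_M x BLOCK_TILE_N tile of the output,
// and each of its THREADS_PER_BLOCK threads computes a THREAD_TILE_M x
// THREAD_TILE_N sub-tile of that tile.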
__shared__ scalar_t As[BLOCK_TILE_M][BLOCK_TILE_K];
__shared__ scalar_t Bs[BLOCK_TILE_K][BLOCK_TILE_N];
for (auto b = blockIdx.z; b < B; b += gridDim.z) {
const index_t row_start = x_offsets[b];
const index_t row_end = x_offsets[b + 1];
const auto length = min(row_end - row_start, (index_t)max_L);
// the indices that the current thread will load into shared mem
const auto inner_row_a = threadIdx.x / BLOCK_TILE_K;
const auto inner_col_a = threadIdx.x % BLOCK_TILE_K;
// the number of rows of As that will be loaded per step by a thread block
const auto A_TILE_ROW_STRIDE = THREADS_PER_BLOCK / BLOCK_TILE_K;
const auto inner_row_b = threadIdx.x / BLOCK_TILE_N;
const auto inner_col_b = threadIdx.x % BLOCK_TILE_N;
const auto B_TILE_ROW_STRIDE = THREADS_PER_BLOCK / BLOCK_TILE_N;
// registers for C
scalar_t accum[THREAD_TILE_M][THREAD_TILE_N] = {0};
// registers for As and Bs
scalar_t fragment_a[THREAD_TILE_M] = {0};
scalar_t fragment_b[THREAD_TILE_N] = {0};
// loop for block tiles in K dimension
for (auto block = 0; block < NUM_K_BLOCKS; block++) {
// load a block of x_values from global memory to shared memory
// apply tiling for threads in a block
#pragma unroll
for (auto offset = 0; offset < BLOCK_TILE_M;
offset += A_TILE_ROW_STRIDE) {
auto x_row_offset = block_row * BLOCK_TILE_M + inner_row_a + offset;
auto x_col_offset = block * BLOCK_TILE_K + inner_col_a;
if ((x_row_offset < length) && (x_col_offset < K)) {
As[inner_row_a + offset][inner_col_a] =
x_values[row_start + x_row_offset][x_col_offset];
} else {
As[inner_row_a + offset][inner_col_a] = 0;
}
}
-const int row_start = x_offsets[b];
-const int row_end = x_offsets[b + 1];
-const int length = min(row_end - row_start, max_L);
-if (length == 0 || l >= length) {
-return;
-} else {
-// TODO: use shared memory and better reduction
-for (int n = threadIdx.x; n < N; n += blockDim.x) {
-at::acc_type<scalar_t, true> acc = 0;
-for (int k = 0; k < K; ++k) {
-acc += x_values[row_start + l][k] * y[b][k][n];
// load a block of y from global memory to shared memory
// apply tiling for threads in a block
#pragma unroll
for (auto offset = 0; offset < BLOCK_TILE_K;
offset += B_TILE_ROW_STRIDE) {
auto y_row_offset = block * BLOCK_TILE_K + inner_row_b + offset;
auto y_col_offset = block_col * BLOCK_TILE_N + inner_col_b;
if ((y_row_offset < K) && (y_col_offset < N)) {
Bs[inner_row_b + offset][inner_col_b] =
y[b][y_row_offset][y_col_offset];
} else {
Bs[inner_row_b + offset][inner_col_b] = 0;
}
}
__syncthreads();
// calculate the results per thread
#pragma unroll
for (auto k = 0; k < BLOCK_TILE_K; k++) {
// load values from shared memory to registers for x_values
for (auto row = 0; row < THREAD_TILE_M; row++) {
fragment_a[row] = As[thread_row * THREAD_TILE_M + row][k];
}
// load values from shared memory to registers for y
#pragma unroll
for (auto col = 0; col < THREAD_TILE_N; col++) {
fragment_b[col] = Bs[k][thread_col * THREAD_TILE_N + col];
}
// each thread calculates THREAD_TILE_M * THREAD_TILE_N elements
#pragma unroll
for (auto row = 0; row < THREAD_TILE_M; row++) {
#pragma unroll
for (auto col = 0; col < THREAD_TILE_N; col++) {
accum[row][col] += fragment_a[row] * fragment_b[col];
}
}
}
__syncthreads();
}
// write the result to the output
#pragma unroll
for (auto row = 0; row < THREAD_TILE_M; row++) {
#pragma unroll
for (auto col = 0; col < THREAD_TILE_N; col++) {
auto out_row_offset =
block_row * BLOCK_TILE_M + thread_row * THREAD_TILE_M + row;
auto out_col_offset =
block_col * BLOCK_TILE_N + thread_col * THREAD_TILE_N + col;
if ((out_row_offset < length) && (out_col_offset < N)) {
output[row_start + out_row_offset][out_col_offset] = accum[row][col];
}
-output[row_start + l][n] = acc;
}
}
}
@@ -2124,9 +2220,18 @@ Tensor jagged_dense_bmm_forward(
const int total_L = x_values.size(0);
auto output = at::zeros({total_L, N}, x_values.options());
if (B > 0 && M > 0 && N > 0) {
-const int block_dim_x =
-std::min(div_round_up(N, kWarpSize) * kWarpSize, kMaxThreads);
-const int block_dim_y = kMaxThreads / block_dim_x;
constexpr int BLOCK_TILE_M = 64;
constexpr int BLOCK_TILE_N = 8;
constexpr int BLOCK_TILE_K = 8;
constexpr int THREAD_TILE_M = 4;
constexpr int THREAD_TILE_N = 4;
const dim3 block(
(BLOCK_TILE_M * BLOCK_TILE_N) / (THREAD_TILE_M * THREAD_TILE_N));
const dim3 grid(
div_round_up(N, BLOCK_TILE_N),
div_round_up(max_L, BLOCK_TILE_M),
std::min(B, 65535));
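// With these tile sizes each block runs (64 * 8) / (4 * 4) = 32 threads; the grid
// covers div_round_up(N, 8) x div_round_up(max_L, 64) output tiles, and batches
// beyond 65535 are handled by the grid-stride loop over blockIdx.z in the kernel.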
AT_DISPATCH_INDEX_TYPES(
x_offsets.scalar_type(), "jagged_dense_bmm_kernel_1", [&] {
@@ -2136,11 +2241,15 @@ Tensor jagged_dense_bmm_forward(
x_values.scalar_type(),
"jagged_dense_bmm_kernel_2",
[&] {
-jagged_dense_bmm_kernel<index_t, scalar_t>
-<<<div_round_up(B * max_L, block_dim_y),
-dim3(block_dim_x, block_dim_y),
-0,
-at::cuda::getCurrentCUDAStream()>>>(
jagged_dense_bmm_kernel<
BLOCK_TILE_M,
BLOCK_TILE_N,
BLOCK_TILE_K,
THREAD_TILE_M,
THREAD_TILE_N,
index_t,
scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
x_values.packed_accessor32<scalar_t, 2>(),
x_offsets.packed_accessor32<index_t, 1>(),
y.packed_accessor32<scalar_t, 3>(),
