Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jagged_dense_bmm operator optimization #1643

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ static constexpr int32_t kWarpSize = 32;
#endif
// Max thread num in one thread block
static constexpr int32_t kMaxThreads = 1024;
// Max block size in Y dimension of a grid
static constexpr int32_t kMaxBlockYDim = 65535;
// Max block size in Z dimension of a grid
static constexpr int32_t kMaxBlockZDim = 65535;

static constexpr float kQParamEps = 1e-8f;

/* For rowwise int8 quantization, two quantization parameters (qparams)
Expand Down
183 changes: 153 additions & 30 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2071,36 +2071,135 @@ Tensor jagged_jagged_bmm_forward(
return output;
}

template <typename index_t, typename scalar_t>
template <
const int BLOCK_TILE_M, // tile height of C that each thread block
// calculates
const int BLOCK_TILE_N, // tile width of C that each thread block
// calculates
const int BLOCK_TILE_K, // tile width of A that each thread block calculates
const int THREAD_TILE_M, // tile height of C that each thread
// calculates
const int THREAD_TILE_N, // tile width of C that each thread calcualtes
typename index_t,
typename scalar_t>
__global__ __launch_bounds__(kMaxThreads) void jagged_dense_bmm_kernel(
const at::PackedTensorAccessor32<scalar_t, 2> x_values,
const at::PackedTensorAccessor32<index_t, 1> x_offsets,
const at::PackedTensorAccessor32<scalar_t, 3> y,
at::PackedTensorAccessor32<scalar_t, 2> output,
const at::PackedTensorAccessor32<scalar_t, 2> __restrict__ x_values,
const at::PackedTensorAccessor32<index_t, 1> __restrict__ x_offsets,
const at::PackedTensorAccessor32<scalar_t, 3> __restrict__ y,
at::PackedTensorAccessor32<scalar_t, 2> __restrict__ output,
const int max_L) {
const int B = x_offsets.size(0) - 1;
const int K = x_values.size(1);
const int N = y.size(2);

const int b_l_begin = blockIdx.x * blockDim.y + threadIdx.y;
const int b_l_step = gridDim.x * blockDim.y;
for (int b_l = b_l_begin; b_l < B * max_L; b_l += b_l_step) {
const int b = b_l / max_L;
const int l = b_l % max_L;
const auto block_row = blockIdx.y;
const auto block_col = blockIdx.x;

const int THREADS_X_PER_BLOCK = BLOCK_TILE_N / THREAD_TILE_N;
const int THREADS_Y_PER_BLOCK = BLOCK_TILE_M / THREAD_TILE_M;
const int THREADS_PER_BLOCK = THREADS_X_PER_BLOCK * THREADS_Y_PER_BLOCK;
const auto thread_row = threadIdx.x / THREADS_X_PER_BLOCK;
const auto thread_col = threadIdx.x % THREADS_X_PER_BLOCK;
const auto NUM_K_BLOCKS = (K + BLOCK_TILE_K - 1) / BLOCK_TILE_K;

__shared__ scalar_t As[BLOCK_TILE_M][BLOCK_TILE_K];
__shared__ scalar_t Bs[BLOCK_TILE_K][BLOCK_TILE_N];

for (auto b = blockIdx.z; b < B; b += gridDim.z) {
const index_t row_start = x_offsets[b];
const index_t row_end = x_offsets[b + 1];
const auto length = min(row_end - row_start, (index_t)max_L);

// the indices that this current will load into shared mem
const auto inner_row_a = threadIdx.x / BLOCK_TILE_K;
const auto inner_col_a = threadIdx.x % BLOCK_TILE_K;
// the number of rows of As that will be loaded per step by a thread block
const auto A_TILE_ROW_STRIDE = THREADS_PER_BLOCK / BLOCK_TILE_K;

const auto inner_row_b = threadIdx.x / BLOCK_TILE_N;
const auto inner_col_b = threadIdx.x % BLOCK_TILE_N;
const auto B_TILE_ROW_STRIDE = THREADS_PER_BLOCK / BLOCK_TILE_N;

// registers for C
scalar_t accum[THREAD_TILE_M][THREAD_TILE_N] = {0};

// registers for As and Bs
scalar_t fragment_a[THREAD_TILE_M] = {0};
scalar_t fragment_b[THREAD_TILE_N] = {0};

// loop for block tiles in K dimension
for (auto block = 0; block < NUM_K_BLOCKS; block++) {
// load a block of x_values from global memory to shared memory
// apply tiling for threads in a block
#pragma unroll
for (auto offset = 0; offset < BLOCK_TILE_M;
offset += A_TILE_ROW_STRIDE) {
auto x_row_offset = block_row * BLOCK_TILE_M + inner_row_a + offset;
auto x_col_offset = block * BLOCK_TILE_K + inner_col_a;
if ((x_row_offset < length) && (x_col_offset < K)) {
As[inner_row_a + offset][inner_col_a] =
x_values[row_start + x_row_offset][x_col_offset];
} else {
As[inner_row_a + offset][inner_col_a] = 0;
}
}

const int row_start = x_offsets[b];
const int row_end = x_offsets[b + 1];
const int length = min(row_end - row_start, max_L);
if (length == 0 || l >= length) {
return;
} else {
// TODO: use shared memory and better reduction
for (int n = threadIdx.x; n < N; n += blockDim.x) {
at::acc_type<scalar_t, true> acc = 0;
for (int k = 0; k < K; ++k) {
acc += x_values[row_start + l][k] * y[b][k][n];
// load a block of y from global memory to shared memory
// apply tiling for threads in a block
#pragma unroll
for (auto offset = 0; offset < BLOCK_TILE_K;
offset += B_TILE_ROW_STRIDE) {
auto y_row_offset = block * BLOCK_TILE_K + inner_row_b + offset;
auto y_col_offset = block_col * BLOCK_TILE_N + inner_col_b;
if ((y_row_offset < K) && (y_col_offset < N)) {
Bs[inner_row_b + offset][inner_col_b] =
y[b][y_row_offset][y_col_offset];
} else {
Bs[inner_row_b + offset][inner_col_b] = 0;
}
}

__syncthreads();

// calculate the results per thread
#pragma unroll
for (auto k = 0; k < BLOCK_TILE_K; k++) {
// load values from shared memory to registers for x_values
for (auto row = 0; row < THREAD_TILE_M; row++) {
fragment_a[row] = As[thread_row * THREAD_TILE_M + row][k];
}

// load values from shared memory to registers for y
#pragma unroll
for (auto col = 0; col < THREAD_TILE_N; col++) {
fragment_b[col] = Bs[k][thread_col * THREAD_TILE_N + col];
}

// each thread calcualtes THREAD_TILE_M * THREAD_TILE_N elements
#pragma unroll
for (auto row = 0; row < THREAD_TILE_M; row++) {
#pragma unroll
for (auto col = 0; col < THREAD_TILE_N; col++) {
accum[row][col] += fragment_a[row] * fragment_b[col];
}
}
}

__syncthreads();
}

// write the result to the output
#pragma unroll
for (auto row = 0; row < THREAD_TILE_M; row++) {
#pragma unroll
for (auto col = 0; col < THREAD_TILE_N; col++) {
auto out_row_offset =
block_row * BLOCK_TILE_M + thread_row * THREAD_TILE_M + row;
auto out_col_offset =
block_col * BLOCK_TILE_N + thread_col * THREAD_TILE_N + col;
if ((out_row_offset < length) && (out_col_offset < N)) {
output[row_start + out_row_offset][out_col_offset] = accum[row][col];
}
output[row_start + l][n] = acc;
}
}
}
Expand All @@ -2124,9 +2223,29 @@ Tensor jagged_dense_bmm_forward(
const int total_L = x_values.size(0);
auto output = at::zeros({total_L, N}, x_values.options());
if (B > 0 && M > 0 && N > 0) {
const int block_dim_x =
std::min(div_round_up(N, kWarpSize) * kWarpSize, kMaxThreads);
const int block_dim_y = kMaxThreads / block_dim_x;
// The shared memory size is (BLOCK_TILE_M + BLOCK_TILE_N) * BLOCK_TILE_K
// BLOCK_TILE_M needs to be multiple of THREAD_TILE_M, and
// BLOCK_TILE_N needs to be multiple of THREAD_TILE_N
// The setting of these parameters needs to balance the hardware's shared
// memory size limit and occupancy
// TODO: autotune these parameters based on max_L and input and output
// tensor sizes
constexpr int BLOCK_TILE_M = 64;
constexpr int BLOCK_TILE_N = 8;
constexpr int BLOCK_TILE_K = 8;
constexpr int THREAD_TILE_M = 4;
constexpr int THREAD_TILE_N = 4;

const dim3 block(
(BLOCK_TILE_M * BLOCK_TILE_N) / (THREAD_TILE_M * THREAD_TILE_N));
const auto grid_dim_x = div_round_up(N, BLOCK_TILE_N);
const auto grid_dim_y = div_round_up(max_L, BLOCK_TILE_M);
TORCH_CHECK(
grid_dim_y <= kMaxBlockYDim,
"max_L cannot be larger than",
grid_dim_y * BLOCK_TILE_M + 1 - BLOCK_TILE_M);
const auto grid_dim_z = std::min(B, kMaxBlockZDim);
const dim3 grid(grid_dim_x, grid_dim_y, grid_dim_z);

AT_DISPATCH_INDEX_TYPES(
x_offsets.scalar_type(), "jagged_dense_bmm_kernel_1", [&] {
Expand All @@ -2136,11 +2255,15 @@ Tensor jagged_dense_bmm_forward(
x_values.scalar_type(),
"jagged_dense_bmm_kernel_2",
[&] {
jagged_dense_bmm_kernel<index_t, scalar_t>
<<<div_round_up(B * max_L, block_dim_y),
dim3(block_dim_x, block_dim_y),
0,
at::cuda::getCurrentCUDAStream()>>>(
jagged_dense_bmm_kernel<
BLOCK_TILE_M,
BLOCK_TILE_N,
BLOCK_TILE_K,
THREAD_TILE_M,
THREAD_TILE_N,
index_t,
scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
x_values.packed_accessor32<scalar_t, 2>(),
x_offsets.packed_accessor32<index_t, 1>(),
y.packed_accessor32<scalar_t, 3>(),
Expand Down