Fix launchconfig y-gridsize too large in epilogue kernel (#1586)
This fixes the launch config for epilogue kernels in the GramMatrix computation for a large number of columns in the resulting kernel matrix.
This code path is triggered when predicting scores with a support vector set larger than 262140, which would result in a y-grid dimension larger than 65535.

Although this is not a regression, this code path might get hit more often now that we allow sparse input data, so we might want to fix it in 23.06 as well.
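For reference, a minimal standalone sketch of the grid-size arithmetic behind these numbers (the ceildiv helper and the main function below are illustrative stand-ins, not code from this commit): with the previous block shape of (32, 4), more than 262140 columns gives ceildiv(n2, 4) > 65535, exceeding the maximum y grid dimension.

#include <algorithm>
#include <cstdio>

// Integer ceiling division, equivalent in effect to raft::ceildiv.
static int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main()
{
  const int n2         = 262144;           // columns in the kernel matrix (> 262140)
  const int old_grid_y = ceildiv(n2, 4);   // previous config: 65536, above the 65535 limit
  const int new_grid_y = std::min(ceildiv(n2, 32), (1 << 16) - 1);  // clamped config: 8192
  std::printf("old grid.y = %d, new grid.y = %d\n", old_grid_y, new_grid_y);
  return 0;
}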

CC @tfeher.

Authors:
  - Malte Förster (https://github.com/mfoerste4)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1586
mfoerste4 authored Jun 22, 2023
1 parent afa5963 commit cb77979
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh
@@ -135,6 +135,17 @@ __global__ void rbf_kernel_expanded(
}
}

+ namespace {
+ std::tuple<dim3, dim3> generateLaunchConfig2dElementwiseOp(int n1, int n2)
+ {
+   dim3 block_shape        = dim3(32, 4);
+   const int num_blocks_x  = raft::ceildiv(n1, 32);
+   const int num_blocks_y  = std::min(raft::ceildiv(n2, 32), (1 << 16) - 1);
+   dim3 grid_shape         = dim3(num_blocks_x, num_blocks_y);
+   return std::make_tuple(grid_shape, block_shape);
+ }
+ }  // namespace

/**
* Create a kernel matrix using polynomial kernel function.
*/
@@ -152,12 +163,11 @@ class PolynomialKernel : public GramMatrixBase<math_t> {
polynomial_kernel_nopad<<<raft::ceildiv<size_t>((size_t)rows * cols, 128), 128, 0, stream>>>(
inout, rows * cols, exponent, gain, offset);
} else {
- int n1 = is_row_major ? cols : rows;
- int n2 = is_row_major ? rows : cols;
- polynomial_kernel<<<dim3(raft::ceildiv(n1, 32), raft::ceildiv(n2, 4), 1),
-                     dim3(32, 4, 1),
-                     0,
-                     stream>>>(inout, ld, n1, n2, exponent, gain, offset);
+ int n1 = is_row_major ? cols : rows;
+ int n2 = is_row_major ? rows : cols;
+ auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2);
+ polynomial_kernel<<<grid_shape, block_shape, 0, stream>>>(
+   inout, ld, n1, n2, exponent, gain, offset);
}
RAFT_CUDA_TRY(cudaPeekAtLastError());
}
@@ -327,12 +337,10 @@ class TanhKernel : public GramMatrixBase<math_t> {
tanh_kernel_nopad<<<raft::ceildiv<size_t>((size_t)rows * cols, 128), 128, 0, stream>>>(
inout, rows * cols, gain, offset);
} else {
- int n1 = is_row_major ? cols : rows;
- int n2 = is_row_major ? rows : cols;
- tanh_kernel<<<dim3(raft::ceildiv(n1, 32), raft::ceildiv(n2, 4), 1),
-               dim3(32, 4, 1),
-               0,
-               stream>>>(inout, ld, n1, n2, gain, offset);
+ int n1 = is_row_major ? cols : rows;
+ int n2 = is_row_major ? rows : cols;
+ auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2);
+ tanh_kernel<<<grid_shape, block_shape, 0, stream>>>(inout, ld, n1, n2, gain, offset);
}
RAFT_CUDA_TRY(cudaPeekAtLastError());
}
@@ -498,14 +506,13 @@ class RBFKernel : public GramMatrixBase<math_t> {
bool is_row_major,
cudaStream_t stream)
{
- int n1 = is_row_major ? cols : rows;
- int n2 = is_row_major ? rows : cols;
- math_t* norm_n1 = is_row_major ? norm_x2 : norm_x1;
- math_t* norm_n2 = is_row_major ? norm_x1 : norm_x2;
- rbf_kernel_expanded<<<dim3(raft::ceildiv(n1, 32), raft::ceildiv(n2, 4), 1),
-                       dim3(32, 4, 1),
-                       0,
-                       stream>>>(inout, ld, n1, n2, norm_n1, norm_n2, gain);
+ int n1 = is_row_major ? cols : rows;
+ int n2 = is_row_major ? rows : cols;
+ math_t* norm_n1 = is_row_major ? norm_x2 : norm_x1;
+ math_t* norm_n2 = is_row_major ? norm_x1 : norm_x2;
+ auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2);
+ rbf_kernel_expanded<<<grid_shape, block_shape, 0, stream>>>(
+   inout, ld, n1, n2, norm_n1, norm_n2, gain);
}

public:
@@ -576,7 +583,6 @@ class RBFKernel : public GramMatrixBase<math_t> {
math_t* norm_x2)
{
cudaStream_t stream = resource::get_cuda_stream(handle);
- // lazy compute norms if not given
rmm::device_uvector<math_t> tmp_norm_x1(0, stream);
rmm::device_uvector<math_t> tmp_norm_x2(0, stream);

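A side note on why clamping gridDim.y is safe: it relies on the elementwise kernels themselves covering any rows beyond the launched grid. The kernel bodies are not part of the hunks above, so the snippet below is only an illustrative grid-stride pattern under that assumption (the name elementwise_2d and the functor parameter are invented for the example), not the actual RAFT kernels:

// Illustrative CUDA epilogue that strides over y, so every element of the
// n1 x n2 matrix is visited even when gridDim.y was clamped to 65535.
template <typename math_t, typename Op>
__global__ void elementwise_2d(math_t* inout, int ld, int n1, int n2, Op op)
{
  const int x = blockIdx.x * blockDim.x + threadIdx.x;
  if (x >= n1) return;
  for (int y = blockIdx.y * blockDim.y + threadIdx.y; y < n2;
       y += gridDim.y * blockDim.y) {
    inout[x + (size_t)y * ld] = op(inout[x + (size_t)y * ld]);
  }
}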