Skip to content

Commit

Permalink
Optimize fusedL2NN when data is skinny
Browse files Browse the repository at this point in the history
  • Loading branch information
ahendriksen committed Aug 26, 2022
1 parent 57df37d commit 67e1328
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 14 deletions.
33 changes: 28 additions & 5 deletions cpp/bench/spatial/fused_l2_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
#include <common/benchmark.hpp>
#include <limits>
#include <raft/cudart_utils.h>
#include <raft/distance/fused_l2_nn.hpp>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/handle.hpp>
#include <raft/linalg/norm.hpp>
#include <raft/random/rng.cuh>

#if defined RAFT_NN_COMPILED
#include <raft/spatial/knn/specializations.hpp>
#endif
// Note: do not include raft/spatial/knn/specializations.hpp here, even when
// RAFT_NN_COMPILED is defined: fusedL2NN is neither specialized nor defined there.

namespace raft::bench::spatial {

Expand Down Expand Up @@ -73,6 +72,30 @@ struct fused_l2_nn : public fixture {
false,
stream);
});

// Report throughput counters derived from the problem size in `params`.
// Every rate counter below uses kIsIterationInvariantRate, so the values
// computed here are per-iteration amounts.

// Number of pairwise distance evaluations: n * m, widened to 64 bits before
// multiplying so the product cannot overflow a 32-bit int.
const int64_t num_dist_calcs = static_cast<int64_t>(params.n) * static_cast<int64_t>(params.m);

// This benchmark counts 3 floating-point operations per element of k for
// every distance evaluation.
const int64_t num_flops = 3 * num_dist_calcs * params.k;

// Element counts moved through memory: (n + m) * k read, n written.
// NOTE(review): fusedL2NN documents its output as having length m; confirm
// whether write_elts should be based on params.m rather than params.n.
const int64_t read_elts =
  static_cast<int64_t>(params.n) * params.k + static_cast<int64_t>(params.m) * params.k;
const int64_t write_elts = static_cast<int64_t>(params.n);

state.counters["D/s"] = benchmark::Counter(num_dist_calcs,
                                           benchmark::Counter::kIsIterationInvariantRate,
                                           benchmark::Counter::OneK::kIs1000);

state.counters["FLOP/s"] = benchmark::Counter(
  num_flops, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000);

// Each written element is a (key, value) pair whose value type is T.
state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(cub::KeyValuePair<int, T>),
                                             benchmark::Counter::kIsIterationInvariantRate,
                                             benchmark::Counter::OneK::kIs1000);
// Size the read traffic by the fixture's element type T instead of a
// hard-coded float, so that non-float instantiations report correct bytes.
state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(T),
                                             benchmark::Counter::kIsIterationInvariantRate,
                                             benchmark::Counter::OneK::kIs1000);

// Expose k as a plain counter so that results can be grouped by it.
state.counters["K"] = benchmark::Counter(params.k);
}

private:
Expand All @@ -88,9 +111,9 @@ const std::vector<fused_l2_nn_inputs> fused_l2_nn_input_vecs = {
{32, 16384, 16384}, {64, 16384, 16384}, {128, 16384, 16384}, {256, 16384, 16384},
{512, 16384, 16384}, {1024, 16384, 16384}, {16384, 32, 16384}, {16384, 64, 16384},
{16384, 128, 16384}, {16384, 256, 16384}, {16384, 512, 16384}, {16384, 1024, 16384},
{16384, 16384, 2}, {16384, 16384, 4}, {16384, 16384, 8}, {16384, 16384, 16},
{16384, 16384, 32}, {16384, 16384, 64}, {16384, 16384, 128}, {16384, 16384, 256},
{16384, 16384, 512}, {16384, 16384, 1024}, {16384, 16384, 16384},

};

RAFT_BENCH_REGISTER(fused_l2_nn<float>, "", fused_l2_nn_input_vecs);
Expand Down
5 changes: 3 additions & 2 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
template <typename DataT,
typename OutT,
typename IdxT,
int VecLen,
typename Policy,
typename ReduceOpT,
typename KVPReduceOpT>
void fusedL2NNImpl(OutT* min,
Expand All @@ -279,7 +279,8 @@ void fusedL2NNImpl(OutT* min,
bool initOutBuffer,
cudaStream_t stream)
{
typedef typename linalg::Policy4x4<DataT, VecLen>::Policy P;
// The kernel policy is determined by fusedL2NN.
typedef Policy P;

dim3 blk(P::Nthreads);
auto nblks = raft::ceildiv<int>(m, P::Nthreads);
Expand Down
59 changes: 52 additions & 7 deletions cpp/include/raft/distance/fused_l2_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/cuda_utils.cuh>
#include <raft/distance/detail/fused_l2_nn.cuh>
#include <raft/handle.hpp>
#include <raft/linalg/contractions.cuh>
#include <stdint.h>

namespace raft {
Expand Down Expand Up @@ -99,20 +100,64 @@ void fusedL2NN(OutT* min,
bool initOutBuffer,
cudaStream_t stream)
{
// When k is smaller than 32, the Policy4x4 results in redundant calculations
// as it uses tiles that have k=32. Therefore, use a "skinny" policy instead
// that uses tiles with a smaller value of k.
bool is_skinny = k < 32;

size_t bytes = sizeof(DataT) * k;
if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) {
detail::fusedL2NNImpl<DataT, OutT, IdxT, 16 / sizeof(DataT), ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
if (is_skinny) {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4Skinny<DataT, 16 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
} else {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4<DataT, 16 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}
} else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) {
detail::fusedL2NNImpl<DataT, OutT, IdxT, 8 / sizeof(DataT), ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
if (is_skinny) {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4Skinny<DataT, 8 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
} else {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4<DataT, 8 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}
} else {
detail::fusedL2NNImpl<DataT, OutT, IdxT, 1, ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
if (is_skinny) {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4Skinny<DataT, 1>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
} else {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4<DataT, 1>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}
}
}

} // namespace distance
} // namespace raft

#endif
#endif
26 changes: 26 additions & 0 deletions cpp/include/raft/linalg/contractions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,32 @@ struct Policy4x4<double, _veclen> {
};
/** @} */

/**
 * @defgroup Policy4x4Skinny Policy4x4 variant with k-block = 8
 *
 * Kernel policies intended for raft::distance::fusedL2NN on skinny matrices,
 * i.e., matrices whose k dimension is small. Relative to Policy4x4, the
 * k-block is 8 instead of 32 and the thread block is 8x8 instead of 16x16.
 *
 * @tparam DataT   element type; specializations are provided for float and
 *                 double only
 * @tparam _veclen forwarded as the second template argument of KernelPolicy
 *                 and ColKernelPolicy
 *
 * @{
 */
// Primary template: declares no members, so only the specializations below
// expose Policy / ColPolicy.
template <typename DataT, int _veclen>
struct Policy4x4Skinny {
};

// Single-precision specialization.
template <int _veclen>
struct Policy4x4Skinny<float, _veclen> {
  using Policy    = KernelPolicy<float, _veclen, 8, 4, 4, 8, 8>;
  using ColPolicy = ColKernelPolicy<float, _veclen, 8, 4, 4, 8, 8>;
};

// Double-precision specialization; same tile parameters as the float one.
template <int _veclen>
struct Policy4x4Skinny<double, _veclen> {
  using Policy    = KernelPolicy<double, _veclen, 8, 4, 4, 8, 8>;
  using ColPolicy = ColKernelPolicy<double, _veclen, 8, 4, 4, 8, 8>;
};
/** @} */

/**
* @defgroup Policy2x8 16 elements per thread Policy with k-block = 16
* @{
Expand Down

0 comments on commit 67e1328

Please sign in to comment.