Optimize fusedL2NN when data is skinny (#794)
The fusedL2NN kernel uses tiling to maximize performance. The current implementation assumes that the input matrices are at least 32 elements wide. When this is not the case, it performs redundant computations. 

This PR adds a policy that is applied when the matrix is skinny (fewer than 32 elements wide). This results in a 1.5-2x performance improvement across GPU architectures.
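As a rough illustration of the redundancy (a minimal sketch in plain C++, separate from the commit): with a fixed 32-wide k-block, the fraction of each tile spent on padding is (32 - k) / 32.

#include <cstdio>

int main()
{
  // Fraction of each 32-wide k-tile that is padding when the data has only
  // k columns (the redundant work the skinny policy avoids).
  const int k_block = 32;
  for (int k : {2, 4, 8, 16, 32}) {
    std::printf("k = %2d -> %2d%% of the k-tile is padding\n",
                k, 100 * (k_block - k) / k_block);
  }
}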

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #794
Allard Hendriksen authored Sep 6, 2022
1 parent 700bb1e commit 8c639d9
Showing 5 changed files with 195 additions and 39 deletions.
35 changes: 31 additions & 4 deletions cpp/bench/spatial/fused_l2_nn.cu
@@ -22,9 +22,12 @@
#include <raft/linalg/norm.cuh>
#include <raft/random/rng.cuh>

-#if defined RAFT_NN_COMPILED
-#include <raft/spatial/knn/specializations.cuh>
-#endif
+// TODO: Once fusedL2NN is specialized in the raft_distance shared library, add
+// back
+//
+// #if defined RAFT_NN_COMPILED
+// #include <raft/spatial/knn/specializations.hpp>
+// #endif

namespace raft::bench::spatial {

@@ -73,6 +76,30 @@ struct fused_l2_nn : public fixture {
false,
stream);
});

+    // Num distance calculations
+    int64_t num_dist_calcs = (int64_t)params.n * (int64_t)params.m;
+
+    int64_t num_flops = 3 * num_dist_calcs * params.k;
+
+    int64_t read_elts  = (int64_t)params.n * params.k + (int64_t)params.m * params.k;
+    int64_t write_elts = (int64_t)params.n;
+
+    state.counters["D/s"] = benchmark::Counter(num_dist_calcs,
+                                               benchmark::Counter::kIsIterationInvariantRate,
+                                               benchmark::Counter::OneK::kIs1000);
+
+    state.counters["FLOP/s"] = benchmark::Counter(
+      num_flops, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000);
+
+    state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(cub::KeyValuePair<int, T>),
+                                                 benchmark::Counter::kIsIterationInvariantRate,
+                                                 benchmark::Counter::OneK::kIs1000);
+    state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(float),
+                                                 benchmark::Counter::kIsIterationInvariantRate,
+                                                 benchmark::Counter::OneK::kIs1000);
+
+    state.counters["K"] = benchmark::Counter(params.k);
}

private:
@@ -88,9 +115,9 @@ const std::vector<fused_l2_nn_inputs> fused_l2_nn_input_vecs = {
{32, 16384, 16384}, {64, 16384, 16384}, {128, 16384, 16384}, {256, 16384, 16384},
{512, 16384, 16384}, {1024, 16384, 16384}, {16384, 32, 16384}, {16384, 64, 16384},
{16384, 128, 16384}, {16384, 256, 16384}, {16384, 512, 16384}, {16384, 1024, 16384},
{16384, 16384, 2}, {16384, 16384, 4}, {16384, 16384, 8}, {16384, 16384, 16},
{16384, 16384, 32}, {16384, 16384, 64}, {16384, 16384, 128}, {16384, 16384, 256},
{16384, 16384, 512}, {16384, 16384, 1024}, {16384, 16384, 16384},

};

RAFT_BENCH_REGISTER(fused_l2_nn<float>, "", fused_l2_nn_input_vecs);
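The num_flops counter added above assumes 3 FLOPs per accumulated element (a subtract, a multiply, and an add), and read_elts assumes each input element of x and y is read once. A minimal sketch in plain C++ (values chosen for illustration, not taken from benchmark output) of what these formulas imply for a skinny case:

#include <cstdint>
#include <cstdio>

int main()
{
  // Mirror the counter formulas from the benchmark for one skinny case.
  int64_t m = 16384, n = 16384, k = 16;
  int64_t num_dist_calcs = m * n;
  int64_t num_flops      = 3 * num_dist_calcs * k;           // ~1.3e10 FLOPs per pass
  int64_t read_bytes     = (m * k + n * k) * sizeof(float);  // ~2.1e6 bytes per pass
  // Roughly 6000 FLOPs per byte read: the skinny case is nowhere near
  // read-bandwidth-bound, so tile shape dominates the runtime.
  std::printf("FLOP per byte read: %.0f\n", double(num_flops) / read_bytes);
}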
27 changes: 11 additions & 16 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -92,14 +92,14 @@ DI void updateReducedVal(
const auto lid = threadIdx.x % raft::WarpSize;
const auto accrowid = threadIdx.x / P::AccThCols;

-  // for now have first lane from each warp update a unique output row. This
-  // will resolve hang issues with pre-Volta architectures
+  // Update each output row in order within a warp. This will resolve hang
+  // issues with pre-Volta architectures
#pragma unroll
  for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) {
-    if (lid == 0) {
+    if (lid == j * P::AccThCols) {
#pragma unroll
      for (int i = 0; i < P::AccRowsPerTh; ++i) {
-        auto rid = gridStrideY + accrowid + j + i * P::AccThRows;
+        auto rid = gridStrideY + accrowid + i * P::AccThRows;
if (rid < m) {
auto value = val[i];
while (atomicCAS(mutex + rid, 0, 1) == 1)
@@ -111,14 +111,6 @@
}
}
}
-    if (j < (raft::WarpSize / P::AccThCols) - 1) {
-#pragma unroll
-      for (int i = 0; i < P::AccRowsPerTh; ++i) {
-        auto tmpkey   = raft::shfl(val[i].key, (j + 1) * P::AccThCols);
-        auto tmpvalue = raft::shfl(val[i].value, (j + 1) * P::AccThCols);
-        val[i]        = {tmpkey, tmpvalue};
-      }
-    }
}
}
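The rewritten loop serializes the mutex-guarded updates across lanes: in iteration j only the lane with id j * P::AccThCols enters the critical section, so each owning lane publishes its own rows and the shuffle rotation removed above is no longer needed. For reference, a minimal standalone sketch of the atomicCAS row-mutex pattern that the critical section uses (hypothetical kernel and names, not raft code):

// Several column tiles (blockIdx.y) contend for the same output row, so the
// merging thread takes an atomicCAS-based row mutex. `mutex` must be
// zero-initialized before the launch.
__global__ void merge_partial_min(float* out, int* mutex, const float* partial, int m)
{
  int rid = blockIdx.x;  // one output row per blockIdx.x
  if (rid >= m || threadIdx.x != 0) return;
  float v = partial[rid * gridDim.y + blockIdx.y];  // this tile's partial minimum
  while (atomicCAS(mutex + rid, 0, 1) == 1)
    ;                                // acquire: spin until the row lock is free
  __threadfence();                   // order the read of out[] after the acquire
  if (v < out[rid]) out[rid] = v;    // critical section: merge the partial result
  __threadfence();                   // publish the write before releasing
  atomicExch(mutex + rid, 0);        // release the row lock
}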

@@ -210,8 +202,10 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = P::AccThCols / 2; j > 0; j >>= 1) {
-        auto tmpkey   = raft::shfl(val[i].key, lid + j);
-        auto tmpvalue = raft::shfl(val[i].value, lid + j);
+        // Actually, the srcLane (lid + j) should be (lid + j) % P::AccThCols,
+        // but the shfl op applies the modulo internally.
+        auto tmpkey   = raft::shfl(val[i].key, lid + j, P::AccThCols);
+        auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols);
KVPair tmp = {tmpkey, tmpvalue};
val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]);
}
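The added width argument confines the butterfly reduction to each P::AccThCols-wide subsection of the warp: with a width smaller than warpSize, the shuffle treats each subsection independently and folds srcLane back modulo width, which is why the explicit modulo mentioned in the comment is unnecessary. A minimal sketch of the same pattern in plain CUDA (hypothetical helper, assuming a power-of-two width and a fully active warp):

// Width-limited butterfly min-reduction. Because srcLane is folded back
// modulo `width`, every lane of a subsection converges to the subsection's
// minimum after the loop.
__device__ float group_min(float v, int width)  // width plays the role of P::AccThCols
{
  const int lane = threadIdx.x % 32;
  for (int j = width / 2; j > 0; j >>= 1) {
    float other = __shfl_sync(0xffffffffu, v, lane + j, width);
    v = fminf(v, other);
  }
  return v;
}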
@@ -261,7 +255,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
template <typename DataT,
typename OutT,
typename IdxT,
-          int VecLen,
+          typename Policy,
typename ReduceOpT,
typename KVPReduceOpT>
void fusedL2NNImpl(OutT* min,
@@ -279,7 +273,8 @@ void fusedL2NNImpl(OutT* min,
bool initOutBuffer,
cudaStream_t stream)
{
-  typedef typename linalg::Policy4x4<DataT, VecLen>::Policy P;
+  // The kernel policy is determined by fusedL2NN.
+  typedef Policy P;

dim3 blk(P::Nthreads);
auto nblks = raft::ceildiv<int>(m, P::Nthreads);
59 changes: 52 additions & 7 deletions cpp/include/raft/distance/fused_l2_nn.cuh
@@ -24,6 +24,7 @@
#include <raft/cuda_utils.cuh>
#include <raft/distance/detail/fused_l2_nn.cuh>
#include <raft/handle.hpp>
+#include <raft/linalg/contractions.cuh>
#include <stdint.h>

namespace raft {
@@ -99,20 +100,64 @@ void fusedL2NN(OutT* min,
bool initOutBuffer,
cudaStream_t stream)
{
+  // When k is smaller than 32, the Policy4x4 results in redundant calculations
+  // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead
+  // that uses tiles with a smaller value of k.
+  bool is_skinny = k < 32;
+
  size_t bytes = sizeof(DataT) * k;
  if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 16 / sizeof(DataT), ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 16 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 16 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
  } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 8 / sizeof(DataT), ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 8 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 8 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
  } else {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 1, ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 1>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 1>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
  }
}

} // namespace distance
} // namespace raft

#endif
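For reference, a hedged usage sketch of this entry point; the reduce-op type names (MinAndDistanceReduceOp, KVPMinReduce) and the workspace requirement of one int mutex per row of x are assumptions based on the surrounding raft headers, so verify them against raft/distance/fused_l2_nn.cuh:

// Usage sketch (assumed reduce-op names; not part of the commit).
#include <cub/cub.cuh>
#include <raft/distance/fused_l2_nn.cuh>

void min_neighbor_per_row(cub::KeyValuePair<int, float>* min,  // [m] (index, distance) out
                          const float* x,                      // [m, k] row-major
                          const float* y,                      // [n, k] row-major
                          const float* xn,                     // [m] squared L2 norms of x
                          const float* yn,                     // [n] squared L2 norms of y
                          void* workspace,                     // assumed >= m * sizeof(int)
                          int m, int n, int k, cudaStream_t stream)
{
  raft::distance::MinAndDistanceReduceOp<int, float> redOp;
  raft::distance::KVPMinReduce<int, float> pairRedOp;
  raft::distance::fusedL2NN<float, cub::KeyValuePair<int, float>, int>(
    min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp,
    false /* sqrt */, true /* initOutBuffer */, stream);
}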
22 changes: 22 additions & 0 deletions cpp/include/raft/linalg/contractions.cuh
@@ -167,6 +167,28 @@ struct Policy4x4<double, _veclen> {
};
/** @} */

+/**
+ * A smaller k-block (8 instead of 32) with fewer threads per block (8x8 instead
+ * of 16x16), which is faster for raft::distance::fusedL2NN on skinny matrices,
+ * i.e., matrices with a small k dimension.
+ *
+ */
+template <typename DataT, int _veclen>
+struct Policy4x4Skinny {
+};
+
+template <int _veclen>
+struct Policy4x4Skinny<float, _veclen> {
+  typedef KernelPolicy<float, _veclen, 8, 4, 4, 8, 8> Policy;
+  typedef ColKernelPolicy<float, _veclen, 8, 4, 4, 8, 8> ColPolicy;
+};
+
+template <int _veclen>
+struct Policy4x4Skinny<double, _veclen> {
+  typedef KernelPolicy<double, _veclen, 8, 4, 4, 8, 8> Policy;
+  typedef ColKernelPolicy<double, _veclen, 8, 4, 4, 8, 8> ColPolicy;
+};
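Reading the KernelPolicy arguments in the same order as the Policy4x4 definitions earlier in this file (an assumed order of veclen, kblk, rpt, cpt, tr, tc), the skinny policy works out to the shape described in the doc comment; a minimal sketch of the arithmetic:

// Tile shape implied by KernelPolicy<DataT, veclen, kblk, rpt, cpt, tr, tc>
// (the parameter order is an assumption; verify against the KernelPolicy
// definition earlier in contractions.cuh).
constexpr int kblk = 8, rpt = 4, cpt = 4, tr = 8, tc = 8;
constexpr int threads_per_block = tr * tc;  // 64 threads (8x8), vs 256 (16x16) in Policy4x4
constexpr int tile_rows = rpt * tr;         // 32 rows of x per block tile
constexpr int tile_cols = cpt * tc;         // 32 rows of y per block tile
constexpr int k_step = kblk;                // 8-wide k-block, vs 32 in Policy4x4
static_assert(threads_per_block == 64 && tile_rows == 32 && tile_cols == 32 && k_step == 8, "");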

/**
* @defgroup Policy2x8 16 elements per thread Policy with k-block = 16
* @{
91 changes: 79 additions & 12 deletions cpp/test/distance/fused_l2_nn.cu
@@ -23,6 +23,13 @@
#include <raft/linalg/norm.cuh>
#include <raft/random/rng.cuh>

+// TODO: Once fusedL2NN is specialized in the raft_distance shared library, add
+// the following:
+//
+// #if defined RAFT_NN_COMPILED
+// #include <raft/spatial/knn/specializations.hpp>
+// #endif

namespace raft {
namespace distance {

@@ -102,6 +109,23 @@ struct Inputs {
DataT tolerance;
int m, n, k;
unsigned long long int seed;

+  friend std::ostream& operator<<(std::ostream& os, const Inputs& p)
+  {
+    return os << "m: " << p.m
+              << ", "
+                 "n: "
+              << p.n
+              << ", "
+                 "k: "
+              << p.k
+              << ", "
+                 "seed: "
+              << p.seed
+              << ", "
+                 "tol: "
+              << p.tolerance;
+  }
};

template <typename DataT, bool Sqrt>
@@ -231,19 +255,62 @@ template <typename K, typename V, typename L>
}

const std::vector<Inputs<float>> inputsf = {
-  {0.001f, 32, 32, 32, 1234ULL},   {0.001f, 32, 64, 32, 1234ULL},  {0.001f, 64, 32, 32, 1234ULL},
-  {0.001f, 64, 64, 32, 1234ULL},   {0.001f, 128, 32, 32, 1234ULL}, {0.001f, 128, 64, 32, 1234ULL},
-  {0.001f, 128, 128, 64, 1234ULL}, {0.001f, 64, 128, 128, 1234ULL},
-
-  {0.001f, 32, 32, 34, 1234ULL},   {0.001f, 32, 64, 34, 1234ULL},  {0.001f, 64, 32, 34, 1234ULL},
-  {0.001f, 64, 64, 34, 1234ULL},   {0.001f, 128, 32, 34, 1234ULL}, {0.001f, 128, 64, 34, 1234ULL},
-  {0.001f, 128, 128, 66, 1234ULL}, {0.001f, 64, 128, 130, 1234ULL},
-
-  {0.001f, 32, 32, 33, 1234ULL},   {0.001f, 32, 64, 33, 1234ULL},  {0.001f, 64, 32, 33, 1234ULL},
-  {0.001f, 64, 64, 33, 1234ULL},   {0.001f, 128, 32, 33, 1234ULL}, {0.001f, 128, 64, 33, 1234ULL},
-  {0.001f, 128, 128, 65, 1234ULL}, {0.001f, 64, 128, 129, 1234ULL},
-
+  {0.001f, 32, 32, 32, 1234ULL},
+  {0.001f, 32, 64, 32, 1234ULL},
+  {0.001f, 64, 32, 32, 1234ULL},
+  {0.001f, 64, 64, 32, 1234ULL},
+  {0.001f, 128, 32, 32, 1234ULL},
+  {0.001f, 128, 64, 32, 1234ULL},
+  {0.001f, 128, 128, 64, 1234ULL},
+  {0.001f, 64, 128, 128, 1234ULL},
+
+  {0.001f, 32, 32, 34, 1234ULL},
+  {0.001f, 32, 64, 34, 1234ULL},
+  {0.001f, 64, 32, 34, 1234ULL},
+  {0.001f, 64, 64, 34, 1234ULL},
+  {0.001f, 128, 32, 34, 1234ULL},
+  {0.001f, 128, 64, 34, 1234ULL},
+  {0.001f, 128, 128, 66, 1234ULL},
+  {0.001f, 64, 128, 130, 1234ULL},
+
+  {0.001f, 32, 32, 33, 1234ULL},
+  {0.001f, 32, 64, 33, 1234ULL},
+  {0.001f, 64, 32, 33, 1234ULL},
+  {0.001f, 64, 64, 33, 1234ULL},
+  {0.001f, 128, 32, 33, 1234ULL},
+  {0.001f, 128, 64, 33, 1234ULL},
+  {0.001f, 128, 128, 65, 1234ULL},
+  {0.001f, 64, 128, 129, 1234ULL},
+  {0.006f, 1805, 134, 2, 1234ULL},
+
+  // Repeat with smaller values of k
+  {0.006f, 32, 32, 1, 1234ULL},
+  {0.001f, 32, 64, 2, 1234ULL},
+  {0.001f, 64, 32, 3, 1234ULL},
+  {0.001f, 64, 64, 4, 1234ULL},
+  {0.001f, 128, 32, 5, 1234ULL},
+  {0.001f, 128, 64, 6, 1234ULL},
+  {0.001f, 128, 128, 7, 1234ULL},
+  {0.001f, 64, 128, 8, 1234ULL},
+
+  {0.001f, 32, 32, 9, 1234ULL},
+  {0.001f, 32, 64, 10, 1234ULL},
+  {0.001f, 64, 32, 11, 1234ULL},
+  {0.001f, 64, 64, 12, 1234ULL},
+  {0.001f, 128, 32, 13, 1234ULL},
+  {0.001f, 128, 64, 14, 1234ULL},
+  {0.001f, 128, 128, 15, 1234ULL},
+  {0.001f, 64, 128, 16, 1234ULL},
+
+  {0.001f, 32, 32, 17, 1234ULL},
+  {0.001f, 32, 64, 18, 1234ULL},
+  {0.001f, 64, 32, 19, 1234ULL},
+  {0.001f, 64, 64, 20, 1234ULL},
+  {0.001f, 128, 32, 21, 1234ULL},
+  {0.001f, 128, 64, 22, 1234ULL},
+  {0.001f, 128, 128, 23, 1234ULL},
+  {0.00001, 64, 128, 24, 1234ULL},
+  {0.001f, 1805, 134, 25, 1234ULL},
};
typedef FusedL2NNTest<float, false> FusedL2NNTestF_Sq;
TEST_P(FusedL2NNTestF_Sq, Result)
