Optimize fusedL2NN when data is skinny #794

Merged: 5 commits, Sep 6, 2022
35 changes: 31 additions & 4 deletions cpp/bench/spatial/fused_l2_nn.cu
@@ -22,9 +22,12 @@
#include <raft/linalg/norm.cuh>
#include <raft/random/rng.cuh>

#if defined RAFT_NN_COMPILED
#include <raft/spatial/knn/specializations.cuh>
#endif
// TODO: Once fusedL2NN is specialized in the raft_distance shared library, add
// back
//
// #if defined RAFT_NN_COMPILED
// #include <raft/spatial/knn/specializations.hpp>
// #endif

namespace raft::bench::spatial {

@@ -73,6 +76,30 @@ struct fused_l2_nn : public fixture {
false,
stream);
});

// Num distance calculations
int64_t num_dist_calcs = (int64_t)params.n * (int64_t)params.m;

int64_t num_flops = 3 * num_dist_calcs * params.k;

int64_t read_elts = (int64_t)params.n * params.k + (int64_t)params.m * params.k;
int64_t write_elts = (int64_t)params.n;

state.counters["D/s"] = benchmark::Counter(num_dist_calcs,
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);

state.counters["FLOP/s"] = benchmark::Counter(
num_flops, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000);

state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(cub::KeyValuePair<int, T>),
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);
state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(float),
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);

state.counters["K"] = benchmark::Counter(params.k);
}

private:
@@ -88,9 +115,9 @@ const std::vector<fused_l2_nn_inputs> fused_l2_nn_input_vecs = {
{32, 16384, 16384}, {64, 16384, 16384}, {128, 16384, 16384}, {256, 16384, 16384},
{512, 16384, 16384}, {1024, 16384, 16384}, {16384, 32, 16384}, {16384, 64, 16384},
{16384, 128, 16384}, {16384, 256, 16384}, {16384, 512, 16384}, {16384, 1024, 16384},
{16384, 16384, 2}, {16384, 16384, 4}, {16384, 16384, 8}, {16384, 16384, 16},
{16384, 16384, 32}, {16384, 16384, 64}, {16384, 16384, 128}, {16384, 16384, 256},
{16384, 16384, 512}, {16384, 16384, 1024}, {16384, 16384, 16384},

};

RAFT_BENCH_REGISTER(fused_l2_nn<float>, "", fused_l2_nn_input_vecs);
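As a sanity check on the counters above, here is a minimal standalone sketch of the same arithmetic for one of the benchmark sizes (n = m = 16384, k = 32). The 3 FLOPs per dimension assume one subtract, one multiply, and one add per accumulated term; the element counts are what the kIsIterationInvariantRate counters turn into per-second rates:

#include <cstdint>
#include <cstdio>

int main()
{
  // One of the benchmark input sizes above: n = m = 16384, k = 32.
  int64_t n = 16384, m = 16384, k = 32;

  int64_t num_dist_calcs = n * m;                   // one distance per row pair
  int64_t num_flops      = 3 * num_dist_calcs * k;  // sub, mul, add per term

  int64_t read_elts  = n * k + m * k;  // both input matrices are read once
  int64_t write_elts = n;              // one KeyValuePair written per x-row

  std::printf("dist: %lld, flops: %lld, read: %lld, written: %lld\n",
              (long long)num_dist_calcs, (long long)num_flops,
              (long long)read_elts, (long long)write_elts);
  return 0;
}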
27 changes: 11 additions & 16 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -92,14 +92,14 @@ DI void updateReducedVal(
const auto lid = threadIdx.x % raft::WarpSize;
const auto accrowid = threadIdx.x / P::AccThCols;

// for now have first lane from each warp update a unique output row. This
// will resolve hang issues with pre-Volta architectures
// Update each output row in order within a warp. This will resolve hang
// issues with pre-Volta architectures
#pragma unroll
for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) {
if (lid == 0) {
if (lid == j * P::AccThCols) {
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rid = gridStrideY + accrowid + j + i * P::AccThRows;
auto rid = gridStrideY + accrowid + i * P::AccThRows;
if (rid < m) {
auto value = val[i];
while (atomicCAS(mutex + rid, 0, 1) == 1)
@@ -111,14 +111,6 @@
}
}
}
if (j < (raft::WarpSize / P::AccThCols) - 1) {
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto tmpkey = raft::shfl(val[i].key, (j + 1) * P::AccThCols);
auto tmpvalue = raft::shfl(val[i].value, (j + 1) * P::AccThCols);
val[i] = {tmpkey, tmpvalue};
}
}
}
}

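For context, the update above serializes concurrent writers with a per-row spinlock built on atomicCAS. A minimal CUDA sketch of that pattern, with the key-value reduction simplified to a plain float min (the function name and exact fence placement are illustrative, not the RAFT code):

__device__ void locked_min_update(float* out, int* mutex, int rid, float v)
{
  // Acquire: spin until the CAS transitions mutex[rid] from 0 (free) to 1.
  while (atomicCAS(mutex + rid, 0, 1) == 1)
    ;
  __threadfence();                     // observe the latest out[rid]
  if (v < out[rid]) { out[rid] = v; }  // critical section: read-modify-write
  __threadfence();                     // make the write visible ...
  atomicExch(mutex + rid, 0);          // ... before releasing the lock
}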
@@ -210,8 +202,10 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = P::AccThCols / 2; j > 0; j >>= 1) {
auto tmpkey = raft::shfl(val[i].key, lid + j);
auto tmpvalue = raft::shfl(val[i].value, lid + j);
// Actually, the srcLane (lid + j) should be (lid + j) % P::AccThCols,
// but the shfl op applies the modulo internally.
auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols);
auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols);
KVPair tmp = {tmpkey, tmpvalue};
val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]);
}
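The width argument added to raft::shfl above confines the shuffle to subgroups of P::AccThCols lanes, which is why no explicit modulo is needed. A standalone CUDA sketch of such a width-limited butterfly reduction, simplified from key-value pairs to a float min (names are hypothetical; assumes all 32 lanes of the warp are active):

template <int WIDTH>
__device__ float subgroup_min(float v)
{
  const int lid = threadIdx.x % 32;
#pragma unroll
  for (int j = WIDTH / 2; j > 0; j >>= 1) {
    // With a width argument, __shfl_sync takes srcLane modulo WIDTH within
    // each subgroup, so (lid + j) needs no explicit wrap-around.
    float other = __shfl_sync(0xffffffffu, v, lid + j, WIDTH);
    v = fminf(v, other);
  }
  return v;  // every lane in the subgroup now holds the subgroup minimum
}

__global__ void demo(const float* in, float* out)
{
  float v = subgroup_min<8>(in[threadIdx.x]);  // subgroups of 8 lanes
  if (threadIdx.x % 8 == 0) out[threadIdx.x / 8] = v;
}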
@@ -261,7 +255,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
template <typename DataT,
typename OutT,
typename IdxT,
int VecLen,
typename Policy,
typename ReduceOpT,
typename KVPReduceOpT>
void fusedL2NNImpl(OutT* min,
@@ -279,7 +273,8 @@ void fusedL2NNImpl(OutT* min,
bool initOutBuffer,
cudaStream_t stream)
{
typedef typename linalg::Policy4x4<DataT, VecLen>::Policy P;
// The kernel policy is determined by fusedL2NN.
typedef Policy P;

dim3 blk(P::Nthreads);
auto nblks = raft::ceildiv<int>(m, P::Nthreads);
59 changes: 52 additions & 7 deletions cpp/include/raft/distance/fused_l2_nn.cuh
@@ -24,6 +24,7 @@
#include <raft/cuda_utils.cuh>
#include <raft/distance/detail/fused_l2_nn.cuh>
#include <raft/handle.hpp>
#include <raft/linalg/contractions.cuh>
#include <stdint.h>

namespace raft {
@@ -99,20 +100,64 @@ void fusedL2NN(OutT* min,
bool initOutBuffer,
cudaStream_t stream)
{
// When k is smaller than 32, the Policy4x4 results in redundant calculations
// as it uses tiles that have k=32. Therefore, use a "skinny" policy instead
// that uses tiles with a smaller value of k.
bool is_skinny = k < 32;

size_t bytes = sizeof(DataT) * k;
if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) {
detail::fusedL2NNImpl<DataT, OutT, IdxT, 16 / sizeof(DataT), ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
if (is_skinny) {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4Skinny<DataT, 16 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
} else {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4<DataT, 16 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}
} else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) {
detail::fusedL2NNImpl<DataT, OutT, IdxT, 8 / sizeof(DataT), ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
if (is_skinny) {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4Skinny<DataT, 8 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
} else {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4<DataT, 8 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}
} else {
detail::fusedL2NNImpl<DataT, OutT, IdxT, 1, ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
if (is_skinny) {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4Skinny<DataT, 1>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
} else {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4<DataT, 1>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}
}
}

} // namespace distance
} // namespace raft

#endif
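The branches above combine two independent choices: the vector load width (16-byte, 8-byte, or scalar, depending on how a row of k elements aligns) and skinny vs. regular tiling via is_skinny = k < 32. A hedged sketch of the load-width selection alone (pick_veclen is a hypothetical helper, not RAFT API):

#include <cstddef>

template <typename DataT>
int pick_veclen(int k)
{
  std::size_t bytes = sizeof(DataT) * k;
  // Widest vectorized load whose size divides both the element size and the
  // byte length of one row; otherwise fall back to scalar loads.
  if (16 % sizeof(DataT) == 0 && bytes % 16 == 0)
    return static_cast<int>(16 / sizeof(DataT));
  if (8 % sizeof(DataT) == 0 && bytes % 8 == 0)
    return static_cast<int>(8 / sizeof(DataT));
  return 1;
}

// e.g. pick_veclen<float>(32) == 4 (float4 loads), pick_veclen<float>(34) == 2,
// pick_veclen<float>(33) == 1 (scalar loads).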
22 changes: 22 additions & 0 deletions cpp/include/raft/linalg/contractions.cuh
@@ -167,6 +167,28 @@
};
/** @} */

/**
* A smaller k-block (8 instead of 32) with fewer threads per block (8x8 instead
* of 16x16), which is faster for raft::distance::fusedL2NN on skinny matrices,
* i.e., matrices with a small k dimension.
*
*/
template <typename DataT, int _veclen>
struct Policy4x4Skinny {
};

template <int _veclen>
struct Policy4x4Skinny<float, _veclen> {
typedef KernelPolicy<float, _veclen, 8, 4, 4, 8, 8> Policy;
typedef ColKernelPolicy<float, _veclen, 8, 4, 4, 8, 8> ColPolicy;
};

template <int _veclen>
struct Policy4x4Skinny<double, _veclen> {
typedef KernelPolicy<double, _veclen, 8, 4, 4, 8, 8> Policy;
typedef ColKernelPolicy<double, _veclen, 8, 4, 4, 8, 8> ColPolicy;
};

/**
* @defgroup Policy2x8 16 elements per thread Policy with k-block = 16
* @{
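Assuming the KernelPolicy parameters follow the pattern <DataT, veclen, k-block, rows-per-thread, cols-per-thread, thread-rows, thread-cols> (an inference from the comment above, not a quoted definition), the skinny policy's numbers work out as follows:

// Skinny policy: KernelPolicy<T, veclen, 8, 4, 4, 8, 8>
constexpr int kblk = 8, rpt = 4, cpt = 4, tr = 8, tc = 8;
constexpr int nthreads = tr * tc;   // 64 threads per block (vs 16 * 16 = 256)
constexpr int tile_m   = rpt * tr;  // 32 rows of x per tile (vs 64)
constexpr int tile_n   = cpt * tc;  // 32 rows of y per tile (vs 64)
static_assert(nthreads == 64 && tile_m == 32 && tile_n == 32, "skinny tile");
// With k-block = 8 instead of 32, a problem with k <= 8 fills a whole
// k-block instead of leaving three quarters of it idle.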
91 changes: 79 additions & 12 deletions cpp/test/distance/fused_l2_nn.cu
@@ -23,6 +23,13 @@
#include <raft/linalg/norm.cuh>
#include <raft/random/rng.cuh>

// TODO: Once fusedL2NN is specialized in the raft_distance shared library, add
// the following:
//
// #if defined RAFT_NN_COMPILED
// #include <raft/spatial/knn/specializations.hpp>
// #endif

namespace raft {
namespace distance {

@@ -102,6 +109,23 @@ struct Inputs {
DataT tolerance;
int m, n, k;
unsigned long long int seed;

friend std::ostream& operator<<(std::ostream& os, const Inputs& p)
{
return os << "m: " << p.m << ", n: " << p.n << ", k: " << p.k
<< ", seed: " << p.seed << ", tol: " << p.tolerance;
}
};

template <typename DataT, bool Sqrt>
@@ -231,19 +255,62 @@ template <typename K, typename V, typename L>
}

const std::vector<Inputs<float>> inputsf = {
{0.001f, 32, 32, 32, 1234ULL}, {0.001f, 32, 64, 32, 1234ULL}, {0.001f, 64, 32, 32, 1234ULL},
{0.001f, 64, 64, 32, 1234ULL}, {0.001f, 128, 32, 32, 1234ULL}, {0.001f, 128, 64, 32, 1234ULL},
{0.001f, 128, 128, 64, 1234ULL}, {0.001f, 64, 128, 128, 1234ULL},

{0.001f, 32, 32, 34, 1234ULL}, {0.001f, 32, 64, 34, 1234ULL}, {0.001f, 64, 32, 34, 1234ULL},
{0.001f, 64, 64, 34, 1234ULL}, {0.001f, 128, 32, 34, 1234ULL}, {0.001f, 128, 64, 34, 1234ULL},
{0.001f, 128, 128, 66, 1234ULL}, {0.001f, 64, 128, 130, 1234ULL},

{0.001f, 32, 32, 33, 1234ULL}, {0.001f, 32, 64, 33, 1234ULL}, {0.001f, 64, 32, 33, 1234ULL},
{0.001f, 64, 64, 33, 1234ULL}, {0.001f, 128, 32, 33, 1234ULL}, {0.001f, 128, 64, 33, 1234ULL},
{0.001f, 128, 128, 65, 1234ULL}, {0.001f, 64, 128, 129, 1234ULL},

{0.001f, 32, 32, 32, 1234ULL},
{0.001f, 32, 64, 32, 1234ULL},
{0.001f, 64, 32, 32, 1234ULL},
{0.001f, 64, 64, 32, 1234ULL},
{0.001f, 128, 32, 32, 1234ULL},
{0.001f, 128, 64, 32, 1234ULL},
{0.001f, 128, 128, 64, 1234ULL},
{0.001f, 64, 128, 128, 1234ULL},

{0.001f, 32, 32, 34, 1234ULL},
{0.001f, 32, 64, 34, 1234ULL},
{0.001f, 64, 32, 34, 1234ULL},
{0.001f, 64, 64, 34, 1234ULL},
{0.001f, 128, 32, 34, 1234ULL},
{0.001f, 128, 64, 34, 1234ULL},
{0.001f, 128, 128, 66, 1234ULL},
{0.001f, 64, 128, 130, 1234ULL},

{0.001f, 32, 32, 33, 1234ULL},
{0.001f, 32, 64, 33, 1234ULL},
{0.001f, 64, 32, 33, 1234ULL},
{0.001f, 64, 64, 33, 1234ULL},
{0.001f, 128, 32, 33, 1234ULL},
{0.001f, 128, 64, 33, 1234ULL},
{0.001f, 128, 128, 65, 1234ULL},
{0.001f, 64, 128, 129, 1234ULL},
{0.006f, 1805, 134, 2, 1234ULL},

// Repeat with smaller values of k
{0.006f, 32, 32, 1, 1234ULL},
{0.001f, 32, 64, 2, 1234ULL},
{0.001f, 64, 32, 3, 1234ULL},
{0.001f, 64, 64, 4, 1234ULL},
{0.001f, 128, 32, 5, 1234ULL},
{0.001f, 128, 64, 6, 1234ULL},
{0.001f, 128, 128, 7, 1234ULL},
{0.001f, 64, 128, 8, 1234ULL},

{0.001f, 32, 32, 9, 1234ULL},
{0.001f, 32, 64, 10, 1234ULL},
{0.001f, 64, 32, 11, 1234ULL},
{0.001f, 64, 64, 12, 1234ULL},
{0.001f, 128, 32, 13, 1234ULL},
{0.001f, 128, 64, 14, 1234ULL},
{0.001f, 128, 128, 15, 1234ULL},
{0.001f, 64, 128, 16, 1234ULL},

{0.001f, 32, 32, 17, 1234ULL},
{0.001f, 32, 64, 18, 1234ULL},
{0.001f, 64, 32, 19, 1234ULL},
{0.001f, 64, 64, 20, 1234ULL},
{0.001f, 128, 32, 21, 1234ULL},
{0.001f, 128, 64, 22, 1234ULL},
{0.001f, 128, 128, 23, 1234ULL},
{0.00001, 64, 128, 24, 1234ULL},
{0.001f, 1805, 134, 25, 1234ULL},
};
typedef FusedL2NNTest<float, false> FusedL2NNTestF_Sq;
TEST_P(FusedL2NNTestF_Sq, Result)
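The tests validate fusedL2NN against a naive kernel. As a rough host-side mental model of what is being checked (not the test's actual GPU reference): for each of the m rows of x, the expected output is the argmin over the n rows of y of the squared L2 distance:

#include <limits>
#include <vector>

// Host-side reference: for each of the m rows of x (row-major, m x k), find
// the index of the nearest of the n rows of y (n x k) under squared L2.
std::vector<int> naive_min_l2(const std::vector<float>& x,
                              const std::vector<float>& y,
                              int m, int n, int k)
{
  std::vector<int> argmin(m, -1);
  for (int i = 0; i < m; ++i) {
    float best = std::numeric_limits<float>::max();
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int d = 0; d < k; ++d) {
        const float diff = x[i * k + d] - y[j * k + d];
        acc += diff * diff;
      }
      if (acc < best) {
        best      = acc;
        argmin[i] = j;
      }
    }
  }
  return argmin;
}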