diff --git a/cpp/bench/spatial/fused_l2_nn.cu b/cpp/bench/spatial/fused_l2_nn.cu
index e5b5dc377a..aa36483145 100644
--- a/cpp/bench/spatial/fused_l2_nn.cu
+++ b/cpp/bench/spatial/fused_l2_nn.cu
@@ -22,9 +22,12 @@
 #include
 #include

-#if defined RAFT_NN_COMPILED
-#include <raft/spatial/knn/specializations.hpp>
-#endif
+// TODO: Once fusedL2NN is specialized in the raft_distance shared library, add
+// back
+//
+// #if defined RAFT_NN_COMPILED
+// #include <raft/spatial/knn/specializations.hpp>
+// #endif

 namespace raft::bench::spatial {

@@ -73,6 +76,30 @@ struct fused_l2_nn : public fixture {
                      false,
                      stream);
     });
+
+    // Num distance calculations
+    int64_t num_dist_calcs = (int64_t)params.n * (int64_t)params.m;
+
+    int64_t num_flops = 3 * num_dist_calcs * params.k;
+
+    int64_t read_elts  = (int64_t)params.n * params.k + (int64_t)params.m * params.k;
+    int64_t write_elts = (int64_t)params.n;
+
+    state.counters["D/s"] = benchmark::Counter(num_dist_calcs,
+                                               benchmark::Counter::kIsIterationInvariantRate,
+                                               benchmark::Counter::OneK::kIs1000);
+
+    state.counters["FLOP/s"] = benchmark::Counter(
+      num_flops, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000);
+
+    state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(cub::KeyValuePair<int, float>),
+                                                 benchmark::Counter::kIsIterationInvariantRate,
+                                                 benchmark::Counter::OneK::kIs1000);
+    state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(float),
+                                                 benchmark::Counter::kIsIterationInvariantRate,
+                                                 benchmark::Counter::OneK::kIs1000);
+
+    state.counters["K"] = benchmark::Counter(params.k);
   }

 private:
@@ -88,9 +115,9 @@ const std::vector<fused_l2_nn_inputs> fused_l2_nn_input_vecs = {
   {32, 16384, 16384},   {64, 16384, 16384},    {128, 16384, 16384},  {256, 16384, 16384},
   {512, 16384, 16384},  {1024, 16384, 16384},  {16384, 32, 16384},   {16384, 64, 16384},
   {16384, 128, 16384},  {16384, 256, 16384},   {16384, 512, 16384},  {16384, 1024, 16384},
+  {16384, 16384, 2},    {16384, 16384, 4},     {16384, 16384, 8},    {16384, 16384, 16},
   {16384, 16384, 32},   {16384, 16384, 64},    {16384, 16384, 128},  {16384, 16384, 256},
   {16384, 16384, 512},  {16384, 16384, 1024},  {16384, 16384, 16384},
-
 };

 RAFT_BENCH_REGISTER(fused_l2_nn, "", fused_l2_nn_input_vecs);
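For reference, the FLOP counter above charges three floating-point operations per element of the k-long inner loop of a squared-L2 distance (one subtract, one multiply, one add), and the read counter charges each input matrix exactly once. A standalone restatement of that arithmetic for one of the benchmarked shapes (editor's sketch, not part of the patch):

    #include <cstdint>
    #include <cstdio>

    int main()
    {
      // Same arithmetic as the new benchmark counters, for m = n = 16384, k = 32.
      std::int64_t m = 16384, n = 16384, k = 32;
      std::int64_t num_dist_calcs = n * m;                   // one candidate distance per (row, col) pair
      std::int64_t num_flops      = 3 * num_dist_calcs * k;  // sub + mul + add per element
      std::int64_t read_elts      = n * k + m * k;           // both operand matrices, read once
      std::int64_t write_elts     = n;                       // one key-value pair per output row
      std::printf("distances: %lld, GFLOP: %.1f, elements read: %lld, written: %lld\n",
                  (long long)num_dist_calcs,
                  num_flops * 1e-9,
                  (long long)read_elts,
                  (long long)write_elts);
      return 0;
    }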
diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh
index 81d02c410c..308f8a096a 100644
--- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh
+++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -92,14 +92,14 @@ DI void updateReducedVal(
   const auto lid      = threadIdx.x % raft::WarpSize;
   const auto accrowid = threadIdx.x / P::AccThCols;

-  // for now have first lane from each warp update a unique output row. This
-  // will resolve hang issues with pre-Volta architectures
+  // Update each output row in order within a warp. This will resolve hang
+  // issues with pre-Volta architectures
 #pragma unroll
   for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) {
-    if (lid == 0) {
+    if (lid == j * P::AccThCols) {
 #pragma unroll
       for (int i = 0; i < P::AccRowsPerTh; ++i) {
-        auto rid = gridStrideY + accrowid + j + i * P::AccThRows;
+        auto rid = gridStrideY + accrowid + i * P::AccThRows;
         if (rid < m) {
           auto value = val[i];
           while (atomicCAS(mutex + rid, 0, 1) == 1)
@@ -111,14 +111,6 @@
         }
       }
     }
-    if (j < (raft::WarpSize / P::AccThCols) - 1) {
-#pragma unroll
-      for (int i = 0; i < P::AccRowsPerTh; ++i) {
-        auto tmpkey   = raft::shfl(val[i].key, (j + 1) * P::AccThCols);
-        auto tmpvalue = raft::shfl(val[i].value, (j + 1) * P::AccThCols);
-        val[i]        = {tmpkey, tmpvalue};
-      }
-    }
   }
 }

@@ -210,8 +202,10 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
    for (int i = 0; i < P::AccRowsPerTh; ++i) {
 #pragma unroll
      for (int j = P::AccThCols / 2; j > 0; j >>= 1) {
-       auto tmpkey   = raft::shfl(val[i].key, lid + j);
-       auto tmpvalue = raft::shfl(val[i].value, lid + j);
+       // Actually, the srcLane (lid + j) should be (lid + j) % P::AccThCols,
+       // but the shfl op applies the modulo internally.
+       auto tmpkey   = raft::shfl(val[i].key, lid + j, P::AccThCols);
+       auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols);
        KVPair tmp = {tmpkey, tmpvalue};
        val[i]     = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]);
      }
@@ -261,7 +255,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
 template <typename DataT,
           typename OutT,
           typename IdxT,
-          int VecLen,
+          typename Policy,
           typename ReduceOpT,
           typename KVPReduceOpT>
 void fusedL2NNImpl(OutT* min,
@@ -279,7 +273,8 @@ void fusedL2NNImpl(OutT* min,
                    bool initOutBuffer,
                    cudaStream_t stream)
 {
-  typedef typename linalg::Policy4x4<DataT, VecLen>::Policy P;
+  // The kernel policy is determined by fusedL2NN.
+  typedef Policy P;

   dim3 blk(P::Nthreads);
   auto nblks = raft::ceildiv<int>(m, P::Nthreads);
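The width argument that the hunk above passes to raft::shfl corresponds to the width parameter of CUDA's __shfl_sync: the warp is split into independent groups of `width` lanes, and a source lane outside [0, width) is taken modulo the width, which is exactly the wrap-around the new comment describes. A minimal standalone kernel using the same srcLane = lid + j rotate pattern (editor's sketch relying only on documented __shfl_sync semantics, not on RAFT internals; launch with a single 32-thread block):

    // After the loop, every lane in each group of Width lanes holds the group
    // minimum: each step combines values j lanes apart, and an out-of-range
    // source lane wraps modulo Width instead of reading across groups.
    template <int Width>
    __global__ void subwarp_min(const float* in, float* out)
    {
      int lid = threadIdx.x % 32;  // lane id within the warp
      float v = in[threadIdx.x];
      for (int j = Width / 2; j > 0; j >>= 1) {
        float other = __shfl_sync(0xffffffffu, v, lid + j, Width);
        v           = other < v ? other : v;
      }
      out[threadIdx.x] = v;
    }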
diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh
index ac8895c9ce..121ccbf60d 100644
--- a/cpp/include/raft/distance/fused_l2_nn.cuh
+++ b/cpp/include/raft/distance/fused_l2_nn.cuh
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include <raft/linalg/contractions.cuh>
 #include

 namespace raft {
@@ -99,20 +100,64 @@ void fusedL2NN(OutT* min,
                bool initOutBuffer,
                cudaStream_t stream)
 {
+  // When k is smaller than 32, the Policy4x4 results in redundant calculations
+  // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead
+  // that uses tiles with a smaller value of k.
+  bool is_skinny = k < 32;
+
   size_t bytes = sizeof(DataT) * k;
   if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 16 / sizeof(DataT), ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 16 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 16 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
   } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 8 / sizeof(DataT), ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 8 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 8 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
   } else {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 1, ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 1>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 1>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
   }
 }

 }  // namespace distance
 }  // namespace raft

-#endif
\ No newline at end of file
+#endif
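The redundancy that the new comment describes is easy to quantify: a tile policy with k-block Kblk effectively rounds the inner dimension up to a multiple of Kblk, so at k = 2 a 32-deep k-block performs 16x the necessary inner-loop work. A host-side sketch of the resulting utilization (editor's illustration; the k-block values 32 and 8 are those of Policy4x4 and of the skinny policy added in contractions.cuh below):

    #include <cstdio>
    #include <initializer_list>

    int main()
    {
      auto padded = [](int k, int kblk) { return (k + kblk - 1) / kblk * kblk; };
      for (int k : {1, 2, 4, 8, 16, 31}) {
        // Fraction of inner-loop iterations that do useful work under each policy.
        std::printf("k=%2d  kblk=32: %5.1f%%  kblk=8: %5.1f%%\n",
                    k,
                    100.0 * k / padded(k, 32),
                    100.0 * k / padded(k, 8));
      }
      return 0;
    }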
diff --git a/cpp/include/raft/linalg/contractions.cuh b/cpp/include/raft/linalg/contractions.cuh
index 5ccbd15c3d..8aed0cb4be 100644
--- a/cpp/include/raft/linalg/contractions.cuh
+++ b/cpp/include/raft/linalg/contractions.cuh
@@ -167,6 +167,28 @@ struct Policy4x4 {
 };
 /** @} */

+/**
+ * A smaller k-block (8 instead of 32) with fewer threads per block (8x8 instead
+ * of 16x16), which is faster for raft::distance::fusedL2NN on skinny matrices,
+ * i.e., matrices with a small k dimension.
+ *
+ */
+template <typename DataT, int _veclen>
+struct Policy4x4Skinny {
+};
+
+template <int _veclen>
+struct Policy4x4Skinny<float, _veclen> {
+  typedef KernelPolicy<float, _veclen, 8, 4, 4, 8, 8> Policy;
+  typedef ColKernelPolicy<float, _veclen, 8, 4, 4, 8, 8> ColPolicy;
+};
+
+template <int _veclen>
+struct Policy4x4Skinny<double, _veclen> {
+  typedef KernelPolicy<double, _veclen, 8, 4, 4, 8, 8> Policy;
+  typedef ColKernelPolicy<double, _veclen, 8, 4, 4, 8, 8> ColPolicy;
+};
+
 /**
  * @defgroup Policy2x8 16 elements per thread Policy with k-block = 16
  * @{
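Reading the specializations above against KernelPolicy's parameter order (DataT, veclen, k-block, rows per thread, columns per thread, thread rows, thread columns): the skinny policy keeps the 4x4 accumulator grid per thread but halves the thread block in each direction and shrinks the k-block from 32 to 8, so a block now computes a 32x32 output tile with 64 threads instead of Policy4x4's 64x64 tile with 256 threads. A compile-time restatement (editor's sketch; the constant names are illustrative, not RAFT identifiers):

    constexpr int kblk = 8;          // depth of the shared-memory tiles along k
    constexpr int rpt = 4, cpt = 4;  // accumulator values per thread (4x4)
    constexpr int tr = 8, tc = 8;    // thread block shape (8x8)

    static_assert(rpt * tr == 32, "32 output rows per block (Policy4x4: 64)");
    static_assert(cpt * tc == 32, "32 output columns per block (Policy4x4: 64)");
    static_assert(tr * tc == 64, "64 threads per block (Policy4x4: 256)");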
diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu
index 192f0c9a74..2a5b30e01f 100644
--- a/cpp/test/distance/fused_l2_nn.cu
+++ b/cpp/test/distance/fused_l2_nn.cu
@@ -23,6 +23,13 @@
 #include
 #include

+// TODO: Once fusedL2NN is specialized in the raft_distance shared library, add
+// the following:
+//
+// #if defined RAFT_NN_COMPILED
+// #include <raft/spatial/knn/specializations.hpp>
+// #endif
+
 namespace raft {
 namespace distance {

@@ -102,6 +109,23 @@ struct Inputs {
   DataT tolerance;
   int m, n, k;
   unsigned long long int seed;
+
+  friend std::ostream& operator<<(std::ostream& os, const Inputs& p)
+  {
+    return os << "m: " << p.m
+              << ", "
+                 "n: "
+              << p.n
+              << ", "
+                 "k: "
+              << p.k
+              << ", "
+                 "seed: "
+              << p.seed
+              << ", "
+                 "tol: "
+              << p.tolerance;
+  }
 };

@@ -231,19 +255,62 @@ template <typename DataT>
 }

 const std::vector<Inputs<float>> inputsf = {
-  {0.001f, 32, 32, 32, 1234ULL},   {0.001f, 32, 64, 32, 1234ULL},   {0.001f, 64, 32, 32, 1234ULL},
-  {0.001f, 64, 64, 32, 1234ULL},   {0.001f, 128, 32, 32, 1234ULL},  {0.001f, 128, 64, 32, 1234ULL},
-  {0.001f, 128, 128, 64, 1234ULL}, {0.001f, 64, 128, 128, 1234ULL},
-
-  {0.001f, 32, 32, 34, 1234ULL},   {0.001f, 32, 64, 34, 1234ULL},   {0.001f, 64, 32, 34, 1234ULL},
-  {0.001f, 64, 64, 34, 1234ULL},   {0.001f, 128, 32, 34, 1234ULL},  {0.001f, 128, 64, 34, 1234ULL},
-  {0.001f, 128, 128, 66, 1234ULL}, {0.001f, 64, 128, 130, 1234ULL},
-
-  {0.001f, 32, 32, 33, 1234ULL},   {0.001f, 32, 64, 33, 1234ULL},   {0.001f, 64, 32, 33, 1234ULL},
-  {0.001f, 64, 64, 33, 1234ULL},   {0.001f, 128, 32, 33, 1234ULL},  {0.001f, 128, 64, 33, 1234ULL},
-  {0.001f, 128, 128, 65, 1234ULL}, {0.001f, 64, 128, 129, 1234ULL},
-
+  {0.001f, 32, 32, 32, 1234ULL},
+  {0.001f, 32, 64, 32, 1234ULL},
+  {0.001f, 64, 32, 32, 1234ULL},
+  {0.001f, 64, 64, 32, 1234ULL},
+  {0.001f, 128, 32, 32, 1234ULL},
+  {0.001f, 128, 64, 32, 1234ULL},
+  {0.001f, 128, 128, 64, 1234ULL},
+  {0.001f, 64, 128, 128, 1234ULL},
+
+  {0.001f, 32, 32, 34, 1234ULL},
+  {0.001f, 32, 64, 34, 1234ULL},
+  {0.001f, 64, 32, 34, 1234ULL},
+  {0.001f, 64, 64, 34, 1234ULL},
+  {0.001f, 128, 32, 34, 1234ULL},
+  {0.001f, 128, 64, 34, 1234ULL},
+  {0.001f, 128, 128, 66, 1234ULL},
+  {0.001f, 64, 128, 130, 1234ULL},
+
+  {0.001f, 32, 32, 33, 1234ULL},
+  {0.001f, 32, 64, 33, 1234ULL},
+  {0.001f, 64, 32, 33, 1234ULL},
+  {0.001f, 64, 64, 33, 1234ULL},
+  {0.001f, 128, 32, 33, 1234ULL},
+  {0.001f, 128, 64, 33, 1234ULL},
+  {0.001f, 128, 128, 65, 1234ULL},
+  {0.001f, 64, 128, 129, 1234ULL},
   {0.006f, 1805, 134, 2, 1234ULL},
+
+  // Repeat with smaller values of k
+  {0.006f, 32, 32, 1, 1234ULL},
+  {0.001f, 32, 64, 2, 1234ULL},
+  {0.001f, 64, 32, 3, 1234ULL},
+  {0.001f, 64, 64, 4, 1234ULL},
+  {0.001f, 128, 32, 5, 1234ULL},
+  {0.001f, 128, 64, 6, 1234ULL},
+  {0.001f, 128, 128, 7, 1234ULL},
+  {0.001f, 64, 128, 8, 1234ULL},
+
+  {0.001f, 32, 32, 9, 1234ULL},
+  {0.001f, 32, 64, 10, 1234ULL},
+  {0.001f, 64, 32, 11, 1234ULL},
+  {0.001f, 64, 64, 12, 1234ULL},
+  {0.001f, 128, 32, 13, 1234ULL},
+  {0.001f, 128, 64, 14, 1234ULL},
+  {0.001f, 128, 128, 15, 1234ULL},
+  {0.001f, 64, 128, 16, 1234ULL},
+
+  {0.001f, 32, 32, 17, 1234ULL},
+  {0.001f, 32, 64, 18, 1234ULL},
+  {0.001f, 64, 32, 19, 1234ULL},
+  {0.001f, 64, 64, 20, 1234ULL},
+  {0.001f, 128, 32, 21, 1234ULL},
+  {0.001f, 128, 64, 22, 1234ULL},
+  {0.001f, 128, 128, 23, 1234ULL},
+  {0.00001f, 64, 128, 24, 1234ULL},
+  {0.001f, 1805, 134, 25, 1234ULL},
 };

 typedef FusedL2NNTest<float, false> FusedL2NNTestF_Sq;
 TEST_P(FusedL2NNTestF_Sq, Result)
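The friend operator<< added to Inputs above is what lets googletest print the failing parameter set for these value-parameterized cases. A standalone check of the format it produces (editor's sketch with the struct instantiated for float rather than templated on DataT):

    #include <iostream>

    struct Inputs {
      float tolerance;
      int m, n, k;
      unsigned long long int seed;

      friend std::ostream& operator<<(std::ostream& os, const Inputs& p)
      {
        return os << "m: " << p.m << ", n: " << p.n << ", k: " << p.k << ", seed: " << p.seed
                  << ", tol: " << p.tolerance;
      }
    };

    int main()
    {
      std::cout << Inputs{0.006f, 32, 32, 1, 1234ULL} << "\n";
      // Prints: m: 32, n: 32, k: 1, seed: 1234, tol: 0.006
      return 0;
    }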