From cfe660775304e58b2e2ba222a880ace5bb96e42e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 22 Sep 2022 11:38:16 +0200 Subject: [PATCH] Add sparseL2NN initial implementation --- .../raft/distance/detail/compress_to_bits.cuh | 49 ++ .../distance/detail/sparse_distance_base.cuh | 362 +++++++++++++ .../raft/distance/detail/sparse_l2_nn.cuh | 303 +++++++++++ cpp/include/raft/distance/sparse_l2_nn.cuh | 114 ++++ .../raft/linalg/detail/contractions.cuh | 42 ++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/sparse_l2_nn.cu | 494 ++++++++++++++++++ 7 files changed, 1365 insertions(+) create mode 100644 cpp/include/raft/distance/detail/compress_to_bits.cuh create mode 100644 cpp/include/raft/distance/detail/sparse_distance_base.cuh create mode 100644 cpp/include/raft/distance/detail/sparse_l2_nn.cuh create mode 100644 cpp/include/raft/distance/sparse_l2_nn.cuh create mode 100644 cpp/test/distance/sparse_l2_nn.cu diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh new file mode 100644 index 0000000000..e9a60154a3 --- /dev/null +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace raft { +namespace distance { +namespace detail { + +template ::value>> +__global__ void compress_to_bits_naive(const bool* in, int in_rows, int in_cols, T* out) +{ + constexpr int bits_per_element = 8 * sizeof(T); + + const size_t i = threadIdx.y + blockIdx.y * blockDim.y; + const size_t j = threadIdx.x + blockIdx.x * blockDim.x; + + if (in_rows <= i || in_cols <= j) { return; } + + bool bit = in[i * in_cols + j]; + int bitpos = j % bits_per_element; + + T bitfield = bit ? T(1) << bitpos : 0; + + const size_t out_rows = raft::ceildiv(in_cols, bits_per_element); + const size_t out_cols = in_rows; + const size_t out_j = i; + const size_t out_i = j / bits_per_element; + if (out_i < out_rows && out_j < out_cols) { atomicOr(&out[out_i * out_cols + out_j], bitfield); } +} + +}; // namespace detail +}; // namespace distance +}; // namespace raft diff --git a/cpp/include/raft/distance/detail/sparse_distance_base.cuh b/cpp/include/raft/distance/detail/sparse_distance_base.cuh new file mode 100644 index 0000000000..6e51ccbab3 --- /dev/null +++ b/cpp/include/raft/distance/detail/sparse_distance_base.cuh @@ -0,0 +1,362 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include + +#include + +namespace raft { +namespace distance { +namespace detail { + +/** + * @brief Device class for L1, L2 and cosine distance metrics. + * @tparam useNorms whether norms are needed + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Policy struct which tunes the Contraction kernel + * @tparam CoreLambda tells how to accumulate an x and y into + acc. its signature: + template void core_lambda(AccT& acc, + const DataT& x, const DataT& y) + * @tparam EpilogueLambda applies an elementwise function to compute final + values. Its signature is: + template void epilogue_lambda + (AccT acc[][], DataT* regxn, DataT* regyn); + * @tparam FinalLambda the final lambda called on final distance value + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine + * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine + * @param[output] pD output matrix + * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. + * @param core_op the core accumulation operation lambda + * @param epilog_op the epilog operation lambda + * @param fin_op the final gemm epilogue lambda + */ + +template > +struct SparseDistances : public BaseClass { + private: + typedef Policy P; + const DataT* xn; + const DataT* yn; + const DataT* const yBase; + const uint64_t* adj; + const IdxT* group_idxs; + IdxT num_groups; + char* smem; + CoreLambda core_op; + EpilogueLambda epilog_op; + FinalLambda fin_op; + rowEpilogueLambda rowEpilog_op; + + AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; + + public: + // Constructor + DI SparseDistances(const DataT* _x, + const DataT* _y, + IdxT _m, + IdxT _n, + IdxT _k, + IdxT _lda, + IdxT _ldb, + IdxT _ldd, + const DataT* _xn, + const DataT* _yn, + const uint64_t* _adj, + const IdxT* _group_idxs, + IdxT _num_groups, + char* _smem, + CoreLambda _core_op, + EpilogueLambda _epilog_op, + FinalLambda _fin_op, + rowEpilogueLambda _rowEpilog_op) + : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), + xn(_xn), + yn(_yn), + yBase(_y), + adj(_adj), + group_idxs(_group_idxs), + num_groups(_num_groups), + smem(_smem), + core_op(_core_op), + epilog_op(_epilog_op), + fin_op(_fin_op), + rowEpilog_op(_rowEpilog_op) + { + } + + DI void run() + { + const auto grid_stride_m = (P::Mblk * gridDim.y); + const auto grid_offset_m = (P::Mblk * blockIdx.y); + + const auto grid_stride_g = gridDim.x; + const auto grid_offset_g = blockIdx.x; + + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + // Start loop over groups + for (auto idx_g = grid_offset_g; idx_g < this->num_groups; idx_g += grid_stride_g) { + // The __syncthreads() ensures that loading the block flag occurs at + // the same time in all threads of the block. Since all threads load + // the same address, this speeds up the code. 
+ __syncthreads(); + const uint64_t block_adj = get_block_adjacency(adj, tile_idx_m, idx_g); + // block_adj is a bitfield that contains a 1 if a row is adjacent to the + // current group. All zero means we can skip this group. + if (block_adj == 0) { continue; } + + // Determine which results, that are computed by this thread, have to + // be taken into account. This information is stored in a bitfield, + // thread_adj. If all results computed by this thread can be ignored, + // then we can also skip some computations (thread_adj == 0). + + // We precompute this information because it is used in various + // locations to skip thread-local computations. + int thread_adj = compute_thread_adjacency(block_adj); + + auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; + const auto tile_end_n = group_idxs[idx_g]; + for (; tile_idx_n < tile_end_n; tile_idx_n += P::Nblk) { + // We provide tile_end_n to limit the number of unnecessary data + // points that are loaded from y. + // TODO: determine if this actually improves performance. + this->ldgXY(tile_idx_m, tile_idx_n, 0, tile_end_n); + + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx, tile_end_n); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + if (thread_adj != 0) { accumulate(); } + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + if (thread_adj != 0) { + accumulate(); // last iteration + } + // This is needed for making sure next grid stride of + // non-norm based metrics uses previously accumulated buffer so + // it doesn't make shmem dirty until previous iteration + // is complete. + this->switch_read_buffer(); + + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, tile_end_n, regxn, regyn); + if (thread_adj != 0) { + epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, tile_end_n); + } + } else { + if (thread_adj != 0) { + epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, tile_end_n); + } + } + } // tile_idx_n + } // idx_g + rowEpilog_op(tile_idx_m); + } // tile_idx_n + } + + private: + DI uint64_t get_block_adjacency(const uint64_t* adj, IdxT tile_idx_m, IdxT idx_group) + { + IdxT block_flag_idx = tile_idx_m / P::Mblk; + return adj[block_flag_idx * this->num_groups + idx_group]; + } + + DI uint32_t compute_thread_adjacency(const uint64_t block_adj) + { + uint32_t thread_adj = 0; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + const uint64_t read_mask = 1ull << (this->accrowid + i * P::AccThRows); + const uint32_t write_mask = 1 << i; + if ((block_adj & read_mask) != 0) { thread_adj |= write_mask; } + } + return thread_adj; + } + + DI void reset_accumulator() + { + // Reset accumulator registers to zero. 
+#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } + } + } + + DI void accumulate() + { +#pragma unroll + for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { + this->ldsXY(ki); +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { +#pragma unroll + for (int v = 0; v < P::Veclen; ++v) { + core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); + } + } + } + } + } + + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + IdxT tile_end_n, + DataT (®xn)[P::AccRowsPerTh], + DataT (®yn)[P::AccColsPerTh]) + { + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; + } + + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < tile_end_n ? yn[idx] : 0; + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } +#pragma unroll + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; + } + } +}; // struct SparseDistances + +/** + * @brief the distance matrix calculation kernel for L1, L2 and cosine + * @tparam useNorms whether norms are needed + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Policy struct which tunes the Contraction kernel + * @tparam CoreLambda lambda which implements accumulation operation + * @tparam EpilogueLambda lambda which implements operation for calculating + final value. + * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major(default), + false for column major + * + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] xn row norms of input matrix A. + * @param[in] yn row norms of input matrix B. 
+ * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] pD output matrix + * @param core_op the core lambda + * @param epilog_op the epilogue lambda + * @param fin_op the final gemm epilogue lambda + */ + +template +__global__ __launch_bounds__(Policy::Nthreads, 2) + + void sparseDistanceMatKernel(const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + const bool* adj, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + CoreLambda core_op, + EpilogueLambda epilog_op, + FinalLambda fin_op) +{ + extern __shared__ char smem[]; + auto rowEpilog = [] __device__(IdxT starty) { return; }; + + SparseDistances + obj(x, y, m, n, k, lda, ldb, ldd, _xn, _yn, smem, core_op, epilog_op, fin_op, rowEpilog); + obj.run(); +} + +}; // namespace detail +}; // namespace distance +}; // namespace raft diff --git a/cpp/include/raft/distance/detail/sparse_l2_nn.cuh b/cpp/include/raft/distance/detail/sparse_l2_nn.cuh new file mode 100644 index 0000000000..acc66b3837 --- /dev/null +++ b/cpp/include/raft/distance/detail/sparse_l2_nn.cuh @@ -0,0 +1,303 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { +namespace detail { + +#if (ENABLE_MEMCPY_ASYNC == 1) +#include +using namespace nvcuda::experimental; +#endif + +template +__global__ __launch_bounds__(P::Nthreads, 2) void sparseL2NNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const uint64_t* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + CoreLambda core_op, + FinalLambda fin_op) +{ + extern __shared__ char smem[]; + + typedef cub::KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + int acc_adj, + DataT* regxn, + DataT* regyn, + IdxT tile_idx_n, + IdxT tile_idx_m, + IdxT tile_end_n) { + KVPReduceOpT pairRed_op(pairRedOp); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + } + } + if (Sqrt) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = raft::mySqrt(acc[i][j]); + } + } + } + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + const bool ignore = (acc_adj & (1 << i)) == 0; + if (ignore) { continue; } +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + tile_idx_n; + if (tile_end_n <= tmpkey) { + // Do not process beyond end of tile. + continue; + } + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < tile_end_n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT tile_idx_m) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + auto tmpkey = raft::shfl(val[i].key, lid + j); + auto tmpvalue = raft::shfl(val[i].value, lid + j); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, tile_idx_m); + + // reset the val array. 
+#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + SparseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + adj, + group_idxs, + num_groups, + smem, + core_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +} + +template +void sparseL2NNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const bool* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + typedef typename linalg::Policy4x4::Policy P; + + static_assert(P::Mblk == 64, "sparseL2NNImpl only supports a policy with 64 rows per block."); + + // First, compress boolean to bitfield. + + // TODO 1: Remove allocation; use workspace instead(?) + // TODO 2: Use a faster compress_to_bits implementation that does not require a pre-zeroed output. + rmm::device_uvector adj64(raft::ceildiv(m, IdxT(64)) * num_groups, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(adj64.data(), 0, adj64.size() * sizeof(uint64_t), stream)); + dim3 compress_grid(raft::ceildiv(m, 32), raft::ceildiv(num_groups, 32)); + compress_to_bits_naive<<>>( + adj, num_groups, m, adj64.data()); + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef cub::KeyValuePair KVPair; + + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + // TODO 3: remove fin_op + auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; + + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + if (sqrt) { + auto sparseL2NNSqrt = sparseL2NNkernel; + dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, sparseL2NNSqrt); + + sparseL2NNSqrt<<>>(min, + x, + y, + xn, + yn, + adj64.data(), + group_idxs, + num_groups, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + core_lambda, + fin_op); + } else { + auto sparseL2NN = sparseL2NNkernel; + dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, sparseL2NN); + sparseL2NN<<>>(min, + x, + y, + xn, + yn, + adj64.data(), + group_idxs, + num_groups, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + core_lambda, + fin_op); + } + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/sparse_l2_nn.cuh b/cpp/include/raft/distance/sparse_l2_nn.cuh new file mode 100644 index 0000000000..c690702cb4 --- /dev/null +++ b/cpp/include/raft/distance/sparse_l2_nn.cuh @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __SPARSE_L2_NN_H +#define __SPARSE_L2_NN_H + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +/** + * @brief Sparse L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void sparseL2NN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const bool* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // TODO: decide on kernel policy based on skinniness of the matrices. If k is + // low, it may make sense to use another kernel policy, like in + // fused_l2_nn.cuh. 
+ detail::sparseL2NNImpl(min, + x, + y, + xn, + yn, + adj, + group_idxs, + num_groups, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + stream); + // } +} + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 6d7a8e2292..4c5a43cd57 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -151,6 +151,12 @@ struct Contractions_NT { ldgY(tile_idx_n, kidx); } + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx, IdxT tile_end_n) + { + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx, tile_end_n); + } + /** * @brief Store current block of X/Y from registers to smem * @param[in] kidx current start index of k to be loaded @@ -248,6 +254,42 @@ struct Contractions_NT { } } + DI void ldgY(IdxT tile_idx_n, IdxT kidx, IdxT tile_end_n) + { + IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; + auto y = isRowMajor ? y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; + + if (isRowMajor) { + auto numRows = tile_end_n; + auto koffset = kidx + scolid; +#pragma unroll + for (int i = 0; i < P::LdgPerThY; ++i) { + if (koffset < ldb && (yrowid + i * P::LdgRowsY) < numRows) { + ldg(ldgDataY[i], y + i * P::LdgRowsY * ldb + koffset); + } else { +#pragma unroll + for (int j = 0; j < P::Veclen; ++j) { + ldgDataY[i][j] = Zero; + } + } + } + } else { + auto numRows = k; + auto koffset = scolid; +#pragma unroll + for (int i = 0; i < P::LdgPerThY; ++i) { + if ((koffset + yrowid) < tile_end_n && (srowid + kidx + i * P::LdgRowsY) < numRows) { + ldg(ldgDataY[i], y + (kidx + i * P::LdgRowsY) * ldb + koffset); + } else { +#pragma unroll + for (int j = 0; j < P::Veclen; ++j) { + ldgDataY[i][j] = Zero; + } + } + } + } + } + DI void stsX(DataT* smem) { auto* saddr = smem + srowid * P::SmemStride + scolid; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 8ca30a5c82..68937e86f3 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -119,6 +119,7 @@ if(BUILD_TESTS) test/distance/dist_minkowski.cu test/distance/dist_russell_rao.cu test/distance/fused_l2_nn.cu + test/distance/sparse_l2_nn.cu test/distance/gram.cu OPTIONAL DIST diff --git a/cpp/test/distance/sparse_l2_nn.cu b/cpp/test/distance/sparse_l2_nn.cu new file mode 100644 index 0000000000..293c78ddee --- /dev/null +++ b/cpp/test/distance/sparse_l2_nn.cu @@ -0,0 +1,494 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft { +namespace distance { +namespace sparse_l2_nn { + +template +struct CubKVPMinReduce { + typedef cub::KeyValuePair KVP; + + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? 
b : a; } + + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + +}; // KVPMinReduce + +template +__global__ __launch_bounds__(32 * NWARPS, 2) void naiveKernel(cub::KeyValuePair* min, + DataT* x, + DataT* y, + bool* adj, + int* group_idxs, + int m, + int n, + int k, + int num_groups, + int* workspace, + DataT maxVal) +{ + const int m_stride = blockDim.y * gridDim.y; + const int m_offset = threadIdx.y + blockIdx.y * blockDim.y; + const int n_stride = blockDim.x * gridDim.x; + const int n_offset = threadIdx.x + blockIdx.x * blockDim.x; + + for (int m_grid = 0; m_grid < m; m_grid += m_stride) { + for (int n_grid = 0; n_grid < n; n_grid += n_stride) { + int midx = m_grid + m_offset; + int nidx = n_grid + n_offset; + + // Do a reverse linear search to determine the group index. + int group_idx = 0; + for (int i = num_groups; 0 <= i; --i) { + if (nidx < group_idxs[i]) { group_idx = i; } + } + const bool include_dist = adj[group_idx * m + midx] && midx < m && nidx < n; + + // Compute L2 metric. + DataT acc = DataT(0); + for (int i = 0; i < k; ++i) { + int xidx = i + midx * k; + int yidx = i + nidx * k; + auto diff = x[xidx] - y[yidx]; + acc += diff * diff; + } + if (Sqrt) { acc = raft::mySqrt(acc); } + ReduceOpT redOp; + typedef cub::WarpReduce> WarpReduce; + __shared__ typename WarpReduce::TempStorage temp[NWARPS]; + int warpId = threadIdx.x / raft::WarpSize; + cub::KeyValuePair tmp; + tmp.key = include_dist ? nidx : -1; + tmp.value = include_dist ? acc : maxVal; + tmp = WarpReduce(temp[warpId]).Reduce(tmp, CubKVPMinReduce()); + if (threadIdx.x % raft::WarpSize == 0 && midx < m) { + while (atomicCAS(workspace + midx, 0, 1) == 1) + ; + __threadfence(); + redOp(midx, min + midx, tmp); + __threadfence(); + atomicCAS(workspace + midx, 1, 0); + } + __syncthreads(); + } + } +} + +template +void naive(cub::KeyValuePair* min, + DataT* x, + DataT* y, + bool* adj, + int* group_idxs, + int m, + int n, + int k, + int num_groups, + int* workspace, + cudaStream_t stream) +{ + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + auto blks = raft::ceildiv(m, 256); + MinAndDistanceReduceOp op; + raft::distance::detail::initKernel, int> + <<>>(min, m, std::numeric_limits::max(), op); + RAFT_CUDA_TRY(cudaGetLastError()); + + const int nwarps = 16; + static const dim3 TPB(32, nwarps, 1); + dim3 nblks(1, 200, 1); + naiveKernel, nwarps><<>>( + min, x, y, adj, group_idxs, m, n, k, num_groups, workspace, std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +enum AdjacencyPattern { + checkerboard = 0, + checkerboard_4 = 1, + checkerboard_64 = 2, + all_true = 3, + all_false = 4 +}; + +template +struct Inputs { + DataT tolerance; + int m, n, k, num_groups; + unsigned long long int seed; + + AdjacencyPattern pattern; + + friend std::ostream& operator<<(std::ostream& os, const Inputs& p) + { + return os << "m: " << p.m + << ", " + "n: " + << p.n + << ", " + "k: " + << p.k + << ", " + "num_groups: " + << p.num_groups + << ", " + "seed: " + << p.seed + << ", " + "tol: " + << p.tolerance; + } +}; + +__global__ void init_adj( + int m, int n, int num_groups, AdjacencyPattern pattern, bool* adj, int* group_idxs) +{ + for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_groups; i += blockDim.y * gridDim.y) { + for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < m; j += blockDim.x * gridDim.x) { + switch (pattern) { + case checkerboard: adj[i * m + j] = (i + j) % 2; break; + case checkerboard_4: adj[i * m + j] = (i + (j / 4)) % 2; break; + case 
checkerboard_64: adj[i * m + j] = (i + (j / 64)) % 2; break; + case all_true: adj[i * m + j] = true; break; + case all_false: adj[i * m + j] = false; break; + default: assert(false && "unknown pattern"); + } + } + } + // Each group is of size n / num_groups. + // + // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive + // scan of the group lengths) + // + // - The first group always starts at index zero, so we do not store it. + // + // - The group_idxs[num_groups - 1] should always equal n. + + if (blockIdx.y == 0 && threadIdx.y == 0) { + const int j_stride = blockDim.x * gridDim.x; + for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_groups; j += j_stride) { + group_idxs[j] = (j + 1) * (n / num_groups); + } + group_idxs[num_groups - 1] = n; + } +} + +template +class SparseL2NNTest : public ::testing::TestWithParam> { + public: + SparseL2NNTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + x(params.m * params.k, stream), + y(params.n * params.k, stream), + adj(params.m * params.num_groups, stream), + group_idxs(params.num_groups, stream), + xn(params.m, stream), + yn(params.n, stream), + min(params.m, stream), + min_ref(params.m, stream), + workspace(params.m * sizeof(int), stream) + { + } + + protected: + void SetUp() override + { + raft::random::RngState r(params.seed); + int m = params.m; + int n = params.n; + int k = params.k; + int num_groups = params.num_groups; + uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0)); + uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0)); + + dim3 block(32, 32); + dim3 grid(10, 10); + init_adj<<>>( + m, n, num_groups, params.pattern, adj.data(), group_idxs.data()); + RAFT_CUDA_TRY(cudaGetLastError()); + + generateGoldenResult(); + raft::linalg::rowNorm(xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm(yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + protected: + Inputs params; + rmm::device_uvector x; + rmm::device_uvector y; + rmm::device_uvector adj; + rmm::device_uvector group_idxs; + rmm::device_uvector xn; + rmm::device_uvector yn; + rmm::device_uvector> min; + rmm::device_uvector> min_ref; + rmm::device_uvector workspace; + raft::handle_t handle; + cudaStream_t stream; + + virtual void generateGoldenResult() + { + int m = params.m; + int n = params.n; + int k = params.k; + int num_groups = params.num_groups; + + naive(min_ref.data(), + x.data(), + y.data(), + adj.data(), + group_idxs.data(), + m, + n, + k, + num_groups, + (int*)workspace.data(), + stream); + } + + void runTest(cub::KeyValuePair* out) + { + int m = params.m; + int n = params.n; + int k = params.k; + int num_groups = params.num_groups; + + MinAndDistanceReduceOp redOp; + sparseL2NN, int>( + out, + x.data(), + y.data(), + xn.data(), + yn.data(), + adj.data(), + group_idxs.data(), + num_groups, + m, + n, + k, + (void*)workspace.data(), + redOp, + raft::distance::KVPMinReduce(), + Sqrt, + true, + stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } +}; + +template +struct CompareApproxAbsKVP { + typedef typename cub::KeyValuePair KVP; + CompareApproxAbsKVP(T eps_) : eps(eps_) {} + bool operator()(const KVP& a, const KVP& b) const + { + T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); + T m = std::max(raft::abs(a.value), raft::abs(b.value)); + T ratio = m >= eps ? 
diff / m : diff; + return (ratio <= eps); + } + + private: + T eps; +}; + +template +struct CompareExactKVP { + typedef typename cub::KeyValuePair KVP; + bool operator()(const KVP& a, const KVP& b) const + { + if (a.value != b.value) return false; + return true; + } +}; + +template +::testing::AssertionResult devArrMatch(const cub::KeyValuePair* expected, + const cub::KeyValuePair* actual, + size_t size, + L eq_compare, + cudaStream_t stream = 0) +{ + typedef typename cub::KeyValuePair KVP; + std::shared_ptr exp_h(new KVP[size]); + std::shared_ptr act_h(new KVP[size]); + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < size; ++i) { + auto exp = exp_h.get()[i]; + auto act = act_h.get()[i]; + if (!eq_compare(exp, act)) { + return ::testing::AssertionFailure() + << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," + << exp.value << " @" << i; + } + } + return ::testing::AssertionSuccess(); +} + +const std::vector> inputsf = { + {0.001f, 32, 32, 32, 2, 1234ULL, AdjacencyPattern::all_true}, + {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_true}, + {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_false}, + {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard}, + {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_4}, + {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_64}, + {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_true}, + {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_false}, + {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard}, + {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_4}, + {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_64}, + {0.001f, (1 << 15) + 19, (1 << 9) + 17, 8, 32, 1234ULL, AdjacencyPattern::all_true}, + {0.001f, (1 << 15) + 19, (1 << 9) + 17, 8, 32, 1234ULL, AdjacencyPattern::all_false}, + {0.001f, (1 << 15) + 19, (1 << 9) + 17, 8, 32, 1234ULL, AdjacencyPattern::checkerboard}, +}; + +typedef SparseL2NNTest SparseL2NNTestF_Sq; +TEST_P(SparseL2NNTestF_Sq, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestF_Sq, ::testing::ValuesIn(inputsf)); +typedef SparseL2NNTest SparseL2NNTestF_Sqrt; +TEST_P(SparseL2NNTestF_Sqrt, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestF_Sqrt, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.00001, 32, 32, 32, 2, 1234ULL, AdjacencyPattern::all_true}, + + {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_true}, + {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_false}, + {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard}, + {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_4}, + {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_64}, + + {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_true}, + {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_false}, + {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard}, + {0.00001, 1 << 9, 1 << 16, 
8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_4}, + {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_64}, +}; +typedef SparseL2NNTest SparseL2NNTestD_Sq; +TEST_P(SparseL2NNTestD_Sq, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestD_Sq, ::testing::ValuesIn(inputsd)); +typedef SparseL2NNTest SparseL2NNTestD_Sqrt; +TEST_P(SparseL2NNTestD_Sqrt, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestD_Sqrt, ::testing::ValuesIn(inputsd)); + +/// This is to test output determinism of the prim +template +class SparseL2NNDetTest : public SparseL2NNTest { + public: + SparseL2NNDetTest() : stream(handle.get_stream()), min1(0, stream) {} + + void SetUp() override + { + SparseL2NNTest::SetUp(); + int m = this->params.m; + min1.resize(m, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void TearDown() override { SparseL2NNTest::TearDown(); } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + rmm::device_uvector> min1; + + static const int NumRepeats = 100; + + void generateGoldenResult() override {} +}; + +typedef SparseL2NNDetTest SparseL2NNDetTestF_Sq; +TEST_P(SparseL2NNDetTestF_Sq, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestF_Sq, ::testing::ValuesIn(inputsf)); +typedef SparseL2NNDetTest SparseL2NNDetTestF_Sqrt; +TEST_P(SparseL2NNDetTestF_Sqrt, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestF_Sqrt, ::testing::ValuesIn(inputsf)); + +typedef SparseL2NNDetTest SparseL2NNDetTestD_Sq; +TEST_P(SparseL2NNDetTestD_Sq, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestD_Sq, ::testing::ValuesIn(inputsd)); +typedef SparseL2NNDetTest SparseL2NNDetTestD_Sqrt; +TEST_P(SparseL2NNDetTestD_Sqrt, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestD_Sqrt, ::testing::ValuesIn(inputsd)); + +} // end namespace sparse_l2_nn +} // end namespace distance +} // end namespace raft
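
Usage note: the host-side call sequence exercised by the tests can be summarized in a short sketch. This is a minimal, illustrative example only, not part of the patch: it assumes the reduction functors MinAndDistanceReduceOp / KVPMinReduce and raft::linalg::rowNorm are reachable through the existing fused_l2_nn and norm headers (as in the test file), and the function name sparse_l2_nn_example is hypothetical.

#include <raft/core/handle.hpp>
#include <raft/distance/sparse_l2_nn.cuh>
#include <raft/distance/fused_l2_nn.cuh>  // assumed location of MinAndDistanceReduceOp / KVPMinReduce
#include <raft/linalg/norm.cuh>
#include <rmm/device_uvector.hpp>
#include <cub/cub.cuh>

void sparse_l2_nn_example(const raft::handle_t& handle,
                          const float* x,         // [m, k] row major, on device
                          const float* y,         // [n, k] row major, on device
                          const bool* adj,        // [num_groups, m]: adj[g * m + i] == true if row i of x
                                                  // must be compared to group g of y
                          const int* group_idxs,  // [num_groups]: exclusive end offset of each group in y;
                                                  // last entry equals n
                          int m, int n, int k, int num_groups)
{
  cudaStream_t stream = handle.get_stream();

  // Row norms of x and y are required because the kernel uses the expanded
  // L2 form ||x||^2 + ||y||^2 - 2 x.y (same call as in the test SetUp above).
  rmm::device_uvector<float> xn(m, stream), yn(n, stream);
  raft::linalg::rowNorm(xn.data(), x, k, m, raft::linalg::L2Norm, true, stream);
  raft::linalg::rowNorm(yn.data(), y, k, n, raft::linalg::L2Norm, true, stream);

  // Output: key/value pair (index, distance) of the nearest admissible neighbor per row of x.
  rmm::device_uvector<cub::KeyValuePair<int, float>> min(m, stream);
  // Workspace of m ints, used by the mutex-based final reduction.
  rmm::device_uvector<char> workspace(m * sizeof(int), stream);

  raft::distance::MinAndDistanceReduceOp<int, float> red_op;
  raft::distance::KVPMinReduce<int, float> pair_red_op;

  raft::distance::sparseL2NN<float, cub::KeyValuePair<int, float>, int>(
    min.data(), x, y, xn.data(), yn.data(), adj, group_idxs, num_groups,
    m, n, k, workspace.data(), red_op, pair_red_op,
    /* sqrt */ false, /* initOutBuffer */ true, stream);
}

The adj / group_idxs pair is what distinguishes this from fusedL2NN: only the (x-row, y-group) combinations marked adjacent are computed, and whole 64-row tiles of x are skipped when their compressed adjacency bitfield is zero.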