diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 1bc2c86243..b1ffc72ba9 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -82,7 +82,6 @@ if(BUILD_BENCH) bench/distance/distance_l1.cu bench/distance/distance_unexp_l2.cu bench/distance/fused_l2_nn.cu - bench/distance/masked_nn.cu bench/distance/kernels.cu bench/main.cpp OPTIONAL diff --git a/cpp/bench/distance/masked_nn.cu b/cpp/bench/distance/masked_nn.cu deleted file mode 100644 index 3677d44864..0000000000 --- a/cpp/bench/distance/masked_nn.cu +++ /dev/null @@ -1,267 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined RAFT_NN_COMPILED -#include -#endif - -namespace raft::bench::distance::masked_nn { - -// Introduce various sparsity patterns -enum AdjacencyPattern { - checkerboard = 0, - checkerboard_4 = 1, - checkerboard_64 = 2, - all_true = 3, - all_false = 4 -}; - -struct Params { - int m, n, k, num_groups; - AdjacencyPattern pattern; -}; // struct Params - -__global__ void init_adj(AdjacencyPattern pattern, - int n, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs) -{ - int m = adj.extent(0); - int num_groups = adj.extent(1); - - for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; - idx_m += blockDim.y * gridDim.y) { - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; - idx_g += blockDim.x * gridDim.x) { - switch (pattern) { - case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; - case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; - case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break; - case all_true: adj(idx_m, idx_g) = true; break; - case all_false: adj(idx_m, idx_g) = false; break; - default: assert(false && "unknown pattern"); - } - } - } - // Each group is of size n / num_groups. - // - // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive - // scan of the group lengths) - // - // - The first group always starts at index zero, so we do not store it. - // - // - The group_idxs[num_groups - 1] should always equal n. - - if (blockIdx.y == 0 && threadIdx.y == 0) { - const int g_stride = blockDim.x * gridDim.x; - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { - group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); - } - group_idxs(num_groups - 1) = n; - } -} - -template -struct masked_l2_nn : public fixture { - using DataT = T; - using IdxT = int; - using OutT = raft::KeyValuePair; - using RedOpT = raft::distance::MinAndDistanceReduceOp; - using PairRedOpT = raft::distance::KVPMinReduce; - using ParamT = raft::distance::MaskedL2NNParams; - - // Parameters - Params params; - // Data - raft::device_vector out; - raft::device_matrix x, y; - raft::device_vector xn, yn; - raft::device_matrix adj; - raft::device_vector group_idxs; - - masked_l2_nn(const Params& p) - : params(p), - out{raft::make_device_vector(handle, p.m)}, - x{raft::make_device_matrix(handle, p.m, p.k)}, - y{raft::make_device_matrix(handle, p.n, p.k)}, - xn{raft::make_device_vector(handle, p.m)}, - yn{raft::make_device_vector(handle, p.n)}, - adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, - group_idxs{raft::make_device_vector(handle, p.num_groups)} - { - raft::random::RngState r(123456ULL); - - uniform(handle, r, x.data_handle(), p.m * p.k, T(-1.0), T(1.0)); - uniform(handle, r, y.data_handle(), p.n * p.k, T(-1.0), T(1.0)); - raft::linalg::rowNorm( - xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream); - raft::linalg::rowNorm( - yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream); - raft::distance::initialize, int>( - handle, out.data_handle(), p.m, std::numeric_limits::max(), RedOpT{}); - - dim3 block(32, 32); - dim3 grid(10, 10); - init_adj<<>>(p.pattern, p.n, adj.view(), group_idxs.view()); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - void run_benchmark(::benchmark::State& state) override - { - bool init_out = true; - bool sqrt = false; - ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; - - loop_on_state(state, [this, masked_l2_params]() { - // It is sufficient to only benchmark the L2-squared metric - raft::distance::maskedL2NN(handle, - masked_l2_params, - x.view(), - y.view(), - xn.view(), - yn.view(), - adj.view(), - group_idxs.view(), - out.view()); - }); - - // Virtual flop count if no skipping had occurred. - size_t virtual_flops = size_t(2) * size_t(params.m) * size_t(params.n) * size_t(params.k); - - int64_t read_elts = params.n * params.k + params.m * params.k; - int64_t write_elts = params.m; - - // Virtual min flops is the number of flops that would have been executed if - // the algorithm had actually skipped each computation that it could have - // skipped. - size_t virtual_min_flops = 0; - switch (params.pattern) { - case checkerboard: - case checkerboard_4: - case checkerboard_64: virtual_min_flops = virtual_flops / 2; break; - case all_true: virtual_min_flops = virtual_flops; break; - case all_false: virtual_min_flops = 0; break; - default: assert(false && "unknown pattern"); - } - - // VFLOP/s is the "virtual" flop count that would have executed if there was - // no adjacency pattern. This is useful for comparing to fusedL2NN - state.counters["VFLOP/s"] = benchmark::Counter(virtual_flops, - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - // Virtual min flops is the number of flops that would have been executed if - // the algorithm had actually skipped each computation that it could have - // skipped. - state.counters["VminFLOP/s"] = benchmark::Counter(virtual_min_flops, - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - - state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(OutT), - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(DataT), - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - - state.counters["m"] = benchmark::Counter(params.m); - state.counters["n"] = benchmark::Counter(params.n); - state.counters["k"] = benchmark::Counter(params.k); - state.counters["num_groups"] = benchmark::Counter(params.num_groups); - state.counters["group size"] = benchmark::Counter(params.n / params.num_groups); - state.counters["Pat"] = benchmark::Counter(static_cast(params.pattern)); - - state.counters["SM count"] = raft::getMultiProcessorCount(); - } -}; // struct MaskedL2NN - -const std::vector masked_l2_nn_input_vecs = { - // Very fat matrices... - {32, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {64, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {128, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {256, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {512, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {1024, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 32, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 64, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 128, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 256, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 512, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 1024, 16384, 32, AdjacencyPattern::checkerboard}, - - // Representative matrices... - {16384, 16384, 32, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 64, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 128, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 256, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 512, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - - {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_4}, - - {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_64}, - - {16384, 16384, 32, 32, AdjacencyPattern::all_true}, - {16384, 16384, 64, 32, AdjacencyPattern::all_true}, - {16384, 16384, 128, 32, AdjacencyPattern::all_true}, - {16384, 16384, 256, 32, AdjacencyPattern::all_true}, - {16384, 16384, 512, 32, AdjacencyPattern::all_true}, - {16384, 16384, 1024, 32, AdjacencyPattern::all_true}, - {16384, 16384, 16384, 32, AdjacencyPattern::all_true}, - - {16384, 16384, 32, 32, AdjacencyPattern::all_false}, - {16384, 16384, 64, 32, AdjacencyPattern::all_false}, - {16384, 16384, 128, 32, AdjacencyPattern::all_false}, - {16384, 16384, 256, 32, AdjacencyPattern::all_false}, - {16384, 16384, 512, 32, AdjacencyPattern::all_false}, - {16384, 16384, 1024, 32, AdjacencyPattern::all_false}, - {16384, 16384, 16384, 32, AdjacencyPattern::all_false}, -}; - -RAFT_BENCH_REGISTER(masked_l2_nn, "", masked_l2_nn_input_vecs); -// We don't benchmark double to keep compile times in check when not using the -// distance library. - -} // namespace raft::bench::distance::masked_nn diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh deleted file mode 100644 index e36b7ce707..0000000000 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include - -namespace raft::distance::detail { - -/** - * @brief Compress 2D boolean matrix to bitfield - * - * Utility kernel for maskedL2NN. - * - * @tparam T - * - * @parameter[in] in An `m x n` boolean matrix. Row major. - * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of - * type T, where T is of size `bits_per_elem` bits. - * Note: the division (`/`) is a ceilDiv. - */ -template ::value>> -__global__ void compress_to_bits_kernel( - raft::device_matrix_view in, - raft::device_matrix_view out) -{ - constexpr int bits_per_element = 8 * sizeof(T); - constexpr int tile_dim_m = bits_per_element; - constexpr int nthreads = 128; - constexpr int tile_dim_n = nthreads; // read 128 bools at once = 1 sector - - // Tile in shared memory is transposed - __shared__ bool smem[tile_dim_n][tile_dim_m]; - - const int num_tiles_per_m = raft::ceildiv(in.extent(0), tile_dim_m); - const int num_tiles_per_n = raft::ceildiv(in.extent(1), tile_dim_n); - - for (int lin_tile_idx = blockIdx.x; true; lin_tile_idx += gridDim.x) { - const int tile_idx_n = tile_dim_n * (lin_tile_idx % num_tiles_per_n); - const int tile_idx_m = tile_dim_m * (lin_tile_idx / num_tiles_per_n); - - if (in.extent(0) <= tile_idx_m) { break; } - // Fill shared memory tile - bool reg_buf[tile_dim_m]; -#pragma unroll - for (int i = 0; i < tile_dim_m; ++i) { - const int in_m = tile_idx_m + i; - const int in_n = tile_idx_n + threadIdx.x; - bool in_bounds = in_m < in.extent(0) && in_n < in.extent(1); - reg_buf[i] = in_bounds ? in(in_m, in_n) : false; - smem[threadIdx.x][i] = reg_buf[i]; - } - __syncthreads(); - - // Drain memory tile into single output element out_elem. - T out_elem{0}; -#pragma unroll - for (int j = 0; j < tile_dim_n; ++j) { - if (smem[threadIdx.x][j]) { out_elem |= T(1) << j; } - } - __syncthreads(); - - // Write output. - int out_m = tile_idx_m / bits_per_element; - int out_n = tile_idx_n + threadIdx.x; - - if (out_m < out.extent(0) && out_n < out.extent(1)) { out(out_m, out_n) = out_elem; } - } -} - -/** - * @brief Compress 2D boolean matrix to bitfield - * - * Utility kernel for maskedL2NN. - * - * @tparam T - * - * @parameter[in] in An `m x n` boolean matrix. Row major. - * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of - * type T, where T is of size `bits_per_elem` bits. - * Note: the division (`/`) is a ceilDiv. - */ -template ::value>> -void compress_to_bits(raft::device_resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out) -{ - auto stream = handle.get_stream(); - constexpr int bits_per_element = 8 * sizeof(T); - - RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0), - "Number of output rows must be ceildiv(input rows, bits_per_elem)"); - RAFT_EXPECTS(in.extent(1) == out.extent(1), "Number of output columns must equal input columns."); - - const int num_SMs = raft::getMultiProcessorCount(); - int blocks_per_sm = 0; - constexpr int num_threads = 128; - constexpr int dyn_smem_size = 0; - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, compress_to_bits_kernel, num_threads, dyn_smem_size)); - - dim3 grid(num_SMs * blocks_per_sm); - dim3 block(128); - compress_to_bits_kernel<<>>(in, out); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 8fbd7a9c69..447359ffe6 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -37,7 +37,6 @@ template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } }; // KVPMinReduce diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh deleted file mode 100644 index 6d4e3f40a6..0000000000 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include - -#include - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief Device class for masked nearest neighbor computations. - * - * @tparam useNorms whether norms are needed - * @tparam DataT input data-type (for x and y matrices) - * @tparam AccT accumulation data-type - * @tparam IdxT index data-type - * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda tells how to accumulate an x and y into - acc. its signature: - template void core_lambda(AccT& acc, - const DataT& x, const DataT& y) - * @tparam EpilogueLambda applies an elementwise function to compute final - values. Its signature is: - template void epilogue_lambda - (AccT acc[][], DataT* regxn, DataT* regyn); - * @tparam FinalLambda the final lambda called on final distance value - * @tparam rowEpilogueLambda epilog lambda that executes when a full row has - * been processed. - * - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of x - * @param[in] n number of columns of y - * @param[in] k number of cols of x and y - * @param[in] lda leading dimension of x - * @param[in] ldb leading dimension of y - * @param[in] ldd parameter to keep Contractions_NT happy.. - * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine - * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine - * @param[in] adj An adjacency matrix encoded as a bitfield indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `(m / 64) x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[in] num_groups The number of groups in group_idxs. - * @param[in] smem shared mem buffer for intermediate storage of x, y, xn & yn. - * @param core_op the core accumulation operation lambda - * @param epilog_op the epilog operation lambda - * @param fin_op the final gemm epilogue lambda - * @param rowEpilog_op epilog lambda that executes when a full row has been processed. - */ -template > -struct MaskedDistances : public BaseClass { - private: - typedef Policy P; - const DataT* xn; - const DataT* yn; - const DataT* const yBase; - const uint64_t* adj; - const IdxT* group_idxs; - IdxT num_groups; - char* smem; - CoreLambda core_op; - EpilogueLambda epilog_op; - FinalLambda fin_op; - rowEpilogueLambda rowEpilog_op; - - AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; - - public: - // Constructor - DI MaskedDistances(const DataT* _x, - const DataT* _y, - IdxT _m, - IdxT _n, - IdxT _k, - IdxT _lda, - IdxT _ldb, - IdxT _ldd, - const DataT* _xn, - const DataT* _yn, - const uint64_t* _adj, - const IdxT* _group_idxs, - IdxT _num_groups, - char* _smem, - CoreLambda _core_op, - EpilogueLambda _epilog_op, - FinalLambda _fin_op, - rowEpilogueLambda _rowEpilog_op) - : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), - xn(_xn), - yn(_yn), - yBase(_y), - adj(_adj), - group_idxs(_group_idxs), - num_groups(_num_groups), - smem(_smem), - core_op(_core_op), - epilog_op(_epilog_op), - fin_op(_fin_op), - rowEpilog_op(_rowEpilog_op) - { - } - - DI void run() - { - const auto grid_stride_m = (P::Mblk * gridDim.y); - const auto grid_offset_m = (P::Mblk * blockIdx.y); - - const auto grid_stride_g = gridDim.x; - const auto grid_offset_g = blockIdx.x; - - for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { - // Start loop over groups - for (auto idx_g = grid_offset_g; idx_g < this->num_groups; idx_g += grid_stride_g) { - const uint64_t block_adj = get_block_adjacency(adj, tile_idx_m, idx_g); - // block_adj is a bitfield that contains a 1 if a row is adjacent to the - // current group. All zero means we can skip this group. - if (block_adj == 0) { continue; } - - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). That is, - // for i = 0,.., AccRowsPerTh and j = 0,.., AccColsPerTh: - // - // ((1 << i) & thread_adj) > 0 <=> acc[i][j] must be computed. - // - // We precompute this information because it is used in various - // locations to skip thread-local computations, specifically: - // - // 1. To skip computations if thread_adj == 0, i.e., none of the values - // of `acc` have to be computed. - // - // 2. In epilog_op, to consider only values of `acc` to be reduced that - // are not masked of. - // - // Note 1: Even when the computation can be skipped for a specific thread, - // the thread still participates in synchronization operations. - // - // Note 2: In theory, it should be possible to skip computations for - // specific rows of `acc`. In practice, however, this does not improve - // performance. - int thread_adj = compute_thread_adjacency(block_adj); - - auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; - const auto group_end_n = group_idxs[idx_g]; - for (; tile_idx_n < group_end_n; tile_idx_n += P::Nblk) { - // We provide group_end_n to limit the number of unnecessary data - // points that are loaded from y. - this->ldgXY(tile_idx_m, tile_idx_n, 0, group_end_n); - - reset_accumulator(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx, group_end_n); - // Process all data in shared memory (previous k-block) and - // accumulate in registers. - if (thread_adj != 0) { accumulate(); } - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); - } - if (thread_adj != 0) { - accumulate(); // last iteration - } - // The pre-condition for the loop over tile_idx_n is that write_buffer - // and read_buffer point to the same buffer. This flips read_buffer - // back so that it satisfies the pre-condition of this loop. - this->switch_read_buffer(); - - if (useNorms) { - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; - load_norms(tile_idx_m, tile_idx_n, group_end_n, regxn, regyn); - if (thread_adj != 0) { - epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, group_end_n); - } - } else { - if (thread_adj != 0) { - epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, group_end_n); - } - } - } // tile_idx_n - } // idx_g - rowEpilog_op(tile_idx_m); - } // tile_idx_m - } - - private: - DI uint64_t get_block_adjacency(const uint64_t* adj, IdxT tile_idx_m, IdxT idx_group) - { - // A single element of `adj` contains exactly enough bits to indicate which - // rows in the current tile to skip and which to compute. - static_assert(P::Mblk == 8 * sizeof(adj[0]), - "maskedL2NN only supports a policy with 64 rows per block."); - IdxT block_flag_idx = tile_idx_m / P::Mblk; - // Index into adj at row tile_idx_m / 64 and column idx_group. - return adj[block_flag_idx * this->num_groups + idx_group]; - } - - DI uint32_t compute_thread_adjacency(const uint64_t block_adj) - { - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). It is described in - // more detail in the run() method. - uint32_t thread_adj = 0; -#pragma unroll - for (int thread_row_idx = 0; thread_row_idx < P::AccRowsPerTh; ++thread_row_idx) { - // Index `thread_row_idx` refers to a row of the current threads' register - // tile `acc`, i.e., acc[i][:]. Index `block_row_idx` refers to the - // corresponding row of the current block tile in shared memory. - const int block_row_idx = this->accrowid + thread_row_idx * P::AccThRows; - - // block_row_is_adjacent is true if the current block_row_idx is adjacent - // to the current group. - const uint64_t block_mask = 1ull << block_row_idx; - const bool block_row_is_adjacent = (block_adj & block_mask) != 0; - if (block_row_is_adjacent) { - // If block row is adjacent, write a 1 bit to thread_adj at location - // `thread_row_idx`. - const uint32_t thread_mask = 1 << thread_row_idx; - thread_adj |= thread_mask; - } - } - return thread_adj; - } - - DI void reset_accumulator() - { - // Reset accumulator registers to zero. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; - } - } - } - - DI void accumulate() - { -#pragma unroll - for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { - this->ldsXY(ki); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { - core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); - } - } - } - } - } - - DI void load_norms(IdxT tile_idx_m, - IdxT tile_idx_n, - IdxT end_n, - DataT (®xn)[P::AccRowsPerTh], - DataT (®yn)[P::AccColsPerTh]) - { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = tile_idx_n + i; - syNorm[i] = idx < end_n ? yn[idx] : 0; - } - __syncthreads(); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } -#pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - } -}; // struct MaskedDistances - -}; // namespace detail -}; // namespace distance -}; // namespace raft diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh deleted file mode 100644 index 1c92de16fc..0000000000 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ /dev/null @@ -1,325 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace distance { -namespace detail { - -template -__global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const uint64_t* adj, - const IdxT* group_idxs, - IdxT num_groups, - IdxT m, - IdxT n, - IdxT k, - bool sqrt, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - CoreLambda core_op, - FinalLambda fin_op) -{ - extern __shared__ char smem[]; - - typedef raft::KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; - } - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [pairRedOp, &val, maxVal, sqrt] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - int thread_adj, - DataT* regxn, - DataT* regyn, - IdxT tile_idx_n, - IdxT tile_idx_m, - IdxT tile_end_n) { - KVPReduceOpT pairRed_op(pairRedOp); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(acc[i][j]); - } - } - } - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). It is described in - // more detail in the maskedDistances.run() method. - const bool ignore = (thread_adj & (1 << i)) == 0; - if (ignore) { continue; } -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + tile_idx_n; - if (tile_end_n <= tmpkey) { - // Do not process beyond end of tile. - continue; - } - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < tile_end_n) { - val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); - } - } - } - }; - - auto rowEpilog_lambda = - [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT tile_idx_m) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - auto tmpkey = raft::shfl(val[i].key, lid + j); - auto tmpvalue = raft::shfl(val[i].value, lid + j); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); - } - } - - updateReducedVal(mutex, min, val, red_op, m, tile_idx_m); - - // reset the val array. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; - } - }; - - IdxT lda = k, ldb = k, ldd = n; - MaskedDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - xn, - yn, - adj, - group_idxs, - num_groups, - smem, - core_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -} - -/** - * @brief Wrapper for maskedL2NNkernel - * - * Responsibilities: - * - Allocate (and initialize) workspace memory for: - * - mutexes used in nearest neighbor update step - * - adjacency matrix bitfield - * - Compress adjacency matrix to bitfield - * - Initialize output buffer (conditional on `initOutBuffer`) - * - Specify core and final operations for the L2 norm - * - Determine optimal launch configuration for kernel. - * - Launch kernel and check for errors. - * - * @tparam DataT Input data-type (for x and y matrices). - * @tparam OutT Output data-type (for key-value pairs). - * @tparam IdxT Index data-type. - * @tparam ReduceOpT A struct to perform the final needed reduction - * operation and also to initialize the output array - * elements with the appropriate initial value needed for - * reduction. - * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. - * - * @param handle RAFT handle for managing expensive resources - * @param[out] out Will contain reduced output (nn key-value pairs) - * @param[in] x First matrix. Row major. Dim = `m x k`. (on device) - * @param[in] y Second matrix. Row major. Dim = `n x k`. (on device) - * @param[in] xn L2 squared norm of `x`. Length = `m`. - * @param[in] yn L2 squared norm of `y`. Length = `n`. - * @param[in] adj A boolean adjacency matrix indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `m x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[in] num_groups Length of `group_idxs`. - * @param m Rows of `x`. - * @param n Rows of `y`. - * @param k Cols of `x` and `y`. - * @param redOp Reduction operator in the epilogue - * @param pairRedOp Reduction operation on key value pairs - * @param sqrt Whether to compute the squared or actual (i.e. sqrt) L2 norm. - * @param initOutBuffer Whether to initialize the output buffer - * - * - */ -template -void maskedL2NNImpl(raft::device_resources const& handle, - OutT* out, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const bool* adj, - const IdxT* group_idxs, - IdxT num_groups, - IdxT m, - IdxT n, - IdxT k, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer) -{ - typedef typename linalg::Policy4x4::Policy P; - - static_assert(P::Mblk == 64, "maskedL2NNImpl only supports a policy with 64 rows per block."); - - // Get stream and workspace memory resource - rmm::mr::device_memory_resource* ws_mr = - dynamic_cast(handle.get_workspace_resource()); - auto stream = handle.get_stream(); - - // Acquire temporary buffers and initialize to zero: - // 1) Adjacency matrix bitfield - // 2) Workspace for fused nearest neighbor operation - size_t m_div_64 = raft::ceildiv(m, IdxT(64)); - rmm::device_uvector ws_adj64{m_div_64 * num_groups, stream, ws_mr}; - rmm::device_uvector ws_fused_nn{size_t(m), stream, ws_mr}; - RAFT_CUDA_TRY(cudaMemsetAsync(ws_adj64.data(), 0, ws_adj64.size() * sizeof(uint64_t), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); - - // Compress boolean adjacency matrix to bitfield. - auto adj_view = raft::make_device_matrix_view(adj, m, num_groups); - auto adj64_view = - raft::make_device_matrix_view(ws_adj64.data(), m_div_64, num_groups); - compress_to_bits(handle, adj_view, adj64_view); - - // Initialize output buffer with keyvalue pairs as determined by the reduction - // operator (it will be called with maxVal). - constexpr auto maxVal = std::numeric_limits::max(); - if (initOutBuffer) { - dim3 grid(raft::ceildiv(m, P::Nthreads)); - dim3 block(P::Nthreads); - - initKernel<<>>(out, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - auto fin_op = raft::identity_op{}; - - auto kernel = maskedL2NNkernel; - constexpr size_t smemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - dim3 block(P::Nthreads); - dim3 grid = launchConfigGenerator

(m, n, smemSize, kernel); - - kernel<<>>(out, - x, - y, - xn, - yn, - ws_adj64.data(), - group_idxs, - num_groups, - m, - n, - k, - sqrt, - maxVal, - ws_fused_nn.data(), - redOp, - pairRedOp, - core_lambda, - fin_op); - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh deleted file mode 100644 index ea2e10a304..0000000000 --- a/cpp/include/raft/distance/masked_nn.cuh +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __MASKED_L2_NN_H -#define __MASKED_L2_NN_H - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace distance { -/** - * \defgroup masked_nn Masked 1-nearest neighbors - * @{ - */ - -/** - * @brief Parameter struct for maskedL2NN function - * - * @tparam ReduceOpT Type of reduction operator in the epilogue. - * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. - * - * Usage example: - * @code{.cpp} - * #include - * - * using IdxT = int; - * using DataT = float; - * using RedOpT = raft::distance::MinAndDistanceReduceOp; - * using PairRedOpT = raft::distance::KVPMinReduce; - * using ParamT = raft::distance::MaskedL2NNParams; - * - * bool init_out = true; - * bool sqrt = false; - * - * ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; - * @endcode - * - * Prescribes how to reduce a distance to an intermediate type (`redOp`), and - * how to reduce two intermediate types (`pairRedOp`). Typically, a distance is - * mapped to an (index, value) pair and (index, value) pair with the lowest - * value (distance) is selected. - * - * In addition, prescribes whether to compute the square root of the distance - * (`sqrt`) and whether to initialize the output buffer (`initOutBuffer`). - */ -template -struct MaskedL2NNParams { - /** Reduction operator in the epilogue */ - ReduceOpT redOp; - /** Reduction operation on key value pairs */ - KVPReduceOpT pairRedOp; - /** Whether the output `minDist` should contain L2-sqrt */ - bool sqrt; - /** Whether to initialize the output buffer before the main kernel launch */ - bool initOutBuffer; -}; - -/** - * @brief Masked L2 distance and 1-nearest-neighbor computation in a single call. - * - * This function enables faster computation of nearest neighbors if the - * computation of distances between certain point pairs can be skipped. - * - * We use an adjacency matrix that describes which distances to calculate. The - * points in `y` are divided into groups, and the adjacency matrix indicates - * whether to compute distances between points in `x` and groups in `y`. In other - * words, if `adj[i,k]` is true then distance between point `x_i`, and points in - * `group_k` will be calculated. - * - * **Performance considerations** - * - * The points in `x` are processed in tiles of `M` points (`M` is currently 64, - * but may change in the future). As a result, the largest compute time - * reduction occurs if all `M` points can skip a group. If only part of the `M` - * points can skip a group, then at most a minor compute time reduction and a - * modest energy use reduction can be expected. - * - * The points in `y` are also grouped into tiles of `N` points (`N` is currently - * 64, but may change in the future). As a result, group sizes should be larger - * than `N` to avoid wasting computational resources. If the group sizes are - * evenly divisible by `N`, then the computation is most efficient, although for - * larger group sizes this effect is minor. - * - * - * **Comparison to SDDM** - * - * [SDDMM](https://ieeexplore.ieee.org/document/8638042) (sampled dense-dense - * matrix multiplication) is a matrix-matrix multiplication where only part of - * the output is computed. Compared to maskedL2NN, there are a few differences: - * - * - The output of maskedL2NN is a single vector (of nearest neighbors) and not - * a sparse matrix. - * - * - The sampling in maskedL2NN is expressed through intermediate "groups" - rather than a CSR format. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * - * @param handle RAFT handle for managing expensive resources - * @param params Parameter struct specifying the reduction operations. - * @param[in] x First matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y Second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] x_norm L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] y_norm L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] adj A boolean adjacency matrix indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `m x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[out] out will contain the reduced output (Length = `m`) - * (on device) - */ -template -void maskedL2NN(raft::device_resources const& handle, - raft::distance::MaskedL2NNParams params, - raft::device_matrix_view x, - raft::device_matrix_view y, - raft::device_vector_view x_norm, - raft::device_vector_view y_norm, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs, - raft::device_vector_view out) -{ - IdxT m = x.extent(0); - IdxT n = y.extent(0); - IdxT k = x.extent(1); - IdxT num_groups = group_idxs.extent(0); - - // Match k dimension of x, y - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); - // Match x, x_norm and y, y_norm - RAFT_EXPECTS(m == x_norm.extent(0), "Length of `x_norm` must match input `x`."); - RAFT_EXPECTS(n == y_norm.extent(0), "Length of `y_norm` must match input `y` "); - // Match adj to x and group_idxs - RAFT_EXPECTS(m == adj.extent(0), "#rows in `adj` must match input `x`."); - RAFT_EXPECTS(num_groups == adj.extent(1), "#cols in `adj` must match length of `group_idxs`."); - // NOTE: We do not check if all indices in group_idxs actually points *inside* y. - - // If there is no work to be done, return immediately. - if (m == 0 || n == 0 || k == 0 || num_groups == 0) { return; } - - detail::maskedL2NNImpl(handle, - out.data_handle(), - x.data_handle(), - y.data_handle(), - x_norm.data_handle(), - y_norm.data_handle(), - adj.data_handle(), - group_idxs.data_handle(), - num_groups, - m, - n, - k, - params.redOp, - params.pairRedOp, - params.sqrt, - params.initOutBuffer); -} - -/** @} */ - -} // namespace distance -} // namespace raft - -#endif diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index b15cb222b4..f2d71117f7 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -151,12 +151,6 @@ struct Contractions_NT { ldgY(tile_idx_n, kidx); } - DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx, IdxT tile_end_n) - { - ldgX(tile_idx_m, kidx); - ldgY(tile_idx_n, kidx, tile_end_n); - } - /** * @brief Store current block of X/Y from registers to smem * @param[in] kidx current start index of k to be loaded @@ -218,15 +212,13 @@ struct Contractions_NT { } } - DI void ldgY(IdxT tile_idx_n, IdxT kidx) { ldgY(tile_idx_n, kidx, n); } - - DI void ldgY(IdxT tile_idx_n, IdxT kidx, IdxT end_n) + DI void ldgY(IdxT tile_idx_n, IdxT kidx) { IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; auto y = isRowMajor ? y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; if (isRowMajor) { - auto numRows = end_n; + auto numRows = n; auto koffset = kidx + scolid; #pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { @@ -244,7 +236,7 @@ struct Contractions_NT { auto koffset = scolid; #pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { - if ((koffset + yrowid) < end_n && (srowid + kidx + i * P::LdgRowsY) < numRows) { + if ((koffset + yrowid) < ldb && (srowid + kidx + i * P::LdgRowsY) < numRows) { ldg(ldgDataY[i], y + (kidx + i * P::LdgRowsY) * ldb + koffset); } else { #pragma unroll diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 2e89418f8e..3c41621274 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -123,8 +123,6 @@ if(BUILD_TESTS) test/distance/dist_minkowski.cu test/distance/dist_russell_rao.cu test/distance/fused_l2_nn.cu - test/distance/masked_nn.cu - test/distance/masked_nn_compress_to_bits.cu test/distance/gram.cu OPTIONAL DIST diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index af67214193..8b9681b9d3 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -385,7 +385,7 @@ class FusedL2NNDetTest : public FusedL2NNTest { rmm::device_uvector> min1; - static const int NumRepeats = 3; + static const int NumRepeats = 100; void generateGoldenResult() override {} }; diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu deleted file mode 100644 index c80c984992..0000000000 --- a/cpp/test/distance/masked_nn.cu +++ /dev/null @@ -1,435 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::distance::masked_nn { - -// The adjacency pattern determines what distances get computed. -enum AdjacencyPattern { - checkerboard = 0, // adjacency matrix looks like a checkerboard (half the distances are computed) - checkerboard_4 = 1, // checkerboard with tiles of size 4x4 - checkerboard_64 = 2, // checkerboard with tiles of size 64x64 - all_true = 3, // no distance computations can be skipped - all_false = 4 // all distance computations can be skipped -}; - -// Kernels: -// - init_adj: to initialize the adjacency kernel with a specific adjacency pattern -// - referenceKernel: to produce the ground-truth output - -__global__ void init_adj(AdjacencyPattern pattern, - int n, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs) -{ - int m = adj.extent(0); - int num_groups = adj.extent(1); - - for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; - idx_m += blockDim.y * gridDim.y) { - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; - idx_g += blockDim.x * gridDim.x) { - switch (pattern) { - case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; - case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; - case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break; - case all_true: adj(idx_m, idx_g) = true; break; - case all_false: adj(idx_m, idx_g) = false; break; - default: assert(false && "unknown pattern"); - } - } - } - // Each group is of size n / num_groups. - // - // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive - // scan of the group lengths) - // - // - The first group always starts at index zero, so we do not store it. - // - // - The group_idxs[num_groups - 1] should always equal n. - - if (blockIdx.y == 0 && threadIdx.y == 0) { - const int g_stride = blockDim.x * gridDim.x; - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { - group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); - } - group_idxs(num_groups - 1) = n; - } -} - -template -__global__ __launch_bounds__(32 * NWARPS, - 2) void referenceKernel(raft::KeyValuePair* min, - DataT* x, - DataT* y, - bool* adj, - int* group_idxs, - int m, - int n, - int k, - int num_groups, - bool sqrt, - int* workspace, - DataT maxVal) -{ - const int m_stride = blockDim.y * gridDim.y; - const int m_offset = threadIdx.y + blockIdx.y * blockDim.y; - const int n_stride = blockDim.x * gridDim.x; - const int n_offset = threadIdx.x + blockIdx.x * blockDim.x; - - for (int m_grid = 0; m_grid < m; m_grid += m_stride) { - for (int n_grid = 0; n_grid < n; n_grid += n_stride) { - int midx = m_grid + m_offset; - int nidx = n_grid + n_offset; - - // Do a reverse linear search to determine the group index. - int group_idx = 0; - for (int i = num_groups; 0 <= i; --i) { - if (nidx < group_idxs[i]) { group_idx = i; } - } - const bool include_dist = adj[midx * num_groups + group_idx] && midx < m && nidx < n; - - // Compute L2 metric. - DataT acc = DataT(0); - for (int i = 0; i < k; ++i) { - int xidx = i + midx * k; - int yidx = i + nidx * k; - auto diff = x[xidx] - y[yidx]; - acc += diff * diff; - } - if (sqrt) { acc = raft::sqrt(acc); } - ReduceOpT redOp; - typedef cub::WarpReduce> WarpReduce; - __shared__ typename WarpReduce::TempStorage temp[NWARPS]; - int warpId = threadIdx.x / raft::WarpSize; - raft::KeyValuePair tmp; - tmp.key = include_dist ? nidx : -1; - tmp.value = include_dist ? acc : maxVal; - tmp = WarpReduce(temp[warpId]).Reduce(tmp, raft::distance::KVPMinReduce{}); - if (threadIdx.x % raft::WarpSize == 0 && midx < m) { - while (atomicCAS(workspace + midx, 0, 1) == 1) - ; - __threadfence(); - redOp(midx, min + midx, tmp); - __threadfence(); - atomicCAS(workspace + midx, 1, 0); - } - __syncthreads(); - } - } -} - -// Structs -// - Params: holds parameters for test case -// - Inputs: holds the inputs to the functions under test (x, y, adj, group_idxs). Is generated from -// the inputs. -struct Params { - double tolerance; - int m, n, k, num_groups; - bool sqrt; - unsigned long long int seed; - AdjacencyPattern pattern; -}; - -inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& -{ - os << "m: " << p.m << ", n: " << p.n << ", k: " << p.k << ", num_groups: " << p.num_groups - << ", sqrt: " << p.sqrt << ", seed: " << p.seed << ", tol: " << p.tolerance; - return os; -} - -template -struct Inputs { - using IdxT = int; - - raft::device_matrix x, y; - raft::device_matrix adj; - raft::device_vector group_idxs; - - Inputs(const raft::handle_t& handle, const Params& p) - : x{raft::make_device_matrix(handle, p.m, p.k)}, - y{raft::make_device_matrix(handle, p.n, p.k)}, - adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, - group_idxs{raft::make_device_vector(handle, p.num_groups)} - { - // Initialize x, y - raft::random::RngState r(p.seed); - uniform(handle, r, x.data_handle(), p.m * p.k, DataT(-1.0), DataT(1.0)); - uniform(handle, r, y.data_handle(), p.n * p.k, DataT(-1.0), DataT(1.0)); - - // Initialize adj, group_idxs. - dim3 block(32, 32); - dim3 grid(10, 10); - init_adj<<>>( - p.pattern, p.n, adj.view(), group_idxs.view()); - RAFT_CUDA_TRY(cudaGetLastError()); - } -}; - -template > -auto reference(const raft::handle_t& handle, Inputs inp, const Params& p) - -> raft::device_vector -{ - int m = inp.x.extent(0); - int n = inp.y.extent(0); - int k = inp.x.extent(1); - int num_groups = inp.group_idxs.extent(0); - - if (m == 0 || n == 0 || k == 0 || num_groups == 0) { - return raft::make_device_vector(handle, 0); - } - - // Initialize workspace - auto stream = handle.get_stream(); - rmm::device_uvector workspace(p.m * sizeof(int), stream); - RAFT_CUDA_TRY(cudaMemsetAsync(workspace.data(), 0, sizeof(int) * m, stream)); - - // Initialize output - auto out = raft::make_device_vector(handle, m); - auto blks = raft::ceildiv(m, 256); - MinAndDistanceReduceOp op; - raft::distance::detail::initKernel, int> - <<>>(out.data_handle(), m, std::numeric_limits::max(), op); - RAFT_CUDA_TRY(cudaGetLastError()); - - // Launch reference kernel - const int nwarps = 16; - static const dim3 TPB(32, nwarps, 1); - dim3 nblks(1, 200, 1); - referenceKernel - <<>>(out.data_handle(), - inp.x.data_handle(), - inp.y.data_handle(), - inp.adj.data_handle(), - inp.group_idxs.data_handle(), - m, - n, - k, - num_groups, - p.sqrt, - (int*)workspace.data(), - std::numeric_limits::max()); - RAFT_CUDA_TRY(cudaGetLastError()); - - return out; -} - -template > -auto run_masked_nn(const raft::handle_t& handle, Inputs inp, const Params& p) - -> raft::device_vector -{ - // Compute norms: - auto x_norm = raft::make_device_vector(handle, p.m); - auto y_norm = raft::make_device_vector(handle, p.n); - - raft::linalg::norm(handle, - std::as_const(inp.x).view(), - x_norm.view(), - raft::linalg::L2Norm, - raft::linalg::Apply::ALONG_ROWS); - raft::linalg::norm(handle, - std::as_const(inp.y).view(), - y_norm.view(), - raft::linalg::L2Norm, - raft::linalg::Apply::ALONG_ROWS); - - // Create parameters for maskedL2NN - using IdxT = int; - using RedOpT = MinAndDistanceReduceOp; - using PairRedOpT = raft::distance::KVPMinReduce; - using ParamT = raft::distance::MaskedL2NNParams; - - bool init_out = true; - ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, p.sqrt, init_out}; - - // Create output - auto out = raft::make_device_vector(handle, p.m); - - // Launch kernel - raft::distance::maskedL2NN(handle, - masked_l2_params, - inp.x.view(), - inp.y.view(), - x_norm.view(), - y_norm.view(), - inp.adj.view(), - inp.group_idxs.view(), - out.view()); - - handle.sync_stream(); - - return out; -} - -template -struct CompareApproxAbsKVP { - typedef typename raft::KeyValuePair KVP; - CompareApproxAbsKVP(T eps_) : eps(eps_) {} - bool operator()(const KVP& a, const KVP& b) const - { - T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); - T m = std::max(raft::abs(a.value), raft::abs(b.value)); - T ratio = m >= eps ? diff / m : diff; - return (ratio <= eps); - } - - private: - T eps; -}; - -template -::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, - const raft::KeyValuePair* actual, - size_t size, - L eq_compare, - cudaStream_t stream = 0) -{ - typedef typename raft::KeyValuePair KVP; - std::shared_ptr exp_h(new KVP[size]); - std::shared_ptr act_h(new KVP[size]); - raft::update_host(exp_h.get(), expected, size, stream); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < size; ++i) { - auto exp = exp_h.get()[i]; - auto act = act_h.get()[i]; - if (!eq_compare(exp, act)) { - return ::testing::AssertionFailure() - << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," - << exp.value << " @" << i; - } - } - return ::testing::AssertionSuccess(); -} - -inline auto gen_params() -> std::vector -{ - // Regular powers of two - auto regular = raft::util::itertools::product({0.001f}, // tolerance - {32, 64, 512}, // m - {32, 64, 512}, // n - {8, 32}, // k - {2, 32}, // num_groups - {true, false}, // sqrt - {1234ULL}, // seed - {AdjacencyPattern::all_true, - AdjacencyPattern::checkerboard, - AdjacencyPattern::checkerboard_64, - AdjacencyPattern::all_false}); - - // Irregular sizes to check tiling and bounds checking - auto irregular = raft::util::itertools::product({0.001f}, // tolerance - {511, 512, 513}, // m - {127, 128, 129}, // n - {5}, // k - {3, 9}, // num_groups - {true, false}, // sqrt - {1234ULL}, // seed - {AdjacencyPattern::all_true, - AdjacencyPattern::checkerboard, - AdjacencyPattern::checkerboard_64}); - - regular.insert(regular.end(), irregular.begin(), irregular.end()); - - return regular; -} - -class MaskedL2NNTest : public ::testing::TestWithParam { - // Empty. -}; - -// -TEST_P(MaskedL2NNTest, ReferenceCheckFloat) -{ - using DataT = float; - - // Get parameters; create handle and input data. - Params p = GetParam(); - raft::handle_t handle{}; - Inputs inputs{handle, p}; - - // Calculate reference and test output - auto out_reference = reference(handle, inputs, p); - auto out_fast = run_masked_nn(handle, inputs, p); - - // Check for differences. - ASSERT_TRUE(devArrMatch(out_reference.data_handle(), - out_fast.data_handle(), - p.m, - CompareApproxAbsKVP(p.tolerance), - handle.get_stream())); -} - -// This test checks whether running the maskedL2NN twice returns the same -// output. -TEST_P(MaskedL2NNTest, DeterminismCheck) -{ - using DataT = float; - - // Get parameters; create handle and input data. - Params p = GetParam(); - raft::handle_t handle{}; - Inputs inputs{handle, p}; - - // Calculate reference and test output - auto out1 = run_masked_nn(handle, inputs, p); - auto out2 = run_masked_nn(handle, inputs, p); - - // Check for differences. - ASSERT_TRUE(devArrMatch(out1.data_handle(), - out2.data_handle(), - p.m, - CompareApproxAbsKVP(p.tolerance), - handle.get_stream())); -} - -TEST_P(MaskedL2NNTest, ReferenceCheckDouble) -{ - using DataT = double; - - // Get parameters; create handle and input data. - Params p = GetParam(); - raft::handle_t handle{}; - Inputs inputs{handle, p}; - - // Calculate reference and test output - auto out_reference = reference(handle, inputs, p); - auto out_fast = run_masked_nn(handle, inputs, p); - - // Check for differences. - ASSERT_TRUE(devArrMatch(out_reference.data_handle(), - out_fast.data_handle(), - p.m, - CompareApproxAbsKVP(p.tolerance), - handle.get_stream())); -} - -INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTest, ::testing::ValuesIn(gen_params())); - -} // end namespace raft::distance::masked_nn diff --git a/cpp/test/distance/masked_nn_compress_to_bits.cu b/cpp/test/distance/masked_nn_compress_to_bits.cu deleted file mode 100644 index 7597362274..0000000000 --- a/cpp/test/distance/masked_nn_compress_to_bits.cu +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "../test_utils.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::distance::masked_nn::compress_to_bits { - -/** - * @brief Transpose and decompress 2D bitfield to boolean matrix - * - * Inverse operation of compress_to_bits - * - * @tparam T - * - * @parameter[in] in An `m x n` bitfield matrix. Row major. - * @parameter in_rows The number of rows of `in`, i.e. `m`. - * @parameter in_cols The number of cols of `in`, i.e. `n`. - * - * @parameter[out] out An `(m * bits_per_elem) x n` boolean matrix. - */ -template ::value>> -__global__ void decompress_bits_kernel(const T* in, int in_rows, int in_cols, bool* out) -{ - constexpr int bits_per_element = 8 * sizeof(T); - - const size_t i = threadIdx.y + blockIdx.y * blockDim.y; - const size_t j = threadIdx.x + blockIdx.x * blockDim.x; - - if (in_rows <= i || in_cols <= j) { return; } - - const size_t out_rows = in_rows * bits_per_element; - const size_t out_cols = in_cols; - const size_t out_i = i * bits_per_element; - const size_t out_j = j; - - if (out_rows <= out_i && out_cols <= out_j) { return; } - - T bitfield = in[i * in_cols + j]; - for (int bitpos = 0; bitpos < bits_per_element; ++bitpos) { - bool bit = ((T(1) << bitpos) & bitfield) != 0; - out[(out_i + bitpos) * out_cols + out_j] = bit; - } -} - -/** - * @brief Transpose and decompress 2D bitfield to boolean matrix - * - * Inverse operation of compress_to_bits - * - * @tparam T - * - * @parameter[in] in An `m x n` bitfield matrix. Row major. - * @parameter in_rows The number of rows of `in`, i.e. `m`. - * @parameter in_cols The number of cols of `in`, i.e. `n`. - * - * @parameter[out] out An `n x (m * bits_per_elem)` boolean matrix. - */ -template ::value>> -void decompress_bits(const raft::handle_t& handle, const T* in, int in_rows, int in_cols, bool* out) -{ - auto stream = handle.get_stream(); - dim3 grid(raft::ceildiv(in_cols, 32), raft::ceildiv(in_rows, 32)); - dim3 block(32, 32); - decompress_bits_kernel<<>>(in, in_rows, in_cols, out); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -// Params holds parameters for test case -struct Params { - int m, n; -}; - -inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& -{ - return os << "m: " << p.m << ", n: " << p.n; -} - -// Check that the following holds -// -// decompress(compress(x)) == x -// -// for 2D boolean matrices x. -template -void check_invertible(const Params& p) -{ - using raft::distance::detail::compress_to_bits; - constexpr int bits_per_elem = sizeof(T) * 8; - - // Make m and n that are safe to ceildiv. - int m = raft::round_up_safe(p.m, bits_per_elem); - int n = p.n; - - // Generate random input - raft::handle_t handle{}; - raft::random::RngState r(1ULL); - auto in = raft::make_device_matrix(handle, m, n); - raft::random::bernoulli(handle, r, in.data_handle(), m * n, 0.5f); - - int tmp_m = raft::ceildiv(m, bits_per_elem); - int out_m = tmp_m * bits_per_elem; - - auto tmp = raft::make_device_matrix(handle, tmp_m, n); - auto out = raft::make_device_matrix(handle, out_m, n); - - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); - - ASSERT_EQ(in.extent(0), out.extent(0)) << "M does not match"; - ASSERT_EQ(in.extent(1), out.extent(1)) << "N does not match"; - - compress_to_bits(handle, in.view(), tmp.view()); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); - - decompress_bits(handle, tmp.data_handle(), tmp.extent(0), tmp.extent(1), out.data_handle()); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); - - // Check for differences. - ASSERT_TRUE(raft::devArrMatch(in.data_handle(), - out.data_handle(), - in.extent(0) * in.extent(1), - raft::Compare(), - handle.get_stream())); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -void check_all_true(const Params& p) -{ - using raft::distance::detail::compress_to_bits; - using T = uint64_t; - constexpr int bits_per_elem = sizeof(T) * 8; - - // Make m and n that are safe to ceildiv. - int m = raft::round_up_safe(p.m, bits_per_elem); - int n = p.n; - - raft::handle_t handle{}; - raft::random::RngState r(1ULL); - auto in = raft::make_device_matrix(handle, m, n); - raft::matrix::fill(handle, in.view(), true); - - int tmp_m = raft::ceildiv(m, bits_per_elem); - auto tmp = raft::make_device_matrix(handle, tmp_m, n); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); - - compress_to_bits(handle, in.view(), tmp.view()); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); - - auto expected = raft::make_device_matrix(handle, tmp_m, n); - raft::matrix::fill(handle, expected.view(), ~T(0)); - - // Check for differences. - ASSERT_TRUE(raft::devArrMatch(expected.data_handle(), - tmp.data_handle(), - tmp.extent(0) * tmp.extent(1), - raft::Compare(), - handle.get_stream())); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -class CompressToBitsTest : public ::testing::TestWithParam { - // Empty. -}; - -TEST_P(CompressToBitsTest, CheckTrue64) { check_all_true(GetParam()); } - -TEST_P(CompressToBitsTest, CheckInvertible64) -{ - using T = uint64_t; - check_invertible(GetParam()); -} - -TEST_P(CompressToBitsTest, CheckInvertible32) -{ - using T = uint32_t; - check_invertible(GetParam()); -} - -std::vector params = raft::util::itertools::product( - {1, 3, 32, 33, 63, 64, 65, 128, 10013}, {1, 3, 32, 33, 63, 64, 65, 13001}); - -INSTANTIATE_TEST_CASE_P(CompressToBits, CompressToBitsTest, ::testing::ValuesIn(params)); - -} // namespace raft::distance::masked_nn::compress_to_bits diff --git a/docs/source/cpp_api/distance.rst b/docs/source/cpp_api/distance.rst index 1632f19fba..eb9bc6255d 100644 --- a/docs/source/cpp_api/distance.rst +++ b/docs/source/cpp_api/distance.rst @@ -25,4 +25,3 @@ namespace *raft::distance* distance_pairwise.rst distance_1nn.rst - distance_masked_nn.rst diff --git a/docs/source/cpp_api/distance_masked_nn.rst b/docs/source/cpp_api/distance_masked_nn.rst deleted file mode 100644 index 89e23ba98a..0000000000 --- a/docs/source/cpp_api/distance_masked_nn.rst +++ /dev/null @@ -1,16 +0,0 @@ -Masked 1-Nearest Neighbors -========================== - -.. role:: py(code) - :language: c++ - :class: highlight - -``#include `` - -namespace *raft::distance* - -.. doxygengroup:: masked_nn - :project: RAFT - :members: - :content-only: -