diff --git a/ci/cpu/build.sh b/ci/cpu/build.sh index 657126fdf0..5bb09520a8 100755 --- a/ci/cpu/build.sh +++ b/ci/cpu/build.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. ######################################### # RAFT CPU conda build script for CI # ######################################### diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 84026203fa..30f026734a 100644 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -1,7 +1,7 @@ #!/bin/bash -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. ######################################### # RAFT GPU build and test script for CI # ######################################### diff --git a/cpp/bench/distance/distance_common.cuh b/cpp/bench/distance/distance_common.cuh index 7ddecd7579..7d51bfa608 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -24,14 +24,14 @@ namespace raft::bench::distance { -struct distance_inputs { +struct distance_params { int m, n, k; bool isRowMajor; -}; // struct distance_inputs +}; // struct distance_params template struct distance : public fixture { - distance(const distance_inputs& p) + distance(const distance_params& p) : params(p), x(p.m * p.k, stream), y(p.n * p.k, stream), @@ -63,13 +63,13 @@ struct distance : public fixture { } private: - distance_inputs params; + distance_params params; rmm::device_uvector x, y, out; rmm::device_uvector workspace; size_t worksize; }; // struct Distance -const std::vector dist_input_vecs{ +const std::vector dist_input_vecs{ {32, 16384, 16384, true}, {64, 16384, 16384, true}, {128, 16384, 16384, true}, {256, 16384, 16384, true}, {512, 16384, 16384, true}, {1024, 16384, 16384, true}, {16384, 32, 16384, true}, {16384, 64, 16384, true}, {16384, 128, 16384, true}, diff --git a/cpp/bench/distance/masked_nn.cu b/cpp/bench/distance/masked_nn.cu new file mode 100644 index 0000000000..a7b941d091 --- 
/dev/null +++ b/cpp/bench/distance/masked_nn.cu @@ -0,0 +1,267 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined RAFT_NN_COMPILED +#include +#endif + +namespace raft::bench::distance::masked_nn { + +// Introduce various sparsity patterns +enum AdjacencyPattern { + checkerboard = 0, + checkerboard_4 = 1, + checkerboard_64 = 2, + all_true = 3, + all_false = 4 +}; + +struct Params { + int m, n, k, num_groups; + AdjacencyPattern pattern; +}; // struct Params + +// Fills `adj` (adj.extent(0) x adj.extent(1) = m x num_groups) with the +// requested pattern and writes the end offset of every group into +// `group_idxs`, for `n` points split into num_groups groups. +__global__ void init_adj(AdjacencyPattern pattern, + int n, + raft::device_matrix_view adj, + raft::device_vector_view group_idxs) +{ + int m = adj.extent(0); + int num_groups = adj.extent(1); + + for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; + idx_m += blockDim.y * gridDim.y) { + for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; + idx_g += blockDim.x * gridDim.x) { + switch (pattern) { + case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; + case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; + case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break; + case all_true: adj(idx_m, idx_g) = true; break; + case all_false: adj(idx_m, idx_g) = false; break; + default: assert(false && "unknown pattern"); 
+ } + } + } + // Each group is of size n / num_groups. + // + // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive + // scan of the group lengths) + // + // - The first group always starts at index zero, so we do not store it. + // + // - The group_idxs[num_groups - 1] should always equal n. + + if (blockIdx.y == 0 && threadIdx.y == 0) { + const int g_stride = blockDim.x * gridDim.x; + for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { + // The last group ends at n (it absorbs the remainder of n / num_groups). + // Handling it inside the loop means every element is written exactly + // once, by a single thread, and nothing is written when num_groups == 0. + const bool is_last_group = idx_g == num_groups - 1; + group_idxs(idx_g) = is_last_group ? n : (idx_g + 1) * (n / num_groups); + } + } +} + +template +struct masked_l2_nn : public fixture { + using DataT = T; + using IdxT = int; + using OutT = raft::KeyValuePair; + using RedOpT = raft::distance::MinAndDistanceReduceOp; + using PairRedOpT = raft::distance::KVPMinReduce; + using ParamT = raft::distance::MaskedL2NNParams; + + // Parameters + Params params; + // Data + raft::device_vector out; + raft::device_matrix x, y; + raft::device_vector xn, yn; + raft::device_matrix adj; + raft::device_vector group_idxs; + + masked_l2_nn(const Params& p) + : params(p), + out{raft::make_device_vector(handle, p.m)}, + x{raft::make_device_matrix(handle, p.m, p.k)}, + y{raft::make_device_matrix(handle, p.n, p.k)}, + xn{raft::make_device_vector(handle, p.m)}, + yn{raft::make_device_vector(handle, p.n)}, + adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, + group_idxs{raft::make_device_vector(handle, p.num_groups)} + { + raft::random::RngState r(123456ULL); + + uniform(handle, r, x.data_handle(), p.m * p.k, T(-1.0), T(1.0)); + uniform(handle, r, y.data_handle(), p.n * p.k, T(-1.0), T(1.0)); + raft::linalg::rowNorm( + xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm( + yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream); + raft::distance::initialize, int>( + handle, out.data_handle(), p.m, std::numeric_limits::max(), RedOpT{}); + 
dim3 block(32, 32); + dim3 grid(10, 10); + init_adj<<>>(p.pattern, p.n, adj.view(), group_idxs.view()); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + // Times maskedL2NN on the inputs prepared by the constructor and reports + // throughput counters and the problem parameters. + void run_benchmark(::benchmark::State& state) override + { + bool init_out = true; + bool sqrt = false; + ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; + + loop_on_state(state, [this, masked_l2_params]() { + // It is sufficient to only benchmark the L2-squared metric + raft::distance::maskedL2NN(handle, + masked_l2_params, + x.view(), + y.view(), + xn.view(), + yn.view(), + adj.view(), + group_idxs.view(), + out.view()); + }); + + // Virtual flop count if no skipping had occurred. + size_t virtual_flops = size_t(2) * size_t(params.m) * size_t(params.n) * size_t(params.k); + + // Widen to 64 bit before multiplying so that the element counts cannot + // overflow int for large problem sizes. + int64_t read_elts = int64_t(params.n) * params.k + int64_t(params.m) * params.k; + int64_t write_elts = params.m; + + // Virtual min flops is the number of flops that would have been executed if + // the algorithm had actually skipped each computation that it could have + // skipped. + size_t virtual_min_flops = 0; + switch (params.pattern) { + case checkerboard: + case checkerboard_4: + case checkerboard_64: virtual_min_flops = virtual_flops / 2; break; + case all_true: virtual_min_flops = virtual_flops; break; + case all_false: virtual_min_flops = 0; break; + default: assert(false && "unknown pattern"); + } + + // VFLOP/s is the "virtual" flop count that would have executed if there was + // no adjacency pattern. This is useful for comparing to fusedL2NN + state.counters["VFLOP/s"] = benchmark::Counter(virtual_flops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + // Virtual min flops is the number of flops that would have been executed if + // the algorithm had actually skipped each computation that it could have + // skipped. 
+ state.counters["VminFLOP/s"] = benchmark::Counter(virtual_min_flops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + + state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(OutT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(DataT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + + state.counters["m"] = benchmark::Counter(params.m); + state.counters["n"] = benchmark::Counter(params.n); + state.counters["k"] = benchmark::Counter(params.k); + state.counters["num_groups"] = benchmark::Counter(params.num_groups); + state.counters["group size"] = benchmark::Counter(params.n / params.num_groups); + state.counters["Pat"] = benchmark::Counter(static_cast(params.pattern)); + + state.counters["SM count"] = raft::getMultiProcessorCount(); + } +}; // struct masked_l2_nn + +// Each entry is {m, n, k, num_groups, pattern}; see struct Params. +const std::vector masked_l2_nn_input_vecs = { + // Very fat matrices... + {32, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {64, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {128, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {256, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {512, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {1024, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 32, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 64, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 128, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 256, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 512, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 1024, 16384, 32, AdjacencyPattern::checkerboard}, + + // Representative matrices... 
+ {16384, 16384, 32, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 64, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 128, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 256, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 512, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + + {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_4}, + + {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_64}, + + {16384, 16384, 32, 32, AdjacencyPattern::all_true}, + {16384, 16384, 64, 32, AdjacencyPattern::all_true}, + {16384, 16384, 128, 32, AdjacencyPattern::all_true}, + {16384, 16384, 256, 32, AdjacencyPattern::all_true}, + {16384, 16384, 512, 32, AdjacencyPattern::all_true}, + {16384, 16384, 1024, 32, AdjacencyPattern::all_true}, + {16384, 16384, 16384, 32, AdjacencyPattern::all_true}, + + {16384, 16384, 32, 32, AdjacencyPattern::all_false}, + {16384, 16384, 64, 32, AdjacencyPattern::all_false}, + {16384, 16384, 128, 32, AdjacencyPattern::all_false}, + {16384, 16384, 256, 32, AdjacencyPattern::all_false}, + {16384, 16384, 512, 32, AdjacencyPattern::all_false}, + {16384, 16384, 1024, 32, 
AdjacencyPattern::all_false}, + {16384, 16384, 16384, 32, AdjacencyPattern::all_false}, +}; + +RAFT_BENCH_REGISTER(masked_l2_nn, "", masked_l2_nn_input_vecs); +// We don't benchmark double to keep compile times in check when not using the +// distance library. + +} // namespace raft::bench::distance::masked_nn \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh new file mode 100644 index 0000000000..d5bf1f421d --- /dev/null +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::distance::detail { + +/** + * @brief Compress 2D boolean matrix to bitfield + * + * Utility kernel for maskedL2NN. + * + * @tparam T Element type of the output bitfield + * (`bits_per_elem = 8 * sizeof(T)`). + * + * @param[in] in An `m x n` boolean matrix. Row major. + * @param[out] out An `(m / bits_per_elem) x n` matrix with elements of + * type T, where T is of size `bits_per_elem` bits. + * Note: the division (`/`) is a ceilDiv. 
+ */ +template ::value>> +__global__ void compress_to_bits_kernel( + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + constexpr int bits_per_element = 8 * sizeof(T); + constexpr int tile_dim_m = bits_per_element; + constexpr int nthreads = 128; + constexpr int tile_dim_n = nthreads; // read 128 bools at once = 1 sector + + // Tile in shared memory is transposed + __shared__ bool smem[tile_dim_n][tile_dim_m]; + + const int num_tiles_per_m = raft::ceildiv(in.extent(0), tile_dim_m); + const int num_tiles_per_n = raft::ceildiv(in.extent(1), tile_dim_n); + + for (int lin_tile_idx = blockIdx.x; true; lin_tile_idx += gridDim.x) { + const int tile_idx_n = tile_dim_n * (lin_tile_idx % num_tiles_per_n); + const int tile_idx_m = tile_dim_m * (lin_tile_idx / num_tiles_per_n); + + if (in.extent(0) <= tile_idx_m) { break; } + // Fill shared memory tile + bool reg_buf[tile_dim_m]; +#pragma unroll + for (int i = 0; i < tile_dim_m; ++i) { + const int in_m = tile_idx_m + i; + const int in_n = tile_idx_n + threadIdx.x; + bool in_bounds = in_m < in.extent(0) && in_n < in.extent(1); + reg_buf[i] = in_bounds ? in(in_m, in_n) : false; + smem[threadIdx.x][i] = reg_buf[i]; + } + __syncthreads(); + + // Drain memory tile into single output element out_elem. A row of smem + // holds tile_dim_m flags (one per bit of T), so the loop bound must be + // tile_dim_m: iterating up to tile_dim_n would read past the end of the + // row and shift T(1) by at least the width of T. + T out_elem{0}; +#pragma unroll + for (int j = 0; j < tile_dim_m; ++j) { + if (smem[threadIdx.x][j]) { out_elem |= T(1) << j; } + } + __syncthreads(); + + // Write output. + int out_m = tile_idx_m / bits_per_element; + int out_n = tile_idx_n + threadIdx.x; + + if (out_m < out.extent(0) && out_n < out.extent(1)) { out(out_m, out_n) = out_elem; } + } +} + +/** + * @brief Compress 2D boolean matrix to bitfield + * + * Utility kernel for maskedL2NN. + * + * @tparam T + * + * @param[in] in An `m x n` boolean matrix. Row major. + * @param[out] out An `(m / bits_per_elem) x n` matrix with elements of + * type T, where T is of size `bits_per_elem` bits. + * Note: the division (`/`) is a ceilDiv. 
+ */ +template ::value>> +void compress_to_bits(raft::device_resources const& handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + auto stream = handle.get_stream(); + constexpr int bits_per_element = 8 * sizeof(T); + + RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0), + "Number of output rows must be ceildiv(input rows, bits_per_elem)"); + RAFT_EXPECTS(in.extent(1) == out.extent(1), "Number of output columns must equal input columns."); + + const int num_SMs = raft::getMultiProcessorCount(); + int blocks_per_sm = 0; + // Must equal the `nthreads` constant hard-coded in compress_to_bits_kernel. + constexpr int num_threads = 128; + constexpr int dyn_smem_size = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, compress_to_bits_kernel, num_threads, dyn_smem_size)); + + // Launch with the same block size that was used for the occupancy query. + dim3 grid(num_SMs * blocks_per_sm); + dim3 block(num_threads); + compress_to_bits_kernel<<>>(in, out); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +}; // namespace raft::distance::detail \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 447359ffe6..8fbd7a9c69 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -37,6 +37,7 @@ template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } }; // KVPMinReduce diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh new file mode 100644 index 0000000000..a383568be9 --- /dev/null +++ b/cpp/include/raft/distance/detail/masked_distance_base.cuh @@ -0,0 +1,326 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include + +#include + +namespace raft { +namespace distance { +namespace detail { + +/** + * @brief Device class for masked nearest neighbor computations. + * + * @tparam useNorms whether norms are needed + * @tparam DataT input data-type (for x and y matrices) + * @tparam AccT accumulation data-type + * @tparam IdxT index data-type + * @tparam Policy struct which tunes the Contraction kernel + * @tparam CoreLambda tells how to accumulate an x and y into + acc. It is called as: + core_op(AccT& acc, const DataT& x, const DataT& y) + * @tparam EpilogueLambda applies an elementwise function to compute final + values. It is called as: + epilog_op(acc, thread_adj, regxn, regyn, + tile_idx_n, tile_idx_m, group_end_n) + * @tparam FinalLambda the final lambda called on final distance value + * @tparam rowEpilogueLambda epilog lambda that executes when a full row has + * been processed. + * + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of x + * @param[in] n number of rows of y + * @param[in] k number of cols of x and y + * @param[in] lda leading dimension of x + * @param[in] ldb leading dimension of y + * @param[in] ldd parameter to keep Contractions_NT happy.. + * @param[in] xn row norms of input matrix x. Required for expanded L2, cosine + * @param[in] yn row norms of input matrix y. 
Required for expanded L2, cosine + * @param[in] adj An adjacency matrix encoded as a bitfield indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `(m / 64) x num_groups`, where the division + * rounds up. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[in] num_groups The number of groups in group_idxs. + * @param[in] smem shared mem buffer for intermediate storage of x, y, xn & yn. + * @param core_op the core accumulation operation lambda + * @param epilog_op the epilog operation lambda + * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed. + */ +template > +struct MaskedDistances : public BaseClass { + private: + typedef Policy P; + const DataT* xn; + const DataT* yn; + const DataT* const yBase; + const uint64_t* adj; + const IdxT* group_idxs; + IdxT num_groups; + char* smem; + CoreLambda core_op; + EpilogueLambda epilog_op; + FinalLambda fin_op; + rowEpilogueLambda rowEpilog_op; + + AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; + + public: + // Constructor + DI MaskedDistances(const DataT* _x, + const DataT* _y, + IdxT _m, + IdxT _n, + IdxT _k, + IdxT _lda, + IdxT _ldb, + IdxT _ldd, + const DataT* _xn, + const DataT* _yn, + const uint64_t* _adj, + const IdxT* _group_idxs, + IdxT _num_groups, + char* _smem, + CoreLambda _core_op, + EpilogueLambda _epilog_op, + FinalLambda _fin_op, + rowEpilogueLambda _rowEpilog_op) + : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), + xn(_xn), + yn(_yn), + yBase(_y), + adj(_adj), + group_idxs(_group_idxs), + num_groups(_num_groups), + smem(_smem), + core_op(_core_op), + epilog_op(_epilog_op), + fin_op(_fin_op), + 
rowEpilog_op(_rowEpilog_op) + { + } + + // Main loop: for every row tile of x and every group of y visited by this + // block, accumulates the tile products and calls epilog_op. rowEpilog_op is + // called once per row tile, after the loop over the groups has finished. + DI void run() + { + const auto grid_stride_m = (P::Mblk * gridDim.y); + const auto grid_offset_m = (P::Mblk * blockIdx.y); + + const auto grid_stride_g = gridDim.x; + const auto grid_offset_g = blockIdx.x; + + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + // Start loop over groups + for (auto idx_g = grid_offset_g; idx_g < this->num_groups; idx_g += grid_stride_g) { + const uint64_t block_adj = get_block_adjacency(adj, tile_idx_m, idx_g); + // block_adj is a bitfield that contains a 1 if a row is adjacent to the + // current group. All zero means we can skip this group. + if (block_adj == 0) { continue; } + + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). That is, + // for i = 0, .., AccRowsPerTh - 1 and j = 0, .., AccColsPerTh - 1: + // + // ((1 << i) & thread_adj) > 0 <=> acc[i][j] must be computed. + // + // We precompute this information because it is used in various + // locations to skip thread-local computations, specifically: + // + // 1. To skip computations if thread_adj == 0, i.e., none of the values + // of `acc` have to be computed. + // + // 2. In epilog_op, to consider only values of `acc` to be reduced that + // are not masked off. + // + // Note 1: Even when the computation can be skipped for a specific thread, + // the thread still participates in synchronization operations. + // + // Note 2: In theory, it should be possible to skip computations for + // specific rows of `acc`. In practice, however, this does not improve + // performance. + int thread_adj = compute_thread_adjacency(block_adj); + + auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; + const auto group_end_n = group_idxs[idx_g]; + for (; tile_idx_n < group_end_n; tile_idx_n += P::Nblk) { + // We provide group_end_n to limit the number of unnecessary data + // points that are loaded from y. 
+ this->ldgXY(tile_idx_m, tile_idx_n, 0, group_end_n); + + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx, group_end_n); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + if (thread_adj != 0) { accumulate(); } + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + if (thread_adj != 0) { + accumulate(); // last iteration + } + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer + // back so that it satisfies the pre-condition of this loop. + this->switch_read_buffer(); + + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + // load_norms is called by every thread, also when thread_adj == 0, + // because it contains a __syncthreads(). + load_norms(tile_idx_m, tile_idx_n, group_end_n, regxn, regyn); + if (thread_adj != 0) { + epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, group_end_n); + } + } else { + if (thread_adj != 0) { + epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, group_end_n); + } + } + } // tile_idx_n + } // idx_g + rowEpilog_op(tile_idx_m); + } // tile_idx_m + } + + private: + // Returns the element of `adj` that holds the adjacency bits of the row tile + // starting at tile_idx_m for group idx_group. + DI uint64_t get_block_adjacency(const uint64_t* adj, IdxT tile_idx_m, IdxT idx_group) + { + // A single element of `adj` contains exactly enough bits to indicate which + // rows in the current tile to skip and which to compute. + static_assert(P::Mblk == 8 * sizeof(adj[0]), + "maskedL2NN only supports a policy with 64 rows per block."); + IdxT block_flag_idx = tile_idx_m / P::Mblk; + // Index into adj at row tile_idx_m / 64 and column idx_group. + return adj[block_flag_idx * this->num_groups + idx_group]; + } + + DI uint32_t compute_thread_adjacency(const uint64_t block_adj) + { + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). 
It is described in + // more detail in the run() method. + uint32_t thread_adj = 0; +#pragma unroll + for (int thread_row_idx = 0; thread_row_idx < P::AccRowsPerTh; ++thread_row_idx) { + // Index `thread_row_idx` refers to a row of the current threads' register + // tile `acc`, i.e., acc[i][:]. Index `block_row_idx` refers to the + // corresponding row of the current block tile in shared memory. + const int block_row_idx = this->accrowid + thread_row_idx * P::AccThRows; + + // block_row_is_adjacent is true if the current block_row_idx is adjacent + // to the current group. + const uint64_t block_mask = 1ull << block_row_idx; + const bool block_row_is_adjacent = (block_adj & block_mask) != 0; + if (block_row_is_adjacent) { + // If block row is adjacent, write a 1 bit to thread_adj at location + // `thread_row_idx`. + const uint32_t thread_mask = 1 << thread_row_idx; + thread_adj |= thread_mask; + } + } + return thread_adj; + } + + DI void reset_accumulator() + { + // Reset accumulator registers to zero. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } + } + } + + // Applies core_op to the register fragments loaded by ldsXY() for one + // k-block, accumulating into `acc`. + DI void accumulate() + { +#pragma unroll + for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { + this->ldsXY(ki); +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { +#pragma unroll + for (int v = 0; v < P::Veclen; ++v) { + core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); + } + } + } + } + } + + // Stages the x norms of the current row tile and the y norms of the current + // column tile in shared memory (zero for indices >= m resp. >= end_n), then + // copies the entries that belong to this thread into regxn / regyn. + // Contains a __syncthreads(), so it must be called by all threads of the + // block. + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + IdxT end_n, + DataT (&regxn)[P::AccRowsPerTh], + DataT (&regyn)[P::AccColsPerTh]) + { + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? 
xn[idx] : 0; + } + + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < end_n ? yn[idx] : 0; + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } +#pragma unroll + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; + } + } +}; // struct MaskedDistances + +}; // namespace detail +}; // namespace distance +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh new file mode 100644 index 0000000000..87000e9e6e --- /dev/null +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -0,0 +1,325 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { +namespace detail { + +// For every row of x, reduces (with pairRedOp) the L2 distances to the rows of +// y that lie in the groups enabled by `adj`, and merges the per-row result +// into `min` through updateReducedVal(). +template +__global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const uint64_t* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + bool sqrt, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + CoreLambda core_op, + FinalLambda fin_op) +{ + extern __shared__ char smem[]; + + typedef raft::KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [pairRedOp, &val, maxVal, sqrt] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + int thread_adj, + DataT* regxn, + DataT* regyn, + IdxT tile_idx_n, + IdxT tile_idx_m, + IdxT tile_end_n) { + KVPReduceOpT pairRed_op(pairRedOp); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + } + } + if (sqrt) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = raft::sqrt(acc[i][j]); + } + } + } + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). It is described in + // more detail in the MaskedDistances::run() method. 
+ const bool ignore = (thread_adj & (1 << i)) == 0; + if (ignore) { continue; } +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + tile_idx_n; + if (tile_end_n <= tmpkey) { + // Do not process beyond end of tile. + continue; + } + // tmpkey < tile_end_n is guaranteed here by the guard above. + KVPair tmp = {tmpkey, acc[i][j]}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT tile_idx_m) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + auto tmpkey = raft::shfl(val[i].key, lid + j); + auto tmpvalue = raft::shfl(val[i].value, lid + j); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, tile_idx_m); + + // reset the val array. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + MaskedDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + adj, + group_idxs, + num_groups, + smem, + core_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +} + +/** + * @brief Wrapper for maskedL2NNkernel + * + * Responsibilities: + * - Allocate (and initialize) workspace memory for: + * - mutexes used in nearest neighbor update step + * - adjacency matrix bitfield + * - Compress adjacency matrix to bitfield + * - Initialize output buffer (conditional on `initOutBuffer`) + * - Specify core and final operations for the L2 norm + * - Determine optimal launch configuration for kernel. + * - Launch kernel and check for errors. 
+ * + * @tparam DataT Input data-type (for x and y matrices). + * @tparam OutT Output data-type (for key-value pairs). + * @tparam IdxT Index data-type. + * @tparam ReduceOpT A struct to perform the final needed reduction + * operation and also to initialize the output array + * elements with the appropriate initial value needed for + * reduction. + * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. + * + * @param handle RAFT handle for managing expensive resources + * @param[out] out Will contain reduced output (nn key-value pairs) + * @param[in] x First matrix. Row major. Dim = `m x k`. (on device) + * @param[in] y Second matrix. Row major. Dim = `n x k`. (on device) + * @param[in] xn L2 squared norm of `x`. Length = `m`. + * @param[in] yn L2 squared norm of `y`. Length = `n`. + * @param[in] adj A boolean adjacency matrix indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `m x num_groups`. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[in] num_groups Length of `group_idxs`. + * @param m Rows of `x`. + * @param n Rows of `y`. + * @param k Cols of `x` and `y`. + * @param redOp Reduction operator in the epilogue + * @param pairRedOp Reduction operation on key value pairs + * @param sqrt Whether to compute the squared or actual (i.e. sqrt) L2 norm. 
+ * @param initOutBuffer Whether to initialize the output buffer + * + * + */ +template +void maskedL2NNImpl(raft::device_resources const& handle, + OutT* out, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const bool* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer) +{ + typedef typename linalg::Policy4x4::Policy P; + + static_assert(P::Mblk == 64, "maskedL2NNImpl only supports a policy with 64 rows per block."); + + // Get stream and workspace memory resource + rmm::mr::device_memory_resource* ws_mr = + dynamic_cast(handle.get_workspace_resource()); + auto stream = handle.get_stream(); + + // Acquire temporary buffers and initialize to zero: + // 1) Adjacency matrix bitfield + // 2) Workspace for fused nearest neighbor operation + size_t m_div_64 = raft::ceildiv(m, IdxT(64)); + rmm::device_uvector ws_adj64{m_div_64 * num_groups, stream, ws_mr}; + rmm::device_uvector ws_fused_nn{size_t(m), stream, ws_mr}; + RAFT_CUDA_TRY(cudaMemsetAsync(ws_adj64.data(), 0, ws_adj64.size() * sizeof(uint64_t), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); + + // Compress boolean adjacency matrix to bitfield. + auto adj_view = raft::make_device_matrix_view(adj, m, num_groups); + auto adj64_view = + raft::make_device_matrix_view(ws_adj64.data(), m_div_64, num_groups); + compress_to_bits(handle, adj_view, adj64_view); + + // Initialize output buffer with keyvalue pairs as determined by the reduction + // operator (it will be called with maxVal). 
+ constexpr auto maxVal = std::numeric_limits::max(); + if (initOutBuffer) { + dim3 grid(raft::ceildiv(m, P::Nthreads)); + dim3 block(P::Nthreads); + + initKernel<<>>(out, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; + auto fin_op = raft::identity_op{}; + + auto kernel = maskedL2NNkernel; + constexpr size_t smemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 block(P::Nthreads); + dim3 grid = launchConfigGenerator

(m, n, smemSize, kernel); + + kernel<<>>(out, + x, + y, + xn, + yn, + ws_adj64.data(), + group_idxs, + num_groups, + m, + n, + k, + sqrt, + maxVal, + ws_fused_nn.data(), + redOp, + pairRedOp, + core_lambda, + fin_op); + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace detail +} // namespace distance +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 445b4bac52..293600ed21 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -59,6 +59,7 @@ namespace detail { * @param core_op the core accumulation operation lambda * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ template m; - gridStrideY += P::Mblk * gridDim.y) { - for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; - gridStrideX += P::Nblk * gridDim.x) { - prolog(gridStrideX, gridStrideY); - loop(); - epilog(gridStrideX, gridStrideY); + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + this->ldgXY(tile_idx_m, grid_offset_n, 0); + for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + // Prolog: + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + // Main loop: + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. 
+ accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + accumulate(); // last iteration + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer back + // so that it satisfies the pre-condition of this loop. + this->switch_read_buffer(); + + // Epilog: + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, regxn, regyn); + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); + } else { + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + } + if (writeOut) { store_output(tile_idx_m, tile_idx_n); } } - rowEpilog_op(gridStrideY); + rowEpilog_op(tile_idx_m); } } private: - DI void updateIndicesY() - { - const auto stride = P::Nblk * gridDim.x; - if (isRowMajor) { - this->y += stride * this->ldb; - } else { - this->y += stride; - } - this->yrowid += stride; - } - - DI void updateIndicesXY() - { - const auto stride = P::Mblk * gridDim.y; - if (isRowMajor) { - this->x += stride * this->lda; - this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; - this->y = yBase + this->yrowid * this->ldb; - } else { - this->x += stride; - this->yrowid = IdxT(blockIdx.x) * P::Nblk; - this->y = yBase + this->yrowid + this->srowid * this->ldb; - } - this->xrowid += stride; - } - - DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) + DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) { // Fetch next grid stride ldg if within range - if ((gridStrideX + gridDim.x * P::Nblk) < this->n) { - updateIndicesY(); - this->ldgXY(0); - } else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) { - updateIndicesXY(); - this->ldgXY(0); + const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; + 
const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m; + if ((next_tile_tile_idx_n) < this->n) { + this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); + } else if ((next_tile_tile_idx_m) < this->m) { + this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); } } - DI void prolog(IdxT gridStrideX, IdxT gridStrideY) + DI void reset_accumulator() { - if (gridStrideX == blockIdx.x * P::Nblk) { this->ldgXY(0); } - + // Reset accumulator registers to zero. #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -184,28 +199,6 @@ struct PairwiseDistances : public BaseClass { acc[i][j] = BaseClass::Zero; } } - - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - } - - DI void loop() - { - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); - accumulate(); // on the previous k-block - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; - } - accumulate(); // last iteration - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. - this->pageRd ^= 1; } DI void accumulate() @@ -226,60 +219,52 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT gridStrideX, IdxT gridStrideY) + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + DataT (&regxn)[P::AccRowsPerTh], + DataT (&regyn)[P::AccColsPerTh]) { - if (useNorms) { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (gridStrideX == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = gridStrideY + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = gridStrideX + i; - syNorm[i] = idx < this->n ? 
yn[idx] : 0; + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (tile_idx_n == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; } + } - __syncthreads(); + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < this->n ? yn[idx] : 0; + } + __syncthreads(); - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } #pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - - // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); - } else { - // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } + } - if (writeOut) { - IdxT starty = gridStrideY + this->accrowid; - IdxT startx = gridStrideX + this->acccolid; + DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) + { + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; #pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) 
{ - // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); - } + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + // Promote to 64 bit index for final write, as output array can be > 2^31 + dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); } } } @@ -477,4 +462,4 @@ dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) }; // namespace detail }; // namespace distance -}; // namespace raft +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh new file mode 100644 index 0000000000..5a6a4f08d8 --- /dev/null +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MASKED_L2_NN_H +#define __MASKED_L2_NN_H + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { +/** + * \defgroup masked_nn Masked 1-nearest neighbors + * @{ + */ + +/** + * @brief Parameter struct for maskedL2NN function + * + * @tparam ReduceOpT Type of reduction operator in the epilogue. + * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. 
+ * + * Usage example: + * @code{.cpp} + * #include + * + * using IdxT = int; + * using DataT = float; + * using RedOpT = raft::distance::MinAndDistanceReduceOp; + * using PairRedOpT = raft::distance::KVPMinReduce; + * using ParamT = raft::distance::MaskedL2NNParams; + * + * bool init_out = true; + * bool sqrt = false; + * + * ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; + * @endcode + * + * Prescribes how to reduce a distance to an intermediate type (`redOp`), and + * how to reduce two intermediate types (`pairRedOp`). Typically, a distance is + * mapped to an (index, value) pair and (index, value) pair with the lowest + * value (distance) is selected. + * + * In addition, prescribes whether to compute the square root of the distance + * (`sqrt`) and whether to initialize the output buffer (`initOutBuffer`). + */ +template +struct MaskedL2NNParams { + /** Reduction operator in the epilogue */ + ReduceOpT redOp; + /** Reduction operation on key value pairs */ + KVPReduceOpT pairRedOp; + /** Whether the output `minDist` should contain L2-sqrt */ + bool sqrt; + /** Whether to initialize the output buffer before the main kernel launch */ + bool initOutBuffer; +}; + +/** + * @brief Masked L2 distance and 1-nearest-neighbor computation in a single call. + * + * This function enables faster computation of nearest neighbors if the + * computation of distances between certain point pairs can be skipped. + * + * We use an adjacency matrix that describes which distances to calculate. The + * points in `y` are divided into groups, and the adjacency matrix indicates + * whether to compute distances between points in `x` and groups in `y`. In other + * words, if `adj[i,k]` is true then distance between point `x_i`, and points in + * `group_k` will be calculated. + * + * **Performance considerations** + * + * The points in `x` are processed in tiles of `M` points (`M` is currently 64, + * but may change in the future). 
As a result, the largest compute time + * reduction occurs if all `M` points can skip a group. If only part of the `M` + * points can skip a group, then at most a minor compute time reduction and a + * modest energy use reduction can be expected. + * + * The points in `y` are also grouped into tiles of `N` points (`N` is currently + * 64, but may change in the future). As a result, group sizes should be larger + * than `N` to avoid wasting computational resources. If the group sizes are + * evenly divisible by `N`, then the computation is most efficient, although for + * larger group sizes this effect is minor. + * + * + * **Comparison to SDDMM** + * + * [SDDMM](https://ieeexplore.ieee.org/document/8638042) (sampled dense-dense + * matrix multiplication) is a matrix-matrix multiplication where only part of + * the output is computed. Compared to maskedL2NN, there are a few differences: + * + * - The output of maskedL2NN is a single vector (of nearest neighbors) and not + * a sparse matrix. + * + * - The sampling in maskedL2NN is expressed through intermediate "groups" + * rather than a CSR format. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * + * @param handle RAFT handle for managing expensive resources + * @param params Parameter struct specifying the reduction operations. + * @param[in] x First matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y Second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] x_norm L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] y_norm L2 squared norm of `y`. Length = `n`. 
(on device) + * @param[in] adj A boolean adjacency matrix indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `m x num_groups`. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[out] out will contain the reduced output (Length = `m`) + * (on device) + */ +template +void maskedL2NN(raft::device_resources const& handle, + raft::distance::MaskedL2NNParams params, + raft::device_matrix_view x, + raft::device_matrix_view y, + raft::device_vector_view x_norm, + raft::device_vector_view y_norm, + raft::device_matrix_view adj, + raft::device_vector_view group_idxs, + raft::device_vector_view out) +{ + IdxT m = x.extent(0); + IdxT n = y.extent(0); + IdxT k = x.extent(1); + IdxT num_groups = group_idxs.extent(0); + + // Match k dimension of x, y + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); + // Match x, x_norm and y, y_norm + RAFT_EXPECTS(m == x_norm.extent(0), "Length of `x_norm` must match input `x`."); + RAFT_EXPECTS(n == y_norm.extent(0), "Length of `y_norm` must match input `y` "); + // Match adj to x and group_idxs + RAFT_EXPECTS(m == adj.extent(0), "#rows in `adj` must match input `x`."); + RAFT_EXPECTS(num_groups == adj.extent(1), "#cols in `adj` must match length of `group_idxs`."); + // NOTE: We do not check if all indices in group_idxs actually points *inside* y. + + // If there is no work to be done, return immediately. 
+ if (m == 0 || n == 0 || k == 0 || num_groups == 0) { return; } + + detail::maskedL2NNImpl(handle, + out.data_handle(), + x.data_handle(), + y.data_handle(), + x_norm.data_handle(), + y_norm.data_handle(), + adj.data_handle(), + group_idxs.data_handle(), + num_groups, + m, + n, + k, + params.redOp, + params.pairRedOp, + params.sqrt, + params.initOutBuffer); +} + +/** @} */ + +} // namespace distance +} // namespace raft + +#endif \ No newline at end of file diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index e247f39bc7..9301580a9e 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -40,14 +40,10 @@ struct Contractions_NT { /** leading dimension in Output D */ IdxT ldd; - /** current thread's global mem row id for X data */ - IdxT xrowid; - /** current thread's global mem row id for Y data */ - IdxT yrowid; /** global memory pointer to X matrix */ - const DataT* x; + const DataT* x_base; /** global memory pointer to Y matrix */ - const DataT* y; + const DataT* y_base; /** current thread's smem row id */ int srowid; @@ -94,10 +90,8 @@ struct Contractions_NT { k(_k), lda(_k), ldb(_k), - xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow), - yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow), - x(_x + xrowid * lda), - y(_y + yrowid * ldb), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -133,6 +127,8 @@ struct Contractions_NT { lda(_lda), ldb(_ldb), ldd(_ldd), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -142,17 +138,6 @@ struct Contractions_NT { pageWr(0), pageRd(0) { - if (isRowMajor) { - xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; - x = _x + xrowid * lda; 
- y = _y + yrowid * ldb; - } else { - xrowid = IdxT(blockIdx.y) * P::Mblk; - yrowid = IdxT(blockIdx.x) * P::Nblk; - x = _x + xrowid + srowid * lda; - y = _y + yrowid + srowid * ldb; - } } protected: @@ -160,10 +145,16 @@ struct Contractions_NT { * @brief Load current block of X/Y from global memory to registers * @param[in] kidx current start index of k to be loaded */ - DI void ldgXY(IdxT kidx) + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) + { + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx); + } + + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx, IdxT tile_end_n) { - ldgX(kidx); - ldgY(kidx); + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx, tile_end_n); } /** @@ -186,9 +177,16 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } + DI void switch_read_buffer() { this->pageRd ^= 1; } + + DI void switch_write_buffer() { this->pageWr ^= 1; } + private: - DI void ldgX(IdxT kidx) + DI void ldgX(IdxT tile_idx_m, IdxT kidx) { + IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; + auto x = isRowMajor ? x_base + xrowid * lda : x_base + xrowid + srowid * lda; + if (isRowMajor) { auto numRows = m; auto koffset = kidx + scolid; @@ -220,10 +218,15 @@ struct Contractions_NT { } } - DI void ldgY(IdxT kidx) + DI void ldgY(IdxT tile_idx_n, IdxT kidx) { ldgY(tile_idx_n, kidx, n); } + + DI void ldgY(IdxT tile_idx_n, IdxT kidx, IdxT end_n) { + IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; + auto y = isRowMajor ? 
y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; + if (isRowMajor) { - auto numRows = n; + auto numRows = end_n; auto koffset = kidx + scolid; #pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { @@ -241,7 +244,7 @@ struct Contractions_NT { auto koffset = scolid; #pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { - if ((koffset + yrowid) < ldb && (srowid + kidx + i * P::LdgRowsY) < numRows) { + if ((koffset + yrowid) < end_n && (srowid + kidx + i * P::LdgRowsY) < numRows) { ldg(ldgDataY[i], y + (kidx + i * P::LdgRowsY) * ldb + koffset); } else { #pragma unroll diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index e4843acee9..1bc6622e43 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -64,7 +64,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { private: DI void prolog() { - this->ldgXY(0); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, 0); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -74,18 +74,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { } this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() { for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, kidx); accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration } @@ -240,4 +240,4 @@ void epsUnexpL2SqNeighborhood(bool* adj, } // namespace detail } // namespace knn } // namespace spatial -} // namespace raft +} // namespace raft \ No newline at end of file diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 3c41621274..41fb917d17 100644 --- 
a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -122,6 +122,8 @@ if(BUILD_TESTS) test/distance/dist_l1.cu test/distance/dist_minkowski.cu test/distance/dist_russell_rao.cu + test/distance/masked_nn.cu + test/distance/masked_nn_compress_to_bits.cu test/distance/fused_l2_nn.cu test/distance/gram.cu OPTIONAL diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 8b9681b9d3..af67214193 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -385,7 +385,7 @@ class FusedL2NNDetTest : public FusedL2NNTest { rmm::device_uvector> min1; - static const int NumRepeats = 100; + static const int NumRepeats = 3; void generateGoldenResult() override {} }; diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu new file mode 100644 index 0000000000..0076634c4f --- /dev/null +++ b/cpp/test/distance/masked_nn.cu @@ -0,0 +1,435 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::distance::masked_nn { + +// The adjacency pattern determines what distances get computed. 
+enum AdjacencyPattern { + checkerboard = 0, // adjacency matrix looks like a checkerboard (half the distances are computed) + checkerboard_4 = 1, // checkerboard with tiles of size 4x4 + checkerboard_64 = 2, // checkerboard with tiles of size 64x64 + all_true = 3, // no distance computations can be skipped + all_false = 4 // all distance computations can be skipped +}; + +// Kernels: +// - init_adj: to initialize the adjacency kernel with a specific adjacency pattern +// - referenceKernel: to produce the ground-truth output + +__global__ void init_adj(AdjacencyPattern pattern, + int n, + raft::device_matrix_view adj, + raft::device_vector_view group_idxs) +{ + int m = adj.extent(0); + int num_groups = adj.extent(1); + + for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; + idx_m += blockDim.y * gridDim.y) { + for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; + idx_g += blockDim.x * gridDim.x) { + switch (pattern) { + case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; + case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; + case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break; + case all_true: adj(idx_m, idx_g) = true; break; + case all_false: adj(idx_m, idx_g) = false; break; + default: assert(false && "unknown pattern"); + } + } + } + // Each group is of size n / num_groups. + // + // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive + // scan of the group lengths) + // + // - The first group always starts at index zero, so we do not store it. + // + // - The group_idxs[num_groups - 1] should always equal n. 
+ + if (blockIdx.y == 0 && threadIdx.y == 0) { + const int g_stride = blockDim.x * gridDim.x; + for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { + group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); + } + group_idxs(num_groups - 1) = n; + } +} + +template +__global__ __launch_bounds__(32 * NWARPS, + 2) void referenceKernel(raft::KeyValuePair* min, + DataT* x, + DataT* y, + bool* adj, + int* group_idxs, + int m, + int n, + int k, + int num_groups, + bool sqrt, + int* workspace, + DataT maxVal) +{ + const int m_stride = blockDim.y * gridDim.y; + const int m_offset = threadIdx.y + blockIdx.y * blockDim.y; + const int n_stride = blockDim.x * gridDim.x; + const int n_offset = threadIdx.x + blockIdx.x * blockDim.x; + + for (int m_grid = 0; m_grid < m; m_grid += m_stride) { + for (int n_grid = 0; n_grid < n; n_grid += n_stride) { + int midx = m_grid + m_offset; + int nidx = n_grid + n_offset; + + // Do a reverse linear search to determine the group index. + int group_idx = 0; + for (int i = num_groups; 0 <= i; --i) { + if (nidx < group_idxs[i]) { group_idx = i; } + } + const bool include_dist = adj[midx * num_groups + group_idx] && midx < m && nidx < n; + + // Compute L2 metric. + DataT acc = DataT(0); + for (int i = 0; i < k; ++i) { + int xidx = i + midx * k; + int yidx = i + nidx * k; + auto diff = x[xidx] - y[yidx]; + acc += diff * diff; + } + if (sqrt) { acc = raft::sqrt(acc); } + ReduceOpT redOp; + typedef cub::WarpReduce> WarpReduce; + __shared__ typename WarpReduce::TempStorage temp[NWARPS]; + int warpId = threadIdx.x / raft::WarpSize; + raft::KeyValuePair tmp; + tmp.key = include_dist ? nidx : -1; + tmp.value = include_dist ? 
acc : maxVal; + tmp = WarpReduce(temp[warpId]).Reduce(tmp, raft::distance::KVPMinReduce{}); + if (threadIdx.x % raft::WarpSize == 0 && midx < m) { + while (atomicCAS(workspace + midx, 0, 1) == 1) + ; + __threadfence(); + redOp(midx, min + midx, tmp); + __threadfence(); + atomicCAS(workspace + midx, 1, 0); + } + __syncthreads(); + } + } +} + +// Structs +// - Params: holds parameters for test case +// - Inputs: holds the inputs to the functions under test (x, y, adj, group_idxs). Is generated from +// the Params. +struct Params { + double tolerance; + int m, n, k, num_groups; + bool sqrt; + unsigned long long int seed; + AdjacencyPattern pattern; +}; + +inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& +{ + os << "m: " << p.m << ", n: " << p.n << ", k: " << p.k << ", num_groups: " << p.num_groups + << ", sqrt: " << p.sqrt << ", seed: " << p.seed << ", tol: " << p.tolerance; + return os; +} + +template +struct Inputs { + using IdxT = int; + + raft::device_matrix x, y; + raft::device_matrix adj; + raft::device_vector group_idxs; + + Inputs(const raft::handle_t& handle, const Params& p) + : x{raft::make_device_matrix(handle, p.m, p.k)}, + y{raft::make_device_matrix(handle, p.n, p.k)}, + adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, + group_idxs{raft::make_device_vector(handle, p.num_groups)} + { + // Initialize x, y + raft::random::RngState r(p.seed); + uniform(handle, r, x.data_handle(), p.m * p.k, DataT(-1.0), DataT(1.0)); + uniform(handle, r, y.data_handle(), p.n * p.k, DataT(-1.0), DataT(1.0)); + + // Initialize adj, group_idxs. 
+ dim3 block(32, 32); + dim3 grid(10, 10); + init_adj<<>>( + p.pattern, p.n, adj.view(), group_idxs.view()); + RAFT_CUDA_TRY(cudaGetLastError()); + } +}; + +template > +auto reference(const raft::handle_t& handle, Inputs inp, const Params& p) + -> raft::device_vector +{ + int m = inp.x.extent(0); + int n = inp.y.extent(0); + int k = inp.x.extent(1); + int num_groups = inp.group_idxs.extent(0); + + if (m == 0 || n == 0 || k == 0 || num_groups == 0) { + return raft::make_device_vector(handle, 0); + } + + // Initialize workspace + auto stream = handle.get_stream(); + rmm::device_uvector workspace(p.m * sizeof(int), stream); + RAFT_CUDA_TRY(cudaMemsetAsync(workspace.data(), 0, sizeof(int) * m, stream)); + + // Initialize output + auto out = raft::make_device_vector(handle, m); + auto blks = raft::ceildiv(m, 256); + MinAndDistanceReduceOp op; + raft::distance::detail::initKernel, int> + <<>>(out.data_handle(), m, std::numeric_limits::max(), op); + RAFT_CUDA_TRY(cudaGetLastError()); + + // Launch reference kernel + const int nwarps = 16; + static const dim3 TPB(32, nwarps, 1); + dim3 nblks(1, 200, 1); + referenceKernel + <<>>(out.data_handle(), + inp.x.data_handle(), + inp.y.data_handle(), + inp.adj.data_handle(), + inp.group_idxs.data_handle(), + m, + n, + k, + num_groups, + p.sqrt, + (int*)workspace.data(), + std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaGetLastError()); + + return out; +} + +template > +auto run_masked_nn(const raft::handle_t& handle, Inputs inp, const Params& p) + -> raft::device_vector +{ + // Compute norms: + auto x_norm = raft::make_device_vector(handle, p.m); + auto y_norm = raft::make_device_vector(handle, p.n); + + raft::linalg::norm(handle, + std::as_const(inp.x).view(), + x_norm.view(), + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS); + raft::linalg::norm(handle, + std::as_const(inp.y).view(), + y_norm.view(), + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS); + + // Create parameters for maskedL2NN + using IdxT = int; 
+ using RedOpT = MinAndDistanceReduceOp; + using PairRedOpT = raft::distance::KVPMinReduce; + using ParamT = raft::distance::MaskedL2NNParams; + + bool init_out = true; + ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, p.sqrt, init_out}; + + // Create output + auto out = raft::make_device_vector(handle, p.m); + + // Launch kernel + raft::distance::maskedL2NN(handle, + masked_l2_params, + inp.x.view(), + inp.y.view(), + x_norm.view(), + y_norm.view(), + inp.adj.view(), + inp.group_idxs.view(), + out.view()); + + handle.sync_stream(); + + return out; +} + +template +struct CompareApproxAbsKVP { // Compares only the magnitudes of `.value`; `.key` is ignored + typedef typename raft::KeyValuePair KVP; + CompareApproxAbsKVP(T eps_) : eps(eps_) {} + bool operator()(const KVP& a, const KVP& b) const + { + T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); + T m = std::max(raft::abs(a.value), raft::abs(b.value)); + T ratio = m >= eps ? diff / m : diff; // Relative error, or absolute when both are < eps + return (ratio <= eps); + } + + private: + T eps; +}; + +template +::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, + const raft::KeyValuePair* actual, + size_t size, + L eq_compare, + cudaStream_t stream = 0) +{ + typedef typename raft::KeyValuePair KVP; + std::shared_ptr exp_h(new KVP[size]); // NOTE(review): new[] needs an array-aware deleter - confirm + std::shared_ptr act_h(new KVP[size]); // NOTE(review): same as above + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < size; ++i) { + auto exp = exp_h.get()[i]; + auto act = act_h.get()[i]; + if (!eq_compare(exp, act)) { // Report the first mismatching element and stop + return ::testing::AssertionFailure() + << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," + << exp.value << " @" << i; + } + } + return ::testing::AssertionSuccess(); +} + +inline auto gen_params() -> std::vector +{ + // Regular powers of two + auto regular = raft::util::itertools::product({0.001f}, // tolerance + {32, 64, 512}, // m + {32, 64, 512}, // n + {8, 32}, // k + {2, 32}, // num_groups + {true, 
false}, // sqrt + {1234ULL}, // seed + {AdjacencyPattern::all_true, + AdjacencyPattern::checkerboard, + AdjacencyPattern::checkerboard_64, + AdjacencyPattern::all_false}); + + // Irregular sizes to check tiling and bounds checking + auto irregular = raft::util::itertools::product({0.001f}, // tolerance + {511, 512, 513}, // m + {127, 128, 129}, // n + {5}, // k + {3, 9}, // num_groups + {true, false}, // sqrt + {1234ULL}, // seed + {AdjacencyPattern::all_true, + AdjacencyPattern::checkerboard, + AdjacencyPattern::checkerboard_64}); + + regular.insert(regular.end(), irregular.begin(), irregular.end()); // Append the irregular cases + + return regular; +} + +class MaskedL2NNTest : public ::testing::TestWithParam { + // Empty. +}; + +// This test checks maskedL2NN against the reference implementation (DataT = float). +TEST_P(MaskedL2NNTest, ReferenceCheckFloat) +{ + using DataT = float; + + // Get parameters; create handle and input data. + Params p = GetParam(); + raft::handle_t handle{}; + Inputs inputs{handle, p}; + + // Calculate reference and test output + auto out_reference = reference(handle, inputs, p); + auto out_fast = run_masked_nn(handle, inputs, p); + + // Check for differences. + ASSERT_TRUE(devArrMatch(out_reference.data_handle(), + out_fast.data_handle(), + p.m, + CompareApproxAbsKVP(p.tolerance), + handle.get_stream())); +} + +// This test checks whether running the maskedL2NN twice returns the same +// output (compared with CompareApproxAbsKVP at p.tolerance, not bit-for-bit). +TEST_P(MaskedL2NNTest, DeterminismCheck) +{ + using DataT = float; + + // Get parameters; create handle and input data. + Params p = GetParam(); + raft::handle_t handle{}; + Inputs inputs{handle, p}; + + // Run maskedL2NN twice on the same inputs + auto out1 = run_masked_nn(handle, inputs, p); + auto out2 = run_masked_nn(handle, inputs, p); + + // Check for differences. + ASSERT_TRUE(devArrMatch(out1.data_handle(), + out2.data_handle(), + p.m, + CompareApproxAbsKVP(p.tolerance), + handle.get_stream())); +} + +TEST_P(MaskedL2NNTest, ReferenceCheckDouble) +{ + using DataT = double; + + // Get parameters; create handle and input data. 
+ Params p = GetParam(); + raft::handle_t handle{}; + Inputs inputs{handle, p}; + + // Calculate reference and test output + auto out_reference = reference(handle, inputs, p); + auto out_fast = run_masked_nn(handle, inputs, p); + + // Check for differences. + ASSERT_TRUE(devArrMatch(out_reference.data_handle(), + out_fast.data_handle(), + p.m, + CompareApproxAbsKVP(p.tolerance), + handle.get_stream())); +} + +INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTest, ::testing::ValuesIn(gen_params())); + +} // end namespace raft::distance::masked_nn \ No newline at end of file diff --git a/cpp/test/distance/masked_nn_compress_to_bits.cu b/cpp/test/distance/masked_nn_compress_to_bits.cu new file mode 100644 index 0000000000..e7d75780be --- /dev/null +++ b/cpp/test/distance/masked_nn_compress_to_bits.cu @@ -0,0 +1,216 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::distance::masked_nn::compress_to_bits { + +/** + * @brief Decompress 2D bitfield to boolean matrix + * + * Inverse operation of compress_to_bits + * + * @tparam T Element type of the bitfield; each element holds `8 * sizeof(T)` bits. + * + * @param[in] in An `m x n` bitfield matrix. Row major. + * @param in_rows The number of rows of `in`, i.e. `m`. 
+ * @param in_cols The number of cols of `in`, i.e. `n`. + * + * @param[out] out An `(m * bits_per_elem) x n` boolean matrix. + */ +template ::value>> +__global__ void decompress_bits_kernel(const T* in, int in_rows, int in_cols, bool* out) +{ + constexpr int bits_per_element = 8 * sizeof(T); + + const size_t i = threadIdx.y + blockIdx.y * blockDim.y; + const size_t j = threadIdx.x + blockIdx.x * blockDim.x; + + if (in_rows <= i || in_cols <= j) { return; } + + const size_t out_rows = in_rows * bits_per_element; + const size_t out_cols = in_cols; + const size_t out_i = i * bits_per_element; + const size_t out_j = j; + + if (out_rows <= out_i || out_cols <= out_j) { return; } // Either index out of range => skip + + T bitfield = in[i * in_cols + j]; + for (int bitpos = 0; bitpos < bits_per_element; ++bitpos) { + bool bit = ((T(1) << bitpos) & bitfield) != 0; + out[(out_i + bitpos) * out_cols + out_j] = bit; // Bit `bitpos` of (i, j) -> row out_i + bitpos + } +} + +/** + * @brief Decompress 2D bitfield to boolean matrix + * + * Inverse operation of compress_to_bits + * + * @tparam T Element type of the bitfield; each element holds `8 * sizeof(T)` bits. + * + * @param[in] in An `m x n` bitfield matrix. Row major. + * @param in_rows The number of rows of `in`, i.e. `m`. + * @param in_cols The number of cols of `in`, i.e. `n`. + * + * @param[out] out An `(m * bits_per_elem) x n` boolean matrix. + */ +template ::value>> +void decompress_bits(const raft::handle_t& handle, const T* in, int in_rows, int in_cols, bool* out) +{ + auto stream = handle.get_stream(); + dim3 grid(raft::ceildiv(in_cols, 32), raft::ceildiv(in_rows, 32)); + dim3 block(32, 32); + decompress_bits_kernel<<>>(in, in_rows, in_cols, out); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +// Params holds parameters for test case +struct Params { + int m, n; +}; + +inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& +{ + return os << "m: " << p.m << ", n: " << p.n; +} + +// Check that the following holds +// +// decompress(compress(x)) == x +// +// for 2D boolean matrices x. 
+template +void check_invertible(const Params& p) +{ + using raft::distance::detail::compress_to_bits; + constexpr int bits_per_elem = sizeof(T) * 8; + + // Round m up so that ceildiv(m, bits_per_elem) * bits_per_elem == m (asserted below). + int m = raft::round_up_safe(p.m, bits_per_elem); + int n = p.n; + + // Generate random input + raft::handle_t handle{}; + raft::random::RngState r(1ULL); + auto in = raft::make_device_matrix(handle, m, n); + raft::random::bernoulli(handle, r, in.data_handle(), m * n, 0.5f); + + int tmp_m = raft::ceildiv(m, bits_per_elem); // Rows of the compressed bitfield `tmp` + int out_m = tmp_m * bits_per_elem; // Rows of the decompressed matrix `out` + + auto tmp = raft::make_device_matrix(handle, tmp_m, n); + auto out = raft::make_device_matrix(handle, out_m, n); + + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); + + ASSERT_EQ(in.extent(0), out.extent(0)) << "M does not match"; + ASSERT_EQ(in.extent(1), out.extent(1)) << "N does not match"; + + compress_to_bits(handle, in.view(), tmp.view()); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); + + decompress_bits(handle, tmp.data_handle(), tmp.extent(0), tmp.extent(1), out.data_handle()); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); + + // Check that the round trip reproduced the input element for element. + ASSERT_TRUE(raft::devArrMatch(in.data_handle(), + out.data_handle(), + in.extent(0) * in.extent(1), + raft::Compare(), + handle.get_stream())); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +void check_all_true(const Params& p) +{ + using raft::distance::detail::compress_to_bits; + using T = uint64_t; + constexpr int bits_per_elem = sizeof(T) * 8; + + // Round m up via round_up_safe(p.m, bits_per_elem); n is used as given. 
+ int m = raft::round_up_safe(p.m, bits_per_elem); + int n = p.n; + + raft::handle_t handle{}; + raft::random::RngState r(1ULL); // NOTE(review): `r` is not used in this function + auto in = raft::make_device_matrix(handle, m, n); + raft::matrix::fill(handle, in.view(), true); + + int tmp_m = raft::ceildiv(m, bits_per_elem); + auto tmp = raft::make_device_matrix(handle, tmp_m, n); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); + + compress_to_bits(handle, in.view(), tmp.view()); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); + + auto expected = raft::make_device_matrix(handle, tmp_m, n); + raft::matrix::fill(handle, expected.view(), ~T(0)); + + // Every compressed element must equal ~T(0), i.e. have all bits set. + ASSERT_TRUE(raft::devArrMatch(expected.data_handle(), + tmp.data_handle(), + tmp.extent(0) * tmp.extent(1), + raft::Compare(), + handle.get_stream())); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +class CompressToBitsTest : public ::testing::TestWithParam { + // Empty. +}; + +TEST_P(CompressToBitsTest, CheckTrue64) { check_all_true(GetParam()); } + +TEST_P(CompressToBitsTest, CheckInvertible64) +{ + using T = uint64_t; + check_invertible(GetParam()); +} + +TEST_P(CompressToBitsTest, CheckInvertible32) +{ + using T = uint32_t; + check_invertible(GetParam()); +} + +std::vector params = raft::util::itertools::product( + {1, 3, 32, 33, 63, 64, 65, 128, 10013}, {1, 3, 32, 33, 63, 64, 65, 13001}); + +INSTANTIATE_TEST_CASE_P(CompressToBits, CompressToBitsTest, ::testing::ValuesIn(params)); + +} // namespace raft::distance::masked_nn::compress_to_bits \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index a85dc15b3b..3b83464d3d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. # # This file is execfile()d with the current directory set to its # containing dir. 
diff --git a/docs/source/cpp_api/distance.rst b/docs/source/cpp_api/distance.rst index eb9bc6255d..1632f19fba 100644 --- a/docs/source/cpp_api/distance.rst +++ b/docs/source/cpp_api/distance.rst @@ -25,3 +25,4 @@ namespace *raft::distance* distance_pairwise.rst distance_1nn.rst + distance_masked_nn.rst diff --git a/docs/source/cpp_api/distance_masked_nn.rst b/docs/source/cpp_api/distance_masked_nn.rst new file mode 100644 index 0000000000..d5f2f6be7f --- /dev/null +++ b/docs/source/cpp_api/distance_masked_nn.rst @@ -0,0 +1 @@ +distance_masked_nn.rst \ No newline at end of file diff --git a/fetch_rapids.cmake b/fetch_rapids.cmake index a6be017d77..2d312bd3e5 100644 --- a/fetch_rapids.cmake +++ b/fetch_rapids.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at diff --git a/python/pylibraft/CMakeLists.txt b/python/pylibraft/CMakeLists.txt index 98d723e27b..b12d0a63ea 100644 --- a/python/pylibraft/CMakeLists.txt +++ b/python/pylibraft/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at