From 4ad7daba36d535c22f9182dae6bb714b34cbba25 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 3 Feb 2023 08:51:05 -0500 Subject: [PATCH] Reverting a few commits from 23.02 and speeding up end-to-end build time (#1232) Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Mark Sadang (https://github.com/msadang) - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/raft/pull/1232 --- .../recipes/libraft/build_libraft_distance.sh | 2 +- conda/recipes/libraft/build_libraft_nn.sh | 2 +- conda/recipes/libraft/build_libraft_tests.sh | 2 +- cpp/CMakeLists.txt | 15 +- cpp/bench/CMakeLists.txt | 1 - cpp/bench/distance/distance_common.cuh | 12 +- cpp/bench/distance/masked_nn.cu | 267 ----------- cpp/bench/neighbors/knn.cuh | 2 - cpp/bench/neighbors/refine.cu | 9 +- .../raft/distance/detail/compress_to_bits.cuh | 122 ----- .../raft/distance/detail/fused_l2_nn.cuh | 1 - .../distance/detail/masked_distance_base.cuh | 326 ------------- .../raft/distance/detail/masked_nn.cuh | 325 ------------- .../detail/pairwise_distance_base.cuh | 205 +++++---- cpp/include/raft/distance/masked_nn.cuh | 199 -------- .../raft/linalg/detail/contractions.cuh | 61 ++- .../raft/neighbors/specializations.cuh | 5 +- .../knn/detail/epsilon_neighborhood.cuh | 10 +- cpp/src/distance/neighbors/ivfpq_build.cu | 14 - .../distance/neighbors/ivfpq_deserialize.cu | 29 ++ .../neighbors/ivfpq_search_float_uint64_t.cu | 42 ++ .../neighbors/ivfpq_search_int8_t_uint64_t.cu | 42 ++ ...ch.cu => ivfpq_search_uint8_t_uint64_t.cu} | 2 - cpp/src/distance/neighbors/ivfpq_serialize.cu | 29 ++ .../neighbors/refine_h_uint64_t_float.cu | 1 + .../neighbors/refine_h_uint64_t_int8_t.cu | 1 + .../ball_cover_all_knn_query.cu | 41 ++ ...all_cover.cu => ball_cover_build_index.cu} | 20 - .../specializations/ball_cover_knn_query.cu | 43 ++ .../brute_force_knn_long_float_int.cu | 41 ++ .../brute_force_knn_long_float_uint.cu | 41 ++ .../brute_force_knn_uint32_t_float_int.cu | 40 ++ .../brute_force_knn_uint32_t_float_uint.cu | 41 ++ .../detail/ball_cover_lowdim_pass_one_3d.cu | 13 - cpp/src/nn/specializations/knn.cu | 86 ---- cpp/test/CMakeLists.txt | 2 - cpp/test/distance/fused_l2_nn.cu | 2 +- cpp/test/distance/masked_nn.cu | 435 ------------------ .../distance/masked_nn_compress_to_bits.cu | 216 --------- cpp/test/neighbors/refine.cu | 2 +- docs/source/cpp_api/distance.rst | 1 - docs/source/cpp_api/distance_masked_nn.rst | 16 - 42 files changed, 565 insertions(+), 2201 deletions(-) delete mode 100644 cpp/bench/distance/masked_nn.cu delete mode 100644 cpp/include/raft/distance/detail/compress_to_bits.cuh delete mode 100644 cpp/include/raft/distance/detail/masked_distance_base.cuh delete mode 100644 cpp/include/raft/distance/detail/masked_nn.cuh delete mode 100644 cpp/include/raft/distance/masked_nn.cuh create mode 100644 cpp/src/distance/neighbors/ivfpq_deserialize.cu create mode 100644 cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu create mode 100644 cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu rename cpp/src/distance/neighbors/{ivfpq_search.cu => ivfpq_search_uint8_t_uint64_t.cu} (96%) create mode 100644 cpp/src/distance/neighbors/ivfpq_serialize.cu create mode 100644 cpp/src/nn/specializations/ball_cover_all_knn_query.cu rename cpp/src/nn/specializations/{ball_cover.cu => ball_cover_build_index.cu} (70%) create mode 100644 cpp/src/nn/specializations/ball_cover_knn_query.cu create mode 100644 cpp/src/nn/specializations/brute_force_knn_long_float_int.cu 
create mode 100644 cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu create mode 100644 cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu create mode 100644 cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu delete mode 100644 cpp/src/nn/specializations/knn.cu delete mode 100644 cpp/test/distance/masked_nn.cu delete mode 100644 cpp/test/distance/masked_nn_compress_to_bits.cu delete mode 100644 docs/source/cpp_api/distance_masked_nn.rst diff --git a/conda/recipes/libraft/build_libraft_distance.sh b/conda/recipes/libraft/build_libraft_distance.sh index dca32b5238..d7e995fc03 100644 --- a/conda/recipes/libraft/build_libraft_distance.sh +++ b/conda/recipes/libraft/build_libraft_distance.sh @@ -1,4 +1,4 @@ #!/usr/bin/env bash # Copyright (c) 2022-2023, NVIDIA CORPORATION. -PARALLEL_LEVEL=8 ./build.sh libraft -v --allgpuarch --compile-dist --no-nvtx +./build.sh libraft -v --allgpuarch --compile-dist --no-nvtx diff --git a/conda/recipes/libraft/build_libraft_nn.sh b/conda/recipes/libraft/build_libraft_nn.sh index 1d82e902a2..9865922cd0 100644 --- a/conda/recipes/libraft/build_libraft_nn.sh +++ b/conda/recipes/libraft/build_libraft_nn.sh @@ -1,4 +1,4 @@ #!/usr/bin/env bash # Copyright (c) 2022-2023, NVIDIA CORPORATION. -PARALLEL_LEVEL=8 ./build.sh libraft -v --allgpuarch --compile-nn --no-nvtx +./build.sh libraft -v --allgpuarch --compile-nn --no-nvtx diff --git a/conda/recipes/libraft/build_libraft_tests.sh b/conda/recipes/libraft/build_libraft_tests.sh index dc2fed2e6b..6adbbe78e1 100644 --- a/conda/recipes/libraft/build_libraft_tests.sh +++ b/conda/recipes/libraft/build_libraft_tests.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash # Copyright (c) 2022-2023, NVIDIA CORPORATION. -PARALLEL_LEVEL=8 ./build.sh tests bench -v --allgpuarch --no-nvtx +./build.sh tests bench -v --allgpuarch --no-nvtx cmake --install cpp/build --component testing diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 228e153f40..1d54409ae6 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -305,7 +305,6 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/neighbors/specializations/refine_h_uint64_t_float.cu src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu - src/distance/neighbors/ivfpq_search.cu src/distance/cluster/kmeans_fit_float.cu src/distance/cluster/kmeans_fit_double.cu src/distance/distance/specializations/detail/canberra.cu @@ -356,6 +355,11 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/distance/specializations/fused_l2_nn_float_int.cu src/distance/distance/specializations/fused_l2_nn_float_int64.cu src/distance/neighbors/ivfpq_build.cu + src/distance/neighbors/ivfpq_deserialize.cu + src/distance/neighbors/ivfpq_serialize.cu + src/distance/neighbors/ivfpq_search_float_uint64_t.cu + src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu + src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_fast.cu src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu @@ -427,16 +431,21 @@ set_target_properties(raft_nn PROPERTIES EXPORT_NAME nn) if(RAFT_COMPILE_NN_LIBRARY) add_library( raft_nn_lib - src/nn/specializations/ball_cover.cu src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu 
src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu + src/nn/specializations/ball_cover_all_knn_query.cu + src/nn/specializations/ball_cover_build_index.cu + src/nn/specializations/ball_cover_knn_query.cu src/nn/specializations/fused_l2_knn_long_float_true.cu src/nn/specializations/fused_l2_knn_long_float_false.cu src/nn/specializations/fused_l2_knn_int_float_true.cu src/nn/specializations/fused_l2_knn_int_float_false.cu - src/nn/specializations/knn.cu + src/nn/specializations/brute_force_knn_long_float_int.cu + src/nn/specializations/brute_force_knn_long_float_uint.cu + src/nn/specializations/brute_force_knn_uint32_t_float_int.cu + src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu ) set_target_properties( raft_nn_lib diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 1bc2c86243..b1ffc72ba9 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -82,7 +82,6 @@ if(BUILD_BENCH) bench/distance/distance_l1.cu bench/distance/distance_unexp_l2.cu bench/distance/fused_l2_nn.cu - bench/distance/masked_nn.cu bench/distance/kernels.cu bench/main.cpp OPTIONAL diff --git a/cpp/bench/distance/distance_common.cuh b/cpp/bench/distance/distance_common.cuh index 1be00ec0c7..7ddecd7579 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,14 +24,14 @@ namespace raft::bench::distance { -struct distance_params { +struct distance_inputs { int m, n, k; bool isRowMajor; -}; // struct distance_params +}; // struct distance_inputs template struct distance : public fixture { - distance(const distance_params& p) + distance(const distance_inputs& p) : params(p), x(p.m * p.k, stream), y(p.n * p.k, stream), @@ -63,13 +63,13 @@ struct distance : public fixture { } private: - distance_params params; + distance_inputs params; rmm::device_uvector x, y, out; rmm::device_uvector workspace; size_t worksize; }; // struct Distance -const std::vector dist_input_vecs{ +const std::vector dist_input_vecs{ {32, 16384, 16384, true}, {64, 16384, 16384, true}, {128, 16384, 16384, true}, {256, 16384, 16384, true}, {512, 16384, 16384, true}, {1024, 16384, 16384, true}, {16384, 32, 16384, true}, {16384, 64, 16384, true}, {16384, 128, 16384, true}, diff --git a/cpp/bench/distance/masked_nn.cu b/cpp/bench/distance/masked_nn.cu deleted file mode 100644 index 3677d44864..0000000000 --- a/cpp/bench/distance/masked_nn.cu +++ /dev/null @@ -1,267 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined RAFT_NN_COMPILED -#include -#endif - -namespace raft::bench::distance::masked_nn { - -// Introduce various sparsity patterns -enum AdjacencyPattern { - checkerboard = 0, - checkerboard_4 = 1, - checkerboard_64 = 2, - all_true = 3, - all_false = 4 -}; - -struct Params { - int m, n, k, num_groups; - AdjacencyPattern pattern; -}; // struct Params - -__global__ void init_adj(AdjacencyPattern pattern, - int n, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs) -{ - int m = adj.extent(0); - int num_groups = adj.extent(1); - - for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; - idx_m += blockDim.y * gridDim.y) { - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; - idx_g += blockDim.x * gridDim.x) { - switch (pattern) { - case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; - case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; - case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break; - case all_true: adj(idx_m, idx_g) = true; break; - case all_false: adj(idx_m, idx_g) = false; break; - default: assert(false && "unknown pattern"); - } - } - } - // Each group is of size n / num_groups. - // - // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive - // scan of the group lengths) - // - // - The first group always starts at index zero, so we do not store it. - // - // - The group_idxs[num_groups - 1] should always equal n. - - if (blockIdx.y == 0 && threadIdx.y == 0) { - const int g_stride = blockDim.x * gridDim.x; - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { - group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); - } - group_idxs(num_groups - 1) = n; - } -} - -template -struct masked_l2_nn : public fixture { - using DataT = T; - using IdxT = int; - using OutT = raft::KeyValuePair; - using RedOpT = raft::distance::MinAndDistanceReduceOp; - using PairRedOpT = raft::distance::KVPMinReduce; - using ParamT = raft::distance::MaskedL2NNParams; - - // Parameters - Params params; - // Data - raft::device_vector out; - raft::device_matrix x, y; - raft::device_vector xn, yn; - raft::device_matrix adj; - raft::device_vector group_idxs; - - masked_l2_nn(const Params& p) - : params(p), - out{raft::make_device_vector(handle, p.m)}, - x{raft::make_device_matrix(handle, p.m, p.k)}, - y{raft::make_device_matrix(handle, p.n, p.k)}, - xn{raft::make_device_vector(handle, p.m)}, - yn{raft::make_device_vector(handle, p.n)}, - adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, - group_idxs{raft::make_device_vector(handle, p.num_groups)} - { - raft::random::RngState r(123456ULL); - - uniform(handle, r, x.data_handle(), p.m * p.k, T(-1.0), T(1.0)); - uniform(handle, r, y.data_handle(), p.n * p.k, T(-1.0), T(1.0)); - raft::linalg::rowNorm( - xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream); - raft::linalg::rowNorm( - yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream); - raft::distance::initialize, int>( - handle, out.data_handle(), p.m, std::numeric_limits::max(), RedOpT{}); - - dim3 block(32, 32); - dim3 grid(10, 10); - init_adj<<>>(p.pattern, p.n, adj.view(), group_idxs.view()); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - void run_benchmark(::benchmark::State& state) override - { - bool init_out = true; - 
bool sqrt = false; - ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; - - loop_on_state(state, [this, masked_l2_params]() { - // It is sufficient to only benchmark the L2-squared metric - raft::distance::maskedL2NN(handle, - masked_l2_params, - x.view(), - y.view(), - xn.view(), - yn.view(), - adj.view(), - group_idxs.view(), - out.view()); - }); - - // Virtual flop count if no skipping had occurred. - size_t virtual_flops = size_t(2) * size_t(params.m) * size_t(params.n) * size_t(params.k); - - int64_t read_elts = params.n * params.k + params.m * params.k; - int64_t write_elts = params.m; - - // Virtual min flops is the number of flops that would have been executed if - // the algorithm had actually skipped each computation that it could have - // skipped. - size_t virtual_min_flops = 0; - switch (params.pattern) { - case checkerboard: - case checkerboard_4: - case checkerboard_64: virtual_min_flops = virtual_flops / 2; break; - case all_true: virtual_min_flops = virtual_flops; break; - case all_false: virtual_min_flops = 0; break; - default: assert(false && "unknown pattern"); - } - - // VFLOP/s is the "virtual" flop count that would have executed if there was - // no adjacency pattern. This is useful for comparing to fusedL2NN - state.counters["VFLOP/s"] = benchmark::Counter(virtual_flops, - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - // Virtual min flops is the number of flops that would have been executed if - // the algorithm had actually skipped each computation that it could have - // skipped. - state.counters["VminFLOP/s"] = benchmark::Counter(virtual_min_flops, - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - - state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(OutT), - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(DataT), - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - - state.counters["m"] = benchmark::Counter(params.m); - state.counters["n"] = benchmark::Counter(params.n); - state.counters["k"] = benchmark::Counter(params.k); - state.counters["num_groups"] = benchmark::Counter(params.num_groups); - state.counters["group size"] = benchmark::Counter(params.n / params.num_groups); - state.counters["Pat"] = benchmark::Counter(static_cast(params.pattern)); - - state.counters["SM count"] = raft::getMultiProcessorCount(); - } -}; // struct MaskedL2NN - -const std::vector masked_l2_nn_input_vecs = { - // Very fat matrices... - {32, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {64, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {128, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {256, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {512, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {1024, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 32, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 64, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 128, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 256, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 512, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 1024, 16384, 32, AdjacencyPattern::checkerboard}, - - // Representative matrices... 
- {16384, 16384, 32, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 64, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 128, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 256, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 512, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - - {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_4}, - - {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_64}, - - {16384, 16384, 32, 32, AdjacencyPattern::all_true}, - {16384, 16384, 64, 32, AdjacencyPattern::all_true}, - {16384, 16384, 128, 32, AdjacencyPattern::all_true}, - {16384, 16384, 256, 32, AdjacencyPattern::all_true}, - {16384, 16384, 512, 32, AdjacencyPattern::all_true}, - {16384, 16384, 1024, 32, AdjacencyPattern::all_true}, - {16384, 16384, 16384, 32, AdjacencyPattern::all_true}, - - {16384, 16384, 32, 32, AdjacencyPattern::all_false}, - {16384, 16384, 64, 32, AdjacencyPattern::all_false}, - {16384, 16384, 128, 32, AdjacencyPattern::all_false}, - {16384, 16384, 256, 32, AdjacencyPattern::all_false}, - {16384, 16384, 512, 32, AdjacencyPattern::all_false}, - {16384, 16384, 1024, 32, AdjacencyPattern::all_false}, - {16384, 16384, 16384, 32, AdjacencyPattern::all_false}, -}; - -RAFT_BENCH_REGISTER(masked_l2_nn, "", masked_l2_nn_input_vecs); -// We don't benchmark double to keep compile times in check when not using the -// distance library. 
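For clarity, the group layout that the init_adj kernel above writes can be reproduced on the host; a minimal sketch assuming num_groups equally sized groups over n points (the helper name is illustrative only, not part of the benchmark):

#include <vector>

// group_idxs[j] holds the *end* index of group j, i.e. the inclusive scan of
// the group lengths; the last entry is forced to n, exactly as init_adj above
// writes it on the device.
std::vector<int> make_group_end_indices(int n, int num_groups)
{
  std::vector<int> group_idxs(num_groups);
  for (int j = 0; j < num_groups; ++j) {
    group_idxs[j] = (j + 1) * (n / num_groups);
  }
  group_idxs[num_groups - 1] = n;
  return group_idxs;
}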
- -} // namespace raft::bench::distance::masked_nn diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index eec1cba99e..633ea33670 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -33,8 +33,6 @@ #if defined RAFT_DISTANCE_COMPILED #include #include -#else -#pragma message("NN / Distance specializations are not enabled; expect very long building times.") #endif #endif diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu index f32af3a57e..255004361c 100644 --- a/cpp/bench/neighbors/refine.cu +++ b/cpp/bench/neighbors/refine.cu @@ -27,7 +27,6 @@ #if defined RAFT_DISTANCE_COMPILED #include -#include #endif #if defined RAFT_NN_COMPILED @@ -114,9 +113,9 @@ std::vector> getInputs() return out; } -using refine_float_uint64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_float_uint64, "", getInputs()); +using refine_float_int64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); -using refine_uint8_uint64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_uint8_uint64, "", getInputs()); +using refine_uint8_int64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); } // namespace raft::bench::neighbors diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh deleted file mode 100644 index e36b7ce707..0000000000 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include - -namespace raft::distance::detail { - -/** - * @brief Compress 2D boolean matrix to bitfield - * - * Utility kernel for maskedL2NN. - * - * @tparam T - * - * @parameter[in] in An `m x n` boolean matrix. Row major. - * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of - * type T, where T is of size `bits_per_elem` bits. - * Note: the division (`/`) is a ceilDiv. 
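Concretely, bit j of out(r, c) stores in(r * bits_per_elem + j, c). A host-side reference sketch of that layout (a hypothetical helper, not part of this header, assuming the uint64_t element type used by the masked L2 path):

#include <cstdint>
#include <vector>

// Host-side reference of the layout compress_to_bits produces: bit j of
// out[r][c] holds in[r * 64 + j][c]; rows past the end of `in` count as false.
std::vector<std::vector<std::uint64_t>> compress_to_bits_host(
  const std::vector<std::vector<bool>>& in)
{
  const int m        = static_cast<int>(in.size());
  const int n        = m == 0 ? 0 : static_cast<int>(in[0].size());
  const int out_rows = (m + 63) / 64;  // ceildiv(m, bits_per_elem)
  std::vector<std::vector<std::uint64_t>> out(out_rows,
                                              std::vector<std::uint64_t>(n, 0));
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      if (in[i][j]) { out[i / 64][j] |= std::uint64_t(1) << (i % 64); }
    }
  }
  return out;
}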
- */ -template ::value>> -__global__ void compress_to_bits_kernel( - raft::device_matrix_view in, - raft::device_matrix_view out) -{ - constexpr int bits_per_element = 8 * sizeof(T); - constexpr int tile_dim_m = bits_per_element; - constexpr int nthreads = 128; - constexpr int tile_dim_n = nthreads; // read 128 bools at once = 1 sector - - // Tile in shared memory is transposed - __shared__ bool smem[tile_dim_n][tile_dim_m]; - - const int num_tiles_per_m = raft::ceildiv(in.extent(0), tile_dim_m); - const int num_tiles_per_n = raft::ceildiv(in.extent(1), tile_dim_n); - - for (int lin_tile_idx = blockIdx.x; true; lin_tile_idx += gridDim.x) { - const int tile_idx_n = tile_dim_n * (lin_tile_idx % num_tiles_per_n); - const int tile_idx_m = tile_dim_m * (lin_tile_idx / num_tiles_per_n); - - if (in.extent(0) <= tile_idx_m) { break; } - // Fill shared memory tile - bool reg_buf[tile_dim_m]; -#pragma unroll - for (int i = 0; i < tile_dim_m; ++i) { - const int in_m = tile_idx_m + i; - const int in_n = tile_idx_n + threadIdx.x; - bool in_bounds = in_m < in.extent(0) && in_n < in.extent(1); - reg_buf[i] = in_bounds ? in(in_m, in_n) : false; - smem[threadIdx.x][i] = reg_buf[i]; - } - __syncthreads(); - - // Drain memory tile into single output element out_elem. - T out_elem{0}; -#pragma unroll - for (int j = 0; j < tile_dim_n; ++j) { - if (smem[threadIdx.x][j]) { out_elem |= T(1) << j; } - } - __syncthreads(); - - // Write output. - int out_m = tile_idx_m / bits_per_element; - int out_n = tile_idx_n + threadIdx.x; - - if (out_m < out.extent(0) && out_n < out.extent(1)) { out(out_m, out_n) = out_elem; } - } -} - -/** - * @brief Compress 2D boolean matrix to bitfield - * - * Utility kernel for maskedL2NN. - * - * @tparam T - * - * @parameter[in] in An `m x n` boolean matrix. Row major. - * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of - * type T, where T is of size `bits_per_elem` bits. - * Note: the division (`/`) is a ceilDiv. - */ -template ::value>> -void compress_to_bits(raft::device_resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out) -{ - auto stream = handle.get_stream(); - constexpr int bits_per_element = 8 * sizeof(T); - - RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0), - "Number of output rows must be ceildiv(input rows, bits_per_elem)"); - RAFT_EXPECTS(in.extent(1) == out.extent(1), "Number of output columns must equal input columns."); - - const int num_SMs = raft::getMultiProcessorCount(); - int blocks_per_sm = 0; - constexpr int num_threads = 128; - constexpr int dyn_smem_size = 0; - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, compress_to_bits_kernel, num_threads, dyn_smem_size)); - - dim3 grid(num_SMs * blocks_per_sm); - dim3 block(128); - compress_to_bits_kernel<<>>(in, out); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 8fbd7a9c69..447359ffe6 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -37,7 +37,6 @@ template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? 
b : a; } }; // KVPMinReduce diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh deleted file mode 100644 index 6d4e3f40a6..0000000000 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include - -#include - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief Device class for masked nearest neighbor computations. - * - * @tparam useNorms whether norms are needed - * @tparam DataT input data-type (for x and y matrices) - * @tparam AccT accumulation data-type - * @tparam IdxT index data-type - * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda tells how to accumulate an x and y into - acc. its signature: - template void core_lambda(AccT& acc, - const DataT& x, const DataT& y) - * @tparam EpilogueLambda applies an elementwise function to compute final - values. Its signature is: - template void epilogue_lambda - (AccT acc[][], DataT* regxn, DataT* regyn); - * @tparam FinalLambda the final lambda called on final distance value - * @tparam rowEpilogueLambda epilog lambda that executes when a full row has - * been processed. - * - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of x - * @param[in] n number of columns of y - * @param[in] k number of cols of x and y - * @param[in] lda leading dimension of x - * @param[in] ldb leading dimension of y - * @param[in] ldd parameter to keep Contractions_NT happy.. - * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine - * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine - * @param[in] adj An adjacency matrix encoded as a bitfield indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `(m / 64) x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[in] num_groups The number of groups in group_idxs. - * @param[in] smem shared mem buffer for intermediate storage of x, y, xn & yn. - * @param core_op the core accumulation operation lambda - * @param epilog_op the epilog operation lambda - * @param fin_op the final gemm epilogue lambda - * @param rowEpilog_op epilog lambda that executes when a full row has been processed. 
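A minimal example of a conforming core_op, mirroring the inner-product accumulation that the masked L2 path later in this patch uses (the row norms are folded back in by the epilogue); the functor name is illustrative, and the real code passes an equivalent __device__ lambda:

// A core_op of the shape documented above: plain inner-product accumulation.
// The expanded L2 distance ||x||^2 + ||y||^2 - 2 * acc is formed afterwards in
// the epilogue operation.
struct l2_expanded_core_op {
  __device__ void operator()(float& acc, const float& x, const float& y) const
  {
    acc += x * y;
  }
};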
- */ -template > -struct MaskedDistances : public BaseClass { - private: - typedef Policy P; - const DataT* xn; - const DataT* yn; - const DataT* const yBase; - const uint64_t* adj; - const IdxT* group_idxs; - IdxT num_groups; - char* smem; - CoreLambda core_op; - EpilogueLambda epilog_op; - FinalLambda fin_op; - rowEpilogueLambda rowEpilog_op; - - AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; - - public: - // Constructor - DI MaskedDistances(const DataT* _x, - const DataT* _y, - IdxT _m, - IdxT _n, - IdxT _k, - IdxT _lda, - IdxT _ldb, - IdxT _ldd, - const DataT* _xn, - const DataT* _yn, - const uint64_t* _adj, - const IdxT* _group_idxs, - IdxT _num_groups, - char* _smem, - CoreLambda _core_op, - EpilogueLambda _epilog_op, - FinalLambda _fin_op, - rowEpilogueLambda _rowEpilog_op) - : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), - xn(_xn), - yn(_yn), - yBase(_y), - adj(_adj), - group_idxs(_group_idxs), - num_groups(_num_groups), - smem(_smem), - core_op(_core_op), - epilog_op(_epilog_op), - fin_op(_fin_op), - rowEpilog_op(_rowEpilog_op) - { - } - - DI void run() - { - const auto grid_stride_m = (P::Mblk * gridDim.y); - const auto grid_offset_m = (P::Mblk * blockIdx.y); - - const auto grid_stride_g = gridDim.x; - const auto grid_offset_g = blockIdx.x; - - for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { - // Start loop over groups - for (auto idx_g = grid_offset_g; idx_g < this->num_groups; idx_g += grid_stride_g) { - const uint64_t block_adj = get_block_adjacency(adj, tile_idx_m, idx_g); - // block_adj is a bitfield that contains a 1 if a row is adjacent to the - // current group. All zero means we can skip this group. - if (block_adj == 0) { continue; } - - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). That is, - // for i = 0,.., AccRowsPerTh and j = 0,.., AccColsPerTh: - // - // ((1 << i) & thread_adj) > 0 <=> acc[i][j] must be computed. - // - // We precompute this information because it is used in various - // locations to skip thread-local computations, specifically: - // - // 1. To skip computations if thread_adj == 0, i.e., none of the values - // of `acc` have to be computed. - // - // 2. In epilog_op, to consider only values of `acc` to be reduced that - // are not masked of. - // - // Note 1: Even when the computation can be skipped for a specific thread, - // the thread still participates in synchronization operations. - // - // Note 2: In theory, it should be possible to skip computations for - // specific rows of `acc`. In practice, however, this does not improve - // performance. - int thread_adj = compute_thread_adjacency(block_adj); - - auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; - const auto group_end_n = group_idxs[idx_g]; - for (; tile_idx_n < group_end_n; tile_idx_n += P::Nblk) { - // We provide group_end_n to limit the number of unnecessary data - // points that are loaded from y. - this->ldgXY(tile_idx_m, tile_idx_n, 0, group_end_n); - - reset_accumulator(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx, group_end_n); - // Process all data in shared memory (previous k-block) and - // accumulate in registers. 
- if (thread_adj != 0) { accumulate(); } - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); - } - if (thread_adj != 0) { - accumulate(); // last iteration - } - // The pre-condition for the loop over tile_idx_n is that write_buffer - // and read_buffer point to the same buffer. This flips read_buffer - // back so that it satisfies the pre-condition of this loop. - this->switch_read_buffer(); - - if (useNorms) { - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; - load_norms(tile_idx_m, tile_idx_n, group_end_n, regxn, regyn); - if (thread_adj != 0) { - epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, group_end_n); - } - } else { - if (thread_adj != 0) { - epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, group_end_n); - } - } - } // tile_idx_n - } // idx_g - rowEpilog_op(tile_idx_m); - } // tile_idx_m - } - - private: - DI uint64_t get_block_adjacency(const uint64_t* adj, IdxT tile_idx_m, IdxT idx_group) - { - // A single element of `adj` contains exactly enough bits to indicate which - // rows in the current tile to skip and which to compute. - static_assert(P::Mblk == 8 * sizeof(adj[0]), - "maskedL2NN only supports a policy with 64 rows per block."); - IdxT block_flag_idx = tile_idx_m / P::Mblk; - // Index into adj at row tile_idx_m / 64 and column idx_group. - return adj[block_flag_idx * this->num_groups + idx_group]; - } - - DI uint32_t compute_thread_adjacency(const uint64_t block_adj) - { - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). It is described in - // more detail in the run() method. - uint32_t thread_adj = 0; -#pragma unroll - for (int thread_row_idx = 0; thread_row_idx < P::AccRowsPerTh; ++thread_row_idx) { - // Index `thread_row_idx` refers to a row of the current threads' register - // tile `acc`, i.e., acc[i][:]. Index `block_row_idx` refers to the - // corresponding row of the current block tile in shared memory. - const int block_row_idx = this->accrowid + thread_row_idx * P::AccThRows; - - // block_row_is_adjacent is true if the current block_row_idx is adjacent - // to the current group. - const uint64_t block_mask = 1ull << block_row_idx; - const bool block_row_is_adjacent = (block_adj & block_mask) != 0; - if (block_row_is_adjacent) { - // If block row is adjacent, write a 1 bit to thread_adj at location - // `thread_row_idx`. - const uint32_t thread_mask = 1 << thread_row_idx; - thread_adj |= thread_mask; - } - } - return thread_adj; - } - - DI void reset_accumulator() - { - // Reset accumulator registers to zero. 
-#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; - } - } - } - - DI void accumulate() - { -#pragma unroll - for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { - this->ldsXY(ki); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { - core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); - } - } - } - } - } - - DI void load_norms(IdxT tile_idx_m, - IdxT tile_idx_n, - IdxT end_n, - DataT (®xn)[P::AccRowsPerTh], - DataT (®yn)[P::AccColsPerTh]) - { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = tile_idx_n + i; - syNorm[i] = idx < end_n ? yn[idx] : 0; - } - __syncthreads(); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } -#pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - } -}; // struct MaskedDistances - -}; // namespace detail -}; // namespace distance -}; // namespace raft diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh deleted file mode 100644 index 1c92de16fc..0000000000 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ /dev/null @@ -1,325 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace distance { -namespace detail { - -template -__global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const uint64_t* adj, - const IdxT* group_idxs, - IdxT num_groups, - IdxT m, - IdxT n, - IdxT k, - bool sqrt, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - CoreLambda core_op, - FinalLambda fin_op) -{ - extern __shared__ char smem[]; - - typedef raft::KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; - } - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [pairRedOp, &val, maxVal, sqrt] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - int thread_adj, - DataT* regxn, - DataT* regyn, - IdxT tile_idx_n, - IdxT tile_idx_m, - IdxT tile_end_n) { - KVPReduceOpT pairRed_op(pairRedOp); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(acc[i][j]); - } - } - } - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). It is described in - // more detail in the maskedDistances.run() method. - const bool ignore = (thread_adj & (1 << i)) == 0; - if (ignore) { continue; } -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + tile_idx_n; - if (tile_end_n <= tmpkey) { - // Do not process beyond end of tile. - continue; - } - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < tile_end_n) { - val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); - } - } - } - }; - - auto rowEpilog_lambda = - [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT tile_idx_m) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - auto tmpkey = raft::shfl(val[i].key, lid + j); - auto tmpvalue = raft::shfl(val[i].value, lid + j); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); - } - } - - updateReducedVal(mutex, min, val, red_op, m, tile_idx_m); - - // reset the val array. 
-#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; - } - }; - - IdxT lda = k, ldb = k, ldd = n; - MaskedDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - xn, - yn, - adj, - group_idxs, - num_groups, - smem, - core_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -} - -/** - * @brief Wrapper for maskedL2NNkernel - * - * Responsibilities: - * - Allocate (and initialize) workspace memory for: - * - mutexes used in nearest neighbor update step - * - adjacency matrix bitfield - * - Compress adjacency matrix to bitfield - * - Initialize output buffer (conditional on `initOutBuffer`) - * - Specify core and final operations for the L2 norm - * - Determine optimal launch configuration for kernel. - * - Launch kernel and check for errors. - * - * @tparam DataT Input data-type (for x and y matrices). - * @tparam OutT Output data-type (for key-value pairs). - * @tparam IdxT Index data-type. - * @tparam ReduceOpT A struct to perform the final needed reduction - * operation and also to initialize the output array - * elements with the appropriate initial value needed for - * reduction. - * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. - * - * @param handle RAFT handle for managing expensive resources - * @param[out] out Will contain reduced output (nn key-value pairs) - * @param[in] x First matrix. Row major. Dim = `m x k`. (on device) - * @param[in] y Second matrix. Row major. Dim = `n x k`. (on device) - * @param[in] xn L2 squared norm of `x`. Length = `m`. - * @param[in] yn L2 squared norm of `y`. Length = `n`. - * @param[in] adj A boolean adjacency matrix indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `m x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[in] num_groups Length of `group_idxs`. - * @param m Rows of `x`. - * @param n Rows of `y`. - * @param k Cols of `x` and `y`. - * @param redOp Reduction operator in the epilogue - * @param pairRedOp Reduction operation on key value pairs - * @param sqrt Whether to compute the squared or actual (i.e. sqrt) L2 norm. 
- * @param initOutBuffer Whether to initialize the output buffer - * - * - */ -template -void maskedL2NNImpl(raft::device_resources const& handle, - OutT* out, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const bool* adj, - const IdxT* group_idxs, - IdxT num_groups, - IdxT m, - IdxT n, - IdxT k, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer) -{ - typedef typename linalg::Policy4x4::Policy P; - - static_assert(P::Mblk == 64, "maskedL2NNImpl only supports a policy with 64 rows per block."); - - // Get stream and workspace memory resource - rmm::mr::device_memory_resource* ws_mr = - dynamic_cast(handle.get_workspace_resource()); - auto stream = handle.get_stream(); - - // Acquire temporary buffers and initialize to zero: - // 1) Adjacency matrix bitfield - // 2) Workspace for fused nearest neighbor operation - size_t m_div_64 = raft::ceildiv(m, IdxT(64)); - rmm::device_uvector ws_adj64{m_div_64 * num_groups, stream, ws_mr}; - rmm::device_uvector ws_fused_nn{size_t(m), stream, ws_mr}; - RAFT_CUDA_TRY(cudaMemsetAsync(ws_adj64.data(), 0, ws_adj64.size() * sizeof(uint64_t), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); - - // Compress boolean adjacency matrix to bitfield. - auto adj_view = raft::make_device_matrix_view(adj, m, num_groups); - auto adj64_view = - raft::make_device_matrix_view(ws_adj64.data(), m_div_64, num_groups); - compress_to_bits(handle, adj_view, adj64_view); - - // Initialize output buffer with keyvalue pairs as determined by the reduction - // operator (it will be called with maxVal). - constexpr auto maxVal = std::numeric_limits::max(); - if (initOutBuffer) { - dim3 grid(raft::ceildiv(m, P::Nthreads)); - dim3 block(P::Nthreads); - - initKernel<<>>(out, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - auto fin_op = raft::identity_op{}; - - auto kernel = maskedL2NNkernel; - constexpr size_t smemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - dim3 block(P::Nthreads); - dim3 grid = launchConfigGenerator
<P>
(m, n, smemSize, kernel); - - kernel<<>>(out, - x, - y, - xn, - yn, - ws_adj64.data(), - group_idxs, - num_groups, - m, - n, - k, - sqrt, - maxVal, - ws_fused_nn.data(), - redOp, - pairRedOp, - core_lambda, - fin_op); - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index d849b23999..445b4bac52 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -59,7 +59,6 @@ namespace detail { * @param core_op the core accumulation operation lambda * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda - * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ template m; tile_idx_m += grid_stride_m) { - this->ldgXY(tile_idx_m, grid_offset_n, 0); - for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { - // Prolog: - reset_accumulator(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - - // Main loop: - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx); - // Process all data in shared memory (previous k-block) and - // accumulate in registers. - accumulate(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); - } - accumulate(); // last iteration - // The pre-condition for the loop over tile_idx_n is that write_buffer - // and read_buffer point to the same buffer. This flips read_buffer back - // so that it satisfies the pre-condition of this loop. - this->switch_read_buffer(); - - // Epilog: - if (useNorms) { - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; - load_norms(tile_idx_m, tile_idx_n, regxn, regyn); - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_m, tile_idx_n); - epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); - } else { - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_m, tile_idx_n); - epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); - } - if (writeOut) { store_output(tile_idx_m, tile_idx_n); } + for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m; + gridStrideY += P::Mblk * gridDim.y) { + for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; + gridStrideX += P::Nblk * gridDim.x) { + prolog(gridStrideX, gridStrideY); + loop(); + epilog(gridStrideX, gridStrideY); } - rowEpilog_op(tile_idx_m); + rowEpilog_op(gridStrideY); } } private: - DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) + DI void updateIndicesY() + { + const auto stride = P::Nblk * gridDim.x; + if (isRowMajor) { + this->y += stride * this->ldb; + } else { + this->y += stride; + } + this->yrowid += stride; + } + + DI void updateIndicesXY() + { + const auto stride = P::Mblk * gridDim.y; + if (isRowMajor) { + this->x += stride * this->lda; + this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; + this->y = yBase + this->yrowid * this->ldb; + } else { + this->x += stride; + this->yrowid = IdxT(blockIdx.x) * P::Nblk; + this->y = yBase + this->yrowid + this->srowid * this->ldb; + } + this->xrowid += stride; + } + + DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) { // Fetch next grid stride ldg if within range - const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; - const auto next_tile_tile_idx_m = 
tile_idx_m + grid_stride_m; - if ((next_tile_tile_idx_n) < this->n) { - this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); - } else if ((next_tile_tile_idx_m) < this->m) { - this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); + if ((gridStrideX + gridDim.x * P::Nblk) < this->n) { + updateIndicesY(); + this->ldgXY(0); + } else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) { + updateIndicesXY(); + this->ldgXY(0); } } - DI void reset_accumulator() + DI void prolog(IdxT gridStrideX, IdxT gridStrideY) { - // Reset accumulator registers to zero. + if (gridStrideX == blockIdx.x * P::Nblk) { this->ldgXY(0); } + #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -199,6 +184,28 @@ struct PairwiseDistances : public BaseClass { acc[i][j] = BaseClass::Zero; } } + + this->stsXY(); + __syncthreads(); + this->pageWr ^= 1; + } + + DI void loop() + { + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(kidx); + accumulate(); // on the previous k-block + this->stsXY(); + __syncthreads(); + this->pageWr ^= 1; + this->pageRd ^= 1; + } + accumulate(); // last iteration + // This is needed for making sure next grid stride of + // non-norm based metrics uses previously accumulated buffer so + // it doesn't make shmem dirty until previous iteration + // is complete. + this->pageRd ^= 1; } DI void accumulate() @@ -219,52 +226,60 @@ struct PairwiseDistances : public BaseClass { } } - DI void load_norms(IdxT tile_idx_m, - IdxT tile_idx_n, - DataT (®xn)[P::AccRowsPerTh], - DataT (®yn)[P::AccColsPerTh]) + DI void epilog(IdxT gridStrideX, IdxT gridStrideY) { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (tile_idx_n == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; + if (useNorms) { + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (gridStrideX == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = gridStrideY + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; + } } - } - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = tile_idx_n + i; - syNorm[i] = idx < this->n ? yn[idx] : 0; - } - __syncthreads(); + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = gridStrideX + i; + syNorm[i] = idx < this->n ? 
yn[idx] : 0; + } + __syncthreads(); + + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } #pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; + } + + // Overlap ldg with epilog computation + ldgNextGridStride(gridStrideX, gridStrideY); + epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); + } else { + // Overlap ldg with epilog computation + ldgNextGridStride(gridStrideX, gridStrideY); + epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); } - } - DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) - { - IdxT starty = tile_idx_m + this->accrowid; - IdxT startx = tile_idx_n + this->acccolid; + if (writeOut) { + IdxT starty = gridStrideY + this->accrowid; + IdxT startx = gridStrideX + this->acccolid; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; #pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + // Promote to 64 bit index for final write, as output array can be > 2^31 + dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); + } } } } diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh deleted file mode 100644 index ea2e10a304..0000000000 --- a/cpp/include/raft/distance/masked_nn.cuh +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __MASKED_L2_NN_H -#define __MASKED_L2_NN_H - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace distance { -/** - * \defgroup masked_nn Masked 1-nearest neighbors - * @{ - */ - -/** - * @brief Parameter struct for maskedL2NN function - * - * @tparam ReduceOpT Type of reduction operator in the epilogue. - * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. 
- * - * Usage example: - * @code{.cpp} - * #include - * - * using IdxT = int; - * using DataT = float; - * using RedOpT = raft::distance::MinAndDistanceReduceOp; - * using PairRedOpT = raft::distance::KVPMinReduce; - * using ParamT = raft::distance::MaskedL2NNParams; - * - * bool init_out = true; - * bool sqrt = false; - * - * ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; - * @endcode - * - * Prescribes how to reduce a distance to an intermediate type (`redOp`), and - * how to reduce two intermediate types (`pairRedOp`). Typically, a distance is - * mapped to an (index, value) pair and (index, value) pair with the lowest - * value (distance) is selected. - * - * In addition, prescribes whether to compute the square root of the distance - * (`sqrt`) and whether to initialize the output buffer (`initOutBuffer`). - */ -template -struct MaskedL2NNParams { - /** Reduction operator in the epilogue */ - ReduceOpT redOp; - /** Reduction operation on key value pairs */ - KVPReduceOpT pairRedOp; - /** Whether the output `minDist` should contain L2-sqrt */ - bool sqrt; - /** Whether to initialize the output buffer before the main kernel launch */ - bool initOutBuffer; -}; - -/** - * @brief Masked L2 distance and 1-nearest-neighbor computation in a single call. - * - * This function enables faster computation of nearest neighbors if the - * computation of distances between certain point pairs can be skipped. - * - * We use an adjacency matrix that describes which distances to calculate. The - * points in `y` are divided into groups, and the adjacency matrix indicates - * whether to compute distances between points in `x` and groups in `y`. In other - * words, if `adj[i,k]` is true then distance between point `x_i`, and points in - * `group_k` will be calculated. - * - * **Performance considerations** - * - * The points in `x` are processed in tiles of `M` points (`M` is currently 64, - * but may change in the future). As a result, the largest compute time - * reduction occurs if all `M` points can skip a group. If only part of the `M` - * points can skip a group, then at most a minor compute time reduction and a - * modest energy use reduction can be expected. - * - * The points in `y` are also grouped into tiles of `N` points (`N` is currently - * 64, but may change in the future). As a result, group sizes should be larger - * than `N` to avoid wasting computational resources. If the group sizes are - * evenly divisible by `N`, then the computation is most efficient, although for - * larger group sizes this effect is minor. - * - * - * **Comparison to SDDM** - * - * [SDDMM](https://ieeexplore.ieee.org/document/8638042) (sampled dense-dense - * matrix multiplication) is a matrix-matrix multiplication where only part of - * the output is computed. Compared to maskedL2NN, there are a few differences: - * - * - The output of maskedL2NN is a single vector (of nearest neighbors) and not - * a sparse matrix. - * - * - The sampling in maskedL2NN is expressed through intermediate "groups" - rather than a CSR format. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. 
- * - * @param handle RAFT handle for managing expensive resources - * @param params Parameter struct specifying the reduction operations. - * @param[in] x First matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y Second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] x_norm L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] y_norm L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] adj A boolean adjacency matrix indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `m x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[out] out will contain the reduced output (Length = `m`) - * (on device) - */ -template -void maskedL2NN(raft::device_resources const& handle, - raft::distance::MaskedL2NNParams params, - raft::device_matrix_view x, - raft::device_matrix_view y, - raft::device_vector_view x_norm, - raft::device_vector_view y_norm, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs, - raft::device_vector_view out) -{ - IdxT m = x.extent(0); - IdxT n = y.extent(0); - IdxT k = x.extent(1); - IdxT num_groups = group_idxs.extent(0); - - // Match k dimension of x, y - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); - // Match x, x_norm and y, y_norm - RAFT_EXPECTS(m == x_norm.extent(0), "Length of `x_norm` must match input `x`."); - RAFT_EXPECTS(n == y_norm.extent(0), "Length of `y_norm` must match input `y` "); - // Match adj to x and group_idxs - RAFT_EXPECTS(m == adj.extent(0), "#rows in `adj` must match input `x`."); - RAFT_EXPECTS(num_groups == adj.extent(1), "#cols in `adj` must match length of `group_idxs`."); - // NOTE: We do not check if all indices in group_idxs actually points *inside* y. - - // If there is no work to be done, return immediately. 
- if (m == 0 || n == 0 || k == 0 || num_groups == 0) { return; } - - detail::maskedL2NNImpl(handle, - out.data_handle(), - x.data_handle(), - y.data_handle(), - x_norm.data_handle(), - y_norm.data_handle(), - adj.data_handle(), - group_idxs.data_handle(), - num_groups, - m, - n, - k, - params.redOp, - params.pairRedOp, - params.sqrt, - params.initOutBuffer); -} - -/** @} */ - -} // namespace distance -} // namespace raft - -#endif diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index b15cb222b4..e247f39bc7 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -40,10 +40,14 @@ struct Contractions_NT { /** leading dimension in Output D */ IdxT ldd; + /** current thread's global mem row id for X data */ + IdxT xrowid; + /** current thread's global mem row id for Y data */ + IdxT yrowid; /** global memory pointer to X matrix */ - const DataT* x_base; + const DataT* x; /** global memory pointer to Y matrix */ - const DataT* y_base; + const DataT* y; /** current thread's smem row id */ int srowid; @@ -90,8 +94,10 @@ struct Contractions_NT { k(_k), lda(_k), ldb(_k), - x_base(_x), - y_base(_y), + xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow), + yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow), + x(_x + xrowid * lda), + y(_y + yrowid * ldb), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -127,8 +133,6 @@ struct Contractions_NT { lda(_lda), ldb(_ldb), ldd(_ldd), - x_base(_x), - y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -138,6 +142,17 @@ struct Contractions_NT { pageWr(0), pageRd(0) { + if (isRowMajor) { + xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; + yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; + x = _x + xrowid * lda; + y = _y + yrowid * ldb; + } else { + xrowid = IdxT(blockIdx.y) * P::Mblk; + yrowid = IdxT(blockIdx.x) * P::Nblk; + x = _x + xrowid + srowid * lda; + y = _y + yrowid + srowid * ldb; + } } protected: @@ -145,16 +160,10 @@ struct Contractions_NT { * @brief Load current block of X/Y from global memory to registers * @param[in] kidx current start index of k to be loaded */ - DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) - { - ldgX(tile_idx_m, kidx); - ldgY(tile_idx_n, kidx); - } - - DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx, IdxT tile_end_n) + DI void ldgXY(IdxT kidx) { - ldgX(tile_idx_m, kidx); - ldgY(tile_idx_n, kidx, tile_end_n); + ldgX(kidx); + ldgY(kidx); } /** @@ -177,16 +186,9 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } - DI void switch_read_buffer() { this->pageRd ^= 1; } - - DI void switch_write_buffer() { this->pageWr ^= 1; } - private: - DI void ldgX(IdxT tile_idx_m, IdxT kidx) + DI void ldgX(IdxT kidx) { - IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; - auto x = isRowMajor ? x_base + xrowid * lda : x_base + xrowid + srowid * lda; - if (isRowMajor) { auto numRows = m; auto koffset = kidx + scolid; @@ -218,15 +220,10 @@ struct Contractions_NT { } } - DI void ldgY(IdxT tile_idx_n, IdxT kidx) { ldgY(tile_idx_n, kidx, n); } - - DI void ldgY(IdxT tile_idx_n, IdxT kidx, IdxT end_n) + DI void ldgY(IdxT kidx) { - IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; - auto y = isRowMajor ? 
y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; - if (isRowMajor) { - auto numRows = end_n; + auto numRows = n; auto koffset = kidx + scolid; #pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { @@ -244,7 +241,7 @@ struct Contractions_NT { auto koffset = scolid; #pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { - if ((koffset + yrowid) < end_n && (srowid + kidx + i * P::LdgRowsY) < numRows) { + if ((koffset + yrowid) < ldb && (srowid + kidx + i * P::LdgRowsY) < numRows) { ldg(ldgDataY[i], y + (kidx + i * P::LdgRowsY) * ldb + koffset); } else { #pragma unroll @@ -318,4 +315,4 @@ struct Contractions_NT { } // namespace detail } // namespace linalg -} // namespace raft +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index d17467c8a7..d6cdae1e68 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -20,8 +20,9 @@ #pragma once #include -#include #include #include -#include + +#include + #endif diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index 7616083796..e4843acee9 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -64,7 +64,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { private: DI void prolog() { - this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, 0); + this->ldgXY(0); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -74,18 +74,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { } this->stsXY(); __syncthreads(); - this->switch_write_buffer(); + this->pageWr ^= 1; } DI void loop() { for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, kidx); + this->ldgXY(kidx); accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); + this->pageWr ^= 1; + this->pageRd ^= 1; } accumulate(); // last iteration } diff --git a/cpp/src/distance/neighbors/ivfpq_build.cu b/cpp/src/distance/neighbors/ivfpq_build.cu index 31e304835b..7e6b12fb80 100644 --- a/cpp/src/distance/neighbors/ivfpq_build.cu +++ b/cpp/src/distance/neighbors/ivfpq_build.cu @@ -64,18 +64,4 @@ RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t); #undef RAFT_INST_BUILD_EXTEND -void serialize(raft::device_resources const& handle, - const std::string& filename, - const raft::neighbors::ivf_pq::index& index) -{ - raft::spatial::knn::ivf_pq::detail::serialize(handle, filename, index); -}; - -void deserialize(raft::device_resources const& handle, - const std::string& filename, - raft::neighbors::ivf_pq::index* index) -{ - if (!index) { RAFT_FAIL("Invalid index pointer"); } - *index = raft::spatial::knn::ivf_pq::detail::deserialize(handle, filename); -}; } // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/ivfpq_deserialize.cu b/cpp/src/distance/neighbors/ivfpq_deserialize.cu new file mode 100644 index 0000000000..e7ad77eef2 --- /dev/null +++ b/cpp/src/distance/neighbors/ivfpq_deserialize.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors::ivf_pq { + +void deserialize(raft::device_resources const& handle, + const std::string& filename, + raft::neighbors::ivf_pq::index* index) +{ + if (!index) { RAFT_FAIL("Invalid index pointer"); } + *index = raft::spatial::knn::ivf_pq::detail::deserialize(handle, filename); +}; +} // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu new file mode 100644 index 0000000000..c463aa9845 --- /dev/null +++ b/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft::runtime::neighbors::ivf_pq { + +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) \ + { \ + raft::neighbors::ivf_pq::search( \ + handle, params, idx, queries, n_queries, k, neighbors, distances, mr); \ + } + +RAFT_SEARCH_INST(float, uint64_t); + +#undef RAFT_INST_SEARCH + +} // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu new file mode 100644 index 0000000000..ab0dd576b9 --- /dev/null +++ b/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +namespace raft::runtime::neighbors::ivf_pq { + +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) \ + { \ + raft::neighbors::ivf_pq::search( \ + handle, params, idx, queries, n_queries, k, neighbors, distances, mr); \ + } + +RAFT_SEARCH_INST(int8_t, uint64_t); + +#undef RAFT_INST_SEARCH + +} // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/ivfpq_search.cu b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu similarity index 96% rename from cpp/src/distance/neighbors/ivfpq_search.cu rename to cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu index 05ab890ea5..2a745eb37d 100644 --- a/cpp/src/distance/neighbors/ivfpq_search.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu @@ -35,8 +35,6 @@ namespace raft::runtime::neighbors::ivf_pq { handle, params, idx, queries, n_queries, k, neighbors, distances, mr); \ } -RAFT_SEARCH_INST(float, uint64_t); -RAFT_SEARCH_INST(int8_t, uint64_t); RAFT_SEARCH_INST(uint8_t, uint64_t); #undef RAFT_INST_SEARCH diff --git a/cpp/src/distance/neighbors/ivfpq_serialize.cu b/cpp/src/distance/neighbors/ivfpq_serialize.cu new file mode 100644 index 0000000000..706c344993 --- /dev/null +++ b/cpp/src/distance/neighbors/ivfpq_serialize.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors::ivf_pq { + +void serialize(raft::device_resources const& handle, + const std::string& filename, + const raft::neighbors::ivf_pq::index& index) +{ + raft::spatial::knn::ivf_pq::detail::serialize(handle, filename, index); +}; + +} // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu index 2a2dcff3bf..8549d65dc5 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu @@ -1,3 +1,4 @@ + /* * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
* diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu index d7c60b62a5..cf6d7a397a 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu @@ -18,6 +18,7 @@ #include namespace raft::runtime::neighbors { + void refine(raft::device_resources const& handle, raft::host_matrix_view dataset, raft::host_matrix_view queries, diff --git a/cpp/src/nn/specializations/ball_cover_all_knn_query.cu b/cpp/src/nn/specializations/ball_cover_all_knn_query.cu new file mode 100644 index 0000000000..da5cd8de4f --- /dev/null +++ b/cpp/src/nn/specializations/ball_cover_all_knn_query.cu @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +// Ignore upstream specializations to avoid unnecessary recompiling +#ifdef RAFT_DISTANCE_COMPILED +#include +#endif + +#include +#include +#include + +#include + +namespace raft::neighbors::ball_cover { +template void all_knn_query( + raft::device_resources const& handle, + BallCoverIndex& index, + std::uint32_t k, + std::int64_t* inds, + float* dists, + bool perform_post_filtering, + float weight); + +}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/nn/specializations/ball_cover.cu b/cpp/src/nn/specializations/ball_cover_build_index.cu similarity index 70% rename from cpp/src/nn/specializations/ball_cover.cu rename to cpp/src/nn/specializations/ball_cover_build_index.cu index f37fda31af..70fcbec356 100644 --- a/cpp/src/nn/specializations/ball_cover.cu +++ b/cpp/src/nn/specializations/ball_cover_build_index.cu @@ -36,24 +36,4 @@ template void build_index( raft::device_resources const& handle, BallCoverIndex& index); -template void knn_query( - raft::device_resources const& handle, - const BallCoverIndex& index, - std::uint32_t k, - const float* query, - std::uint32_t n_query_pts, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -template void all_knn_query( - raft::device_resources const& handle, - BallCoverIndex& index, - std::uint32_t k, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - }; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/nn/specializations/ball_cover_knn_query.cu b/cpp/src/nn/specializations/ball_cover_knn_query.cu new file mode 100644 index 0000000000..d5ca1cbc1c --- /dev/null +++ b/cpp/src/nn/specializations/ball_cover_knn_query.cu @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +// Ignore upstream specializations to avoid unnecessary recompiling +#ifdef RAFT_DISTANCE_COMPILED +#include +#endif + +#include +#include +#include + +#include + +namespace raft::neighbors::ball_cover { +template void knn_query( + raft::device_resources const& handle, + const BallCoverIndex& index, + std::uint32_t k, + const float* query, + std::uint32_t n_query_pts, + std::int64_t* inds, + float* dists, + bool perform_post_filtering, + float weight); + +}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu new file mode 100644 index 0000000000..b08bcfbc79 --- /dev/null +++ b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft { +namespace spatial { +namespace knn { + +template void brute_force_knn(raft::device_resources const& handle, + std::vector& input, + std::vector& sizes, + int D, + float* search_items, + int n, + long* res_I, + float* res_D, + int k, + bool rowMajorIndex, + bool rowMajorQuery, + std::vector* translations, + distance::DistanceType metric, + float metric_arg); + +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu new file mode 100644 index 0000000000..78cb92bb38 --- /dev/null +++ b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +namespace raft { +namespace spatial { +namespace knn { + +template void brute_force_knn(raft::device_resources const& handle, + std::vector& input, + std::vector& sizes, + unsigned int D, + float* search_items, + unsigned int n, + long* res_I, + float* res_D, + unsigned int k, + bool rowMajorIndex, + bool rowMajorQuery, + std::vector* translations, + distance::DistanceType metric, + float metric_arg); + +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu new file mode 100644 index 0000000000..0082a30796 --- /dev/null +++ b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +template void brute_force_knn(raft::device_resources const& handle, + std::vector& input, + std::vector& sizes, + int D, + float* search_items, + int n, + uint32_t* res_I, + float* res_D, + int k, + bool rowMajorIndex, + bool rowMajorQuery, + std::vector* translations, + distance::DistanceType metric, + float metric_arg); + +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu new file mode 100644 index 0000000000..b2a1af2cf0 --- /dev/null +++ b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +namespace raft { +namespace spatial { +namespace knn { + +template void brute_force_knn(raft::device_resources const& handle, + std::vector& input, + std::vector& sizes, + unsigned int D, + float* search_items, + unsigned int n, + uint32_t* res_I, + float* res_D, + unsigned int k, + bool rowMajorIndex, + bool rowMajorQuery, + std::vector* translations, + distance::DistanceType metric, + float metric_arg); + +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu index a12d6548ed..1a1c17b29f 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu @@ -37,19 +37,6 @@ template void rbc_low_dim_pass_one( float weight, std::uint32_t* dists_counter); -template void rbc_low_dim_pass_two( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); }; // namespace detail }; // namespace knn }; // namespace spatial diff --git a/cpp/src/nn/specializations/knn.cu b/cpp/src/nn/specializations/knn.cu deleted file mode 100644 index d135610bfb..0000000000 --- a/cpp/src/nn/specializations/knn.cu +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - long* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - long* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - uint32_t* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - uint32_t* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 2e89418f8e..3c41621274 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -123,8 +123,6 @@ if(BUILD_TESTS) test/distance/dist_minkowski.cu test/distance/dist_russell_rao.cu test/distance/fused_l2_nn.cu - test/distance/masked_nn.cu - test/distance/masked_nn_compress_to_bits.cu test/distance/gram.cu OPTIONAL DIST diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index af67214193..8b9681b9d3 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -385,7 +385,7 @@ class FusedL2NNDetTest : public FusedL2NNTest { rmm::device_uvector> min1; - static const int NumRepeats = 3; + static const int NumRepeats = 100; void generateGoldenResult() override {} }; diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu deleted file mode 100644 index c80c984992..0000000000 --- a/cpp/test/distance/masked_nn.cu +++ /dev/null @@ -1,435 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::distance::masked_nn { - -// The adjacency pattern determines what distances get computed. 
-enum AdjacencyPattern { - checkerboard = 0, // adjacency matrix looks like a checkerboard (half the distances are computed) - checkerboard_4 = 1, // checkerboard with tiles of size 4x4 - checkerboard_64 = 2, // checkerboard with tiles of size 64x64 - all_true = 3, // no distance computations can be skipped - all_false = 4 // all distance computations can be skipped -}; - -// Kernels: -// - init_adj: to initialize the adjacency kernel with a specific adjacency pattern -// - referenceKernel: to produce the ground-truth output - -__global__ void init_adj(AdjacencyPattern pattern, - int n, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs) -{ - int m = adj.extent(0); - int num_groups = adj.extent(1); - - for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; - idx_m += blockDim.y * gridDim.y) { - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; - idx_g += blockDim.x * gridDim.x) { - switch (pattern) { - case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; - case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; - case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break; - case all_true: adj(idx_m, idx_g) = true; break; - case all_false: adj(idx_m, idx_g) = false; break; - default: assert(false && "unknown pattern"); - } - } - } - // Each group is of size n / num_groups. - // - // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive - // scan of the group lengths) - // - // - The first group always starts at index zero, so we do not store it. - // - // - The group_idxs[num_groups - 1] should always equal n. - - if (blockIdx.y == 0 && threadIdx.y == 0) { - const int g_stride = blockDim.x * gridDim.x; - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { - group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); - } - group_idxs(num_groups - 1) = n; - } -} - -template -__global__ __launch_bounds__(32 * NWARPS, - 2) void referenceKernel(raft::KeyValuePair* min, - DataT* x, - DataT* y, - bool* adj, - int* group_idxs, - int m, - int n, - int k, - int num_groups, - bool sqrt, - int* workspace, - DataT maxVal) -{ - const int m_stride = blockDim.y * gridDim.y; - const int m_offset = threadIdx.y + blockIdx.y * blockDim.y; - const int n_stride = blockDim.x * gridDim.x; - const int n_offset = threadIdx.x + blockIdx.x * blockDim.x; - - for (int m_grid = 0; m_grid < m; m_grid += m_stride) { - for (int n_grid = 0; n_grid < n; n_grid += n_stride) { - int midx = m_grid + m_offset; - int nidx = n_grid + n_offset; - - // Do a reverse linear search to determine the group index. - int group_idx = 0; - for (int i = num_groups; 0 <= i; --i) { - if (nidx < group_idxs[i]) { group_idx = i; } - } - const bool include_dist = adj[midx * num_groups + group_idx] && midx < m && nidx < n; - - // Compute L2 metric. - DataT acc = DataT(0); - for (int i = 0; i < k; ++i) { - int xidx = i + midx * k; - int yidx = i + nidx * k; - auto diff = x[xidx] - y[yidx]; - acc += diff * diff; - } - if (sqrt) { acc = raft::sqrt(acc); } - ReduceOpT redOp; - typedef cub::WarpReduce> WarpReduce; - __shared__ typename WarpReduce::TempStorage temp[NWARPS]; - int warpId = threadIdx.x / raft::WarpSize; - raft::KeyValuePair tmp; - tmp.key = include_dist ? nidx : -1; - tmp.value = include_dist ? 
acc : maxVal; - tmp = WarpReduce(temp[warpId]).Reduce(tmp, raft::distance::KVPMinReduce{}); - if (threadIdx.x % raft::WarpSize == 0 && midx < m) { - while (atomicCAS(workspace + midx, 0, 1) == 1) - ; - __threadfence(); - redOp(midx, min + midx, tmp); - __threadfence(); - atomicCAS(workspace + midx, 1, 0); - } - __syncthreads(); - } - } -} - -// Structs -// - Params: holds parameters for test case -// - Inputs: holds the inputs to the functions under test (x, y, adj, group_idxs). Is generated from -// the inputs. -struct Params { - double tolerance; - int m, n, k, num_groups; - bool sqrt; - unsigned long long int seed; - AdjacencyPattern pattern; -}; - -inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& -{ - os << "m: " << p.m << ", n: " << p.n << ", k: " << p.k << ", num_groups: " << p.num_groups - << ", sqrt: " << p.sqrt << ", seed: " << p.seed << ", tol: " << p.tolerance; - return os; -} - -template -struct Inputs { - using IdxT = int; - - raft::device_matrix x, y; - raft::device_matrix adj; - raft::device_vector group_idxs; - - Inputs(const raft::handle_t& handle, const Params& p) - : x{raft::make_device_matrix(handle, p.m, p.k)}, - y{raft::make_device_matrix(handle, p.n, p.k)}, - adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, - group_idxs{raft::make_device_vector(handle, p.num_groups)} - { - // Initialize x, y - raft::random::RngState r(p.seed); - uniform(handle, r, x.data_handle(), p.m * p.k, DataT(-1.0), DataT(1.0)); - uniform(handle, r, y.data_handle(), p.n * p.k, DataT(-1.0), DataT(1.0)); - - // Initialize adj, group_idxs. - dim3 block(32, 32); - dim3 grid(10, 10); - init_adj<<>>( - p.pattern, p.n, adj.view(), group_idxs.view()); - RAFT_CUDA_TRY(cudaGetLastError()); - } -}; - -template > -auto reference(const raft::handle_t& handle, Inputs inp, const Params& p) - -> raft::device_vector -{ - int m = inp.x.extent(0); - int n = inp.y.extent(0); - int k = inp.x.extent(1); - int num_groups = inp.group_idxs.extent(0); - - if (m == 0 || n == 0 || k == 0 || num_groups == 0) { - return raft::make_device_vector(handle, 0); - } - - // Initialize workspace - auto stream = handle.get_stream(); - rmm::device_uvector workspace(p.m * sizeof(int), stream); - RAFT_CUDA_TRY(cudaMemsetAsync(workspace.data(), 0, sizeof(int) * m, stream)); - - // Initialize output - auto out = raft::make_device_vector(handle, m); - auto blks = raft::ceildiv(m, 256); - MinAndDistanceReduceOp op; - raft::distance::detail::initKernel, int> - <<>>(out.data_handle(), m, std::numeric_limits::max(), op); - RAFT_CUDA_TRY(cudaGetLastError()); - - // Launch reference kernel - const int nwarps = 16; - static const dim3 TPB(32, nwarps, 1); - dim3 nblks(1, 200, 1); - referenceKernel - <<>>(out.data_handle(), - inp.x.data_handle(), - inp.y.data_handle(), - inp.adj.data_handle(), - inp.group_idxs.data_handle(), - m, - n, - k, - num_groups, - p.sqrt, - (int*)workspace.data(), - std::numeric_limits::max()); - RAFT_CUDA_TRY(cudaGetLastError()); - - return out; -} - -template > -auto run_masked_nn(const raft::handle_t& handle, Inputs inp, const Params& p) - -> raft::device_vector -{ - // Compute norms: - auto x_norm = raft::make_device_vector(handle, p.m); - auto y_norm = raft::make_device_vector(handle, p.n); - - raft::linalg::norm(handle, - std::as_const(inp.x).view(), - x_norm.view(), - raft::linalg::L2Norm, - raft::linalg::Apply::ALONG_ROWS); - raft::linalg::norm(handle, - std::as_const(inp.y).view(), - y_norm.view(), - raft::linalg::L2Norm, - raft::linalg::Apply::ALONG_ROWS); - - // Create 
parameters for maskedL2NN - using IdxT = int; - using RedOpT = MinAndDistanceReduceOp; - using PairRedOpT = raft::distance::KVPMinReduce; - using ParamT = raft::distance::MaskedL2NNParams; - - bool init_out = true; - ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, p.sqrt, init_out}; - - // Create output - auto out = raft::make_device_vector(handle, p.m); - - // Launch kernel - raft::distance::maskedL2NN(handle, - masked_l2_params, - inp.x.view(), - inp.y.view(), - x_norm.view(), - y_norm.view(), - inp.adj.view(), - inp.group_idxs.view(), - out.view()); - - handle.sync_stream(); - - return out; -} - -template -struct CompareApproxAbsKVP { - typedef typename raft::KeyValuePair KVP; - CompareApproxAbsKVP(T eps_) : eps(eps_) {} - bool operator()(const KVP& a, const KVP& b) const - { - T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); - T m = std::max(raft::abs(a.value), raft::abs(b.value)); - T ratio = m >= eps ? diff / m : diff; - return (ratio <= eps); - } - - private: - T eps; -}; - -template -::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, - const raft::KeyValuePair* actual, - size_t size, - L eq_compare, - cudaStream_t stream = 0) -{ - typedef typename raft::KeyValuePair KVP; - std::shared_ptr exp_h(new KVP[size]); - std::shared_ptr act_h(new KVP[size]); - raft::update_host(exp_h.get(), expected, size, stream); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < size; ++i) { - auto exp = exp_h.get()[i]; - auto act = act_h.get()[i]; - if (!eq_compare(exp, act)) { - return ::testing::AssertionFailure() - << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," - << exp.value << " @" << i; - } - } - return ::testing::AssertionSuccess(); -} - -inline auto gen_params() -> std::vector -{ - // Regular powers of two - auto regular = raft::util::itertools::product({0.001f}, // tolerance - {32, 64, 512}, // m - {32, 64, 512}, // n - {8, 32}, // k - {2, 32}, // num_groups - {true, false}, // sqrt - {1234ULL}, // seed - {AdjacencyPattern::all_true, - AdjacencyPattern::checkerboard, - AdjacencyPattern::checkerboard_64, - AdjacencyPattern::all_false}); - - // Irregular sizes to check tiling and bounds checking - auto irregular = raft::util::itertools::product({0.001f}, // tolerance - {511, 512, 513}, // m - {127, 128, 129}, // n - {5}, // k - {3, 9}, // num_groups - {true, false}, // sqrt - {1234ULL}, // seed - {AdjacencyPattern::all_true, - AdjacencyPattern::checkerboard, - AdjacencyPattern::checkerboard_64}); - - regular.insert(regular.end(), irregular.begin(), irregular.end()); - - return regular; -} - -class MaskedL2NNTest : public ::testing::TestWithParam { - // Empty. -}; - -// -TEST_P(MaskedL2NNTest, ReferenceCheckFloat) -{ - using DataT = float; - - // Get parameters; create handle and input data. - Params p = GetParam(); - raft::handle_t handle{}; - Inputs inputs{handle, p}; - - // Calculate reference and test output - auto out_reference = reference(handle, inputs, p); - auto out_fast = run_masked_nn(handle, inputs, p); - - // Check for differences. - ASSERT_TRUE(devArrMatch(out_reference.data_handle(), - out_fast.data_handle(), - p.m, - CompareApproxAbsKVP(p.tolerance), - handle.get_stream())); -} - -// This test checks whether running the maskedL2NN twice returns the same -// output. -TEST_P(MaskedL2NNTest, DeterminismCheck) -{ - using DataT = float; - - // Get parameters; create handle and input data. 
- Params p = GetParam(); - raft::handle_t handle{}; - Inputs inputs{handle, p}; - - // Calculate reference and test output - auto out1 = run_masked_nn(handle, inputs, p); - auto out2 = run_masked_nn(handle, inputs, p); - - // Check for differences. - ASSERT_TRUE(devArrMatch(out1.data_handle(), - out2.data_handle(), - p.m, - CompareApproxAbsKVP(p.tolerance), - handle.get_stream())); -} - -TEST_P(MaskedL2NNTest, ReferenceCheckDouble) -{ - using DataT = double; - - // Get parameters; create handle and input data. - Params p = GetParam(); - raft::handle_t handle{}; - Inputs inputs{handle, p}; - - // Calculate reference and test output - auto out_reference = reference(handle, inputs, p); - auto out_fast = run_masked_nn(handle, inputs, p); - - // Check for differences. - ASSERT_TRUE(devArrMatch(out_reference.data_handle(), - out_fast.data_handle(), - p.m, - CompareApproxAbsKVP(p.tolerance), - handle.get_stream())); -} - -INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTest, ::testing::ValuesIn(gen_params())); - -} // end namespace raft::distance::masked_nn diff --git a/cpp/test/distance/masked_nn_compress_to_bits.cu b/cpp/test/distance/masked_nn_compress_to_bits.cu deleted file mode 100644 index 7597362274..0000000000 --- a/cpp/test/distance/masked_nn_compress_to_bits.cu +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "../test_utils.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::distance::masked_nn::compress_to_bits { - -/** - * @brief Transpose and decompress 2D bitfield to boolean matrix - * - * Inverse operation of compress_to_bits - * - * @tparam T - * - * @parameter[in] in An `m x n` bitfield matrix. Row major. - * @parameter in_rows The number of rows of `in`, i.e. `m`. - * @parameter in_cols The number of cols of `in`, i.e. `n`. - * - * @parameter[out] out An `(m * bits_per_elem) x n` boolean matrix. - */ -template ::value>> -__global__ void decompress_bits_kernel(const T* in, int in_rows, int in_cols, bool* out) -{ - constexpr int bits_per_element = 8 * sizeof(T); - - const size_t i = threadIdx.y + blockIdx.y * blockDim.y; - const size_t j = threadIdx.x + blockIdx.x * blockDim.x; - - if (in_rows <= i || in_cols <= j) { return; } - - const size_t out_rows = in_rows * bits_per_element; - const size_t out_cols = in_cols; - const size_t out_i = i * bits_per_element; - const size_t out_j = j; - - if (out_rows <= out_i && out_cols <= out_j) { return; } - - T bitfield = in[i * in_cols + j]; - for (int bitpos = 0; bitpos < bits_per_element; ++bitpos) { - bool bit = ((T(1) << bitpos) & bitfield) != 0; - out[(out_i + bitpos) * out_cols + out_j] = bit; - } -} - -/** - * @brief Transpose and decompress 2D bitfield to boolean matrix - * - * Inverse operation of compress_to_bits - * - * @tparam T - * - * @parameter[in] in An `m x n` bitfield matrix. Row major. 
- * @parameter in_rows The number of rows of `in`, i.e. `m`. - * @parameter in_cols The number of cols of `in`, i.e. `n`. - * - * @parameter[out] out An `n x (m * bits_per_elem)` boolean matrix. - */ -template ::value>> -void decompress_bits(const raft::handle_t& handle, const T* in, int in_rows, int in_cols, bool* out) -{ - auto stream = handle.get_stream(); - dim3 grid(raft::ceildiv(in_cols, 32), raft::ceildiv(in_rows, 32)); - dim3 block(32, 32); - decompress_bits_kernel<<>>(in, in_rows, in_cols, out); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -// Params holds parameters for test case -struct Params { - int m, n; -}; - -inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& -{ - return os << "m: " << p.m << ", n: " << p.n; -} - -// Check that the following holds -// -// decompress(compress(x)) == x -// -// for 2D boolean matrices x. -template -void check_invertible(const Params& p) -{ - using raft::distance::detail::compress_to_bits; - constexpr int bits_per_elem = sizeof(T) * 8; - - // Make m and n that are safe to ceildiv. - int m = raft::round_up_safe(p.m, bits_per_elem); - int n = p.n; - - // Generate random input - raft::handle_t handle{}; - raft::random::RngState r(1ULL); - auto in = raft::make_device_matrix(handle, m, n); - raft::random::bernoulli(handle, r, in.data_handle(), m * n, 0.5f); - - int tmp_m = raft::ceildiv(m, bits_per_elem); - int out_m = tmp_m * bits_per_elem; - - auto tmp = raft::make_device_matrix(handle, tmp_m, n); - auto out = raft::make_device_matrix(handle, out_m, n); - - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); - - ASSERT_EQ(in.extent(0), out.extent(0)) << "M does not match"; - ASSERT_EQ(in.extent(1), out.extent(1)) << "N does not match"; - - compress_to_bits(handle, in.view(), tmp.view()); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); - - decompress_bits(handle, tmp.data_handle(), tmp.extent(0), tmp.extent(1), out.data_handle()); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); - - // Check for differences. - ASSERT_TRUE(raft::devArrMatch(in.data_handle(), - out.data_handle(), - in.extent(0) * in.extent(1), - raft::Compare(), - handle.get_stream())); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -void check_all_true(const Params& p) -{ - using raft::distance::detail::compress_to_bits; - using T = uint64_t; - constexpr int bits_per_elem = sizeof(T) * 8; - - // Make m and n that are safe to ceildiv. - int m = raft::round_up_safe(p.m, bits_per_elem); - int n = p.n; - - raft::handle_t handle{}; - raft::random::RngState r(1ULL); - auto in = raft::make_device_matrix(handle, m, n); - raft::matrix::fill(handle, in.view(), true); - - int tmp_m = raft::ceildiv(m, bits_per_elem); - auto tmp = raft::make_device_matrix(handle, tmp_m, n); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); - - compress_to_bits(handle, in.view(), tmp.view()); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); - - auto expected = raft::make_device_matrix(handle, tmp_m, n); - raft::matrix::fill(handle, expected.view(), ~T(0)); - - // Check for differences. - ASSERT_TRUE(raft::devArrMatch(expected.data_handle(), - tmp.data_handle(), - tmp.extent(0) * tmp.extent(1), - raft::Compare(), - handle.get_stream())); - handle.sync_stream(); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -class CompressToBitsTest : public ::testing::TestWithParam { - // Empty. 
-}; - -TEST_P(CompressToBitsTest, CheckTrue64) { check_all_true(GetParam()); } - -TEST_P(CompressToBitsTest, CheckInvertible64) -{ - using T = uint64_t; - check_invertible(GetParam()); -} - -TEST_P(CompressToBitsTest, CheckInvertible32) -{ - using T = uint32_t; - check_invertible(GetParam()); -} - -std::vector params = raft::util::itertools::product( - {1, 3, 32, 33, 63, 64, 65, 128, 10013}, {1, 3, 32, 33, 63, 64, 65, 13001}); - -INSTANTIATE_TEST_CASE_P(CompressToBits, CompressToBitsTest, ::testing::ValuesIn(params)); - -} // namespace raft::distance::masked_nn::compress_to_bits diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index a78f5cfe5c..174dce5a7f 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -32,7 +32,7 @@ #include #if defined RAFT_DISTANCE_COMPILED -#include +#include #endif #include diff --git a/docs/source/cpp_api/distance.rst b/docs/source/cpp_api/distance.rst index 1632f19fba..eb9bc6255d 100644 --- a/docs/source/cpp_api/distance.rst +++ b/docs/source/cpp_api/distance.rst @@ -25,4 +25,3 @@ namespace *raft::distance* distance_pairwise.rst distance_1nn.rst - distance_masked_nn.rst diff --git a/docs/source/cpp_api/distance_masked_nn.rst b/docs/source/cpp_api/distance_masked_nn.rst deleted file mode 100644 index 89e23ba98a..0000000000 --- a/docs/source/cpp_api/distance_masked_nn.rst +++ /dev/null @@ -1,16 +0,0 @@ -Masked 1-Nearest Neighbors -========================== - -.. role:: py(code) - :language: c++ - :class: highlight - -``#include `` - -namespace *raft::distance* - -.. doxygengroup:: masked_nn - :project: RAFT - :members: - :content-only: -
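The ivfpq_serialize.cu and ivfpq_deserialize.cu hunks above move the runtime serialize/deserialize definitions into their own translation units (and the per-type ivfpq_search_*.cu files split the search instantiations) so the compiler can work on them in parallel, which is the build-time speedup this patch targets. Below is a minimal usage sketch of those two runtime entry points. It is not part of the patch: the header paths, the "ivf_pq.bin" file name, and the wrapper function are assumptions for illustration; only the serialize/deserialize signatures and the uint64_t index type are taken from the hunks above.

// Sketch only: header paths and the "ivf_pq.bin" file name are assumptions,
// not part of this patch; the serialize/deserialize calls mirror the
// definitions added in ivfpq_serialize.cu and ivfpq_deserialize.cu above.
#include <cstdint>
#include <string>

#include <raft/core/device_resources.hpp>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <raft_runtime/neighbors/ivf_pq.hpp>  // assumed location of the runtime declarations

void save_and_reload(raft::device_resources const& handle,
                     raft::neighbors::ivf_pq::index<uint64_t> const& built,
                     raft::neighbors::ivf_pq::index<uint64_t>* reloaded)
{
  // Write the trained index to disk (definition now lives in ivfpq_serialize.cu).
  raft::runtime::neighbors::ivf_pq::serialize(handle, "ivf_pq.bin", built);

  // Read it back into the caller-provided index object; as shown in the
  // ivfpq_deserialize.cu hunk, a null output pointer triggers RAFT_FAIL.
  raft::runtime::neighbors::ivf_pq::deserialize(handle, "ivf_pq.bin", reloaded);
}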