diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt
index 813483adc5..6b985acfc3 100644
--- a/cpp/bench/CMakeLists.txt
+++ b/cpp/bench/CMakeLists.txt
@@ -103,7 +103,10 @@ if(BUILD_BENCH)
     bench/main.cpp
   )

-  ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/main.cpp)
+  ConfigureBench(
+    NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/matrix/select_k.cu
+    bench/main.cpp
+  )

   ConfigureBench(
     NAME RANDOM_BENCH PATH bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu
@@ -127,7 +130,6 @@ if(BUILD_BENCH)
     bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu
     bench/neighbors/knn/ivf_pq_uint8_t_uint32_t.cu
     bench/neighbors/refine.cu
-    bench/neighbors/selection.cu
     bench/main.cpp
     OPTIONAL
     DIST
diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu
index 0d0dea0fdb..52f5aab7f3 100644
--- a/cpp/bench/matrix/argmin.cu
+++ b/cpp/bench/matrix/argmin.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@
 #include
 #include
 #include
+#include <raft/util/itertools.hpp>

 #include

-namespace raft::bench::linalg {
+namespace raft::bench::matrix {

 template <typename IdxT>
 struct ArgminParams {
@@ -57,15 +58,11 @@ struct Argmin : public fixture {
   raft::device_vector<OutT, IdxT> indices;
 };  // struct Argmin

-const std::vector<ArgminParams<int64_t>> argmin_inputs_i64{
-  {1000, 64},     {1000, 128},     {1000, 256},     {1000, 512},     {1000, 1024},
-  {10000, 64},    {10000, 128},    {10000, 256},    {10000, 512},    {10000, 1024},
-  {100000, 64},   {100000, 128},   {100000, 256},   {100000, 512},   {100000, 1024},
-  {1000000, 64},  {1000000, 128},  {1000000, 256},  {1000000, 512},  {1000000, 1024},
-  {10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024},
-};
+const std::vector<ArgminParams<int64_t>> argmin_inputs_i64 =
+  raft::util::itertools::product<ArgminParams<int64_t>>({1000, 10000, 100000, 1000000, 10000000},
+                                                        {64, 128, 256, 512, 1024});

 RAFT_BENCH_REGISTER((Argmin<float, uint32_t, int64_t>), "", argmin_inputs_i64);
 RAFT_BENCH_REGISTER((Argmin<double, uint32_t, int64_t>), "", argmin_inputs_i64);

-}  // namespace raft::bench::linalg
+}  // namespace raft::bench::matrix
diff --git a/cpp/bench/matrix/gather.cu b/cpp/bench/matrix/gather.cu
new file mode 100644
index 0000000000..97812c20a1
--- /dev/null
+++ b/cpp/bench/matrix/gather.cu
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+
+#include
+
+namespace raft::bench::matrix {
+
+template <typename IdxT>
+struct GatherParams {
+  IdxT rows, cols, map_length;
+};
+
+template <typename IdxT>
+inline auto operator<<(std::ostream& os, const GatherParams<IdxT>& p) -> std::ostream&
+{
+  os << p.rows << "#" << p.cols << "#" << p.map_length;
+  return os;
+}
+
+template <typename T, typename MapT, typename IdxT, bool Conditional = false>
+struct Gather : public fixture {
+  Gather(const GatherParams<IdxT>& p) : params(p) {}
+
+  void allocate_data(const ::benchmark::State& state) override
+  {
+    matrix  = raft::make_device_matrix<T, IdxT>(handle, params.rows, params.cols);
+    map     = raft::make_device_vector<MapT, IdxT>(handle, params.map_length);
+    out     = raft::make_device_matrix<T, IdxT>(handle, params.map_length, params.cols);
+    stencil = raft::make_device_vector<T, IdxT>(handle, Conditional ? params.map_length : IdxT(0));
+
+    raft::random::RngState rng{1234};
+    raft::random::uniform(
+      rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
+    raft::random::uniformInt(
+      handle, rng, map.data_handle(), params.map_length, (MapT)0, (MapT)params.rows);
+    if constexpr (Conditional) {
+      raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream);
+    }
+    handle.sync_stream(stream);
+  }
+
+  void run_benchmark(::benchmark::State& state) override
+  {
+    std::ostringstream label_stream;
+    label_stream << params;
+    state.SetLabel(label_stream.str());
+
+    loop_on_state(state, [this]() {
+      auto matrix_const_view = raft::make_device_matrix_view<const T, IdxT>(
+        matrix.data_handle(), matrix.extent(0), matrix.extent(1));
+      auto map_const_view =
+        raft::make_device_vector_view<const MapT, IdxT>(map.data_handle(), map.extent(0));
+      if constexpr (Conditional) {
+        auto stencil_const_view =
+          raft::make_device_vector_view<const T, IdxT>(stencil.data_handle(), stencil.extent(0));
+        auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op());
+        raft::matrix::gather_if(
+          handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op);
+      } else {
+        raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view());
+      }
+    });
+  }
+
+ private:
+  GatherParams<IdxT> params;
+  raft::device_matrix<T, IdxT> matrix, out;
+  raft::device_vector<T, IdxT> stencil;
+  raft::device_vector<MapT, IdxT> map;
+};  // struct Gather
+
+template <typename T, typename MapT, typename IdxT>
+using GatherIf = Gather<T, MapT, IdxT, true>;
+
+const std::vector<GatherParams<int64_t>> gather_inputs_i64 =
+  raft::util::itertools::product<GatherParams<int64_t>>(
+    {1000000}, {10, 20, 50, 100, 200, 500}, {1000, 10000, 100000, 1000000});
+
+RAFT_BENCH_REGISTER((Gather<float, uint32_t, int64_t>), "", gather_inputs_i64);
+RAFT_BENCH_REGISTER((Gather<double, uint32_t, int64_t>), "", gather_inputs_i64);
+RAFT_BENCH_REGISTER((GatherIf<float, uint32_t, int64_t>), "", gather_inputs_i64);
+RAFT_BENCH_REGISTER((GatherIf<double, uint32_t, int64_t>), "", gather_inputs_i64);
+}  // namespace raft::bench::matrix
diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu
new file mode 100644
index 0000000000..452a50ba50
--- /dev/null
+++ b/cpp/bench/matrix/select_k.cu
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * TODO: reconsider how to organize shared test+bench files better
+ * Related Issue: https://github.com/rapidsai/raft/issues/1153
+ * (although this header does not depend on any gtest headers)
+ */
+#include "../../test/matrix/select_k.cuh"
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace raft::matrix {
+
+using namespace raft::bench;  // NOLINT
+
+template <typename KeyT, typename IdxT, select::Algo Algo>
+struct selection : public fixture {
+  explicit selection(const select::params& p)
+    : params_(p),
+      in_dists_(p.batch_size * p.len, stream),
+      in_ids_(p.batch_size * p.len, stream),
+      out_dists_(p.batch_size * p.k, stream),
+      out_ids_(p.batch_size * p.k, stream)
+  {
+    raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream);
+    raft::random::RngState state{42};
+    raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0));
+  }
+
+  void run_benchmark(::benchmark::State& state) override  // NOLINT
+  {
+    handle_t handle{stream};
+    using_pool_memory_res res;
+    try {
+      std::ostringstream label_stream;
+      label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k;
+      state.SetLabel(label_stream.str());
+      loop_on_state(state, [this, &handle]() {
+        select::select_k_impl(handle,
+                              Algo,
+                              in_dists_.data(),
+                              in_ids_.data(),
+                              params_.batch_size,
+                              params_.len,
+                              params_.k,
+                              out_dists_.data(),
+                              out_ids_.data(),
+                              params_.select_min);
+      });
+    } catch (raft::exception& e) {
+      state.SkipWithError(e.what());
+    }
+  }
+
+ private:
+  const select::params params_;
+  rmm::device_uvector<KeyT> in_dists_, out_dists_;
+  rmm::device_uvector<IdxT> in_ids_, out_ids_;
+};
+
+const std::vector<select::params> kInputs{
+  {20000, 500, 1, true},   {20000, 500, 2, true},    {20000, 500, 4, true},
+  {20000, 500, 8, true},   {20000, 500, 16, true},   {20000, 500, 32, true},
+  {20000, 500, 64, true},  {20000, 500, 128, true},  {20000, 500, 256, true},
+
+  {1000, 10000, 1, true},  {1000, 10000, 2, true},   {1000, 10000, 4, true},
+  {1000, 10000, 8, true},  {1000, 10000, 16, true},  {1000, 10000, 32, true},
+  {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true},
+
+  {100, 100000, 1, true},  {100, 100000, 2, true},   {100, 100000, 4, true},
+  {100, 100000, 8, true},  {100, 100000, 16, true},  {100, 100000, 32, true},
+  {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true},
+
+  {10, 1000000, 1, true},  {10, 1000000, 2, true},   {10, 1000000, 4, true},
+  {10, 1000000, 8, true},  {10, 1000000, 16, true},  {10, 1000000, 32, true},
+  {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true},
+};
+
+#define SELECTION_REGISTER(KeyT, IdxT, A)                          \
+  namespace BENCHMARK_PRIVATE_NAME(selection)                      \
+  {                                                                \
+    using SelectK = selection<KeyT, IdxT, select::Algo::A>;        \
+    RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
+  }
+
+SELECTION_REGISTER(float, int, kPublicApi);           // NOLINT
+SELECTION_REGISTER(float, int, kRadix8bits);          // NOLINT
+SELECTION_REGISTER(float, int, kRadix11bits);         // NOLINT
+SELECTION_REGISTER(float, int, kWarpAuto);            // NOLINT
+SELECTION_REGISTER(float, int, kWarpImmediate);       // NOLINT
+SELECTION_REGISTER(float, int, kWarpFiltered);        // NOLINT
+SELECTION_REGISTER(float, int, kWarpDistributed);     // NOLINT
+SELECTION_REGISTER(float, int, kWarpDistributedShm);  // NOLINT
+
+SELECTION_REGISTER(double, int, kRadix8bits);   // NOLINT
+SELECTION_REGISTER(double, int, kRadix11bits);  // NOLINT
+SELECTION_REGISTER(double, int, kWarpAuto);     // NOLINT
+
+SELECTION_REGISTER(double, size_t, kRadix8bits);  // NOLINT
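+
+// For reference (illustrative): each SELECTION_REGISTER above and below expands
+// to roughly
+//
+//   namespace BENCHMARK_PRIVATE_NAME(selection) {
+//   using SelectK = selection<double, size_t, select::Algo::kRadix8bits>;
+//   RAFT_BENCH_REGISTER(SelectK, "double/size_t/kRadix8bits", kInputs);
+//   }
+//
+// so every (KeyT, IdxT, Algo) triple becomes a separately-labelled benchmark.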
+SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT + +} // namespace raft::matrix diff --git a/cpp/bench/neighbors/selection.cu b/cpp/bench/neighbors/selection.cu deleted file mode 100644 index 1f116c199f..0000000000 --- a/cpp/bench/neighbors/selection.cu +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#if defined RAFT_NN_COMPILED -#include -#endif - -#include -#include - -#include -#include - -namespace raft::bench::spatial { - -struct params { - int n_inputs; - int input_len; - int k; - int select_min; -}; - -template -struct selection : public fixture { - explicit selection(const params& p) - : params_(p), - in_dists_(p.n_inputs * p.input_len, stream), - in_ids_(p.n_inputs * p.input_len, stream), - out_dists_(p.n_inputs * p.k, stream), - out_ids_(p.n_inputs * p.k, stream) - { - raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream); - raft::random::RngState state{42}; - raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); - } - - void run_benchmark(::benchmark::State& state) override - { - using_pool_memory_res res; - try { - std::ostringstream label_stream; - label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k; - state.SetLabel(label_stream.str()); - loop_on_state(state, [this]() { - raft::spatial::knn::select_k(in_dists_.data(), - in_ids_.data(), - params_.n_inputs, - params_.input_len, - out_dists_.data(), - out_ids_.data(), - params_.select_min, - params_.k, - stream, - Algo); - }); - } catch (raft::exception& e) { - state.SkipWithError(e.what()); - } - } - - private: - const params params_; - rmm::device_uvector in_dists_, out_dists_; - rmm::device_uvector in_ids_, out_ids_; -}; - -const std::vector kInputs{ - {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, - {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, - {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, - - {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, - {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, - {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, - - {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, - {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, - {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, - - {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, - {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, - {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, 
true}, -}; - -#define SELECTION_REGISTER(KeyT, IdxT, Algo) \ - namespace BENCHMARK_PRIVATE_NAME(selection) \ - { \ - using SelectK = selection; \ - RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \ - } - -SELECTION_REGISTER(float, int, FAISS); -SELECTION_REGISTER(float, int, RADIX_8_BITS); -SELECTION_REGISTER(float, int, RADIX_11_BITS); -SELECTION_REGISTER(float, int, WARP_SORT); - -SELECTION_REGISTER(double, int, FAISS); -SELECTION_REGISTER(double, int, RADIX_8_BITS); -SELECTION_REGISTER(double, int, RADIX_11_BITS); -SELECTION_REGISTER(double, int, WARP_SORT); - -SELECTION_REGISTER(double, size_t, FAISS); -SELECTION_REGISTER(double, size_t, RADIX_8_BITS); -SELECTION_REGISTER(double, size_t, RADIX_11_BITS); -SELECTION_REGISTER(double, size_t, WARP_SORT); - -} // namespace raft::bench::spatial diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 811a5466c3..3e02ce064e 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -30,6 +30,10 @@ function(find_and_configure_cutlass) CACHE BOOL "Disable CUTLASS to build with cuBLAS library." ) + if (CUDA_STATIC_RUNTIME) + set(CUDART_LIBRARY "${CUDA_cudart_static_LIBRARY}" CACHE FILEPATH "fixing cutlass cmake code" FORCE) + endif() + rapids_cpm_find( NvidiaCutlass ${PKG_VERSION} GLOBAL_TARGETS nvidia::cutlass::cutlass diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 2fd33ac759..559793442f 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -335,7 +335,7 @@ void shuffleAndGather(const raft::handle_t& handle, in.extent(1), in.extent(0), indices.data_handle(), - n_samples_to_gather, + static_cast(n_samples_to_gather), out.data_handle(), stream); } diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index 9b9e07cf4f..ec0b92dde2 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -59,11 +59,18 @@ namespace raft { */ class device_resources : public resources { public: - // delete copy/move constructors and assignment operators as - // copying and moving underlying resources is unsafe - device_resources(const device_resources&) = delete; - device_resources& operator=(const device_resources&) = delete; - device_resources(device_resources&&) = delete; + device_resources(const device_resources& handle, + rmm::mr::device_memory_resource* workspace_resource) + : resources{handle} + { + // replace the resource factory for the workspace_resources + resources::add_resource_factory( + std::make_shared(workspace_resource)); + } + + device_resources(const device_resources& handle) : resources{handle} {} + + device_resources(device_resources&&) = delete; device_resources& operator=(device_resources&&) = delete; /** @@ -210,7 +217,7 @@ class device_resources : public resources { return resource::get_subcomm(*this, key); } - const rmm::mr::device_memory_resource* get_workspace_resource() const + rmm::mr::device_memory_resource* get_workspace_resource() const { return resource::get_workspace_resource(*this); } diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 6486965cdf..c1e7aa538f 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -32,11 +32,14 @@ namespace raft { */ class handle_t : public raft::device_resources { public: - // delete copy/move 
constructors and assignment operators as
-  // copying and moving underlying resources is unsafe
-  handle_t(const handle_t&) = delete;
-  handle_t& operator=(const handle_t&) = delete;
-  handle_t(handle_t&&)                 = delete;
+  handle_t(const handle_t& handle, rmm::mr::device_memory_resource* workspace_resource)
+    : device_resources(handle, workspace_resource)
+  {
+  }
+
+  handle_t(const handle_t& handle) : device_resources{handle} {}
+
+  handle_t(handle_t&&) = delete;
   handle_t& operator=(handle_t&&) = delete;

   /**
diff --git a/cpp/include/raft/core/math.hpp b/cpp/include/raft/core/math.hpp
new file mode 100644
index 0000000000..c5f08b84b7
--- /dev/null
+++ b/cpp/include/raft/core/math.hpp
@@ -0,0 +1,320 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+#include
+
+namespace raft {
+
+/**
+ * @defgroup Absolute Absolute value
+ * @{
+ */
+template <typename T>
+RAFT_INLINE_FUNCTION auto abs(T x)
+  -> std::enable_if_t<std::is_same_v<float, T> || std::is_same_v<double, T> ||
+                        std::is_same_v<int, T> || std::is_same_v<long int, T> ||
+                        std::is_same_v<long long int, T>,
+                      T>
+{
+#ifdef __CUDA_ARCH__
+  return ::abs(x);
+#else
+  return std::abs(x);
+#endif
+}
+template <typename T>
+constexpr RAFT_INLINE_FUNCTION auto abs(T x)
+  -> std::enable_if_t<!std::is_same_v<float, T> && !std::is_same_v<double, T> &&
+                        !std::is_same_v<int, T> && !std::is_same_v<long int, T> &&
+                        !std::is_same_v<long long int, T>,
+                      T>
+{
+  return x < T{0} ? -x : x;
+}
+/** @} */
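+
+// Illustrative only (hypothetical helper, not part of this header): thanks to
+// the __CUDA_ARCH__ dispatch above, the same expression compiles for both host
+// and device code paths:
+//
+//   template <typename T>
+//   RAFT_INLINE_FUNCTION T l1_term(T a, T b) { return raft::abs(a - b); }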
+
+/**
+ * @defgroup Trigonometry Trigonometry functions
+ * @{
+ */
+/** Inverse cosine */
+template <typename T>
+RAFT_INLINE_FUNCTION auto acos(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::acos(x);
+#else
+  return std::acos(x);
+#endif
+}
+
+/** Inverse sine */
+template <typename T>
+RAFT_INLINE_FUNCTION auto asin(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::asin(x);
+#else
+  return std::asin(x);
+#endif
+}
+
+/** Inverse hyperbolic tangent */
+template <typename T>
+RAFT_INLINE_FUNCTION auto atanh(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::atanh(x);
+#else
+  return std::atanh(x);
+#endif
+}
+
+/** Cosine */
+template <typename T>
+RAFT_INLINE_FUNCTION auto cos(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::cos(x);
+#else
+  return std::cos(x);
+#endif
+}
+
+/** Sine */
+template <typename T>
+RAFT_INLINE_FUNCTION auto sin(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::sin(x);
+#else
+  return std::sin(x);
+#endif
+}
+
+/** Sine and cosine */
+template <typename T>
+RAFT_INLINE_FUNCTION std::enable_if_t<std::is_same_v<float, T> || std::is_same_v<double, T>>
+sincos(const T& x, T* s, T* c)
+{
+#ifdef __CUDA_ARCH__
+  ::sincos(x, s, c);
+#else
+  *s = std::sin(x);
+  *c = std::cos(x);
+#endif
+}
+
+/** Hyperbolic tangent */
+template <typename T>
+RAFT_INLINE_FUNCTION auto tanh(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::tanh(x);
+#else
+  return std::tanh(x);
+#endif
+}
+/** @} */
+
+/**
+ * @defgroup Exponential Exponential and logarithm
+ * @{
+ */
+/** Exponential function */
+template <typename T>
+RAFT_INLINE_FUNCTION auto exp(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::exp(x);
+#else
+  return std::exp(x);
+#endif
+}
+
+/** Natural logarithm */
+template <typename T>
+RAFT_INLINE_FUNCTION auto log(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::log(x);
+#else
+  return std::log(x);
+#endif
+}
+/** @} */
+
+/**
+ * @defgroup Maximum Maximum of two or more values.
+ *
+ * The CUDA Math API has overloads for all combinations of float/double. We provide similar
+ * functionality while wrapping around std::max, which only supports arguments of the same type.
+ * However, while the CUDA Math API supports combinations of unsigned and signed integers, this is
+ * very error-prone, so we do not support that and require the user to cast instead (e.g. the max
+ * of -1 and 1u is 4294967295u...).
+ *
+ * When no overload matches, we provide a generic implementation but require that both types be the
+ * same (and that the less-than operator be defined).
+ * @{
+ */
+template <typename T1, typename T2>
+RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y)
+{
+#ifdef __CUDA_ARCH__
+  // Combinations of types supported by the CUDA Math API
+  if constexpr ((std::is_integral_v<T1> && std::is_integral_v<T2> && std::is_same_v<T1, T2>) ||
+                ((std::is_same_v<T1, float> || std::is_same_v<T1, double>) &&
+                 (std::is_same_v<T2, float> || std::is_same_v<T2, double>))) {
+    return ::max(x, y);
+  }
+  // Else, check that the types are the same and provide a generic implementation
+  else {
+    static_assert(
+      std::is_same_v<T1, T2>,
+      "No native max overload for these types. Both argument types must be the same to use "
+      "the generic max. Please cast appropriately.");
+    return (x < y) ? y : x;
+  }
+#else
+  if constexpr (std::is_same_v<T1, float> && std::is_same_v<T2, double>) {
+    return std::max(static_cast<double>(x), y);
+  } else if constexpr (std::is_same_v<T1, double> && std::is_same_v<T2, float>) {
+    return std::max(x, static_cast<double>(y));
+  } else {
+    static_assert(
+      std::is_same_v<T1, T2>,
+      "std::max requires that both argument types be the same. Please cast appropriately.");
+    return std::max(x, y);
+  }
+#endif
+}
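+
+// Illustrative only, following the rules above (hypothetical snippets):
+//   raft::max(3.0f, 1.0);                 // float/double combination -> double
+//   raft::max(-1, 1);                     // same integral types: fine
+//   raft::max(-1, 1u);                    // rejected at compile time by design
+//   raft::max(-1, static_cast<int>(1u));  // fine after an explicit cast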
+
+/** Many-argument overload to avoid verbose nested calls or use with variadic arguments */
+template <typename T1, typename T2, typename... Args>
+RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y, Args&&... args)
+{
+  return raft::max(x, raft::max(y, std::forward<Args>(args)...));
+}
+
+/** One-argument overload for convenience when using with variadic arguments */
+template <typename T>
+constexpr RAFT_INLINE_FUNCTION auto max(const T& x)
+{
+  return x;
+}
+/** @} */
+
+/**
+ * @defgroup Minimum Minimum of two or more values.
+ *
+ * The CUDA Math API has overloads for all combinations of float/double. We provide similar
+ * functionality while wrapping around std::min, which only supports arguments of the same type.
+ * However, while the CUDA Math API supports combinations of unsigned and signed integers, this is
+ * very error-prone, so we do not support that and require the user to cast instead (e.g. the min
+ * of -1 and 1u is 1u...).
+ *
+ * When no overload matches, we provide a generic implementation but require that both types be the
+ * same (and that the less-than operator be defined).
+ * @{
+ */
+template <typename T1, typename T2>
+RAFT_INLINE_FUNCTION auto min(const T1& x, const T2& y)
+{
+#ifdef __CUDA_ARCH__
+  // Combinations of types supported by the CUDA Math API
+  if constexpr ((std::is_integral_v<T1> && std::is_integral_v<T2> && std::is_same_v<T1, T2>) ||
+                ((std::is_same_v<T1, float> || std::is_same_v<T1, double>) &&
+                 (std::is_same_v<T2, float> || std::is_same_v<T2, double>))) {
+    return ::min(x, y);
+  }
+  // Else, check that the types are the same and provide a generic implementation
+  else {
+    static_assert(
+      std::is_same_v<T1, T2>,
+      "No native min overload for these types. Both argument types must be the same to use "
+      "the generic min. Please cast appropriately.");
+    return (y < x) ? y : x;
+  }
+#else
+  if constexpr (std::is_same_v<T1, float> && std::is_same_v<T2, double>) {
+    return std::min(static_cast<double>(x), y);
+  } else if constexpr (std::is_same_v<T1, double> && std::is_same_v<T2, float>) {
+    return std::min(x, static_cast<double>(y));
+  } else {
+    static_assert(
+      std::is_same_v<T1, T2>,
+      "std::min requires that both argument types be the same. Please cast appropriately.");
+    return std::min(x, y);
+  }
+#endif
+}
+
+/** Many-argument overload to avoid verbose nested calls or use with variadic arguments */
+template <typename T1, typename T2, typename... Args>
+RAFT_INLINE_FUNCTION auto min(const T1& x, const T2& y, Args&&... args)
+{
+  return raft::min(x, raft::min(y, std::forward<Args>(args)...));
+}
+
+/** One-argument overload for convenience when using with variadic arguments */
+template <typename T>
+constexpr RAFT_INLINE_FUNCTION auto min(const T& x)
+{
+  return x;
+}
+/** @} */
+
+/**
+ * @defgroup Power Power and root functions
+ * @{
+ */
+/** Power */
+template <typename T1, typename T2>
+RAFT_INLINE_FUNCTION auto pow(T1 x, T2 y)
+{
+#ifdef __CUDA_ARCH__
+  return ::pow(x, y);
+#else
+  return std::pow(x, y);
+#endif
+}
+
+/** Square root */
+template <typename T>
+RAFT_INLINE_FUNCTION auto sqrt(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::sqrt(x);
+#else
+  return std::sqrt(x);
+#endif
+}
+/** @} */
+
+/** Sign */
+template <typename T>
+RAFT_INLINE_FUNCTION auto sgn(T val) -> int
+{
+  return (T(0) < val) - (val < T(0));
+}
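+
+// Illustrative only: the two bool comparisons each yield 0 or 1, so
+//   raft::sgn(-2.5) == 0 - 1 == -1,  raft::sgn(0) == 0 - 0 == 0,  raft::sgn(7) == 1 - 0 == 1.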
+
+}  // namespace raft
diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp
index 398354df46..edb437c880 100644
--- a/cpp/include/raft/core/operators.hpp
+++ b/cpp/include/raft/core/operators.hpp
@@ -23,6 +23,7 @@
 #include
 #include
+#include <raft/core/math.hpp>

 namespace raft {

@@ -75,9 +76,9 @@ struct value_op {

 struct sqrt_op {
   template <typename Type, typename... UnusedArgs>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const
+  RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const
   {
-    return std::sqrt(in);
+    return raft::sqrt(in);
   }
 };

@@ -91,9 +92,9 @@ struct nz_op {

 struct abs_op {
   template <typename Type, typename... UnusedArgs>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const
+  RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const
   {
-    return std::abs(in);
+    return raft::abs(in);
   }
 };

@@ -146,29 +147,35 @@ struct div_checkzero_op {
   }
 };

+struct modulo_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
+  {
+    return a % b;
+  }
+};
+
 struct pow_op {
   template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
   {
-    return std::pow(a, b);
+    return raft::pow(a, b);
   }
 };

 struct min_op {
-  template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  template <typename... Args>
+  RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const
   {
-    if (a > b) { return b; }
-    return a;
+    return raft::min(std::forward<Args>(args)...);
   }
 };

 struct max_op {
-  template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  template <typename... Args>
+  RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const
   {
-    if (b > a) { return b; }
-    return a;
+    return raft::max(std::forward<Args>(args)...);
   }
 };

@@ -190,17 +197,49 @@ struct argmax_op {
   }
 };

+struct greater_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
+  {
+    return a > b;
+  }
+};
+
+struct less_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
+  {
+    return a < b;
+  }
+};
+
+struct greater_or_equal_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
+  {
+    return a >= b;
+  }
+};
+
+struct less_or_equal_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
+  {
+    return a <= b;
+  }
+};
+
 struct equal_op {
-  template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
   {
     return a == b;
   }
 };

 struct notequal_op {
-  template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
   {
     return a != b;
   }
@@ -268,6 +307,9 @@ using div_checkzero_const_op = plug_const_op<Type, div_checkzero_op>;

+template <typename Type>
+using modulo_const_op = plug_const_op<Type, modulo_op>;
+
 template <typename Type>
 using pow_const_op = plug_const_op<Type, pow_op>;
diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp
index 797fd5968d..64e281e934 100644
--- a/cpp/include/raft/core/resources.hpp
+++ b/cpp/include/raft/core/resources.hpp
@@ -62,9 +62,12 @@ class resources {
     }
   }

-  resources(const resources&)            = delete;
-  resources& operator=(const resources&) = delete;
-  resources(resources&&)                 = delete;
+  /**
+   * @brief Shallow copy of underlying resources instance.
+   * Note that this does not create any new resources.
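+   * For example (illustrative):
+   * @code{.cpp}
+   * raft::resources res;
+   * raft::resources view = res;  // shares the same underlying resource factories
+   * @endcode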
+ */ + resources(const resources& res) : factories_(res.factories_), resources_(res.resources_) {} + resources(resources&&) = delete; resources& operator=(resources&&) = delete; /** @@ -120,7 +123,7 @@ class resources { return reinterpret_cast(res->get_resource()); } - private: + protected: mutable std::mutex mutex_; mutable std::vector factories_; mutable std::vector resources_; diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 43a904edba..f17a26dc4b 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -73,8 +73,8 @@ static void canberraImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); - const auto add = raft::myAbs(x) + raft::myAbs(y); + const auto diff = raft::abs(x - y); + const auto add = raft::abs(x) + raft::abs(y); // deal with potential for 0 in denominator by // forcing 1/0 instead acc += ((add != 0) * diff / (add + (add == 0))); diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index 52573bd170..43b36e7921 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -73,8 +73,8 @@ static void chebyshevImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); - acc = raft::myMax(acc, diff); + const auto diff = raft::abs(x - y); + acc = raft::max(acc, diff); }; // epilogue operation lambda for final value calculation diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index 9bdbbf112c..f7fe3678e6 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -125,7 +125,7 @@ static void correlationImpl(const DataT* x, auto Q_denom = k * regx2n[i] - (regxn[i] * regxn[i]); auto R_denom = k * regy2n[j] - (regyn[j] * regyn[j]); - acc[i][j] = 1 - (numer / raft::mySqrt(Q_denom * R_denom)); + acc[i][j] = 1 - (numer / raft::sqrt(Q_denom * R_denom)); } } }; diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 4184810fff..1a2db63f5c 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ struct L2ExpandedOp { __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; - return sqrt ? raft::mySqrt(outVal) : outVal; + return sqrt ? 
raft::sqrt(outVal) : outVal; } __device__ AccT operator()(DataT aData) const noexcept { return aData; } @@ -130,7 +130,7 @@ void euclideanExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + acc[i][j] = raft::sqrt(acc[i][j]); } } } @@ -350,7 +350,7 @@ void euclideanUnExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + acc[i][j] = raft::sqrt(acc[i][j]); } } } diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 5311a26d19..b7ee95795f 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -185,7 +185,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { auto acc_ij = acc[i][j]; - acc[i][j] = acc_ij > DataT{0} ? raft::mySqrt(acc_ij) : DataT{0}; + acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0}; } } } diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index 51f462ab36..13507fe84f 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -105,7 +105,7 @@ static void hellingerImpl(const DataT* x, // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative const auto finalVal = (1 - acc[i][j]); const auto rectifier = (!signbit(finalVal)); - acc[i][j] = raft::mySqrt(rectifier * finalVal); + acc[i][j] = raft::sqrt(rectifier * finalVal); } } }; diff --git a/cpp/include/raft/distance/detail/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh index 92ee071cf5..f96da01b87 100644 --- a/cpp/include/raft/distance/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/jensen_shannon.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -78,11 +78,11 @@ static void jensenShannonImpl(const DataT* x, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { const DataT m = 0.5f * (x + y); const bool m_zero = (m == 0); - const auto logM = (!m_zero) * raft::myLog(m + m_zero); + const auto logM = (!m_zero) * raft::log(m + m_zero); const bool x_zero = (x == 0); const bool y_zero = (y == 0); - acc += (-x * (logM - raft::myLog(x + x_zero))) + (-y * (logM - raft::myLog(y + y_zero))); + acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); }; // epilogue operation lambda for final value calculation @@ -95,7 +95,7 @@ static void jensenShannonImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(0.5 * acc[i][j]); + acc[i][j] = raft::sqrt(0.5 * acc[i][j]); } } }; diff --git a/cpp/include/raft/distance/detail/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh index 4c0c4b6ace..7ebeaf4de9 100644 --- a/cpp/include/raft/distance/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/kl_divergence.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,10 +81,10 @@ static void klDivergenceImpl(const DataT* x, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { if (isRowMajor) { const bool x_zero = (x == 0); - acc += x * (raft::myLog(x + x_zero) - y); + acc += x * (raft::log(x + x_zero) - y); } else { const bool y_zero = (y == 0); - acc += y * (raft::myLog(y + y_zero) - x); + acc += y * (raft::log(y + y_zero) - x); } }; @@ -92,23 +92,23 @@ static void klDivergenceImpl(const DataT* x, if (isRowMajor) { const bool x_zero = (x == 0); const bool y_zero = (y == 0); - acc += x * (raft::myLog(x + x_zero) - (!y_zero) * raft::myLog(y + y_zero)); + acc += x * (raft::log(x + x_zero) - (!y_zero) * raft::log(y + y_zero)); } else { const bool y_zero = (y == 0); const bool x_zero = (x == 0); - acc += y * (raft::myLog(y + y_zero) - (!x_zero) * raft::myLog(x + x_zero)); + acc += y * (raft::log(y + y_zero) - (!x_zero) * raft::log(x + x_zero)); } }; auto unaryOp_lambda = [] __device__(DataT input) { const bool x_zero = (input == 0); - return (!x_zero) * raft::myLog(input + x_zero); + return (!x_zero) * raft::log(input + x_zero); }; auto unaryOp_lambda_reverse = [] __device__(DataT input) { // reverse previous log (x) back to x using (e ^ log(x)) const bool x_zero = (input == 0); - return (!x_zero) * raft::myExp(input); + return (!x_zero) * raft::exp(input); }; // epilogue operation lambda for final value calculation diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 87893bab7c..bf10651b60 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -71,7 +71,7 @@ static void l1Impl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); + const auto diff = raft::abs(x - y); acc += diff; }; diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh index bda83babf1..42af8cd281 100644 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 
2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -74,8 +74,8 @@ void minkowskiUnExpImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [p] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); - acc += raft::myPow(diff, p); + const auto diff = raft::abs(x - y); + acc += raft::pow(diff, p); }; // epilogue operation lambda for final value calculation @@ -89,7 +89,7 @@ void minkowskiUnExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::myPow(acc[i][j], one_over_p); + acc[i][j] = raft::pow(acc[i][j], one_over_p); } } }; diff --git a/cpp/include/raft/linalg/detail/lstsq.cuh b/cpp/include/raft/linalg/detail/lstsq.cuh index 1273956b21..f0cf300e2f 100644 --- a/cpp/include/raft/linalg/detail/lstsq.cuh +++ b/cpp/include/raft/linalg/detail/lstsq.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -104,7 +104,7 @@ struct DivideByNonZero { operator()(const math_t a, const math_t b) const { - return raft::myAbs(b) >= eps ? a / b : a; + return raft::abs(b) >= eps ? a / b : a; } }; diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index c006f69e47..a8efc2d0d0 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,41 +17,63 @@ #pragma once #include +#include namespace raft { namespace matrix { namespace detail { -// gatherKernel conditionally copies rows from the source matrix 'in' into the destination matrix -// 'out' according to a map (or a transformed map) -template +struct gather_policy { + static constexpr int n_threads = tpb; + static constexpr int work_per_thread = wpt; + static constexpr int stride = tpb * wpt; +}; + +/** Conditionally copies rows from the source matrix 'in' into the destination matrix + * 'out' according to a map (or a transformed map) */ +template -__global__ void gatherKernel(const MatrixIteratorT in, - IndexT D, - IndexT N, - MapIteratorT map, - StencilIteratorT stencil, - MatrixIteratorT out, - PredicateOp pred_op, - MapTransformOp transform_op) + typename OutputIteratorT, + typename IndexT> +__global__ void gather_kernel(const InputIteratorT in, + IndexT D, + IndexT len, + const MapIteratorT map, + StencilIteratorT stencil, + OutputIteratorT out, + PredicateOp pred_op, + MapTransformOp transform_op) { typedef typename std::iterator_traits::value_type MapValueT; typedef typename std::iterator_traits::value_type StencilValueT; - IndexT outRowStart = blockIdx.x * D; - MapValueT map_val = map[blockIdx.x]; - StencilValueT stencil_val = stencil[blockIdx.x]; +#pragma unroll + for (IndexT wid = 0; wid < Policy::work_per_thread; wid++) { + IndexT tid = threadIdx.x + (Policy::work_per_thread * static_cast(blockIdx.x) + wid) * + Policy::n_threads; + if (tid < len) { + IndexT i_dst = tid / D; + IndexT j = tid % D; + + MapValueT map_val = map[i_dst]; + StencilValueT stencil_val = stencil[i_dst]; - bool predicate = pred_op(stencil_val); - if (predicate) { - IndexT inRowStart = transform_op(map_val) * D; - for (int i = threadIdx.x; i < D; i += TPB) { - out[outRowStart + i] = in[inRowStart + i]; + bool predicate = pred_op(stencil_val); + if (predicate) { + IndexT i_src = transform_op(map_val); + out[tid] = in[i_src * D + j]; + } } } } @@ -60,7 +82,7 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @brief gather conditionally copies rows from a source matrix into a destination matrix according * to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -69,7 +91,10 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -83,18 +108,20 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gatherImpl(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gatherImpl(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) @@ -102,9 +129,6 @@ void gatherImpl(const MatrixIteratorT in, // skip in case of 0 length input if (map_length <= 0 || N <= 0 || D <= 0) return; - // signed integer type for indexing or global offsets - typedef int IndexT; - // map value type typedef typename std::iterator_traits::value_type MapValueT; @@ -121,38 +145,26 @@ void gatherImpl(const MatrixIteratorT in, static_assert((std::is_convertible::value), "UnaryPredicateOp's result type must be convertible to bool type"); - if (D <= 32) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); - } else if (D <= 64) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); - } else if (D <= 128) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + IndexT len = map_length * D; + constexpr int TPB = 128; + const int n_sm = raft::getMultiProcessorCount(); + // The following empirical heuristics enforce that we keep a good balance between having enough + // blocks and enough work per thread. + if (len < 32 * TPB * n_sm) { + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); + } else if (len < 32 * 4 * TPB * n_sm) { + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); } else { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -160,10 +172,13 @@ void gatherImpl(const MatrixIteratorT in, /** * @brief gather copies rows from a source matrix into a destination matrix according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
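+ *
+ * Usage sketch (illustrative, hypothetical row-major device buffers):
+ * @code{.cpp}
+ *   // copy rows map[0], ..., map[map_length - 1] of the N x D matrix `in` into `out`
+ *   raft::matrix::detail::gather(in, D, N, map, map_length, out, stream);
+ * @endcode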
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -174,13 +189,13 @@ void gatherImpl(const MatrixIteratorT in, * @param out Pointer to the output matrix (assumed to be row-major) * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; @@ -192,12 +207,15 @@ void gather(const MatrixIteratorT in, * @brief gather copies rows from a source matrix into a destination matrix according to a * transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -209,13 +227,17 @@ void gather(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, MapTransformOp transform_op, cudaStream_t stream) { @@ -227,7 +249,7 @@ void gather(const MatrixIteratorT in, * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -235,6 +257,9 @@ void gather(const MatrixIteratorT in, * simple pointer type). * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
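+ *
+ * Usage sketch (illustrative, hypothetical buffers; pred_op returns a bool-convertible value):
+ * @code{.cpp}
+ *   // copy row map[i] into out row i only where pred_op(stencil[i]) is true
+ *   raft::matrix::detail::gather_if(in, D, N, map, stencil, map_length, out, pred_op, stream);
+ * @endcode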
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -247,17 +272,19 @@ void gather(const MatrixIteratorT in, * @param pred_op Predicate to apply to the stencil values * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename UnaryPredicateOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, cudaStream_t stream) { @@ -269,7 +296,7 @@ void gather_if(const MatrixIteratorT in, * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -278,7 +305,10 @@ void gather_if(const MatrixIteratorT in, * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT type. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -292,18 +322,20 @@ void gather_if(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index c559da3942..f5c33d1cf6 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -87,10 +87,10 @@ void seqRoot(math_t* in, if (a < math_t(0)) { return math_t(0); } else { - return sqrt(a * scalar); + return raft::sqrt(a * scalar); } } else { - return sqrt(a * scalar); + return raft::sqrt(a * scalar); } }, stream); @@ -278,7 +278,7 @@ void matrixVectorBinaryDivSkipZero(Type* data, rowMajor, bcastAlongRows, [] __device__(Type a, Type b) { - if (raft::myAbs(b) < Type(1e-10)) + if (raft::abs(b) < Type(1e-10)) return Type(0); else return a / b; @@ -294,7 +294,7 @@ void matrixVectorBinaryDivSkipZero(Type* data, rowMajor, bcastAlongRows, [] __device__(Type a, Type b) { - if (raft::myAbs(b) < Type(1e-10)) + if (raft::abs(b) < Type(1e-10)) return a; else return a / b; diff --git a/cpp/include/raft/spatial/knn/detail/topk.cuh b/cpp/include/raft/matrix/detail/select_k.cuh similarity index 59% rename from cpp/include/raft/spatial/knn/detail/topk.cuh rename to cpp/include/raft/matrix/detail/select_k.cuh index f4dcb53088..ac1ba3dfa3 100644 --- a/cpp/include/raft/spatial/knn/detail/topk.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,34 +16,34 @@ #pragma once -#include "topk/radix_topk.cuh" -#include "topk/warpsort_topk.cuh" +#include "select_radix.cuh" +#include "select_warpsort.cuh" #include #include #include -namespace raft::spatial::knn::detail { +namespace raft::matrix::detail { /** * Select k smallest or largest key/values from each row in the input data. * - * If you think of the input data `in_keys` as a row-major matrix with len columns and - * batch_size rows, then this function selects k smallest/largest values in each row and fills - * in the row-major matrix `out` of size (batch_size, k). + * If you think of the input data `in_val` as a row-major matrix with `len` columns and + * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills + * in the row-major matrix `out_val` of size (batch_size, k). * * @tparam T * the type of the keys (what is being compared). * @tparam IdxT * the index type (what is being selected together with the keys). * - * @param[in] in + * @param[in] in_val * contiguous device array of inputs of size (len * batch_size); * these are compared and selected. * @param[in] in_idx * contiguous device array of inputs of size (len * batch_size); - * typically, these are indices of the corresponding in_keys. + * typically, these are indices of the corresponding in_val. * @param batch_size * number of input rows, i.e. the batch size. * @param len @@ -51,12 +51,12 @@ namespace raft::spatial::knn::detail { * Invariant: len >= k. * @param k * the number of outputs to select in each input row. - * @param[out] out + * @param[out] out_val * contiguous device array of outputs of size (k * batch_size); - * the k smallest/largest values from each row of the `in_keys`. + * the k smallest/largest values from each row of the `in_val`. * @param[out] out_idx * contiguous device array of outputs of size (k * batch_size); - * the payload selected together with `out`. + * the payload selected together with `out_val`. * @param select_min * whether to select k smallest (true) or largest (false) keys. * @param stream @@ -64,28 +64,28 @@ namespace raft::spatial::knn::detail { * memory pool here to avoid memory allocations within the call). 
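 *
 * Usage sketch (illustrative, hypothetical device buffers):
 * @code{.cpp}
 *   // select the 10 smallest values (and their indices) in each of 100 rows
 *   raft::matrix::detail::select_k(in_val, in_idx, 100, 10000, 10, out_val, out_idx, true, stream);
 * @endcode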
 */
 template <typename T, typename IdxT>
-void select_topk(const T* in,
-                 const IdxT* in_idx,
-                 size_t batch_size,
-                 size_t len,
-                 int k,
-                 T* out,
-                 IdxT* out_idx,
-                 bool select_min,
-                 rmm::cuda_stream_view stream,
-                 rmm::mr::device_memory_resource* mr = nullptr)
+void select_k(const T* in_val,
+              const IdxT* in_idx,
+              size_t batch_size,
+              size_t len,
+              int k,
+              T* out_val,
+              IdxT* out_idx,
+              bool select_min,
+              rmm::cuda_stream_view stream,
+              rmm::mr::device_memory_resource* mr = nullptr)
 {
   common::nvtx::range<common::nvtx::domain::raft> fun_scope(
-    "matrix::select_topk(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
+    "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
   // TODO (achirkin): investigate the trade-off for a wider variety of inputs.
   const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128;
-  if (k <= raft::spatial::knn::detail::topk::kMaxCapacity && !radix_faster) {
-    topk::warp_sort_topk<T, IdxT>(
-      in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr);
+  if (k <= select::warpsort::kMaxCapacity && !radix_faster) {
+    select::warpsort::select_k<T, IdxT>(
+      in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
   } else {
-    topk::radix_topk<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
-      in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr);
+    select::radix::select_k<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
+      in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
   }
 }

-}  // namespace raft::spatial::knn::detail
+}  // namespace raft::matrix::detail
diff --git a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh
similarity index 87%
rename from cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh
rename to cpp/include/raft/matrix/detail/select_radix.cuh
index 9c0f20b706..de19e63a4c 100644
--- a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh
+++ b/cpp/include/raft/matrix/detail/select_radix.cuh
@@ -17,6 +17,7 @@
 #pragma once

 #include
+#include
 #include
 #include
 #include
@@ -27,29 +28,29 @@
 #include
 #include

-#include
+#include

-namespace raft::spatial::knn::detail::topk {
+namespace raft::matrix::detail::select::radix {

 constexpr int ITEM_PER_THREAD      = 32;
 constexpr int VECTORIZED_READ_SIZE = 16;

 template <int BitsPerPass>
-__host__ __device__ constexpr int calc_num_buckets()
+_RAFT_HOST_DEVICE constexpr int calc_num_buckets()
 {
   return 1 << BitsPerPass;
 }

 template <typename T, int BitsPerPass>
-__host__ __device__ constexpr int calc_num_passes()
+_RAFT_HOST_DEVICE constexpr int calc_num_passes()
 {
   return ceildiv<int>(sizeof(T) * 8, BitsPerPass);
 }

 // Minimum reasonable block size for the given radix size.
 template <typename T, int BitsPerPass, typename IdxT>
-__host__ __device__ constexpr int calc_min_block_size()
+_RAFT_HOST_DEVICE constexpr int calc_min_block_size()
 {
   return 1 << std::max<int>(BitsPerPass - 4, Pow2<WarpSize>::Log2 + 1);
 }
@@ -62,7 +63,7 @@
  * NB: Use pass=-1 for calc_mask().
  */
 template <typename T, int BitsPerPass>
-__device__ constexpr int calc_start_bit(int pass)
+_RAFT_DEVICE constexpr int calc_start_bit(int pass)
 {
   int start_bit = static_cast<int>(sizeof(T) * 8) - (pass + 1) * BitsPerPass;
   if (start_bit < 0) { start_bit = 0; }
   return start_bit;
 }

 template <typename T, int BitsPerPass>
-__device__ constexpr unsigned calc_mask(int pass)
+_RAFT_DEVICE constexpr unsigned calc_mask(int pass)
 {
   static_assert(BitsPerPass <= 31);
   int num_bits = calc_start_bit<T, BitsPerPass>(pass - 1) - calc_start_bit<T, BitsPerPass>(pass);
@@ -82,7 +83,7 @@
 * as of integers.
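 * For instance (illustrative): after twiddle_in with greater=false, the unsigned
 * representations of -0.5f, 0.0f and 2.0f compare in the same order as the
 * original float values, so each radix pass can bucket on the twiddled bits.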
*/ template -__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) +_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) { auto bits = reinterpret_cast::UnsignedBits&>(key); bits = cub::Traits::TwiddleIn(bits); @@ -91,7 +92,7 @@ __device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) } template -__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) +_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater) { static_assert(BitsPerPass <= sizeof(int) * 8 - 1); // so return type can be int return (twiddle_in(x, greater) >> start_bit) & mask; @@ -112,7 +113,7 @@ __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) * @param f the lambda taking two arguments (T x, IdxT idx) */ template -__device__ void vectorized_process(const T* in, IdxT len, Func f) +_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) { const IdxT stride = blockDim.x * gridDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -167,18 +168,18 @@ struct Counter { * (see steps 4-1 in `radix_kernel` description). */ template -__device__ void filter_and_histogram(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - IdxT len, - Counter* counter, - IdxT* histogram, - bool greater, - int pass, - int k) +_RAFT_DEVICE void filter_and_histogram(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + IdxT len, + Counter* counter, + IdxT* histogram, + bool greater, + int pass, + int k) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -260,10 +261,10 @@ __device__ void filter_and_histogram(const T* in_buf, * (step 2 in `radix_kernel` description) */ template -__device__ void scan(volatile IdxT* histogram, - const int start, - const int num_buckets, - const IdxT current) +_RAFT_DEVICE void scan(volatile IdxT* histogram, + const int start, + const int num_buckets, + const IdxT current) { typedef cub::BlockScan BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; @@ -284,7 +285,7 @@ __device__ void scan(volatile IdxT* histogram, * (steps 2-3 in `radix_kernel` description) */ template -__device__ void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) +_RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) { constexpr int num_buckets = calc_num_buckets(); int index = threadIdx.x; @@ -547,21 +548,21 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) * memory pool here to avoid memory allocations within the call). */ template -void radix_topk(const T* in, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) +void select_k(const T* in, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { // reduce the block size if the input length is too small. 
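  //
  // (A sketch of what the compile-time recursion below achieves, assuming the
  //  defaults chosen by the dispatcher for 4-byte keys, BitsPerPass = 11 and
  //  BlockSize = 512: there are 2^11 = 2048 buckets, ceildiv(32, 11) = 3
  //  passes, and calc_min_block_size() yields 1 << max(11 - 4, 6) = 128; so a
  //  row with len < BlockSize * ITEM_PER_THREAD re-instantiates the call with
  //  BlockSize / 2 until the blocks are busy enough or the 128-thread floor
  //  is reached.)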
if constexpr (BlockSize > calc_min_block_size()) { if (BlockSize * ITEM_PER_THREAD > len) { - return radix_topk( + return select_k( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); } } @@ -573,23 +574,33 @@ void radix_topk(const T* in, dim3 blocks = get_optimal_grid_size(batch_size, len); size_t max_chunk_size = blocks.y; - auto pool_guard = raft::get_pool_memory_resource( - mr, - max_chunk_size * (sizeof(Counter) // counters - + sizeof(IdxT) * (num_buckets + 2) // histograms and IdxT bufs - + sizeof(T) * 2 // T bufs - )); + size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); + size_t req_buf = max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)); + size_t mem_req = req_aux + req_buf; + size_t mem_free, mem_total; + RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); + std::optional managed_memory; + rmm::mr::device_memory_resource* mr_buf = nullptr; + if (mem_req > mem_free) { + // if there's not enough memory for buffers on the device, resort to the managed memory. + mem_req = req_aux; + managed_memory.emplace(); + mr_buf = &managed_memory.value(); + } + + auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); if (pool_guard) { - RAFT_LOG_DEBUG("radix_topk: using pool memory resource with initial size %zu bytes", + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } + if (mr_buf == nullptr) { mr_buf = mr; } rmm::device_uvector> counters(max_chunk_size, stream, mr); - rmm::device_uvector histograms(num_buckets * max_chunk_size, stream, mr); - rmm::device_uvector buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector buf2(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); + rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); + rmm::device_uvector buf1(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector idx_buf1(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector buf2(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector idx_buf2(max_chunk_size * len, stream, mr_buf); for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) { blocks.y = std::min(max_chunk_size, batch_size - offset); @@ -646,4 +657,4 @@ void radix_topk(const T* in, } } -} // namespace raft::spatial::knn::detail::topk +} // namespace raft::matrix::detail::select::radix diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh similarity index 71% rename from cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh rename to cpp/include/raft/matrix/detail/select_warpsort.cuh index c06aa04aea..d362b73792 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -16,10 +16,11 @@ #pragma once -#include "bitonic_sort.cuh" - +#include #include +#include #include +#include #include #include @@ -31,12 +32,12 @@ /* Three APIs of different scopes are provided: - 1. host function: warp_sort_topk() + 1. host function: select_k() 2. block-wide API: class block_sort 3. warp-wide API: several implementations of warp_sort_* - 1. warp_sort_topk() + 1. select_k() (see the docstring) 2. class block_sort @@ -74,7 +75,7 @@ These two classes can be regarded as fixed size priority queue for a warp. Usage is similar to class block_sort. No shared memory is needed. 
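
    A warp-wide usage sketch (a hypothetical single-warp kernel selecting the
    k smallest values, k <= Capacity = 32; padding the loop bound to a
    multiple of WarpSize keeps all lanes converged inside add()):

      __global__ void warp_topk(const float* in, int len, int k,
                                float* out, int* out_idx) {
        warp_sort_immediate<32, true, float, int> queue(k);
        const int len_up = Pow2<WarpSize>::roundUp(len);
        for (int i = threadIdx.x; i < len_up; i += WarpSize) {
          queue.add(i < len ? in[i] : queue.kDummy, i);
        }
        queue.done();
        queue.store(out, out_idx);
      }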
- The host function (warp_sort_topk) uses a heuristic to choose between these two classes for + The host function (select_k) uses a heuristic to choose between these two classes for sorting, warp_sort_immediate being chosen when the number of inputs per warp is somewhat small (see the usage of LaunchThreshold::len_factor_for_choosing). @@ -94,7 +95,7 @@ } */ -namespace raft::spatial::knn::detail::topk { +namespace raft::matrix::detail::select::warpsort { static constexpr int kMaxCapacity = 256; @@ -102,18 +103,12 @@ namespace { /** Whether 'left` should indeed be on the left w.r.t. `right`. */ template -__device__ __forceinline__ auto is_ordered(T left, T right) -> bool +_RAFT_DEVICE _RAFT_FORCEINLINE auto is_ordered(T left, T right) -> bool { if constexpr (Ascending) { return left < right; } if constexpr (!Ascending) { return left > right; } } -constexpr auto calc_capacity(int k) -> int -{ - int capacity = isPo2(k) ? k : (1 << (log2(k) + 1)); - return capacity; -} - } // namespace /** @@ -134,7 +129,7 @@ constexpr auto calc_capacity(int k) -> int */ template class warp_sort { - static_assert(isPo2(Capacity)); + static_assert(is_a_power_of_two(Capacity)); static_assert(std::is_default_constructible_v); public: @@ -148,13 +143,16 @@ class warp_sort { /** The number of elements to select. */ const int k; + /** Extra memory required per-block for keeping the state (shared or global). */ + constexpr static auto mem_required(uint32_t block_size) -> size_t { return 0; } + /** * Construct the warp_sort empty queue. * * @param k * number of elements to select. */ - __device__ warp_sort(int k) : k(k) + _RAFT_DEVICE warp_sort(int k) : k(k) { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -182,7 +180,7 @@ class warp_sort { * It serves as a conditional; when `false` the function does nothing. * We need it to ensure threads within a full warp don't diverge calling `bitonic::merge()`. */ - __device__ void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true) + _RAFT_DEVICE void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true) { if (do_merge) { int idx = Pow2::mod(laneId()) ^ Pow2::Mask; @@ -198,7 +196,7 @@ class warp_sort { } } if (kWarpWidth < WarpSize || do_merge) { - topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); } } @@ -211,15 +209,23 @@ class warp_sort { * @param[out] out_idx * device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth` * (length: k <= kWarpWidth * kMaxArrLen). + * @param valF (optional) postprocess values (T -> OutT) + * @param idxF (optional) postprocess indices (IdxT -> OutIdxT) */ - template - __device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const + template + _RAFT_DEVICE void store(OutT* out, + OutIdxT* out_idx, + ValF valF = raft::identity_op{}, + IdxF idxF = raft::identity_op{}) const { int idx = Pow2::mod(laneId()); #pragma unroll kMaxArrLen for (int i = 0; i < kMaxArrLen && idx < k; i++, idx += kWarpWidth) { - out[idx] = post_process(val_arr_[i]); - out_idx[idx] = idx_arr_[i]; + out[idx] = valF(val_arr_[i]); + out_idx[idx] = idxF(idx_arr_[i]); } } @@ -246,8 +252,8 @@ class warp_sort { * the associated indices of the elements in the same format as `keys_in`. 
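 *
 * E.g., conceptually, with Ascending = true and k = 4: if the queue holds
 * {1, 3, 5, 7} and `keys_in` holds {6, 4, 2, 0} (sorted in the opposite
 * direction, as required), the queue holds {0, 1, 2, 3} after the merge.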
*/ template - __device__ __forceinline__ void merge_in(const T* __restrict__ keys_in, - const IdxT* __restrict__ ids_in) + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_in(const T* __restrict__ keys_in, + const IdxT* __restrict__ ids_in) { #pragma unroll for (int i = std::min(kMaxArrLen, PerThreadSizeIn); i > 0; i--) { @@ -258,7 +264,7 @@ class warp_sort { idx_arr_[kMaxArrLen - i] = ids_in[PerThreadSizeIn - i]; } } - topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); } }; @@ -276,8 +282,9 @@ class warp_sort_filtered : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_filtered(int k, T limit) + explicit _RAFT_DEVICE warp_sort_filtered(int k, T limit = kDummy) : warp_sort(k), buf_len_(0), k_th_(limit) { #pragma unroll @@ -287,12 +294,14 @@ class warp_sort_filtered : public warp_sort { } } - __device__ __forceinline__ explicit warp_sort_filtered(int k) - : warp_sort_filtered(k, kDummy) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, + uint8_t* = nullptr, + T limit = kDummy) { + return warp_sort_filtered{k, limit}; } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE void add(T val, IdxT idx) { // comparing for k_th should reduce the total amount of updates: // `false` means the input value is surely not in the top-k values. @@ -310,22 +319,22 @@ class warp_sort_filtered : public warp_sort { if (do_add) { add_to_buf_(val, idx); } } - __device__ void done() + _RAFT_DEVICE void done() { if (any(buf_len_ != 0)) { merge_buf_(); } } private: - __device__ __forceinline__ void set_k_th_() + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() { // NB on using srcLane: it's ok if it is outside the warp size / width; // the modulo op will be done inside the __shfl_sync. k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); } - __device__ __forceinline__ void merge_buf_() + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); buf_len_ = 0; set_k_th_(); // contains warp sync @@ -335,7 +344,7 @@ class warp_sort_filtered : public warp_sort { } } - __device__ __forceinline__ void add_to_buf_(T val, IdxT idx) + _RAFT_DEVICE _RAFT_FORCEINLINE void add_to_buf_(T val, IdxT idx) { // NB: the loop is used here to ensure the constant indexing, // to not force the buffers spill into the local memory. 
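  //
  // (Context for the NB above, a sketch of the constant-indexing idiom:
  //  `buf_len_` is a runtime value, and registers are not dynamically
  //  addressable, so writing `val_buf_[buf_len_] = val;` directly would spill
  //  the buffers to local memory. The unrolled loop keeps every index a
  //  compile-time constant instead; `kBufLen` stands in for the actual bound:
  //
  //    #pragma unroll
  //    for (int i = 0; i < kBufLen; i++) {
  //      if (i == buf_len_) { val_buf_[i] = val; idx_buf_[i] = idx; }
  //    }
  //  )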
@@ -374,8 +383,9 @@ class warp_sort_distributed : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_distributed(int k, T limit) + explicit _RAFT_DEVICE warp_sort_distributed(int k, T limit = kDummy) : warp_sort(k), buf_val_(kDummy), buf_idx_(IdxT{}), @@ -384,12 +394,14 @@ class warp_sort_distributed : public warp_sort { { } - __device__ __forceinline__ explicit warp_sort_distributed(int k) - : warp_sort_distributed(k, kDummy) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, + uint8_t* = nullptr, + T limit = kDummy) { + return warp_sort_distributed{k, limit}; } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE void add(T val, IdxT idx) { // mask tells which lanes in the warp have valid items to be added uint32_t mask = ballot(is_ordered(val, k_th_)); @@ -429,7 +441,7 @@ class warp_sort_distributed : public warp_sort { } } - __device__ void done() + _RAFT_DEVICE void done() { if (buf_len_ != 0) { merge_buf_(); @@ -438,16 +450,16 @@ class warp_sort_distributed : public warp_sort { } private: - __device__ __forceinline__ void set_k_th_() + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() { // NB on using srcLane: it's ok if it is outside the warp size / width; // the modulo op will be done inside the __shfl_sync. k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); } - __device__ __forceinline__ void merge_buf_() + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() { - topk::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_); + util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_); this->merge_in<1>(&buf_val_, &buf_idx_); set_k_th_(); // contains warp sync buf_val_ = kDummy; @@ -464,6 +476,117 @@ class warp_sort_distributed : public warp_sort { T k_th_; }; +/** + * The same as `warp_sort_distributed`, but keeps the temporary value and index buffers + * in the given external pointers (normally, a shared memory pointer should be passed in). 
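 *
 * A launch-side sizing sketch (hypothetical names; within this file the
 * queue is normally reached indirectly through block_sort and block_kernel,
 * which do the same arithmetic via mem_required()):
 *
 *   using queue_t = warp_sort_distributed_ext<256, true, float, int>;
 *   size_t smem_size = queue_t::mem_required(block_dim);  // bytes per block
 *   my_kernel<<<grid_dim, block_dim, smem_size, stream>>>(...);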
+ */ +template +class warp_sort_distributed_ext : public warp_sort { + public: + using warp_sort::kDummy; + using warp_sort::kWarpWidth; + using warp_sort::k; + + constexpr static auto mem_required(uint32_t block_size) -> size_t + { + return (sizeof(T) + sizeof(IdxT)) * block_size; + } + + _RAFT_DEVICE warp_sort_distributed_ext(int k, T* val_buf, IdxT* idx_buf, T limit = kDummy) + : warp_sort(k), + val_buf_(val_buf), + idx_buf_(idx_buf), + buf_len_(0), + k_th_(limit) + { + val_buf_[laneId()] = kDummy; + } + + _RAFT_DEVICE static auto init_blockwide(int k, uint8_t* shmem, T limit = kDummy) + { + T* val_buf = nullptr; + IdxT* idx_buf = nullptr; + if constexpr (alignof(T) >= alignof(IdxT)) { + val_buf = reinterpret_cast(shmem); + idx_buf = reinterpret_cast(val_buf + blockDim.x); + } else { + idx_buf = reinterpret_cast(shmem); + val_buf = reinterpret_cast(idx_buf + blockDim.x); + } + auto warp_offset = Pow2::roundDown(threadIdx.x); + val_buf += warp_offset; + idx_buf += warp_offset; + return warp_sort_distributed_ext{k, val_buf, idx_buf, limit}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) + { + bool do_add = is_ordered(val, k_th_); + // mask tells which lanes in the warp have valid items to be added + uint32_t mask = ballot(do_add); + if (mask == 0) { return; } + // where to put the element in the tmp buffer + int dst_ix = buf_len_ + __popc(mask & ((1u << laneId()) - 1u)); + // put all elements, which fit into the current tmp buffer + if (do_add && dst_ix < WarpSize) { + val_buf_[dst_ix] = val; + idx_buf_[dst_ix] = idx; + do_add = false; + } + // Total number of elements to be added + buf_len_ += __popc(mask); + // If the buffer is still not full, we can return + if (buf_len_ < WarpSize) { return; } + // Otherwise, merge the warp tmp buffer into the queue + merge_buf_(); // implies warp sync + buf_len_ -= WarpSize; + // save the inputs that couldn't fit before the merge + if (do_add) { + dst_ix -= WarpSize; + val_buf_[dst_ix] = val; + idx_buf_[dst_ix] = idx; + } + } + + _RAFT_DEVICE void done() + { + if (buf_len_ != 0) { + merge_buf_(); + buf_len_ = 0; + } + __syncthreads(); + } + + private: + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() + { + // NB on using srcLane: it's ok if it is outside the warp size / width; + // the modulo op will be done inside the __shfl_sync. + k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); + } + + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() + { + __syncwarp(); // make sure the threads are aware of the data written by others + T buf_val = val_buf_[laneId()]; + IdxT buf_idx = idx_buf_[laneId()]; + val_buf_[laneId()] = kDummy; + util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val, buf_idx); + this->merge_in<1>(&buf_val, &buf_idx); + set_k_th_(); // contains warp sync + } + + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + + T* val_buf_; + IdxT* idx_buf_; + uint32_t buf_len_; // 0 <= buf_len_ < WarpSize + + T k_th_; +}; + /** * This version of warp_sort adds every input element into the intermediate sorting * buffer, and thus does the sorting step every `Capacity` input elements. 
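 *
 * As a rough cost model: warp_sort_immediate pays one bitonic sort per
 * Capacity accepted inputs regardless of their values, which tends to win
 * when len / k is small, whereas warp_sort_filtered compares candidates
 * against the current k-th value first and buffers only the survivors, which
 * tends to win when most inputs are rejected (len / k large). The host-side
 * heuristic encodes this via LaunchThreshold<...>::len_factor_for_choosing.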
@@ -476,8 +599,10 @@ class warp_sort_immediate : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_immediate(int k) : warp_sort(k), buf_len_(0) + explicit _RAFT_DEVICE warp_sort_immediate(int k) + : warp_sort(k), buf_len_(0) { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -486,7 +611,12 @@ class warp_sort_immediate : public warp_sort { } } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, uint8_t* = nullptr) + { + return warp_sort_immediate{k}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) { // NB: the loop is used here to ensure the constant indexing, // to not force the buffers spill into the local memory. @@ -500,7 +630,7 @@ class warp_sort_immediate : public warp_sort { ++buf_len_; if (buf_len_ == kMaxArrLen) { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -510,10 +640,10 @@ class warp_sort_immediate : public warp_sort { } } - __device__ void done() + _RAFT_DEVICE void done() { if (buf_len_ != 0) { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); } } @@ -545,15 +675,11 @@ class block_sort { using queue_t = WarpSortWarpWide; template - __device__ block_sort(int k, uint8_t* smem_buf, Args... args) : queue_(k, args...) + _RAFT_DEVICE block_sort(int k, Args... args) : queue_(queue_t::init_blockwide(k, args...)) { - val_smem_ = reinterpret_cast(smem_buf); - const int num_of_warp = subwarp_align::div(blockDim.x); - idx_smem_ = reinterpret_cast( - smem_buf + Pow2<256>::roundUp(ceildiv(num_of_warp, 2) * sizeof(T) * k)); } - __device__ void add(T val, IdxT idx) { queue_.add(val, idx); } + _RAFT_DEVICE void add(T val, IdxT idx) { queue_.add(val, idx); } /** * At the point of calling this function, the warp-level queues consumed all input @@ -561,22 +687,26 @@ class block_sort { * * Here we tree-merge the results using the shared memory and block sync. */ - __device__ void done() + _RAFT_DEVICE void done(uint8_t* smem_buf) { queue_.done(); + int nwarps = subwarp_align::div(blockDim.x); + auto val_smem = reinterpret_cast(smem_buf); + auto idx_smem = reinterpret_cast( + smem_buf + Pow2<256>::roundUp(ceildiv(nwarps, 2) * sizeof(T) * queue_.k)); + const int warp_id = subwarp_align::div(threadIdx.x); // NB: there is no need for the second __synchthreads between .load_sorted and .store: // we shift the pointers every iteration, such that individual warps either access the same // locations or do not overlap with any of the other warps. The access patterns within warps // are different for the two functions, but .load_sorted implies warp sync at the end, so // there is no need for __syncwarp either. 
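      //
      // (A sketch of one round of the tree merge below, assuming nwarps = 8:
      //  split = 4, so warps 4..7 store their queues into shared memory while
      //  warps 0..3 load_sorted and merge them into their own queues; the
      //  active warp count then halves, 8 -> 4 -> 2 -> 1, so after
      //  ceil(log2(nwarps)) rounds the block-wide result sits in warp 0,
      //  ready for the subsequent store().)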
- for (int shift_mask = ~0, nwarps = subwarp_align::div(blockDim.x), split = (nwarps + 1) >> 1; - nwarps > 1; + for (int shift_mask = ~0, split = (nwarps + 1) >> 1; nwarps > 1; nwarps = split, split = (nwarps + 1) >> 1) { if (warp_id < nwarps && warp_id >= split) { int dst_warp_shift = (warp_id - (split & shift_mask)) * queue_.k; - queue_.store(val_smem_ + dst_warp_shift, idx_smem_ + dst_warp_shift); + queue_.store(val_smem + dst_warp_shift, idx_smem + dst_warp_shift); } __syncthreads(); @@ -586,23 +716,27 @@ class block_sort { // The last argument serves as a condition for loading // -- to make sure threads within a full warp do not diverge on `bitonic::merge()` queue_.load_sorted( - val_smem_ + src_warp_shift, idx_smem_ + src_warp_shift, warp_id < nwarps - split); + val_smem + src_warp_shift, idx_smem + src_warp_shift, warp_id < nwarps - split); } } } /** Save the content by the pointer location. */ - template - __device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const + template + _RAFT_DEVICE void store(OutT* out, + OutIdxT* out_idx, + ValF valF = raft::identity_op{}, + IdxF idxF = raft::identity_op{}) const { - if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, post_process); } + if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, valF, idxF); } } private: using subwarp_align = Pow2; queue_t queue_; - T* val_smem_; - IdxT* idx_smem_; }; /** @@ -620,7 +754,10 @@ __launch_bounds__(256) __global__ void block_kernel(const T* in, const IdxT* in_idx, IdxT len, int k, T* out, IdxT* out_idx) { extern __shared__ __align__(256) uint8_t smem_buf_bytes[]; - block_sort queue(k, smem_buf_bytes); + using bq_t = block_sort; + uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr; + bq_t queue(k, warp_smem); + in += blockIdx.y * len; if (in_idx != nullptr) { in_idx += blockIdx.y * len; } @@ -631,7 +768,7 @@ __launch_bounds__(256) __global__ (i < len && in_idx != nullptr) ? __ldcs(in_idx + i) : i); } - queue.done(); + queue.done(smem_buf_bytes); const int block_id = blockIdx.x + gridDim.x * blockIdx.y; queue.store(out + block_id * k, out_idx + block_id * k); } @@ -658,7 +795,7 @@ struct launch_setup { int* min_grid_size, int block_size_limit = 0) { - const int capacity = calc_capacity(k); + const int capacity = bound_by_power_of_two(k); if constexpr (Capacity > 1) { if (capacity < Capacity) { return launch_setup::calc_optimal_params( @@ -691,7 +828,7 @@ struct launch_setup { IdxT* out_idx, rmm::cuda_stream_view stream) { - const int capacity = calc_capacity(k); + const int capacity = bound_by_power_of_two(k); if constexpr (Capacity > 1) { if (capacity < Capacity) { return launch_setup::kernel(k, @@ -742,6 +879,18 @@ struct LaunchThreshold { static constexpr int len_factor_for_single_block = 32; }; +template <> +struct LaunchThreshold { + static constexpr int len_factor_for_multi_block = 2; + static constexpr int len_factor_for_single_block = 32; +}; + +template <> +struct LaunchThreshold { + static constexpr int len_factor_for_multi_block = 2; + static constexpr int len_factor_for_single_block = 32; +}; + template <> struct LaunchThreshold { static constexpr int len_factor_for_choosing = 4; @@ -753,7 +902,7 @@ template