From 2e98138c7d022e754e2b2f3bbab377d9d47150c3 Mon Sep 17 00:00:00 2001 From: Mark Hoemmen Date: Thu, 29 Sep 2022 10:25:45 -0600 Subject: [PATCH 1/3] mdspan-ify several raft::random rng functions (#857) Add overloads taking the output vector as mdspan, of the following `raft::random` functions: * normal * lognormal * uniform * gumbel * logistic * exponential * rayleigh * laplace I plan to finish the remaining `raft::random` functions as part of this PR. However, this PR should be ready to merge now. Authors: - Mark Hoemmen (https://github.com/mhoemmen) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/857 --- cpp/include/raft/random/rng.cuh | 180 +++++++++++++++++++++++++++++++- cpp/test/random/rng.cu | 140 +++++++++++++++++++++---- 2 files changed, 296 insertions(+), 24 deletions(-) diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index ba6254bfc3..106881fa1a 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -30,6 +30,28 @@ namespace raft::random { /** * @brief Generate uniformly distributed numbers in the given range * + * @tparam OutputValueType Data type of output random number + * @tparam Index Data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output array + * @param[in] start start of the range + * @param[in] end end of the range + */ +template +void uniform(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType start, + OutputValueType end) +{ + detail::uniform(rng_state, out.data_handle(), out.extent(0), start, end, handle.get_stream()); +} + +/** + * @brief Legacy overload of `uniform` taking raw pointers + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -75,6 +97,29 @@ void uniformInt(const raft::handle_t& handle, /** * @brief Generate normal distributed numbers + * with a given mean and standard deviation + * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output array + * @param[in] mu mean of the distribution + * @param[in] sigma std-dev of the distribution + */ +template +void normal(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType sigma) +{ + detail::normal(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `normal`. * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays @@ -217,6 +262,29 @@ void scaled_bernoulli(const raft::handle_t& handle, /** * @brief Generate Gumbel distributed random numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] mu mean value + * @param[in] beta scale value + * @note https://en.wikipedia.org/wiki/Gumbel_distribution + */ +template +void gumbel(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType beta) +{ + detail::gumbel(rng_state, out.data_handle(), out.extent(0), mu, beta, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `gumbel`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -241,6 +309,28 @@ void gumbel(const raft::handle_t& handle, /** * @brief Generate lognormal distributed numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out the output array + * @param[in] mu mean of the distribution + * @param[in] sigma standard deviation of the distribution + */ +template +void lognormal(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType sigma) +{ + detail::lognormal(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `lognormal`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -248,7 +338,7 @@ void gumbel(const raft::handle_t& handle, * @param[out] ptr the output array * @param[in] len the number of elements in the output * @param[in] mu mean of the distribution - * @param[in] sigma std-dev of the distribution + * @param[in] sigma standard deviation of the distribution */ template void lognormal(const raft::handle_t& handle, @@ -264,6 +354,28 @@ void lognormal(const raft::handle_t& handle, /** * @brief Generate logistic distributed random numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] mu mean value + * @param[in] scale scale value + */ +template +void logistic(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType scale) +{ + detail::logistic(rng_state, out.data_handle(), out.extent(0), mu, scale, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `logistic`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management @@ -287,13 +399,33 @@ void logistic(const raft::handle_t& handle, /** * @brief Generate exponentially distributed random numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] lambda the exponential distribution's lambda parameter + */ +template +void exponential(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType lambda) +{ + detail::exponential(rng_state, out.data_handle(), out.extent(0), lambda, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `exponential`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management * @param[in] rng_state random number generator state * @param[out] ptr output array * @param[in] len number of elements in the output array - * @param[in] lambda the lambda + * @param[in] lambda the exponential distribution's lambda parameter */ template void exponential( @@ -305,13 +437,33 @@ void exponential( /** * @brief Generate rayleigh distributed random numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] sigma the distribution's sigma parameter + */ +template +void rayleigh(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType sigma) +{ + detail::rayleigh(rng_state, out.data_handle(), out.extent(0), sigma, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `rayleigh`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management * @param[in] rng_state random number generator state * @param[out] ptr output array * @param[in] len number of elements in the output array - * @param[in] sigma the sigma + * @param[in] sigma the distribution's sigma parameter */ template void rayleigh( @@ -323,6 +475,28 @@ void rayleigh( /** * @brief Generate laplace distributed random numbers * + * @tparam OutputValueType data type of output random number + * @tparam IndexType data type used to represent length of the arrays + * + * @param[in] handle raft handle for resource management + * @param[in] rng_state random number generator state + * @param[out] out output array + * @param[in] mu the mean + * @param[in] scale the scale + */ +template +void laplace(const raft::handle_t& handle, + RngState& rng_state, + raft::device_vector_view out, + OutputValueType mu, + OutputValueType scale) +{ + detail::laplace(rng_state, out.data_handle(), out.extent(0), mu, scale, handle.get_stream()); +} + +/** + * @brief Legacy raw pointer overload of `laplace`. + * * @tparam OutType data type of output random number * @tparam LenType data type used to represent length of the arrays * @param[in] handle raft handle for resource management diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index d778555076..8b32742f34 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include "../test_utils.h" @@ -187,7 +188,100 @@ class RngTest : public ::testing::TestWithParam> { T h_stats[2]; // mean, var }; -typedef RngTest RngTestF; +template +class RngMdspanTest : public ::testing::TestWithParam> { + public: + RngMdspanTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + data(0, stream), + stats(2, stream) + { + data.resize(params.len, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(stats.data(), 0, 2 * sizeof(T), stream)); + } + + protected: + void SetUp() override + { + RngState r(params.seed, params.gtype); + + raft::device_vector_view data_view(data.data(), data.size()); + const auto len = data_view.extent(0); + + switch (params.type) { + case RNG_Normal: normal(handle, r, data_view, params.start, params.end); break; + case RNG_LogNormal: lognormal(handle, r, data_view, params.start, params.end); break; + case RNG_Uniform: uniform(handle, r, data_view, params.start, params.end); break; + case RNG_Gumbel: gumbel(handle, r, data_view, params.start, params.end); break; + case RNG_Logistic: logistic(handle, r, data_view, params.start, params.end); break; + case RNG_Exp: exponential(handle, r, data_view, params.start); break; + case RNG_Rayleigh: rayleigh(handle, r, data_view, params.start); break; + case RNG_Laplace: laplace(handle, r, data_view, params.start, params.end); break; + }; + static const int threads = 128; + meanKernel<<>>( + stats.data(), data.data(), params.len); + update_host(h_stats, stats.data(), 2, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + h_stats[0] /= params.len; + h_stats[1] = (h_stats[1] / params.len) - (h_stats[0] * h_stats[0]); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void getExpectedMeanVar(T meanvar[2]) + { + switch (params.type) { + case RNG_Normal: + meanvar[0] = params.start; + meanvar[1] = params.end * params.end; + break; + case RNG_LogNormal: { + auto var = params.end * params.end; + auto mu = params.start; + meanvar[0] = raft::myExp(mu + var * T(0.5)); + meanvar[1] = (raft::myExp(var) - T(1.0)) * raft::myExp(T(2.0) * mu + var); + break; + } + case RNG_Uniform: + meanvar[0] = (params.start + params.end) * T(0.5); + meanvar[1] = params.end - params.start; + meanvar[1] = meanvar[1] * meanvar[1] / T(12.0); + break; + case RNG_Gumbel: { + auto gamma = T(0.577215664901532); + meanvar[0] = params.start + params.end * gamma; + meanvar[1] = T(3.1415) * T(3.1415) * params.end * params.end / T(6.0); + break; + } + case RNG_Logistic: + meanvar[0] = params.start; + meanvar[1] = T(3.1415) * T(3.1415) * params.end * params.end / T(3.0); + break; + case RNG_Exp: + meanvar[0] = T(1.0) / params.start; + meanvar[1] = meanvar[0] * meanvar[0]; + break; + case RNG_Rayleigh: + meanvar[0] = params.start * raft::mySqrt(T(3.1415 / 2.0)); + meanvar[1] = ((T(4.0) - T(3.1415)) / T(2.0)) * params.start * params.start; + break; + case RNG_Laplace: + meanvar[0] = params.start; + meanvar[1] = T(2.0) * params.end * params.end; + break; + }; + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + RngInputs params; + rmm::device_uvector data, stats; + T h_stats[2]; // mean, var +}; + const std::vector> inputsf = { // Test with Philox {1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPhilox, 1234ULL}, @@ -206,16 +300,22 @@ const std::vector> inputsf = { {1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPC, 1234ULL}, {1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPC, 1234ULL}}; -TEST_P(RngTestF, Result) -{ - float meanvar[2]; - getExpectedMeanVar(meanvar); - ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(NUM_SIGMA * MAX_SIGMA))); - ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(NUM_SIGMA * MAX_SIGMA))); -} +#define _RAFT_RNG_TEST_BODY(VALUE_TYPE) \ + do { \ + VALUE_TYPE meanvar[2]; \ + getExpectedMeanVar(meanvar); \ + ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(NUM_SIGMA * MAX_SIGMA))); \ + ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(NUM_SIGMA * MAX_SIGMA))); \ + } while (false) + +using RngTestF = RngTest; +TEST_P(RngTestF, Result) { _RAFT_RNG_TEST_BODY(float); } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestF, ::testing::ValuesIn(inputsf)); -typedef RngTest RngTestD; +using RngMdspanTestF = RngMdspanTest; +TEST_P(RngMdspanTestF, Result) { _RAFT_RNG_TEST_BODY(float); } +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestF, ::testing::ValuesIn(inputsf)); + const std::vector> inputsd = { // Test with Philox {1024 * 1024, 3.0f, 1.3f, RNG_Normal, GenPhilox, 1234ULL}, @@ -234,15 +334,14 @@ const std::vector> inputsd = { {1024 * 1024, 1.6f, 0.0f, RNG_Rayleigh, GenPC, 1234ULL}, {1024 * 1024, 2.6f, 1.3f, RNG_Laplace, GenPC, 1234ULL}}; -TEST_P(RngTestD, Result) -{ - double meanvar[2]; - getExpectedMeanVar(meanvar); - ASSERT_TRUE(match(meanvar[0], h_stats[0], CompareApprox(NUM_SIGMA * MAX_SIGMA))); - ASSERT_TRUE(match(meanvar[1], h_stats[1], CompareApprox(NUM_SIGMA * MAX_SIGMA))); -} +using RngTestD = RngTest; +TEST_P(RngTestD, Result) { _RAFT_RNG_TEST_BODY(double); } INSTANTIATE_TEST_SUITE_P(RngTests, RngTestD, ::testing::ValuesIn(inputsd)); +using RngMdspanTestD = RngMdspanTest; +TEST_P(RngMdspanTestD, Result) { _RAFT_RNG_TEST_BODY(double); } +INSTANTIATE_TEST_SUITE_P(RngMdspanTests, RngMdspanTestD, ::testing::ValuesIn(inputsd)); + // ---------------------------------------------------------------------- // // Test for expected variance in mean calculations @@ -353,11 +452,10 @@ class ScaledBernoulliTest : public ::testing::Test { void rangeCheck() { - T* h_data = new T[len]; - update_host(h_data, data.data(), len, stream); - ASSERT_TRUE( - std::none_of(h_data, h_data + len, [](const T& a) { return a < -scale || a > scale; })); - delete[] h_data; + auto h_data = std::make_unique(len); + update_host(h_data.get(), data.data(), len, stream); + ASSERT_TRUE(std::none_of( + h_data.get(), h_data.get() + len, [](const T& a) { return a < -scale || a > scale; })); } raft::handle_t handle; From d475fca0c3475a71afc93ac8c9b3c37e7d238121 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Fri, 30 Sep 2022 16:08:28 +0200 Subject: [PATCH 2/3] KMeans benchmarks (cuML + ANN implementations) and fix for IndexT=int64_t (#795) This PR is an answer to the first two sub-tasks described in the issue #700 cc @tfeher @achirkin Authors: - Louis Sugy (https://github.com/Nyrio) - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/795 --- cpp/CMakeLists.txt | 7 + cpp/bench/CMakeLists.txt | 2 + cpp/bench/cluster/kmeans.cu | 115 ++++++++++++ cpp/bench/cluster/kmeans_balanced.cu | 110 +++++++++++ cpp/bench/common/benchmark.hpp | 79 +++++++- cpp/bench/spatial/knn.cu | 1 + cpp/include/raft/cluster/detail/kmeans.cuh | 15 +- .../raft/cluster/detail/kmeans_common.cuh | 175 +++++++----------- cpp/include/raft/core/detail/macros.hpp | 17 ++ .../raft/distance/detail/fused_l2_nn.cuh | 6 + cpp/include/raft/distance/fused_l2_nn.cuh | 55 ++++++ .../specializations/detail/russel_rao.cuh | 68 +++++++ .../distance/specializations/distance.cuh | 2 + .../specializations/fused_l2_nn_min.cuh | 126 +++++++++++++ cpp/include/raft/random/detail/rng_impl.cuh | 11 +- .../knn/detail/ann_kmeans_balanced.cuh | 8 +- .../raft/spatial/knn/detail/ann_utils.cuh | 62 ++++--- cpp/include/raft/spectral/detail/warn_dbg.hpp | 18 +- .../russel_rao_double_double_double_int.cu | 38 ++++ .../russel_rao_float_float_float_int.cu | 37 ++++ .../russel_rao_float_float_float_uint32.cu | 39 ++++ .../specializations/fused_l2_nn_double_int.cu | 49 +++++ .../fused_l2_nn_double_int64.cu | 49 +++++ .../specializations/fused_l2_nn_float_int.cu | 49 +++++ .../fused_l2_nn_float_int64.cu | 49 +++++ 25 files changed, 1032 insertions(+), 155 deletions(-) create mode 100644 cpp/bench/cluster/kmeans.cu create mode 100644 cpp/bench/cluster/kmeans_balanced.cu create mode 100644 cpp/include/raft/distance/specializations/detail/russel_rao.cuh create mode 100644 cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh create mode 100644 cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu create mode 100644 cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu create mode 100644 cpp/src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu create mode 100644 cpp/src/distance/specializations/fused_l2_nn_double_int.cu create mode 100644 cpp/src/distance/specializations/fused_l2_nn_double_int64.cu create mode 100644 cpp/src/distance/specializations/fused_l2_nn_float_int.cu create mode 100644 cpp/src/distance/specializations/fused_l2_nn_float_int64.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index df41b47766..2da18d2a74 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -274,6 +274,13 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu src/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu + src/distance/specializations/detail/russel_rao_double_double_double_int.cu + src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu + src/distance/specializations/detail/russel_rao_float_float_float_int.cu + src/distance/specializations/fused_l2_nn_double_int.cu + src/distance/specializations/fused_l2_nn_double_int64.cu + src/distance/specializations/fused_l2_nn_float_int.cu + src/distance/specializations/fused_l2_nn_float_int64.cu src/random/specializations/rmat_rectangular_generator_int_double.cu src/random/specializations/rmat_rectangular_generator_int64_double.cu src/random/specializations/rmat_rectangular_generator_int_float.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 6b2d463d0e..266571d4f3 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -18,6 +18,8 @@ set(RAFT_CPP_BENCH_TARGET "bench_raft") # (please keep the filenames in alphabetical order) add_executable(${RAFT_CPP_BENCH_TARGET} + bench/cluster/kmeans_balanced.cu + bench/cluster/kmeans.cu bench/distance/distance_cosine.cu bench/distance/distance_exp_l2.cu bench/distance/distance_l1.cu diff --git a/cpp/bench/cluster/kmeans.cu b/cpp/bench/cluster/kmeans.cu new file mode 100644 index 0000000000..bf4cc2f686 --- /dev/null +++ b/cpp/bench/cluster/kmeans.cu @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#if defined RAFT_DISTANCE_COMPILED && defined RAFT_NN_COMPILED +#include +#endif + +namespace raft::bench::cluster { + +struct KMeansBenchParams { + DatasetParams data; + BlobsParams blobs; + raft::cluster::KMeansParams kmeans; +}; + +template +struct KMeans : public BlobsFixture { + KMeans(const KMeansBenchParams& p) : BlobsFixture(p.data, p.blobs), params(p) {} + + void run_benchmark(::benchmark::State& state) override + { + raft::device_matrix_view X_view = this->X.view(); + std::optional> opt_weights_view = std::nullopt; + std::optional> centroids_view = + std::make_optional>(centroids.view()); + raft::device_vector_view labels_view = labels.view(); + raft::host_scalar_view inertia_view = raft::make_host_scalar_view(&inertia); + raft::host_scalar_view n_iter_view = raft::make_host_scalar_view(&n_iter); + + this->loop_on_state(state, [&]() { + raft::cluster::kmeans_fit_predict(this->handle, + params.kmeans, + X_view, + opt_weights_view, + centroids_view, + labels_view, + inertia_view, + n_iter_view); + }); + } + + void allocate_temp_buffers(const ::benchmark::State& state) override + { + centroids = + raft::make_device_matrix(this->handle, params.kmeans.n_clusters, params.data.cols); + labels = raft::make_device_vector(this->handle, params.data.rows); + } + + private: + KMeansBenchParams params; + raft::device_matrix centroids; + raft::device_vector labels; + T inertia; + IndexT n_iter; +}; // struct KMeans + +std::vector getKMeansInputs() +{ + std::vector out; + KMeansBenchParams p; + p.data.row_major = true; + p.blobs.cluster_std = 1.0; + p.blobs.shuffle = false; + p.blobs.center_box_min = -10.0; + p.blobs.center_box_max = 10.0; + p.blobs.seed = 12345ULL; + p.kmeans.init = raft::cluster::KMeansParams::KMeansPlusPlus; + p.kmeans.max_iter = 300; + p.kmeans.tol = 1e-4; + p.kmeans.verbosity = RAFT_LEVEL_INFO; + p.kmeans.metric = raft::distance::DistanceType::L2Expanded; + p.kmeans.inertia_check = true; + std::vector> row_cols_k = { + {1000000, 20, 1000}, + {3000000, 50, 20}, + {10000000, 50, 5}, + }; + for (auto& rck : row_cols_k) { + p.data.rows = std::get<0>(rck); + p.data.cols = std::get<1>(rck); + p.blobs.n_clusters = std::get<2>(rck); + p.kmeans.n_clusters = std::get<2>(rck); + for (auto bs_shift : std::vector({16, 18})) { + p.kmeans.batch_samples = 1 << bs_shift; + out.push_back(p); + } + } + return out; +} + +// note(lsugy): commenting out int64_t because the templates are not compiled in the distance +// library, resulting in long compilation times. +RAFT_BENCH_REGISTER((KMeans), "", getKMeansInputs()); +RAFT_BENCH_REGISTER((KMeans), "", getKMeansInputs()); +// RAFT_BENCH_REGISTER((KMeans), "", getKMeansInputs()); +// RAFT_BENCH_REGISTER((KMeans), "", getKMeansInputs()); + +} // namespace raft::bench::cluster diff --git a/cpp/bench/cluster/kmeans_balanced.cu b/cpp/bench/cluster/kmeans_balanced.cu new file mode 100644 index 0000000000..210b40ced8 --- /dev/null +++ b/cpp/bench/cluster/kmeans_balanced.cu @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#if defined RAFT_DISTANCE_COMPILED && defined RAFT_NN_COMPILED +#include +#endif + +namespace raft::bench::cluster { + +struct KMeansBalancedBenchParams { + DatasetParams data; + uint32_t max_iter; + uint32_t n_lists; + raft::distance::DistanceType metric; +}; + +template +struct KMeansBalanced : public fixture { + KMeansBalanced(const KMeansBalancedBenchParams& p) : params(p) {} + + void run_benchmark(::benchmark::State& state) override + { + this->loop_on_state(state, [this]() { + raft::spatial::knn::detail::kmeans::build_hierarchical(this->handle, + this->params.max_iter, + (uint32_t)this->params.data.cols, + this->X.data_handle(), + this->params.data.rows, + this->centroids.data_handle(), + this->params.n_lists, + this->params.metric, + this->handle.get_stream()); + }); + } + + void allocate_data(const ::benchmark::State& state) override + { + X = raft::make_device_matrix(handle, params.data.rows, params.data.cols); + + raft::random::RngState rng{1234}; + constexpr T kRangeMax = std::is_integral_v ? std::numeric_limits::max() : T(1); + constexpr T kRangeMin = std::is_integral_v ? std::numeric_limits::min() : T(-1); + if constexpr (std::is_integral_v) { + raft::random::uniformInt( + rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream); + } else { + raft::random::uniform( + rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream); + } + handle.sync_stream(stream); + } + + void allocate_temp_buffers(const ::benchmark::State& state) override + { + centroids = + raft::make_device_matrix(this->handle, params.n_lists, params.data.cols); + } + + private: + KMeansBalancedBenchParams params; + raft::device_matrix X; + raft::device_matrix centroids; +}; // struct KMeansBalanced + +std::vector getKMeansBalancedInputs() +{ + std::vector out; + KMeansBalancedBenchParams p; + p.data.row_major = true; + p.max_iter = 20; + p.metric = raft::distance::DistanceType::L2Expanded; + std::vector> row_cols = { + {100000, 128}, {1000000, 128}, {10000000, 128}, + // The following dataset sizes are too large for most GPUs. + // {100000000, 128}, + }; + for (auto& rc : row_cols) { + p.data.rows = rc.first; + p.data.cols = rc.second; + for (auto n_lists : std::vector({1000, 10000, 100000})) { + p.n_lists = n_lists; + out.push_back(p); + } + } + return out; +} + +// Note: the datasets sizes are too large for 32-bit index types. +RAFT_BENCH_REGISTER((KMeansBalanced), "", getKMeansBalancedInputs()); +RAFT_BENCH_REGISTER((KMeansBalanced), "", getKMeansBalancedInputs()); +RAFT_BENCH_REGISTER((KMeansBalanced), "", getKMeansBalancedInputs()); + +} // namespace raft::bench::cluster diff --git a/cpp/bench/common/benchmark.hpp b/cpp/bench/common/benchmark.hpp index fb878a0c8d..adfe5218e2 100644 --- a/cpp/bench/common/benchmark.hpp +++ b/cpp/bench/common/benchmark.hpp @@ -18,9 +18,12 @@ #include +#include +#include #include #include #include +#include #include @@ -121,6 +124,10 @@ class fixture { // every benchmark should be overriding this virtual void run_benchmark(::benchmark::State& state) = 0; virtual void generate_metrics(::benchmark::State& state) {} + virtual void allocate_data(const ::benchmark::State& state) {} + virtual void deallocate_data(const ::benchmark::State& state) {} + virtual void allocate_temp_buffers(const ::benchmark::State& state) {} + virtual void deallocate_temp_buffers(const ::benchmark::State& state) {} protected: /** The helper that writes zeroes to some buffer in GPU memory to flush the L2 cache. */ @@ -144,6 +151,58 @@ class fixture { } }; +/** Indicates the dataset size. */ +struct DatasetParams { + size_t rows; + size_t cols; + bool row_major; +}; + +/** Holds params needed to generate blobs dataset */ +struct BlobsParams { + int n_clusters; + double cluster_std; + bool shuffle; + double center_box_min, center_box_max; + uint64_t seed; +}; + +/** Fixture for cluster benchmarks using make_blobs */ +template +class BlobsFixture : public fixture { + public: + BlobsFixture(const DatasetParams dp, const BlobsParams bp) : data_params(dp), blobs_params(bp) {} + + virtual void run_benchmark(::benchmark::State& state) = 0; + + void allocate_data(const ::benchmark::State& state) override + { + auto labels_ref = raft::make_device_vector(this->handle, data_params.rows); + X = raft::make_device_matrix(this->handle, data_params.rows, data_params.cols); + + raft::random::make_blobs(X.data_handle(), + labels_ref.data_handle(), + (IndexT)data_params.rows, + (IndexT)data_params.cols, + (IndexT)blobs_params.n_clusters, + stream, + data_params.row_major, + nullptr, + nullptr, + (T)blobs_params.cluster_std, + blobs_params.shuffle, + (T)blobs_params.center_box_min, + (T)blobs_params.center_box_max, + blobs_params.seed); + this->handle.sync_stream(stream); + } + + protected: + DatasetParams data_params; + BlobsParams blobs_params; + raft::device_matrix X; +}; + namespace internal { template @@ -162,8 +221,17 @@ class Fixture : public ::benchmark::Fixture { { fixture_ = std::apply([](const Params&... ps) { return std::make_unique(ps...); }, params_); + fixture_->allocate_data(state); + fixture_->allocate_temp_buffers(state); + } + + void TearDown(const State& state) override + { + fixture_->deallocate_temp_buffers(state); + fixture_->deallocate_data(state); + fixture_.reset(); } - void TearDown(const State& state) override { fixture_.reset(); } + void SetUp(State& st) override { SetUp(const_cast(st)); } void TearDown(State& st) override { TearDown(const_cast(st)); } @@ -248,6 +316,10 @@ struct registrar { }; // namespace internal +#define RAFT_BENCH_REGISTER_INTERNAL(TestClass, ...) \ + static raft::bench::internal::registrar BENCHMARK_PRIVATE_NAME(registrar)( \ + RAFT_STRINGIFY(TestClass), __VA_ARGS__) + /** * This is the entry point macro for all benchmarks. This needs to be called * for the set of benchmarks to be registered so that the main harness inside @@ -262,8 +334,7 @@ struct registrar { * empty string * @param params... zero or more lists of params upon which to benchmark. */ -#define RAFT_BENCH_REGISTER(TestClass, ...) \ - static raft::bench::internal::registrar BENCHMARK_PRIVATE_NAME(registrar)( \ - #TestClass, __VA_ARGS__) +#define RAFT_BENCH_REGISTER(TestClass, ...) \ + RAFT_BENCH_REGISTER_INTERNAL(RAFT_DEPAREN(TestClass), __VA_ARGS__) } // namespace raft::bench diff --git a/cpp/bench/spatial/knn.cu b/cpp/bench/spatial/knn.cu index 64a1217d7f..6b08c7ee33 100644 --- a/cpp/bench/spatial/knn.cu +++ b/cpp/bench/spatial/knn.cu @@ -19,6 +19,7 @@ #include #include +#include #if defined RAFT_NN_COMPILED #include #endif diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 8a463b97ef..26005f58a0 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -563,12 +563,12 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, X.data_handle() + cIdx * n_features, 1, n_features); // flag the sample that is chosen as initial centroid - std::vector h_isSampleCentroid(n_samples); + std::vector h_isSampleCentroid(n_samples); std::fill(h_isSampleCentroid.begin(), h_isSampleCentroid.end(), 0); h_isSampleCentroid[cIdx] = 1; // device buffer to flag the sample that is chosen as initial centroid - auto isSampleCentroid = raft::make_device_vector(handle, n_samples); + auto isSampleCentroid = raft::make_device_vector(handle, n_samples); raft::copy( isSampleCentroid.data_handle(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream); @@ -800,6 +800,17 @@ void kmeans_fit(handle_t const& handle, RAFT_EXPECTS(centroids.extent(1) == n_features, "invalid parameter (centroids.extent(1) != n_features)"); + // Display a warning if batch_centroids is set and a fusedL2NN-compatible metric is used + if (params.batch_centroids != 0 && params.batch_centroids != params.n_clusters && + (params.metric == raft::distance::DistanceType::L2Expanded || + params.metric == raft::distance::DistanceType::L2SqrtExpanded)) { + RAFT_LOG_INFO( + "batch_centroids=%d was passed, but batch_centroids=%d will be used (reason: " + "batch_centroids has no impact on the memory footprint when FusedL2NN can be used)", + params.batch_centroids, + params.n_clusters); + } + logger::get(RAFT_NAME).set_level(params.verbosity); // Allocate memory diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index ecbb39d60f..e9929a089d 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -51,44 +51,16 @@ namespace raft { namespace cluster { namespace detail { -template -struct FusedL2NNReduceOp { - IndexT offset; - - FusedL2NNReduceOp(IndexT _offset) : offset(_offset){}; - - typedef typename cub::KeyValuePair KVP; - DI void operator()(IndexT rit, KVP* out, const KVP& other) - { - if (other.value < out->value) { - out->key = offset + other.key; - out->value = other.value; - } - } - - DI void operator()(IndexT rit, DataT* out, const KVP& other) - { - if (other.value < *out) { *out = other.value; } - } - - DI void init(DataT* out, DataT maxVal) { *out = maxVal; } - DI void init(KVP* out, DataT maxVal) - { - out->key = -1; - out->value = maxVal; - } -}; - template struct SamplingOp { DataT* rnd; - int* flag; + uint8_t* flag; DataT cluster_cost; double oversampling_factor; IndexT n_clusters; CUB_RUNTIME_FUNCTION __forceinline__ - SamplingOp(DataT c, double l, IndexT k, DataT* rand, int* ptr) + SamplingOp(DataT c, double l, IndexT k, DataT* rand, uint8_t* ptr) : cluster_cost(c), oversampling_factor(l), n_clusters(k), rnd(rand), flag(ptr) { } @@ -240,7 +212,7 @@ template void sampleCentroids(const raft::handle_t& handle, const raft::device_matrix_view& X, const raft::device_vector_view& minClusterDistance, - const raft::device_vector_view& isSampleCentroid, + const raft::device_vector_view& isSampleCentroid, SamplingOp& select_op, rmm::device_uvector& inRankCp, rmm::device_uvector& workspace) @@ -278,7 +250,7 @@ void sampleCentroids(const raft::handle_t& handle, raft::copy(&nPtsSampledInRank, nSelected.data_handle(), 1, stream); handle.sync_stream(stream); - IndexT* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); + uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); thrust::for_each_n(handle.get_thrust_policy(), sampledMinClusterDistance.data_handle(), nPtsSampledInRank, @@ -346,13 +318,13 @@ void shuffleAndGather(const raft::handle_t& handle, if (workspace) { // shuffle indices on device - raft::random::permute(indices.data_handle(), - nullptr, - nullptr, - (IndexT)in.extent(1), - (IndexT)in.extent(0), - true, - stream); + raft::random::permute(indices.data_handle(), + nullptr, + nullptr, + (IndexT)in.extent(1), + (IndexT)in.extent(0), + true, + stream); } else { // shuffle indices on host and copy to device... std::vector ht_indices(n_samples); @@ -443,41 +415,35 @@ void minClusterAndDistanceCompute( auto L2NormXView = raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - // tile over the centroids - for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch - auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - - // centroidsView [nc x n_features] - view representing the current batch - // of centroids - auto centroidsView = raft::make_device_matrix_view( - centroids.data_handle() + (cIdx * n_features), nc, n_features); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + workspace.resize((sizeof(int)) * ns, stream); + + // todo(lsugy): remove cIdx + raft::distance::fusedL2NNMinReduce, IndexT>( + minClusterAndDistanceView.data_handle(), + datasetView.data_handle(), + centroids.data_handle(), + L2NormXView.data_handle(), + centroidsNorm.data_handle(), + ns, + n_clusters, + n_features, + (void*)workspace.data(), + metric != raft::distance::DistanceType::L2Expanded, + false, + stream); + } else { + // tile over the centroids + for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { + // # of centroids for the current batch + auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); + + // centroidsView [nc x n_features] - view representing the current batch + // of centroids + auto centroidsView = raft::make_device_matrix_view( + centroids.data_handle() + (cIdx * n_features), nc, n_features); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto centroidsNormView = - raft::make_device_vector_view(centroidsNorm.data_handle() + cIdx, nc); - workspace.resize((sizeof(int)) * ns, stream); - - FusedL2NNReduceOp redOp(cIdx); - raft::distance::KVPMinReduce pairRedOp; - - raft::distance::fusedL2NN, IndexT>( - minClusterAndDistanceView.data_handle(), - datasetView.data_handle(), - centroidsView.data_handle(), - L2NormXView.data_handle(), - centroidsNormView.data_handle(), - ns, - nc, - n_features, - (void*)workspace.data(), - redOp, - pairRedOp, - (metric == raft::distance::DistanceType::L2Expanded) ? false : true, - false, - stream); - } else { // pairwiseDistanceView [ns x nc] - view representing the pairwise // distance for current batch auto pairwiseDistanceView = @@ -578,40 +544,35 @@ void minClusterDistanceCompute(const raft::handle_t& handle, auto L2NormXView = raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - // tile over the centroids - for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch - auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - - // centroidsView [nc x n_features] - view representing the current batch - // of centroids - auto centroidsView = raft::make_device_matrix_view( - centroids.data_handle() + cIdx * n_features, nc, n_features); - - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto centroidsNormView = - raft::make_device_vector_view(centroidsNorm.data_handle() + cIdx, nc); - workspace.resize((sizeof(IndexT)) * ns, stream); - - FusedL2NNReduceOp redOp(cIdx); - raft::distance::KVPMinReduce pairRedOp; - raft::distance::fusedL2NN( - minClusterDistanceView.data_handle(), - datasetView.data_handle(), - centroidsView.data_handle(), - L2NormXView.data_handle(), - centroidsNormView.data_handle(), - ns, - nc, - n_features, - (void*)workspace.data(), - redOp, - pairRedOp, - (metric != raft::distance::DistanceType::L2Expanded), - false, - stream); - } else { + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + workspace.resize((sizeof(IndexT)) * ns, stream); + + // todo(lsugy): remove cIdx + raft::distance::fusedL2NNMinReduce( + minClusterDistanceView.data_handle(), + datasetView.data_handle(), + centroids.data_handle(), + L2NormXView.data_handle(), + centroidsNorm.data_handle(), + ns, + n_clusters, + n_features, + (void*)workspace.data(), + metric != raft::distance::DistanceType::L2Expanded, + false, + stream); + } else { + // tile over the centroids + for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { + // # of centroids for the current batch + auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); + + // centroidsView [nc x n_features] - view representing the current batch + // of centroids + auto centroidsView = raft::make_device_matrix_view( + centroids.data_handle() + cIdx * n_features, nc, n_features); + // pairwiseDistanceView [ns x nc] - view representing the pairwise // distance for current batch auto pairwiseDistanceView = diff --git a/cpp/include/raft/core/detail/macros.hpp b/cpp/include/raft/core/detail/macros.hpp index 4b804d61e3..66b67579fc 100644 --- a/cpp/include/raft/core/detail/macros.hpp +++ b/cpp/include/raft/core/detail/macros.hpp @@ -39,3 +39,20 @@ #ifndef RAFT_INLINE_FUNCTION #define RAFT_INLINE_FUNCTION _RAFT_FORCEINLINE _RAFT_HOST_DEVICE #endif + +/** + * Some macro magic to remove optional parentheses of a macro argument. + * See https://stackoverflow.com/a/62984543 + */ +#ifndef RAFT_DEPAREN_MAGICRAFT_DEPAREN_H1 +#define RAFT_DEPAREN(X) RAFT_DEPAREN_H2(RAFT_DEPAREN_H1 X) +#define RAFT_DEPAREN_H1(...) RAFT_DEPAREN_H1 __VA_ARGS__ +#define RAFT_DEPAREN_H2(...) RAFT_DEPAREN_H3(__VA_ARGS__) +#define RAFT_DEPAREN_H3(...) RAFT_DEPAREN_MAGIC##__VA_ARGS__ +#define RAFT_DEPAREN_MAGICRAFT_DEPAREN_H1 +#endif + +#ifndef RAFT_STRINGIFY +#define RAFT_STRINGIFY_DETAIL(...) #__VA_ARGS__ +#define RAFT_STRINGIFY(...) RAFT_STRINGIFY_DETAIL(__VA_ARGS__) +#endif diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index f46338943f..6bd622853d 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -51,6 +51,12 @@ struct MinAndDistanceReduceOpImpl { } } + DI void operator()(LabelT rid, DataT* out, const KVP& other) + { + if (other.value < *out) { *out = other.value; } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } DI void init(KVP* out, DataT maxVal) { out->key = -1; diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index c1cf790203..ed781f1d18 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -26,6 +26,7 @@ #include #include #include +#include namespace raft { namespace distance { @@ -100,6 +101,10 @@ void fusedL2NN(OutT* min, bool initOutBuffer, cudaStream_t stream) { + // Assigning -1 to unsigned integers results in a compiler error. + // Enforce a signed IdxT here with a clear error message. + static_assert(std::is_signed_v, "fusedL2NN only supports signed index types."); + // When k is smaller than 32, the Policy4x4 results in redundant calculations // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead // that uses tiles with a smaller value of k. @@ -157,6 +162,56 @@ void fusedL2NN(OutT* min, } } +/** + * @brief Wrapper around fusedL2NN with minimum reduction operators. + * + * fusedL2NN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * This should be preferred to the more generic API when possible, in order to + * reduce compilation times for users of the shared library. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances (e.g. cub::KeyValuePair) or store only the min distances. + * @tparam IdxT indexing arithmetic type + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void fusedL2NNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + MinAndDistanceReduceOp redOp; + KVPMinReduce pairRedOp; + + fusedL2NN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); +} + } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh new file mode 100644 index 0000000000..f0aa1c27ee --- /dev/null +++ b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft { +namespace distance { +namespace detail { +extern template void +distance( + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + void* workspace, + size_t worksize, + cudaStream_t stream, + bool isRowMajor, + float metric_arg); + +extern template void +distance( + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + void* workspace, + size_t worksize, + cudaStream_t stream, + bool isRowMajor, + double metric_arg); + +extern template void +distance( + const float* x, + const float* y, + float* dist, + std::uint32_t m, + std::uint32_t n, + std::uint32_t k, + void* workspace, + size_t worksize, + cudaStream_t stream, + bool isRowMajor, + float metric_arg); + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index 7553f87e39..73d075f260 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -30,3 +30,5 @@ #include #include #include +#include +#include diff --git a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh new file mode 100644 index 0000000000..deddf65b37 --- /dev/null +++ b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft { +namespace distance { + +extern template void fusedL2NNMinReduce, int>( + cub::KeyValuePair* min, + const float* x, + const float* y, + const float* xn, + const float* yn, + int m, + int n, + int k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +extern template void fusedL2NNMinReduce, int64_t>( + cub::KeyValuePair* min, + const float* x, + const float* y, + const float* xn, + const float* yn, + int64_t m, + int64_t n, + int64_t k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +extern template void fusedL2NNMinReduce, int>( + cub::KeyValuePair* min, + const double* x, + const double* y, + const double* xn, + const double* yn, + int m, + int n, + int k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +extern template void fusedL2NNMinReduce, int64_t>( + cub::KeyValuePair* min, + const double* x, + const double* y, + const double* xn, + const double* yn, + int64_t m, + int64_t n, + int64_t k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +extern template void fusedL2NNMinReduce(float* min, + const float* x, + const float* y, + const float* xn, + const float* yn, + int m, + int n, + int k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +extern template void fusedL2NNMinReduce(float* min, + const float* x, + const float* y, + const float* xn, + const float* yn, + int64_t m, + int64_t n, + int64_t k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +extern template void fusedL2NNMinReduce(double* min, + const double* x, + const double* y, + const double* xn, + const double* yn, + int m, + int n, + int k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +extern template void fusedL2NNMinReduce(double* min, + const double* x, + const double* y, + const double* xn, + const double* yn, + int64_t m, + int64_t n, + int64_t k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); + +} // namespace distance +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index d4471a4560..5aecbfcaa2 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -26,16 +27,6 @@ namespace raft { namespace random { namespace detail { -/** - * Some macro magic to remove optional parentheses of a macro argument. - * See https://stackoverflow.com/a/62984543 - */ -#define RAFT_DEPAREN(X) RAFT_DEPAREN_H2(RAFT_DEPAREN_H1 X) -#define RAFT_DEPAREN_H1(...) RAFT_DEPAREN_H1 __VA_ARGS__ -#define RAFT_DEPAREN_H2(...) RAFT_DEPAREN_H3(__VA_ARGS__) -#define RAFT_DEPAREN_H3(...) RAFT_DEPAREN_MAGIC##__VA_ARGS__ -#define RAFT_DEPAREN_MAGICRAFT_DEPAREN_H1 - /** * This macro will invoke function `func` with the correct instantiation of * device state as the first parameter, and passes all subsequent macro diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index 122306639f..f64c5549a4 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -80,8 +80,8 @@ inline void predict_float_core(const handle_t& handle, case raft::distance::DistanceType::L2Unexpanded: { rmm::device_uvector sqsum_centers(n_clusters, stream, mr); rmm::device_uvector sqsum_data(n_rows, stream, mr); - utils::dots_along_rows(n_clusters, dim, centers, sqsum_centers.data(), stream); - utils::dots_along_rows(n_rows, dim, dataset, sqsum_data.data(), stream); + utils::dots_along_rows(n_clusters, dim, centers, sqsum_centers.data(), stream); + utils::dots_along_rows(n_rows, dim, dataset, sqsum_data.data(), stream); utils::outer_add( sqsum_data.data(), n_rows, sqsum_centers.data(), n_clusters, distances.data(), stream); alpha = -2.0; @@ -105,7 +105,7 @@ inline void predict_float_core(const handle_t& handle, distances.data(), n_clusters, stream); - utils::argmin_along_rows(n_rows, n_clusters, distances.data(), labels, stream); + utils::argmin_along_rows(n_rows, n_clusters, distances.data(), labels, stream); } /** @@ -494,7 +494,7 @@ void balancing_em_iters(const handle_t& handle, case raft::distance::DistanceType::InnerProduct: case raft::distance::DistanceType::CosineExpanded: case raft::distance::DistanceType::CorrelationExpanded: - utils::normalize_rows(n_clusters, dim, cluster_centers, stream); + utils::normalize_rows(n_clusters, dim, cluster_centers, stream); default: break; } // E: Expectation step - predict labels diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index 9667053afd..e55758711a 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -151,18 +151,16 @@ inline void memzero(T* ptr, size_t n_elems, rmm::cuda_stream_view stream) } } -__global__ void argmin_along_rows_kernel(uint32_t n_rows, - uint32_t n_cols, - const float* a, - uint32_t* out) +template +__global__ void argmin_along_rows_kernel(IdxT n_rows, IdxT n_cols, const float* a, OutT* out) { - __shared__ uint32_t shm_ids[1024]; // NOLINT - __shared__ float shm_vals[1024]; // NOLINT - uint32_t i = blockIdx.x; + __shared__ OutT shm_ids[1024]; // NOLINT + __shared__ float shm_vals[1024]; // NOLINT + IdxT i = blockIdx.x; if (i >= n_rows) return; - uint32_t min_idx = n_cols; - float min_val = raft::upper_bound(); - for (uint32_t j = threadIdx.x; j < n_cols; j += blockDim.x) { + OutT min_idx = n_cols; + float min_val = raft::upper_bound(); + for (OutT j = threadIdx.x; j < n_cols; j += blockDim.x) { if (min_val > a[j + n_cols * i]) { min_val = a[j + n_cols * i]; min_idx = j; @@ -171,7 +169,7 @@ __global__ void argmin_along_rows_kernel(uint32_t n_rows, shm_vals[threadIdx.x] = min_val; shm_ids[threadIdx.x] = min_idx; __syncthreads(); - for (uint32_t offset = blockDim.x / 2; offset > 0; offset >>= 1) { + for (IdxT offset = blockDim.x / 2; offset > 0; offset >>= 1) { if (threadIdx.x < offset) { if (shm_vals[threadIdx.x] < shm_vals[threadIdx.x + offset]) { } else if (shm_vals[threadIdx.x] > shm_vals[threadIdx.x + offset]) { @@ -192,30 +190,35 @@ __global__ void argmin_along_rows_kernel(uint32_t n_rows, * NB: device-only function * TODO: specialize select_k for the case of `k == 1` and use that one instead. * + * @tparam IdxT index type + * @tparam OutT output type + * * @param n_rows * @param n_cols * @param[in] a device pointer to the row-major matrix [n_rows, n_cols] * @param[out] out device pointer to the vector of selected indices [n_rows] * @param stream */ +template inline void argmin_along_rows( - uint32_t n_rows, uint32_t n_cols, const float* a, uint32_t* out, rmm::cuda_stream_view stream) + IdxT n_rows, IdxT n_cols, const float* a, OutT* out, rmm::cuda_stream_view stream) { - uint32_t block_dim = 1024; + IdxT block_dim = 1024; while (block_dim > n_cols) { block_dim /= 2; } - block_dim = max(block_dim, 128); - argmin_along_rows_kernel<<>>(n_rows, n_cols, a, out); + block_dim = max(block_dim, (IdxT)128); + argmin_along_rows_kernel<<>>(n_rows, n_cols, a, out); } -__global__ void dots_along_rows_kernel(uint32_t n_rows, uint32_t n_cols, const float* a, float* out) +template +__global__ void dots_along_rows_kernel(IdxT n_rows, IdxT n_cols, const float* a, float* out) { - uint64_t i = threadIdx.y + (blockDim.y * blockIdx.x); + IdxT i = threadIdx.y + (blockDim.y * blockIdx.x); if (i >= n_rows) return; float sqsum = 0.0; - for (uint64_t j = threadIdx.x; j < n_cols; j += blockDim.x) { + for (IdxT j = threadIdx.x; j < n_cols; j += blockDim.x) { float val = a[j + (n_cols * i)]; sqsum += val * val; } @@ -232,18 +235,21 @@ __global__ void dots_along_rows_kernel(uint32_t n_rows, uint32_t n_cols, const f * * NB: device-only function * + * @tparam IdxT index type + * * @param n_rows * @param n_cols * @param[in] a device pointer to the row-major matrix [n_rows, n_cols] * @param[out] out device pointer to the vector of dot-products [n_rows] * @param stream */ +template inline void dots_along_rows( - uint32_t n_rows, uint32_t n_cols, const float* a, float* out, rmm::cuda_stream_view stream) + IdxT n_rows, IdxT n_cols, const float* a, float* out, rmm::cuda_stream_view stream) { dim3 threads(32, 4, 1); - dim3 blocks(ceildiv(n_rows, threads.y), 1, 1); - dots_along_rows_kernel<<>>(n_rows, n_cols, a, out); + dim3 blocks(ceildiv(n_rows, threads.y), 1, 1); + dots_along_rows_kernel<<>>(n_rows, n_cols, a, out); /** * TODO: this can be replaced with the rowNorm helper as shown below. * However, the rowNorm helper seems to incur a significant performance penalty @@ -317,13 +323,14 @@ void accumulate_into_selected(size_t n_rows, } } -__global__ void normalize_rows_kernel(uint32_t n_rows, uint32_t n_cols, float* a) +template +__global__ void normalize_rows_kernel(IdxT n_rows, IdxT n_cols, float* a) { uint64_t i = threadIdx.y + (blockDim.y * blockIdx.x); if (i >= n_rows) return; float sqsum = 0.0; - for (uint32_t j = threadIdx.x; j < n_cols; j += blockDim.x) { + for (IdxT j = threadIdx.x; j < n_cols; j += blockDim.x) { float val = a[j + (n_cols * i)]; sqsum += val * val; } @@ -334,7 +341,7 @@ __global__ void normalize_rows_kernel(uint32_t n_rows, uint32_t n_cols, float* a sqsum += __shfl_xor_sync(0xffffffff, sqsum, 16); if (sqsum <= 1e-8) return; sqsum = rsqrtf(sqsum); // reciprocal of the square root - for (uint32_t j = threadIdx.x; j < n_cols; j += blockDim.x) { + for (IdxT j = threadIdx.x; j < n_cols; j += blockDim.x) { a[j + n_cols * i] *= sqsum; } } @@ -344,16 +351,19 @@ __global__ void normalize_rows_kernel(uint32_t n_rows, uint32_t n_cols, float* a * * NB: device-only function * + * @tparam IdxT index type + * * @param[in] n_rows * @param[in] n_cols * @param[inout] a device pointer to a row-major matrix [n_rows, n_cols] * @param stream */ -inline void normalize_rows(uint32_t n_rows, uint32_t n_cols, float* a, rmm::cuda_stream_view stream) +template +inline void normalize_rows(IdxT n_rows, IdxT n_cols, float* a, rmm::cuda_stream_view stream) { dim3 threads(32, 4, 1); // DO NOT CHANGE dim3 blocks(ceildiv(n_rows, threads.y), 1, 1); - normalize_rows_kernel<<>>(n_rows, n_cols, a); + normalize_rows_kernel<<>>(n_rows, n_cols, a); } template diff --git a/cpp/include/raft/spectral/detail/warn_dbg.hpp b/cpp/include/raft/spectral/detail/warn_dbg.hpp index 08a4e6efb5..d714e1dc86 100644 --- a/cpp/include/raft/spectral/detail/warn_dbg.hpp +++ b/cpp/include/raft/spectral/detail/warn_dbg.hpp @@ -1,10 +1,24 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include #include -#define STRINGIFY_DETAIL(x) #x -#define RAFT_STRINGIFY(x) STRINGIFY_DETAIL(x) +#include #ifdef DEBUG #define COUT() (std::cout) diff --git a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu new file mode 100644 index 0000000000..809f6ccb78 --- /dev/null +++ b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft { +namespace distance { +namespace detail { +template void +distance( + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + void* workspace, + std::size_t worksize, + cudaStream_t stream, + bool isRowMajor, + double metric_arg); + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu new file mode 100644 index 0000000000..831384fdd7 --- /dev/null +++ b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft { +namespace distance { +namespace detail { +template void distance( + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + void* workspace, + std::size_t worksize, + cudaStream_t stream, + bool isRowMajor, + float metric_arg); + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu new file mode 100644 index 0000000000..fcb8d42ab3 --- /dev/null +++ b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft { +namespace distance { +namespace detail { + +template void +distance( + const float* x, + const float* y, + float* dist, + std::uint32_t m, + std::uint32_t n, + std::uint32_t k, + void* workspace, + std::size_t worksize, + cudaStream_t stream, + bool isRowMajor, + float metric_arg); + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_double_int.cu b/cpp/src/distance/specializations/fused_l2_nn_double_int.cu new file mode 100644 index 0000000000..b032261169 --- /dev/null +++ b/cpp/src/distance/specializations/fused_l2_nn_double_int.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft { +namespace distance { + +template void fusedL2NNMinReduce, int>( + cub::KeyValuePair* min, + const double* x, + const double* y, + const double* xn, + const double* yn, + int m, + int n, + int k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +template void fusedL2NNMinReduce(double* min, + const double* x, + const double* y, + const double* xn, + const double* yn, + int m, + int n, + int k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); + +} // namespace distance +} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu b/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu new file mode 100644 index 0000000000..a208b013d5 --- /dev/null +++ b/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft { +namespace distance { + +template void fusedL2NNMinReduce, int64_t>( + cub::KeyValuePair* min, + const double* x, + const double* y, + const double* xn, + const double* yn, + int64_t m, + int64_t n, + int64_t k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +template void fusedL2NNMinReduce(double* min, + const double* x, + const double* y, + const double* xn, + const double* yn, + int64_t m, + int64_t n, + int64_t k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); + +} // namespace distance +} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_float_int.cu b/cpp/src/distance/specializations/fused_l2_nn_float_int.cu new file mode 100644 index 0000000000..f58349a826 --- /dev/null +++ b/cpp/src/distance/specializations/fused_l2_nn_float_int.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft { +namespace distance { + +template void fusedL2NNMinReduce, int>( + cub::KeyValuePair* min, + const float* x, + const float* y, + const float* xn, + const float* yn, + int m, + int n, + int k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +template void fusedL2NNMinReduce(float* min, + const float* x, + const float* y, + const float* xn, + const float* yn, + int m, + int n, + int k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); + +} // namespace distance +} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu b/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu new file mode 100644 index 0000000000..e43c3aa4e9 --- /dev/null +++ b/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft { +namespace distance { + +template void fusedL2NNMinReduce, int64_t>( + cub::KeyValuePair* min, + const float* x, + const float* y, + const float* xn, + const float* yn, + int64_t m, + int64_t n, + int64_t k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); +template void fusedL2NNMinReduce(float* min, + const float* x, + const float* y, + const float* xn, + const float* yn, + int64_t m, + int64_t n, + int64_t k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream); + +} // namespace distance +} // namespace raft From 25ebd2c923d2bbfceaec145ff373ebc1400f5352 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 30 Sep 2022 15:49:43 -0400 Subject: [PATCH 3/3] Adding raft distance specialization --- cpp/bench/spatial/knn.cu | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cpp/bench/spatial/knn.cu b/cpp/bench/spatial/knn.cu index 64a1217d7f..98721b4a63 100644 --- a/cpp/bench/spatial/knn.cu +++ b/cpp/bench/spatial/knn.cu @@ -19,6 +19,12 @@ #include #include +#include + +#if defined RAFT_DISTANCE_COMPILED +#include +#endif + #if defined RAFT_NN_COMPILED #include #endif