diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 81a6d76507..ce6eb00bc1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -277,10 +277,10 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/specializations/detail/russel_rao_double_double_double_int.cu src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu src/distance/specializations/detail/russel_rao_float_float_float_int.cu - src/distance/specializations/fused_l2_nn_double_int.cu - src/distance/specializations/fused_l2_nn_double_int64.cu - src/distance/specializations/fused_l2_nn_float_int.cu - src/distance/specializations/fused_l2_nn_float_int64.cu +# src/distance/specializations/fused_l2_nn_double_int.cu +# src/distance/specializations/fused_l2_nn_double_int64.cu +# src/distance/specializations/fused_l2_nn_float_int.cu +# src/distance/specializations/fused_l2_nn_float_int64.cu src/random/specializations/rmat_rectangular_generator_int_double.cu src/random/specializations/rmat_rectangular_generator_int64_double.cu src/random/specializations/rmat_rectangular_generator_int_float.cu diff --git a/cpp/bench/spatial/fused_l2_nn.cu b/cpp/bench/spatial/fused_l2_nn.cu index 2463089675..aa36483145 100644 --- a/cpp/bench/spatial/fused_l2_nn.cu +++ b/cpp/bench/spatial/fused_l2_nn.cu @@ -53,7 +53,7 @@ struct fused_l2_nn : public fixture { uniform(handle, r, y.data(), p.n * p.k, T(-1.0), T(1.0)); raft::linalg::rowNorm(xn.data(), x.data(), p.k, p.m, raft::linalg::L2Norm, true, stream); raft::linalg::rowNorm(yn.data(), y.data(), p.k, p.n, raft::linalg::L2Norm, true, stream); - raft::distance::initialize, int>( + raft::distance::initialize, int>( handle, out.data(), p.m, std::numeric_limits::max(), op); } @@ -61,20 +61,20 @@ struct fused_l2_nn : public fixture { { loop_on_state(state, [this]() { // it is enough to only benchmark the L2-squared metric - raft::distance::fusedL2NN, int>(out.data(), - x.data(), - y.data(), - xn.data(), - yn.data(), - params.m, - params.n, - params.k, - (void*)workspace.data(), - op, - pairRedOp, - false, - false, - stream); + raft::distance::fusedL2NN, int>(out.data(), + x.data(), + y.data(), + xn.data(), + yn.data(), + params.m, + params.n, + params.k, + (void*)workspace.data(), + op, + pairRedOp, + false, + false, + stream); }); // Num distance calculations @@ -92,7 +92,7 @@ struct fused_l2_nn : public fixture { state.counters["FLOP/s"] = benchmark::Counter( num_flops, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000); - state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(raft::KeyValuePair), + state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(cub::KeyValuePair), benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000); state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(float), @@ -105,7 +105,7 @@ struct fused_l2_nn : public fixture { private: fused_l2_nn_inputs params; rmm::device_uvector x, y, xn, yn; - rmm::device_uvector> out; + rmm::device_uvector> out; rmm::device_uvector workspace; raft::distance::KVPMinReduce pairRedOp; raft::distance::MinAndDistanceReduceOp op; diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 94fee3edbf..26005f58a0 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -32,7 +32,6 @@ #include #include #include -#include #include #include #include @@ -279,7 +278,7 @@ void kmeans_fit_main(const raft::handle_t& handle, // - key is the index of nearest cluster // - value is the distance to the nearest cluster auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); + raft::make_device_vector, IndexT>(handle, n_samples); // temporary buffer to store L2 norm of centroids or distance matrix, // destructor releases the resource @@ -293,7 +292,7 @@ void kmeans_fit_main(const raft::handle_t& handle, // resource auto wtInCluster = raft::make_device_vector(handle, n_clusters); - rmm::device_scalar> clusterCostD(stream); + rmm::device_scalar> clusterCostD(stream); // L2 norm of X: ||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); @@ -338,12 +337,12 @@ void kmeans_fit_main(const raft::handle_t& handle, workspace); // Using TransformInputIteratorT to dereference an array of - // raft::KeyValuePair and converting them to just return the Key to be used + // cub::KeyValuePair and converting them to just return the Key to be used // in reduce_rows_by_key prims detail::KeyValueIndexOp conversion_op; cub::TransformInputIterator, - raft::KeyValuePair*> + cub::KeyValuePair*> itr(minClusterAndDistance.data_handle(), conversion_op); workspace.resize(n_samples, stream); @@ -401,14 +400,14 @@ void kmeans_fit_main(const raft::handle_t& handle, itr_wt, wtInCluster.size(), newCentroids.data_handle(), - [=] __device__(raft::KeyValuePair map) { // predicate + [=] __device__(cub::KeyValuePair map) { // predicate // copy when the # of samples in the cluster is 0 if (map.value == 0) return true; else return false; }, - [=] __device__(raft::KeyValuePair map) { // map + [=] __device__(cub::KeyValuePair map) { // map return map.key; }, stream); @@ -440,9 +439,9 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; + [] __device__(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { + cub::KeyValuePair res; res.key = 0; res.value = a.value + b.value; return res; @@ -490,8 +489,8 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.data_handle() + minClusterAndDistance.size(), weight.data_handle(), minClusterAndDistance.data_handle(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; + [=] __device__(const cub::KeyValuePair kvp, DataT wt) { + cub::KeyValuePair res; res.value = kvp.value * wt; res.key = kvp.key; return res; @@ -502,9 +501,9 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; + [] __device__(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { + cub::KeyValuePair res; res.key = 0; res.value = a.value + b.value; return res; @@ -971,7 +970,7 @@ void kmeans_predict(handle_t const& handle, if (normalize_weight) checkWeight(handle, weight.view(), workspace); auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); + raft::make_device_vector, IndexT>(handle, n_samples); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); // L2 norm of X: ||x||^2 @@ -1002,15 +1001,15 @@ void kmeans_predict(handle_t const& handle, workspace); // calculate cluster cost phi_x(C) - rmm::device_scalar> clusterCostD(stream); + rmm::device_scalar> clusterCostD(stream); // TODO: add different templates for InType of binaryOp to avoid thrust transform thrust::transform(handle.get_thrust_policy(), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), weight.data_handle(), minClusterAndDistance.data_handle(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; + [=] __device__(const cub::KeyValuePair kvp, DataT wt) { + cub::KeyValuePair res; res.value = kvp.value * wt; res.key = kvp.key; return res; @@ -1020,9 +1019,9 @@ void kmeans_predict(handle_t const& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; + [] __device__(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { + cub::KeyValuePair res; res.key = 0; res.value = a.value + b.value; return res; @@ -1034,7 +1033,7 @@ void kmeans_predict(handle_t const& handle, minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), labels.data_handle(), - [=] __device__(raft::KeyValuePair pair) { return pair.key; }); + [=] __device__(cub::KeyValuePair pair) { return pair.key; }); } template diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index d4dd565ea0..e9929a089d 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -31,7 +31,6 @@ #include #include #include -#include #include #include #include @@ -67,7 +66,7 @@ struct SamplingOp { } __host__ __device__ __forceinline__ bool operator()( - const raft::KeyValuePair& a) const + const cub::KeyValuePair& a) const { DataT prob_threshold = (DataT)rnd[a.key]; @@ -80,7 +79,7 @@ struct SamplingOp { template struct KeyValueIndexOp { __host__ __device__ __forceinline__ IndexT - operator()(const raft::KeyValuePair& a) const + operator()(const cub::KeyValuePair& a) const { return a.key; } @@ -225,7 +224,7 @@ void sampleCentroids(const raft::handle_t& handle, auto nSelected = raft::make_device_scalar(handle, 0); cub::ArgIndexInputIterator ip_itr(minClusterDistance.data_handle()); auto sampledMinClusterDistance = - raft::make_device_vector, IndexT>(handle, n_local_samples); + raft::make_device_vector, IndexT>(handle, n_local_samples); size_t temp_storage_bytes = 0; RAFT_CUDA_TRY(cub::DeviceSelect::If(nullptr, temp_storage_bytes, @@ -255,7 +254,7 @@ void sampleCentroids(const raft::handle_t& handle, thrust::for_each_n(handle.get_thrust_policy(), sampledMinClusterDistance.data_handle(), nPtsSampledInRank, - [=] __device__(raft::KeyValuePair val) { + [=] __device__(cub::KeyValuePair val) { rawPtr_isSampleCentroid[val.key] = 1; }); @@ -267,7 +266,7 @@ void sampleCentroids(const raft::handle_t& handle, sampledMinClusterDistance.data_handle(), nPtsSampledInRank, inRankCp.data(), - [=] __device__(raft::KeyValuePair val) { // MapTransformOp + [=] __device__(cub::KeyValuePair val) { // MapTransformOp return val.key; }, stream); @@ -356,7 +355,7 @@ void minClusterAndDistanceCompute( const KMeansParams& params, const raft::device_matrix_view X, const raft::device_matrix_view centroids, - const raft::device_vector_view, IndexT> minClusterAndDistance, + const raft::device_vector_view, IndexT> minClusterAndDistance, const raft::device_vector_view L2NormX, rmm::device_uvector& L2NormBuf_OR_DistBuf, rmm::device_uvector& workspace) @@ -391,7 +390,7 @@ void minClusterAndDistanceCompute( auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + cub::KeyValuePair initial_value(0, std::numeric_limits::max()); thrust::fill(handle.get_thrust_policy(), minClusterAndDistance.data_handle(), @@ -410,7 +409,7 @@ void minClusterAndDistanceCompute( // minClusterAndDistanceView [ns x n_clusters] auto minClusterAndDistanceView = - raft::make_device_vector_view, IndexT>( + raft::make_device_vector_view, IndexT>( minClusterAndDistance.data_handle() + dIdx, ns); auto L2NormXView = @@ -421,7 +420,7 @@ void minClusterAndDistanceCompute( workspace.resize((sizeof(int)) * ns, stream); // todo(lsugy): remove cIdx - raft::distance::fusedL2NNMinReduce, IndexT>( + raft::distance::fusedL2NNMinReduce, IndexT>( minClusterAndDistanceView.data_handle(), datasetView.data_handle(), centroids.data_handle(), @@ -467,15 +466,15 @@ void minClusterAndDistanceCompute( stream, true, [=] __device__(const DataT val, const IndexT i) { - raft::KeyValuePair pair; + cub::KeyValuePair pair; pair.key = cIdx + i; pair.value = val; return pair; }, - [=] __device__(raft::KeyValuePair a, raft::KeyValuePair b) { + [=] __device__(cub::KeyValuePair a, cub::KeyValuePair b) { return (b.value < a.value) ? b : a; }, - [=] __device__(raft::KeyValuePair pair) { return pair; }); + [=] __device__(cub::KeyValuePair pair) { return pair; }); } } } @@ -624,7 +623,7 @@ void countSamplesInCluster(const raft::handle_t& handle, // - key is the index of nearest cluster // - value is the distance to the nearest cluster auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); + raft::make_device_vector, IndexT>(handle, n_samples); // temporary buffer to store distance matrix, destructor releases the resource rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); @@ -643,13 +642,13 @@ void countSamplesInCluster(const raft::handle_t& handle, L2NormBuf_OR_DistBuf, workspace); - // Using TransformInputIteratorT to dereference an array of raft::KeyValuePair + // Using TransformInputIteratorT to dereference an array of cub::KeyValuePair // and converting them to just return the Key to be used in reduce_rows_by_key // prims detail::KeyValueIndexOp conversion_op; cub::TransformInputIterator, - raft::KeyValuePair*> + cub::KeyValuePair*> itr(minClusterAndDistance.data_handle(), conversion_op); // count # of samples in each cluster diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index 0ce35da4a5..539fc33c40 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -18,7 +18,6 @@ #include #include #include -#include #include namespace raft::cluster { @@ -354,7 +353,7 @@ void minClusterAndDistanceCompute( const KMeansParams& params, const raft::device_matrix_view X, const raft::device_matrix_view centroids, - const raft::device_vector_view, IndexT>& minClusterAndDistance, + const raft::device_vector_view, IndexT>& minClusterAndDistance, const raft::device_vector_view& L2NormX, rmm::device_uvector& L2NormBuf_OR_DistBuf, rmm::device_uvector& workspace) diff --git a/cpp/include/raft/core/detail/macros.hpp b/cpp/include/raft/core/detail/macros.hpp index bfb47437ad..66b67579fc 100644 --- a/cpp/include/raft/core/detail/macros.hpp +++ b/cpp/include/raft/core/detail/macros.hpp @@ -37,7 +37,7 @@ #define _RAFT_HOST_DEVICE _RAFT_HOST _RAFT_DEVICE #ifndef RAFT_INLINE_FUNCTION -#define RAFT_INLINE_FUNCTION _RAFT_HOST_DEVICE _RAFT_FORCEINLINE +#define RAFT_INLINE_FUNCTION _RAFT_FORCEINLINE _RAFT_HOST_DEVICE #endif /** diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp deleted file mode 100644 index f6ea841dc4..0000000000 --- a/cpp/include/raft/core/kvp.hpp +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#ifdef _RAFT_HAS_CUDA -#include -#endif -namespace raft { -/** - * \brief A key identifier paired with a corresponding value - * - */ -template -struct KeyValuePair { - typedef _Key Key; ///< Key data type - typedef _Value Value; ///< Value data type - - Key key; ///< Item key - Value value; ///< Item value - - /// Constructor - RAFT_INLINE_FUNCTION KeyValuePair() {} - -#ifdef _RAFT_HAS_CUDA - /// Conversion Constructor to allow integration w/ cub - RAFT_INLINE_FUNCTION KeyValuePair(cub::KeyValuePair<_Key, _Value> kvp) - : key(kvp.key), value(kvp.value) - { - } - - RAFT_INLINE_FUNCTION operator cub::KeyValuePair<_Key, _Value>() - { - return cub::KeyValuePair(key, value); - } -#endif - - /// Constructor - RAFT_INLINE_FUNCTION KeyValuePair(Key const& key, Value const& value) : key(key), value(value) {} - - /// Inequality operator - RAFT_INLINE_FUNCTION bool operator!=(const KeyValuePair& b) - { - return (value != b.value) || (key != b.key); - } -}; -} // end namespace raft diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 1385d0aa09..8aae7d40f4 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -16,8 +16,8 @@ #pragma once +#include #include -#include #include #include #include @@ -25,7 +25,6 @@ namespace raft { namespace distance { - namespace detail { #if (ENABLE_MEMCPY_ASYNC == 1) @@ -35,14 +34,15 @@ using namespace nvcuda::experimental; template struct KVPMinReduceImpl { - typedef raft::KeyValuePair KVP; + typedef cub::KeyValuePair KVP; + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } }; // KVPMinReduce template struct MinAndDistanceReduceOpImpl { - typedef typename raft::KeyValuePair KVP; + typedef typename cub::KeyValuePair KVP; DI void operator()(LabelT rid, KVP* out, const KVP& other) { if (other.value < out->value) { @@ -66,7 +66,7 @@ struct MinAndDistanceReduceOpImpl { template struct MinReduceOpImpl { - typedef typename raft::KeyValuePair KVP; + typedef typename cub::KeyValuePair KVP; DI void operator()(LabelT rid, DataT* out, const KVP& other) { if (other.value < *out) { *out = other.value; } @@ -146,7 +146,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, { extern __shared__ char smem[]; - typedef KeyValuePair KVPair; + typedef cub::KeyValuePair KVPair; KVPair val[P::AccRowsPerTh]; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { @@ -285,7 +285,7 @@ void fusedL2NNImpl(OutT* min, dim3 blk(P::Nthreads); auto nblks = raft::ceildiv(m, P::Nthreads); constexpr auto maxVal = std::numeric_limits::max(); - typedef KeyValuePair KVPair; + typedef cub::KeyValuePair KVPair; // Accumulation operation lambda auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index fb4fb8d34c..2915bce360 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -168,8 +168,7 @@ void fusedL2NN(OutT* min, * * @tparam DataT data type * @tparam OutT output type to either store 1-NN indices and their minimum - * distances (e.g. raft::KeyValuePair) or store only the min - * distances. + * distances (e.g. cub::KeyValuePair) or store only the min distances. * @tparam IdxT indexing arithmetic type * @param[out] min will contain the reduced output (Length = `m`) * (on device) diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index 73d075f260..3b7d08f2aa 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -31,4 +31,4 @@ #include #include #include -#include +//#include diff --git a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh index 88e1216635..deddf65b37 100644 --- a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh +++ b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh @@ -16,14 +16,13 @@ #pragma once -#include #include namespace raft { namespace distance { -extern template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, +extern template void fusedL2NNMinReduce, int>( + cub::KeyValuePair* min, const float* x, const float* y, const float* xn, @@ -35,8 +34,8 @@ extern template void fusedL2NNMinReduce, i bool sqrt, bool initOutBuffer, cudaStream_t stream); -extern template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, +extern template void fusedL2NNMinReduce, int64_t>( + cub::KeyValuePair* min, const float* x, const float* y, const float* xn, @@ -48,8 +47,8 @@ extern template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, +extern template void fusedL2NNMinReduce, int>( + cub::KeyValuePair* min, const double* x, const double* y, const double* xn, @@ -61,8 +60,8 @@ extern template void fusedL2NNMinReduce, bool sqrt, bool initOutBuffer, cudaStream_t stream); -extern template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, +extern template void fusedL2NNMinReduce, int64_t>( + cub::KeyValuePair* min, const double* x, const double* y, const double* xn, diff --git a/cpp/include/raft/sparse/spatial/detail/connect_components.cuh b/cpp/include/raft/sparse/spatial/detail/connect_components.cuh index 1c14669e28..f515ab5739 100644 --- a/cpp/include/raft/sparse/spatial/detail/connect_components.cuh +++ b/cpp/include/raft/sparse/spatial/detail/connect_components.cuh @@ -31,7 +31,6 @@ #include #include -#include #include #include #include @@ -46,6 +45,43 @@ #include namespace raft::sparse::spatial::detail { +/** + * \brief A key identifier paired with a corresponding value + * + * NOTE: This is being included close to where it's being used + * because it's meant to be temporary. There is a conflict + * between the cub and thrust_cub namespaces with older CUDA + * versions so we're using our own as a workaround. + */ +template +struct KeyValuePair { + typedef _Key Key; ///< Key data type + typedef _Value Value; ///< Value data type + + Key key; ///< Item key + Value value; ///< Item value + + /// Constructor + __host__ __device__ __forceinline__ KeyValuePair() {} + + /// Copy Constructor + __host__ __device__ __forceinline__ KeyValuePair(cub::KeyValuePair<_Key, _Value> kvp) + : key(kvp.key), value(kvp.value) + { + } + + /// Constructor + __host__ __device__ __forceinline__ KeyValuePair(Key const& key, Value const& value) + : key(key), value(value) + { + } + + /// Inequality operator + __host__ __device__ __forceinline__ bool operator!=(const KeyValuePair& b) + { + return (value != b.value) || (key != b.key); + } +}; /** * Functor with reduction ops for performing fused 1-nn @@ -61,7 +97,7 @@ struct FixConnectivitiesRedOp { FixConnectivitiesRedOp(value_idx* colors_, value_idx m_) : colors(colors_), m(m_){}; - typedef typename raft::KeyValuePair KVP; + typedef typename cub::KeyValuePair KVP; DI void operator()(value_idx rit, KVP* out, const KVP& other) { if (rit < m && other.value < out->value && colors[rit] != colors[other.key]) { @@ -112,7 +148,7 @@ struct TupleComp { template struct CubKVPMinReduce { - typedef raft::KeyValuePair KVP; + typedef cub::KeyValuePair KVP; DI KVP @@ -161,7 +197,7 @@ struct LookupColorOp { DI value_idx - operator()(const raft::KeyValuePair& kvp) + operator()(const cub::KeyValuePair& kvp) { return colors[kvp.key]; } @@ -182,7 +218,7 @@ struct LookupColorOp { * @param[in] stream cuda stream for which to order cuda operations */ template -void perform_1nn(raft::KeyValuePair* kvp, +void perform_1nn(cub::KeyValuePair* kvp, value_idx* nn_colors, value_idx* colors, const value_t* X, @@ -196,7 +232,7 @@ void perform_1nn(raft::KeyValuePair* kvp, raft::linalg::rowNorm(x_norm.data(), X, n_cols, n_rows, raft::linalg::L2Norm, true, stream); - raft::distance::fusedL2NN, value_idx>( + raft::distance::fusedL2NN, value_idx>( kvp, X, X, @@ -231,7 +267,7 @@ void perform_1nn(raft::KeyValuePair* kvp, template void sort_by_color(value_idx* colors, value_idx* nn_colors, - raft::KeyValuePair* kvp, + cub::KeyValuePair* kvp, value_idx* src_indices, size_t n_rows, cudaStream_t stream) @@ -253,7 +289,7 @@ __global__ void min_components_by_color_kernel(value_idx* out_rows, value_t* out_vals, const value_idx* out_index, const value_idx* indices, - const raft::KeyValuePair* kvp, + const cub::KeyValuePair* kvp, size_t nnz) { size_t tid = blockDim.x * blockIdx.x + threadIdx.x; @@ -287,7 +323,7 @@ template void min_components_by_color(raft::sparse::COO& coo, const value_idx* out_index, const value_idx* indices, - const raft::KeyValuePair* kvp, + const cub::KeyValuePair* kvp, size_t nnz, cudaStream_t stream) { @@ -348,7 +384,7 @@ void connect_components( * is guaranteed to be != color of its nearest neighbor. */ rmm::device_uvector nn_colors(n_rows, stream); - rmm::device_uvector> temp_inds_dists(n_rows, stream); + rmm::device_uvector> temp_inds_dists(n_rows, stream); rmm::device_uvector src_indices(n_rows, stream); perform_1nn(temp_inds_dists.data(), diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index bf0df065b2..6d3289e14c 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -84,9 +84,9 @@ inline void predict_float_core(const handle_t& handle, auto workspace = raft::make_device_mdarray( handle, mr, make_extents((sizeof(int)) * n_rows)); - auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( + auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( handle, mr, make_extents(n_rows)); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + cub::KeyValuePair initial_value(0, std::numeric_limits::max()); thrust::fill(handle.get_thrust_policy(), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), @@ -97,7 +97,7 @@ inline void predict_float_core(const handle_t& handle, raft::linalg::rowNorm( centroidsNorm.data_handle(), centers, dim, n_clusters, raft::linalg::L2Norm, true, stream); - raft::distance::fusedL2NNMinReduce, IdxT>( + raft::distance::fusedL2NNMinReduce, IdxT>( minClusterAndDistance.data_handle(), dataset, centers, @@ -117,7 +117,7 @@ inline void predict_float_core(const handle_t& handle, minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + n_rows, labels, - [=] __device__(raft::KeyValuePair kvp) { + [=] __device__(cub::KeyValuePair kvp) { return static_cast(kvp.key); }); break; diff --git a/cpp/src/distance/specializations/fused_l2_nn_double_int.cu b/cpp/src/distance/specializations/fused_l2_nn_double_int.cu index 4448ee0cc2..b032261169 100644 --- a/cpp/src/distance/specializations/fused_l2_nn_double_int.cu +++ b/cpp/src/distance/specializations/fused_l2_nn_double_int.cu @@ -14,14 +14,13 @@ * limitations under the License. */ -#include #include namespace raft { namespace distance { -template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, +template void fusedL2NNMinReduce, int>( + cub::KeyValuePair* min, const double* x, const double* y, const double* xn, diff --git a/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu b/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu index 54478a1656..a208b013d5 100644 --- a/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu +++ b/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu @@ -14,14 +14,13 @@ * limitations under the License. */ -#include #include namespace raft { namespace distance { -template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, +template void fusedL2NNMinReduce, int64_t>( + cub::KeyValuePair* min, const double* x, const double* y, const double* xn, diff --git a/cpp/src/distance/specializations/fused_l2_nn_float_int.cu b/cpp/src/distance/specializations/fused_l2_nn_float_int.cu index e25c9fad91..f58349a826 100644 --- a/cpp/src/distance/specializations/fused_l2_nn_float_int.cu +++ b/cpp/src/distance/specializations/fused_l2_nn_float_int.cu @@ -14,14 +14,13 @@ * limitations under the License. */ -#include #include namespace raft { namespace distance { -template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, +template void fusedL2NNMinReduce, int>( + cub::KeyValuePair* min, const float* x, const float* y, const float* xn, diff --git a/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu b/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu index b7abd91304..e43c3aa4e9 100644 --- a/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu +++ b/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu @@ -14,14 +14,13 @@ * limitations under the License. */ -#include #include namespace raft { namespace distance { -template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, +template void fusedL2NNMinReduce, int64_t>( + cub::KeyValuePair* min, const float* x, const float* y, const float* xn, diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index fdb6bf68fe..2838a2209e 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -16,7 +16,6 @@ #include "../test_utils.h" #include -#include #include #include #include @@ -24,20 +23,19 @@ #include #include -#if defined RAFT_NN_COMPILED -#include -#endif - -#if defined RAFT_DISTANCE_COMPILED -#include -#endif +// TODO: Once fusedL2NN is specialized in the raft_distance shared library, add +// the following: +// +// #if defined RAFT_NN_COMPILED +// #include +// #endif namespace raft { namespace distance { template struct CubKVPMinReduce { - typedef raft::KeyValuePair KVP; + typedef cub::KeyValuePair KVP; DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } @@ -46,7 +44,7 @@ struct CubKVPMinReduce { }; // KVPMinReduce template -__global__ void naiveKernel(raft::KeyValuePair* min, +__global__ void naiveKernel(cub::KeyValuePair* min, DataT* x, DataT* y, int m, @@ -66,10 +64,10 @@ __global__ void naiveKernel(raft::KeyValuePair* min, } if (Sqrt) { acc = raft::mySqrt(acc); } ReduceOpT redOp; - typedef cub::WarpReduce> WarpReduce; + typedef cub::WarpReduce> WarpReduce; __shared__ typename WarpReduce::TempStorage temp[NWARPS]; int warpId = threadIdx.x / raft::WarpSize; - raft::KeyValuePair tmp; + cub::KeyValuePair tmp; tmp.key = nidx; tmp.value = midx >= m || nidx >= n ? maxVal : acc; tmp = WarpReduce(temp[warpId]).Reduce(tmp, CubKVPMinReduce()); @@ -84,7 +82,7 @@ __global__ void naiveKernel(raft::KeyValuePair* min, } template -void naive(raft::KeyValuePair* min, +void naive(cub::KeyValuePair* min, DataT* x, DataT* y, int m, @@ -98,7 +96,7 @@ void naive(raft::KeyValuePair* min, RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); auto blks = raft::ceildiv(m, 256); MinAndDistanceReduceOp op; - detail::initKernel, int> + detail::initKernel, int> <<>>(min, m, std::numeric_limits::max(), op); RAFT_CUDA_TRY(cudaGetLastError()); naiveKernel, 16> @@ -167,8 +165,8 @@ class FusedL2NNTest : public ::testing::TestWithParam> { rmm::device_uvector y; rmm::device_uvector xn; rmm::device_uvector yn; - rmm::device_uvector> min; - rmm::device_uvector> min_ref; + rmm::device_uvector> min; + rmm::device_uvector> min_ref; rmm::device_uvector workspace; raft::handle_t handle; cudaStream_t stream; @@ -181,34 +179,33 @@ class FusedL2NNTest : public ::testing::TestWithParam> { naive(min_ref.data(), x.data(), y.data(), m, n, k, (int*)workspace.data(), stream); } - void runTest(raft::KeyValuePair* out) + void runTest(cub::KeyValuePair* out) { int m = params.m; int n = params.n; int k = params.k; MinAndDistanceReduceOp redOp; - fusedL2NN, int>( - out, - x.data(), - y.data(), - xn.data(), - yn.data(), - m, - n, - k, - (void*)workspace.data(), - redOp, - raft::distance::KVPMinReduce(), - Sqrt, - true, - stream); + fusedL2NN, int>(out, + x.data(), + y.data(), + xn.data(), + yn.data(), + m, + n, + k, + (void*)workspace.data(), + redOp, + raft::distance::KVPMinReduce(), + Sqrt, + true, + stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } }; template struct CompareApproxAbsKVP { - typedef typename raft::KeyValuePair KVP; + typedef typename cub::KeyValuePair KVP; CompareApproxAbsKVP(T eps_) : eps(eps_) {} bool operator()(const KVP& a, const KVP& b) const { @@ -224,7 +221,7 @@ struct CompareApproxAbsKVP { template struct CompareExactKVP { - typedef typename raft::KeyValuePair KVP; + typedef typename cub::KeyValuePair KVP; bool operator()(const KVP& a, const KVP& b) const { if (a.value != b.value) return false; @@ -233,13 +230,13 @@ struct CompareExactKVP { }; template -::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, - const raft::KeyValuePair* actual, +::testing::AssertionResult devArrMatch(const cub::KeyValuePair* expected, + const cub::KeyValuePair* actual, size_t size, L eq_compare, cudaStream_t stream = 0) { - typedef typename raft::KeyValuePair KVP; + typedef typename cub::KeyValuePair KVP; std::shared_ptr exp_h(new KVP[size]); std::shared_ptr act_h(new KVP[size]); raft::update_host(exp_h.get(), expected, size, stream); @@ -387,7 +384,7 @@ class FusedL2NNDetTest : public FusedL2NNTest { raft::handle_t handle; cudaStream_t stream; - rmm::device_uvector> min1; + rmm::device_uvector> min1; static const int NumRepeats = 100;