From f87c1e0fb0facadfb7cef353b0d7bfa089f8ba06 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Fri, 25 Nov 2022 19:54:13 +0100 Subject: [PATCH 01/22] Unify use of common functors --- cpp/include/raft/cluster/detail/kmeans.cuh | 116 ++++++-------- .../raft/cluster/detail/kmeans_common.cuh | 69 ++++----- cpp/include/raft/cluster/kmeans.cuh | 3 +- cpp/include/raft/distance/detail/cosine.cuh | 11 +- .../raft/distance/detail/euclidean.cuh | 11 +- .../raft/distance/detail/fused_l2_nn.cuh | 10 +- .../raft/distance/detail/hellinger.cuh | 14 +- cpp/include/raft/label/detail/classlabels.cuh | 3 +- cpp/include/raft/linalg/add.cuh | 3 +- cpp/include/raft/linalg/detail/add.cuh | 8 +- cpp/include/raft/linalg/detail/divide.cuh | 5 +- cpp/include/raft/linalg/detail/eltwise.cuh | 17 +- cpp/include/raft/linalg/detail/functional.cuh | 69 --------- cpp/include/raft/linalg/detail/multiply.cuh | 3 +- .../raft/linalg/detail/strided_reduction.cuh | 4 +- cpp/include/raft/linalg/detail/subtract.cuh | 6 +- cpp/include/raft/linalg/divide.cuh | 3 +- cpp/include/raft/linalg/eltwise.cuh | 6 - cpp/include/raft/linalg/power.cuh | 6 +- cpp/include/raft/linalg/sqrt.cuh | 3 +- cpp/include/raft/matrix/detail/gather.cuh | 37 +---- cpp/include/raft/matrix/detail/math.cuh | 46 +----- .../sparse/distance/detail/lp_distance.cuh | 35 ++--- cpp/include/raft/sparse/op/detail/slice.cuh | 11 +- .../raft/spatial/knn/detail/ann_quantized.cuh | 6 +- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 27 ++-- .../raft/spatial/knn/detail/ivf_pq_build.cuh | 9 +- .../raft/spatial/knn/detail/ivf_pq_search.cuh | 16 +- .../raft/spatial/knn/detail/processing.cuh | 53 +++---- cpp/include/raft/stats/detail/mean_center.cuh | 20 +-- .../raft/stats/detail/silhouette_score.cuh | 18 +-- .../raft/stats/detail/weighted_mean.cuh | 4 +- cpp/include/raft/util/cuda_utils.cuh | 146 ++++++++++++++++-- cpp/src/distance/cluster_cost.cuh | 22 ++- cpp/test/cluster/kmeans.cu | 32 ++-- cpp/test/linalg/binary_op.cu | 10 +- cpp/test/linalg/coalesced_reduction.cu | 3 +- cpp/test/linalg/map_then_reduce.cu | 11 +- cpp/test/linalg/matrix_vector.cu | 15 +- cpp/test/linalg/norm.cu | 22 ++- cpp/test/linalg/normalize.cu | 10 +- cpp/test/matrix/linewise_op.cu | 7 +- cpp/test/sparse/dist_coo_spmv.cu | 13 +- 43 files changed, 394 insertions(+), 549 deletions(-) delete mode 100644 cpp/include/raft/linalg/detail/functional.cuh diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 5aa9870b46..8b04dd2a75 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -195,16 +195,15 @@ void kmeansPlusPlus(const raft::handle_t& handle, // Outputs minDistanceBuf[n_trials x n_samples] where minDistance[i, :] contains updated // minClusterDistance that includes candidate-i auto minDistBuf = distBuffer.view(); - raft::linalg::matrixVectorOp( - minDistBuf.data_handle(), - pwd.data_handle(), - minClusterDistance.data_handle(), - pwd.extent(1), - pwd.extent(0), - true, - true, - [=] __device__(DataT mat, DataT vec) { return vec <= mat ? 
vec : mat; }, - stream); + raft::linalg::matrixVectorOp(minDistBuf.data_handle(), + pwd.data_handle(), + minClusterDistance.data_handle(), + pwd.extent(1), + pwd.extent(0), + true, + true, + raft::Min{}, + stream); // Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using // centroid candidate-i @@ -321,21 +320,15 @@ void update_centroids(const raft::handle_t& handle, // weight_per_cluster[n_clusters] - 1D array, weight_per_cluster[i] contains sum of weights in // cluster-i. // Note - when weight_per_cluster[i] is 0, new_centroids[i] is reset to 0 - raft::linalg::matrixVectorOp( - new_centroids.data_handle(), - new_centroids.data_handle(), - weight_per_cluster.data_handle(), - new_centroids.extent(1), - new_centroids.extent(0), - true, - false, - [=] __device__(DataT mat, DataT vec) { - if (vec == 0) - return DataT(0); - else - return mat / vec; - }, - handle.get_stream()); + raft::linalg::matrixVectorOp(new_centroids.data_handle(), + new_centroids.data_handle(), + weight_per_cluster.data_handle(), + new_centroids.extent(1), + new_centroids.extent(0), + true, + false, + raft::DivideCheckZero{}, + handle.get_stream()); // copy centroids[i] to new_centroids[i] when weight_per_cluster[i] is 0 cub::ArgIndexInputIterator itr_wt(weight_per_cluster.data_handle()); @@ -351,9 +344,7 @@ void update_centroids(const raft::handle_t& handle, // copy when the sum of weights in the cluster is 0 return map.value == 0; }, - [=] __device__(raft::KeyValuePair map) { // map - return map.key; - }, + raft::KeyOp{}, handle.get_stream()); } @@ -394,7 +385,7 @@ void kmeans_fit_main(const raft::handle_t& handle, // resource auto wtInCluster = raft::make_device_vector(handle, n_clusters); - rmm::device_scalar> clusterCostD(stream); + rmm::device_scalar clusterCostD(stream); // L2 norm of X: ||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); @@ -465,16 +456,12 @@ void kmeans_fit_main(const raft::handle_t& handle, // compute the squared norm between the newCentroids and the original // centroids, destructor releases the resource auto sqrdNorm = raft::make_device_scalar(handle, DataT(0)); - raft::linalg::mapThenSumReduce( - sqrdNorm.data_handle(), - newCentroids.size(), - [=] __device__(const DataT a, const DataT b) { - DataT diff = a - b; - return diff * diff; - }, - stream, - centroids.data_handle(), - newCentroids.data_handle()); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + newCentroids.size(), + raft::SqDiff{}, + stream, + centroids.data_handle(), + newCentroids.data_handle()); DataT sqrdNormError = 0; raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); @@ -489,18 +476,11 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }); - - DataT curClusteringCost = 0; - raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, stream); - - handle.sync_stream(stream); + raft::ValueOp{}, + raft::Sum{}); + + DataT curClusteringCost = clusterCostD.value(stream); + ASSERT(curClusteringCost != (DataT)0.0, "Too few points and centroids being found is getting 0 cost from " "centers"); @@ -553,15 +533,10 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const 
raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }); + raft::ValueOp{}, + raft::Sum{}); - raft::copy(inertia.data_handle(), &(clusterCostD.data()->value), 1, stream); + inertia[0] = clusterCostD.value(stream); RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], @@ -673,7 +648,8 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, minClusterDistanceVec.view(), workspace, raft::make_device_scalar_view(clusterCost.data()), - [] __device__(const DataT& a, const DataT& b) { return a + b; }); + raft::Nop{}, + raft::Sum{}); auto psi = clusterCost.value(stream); @@ -705,7 +681,8 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, minClusterDistanceVec.view(), workspace, raft::make_device_scalar_view(clusterCost.data()), - [] __device__(const DataT& a, const DataT& b) { return a + b; }); + raft::Nop{}, + raft::Sum{}); psi = clusterCost.value(stream); @@ -1074,7 +1051,7 @@ void kmeans_predict(handle_t const& handle, workspace); // calculate cluster cost phi_x(C) - rmm::device_scalar> clusterCostD(stream); + rmm::device_scalar clusterCostD(stream); // TODO: add different templates for InType of binaryOp to avoid thrust transform thrust::transform(handle.get_thrust_policy(), minClusterAndDistance.data_handle(), @@ -1092,21 +1069,16 @@ void kmeans_predict(handle_t const& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }); - - raft::copy(inertia.data_handle(), &(clusterCostD.data()->value), 1, stream); + raft::ValueOp{}, + raft::Sum{}); thrust::transform(handle.get_thrust_policy(), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), labels.data_handle(), - [=] __device__(raft::KeyValuePair pair) { return pair.key; }); + raft::KeyOp{}); + + inertia[0] = clusterCostD.value(stream); } template diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 2973be8c23..ab592e584e 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -157,11 +157,7 @@ void checkWeight(const raft::handle_t& handle, auto scale = static_cast(n_samples) / wt_sum; raft::linalg::unaryOp( - weight.data_handle(), - weight.data_handle(), - n_samples, - [=] __device__(const DataT& wt) { return wt * scale; }, - stream); + weight.data_handle(), weight.data_handle(), n_samples, raft::ScalarMul{scale}, stream); } } @@ -179,33 +175,42 @@ IndexT getCentroidsBatchSize(int batch_centroids, IndexT n_local_clusters) return (minVal == 0) ? 
n_local_clusters : minVal; } -template +template void computeClusterCost(const raft::handle_t& handle, - raft::device_vector_view minClusterDistance, + raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, + raft::device_scalar_view clusterCost, + MainOpT main_op, ReductionOpT reduction_op) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = handle.get_stream(); + + cub::TransformInputIterator itr(minClusterDistance.data_handle(), + main_op); + size_t temp_storage_bytes = 0; RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, - minClusterDistance.data_handle(), + itr, clusterCost.data_handle(), minClusterDistance.size(), reduction_op, - DataT(), + OutputT(), stream)); workspace.resize(temp_storage_bytes, stream); RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(workspace.data(), temp_storage_bytes, - minClusterDistance.data_handle(), + itr, clusterCost.data_handle(), minClusterDistance.size(), reduction_op, - DataT(), + OutputT(), stream)); } @@ -267,9 +272,7 @@ void sampleCentroids(const raft::handle_t& handle, sampledMinClusterDistance.data_handle(), nPtsSampledInRank, inRankCp.data(), - [=] __device__(raft::KeyValuePair val) { // MapTransformOp - return val.key; - }, + raft::KeyOp{}, stream); } @@ -464,10 +467,8 @@ void minClusterAndDistanceCompute( pair.value = val; return pair; }, - [=] __device__(raft::KeyValuePair a, raft::KeyValuePair b) { - return (b.value < a.value) ? b : a; - }, - [=] __device__(raft::KeyValuePair pair) { return pair; }); + raft::ArgMin{}, + raft::Nop, IndexT>{}); } } } @@ -542,7 +543,6 @@ void minClusterDistanceCompute(const raft::handle_t& handle, if (is_fused) { workspace.resize((sizeof(IndexT)) * ns, stream); - // todo(lsugy): remove cIdx raft::distance::fusedL2NNMinReduce( minClusterDistanceView.data_handle(), datasetView.data_handle(), @@ -577,23 +577,16 @@ void minClusterDistanceCompute(const raft::handle_t& handle, pairwise_distance_kmeans( handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); - raft::linalg::coalescedReduction( - minClusterDistanceView.data_handle(), - pairwiseDistanceView.data_handle(), - pairwiseDistanceView.extent(1), - pairwiseDistanceView.extent(0), - std::numeric_limits::max(), - stream, - true, - [=] __device__(DataT val, IndexT i) { // MainLambda - return val; - }, - [=] __device__(DataT a, DataT b) { // ReduceLambda - return (b < a) ? 
b : a; - }, - [=] __device__(DataT val) { // FinalLambda - return val; - }); + raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(), + pairwiseDistanceView.data_handle(), + pairwiseDistanceView.extent(1), + pairwiseDistanceView.extent(0), + std::numeric_limits::max(), + stream, + true, + raft::Nop{}, + raft::Min{}, + raft::Nop{}); } } } diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index d64815244b..ac51164375 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -313,7 +313,8 @@ void cluster_cost(const raft::handle_t& handle, raft::device_scalar_view clusterCost, ReductionOpT reduction_op) { - detail::computeClusterCost(handle, minClusterDistance, workspace, clusterCost, reduction_op); + detail::computeClusterCost( + handle, minClusterDistance, workspace, clusterCost, raft::Nop{}, reduction_op); } /** diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index f06051962f..ed843bb0c7 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -229,8 +229,6 @@ void cosineAlgo1(Index_ m, cudaStream_t stream, bool isRowMajor) { - auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); }; - // raft distance support inputs as float/double and output as uint8_t/float/double. static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), "OutType can be uint8_t, float, double," @@ -248,10 +246,13 @@ void cosineAlgo1(Index_ m, InType* row_vec = workspace; if (pA != pB) { row_vec += m; - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); - raft::linalg::rowNorm(row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::SqrtOp{}); + raft::linalg::rowNorm( + row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::SqrtOp{}); } else { - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::SqrtOp{}); } if (isRowMajor) { diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 5ea74fa884..e8348b3cab 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -247,8 +247,6 @@ void euclideanAlgo1(Index_ m, cudaStream_t stream, bool isRowMajor) { - auto norm_op = [] __device__(InType in) { return in; }; - // raft distance support inputs as float/double and output as uint8_t/float/double. 
static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), "OutType can be uint8_t, float, double," @@ -266,10 +264,13 @@ void euclideanAlgo1(Index_ m, InType* row_vec = workspace; if (pA != pB) { row_vec += m; - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); - raft::linalg::rowNorm(row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); + raft::linalg::rowNorm( + row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); } else { - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); } if (isRowMajor) { diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index e8c2648c2e..437b45f1d2 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -298,8 +298,6 @@ void fusedL2NNImpl(OutT* min, RAFT_CUDA_TRY(cudaGetLastError()); } - auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; - constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); if (sqrt) { auto fusedL2NNSqrt = fusedL2NNkernel; + raft::Nop>; dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, fusedL2NNSqrt);
     fusedL2NNSqrt<<<grid, blk, shmemSize, stream>>>(
-      min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op);
+      min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, raft::Nop<DataT, int>{});
   } else {
     auto fusedL2NN = fusedL2NNkernel<DataT,
                                      OutT,
                                      IdxT,
                                      false,
                                      P,
                                      ReduceOpT,
                                      KVPReduceOpT,
                                      decltype(core_lambda),
-                                     decltype(fin_op)>;
+                                     raft::Nop<DataT, int>>;
     dim3 grid = launchConfigGenerator<P>
(m, n, shmemSize, fusedL2NN); fusedL2NN<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, raft::Nop{}); } RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index 31854fd1d6..79d134459b 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -78,14 +78,10 @@ static void hellingerImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); - auto unaryOp_lambda = [] __device__(DataT input) { return raft::mySqrt(input); }; // First sqrt x and y - raft::linalg::unaryOp( - (DataT*)x, x, m * k, unaryOp_lambda, stream); - + raft::linalg::unaryOp, IdxT>((DataT*)x, x, m * k, SqrtOp{}, stream); if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); + raft::linalg::unaryOp, IdxT>((DataT*)y, y, n * k, SqrtOp{}, stream); } // Accumulation operation lambda @@ -145,11 +141,9 @@ static void hellingerImpl(const DataT* x, } // Revert sqrt of x and y - raft::linalg::unaryOp( - (DataT*)x, x, m * k, unaryOp_lambda, stream); + raft::linalg::unaryOp, IdxT>((DataT*)x, x, m * k, SqrtOp{}, stream); if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); + raft::linalg::unaryOp, IdxT>((DataT*)y, y, n * k, SqrtOp{}, stream); } RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/label/detail/classlabels.cuh b/cpp/include/raft/label/detail/classlabels.cuh index 0af1c70b91..a3a98d3124 100644 --- a/cpp/include/raft/label/detail/classlabels.cuh +++ b/cpp/include/raft/label/detail/classlabels.cuh @@ -194,8 +194,7 @@ void make_monotonic( template void make_monotonic(Type* out, Type* in, size_t N, cudaStream_t stream, bool zero_based = false) { - make_monotonic( - out, in, N, stream, [] __device__(Type val) { return false; }, zero_based); + make_monotonic(out, in, N, stream, raft::ConstOp(false), zero_based); } }; // namespace detail diff --git a/cpp/include/raft/linalg/add.cuh b/cpp/include/raft/linalg/add.cuh index 37956fe762..e54eaedec6 100644 --- a/cpp/include/raft/linalg/add.cuh +++ b/cpp/include/raft/linalg/add.cuh @@ -27,13 +27,12 @@ #include #include +#include #include namespace raft { namespace linalg { -using detail::adds_scalar; - /** * @ingroup arithmetic * @brief Elementwise scalar add operation on the input buffer diff --git a/cpp/include/raft/linalg/detail/add.cuh b/cpp/include/raft/linalg/detail/add.cuh index 34966ebbc2..81489ed287 100644 --- a/cpp/include/raft/linalg/detail/add.cuh +++ b/cpp/include/raft/linalg/detail/add.cuh @@ -16,14 +16,10 @@ #pragma once -#include "functional.cuh" - #include #include #include -#include - namespace raft { namespace linalg { namespace detail { @@ -31,13 +27,13 @@ namespace detail { template void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, adds_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::ScalarAdd(scalar), stream); } template void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::plus(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::Sum(), stream); } template diff --git a/cpp/include/raft/linalg/detail/divide.cuh b/cpp/include/raft/linalg/detail/divide.cuh index 333cd3e83c..c699deca78 100644 --- a/cpp/include/raft/linalg/detail/divide.cuh 
+++ b/cpp/include/raft/linalg/detail/divide.cuh @@ -16,10 +16,9 @@ #pragma once -#include "functional.cuh" - #include #include +#include namespace raft { namespace linalg { @@ -28,7 +27,7 @@ namespace detail { template void divideScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, divides_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::ScalarDiv(scalar), stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/eltwise.cuh b/cpp/include/raft/linalg/detail/eltwise.cuh index 019f86a779..c087ab8c1c 100644 --- a/cpp/include/raft/linalg/detail/eltwise.cuh +++ b/cpp/include/raft/linalg/detail/eltwise.cuh @@ -20,8 +20,7 @@ #include #include - -#include +#include namespace raft { namespace linalg { @@ -30,48 +29,48 @@ namespace detail { template void scalarAdd(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, adds_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::ScalarAdd(scalar), stream); } template void scalarMultiply(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, multiplies_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::ScalarMul(scalar), stream); } template void eltwiseAdd( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::plus(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::Sum(), stream); } template void eltwiseSub( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::minus(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::Subtract(), stream); } template void eltwiseMultiply( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::multiplies(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::Multiply(), stream); } template void eltwiseDivide( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::divides(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::Divide(), stream); } template void eltwiseDivideCheckZero( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, divides_check_zero(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::ScalarDivCheckZero(), stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/functional.cuh b/cpp/include/raft/linalg/detail/functional.cuh deleted file mode 100644 index 067b1565e0..0000000000 --- a/cpp/include/raft/linalg/detail/functional.cuh +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace linalg { -namespace detail { - -template -struct divides_scalar { - public: - divides_scalar(ArgType scalar) : scalar_(scalar) {} - - __host__ __device__ inline ReturnType operator()(ArgType in) { return in / scalar_; } - - private: - ArgType scalar_; -}; - -template -struct adds_scalar { - public: - adds_scalar(ArgType scalar) : scalar_(scalar) {} - - __host__ __device__ inline ReturnType operator()(ArgType in) { return in + scalar_; } - - private: - ArgType scalar_; -}; - -template -struct multiplies_scalar { - public: - multiplies_scalar(ArgType scalar) : scalar_(scalar) {} - - __host__ __device__ inline ReturnType operator()(ArgType in) { return in * scalar_; } - - private: - ArgType scalar_; -}; - -template -struct divides_check_zero { - public: - __host__ __device__ inline ReturnType operator()(ArgType a, ArgType b) - { - return (b == static_cast(0)) ? 0.0 : a / b; - } -}; - -} // namespace detail -} // namespace linalg -} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/linalg/detail/multiply.cuh b/cpp/include/raft/linalg/detail/multiply.cuh index f1a8548bfa..7c492b839f 100644 --- a/cpp/include/raft/linalg/detail/multiply.cuh +++ b/cpp/include/raft/linalg/detail/multiply.cuh @@ -26,8 +26,7 @@ template void multiplyScalar( math_t* out, const math_t* in, const math_t scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp( - out, in, len, [scalar] __device__(math_t in) { return in * scalar; }, stream); + raft::linalg::unaryOp(out, in, len, raft::ScalarMul{scalar}, stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/strided_reduction.cuh b/cpp/include/raft/linalg/detail/strided_reduction.cuh index d72bd54a32..d9b14a8155 100644 --- a/cpp/include/raft/linalg/detail/strided_reduction.cuh +++ b/cpp/include/raft/linalg/detail/strided_reduction.cuh @@ -123,9 +123,7 @@ void stridedReduction(OutType* dots, { ///@todo: this extra should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) - if (!inplace) - raft::linalg::unaryOp( - dots, dots, D, [init] __device__(OutType a) { return init; }, stream); + if (!inplace) raft::linalg::unaryOp(dots, dots, D, raft::ConstOp(init), stream); // Arbitrary numbers for now, probably need to tune const dim3 thrds(32, 16); diff --git a/cpp/include/raft/linalg/detail/subtract.cuh b/cpp/include/raft/linalg/detail/subtract.cuh index ae0f09d2fe..e28f7c8e4c 100644 --- a/cpp/include/raft/linalg/detail/subtract.cuh +++ b/cpp/include/raft/linalg/detail/subtract.cuh @@ -27,15 +27,13 @@ namespace detail { template void subtractScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - auto op = [scalar] __device__(InT in) { return OutT(in - scalar); }; - raft::linalg::unaryOp(out, in, len, op, stream); + raft::linalg::unaryOp(out, in, len, raft::ScalarSub(scalar), stream); } template void subtract(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t stream) { - auto op = [] __device__(InT a, InT b) { return OutT(a - b); }; - raft::linalg::binaryOp(out, in1, in2, len, op, stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::Subtract(), stream); } template diff --git a/cpp/include/raft/linalg/divide.cuh b/cpp/include/raft/linalg/divide.cuh index 53b083045e..526d8a9716 100644 --- a/cpp/include/raft/linalg/divide.cuh +++ 
b/cpp/include/raft/linalg/divide.cuh @@ -21,13 +21,12 @@ #include "detail/divide.cuh" #include +#include #include namespace raft { namespace linalg { -using detail::divides_scalar; - /** * @defgroup ScalarOps Scalar operations on the input buffer * @tparam OutT output data-type upon which the math operation will be performed diff --git a/cpp/include/raft/linalg/eltwise.cuh b/cpp/include/raft/linalg/eltwise.cuh index dbc06a4af3..2e6c1a4ab5 100644 --- a/cpp/include/raft/linalg/eltwise.cuh +++ b/cpp/include/raft/linalg/eltwise.cuh @@ -23,8 +23,6 @@ namespace raft { namespace linalg { -using detail::adds_scalar; - /** * @defgroup ScalarOps Scalar operations on the input buffer * @tparam InType data-type upon which the math operation will be performed @@ -42,8 +40,6 @@ void scalarAdd(OutType* out, const InType* in, InType scalar, IdxType len, cudaS detail::scalarAdd(out, in, scalar, len, stream); } -using detail::multiplies_scalar; - template void scalarMultiply(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { @@ -90,8 +86,6 @@ void eltwiseDivide( detail::eltwiseDivide(out, in1, in2, len, stream); } -using detail::divides_check_zero; - template void eltwiseDivideCheckZero( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index acd226b71d..d5d898d768 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -41,8 +41,7 @@ namespace linalg { template void powerScalar(out_t* out, const in_t* in, const in_t scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp( - out, in, len, [scalar] __device__(in_t in) { return raft::myPow(in, scalar); }, stream); + raft::linalg::unaryOp(out, in, len, raft::ScalarPow(scalar), stream); } /** @} */ @@ -61,8 +60,7 @@ void powerScalar(out_t* out, const in_t* in, const in_t scalar, IdxType len, cud template void power(out_t* out, const in_t* in1, const in_t* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp( - out, in1, in2, len, [] __device__(in_t a, in_t b) { return raft::myPow(a, b); }, stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::Pow(), stream); } /** @} */ diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index 2951285c3a..a8cc3ec6ba 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -38,8 +38,7 @@ namespace linalg { template void sqrt(out_t* out, const in_t* in, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp( - out, in, len, [] __device__(in_t in) { return raft::mySqrt(in); }, stream); + raft::linalg::unaryOp(out, in, len, SqrtOp{}, stream); } /** @} */ diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 3738afba5d..bbd24f0353 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft { namespace matrix { namespace detail { @@ -183,16 +185,7 @@ void gather(const MatrixIteratorT in, { typedef typename std::iterator_traits::value_type MapValueT; gatherImpl( - in, - D, - N, - map, - map, - map_length, - out, - [] __device__(MapValueT val) { return true; }, - [] __device__(MapValueT val) { return val; }, - stream); + in, D, N, map, map, map_length, out, raft::ConstOp(true), raft::Nop(), stream); } /** @@ -227,17 +220,7 @@ void gather(const MatrixIteratorT in, cudaStream_t stream) { typedef typename 
std::iterator_traits::value_type MapValueT; - gatherImpl( - in, - D, - N, - map, - map, - map_length, - out, - [] __device__(MapValueT val) { return true; }, - transform_op, - stream); + gatherImpl(in, D, N, map, map, map_length, out, raft::ConstOp(true), transform_op, stream); } /** @@ -279,17 +262,7 @@ void gather_if(const MatrixIteratorT in, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; - gatherImpl( - in, - D, - N, - map, - stencil, - map_length, - out, - pred_op, - [] __device__(MapValueT val) { return val; }, - stream); + gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, raft::Nop(), stream); } /** diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 64c85a03a5..dec3d17b96 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -188,8 +188,7 @@ void reciprocal(math_t* in, math_t* out, IdxType len, cudaStream_t stream) template void setValue(math_t* out, const math_t* in, math_t scalar, int len, cudaStream_t stream = 0) { - raft::linalg::unaryOp( - out, in, len, [scalar] __device__(math_t in) { return scalar; }, stream); + raft::linalg::unaryOp(out, in, len, raft::ConstOp(scalar), stream); } template @@ -201,8 +200,7 @@ void ratio( rmm::device_scalar d_sum(stream); auto* d_sum_ptr = d_sum.data(); - auto no_op = [] __device__(math_t in) { return in; }; - raft::linalg::mapThenSumReduce(d_sum_ptr, len, no_op, stream, src); + raft::linalg::mapThenSumReduce(d_sum_ptr, len, raft::Nop{}, stream, src); raft::linalg::unaryOp( d_dest, d_src, len, [=] __device__(math_t a) { return a / (*d_sum_ptr); }, stream); } @@ -217,15 +215,7 @@ void matrixVectorBinaryMult(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a * b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::Multiply(), stream); } template @@ -264,15 +254,7 @@ void matrixVectorBinaryDiv(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a / b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::Divide(), stream); } template @@ -330,15 +312,7 @@ void matrixVectorBinaryAdd(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a + b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::Sum(), stream); } template @@ -351,15 +325,7 @@ void matrixVectorBinarySub(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a - b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::Subtract(), stream); } // Computes an argmin/argmax column-wise in a DxN matrix diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index 0707eb2a9b..3c5ee62e5e 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -193,13 +193,12 @@ class lp_unexpanded_distances_t : public distances_t { { unexpanded_lp_distances(out_dists, config_, PDiff(p), Sum(), AtomicAdd()); - float one_over_p = 
1.0f / p; - raft::linalg::unaryOp( - out_dists, - out_dists, - config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return pow(input, one_over_p); }, - config_->handle.get_stream()); + value_t one_over_p = value_t{1} / p; + raft::linalg::unaryOp(out_dists, + out_dists, + config_->a_nrows * config_->b_nrows, + raft::ScalarPow(one_over_p), + config_->handle.get_stream()); } private: @@ -220,12 +219,11 @@ class hamming_unexpanded_distances_t : public distances_t { unexpanded_lp_distances(out_dists, config_, NotEqual(), Sum(), AtomicAdd()); value_t n_cols = 1.0 / config_->a_ncols; - raft::linalg::unaryOp( - out_dists, - out_dists, - config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return input * n_cols; }, - config_->handle.get_stream()); + raft::linalg::unaryOp(out_dists, + out_dists, + config_->a_nrows * config_->b_nrows, + raft::ScalarMul(n_cols), + config_->handle.get_stream()); } private: @@ -302,12 +300,11 @@ class kl_divergence_unexpanded_distances_t : public distances_t { Sum(), AtomicAdd()); - raft::linalg::unaryOp( - out_dists, - out_dists, - config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return 0.5 * input; }, - config_->handle.get_stream()); + raft::linalg::unaryOp(out_dists, + out_dists, + config_->a_nrows * config_->b_nrows, + raft::ScalarMul(0.5), + config_->handle.get_stream()); } private: diff --git a/cpp/include/raft/sparse/op/detail/slice.cuh b/cpp/include/raft/sparse/op/detail/slice.cuh index 193d246b4b..d402a9e1e3 100644 --- a/cpp/include/raft/sparse/op/detail/slice.cuh +++ b/cpp/include/raft/sparse/op/detail/slice.cuh @@ -70,12 +70,11 @@ void csr_row_slice_indptr(value_idx start_row, // we add another 1 to stop row. raft::copy_async(indptr_out, indptr + start_row, (stop_row + 2) - start_row, stream); - raft::linalg::unaryOp( - indptr_out, - indptr_out, - (stop_row + 2) - start_row, - [s_offset] __device__(value_idx input) { return input - s_offset; }, - stream); + raft::linalg::unaryOp(indptr_out, + indptr_out, + (stop_row + 2) - start_row, + raft::ScalarSub(s_offset), + stream); } /** diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index e5900ffd69..c3f7015bab 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -202,11 +202,7 @@ void approx_knn_search(const handle_t& handle, float p = 0.5; // standard l2 if (index->metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / index->metricArg; raft::linalg::unaryOp( - distances, - distances, - n * k, - [p] __device__(float input) { return powf(input, p); }, - handle.get_stream()); + distances, distances, n * k, raft::ScalarPow(p), handle.get_stream()); } if constexpr (std::is_same_v) { index->metric_processor->postprocess(distances); } } diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 41f1df85fe..795901903e 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -566,8 +566,6 @@ void fusedL2UnexpKnnImpl(const DataT* x, acc += diff * diff; }; - auto fin_op = [] __device__(AccT d_val, int g_d_idx) { return d_val; }; - typedef cub::KeyValuePair Pair; if (isRowMajor) { @@ -578,7 +576,7 @@ void fusedL2UnexpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - decltype(fin_op), + raft::Nop, 32, 2, usePrevTopKs, @@ -590,7 +588,7 @@ void 
fusedL2UnexpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - decltype(fin_op), + raft::Nop, 64, 3, usePrevTopKs, @@ -630,7 +628,7 @@ void fusedL2UnexpKnnImpl(const DataT* x, ldb, ldd, core_lambda, - fin_op, + raft::Nop{}, sqrt, (uint32_t)numOfNN, (int*)workspace, @@ -757,8 +755,6 @@ void fusedL2ExpKnnImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; - auto fin_op = [] __device__(AccT d_val, int g_d_idx) { return d_val; }; - typedef cub::KeyValuePair Pair; if (isRowMajor) { @@ -769,7 +765,7 @@ void fusedL2ExpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - decltype(fin_op), + raft::Nop, 32, 2, usePrevTopKs, @@ -781,7 +777,7 @@ void fusedL2ExpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - decltype(fin_op), + raft::Nop, 64, 3, usePrevTopKs, @@ -818,14 +814,15 @@ void fusedL2ExpKnnImpl(const DataT* x, DataT* xn = (DataT*)workspace; DataT* yn = (DataT*)workspace; - auto norm_op = [] __device__(DataT in) { return in; }; - if (x != y) { yn += m; - raft::linalg::rowNorm(xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); - raft::linalg::rowNorm(yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); + raft::linalg::rowNorm( + yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); } else { - raft::linalg::rowNorm(xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); } fusedL2ExpKnnRowMajor<<>>(x, y, @@ -838,7 +835,7 @@ void fusedL2ExpKnnImpl(const DataT* x, ldb, ldd, core_lambda, - fin_op, + raft::Nop{}, sqrt, (uint32_t)numOfNN, mutexes, diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh index 9262ef6baf..6604c1d9ce 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh @@ -259,14 +259,7 @@ void select_residuals(const handle_t& handle, n_rows, (IdxT)dim, dataset, row_ids, (IdxT)dim, tmp.data(), (IdxT)dim, stream); raft::matrix::linewiseOp( - tmp.data(), - tmp.data(), - IdxT(dim), - n_rows, - true, - [] __device__(float a, float b) { return a - b; }, - stream, - center); + tmp.data(), tmp.data(), IdxT(dim), n_rows, true, raft::Subtract{}, stream, center); float alpha = 1.0; float beta = 0.0; diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index c1a3682f47..cd9ae5f712 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -419,13 +419,7 @@ void postprocess_distances(float* out, // [n_queries, topk] case distance::DistanceType::L2Unexpanded: case distance::DistanceType::L2Expanded: { linalg::unaryOp( - out, - in, - len, - [scaling_factor] __device__(ScoreT x) -> float { - return scaling_factor * scaling_factor * float(x); - }, - stream); + out, in, len, raft::ScalarMul(scaling_factor * scaling_factor), stream); } break; case distance::DistanceType::L2SqrtUnexpanded: case distance::DistanceType::L2SqrtExpanded: { @@ -438,13 +432,7 @@ void postprocess_distances(float* out, // [n_queries, topk] } break; case distance::DistanceType::InnerProduct: { linalg::unaryOp( - out, - in, - len, - [scaling_factor] __device__(ScoreT x) -> float { - return 
-scaling_factor * scaling_factor * float(x); - }, - stream); + out, in, len, raft::ScalarMul(-scaling_factor * scaling_factor), stream); } break; default: RAFT_FAIL("Unexpected metric."); } diff --git a/cpp/include/raft/spatial/knn/detail/processing.cuh b/cpp/include/raft/spatial/knn/detail/processing.cuh index a80c1c1935..8b30010dc3 100644 --- a/cpp/include/raft/spatial/knn/detail/processing.cuh +++ b/cpp/include/raft/spatial/knn/detail/processing.cuh @@ -59,32 +59,30 @@ class CosineMetricProcessor : public MetricProcessor { raft::linalg::NormType::L2Norm, row_major_, stream_, - [] __device__(math_t in) { return sqrtf(in); }); - - raft::linalg::matrixVectorOp( - data, - data, - colsums_.data(), - n_cols_, - n_rows_, - row_major_, - false, - [] __device__(math_t mat_in, math_t vec_in) { return mat_in / vec_in; }, - stream_); + raft::SqrtOp{}); + + raft::linalg::matrixVectorOp(data, + data, + colsums_.data(), + n_cols_, + n_rows_, + row_major_, + false, + raft::Divide{}, + stream_); } void revert(math_t* data) { - raft::linalg::matrixVectorOp( - data, - data, - colsums_.data(), - n_cols_, - n_rows_, - row_major_, - false, - [] __device__(math_t mat_in, math_t vec_in) { return mat_in * vec_in; }, - stream_); + raft::linalg::matrixVectorOp(data, + data, + colsums_.data(), + n_cols_, + n_rows_, + row_major_, + false, + raft::Multiply{}, + stream_); } void postprocess(math_t* data) @@ -122,12 +120,11 @@ class CorrelationMetricProcessor : public CosineMetricProcessor { true, cosine::stream_); - raft::linalg::unaryOp( - means_.data(), - means_.data(), - cosine::n_rows_, - [=] __device__(math_t in) { return in * normalizer_const; }, - cosine::stream_); + raft::linalg::unaryOp(means_.data(), + means_.data(), + cosine::n_rows_, + raft::ScalarMul(normalizer_const), + cosine::stream_); raft::stats::meanCenter(data, data, diff --git a/cpp/include/raft/stats/detail/mean_center.cuh b/cpp/include/raft/stats/detail/mean_center.cuh index 61017511b1..8453744789 100644 --- a/cpp/include/raft/stats/detail/mean_center.cuh +++ b/cpp/include/raft/stats/detail/mean_center.cuh @@ -49,15 +49,7 @@ void meanCenter(Type* out, cudaStream_t stream) { raft::linalg::matrixVectorOp( - out, - data, - mu, - D, - N, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a - b; }, - stream); + out, data, mu, D, N, rowMajor, bcastAlongRows, raft::Subtract{}, stream); } /** @@ -85,15 +77,7 @@ void meanAdd(Type* out, cudaStream_t stream) { raft::linalg::matrixVectorOp( - out, - data, - mu, - D, - N, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a + b; }, - stream); + out, data, mu, D, N, rowMajor, bcastAlongRows, raft::Sum{}, stream); } }; // end namespace detail diff --git a/cpp/include/raft/stats/detail/silhouette_score.cuh b/cpp/include/raft/stats/detail/silhouette_score.cuh index 076d9b13e5..115a6ad9c6 100644 --- a/cpp/include/raft/stats/detail/silhouette_score.cuh +++ b/cpp/include/raft/stats/detail/silhouette_score.cuh @@ -172,20 +172,6 @@ struct SilOp { } }; -/** - * @brief structure that defines the reduction Lambda to find minimum between elements - */ -template -struct MinOp { - HDI DataT operator()(DataT a, DataT b) - { - if (a > b) - return b; - else - return a; - } -}; - /** * @brief main function that returns the average silhouette score for a given set of data and its * clusterings @@ -300,8 +286,8 @@ DataT silhouette_score( true, stream, false, - raft::Nop(), - MinOp()); + raft::Nop{}, + raft::Min{}); // calculating the silhouette score per sample using the d_aArray and 
d_bArray raft::linalg::binaryOp>( diff --git a/cpp/include/raft/stats/detail/weighted_mean.cuh b/cpp/include/raft/stats/detail/weighted_mean.cuh index 43dbe4e7f1..c8b59e08a2 100644 --- a/cpp/include/raft/stats/detail/weighted_mean.cuh +++ b/cpp/include/raft/stats/detail/weighted_mean.cuh @@ -66,8 +66,8 @@ void weightedMean(Type* mu, stream, false, [weights] __device__(Type v, IdxType i) { return v * weights[i]; }, - [] __device__(Type a, Type b) { return a + b; }, - [WS] __device__(Type v) { return v / WS; }); + raft::Sum{}, + raft::ScalarDiv(WS)); } }; // end namespace detail }; // end namespace stats diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index 5818fc21f3..f167ad76ed 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -508,44 +508,170 @@ HDI double myATanh(double x) /** @} */ /** - * @defgroup LambdaOps Lambda operations in reduction kernels + * @defgroup LambdaOps Commonly used lambda operations * @{ */ -// IdxType mostly to be used for MainLambda in *Reduction kernels +// The optional index argument is mostly to be used for MainLambda in reduction kernels template struct Nop { - HDI Type operator()(Type in, IdxType i = 0) { return in; } + HDI Type operator()(Type in, IdxType i = 0) const { return in; } +}; + +struct KeyOp { + template + HDI typename KVP::Key operator()(const KVP& p, IdxType i = 0) const + { + return p.key; + } +}; + +struct ValueOp { + template + HDI typename KVP::Value operator()(const KVP& p, IdxType i = 0) const + { + return p.value; + } }; template struct SqrtOp { - HDI Type operator()(Type in, IdxType i = 0) { return mySqrt(in); } + HDI Type operator()(Type in, IdxType i = 0) const { return mySqrt(in); } }; template struct L0Op { - HDI Type operator()(Type in, IdxType i = 0) { return in != Type(0) ? Type(1) : Type(0); } + HDI Type operator()(Type in, IdxType i = 0) const { return in != Type(0) ? 
Type(1) : Type(0); } }; template struct L1Op { - HDI Type operator()(Type in, IdxType i = 0) { return myAbs(in); } + HDI Type operator()(Type in, IdxType i = 0) const { return myAbs(in); } }; template struct L2Op { - HDI Type operator()(Type in, IdxType i = 0) { return in * in; } + HDI Type operator()(Type in, IdxType i = 0) const { return in * in; } }; -template +template struct Sum { - HDI Type operator()(Type a, Type b) { return a + b; } + HDI OutT operator()(InT a, InT b) const { return a + b; } +}; + +template +struct Subtract { + HDI OutT operator()(InT a, InT b) const { return a - b; } +}; + +template +struct Multiply { + HDI OutT operator()(InT a, InT b) const { return a * b; } +}; + +template +struct Divide { + HDI OutT operator()(InT a, InT b) const { return a / b; } +}; + +template +struct DivideCheckZero { + HDI OutT operator()(InT a, InT b) const + { + if (b == InT{0}) { return InT{0}; } + return a / b; + } +}; + +template +struct Pow { + HDI OutT operator()(InT a, InT b) const { return raft::myPow(a, b); } +}; + +template +struct Min { + HDI Type operator()(Type a, Type b) const + { + if (a > b) { return b; } + return a; + } }; template struct Max { - HDI Type operator()(Type a, Type b) { return myMax(a, b); } + HDI Type operator()(Type a, Type b) const + { + if (b > a) { return b; } + return a; + } +}; + +template +struct SqDiff { + HDI Type operator()(Type a, Type b) const + { + Type diff = a - b; + return diff * diff; + } +}; + +struct ArgMin { + template + HDI KVP operator()(const KVP& a, const KVP& b) const + { + if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; } + return a; + } +}; + +struct ArgMax { + template + HDI KVP operator()(const KVP& a, const KVP& b) const + { + if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; } + return a; + } +}; + +template +struct ConstOp { + const OutT scalar; + + ConstOp(OutT s) : scalar{s} {} + + template + HDI OutT operator()(InT unused) const + { + return scalar; + } +}; + +template +struct ScalarOp { + ComposedOpT composed_op; + const InT scalar; + + ScalarOp(InT s) : scalar{s} {} + + HDI OutT operator()(InT a) const { return composed_op(a, scalar); } }; + +template +using ScalarAdd = ScalarOp, InT, OutT>; + +template +using ScalarSub = ScalarOp, InT, OutT>; + +template +using ScalarMul = ScalarOp, InT, OutT>; + +template +using ScalarDiv = ScalarOp, InT, OutT>; + +template +using ScalarDivCheckZero = ScalarOp, InT, OutT>; + +template +using ScalarPow = ScalarOp, InT, OutT>; /** @} */ /** diff --git a/cpp/src/distance/cluster_cost.cuh b/cpp/src/distance/cluster_cost.cuh index 344673830b..82b6c041a3 100644 --- a/cpp/src/distance/cluster_cost.cuh +++ b/cpp/src/distance/cluster_cost.cuh @@ -59,20 +59,18 @@ void cluster_cost(const raft::handle_t& handle, handle.get_stream()); auto distances = raft::make_device_vector(handle, n_samples); - thrust::transform( - handle.get_thrust_policy(), - min_cluster_distance.data_handle(), - min_cluster_distance.data_handle() + n_samples, - distances.data_handle(), - [] __device__(const raft::KeyValuePair& a) { return a.value; }); + thrust::transform(handle.get_thrust_policy(), + min_cluster_distance.data_handle(), + min_cluster_distance.data_handle() + n_samples, + distances.data_handle(), + raft::ValueOp{}); rmm::device_scalar device_cost(0, handle.get_stream()); - raft::cluster::kmeans::cluster_cost( - handle, - distances.view(), - workspace, - make_device_scalar_view(device_cost.data()), - [] __device__(const ElementType& a, const 
ElementType& b) { return a + b; }); + raft::cluster::kmeans::cluster_cost(handle, + distances.view(), + workspace, + make_device_scalar_view(device_cost.data()), + raft::Sum{}); raft::update_host(cost, device_cost.data(), 1, handle.get_stream()); } diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index 698d23ac27..e5fa78e3aa 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -50,11 +50,7 @@ void run_cluster_cost(const raft::handle_t& handle, raft::device_scalar_view clusterCost) { raft::cluster::kmeans::cluster_cost( - handle, - minClusterDistance, - workspace, - clusterCost, - [] __device__(const DataT& a, const DataT& b) { return a + b; }); + handle, minClusterDistance, workspace, clusterCost, raft::Sum{}); } template @@ -347,25 +343,25 @@ const std::vector> inputsf2 = {{1000, 32, 5, 0.0001f, true}, {10000, 500, 100, 0.0001f, true}, {10000, 500, 100, 0.0001f, false}}; -const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, - {1000, 32, 5, 0.0001, false}, - {1000, 100, 20, 0.0001, true}, - {1000, 100, 20, 0.0001, false}, - {10000, 32, 10, 0.0001, true}, - {10000, 32, 10, 0.0001, false}, - {10000, 100, 50, 0.0001, true}, - {10000, 100, 50, 0.0001, false}, - {10000, 500, 100, 0.0001, true}, - {10000, 500, 100, 0.0001, false}}; +// const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, +// {1000, 32, 5, 0.0001, false}, +// {1000, 100, 20, 0.0001, true}, +// {1000, 100, 20, 0.0001, false}, +// {10000, 32, 10, 0.0001, true}, +// {10000, 32, 10, 0.0001, false}, +// {10000, 100, 50, 0.0001, true}, +// {10000, 100, 50, 0.0001, false}, +// {10000, 500, 100, 0.0001, true}, +// {10000, 500, 100, 0.0001, false}}; typedef KmeansTest KmeansTestF; TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score == 1.0); } -typedef KmeansTest KmeansTestD; -TEST_P(KmeansTestD, Result) { ASSERT_TRUE(score == 1.0); } +// typedef KmeansTest KmeansTestD; +// TEST_P(KmeansTestD, Result) { ASSERT_TRUE(score == 1.0); } INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); -INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestD, ::testing::ValuesIn(inputsd2)); +// INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestD, ::testing::ValuesIn(inputsd2)); } // namespace raft diff --git a/cpp/test/linalg/binary_op.cu b/cpp/test/linalg/binary_op.cu index b92fa09427..848e37de42 100644 --- a/cpp/test/linalg/binary_op.cu +++ b/cpp/test/linalg/binary_op.cu @@ -36,8 +36,7 @@ void binaryOpLaunch( auto in1_view = raft::make_device_vector_view(in1, len); auto in2_view = raft::make_device_vector_view(in2, len); - binary_op( - handle, in1_view, in2_view, out_view, [] __device__(InType a, InType b) { return a + b; }); + binary_op(handle, in1_view, in2_view, out_view, raft::Sum{}); } template @@ -139,12 +138,7 @@ class BinaryOpAlignment : public ::testing::Test { RAFT_CUDA_TRY(cudaMemsetAsync(x.data(), 0, n * sizeof(math_t), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(y.data(), 0, n * sizeof(math_t), stream)); raft::linalg::binaryOp( - z.data() + 9, - x.data() + 137, - y.data() + 19, - 256, - [] __device__(math_t x, math_t y) { return x + y; }, - handle.get_stream()); + z.data() + 9, x.data() + 137, y.data() + 19, 256, raft::Sum{}, handle.get_stream()); } raft::handle_t handle; diff --git a/cpp/test/linalg/coalesced_reduction.cu b/cpp/test/linalg/coalesced_reduction.cu index 791537b430..c429d925ce 100644 --- a/cpp/test/linalg/coalesced_reduction.cu +++ b/cpp/test/linalg/coalesced_reduction.cu @@ -47,8 +47,7 @@ void coalescedReductionLaunch( { auto dots_view = 
raft::make_device_vector_view(dots, rows); auto data_view = raft::make_device_matrix_view(data, rows, cols); - coalesced_reduction( - handle, data_view, dots_view, (T)0, inplace, [] __device__(T in, int i) { return in * in; }); + coalesced_reduction(handle, data_view, dots_view, (T)0, inplace, raft::L2Op{}); } template diff --git a/cpp/test/linalg/map_then_reduce.cu b/cpp/test/linalg/map_then_reduce.cu index adf784f601..7907f7d1ca 100644 --- a/cpp/test/linalg/map_then_reduce.cu +++ b/cpp/test/linalg/map_then_reduce.cu @@ -63,9 +63,8 @@ template void mapReduceLaunch( OutType* out_ref, OutType* out, const InType* in, size_t len, cudaStream_t stream) { - auto op = [] __device__(InType in) { return in; }; - naiveMapReduce(out_ref, in, len, op, stream); - mapThenSumReduce(out, len, op, 0, in); + naiveMapReduce(out_ref, in, len, raft::Nop{}, stream); + mapThenSumReduce(out, len, raft::Nop{}, 0, in); } template @@ -150,23 +149,21 @@ class MapGenericReduceTest : public ::testing::Test { void testMin() { - auto op = [] __device__(InType in) { return in; }; OutType neutral = std::numeric_limits::max(); auto output_view = raft::make_device_scalar_view(output.data()); auto input_view = raft::make_device_vector_view( input.data(), static_cast(input.size())); - map_reduce(handle, input_view, output_view, neutral, op, cub::Min()); + map_reduce(handle, input_view, output_view, neutral, raft::Nop{}, cub::Min()); EXPECT_TRUE(raft::devArrMatch( OutType(1), output.data(), 1, raft::Compare(), handle.get_stream())); } void testMax() { - auto op = [] __device__(InType in) { return in; }; OutType neutral = std::numeric_limits::min(); auto output_view = raft::make_device_scalar_view(output.data()); auto input_view = raft::make_device_vector_view( input.data(), static_cast(input.size())); - map_reduce(handle, input_view, output_view, neutral, op, cub::Max()); + map_reduce(handle, input_view, output_view, neutral, raft::Nop{}, cub::Max()); EXPECT_TRUE(raft::devArrMatch( OutType(5), output.data(), 1, raft::Compare(), handle.get_stream())); } diff --git a/cpp/test/linalg/matrix_vector.cu b/cpp/test/linalg/matrix_vector.cu index f103b5918b..2424f8d3aa 100644 --- a/cpp/test/linalg/matrix_vector.cu +++ b/cpp/test/linalg/matrix_vector.cu @@ -113,34 +113,25 @@ void naive_matrix_vector_op_launch(const raft::handle_t& handle, return mat_element; } }; - auto operation_div = [] __device__(T mat_element, T vec_element) { - return mat_element / vec_element; - }; auto operation_bin_div_skip_zero = [] __device__(T mat_element, T vec_element) { if (raft::myAbs(vec_element) < T(1e-10)) return T(0); else return mat_element / vec_element; }; - auto operation_bin_add = [] __device__(T mat_element, T vec_element) { - return mat_element + vec_element; - }; - auto operation_bin_sub = [] __device__(T mat_element, T vec_element) { - return mat_element - vec_element; - }; if (operation_type == 0) { naiveMatVec( in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_mult_skip_zero, stream); } else if (operation_type == 1) { - naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, operation_div, stream); + naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::Divide{}, stream); } else if (operation_type == 2) { naiveMatVec( in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_div_skip_zero, stream); } else if (operation_type == 3) { - naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_add, stream); + naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::Sum{}, stream); } 
else if (operation_type == 4) { - naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_sub, stream); + naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::Subtract{}, stream); } else { THROW("Unknown operation type '%d'!", (int)operation_type); } diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index f0b8d3bb55..1752c51af8 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -95,11 +95,12 @@ class RowNormTest : public ::testing::TestWithParam> { auto input_col_major = raft::make_device_matrix_view( data.data(), params.rows, params.cols); if (params.do_sqrt) { - auto fin_op = [] __device__(const T in) { return raft::mySqrt(in); }; if (params.rowMajor) { - norm(handle, input_row_major, output_view, params.type, Apply::ALONG_ROWS, fin_op); + norm( + handle, input_row_major, output_view, params.type, Apply::ALONG_ROWS, raft::SqrtOp{}); } else { - norm(handle, input_col_major, output_view, params.type, Apply::ALONG_ROWS, fin_op); + norm( + handle, input_col_major, output_view, params.type, Apply::ALONG_ROWS, raft::SqrtOp{}); } } else { if (params.rowMajor) { @@ -171,11 +172,20 @@ class ColNormTest : public ::testing::TestWithParam> { auto input_col_major = raft::make_device_matrix_view( data.data(), params.rows, params.cols); if (params.do_sqrt) { - auto fin_op = [] __device__(const T in) { return raft::mySqrt(in); }; if (params.rowMajor) { - norm(handle, input_row_major, output_view, params.type, Apply::ALONG_COLUMNS, fin_op); + norm(handle, + input_row_major, + output_view, + params.type, + Apply::ALONG_COLUMNS, + raft::SqrtOp{}); } else { - norm(handle, input_col_major, output_view, params.type, Apply::ALONG_COLUMNS, fin_op); + norm(handle, + input_col_major, + output_view, + params.type, + Apply::ALONG_COLUMNS, + raft::SqrtOp{}); } } else { if (params.rowMajor) { diff --git a/cpp/test/linalg/normalize.cu b/cpp/test/linalg/normalize.cu index cb949b6a5d..702adbf6d7 100644 --- a/cpp/test/linalg/normalize.cu +++ b/cpp/test/linalg/normalize.cu @@ -53,15 +53,7 @@ void rowNormalizeRef( raft::linalg::rowNorm(norm.data(), in, cols, rows, norm_type, true, stream, raft::Nop()); } raft::linalg::matrixVectorOp( - out, - in, - norm.data(), - cols, - rows, - true, - false, - [] __device__(T a, T b) { return a / b; }, - stream); + out, in, norm.data(), cols, rows, true, false, raft::Divide{}, stream); } template diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 2e3d54dcf5..e8151c72ee 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -58,7 +58,6 @@ struct LinewiseTest : public ::testing::TestWithParam void runLinewiseSum(T* out, const T* in, const I lineLen, const I nLines, const T* vec) { - auto f = [] __device__(T a, T b) -> T { return a + b; }; constexpr auto rowmajor = std::is_same_v; I m = rowmajor ? lineLen : nLines; @@ -68,7 +67,8 @@ struct LinewiseTest : public ::testing::TestWithParam(out, n, m); auto vec_view = raft::make_device_vector_view(vec, lineLen); - matrix::linewise_op(handle, in_view, out_view, raft::is_row_major(in_view), f, vec_view); + matrix::linewise_op( + handle, in_view, out_view, raft::is_row_major(in_view), raft::Sum{}, vec_view); } template @@ -107,9 +107,8 @@ struct LinewiseTest : public ::testing::TestWithParam T { return a + b; }; auto vec_view = raft::make_device_vector_view(vec, alongLines ? 
lineLen : nLines); - matrix::linewise_op(handle, in, out, alongLines, f, vec_view); + matrix::linewise_op(handle, in, out, alongLines, raft::Sum{}, vec_view); } /** diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index c004aeaef0..e0a9a7c71c 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -158,13 +158,12 @@ class SparseDistanceCOOSPMVTest case raft::distance::DistanceType::LpUnexpanded: { compute_dist( detail::PDiff(params.input_configuration.metric_arg), detail::Sum(), detail::AtomicAdd()); - float p = 1.0f / params.input_configuration.metric_arg; - raft::linalg::unaryOp( - out_dists.data(), - out_dists.data(), - dist_config.a_nrows * dist_config.b_nrows, - [=] __device__(value_t input) { return powf(input, p); }, - dist_config.handle.get_stream()); + value_t p = value_t{1} / params.input_configuration.metric_arg; + raft::linalg::unaryOp(out_dists.data(), + out_dists.data(), + dist_config.a_nrows * dist_config.b_nrows, + raft::ScalarPow{p}, + dist_config.handle.get_stream()); } break; default: throw raft::exception("Unknown distance"); From 5ca93074517e2c6260cbcd9deefa14c4e2544813 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Fri, 25 Nov 2022 19:58:13 +0100 Subject: [PATCH 02/22] Undo unintentional changes to k-means test --- cpp/test/cluster/kmeans.cu | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index e5fa78e3aa..20f08eedb8 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -343,25 +343,25 @@ const std::vector> inputsf2 = {{1000, 32, 5, 0.0001f, true}, {10000, 500, 100, 0.0001f, true}, {10000, 500, 100, 0.0001f, false}}; -// const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, -// {1000, 32, 5, 0.0001, false}, -// {1000, 100, 20, 0.0001, true}, -// {1000, 100, 20, 0.0001, false}, -// {10000, 32, 10, 0.0001, true}, -// {10000, 32, 10, 0.0001, false}, -// {10000, 100, 50, 0.0001, true}, -// {10000, 100, 50, 0.0001, false}, -// {10000, 500, 100, 0.0001, true}, -// {10000, 500, 100, 0.0001, false}}; +const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, + {1000, 32, 5, 0.0001, false}, + {1000, 100, 20, 0.0001, true}, + {1000, 100, 20, 0.0001, false}, + {10000, 32, 10, 0.0001, true}, + {10000, 32, 10, 0.0001, false}, + {10000, 100, 50, 0.0001, true}, + {10000, 100, 50, 0.0001, false}, + {10000, 500, 100, 0.0001, true}, + {10000, 500, 100, 0.0001, false}}; typedef KmeansTest KmeansTestF; TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score == 1.0); } -// typedef KmeansTest KmeansTestD; -// TEST_P(KmeansTestD, Result) { ASSERT_TRUE(score == 1.0); } +typedef KmeansTest KmeansTestD; +TEST_P(KmeansTestD, Result) { ASSERT_TRUE(score == 1.0); } INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); -// INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestD, ::testing::ValuesIn(inputsd2)); +INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestD, ::testing::ValuesIn(inputsd2)); } // namespace raft From 20d807a370368654aa0f7389e2f84bc05775e9b2 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Fri, 25 Nov 2022 20:07:37 +0100 Subject: [PATCH 03/22] Remove include of deleted file --- cpp/include/raft/linalg/detail/eltwise.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/include/raft/linalg/detail/eltwise.cuh b/cpp/include/raft/linalg/detail/eltwise.cuh index c087ab8c1c..93b19c37bb 100644 --- a/cpp/include/raft/linalg/detail/eltwise.cuh +++ 
b/cpp/include/raft/linalg/detail/eltwise.cuh @@ -16,8 +16,6 @@ #pragma once -#include "functional.cuh" - #include #include #include From a744213aeb0e6ecc9488fd0c450bcc678df6270b Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Fri, 25 Nov 2022 20:26:17 +0100 Subject: [PATCH 04/22] Clang format --- .../raft/distance/detail/fused_l2_nn.cuh | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 437b45f1d2..b8778c691e 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -311,8 +311,20 @@ void fusedL2NNImpl(OutT* min, raft::Nop>; dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, fusedL2NNSqrt); - fusedL2NNSqrt<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, raft::Nop{}); + fusedL2NNSqrt<<>>(min, + x, + y, + xn, + yn, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + core_lambda, + raft::Nop{}); } else { auto fusedL2NN = fusedL2NNkernel>; dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, raft::Nop{}); + fusedL2NN<<>>(min, + x, + y, + xn, + yn, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + core_lambda, + raft::Nop{}); } RAFT_CUDA_TRY(cudaGetLastError()); From 1b91e53cc85b4f136bfe5bf197144bcf4f2a89af Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Mon, 28 Nov 2022 11:15:51 +0100 Subject: [PATCH 05/22] Add missing namespace specifiers --- cpp/include/raft/distance/detail/hellinger.cuh | 12 ++++++++---- cpp/include/raft/linalg/sqrt.cuh | 2 +- cpp/include/raft/stats/detail/silhouette_score.cuh | 4 ++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index 79d134459b..f39fd6158f 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -79,9 +79,11 @@ static void hellingerImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); // First sqrt x and y - raft::linalg::unaryOp, IdxT>((DataT*)x, x, m * k, SqrtOp{}, stream); + raft::linalg::unaryOp, IdxT>( + (DataT*)x, x, m * k, raft::SqrtOp{}, stream); if (x != y) { - raft::linalg::unaryOp, IdxT>((DataT*)y, y, n * k, SqrtOp{}, stream); + raft::linalg::unaryOp, IdxT>( + (DataT*)y, y, n * k, raft::SqrtOp{}, stream); } // Accumulation operation lambda @@ -141,9 +143,11 @@ static void hellingerImpl(const DataT* x, } // Revert sqrt of x and y - raft::linalg::unaryOp, IdxT>((DataT*)x, x, m * k, SqrtOp{}, stream); + raft::linalg::unaryOp, IdxT>( + (DataT*)x, x, m * k, raft::SqrtOp{}, stream); if (x != y) { - raft::linalg::unaryOp, IdxT>((DataT*)y, y, n * k, SqrtOp{}, stream); + raft::linalg::unaryOp, IdxT>( + (DataT*)y, y, n * k, raft::SqrtOp{}, stream); } RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index a8cc3ec6ba..6d0a5d58f3 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -38,7 +38,7 @@ namespace linalg { template void sqrt(out_t* out, const in_t* in, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, SqrtOp{}, stream); + raft::linalg::unaryOp(out, in, len, raft::SqrtOp{}, stream); } /** @} */ diff --git a/cpp/include/raft/stats/detail/silhouette_score.cuh b/cpp/include/raft/stats/detail/silhouette_score.cuh index 115a6ad9c6..b95d3783e5 100644 --- a/cpp/include/raft/stats/detail/silhouette_score.cuh +++ b/cpp/include/raft/stats/detail/silhouette_score.cuh @@ -272,11 +272,11 @@ DataT silhouette_score( nRows, true, true, - DivOp(), + raft::DivOp(), stream); // calculating row-wise minimum - raft::linalg::reduce, MinOp>( + raft::linalg::reduce, raft::MinOp>( d_bArray.data(), averageDistanceBetweenSampleAndCluster.data(), nLabels, From f74c3d4d79f7f171c7e84967d81837217b1bd928 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Wed, 30 Nov 2022 14:35:21 +0100 Subject: [PATCH 06/22] New functors have automatic type inference, old ones are deprecated --- cpp/bench/linalg/norm.cu | 3 +- cpp/include/raft/cluster/detail/kmeans.cuh | 30 +-- .../raft/cluster/detail/kmeans_common.cuh | 19 +- cpp/include/raft/cluster/kmeans.cuh | 3 +- cpp/include/raft/distance/detail/canberra.cuh | 2 +- .../raft/distance/detail/chebyshev.cuh | 2 +- .../raft/distance/detail/correlation.cuh | 12 +- cpp/include/raft/distance/detail/cosine.cuh | 7 +- .../raft/distance/detail/euclidean.cuh | 7 +- .../raft/distance/detail/fused_l2_nn.cuh | 8 +-
.../raft/distance/detail/hellinger.cuh | 13 +- cpp/include/raft/distance/detail/l1.cuh | 2 +- .../raft/distance/detail/minkowski.cuh | 2 +- cpp/include/raft/label/detail/classlabels.cuh | 2 +- cpp/include/raft/linalg/add.cuh | 1 - .../raft/linalg/coalesced_reduction.cuh | 25 +- cpp/include/raft/linalg/detail/add.cuh | 4 +- .../linalg/detail/coalesced_reduction.cuh | 86 +++---- cpp/include/raft/linalg/detail/divide.cuh | 2 +- cpp/include/raft/linalg/detail/eltwise.cuh | 14 +- cpp/include/raft/linalg/detail/multiply.cuh | 3 +- cpp/include/raft/linalg/detail/norm.cuh | 25 +- cpp/include/raft/linalg/detail/reduce.cuh | 12 +- .../raft/linalg/detail/strided_reduction.cuh | 18 +- cpp/include/raft/linalg/detail/subtract.cuh | 4 +- cpp/include/raft/linalg/norm.cuh | 13 +- cpp/include/raft/linalg/normalize.cuh | 31 +-- cpp/include/raft/linalg/power.cuh | 4 +- cpp/include/raft/linalg/reduce.cuh | 25 +- cpp/include/raft/linalg/sqrt.cuh | 2 +- cpp/include/raft/linalg/strided_reduction.cuh | 25 +- cpp/include/raft/matrix/detail/gather.cuh | 6 +- cpp/include/raft/matrix/detail/math.cuh | 12 +- .../sparse/distance/detail/lp_distance.cuh | 6 +- cpp/include/raft/sparse/op/detail/slice.cuh | 2 +- .../raft/spatial/knn/detail/ann_quantized.cuh | 2 +- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 19 +- .../spatial/knn/detail/ivf_flat_build.cuh | 3 +- .../spatial/knn/detail/ivf_flat_search.cuh | 2 +- .../raft/spatial/knn/detail/ivf_pq_build.cuh | 4 +- .../raft/spatial/knn/detail/ivf_pq_search.cuh | 4 +- .../raft/spatial/knn/detail/processing.cuh | 29 +-- .../stats/detail/batched/silhouette_score.cuh | 25 +- cpp/include/raft/stats/detail/mean_center.cuh | 4 +- .../raft/stats/detail/silhouette_score.cuh | 20 +- .../raft/stats/detail/weighted_mean.cuh | 5 +- cpp/include/raft/util/cuda_utils.cuh | 236 +++++++++++++----- cpp/include/raft/util/scatter.cuh | 4 +- cpp/src/distance/cluster_cost.cuh | 5 +- cpp/test/cluster/kmeans.cu | 2 +- cpp/test/distance/distance_base.cuh | 2 +- cpp/test/linalg/binary_op.cu | 4 +- cpp/test/linalg/coalesced_reduction.cu | 14 +- cpp/test/linalg/map_then_reduce.cu | 9 +- cpp/test/linalg/matrix_vector.cu | 7 +- cpp/test/linalg/norm.cu | 23 +- cpp/test/linalg/normalize.cu | 8 +- cpp/test/linalg/reduce.cu | 12 +- cpp/test/linalg/reduce.cuh | 36 +-- cpp/test/linalg/strided_reduction.cu | 15 +- cpp/test/matrix/linewise_op.cu | 4 +- cpp/test/sparse/dist_coo_spmv.cu | 2 +- 62 files changed, 506 insertions(+), 426 deletions(-) diff --git a/cpp/bench/linalg/norm.cu b/cpp/bench/linalg/norm.cu index cce4195cf1..efecee88c9 100644 --- a/cpp/bench/linalg/norm.cu +++ b/cpp/bench/linalg/norm.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -60,7 +61,7 @@ struct rowNorm : public fixture { output_view, raft::linalg::L2Norm, raft::linalg::Apply::ALONG_ROWS, - raft::SqrtOp()); + raft::sqrt_op()); }); } diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 4a74ea5801..d021a902a6 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -202,7 +202,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, pwd.extent(0), true, true, - raft::Min{}, + raft::min_op{}, stream); // Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using @@ -330,7 +330,7 @@ void update_centroids(const raft::handle_t& handle, new_centroids.extent(0), true, false, - raft::DivideCheckZero{}, + raft::div_checkzero_op{}, handle.get_stream()); // copy centroids[i] to 
new_centroids[i] when weight_per_cluster[i] is 0 @@ -347,7 +347,7 @@ void update_centroids(const raft::handle_t& handle, // copy when the sum of weights in the cluster is 0 return map.value == 0; }, - raft::KeyOp{}, + raft::key_op{}, handle.get_stream()); } @@ -461,7 +461,7 @@ void kmeans_fit_main(const raft::handle_t& handle, auto sqrdNorm = raft::make_device_scalar(handle, DataT(0)); raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), newCentroids.size(), - raft::SqDiff{}, + raft::sqdiff_op{}, stream, centroids.data_handle(), newCentroids.data_handle()); @@ -479,8 +479,8 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - raft::ValueOp{}, - raft::Sum{}); + raft::value_op{}, + raft::add_op{}); DataT curClusteringCost = clusterCostD.value(stream); @@ -536,8 +536,8 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - raft::ValueOp{}, - raft::Sum{}); + raft::value_op{}, + raft::add_op{}); inertia[0] = clusterCostD.value(stream); @@ -651,8 +651,8 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, minClusterDistanceVec.view(), workspace, raft::make_device_scalar_view(clusterCost.data()), - raft::Nop{}, - raft::Sum{}); + raft::identity_op{}, + raft::add_op{}); auto psi = clusterCost.value(stream); @@ -684,8 +684,8 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, minClusterDistanceVec.view(), workspace, raft::make_device_scalar_view(clusterCost.data()), - raft::Nop{}, - raft::Sum{}); + raft::identity_op{}, + raft::add_op{}); psi = clusterCost.value(stream); @@ -1072,14 +1072,14 @@ void kmeans_predict(handle_t const& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - raft::ValueOp{}, - raft::Sum{}); + raft::value_op{}, + raft::add_op{}); thrust::transform(handle.get_thrust_policy(), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), labels.data_handle(), - raft::KeyOp{}); + raft::key_op{}); inertia[0] = clusterCostD.value(stream); } diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index ab592e584e..77c772e2ce 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -156,8 +156,11 @@ void checkWeight(const raft::handle_t& handle, n_samples); auto scale = static_cast(n_samples) / wt_sum; - raft::linalg::unaryOp( - weight.data_handle(), weight.data_handle(), n_samples, raft::ScalarMul{scale}, stream); + raft::linalg::unaryOp(weight.data_handle(), + weight.data_handle(), + n_samples, + raft::scalar_mul_op{scale}, + stream); } } @@ -272,7 +275,7 @@ void sampleCentroids(const raft::handle_t& handle, sampledMinClusterDistance.data_handle(), nPtsSampledInRank, inRankCp.data(), - raft::KeyOp{}, + raft::key_op{}, stream); } @@ -467,8 +470,8 @@ void minClusterAndDistanceCompute( pair.value = val; return pair; }, - raft::ArgMin{}, - raft::Nop, IndexT>{}); + raft::argmin_op{}, + raft::identity_op{}); } } } @@ -584,9 +587,9 @@ void minClusterDistanceCompute(const raft::handle_t& handle, std::numeric_limits::max(), stream, true, - raft::Nop{}, - raft::Min{}, - raft::Nop{}); + raft::identity_op{}, + raft::min_op{}, + raft::identity_op{}); } } } diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index ac51164375..994d4992b3 
100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -20,6 +20,7 @@ #include #include #include +#include namespace raft::cluster::kmeans { @@ -314,7 +315,7 @@ void cluster_cost(const raft::handle_t& handle, ReductionOpT reduction_op) { detail::computeClusterCost( - handle, minClusterDistance, workspace, clusterCost, raft::Nop{}, reduction_op); + handle, minClusterDistance, workspace, clusterCost, raft::identity_op{}, reduction_op); } /** diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 6be994b80a..4693d742a1 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -73,7 +73,7 @@ static void canberraImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); + const auto diff = raft::myAbs(x - y); const auto add = raft::myAbs(x) + raft::myAbs(y); // deal with potential for 0 in denominator by // forcing 1/0 instead diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index 1ac10f269e..c2312824c1 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -72,7 +72,7 @@ static void chebyshevImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); + const auto diff = raft::myAbs(x - y); acc = raft::myMax(acc, diff); }; diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index 2b77d280fe..9bdbbf112c 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -262,8 +262,8 @@ void correlationImpl(int m, true, stream, false, - raft::Nop(), - raft::Sum()); + raft::identity_op(), + raft::add_op()); raft::linalg::reduce(norm_row_vec, pB, k, @@ -273,8 +273,8 @@ void correlationImpl(int m, true, stream, false, - raft::Nop(), - raft::Sum()); + raft::identity_op(), + raft::add_op()); sq_norm_col_vec += (m + n); sq_norm_row_vec = sq_norm_col_vec + m; @@ -290,8 +290,8 @@ void correlationImpl(int m, true, stream, false, - raft::Nop(), - raft::Sum()); + raft::identity_op(), + raft::add_op()); sq_norm_col_vec += m; sq_norm_row_vec = sq_norm_col_vec; raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream); diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index ed843bb0c7..46a694aa51 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -19,6 +19,7 @@ #include #include #include +#include namespace raft { namespace distance { @@ -247,12 +248,12 @@ void cosineAlgo1(Index_ m, if (pA != pB) { row_vec += m; raft::linalg::rowNorm( - col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::SqrtOp{}); + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); raft::linalg::rowNorm( - row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::SqrtOp{}); + row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); } else { raft::linalg::rowNorm( - col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::SqrtOp{}); + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); } if (isRowMajor) { 
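For context on the rowNorm calls above: the norm primitives compose a per-element main op, a pairwise reduce op and a final op, and raft::sqrt_op is only the final stage (the L2 path uses raft::sq_op and raft::add_op internally, as the detail/norm.cuh hunks further down show). A small host-side sketch of that composition, with hypothetical helper names that only mirror the RAFT pattern:

#include <cmath>
#include <cstdio>

// Sketches of the new type-polymorphic ops (illustrative, not the real raft headers).
struct sq_op {
  template <typename T>
  constexpr auto operator()(const T& x) const { return x * x; }
};
struct add_op {
  template <typename T1, typename T2>
  constexpr auto operator()(const T1& a, const T2& b) const { return a + b; }
};
struct sqrt_op {
  template <typename T>
  auto operator()(const T& x) const { return std::sqrt(x); }
};

// One row of a rowNorm-style reduction: final_op(reduce_op(..., main_op(x_i))).
template <typename T, typename MainOp, typename ReduceOp, typename FinalOp>
T row_reduce(const T* row, int d, T init, MainOp main_op, ReduceOp reduce_op, FinalOp final_op) {
  T acc = init;
  for (int i = 0; i < d; ++i) acc = reduce_op(acc, main_op(row[i]));
  return final_op(acc);
}

int main() {
  const float row[3] = {3.0f, 4.0f, 12.0f};
  // L2 norm: sqrt(9 + 16 + 144) = 13
  std::printf("%g\n", row_reduce(row, 3, 0.0f, sq_op{}, add_op{}, sqrt_op{}));
  return 0;
}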
diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index e8348b3cab..4184810fff 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -19,6 +19,7 @@ #include #include #include +#include namespace raft { namespace distance { @@ -265,12 +266,12 @@ void euclideanAlgo1(Index_ m, if (pA != pB) { row_vec += m; raft::linalg::rowNorm( - col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); raft::linalg::rowNorm( - row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); + row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } else { raft::linalg::rowNorm( - col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } if (isRowMajor) { diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index b8778c691e..c9750df8ad 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -308,7 +308,7 @@ void fusedL2NNImpl(OutT* min, ReduceOpT, KVPReduceOpT, decltype(core_lambda), - raft::Nop>; + raft::identity_op>; dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, fusedL2NNSqrt); fusedL2NNSqrt<<>>(min, @@ -324,7 +324,7 @@ void fusedL2NNImpl(OutT* min, redOp, pairRedOp, core_lambda, - raft::Nop{}); + raft::identity_op{}); } else { auto fusedL2NN = fusedL2NNkernel>; + raft::identity_op>; dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, fusedL2NN); fusedL2NN<<>>(min, x, @@ -349,7 +349,7 @@ void fusedL2NNImpl(OutT* min, redOp, pairRedOp, core_lambda, - raft::Nop{}); + raft::identity_op{}); } RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index f39fd6158f..51f462ab36 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -17,6 +17,7 @@ #pragma once #include #include +#include namespace raft { namespace distance { @@ -79,11 +80,9 @@ static void hellingerImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); // First sqrt x and y - raft::linalg::unaryOp, IdxT>( - (DataT*)x, x, m * k, raft::SqrtOp{}, stream); + raft::linalg::unaryOp((DataT*)x, x, m * k, raft::sqrt_op{}, stream); if (x != y) { - raft::linalg::unaryOp, IdxT>( - (DataT*)y, y, n * k, raft::SqrtOp{}, stream); + raft::linalg::unaryOp((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } // Accumulation operation lambda @@ -143,11 +142,9 @@ static void hellingerImpl(const DataT* x, } // Revert sqrt of x and y - raft::linalg::unaryOp, IdxT>( - (DataT*)x, x, m * k, raft::SqrtOp{}, stream); + raft::linalg::unaryOp((DataT*)x, x, m * k, raft::sqrt_op{}, stream); if (x != y) { - raft::linalg::unaryOp, IdxT>( - (DataT*)y, y, n * k, raft::SqrtOp{}, stream); + raft::linalg::unaryOp((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 6372019fd3..8de0035dbe 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -71,7 +71,7 @@ static void l1Impl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); + const auto diff = raft::myAbs(x - y); acc += diff; }; diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh index d3d0979d0d..71323fc61d 100644 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -74,7 +74,7 @@ void minkowskiUnExpImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [p] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); + const auto diff = raft::myAbs(x - y); acc += raft::myPow(diff, p); }; diff --git a/cpp/include/raft/label/detail/classlabels.cuh b/cpp/include/raft/label/detail/classlabels.cuh index a3a98d3124..cef249e6e1 100644 --- a/cpp/include/raft/label/detail/classlabels.cuh +++ b/cpp/include/raft/label/detail/classlabels.cuh @@ -194,7 +194,7 @@ void make_monotonic( template void make_monotonic(Type* out, Type* in, size_t N, cudaStream_t stream, bool zero_based = false) { - make_monotonic(out, in, N, stream, raft::ConstOp(false), zero_based); + make_monotonic(out, in, N, stream, raft::const_op(false), zero_based); } }; // namespace detail diff --git a/cpp/include/raft/linalg/add.cuh b/cpp/include/raft/linalg/add.cuh index e54eaedec6..ec16db5251 100644 --- a/cpp/include/raft/linalg/add.cuh +++ b/cpp/include/raft/linalg/add.cuh @@ -27,7 +27,6 @@ #include #include -#include #include namespace raft { diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index e9e5a99f46..b9c20c7e4c 100644 --- a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ 
-22,6 +22,7 @@ #include #include +#include namespace raft { namespace linalg { @@ -56,9 +57,9 @@ namespace linalg { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReduction(OutType* dots, const InType* data, IdxType D, @@ -66,9 +67,9 @@ void coalescedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { detail::coalescedReduction( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); @@ -113,17 +114,17 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalesced_reduction(const raft::handle_t& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutValueType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if constexpr (std::is_same_v) { RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), diff --git a/cpp/include/raft/linalg/detail/add.cuh b/cpp/include/raft/linalg/detail/add.cuh index 81489ed287..b1b6922809 100644 --- a/cpp/include/raft/linalg/detail/add.cuh +++ b/cpp/include/raft/linalg/detail/add.cuh @@ -27,13 +27,13 @@ namespace detail { template void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::ScalarAdd(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::scalar_add_op(scalar), stream); } template void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, raft::Sum(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::add_op(), stream); } template diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh index 63351f5475..4dc3d5bd8c 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh @@ -71,9 +71,9 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThin(OutType* dots, const InType* data, IdxType D, @@ -81,9 +81,9 @@ void coalescedReductionThin(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope( "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); @@ -97,9 +97,9 @@ void coalescedReductionThin(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + 
typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThinDispatcher(OutType* dots, const InType* data, IdxType D, @@ -107,9 +107,9 @@ void coalescedReductionThinDispatcher(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if (D <= IdxType(2)) { coalescedReductionThin>( @@ -168,9 +168,9 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionMedium(OutType* dots, const InType* data, IdxType D, @@ -178,9 +178,9 @@ void coalescedReductionMedium(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope("coalescedReductionMedium<%d>", TPB); coalescedReductionMediumKernel @@ -191,9 +191,9 @@ void coalescedReductionMedium(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionMediumDispatcher(OutType* dots, const InType* data, IdxType D, @@ -201,9 +201,9 @@ void coalescedReductionMediumDispatcher(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { // Note: for now, this kernel is only used when D > 256. If this changes in the future, use // smaller block sizes when relevant. 
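The defaults above can be spelled raft::identity_op / raft::add_op with no template arguments because the new functors template the call operator rather than the class, so argument types are deduced at each call site. A minimal sketch of the two styles (illustrative only; the real definitions live in raft/util/cuda_utils.cuh per the diffstat and also accept extra unused arguments, such as an element index, that this sketch omits):

#include <cstdio>

// Old style: the class carries the type, so every caller must spell it out.
template <typename T>
struct Nop {
  constexpr T operator()(T in) const { return in; }
};

// New style: the call operator is templated; one stateless object works for any
// argument type and the compiler deduces T at each call site.
struct identity_op {
  template <typename T>
  constexpr T operator()(const T& in) const { return in; }
};

int main() {
  Nop<float> old_op;   // old: explicit instantiation per type
  identity_op new_op;  // new: no type parameter needed
  std::printf("%g %d\n", old_op(2.5f), new_op(42));
  return 0;
}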
@@ -251,9 +251,9 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThick(OutType* dots, const InType* data, IdxType D, @@ -261,9 +261,9 @@ void coalescedReductionThick(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope( "coalescedReductionThick<%d,%d>", ThickPolicy::ThreadsPerBlock, ThickPolicy::BlocksPerRow); @@ -291,7 +291,7 @@ void coalescedReductionThick(OutType* dots, init, stream, inplace, - raft::Nop(), + raft::identity_op(), reduce_op, final_op); } @@ -299,9 +299,9 @@ void coalescedReductionThick(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThickDispatcher(OutType* dots, const InType* data, IdxType D, @@ -309,9 +309,9 @@ void coalescedReductionThickDispatcher(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { // Note: multiple elements per thread to take advantage of the sequential reduction and loop // unrolling @@ -330,9 +330,9 @@ void coalescedReductionThickDispatcher(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReduction(OutType* dots, const InType* data, IdxType D, @@ -340,9 +340,9 @@ void coalescedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { /* The primitive selects one of three implementations based on heuristics: * - Thin: very efficient when D is small and/or N is large diff --git a/cpp/include/raft/linalg/detail/divide.cuh b/cpp/include/raft/linalg/detail/divide.cuh index c699deca78..29388a2035 100644 --- a/cpp/include/raft/linalg/detail/divide.cuh +++ b/cpp/include/raft/linalg/detail/divide.cuh @@ -27,7 +27,7 @@ namespace detail { template void divideScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::ScalarDiv(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::scalar_div_op(scalar), stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/eltwise.cuh b/cpp/include/raft/linalg/detail/eltwise.cuh index 93b19c37bb..f744876820 100644 --- a/cpp/include/raft/linalg/detail/eltwise.cuh +++ b/cpp/include/raft/linalg/detail/eltwise.cuh @@ -27,48 +27,48 @@ namespace detail { template void scalarAdd(OutType* out, const InType* in, InType scalar, IdxType len, 
cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::ScalarAdd(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::scalar_add_op(scalar), stream); } template void scalarMultiply(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::ScalarMul(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::scalar_mul_op(scalar), stream); } template void eltwiseAdd( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, raft::Sum(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::add_op(), stream); } template void eltwiseSub( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, raft::Subtract(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::sub_op(), stream); } template void eltwiseMultiply( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, raft::Multiply(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::mul_op(), stream); } template void eltwiseDivide( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, raft::Divide(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::div_op(), stream); } template void eltwiseDivideCheckZero( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, raft::ScalarDivCheckZero(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::div_checkzero_op(), stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/multiply.cuh b/cpp/include/raft/linalg/detail/multiply.cuh index 7c492b839f..8641ccd154 100644 --- a/cpp/include/raft/linalg/detail/multiply.cuh +++ b/cpp/include/raft/linalg/detail/multiply.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include namespace raft { namespace linalg { @@ -26,7 +27,7 @@ template void multiplyScalar( math_t* out, const math_t* in, const math_t scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::ScalarMul{scalar}, stream); + raft::linalg::unaryOp(out, in, len, raft::scalar_mul_op{scalar}, stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/norm.cuh b/cpp/include/raft/linalg/detail/norm.cuh index f2f08233d5..2927a1bcfe 100644 --- a/cpp/include/raft/linalg/detail/norm.cuh +++ b/cpp/include/raft/linalg/detail/norm.cuh @@ -18,6 +18,7 @@ #include #include +#include namespace raft { namespace linalg { @@ -44,8 +45,8 @@ void rowNormCaller(Type* dots, true, stream, false, - raft::L1Op(), - raft::Sum(), + raft::abs_op(), + raft::add_op(), fin_op); break; case L2Norm: @@ -58,8 +59,8 @@ void rowNormCaller(Type* dots, true, stream, false, - raft::L2Op(), - raft::Sum(), + raft::sq_op(), + raft::add_op(), fin_op); break; case LinfNorm: @@ -72,8 +73,8 @@ void rowNormCaller(Type* dots, true, stream, false, - raft::L1Op(), - raft::Max(), + raft::abs_op(), + raft::max_op(), fin_op); break; default: THROW("Unsupported norm type: %d", type); @@ -101,8 +102,8 @@ void colNormCaller(Type* dots, false, stream, false, - raft::L1Op(), - raft::Sum(), + raft::abs_op(), + raft::add_op(), fin_op); break; case L2Norm: @@ -115,8 +116,8 @@ void colNormCaller(Type* dots, false, stream, false, - raft::L2Op(), - raft::Sum(), + raft::sq_op(), 
+ raft::add_op(), fin_op); break; case LinfNorm: @@ -129,8 +130,8 @@ void colNormCaller(Type* dots, false, stream, false, - raft::L1Op(), - raft::Max(), + raft::abs_op(), + raft::max_op(), fin_op); break; default: THROW("Unsupported norm type: %d", type); diff --git a/cpp/include/raft/linalg/detail/reduce.cuh b/cpp/include/raft/linalg/detail/reduce.cuh index 3022973b43..b359019bef 100644 --- a/cpp/include/raft/linalg/detail/reduce.cuh +++ b/cpp/include/raft/linalg/detail/reduce.cuh @@ -27,9 +27,9 @@ namespace detail { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void reduce(OutType* dots, const InType* data, IdxType D, @@ -39,9 +39,9 @@ void reduce(OutType* dots, bool alongRows, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if (rowMajor && alongRows) { raft::linalg::coalescedReduction( diff --git a/cpp/include/raft/linalg/detail/strided_reduction.cuh b/cpp/include/raft/linalg/detail/strided_reduction.cuh index d9b14a8155..aa4517b2b5 100644 --- a/cpp/include/raft/linalg/detail/strided_reduction.cuh +++ b/cpp/include/raft/linalg/detail/strided_reduction.cuh @@ -107,9 +107,9 @@ __global__ void stridedReductionKernel(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void stridedReduction(OutType* dots, const InType* data, IdxType D, @@ -117,13 +117,13 @@ void stridedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { ///@todo: this extra should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) - if (!inplace) raft::linalg::unaryOp(dots, dots, D, raft::ConstOp(init), stream); + if (!inplace) raft::linalg::unaryOp(dots, dots, D, raft::const_op(init), stream); // Arbitrary numbers for now, probably need to tune const dim3 thrds(32, 16); @@ -135,7 +135,7 @@ void stridedReduction(OutType* dots, ///@todo: this complication should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) - if constexpr (std::is_same>::value && + if constexpr (std::is_same::value && std::is_same::value) stridedSummationKernel <<>>(dots, data, D, N, init, main_op); @@ -146,7 +146,7 @@ void stridedReduction(OutType* dots, ///@todo: this complication should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) // Perform final op on output data - if (!std::is_same>::value) + if (!std::is_same::value) raft::linalg::unaryOp(dots, dots, D, final_op, stream); } diff --git a/cpp/include/raft/linalg/detail/subtract.cuh b/cpp/include/raft/linalg/detail/subtract.cuh index e28f7c8e4c..3eebd1a55f 100644 --- a/cpp/include/raft/linalg/detail/subtract.cuh +++ b/cpp/include/raft/linalg/detail/subtract.cuh @@ 
-27,13 +27,13 @@ namespace detail { template void subtractScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::ScalarSub(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::scalar_sub_op(scalar), stream); } template void subtract(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, raft::Subtract(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::sub_op(), stream); } template diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index 9abfd3bdb0..91e47ece6f 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -23,6 +23,7 @@ #include #include +#include #include namespace raft { @@ -47,7 +48,7 @@ namespace linalg { * @param stream cuda stream where to launch work * @param fin_op the final lambda op */ -template > +template void rowNorm(Type* dots, const Type* data, IdxType D, @@ -55,7 +56,7 @@ void rowNorm(Type* dots, NormType type, bool rowMajor, cudaStream_t stream, - Lambda fin_op = raft::Nop()) + Lambda fin_op = raft::identity_op()) { detail::rowNormCaller(dots, data, D, N, type, rowMajor, stream, fin_op); } @@ -74,7 +75,7 @@ void rowNorm(Type* dots, * @param stream cuda stream where to launch work * @param fin_op the final lambda op */ -template > +template void colNorm(Type* dots, const Type* data, IdxType D, @@ -82,7 +83,7 @@ void colNorm(Type* dots, NormType type, bool rowMajor, cudaStream_t stream, - Lambda fin_op = raft::Nop()) + Lambda fin_op = raft::identity_op()) { detail::colNormCaller(dots, data, D, N, type, rowMajor, stream, fin_op); } @@ -104,13 +105,13 @@ void colNorm(Type* dots, template > + typename Lambda = raft::identity_op> void norm(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_vector_view out, NormType type, Apply apply, - Lambda fin_op = raft::Nop()) + Lambda fin_op = raft::identity_op()) { RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh index 4bdf697581..fba8a0f09c 100644 --- a/cpp/include/raft/linalg/normalize.cuh +++ b/cpp/include/raft/linalg/normalize.cuh @@ -19,6 +19,7 @@ #include "detail/normalize.cuh" #include +#include namespace raft { namespace linalg { @@ -94,34 +95,16 @@ void row_normalize(const raft::handle_t& handle, { switch (norm_type) { case L1Norm: - row_normalize(handle, - in, - out, - ElementType(0), - raft::L1Op(), - raft::Sum(), - raft::Nop(), - eps); + row_normalize( + handle, in, out, ElementType(0), raft::abs_op(), raft::add_op(), raft::identity_op(), eps); break; case L2Norm: - row_normalize(handle, - in, - out, - ElementType(0), - raft::L2Op(), - raft::Sum(), - raft::SqrtOp(), - eps); + row_normalize( + handle, in, out, ElementType(0), raft::sq_op(), raft::add_op(), raft::sqrt_op(), eps); break; case LinfNorm: - row_normalize(handle, - in, - out, - ElementType(0), - raft::L1Op(), - raft::Max(), - raft::Nop(), - eps); + row_normalize( + handle, in, out, ElementType(0), raft::abs_op(), raft::max_op(), raft::identity_op(), eps); break; default: THROW("Unsupported norm type: %d", norm_type); } diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index d5d898d768..ef8fea20ee 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -41,7 +41,7 @@ namespace linalg { template void powerScalar(out_t* out, const in_t* in, 
const in_t scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::ScalarPow(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::scalar_pow_op(scalar), stream); } /** @} */ @@ -60,7 +60,7 @@ void powerScalar(out_t* out, const in_t* in, const in_t scalar, IdxType len, cud template void power(out_t* out, const in_t* in1, const in_t* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, raft::Pow(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::pow_op(), stream); } /** @} */ diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index 5579acf355..45660e5ed8 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -22,6 +22,7 @@ #include "linalg_types.hpp" #include +#include #include namespace raft { @@ -59,9 +60,9 @@ namespace linalg { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void reduce(OutType* dots, const InType* data, IdxType D, @@ -71,9 +72,9 @@ void reduce(OutType* dots, bool alongRows, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { detail::reduce( dots, data, D, N, init, rowMajor, alongRows, stream, inplace, main_op, reduce_op, final_op); @@ -118,18 +119,18 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void reduce(const raft::handle_t& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutElementType init, Apply apply, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { RAFT_EXPECTS(raft::is_row_or_column_major(data), "Input must be contiguous"); diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index 6d0a5d58f3..cc9bfc69ef 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -38,7 +38,7 @@ namespace linalg { template void sqrt(out_t* out, const in_t* in, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::SqrtOp{}, stream); + raft::linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream); } /** @} */ diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index 0aa4aecef5..4f44787bbb 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -23,6 +23,7 @@ #include #include +#include #include @@ -59,9 +60,9 @@ namespace linalg { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void stridedReduction(OutType* dots, const InType* data, IdxType D, @@ -69,9 +70,9 @@ void stridedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = 
raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { // Only compile for types supported by myAtomicReduce, but don't make the compilation fail in // other cases, because coalescedReduction supports arbitrary types. @@ -124,17 +125,17 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void strided_reduction(const raft::handle_t& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutValueType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if constexpr (std::is_same_v) { RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index bbd24f0353..a2e1562af8 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -185,7 +185,7 @@ void gather(const MatrixIteratorT in, { typedef typename std::iterator_traits::value_type MapValueT; gatherImpl( - in, D, N, map, map, map_length, out, raft::ConstOp(true), raft::Nop(), stream); + in, D, N, map, map, map_length, out, raft::const_op(true), raft::identity_op(), stream); } /** @@ -220,7 +220,7 @@ void gather(const MatrixIteratorT in, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; - gatherImpl(in, D, N, map, map, map_length, out, raft::ConstOp(true), transform_op, stream); + gatherImpl(in, D, N, map, map, map_length, out, raft::const_op(true), transform_op, stream); } /** @@ -262,7 +262,7 @@ void gather_if(const MatrixIteratorT in, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; - gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, raft::Nop(), stream); + gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, raft::identity_op(), stream); } /** diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index dec3d17b96..6e80ff6880 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -188,7 +188,7 @@ void reciprocal(math_t* in, math_t* out, IdxType len, cudaStream_t stream) template void setValue(math_t* out, const math_t* in, math_t scalar, int len, cudaStream_t stream = 0) { - raft::linalg::unaryOp(out, in, len, raft::ConstOp(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::const_op(scalar), stream); } template @@ -200,7 +200,7 @@ void ratio( rmm::device_scalar d_sum(stream); auto* d_sum_ptr = d_sum.data(); - raft::linalg::mapThenSumReduce(d_sum_ptr, len, raft::Nop{}, stream, src); + raft::linalg::mapThenSumReduce(d_sum_ptr, len, raft::identity_op{}, stream, src); raft::linalg::unaryOp( d_dest, d_src, len, [=] __device__(math_t a) { return a / (*d_sum_ptr); }, stream); } @@ -215,7 +215,7 @@ void matrixVectorBinaryMult(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::Multiply(), stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::mul_op(), stream); } 
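All four matrixVectorBinary* helpers in this file now differ only in the functor they forward to matrixVectorOp. A host-side sketch of the underlying broadcast pattern (hypothetical names; shown for the row-major, bcastAlongRows configuration):

#include <cstdio>

struct mul_op {
  template <typename T1, typename T2>
  constexpr auto operator()(const T1& a, const T2& b) const { return a * b; }
};

// op(mat[i][j], vec[j]) in place: the element/vector pairing matrixVectorOp
// performs on device for this configuration.
template <typename T, typename BinOp>
void matrix_vector_broadcast(T* mat, const T* vec, int n_row, int n_col, BinOp op) {
  for (int i = 0; i < n_row; ++i)
    for (int j = 0; j < n_col; ++j)
      mat[i * n_col + j] = op(mat[i * n_col + j], vec[j]);
}

int main() {
  float mat[2 * 3] = {1, 2, 3, 4, 5, 6};
  const float vec[3] = {10, 100, 1000};
  matrix_vector_broadcast(mat, vec, 2, 3, mul_op{});  // mul_op -> matrixVectorBinaryMult
  for (float v : mat) std::printf("%g ", v);          // 10 200 3000 40 500 6000
  std::printf("\n");
  return 0;
}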
template @@ -254,7 +254,7 @@ void matrixVectorBinaryDiv(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::Divide(), stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::div_op(), stream); } template @@ -312,7 +312,7 @@ void matrixVectorBinaryAdd(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::Sum(), stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::add_op(), stream); } template @@ -325,7 +325,7 @@ void matrixVectorBinarySub(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::Subtract(), stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::sub_op(), stream); } // Computes an argmin/argmax column-wise in a DxN matrix diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index 3c5ee62e5e..c0a2511db9 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -197,7 +197,7 @@ class lp_unexpanded_distances_t : public distances_t { raft::linalg::unaryOp(out_dists, out_dists, config_->a_nrows * config_->b_nrows, - raft::ScalarPow(one_over_p), + raft::scalar_pow_op(one_over_p), config_->handle.get_stream()); } @@ -222,7 +222,7 @@ class hamming_unexpanded_distances_t : public distances_t { raft::linalg::unaryOp(out_dists, out_dists, config_->a_nrows * config_->b_nrows, - raft::ScalarMul(n_cols), + raft::scalar_mul_op(n_cols), config_->handle.get_stream()); } @@ -303,7 +303,7 @@ class kl_divergence_unexpanded_distances_t : public distances_t { raft::linalg::unaryOp(out_dists, out_dists, config_->a_nrows * config_->b_nrows, - raft::ScalarMul(0.5), + raft::scalar_mul_op(0.5), config_->handle.get_stream()); } diff --git a/cpp/include/raft/sparse/op/detail/slice.cuh b/cpp/include/raft/sparse/op/detail/slice.cuh index d402a9e1e3..ddf002bc0e 100644 --- a/cpp/include/raft/sparse/op/detail/slice.cuh +++ b/cpp/include/raft/sparse/op/detail/slice.cuh @@ -73,7 +73,7 @@ void csr_row_slice_indptr(value_idx start_row, raft::linalg::unaryOp(indptr_out, indptr_out, (stop_row + 2) - start_row, - raft::ScalarSub(s_offset), + raft::scalar_sub_op(s_offset), stream); } diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index c3f7015bab..57ca2625bc 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -202,7 +202,7 @@ void approx_knn_search(const handle_t& handle, float p = 0.5; // standard l2 if (index->metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / index->metricArg; raft::linalg::unaryOp( - distances, distances, n * k, raft::ScalarPow(p), handle.get_stream()); + distances, distances, n * k, raft::scalar_pow_op(p), handle.get_stream()); } if constexpr (std::is_same_v) { index->metric_processor->postprocess(distances); } } diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 795901903e..3765f1a719 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -22,6 +22,7 @@ #include "processing.cuh" #include #include +#include namespace raft { namespace spatial { @@ -576,7 
+577,7 @@ void fusedL2UnexpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - raft::Nop, + raft::identity_op, 32, 2, usePrevTopKs, @@ -588,7 +589,7 @@ void fusedL2UnexpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - raft::Nop, + raft::identity_op, 64, 3, usePrevTopKs, @@ -628,7 +629,7 @@ void fusedL2UnexpKnnImpl(const DataT* x, ldb, ldd, core_lambda, - raft::Nop{}, + raft::identity_op{}, sqrt, (uint32_t)numOfNN, (int*)workspace, @@ -765,7 +766,7 @@ void fusedL2ExpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - raft::Nop, + raft::identity_op, 32, 2, usePrevTopKs, @@ -777,7 +778,7 @@ void fusedL2ExpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - raft::Nop, + raft::identity_op, 64, 3, usePrevTopKs, @@ -817,12 +818,12 @@ void fusedL2ExpKnnImpl(const DataT* x, if (x != y) { yn += m; raft::linalg::rowNorm( - xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); + xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); raft::linalg::rowNorm( - yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); + yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } else { raft::linalg::rowNorm( - xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop{}); + xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } fusedL2ExpKnnRowMajor<<>>(x, y, @@ -835,7 +836,7 @@ void fusedL2ExpKnnImpl(const DataT* x, ldb, ldd, core_lambda, - raft::Nop{}, + raft::identity_op{}, sqrt, (uint32_t)numOfNN, mutexes, diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index 14c4dd85f1..6740337071 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -245,7 +246,7 @@ inline auto extend(const handle_t& handle, raft::linalg::L2Norm, true, stream, - raft::SqrtOp()); + raft::sqrt_op()); RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min(dim, 20)); } } diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 94f4dc96c6..f43b15a117 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -1110,7 +1110,7 @@ void search_impl(const handle_t& handle, raft::linalg::L2Norm, true, stream, - raft::SqrtOp()); + raft::sqrt_op()); utils::outer_add(query_norm_dev.data(), (IdxT)n_queries, index.center_norms()->data_handle(), diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh index 6604c1d9ce..7a1a4492d1 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh @@ -259,7 +259,7 @@ void select_residuals(const handle_t& handle, n_rows, (IdxT)dim, dataset, row_ids, (IdxT)dim, tmp.data(), (IdxT)dim, stream); raft::matrix::linewiseOp( - tmp.data(), tmp.data(), IdxT(dim), n_rows, true, raft::Subtract{}, stream, center); + tmp.data(), tmp.data(), IdxT(dim), n_rows, true, raft::sub_op{}, stream, center); float alpha = 1.0; float beta = 0.0; @@ -1179,7 +1179,7 @@ inline auto build_device( raft::linalg::L2Norm, true, stream, - raft::SqrtOp()); + raft::sqrt_op()); RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle() + index.dim(), sizeof(float) * index.dim_ext(), 
center_norms.data(), diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index cd9ae5f712..cfedd0e606 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -419,7 +419,7 @@ void postprocess_distances(float* out, // [n_queries, topk] case distance::DistanceType::L2Unexpanded: case distance::DistanceType::L2Expanded: { linalg::unaryOp( - out, in, len, raft::ScalarMul(scaling_factor * scaling_factor), stream); + out, in, len, raft::scalar_mul_op(scaling_factor * scaling_factor), stream); } break; case distance::DistanceType::L2SqrtUnexpanded: case distance::DistanceType::L2SqrtExpanded: { @@ -432,7 +432,7 @@ void postprocess_distances(float* out, // [n_queries, topk] } break; case distance::DistanceType::InnerProduct: { linalg::unaryOp( - out, in, len, raft::ScalarMul(-scaling_factor * scaling_factor), stream); + out, in, len, raft::scalar_mul_op(-scaling_factor * scaling_factor), stream); } break; default: RAFT_FAIL("Unexpected metric."); } diff --git a/cpp/include/raft/spatial/knn/detail/processing.cuh b/cpp/include/raft/spatial/knn/detail/processing.cuh index 8b30010dc3..747a1db00d 100644 --- a/cpp/include/raft/spatial/knn/detail/processing.cuh +++ b/cpp/include/raft/spatial/knn/detail/processing.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include namespace raft { @@ -59,30 +60,16 @@ class CosineMetricProcessor : public MetricProcessor { raft::linalg::NormType::L2Norm, row_major_, stream_, - raft::SqrtOp{}); - - raft::linalg::matrixVectorOp(data, - data, - colsums_.data(), - n_cols_, - n_rows_, - row_major_, - false, - raft::Divide{}, - stream_); + raft::sqrt_op{}); + + raft::linalg::matrixVectorOp( + data, data, colsums_.data(), n_cols_, n_rows_, row_major_, false, raft::div_op{}, stream_); } void revert(math_t* data) { - raft::linalg::matrixVectorOp(data, - data, - colsums_.data(), - n_cols_, - n_rows_, - row_major_, - false, - raft::Multiply{}, - stream_); + raft::linalg::matrixVectorOp( + data, data, colsums_.data(), n_cols_, n_rows_, row_major_, false, raft::mul_op{}, stream_); } void postprocess(math_t* data) @@ -123,7 +110,7 @@ class CorrelationMetricProcessor : public CosineMetricProcessor { raft::linalg::unaryOp(means_.data(), means_.data(), cosine::n_rows_, - raft::ScalarMul(normalizer_const), + raft::scalar_mul_op(normalizer_const), cosine::stream_); raft::stats::meanCenter(data, diff --git a/cpp/include/raft/stats/detail/batched/silhouette_score.cuh b/cpp/include/raft/stats/detail/batched/silhouette_score.cuh index 25a3721af1..5bd8c4a7ab 100644 --- a/cpp/include/raft/stats/detail/batched/silhouette_score.cuh +++ b/cpp/include/raft/stats/detail/batched/silhouette_score.cuh @@ -249,19 +249,18 @@ value_t silhouette_score( // calculating row-wise minimum in b // this prim only supports int indices for now - raft::linalg:: - reduce, raft::stats::detail::MinOp>( - b_ptr, - b_ptr, - n_labels, - n_rows, - std::numeric_limits::max(), - true, - true, - stream, - false, - raft::Nop(), - raft::stats::detail::MinOp()); + raft::linalg::reduce( + b_ptr, + b_ptr, + n_labels, + n_rows, + std::numeric_limits::max(), + true, + true, + stream, + false, + raft::identity_op(), + raft::min_op()); // calculating the silhouette score per sample raft::linalg::binaryOp, value_t, value_idx>( diff --git a/cpp/include/raft/stats/detail/mean_center.cuh b/cpp/include/raft/stats/detail/mean_center.cuh index 8453744789..6e1c07e1e3 100644 --- 
a/cpp/include/raft/stats/detail/mean_center.cuh +++ b/cpp/include/raft/stats/detail/mean_center.cuh @@ -49,7 +49,7 @@ void meanCenter(Type* out, cudaStream_t stream) { raft::linalg::matrixVectorOp( - out, data, mu, D, N, rowMajor, bcastAlongRows, raft::Subtract{}, stream); + out, data, mu, D, N, rowMajor, bcastAlongRows, raft::sub_op{}, stream); } /** @@ -77,7 +77,7 @@ void meanAdd(Type* out, cudaStream_t stream) { raft::linalg::matrixVectorOp( - out, data, mu, D, N, rowMajor, bcastAlongRows, raft::Sum{}, stream); + out, data, mu, D, N, rowMajor, bcastAlongRows, raft::add_op{}, stream); } }; // end namespace detail diff --git a/cpp/include/raft/stats/detail/silhouette_score.cuh b/cpp/include/raft/stats/detail/silhouette_score.cuh index b95d3783e5..7b899cb92d 100644 --- a/cpp/include/raft/stats/detail/silhouette_score.cuh +++ b/cpp/include/raft/stats/detail/silhouette_score.cuh @@ -272,11 +272,11 @@ DataT silhouette_score( nRows, true, true, - raft::DivOp(), + raft::div_op(), stream); // calculating row-wise minimum - raft::linalg::reduce, raft::MinOp>( + raft::linalg::reduce( d_bArray.data(), averageDistanceBetweenSampleAndCluster.data(), nLabels, @@ -286,8 +286,8 @@ DataT silhouette_score( true, stream, false, - raft::Nop{}, - raft::Min{}); + raft::identity_op{}, + raft::min_op{}); // calculating the silhouette score per sample using the d_aArray and d_bArray raft::linalg::binaryOp>( @@ -297,12 +297,12 @@ DataT silhouette_score( rmm::device_scalar d_avgSilhouetteScore(stream); RAFT_CUDA_TRY(cudaMemsetAsync(d_avgSilhouetteScore.data(), 0, sizeof(DataT), stream)); - raft::linalg::mapThenSumReduce>(d_avgSilhouetteScore.data(), - nRows, - raft::Nop(), - stream, - perSampleSilScore, - perSampleSilScore); + raft::linalg::mapThenSumReduce(d_avgSilhouetteScore.data(), + nRows, + raft::identity_op(), + stream, + perSampleSilScore, + perSampleSilScore); DataT avgSilhouetteScore = d_avgSilhouetteScore.value(stream); diff --git a/cpp/include/raft/stats/detail/weighted_mean.cuh b/cpp/include/raft/stats/detail/weighted_mean.cuh index c8b59e08a2..ba7e6bf4c4 100644 --- a/cpp/include/raft/stats/detail/weighted_mean.cuh +++ b/cpp/include/raft/stats/detail/weighted_mean.cuh @@ -18,6 +18,7 @@ #include #include +#include #include namespace raft { @@ -66,8 +67,8 @@ void weightedMean(Type* mu, stream, false, [weights] __device__(Type v, IdxType i) { return v * weights[i]; }, - raft::Sum{}, - raft::ScalarDiv(WS)); + raft::add_op{}, + raft::scalar_div_op(WS)); } }; // end namespace detail }; // end namespace stats diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index f167ad76ed..b1b5853abc 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -508,16 +508,87 @@ HDI double myATanh(double x) /** @} */ /** - * @defgroup LambdaOps Commonly used lambda operations + * @defgroup LambdaOps Legacy lambda operations, to be deprecated * @{ */ -// The optional index argument is mostly to be used for MainLambda in reduction kernels template struct Nop { - HDI Type operator()(Type in, IdxType i = 0) const { return in; } + [[deprecated("Nop is deprecated. Use identity_op instead.")]] HDI Type + operator()(Type in, IdxType i = 0) const + { + return in; + } +}; + +template +struct SqrtOp { + [[deprecated("SqrtOp is deprecated. Use sqrt_op instead.")]] HDI Type + operator()(Type in, IdxType i = 0) const + { + return mySqrt(in); + } +}; + +template +struct L0Op { + [[deprecated("L0Op is deprecated. 
Use nz_op instead.")]] HDI Type operator()(Type in, + IdxType i = 0) const + { + return in != Type(0) ? Type(1) : Type(0); + } +}; + +template +struct L1Op { + [[deprecated("L1Op is deprecated. Use abs_op instead.")]] HDI Type operator()(Type in, + IdxType i = 0) const + { + return myAbs(in); + } +}; + +template +struct L2Op { + [[deprecated("L2Op is deprecated. Use sq_op instead.")]] HDI Type operator()(Type in, + IdxType i = 0) const + { + return in * in; + } }; -struct KeyOp { +template +struct Sum { + [[deprecated("Sum is deprecated. Use add_op instead.")]] HDI OutT operator()(InT a, InT b) const + { + return a + b; + } +}; + +template +struct Max { + [[deprecated("Max is deprecated. Use max_op instead.")]] HDI Type operator()(Type a, Type b) const + { + if (b > a) { return b; } + return a; + } +}; +/** @} */ + +/** + * @defgroup Functors Commonly used functors. + * The optional index argument is mostly to be used for MainLambda in reduction kernels + * @{ + */ + +struct identity_op { + template + HDI Type operator()(Type in, IdxType i = 0) const + { + return in; + } +}; + +struct key_op { template HDI typename KVP::Key operator()(const KVP& p, IdxType i = 0) const { @@ -525,7 +596,7 @@ struct KeyOp { } }; -struct ValueOp { +struct value_op { template HDI typename KVP::Value operator()(const KVP& p, IdxType i = 0) const { @@ -533,62 +604,89 @@ struct ValueOp { } }; -template -struct SqrtOp { - HDI Type operator()(Type in, IdxType i = 0) const { return mySqrt(in); } +struct sqrt_op { + template + HDI Type operator()(Type in, IdxType i = 0) const + { + return mySqrt(in); + } }; -template -struct L0Op { - HDI Type operator()(Type in, IdxType i = 0) const { return in != Type(0) ? Type(1) : Type(0); } +struct nz_op { + template + HDI Type operator()(Type in, IdxType i = 0) const + { + return in != Type(0) ? 
Type(1) : Type(0); + } }; -template -struct L1Op { - HDI Type operator()(Type in, IdxType i = 0) const { return myAbs(in); } +struct abs_op { + template + HDI Type operator()(Type in, IdxType i = 0) const + { + return myAbs(in); + } }; -template -struct L2Op { - HDI Type operator()(Type in, IdxType i = 0) const { return in * in; } +struct sq_op { + template + HDI Type operator()(Type in, IdxType i = 0) const + { + return in * in; + } }; -template -struct Sum { - HDI OutT operator()(InT a, InT b) const { return a + b; } +struct add_op { + template + HDI auto operator()(T1 a, T2 b) const + { + return a + b; + } }; -template -struct Subtract { - HDI OutT operator()(InT a, InT b) const { return a - b; } +struct sub_op { + template + HDI auto operator()(T1 a, T2 b) const + { + return a - b; + } }; -template -struct Multiply { - HDI OutT operator()(InT a, InT b) const { return a * b; } +struct mul_op { + template + HDI auto operator()(T1 a, T2 b) const + { + return a * b; + } }; -template -struct Divide { - HDI OutT operator()(InT a, InT b) const { return a / b; } +struct div_op { + template + HDI auto operator()(T1 a, T2 b) const + { + return a / b; + } }; -template -struct DivideCheckZero { - HDI OutT operator()(InT a, InT b) const +struct div_checkzero_op { + template + HDI Type operator()(Type a, Type b) const { - if (b == InT{0}) { return InT{0}; } + if (b == Type{0}) { return Type{0}; } return a / b; } }; -template -struct Pow { - HDI OutT operator()(InT a, InT b) const { return raft::myPow(a, b); } +struct pow_op { + template + HDI Type operator()(Type a, Type b) const + { + return raft::myPow(a, b); + } }; -template -struct Min { +struct min_op { + template HDI Type operator()(Type a, Type b) const { if (a > b) { return b; } @@ -596,8 +694,8 @@ struct Min { } }; -template -struct Max { +struct max_op { + template HDI Type operator()(Type a, Type b) const { if (b > a) { return b; } @@ -605,8 +703,8 @@ struct Max { } }; -template -struct SqDiff { +struct sqdiff_op { + template HDI Type operator()(Type a, Type b) const { Type diff = a - b; @@ -614,7 +712,7 @@ struct SqDiff { } }; -struct ArgMin { +struct argmin_op { template HDI KVP operator()(const KVP& a, const KVP& b) const { @@ -623,7 +721,7 @@ struct ArgMin { } }; -struct ArgMax { +struct argmax_op { template HDI KVP operator()(const KVP& a, const KVP& b) const { @@ -632,46 +730,50 @@ struct ArgMax { } }; -template -struct ConstOp { - const OutT scalar; +template +struct const_op { + const ScalarT scalar; - ConstOp(OutT s) : scalar{s} {} + const_op(ScalarT s) : scalar{s} {} template - HDI OutT operator()(InT unused) const + HDI ScalarT operator()(InT unused) const { return scalar; } }; -template -struct ScalarOp { +template +struct scalar_op { ComposedOpT composed_op; - const InT scalar; + const ScalarT scalar; - ScalarOp(InT s) : scalar{s} {} + scalar_op(ScalarT s) : scalar{s} {} - HDI OutT operator()(InT a) const { return composed_op(a, scalar); } + template + HDI auto operator()(InT a) const + { + return composed_op(a, scalar); + } }; -template -using ScalarAdd = ScalarOp, InT, OutT>; +template +using scalar_add_op = scalar_op; -template -using ScalarSub = ScalarOp, InT, OutT>; +template +using scalar_sub_op = scalar_op; -template -using ScalarMul = ScalarOp, InT, OutT>; +template +using scalar_mul_op = scalar_op; -template -using ScalarDiv = ScalarOp, InT, OutT>; +template +using scalar_div_op = scalar_op; -template -using ScalarDivCheckZero = ScalarOp, InT, OutT>; +template +using scalar_div_checkzero_op = scalar_op; -template 
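The scalar_* aliases being introduced here bind a runtime constant to one of the binary functors, yielding a unary functor that the element-wise prims accept directly, replacing the ad-hoc __device__ lambdas that previously captured the constant. A minimal sketch, assuming the unaryOp pointer API used throughout this series; the helper name halve is illustrative only:

    #include <raft/linalg/unary_op.cuh>
    #include <raft/util/cuda_utils.cuh>  // raft::scalar_mul_op, as of this patch

    inline void halve(float* out, const float* in, int len, cudaStream_t stream)
    {
      raft::scalar_mul_op<float> op(0.5f);  // op(x) == x * 0.5f
      raft::linalg::unaryOp(out, in, len, op, stream);
    }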
-using ScalarPow = ScalarOp, InT, OutT>; +template +using scalar_pow_op = scalar_op; /** @} */ /** @@ -1065,7 +1167,7 @@ DI T warpReduce(T val, ReduceLambda reduce_op) template DI T warpReduce(T val) { - return warpReduce(val, raft::Sum{}); + return warpReduce(val, raft::add_op{}); } /** diff --git a/cpp/include/raft/util/scatter.cuh b/cpp/include/raft/util/scatter.cuh index c20afa5454..58b5ce0bc1 100644 --- a/cpp/include/raft/util/scatter.cuh +++ b/cpp/include/raft/util/scatter.cuh @@ -37,13 +37,13 @@ namespace raft { * will be applied to every element before scattering it to the right location. * The second param in this method will be the destination index. */ -template , int TPB = 256> +template void scatter(DataT* out, const DataT* in, const IdxT* idx, IdxT len, cudaStream_t stream, - Lambda op = raft::Nop()) + Lambda op = raft::identity_op()) { if (len <= 0) return; constexpr size_t DataSize = sizeof(DataT); diff --git a/cpp/src/distance/cluster_cost.cuh b/cpp/src/distance/cluster_cost.cuh index 82b6c041a3..3dca92b2db 100644 --- a/cpp/src/distance/cluster_cost.cuh +++ b/cpp/src/distance/cluster_cost.cuh @@ -18,6 +18,7 @@ #include #include #include +#include namespace raft::cluster::kmeans::runtime { template @@ -63,14 +64,14 @@ void cluster_cost(const raft::handle_t& handle, min_cluster_distance.data_handle(), min_cluster_distance.data_handle() + n_samples, distances.data_handle(), - raft::ValueOp{}); + raft::value_op{}); rmm::device_scalar device_cost(0, handle.get_stream()); raft::cluster::kmeans::cluster_cost(handle, distances.view(), workspace, make_device_scalar_view(device_cost.data()), - raft::Sum{}); + raft::add_op{}); raft::update_host(cost, device_cost.data(), 1, handle.get_stream()); } diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index 20f08eedb8..406f3f4ca1 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -50,7 +50,7 @@ void run_cluster_cost(const raft::handle_t& handle, raft::device_scalar_view clusterCost) { raft::cluster::kmeans::cluster_cost( - handle, minClusterDistance, workspace, clusterCost, raft::Sum{}); + handle, minClusterDistance, workspace, clusterCost, raft::add_op{}); } template diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 19d449c18b..302a81320c 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -166,7 +166,7 @@ __global__ void naiveLpUnexpDistanceKernel(DataType* dist, int yidx = isRowMajor ? 
i + nidx * k : i * n + nidx; auto a = x[xidx]; auto b = y[yidx]; - auto diff = raft::L1Op()(a - b); + auto diff = raft::myAbs(a - b); acc += raft::myPow(diff, p); } auto one_over_p = 1 / p; diff --git a/cpp/test/linalg/binary_op.cu b/cpp/test/linalg/binary_op.cu index 848e37de42..34def617d2 100644 --- a/cpp/test/linalg/binary_op.cu +++ b/cpp/test/linalg/binary_op.cu @@ -36,7 +36,7 @@ void binaryOpLaunch( auto in1_view = raft::make_device_vector_view(in1, len); auto in2_view = raft::make_device_vector_view(in2, len); - binary_op(handle, in1_view, in2_view, out_view, raft::Sum{}); + binary_op(handle, in1_view, in2_view, out_view, raft::add_op{}); } template @@ -138,7 +138,7 @@ class BinaryOpAlignment : public ::testing::Test { RAFT_CUDA_TRY(cudaMemsetAsync(x.data(), 0, n * sizeof(math_t), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(y.data(), 0, n * sizeof(math_t), stream)); raft::linalg::binaryOp( - z.data() + 9, x.data() + 137, y.data() + 19, 256, raft::Sum{}, handle.get_stream()); + z.data() + 9, x.data() + 137, y.data() + 19, 256, raft::add_op{}, handle.get_stream()); } raft::handle_t handle; diff --git a/cpp/test/linalg/coalesced_reduction.cu b/cpp/test/linalg/coalesced_reduction.cu index c429d925ce..c2e466f223 100644 --- a/cpp/test/linalg/coalesced_reduction.cu +++ b/cpp/test/linalg/coalesced_reduction.cu @@ -47,7 +47,7 @@ void coalescedReductionLaunch( { auto dots_view = raft::make_device_vector_view(dots, rows); auto data_view = raft::make_device_matrix_view(data, rows, cols); - coalesced_reduction(handle, data_view, dots_view, (T)0, inplace, raft::L2Op{}); + coalesced_reduction(handle, data_view, dots_view, (T)0, inplace, raft::sq_op{}); } template @@ -79,9 +79,9 @@ class coalescedReductionTest : public ::testing::TestWithParam{}, - raft::Sum{}, - raft::Nop{}); + raft::sq_op{}, + raft::add_op{}, + raft::identity_op{}); naiveCoalescedReduction(dots_exp.data(), data.data(), cols, @@ -89,9 +89,9 @@ class coalescedReductionTest : public ::testing::TestWithParam{}, - raft::Sum{}, - raft::Nop{}); + raft::sq_op{}, + raft::add_op{}, + raft::identity_op{}); coalescedReductionLaunch(handle, dots_act.data(), data.data(), cols, rows); coalescedReductionLaunch(handle, dots_act.data(), data.data(), cols, rows, true); diff --git a/cpp/test/linalg/map_then_reduce.cu b/cpp/test/linalg/map_then_reduce.cu index 7907f7d1ca..9828f5b206 100644 --- a/cpp/test/linalg/map_then_reduce.cu +++ b/cpp/test/linalg/map_then_reduce.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -63,8 +64,8 @@ template void mapReduceLaunch( OutType* out_ref, OutType* out, const InType* in, size_t len, cudaStream_t stream) { - naiveMapReduce(out_ref, in, len, raft::Nop{}, stream); - mapThenSumReduce(out, len, raft::Nop{}, 0, in); + naiveMapReduce(out_ref, in, len, raft::identity_op{}, stream); + mapThenSumReduce(out, len, raft::identity_op{}, 0, in); } template @@ -153,7 +154,7 @@ class MapGenericReduceTest : public ::testing::Test { auto output_view = raft::make_device_scalar_view(output.data()); auto input_view = raft::make_device_vector_view( input.data(), static_cast(input.size())); - map_reduce(handle, input_view, output_view, neutral, raft::Nop{}, cub::Min()); + map_reduce(handle, input_view, output_view, neutral, raft::identity_op{}, cub::Min()); EXPECT_TRUE(raft::devArrMatch( OutType(1), output.data(), 1, raft::Compare(), handle.get_stream())); } @@ -163,7 +164,7 @@ class MapGenericReduceTest : public ::testing::Test { auto output_view = raft::make_device_scalar_view(output.data()); auto 
input_view = raft::make_device_vector_view( input.data(), static_cast(input.size())); - map_reduce(handle, input_view, output_view, neutral, raft::Nop{}, cub::Max()); + map_reduce(handle, input_view, output_view, neutral, raft::identity_op{}, cub::Max()); EXPECT_TRUE(raft::devArrMatch( OutType(5), output.data(), 1, raft::Compare(), handle.get_stream())); } diff --git a/cpp/test/linalg/matrix_vector.cu b/cpp/test/linalg/matrix_vector.cu index 2424f8d3aa..83d1d60ef6 100644 --- a/cpp/test/linalg/matrix_vector.cu +++ b/cpp/test/linalg/matrix_vector.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include namespace raft { @@ -124,14 +125,14 @@ void naive_matrix_vector_op_launch(const raft::handle_t& handle, naiveMatVec( in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_mult_skip_zero, stream); } else if (operation_type == 1) { - naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::Divide{}, stream); + naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::div_op{}, stream); } else if (operation_type == 2) { naiveMatVec( in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_div_skip_zero, stream); } else if (operation_type == 3) { - naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::Sum{}, stream); + naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::add_op{}, stream); } else if (operation_type == 4) { - naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::Subtract{}, stream); + naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::sub_op{}, stream); } else { THROW("Unknown operation type '%d'!", (int)operation_type); } diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index 1752c51af8..53a1656ec4 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -96,11 +97,9 @@ class RowNormTest : public ::testing::TestWithParam> { data.data(), params.rows, params.cols); if (params.do_sqrt) { if (params.rowMajor) { - norm( - handle, input_row_major, output_view, params.type, Apply::ALONG_ROWS, raft::SqrtOp{}); + norm(handle, input_row_major, output_view, params.type, Apply::ALONG_ROWS, raft::sqrt_op{}); } else { - norm( - handle, input_col_major, output_view, params.type, Apply::ALONG_ROWS, raft::SqrtOp{}); + norm(handle, input_col_major, output_view, params.type, Apply::ALONG_ROWS, raft::sqrt_op{}); } } else { if (params.rowMajor) { @@ -173,19 +172,11 @@ class ColNormTest : public ::testing::TestWithParam> { data.data(), params.rows, params.cols); if (params.do_sqrt) { if (params.rowMajor) { - norm(handle, - input_row_major, - output_view, - params.type, - Apply::ALONG_COLUMNS, - raft::SqrtOp{}); + norm( + handle, input_row_major, output_view, params.type, Apply::ALONG_COLUMNS, raft::sqrt_op{}); } else { - norm(handle, - input_col_major, - output_view, - params.type, - Apply::ALONG_COLUMNS, - raft::SqrtOp{}); + norm( + handle, input_col_major, output_view, params.type, Apply::ALONG_COLUMNS, raft::sqrt_op{}); } } else { if (params.rowMajor) { diff --git a/cpp/test/linalg/normalize.cu b/cpp/test/linalg/normalize.cu index 702adbf6d7..111253206c 100644 --- a/cpp/test/linalg/normalize.cu +++ b/cpp/test/linalg/normalize.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -48,12 +49,13 @@ void rowNormalizeRef( { rmm::device_uvector norm(rows, stream); if (norm_type == raft::linalg::L2Norm) { - raft::linalg::rowNorm(norm.data(), in, cols, rows, norm_type, true, stream, 
raft::SqrtOp()); + raft::linalg::rowNorm(norm.data(), in, cols, rows, norm_type, true, stream, raft::sqrt_op()); } else { - raft::linalg::rowNorm(norm.data(), in, cols, rows, norm_type, true, stream, raft::Nop()); + raft::linalg::rowNorm( + norm.data(), in, cols, rows, norm_type, true, stream, raft::identity_op()); } raft::linalg::matrixVectorOp( - out, in, norm.data(), cols, rows, true, false, raft::Divide{}, stream); + out, in, norm.data(), cols, rows, true, false, raft::div_op{}, stream); } template diff --git a/cpp/test/linalg/reduce.cu b/cpp/test/linalg/reduce.cu index 00f3810d28..91cc56ac35 100644 --- a/cpp/test/linalg/reduce.cu +++ b/cpp/test/linalg/reduce.cu @@ -101,9 +101,9 @@ void reduceLaunch(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::SqrtOp> + typename MainLambda = raft::sq_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::sqrt_op> class ReduceTest : public ::testing::TestWithParam> { public: ReduceTest() @@ -301,7 +301,7 @@ REDUCE_TEST((ReduceTest, ArgMaxOp, - raft::Nop, int>>), + raft::identity_op>), ReduceTestKVPISI32, inputs_kvpis_i32); REDUCE_TEST((ReduceTest, ArgMaxOp, - raft::Nop, int>>), + raft::identity_op>), ReduceTestKVPIFI32, inputs_kvpif_i32); REDUCE_TEST((ReduceTest, ArgMaxOp, - raft::Nop, int>>), + raft::identity_op>), ReduceTestKVPIDI32, inputs_kvpid_i32); diff --git a/cpp/test/linalg/reduce.cuh b/cpp/test/linalg/reduce.cuh index 0dcffd3f41..ea67b32b1a 100644 --- a/cpp/test/linalg/reduce.cuh +++ b/cpp/test/linalg/reduce.cuh @@ -61,9 +61,9 @@ __global__ void naiveCoalescedReductionKernel(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void naiveCoalescedReduction(OutType* dots, const InType* data, IdxType D, @@ -71,9 +71,9 @@ void naiveCoalescedReduction(OutType* dots, cudaStream_t stream, OutType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda fin_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda fin_op = raft::identity_op()) { static const IdxType TPB = 64; IdxType nblks = raft::ceildiv(N, TPB); @@ -115,9 +115,9 @@ __global__ void naiveStridedReductionKernel(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void naiveStridedReduction(OutType* dots, const InType* data, IdxType D, @@ -125,9 +125,9 @@ void naiveStridedReduction(OutType* dots, cudaStream_t stream, OutType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda fin_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda fin_op = raft::identity_op()) { static const IdxType TPB = 64; IdxType nblks = raft::ceildiv(D, TPB); @@ -139,9 +139,9 @@ void naiveStridedReduction(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void naiveReduction(OutType* dots, const InType* data, IdxType D, @@ -151,9 +151,9 @@ void naiveReduction(OutType* dots, cudaStream_t 
stream, OutType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda fin_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda fin_op = raft::identity_op()) { if (rowMajor && alongRows) { naiveCoalescedReduction(dots, data, D, N, stream, init, inplace, main_op, reduce_op, fin_op); diff --git a/cpp/test/linalg/strided_reduction.cu b/cpp/test/linalg/strided_reduction.cu index 77ca585ea5..50daafb64b 100644 --- a/cpp/test/linalg/strided_reduction.cu +++ b/cpp/test/linalg/strided_reduction.cu @@ -19,6 +19,7 @@ #include #include #include +#include #include namespace raft { @@ -38,7 +39,7 @@ void stridedReductionLaunch( raft::handle_t handle{stream}; auto dots_view = raft::make_device_vector_view(dots, cols); auto data_view = raft::make_device_matrix_view(data, rows, cols); - strided_reduction(handle, data_view, dots_view, (T)0, inplace, raft::L2Op{}); + strided_reduction(handle, data_view, dots_view, (T)0, inplace, raft::sq_op{}); } template @@ -70,9 +71,9 @@ class stridedReductionTest : public ::testing::TestWithParam{}, - raft::Sum{}, - raft::Nop{}); + raft::sq_op{}, + raft::add_op{}, + raft::identity_op{}); naiveStridedReduction(dots_exp.data(), data.data(), cols, @@ -80,9 +81,9 @@ class stridedReductionTest : public ::testing::TestWithParam{}, - raft::Sum{}, - raft::Nop{}); + raft::sq_op{}, + raft::add_op{}, + raft::identity_op{}); stridedReductionLaunch(dots_act.data(), data.data(), cols, rows, false, stream); stridedReductionLaunch(dots_act.data(), data.data(), cols, rows, true, stream); handle.sync_stream(stream); diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index e8151c72ee..cb9a1d0c31 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -68,7 +68,7 @@ struct LinewiseTest : public ::testing::TestWithParam(vec, lineLen); matrix::linewise_op( - handle, in_view, out_view, raft::is_row_major(in_view), raft::Sum{}, vec_view); + handle, in_view, out_view, raft::is_row_major(in_view), raft::add_op{}, vec_view); } template @@ -108,7 +108,7 @@ struct LinewiseTest : public ::testing::TestWithParam(vec, alongLines ? 
lineLen : nLines); - matrix::linewise_op(handle, in, out, alongLines, raft::Sum{}, vec_view); + matrix::linewise_op(handle, in, out, alongLines, raft::add_op{}, vec_view); } /** diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index e0a9a7c71c..a51f01ddb3 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -162,7 +162,7 @@ class SparseDistanceCOOSPMVTest raft::linalg::unaryOp(out_dists.data(), out_dists.data(), dist_config.a_nrows * dist_config.b_nrows, - raft::ScalarPow{p}, + raft::scalar_pow_op{p}, dist_config.handle.get_stream()); } break; From 406f12ec7da75b07d3c25d777fd20991cd1bd717 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Wed, 30 Nov 2022 14:42:08 +0100 Subject: [PATCH 07/22] Update copyright year --- cpp/include/raft/distance/detail/canberra.cuh | 2 +- cpp/include/raft/distance/detail/chebyshev.cuh | 2 +- cpp/include/raft/distance/detail/l1.cuh | 2 +- cpp/include/raft/distance/detail/minkowski.cuh | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 4693d742a1..90ed3940e1 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index c2312824c1..454ee8c8bb 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 8de0035dbe..95514db60b 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh index 71323fc61d..bda83babf1 100644 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
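The next patch lifts a type-deduction ambiguity in these functors: with operator() templated on both arguments and an auto return type, a mixed-type call path such as the IVF-PQ distance postprocessing (ScoreT in, float out) no longer pins down where the conversion happens. Composing an explicit cast into the functor resolves this. A minimal sketch of the composition pattern, assuming the cast_op / binary_compose_op / scalar_op definitions introduced below; the helper name make_scaled_cast_op and the exact composition are illustrative, not necessarily the ones used in ivf_pq_search.cuh:

    #include <raft/util/cuda_utils.cuh>

    // op(x) == static_cast<float>(x) * scaling_factor for any numeric x.
    inline auto make_scaled_cast_op(float scaling_factor)
    {
      using mul_with_cast =
        raft::binary_compose_op<raft::mul_op, raft::cast_op<float>, raft::identity_op>;
      return raft::scalar_op<mul_with_cast, float>(scaling_factor);
    }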
From 2701ba434f22b83dc2281ea579438c42b00798c5 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Wed, 30 Nov 2022 16:31:25 +0100 Subject: [PATCH 08/22] Lift operator ambiguity using compose ops --- .../raft/spatial/knn/detail/ivf_pq_search.cuh | 19 ++++++++-- cpp/include/raft/util/cuda_utils.cuh | 35 +++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index cfedd0e606..0013dbe896 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -419,7 +419,12 @@ void postprocess_distances(float* out, // [n_queries, topk] case distance::DistanceType::L2Unexpanded: case distance::DistanceType::L2Expanded: { linalg::unaryOp( - out, in, len, raft::scalar_mul_op(scaling_factor * scaling_factor), stream); + out, + in, + len, + raft::scalar_op>, float>( + scaling_factor * scaling_factor), + stream); } break; case distance::DistanceType::L2SqrtUnexpanded: case distance::DistanceType::L2SqrtExpanded: { @@ -427,12 +432,20 @@ void postprocess_distances(float* out, // [n_queries, topk] out, in, len, - [scaling_factor] __device__(ScoreT x) -> float { return scaling_factor * sqrtf(float(x)); }, + raft::scalar_op< + raft::binary_compose_op>>, + float>(scaling_factor), stream); } break; case distance::DistanceType::InnerProduct: { linalg::unaryOp( - out, in, len, raft::scalar_mul_op(-scaling_factor * scaling_factor), stream); + out, + in, + len, + raft::scalar_op>, float>( + -scaling_factor * scaling_factor), + stream); } break; default: RAFT_FAIL("Unexpected metric."); } diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index b1b5853abc..75e5a8dc45 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -588,6 +588,15 @@ struct identity_op { } }; +template +struct cast_op { + template + HDI OutT operator()(InT in, IdxType i = 0) const + { + return static_cast(in); + } +}; + struct key_op { template HDI typename KVP::Key operator()(const KVP& p, IdxType i = 0) const @@ -774,6 +783,32 @@ using scalar_div_checkzero_op = scalar_op; template using scalar_pow_op = scalar_op; + +template +struct unary_compose_op { + OuterOpT outer_op; + InnerOpT inner_op; + + template + HDI auto operator()(Type a) const + { + return outer_op(inner_op(a)); + } +}; + +template +struct binary_compose_op { + OuterOpT outer_op; + InnerOpT1 inner_op1; + InnerOpT2 inner_op2; + + template + HDI auto operator()(T1 a, T2 b) const + { + return outer_op(inner_op1(a), inner_op2(b)); + } +}; + /** @} */ /** From 6051c7f3ca28b7a4ebb0f60c5cdaae47e4b80536 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Thu, 1 Dec 2022 11:23:42 +0100 Subject: [PATCH 09/22] Fix conflicting template type --- .../raft/stats/detail/silhouette_score.cuh | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/stats/detail/silhouette_score.cuh b/cpp/include/raft/stats/detail/silhouette_score.cuh index 7b899cb92d..27ad8621dc 100644 --- a/cpp/include/raft/stats/detail/silhouette_score.cuh +++ b/cpp/include/raft/stats/detail/silhouette_score.cuh @@ -264,16 +264,16 @@ DataT silhouette_score( RAFT_CUDA_TRY(cudaMemsetAsync( averageDistanceBetweenSampleAndCluster.data(), 0, nRows * nLabels * sizeof(DataT), stream)); - raft::linalg::matrixVectorOp>(averageDistanceBetweenSampleAndCluster.data(), - sampleToClusterSumOfDistances.data(), - 
binCountArray.data(), - binCountArray.data(), - nLabels, - nRows, - true, - true, - raft::div_op(), - stream); + raft::linalg::matrixVectorOp(averageDistanceBetweenSampleAndCluster.data(), + sampleToClusterSumOfDistances.data(), + binCountArray.data(), + binCountArray.data(), + nLabels, + nRows, + true, + true, + raft::div_op(), + stream); // calculating row-wise minimum raft::linalg::reduce( From 062d8751266f3237572699666d8b31016b1b89b0 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Thu, 1 Dec 2022 13:33:14 +0100 Subject: [PATCH 10/22] Move new operators to their own header --- cpp/cmake/modules/ConfigureCUDA.cmake | 2 +- cpp/include/raft/cluster/detail/kmeans.cuh | 1 + .../raft/cluster/detail/kmeans_common.cuh | 1 + cpp/include/raft/cluster/kmeans.cuh | 2 +- cpp/include/raft/core/operators.hpp | 262 ++++++++++++++++++ .../detail/pairwise_distance_base.cuh | 1 + cpp/include/raft/label/detail/classlabels.cuh | 1 + .../raft/linalg/coalesced_reduction.cuh | 2 +- cpp/include/raft/linalg/detail/add.cuh | 1 + .../linalg/detail/coalesced_reduction.cuh | 1 + cpp/include/raft/linalg/detail/divide.cuh | 2 +- cpp/include/raft/linalg/detail/eltwise.cuh | 2 +- cpp/include/raft/linalg/detail/multiply.cuh | 2 +- cpp/include/raft/linalg/detail/norm.cuh | 2 +- cpp/include/raft/linalg/detail/reduce.cuh | 2 +- .../raft/linalg/detail/strided_reduction.cuh | 1 + cpp/include/raft/linalg/detail/subtract.cuh | 1 + cpp/include/raft/linalg/norm.cuh | 2 +- cpp/include/raft/linalg/normalize.cuh | 2 +- cpp/include/raft/linalg/power.cuh | 2 +- cpp/include/raft/linalg/reduce.cuh | 2 +- cpp/include/raft/linalg/sqrt.cuh | 2 +- cpp/include/raft/linalg/strided_reduction.cuh | 2 +- cpp/include/raft/matrix/detail/gather.cuh | 2 +- cpp/include/raft/matrix/detail/math.cuh | 1 + .../sparse/distance/detail/lp_distance.cuh | 1 + cpp/include/raft/sparse/op/detail/slice.cuh | 1 + .../raft/spatial/knn/detail/ann_quantized.cuh | 1 + .../raft/spatial/knn/detail/fused_l2_knn.cuh | 1 + .../spatial/knn/detail/ivf_flat_build.cuh | 2 +- .../spatial/knn/detail/ivf_flat_search.cuh | 1 + .../raft/spatial/knn/detail/ivf_pq_build.cuh | 1 + .../raft/spatial/knn/detail/ivf_pq_search.cuh | 1 + .../raft/spatial/knn/detail/processing.cuh | 2 +- .../raft/stats/detail/silhouette_score.cuh | 1 + cpp/include/raft/stats/mean_center.cuh | 1 + cpp/include/raft/stats/silhouette_score.cuh | 1 + cpp/include/raft/stats/weighted_mean.cuh | 1 + cpp/include/raft/util/cuda_utils.cuh | 238 +--------------- cpp/include/raft/util/scatter.cuh | 1 + cpp/src/distance/cluster_cost.cuh | 1 + cpp/test/cluster/kmeans.cu | 1 + cpp/test/distance/distance_base.cuh | 1 + cpp/test/linalg/binary_op.cu | 1 + cpp/test/linalg/coalesced_reduction.cu | 1 + cpp/test/linalg/map_then_reduce.cu | 1 + cpp/test/linalg/matrix_vector.cu | 1 + cpp/test/linalg/norm.cu | 1 + cpp/test/linalg/normalize.cu | 1 + cpp/test/linalg/reduce.cu | 1 + cpp/test/linalg/reduce.cuh | 1 + cpp/test/linalg/strided_reduction.cu | 1 + cpp/test/matrix/linewise_op.cu | 1 + cpp/test/sparse/dist_coo_spmv.cu | 1 + 54 files changed, 315 insertions(+), 254 deletions(-) create mode 100644 cpp/include/raft/core/operators.hpp diff --git a/cpp/cmake/modules/ConfigureCUDA.cmake b/cpp/cmake/modules/ConfigureCUDA.cmake index 5e68ca5bc4..624755a867 100644 --- a/cpp/cmake/modules/ConfigureCUDA.cmake +++ b/cpp/cmake/modules/ConfigureCUDA.cmake @@ -18,7 +18,7 @@ if(DISABLE_DEPRECATION_WARNINGS) endif() if(CMAKE_COMPILER_IS_GNUCXX) - list(APPEND RAFT_CXX_FLAGS -Wall -Werror -Wno-unknown-pragmas -Wno-error=deprecated-declarations) + 
list(APPEND RAFT_CXX_FLAGS -Wall -Werror -Wno-unknown-pragmas -Wno-error=deprecated-declarations -ftemplate-backtrace-limit=100)
 endif()
 
 list(APPEND RAFT_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr)
diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh
index d021a902a6..6208901efa 100644
--- a/cpp/include/raft/cluster/detail/kmeans.cuh
+++ b/cpp/include/raft/cluster/detail/kmeans.cuh
@@ -36,6 +36,7 @@
 #include
 #include
 #include
+#include <raft/core/operators.hpp>
 #include
 #include
 #include
diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh
index 77c772e2ce..3bc1b5b89b 100644
--- a/cpp/include/raft/cluster/detail/kmeans_common.cuh
+++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh
@@ -34,6 +34,7 @@
 #include
 #include
 #include
+#include <raft/core/operators.hpp>
 #include
 #include
 #include
diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh
index 994d4992b3..4b912dc966 100644
--- a/cpp/include/raft/cluster/kmeans.cuh
+++ b/cpp/include/raft/cluster/kmeans.cuh
@@ -20,7 +20,7 @@
 #include
 #include
 #include
-#include
+#include <raft/core/operators.hpp>
 
 namespace raft::cluster::kmeans {
 
diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp
new file mode 100644
index 0000000000..1ae2d45131
--- /dev/null
+++ b/cpp/include/raft/core/operators.hpp
@@ -0,0 +1,262 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cmath>
+#include
+
+#include
+
+namespace raft {
+
+/**
+ * @defgroup Functors Commonly used functors.
+ * The optional index argument is mostly to be used for MainLambda in reduction kernels
+ * @{
+ */
+
+struct identity_op {
+  template <typename Type, typename IdxType = int>
+  constexpr RAFT_INLINE_FUNCTION Type operator()(Type in, IdxType i = 0) const
+  {
+    return in;
+  }
+};
+
+template <typename OutT>
+struct cast_op {
+  template <typename InT, typename IdxType = int>
+  constexpr RAFT_INLINE_FUNCTION OutT operator()(InT in, IdxType i = 0) const
+  {
+    return static_cast<OutT>(in);
+  }
+};
+
+struct key_op {
+  template <typename KVP, typename IdxType = int>
+  constexpr RAFT_INLINE_FUNCTION typename KVP::Key operator()(const KVP& p, IdxType i = 0) const
+  {
+    return p.key;
+  }
+};
+
+struct value_op {
+  template <typename KVP, typename IdxType = int>
+  constexpr RAFT_INLINE_FUNCTION typename KVP::Value operator()(const KVP& p, IdxType i = 0) const
+  {
+    return p.value;
+  }
+};
+
+struct sqrt_op {
+  template <typename Type, typename IdxType = int>
+  constexpr RAFT_INLINE_FUNCTION Type operator()(Type in, IdxType i = 0) const
+  {
+    return std::sqrt(in);
+  }
+};
+
+struct nz_op {
+  template <typename Type, typename IdxType = int>
+  constexpr RAFT_INLINE_FUNCTION Type operator()(Type in, IdxType i = 0) const
+  {
+    return in != Type(0) ? Type(1) : Type(0);
+  }
+};
+
+struct abs_op {
+  template <typename Type, typename IdxType = int>
+  constexpr RAFT_INLINE_FUNCTION Type operator()(Type in, IdxType i = 0) const
+  {
+    return std::abs(in);
+  }
+};
+
+struct sq_op {
+  template <typename Type, typename IdxType = int>
+  constexpr RAFT_INLINE_FUNCTION Type operator()(Type in, IdxType i = 0) const
+  {
+    return in * in;
+  }
+};
+
+struct add_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(T1 a, T2 b) const
+  {
+    return a + b;
+  }
+};
+
+struct sub_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(T1 a, T2 b) const
+  {
+    return a - b;
+  }
+};
+
+struct mul_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(T1 a, T2 b) const
+  {
+    return a * b;
+  }
+};
+
+struct div_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(T1 a, T2 b) const
+  {
+    return a / b;
+  }
+};
+
+struct div_checkzero_op {
+  template <typename Type>
+  constexpr RAFT_INLINE_FUNCTION Type operator()(Type a, Type b) const
+  {
+    if (b == Type{0}) { return Type{0}; }
+    return a / b;
+  }
+};
+
+struct pow_op {
+  template <typename Type>
+  constexpr RAFT_INLINE_FUNCTION Type operator()(Type a, Type b) const
+  {
+    return std::pow(a, b);
+  }
+};
+
+struct min_op {
+  template <typename Type>
+  constexpr RAFT_INLINE_FUNCTION Type operator()(Type a, Type b) const
+  {
+    if (a > b) { return b; }
+    return a;
+  }
+};
+
+struct max_op {
+  template <typename Type>
+  constexpr RAFT_INLINE_FUNCTION Type operator()(Type a, Type b) const
+  {
+    if (b > a) { return b; }
+    return a;
+  }
+};
+
+struct sqdiff_op {
+  template <typename Type>
+  constexpr RAFT_INLINE_FUNCTION Type operator()(Type a, Type b) const
+  {
+    Type diff = a - b;
+    return diff * diff;
+  }
+};
+
+struct argmin_op {
+  template <typename KVP>
+  constexpr RAFT_INLINE_FUNCTION KVP operator()(const KVP& a, const KVP& b) const
+  {
+    if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; }
+    return a;
+  }
+};
+
+struct argmax_op {
+  template <typename KVP>
+  constexpr RAFT_INLINE_FUNCTION KVP operator()(const KVP& a, const KVP& b) const
+  {
+    if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; }
+    return a;
+  }
+};
+
+template <typename ScalarT>
+struct const_op {
+  const ScalarT scalar;
+
+  constexpr const_op(ScalarT s) : scalar{s} {}
+
+  template <typename InT>
+  constexpr RAFT_INLINE_FUNCTION ScalarT operator()(InT unused) const
+  {
+    return scalar;
+  }
+};
+
+template <typename ComposedOpT, typename ScalarT>
+struct scalar_op {
+  ComposedOpT composed_op;
+  const ScalarT scalar;
+
+  constexpr scalar_op(ScalarT s) : scalar{s} {}
+
+  template <typename InT>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(InT a) const
+  {
+    return composed_op(a, scalar);
+  }
+};
+
+template <typename ScalarT>
+using scalar_add_op = scalar_op<add_op, ScalarT>;
+
+template <typename ScalarT>
+using scalar_sub_op = scalar_op<sub_op, ScalarT>;
+
+template <typename ScalarT>
+using scalar_mul_op = scalar_op<mul_op, ScalarT>;
+
+template <typename ScalarT>
+using scalar_div_op = scalar_op<div_op, ScalarT>;
+
+template <typename ScalarT>
+using scalar_div_checkzero_op = scalar_op<div_checkzero_op, ScalarT>;
+
+template <typename ScalarT>
+using scalar_pow_op = scalar_op<pow_op, ScalarT>;
+
+template <typename OuterOpT, typename InnerOpT>
+struct unary_compose_op {
+  OuterOpT outer_op;
+  InnerOpT inner_op;
+
+  template <typename Type>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(Type a) const
+  {
+    return outer_op(inner_op(a));
+  }
+};
+
+template <typename OuterOpT, typename InnerOpT1, typename InnerOpT2>
+struct binary_compose_op {
+  OuterOpT outer_op;
+  InnerOpT1 inner_op1;
+  InnerOpT2 inner_op2;
+
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(T1 a, T2 b) const
+  {
+    return outer_op(inner_op1(a), inner_op2(b));
+  }
+};
+
+/** @} */
+} // namespace raft
diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh
index 26536d13cd..69bb83d29a 100644
--- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh
+++
b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -14,6 +14,7 @@ * limitations under the License. */ #pragma once +#include #include #include #include diff --git a/cpp/include/raft/label/detail/classlabels.cuh b/cpp/include/raft/label/detail/classlabels.cuh index cef249e6e1..64d8b4bfae 100644 --- a/cpp/include/raft/label/detail/classlabels.cuh +++ b/cpp/include/raft/label/detail/classlabels.cuh @@ -18,6 +18,7 @@ #include +#include #include #include #include diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index b9c20c7e4c..45cd640edc 100644 --- a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ -22,7 +22,7 @@ #include #include -#include +#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/linalg/detail/add.cuh b/cpp/include/raft/linalg/detail/add.cuh index b1b6922809..bec1dcd9b1 100644 --- a/cpp/include/raft/linalg/detail/add.cuh +++ b/cpp/include/raft/linalg/detail/add.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh index 4dc3d5bd8c..238e17fa56 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include diff --git a/cpp/include/raft/linalg/detail/divide.cuh b/cpp/include/raft/linalg/detail/divide.cuh index 29388a2035..d25316649b 100644 --- a/cpp/include/raft/linalg/detail/divide.cuh +++ b/cpp/include/raft/linalg/detail/divide.cuh @@ -17,8 +17,8 @@ #pragma once #include +#include #include -#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/linalg/detail/eltwise.cuh b/cpp/include/raft/linalg/detail/eltwise.cuh index f744876820..60a31a4357 100644 --- a/cpp/include/raft/linalg/detail/eltwise.cuh +++ b/cpp/include/raft/linalg/detail/eltwise.cuh @@ -16,9 +16,9 @@ #pragma once +#include #include #include -#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/linalg/detail/multiply.cuh b/cpp/include/raft/linalg/detail/multiply.cuh index 8641ccd154..bb757d4531 100644 --- a/cpp/include/raft/linalg/detail/multiply.cuh +++ b/cpp/include/raft/linalg/detail/multiply.cuh @@ -16,8 +16,8 @@ #pragma once +#include #include -#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/linalg/detail/norm.cuh b/cpp/include/raft/linalg/detail/norm.cuh index 2927a1bcfe..ed7e360848 100644 --- a/cpp/include/raft/linalg/detail/norm.cuh +++ b/cpp/include/raft/linalg/detail/norm.cuh @@ -16,9 +16,9 @@ #pragma once +#include #include #include -#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/linalg/detail/reduce.cuh b/cpp/include/raft/linalg/detail/reduce.cuh index b359019bef..721ca8179f 100644 --- a/cpp/include/raft/linalg/detail/reduce.cuh +++ b/cpp/include/raft/linalg/detail/reduce.cuh @@ -16,9 +16,9 @@ #pragma once +#include #include #include -#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/linalg/detail/strided_reduction.cuh b/cpp/include/raft/linalg/detail/strided_reduction.cuh index aa4517b2b5..0e516b4750 100644 --- a/cpp/include/raft/linalg/detail/strided_reduction.cuh +++ b/cpp/include/raft/linalg/detail/strided_reduction.cuh @@ -18,6 +18,7 @@ #include "unary_op.cuh" #include +#include #include #include #include diff --git a/cpp/include/raft/linalg/detail/subtract.cuh 
b/cpp/include/raft/linalg/detail/subtract.cuh index 3eebd1a55f..378d29e4bc 100644 --- a/cpp/include/raft/linalg/detail/subtract.cuh +++ b/cpp/include/raft/linalg/detail/subtract.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index 91e47ece6f..55f8f2b2ec 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -22,8 +22,8 @@ #include "linalg_types.hpp" #include +#include #include -#include #include namespace raft { diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh index fba8a0f09c..e303f2e437 100644 --- a/cpp/include/raft/linalg/normalize.cuh +++ b/cpp/include/raft/linalg/normalize.cuh @@ -18,8 +18,8 @@ #include "detail/normalize.cuh" +#include #include -#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index ef8fea20ee..2c9897b486 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -19,9 +19,9 @@ #pragma once #include +#include #include #include -#include #include namespace raft { diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index 45660e5ed8..3eb8196408 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -22,7 +22,7 @@ #include "linalg_types.hpp" #include -#include +#include #include namespace raft { diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index cc9bfc69ef..ad6cad2eb2 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -19,8 +19,8 @@ #pragma once #include +#include #include -#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index 4f44787bbb..d9c26910e7 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -22,8 +22,8 @@ #include "detail/strided_reduction.cuh" #include +#include #include -#include #include diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index a2e1562af8..c006f69e47 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include namespace raft { namespace matrix { diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 6e80ff6880..c559da3942 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -19,6 +19,7 @@ #include #include +#include #include #include #include diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index c0a2511db9..5129258f39 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -18,6 +18,7 @@ #include +#include #include #include #include diff --git a/cpp/include/raft/sparse/op/detail/slice.cuh b/cpp/include/raft/sparse/op/detail/slice.cuh index ddf002bc0e..78a3592c94 100644 --- a/cpp/include/raft/sparse/op/detail/slice.cuh +++ b/cpp/include/raft/sparse/op/detail/slice.cuh @@ -18,6 +18,7 @@ #include +#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 
57ca2625bc..08a06ad4a4 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -22,6 +22,7 @@ #include "common_faiss.h" #include "processing.cuh" +#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 3765f1a719..85a05877f1 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -20,6 +20,7 @@ #include // TODO: Need to hide the PairwiseDistance class impl and expose to public API #include "processing.cuh" +#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index 6740337071..f08a97e0f7 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -24,11 +24,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index f43b15a117..d2f7d681d7 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh index 7a1a4492d1..2a2fc1f10b 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index 0013dbe896..8f0a23df03 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/processing.cuh b/cpp/include/raft/spatial/knn/detail/processing.cuh index 747a1db00d..42f74e0ab7 100644 --- a/cpp/include/raft/spatial/knn/detail/processing.cuh +++ b/cpp/include/raft/spatial/knn/detail/processing.cuh @@ -17,13 +17,13 @@ #include "processing.hpp" +#include #include #include #include #include #include #include -#include #include namespace raft { diff --git a/cpp/include/raft/stats/detail/silhouette_score.cuh b/cpp/include/raft/stats/detail/silhouette_score.cuh index 27ad8621dc..136c2db16c 100644 --- a/cpp/include/raft/stats/detail/silhouette_score.cuh +++ b/cpp/include/raft/stats/detail/silhouette_score.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/stats/mean_center.cuh b/cpp/include/raft/stats/mean_center.cuh index 9f49ff8be2..333d49e193 100644 --- a/cpp/include/raft/stats/mean_center.cuh +++ b/cpp/include/raft/stats/mean_center.cuh @@ -20,6 +20,7 @@ #pragma once #include +#include #include namespace raft { diff --git a/cpp/include/raft/stats/silhouette_score.cuh b/cpp/include/raft/stats/silhouette_score.cuh index fafddb7b23..ff44acc459 100644 --- a/cpp/include/raft/stats/silhouette_score.cuh +++ b/cpp/include/raft/stats/silhouette_score.cuh @@ -19,6 +19,7 @@ #pragma once #include +#include #include #include diff 
--git a/cpp/include/raft/stats/weighted_mean.cuh b/cpp/include/raft/stats/weighted_mean.cuh index 30a922b243..95151c5453 100644 --- a/cpp/include/raft/stats/weighted_mean.cuh +++ b/cpp/include/raft/stats/weighted_mean.cuh @@ -20,6 +20,7 @@ #pragma once #include +#include #include namespace raft { diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index 75e5a8dc45..61dd6e0ad8 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -21,6 +21,7 @@ #include #include +#include #ifndef ENABLE_MEMCPY_ASYNC // enable memcpy_async interface by default for newer GPUs @@ -574,243 +575,6 @@ struct Max { }; /** @} */ -/** - * @defgroup Functors Commonly used functors. - * The optional index argument is mostly to be used for MainLambda in reduction kernels - * @{ - */ - -struct identity_op { - template - HDI Type operator()(Type in, IdxType i = 0) const - { - return in; - } -}; - -template -struct cast_op { - template - HDI OutT operator()(InT in, IdxType i = 0) const - { - return static_cast(in); - } -}; - -struct key_op { - template - HDI typename KVP::Key operator()(const KVP& p, IdxType i = 0) const - { - return p.key; - } -}; - -struct value_op { - template - HDI typename KVP::Value operator()(const KVP& p, IdxType i = 0) const - { - return p.value; - } -}; - -struct sqrt_op { - template - HDI Type operator()(Type in, IdxType i = 0) const - { - return mySqrt(in); - } -}; - -struct nz_op { - template - HDI Type operator()(Type in, IdxType i = 0) const - { - return in != Type(0) ? Type(1) : Type(0); - } -}; - -struct abs_op { - template - HDI Type operator()(Type in, IdxType i = 0) const - { - return myAbs(in); - } -}; - -struct sq_op { - template - HDI Type operator()(Type in, IdxType i = 0) const - { - return in * in; - } -}; - -struct add_op { - template - HDI auto operator()(T1 a, T2 b) const - { - return a + b; - } -}; - -struct sub_op { - template - HDI auto operator()(T1 a, T2 b) const - { - return a - b; - } -}; - -struct mul_op { - template - HDI auto operator()(T1 a, T2 b) const - { - return a * b; - } -}; - -struct div_op { - template - HDI auto operator()(T1 a, T2 b) const - { - return a / b; - } -}; - -struct div_checkzero_op { - template - HDI Type operator()(Type a, Type b) const - { - if (b == Type{0}) { return Type{0}; } - return a / b; - } -}; - -struct pow_op { - template - HDI Type operator()(Type a, Type b) const - { - return raft::myPow(a, b); - } -}; - -struct min_op { - template - HDI Type operator()(Type a, Type b) const - { - if (a > b) { return b; } - return a; - } -}; - -struct max_op { - template - HDI Type operator()(Type a, Type b) const - { - if (b > a) { return b; } - return a; - } -}; - -struct sqdiff_op { - template - HDI Type operator()(Type a, Type b) const - { - Type diff = a - b; - return diff * diff; - } -}; - -struct argmin_op { - template - HDI KVP operator()(const KVP& a, const KVP& b) const - { - if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; } - return a; - } -}; - -struct argmax_op { - template - HDI KVP operator()(const KVP& a, const KVP& b) const - { - if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; } - return a; - } -}; - -template -struct const_op { - const ScalarT scalar; - - const_op(ScalarT s) : scalar{s} {} - - template - HDI ScalarT operator()(InT unused) const - { - return scalar; - } -}; - -template -struct scalar_op { - ComposedOpT composed_op; - const ScalarT scalar; - - scalar_op(ScalarT s) : 
scalar{s} {} - - template - HDI auto operator()(InT a) const - { - return composed_op(a, scalar); - } -}; - -template -using scalar_add_op = scalar_op; - -template -using scalar_sub_op = scalar_op; - -template -using scalar_mul_op = scalar_op; - -template -using scalar_div_op = scalar_op; - -template -using scalar_div_checkzero_op = scalar_op; - -template -using scalar_pow_op = scalar_op; - -template -struct unary_compose_op { - OuterOpT outer_op; - InnerOpT inner_op; - - template - HDI auto operator()(Type a) const - { - return outer_op(inner_op(a)); - } -}; - -template -struct binary_compose_op { - OuterOpT outer_op; - InnerOpT1 inner_op1; - InnerOpT2 inner_op2; - - template - HDI auto operator()(T1 a, T2 b) const - { - return outer_op(inner_op1(a), inner_op2(b)); - } -}; - -/** @} */ - /** * @defgroup Sign Obtain sign value * @brief Obtain sign of x diff --git a/cpp/include/raft/util/scatter.cuh b/cpp/include/raft/util/scatter.cuh index 58b5ce0bc1..e69be36ad9 100644 --- a/cpp/include/raft/util/scatter.cuh +++ b/cpp/include/raft/util/scatter.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include diff --git a/cpp/src/distance/cluster_cost.cuh b/cpp/src/distance/cluster_cost.cuh index 3dca92b2db..c1c39dceb7 100644 --- a/cpp/src/distance/cluster_cost.cuh +++ b/cpp/src/distance/cluster_cost.cuh @@ -15,6 +15,7 @@ */ #include +#include #include #include #include diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index 406f3f4ca1..2c78b1ea2a 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 302a81320c..8b8c53d354 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -17,6 +17,7 @@ #include "../test_utils.h" #include #include +#include #include #include #include diff --git a/cpp/test/linalg/binary_op.cu b/cpp/test/linalg/binary_op.cu index 34def617d2..f4f9cd11d7 100644 --- a/cpp/test/linalg/binary_op.cu +++ b/cpp/test/linalg/binary_op.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include "binary_op.cuh" #include +#include #include #include #include diff --git a/cpp/test/linalg/coalesced_reduction.cu b/cpp/test/linalg/coalesced_reduction.cu index c2e466f223..1466f557dd 100644 --- a/cpp/test/linalg/coalesced_reduction.cu +++ b/cpp/test/linalg/coalesced_reduction.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include "reduce.cuh" #include +#include #include #include #include diff --git a/cpp/test/linalg/map_then_reduce.cu b/cpp/test/linalg/map_then_reduce.cu index 9828f5b206..ba9cca6f04 100644 --- a/cpp/test/linalg/map_then_reduce.cu +++ b/cpp/test/linalg/map_then_reduce.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include #include +#include #include #include #include diff --git a/cpp/test/linalg/matrix_vector.cu b/cpp/test/linalg/matrix_vector.cu index 83d1d60ef6..e321d45703 100644 --- a/cpp/test/linalg/matrix_vector.cu +++ b/cpp/test/linalg/matrix_vector.cu @@ -18,6 +18,7 @@ #include "matrix_vector_op.cuh" #include #include +#include #include #include #include diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index 53a1656ec4..09ebdbd13e 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -16,6 +16,7 @@ #include "../test_utils.h" #include +#include #include #include #include diff --git a/cpp/test/linalg/normalize.cu b/cpp/test/linalg/normalize.cu index 111253206c..599592fce1 100644 --- 
a/cpp/test/linalg/normalize.cu +++ b/cpp/test/linalg/normalize.cu @@ -16,6 +16,7 @@ #include "../test_utils.h" #include +#include #include #include #include diff --git a/cpp/test/linalg/reduce.cu b/cpp/test/linalg/reduce.cu index 91cc56ac35..cd112526e0 100644 --- a/cpp/test/linalg/reduce.cu +++ b/cpp/test/linalg/reduce.cu @@ -18,6 +18,7 @@ #include "reduce.cuh" #include #include +#include #include #include #include diff --git a/cpp/test/linalg/reduce.cuh b/cpp/test/linalg/reduce.cuh index ea67b32b1a..17e91ce202 100644 --- a/cpp/test/linalg/reduce.cuh +++ b/cpp/test/linalg/reduce.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include diff --git a/cpp/test/linalg/strided_reduction.cu b/cpp/test/linalg/strided_reduction.cu index 50daafb64b..789ea64c3a 100644 --- a/cpp/test/linalg/strided_reduction.cu +++ b/cpp/test/linalg/strided_reduction.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include "reduce.cuh" #include +#include #include #include #include diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index cb9a1d0c31..5dcdf265ac 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index a51f01ddb3..1b5de9183d 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -16,6 +16,7 @@ #include +#include #include #include #include From c4bfe37fd0c36f037a187ec15daee47ac5b9f139 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Thu, 1 Dec 2022 13:37:02 +0100 Subject: [PATCH 11/22] Revert unwanted change on cmake file --- cpp/cmake/modules/ConfigureCUDA.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/modules/ConfigureCUDA.cmake b/cpp/cmake/modules/ConfigureCUDA.cmake index 624755a867..5e68ca5bc4 100644 --- a/cpp/cmake/modules/ConfigureCUDA.cmake +++ b/cpp/cmake/modules/ConfigureCUDA.cmake @@ -18,7 +18,7 @@ if(DISABLE_DEPRECATION_WARNINGS) endif() if(CMAKE_COMPILER_IS_GNUCXX) - list(APPEND RAFT_CXX_FLAGS -Wall -Werror -Wno-unknown-pragmas -Wno-error=deprecated-declarations -ftemplate-backtrace-limit=100) + list(APPEND RAFT_CXX_FLAGS -Wall -Werror -Wno-unknown-pragmas -Wno-error=deprecated-declarations) endif() list(APPEND RAFT_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr) From e1f0a0b1f19bc81e262f84952841dd9840858e4b Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Thu, 1 Dec 2022 14:56:01 +0100 Subject: [PATCH 12/22] Support arbitrary number of arguments for inner_op in unary_compose_op --- cpp/include/raft/core/operators.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index 1ae2d45131..a29d23f076 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -238,10 +238,10 @@ struct unary_compose_op { OuterOpT outer_op; InnerOpT inner_op; - template - constexpr RAFT_INLINE_FUNCTION auto operator()(Type a) const + template + constexpr RAFT_INLINE_FUNCTION auto operator()(Args... 
args) const { - return outer_op(inner_op(a)); + return outer_op(inner_op(args...)); } }; From fe0bff9aeab85be770b843e6019eb293533af985 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Fri, 2 Dec 2022 16:46:25 +0100 Subject: [PATCH 13/22] Operators improvements: perfect forwarding, references, better support for non-default-constructible operators --- cpp/include/raft/core/operators.hpp | 124 +++++++++++------- .../raft/spatial/knn/detail/ivf_pq_search.cuh | 41 +++--- 2 files changed, 95 insertions(+), 70 deletions(-) diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index a29d23f076..ef281fdf29 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -18,6 +18,8 @@ #include #include +#include +#include #include @@ -25,13 +27,13 @@ namespace raft { /** * @defgroup Functors Commonly used functors. - * The optional index argument is mostly to be used for MainLambda in reduction kernels + * The optional unused arguments are useful for kernels that pass the index along with the value. * @{ */ struct identity_op { - template - constexpr RAFT_INLINE_FUNCTION Type operator()(Type in, IdxType i = 0) const + template + constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& in, UnusedArgs...) const { return in; } @@ -39,56 +41,56 @@ struct identity_op { template struct cast_op { - template - constexpr RAFT_INLINE_FUNCTION OutT operator()(InT in, IdxType i = 0) const + template + constexpr RAFT_INLINE_FUNCTION OutT operator()(InT in, UnusedArgs...) const { return static_cast(in); } }; struct key_op { - template - constexpr RAFT_INLINE_FUNCTION typename KVP::Key operator()(const KVP& p, IdxType i = 0) const + template + constexpr RAFT_INLINE_FUNCTION typename KVP::Key operator()(const KVP& p, UnusedArgs...) const { return p.key; } }; struct value_op { - template - constexpr RAFT_INLINE_FUNCTION typename KVP::Value operator()(const KVP& p, IdxType i = 0) const + template + constexpr RAFT_INLINE_FUNCTION typename KVP::Value operator()(const KVP& p, UnusedArgs...) const { return p.value; } }; struct sqrt_op { - template - constexpr RAFT_INLINE_FUNCTION Type operator()(Type in, IdxType i = 0) const + template + constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& in, UnusedArgs...) const { return std::sqrt(in); } }; struct nz_op { - template - constexpr RAFT_INLINE_FUNCTION Type operator()(Type in, IdxType i = 0) const + template + constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& in, UnusedArgs...) const { return in != Type(0) ? Type(1) : Type(0); } }; struct abs_op { - template - constexpr RAFT_INLINE_FUNCTION Type operator()(Type in, IdxType i = 0) const + template + constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& in, UnusedArgs...) const { return std::abs(in); } }; struct sq_op { - template - constexpr RAFT_INLINE_FUNCTION Type operator()(Type in, IdxType i = 0) const + template + constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& in, UnusedArgs...) 
const { return in * in; } @@ -96,7 +98,7 @@ struct sq_op { struct add_op { template - constexpr RAFT_INLINE_FUNCTION auto operator()(T1 a, T2 b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const { return a + b; } @@ -104,7 +106,7 @@ struct add_op { struct sub_op { template - constexpr RAFT_INLINE_FUNCTION auto operator()(T1 a, T2 b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const { return a - b; } @@ -112,7 +114,7 @@ struct sub_op { struct mul_op { template - constexpr RAFT_INLINE_FUNCTION auto operator()(T1 a, T2 b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const { return a * b; } @@ -120,7 +122,7 @@ struct mul_op { struct div_op { template - constexpr RAFT_INLINE_FUNCTION auto operator()(T1 a, T2 b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const { return a / b; } @@ -128,7 +130,7 @@ struct div_op { struct div_checkzero_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(Type a, Type b) const + constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& a, const Type& b) const { if (b == Type{0}) { return Type{0}; } return a / b; @@ -137,7 +139,7 @@ struct div_checkzero_op { struct pow_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(Type a, Type b) const + constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& a, const Type& b) const { return std::pow(a, b); } @@ -145,7 +147,7 @@ struct pow_op { struct min_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(Type a, Type b) const + constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& a, const Type& b) const { if (a > b) { return b; } return a; @@ -154,7 +156,7 @@ struct min_op { struct max_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(Type a, Type b) const + constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& a, const Type& b) const { if (b > a) { return b; } return a; @@ -163,7 +165,7 @@ struct max_op { struct sqdiff_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(Type a, Type b) const + constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& a, const Type& b) const { Type diff = a - b; return diff * diff; @@ -192,21 +194,26 @@ template struct const_op { const ScalarT scalar; - constexpr const_op(ScalarT s) : scalar{s} {} + constexpr const_op(const ScalarT& s) : scalar{s} {} - template - constexpr RAFT_INLINE_FUNCTION ScalarT operator()(InT unused) const + template + constexpr RAFT_INLINE_FUNCTION ScalarT operator()(Args...) 
const { return scalar; } }; -template +template struct scalar_op { - ComposedOpT composed_op; const ScalarT scalar; + const BinaryOpT composed_op; - constexpr scalar_op(ScalarT s) : scalar{s} {} + template >> + constexpr scalar_op(const ScalarT& s) : scalar{s}, composed_op{} + { + } + constexpr scalar_op(const ScalarT& s, BinaryOpT o) : scalar{s}, composed_op{o} {} template constexpr RAFT_INLINE_FUNCTION auto operator()(InT a) const @@ -216,43 +223,66 @@ struct scalar_op { }; template -using scalar_add_op = scalar_op; +using scalar_add_op = scalar_op; template -using scalar_sub_op = scalar_op; +using scalar_sub_op = scalar_op; template -using scalar_mul_op = scalar_op; +using scalar_mul_op = scalar_op; template -using scalar_div_op = scalar_op; +using scalar_div_op = scalar_op; template -using scalar_div_checkzero_op = scalar_op; +using scalar_div_checkzero_op = scalar_op; template -using scalar_pow_op = scalar_op; +using scalar_pow_op = scalar_op; template -struct unary_compose_op { - OuterOpT outer_op; - InnerOpT inner_op; +struct compose_op { + const OuterOpT outer_op; + const InnerOpT inner_op; + + template && + std::is_default_constructible_v>> + constexpr compose_op() : outer_op{}, inner_op{} + { + } + constexpr compose_op(OuterOpT out_op, InnerOpT in_op) : outer_op{out_op}, inner_op{in_op} {} template - constexpr RAFT_INLINE_FUNCTION auto operator()(Args... args) const + constexpr RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const { - return outer_op(inner_op(args...)); + return outer_op(inner_op(std::forward(args)...)); } }; template struct binary_compose_op { - OuterOpT outer_op; - InnerOpT1 inner_op1; - InnerOpT2 inner_op2; + const OuterOpT outer_op; + const InnerOpT1 inner_op1; + const InnerOpT2 inner_op2; + + template && + std::is_default_constructible_v && + std::is_default_constructible_v>> + constexpr binary_compose_op() : outer_op{}, inner_op1{}, inner_op2{} + { + } + constexpr binary_compose_op(OuterOpT out_op, InnerOpT1 in_op1, InnerOpT2 in_op2) + : outer_op{out_op}, inner_op1{in_op1}, inner_op2{in_op2} + { + } template - constexpr RAFT_INLINE_FUNCTION auto operator()(T1 a, T2 b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const { return outer_op(inner_op1(a), inner_op2(b)); } diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index 8f0a23df03..48bf934f5d 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -419,34 +419,29 @@ void postprocess_distances(float* out, // [n_queries, topk] switch (metric) { case distance::DistanceType::L2Unexpanded: case distance::DistanceType::L2Expanded: { - linalg::unaryOp( - out, - in, - len, - raft::scalar_op>, float>( - scaling_factor * scaling_factor), - stream); + linalg::unaryOp(out, + in, + len, + raft::compose_op(raft::scalar_mul_op{scaling_factor * scaling_factor}, + raft::cast_op{}), + stream); } break; case distance::DistanceType::L2SqrtUnexpanded: case distance::DistanceType::L2SqrtExpanded: { - linalg::unaryOp( - out, - in, - len, - raft::scalar_op< - raft::binary_compose_op>>, - float>(scaling_factor), - stream); + linalg::unaryOp(out, + in, + len, + raft::compose_op(raft::scalar_mul_op{scaling_factor}, + raft::compose_op>{}), + stream); } break; case distance::DistanceType::InnerProduct: { - linalg::unaryOp( - out, - in, - len, - raft::scalar_op>, float>( - -scaling_factor * scaling_factor), - stream); + linalg::unaryOp(out, + in, + 
len, + raft::compose_op(raft::scalar_mul_op{-scaling_factor * scaling_factor}, + raft::cast_op{}), + stream); } break; default: RAFT_FAIL("Unexpected metric."); } From 5c07ef3886cdd16180f9b7d98d2af3054537d322 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Fri, 2 Dec 2022 17:08:23 +0100 Subject: [PATCH 14/22] Remove unused includes and bring back custom DivOp in silhouette_score.cuh --- cpp/include/raft/stats/detail/silhouette_score.cuh | 2 +- cpp/include/raft/stats/mean_center.cuh | 1 - cpp/include/raft/stats/silhouette_score.cuh | 1 - cpp/include/raft/stats/weighted_mean.cuh | 1 - 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/cpp/include/raft/stats/detail/silhouette_score.cuh b/cpp/include/raft/stats/detail/silhouette_score.cuh index 136c2db16c..3cf95c3941 100644 --- a/cpp/include/raft/stats/detail/silhouette_score.cuh +++ b/cpp/include/raft/stats/detail/silhouette_score.cuh @@ -273,7 +273,7 @@ DataT silhouette_score( nRows, true, true, - raft::div_op(), + DivOp(), stream); // calculating row-wise minimum diff --git a/cpp/include/raft/stats/mean_center.cuh b/cpp/include/raft/stats/mean_center.cuh index 333d49e193..9f49ff8be2 100644 --- a/cpp/include/raft/stats/mean_center.cuh +++ b/cpp/include/raft/stats/mean_center.cuh @@ -20,7 +20,6 @@ #pragma once #include -#include #include namespace raft { diff --git a/cpp/include/raft/stats/silhouette_score.cuh b/cpp/include/raft/stats/silhouette_score.cuh index ff44acc459..fafddb7b23 100644 --- a/cpp/include/raft/stats/silhouette_score.cuh +++ b/cpp/include/raft/stats/silhouette_score.cuh @@ -19,7 +19,6 @@ #pragma once #include -#include #include #include diff --git a/cpp/include/raft/stats/weighted_mean.cuh b/cpp/include/raft/stats/weighted_mean.cuh index 95151c5453..30a922b243 100644 --- a/cpp/include/raft/stats/weighted_mean.cuh +++ b/cpp/include/raft/stats/weighted_mean.cuh @@ -20,7 +20,6 @@ #pragma once #include -#include #include namespace raft { From 66f52ed42124e8715e71d5625773ed447bd5583b Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Mon, 5 Dec 2022 18:16:55 +0100 Subject: [PATCH 15/22] More powerful compose ops --- cpp/include/raft/core/operators.hpp | 113 +++++++++++++----- .../raft/spatial/knn/detail/ivf_pq_search.cuh | 13 +- 2 files changed, 93 insertions(+), 33 deletions(-) diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index ef281fdf29..ce2064f046 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -203,6 +203,20 @@ struct const_op { } }; +/** + * @brief Wraps around a binary operator, passing a given scalar on the right-hand side. + * + * Usage example: + * @code{.cpp} + * #include + * + * raft::scalar_op op(2.0f); + * std::cout << op(2.1f) << std::endl; // 4.2 + * @endcode + * + * @tparam ScalarT + * @tparam BinaryOpT + */ template struct scalar_op { const ScalarT scalar; @@ -240,51 +254,96 @@ using scalar_div_checkzero_op = scalar_op; template using scalar_pow_op = scalar_op; -template +/** + * @brief Constructs an operator by composing a chain of operators. + * + * Note that all arguments are passed to the innermost operator. + * + * Usage example: + * @code{.cpp} + * #include + * + * auto op = raft::compose_op(raft::sqrt_op(), raft::abs_op(), raft::cast_op(), + * raft::scalar_add_op(8)); + * std::cout << op(-50) << std::endl; // 6.48074 + * @endcode + * + * @tparam OpsT Any number of operation types. 
+ */ +template struct compose_op { - const OuterOpT outer_op; - const InnerOpT inner_op; + const std::tuple ops; - template && - std::is_default_constructible_v>> - constexpr compose_op() : outer_op{}, inner_op{} + template , + typename CondT = std::enable_if_t>> + constexpr compose_op() : ops{} { } - constexpr compose_op(OuterOpT out_op, InnerOpT in_op) : outer_op{out_op}, inner_op{in_op} {} + constexpr explicit compose_op(OpsT... ops) : ops{ops...} {} template constexpr RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const { - return outer_op(inner_op(std::forward(args)...)); + return compose(std::forward(args)...); + } + + private: + template + constexpr RAFT_INLINE_FUNCTION auto compose(Args&&... args) const + { + if constexpr (RemOps > 0) { + return compose(std::get(ops)(std::forward(args)...)); + } else { + return identity_op{}(std::forward(args)...); + } } }; -template -struct binary_compose_op { +/** + * @brief Constructs an operator by composing an outer op with one inner op for each of its inputs. + * + * Usage example: + * @code{.cpp} + * #include + * + * raft::map_args_op> op; + * std::cout << op(42.0f, 10) << std::endl; // 16.4807 + * @endcode + * + * @tparam OuterOpT Outer operation type + * @tparam ArgOpsT Operation types for each input of the outer operation + */ +template +struct map_args_op { const OuterOpT outer_op; - const InnerOpT1 inner_op1; - const InnerOpT2 inner_op2; - - template && - std::is_default_constructible_v && - std::is_default_constructible_v>> - constexpr binary_compose_op() : outer_op{}, inner_op1{}, inner_op2{} + const std::tuple arg_ops; + + template , + typename CondT = std::enable_if_t && + std::is_default_constructible_v>> + constexpr map_args_op() : outer_op{}, arg_ops{} { } - constexpr binary_compose_op(OuterOpT out_op, InnerOpT1 in_op1, InnerOpT2 in_op2) - : outer_op{out_op}, inner_op1{in_op1}, inner_op2{in_op2} + constexpr explicit map_args_op(OuterOpT outer_op, ArgOpsT... arg_ops) + : outer_op{outer_op}, arg_ops{arg_ops...} { } - template - constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + template + constexpr RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const + { + constexpr size_t kNumOps = sizeof...(ArgOpsT); + static_assert(kNumOps == sizeof...(Args), + "The number of arguments does not match the number of mapping operators"); + return map_args(std::make_index_sequence{}, std::forward(args)...); + } + + private: + template + constexpr RAFT_INLINE_FUNCTION auto map_args(std::index_sequence, Args&&... 
args) const { - return outer_op(inner_op1(a), inner_op2(b)); + return outer_op(std::get(arg_ops)(std::forward(args))...); } }; diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index 48bf934f5d..83d2e357b9 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -428,12 +428,13 @@ void postprocess_distances(float* out, // [n_queries, topk] } break; case distance::DistanceType::L2SqrtUnexpanded: case distance::DistanceType::L2SqrtExpanded: { - linalg::unaryOp(out, - in, - len, - raft::compose_op(raft::scalar_mul_op{scaling_factor}, - raft::compose_op>{}), - stream); + linalg::unaryOp( + out, + in, + len, + raft::compose_op{ + raft::scalar_mul_op{scaling_factor}, raft::sqrt_op{}, raft::cast_op{}}, + stream); } break; case distance::DistanceType::InnerProduct: { linalg::unaryOp(out, From d5f94dc196bb760dd8962da8fcbcc369f551b161 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Mon, 5 Dec 2022 18:28:22 +0100 Subject: [PATCH 16/22] Add tuple header --- cpp/include/raft/core/operators.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index ce2064f046..0a23264007 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -18,6 +18,7 @@ #include #include +#include #include #include From fee9a50dbbcb705e0538eec0dd68d7246e4a35fc Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Tue, 6 Dec 2022 15:18:25 +0100 Subject: [PATCH 17/22] Change return types to auto --- cpp/include/raft/core/operators.hpp | 36 ++++++++++++++--------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index 0a23264007..c8cf719ce4 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -34,7 +34,7 @@ namespace raft { struct identity_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& in, UnusedArgs...) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const { return in; } @@ -43,7 +43,7 @@ struct identity_op { template struct cast_op { template - constexpr RAFT_INLINE_FUNCTION OutT operator()(InT in, UnusedArgs...) const + constexpr RAFT_INLINE_FUNCTION auto operator()(InT in, UnusedArgs...) const { return static_cast(in); } @@ -51,7 +51,7 @@ struct cast_op { struct key_op { template - constexpr RAFT_INLINE_FUNCTION typename KVP::Key operator()(const KVP& p, UnusedArgs...) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& p, UnusedArgs...) const { return p.key; } @@ -59,7 +59,7 @@ struct key_op { struct value_op { template - constexpr RAFT_INLINE_FUNCTION typename KVP::Value operator()(const KVP& p, UnusedArgs...) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& p, UnusedArgs...) const { return p.value; } @@ -67,7 +67,7 @@ struct value_op { struct sqrt_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& in, UnusedArgs...) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const { return std::sqrt(in); } @@ -75,7 +75,7 @@ struct sqrt_op { struct nz_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& in, UnusedArgs...) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const { return in != Type(0) ? 
Type(1) : Type(0); } @@ -83,7 +83,7 @@ struct nz_op { struct abs_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& in, UnusedArgs...) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const { return std::abs(in); } @@ -91,7 +91,7 @@ struct abs_op { struct sq_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& in, UnusedArgs...) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const { return in * in; } @@ -131,7 +131,7 @@ struct div_op { struct div_checkzero_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& a, const Type& b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const { if (b == Type{0}) { return Type{0}; } return a / b; @@ -140,7 +140,7 @@ struct div_checkzero_op { struct pow_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& a, const Type& b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const { return std::pow(a, b); } @@ -148,7 +148,7 @@ struct pow_op { struct min_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& a, const Type& b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const { if (a > b) { return b; } return a; @@ -157,7 +157,7 @@ struct min_op { struct max_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& a, const Type& b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const { if (b > a) { return b; } return a; @@ -166,7 +166,7 @@ struct max_op { struct sqdiff_op { template - constexpr RAFT_INLINE_FUNCTION Type operator()(const Type& a, const Type& b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const { Type diff = a - b; return diff * diff; @@ -175,7 +175,7 @@ struct sqdiff_op { struct argmin_op { template - constexpr RAFT_INLINE_FUNCTION KVP operator()(const KVP& a, const KVP& b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& a, const KVP& b) const { if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; } return a; @@ -184,7 +184,7 @@ struct argmin_op { struct argmax_op { template - constexpr RAFT_INLINE_FUNCTION KVP operator()(const KVP& a, const KVP& b) const + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& a, const KVP& b) const { if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; } return a; @@ -195,10 +195,10 @@ template struct const_op { const ScalarT scalar; - constexpr const_op(const ScalarT& s) : scalar{s} {} + constexpr explicit const_op(const ScalarT& s) : scalar{s} {} template - constexpr RAFT_INLINE_FUNCTION ScalarT operator()(Args...) const + constexpr RAFT_INLINE_FUNCTION auto operator()(Args...) 
const { return scalar; } @@ -225,7 +225,7 @@ struct scalar_op { template >> - constexpr scalar_op(const ScalarT& s) : scalar{s}, composed_op{} + constexpr explicit scalar_op(const ScalarT& s) : scalar{s}, composed_op{} { } constexpr scalar_op(const ScalarT& s, BinaryOpT o) : scalar{s}, composed_op{o} {} From f2f0f7953f292f0851bd8240781f3b4703e960ab Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Tue, 6 Dec 2022 18:24:15 +0100 Subject: [PATCH 18/22] Suppress clang-tidy modernize-use-default-member-init warning --- cpp/include/raft/core/operators.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index c8cf719ce4..69d3802927 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -277,7 +277,7 @@ struct compose_op { template , typename CondT = std::enable_if_t>> - constexpr compose_op() : ops{} + constexpr compose_op() { } constexpr explicit compose_op(OpsT... ops) : ops{ops...} {} @@ -323,7 +323,7 @@ struct map_args_op { typename T2 = std::tuple, typename CondT = std::enable_if_t && std::is_default_constructible_v>> - constexpr map_args_op() : outer_op{}, arg_ops{} + constexpr map_args_op() { } constexpr explicit map_args_op(OuterOpT outer_op, ArgOpsT... arg_ops) From 7113adccbd591cac2a7e9b0dfa825063add99cec Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Wed, 7 Dec 2022 15:55:22 +0100 Subject: [PATCH 19/22] Rename scalar_xxx_op to xxx_const_op --- .../raft/cluster/detail/kmeans_common.cuh | 2 +- cpp/include/raft/core/operators.hpp | 32 +++++++++---------- cpp/include/raft/linalg/detail/add.cuh | 2 +- cpp/include/raft/linalg/detail/divide.cuh | 2 +- cpp/include/raft/linalg/detail/eltwise.cuh | 4 +-- cpp/include/raft/linalg/detail/multiply.cuh | 2 +- cpp/include/raft/linalg/detail/subtract.cuh | 2 +- cpp/include/raft/linalg/power.cuh | 2 +- .../sparse/distance/detail/lp_distance.cuh | 6 ++-- cpp/include/raft/sparse/op/detail/slice.cuh | 2 +- .../raft/spatial/knn/detail/ann_quantized.cuh | 2 +- .../raft/spatial/knn/detail/ivf_pq_search.cuh | 6 ++-- .../raft/spatial/knn/detail/processing.cuh | 2 +- .../raft/stats/detail/weighted_mean.cuh | 2 +- cpp/test/sparse/dist_coo_spmv.cu | 2 +- 15 files changed, 35 insertions(+), 35 deletions(-) diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 3bc1b5b89b..4e52661278 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -160,7 +160,7 @@ void checkWeight(const raft::handle_t& handle, raft::linalg::unaryOp(weight.data_handle(), weight.data_handle(), n_samples, - raft::scalar_mul_op{scale}, + raft::mul_const_op{scale}, stream); } } diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index 69d3802927..3bba5a9c0d 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -205,55 +205,55 @@ struct const_op { }; /** - * @brief Wraps around a binary operator, passing a given scalar on the right-hand side. + * @brief Wraps around a binary operator, passing a constant on the right-hand side. 
* * Usage example: * @code{.cpp} * #include * - * raft::scalar_op op(2.0f); + * raft::plug_const_op op(2.0f); * std::cout << op(2.1f) << std::endl; // 4.2 * @endcode * - * @tparam ScalarT + * @tparam ConstT * @tparam BinaryOpT */ -template -struct scalar_op { - const ScalarT scalar; +template +struct plug_const_op { + const ConstT c; const BinaryOpT composed_op; template >> - constexpr explicit scalar_op(const ScalarT& s) : scalar{s}, composed_op{} + constexpr explicit plug_const_op(const ConstT& s) : c{s}, composed_op{} { } - constexpr scalar_op(const ScalarT& s, BinaryOpT o) : scalar{s}, composed_op{o} {} + constexpr plug_const_op(const ConstT& s, BinaryOpT o) : c{s}, composed_op{o} {} template constexpr RAFT_INLINE_FUNCTION auto operator()(InT a) const { - return composed_op(a, scalar); + return composed_op(a, c); } }; template -using scalar_add_op = scalar_op; +using add_const_op = plug_const_op; template -using scalar_sub_op = scalar_op; +using sub_const_op = plug_const_op; template -using scalar_mul_op = scalar_op; +using mul_const_op = plug_const_op; template -using scalar_div_op = scalar_op; +using div_const_op = plug_const_op; template -using scalar_div_checkzero_op = scalar_op; +using div_checkzero_const_op = plug_const_op; template -using scalar_pow_op = scalar_op; +using pow_const_op = plug_const_op; /** * @brief Constructs an operator by composing a chain of operators. @@ -265,7 +265,7 @@ using scalar_pow_op = scalar_op; * #include * * auto op = raft::compose_op(raft::sqrt_op(), raft::abs_op(), raft::cast_op(), - * raft::scalar_add_op(8)); + * raft::add_const_op(8)); * std::cout << op(-50) << std::endl; // 6.48074 * @endcode * diff --git a/cpp/include/raft/linalg/detail/add.cuh b/cpp/include/raft/linalg/detail/add.cuh index bec1dcd9b1..bf9b2bd1d8 100644 --- a/cpp/include/raft/linalg/detail/add.cuh +++ b/cpp/include/raft/linalg/detail/add.cuh @@ -28,7 +28,7 @@ namespace detail { template void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::scalar_add_op(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::add_const_op(scalar), stream); } template diff --git a/cpp/include/raft/linalg/detail/divide.cuh b/cpp/include/raft/linalg/detail/divide.cuh index d25316649b..eef1d19d6e 100644 --- a/cpp/include/raft/linalg/detail/divide.cuh +++ b/cpp/include/raft/linalg/detail/divide.cuh @@ -27,7 +27,7 @@ namespace detail { template void divideScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::scalar_div_op(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::div_const_op(scalar), stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/eltwise.cuh b/cpp/include/raft/linalg/detail/eltwise.cuh index 60a31a4357..25b4ca0499 100644 --- a/cpp/include/raft/linalg/detail/eltwise.cuh +++ b/cpp/include/raft/linalg/detail/eltwise.cuh @@ -27,13 +27,13 @@ namespace detail { template void scalarAdd(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::scalar_add_op(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::add_const_op(scalar), stream); } template void scalarMultiply(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::scalar_mul_op(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::mul_const_op(scalar), stream); } template diff --git 
a/cpp/include/raft/linalg/detail/multiply.cuh b/cpp/include/raft/linalg/detail/multiply.cuh index bb757d4531..84b832d875 100644 --- a/cpp/include/raft/linalg/detail/multiply.cuh +++ b/cpp/include/raft/linalg/detail/multiply.cuh @@ -27,7 +27,7 @@ template void multiplyScalar( math_t* out, const math_t* in, const math_t scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::scalar_mul_op{scalar}, stream); + raft::linalg::unaryOp(out, in, len, raft::mul_const_op{scalar}, stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/subtract.cuh b/cpp/include/raft/linalg/detail/subtract.cuh index 378d29e4bc..6df09df8ed 100644 --- a/cpp/include/raft/linalg/detail/subtract.cuh +++ b/cpp/include/raft/linalg/detail/subtract.cuh @@ -28,7 +28,7 @@ namespace detail { template void subtractScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::scalar_sub_op(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::sub_const_op(scalar), stream); } template diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index 2c9897b486..59c2cdf314 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -41,7 +41,7 @@ namespace linalg { template void powerScalar(out_t* out, const in_t* in, const in_t scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, raft::scalar_pow_op(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::pow_const_op(scalar), stream); } /** @} */ diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index 5129258f39..01b665e207 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -198,7 +198,7 @@ class lp_unexpanded_distances_t : public distances_t { raft::linalg::unaryOp(out_dists, out_dists, config_->a_nrows * config_->b_nrows, - raft::scalar_pow_op(one_over_p), + raft::pow_const_op(one_over_p), config_->handle.get_stream()); } @@ -223,7 +223,7 @@ class hamming_unexpanded_distances_t : public distances_t { raft::linalg::unaryOp(out_dists, out_dists, config_->a_nrows * config_->b_nrows, - raft::scalar_mul_op(n_cols), + raft::mul_const_op(n_cols), config_->handle.get_stream()); } @@ -304,7 +304,7 @@ class kl_divergence_unexpanded_distances_t : public distances_t { raft::linalg::unaryOp(out_dists, out_dists, config_->a_nrows * config_->b_nrows, - raft::scalar_mul_op(0.5), + raft::mul_const_op(0.5), config_->handle.get_stream()); } diff --git a/cpp/include/raft/sparse/op/detail/slice.cuh b/cpp/include/raft/sparse/op/detail/slice.cuh index 78a3592c94..4d2f1a4195 100644 --- a/cpp/include/raft/sparse/op/detail/slice.cuh +++ b/cpp/include/raft/sparse/op/detail/slice.cuh @@ -74,7 +74,7 @@ void csr_row_slice_indptr(value_idx start_row, raft::linalg::unaryOp(indptr_out, indptr_out, (stop_row + 2) - start_row, - raft::scalar_sub_op(s_offset), + raft::sub_const_op(s_offset), stream); } diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 08a06ad4a4..55675f2a46 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -203,7 +203,7 @@ void approx_knn_search(const handle_t& handle, float p = 0.5; // standard l2 if (index->metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / 
index->metricArg; raft::linalg::unaryOp( - distances, distances, n * k, raft::scalar_pow_op(p), handle.get_stream()); + distances, distances, n * k, raft::pow_const_op(p), handle.get_stream()); } if constexpr (std::is_same_v) { index->metric_processor->postprocess(distances); } } diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index 83d2e357b9..4ecb81edcb 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -422,7 +422,7 @@ void postprocess_distances(float* out, // [n_queries, topk] linalg::unaryOp(out, in, len, - raft::compose_op(raft::scalar_mul_op{scaling_factor * scaling_factor}, + raft::compose_op(raft::mul_const_op{scaling_factor * scaling_factor}, raft::cast_op{}), stream); } break; @@ -433,14 +433,14 @@ void postprocess_distances(float* out, // [n_queries, topk] in, len, raft::compose_op{ - raft::scalar_mul_op{scaling_factor}, raft::sqrt_op{}, raft::cast_op{}}, + raft::mul_const_op{scaling_factor}, raft::sqrt_op{}, raft::cast_op{}}, stream); } break; case distance::DistanceType::InnerProduct: { linalg::unaryOp(out, in, len, - raft::compose_op(raft::scalar_mul_op{-scaling_factor * scaling_factor}, + raft::compose_op(raft::mul_const_op{-scaling_factor * scaling_factor}, raft::cast_op{}), stream); } break; diff --git a/cpp/include/raft/spatial/knn/detail/processing.cuh b/cpp/include/raft/spatial/knn/detail/processing.cuh index 42f74e0ab7..b4b1cb2c14 100644 --- a/cpp/include/raft/spatial/knn/detail/processing.cuh +++ b/cpp/include/raft/spatial/knn/detail/processing.cuh @@ -110,7 +110,7 @@ class CorrelationMetricProcessor : public CosineMetricProcessor { raft::linalg::unaryOp(means_.data(), means_.data(), cosine::n_rows_, - raft::scalar_mul_op(normalizer_const), + raft::mul_const_op(normalizer_const), cosine::stream_); raft::stats::meanCenter(data, diff --git a/cpp/include/raft/stats/detail/weighted_mean.cuh b/cpp/include/raft/stats/detail/weighted_mean.cuh index ba7e6bf4c4..ada0995f7d 100644 --- a/cpp/include/raft/stats/detail/weighted_mean.cuh +++ b/cpp/include/raft/stats/detail/weighted_mean.cuh @@ -68,7 +68,7 @@ void weightedMean(Type* mu, false, [weights] __device__(Type v, IdxType i) { return v * weights[i]; }, raft::add_op{}, - raft::scalar_div_op(WS)); + raft::div_const_op(WS)); } }; // end namespace detail }; // end namespace stats diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index 1b5de9183d..9a8f650449 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -163,7 +163,7 @@ class SparseDistanceCOOSPMVTest raft::linalg::unaryOp(out_dists.data(), out_dists.data(), dist_config.a_nrows * dist_config.b_nrows, - raft::scalar_pow_op{p}, + raft::pow_const_op{p}, dist_config.handle.get_stream()); } break; From 384606f7de774b8eb1d84198adf7ab59bd66841c Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Fri, 9 Dec 2022 14:58:04 +0100 Subject: [PATCH 20/22] Operators tests --- cpp/include/raft/core/operators.hpp | 4 +- cpp/test/CMakeLists.txt | 2 + cpp/test/cluster/kmeans.cu | 2 +- cpp/test/cluster/linkage.cu | 2 +- cpp/test/core/operators_device.cu | 340 ++++++++++++++++++ cpp/test/core/operators_host.cpp | 286 +++++++++++++++ cpp/test/distance/dist_adj.cu | 2 +- cpp/test/distance/dist_canberra.cu | 2 +- cpp/test/distance/dist_chebyshev.cu | 2 +- cpp/test/distance/dist_correlation.cu | 2 +- cpp/test/distance/dist_cos.cu | 2 +- cpp/test/distance/dist_euc_exp.cu | 2 +- 
cpp/test/distance/dist_euc_unexp.cu | 2 +- cpp/test/distance/dist_eucsqrt_exp.cu | 2 +- cpp/test/distance/dist_hamming.cu | 2 +- cpp/test/distance/dist_hellinger.cu | 2 +- cpp/test/distance/dist_jensen_shannon.cu | 2 +- cpp/test/distance/dist_kl_divergence.cu | 2 +- cpp/test/distance/dist_l1.cu | 2 +- cpp/test/distance/dist_minkowski.cu | 2 +- cpp/test/distance/dist_russell_rao.cu | 2 +- cpp/test/distance/distance_base.cuh | 2 +- cpp/test/distance/fused_l2_nn.cu | 2 +- cpp/test/distance/gram.cu | 2 +- cpp/test/label/label.cu | 2 +- cpp/test/label/merge_labels.cu | 2 +- cpp/test/linalg/add.cu | 2 +- cpp/test/linalg/axpy.cu | 2 +- cpp/test/linalg/binary_op.cu | 2 +- cpp/test/linalg/binary_op.cuh | 2 +- cpp/test/linalg/cholesky_r1.cu | 2 +- cpp/test/linalg/coalesced_reduction.cu | 2 +- cpp/test/linalg/divide.cu | 2 +- cpp/test/linalg/dot.cu | 2 +- cpp/test/linalg/eig.cu | 2 +- cpp/test/linalg/eig_sel.cu | 2 +- cpp/test/linalg/eltwise.cu | 2 +- cpp/test/linalg/gemm_layout.cu | 2 +- cpp/test/linalg/gemv.cu | 2 +- cpp/test/linalg/map.cu | 2 +- cpp/test/linalg/map_then_reduce.cu | 2 +- cpp/test/linalg/matrix_vector.cu | 2 +- cpp/test/linalg/matrix_vector_op.cu | 2 +- cpp/test/linalg/matrix_vector_op.cuh | 2 +- cpp/test/linalg/mean_squared_error.cu | 2 +- cpp/test/linalg/multiply.cu | 2 +- cpp/test/linalg/norm.cu | 2 +- cpp/test/linalg/normalize.cu | 2 +- cpp/test/linalg/power.cu | 2 +- cpp/test/linalg/reduce.cu | 2 +- cpp/test/linalg/reduce_cols_by_key.cu | 2 +- cpp/test/linalg/reduce_rows_by_key.cu | 2 +- cpp/test/linalg/rsvd.cu | 2 +- cpp/test/linalg/sqrt.cu | 2 +- cpp/test/linalg/strided_reduction.cu | 2 +- cpp/test/linalg/subtract.cu | 2 +- cpp/test/linalg/svd.cu | 2 +- cpp/test/linalg/ternary_op.cu | 2 +- cpp/test/linalg/transpose.cu | 2 +- cpp/test/linalg/unary_op.cu | 2 +- cpp/test/linalg/unary_op.cuh | 2 +- cpp/test/matrix/argmax.cu | 2 +- cpp/test/matrix/argmin.cu | 2 +- cpp/test/matrix/columnSort.cu | 2 +- cpp/test/matrix/diagonal.cu | 2 +- cpp/test/matrix/gather.cu | 2 +- cpp/test/matrix/linewise_op.cu | 2 +- cpp/test/matrix/math.cu | 2 +- cpp/test/matrix/matrix.cu | 2 +- cpp/test/matrix/norm.cu | 2 +- cpp/test/matrix/reverse.cu | 2 +- cpp/test/matrix/slice.cu | 2 +- cpp/test/matrix/triangular.cu | 2 +- cpp/test/mst.cu | 2 +- cpp/test/neighbors/ann_ivf_flat.cu | 2 +- cpp/test/neighbors/ann_ivf_pq.cuh | 2 +- cpp/test/neighbors/ann_utils.cuh | 2 +- cpp/test/neighbors/ball_cover.cu | 2 +- cpp/test/neighbors/epsilon_neighborhood.cu | 2 +- cpp/test/neighbors/faiss_mr.cu | 2 +- cpp/test/neighbors/fused_l2_knn.cu | 2 +- cpp/test/neighbors/haversine.cu | 2 +- cpp/test/neighbors/knn.cu | 2 +- cpp/test/neighbors/refine.cu | 2 +- cpp/test/neighbors/selection.cu | 2 +- cpp/test/random/make_blobs.cu | 2 +- cpp/test/random/make_regression.cu | 2 +- cpp/test/random/multi_variable_gaussian.cu | 2 +- cpp/test/random/permute.cu | 2 +- cpp/test/random/rmat_rectangular_generator.cu | 2 +- cpp/test/random/rng.cu | 2 +- cpp/test/random/rng_discrete.cu | 2 +- cpp/test/random/rng_int.cu | 2 +- cpp/test/random/sample_without_replacement.cu | 2 +- cpp/test/sparse/add.cu | 2 +- cpp/test/sparse/convert_coo.cu | 2 +- cpp/test/sparse/convert_csr.cu | 2 +- cpp/test/sparse/csr_row_slice.cu | 2 +- cpp/test/sparse/csr_to_dense.cu | 2 +- cpp/test/sparse/csr_transpose.cu | 2 +- cpp/test/sparse/degree.cu | 2 +- cpp/test/sparse/dist_coo_spmv.cu | 2 +- cpp/test/sparse/distance.cu | 2 +- cpp/test/sparse/filter.cu | 2 +- cpp/test/sparse/neighbors/brute_force.cu | 2 +- .../sparse/neighbors/connect_components.cu | 2 +- 
cpp/test/sparse/neighbors/knn_graph.cu | 2 +- cpp/test/sparse/norm.cu | 2 +- cpp/test/sparse/reduce.cu | 2 +- cpp/test/sparse/row_op.cu | 2 +- cpp/test/sparse/sort.cu | 2 +- cpp/test/sparse/spgemmi.cu | 2 +- cpp/test/sparse/symmetrize.cu | 2 +- cpp/test/stats/accuracy.cu | 2 +- cpp/test/stats/adjusted_rand_index.cu | 2 +- cpp/test/stats/completeness_score.cu | 2 +- cpp/test/stats/contingencyMatrix.cu | 2 +- cpp/test/stats/cov.cu | 2 +- cpp/test/stats/dispersion.cu | 2 +- cpp/test/stats/entropy.cu | 2 +- cpp/test/stats/histogram.cu | 2 +- cpp/test/stats/homogeneity_score.cu | 2 +- cpp/test/stats/information_criterion.cu | 2 +- cpp/test/stats/kl_divergence.cu | 2 +- cpp/test/stats/mean.cu | 2 +- cpp/test/stats/mean_center.cu | 2 +- cpp/test/stats/meanvar.cu | 2 +- cpp/test/stats/minmax.cu | 2 +- cpp/test/stats/mutual_info_score.cu | 2 +- cpp/test/stats/r2_score.cu | 2 +- cpp/test/stats/rand_index.cu | 2 +- cpp/test/stats/regression_metrics.cu | 2 +- cpp/test/stats/silhouette_score.cu | 2 +- cpp/test/stats/stddev.cu | 2 +- cpp/test/stats/sum.cu | 2 +- cpp/test/stats/trustworthiness.cu | 2 +- cpp/test/stats/v_measure.cu | 2 +- cpp/test/stats/weighted_mean.cu | 2 +- cpp/test/test_utils.cuh | 329 +++++++++++++++++ cpp/test/test_utils.h | 336 ++--------------- 140 files changed, 1114 insertions(+), 451 deletions(-) create mode 100644 cpp/test/core/operators_device.cu create mode 100644 cpp/test/core/operators_host.cpp create mode 100644 cpp/test/test_utils.cuh diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index 3bba5a9c0d..8a694a140d 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -225,7 +225,8 @@ struct plug_const_op { template >> - constexpr explicit plug_const_op(const ConstT& s) : c{s}, composed_op{} + constexpr explicit plug_const_op(const ConstT& s) + : c{s}, composed_op{} // The compiler complains if composed_op is not initialized explicitly { } constexpr plug_const_op(const ConstT& s, BinaryOpT o) : c{s}, composed_op{o} {} @@ -324,6 +325,7 @@ struct map_args_op { typename CondT = std::enable_if_t && std::is_default_constructible_v>> constexpr map_args_op() + : outer_op{} // The compiler complains if outer_op is not initialized explicitly { } constexpr explicit map_args_op(OuterOpT outer_op, ArgOpsT... arg_ops) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index a75eb3bff6..5be8401a6f 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -86,6 +86,8 @@ if(BUILD_TESTS) CORE_TEST PATH test/common/logger.cpp + test/core/operators_device.cu + test/core/operators_host.cpp test/handle.cpp test/interruptible.cu test/nvtx.cpp diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index 2c78b1ea2a..9644541a0c 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/cluster/linkage.cu b/cpp/test/cluster/linkage.cu index 5533f552bd..53aa5c55e3 100644 --- a/cpp/test/cluster/linkage.cu +++ b/cpp/test/cluster/linkage.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/core/operators_device.cu b/cpp/test/core/operators_device.cu new file mode 100644 index 0000000000..1697a09fcf --- /dev/null +++ b/cpp/test/core/operators_device.cu @@ -0,0 +1,340 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +#include "../test_utils.cuh" +#include +#include +#include +#include + +template +__global__ void eval_op_on_device_kernel(OutT* out, OpT op, Args... args) +{ + out[0] = op(std::forward(args)...); +} + +template +auto eval_op_on_device(OpT op, Args&&... args) +{ + typedef decltype(op(args...)) OutT; + auto stream = rmm::cuda_stream_default; + rmm::device_scalar result(stream); + eval_op_on_device_kernel<<<1, 1, 0, stream>>>(result.data(), op, std::forward(args)...); + return result.value(stream); +} + +TEST(OperatorsDevice, IdentityOp) +{ + raft::identity_op op; + ASSERT_TRUE(raft::match(12.34f, eval_op_on_device(op, 12.34f, 0), raft::Compare())); +} + +TEST(OperatorsDevice, CastOp) +{ + raft::cast_op op; + ASSERT_TRUE( + raft::match(1234.0f, eval_op_on_device(op, 1234, 0), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, KeyOp) +{ + raft::key_op op; + raft::KeyValuePair kvp(12, 3.4f); + ASSERT_TRUE(raft::match(12, eval_op_on_device(op, kvp, 0), raft::Compare())); +} + +TEST(OperatorsDevice, ValueOp) +{ + raft::value_op op; + raft::KeyValuePair kvp(12, 3.4f); + ASSERT_TRUE( + raft::match(3.4f, eval_op_on_device(op, kvp, 0), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, SqrtOpF) +{ + raft::sqrt_op op; + ASSERT_TRUE(raft::match( + std::sqrt(12.34f), eval_op_on_device(op, 12.34f, 0), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::sqrt(12.34), eval_op_on_device(op, 12.34, 0), raft::CompareApprox(0.000001))); +} + +TEST(OperatorsDevice, NZOp) +{ + raft::nz_op op; + ASSERT_TRUE( + raft::match(0.0f, eval_op_on_device(op, 0.0f, 0), raft::CompareApprox(0.00001f))); + ASSERT_TRUE( + raft::match(1.0f, eval_op_on_device(op, 12.34f, 0), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, AbsOp) +{ + raft::abs_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, -12.34f, 0), raft::CompareApprox(0.00001f))); + ASSERT_TRUE( + raft::match(12.34, eval_op_on_device(op, -12.34, 0), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(1234, eval_op_on_device(op, -1234, 0), raft::Compare())); +} + +TEST(OperatorsDevice, SqOp) +{ + raft::sq_op op; + ASSERT_TRUE( + raft::match(152.2756f, eval_op_on_device(op, 12.34f, 0), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(289, eval_op_on_device(op, -17, 0), raft::Compare())); +} + +TEST(OperatorsDevice, AddOp) +{ + raft::add_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 12.0f, 0.34f), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(1234, eval_op_on_device(op, 1200, 34), raft::Compare())); +} + +TEST(OperatorsDevice, SubOp) +{ + raft::sub_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 13.0f, 0.66f), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(1234, eval_op_on_device(op, 1300, 66), raft::Compare())); +} + +TEST(OperatorsDevice, MulOp) +{ + raft::mul_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 2.0f, 
6.17f), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, DivOp) +{ + raft::div_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 37.02f, 3.0f), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, DivCheckZeroOp) +{ + raft::div_checkzero_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 37.02f, 3.0f), raft::CompareApprox(0.00001f))); + ASSERT_TRUE( + raft::match(0.0f, eval_op_on_device(op, 37.02f, 0.0f), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, PowOp) +{ + raft::pow_op op; + ASSERT_TRUE( + raft::match(1000.0f, eval_op_on_device(op, 10.0f, 3.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(1000.0, eval_op_on_device(op, 10.0, 3.0), raft::CompareApprox(0.000001))); +} + +TEST(OperatorsDevice, MinOp) +{ + raft::min_op op; + ASSERT_TRUE( + raft::match(3.0f, eval_op_on_device(op, 3.0f, 5.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(3.0, eval_op_on_device(op, 5.0, 3.0), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(3, eval_op_on_device(op, 3, 5), raft::Compare())); +} + +TEST(OperatorsDevice, MaxOp) +{ + raft::max_op op; + ASSERT_TRUE( + raft::match(5.0f, eval_op_on_device(op, 3.0f, 5.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(5.0, eval_op_on_device(op, 5.0, 3.0), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(5, eval_op_on_device(op, 3, 5), raft::Compare())); +} + +TEST(OperatorsDevice, SqDiffOp) +{ + raft::sqdiff_op op; + ASSERT_TRUE( + raft::match(4.0f, eval_op_on_device(op, 3.0f, 5.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(4.0, eval_op_on_device(op, 5.0, 3.0), raft::CompareApprox(0.000001))); +} + +TEST(OperatorsDevice, ArgminOp) +{ + raft::argmin_op op; + raft::KeyValuePair kvp_a(0, 1.2f); + raft::KeyValuePair kvp_b(0, 3.4f); + raft::KeyValuePair kvp_c(1, 1.2f); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_a, kvp_b), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_b, kvp_a), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_a, kvp_c), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_c, kvp_a), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_c, eval_op_on_device(op, kvp_b, kvp_c), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_c, eval_op_on_device(op, kvp_c, kvp_b), raft::Compare>())); +} + +TEST(OperatorsDevice, ArgmaxOp) +{ + raft::argmax_op op; + raft::KeyValuePair kvp_a(0, 1.2f); + raft::KeyValuePair kvp_b(0, 3.4f); + raft::KeyValuePair kvp_c(1, 1.2f); + ASSERT_TRUE(raft::match( + kvp_b, eval_op_on_device(op, kvp_a, kvp_b), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_b, eval_op_on_device(op, kvp_b, kvp_a), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_a, kvp_c), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_c, kvp_a), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_b, eval_op_on_device(op, kvp_b, kvp_c), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_b, eval_op_on_device(op, kvp_c, kvp_b), raft::Compare>())); +} + +TEST(OperatorsDevice, ConstOp) +{ + raft::const_op op(12.34f); + ASSERT_TRUE(raft::match(12.34f, eval_op_on_device(op), raft::Compare())); + ASSERT_TRUE(raft::match(12.34f, eval_op_on_device(op, 42), raft::Compare())); + ASSERT_TRUE(raft::match(12.34f, eval_op_on_device(op, 13, 37.0f), raft::Compare())); +} + +template +struct trinary_add { + const T c; + constexpr explicit 
+template <typename T>
+struct trinary_add {
+  const T c;
+  constexpr explicit trinary_add(const T& c_) : c{c_} {}
+  constexpr RAFT_INLINE_FUNCTION auto operator()(T a, T b) const { return a + b + c; }
+};
+
+TEST(OperatorsDevice, PlugConstOp)
+{
+  // First, wrap around a default-constructible op
+  {
+    raft::plug_const_op<float, raft::add_op> op(0.34f);
+    ASSERT_TRUE(
+      raft::match(12.34f, eval_op_on_device(op, 12.0f), raft::CompareApprox<float>(0.0001f)));
+  }
+
+  // Second, wrap around a non-default-constructible op
+  {
+    auto op = raft::plug_const_op(10.0f, trinary_add(2.0f));
+    ASSERT_TRUE(
+      raft::match(12.34f, eval_op_on_device(op, 0.34f), raft::CompareApprox<float>(0.0001f)));
+  }
+}
+
+TEST(OperatorsDevice, AddConstOp)
+{
+  raft::add_const_op<float> op(0.34f);
+  ASSERT_TRUE(
+    raft::match(12.34f, eval_op_on_device(op, 12.0f), raft::CompareApprox<float>(0.0001f)));
+}
+
+TEST(OperatorsDevice, SubConstOp)
+{
+  raft::sub_const_op<float> op(0.66f);
+  ASSERT_TRUE(
+    raft::match(12.34f, eval_op_on_device(op, 13.0f), raft::CompareApprox<float>(0.0001f)));
+}
+
+TEST(OperatorsDevice, MulConstOp)
+{
+  raft::mul_const_op<float> op(2.0f);
+  ASSERT_TRUE(
+    raft::match(12.34f, eval_op_on_device(op, 6.17f), raft::CompareApprox<float>(0.0001f)));
+}
+
+TEST(OperatorsDevice, DivConstOp)
+{
+  raft::div_const_op<float> op(3.0f);
+  ASSERT_TRUE(
+    raft::match(12.34f, eval_op_on_device(op, 37.02f), raft::CompareApprox<float>(0.0001f)));
+}
+
+TEST(OperatorsDevice, DivCheckZeroConstOp)
+{
+  // Non-zero denominator
+  {
+    raft::div_checkzero_const_op<float> op(3.0f);
+    ASSERT_TRUE(
+      raft::match(12.34f, eval_op_on_device(op, 37.02f), raft::CompareApprox<float>(0.0001f)));
+  }
+  // Zero denominator
+  {
+    raft::div_checkzero_const_op<float> op(0.0f);
+    ASSERT_TRUE(
+      raft::match(0.0f, eval_op_on_device(op, 37.02f), raft::CompareApprox<float>(0.0001f)));
+  }
+}
+
+TEST(OperatorsDevice, PowConstOp)
+{
+  raft::pow_const_op<float> op(3.0f);
+  ASSERT_TRUE(
+    raft::match(1000.0f, eval_op_on_device(op, 10.0f), raft::CompareApprox<float>(0.0001f)));
+}
+
+TEST(OperatorsDevice, ComposeOp)
+{
+  // All ops are default-constructible
+  {
+    raft::compose_op<raft::sqrt_op, raft::abs_op, raft::cast_op<float>> op;
+    ASSERT_TRUE(raft::match(
+      std::sqrt(42.0f), eval_op_on_device(op, -42, 0), raft::CompareApprox<float>(0.0001f)));
+  }
+  // Some ops are not default-constructible
+  {
+    auto op = raft::compose_op(
+      raft::sqrt_op(), raft::abs_op(), raft::add_const_op<float>(8.0f), raft::cast_op<float>());
+    ASSERT_TRUE(raft::match(
+      std::sqrt(42.0f), eval_op_on_device(op, -50, 0), raft::CompareApprox<float>(0.0001f)));
+  }
+}
+
+TEST(OperatorsDevice, MapArgsOp)
+{
+  // All ops are default-constructible
+  {
+    raft::map_args_op<raft::add_op, raft::sq_op, raft::abs_op> op;
+    ASSERT_TRUE(
+      raft::match(42.0f, eval_op_on_device(op, 5.0f, -17.0f), raft::CompareApprox<float>(0.0001f)));
+  }
+  // Some ops are not default-constructible
+  {
+    auto op = raft::map_args_op(
+      raft::add_op(), raft::pow_const_op<float>(2.0f), raft::mul_const_op<float>(-1.0f));
+    ASSERT_TRUE(
+      raft::match(42.0f, eval_op_on_device(op, 5.0f, -17.0f), raft::CompareApprox<float>(0.0001f)));
+  }
+}
diff --git a/cpp/test/core/operators_host.cpp b/cpp/test/core/operators_host.cpp
new file mode 100644
index 0000000000..de66fda919
--- /dev/null
+++ b/cpp/test/core/operators_host.cpp
@@ -0,0 +1,286 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include 
+#include 
+
+#include 
+
+#include "../test_utils.h"
+#include 
+#include 
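+
+// These tests mirror operators_device.cu one-to-one, invoking each functor's
+// operator() directly on the host instead of through a kernel launch.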
+
+TEST(OperatorsHost, IdentityOp)
+{
+  raft::identity_op op;
+  ASSERT_TRUE(raft::match(12.34f, op(12.34f, 0), raft::Compare<float>()));
+}
+
+TEST(OperatorsHost, CastOp)
+{
+  raft::cast_op<float> op;
+  ASSERT_TRUE(raft::match(1234.0f, op(1234, 0), raft::CompareApprox<float>(0.00001f)));
+}
+
+TEST(OperatorsHost, KeyOp)
+{
+  raft::key_op op;
+  raft::KeyValuePair<int, float> kvp(12, 3.4f);
+  ASSERT_TRUE(raft::match(12, op(kvp, 0), raft::Compare<int>()));
+}
+
+TEST(OperatorsHost, ValueOp)
+{
+  raft::value_op op;
+  raft::KeyValuePair<int, float> kvp(12, 3.4f);
+  ASSERT_TRUE(raft::match(3.4f, op(kvp, 0), raft::CompareApprox<float>(0.00001f)));
+}
+
+TEST(OperatorsHost, SqrtOpF)
+{
+  raft::sqrt_op op;
+  ASSERT_TRUE(raft::match(std::sqrt(12.34f), op(12.34f, 0), raft::CompareApprox<float>(0.0001f)));
+  ASSERT_TRUE(raft::match(std::sqrt(12.34), op(12.34, 0), raft::CompareApprox<double>(0.000001)));
+}
+
+TEST(OperatorsHost, NZOp)
+{
+  raft::nz_op op;
+  ASSERT_TRUE(raft::match(0.0f, op(0.0f, 0), raft::CompareApprox<float>(0.00001f)));
+  ASSERT_TRUE(raft::match(1.0f, op(12.34f, 0), raft::CompareApprox<float>(0.00001f)));
+}
+
+TEST(OperatorsHost, AbsOp)
+{
+  raft::abs_op op;
+  ASSERT_TRUE(raft::match(12.34f, op(-12.34f, 0), raft::CompareApprox<float>(0.00001f)));
+  ASSERT_TRUE(raft::match(12.34, op(-12.34, 0), raft::CompareApprox<double>(0.000001)));
+  ASSERT_TRUE(raft::match(1234, op(-1234, 0), raft::Compare<int>()));
+}
+
+TEST(OperatorsHost, SqOp)
+{
+  raft::sq_op op;
+  ASSERT_TRUE(raft::match(152.2756f, op(12.34f, 0), raft::CompareApprox<float>(0.00001f)));
+  ASSERT_TRUE(raft::match(289, op(-17, 0), raft::Compare<int>()));
+}
+
+TEST(OperatorsHost, AddOp)
+{
+  raft::add_op op;
+  ASSERT_TRUE(raft::match(12.34f, op(12.0f, 0.34f), raft::CompareApprox<float>(0.00001f)));
+  ASSERT_TRUE(raft::match(1234, op(1200, 34), raft::Compare<int>()));
+}
+
+TEST(OperatorsHost, SubOp)
+{
+  raft::sub_op op;
+  ASSERT_TRUE(raft::match(12.34f, op(13.0f, 0.66f), raft::CompareApprox<float>(0.00001f)));
+  ASSERT_TRUE(raft::match(1234, op(1300, 66), raft::Compare<int>()));
+}
+
+TEST(OperatorsHost, MulOp)
+{
+  raft::mul_op op;
+  ASSERT_TRUE(raft::match(12.34f, op(2.0f, 6.17f), raft::CompareApprox<float>(0.00001f)));
+}
+
+TEST(OperatorsHost, DivOp)
+{
+  raft::div_op op;
+  ASSERT_TRUE(raft::match(12.34f, op(37.02f, 3.0f), raft::CompareApprox<float>(0.00001f)));
+}
+
+TEST(OperatorsHost, DivCheckZeroOp)
+{
+  raft::div_checkzero_op op;
+  ASSERT_TRUE(raft::match(12.34f, op(37.02f, 3.0f), raft::CompareApprox<float>(0.00001f)));
+  ASSERT_TRUE(raft::match(0.0f, op(37.02f, 0.0f), raft::CompareApprox<float>(0.00001f)));
+}
+
+TEST(OperatorsHost, PowOp)
+{
+  raft::pow_op op;
+  ASSERT_TRUE(raft::match(1000.0f, op(10.0f, 3.0f), raft::CompareApprox<float>(0.0001f)));
+  ASSERT_TRUE(raft::match(1000.0, op(10.0, 3.0), raft::CompareApprox<double>(0.000001)));
+}
+
+TEST(OperatorsHost, MinOp)
+{
+  raft::min_op op;
+  ASSERT_TRUE(raft::match(3.0f, op(3.0f, 5.0f), raft::CompareApprox<float>(0.0001f)));
+  ASSERT_TRUE(raft::match(3.0, op(5.0, 3.0), raft::CompareApprox<double>(0.000001)));
+  ASSERT_TRUE(raft::match(3, op(3, 5), raft::Compare<int>()));
+}
+
+TEST(OperatorsHost, MaxOp)
+{
+  raft::max_op op;
+  ASSERT_TRUE(raft::match(5.0f, op(3.0f, 5.0f), raft::CompareApprox<float>(0.0001f)));
+  ASSERT_TRUE(raft::match(5.0, op(5.0, 3.0), raft::CompareApprox<double>(0.000001)));
+  ASSERT_TRUE(raft::match(5, op(3, 5), raft::Compare<int>()));
+}
+
+TEST(OperatorsHost, SqDiffOp)
+{
+  raft::sqdiff_op op;
+  ASSERT_TRUE(raft::match(4.0f, op(3.0f, 5.0f), raft::CompareApprox<float>(0.0001f)));
+  ASSERT_TRUE(raft::match(4.0, op(5.0, 3.0), raft::CompareApprox<double>(0.000001)));
+}
+
+TEST(OperatorsHost, ArgminOp)
+{
+  raft::argmin_op op;
+  raft::KeyValuePair<int, float> kvp_a(0, 1.2f);
+  raft::KeyValuePair<int, float> kvp_b(0, 3.4f);
+  raft::KeyValuePair<int, float> kvp_c(1, 1.2f);
+  ASSERT_TRUE(
+    raft::match(kvp_a, op(kvp_a, kvp_b), raft::Compare<raft::KeyValuePair<int, float>>()));
+  ASSERT_TRUE(
+    raft::match(kvp_a, op(kvp_b, kvp_a), raft::Compare<raft::KeyValuePair<int, float>>()));
+  ASSERT_TRUE(
+    raft::match(kvp_a, op(kvp_a, kvp_c), raft::Compare<raft::KeyValuePair<int, float>>()));
+  ASSERT_TRUE(
+    raft::match(kvp_a, op(kvp_c, kvp_a), raft::Compare<raft::KeyValuePair<int, float>>()));
+  ASSERT_TRUE(
+    raft::match(kvp_c, op(kvp_b, kvp_c), raft::Compare<raft::KeyValuePair<int, float>>()));
+  ASSERT_TRUE(
+    raft::match(kvp_c, op(kvp_c, kvp_b), raft::Compare<raft::KeyValuePair<int, float>>()));
+}
+
+TEST(OperatorsHost, ArgmaxOp)
+{
+  raft::argmax_op op;
+  raft::KeyValuePair<int, float> kvp_a(0, 1.2f);
+  raft::KeyValuePair<int, float> kvp_b(0, 3.4f);
+  raft::KeyValuePair<int, float> kvp_c(1, 1.2f);
+  ASSERT_TRUE(
+    raft::match(kvp_b, op(kvp_a, kvp_b), raft::Compare<raft::KeyValuePair<int, float>>()));
+  ASSERT_TRUE(
+    raft::match(kvp_b, op(kvp_b, kvp_a), raft::Compare<raft::KeyValuePair<int, float>>()));
+  ASSERT_TRUE(
+    raft::match(kvp_a, op(kvp_a, kvp_c), raft::Compare<raft::KeyValuePair<int, float>>()));
+  ASSERT_TRUE(
+    raft::match(kvp_a, op(kvp_c, kvp_a), raft::Compare<raft::KeyValuePair<int, float>>()));
+  ASSERT_TRUE(
+    raft::match(kvp_b, op(kvp_b, kvp_c), raft::Compare<raft::KeyValuePair<int, float>>()));
+  ASSERT_TRUE(
+    raft::match(kvp_b, op(kvp_c, kvp_b), raft::Compare<raft::KeyValuePair<int, float>>()));
+}
+
+TEST(OperatorsHost, ConstOp)
+{
+  raft::const_op op(12.34f);
+  ASSERT_TRUE(raft::match(12.34f, op(), raft::Compare<float>()));
+  ASSERT_TRUE(raft::match(12.34f, op(42), raft::Compare<float>()));
+  ASSERT_TRUE(raft::match(12.34f, op(13, 37.0f), raft::Compare<float>()));
+}
+
+// trinary_add is intentionally not default-constructible (see operators_device.cu)
+template <typename T>
+struct trinary_add {
+  const T c;
+  constexpr explicit trinary_add(const T& c_) : c{c_} {}
+  constexpr RAFT_INLINE_FUNCTION auto operator()(T a, T b) const { return a + b + c; }
+};
+
+TEST(OperatorsHost, PlugConstOp)
+{
+  // First, wrap around a default-constructible op
+  {
+    raft::plug_const_op<float, raft::add_op> op(0.34f);
+    ASSERT_TRUE(raft::match(12.34f, op(12.0f), raft::CompareApprox<float>(0.0001f)));
+  }
+
+  // Second, wrap around a non-default-constructible op
+  {
+    auto op = raft::plug_const_op(10.0f, trinary_add(2.0f));
+    ASSERT_TRUE(raft::match(12.34f, op(0.34f), raft::CompareApprox<float>(0.0001f)));
+  }
+}
+
+TEST(OperatorsHost, AddConstOp)
+{
+  raft::add_const_op<float> op(0.34f);
+  ASSERT_TRUE(raft::match(12.34f, op(12.0f), raft::CompareApprox<float>(0.0001f)));
+}
+
+TEST(OperatorsHost, SubConstOp)
+{
+  raft::sub_const_op<float> op(0.66f);
+  ASSERT_TRUE(raft::match(12.34f, op(13.0f), raft::CompareApprox<float>(0.0001f)));
+}
+
+TEST(OperatorsHost, MulConstOp)
+{
+  raft::mul_const_op<float> op(2.0f);
+  ASSERT_TRUE(raft::match(12.34f, op(6.17f), raft::CompareApprox<float>(0.0001f)));
+}
+
+TEST(OperatorsHost, DivConstOp)
+{
+  raft::div_const_op<float> op(3.0f);
+  ASSERT_TRUE(raft::match(12.34f, op(37.02f), raft::CompareApprox<float>(0.0001f)));
+}
+
+TEST(OperatorsHost, DivCheckZeroConstOp)
+{
+  // Non-zero denominator
+  {
+    raft::div_checkzero_const_op<float> op(3.0f);
+    ASSERT_TRUE(raft::match(12.34f, op(37.02f), raft::CompareApprox<float>(0.0001f)));
+  }
+  // Zero denominator
+  {
+    raft::div_checkzero_const_op<float> op(0.0f);
+    ASSERT_TRUE(raft::match(0.0f, op(37.02f), raft::CompareApprox<float>(0.0001f)));
+  }
+}
+
+TEST(OperatorsHost, PowConstOp)
+{
+  raft::pow_const_op<float> op(3.0f);
+  ASSERT_TRUE(raft::match(1000.0f, op(10.0f), raft::CompareApprox<float>(0.0001f)));
+}
+
+TEST(OperatorsHost, ComposeOp)
+{
+  // All ops are default-constructible
+  {
+    raft::compose_op<raft::sqrt_op, raft::abs_op, raft::cast_op<float>> op;
+    ASSERT_TRUE(raft::match(std::sqrt(42.0f), op(-42, 0), raft::CompareApprox<float>(0.0001f)));
+  }
+  // Some ops are not default-constructible
+  {
+    auto op = raft::compose_op(
+      raft::sqrt_op(), raft::abs_op(), raft::add_const_op<float>(8.0f), raft::cast_op<float>());
+    ASSERT_TRUE(raft::match(std::sqrt(42.0f), op(-50, 0), raft::CompareApprox<float>(0.0001f)));
+  }
+}
+
+TEST(OperatorsHost, MapArgsOp)
+{
+  // All ops are default-constructible
+  {
+    raft::map_args_op<raft::add_op, raft::sq_op, raft::abs_op> op;
+    ASSERT_TRUE(raft::match(42.0f, op(5.0f, -17.0f), raft::CompareApprox<float>(0.0001f)));
+  }
+  // Some ops are not default-constructible
+  {
+    auto op = raft::map_args_op(
+      raft::add_op(), raft::pow_const_op<float>(2.0f), raft::mul_const_op<float>(-1.0f));
+    ASSERT_TRUE(raft::match(42.0f, op(5.0f, -17.0f), raft::CompareApprox<float>(0.0001f)));
+  }
+}
diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu
index f3f36b4576..4f6dfaac24 100644
--- a/cpp/test/distance/dist_adj.cu
+++ b/cpp/test/distance/dist_adj.cu
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include 
 #include 
 #include 
diff --git a/cpp/test/distance/dist_canberra.cu b/cpp/test/distance/dist_canberra.cu
index 1f368fbee8..db5555d9c8 100644
--- a/cpp/test/distance/dist_canberra.cu
+++ b/cpp/test/distance/dist_canberra.cu
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include "distance_base.cuh"
 
 namespace raft {
diff --git a/cpp/test/distance/dist_chebyshev.cu b/cpp/test/distance/dist_chebyshev.cu
index 8f506601ca..abad828de7 100644
--- a/cpp/test/distance/dist_chebyshev.cu
+++ b/cpp/test/distance/dist_chebyshev.cu
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include "distance_base.cuh"
 
 namespace raft {
diff --git a/cpp/test/distance/dist_correlation.cu b/cpp/test/distance/dist_correlation.cu
index 77d770b4d1..0e3f0ee0b5 100644
--- a/cpp/test/distance/dist_correlation.cu
+++ b/cpp/test/distance/dist_correlation.cu
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include "distance_base.cuh"
 
 namespace raft {
diff --git a/cpp/test/distance/dist_cos.cu b/cpp/test/distance/dist_cos.cu
index 900a71e514..9faf7651f7 100644
--- a/cpp/test/distance/dist_cos.cu
+++ b/cpp/test/distance/dist_cos.cu
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include "distance_base.cuh"
 
 namespace raft {
diff --git a/cpp/test/distance/dist_euc_exp.cu b/cpp/test/distance/dist_euc_exp.cu
index 5371b8a3e2..567e279691 100644
--- a/cpp/test/distance/dist_euc_exp.cu
+++ b/cpp/test/distance/dist_euc_exp.cu
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include "distance_base.cuh"
 
 namespace raft {
diff --git a/cpp/test/distance/dist_euc_unexp.cu b/cpp/test/distance/dist_euc_unexp.cu
index 81e6be7116..311ad190e2 100644
--- a/cpp/test/distance/dist_euc_unexp.cu
+++ b/cpp/test/distance/dist_euc_unexp.cu
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include "distance_base.cuh"
 
 namespace raft {
diff --git a/cpp/test/distance/dist_eucsqrt_exp.cu b/cpp/test/distance/dist_eucsqrt_exp.cu
index c4f2dc80c2..d717158649 100644
--- a/cpp/test/distance/dist_eucsqrt_exp.cu
+++ b/cpp/test/distance/dist_eucsqrt_exp.cu
@@ -14,7 +14,7 @@
  * limitations under the License.
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_hamming.cu b/cpp/test/distance/dist_hamming.cu index 616ce8f729..1eef9fba4e 100644 --- a/cpp/test/distance/dist_hamming.cu +++ b/cpp/test/distance/dist_hamming.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_hellinger.cu b/cpp/test/distance/dist_hellinger.cu index d6f994aaf6..85a157aa31 100644 --- a/cpp/test/distance/dist_hellinger.cu +++ b/cpp/test/distance/dist_hellinger.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_jensen_shannon.cu b/cpp/test/distance/dist_jensen_shannon.cu index 43e4f3aa0f..a1e2f9f38c 100644 --- a/cpp/test/distance/dist_jensen_shannon.cu +++ b/cpp/test/distance/dist_jensen_shannon.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_kl_divergence.cu b/cpp/test/distance/dist_kl_divergence.cu index 6a5fe8d7ac..94330d9450 100644 --- a/cpp/test/distance/dist_kl_divergence.cu +++ b/cpp/test/distance/dist_kl_divergence.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_l1.cu b/cpp/test/distance/dist_l1.cu index 322fb52d5c..dc6bcf72b7 100644 --- a/cpp/test/distance/dist_l1.cu +++ b/cpp/test/distance/dist_l1.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_minkowski.cu b/cpp/test/distance/dist_minkowski.cu index 3e0a2ead92..af2661da3a 100644 --- a/cpp/test/distance/dist_minkowski.cu +++ b/cpp/test/distance/dist_minkowski.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_russell_rao.cu b/cpp/test/distance/dist_russell_rao.cu index e92a01c70a..3c5124c31f 100644 --- a/cpp/test/distance/dist_russell_rao.cu +++ b/cpp/test/distance/dist_russell_rao.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 8b8c53d354..067b1b2c0e 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 800f45c7fc..c4126d25df 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu index 168e3d93f8..4366e023a0 100644 --- a/cpp/test/distance/gram.cu +++ b/cpp/test/distance/gram.cu @@ -18,7 +18,7 @@ #include #endif -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/label/label.cu b/cpp/test/label/label.cu index 02b3191c4d..bda87d423c 100644 --- a/cpp/test/label/label.cu +++ b/cpp/test/label/label.cu @@ -18,7 +18,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/label/merge_labels.cu b/cpp/test/label/merge_labels.cu index 184ab4922f..c3d2f82a84 100644 --- a/cpp/test/label/merge_labels.cu +++ b/cpp/test/label/merge_labels.cu @@ -17,7 +17,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/add.cu b/cpp/test/linalg/add.cu index c73791086b..0e5fc40232 100644 --- a/cpp/test/linalg/add.cu +++ b/cpp/test/linalg/add.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "add.cuh" #include #include diff --git a/cpp/test/linalg/axpy.cu b/cpp/test/linalg/axpy.cu index f6cabae012..2eb11f314d 100644 --- a/cpp/test/linalg/axpy.cu +++ b/cpp/test/linalg/axpy.cu @@ -15,7 +15,7 @@ */ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/binary_op.cu b/cpp/test/linalg/binary_op.cu index f4f9cd11d7..ac143842cb 100644 --- a/cpp/test/linalg/binary_op.cu +++ b/cpp/test/linalg/binary_op.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "binary_op.cuh" #include #include diff --git a/cpp/test/linalg/binary_op.cuh b/cpp/test/linalg/binary_op.cuh index 62820ddb97..8b0bc609d2 100644 --- a/cpp/test/linalg/binary_op.cuh +++ b/cpp/test/linalg/binary_op.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/linalg/cholesky_r1.cu b/cpp/test/linalg/cholesky_r1.cu index 134b7645ff..9d90b03a6e 100644 --- a/cpp/test/linalg/cholesky_r1.cu +++ b/cpp/test/linalg/cholesky_r1.cu @@ -22,7 +22,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include namespace raft { diff --git a/cpp/test/linalg/coalesced_reduction.cu b/cpp/test/linalg/coalesced_reduction.cu index 1466f557dd..dc82ab9511 100644 --- a/cpp/test/linalg/coalesced_reduction.cu +++ b/cpp/test/linalg/coalesced_reduction.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "reduce.cuh" #include #include diff --git a/cpp/test/linalg/divide.cu b/cpp/test/linalg/divide.cu index 4e2e5cdba7..4b5ea0a2dc 100644 --- a/cpp/test/linalg/divide.cu +++ b/cpp/test/linalg/divide.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "unary_op.cuh" #include #include diff --git a/cpp/test/linalg/dot.cu b/cpp/test/linalg/dot.cu index b5007aea32..80a9f24aba 100644 --- a/cpp/test/linalg/dot.cu +++ b/cpp/test/linalg/dot.cu @@ -15,7 +15,7 @@ */ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/eig.cu b/cpp/test/linalg/eig.cu index a913b14fcb..4b834c1aa8 100644 --- a/cpp/test/linalg/eig.cu +++ b/cpp/test/linalg/eig.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/eig_sel.cu b/cpp/test/linalg/eig_sel.cu index 9d57c4fa0a..c2b12e5d4a 100644 --- a/cpp/test/linalg/eig_sel.cu +++ b/cpp/test/linalg/eig_sel.cu @@ -16,7 +16,7 @@ #if CUDART_VERSION >= 10010 -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/eltwise.cu b/cpp/test/linalg/eltwise.cu index 07ded5ec79..d9ab7e0984 100644 --- a/cpp/test/linalg/eltwise.cu +++ b/cpp/test/linalg/eltwise.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/gemm_layout.cu b/cpp/test/linalg/gemm_layout.cu index dbe10ab4cc..a992a32304 100644 --- a/cpp/test/linalg/gemm_layout.cu +++ b/cpp/test/linalg/gemm_layout.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/gemv.cu b/cpp/test/linalg/gemv.cu index 2bd9abc200..594810bab2 100644 --- a/cpp/test/linalg/gemv.cu +++ b/cpp/test/linalg/gemv.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/map.cu b/cpp/test/linalg/map.cu index 95a2aff130..7e3a1562d9 100644 --- a/cpp/test/linalg/map.cu +++ b/cpp/test/linalg/map.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/map_then_reduce.cu b/cpp/test/linalg/map_then_reduce.cu index ba9cca6f04..1e7f58ec38 100644 --- a/cpp/test/linalg/map_then_reduce.cu +++ b/cpp/test/linalg/map_then_reduce.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/matrix_vector.cu b/cpp/test/linalg/matrix_vector.cu index e321d45703..7018e1da96 100644 --- a/cpp/test/linalg/matrix_vector.cu +++ b/cpp/test/linalg/matrix_vector.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "matrix_vector_op.cuh" #include #include diff --git a/cpp/test/linalg/matrix_vector_op.cu b/cpp/test/linalg/matrix_vector_op.cu index 1c96c3fc74..e2775c168d 100644 --- a/cpp/test/linalg/matrix_vector_op.cu +++ b/cpp/test/linalg/matrix_vector_op.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "matrix_vector_op.cuh" #include #include diff --git a/cpp/test/linalg/matrix_vector_op.cuh b/cpp/test/linalg/matrix_vector_op.cuh index 602d05d153..cf316ef111 100644 --- a/cpp/test/linalg/matrix_vector_op.cuh +++ b/cpp/test/linalg/matrix_vector_op.cuh @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/linalg/mean_squared_error.cu b/cpp/test/linalg/mean_squared_error.cu index 795f831417..18e7debcb1 100644 --- a/cpp/test/linalg/mean_squared_error.cu +++ b/cpp/test/linalg/mean_squared_error.cu @@ -15,7 +15,7 @@ */ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/multiply.cu b/cpp/test/linalg/multiply.cu index 1d6446c5c0..c90fb93fd0 100644 --- a/cpp/test/linalg/multiply.cu +++ b/cpp/test/linalg/multiply.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "unary_op.cuh" #include #include diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index 09ebdbd13e..94540b9ff6 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/normalize.cu b/cpp/test/linalg/normalize.cu index 599592fce1..0a6786b1ee 100644 --- a/cpp/test/linalg/normalize.cu +++ b/cpp/test/linalg/normalize.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/power.cu b/cpp/test/linalg/power.cu index bdab49d5c8..54c2e2a7aa 100644 --- a/cpp/test/linalg/power.cu +++ b/cpp/test/linalg/power.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/reduce.cu b/cpp/test/linalg/reduce.cu index cd112526e0..4ad382c4f7 100644 --- a/cpp/test/linalg/reduce.cu +++ b/cpp/test/linalg/reduce.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "reduce.cuh" #include #include diff --git a/cpp/test/linalg/reduce_cols_by_key.cu b/cpp/test/linalg/reduce_cols_by_key.cu index 63afbe2fed..a378c450ce 100644 --- a/cpp/test/linalg/reduce_cols_by_key.cu +++ b/cpp/test/linalg/reduce_cols_by_key.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/reduce_rows_by_key.cu b/cpp/test/linalg/reduce_rows_by_key.cu index 7b124cb7bb..22229d2224 100644 --- a/cpp/test/linalg/reduce_rows_by_key.cu +++ b/cpp/test/linalg/reduce_rows_by_key.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/rsvd.cu b/cpp/test/linalg/rsvd.cu index f774d59631..04e17468c3 100644 --- a/cpp/test/linalg/rsvd.cu +++ b/cpp/test/linalg/rsvd.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/sqrt.cu b/cpp/test/linalg/sqrt.cu index ed57e94914..9008313b58 100644 --- a/cpp/test/linalg/sqrt.cu +++ b/cpp/test/linalg/sqrt.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/strided_reduction.cu b/cpp/test/linalg/strided_reduction.cu index 789ea64c3a..6a8c43ad52 100644 --- a/cpp/test/linalg/strided_reduction.cu +++ b/cpp/test/linalg/strided_reduction.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "reduce.cuh" #include #include diff --git a/cpp/test/linalg/subtract.cu b/cpp/test/linalg/subtract.cu index 3904f9f33f..426fc98f9f 100644 --- a/cpp/test/linalg/subtract.cu +++ b/cpp/test/linalg/subtract.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/svd.cu b/cpp/test/linalg/svd.cu index c18417dc9e..7918d481db 100644 --- a/cpp/test/linalg/svd.cu +++ b/cpp/test/linalg/svd.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/ternary_op.cu b/cpp/test/linalg/ternary_op.cu index e172d771cd..c78df08820 100644 --- a/cpp/test/linalg/ternary_op.cu +++ b/cpp/test/linalg/ternary_op.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index 6a05317f49..110dc527d3 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/linalg/unary_op.cu b/cpp/test/linalg/unary_op.cu index 57b009a0ac..3ebf70e69f 100644 --- a/cpp/test/linalg/unary_op.cu +++ b/cpp/test/linalg/unary_op.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "unary_op.cuh" #include #include diff --git a/cpp/test/linalg/unary_op.cuh b/cpp/test/linalg/unary_op.cuh index 190d531a9f..28bcc004a4 100644 --- a/cpp/test/linalg/unary_op.cuh +++ b/cpp/test/linalg/unary_op.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu index 0219eb1aff..33af0ce5a4 100644 --- a/cpp/test/matrix/argmax.cu +++ b/cpp/test/matrix/argmax.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/argmin.cu b/cpp/test/matrix/argmin.cu index bdf178cd8a..22f0a6cac0 100644 --- a/cpp/test/matrix/argmin.cu +++ b/cpp/test/matrix/argmin.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/columnSort.cu b/cpp/test/matrix/columnSort.cu index aba1c4e1f0..000a911efd 100644 --- a/cpp/test/matrix/columnSort.cu +++ b/cpp/test/matrix/columnSort.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/diagonal.cu b/cpp/test/matrix/diagonal.cu index e1ad9e144b..f6cd178b23 100644 --- a/cpp/test/matrix/diagonal.cu +++ b/cpp/test/matrix/diagonal.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index 4b3244913b..0bea62e9cf 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 5dcdf265ac..9ce1371944 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -15,7 +15,7 @@ */ #include "../linalg/matrix_vector_op.cuh" -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index 684b550dfc..f2c1a6249c 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 78391d5ff2..8cfbdac32b 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/matrix/norm.cu b/cpp/test/matrix/norm.cu index 38fdd409eb..b1e10c9047 100644 --- a/cpp/test/matrix/norm.cu +++ b/cpp/test/matrix/norm.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/reverse.cu b/cpp/test/matrix/reverse.cu index c905b8711e..49d501b6d0 100644 --- a/cpp/test/matrix/reverse.cu +++ b/cpp/test/matrix/reverse.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/slice.cu b/cpp/test/matrix/slice.cu index 5faf672d13..9060357b3f 100644 --- a/cpp/test/matrix/slice.cu +++ b/cpp/test/matrix/slice.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/triangular.cu b/cpp/test/matrix/triangular.cu index 9af3defb5d..9c6c49066b 100644 --- a/cpp/test/matrix/triangular.cu +++ b/cpp/test/matrix/triangular.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/mst.cu b/cpp/test/mst.cu index 544ca80a46..d11f0b5842 100644 --- a/cpp/test/mst.cu +++ b/cpp/test/mst.cu @@ -16,7 +16,7 @@ #include -#include "test_utils.h" +#include "test_utils.cuh" #include #include #include diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 735d569318..1207b75a4a 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "ann_utils.cuh" #include diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 9d6ad11ccb..7c3ec044b1 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include "../test_utils.h" +#include "../test_utils.cuh" #include "ann_utils.cuh" #include diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 07ef410d36..05fe6ab92d 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -25,7 +25,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include namespace raft::neighbors { diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 47030b0d62..7405863b9f 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "spatial_data.h" #include #include diff --git a/cpp/test/neighbors/epsilon_neighborhood.cu b/cpp/test/neighbors/epsilon_neighborhood.cu index c83817f6f8..4f33db489e 100644 --- a/cpp/test/neighbors/epsilon_neighborhood.cu +++ b/cpp/test/neighbors/epsilon_neighborhood.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/neighbors/faiss_mr.cu b/cpp/test/neighbors/faiss_mr.cu index 91ba1cc94c..38e793d120 100644 --- a/cpp/test/neighbors/faiss_mr.cu +++ b/cpp/test/neighbors/faiss_mr.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index 8df193d53d..d57f99da50 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/neighbors/haversine.cu b/cpp/test/neighbors/haversine.cu index 78bd377156..91a2ca07df 100644 --- a/cpp/test/neighbors/haversine.cu +++ b/cpp/test/neighbors/haversine.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index eb5ecf663f..ff3a6a80b4 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index 06c1317b1e..674171e030 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "ann_utils.cuh" #include "refine_helper.cuh" diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index bfcfca5ead..d793ea46ee 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -20,7 +20,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu index 1f14fd23f7..741b374c8c 100644 --- a/cpp/test/random/make_blobs.cu +++ b/cpp/test/random/make_blobs.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/make_regression.cu b/cpp/test/random/make_regression.cu index 65d4c4cb31..c2e447adf8 100644 --- a/cpp/test/random/make_regression.cu +++ b/cpp/test/random/make_regression.cu @@ -19,7 +19,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/multi_variable_gaussian.cu b/cpp/test/random/multi_variable_gaussian.cu index f9f79e6845..04626a53c7 100644 --- a/cpp/test/random/multi_variable_gaussian.cu +++ b/cpp/test/random/multi_variable_gaussian.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/permute.cu b/cpp/test/random/permute.cu index 32e5540d51..be4f2a005f 100644 --- a/cpp/test/random/permute.cu +++ b/cpp/test/random/permute.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/rmat_rectangular_generator.cu b/cpp/test/random/rmat_rectangular_generator.cu index 0baaaf28cf..c1c4752453 100644 --- a/cpp/test/random/rmat_rectangular_generator.cu +++ b/cpp/test/random/rmat_rectangular_generator.cu @@ -19,7 +19,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index 82f6e0e247..bdce79b76e 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -17,7 +17,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/rng_discrete.cu b/cpp/test/random/rng_discrete.cu index b7aef51af5..1dee281527 100644 --- a/cpp/test/random/rng_discrete.cu +++ b/cpp/test/random/rng_discrete.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/rng_int.cu b/cpp/test/random/rng_int.cu index d5270c456e..89d6d208a5 100644 --- a/cpp/test/random/rng_int.cu +++ b/cpp/test/random/rng_int.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/sample_without_replacement.cu b/cpp/test/random/sample_without_replacement.cu index 355df1fcc1..a6cf3569e6 100644 --- a/cpp/test/random/sample_without_replacement.cu +++ b/cpp/test/random/sample_without_replacement.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/sparse/add.cu b/cpp/test/sparse/add.cu index 862cbffdc7..692094a861 100644 --- a/cpp/test/sparse/add.cu +++ b/cpp/test/sparse/add.cu @@ -20,7 +20,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/convert_coo.cu b/cpp/test/sparse/convert_coo.cu index 1142c6f3f2..21f81bfa6a 100644 --- a/cpp/test/sparse/convert_coo.cu +++ b/cpp/test/sparse/convert_coo.cu @@ -22,7 +22,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 007cbd7fdb..bc81dcaba5 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/csr_row_slice.cu b/cpp/test/sparse/csr_row_slice.cu index 39b235d5f1..6b10f8d798 100644 --- a/cpp/test/sparse/csr_row_slice.cu +++ b/cpp/test/sparse/csr_row_slice.cu @@ -24,7 +24,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" namespace raft { namespace sparse { diff --git a/cpp/test/sparse/csr_to_dense.cu b/cpp/test/sparse/csr_to_dense.cu index 5811c5c22b..b99d6fa7fd 100644 --- a/cpp/test/sparse/csr_to_dense.cu +++ b/cpp/test/sparse/csr_to_dense.cu @@ -24,7 +24,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" namespace raft { namespace sparse { diff --git a/cpp/test/sparse/csr_transpose.cu b/cpp/test/sparse/csr_transpose.cu index 108d38a8b4..2342f6e7ef 100644 --- a/cpp/test/sparse/csr_transpose.cu +++ b/cpp/test/sparse/csr_transpose.cu @@ -23,7 +23,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" namespace raft { namespace sparse { diff --git a/cpp/test/sparse/degree.cu b/cpp/test/sparse/degree.cu index a4af021c05..b5b22ff15c 100644 --- a/cpp/test/sparse/degree.cu +++ b/cpp/test/sparse/degree.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index 9a8f650449..be73bd3130 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -27,7 +27,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index 4ce2f4cbde..367bfddad1 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -24,7 +24,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" namespace raft { namespace sparse { diff --git a/cpp/test/sparse/filter.cu b/cpp/test/sparse/filter.cu index ba80c84fd5..6ada5ddcad 100644 --- a/cpp/test/sparse/filter.cu +++ b/cpp/test/sparse/filter.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/neighbors/brute_force.cu b/cpp/test/sparse/neighbors/brute_force.cu index 8fa5e8322d..96ba3bc48f 100644 --- a/cpp/test/sparse/neighbors/brute_force.cu +++ b/cpp/test/sparse/neighbors/brute_force.cu @@ -17,7 +17,7 @@ #include #include -#include "../../test_utils.h" +#include "../../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/neighbors/connect_components.cu b/cpp/test/sparse/neighbors/connect_components.cu index fc4eecd4ee..f469cc1aa2 100644 --- a/cpp/test/sparse/neighbors/connect_components.cu +++ b/cpp/test/sparse/neighbors/connect_components.cu @@ -34,7 +34,7 @@ #include #include -#include "../../test_utils.h" +#include "../../test_utils.cuh" namespace raft { namespace sparse { diff --git a/cpp/test/sparse/neighbors/knn_graph.cu b/cpp/test/sparse/neighbors/knn_graph.cu index d6f9e8386f..c3700a536c 100644 --- a/cpp/test/sparse/neighbors/knn_graph.cu +++ b/cpp/test/sparse/neighbors/knn_graph.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../../test_utils.h" +#include "../../test_utils.cuh" #include #include #include diff --git a/cpp/test/sparse/norm.cu b/cpp/test/sparse/norm.cu index 8e54edd6c9..1a69acc535 100644 --- a/cpp/test/sparse/norm.cu +++ b/cpp/test/sparse/norm.cu @@ -16,7 +16,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/reduce.cu b/cpp/test/sparse/reduce.cu index 4280192723..5752624435 100644 --- a/cpp/test/sparse/reduce.cu +++ b/cpp/test/sparse/reduce.cu @@ -16,7 +16,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/sparse/row_op.cu b/cpp/test/sparse/row_op.cu index 732bd06103..6393c5ee86 100644 --- a/cpp/test/sparse/row_op.cu +++ b/cpp/test/sparse/row_op.cu @@ -19,7 +19,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/sort.cu b/cpp/test/sparse/sort.cu index 9b75965498..23c2f5b67a 100644 --- a/cpp/test/sparse/sort.cu +++ b/cpp/test/sparse/sort.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/sparse/spgemmi.cu b/cpp/test/sparse/spgemmi.cu index a132c94fde..653c2fa29b 100644 --- a/cpp/test/sparse/spgemmi.cu +++ b/cpp/test/sparse/spgemmi.cu @@ -16,7 +16,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/symmetrize.cu b/cpp/test/sparse/symmetrize.cu index 7cf1a1e07d..6f2f877304 100644 --- a/cpp/test/sparse/symmetrize.cu +++ b/cpp/test/sparse/symmetrize.cu @@ -22,7 +22,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include diff --git a/cpp/test/stats/accuracy.cu b/cpp/test/stats/accuracy.cu index 192c187794..eaccdecab4 100644 --- a/cpp/test/stats/accuracy.cu +++ b/cpp/test/stats/accuracy.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/adjusted_rand_index.cu b/cpp/test/stats/adjusted_rand_index.cu index f113af821d..52bc72174a 100644 --- a/cpp/test/stats/adjusted_rand_index.cu +++ b/cpp/test/stats/adjusted_rand_index.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/completeness_score.cu b/cpp/test/stats/completeness_score.cu index 2f8a40afdc..a9d1748f88 100644 --- a/cpp/test/stats/completeness_score.cu +++ b/cpp/test/stats/completeness_score.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/contingencyMatrix.cu b/cpp/test/stats/contingencyMatrix.cu index 7943610689..d27114388e 100644 --- a/cpp/test/stats/contingencyMatrix.cu +++ b/cpp/test/stats/contingencyMatrix.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/cov.cu b/cpp/test/stats/cov.cu index 890c5b7826..59a2c6e081 100644 --- a/cpp/test/stats/cov.cu +++ b/cpp/test/stats/cov.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/dispersion.cu b/cpp/test/stats/dispersion.cu index 4f18c9fb54..e414fcf5f4 100644 --- a/cpp/test/stats/dispersion.cu +++ b/cpp/test/stats/dispersion.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/entropy.cu b/cpp/test/stats/entropy.cu index 04aa9f7a80..96b2b9f590 100644 --- a/cpp/test/stats/entropy.cu +++ b/cpp/test/stats/entropy.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/histogram.cu b/cpp/test/stats/histogram.cu index d9793a57df..76677ac27c 100644 --- a/cpp/test/stats/histogram.cu +++ b/cpp/test/stats/histogram.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/homogeneity_score.cu b/cpp/test/stats/homogeneity_score.cu index 9bd6d9266b..ecbf160770 100644 --- a/cpp/test/stats/homogeneity_score.cu +++ b/cpp/test/stats/homogeneity_score.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/information_criterion.cu b/cpp/test/stats/information_criterion.cu index 4a9a2128c6..2cfbd787c6 100644 --- a/cpp/test/stats/information_criterion.cu +++ b/cpp/test/stats/information_criterion.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include diff --git a/cpp/test/stats/kl_divergence.cu b/cpp/test/stats/kl_divergence.cu index 58a64f7199..b5a6c393f3 100644 --- a/cpp/test/stats/kl_divergence.cu +++ b/cpp/test/stats/kl_divergence.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/mean.cu b/cpp/test/stats/mean.cu index b299f81f68..19398d6d8e 100644 --- a/cpp/test/stats/mean.cu +++ b/cpp/test/stats/mean.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/mean_center.cu b/cpp/test/stats/mean_center.cu index 30dcdd475b..31947ef527 100644 --- a/cpp/test/stats/mean_center.cu +++ b/cpp/test/stats/mean_center.cu @@ -15,7 +15,7 @@ */ #include "../linalg/matrix_vector_op.cuh" -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/meanvar.cu b/cpp/test/stats/meanvar.cu index 424395c5e8..fb9fc13dec 100644 --- a/cpp/test/stats/meanvar.cu +++ b/cpp/test/stats/meanvar.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/minmax.cu b/cpp/test/stats/minmax.cu index a2ba6bfc9e..1171995d5c 100644 --- a/cpp/test/stats/minmax.cu +++ b/cpp/test/stats/minmax.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/mutual_info_score.cu b/cpp/test/stats/mutual_info_score.cu index fb9362df52..8b6e7b2095 100644 --- a/cpp/test/stats/mutual_info_score.cu +++ b/cpp/test/stats/mutual_info_score.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/r2_score.cu b/cpp/test/stats/r2_score.cu index d77daacb04..7fb15505ab 100644 --- a/cpp/test/stats/r2_score.cu +++ b/cpp/test/stats/r2_score.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/rand_index.cu b/cpp/test/stats/rand_index.cu index 67e4ab5517..0010f3cbcd 100644 --- a/cpp/test/stats/rand_index.cu +++ b/cpp/test/stats/rand_index.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include diff --git a/cpp/test/stats/regression_metrics.cu b/cpp/test/stats/regression_metrics.cu index effc3d04dd..86ac03c8b3 100644 --- a/cpp/test/stats/regression_metrics.cu +++ b/cpp/test/stats/regression_metrics.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/silhouette_score.cu b/cpp/test/stats/silhouette_score.cu index 37a6fff786..876926b71a 100644 --- a/cpp/test/stats/silhouette_score.cu +++ b/cpp/test/stats/silhouette_score.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/stddev.cu b/cpp/test/stats/stddev.cu index 73f30f17e9..7f54eee2ab 100644 --- a/cpp/test/stats/stddev.cu +++ b/cpp/test/stats/stddev.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/sum.cu b/cpp/test/stats/sum.cu index e67988abb0..c69bd04c6e 100644 --- a/cpp/test/stats/sum.cu +++ b/cpp/test/stats/sum.cu @@ -14,7 +14,7 @@ * limitations under the License. 
 */
 
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include 
 #include 
diff --git a/cpp/test/stats/trustworthiness.cu b/cpp/test/stats/trustworthiness.cu
index cbb8228f8f..a95cddf5aa 100644
--- a/cpp/test/stats/trustworthiness.cu
+++ b/cpp/test/stats/trustworthiness.cu
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include 
 #include 
 #include 
diff --git a/cpp/test/stats/v_measure.cu b/cpp/test/stats/v_measure.cu
index 0cbc2da7d9..79899c1d75 100644
--- a/cpp/test/stats/v_measure.cu
+++ b/cpp/test/stats/v_measure.cu
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include 
 #include 
 #include 
diff --git a/cpp/test/stats/weighted_mean.cu b/cpp/test/stats/weighted_mean.cu
index 9f33855572..8b4af07898 100644
--- a/cpp/test/stats/weighted_mean.cu
+++ b/cpp/test/stats/weighted_mean.cu
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "../test_utils.h"
+#include "../test_utils.cuh"
 #include 
 #include 
 #include 
diff --git a/cpp/test/test_utils.cuh b/cpp/test/test_utils.cuh
new file mode 100644
index 0000000000..5704eefae3
--- /dev/null
+++ b/cpp/test/test_utils.cuh
@@ -0,0 +1,329 @@
+/*
+ * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "test_utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace raft {
+
+/*
+ * @brief Helper function to compare 2 device n-D arrays with custom comparison
+ * @tparam T the data type of the arrays
+ * @tparam L the comparator lambda or object function
+ * @param expected expected value(s)
+ * @param actual actual values
+ * @param eq_compare the comparator
+ * @param stream cuda stream
+ * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE
+ * @{
+ */
+template <typename T, typename L>
+testing::AssertionResult devArrMatch(
+  const T* expected, const T* actual, size_t size, L eq_compare, cudaStream_t stream = 0)
+{
+  std::unique_ptr<T[]> exp_h(new T[size]);
+  std::unique_ptr<T[]> act_h(new T[size]);
+  raft::update_host(exp_h.get(), expected, size, stream);
+  raft::update_host(act_h.get(), actual, size, stream);
+  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  for (size_t i(0); i < size; ++i) {
+    auto exp = exp_h.get()[i];
+    auto act = act_h.get()[i];
+    if (!eq_compare(exp, act)) {
+      return testing::AssertionFailure() << "actual=" << act << " != expected=" << exp << " @" << i;
+    }
+  }
+  return testing::AssertionSuccess();
+}
+
+template <typename T, typename L>
+testing::AssertionResult devArrMatch(
+  T expected, const T* actual, size_t size, L eq_compare, cudaStream_t stream = 0)
+{
+  std::unique_ptr<T[]> act_h(new T[size]);
+  raft::update_host(act_h.get(), actual, size, stream);
+  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  for (size_t i(0); i < size; ++i) {
+    auto act = act_h.get()[i];
+    if (!eq_compare(expected, act)) {
+      return testing::AssertionFailure()
+             << "actual=" << act << " != expected=" << expected << " @" << i;
+    }
+  }
+  return testing::AssertionSuccess();
+}
+
+template <typename T, typename L>
+testing::AssertionResult devArrMatch(const T* expected,
+                                     const T* actual,
+                                     size_t rows,
+                                     size_t cols,
+                                     L eq_compare,
+                                     cudaStream_t stream = 0)
+{
+  size_t size = rows * cols;
+  std::unique_ptr<T[]> exp_h(new T[size]);
+  std::unique_ptr<T[]> act_h(new T[size]);
+  raft::update_host(exp_h.get(), expected, size, stream);
+  raft::update_host(act_h.get(), actual, size, stream);
+  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  for (size_t i(0); i < rows; ++i) {
+    for (size_t j(0); j < cols; ++j) {
+      auto idx = i * cols + j;  // row major assumption!
+      auto exp = exp_h.get()[idx];
+      auto act = act_h.get()[idx];
+      if (!eq_compare(exp, act)) {
+        return testing::AssertionFailure()
+               << "actual=" << act << " != expected=" << exp << " @" << i << "," << j;
+      }
+    }
+  }
+  return testing::AssertionSuccess();
+}
+
+template <typename T, typename L>
+testing::AssertionResult devArrMatch(
+  T expected, const T* actual, size_t rows, size_t cols, L eq_compare, cudaStream_t stream = 0)
+{
+  size_t size = rows * cols;
+  std::unique_ptr<T[]> act_h(new T[size]);
+  raft::update_host(act_h.get(), actual, size, stream);
+  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  for (size_t i(0); i < rows; ++i) {
+    for (size_t j(0); j < cols; ++j) {
+      auto idx = i * cols + j;  // row major assumption!
+      auto act = act_h.get()[idx];
+      if (!eq_compare(expected, act)) {
+        return testing::AssertionFailure()
+               << "actual=" << act << " != expected=" << expected << " @" << i << "," << j;
+      }
+    }
+  }
+  return testing::AssertionSuccess();
+}
+
+/*
+ * @brief Helper function to compare a device n-D array with an expected array
+ * on the host, using a custom comparison
+ * @tparam T the data type of the arrays
+ * @tparam L the comparator lambda or object function
+ * @param expected_h host array of expected value(s)
+ * @param actual_d device array actual values
+ * @param eq_compare the comparator
+ * @param stream cuda stream
+ * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE
+ */
+template
+testing::AssertionResult devArrMatchHost(
+  const T* expected_h, const T* actual_d, size_t size, L eq_compare, cudaStream_t stream = 0)
+{
+  std::unique_ptr act_h(new T[size]);
+  raft::update_host(act_h.get(), actual_d, size, stream);
+  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  bool ok = true;
+  auto fail = testing::AssertionFailure();
+  for (size_t i(0); i < size; ++i) {
+    auto exp = expected_h[i];
+    auto act = act_h.get()[i];
+    if (!eq_compare(exp, act)) {
+      ok = false;
+      fail << "actual=" << act << " != expected=" << exp << " @" << i << "; ";
+    }
+  }
+  if (!ok) return fail;
+  return testing::AssertionSuccess();
+}
+
+/**
+ * @brief Helper function to compare host vectors using a custom comparison
+ * @tparam T the element type
+ * @tparam L the comparator lambda or object function
+ * @param expected_h host vector of expected value(s)
+ * @param actual_h host vector actual values
+ * @param eq_compare the comparator
+ * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE
+ */
+template
+testing::AssertionResult hostVecMatch(const std::vector& expected_h,
+                                      const std::vector& actual_h,
+                                      L eq_compare)
+{
+  auto n = actual_h.size();
+  if (n != expected_h.size())
+    return testing::AssertionFailure()
+           << "vector size mismatch: "
+           << "actual=" << n << " != expected=" << expected_h.size() << "; ";
+  for (size_t i = 0; i < n; ++i) {
+    auto exp = expected_h[i];
+    auto act = actual_h[i];
+    if (!eq_compare(exp, act)) {
+      return testing::AssertionFailure()
+             << "actual=" << act << " != expected=" << exp << " @" << i << "; ";
+    }
+  }
+  return testing::AssertionSuccess();
+}
+
+/*
+ * @brief Helper function to compare diagonal values of a 2D matrix
+ * @tparam T the data type of the arrays
+ * @tparam L the comparator lambda or object function
+ * @param expected expected value along diagonal
+ * @param actual actual matrix
+ * @param eq_compare the comparator
+ * @param stream cuda stream
+ * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE
+ */
+template
+testing::AssertionResult diagonalMatch(
+  T expected, const T* actual, size_t rows, size_t cols, L eq_compare, cudaStream_t stream = 0)
+{
+  size_t size = rows * cols;
+  std::unique_ptr act_h(new T[size]);
+  raft::update_host(act_h.get(), actual, size, stream);
+  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  for (size_t i(0); i < rows; ++i) {
+    for (size_t j(0); j < cols; ++j) {
+      if (i != j) continue;
+      auto idx = i * cols + j;  // row major assumption!
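+      // the `continue` above skips off-diagonal entries, so idx only ever
+      // addresses element (i, i), i.e. the main diagonal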
+ auto act = act_h.get()[idx]; + if (!eq_compare(expected, act)) { + return testing::AssertionFailure() + << "actual=" << act << " != expected=" << expected << " @" << i << "," << j; + } + } + } + return testing::AssertionSuccess(); +} + +template +typename std::enable_if_t> gen_uniform(T* out, + raft::random::RngState& rng, + IdxT len, + cudaStream_t stream, + T range_min = T(-1), + T range_max = T(1)) +{ + raft::random::uniform(rng, out, len, range_min, range_max, stream); +} + +template +typename std::enable_if_t> gen_uniform(T* out, + raft::random::RngState& rng, + IdxT len, + cudaStream_t stream, + T range_min = T(0), + T range_max = T(100)) +{ + raft::random::uniformInt(rng, out, len, range_min, range_max, stream); +} + +template +void gen_uniform(raft::KeyValuePair* out, + raft::random::RngState& rng, + IdxT len, + cudaStream_t stream) +{ + rmm::device_uvector keys(len, stream); + rmm::device_uvector values(len, stream); + + gen_uniform(keys.data(), rng, len, stream); + gen_uniform(values.data(), rng, len, stream); + + const T1* d_keys = keys.data(); + const T2* d_values = values.data(); + auto counting = thrust::make_counting_iterator(0); + thrust::for_each(rmm::exec_policy(stream), + counting, + counting + len, + [out, d_keys, d_values] __device__(int idx) { + out[idx].key = d_keys[idx]; + out[idx].value = d_values[idx]; + }); +} + +/** @} */ + +/** time the function call 'func' using cuda events */ +#define TIMEIT_LOOP(ms, count, func) \ + do { \ + cudaEvent_t start, stop; \ + RAFT_CUDA_TRY(cudaEventCreate(&start)); \ + RAFT_CUDA_TRY(cudaEventCreate(&stop)); \ + RAFT_CUDA_TRY(cudaEventRecord(start)); \ + for (int i = 0; i < count; ++i) { \ + func; \ + } \ + RAFT_CUDA_TRY(cudaEventRecord(stop)); \ + RAFT_CUDA_TRY(cudaEventSynchronize(stop)); \ + ms = 0.f; \ + RAFT_CUDA_TRY(cudaEventElapsedTime(&ms, start, stop)); \ + ms /= args.runs; \ + } while (0) + +inline std::vector read_csv(std::string filename, bool skip_first_n_columns = 1) +{ + std::vector result; + std::ifstream myFile(filename); + if (!myFile.is_open()) throw std::runtime_error("Could not open file"); + + std::string line, colname; + int val; + + if (myFile.good()) { + std::getline(myFile, line); + std::stringstream ss(line); + while (std::getline(ss, colname, ',')) {} + } + + int n_lines = 0; + while (std::getline(myFile, line)) { + std::stringstream ss(line); + int colIdx = 0; + while (ss >> val) { + if (colIdx >= skip_first_n_columns) { + result.push_back(val); + if (ss.peek() == ',') ss.ignore(); + } + colIdx++; + } + n_lines++; + } + + printf("lines read: %d\n", n_lines); + myFile.close(); + return result; +} + +}; // end namespace raft \ No newline at end of file diff --git a/cpp/test/test_utils.h b/cpp/test/test_utils.h index 26483e6b2d..75590463b0 100644 --- a/cpp/test/test_utils.h +++ b/cpp/test/test_utils.h @@ -15,23 +15,13 @@ */ #pragma once -#include + +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include + +#include namespace raft { @@ -40,13 +30,22 @@ struct Compare { bool operator()(const T& a, const T& b) const { return a == b; } }; +template +struct Compare> { + bool operator()(const raft::KeyValuePair& a, + const raft::KeyValuePair& b) const + { + return a.key == b.key && a.value == b.value; + } +}; + template struct CompareApprox { CompareApprox(T eps_) : eps(eps_) {} bool operator()(const T& a, const T& b) const { - T diff = abs(a - b); - T m = std::max(abs(a), abs(b)); + T diff = 
std::abs(a - b); + T m = std::max(std::abs(a), std::abs(b)); T ratio = diff > eps ? diff / m : diff; return (ratio <= eps); @@ -85,8 +84,8 @@ struct CompareApproxAbs { CompareApproxAbs(T eps_) : eps(eps_) {} bool operator()(const T& a, const T& b) const { - T diff = abs(abs(a) - abs(b)); - T m = std::max(abs(a), abs(b)); + T diff = std::abs(std::abs(a) - std::abs(b)); + T m = std::max(std::abs(a), std::abs(b)); T ratio = diff >= eps ? diff / m : diff; return (ratio <= eps); } @@ -98,210 +97,14 @@ struct CompareApproxAbs { template struct CompareApproxNoScaling { CompareApproxNoScaling(T eps_) : eps(eps_) {} - bool operator()(const T& a, const T& b) const { return (abs(a - b) <= eps); } + bool operator()(const T& a, const T& b) const { return (std::abs(a - b) <= eps); } private: T eps; }; -template -__host__ __device__ T abs(const T& a) -{ - return a > T(0) ? a : -a; -} - -/* - * @brief Helper function to compare 2 device n-D arrays with custom comparison - * @tparam T the data type of the arrays - * @tparam L the comparator lambda or object function - * @param expected expected value(s) - * @param actual actual values - * @param eq_compare the comparator - * @param stream cuda stream - * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE - * @{ - */ -template -testing::AssertionResult devArrMatch( - const T* expected, const T* actual, size_t size, L eq_compare, cudaStream_t stream = 0) -{ - std::unique_ptr exp_h(new T[size]); - std::unique_ptr act_h(new T[size]); - raft::update_host(exp_h.get(), expected, size, stream); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < size; ++i) { - auto exp = exp_h.get()[i]; - auto act = act_h.get()[i]; - if (!eq_compare(exp, act)) { - return testing::AssertionFailure() << "actual=" << act << " != expected=" << exp << " @" << i; - } - } - return testing::AssertionSuccess(); -} - -template -testing::AssertionResult devArrMatch( - T expected, const T* actual, size_t size, L eq_compare, cudaStream_t stream = 0) -{ - std::unique_ptr act_h(new T[size]); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < size; ++i) { - auto act = act_h.get()[i]; - if (!eq_compare(expected, act)) { - return testing::AssertionFailure() - << "actual=" << act << " != expected=" << expected << " @" << i; - } - } - return testing::AssertionSuccess(); -} - -template -testing::AssertionResult devArrMatch(const T* expected, - const T* actual, - size_t rows, - size_t cols, - L eq_compare, - cudaStream_t stream = 0) -{ - size_t size = rows * cols; - std::unique_ptr exp_h(new T[size]); - std::unique_ptr act_h(new T[size]); - raft::update_host(exp_h.get(), expected, size, stream); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < rows; ++i) { - for (size_t j(0); j < cols; ++j) { - auto idx = i * cols + j; // row major assumption! 
- auto exp = exp_h.get()[idx]; - auto act = act_h.get()[idx]; - if (!eq_compare(exp, act)) { - return testing::AssertionFailure() - << "actual=" << act << " != expected=" << exp << " @" << i << "," << j; - } - } - } - return testing::AssertionSuccess(); -} - -template -testing::AssertionResult devArrMatch( - T expected, const T* actual, size_t rows, size_t cols, L eq_compare, cudaStream_t stream = 0) -{ - size_t size = rows * cols; - std::unique_ptr act_h(new T[size]); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < rows; ++i) { - for (size_t j(0); j < cols; ++j) { - auto idx = i * cols + j; // row major assumption! - auto act = act_h.get()[idx]; - if (!eq_compare(expected, act)) { - return testing::AssertionFailure() - << "actual=" << act << " != expected=" << expected << " @" << i << "," << j; - } - } - } - return testing::AssertionSuccess(); -} - -/* - * @brief Helper function to compare a device n-D arrays with an expected array - * on the host, using a custom comparison - * @tparam T the data type of the arrays - * @tparam L the comparator lambda or object function - * @param expected_h host array of expected value(s) - * @param actual_d device array actual values - * @param eq_compare the comparator - * @param stream cuda stream - * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE - */ -template -testing::AssertionResult devArrMatchHost( - const T* expected_h, const T* actual_d, size_t size, L eq_compare, cudaStream_t stream = 0) -{ - std::unique_ptr act_h(new T[size]); - raft::update_host(act_h.get(), actual_d, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - bool ok = true; - auto fail = testing::AssertionFailure(); - for (size_t i(0); i < size; ++i) { - auto exp = expected_h[i]; - auto act = act_h.get()[i]; - if (!eq_compare(exp, act)) { - ok = false; - fail << "actual=" << act << " != expected=" << exp << " @" << i << "; "; - } - } - if (!ok) return fail; - return testing::AssertionSuccess(); -} - -/** - * @brief Helper function to compare host vectors using a custom comparison - * @tparam T the element type - * @tparam L the comparator lambda or object function - * @param expected_h host vector of expected value(s) - * @param actual_h host vector actual values - * @param eq_compare the comparator - * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE - */ -template -testing::AssertionResult hostVecMatch(const std::vector& expected_h, - const std::vector& actual_h, - L eq_compare) -{ - auto n = actual_h.size(); - if (n != expected_h.size()) - return testing::AssertionFailure() - << "vector sizez mismatch: " - << "actual=" << n << " != expected=" << expected_h.size() << "; "; - for (size_t i = 0; i < n; ++i) { - auto exp = expected_h[i]; - auto act = actual_h[i]; - if (!eq_compare(exp, act)) { - return testing::AssertionFailure() - << "actual=" << act << " != expected=" << exp << " @" << i << "; "; - } - } - return testing::AssertionSuccess(); -} - -/* - * @brief Helper function to compare diagonal values of a 2D matrix - * @tparam T the data type of the arrays - * @tparam L the comparator lambda or object function - * @param expected expected value along diagonal - * @param actual actual matrix - * @param eq_compare the comparator - * @param stream cuda stream - * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE - */ -template -testing::AssertionResult diagonalMatch( - T expected, const T* actual, size_t rows, 
size_t cols, L eq_compare, cudaStream_t stream = 0) -{ - size_t size = rows * cols; - std::unique_ptr act_h(new T[size]); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < rows; ++i) { - for (size_t j(0); j < cols; ++j) { - if (i != j) continue; - auto idx = i * cols + j; // row major assumption! - auto act = act_h.get()[idx]; - if (!eq_compare(expected, act)) { - return testing::AssertionFailure() - << "actual=" << act << " != expected=" << expected << " @" << i << "," << j; - } - } - } - return testing::AssertionSuccess(); -} - template -testing::AssertionResult match(const T expected, T actual, L eq_compare) +testing::AssertionResult match(const T& expected, const T& actual, L eq_compare) { if (!eq_compare(expected, actual)) { return testing::AssertionFailure() << "actual=" << actual << " != expected=" << expected; @@ -309,103 +112,4 @@ testing::AssertionResult match(const T expected, T actual, L eq_compare) return testing::AssertionSuccess(); } -template -typename std::enable_if_t> gen_uniform(T* out, - raft::random::RngState& rng, - IdxT len, - cudaStream_t stream, - T range_min = T(-1), - T range_max = T(1)) -{ - raft::random::uniform(rng, out, len, range_min, range_max, stream); -} - -template -typename std::enable_if_t> gen_uniform(T* out, - raft::random::RngState& rng, - IdxT len, - cudaStream_t stream, - T range_min = T(0), - T range_max = T(100)) -{ - raft::random::uniformInt(rng, out, len, range_min, range_max, stream); -} - -template -void gen_uniform(raft::KeyValuePair* out, - raft::random::RngState& rng, - IdxT len, - cudaStream_t stream) -{ - rmm::device_uvector keys(len, stream); - rmm::device_uvector values(len, stream); - - gen_uniform(keys.data(), rng, len, stream); - gen_uniform(values.data(), rng, len, stream); - - const T1* d_keys = keys.data(); - const T2* d_values = values.data(); - auto counting = thrust::make_counting_iterator(0); - thrust::for_each(rmm::exec_policy(stream), - counting, - counting + len, - [out, d_keys, d_values] __device__(int idx) { - out[idx].key = d_keys[idx]; - out[idx].value = d_values[idx]; - }); -} - -/** @} */ - -/** time the function call 'func' using cuda events */ -#define TIMEIT_LOOP(ms, count, func) \ - do { \ - cudaEvent_t start, stop; \ - RAFT_CUDA_TRY(cudaEventCreate(&start)); \ - RAFT_CUDA_TRY(cudaEventCreate(&stop)); \ - RAFT_CUDA_TRY(cudaEventRecord(start)); \ - for (int i = 0; i < count; ++i) { \ - func; \ - } \ - RAFT_CUDA_TRY(cudaEventRecord(stop)); \ - RAFT_CUDA_TRY(cudaEventSynchronize(stop)); \ - ms = 0.f; \ - RAFT_CUDA_TRY(cudaEventElapsedTime(&ms, start, stop)); \ - ms /= args.runs; \ - } while (0) - -inline std::vector read_csv(std::string filename, bool skip_first_n_columns = 1) -{ - std::vector result; - std::ifstream myFile(filename); - if (!myFile.is_open()) throw std::runtime_error("Could not open file"); - - std::string line, colname; - int val; - - if (myFile.good()) { - std::getline(myFile, line); - std::stringstream ss(line); - while (std::getline(ss, colname, ',')) {} - } - - int n_lines = 0; - while (std::getline(myFile, line)) { - std::stringstream ss(line); - int colIdx = 0; - while (ss >> val) { - if (colIdx >= skip_first_n_columns) { - result.push_back(val); - if (ss.peek() == ',') ss.ignore(); - } - colIdx++; - } - n_lines++; - } - - printf("lines read: %d\n", n_lines); - myFile.close(); - return result; -} - }; // end namespace raft From c6ee689b9344c95cf52b85ee237046e1431b7aaa Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: 
Fri, 9 Dec 2022 15:03:01 +0100 Subject: [PATCH 21/22] Typo explictly -> explicitly --- cpp/include/raft/core/operators.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index 8a694a140d..19b9695896 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -226,7 +226,7 @@ struct plug_const_op { template >> constexpr explicit plug_const_op(const ConstT& s) - : c{s}, composed_op{} // The compiler complains if composed_op is not initialized explictly + : c{s}, composed_op{} // The compiler complains if composed_op is not initialized explicitly { } constexpr plug_const_op(const ConstT& s, BinaryOpT o) : c{s}, composed_op{o} {} @@ -325,7 +325,7 @@ struct map_args_op { typename CondT = std::enable_if_t && std::is_default_constructible_v>> constexpr map_args_op() - : outer_op{} // The compiler complains if outer_op is not initialized explictly + : outer_op{} // The compiler complains if outer_op is not initialized explicitly { } constexpr explicit map_args_op(OuterOpT outer_op, ArgOpsT... arg_ops) From b734f7e72366f71906b5227000a181e5043677e3 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Mon, 12 Dec 2022 12:08:05 +0100 Subject: [PATCH 22/22] Replace raft::abs in tests with std::abs --- cpp/test/distance/fused_l2_nn.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index c4126d25df..252f56607f 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -208,8 +208,8 @@ struct CompareApproxAbsKVP { CompareApproxAbsKVP(T eps_) : eps(eps_) {} bool operator()(const KVP& a, const KVP& b) const { - T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); - T m = std::max(raft::abs(a.value), raft::abs(b.value)); + T diff = std::abs(std::abs(a.value) - std::abs(b.value)); + T m = std::max(std::abs(a.value), std::abs(b.value)); T ratio = m >= eps ? diff / m : diff; return (ratio <= eps); }
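As a usage reference for the helpers consolidated into test_utils.cuh above, the following is a minimal, hypothetical gtest sketch, not part of the patch. The test name, buffer contents, and the 1e-4 tolerance are illustrative assumptions, and raft::update_device is assumed to be available transitively through test_utils.cuh:

#include "../test_utils.cuh"

#include <gtest/gtest.h>
#include <rmm/device_uvector.hpp>

#include <vector>

// Hypothetical usage sketch: exercises devArrMatch/hostVecMatch with the
// Compare/CompareApprox functors defined in test_utils.h / test_utils.cuh.
TEST(TestUtilsSketch, ApproxCompare)
{
  cudaStream_t stream;
  RAFT_CUDA_TRY(cudaStreamCreate(&stream));

  const size_t n = 4;
  std::vector<float> h{1.f, 2.f, 3.f, 4.f};
  rmm::device_uvector<float> exp_d(n, stream), act_d(n, stream);
  raft::update_device(exp_d.data(), h.data(), n, stream);
  raft::update_device(act_d.data(), h.data(), n, stream);

  // element-wise device comparison with a relative tolerance of 1e-4
  ASSERT_TRUE(
    raft::devArrMatch(exp_d.data(), act_d.data(), n, raft::CompareApprox<float>(1e-4f), stream));
  // the same comparator interface works for host-side vectors
  ASSERT_TRUE(raft::hostVecMatch(h, h, raft::Compare<float>()));

  RAFT_CUDA_TRY(cudaStreamDestroy(stream));
}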