diff --git a/cpp/bench/linalg/norm.cu b/cpp/bench/linalg/norm.cu index cce4195cf1..efecee88c9 100644 --- a/cpp/bench/linalg/norm.cu +++ b/cpp/bench/linalg/norm.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -60,7 +61,7 @@ struct rowNorm : public fixture { output_view, raft::linalg::L2Norm, raft::linalg::Apply::ALONG_ROWS, - raft::SqrtOp()); + raft::sqrt_op()); }); } diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 2d3481b4e1..e575849536 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -197,16 +198,15 @@ void kmeansPlusPlus(const raft::handle_t& handle, // Outputs minDistanceBuf[n_trials x n_samples] where minDistance[i, :] contains updated // minClusterDistance that includes candidate-i auto minDistBuf = distBuffer.view(); - raft::linalg::matrixVectorOp( - minDistBuf.data_handle(), - pwd.data_handle(), - minClusterDistance.data_handle(), - pwd.extent(1), - pwd.extent(0), - true, - true, - [=] __device__(DataT mat, DataT vec) { return vec <= mat ? vec : mat; }, - stream); + raft::linalg::matrixVectorOp(minDistBuf.data_handle(), + pwd.data_handle(), + minClusterDistance.data_handle(), + pwd.extent(1), + pwd.extent(0), + true, + true, + raft::min_op{}, + stream); // Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using // centroid candidate-i @@ -326,21 +326,15 @@ void update_centroids(const raft::handle_t& handle, // weight_per_cluster[n_clusters] - 1D array, weight_per_cluster[i] contains sum of weights in // cluster-i. // Note - when weight_per_cluster[i] is 0, new_centroids[i] is reset to 0 - raft::linalg::matrixVectorOp( - new_centroids.data_handle(), - new_centroids.data_handle(), - weight_per_cluster.data_handle(), - new_centroids.extent(1), - new_centroids.extent(0), - true, - false, - [=] __device__(DataT mat, DataT vec) { - if (vec == 0) - return DataT(0); - else - return mat / vec; - }, - handle.get_stream()); + raft::linalg::matrixVectorOp(new_centroids.data_handle(), + new_centroids.data_handle(), + weight_per_cluster.data_handle(), + new_centroids.extent(1), + new_centroids.extent(0), + true, + false, + raft::div_checkzero_op{}, + handle.get_stream()); // copy centroids[i] to new_centroids[i] when weight_per_cluster[i] is 0 cub::ArgIndexInputIterator itr_wt(weight_per_cluster.data_handle()); @@ -356,9 +350,7 @@ void update_centroids(const raft::handle_t& handle, // copy when the sum of weights in the cluster is 0 return map.value == 0; }, - [=] __device__(raft::KeyValuePair map) { // map - return map.key; - }, + raft::key_op{}, handle.get_stream()); } @@ -399,7 +391,7 @@ void kmeans_fit_main(const raft::handle_t& handle, // resource auto wtInCluster = raft::make_device_vector(handle, n_clusters); - rmm::device_scalar> clusterCostD(stream); + rmm::device_scalar clusterCostD(stream); // L2 norm of X: ||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); @@ -470,16 +462,12 @@ void kmeans_fit_main(const raft::handle_t& handle, // compute the squared norm between the newCentroids and the original // centroids, destructor releases the resource auto sqrdNorm = raft::make_device_scalar(handle, DataT(0)); - raft::linalg::mapThenSumReduce( - sqrdNorm.data_handle(), - newCentroids.size(), - [=] __device__(const DataT a, const DataT b) { - DataT diff = a - b; - return diff * diff; - }, - stream, - centroids.data_handle(), - newCentroids.data_handle()); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + newCentroids.size(), + raft::sqdiff_op{}, + stream, + centroids.data_handle(), + newCentroids.data_handle()); DataT sqrdNormError = 0; raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); @@ -494,18 +482,11 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }); - - DataT curClusteringCost = 0; - raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, stream); - - handle.sync_stream(stream); + raft::value_op{}, + raft::add_op{}); + + DataT curClusteringCost = clusterCostD.value(stream); + ASSERT(curClusteringCost != (DataT)0.0, "Too few points and centroids being found is getting 0 cost from " "centers"); @@ -558,15 +539,10 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }); + raft::value_op{}, + raft::add_op{}); - raft::copy(inertia.data_handle(), &(clusterCostD.data()->value), 1, stream); + inertia[0] = clusterCostD.value(stream); RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], @@ -678,7 +654,8 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, minClusterDistanceVec.view(), workspace, raft::make_device_scalar_view(clusterCost.data()), - [] __device__(const DataT& a, const DataT& b) { return a + b; }); + raft::identity_op{}, + raft::add_op{}); auto psi = clusterCost.value(stream); @@ -710,7 +687,8 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, minClusterDistanceVec.view(), workspace, raft::make_device_scalar_view(clusterCost.data()), - [] __device__(const DataT& a, const DataT& b) { return a + b; }); + raft::identity_op{}, + raft::add_op{}); psi = clusterCost.value(stream); @@ -1079,7 +1057,7 @@ void kmeans_predict(handle_t const& handle, workspace); // calculate cluster cost phi_x(C) - rmm::device_scalar> clusterCostD(stream); + rmm::device_scalar clusterCostD(stream); // TODO: add different templates for InType of binaryOp to avoid thrust transform thrust::transform(handle.get_thrust_policy(), minClusterAndDistance.data_handle(), @@ -1097,21 +1075,16 @@ void kmeans_predict(handle_t const& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }); - - raft::copy(inertia.data_handle(), &(clusterCostD.data()->value), 1, stream); + raft::value_op{}, + raft::add_op{}); thrust::transform(handle.get_thrust_policy(), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), labels.data_handle(), - [=] __device__(raft::KeyValuePair pair) { return pair.key; }); + raft::key_op{}); + + inertia[0] = clusterCostD.value(stream); } template diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 2973be8c23..4e52661278 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -156,12 +157,11 @@ void checkWeight(const raft::handle_t& handle, n_samples); auto scale = static_cast(n_samples) / wt_sum; - raft::linalg::unaryOp( - weight.data_handle(), - weight.data_handle(), - n_samples, - [=] __device__(const DataT& wt) { return wt * scale; }, - stream); + raft::linalg::unaryOp(weight.data_handle(), + weight.data_handle(), + n_samples, + raft::mul_const_op{scale}, + stream); } } @@ -179,33 +179,42 @@ IndexT getCentroidsBatchSize(int batch_centroids, IndexT n_local_clusters) return (minVal == 0) ? n_local_clusters : minVal; } -template +template void computeClusterCost(const raft::handle_t& handle, - raft::device_vector_view minClusterDistance, + raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, + raft::device_scalar_view clusterCost, + MainOpT main_op, ReductionOpT reduction_op) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = handle.get_stream(); + + cub::TransformInputIterator itr(minClusterDistance.data_handle(), + main_op); + size_t temp_storage_bytes = 0; RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, - minClusterDistance.data_handle(), + itr, clusterCost.data_handle(), minClusterDistance.size(), reduction_op, - DataT(), + OutputT(), stream)); workspace.resize(temp_storage_bytes, stream); RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(workspace.data(), temp_storage_bytes, - minClusterDistance.data_handle(), + itr, clusterCost.data_handle(), minClusterDistance.size(), reduction_op, - DataT(), + OutputT(), stream)); } @@ -267,9 +276,7 @@ void sampleCentroids(const raft::handle_t& handle, sampledMinClusterDistance.data_handle(), nPtsSampledInRank, inRankCp.data(), - [=] __device__(raft::KeyValuePair val) { // MapTransformOp - return val.key; - }, + raft::key_op{}, stream); } @@ -464,10 +471,8 @@ void minClusterAndDistanceCompute( pair.value = val; return pair; }, - [=] __device__(raft::KeyValuePair a, raft::KeyValuePair b) { - return (b.value < a.value) ? b : a; - }, - [=] __device__(raft::KeyValuePair pair) { return pair; }); + raft::argmin_op{}, + raft::identity_op{}); } } } @@ -542,7 +547,6 @@ void minClusterDistanceCompute(const raft::handle_t& handle, if (is_fused) { workspace.resize((sizeof(IndexT)) * ns, stream); - // todo(lsugy): remove cIdx raft::distance::fusedL2NNMinReduce( minClusterDistanceView.data_handle(), datasetView.data_handle(), @@ -577,23 +581,16 @@ void minClusterDistanceCompute(const raft::handle_t& handle, pairwise_distance_kmeans( handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); - raft::linalg::coalescedReduction( - minClusterDistanceView.data_handle(), - pairwiseDistanceView.data_handle(), - pairwiseDistanceView.extent(1), - pairwiseDistanceView.extent(0), - std::numeric_limits::max(), - stream, - true, - [=] __device__(DataT val, IndexT i) { // MainLambda - return val; - }, - [=] __device__(DataT a, DataT b) { // ReduceLambda - return (b < a) ? b : a; - }, - [=] __device__(DataT val) { // FinalLambda - return val; - }); + raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(), + pairwiseDistanceView.data_handle(), + pairwiseDistanceView.extent(1), + pairwiseDistanceView.extent(0), + std::numeric_limits::max(), + stream, + true, + raft::identity_op{}, + raft::min_op{}, + raft::identity_op{}); } } } diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index d64815244b..4b912dc966 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -20,6 +20,7 @@ #include #include #include +#include namespace raft::cluster::kmeans { @@ -313,7 +314,8 @@ void cluster_cost(const raft::handle_t& handle, raft::device_scalar_view clusterCost, ReductionOpT reduction_op) { - detail::computeClusterCost(handle, minClusterDistance, workspace, clusterCost, reduction_op); + detail::computeClusterCost( + handle, minClusterDistance, workspace, clusterCost, raft::identity_op{}, reduction_op); } /** diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp new file mode 100644 index 0000000000..19b9695896 --- /dev/null +++ b/cpp/include/raft/core/operators.hpp @@ -0,0 +1,354 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace raft { + +/** + * @defgroup Functors Commonly used functors. + * The optional unused arguments are useful for kernels that pass the index along with the value. + * @{ + */ + +struct identity_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const + { + return in; + } +}; + +template +struct cast_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(InT in, UnusedArgs...) const + { + return static_cast(in); + } +}; + +struct key_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& p, UnusedArgs...) const + { + return p.key; + } +}; + +struct value_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& p, UnusedArgs...) const + { + return p.value; + } +}; + +struct sqrt_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const + { + return std::sqrt(in); + } +}; + +struct nz_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const + { + return in != Type(0) ? Type(1) : Type(0); + } +}; + +struct abs_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const + { + return std::abs(in); + } +}; + +struct sq_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const + { + return in * in; + } +}; + +struct add_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a + b; + } +}; + +struct sub_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a - b; + } +}; + +struct mul_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a * b; + } +}; + +struct div_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a / b; + } +}; + +struct div_checkzero_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + { + if (b == Type{0}) { return Type{0}; } + return a / b; + } +}; + +struct pow_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + { + return std::pow(a, b); + } +}; + +struct min_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + { + if (a > b) { return b; } + return a; + } +}; + +struct max_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + { + if (b > a) { return b; } + return a; + } +}; + +struct sqdiff_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + { + Type diff = a - b; + return diff * diff; + } +}; + +struct argmin_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& a, const KVP& b) const + { + if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; } + return a; + } +}; + +struct argmax_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& a, const KVP& b) const + { + if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; } + return a; + } +}; + +template +struct const_op { + const ScalarT scalar; + + constexpr explicit const_op(const ScalarT& s) : scalar{s} {} + + template + constexpr RAFT_INLINE_FUNCTION auto operator()(Args...) const + { + return scalar; + } +}; + +/** + * @brief Wraps around a binary operator, passing a constant on the right-hand side. + * + * Usage example: + * @code{.cpp} + * #include + * + * raft::plug_const_op op(2.0f); + * std::cout << op(2.1f) << std::endl; // 4.2 + * @endcode + * + * @tparam ConstT + * @tparam BinaryOpT + */ +template +struct plug_const_op { + const ConstT c; + const BinaryOpT composed_op; + + template >> + constexpr explicit plug_const_op(const ConstT& s) + : c{s}, composed_op{} // The compiler complains if composed_op is not initialized explicitly + { + } + constexpr plug_const_op(const ConstT& s, BinaryOpT o) : c{s}, composed_op{o} {} + + template + constexpr RAFT_INLINE_FUNCTION auto operator()(InT a) const + { + return composed_op(a, c); + } +}; + +template +using add_const_op = plug_const_op; + +template +using sub_const_op = plug_const_op; + +template +using mul_const_op = plug_const_op; + +template +using div_const_op = plug_const_op; + +template +using div_checkzero_const_op = plug_const_op; + +template +using pow_const_op = plug_const_op; + +/** + * @brief Constructs an operator by composing a chain of operators. + * + * Note that all arguments are passed to the innermost operator. + * + * Usage example: + * @code{.cpp} + * #include + * + * auto op = raft::compose_op(raft::sqrt_op(), raft::abs_op(), raft::cast_op(), + * raft::add_const_op(8)); + * std::cout << op(-50) << std::endl; // 6.48074 + * @endcode + * + * @tparam OpsT Any number of operation types. + */ +template +struct compose_op { + const std::tuple ops; + + template , + typename CondT = std::enable_if_t>> + constexpr compose_op() + { + } + constexpr explicit compose_op(OpsT... ops) : ops{ops...} {} + + template + constexpr RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const + { + return compose(std::forward(args)...); + } + + private: + template + constexpr RAFT_INLINE_FUNCTION auto compose(Args&&... args) const + { + if constexpr (RemOps > 0) { + return compose(std::get(ops)(std::forward(args)...)); + } else { + return identity_op{}(std::forward(args)...); + } + } +}; + +/** + * @brief Constructs an operator by composing an outer op with one inner op for each of its inputs. + * + * Usage example: + * @code{.cpp} + * #include + * + * raft::map_args_op> op; + * std::cout << op(42.0f, 10) << std::endl; // 16.4807 + * @endcode + * + * @tparam OuterOpT Outer operation type + * @tparam ArgOpsT Operation types for each input of the outer operation + */ +template +struct map_args_op { + const OuterOpT outer_op; + const std::tuple arg_ops; + + template , + typename CondT = std::enable_if_t && + std::is_default_constructible_v>> + constexpr map_args_op() + : outer_op{} // The compiler complains if outer_op is not initialized explicitly + { + } + constexpr explicit map_args_op(OuterOpT outer_op, ArgOpsT... arg_ops) + : outer_op{outer_op}, arg_ops{arg_ops...} + { + } + + template + constexpr RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const + { + constexpr size_t kNumOps = sizeof...(ArgOpsT); + static_assert(kNumOps == sizeof...(Args), + "The number of arguments does not match the number of mapping operators"); + return map_args(std::make_index_sequence{}, std::forward(args)...); + } + + private: + template + constexpr RAFT_INLINE_FUNCTION auto map_args(std::index_sequence, Args&&... args) const + { + return outer_op(std::get(arg_ops)(std::forward(args))...); + } +}; + +/** @} */ +} // namespace raft diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 6be994b80a..90ed3940e1 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,7 +73,7 @@ static void canberraImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); + const auto diff = raft::myAbs(x - y); const auto add = raft::myAbs(x) + raft::myAbs(y); // deal with potential for 0 in denominator by // forcing 1/0 instead diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index 1ac10f269e..454ee8c8bb 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,7 +72,7 @@ static void chebyshevImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); + const auto diff = raft::myAbs(x - y); acc = raft::myMax(acc, diff); }; diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index 2b77d280fe..9bdbbf112c 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -262,8 +262,8 @@ void correlationImpl(int m, true, stream, false, - raft::Nop(), - raft::Sum()); + raft::identity_op(), + raft::add_op()); raft::linalg::reduce(norm_row_vec, pB, k, @@ -273,8 +273,8 @@ void correlationImpl(int m, true, stream, false, - raft::Nop(), - raft::Sum()); + raft::identity_op(), + raft::add_op()); sq_norm_col_vec += (m + n); sq_norm_row_vec = sq_norm_col_vec + m; @@ -290,8 +290,8 @@ void correlationImpl(int m, true, stream, false, - raft::Nop(), - raft::Sum()); + raft::identity_op(), + raft::add_op()); sq_norm_col_vec += m; sq_norm_row_vec = sq_norm_col_vec; raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream); diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index f06051962f..46a694aa51 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -19,6 +19,7 @@ #include #include #include +#include namespace raft { namespace distance { @@ -229,8 +230,6 @@ void cosineAlgo1(Index_ m, cudaStream_t stream, bool isRowMajor) { - auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); }; - // raft distance support inputs as float/double and output as uint8_t/float/double. static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), "OutType can be uint8_t, float, double," @@ -248,10 +247,13 @@ void cosineAlgo1(Index_ m, InType* row_vec = workspace; if (pA != pB) { row_vec += m; - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); - raft::linalg::rowNorm(row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); + raft::linalg::rowNorm( + row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); } else { - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); } if (isRowMajor) { diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 5ea74fa884..4184810fff 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -19,6 +19,7 @@ #include #include #include +#include namespace raft { namespace distance { @@ -247,8 +248,6 @@ void euclideanAlgo1(Index_ m, cudaStream_t stream, bool isRowMajor) { - auto norm_op = [] __device__(InType in) { return in; }; - // raft distance support inputs as float/double and output as uint8_t/float/double. static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), "OutType can be uint8_t, float, double," @@ -266,10 +265,13 @@ void euclideanAlgo1(Index_ m, InType* row_vec = workspace; if (pA != pB) { row_vec += m; - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); - raft::linalg::rowNorm(row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + raft::linalg::rowNorm( + row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } else { - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } if (isRowMajor) { diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index e8c2648c2e..c9750df8ad 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -298,8 +298,6 @@ void fusedL2NNImpl(OutT* min, RAFT_CUDA_TRY(cudaGetLastError()); } - auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; - constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); if (sqrt) { auto fusedL2NNSqrt = fusedL2NNkernel; + raft::identity_op>; dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); - fusedL2NNSqrt<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + fusedL2NNSqrt<<>>(min, + x, + y, + xn, + yn, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + core_lambda, + raft::identity_op{}); } else { auto fusedL2NN = fusedL2NNkernel; + raft::identity_op>; dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + fusedL2NN<<>>(min, + x, + y, + xn, + yn, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + core_lambda, + raft::identity_op{}); } RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index 31854fd1d6..51f462ab36 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -17,6 +17,7 @@ #pragma once #include #include +#include namespace raft { namespace distance { @@ -78,14 +79,10 @@ static void hellingerImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); - auto unaryOp_lambda = [] __device__(DataT input) { return raft::mySqrt(input); }; // First sqrt x and y - raft::linalg::unaryOp( - (DataT*)x, x, m * k, unaryOp_lambda, stream); - + raft::linalg::unaryOp((DataT*)x, x, m * k, raft::sqrt_op{}, stream); if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); + raft::linalg::unaryOp((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } // Accumulation operation lambda @@ -145,11 +142,9 @@ static void hellingerImpl(const DataT* x, } // Revert sqrt of x and y - raft::linalg::unaryOp( - (DataT*)x, x, m * k, unaryOp_lambda, stream); + raft::linalg::unaryOp((DataT*)x, x, m * k, raft::sqrt_op{}, stream); if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); + raft::linalg::unaryOp((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 6372019fd3..95514db60b 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -71,7 +71,7 @@ static void l1Impl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); + const auto diff = raft::myAbs(x - y); acc += diff; }; diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh index d3d0979d0d..bda83babf1 100644 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -74,7 +74,7 @@ void minkowskiUnExpImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [p] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); + const auto diff = raft::myAbs(x - y); acc += raft::myPow(diff, p); }; diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 26536d13cd..69bb83d29a 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -14,6 +14,7 @@ * limitations under the License. */ #pragma once +#include #include #include #include diff --git a/cpp/include/raft/label/detail/classlabels.cuh b/cpp/include/raft/label/detail/classlabels.cuh index 0af1c70b91..64d8b4bfae 100644 --- a/cpp/include/raft/label/detail/classlabels.cuh +++ b/cpp/include/raft/label/detail/classlabels.cuh @@ -18,6 +18,7 @@ #include +#include #include #include #include @@ -194,8 +195,7 @@ void make_monotonic( template void make_monotonic(Type* out, Type* in, size_t N, cudaStream_t stream, bool zero_based = false) { - make_monotonic( - out, in, N, stream, [] __device__(Type val) { return false; }, zero_based); + make_monotonic(out, in, N, stream, raft::const_op(false), zero_based); } }; // namespace detail diff --git a/cpp/include/raft/linalg/add.cuh b/cpp/include/raft/linalg/add.cuh index 341ba45af5..27ab24abe8 100644 --- a/cpp/include/raft/linalg/add.cuh +++ b/cpp/include/raft/linalg/add.cuh @@ -27,8 +27,6 @@ namespace raft { namespace linalg { -using detail::adds_scalar; - /** * @ingroup arithmetic * @brief Elementwise scalar add operation on the input buffer diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index e9e5a99f46..45cd640edc 100644 --- a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ -22,6 +22,7 @@ #include #include +#include namespace raft { namespace linalg { @@ -56,9 +57,9 @@ namespace linalg { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReduction(OutType* dots, const InType* data, IdxType D, @@ -66,9 +67,9 @@ void coalescedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { detail::coalescedReduction( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); @@ -113,17 +114,17 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalesced_reduction(const raft::handle_t& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutValueType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if constexpr (std::is_same_v) { RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), diff --git a/cpp/include/raft/linalg/detail/add.cuh b/cpp/include/raft/linalg/detail/add.cuh index 34966ebbc2..bf9b2bd1d8 100644 --- a/cpp/include/raft/linalg/detail/add.cuh +++ b/cpp/include/raft/linalg/detail/add.cuh @@ -16,14 +16,11 @@ #pragma once -#include "functional.cuh" - +#include #include #include #include -#include - namespace raft { namespace linalg { namespace detail { @@ -31,13 +28,13 @@ namespace detail { template void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, adds_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::add_const_op(scalar), stream); } template void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::plus(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::add_op(), stream); } template diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh index 63351f5475..238e17fa56 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -71,9 +72,9 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThin(OutType* dots, const InType* data, IdxType D, @@ -81,9 +82,9 @@ void coalescedReductionThin(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope( "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); @@ -97,9 +98,9 @@ void coalescedReductionThin(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThinDispatcher(OutType* dots, const InType* data, IdxType D, @@ -107,9 +108,9 @@ void coalescedReductionThinDispatcher(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if (D <= IdxType(2)) { coalescedReductionThin>( @@ -168,9 +169,9 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionMedium(OutType* dots, const InType* data, IdxType D, @@ -178,9 +179,9 @@ void coalescedReductionMedium(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope("coalescedReductionMedium<%d>", TPB); coalescedReductionMediumKernel @@ -191,9 +192,9 @@ void coalescedReductionMedium(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionMediumDispatcher(OutType* dots, const InType* data, IdxType D, @@ -201,9 +202,9 @@ void coalescedReductionMediumDispatcher(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { // Note: for now, this kernel is only used when D > 256. If this changes in the future, use // smaller block sizes when relevant. @@ -251,9 +252,9 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThick(OutType* dots, const InType* data, IdxType D, @@ -261,9 +262,9 @@ void coalescedReductionThick(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope( "coalescedReductionThick<%d,%d>", ThickPolicy::ThreadsPerBlock, ThickPolicy::BlocksPerRow); @@ -291,7 +292,7 @@ void coalescedReductionThick(OutType* dots, init, stream, inplace, - raft::Nop(), + raft::identity_op(), reduce_op, final_op); } @@ -299,9 +300,9 @@ void coalescedReductionThick(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThickDispatcher(OutType* dots, const InType* data, IdxType D, @@ -309,9 +310,9 @@ void coalescedReductionThickDispatcher(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { // Note: multiple elements per thread to take advantage of the sequential reduction and loop // unrolling @@ -330,9 +331,9 @@ void coalescedReductionThickDispatcher(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReduction(OutType* dots, const InType* data, IdxType D, @@ -340,9 +341,9 @@ void coalescedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { /* The primitive selects one of three implementations based on heuristics: * - Thin: very efficient when D is small and/or N is large diff --git a/cpp/include/raft/linalg/detail/divide.cuh b/cpp/include/raft/linalg/detail/divide.cuh index 333cd3e83c..eef1d19d6e 100644 --- a/cpp/include/raft/linalg/detail/divide.cuh +++ b/cpp/include/raft/linalg/detail/divide.cuh @@ -16,9 +16,8 @@ #pragma once -#include "functional.cuh" - #include +#include #include namespace raft { @@ -28,7 +27,7 @@ namespace detail { template void divideScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, divides_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::div_const_op(scalar), stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/eltwise.cuh b/cpp/include/raft/linalg/detail/eltwise.cuh index 019f86a779..25b4ca0499 100644 --- a/cpp/include/raft/linalg/detail/eltwise.cuh +++ b/cpp/include/raft/linalg/detail/eltwise.cuh @@ -16,13 +16,10 @@ #pragma once -#include "functional.cuh" - +#include #include #include -#include - namespace raft { namespace linalg { namespace detail { @@ -30,48 +27,48 @@ namespace detail { template void scalarAdd(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, adds_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::add_const_op(scalar), stream); } template void scalarMultiply(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, multiplies_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::mul_const_op(scalar), stream); } template void eltwiseAdd( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::plus(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::add_op(), stream); } template void eltwiseSub( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::minus(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::sub_op(), stream); } template void eltwiseMultiply( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::multiplies(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::mul_op(), stream); } template void eltwiseDivide( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::divides(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::div_op(), stream); } template void eltwiseDivideCheckZero( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, divides_check_zero(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::div_checkzero_op(), stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/functional.cuh b/cpp/include/raft/linalg/detail/functional.cuh deleted file mode 100644 index 067b1565e0..0000000000 --- a/cpp/include/raft/linalg/detail/functional.cuh +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace linalg { -namespace detail { - -template -struct divides_scalar { - public: - divides_scalar(ArgType scalar) : scalar_(scalar) {} - - __host__ __device__ inline ReturnType operator()(ArgType in) { return in / scalar_; } - - private: - ArgType scalar_; -}; - -template -struct adds_scalar { - public: - adds_scalar(ArgType scalar) : scalar_(scalar) {} - - __host__ __device__ inline ReturnType operator()(ArgType in) { return in + scalar_; } - - private: - ArgType scalar_; -}; - -template -struct multiplies_scalar { - public: - multiplies_scalar(ArgType scalar) : scalar_(scalar) {} - - __host__ __device__ inline ReturnType operator()(ArgType in) { return in * scalar_; } - - private: - ArgType scalar_; -}; - -template -struct divides_check_zero { - public: - __host__ __device__ inline ReturnType operator()(ArgType a, ArgType b) - { - return (b == static_cast(0)) ? 0.0 : a / b; - } -}; - -} // namespace detail -} // namespace linalg -} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/linalg/detail/multiply.cuh b/cpp/include/raft/linalg/detail/multiply.cuh index f1a8548bfa..84b832d875 100644 --- a/cpp/include/raft/linalg/detail/multiply.cuh +++ b/cpp/include/raft/linalg/detail/multiply.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include namespace raft { @@ -26,8 +27,7 @@ template void multiplyScalar( math_t* out, const math_t* in, const math_t scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp( - out, in, len, [scalar] __device__(math_t in) { return in * scalar; }, stream); + raft::linalg::unaryOp(out, in, len, raft::mul_const_op{scalar}, stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/norm.cuh b/cpp/include/raft/linalg/detail/norm.cuh index f2f08233d5..ed7e360848 100644 --- a/cpp/include/raft/linalg/detail/norm.cuh +++ b/cpp/include/raft/linalg/detail/norm.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -44,8 +45,8 @@ void rowNormCaller(Type* dots, true, stream, false, - raft::L1Op(), - raft::Sum(), + raft::abs_op(), + raft::add_op(), fin_op); break; case L2Norm: @@ -58,8 +59,8 @@ void rowNormCaller(Type* dots, true, stream, false, - raft::L2Op(), - raft::Sum(), + raft::sq_op(), + raft::add_op(), fin_op); break; case LinfNorm: @@ -72,8 +73,8 @@ void rowNormCaller(Type* dots, true, stream, false, - raft::L1Op(), - raft::Max(), + raft::abs_op(), + raft::max_op(), fin_op); break; default: THROW("Unsupported norm type: %d", type); @@ -101,8 +102,8 @@ void colNormCaller(Type* dots, false, stream, false, - raft::L1Op(), - raft::Sum(), + raft::abs_op(), + raft::add_op(), fin_op); break; case L2Norm: @@ -115,8 +116,8 @@ void colNormCaller(Type* dots, false, stream, false, - raft::L2Op(), - raft::Sum(), + raft::sq_op(), + raft::add_op(), fin_op); break; case LinfNorm: @@ -129,8 +130,8 @@ void colNormCaller(Type* dots, false, stream, false, - raft::L1Op(), - raft::Max(), + raft::abs_op(), + raft::max_op(), fin_op); break; default: THROW("Unsupported norm type: %d", type); diff --git a/cpp/include/raft/linalg/detail/reduce.cuh b/cpp/include/raft/linalg/detail/reduce.cuh index 3022973b43..721ca8179f 100644 --- a/cpp/include/raft/linalg/detail/reduce.cuh +++ b/cpp/include/raft/linalg/detail/reduce.cuh @@ -16,9 +16,9 @@ #pragma once +#include #include #include -#include namespace raft { namespace linalg { @@ -27,9 +27,9 @@ namespace detail { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void reduce(OutType* dots, const InType* data, IdxType D, @@ -39,9 +39,9 @@ void reduce(OutType* dots, bool alongRows, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if (rowMajor && alongRows) { raft::linalg::coalescedReduction( diff --git a/cpp/include/raft/linalg/detail/strided_reduction.cuh b/cpp/include/raft/linalg/detail/strided_reduction.cuh index d72bd54a32..0e516b4750 100644 --- a/cpp/include/raft/linalg/detail/strided_reduction.cuh +++ b/cpp/include/raft/linalg/detail/strided_reduction.cuh @@ -18,6 +18,7 @@ #include "unary_op.cuh" #include +#include #include #include #include @@ -107,9 +108,9 @@ __global__ void stridedReductionKernel(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void stridedReduction(OutType* dots, const InType* data, IdxType D, @@ -117,15 +118,13 @@ void stridedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { ///@todo: this extra should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) - if (!inplace) - raft::linalg::unaryOp( - dots, dots, D, [init] __device__(OutType a) { return init; }, stream); + if (!inplace) raft::linalg::unaryOp(dots, dots, D, raft::const_op(init), stream); // Arbitrary numbers for now, probably need to tune const dim3 thrds(32, 16); @@ -137,7 +136,7 @@ void stridedReduction(OutType* dots, ///@todo: this complication should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) - if constexpr (std::is_same>::value && + if constexpr (std::is_same::value && std::is_same::value) stridedSummationKernel <<>>(dots, data, D, N, init, main_op); @@ -148,7 +147,7 @@ void stridedReduction(OutType* dots, ///@todo: this complication should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) // Perform final op on output data - if (!std::is_same>::value) + if (!std::is_same::value) raft::linalg::unaryOp(dots, dots, D, final_op, stream); } diff --git a/cpp/include/raft/linalg/detail/subtract.cuh b/cpp/include/raft/linalg/detail/subtract.cuh index ae0f09d2fe..6df09df8ed 100644 --- a/cpp/include/raft/linalg/detail/subtract.cuh +++ b/cpp/include/raft/linalg/detail/subtract.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -27,15 +28,13 @@ namespace detail { template void subtractScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - auto op = [scalar] __device__(InT in) { return OutT(in - scalar); }; - raft::linalg::unaryOp(out, in, len, op, stream); + raft::linalg::unaryOp(out, in, len, raft::sub_const_op(scalar), stream); } template void subtract(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t stream) { - auto op = [] __device__(InT a, InT b) { return OutT(a - b); }; - raft::linalg::binaryOp(out, in1, in2, len, op, stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::sub_op(), stream); } template diff --git a/cpp/include/raft/linalg/divide.cuh b/cpp/include/raft/linalg/divide.cuh index 53b083045e..526d8a9716 100644 --- a/cpp/include/raft/linalg/divide.cuh +++ b/cpp/include/raft/linalg/divide.cuh @@ -21,13 +21,12 @@ #include "detail/divide.cuh" #include +#include #include namespace raft { namespace linalg { -using detail::divides_scalar; - /** * @defgroup ScalarOps Scalar operations on the input buffer * @tparam OutT output data-type upon which the math operation will be performed diff --git a/cpp/include/raft/linalg/eltwise.cuh b/cpp/include/raft/linalg/eltwise.cuh index dbc06a4af3..2e6c1a4ab5 100644 --- a/cpp/include/raft/linalg/eltwise.cuh +++ b/cpp/include/raft/linalg/eltwise.cuh @@ -23,8 +23,6 @@ namespace raft { namespace linalg { -using detail::adds_scalar; - /** * @defgroup ScalarOps Scalar operations on the input buffer * @tparam InType data-type upon which the math operation will be performed @@ -42,8 +40,6 @@ void scalarAdd(OutType* out, const InType* in, InType scalar, IdxType len, cudaS detail::scalarAdd(out, in, scalar, len, stream); } -using detail::multiplies_scalar; - template void scalarMultiply(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { @@ -90,8 +86,6 @@ void eltwiseDivide( detail::eltwiseDivide(out, in1, in2, len, stream); } -using detail::divides_check_zero; - template void eltwiseDivideCheckZero( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index 19757eb86d..b64b128fa2 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -22,6 +22,7 @@ #include "linalg_types.hpp" #include +#include #include #include @@ -47,7 +48,7 @@ namespace linalg { * @param stream cuda stream where to launch work * @param fin_op the final lambda op */ -template > +template void rowNorm(Type* dots, const Type* data, IdxType D, @@ -55,7 +56,7 @@ void rowNorm(Type* dots, NormType type, bool rowMajor, cudaStream_t stream, - Lambda fin_op = raft::Nop()) + Lambda fin_op = raft::identity_op()) { detail::rowNormCaller(dots, data, D, N, type, rowMajor, stream, fin_op); } @@ -74,7 +75,7 @@ void rowNorm(Type* dots, * @param stream cuda stream where to launch work * @param fin_op the final lambda op */ -template > +template void colNorm(Type* dots, const Type* data, IdxType D, @@ -82,7 +83,7 @@ void colNorm(Type* dots, NormType type, bool rowMajor, cudaStream_t stream, - Lambda fin_op = raft::Nop()) + Lambda fin_op = raft::identity_op()) { detail::colNormCaller(dots, data, D, N, type, rowMajor, stream, fin_op); } @@ -109,13 +110,13 @@ void colNorm(Type* dots, template > + typename Lambda = raft::identity_op> void norm(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_vector_view out, NormType type, Apply apply, - Lambda fin_op = raft::Nop()) + Lambda fin_op = raft::identity_op()) { RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh index e64436762c..bf6ef5a570 100644 --- a/cpp/include/raft/linalg/normalize.cuh +++ b/cpp/include/raft/linalg/normalize.cuh @@ -18,6 +18,7 @@ #include "detail/normalize.cuh" +#include #include namespace raft { @@ -99,34 +100,16 @@ void row_normalize(const raft::handle_t& handle, { switch (norm_type) { case L1Norm: - row_normalize(handle, - in, - out, - ElementType(0), - raft::L1Op(), - raft::Sum(), - raft::Nop(), - eps); + row_normalize( + handle, in, out, ElementType(0), raft::abs_op(), raft::add_op(), raft::identity_op(), eps); break; case L2Norm: - row_normalize(handle, - in, - out, - ElementType(0), - raft::L2Op(), - raft::Sum(), - raft::SqrtOp(), - eps); + row_normalize( + handle, in, out, ElementType(0), raft::sq_op(), raft::add_op(), raft::sqrt_op(), eps); break; case LinfNorm: - row_normalize(handle, - in, - out, - ElementType(0), - raft::L1Op(), - raft::Max(), - raft::Nop(), - eps); + row_normalize( + handle, in, out, ElementType(0), raft::abs_op(), raft::max_op(), raft::identity_op(), eps); break; default: THROW("Unsupported norm type: %d", norm_type); } diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index acd226b71d..59c2cdf314 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -19,9 +19,9 @@ #pragma once #include +#include #include #include -#include #include namespace raft { @@ -41,8 +41,7 @@ namespace linalg { template void powerScalar(out_t* out, const in_t* in, const in_t scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp( - out, in, len, [scalar] __device__(in_t in) { return raft::myPow(in, scalar); }, stream); + raft::linalg::unaryOp(out, in, len, raft::pow_const_op(scalar), stream); } /** @} */ @@ -61,8 +60,7 @@ void powerScalar(out_t* out, const in_t* in, const in_t scalar, IdxType len, cud template void power(out_t* out, const in_t* in1, const in_t* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp( - out, in1, in2, len, [] __device__(in_t a, in_t b) { return raft::myPow(a, b); }, stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::pow_op(), stream); } /** @} */ diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index 5579acf355..3eb8196408 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -22,6 +22,7 @@ #include "linalg_types.hpp" #include +#include #include namespace raft { @@ -59,9 +60,9 @@ namespace linalg { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void reduce(OutType* dots, const InType* data, IdxType D, @@ -71,9 +72,9 @@ void reduce(OutType* dots, bool alongRows, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { detail::reduce( dots, data, D, N, init, rowMajor, alongRows, stream, inplace, main_op, reduce_op, final_op); @@ -118,18 +119,18 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void reduce(const raft::handle_t& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutElementType init, Apply apply, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { RAFT_EXPECTS(raft::is_row_or_column_major(data), "Input must be contiguous"); diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index 2951285c3a..ad6cad2eb2 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -19,8 +19,8 @@ #pragma once #include +#include #include -#include namespace raft { namespace linalg { @@ -38,8 +38,7 @@ namespace linalg { template void sqrt(out_t* out, const in_t* in, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp( - out, in, len, [] __device__(in_t in) { return raft::mySqrt(in); }, stream); + raft::linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream); } /** @} */ diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index 0aa4aecef5..d9c26910e7 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -22,6 +22,7 @@ #include "detail/strided_reduction.cuh" #include +#include #include #include @@ -59,9 +60,9 @@ namespace linalg { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void stridedReduction(OutType* dots, const InType* data, IdxType D, @@ -69,9 +70,9 @@ void stridedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { // Only compile for types supported by myAtomicReduce, but don't make the compilation fail in // other cases, because coalescedReduction supports arbitrary types. @@ -124,17 +125,17 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void strided_reduction(const raft::handle_t& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutValueType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if constexpr (std::is_same_v) { RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 3738afba5d..c006f69e47 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft { namespace matrix { namespace detail { @@ -183,16 +185,7 @@ void gather(const MatrixIteratorT in, { typedef typename std::iterator_traits::value_type MapValueT; gatherImpl( - in, - D, - N, - map, - map, - map_length, - out, - [] __device__(MapValueT val) { return true; }, - [] __device__(MapValueT val) { return val; }, - stream); + in, D, N, map, map, map_length, out, raft::const_op(true), raft::identity_op(), stream); } /** @@ -227,17 +220,7 @@ void gather(const MatrixIteratorT in, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; - gatherImpl( - in, - D, - N, - map, - map, - map_length, - out, - [] __device__(MapValueT val) { return true; }, - transform_op, - stream); + gatherImpl(in, D, N, map, map, map_length, out, raft::const_op(true), transform_op, stream); } /** @@ -279,17 +262,7 @@ void gather_if(const MatrixIteratorT in, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; - gatherImpl( - in, - D, - N, - map, - stencil, - map_length, - out, - pred_op, - [] __device__(MapValueT val) { return val; }, - stream); + gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, raft::identity_op(), stream); } /** diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 64c85a03a5..c559da3942 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -188,8 +189,7 @@ void reciprocal(math_t* in, math_t* out, IdxType len, cudaStream_t stream) template void setValue(math_t* out, const math_t* in, math_t scalar, int len, cudaStream_t stream = 0) { - raft::linalg::unaryOp( - out, in, len, [scalar] __device__(math_t in) { return scalar; }, stream); + raft::linalg::unaryOp(out, in, len, raft::const_op(scalar), stream); } template @@ -201,8 +201,7 @@ void ratio( rmm::device_scalar d_sum(stream); auto* d_sum_ptr = d_sum.data(); - auto no_op = [] __device__(math_t in) { return in; }; - raft::linalg::mapThenSumReduce(d_sum_ptr, len, no_op, stream, src); + raft::linalg::mapThenSumReduce(d_sum_ptr, len, raft::identity_op{}, stream, src); raft::linalg::unaryOp( d_dest, d_src, len, [=] __device__(math_t a) { return a / (*d_sum_ptr); }, stream); } @@ -217,15 +216,7 @@ void matrixVectorBinaryMult(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a * b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::mul_op(), stream); } template @@ -264,15 +255,7 @@ void matrixVectorBinaryDiv(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a / b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::div_op(), stream); } template @@ -330,15 +313,7 @@ void matrixVectorBinaryAdd(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a + b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::add_op(), stream); } template @@ -351,15 +326,7 @@ void matrixVectorBinarySub(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a - b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::sub_op(), stream); } // Computes an argmin/argmax column-wise in a DxN matrix diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index 0707eb2a9b..01b665e207 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -18,6 +18,7 @@ #include +#include #include #include #include @@ -193,13 +194,12 @@ class lp_unexpanded_distances_t : public distances_t { { unexpanded_lp_distances(out_dists, config_, PDiff(p), Sum(), AtomicAdd()); - float one_over_p = 1.0f / p; - raft::linalg::unaryOp( - out_dists, - out_dists, - config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return pow(input, one_over_p); }, - config_->handle.get_stream()); + value_t one_over_p = value_t{1} / p; + raft::linalg::unaryOp(out_dists, + out_dists, + config_->a_nrows * config_->b_nrows, + raft::pow_const_op(one_over_p), + config_->handle.get_stream()); } private: @@ -220,12 +220,11 @@ class hamming_unexpanded_distances_t : public distances_t { unexpanded_lp_distances(out_dists, config_, NotEqual(), Sum(), AtomicAdd()); value_t n_cols = 1.0 / config_->a_ncols; - raft::linalg::unaryOp( - out_dists, - out_dists, - config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return input * n_cols; }, - config_->handle.get_stream()); + raft::linalg::unaryOp(out_dists, + out_dists, + config_->a_nrows * config_->b_nrows, + raft::mul_const_op(n_cols), + config_->handle.get_stream()); } private: @@ -302,12 +301,11 @@ class kl_divergence_unexpanded_distances_t : public distances_t { Sum(), AtomicAdd()); - raft::linalg::unaryOp( - out_dists, - out_dists, - config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return 0.5 * input; }, - config_->handle.get_stream()); + raft::linalg::unaryOp(out_dists, + out_dists, + config_->a_nrows * config_->b_nrows, + raft::mul_const_op(0.5), + config_->handle.get_stream()); } private: diff --git a/cpp/include/raft/sparse/op/detail/slice.cuh b/cpp/include/raft/sparse/op/detail/slice.cuh index 193d246b4b..4d2f1a4195 100644 --- a/cpp/include/raft/sparse/op/detail/slice.cuh +++ b/cpp/include/raft/sparse/op/detail/slice.cuh @@ -18,6 +18,7 @@ #include +#include #include #include #include @@ -70,12 +71,11 @@ void csr_row_slice_indptr(value_idx start_row, // we add another 1 to stop row. raft::copy_async(indptr_out, indptr + start_row, (stop_row + 2) - start_row, stream); - raft::linalg::unaryOp( - indptr_out, - indptr_out, - (stop_row + 2) - start_row, - [s_offset] __device__(value_idx input) { return input - s_offset; }, - stream); + raft::linalg::unaryOp(indptr_out, + indptr_out, + (stop_row + 2) - start_row, + raft::sub_const_op(s_offset), + stream); } /** diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index e5900ffd69..55675f2a46 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -22,6 +22,7 @@ #include "common_faiss.h" #include "processing.cuh" +#include #include #include @@ -202,11 +203,7 @@ void approx_knn_search(const handle_t& handle, float p = 0.5; // standard l2 if (index->metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / index->metricArg; raft::linalg::unaryOp( - distances, - distances, - n * k, - [p] __device__(float input) { return powf(input, p); }, - handle.get_stream()); + distances, distances, n * k, raft::pow_const_op(p), handle.get_stream()); } if constexpr (std::is_same_v) { index->metric_processor->postprocess(distances); } } diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 41f1df85fe..85a05877f1 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -20,8 +20,10 @@ #include // TODO: Need to hide the PairwiseDistance class impl and expose to public API #include "processing.cuh" +#include #include #include +#include namespace raft { namespace spatial { @@ -566,8 +568,6 @@ void fusedL2UnexpKnnImpl(const DataT* x, acc += diff * diff; }; - auto fin_op = [] __device__(AccT d_val, int g_d_idx) { return d_val; }; - typedef cub::KeyValuePair Pair; if (isRowMajor) { @@ -578,7 +578,7 @@ void fusedL2UnexpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - decltype(fin_op), + raft::identity_op, 32, 2, usePrevTopKs, @@ -590,7 +590,7 @@ void fusedL2UnexpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - decltype(fin_op), + raft::identity_op, 64, 3, usePrevTopKs, @@ -630,7 +630,7 @@ void fusedL2UnexpKnnImpl(const DataT* x, ldb, ldd, core_lambda, - fin_op, + raft::identity_op{}, sqrt, (uint32_t)numOfNN, (int*)workspace, @@ -757,8 +757,6 @@ void fusedL2ExpKnnImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; - auto fin_op = [] __device__(AccT d_val, int g_d_idx) { return d_val; }; - typedef cub::KeyValuePair Pair; if (isRowMajor) { @@ -769,7 +767,7 @@ void fusedL2ExpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - decltype(fin_op), + raft::identity_op, 32, 2, usePrevTopKs, @@ -781,7 +779,7 @@ void fusedL2ExpKnnImpl(const DataT* x, IdxT, KPolicy, decltype(core_lambda), - decltype(fin_op), + raft::identity_op, 64, 3, usePrevTopKs, @@ -818,14 +816,15 @@ void fusedL2ExpKnnImpl(const DataT* x, DataT* xn = (DataT*)workspace; DataT* yn = (DataT*)workspace; - auto norm_op = [] __device__(DataT in) { return in; }; - if (x != y) { yn += m; - raft::linalg::rowNorm(xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); - raft::linalg::rowNorm(yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + raft::linalg::rowNorm( + yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } else { - raft::linalg::rowNorm(xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } fusedL2ExpKnnRowMajor<<>>(x, y, @@ -838,7 +837,7 @@ void fusedL2ExpKnnImpl(const DataT* x, ldb, ldd, core_lambda, - fin_op, + raft::identity_op{}, sqrt, (uint32_t)numOfNN, mutexes, diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index 14c4dd85f1..f08a97e0f7 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -245,7 +246,7 @@ inline auto extend(const handle_t& handle, raft::linalg::L2Norm, true, stream, - raft::SqrtOp()); + raft::sqrt_op()); RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min(dim, 20)); } } diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 94f4dc96c6..d2f7d681d7 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -1110,7 +1111,7 @@ void search_impl(const handle_t& handle, raft::linalg::L2Norm, true, stream, - raft::SqrtOp()); + raft::sqrt_op()); utils::outer_add(query_norm_dev.data(), (IdxT)n_queries, index.center_norms()->data_handle(), diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh index 9262ef6baf..2a2fc1f10b 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -259,14 +260,7 @@ void select_residuals(const handle_t& handle, n_rows, (IdxT)dim, dataset, row_ids, (IdxT)dim, tmp.data(), (IdxT)dim, stream); raft::matrix::linewiseOp( - tmp.data(), - tmp.data(), - IdxT(dim), - n_rows, - true, - [] __device__(float a, float b) { return a - b; }, - stream, - center); + tmp.data(), tmp.data(), IdxT(dim), n_rows, true, raft::sub_op{}, stream, center); float alpha = 1.0; float beta = 0.0; @@ -1186,7 +1180,7 @@ inline auto build_device( raft::linalg::L2Norm, true, stream, - raft::SqrtOp()); + raft::sqrt_op()); RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle() + index.dim(), sizeof(float) * index.dim_ext(), center_norms.data(), diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index c1a3682f47..4ecb81edcb 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -418,14 +419,12 @@ void postprocess_distances(float* out, // [n_queries, topk] switch (metric) { case distance::DistanceType::L2Unexpanded: case distance::DistanceType::L2Expanded: { - linalg::unaryOp( - out, - in, - len, - [scaling_factor] __device__(ScoreT x) -> float { - return scaling_factor * scaling_factor * float(x); - }, - stream); + linalg::unaryOp(out, + in, + len, + raft::compose_op(raft::mul_const_op{scaling_factor * scaling_factor}, + raft::cast_op{}), + stream); } break; case distance::DistanceType::L2SqrtUnexpanded: case distance::DistanceType::L2SqrtExpanded: { @@ -433,18 +432,17 @@ void postprocess_distances(float* out, // [n_queries, topk] out, in, len, - [scaling_factor] __device__(ScoreT x) -> float { return scaling_factor * sqrtf(float(x)); }, + raft::compose_op{ + raft::mul_const_op{scaling_factor}, raft::sqrt_op{}, raft::cast_op{}}, stream); } break; case distance::DistanceType::InnerProduct: { - linalg::unaryOp( - out, - in, - len, - [scaling_factor] __device__(ScoreT x) -> float { - return -scaling_factor * scaling_factor * float(x); - }, - stream); + linalg::unaryOp(out, + in, + len, + raft::compose_op(raft::mul_const_op{-scaling_factor * scaling_factor}, + raft::cast_op{}), + stream); } break; default: RAFT_FAIL("Unexpected metric."); } diff --git a/cpp/include/raft/spatial/knn/detail/processing.cuh b/cpp/include/raft/spatial/knn/detail/processing.cuh index a80c1c1935..b4b1cb2c14 100644 --- a/cpp/include/raft/spatial/knn/detail/processing.cuh +++ b/cpp/include/raft/spatial/knn/detail/processing.cuh @@ -17,6 +17,7 @@ #include "processing.hpp" +#include #include #include #include @@ -59,32 +60,16 @@ class CosineMetricProcessor : public MetricProcessor { raft::linalg::NormType::L2Norm, row_major_, stream_, - [] __device__(math_t in) { return sqrtf(in); }); + raft::sqrt_op{}); raft::linalg::matrixVectorOp( - data, - data, - colsums_.data(), - n_cols_, - n_rows_, - row_major_, - false, - [] __device__(math_t mat_in, math_t vec_in) { return mat_in / vec_in; }, - stream_); + data, data, colsums_.data(), n_cols_, n_rows_, row_major_, false, raft::div_op{}, stream_); } void revert(math_t* data) { raft::linalg::matrixVectorOp( - data, - data, - colsums_.data(), - n_cols_, - n_rows_, - row_major_, - false, - [] __device__(math_t mat_in, math_t vec_in) { return mat_in * vec_in; }, - stream_); + data, data, colsums_.data(), n_cols_, n_rows_, row_major_, false, raft::mul_op{}, stream_); } void postprocess(math_t* data) @@ -122,12 +107,11 @@ class CorrelationMetricProcessor : public CosineMetricProcessor { true, cosine::stream_); - raft::linalg::unaryOp( - means_.data(), - means_.data(), - cosine::n_rows_, - [=] __device__(math_t in) { return in * normalizer_const; }, - cosine::stream_); + raft::linalg::unaryOp(means_.data(), + means_.data(), + cosine::n_rows_, + raft::mul_const_op(normalizer_const), + cosine::stream_); raft::stats::meanCenter(data, data, diff --git a/cpp/include/raft/stats/detail/batched/silhouette_score.cuh b/cpp/include/raft/stats/detail/batched/silhouette_score.cuh index 25a3721af1..5bd8c4a7ab 100644 --- a/cpp/include/raft/stats/detail/batched/silhouette_score.cuh +++ b/cpp/include/raft/stats/detail/batched/silhouette_score.cuh @@ -249,19 +249,18 @@ value_t silhouette_score( // calculating row-wise minimum in b // this prim only supports int indices for now - raft::linalg:: - reduce, raft::stats::detail::MinOp>( - b_ptr, - b_ptr, - n_labels, - n_rows, - std::numeric_limits::max(), - true, - true, - stream, - false, - raft::Nop(), - raft::stats::detail::MinOp()); + raft::linalg::reduce( + b_ptr, + b_ptr, + n_labels, + n_rows, + std::numeric_limits::max(), + true, + true, + stream, + false, + raft::identity_op(), + raft::min_op()); // calculating the silhouette score per sample raft::linalg::binaryOp, value_t, value_idx>( diff --git a/cpp/include/raft/stats/detail/mean_center.cuh b/cpp/include/raft/stats/detail/mean_center.cuh index 61017511b1..6e1c07e1e3 100644 --- a/cpp/include/raft/stats/detail/mean_center.cuh +++ b/cpp/include/raft/stats/detail/mean_center.cuh @@ -49,15 +49,7 @@ void meanCenter(Type* out, cudaStream_t stream) { raft::linalg::matrixVectorOp( - out, - data, - mu, - D, - N, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a - b; }, - stream); + out, data, mu, D, N, rowMajor, bcastAlongRows, raft::sub_op{}, stream); } /** @@ -85,15 +77,7 @@ void meanAdd(Type* out, cudaStream_t stream) { raft::linalg::matrixVectorOp( - out, - data, - mu, - D, - N, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a + b; }, - stream); + out, data, mu, D, N, rowMajor, bcastAlongRows, raft::add_op{}, stream); } }; // end namespace detail diff --git a/cpp/include/raft/stats/detail/silhouette_score.cuh b/cpp/include/raft/stats/detail/silhouette_score.cuh index 076d9b13e5..3cf95c3941 100644 --- a/cpp/include/raft/stats/detail/silhouette_score.cuh +++ b/cpp/include/raft/stats/detail/silhouette_score.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -172,20 +173,6 @@ struct SilOp { } }; -/** - * @brief structure that defines the reduction Lambda to find minimum between elements - */ -template -struct MinOp { - HDI DataT operator()(DataT a, DataT b) - { - if (a > b) - return b; - else - return a; - } -}; - /** * @brief main function that returns the average silhouette score for a given set of data and its * clusterings @@ -278,19 +265,19 @@ DataT silhouette_score( RAFT_CUDA_TRY(cudaMemsetAsync( averageDistanceBetweenSampleAndCluster.data(), 0, nRows * nLabels * sizeof(DataT), stream)); - raft::linalg::matrixVectorOp>(averageDistanceBetweenSampleAndCluster.data(), - sampleToClusterSumOfDistances.data(), - binCountArray.data(), - binCountArray.data(), - nLabels, - nRows, - true, - true, - DivOp(), - stream); + raft::linalg::matrixVectorOp(averageDistanceBetweenSampleAndCluster.data(), + sampleToClusterSumOfDistances.data(), + binCountArray.data(), + binCountArray.data(), + nLabels, + nRows, + true, + true, + DivOp(), + stream); // calculating row-wise minimum - raft::linalg::reduce, MinOp>( + raft::linalg::reduce( d_bArray.data(), averageDistanceBetweenSampleAndCluster.data(), nLabels, @@ -300,8 +287,8 @@ DataT silhouette_score( true, stream, false, - raft::Nop(), - MinOp()); + raft::identity_op{}, + raft::min_op{}); // calculating the silhouette score per sample using the d_aArray and d_bArray raft::linalg::binaryOp>( @@ -311,12 +298,12 @@ DataT silhouette_score( rmm::device_scalar d_avgSilhouetteScore(stream); RAFT_CUDA_TRY(cudaMemsetAsync(d_avgSilhouetteScore.data(), 0, sizeof(DataT), stream)); - raft::linalg::mapThenSumReduce>(d_avgSilhouetteScore.data(), - nRows, - raft::Nop(), - stream, - perSampleSilScore, - perSampleSilScore); + raft::linalg::mapThenSumReduce(d_avgSilhouetteScore.data(), + nRows, + raft::identity_op(), + stream, + perSampleSilScore, + perSampleSilScore); DataT avgSilhouetteScore = d_avgSilhouetteScore.value(stream); diff --git a/cpp/include/raft/stats/detail/weighted_mean.cuh b/cpp/include/raft/stats/detail/weighted_mean.cuh index 43dbe4e7f1..ada0995f7d 100644 --- a/cpp/include/raft/stats/detail/weighted_mean.cuh +++ b/cpp/include/raft/stats/detail/weighted_mean.cuh @@ -18,6 +18,7 @@ #include #include +#include #include namespace raft { @@ -66,8 +67,8 @@ void weightedMean(Type* mu, stream, false, [weights] __device__(Type v, IdxType i) { return v * weights[i]; }, - [] __device__(Type a, Type b) { return a + b; }, - [WS] __device__(Type v) { return v / WS; }); + raft::add_op{}, + raft::div_const_op(WS)); } }; // end namespace detail }; // end namespace stats diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index 5818fc21f3..61dd6e0ad8 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -21,6 +21,7 @@ #include #include +#include #ifndef ENABLE_MEMCPY_ASYNC // enable memcpy_async interface by default for newer GPUs @@ -508,43 +509,69 @@ HDI double myATanh(double x) /** @} */ /** - * @defgroup LambdaOps Lambda operations in reduction kernels + * @defgroup LambdaOps Legacy lambda operations, to be deprecated * @{ */ -// IdxType mostly to be used for MainLambda in *Reduction kernels template struct Nop { - HDI Type operator()(Type in, IdxType i = 0) { return in; } + [[deprecated("Nop is deprecated. Use identity_op instead.")]] HDI Type + operator()(Type in, IdxType i = 0) const + { + return in; + } }; template struct SqrtOp { - HDI Type operator()(Type in, IdxType i = 0) { return mySqrt(in); } + [[deprecated("SqrtOp is deprecated. Use sqrt_op instead.")]] HDI Type + operator()(Type in, IdxType i = 0) const + { + return mySqrt(in); + } }; template struct L0Op { - HDI Type operator()(Type in, IdxType i = 0) { return in != Type(0) ? Type(1) : Type(0); } + [[deprecated("L0Op is deprecated. Use nz_op instead.")]] HDI Type operator()(Type in, + IdxType i = 0) const + { + return in != Type(0) ? Type(1) : Type(0); + } }; template struct L1Op { - HDI Type operator()(Type in, IdxType i = 0) { return myAbs(in); } + [[deprecated("L1Op is deprecated. Use abs_op instead.")]] HDI Type operator()(Type in, + IdxType i = 0) const + { + return myAbs(in); + } }; template struct L2Op { - HDI Type operator()(Type in, IdxType i = 0) { return in * in; } + [[deprecated("L2Op is deprecated. Use sq_op instead.")]] HDI Type operator()(Type in, + IdxType i = 0) const + { + return in * in; + } }; -template +template struct Sum { - HDI Type operator()(Type a, Type b) { return a + b; } + [[deprecated("Sum is deprecated. Use add_op instead.")]] HDI OutT operator()(InT a, InT b) const + { + return a + b; + } }; template struct Max { - HDI Type operator()(Type a, Type b) { return myMax(a, b); } + [[deprecated("Max is deprecated. Use max_op instead.")]] HDI Type operator()(Type a, Type b) const + { + if (b > a) { return b; } + return a; + } }; /** @} */ @@ -939,7 +966,7 @@ DI T warpReduce(T val, ReduceLambda reduce_op) template DI T warpReduce(T val) { - return warpReduce(val, raft::Sum{}); + return warpReduce(val, raft::add_op{}); } /** diff --git a/cpp/include/raft/util/scatter.cuh b/cpp/include/raft/util/scatter.cuh index c20afa5454..e69be36ad9 100644 --- a/cpp/include/raft/util/scatter.cuh +++ b/cpp/include/raft/util/scatter.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -37,13 +38,13 @@ namespace raft { * will be applied to every element before scattering it to the right location. * The second param in this method will be the destination index. */ -template , int TPB = 256> +template void scatter(DataT* out, const DataT* in, const IdxT* idx, IdxT len, cudaStream_t stream, - Lambda op = raft::Nop()) + Lambda op = raft::identity_op()) { if (len <= 0) return; constexpr size_t DataSize = sizeof(DataT); diff --git a/cpp/src/distance/cluster/cluster_cost.cuh b/cpp/src/distance/cluster/cluster_cost.cuh index ac8247af95..4f208ab1cd 100644 --- a/cpp/src/distance/cluster/cluster_cost.cuh +++ b/cpp/src/distance/cluster/cluster_cost.cuh @@ -15,9 +15,11 @@ */ #include +#include #include #include #include +#include namespace raft::runtime::cluster::kmeans { template @@ -59,20 +61,18 @@ void cluster_cost(const raft::handle_t& handle, handle.get_stream()); auto distances = raft::make_device_vector(handle, n_samples); - thrust::transform( - handle.get_thrust_policy(), - min_cluster_distance.data_handle(), - min_cluster_distance.data_handle() + n_samples, - distances.data_handle(), - [] __device__(const raft::KeyValuePair& a) { return a.value; }); + thrust::transform(handle.get_thrust_policy(), + min_cluster_distance.data_handle(), + min_cluster_distance.data_handle() + n_samples, + distances.data_handle(), + raft::value_op{}); rmm::device_scalar device_cost(0, handle.get_stream()); - raft::cluster::kmeans::cluster_cost( - handle, - distances.view(), - workspace, - make_device_scalar_view(device_cost.data()), - [] __device__(const ElementType& a, const ElementType& b) { return a + b; }); + raft::cluster::kmeans::cluster_cost(handle, + distances.view(), + workspace, + make_device_scalar_view(device_cost.data()), + raft::add_op{}); raft::update_host(cost, device_cost.data(), 1, handle.get_stream()); } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index a75eb3bff6..5be8401a6f 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -86,6 +86,8 @@ if(BUILD_TESTS) CORE_TEST PATH test/common/logger.cpp + test/core/operators_device.cu + test/core/operators_host.cpp test/handle.cpp test/interruptible.cu test/nvtx.cpp diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index 698d23ac27..9644541a0c 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -50,11 +51,7 @@ void run_cluster_cost(const raft::handle_t& handle, raft::device_scalar_view clusterCost) { raft::cluster::kmeans::cluster_cost( - handle, - minClusterDistance, - workspace, - clusterCost, - [] __device__(const DataT& a, const DataT& b) { return a + b; }); + handle, minClusterDistance, workspace, clusterCost, raft::add_op{}); } template diff --git a/cpp/test/cluster/linkage.cu b/cpp/test/cluster/linkage.cu index 5533f552bd..53aa5c55e3 100644 --- a/cpp/test/cluster/linkage.cu +++ b/cpp/test/cluster/linkage.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/core/operators_device.cu b/cpp/test/core/operators_device.cu new file mode 100644 index 0000000000..1697a09fcf --- /dev/null +++ b/cpp/test/core/operators_device.cu @@ -0,0 +1,340 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +#include "../test_utils.cuh" +#include +#include +#include +#include + +template +__global__ void eval_op_on_device_kernel(OutT* out, OpT op, Args... args) +{ + out[0] = op(std::forward(args)...); +} + +template +auto eval_op_on_device(OpT op, Args&&... args) +{ + typedef decltype(op(args...)) OutT; + auto stream = rmm::cuda_stream_default; + rmm::device_scalar result(stream); + eval_op_on_device_kernel<<<1, 1, 0, stream>>>(result.data(), op, std::forward(args)...); + return result.value(stream); +} + +TEST(OperatorsDevice, IdentityOp) +{ + raft::identity_op op; + ASSERT_TRUE(raft::match(12.34f, eval_op_on_device(op, 12.34f, 0), raft::Compare())); +} + +TEST(OperatorsDevice, CastOp) +{ + raft::cast_op op; + ASSERT_TRUE( + raft::match(1234.0f, eval_op_on_device(op, 1234, 0), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, KeyOp) +{ + raft::key_op op; + raft::KeyValuePair kvp(12, 3.4f); + ASSERT_TRUE(raft::match(12, eval_op_on_device(op, kvp, 0), raft::Compare())); +} + +TEST(OperatorsDevice, ValueOp) +{ + raft::value_op op; + raft::KeyValuePair kvp(12, 3.4f); + ASSERT_TRUE( + raft::match(3.4f, eval_op_on_device(op, kvp, 0), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, SqrtOpF) +{ + raft::sqrt_op op; + ASSERT_TRUE(raft::match( + std::sqrt(12.34f), eval_op_on_device(op, 12.34f, 0), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::sqrt(12.34), eval_op_on_device(op, 12.34, 0), raft::CompareApprox(0.000001))); +} + +TEST(OperatorsDevice, NZOp) +{ + raft::nz_op op; + ASSERT_TRUE( + raft::match(0.0f, eval_op_on_device(op, 0.0f, 0), raft::CompareApprox(0.00001f))); + ASSERT_TRUE( + raft::match(1.0f, eval_op_on_device(op, 12.34f, 0), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, AbsOp) +{ + raft::abs_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, -12.34f, 0), raft::CompareApprox(0.00001f))); + ASSERT_TRUE( + raft::match(12.34, eval_op_on_device(op, -12.34, 0), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(1234, eval_op_on_device(op, -1234, 0), raft::Compare())); +} + +TEST(OperatorsDevice, SqOp) +{ + raft::sq_op op; + ASSERT_TRUE( + raft::match(152.2756f, eval_op_on_device(op, 12.34f, 0), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(289, eval_op_on_device(op, -17, 0), raft::Compare())); +} + +TEST(OperatorsDevice, AddOp) +{ + raft::add_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 12.0f, 0.34f), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(1234, eval_op_on_device(op, 1200, 34), raft::Compare())); +} + +TEST(OperatorsDevice, SubOp) +{ + raft::sub_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 13.0f, 0.66f), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(1234, eval_op_on_device(op, 1300, 66), raft::Compare())); +} + +TEST(OperatorsDevice, MulOp) +{ + raft::mul_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 2.0f, 6.17f), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, DivOp) +{ + raft::div_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 37.02f, 3.0f), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, DivCheckZeroOp) +{ + raft::div_checkzero_op op; + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 37.02f, 3.0f), raft::CompareApprox(0.00001f))); + ASSERT_TRUE( + raft::match(0.0f, eval_op_on_device(op, 37.02f, 0.0f), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsDevice, PowOp) +{ + raft::pow_op op; + ASSERT_TRUE( + raft::match(1000.0f, eval_op_on_device(op, 10.0f, 3.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(1000.0, eval_op_on_device(op, 10.0, 3.0), raft::CompareApprox(0.000001))); +} + +TEST(OperatorsDevice, MinOp) +{ + raft::min_op op; + ASSERT_TRUE( + raft::match(3.0f, eval_op_on_device(op, 3.0f, 5.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(3.0, eval_op_on_device(op, 5.0, 3.0), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(3, eval_op_on_device(op, 3, 5), raft::Compare())); +} + +TEST(OperatorsDevice, MaxOp) +{ + raft::max_op op; + ASSERT_TRUE( + raft::match(5.0f, eval_op_on_device(op, 3.0f, 5.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(5.0, eval_op_on_device(op, 5.0, 3.0), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(5, eval_op_on_device(op, 3, 5), raft::Compare())); +} + +TEST(OperatorsDevice, SqDiffOp) +{ + raft::sqdiff_op op; + ASSERT_TRUE( + raft::match(4.0f, eval_op_on_device(op, 3.0f, 5.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(4.0, eval_op_on_device(op, 5.0, 3.0), raft::CompareApprox(0.000001))); +} + +TEST(OperatorsDevice, ArgminOp) +{ + raft::argmin_op op; + raft::KeyValuePair kvp_a(0, 1.2f); + raft::KeyValuePair kvp_b(0, 3.4f); + raft::KeyValuePair kvp_c(1, 1.2f); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_a, kvp_b), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_b, kvp_a), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_a, kvp_c), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_c, kvp_a), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_c, eval_op_on_device(op, kvp_b, kvp_c), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_c, eval_op_on_device(op, kvp_c, kvp_b), raft::Compare>())); +} + +TEST(OperatorsDevice, ArgmaxOp) +{ + raft::argmax_op op; + raft::KeyValuePair kvp_a(0, 1.2f); + raft::KeyValuePair kvp_b(0, 3.4f); + raft::KeyValuePair kvp_c(1, 1.2f); + ASSERT_TRUE(raft::match( + kvp_b, eval_op_on_device(op, kvp_a, kvp_b), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_b, eval_op_on_device(op, kvp_b, kvp_a), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_a, kvp_c), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_a, eval_op_on_device(op, kvp_c, kvp_a), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_b, eval_op_on_device(op, kvp_b, kvp_c), raft::Compare>())); + ASSERT_TRUE(raft::match( + kvp_b, eval_op_on_device(op, kvp_c, kvp_b), raft::Compare>())); +} + +TEST(OperatorsDevice, ConstOp) +{ + raft::const_op op(12.34f); + ASSERT_TRUE(raft::match(12.34f, eval_op_on_device(op), raft::Compare())); + ASSERT_TRUE(raft::match(12.34f, eval_op_on_device(op, 42), raft::Compare())); + ASSERT_TRUE(raft::match(12.34f, eval_op_on_device(op, 13, 37.0f), raft::Compare())); +} + +template +struct trinary_add { + const T c; + constexpr explicit trinary_add(const T& c_) : c{c_} {} + constexpr RAFT_INLINE_FUNCTION auto operator()(T a, T b) const { return a + b + c; } +}; + +TEST(OperatorsDevice, PlugConstOp) +{ + // First, wrap around a default-constructible op + { + raft::plug_const_op op(0.34f); + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 12.0f), raft::CompareApprox(0.0001f))); + } + + // Second, wrap around a non-default-constructible op + { + auto op = raft::plug_const_op(10.0f, trinary_add(2.0f)); + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 0.34f), raft::CompareApprox(0.0001f))); + } +} + +TEST(OperatorsDevice, AddConstOp) +{ + raft::add_const_op op(0.34f); + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 12.0f), raft::CompareApprox(0.0001f))); +} + +TEST(OperatorsDevice, SubConstOp) +{ + raft::sub_const_op op(0.66f); + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 13.0f), raft::CompareApprox(0.0001f))); +} + +TEST(OperatorsDevice, MulConstOp) +{ + raft::mul_const_op op(2.0f); + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 6.17f), raft::CompareApprox(0.0001f))); +} + +TEST(OperatorsDevice, DivConstOp) +{ + raft::div_const_op op(3.0f); + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 37.02f), raft::CompareApprox(0.0001f))); +} + +TEST(OperatorsDevice, DivCheckZeroConstOp) +{ + // Non-zero denominator + { + raft::div_checkzero_const_op op(3.0f); + ASSERT_TRUE( + raft::match(12.34f, eval_op_on_device(op, 37.02f), raft::CompareApprox(0.0001f))); + } + // Zero denominator + { + raft::div_checkzero_const_op op(0.0f); + ASSERT_TRUE( + raft::match(0.0f, eval_op_on_device(op, 37.02f), raft::CompareApprox(0.0001f))); + } +} + +TEST(OperatorsDevice, PowConstOp) +{ + raft::pow_const_op op(3.0f); + ASSERT_TRUE( + raft::match(1000.0f, eval_op_on_device(op, 10.0f), raft::CompareApprox(0.0001f))); +} + +TEST(OperatorsDevice, ComposeOp) +{ + // All ops are default-constructible + { + raft::compose_op> op; + ASSERT_TRUE(raft::match( + std::sqrt(42.0f), eval_op_on_device(op, -42, 0), raft::CompareApprox(0.0001f))); + } + // Some ops are not default-constructible + { + auto op = raft::compose_op( + raft::sqrt_op(), raft::abs_op(), raft::add_const_op(8.0f), raft::cast_op()); + ASSERT_TRUE(raft::match( + std::sqrt(42.0f), eval_op_on_device(op, -50, 0), raft::CompareApprox(0.0001f))); + } +} + +TEST(OperatorsDevice, MapArgsOp) +{ + // All ops are default-constructible + { + raft::map_args_op op; + ASSERT_TRUE( + raft::match(42.0f, eval_op_on_device(op, 5.0f, -17.0f), raft::CompareApprox(0.0001f))); + } + // Some ops are not default-constructible + { + auto op = raft::map_args_op( + raft::add_op(), raft::pow_const_op(2.0f), raft::mul_const_op(-1.0f)); + ASSERT_TRUE( + raft::match(42.0f, eval_op_on_device(op, 5.0f, -17.0f), raft::CompareApprox(0.0001f))); + } +} diff --git a/cpp/test/core/operators_host.cpp b/cpp/test/core/operators_host.cpp new file mode 100644 index 0000000000..de66fda919 --- /dev/null +++ b/cpp/test/core/operators_host.cpp @@ -0,0 +1,286 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +#include "../test_utils.h" +#include +#include + +TEST(OperatorsHost, IdentityOp) +{ + raft::identity_op op; + ASSERT_TRUE(raft::match(12.34f, op(12.34f, 0), raft::Compare())); +} + +TEST(OperatorsHost, CastOp) +{ + raft::cast_op op; + ASSERT_TRUE(raft::match(1234.0f, op(1234, 0), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsHost, KeyOp) +{ + raft::key_op op; + raft::KeyValuePair kvp(12, 3.4f); + ASSERT_TRUE(raft::match(12, op(kvp, 0), raft::Compare())); +} + +TEST(OperatorsHost, ValueOp) +{ + raft::value_op op; + raft::KeyValuePair kvp(12, 3.4f); + ASSERT_TRUE(raft::match(3.4f, op(kvp, 0), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsHost, SqrtOpF) +{ + raft::sqrt_op op; + ASSERT_TRUE(raft::match(std::sqrt(12.34f), op(12.34f, 0), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match(std::sqrt(12.34), op(12.34, 0), raft::CompareApprox(0.000001))); +} + +TEST(OperatorsHost, NZOp) +{ + raft::nz_op op; + ASSERT_TRUE(raft::match(0.0f, op(0.0f, 0), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(1.0f, op(12.34f, 0), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsHost, AbsOp) +{ + raft::abs_op op; + ASSERT_TRUE(raft::match(12.34f, op(-12.34f, 0), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(12.34, op(-12.34, 0), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(1234, op(-1234, 0), raft::Compare())); +} + +TEST(OperatorsHost, SqOp) +{ + raft::sq_op op; + ASSERT_TRUE(raft::match(152.2756f, op(12.34f, 0), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(289, op(-17, 0), raft::Compare())); +} + +TEST(OperatorsHost, AddOp) +{ + raft::add_op op; + ASSERT_TRUE(raft::match(12.34f, op(12.0f, 0.34f), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(1234, op(1200, 34), raft::Compare())); +} + +TEST(OperatorsHost, SubOp) +{ + raft::sub_op op; + ASSERT_TRUE(raft::match(12.34f, op(13.0f, 0.66f), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(1234, op(1300, 66), raft::Compare())); +} + +TEST(OperatorsHost, MulOp) +{ + raft::mul_op op; + ASSERT_TRUE(raft::match(12.34f, op(2.0f, 6.17f), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsHost, DivOp) +{ + raft::div_op op; + ASSERT_TRUE(raft::match(12.34f, op(37.02f, 3.0f), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsHost, DivCheckZeroOp) +{ + raft::div_checkzero_op op; + ASSERT_TRUE(raft::match(12.34f, op(37.02f, 3.0f), raft::CompareApprox(0.00001f))); + ASSERT_TRUE(raft::match(0.0f, op(37.02f, 0.0f), raft::CompareApprox(0.00001f))); +} + +TEST(OperatorsHost, PowOp) +{ + raft::pow_op op; + ASSERT_TRUE(raft::match(1000.0f, op(10.0f, 3.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match(1000.0, op(10.0, 3.0), raft::CompareApprox(0.000001))); +} + +TEST(OperatorsHost, MinOp) +{ + raft::min_op op; + ASSERT_TRUE(raft::match(3.0f, op(3.0f, 5.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match(3.0, op(5.0, 3.0), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(3, op(3, 5), raft::Compare())); +} + +TEST(OperatorsHost, MaxOp) +{ + raft::max_op op; + ASSERT_TRUE(raft::match(5.0f, op(3.0f, 5.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match(5.0, op(5.0, 3.0), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(5, op(3, 5), raft::Compare())); +} + +TEST(OperatorsHost, SqDiffOp) +{ + raft::sqdiff_op op; + ASSERT_TRUE(raft::match(4.0f, op(3.0f, 5.0f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match(4.0, op(5.0, 3.0), raft::CompareApprox(0.000001))); +} + +TEST(OperatorsHost, ArgminOp) +{ + raft::argmin_op op; + raft::KeyValuePair kvp_a(0, 1.2f); + raft::KeyValuePair kvp_b(0, 3.4f); + raft::KeyValuePair kvp_c(1, 1.2f); + ASSERT_TRUE( + raft::match(kvp_a, op(kvp_a, kvp_b), raft::Compare>())); + ASSERT_TRUE( + raft::match(kvp_a, op(kvp_b, kvp_a), raft::Compare>())); + ASSERT_TRUE( + raft::match(kvp_a, op(kvp_a, kvp_c), raft::Compare>())); + ASSERT_TRUE( + raft::match(kvp_a, op(kvp_c, kvp_a), raft::Compare>())); + ASSERT_TRUE( + raft::match(kvp_c, op(kvp_b, kvp_c), raft::Compare>())); + ASSERT_TRUE( + raft::match(kvp_c, op(kvp_c, kvp_b), raft::Compare>())); +} + +TEST(OperatorsHost, ArgmaxOp) +{ + raft::argmax_op op; + raft::KeyValuePair kvp_a(0, 1.2f); + raft::KeyValuePair kvp_b(0, 3.4f); + raft::KeyValuePair kvp_c(1, 1.2f); + ASSERT_TRUE( + raft::match(kvp_b, op(kvp_a, kvp_b), raft::Compare>())); + ASSERT_TRUE( + raft::match(kvp_b, op(kvp_b, kvp_a), raft::Compare>())); + ASSERT_TRUE( + raft::match(kvp_a, op(kvp_a, kvp_c), raft::Compare>())); + ASSERT_TRUE( + raft::match(kvp_a, op(kvp_c, kvp_a), raft::Compare>())); + ASSERT_TRUE( + raft::match(kvp_b, op(kvp_b, kvp_c), raft::Compare>())); + ASSERT_TRUE( + raft::match(kvp_b, op(kvp_c, kvp_b), raft::Compare>())); +} + +TEST(OperatorsHost, ConstOp) +{ + raft::const_op op(12.34f); + ASSERT_TRUE(raft::match(12.34f, op(), raft::Compare())); + ASSERT_TRUE(raft::match(12.34f, op(42), raft::Compare())); + ASSERT_TRUE(raft::match(12.34f, op(13, 37.0f), raft::Compare())); +} + +template +struct trinary_add { + const T c; + constexpr explicit trinary_add(const T& c_) : c{c_} {} + constexpr RAFT_INLINE_FUNCTION auto operator()(T a, T b) const { return a + b + c; } +}; + +TEST(OperatorsHost, PlugConstOp) +{ + // First, wrap around a default-constructible op + { + raft::plug_const_op op(0.34f); + ASSERT_TRUE(raft::match(12.34f, op(12.0f), raft::CompareApprox(0.0001f))); + } + + // Second, wrap around a non-default-constructible op + { + auto op = raft::plug_const_op(10.0f, trinary_add(2.0f)); + ASSERT_TRUE(raft::match(12.34f, op(0.34f), raft::CompareApprox(0.0001f))); + } +} + +TEST(OperatorsHost, AddConstOp) +{ + raft::add_const_op op(0.34f); + ASSERT_TRUE(raft::match(12.34f, op(12.0f), raft::CompareApprox(0.0001f))); +} + +TEST(OperatorsHost, SubConstOp) +{ + raft::sub_const_op op(0.66f); + ASSERT_TRUE(raft::match(12.34f, op(13.0f), raft::CompareApprox(0.0001f))); +} + +TEST(OperatorsHost, MulConstOp) +{ + raft::mul_const_op op(2.0f); + ASSERT_TRUE(raft::match(12.34f, op(6.17f), raft::CompareApprox(0.0001f))); +} + +TEST(OperatorsHost, DivConstOp) +{ + raft::div_const_op op(3.0f); + ASSERT_TRUE(raft::match(12.34f, op(37.02f), raft::CompareApprox(0.0001f))); +} + +TEST(OperatorsHost, DivCheckZeroConstOp) +{ + // Non-zero denominator + { + raft::div_checkzero_const_op op(3.0f); + ASSERT_TRUE(raft::match(12.34f, op(37.02f), raft::CompareApprox(0.0001f))); + } + // Zero denominator + { + raft::div_checkzero_const_op op(0.0f); + ASSERT_TRUE(raft::match(0.0f, op(37.02f), raft::CompareApprox(0.0001f))); + } +} + +TEST(OperatorsHost, PowConstOp) +{ + raft::pow_const_op op(3.0f); + ASSERT_TRUE(raft::match(1000.0f, op(10.0f), raft::CompareApprox(0.0001f))); +} + +TEST(OperatorsHost, ComposeOp) +{ + // All ops are default-constructible + { + raft::compose_op> op; + ASSERT_TRUE(raft::match(std::sqrt(42.0f), op(-42, 0), raft::CompareApprox(0.0001f))); + } + // Some ops are not default-constructible + { + auto op = raft::compose_op( + raft::sqrt_op(), raft::abs_op(), raft::add_const_op(8.0f), raft::cast_op()); + ASSERT_TRUE(raft::match(std::sqrt(42.0f), op(-50, 0), raft::CompareApprox(0.0001f))); + } +} + +TEST(OperatorsHost, MapArgsOp) +{ + // All ops are default-constructible + { + raft::map_args_op op; + ASSERT_TRUE(raft::match(42.0f, op(5.0f, -17.0f), raft::CompareApprox(0.0001f))); + } + // Some ops are not default-constructible + { + auto op = raft::map_args_op( + raft::add_op(), raft::pow_const_op(2.0f), raft::mul_const_op(-1.0f)); + ASSERT_TRUE(raft::match(42.0f, op(5.0f, -17.0f), raft::CompareApprox(0.0001f))); + } +} diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index f3f36b4576..4f6dfaac24 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/distance/dist_canberra.cu b/cpp/test/distance/dist_canberra.cu index 1f368fbee8..db5555d9c8 100644 --- a/cpp/test/distance/dist_canberra.cu +++ b/cpp/test/distance/dist_canberra.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_chebyshev.cu b/cpp/test/distance/dist_chebyshev.cu index 8f506601ca..abad828de7 100644 --- a/cpp/test/distance/dist_chebyshev.cu +++ b/cpp/test/distance/dist_chebyshev.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_correlation.cu b/cpp/test/distance/dist_correlation.cu index 77d770b4d1..0e3f0ee0b5 100644 --- a/cpp/test/distance/dist_correlation.cu +++ b/cpp/test/distance/dist_correlation.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_cos.cu b/cpp/test/distance/dist_cos.cu index 900a71e514..9faf7651f7 100644 --- a/cpp/test/distance/dist_cos.cu +++ b/cpp/test/distance/dist_cos.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_euc_exp.cu b/cpp/test/distance/dist_euc_exp.cu index 5371b8a3e2..567e279691 100644 --- a/cpp/test/distance/dist_euc_exp.cu +++ b/cpp/test/distance/dist_euc_exp.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_euc_unexp.cu b/cpp/test/distance/dist_euc_unexp.cu index 81e6be7116..311ad190e2 100644 --- a/cpp/test/distance/dist_euc_unexp.cu +++ b/cpp/test/distance/dist_euc_unexp.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_eucsqrt_exp.cu b/cpp/test/distance/dist_eucsqrt_exp.cu index c4f2dc80c2..d717158649 100644 --- a/cpp/test/distance/dist_eucsqrt_exp.cu +++ b/cpp/test/distance/dist_eucsqrt_exp.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_hamming.cu b/cpp/test/distance/dist_hamming.cu index 616ce8f729..1eef9fba4e 100644 --- a/cpp/test/distance/dist_hamming.cu +++ b/cpp/test/distance/dist_hamming.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_hellinger.cu b/cpp/test/distance/dist_hellinger.cu index d6f994aaf6..85a157aa31 100644 --- a/cpp/test/distance/dist_hellinger.cu +++ b/cpp/test/distance/dist_hellinger.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_jensen_shannon.cu b/cpp/test/distance/dist_jensen_shannon.cu index 43e4f3aa0f..a1e2f9f38c 100644 --- a/cpp/test/distance/dist_jensen_shannon.cu +++ b/cpp/test/distance/dist_jensen_shannon.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_kl_divergence.cu b/cpp/test/distance/dist_kl_divergence.cu index 6a5fe8d7ac..94330d9450 100644 --- a/cpp/test/distance/dist_kl_divergence.cu +++ b/cpp/test/distance/dist_kl_divergence.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_l1.cu b/cpp/test/distance/dist_l1.cu index 322fb52d5c..dc6bcf72b7 100644 --- a/cpp/test/distance/dist_l1.cu +++ b/cpp/test/distance/dist_l1.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_minkowski.cu b/cpp/test/distance/dist_minkowski.cu index 3e0a2ead92..af2661da3a 100644 --- a/cpp/test/distance/dist_minkowski.cu +++ b/cpp/test/distance/dist_minkowski.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/dist_russell_rao.cu b/cpp/test/distance/dist_russell_rao.cu index e92a01c70a..3c5124c31f 100644 --- a/cpp/test/distance/dist_russell_rao.cu +++ b/cpp/test/distance/dist_russell_rao.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "distance_base.cuh" namespace raft { diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 19d449c18b..067b1b2c0e 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -14,9 +14,10 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include +#include #include #include #include @@ -166,7 +167,7 @@ __global__ void naiveLpUnexpDistanceKernel(DataType* dist, int yidx = isRowMajor ? i + nidx * k : i * n + nidx; auto a = x[xidx]; auto b = y[yidx]; - auto diff = raft::L1Op()(a - b); + auto diff = raft::myAbs(a - b); acc += raft::myPow(diff, p); } auto one_over_p = 1 / p; diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 800f45c7fc..252f56607f 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include @@ -208,8 +208,8 @@ struct CompareApproxAbsKVP { CompareApproxAbsKVP(T eps_) : eps(eps_) {} bool operator()(const KVP& a, const KVP& b) const { - T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); - T m = std::max(raft::abs(a.value), raft::abs(b.value)); + T diff = std::abs(std::abs(a.value) - std::abs(b.value)); + T m = std::max(std::abs(a.value), std::abs(b.value)); T ratio = m >= eps ? diff / m : diff; return (ratio <= eps); } diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu index 168e3d93f8..4366e023a0 100644 --- a/cpp/test/distance/gram.cu +++ b/cpp/test/distance/gram.cu @@ -18,7 +18,7 @@ #include #endif -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/label/label.cu b/cpp/test/label/label.cu index 02b3191c4d..bda87d423c 100644 --- a/cpp/test/label/label.cu +++ b/cpp/test/label/label.cu @@ -18,7 +18,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/label/merge_labels.cu b/cpp/test/label/merge_labels.cu index 184ab4922f..c3d2f82a84 100644 --- a/cpp/test/label/merge_labels.cu +++ b/cpp/test/label/merge_labels.cu @@ -17,7 +17,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/add.cu b/cpp/test/linalg/add.cu index c73791086b..0e5fc40232 100644 --- a/cpp/test/linalg/add.cu +++ b/cpp/test/linalg/add.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "add.cuh" #include #include diff --git a/cpp/test/linalg/axpy.cu b/cpp/test/linalg/axpy.cu index f6cabae012..2eb11f314d 100644 --- a/cpp/test/linalg/axpy.cu +++ b/cpp/test/linalg/axpy.cu @@ -15,7 +15,7 @@ */ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/binary_op.cu b/cpp/test/linalg/binary_op.cu index b92fa09427..ac143842cb 100644 --- a/cpp/test/linalg/binary_op.cu +++ b/cpp/test/linalg/binary_op.cu @@ -14,9 +14,10 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "binary_op.cuh" #include +#include #include #include #include @@ -36,8 +37,7 @@ void binaryOpLaunch( auto in1_view = raft::make_device_vector_view(in1, len); auto in2_view = raft::make_device_vector_view(in2, len); - binary_op( - handle, in1_view, in2_view, out_view, [] __device__(InType a, InType b) { return a + b; }); + binary_op(handle, in1_view, in2_view, out_view, raft::add_op{}); } template @@ -139,12 +139,7 @@ class BinaryOpAlignment : public ::testing::Test { RAFT_CUDA_TRY(cudaMemsetAsync(x.data(), 0, n * sizeof(math_t), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(y.data(), 0, n * sizeof(math_t), stream)); raft::linalg::binaryOp( - z.data() + 9, - x.data() + 137, - y.data() + 19, - 256, - [] __device__(math_t x, math_t y) { return x + y; }, - handle.get_stream()); + z.data() + 9, x.data() + 137, y.data() + 19, 256, raft::add_op{}, handle.get_stream()); } raft::handle_t handle; diff --git a/cpp/test/linalg/binary_op.cuh b/cpp/test/linalg/binary_op.cuh index 62820ddb97..8b0bc609d2 100644 --- a/cpp/test/linalg/binary_op.cuh +++ b/cpp/test/linalg/binary_op.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/linalg/cholesky_r1.cu b/cpp/test/linalg/cholesky_r1.cu index 134b7645ff..9d90b03a6e 100644 --- a/cpp/test/linalg/cholesky_r1.cu +++ b/cpp/test/linalg/cholesky_r1.cu @@ -22,7 +22,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include namespace raft { diff --git a/cpp/test/linalg/coalesced_reduction.cu b/cpp/test/linalg/coalesced_reduction.cu index 791537b430..dc82ab9511 100644 --- a/cpp/test/linalg/coalesced_reduction.cu +++ b/cpp/test/linalg/coalesced_reduction.cu @@ -14,9 +14,10 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "reduce.cuh" #include +#include #include #include #include @@ -47,8 +48,7 @@ void coalescedReductionLaunch( { auto dots_view = raft::make_device_vector_view(dots, rows); auto data_view = raft::make_device_matrix_view(data, rows, cols); - coalesced_reduction( - handle, data_view, dots_view, (T)0, inplace, [] __device__(T in, int i) { return in * in; }); + coalesced_reduction(handle, data_view, dots_view, (T)0, inplace, raft::sq_op{}); } template @@ -80,9 +80,9 @@ class coalescedReductionTest : public ::testing::TestWithParam{}, - raft::Sum{}, - raft::Nop{}); + raft::sq_op{}, + raft::add_op{}, + raft::identity_op{}); naiveCoalescedReduction(dots_exp.data(), data.data(), cols, @@ -90,9 +90,9 @@ class coalescedReductionTest : public ::testing::TestWithParam{}, - raft::Sum{}, - raft::Nop{}); + raft::sq_op{}, + raft::add_op{}, + raft::identity_op{}); coalescedReductionLaunch(handle, dots_act.data(), data.data(), cols, rows); coalescedReductionLaunch(handle, dots_act.data(), data.data(), cols, rows, true); diff --git a/cpp/test/linalg/divide.cu b/cpp/test/linalg/divide.cu index 4e2e5cdba7..4b5ea0a2dc 100644 --- a/cpp/test/linalg/divide.cu +++ b/cpp/test/linalg/divide.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "unary_op.cuh" #include #include diff --git a/cpp/test/linalg/dot.cu b/cpp/test/linalg/dot.cu index b5007aea32..80a9f24aba 100644 --- a/cpp/test/linalg/dot.cu +++ b/cpp/test/linalg/dot.cu @@ -15,7 +15,7 @@ */ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/eig.cu b/cpp/test/linalg/eig.cu index a913b14fcb..4b834c1aa8 100644 --- a/cpp/test/linalg/eig.cu +++ b/cpp/test/linalg/eig.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/eig_sel.cu b/cpp/test/linalg/eig_sel.cu index 9d57c4fa0a..c2b12e5d4a 100644 --- a/cpp/test/linalg/eig_sel.cu +++ b/cpp/test/linalg/eig_sel.cu @@ -16,7 +16,7 @@ #if CUDART_VERSION >= 10010 -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/eltwise.cu b/cpp/test/linalg/eltwise.cu index 07ded5ec79..d9ab7e0984 100644 --- a/cpp/test/linalg/eltwise.cu +++ b/cpp/test/linalg/eltwise.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/gemm_layout.cu b/cpp/test/linalg/gemm_layout.cu index dbe10ab4cc..a992a32304 100644 --- a/cpp/test/linalg/gemm_layout.cu +++ b/cpp/test/linalg/gemm_layout.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/gemv.cu b/cpp/test/linalg/gemv.cu index 2bd9abc200..594810bab2 100644 --- a/cpp/test/linalg/gemv.cu +++ b/cpp/test/linalg/gemv.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/map.cu b/cpp/test/linalg/map.cu index 95a2aff130..7e3a1562d9 100644 --- a/cpp/test/linalg/map.cu +++ b/cpp/test/linalg/map.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/map_then_reduce.cu b/cpp/test/linalg/map_then_reduce.cu index adf784f601..1e7f58ec38 100644 --- a/cpp/test/linalg/map_then_reduce.cu +++ b/cpp/test/linalg/map_then_reduce.cu @@ -14,12 +14,14 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include +#include #include #include #include +#include #include #include #include @@ -63,9 +65,8 @@ template void mapReduceLaunch( OutType* out_ref, OutType* out, const InType* in, size_t len, cudaStream_t stream) { - auto op = [] __device__(InType in) { return in; }; - naiveMapReduce(out_ref, in, len, op, stream); - mapThenSumReduce(out, len, op, 0, in); + naiveMapReduce(out_ref, in, len, raft::identity_op{}, stream); + mapThenSumReduce(out, len, raft::identity_op{}, 0, in); } template @@ -150,23 +151,21 @@ class MapGenericReduceTest : public ::testing::Test { void testMin() { - auto op = [] __device__(InType in) { return in; }; OutType neutral = std::numeric_limits::max(); auto output_view = raft::make_device_scalar_view(output.data()); auto input_view = raft::make_device_vector_view( input.data(), static_cast(input.size())); - map_reduce(handle, input_view, output_view, neutral, op, cub::Min()); + map_reduce(handle, input_view, output_view, neutral, raft::identity_op{}, cub::Min()); EXPECT_TRUE(raft::devArrMatch( OutType(1), output.data(), 1, raft::Compare(), handle.get_stream())); } void testMax() { - auto op = [] __device__(InType in) { return in; }; OutType neutral = std::numeric_limits::min(); auto output_view = raft::make_device_scalar_view(output.data()); auto input_view = raft::make_device_vector_view( input.data(), static_cast(input.size())); - map_reduce(handle, input_view, output_view, neutral, op, cub::Max()); + map_reduce(handle, input_view, output_view, neutral, raft::identity_op{}, cub::Max()); EXPECT_TRUE(raft::devArrMatch( OutType(5), output.data(), 1, raft::Compare(), handle.get_stream())); } diff --git a/cpp/test/linalg/matrix_vector.cu b/cpp/test/linalg/matrix_vector.cu index f103b5918b..7018e1da96 100644 --- a/cpp/test/linalg/matrix_vector.cu +++ b/cpp/test/linalg/matrix_vector.cu @@ -14,12 +14,14 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "matrix_vector_op.cuh" #include #include +#include #include #include +#include #include namespace raft { @@ -113,34 +115,25 @@ void naive_matrix_vector_op_launch(const raft::handle_t& handle, return mat_element; } }; - auto operation_div = [] __device__(T mat_element, T vec_element) { - return mat_element / vec_element; - }; auto operation_bin_div_skip_zero = [] __device__(T mat_element, T vec_element) { if (raft::myAbs(vec_element) < T(1e-10)) return T(0); else return mat_element / vec_element; }; - auto operation_bin_add = [] __device__(T mat_element, T vec_element) { - return mat_element + vec_element; - }; - auto operation_bin_sub = [] __device__(T mat_element, T vec_element) { - return mat_element - vec_element; - }; if (operation_type == 0) { naiveMatVec( in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_mult_skip_zero, stream); } else if (operation_type == 1) { - naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, operation_div, stream); + naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::div_op{}, stream); } else if (operation_type == 2) { naiveMatVec( in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_div_skip_zero, stream); } else if (operation_type == 3) { - naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_add, stream); + naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::add_op{}, stream); } else if (operation_type == 4) { - naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_sub, stream); + naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, raft::sub_op{}, stream); } else { THROW("Unknown operation type '%d'!", (int)operation_type); } diff --git a/cpp/test/linalg/matrix_vector_op.cu b/cpp/test/linalg/matrix_vector_op.cu index 1c96c3fc74..e2775c168d 100644 --- a/cpp/test/linalg/matrix_vector_op.cu +++ b/cpp/test/linalg/matrix_vector_op.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "matrix_vector_op.cuh" #include #include diff --git a/cpp/test/linalg/matrix_vector_op.cuh b/cpp/test/linalg/matrix_vector_op.cuh index 602d05d153..cf316ef111 100644 --- a/cpp/test/linalg/matrix_vector_op.cuh +++ b/cpp/test/linalg/matrix_vector_op.cuh @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/linalg/mean_squared_error.cu b/cpp/test/linalg/mean_squared_error.cu index 795f831417..18e7debcb1 100644 --- a/cpp/test/linalg/mean_squared_error.cu +++ b/cpp/test/linalg/mean_squared_error.cu @@ -15,7 +15,7 @@ */ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/multiply.cu b/cpp/test/linalg/multiply.cu index 1d6446c5c0..c90fb93fd0 100644 --- a/cpp/test/linalg/multiply.cu +++ b/cpp/test/linalg/multiply.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "unary_op.cuh" #include #include diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index f0b8d3bb55..94540b9ff6 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -14,10 +14,12 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include +#include #include #include +#include #include #include @@ -95,11 +97,10 @@ class RowNormTest : public ::testing::TestWithParam> { auto input_col_major = raft::make_device_matrix_view( data.data(), params.rows, params.cols); if (params.do_sqrt) { - auto fin_op = [] __device__(const T in) { return raft::mySqrt(in); }; if (params.rowMajor) { - norm(handle, input_row_major, output_view, params.type, Apply::ALONG_ROWS, fin_op); + norm(handle, input_row_major, output_view, params.type, Apply::ALONG_ROWS, raft::sqrt_op{}); } else { - norm(handle, input_col_major, output_view, params.type, Apply::ALONG_ROWS, fin_op); + norm(handle, input_col_major, output_view, params.type, Apply::ALONG_ROWS, raft::sqrt_op{}); } } else { if (params.rowMajor) { @@ -171,11 +172,12 @@ class ColNormTest : public ::testing::TestWithParam> { auto input_col_major = raft::make_device_matrix_view( data.data(), params.rows, params.cols); if (params.do_sqrt) { - auto fin_op = [] __device__(const T in) { return raft::mySqrt(in); }; if (params.rowMajor) { - norm(handle, input_row_major, output_view, params.type, Apply::ALONG_COLUMNS, fin_op); + norm( + handle, input_row_major, output_view, params.type, Apply::ALONG_COLUMNS, raft::sqrt_op{}); } else { - norm(handle, input_col_major, output_view, params.type, Apply::ALONG_COLUMNS, fin_op); + norm( + handle, input_col_major, output_view, params.type, Apply::ALONG_COLUMNS, raft::sqrt_op{}); } } else { if (params.rowMajor) { diff --git a/cpp/test/linalg/normalize.cu b/cpp/test/linalg/normalize.cu index cb949b6a5d..0a6786b1ee 100644 --- a/cpp/test/linalg/normalize.cu +++ b/cpp/test/linalg/normalize.cu @@ -14,12 +14,14 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include +#include #include #include #include #include +#include #include #include @@ -48,20 +50,13 @@ void rowNormalizeRef( { rmm::device_uvector norm(rows, stream); if (norm_type == raft::linalg::L2Norm) { - raft::linalg::rowNorm(norm.data(), in, cols, rows, norm_type, true, stream, raft::SqrtOp()); + raft::linalg::rowNorm(norm.data(), in, cols, rows, norm_type, true, stream, raft::sqrt_op()); } else { - raft::linalg::rowNorm(norm.data(), in, cols, rows, norm_type, true, stream, raft::Nop()); + raft::linalg::rowNorm( + norm.data(), in, cols, rows, norm_type, true, stream, raft::identity_op()); } raft::linalg::matrixVectorOp( - out, - in, - norm.data(), - cols, - rows, - true, - false, - [] __device__(T a, T b) { return a / b; }, - stream); + out, in, norm.data(), cols, rows, true, false, raft::div_op{}, stream); } template diff --git a/cpp/test/linalg/power.cu b/cpp/test/linalg/power.cu index bdab49d5c8..54c2e2a7aa 100644 --- a/cpp/test/linalg/power.cu +++ b/cpp/test/linalg/power.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/reduce.cu b/cpp/test/linalg/reduce.cu index 00f3810d28..4ad382c4f7 100644 --- a/cpp/test/linalg/reduce.cu +++ b/cpp/test/linalg/reduce.cu @@ -14,10 +14,11 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "reduce.cuh" #include #include +#include #include #include #include @@ -101,9 +102,9 @@ void reduceLaunch(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::SqrtOp> + typename MainLambda = raft::sq_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::sqrt_op> class ReduceTest : public ::testing::TestWithParam> { public: ReduceTest() @@ -301,7 +302,7 @@ REDUCE_TEST((ReduceTest, ArgMaxOp, - raft::Nop, int>>), + raft::identity_op>), ReduceTestKVPISI32, inputs_kvpis_i32); REDUCE_TEST((ReduceTest, ArgMaxOp, - raft::Nop, int>>), + raft::identity_op>), ReduceTestKVPIFI32, inputs_kvpif_i32); REDUCE_TEST((ReduceTest, ArgMaxOp, - raft::Nop, int>>), + raft::identity_op>), ReduceTestKVPIDI32, inputs_kvpid_i32); diff --git a/cpp/test/linalg/reduce.cuh b/cpp/test/linalg/reduce.cuh index 0dcffd3f41..17e91ce202 100644 --- a/cpp/test/linalg/reduce.cuh +++ b/cpp/test/linalg/reduce.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -61,9 +62,9 @@ __global__ void naiveCoalescedReductionKernel(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void naiveCoalescedReduction(OutType* dots, const InType* data, IdxType D, @@ -71,9 +72,9 @@ void naiveCoalescedReduction(OutType* dots, cudaStream_t stream, OutType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda fin_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda fin_op = raft::identity_op()) { static const IdxType TPB = 64; IdxType nblks = raft::ceildiv(N, TPB); @@ -115,9 +116,9 @@ __global__ void naiveStridedReductionKernel(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void naiveStridedReduction(OutType* dots, const InType* data, IdxType D, @@ -125,9 +126,9 @@ void naiveStridedReduction(OutType* dots, cudaStream_t stream, OutType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda fin_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda fin_op = raft::identity_op()) { static const IdxType TPB = 64; IdxType nblks = raft::ceildiv(D, TPB); @@ -139,9 +140,9 @@ void naiveStridedReduction(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void naiveReduction(OutType* dots, const InType* data, IdxType D, @@ -151,9 +152,9 @@ void naiveReduction(OutType* dots, cudaStream_t stream, OutType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda fin_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda fin_op = raft::identity_op()) { if (rowMajor && alongRows) { naiveCoalescedReduction(dots, data, D, N, stream, init, inplace, main_op, reduce_op, fin_op); diff --git a/cpp/test/linalg/reduce_cols_by_key.cu b/cpp/test/linalg/reduce_cols_by_key.cu index 63afbe2fed..a378c450ce 100644 --- a/cpp/test/linalg/reduce_cols_by_key.cu +++ b/cpp/test/linalg/reduce_cols_by_key.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/reduce_rows_by_key.cu b/cpp/test/linalg/reduce_rows_by_key.cu index 7b124cb7bb..22229d2224 100644 --- a/cpp/test/linalg/reduce_rows_by_key.cu +++ b/cpp/test/linalg/reduce_rows_by_key.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/rsvd.cu b/cpp/test/linalg/rsvd.cu index f774d59631..04e17468c3 100644 --- a/cpp/test/linalg/rsvd.cu +++ b/cpp/test/linalg/rsvd.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/sqrt.cu b/cpp/test/linalg/sqrt.cu index ed57e94914..9008313b58 100644 --- a/cpp/test/linalg/sqrt.cu +++ b/cpp/test/linalg/sqrt.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/strided_reduction.cu b/cpp/test/linalg/strided_reduction.cu index 77ca585ea5..6a8c43ad52 100644 --- a/cpp/test/linalg/strided_reduction.cu +++ b/cpp/test/linalg/strided_reduction.cu @@ -14,11 +14,13 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "reduce.cuh" #include +#include #include #include +#include #include namespace raft { @@ -38,7 +40,7 @@ void stridedReductionLaunch( raft::handle_t handle{stream}; auto dots_view = raft::make_device_vector_view(dots, cols); auto data_view = raft::make_device_matrix_view(data, rows, cols); - strided_reduction(handle, data_view, dots_view, (T)0, inplace, raft::L2Op{}); + strided_reduction(handle, data_view, dots_view, (T)0, inplace, raft::sq_op{}); } template @@ -70,9 +72,9 @@ class stridedReductionTest : public ::testing::TestWithParam{}, - raft::Sum{}, - raft::Nop{}); + raft::sq_op{}, + raft::add_op{}, + raft::identity_op{}); naiveStridedReduction(dots_exp.data(), data.data(), cols, @@ -80,9 +82,9 @@ class stridedReductionTest : public ::testing::TestWithParam{}, - raft::Sum{}, - raft::Nop{}); + raft::sq_op{}, + raft::add_op{}, + raft::identity_op{}); stridedReductionLaunch(dots_act.data(), data.data(), cols, rows, false, stream); stridedReductionLaunch(dots_act.data(), data.data(), cols, rows, true, stream); handle.sync_stream(stream); diff --git a/cpp/test/linalg/subtract.cu b/cpp/test/linalg/subtract.cu index 3904f9f33f..426fc98f9f 100644 --- a/cpp/test/linalg/subtract.cu +++ b/cpp/test/linalg/subtract.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/svd.cu b/cpp/test/linalg/svd.cu index c18417dc9e..7918d481db 100644 --- a/cpp/test/linalg/svd.cu +++ b/cpp/test/linalg/svd.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/ternary_op.cu b/cpp/test/linalg/ternary_op.cu index e172d771cd..c78df08820 100644 --- a/cpp/test/linalg/ternary_op.cu +++ b/cpp/test/linalg/ternary_op.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index 6a05317f49..110dc527d3 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/linalg/unary_op.cu b/cpp/test/linalg/unary_op.cu index 57b009a0ac..3ebf70e69f 100644 --- a/cpp/test/linalg/unary_op.cu +++ b/cpp/test/linalg/unary_op.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "unary_op.cuh" #include #include diff --git a/cpp/test/linalg/unary_op.cuh b/cpp/test/linalg/unary_op.cuh index 190d531a9f..28bcc004a4 100644 --- a/cpp/test/linalg/unary_op.cuh +++ b/cpp/test/linalg/unary_op.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu index 0219eb1aff..33af0ce5a4 100644 --- a/cpp/test/matrix/argmax.cu +++ b/cpp/test/matrix/argmax.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/argmin.cu b/cpp/test/matrix/argmin.cu index bdf178cd8a..22f0a6cac0 100644 --- a/cpp/test/matrix/argmin.cu +++ b/cpp/test/matrix/argmin.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/columnSort.cu b/cpp/test/matrix/columnSort.cu index aba1c4e1f0..000a911efd 100644 --- a/cpp/test/matrix/columnSort.cu +++ b/cpp/test/matrix/columnSort.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/diagonal.cu b/cpp/test/matrix/diagonal.cu index e1ad9e144b..f6cd178b23 100644 --- a/cpp/test/matrix/diagonal.cu +++ b/cpp/test/matrix/diagonal.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index 4b3244913b..0bea62e9cf 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 2e3d54dcf5..9ce1371944 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -15,11 +15,12 @@ */ #include "../linalg/matrix_vector_op.cuh" -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include #include +#include #include #include #include @@ -58,7 +59,6 @@ struct LinewiseTest : public ::testing::TestWithParam void runLinewiseSum(T* out, const T* in, const I lineLen, const I nLines, const T* vec) { - auto f = [] __device__(T a, T b) -> T { return a + b; }; constexpr auto rowmajor = std::is_same_v; I m = rowmajor ? lineLen : nLines; @@ -68,7 +68,8 @@ struct LinewiseTest : public ::testing::TestWithParam(out, n, m); auto vec_view = raft::make_device_vector_view(vec, lineLen); - matrix::linewise_op(handle, in_view, out_view, raft::is_row_major(in_view), f, vec_view); + matrix::linewise_op( + handle, in_view, out_view, raft::is_row_major(in_view), raft::add_op{}, vec_view); } template @@ -107,9 +108,8 @@ struct LinewiseTest : public ::testing::TestWithParam T { return a + b; }; auto vec_view = raft::make_device_vector_view(vec, alongLines ? lineLen : nLines); - matrix::linewise_op(handle, in, out, alongLines, f, vec_view); + matrix::linewise_op(handle, in, out, alongLines, raft::add_op{}, vec_view); } /** diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index 684b550dfc..f2c1a6249c 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 78391d5ff2..8cfbdac32b 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/matrix/norm.cu b/cpp/test/matrix/norm.cu index 38fdd409eb..b1e10c9047 100644 --- a/cpp/test/matrix/norm.cu +++ b/cpp/test/matrix/norm.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/reverse.cu b/cpp/test/matrix/reverse.cu index c905b8711e..49d501b6d0 100644 --- a/cpp/test/matrix/reverse.cu +++ b/cpp/test/matrix/reverse.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/slice.cu b/cpp/test/matrix/slice.cu index 5faf672d13..9060357b3f 100644 --- a/cpp/test/matrix/slice.cu +++ b/cpp/test/matrix/slice.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/matrix/triangular.cu b/cpp/test/matrix/triangular.cu index 9af3defb5d..9c6c49066b 100644 --- a/cpp/test/matrix/triangular.cu +++ b/cpp/test/matrix/triangular.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/mst.cu b/cpp/test/mst.cu index 544ca80a46..d11f0b5842 100644 --- a/cpp/test/mst.cu +++ b/cpp/test/mst.cu @@ -16,7 +16,7 @@ #include -#include "test_utils.h" +#include "test_utils.cuh" #include #include #include diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 735d569318..1207b75a4a 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "ann_utils.cuh" #include diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 9d6ad11ccb..7c3ec044b1 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include "../test_utils.h" +#include "../test_utils.cuh" #include "ann_utils.cuh" #include diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 07ef410d36..05fe6ab92d 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -25,7 +25,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include namespace raft::neighbors { diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 47030b0d62..7405863b9f 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "spatial_data.h" #include #include diff --git a/cpp/test/neighbors/epsilon_neighborhood.cu b/cpp/test/neighbors/epsilon_neighborhood.cu index c83817f6f8..4f33db489e 100644 --- a/cpp/test/neighbors/epsilon_neighborhood.cu +++ b/cpp/test/neighbors/epsilon_neighborhood.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/neighbors/faiss_mr.cu b/cpp/test/neighbors/faiss_mr.cu index 91ba1cc94c..38e793d120 100644 --- a/cpp/test/neighbors/faiss_mr.cu +++ b/cpp/test/neighbors/faiss_mr.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index 8df193d53d..d57f99da50 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/neighbors/haversine.cu b/cpp/test/neighbors/haversine.cu index 78bd377156..91a2ca07df 100644 --- a/cpp/test/neighbors/haversine.cu +++ b/cpp/test/neighbors/haversine.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index eb5ecf663f..ff3a6a80b4 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index 06c1317b1e..674171e030 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include "ann_utils.cuh" #include "refine_helper.cuh" diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index bfcfca5ead..d793ea46ee 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -20,7 +20,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu index 1f14fd23f7..741b374c8c 100644 --- a/cpp/test/random/make_blobs.cu +++ b/cpp/test/random/make_blobs.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/make_regression.cu b/cpp/test/random/make_regression.cu index 65d4c4cb31..c2e447adf8 100644 --- a/cpp/test/random/make_regression.cu +++ b/cpp/test/random/make_regression.cu @@ -19,7 +19,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/multi_variable_gaussian.cu b/cpp/test/random/multi_variable_gaussian.cu index f9f79e6845..04626a53c7 100644 --- a/cpp/test/random/multi_variable_gaussian.cu +++ b/cpp/test/random/multi_variable_gaussian.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/permute.cu b/cpp/test/random/permute.cu index 32e5540d51..be4f2a005f 100644 --- a/cpp/test/random/permute.cu +++ b/cpp/test/random/permute.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/rmat_rectangular_generator.cu b/cpp/test/random/rmat_rectangular_generator.cu index 0baaaf28cf..c1c4752453 100644 --- a/cpp/test/random/rmat_rectangular_generator.cu +++ b/cpp/test/random/rmat_rectangular_generator.cu @@ -19,7 +19,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index 82f6e0e247..bdce79b76e 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -17,7 +17,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/rng_discrete.cu b/cpp/test/random/rng_discrete.cu index b7aef51af5..1dee281527 100644 --- a/cpp/test/random/rng_discrete.cu +++ b/cpp/test/random/rng_discrete.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/rng_int.cu b/cpp/test/random/rng_int.cu index d5270c456e..89d6d208a5 100644 --- a/cpp/test/random/rng_int.cu +++ b/cpp/test/random/rng_int.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/random/sample_without_replacement.cu b/cpp/test/random/sample_without_replacement.cu index 355df1fcc1..a6cf3569e6 100644 --- a/cpp/test/random/sample_without_replacement.cu +++ b/cpp/test/random/sample_without_replacement.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/sparse/add.cu b/cpp/test/sparse/add.cu index 862cbffdc7..692094a861 100644 --- a/cpp/test/sparse/add.cu +++ b/cpp/test/sparse/add.cu @@ -20,7 +20,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/convert_coo.cu b/cpp/test/sparse/convert_coo.cu index 1142c6f3f2..21f81bfa6a 100644 --- a/cpp/test/sparse/convert_coo.cu +++ b/cpp/test/sparse/convert_coo.cu @@ -22,7 +22,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 007cbd7fdb..bc81dcaba5 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/csr_row_slice.cu b/cpp/test/sparse/csr_row_slice.cu index 39b235d5f1..6b10f8d798 100644 --- a/cpp/test/sparse/csr_row_slice.cu +++ b/cpp/test/sparse/csr_row_slice.cu @@ -24,7 +24,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" namespace raft { namespace sparse { diff --git a/cpp/test/sparse/csr_to_dense.cu b/cpp/test/sparse/csr_to_dense.cu index 5811c5c22b..b99d6fa7fd 100644 --- a/cpp/test/sparse/csr_to_dense.cu +++ b/cpp/test/sparse/csr_to_dense.cu @@ -24,7 +24,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" namespace raft { namespace sparse { diff --git a/cpp/test/sparse/csr_transpose.cu b/cpp/test/sparse/csr_transpose.cu index 108d38a8b4..2342f6e7ef 100644 --- a/cpp/test/sparse/csr_transpose.cu +++ b/cpp/test/sparse/csr_transpose.cu @@ -23,7 +23,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" namespace raft { namespace sparse { diff --git a/cpp/test/sparse/degree.cu b/cpp/test/sparse/degree.cu index a4af021c05..b5b22ff15c 100644 --- a/cpp/test/sparse/degree.cu +++ b/cpp/test/sparse/degree.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index c004aeaef0..be73bd3130 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -16,6 +16,7 @@ #include +#include #include #include #include @@ -26,7 +27,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include @@ -158,13 +159,12 @@ class SparseDistanceCOOSPMVTest case raft::distance::DistanceType::LpUnexpanded: { compute_dist( detail::PDiff(params.input_configuration.metric_arg), detail::Sum(), detail::AtomicAdd()); - float p = 1.0f / params.input_configuration.metric_arg; - raft::linalg::unaryOp( - out_dists.data(), - out_dists.data(), - dist_config.a_nrows * dist_config.b_nrows, - [=] __device__(value_t input) { return powf(input, p); }, - dist_config.handle.get_stream()); + value_t p = value_t{1} / params.input_configuration.metric_arg; + raft::linalg::unaryOp(out_dists.data(), + out_dists.data(), + dist_config.a_nrows * dist_config.b_nrows, + raft::pow_const_op{p}, + dist_config.handle.get_stream()); } break; default: throw raft::exception("Unknown distance"); diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index 4ce2f4cbde..367bfddad1 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -24,7 +24,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" namespace raft { namespace sparse { diff --git a/cpp/test/sparse/filter.cu b/cpp/test/sparse/filter.cu index ba80c84fd5..6ada5ddcad 100644 --- a/cpp/test/sparse/filter.cu +++ b/cpp/test/sparse/filter.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/neighbors/brute_force.cu b/cpp/test/sparse/neighbors/brute_force.cu index 8fa5e8322d..96ba3bc48f 100644 --- a/cpp/test/sparse/neighbors/brute_force.cu +++ b/cpp/test/sparse/neighbors/brute_force.cu @@ -17,7 +17,7 @@ #include #include -#include "../../test_utils.h" +#include "../../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/neighbors/connect_components.cu b/cpp/test/sparse/neighbors/connect_components.cu index fc4eecd4ee..f469cc1aa2 100644 --- a/cpp/test/sparse/neighbors/connect_components.cu +++ b/cpp/test/sparse/neighbors/connect_components.cu @@ -34,7 +34,7 @@ #include #include -#include "../../test_utils.h" +#include "../../test_utils.cuh" namespace raft { namespace sparse { diff --git a/cpp/test/sparse/neighbors/knn_graph.cu b/cpp/test/sparse/neighbors/knn_graph.cu index d6f9e8386f..c3700a536c 100644 --- a/cpp/test/sparse/neighbors/knn_graph.cu +++ b/cpp/test/sparse/neighbors/knn_graph.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../../test_utils.h" +#include "../../test_utils.cuh" #include #include #include diff --git a/cpp/test/sparse/norm.cu b/cpp/test/sparse/norm.cu index 8e54edd6c9..1a69acc535 100644 --- a/cpp/test/sparse/norm.cu +++ b/cpp/test/sparse/norm.cu @@ -16,7 +16,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/reduce.cu b/cpp/test/sparse/reduce.cu index 4280192723..5752624435 100644 --- a/cpp/test/sparse/reduce.cu +++ b/cpp/test/sparse/reduce.cu @@ -16,7 +16,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/sparse/row_op.cu b/cpp/test/sparse/row_op.cu index 732bd06103..6393c5ee86 100644 --- a/cpp/test/sparse/row_op.cu +++ b/cpp/test/sparse/row_op.cu @@ -19,7 +19,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/sort.cu b/cpp/test/sparse/sort.cu index 9b75965498..23c2f5b67a 100644 --- a/cpp/test/sparse/sort.cu +++ b/cpp/test/sparse/sort.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/sparse/spgemmi.cu b/cpp/test/sparse/spgemmi.cu index a132c94fde..653c2fa29b 100644 --- a/cpp/test/sparse/spgemmi.cu +++ b/cpp/test/sparse/spgemmi.cu @@ -16,7 +16,7 @@ #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/sparse/symmetrize.cu b/cpp/test/sparse/symmetrize.cu index 7cf1a1e07d..6f2f877304 100644 --- a/cpp/test/sparse/symmetrize.cu +++ b/cpp/test/sparse/symmetrize.cu @@ -22,7 +22,7 @@ #include #include -#include "../test_utils.h" +#include "../test_utils.cuh" #include diff --git a/cpp/test/stats/accuracy.cu b/cpp/test/stats/accuracy.cu index 192c187794..eaccdecab4 100644 --- a/cpp/test/stats/accuracy.cu +++ b/cpp/test/stats/accuracy.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/adjusted_rand_index.cu b/cpp/test/stats/adjusted_rand_index.cu index f113af821d..52bc72174a 100644 --- a/cpp/test/stats/adjusted_rand_index.cu +++ b/cpp/test/stats/adjusted_rand_index.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/completeness_score.cu b/cpp/test/stats/completeness_score.cu index 2f8a40afdc..a9d1748f88 100644 --- a/cpp/test/stats/completeness_score.cu +++ b/cpp/test/stats/completeness_score.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/contingencyMatrix.cu b/cpp/test/stats/contingencyMatrix.cu index 7943610689..d27114388e 100644 --- a/cpp/test/stats/contingencyMatrix.cu +++ b/cpp/test/stats/contingencyMatrix.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/cov.cu b/cpp/test/stats/cov.cu index 890c5b7826..59a2c6e081 100644 --- a/cpp/test/stats/cov.cu +++ b/cpp/test/stats/cov.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/dispersion.cu b/cpp/test/stats/dispersion.cu index 4f18c9fb54..e414fcf5f4 100644 --- a/cpp/test/stats/dispersion.cu +++ b/cpp/test/stats/dispersion.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/entropy.cu b/cpp/test/stats/entropy.cu index 04aa9f7a80..96b2b9f590 100644 --- a/cpp/test/stats/entropy.cu +++ b/cpp/test/stats/entropy.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/histogram.cu b/cpp/test/stats/histogram.cu index d9793a57df..76677ac27c 100644 --- a/cpp/test/stats/histogram.cu +++ b/cpp/test/stats/histogram.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/homogeneity_score.cu b/cpp/test/stats/homogeneity_score.cu index 9bd6d9266b..ecbf160770 100644 --- a/cpp/test/stats/homogeneity_score.cu +++ b/cpp/test/stats/homogeneity_score.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/information_criterion.cu b/cpp/test/stats/information_criterion.cu index 4a9a2128c6..2cfbd787c6 100644 --- a/cpp/test/stats/information_criterion.cu +++ b/cpp/test/stats/information_criterion.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include diff --git a/cpp/test/stats/kl_divergence.cu b/cpp/test/stats/kl_divergence.cu index 58a64f7199..b5a6c393f3 100644 --- a/cpp/test/stats/kl_divergence.cu +++ b/cpp/test/stats/kl_divergence.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/mean.cu b/cpp/test/stats/mean.cu index b299f81f68..19398d6d8e 100644 --- a/cpp/test/stats/mean.cu +++ b/cpp/test/stats/mean.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/mean_center.cu b/cpp/test/stats/mean_center.cu index 30dcdd475b..31947ef527 100644 --- a/cpp/test/stats/mean_center.cu +++ b/cpp/test/stats/mean_center.cu @@ -15,7 +15,7 @@ */ #include "../linalg/matrix_vector_op.cuh" -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/meanvar.cu b/cpp/test/stats/meanvar.cu index 424395c5e8..fb9fc13dec 100644 --- a/cpp/test/stats/meanvar.cu +++ b/cpp/test/stats/meanvar.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/minmax.cu b/cpp/test/stats/minmax.cu index a2ba6bfc9e..1171995d5c 100644 --- a/cpp/test/stats/minmax.cu +++ b/cpp/test/stats/minmax.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/mutual_info_score.cu b/cpp/test/stats/mutual_info_score.cu index fb9362df52..8b6e7b2095 100644 --- a/cpp/test/stats/mutual_info_score.cu +++ b/cpp/test/stats/mutual_info_score.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/r2_score.cu b/cpp/test/stats/r2_score.cu index d77daacb04..7fb15505ab 100644 --- a/cpp/test/stats/r2_score.cu +++ b/cpp/test/stats/r2_score.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/rand_index.cu b/cpp/test/stats/rand_index.cu index 67e4ab5517..0010f3cbcd 100644 --- a/cpp/test/stats/rand_index.cu +++ b/cpp/test/stats/rand_index.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include diff --git a/cpp/test/stats/regression_metrics.cu b/cpp/test/stats/regression_metrics.cu index effc3d04dd..86ac03c8b3 100644 --- a/cpp/test/stats/regression_metrics.cu +++ b/cpp/test/stats/regression_metrics.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/silhouette_score.cu b/cpp/test/stats/silhouette_score.cu index 37a6fff786..876926b71a 100644 --- a/cpp/test/stats/silhouette_score.cu +++ b/cpp/test/stats/silhouette_score.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/stddev.cu b/cpp/test/stats/stddev.cu index 73f30f17e9..7f54eee2ab 100644 --- a/cpp/test/stats/stddev.cu +++ b/cpp/test/stats/stddev.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/sum.cu b/cpp/test/stats/sum.cu index e67988abb0..c69bd04c6e 100644 --- a/cpp/test/stats/sum.cu +++ b/cpp/test/stats/sum.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include diff --git a/cpp/test/stats/trustworthiness.cu b/cpp/test/stats/trustworthiness.cu index cbb8228f8f..a95cddf5aa 100644 --- a/cpp/test/stats/trustworthiness.cu +++ b/cpp/test/stats/trustworthiness.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/v_measure.cu b/cpp/test/stats/v_measure.cu index 0cbc2da7d9..79899c1d75 100644 --- a/cpp/test/stats/v_measure.cu +++ b/cpp/test/stats/v_measure.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/stats/weighted_mean.cu b/cpp/test/stats/weighted_mean.cu index 9f33855572..8b4af07898 100644 --- a/cpp/test/stats/weighted_mean.cu +++ b/cpp/test/stats/weighted_mean.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/test_utils.cuh b/cpp/test/test_utils.cuh new file mode 100644 index 0000000000..5704eefae3 --- /dev/null +++ b/cpp/test/test_utils.cuh @@ -0,0 +1,329 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "test_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +/* + * @brief Helper function to compare 2 device n-D arrays with custom comparison + * @tparam T the data type of the arrays + * @tparam L the comparator lambda or object function + * @param expected expected value(s) + * @param actual actual values + * @param eq_compare the comparator + * @param stream cuda stream + * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE + * @{ + */ +template +testing::AssertionResult devArrMatch( + const T* expected, const T* actual, size_t size, L eq_compare, cudaStream_t stream = 0) +{ + std::unique_ptr exp_h(new T[size]); + std::unique_ptr act_h(new T[size]); + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < size; ++i) { + auto exp = exp_h.get()[i]; + auto act = act_h.get()[i]; + if (!eq_compare(exp, act)) { + return testing::AssertionFailure() << "actual=" << act << " != expected=" << exp << " @" << i; + } + } + return testing::AssertionSuccess(); +} + +template +testing::AssertionResult devArrMatch( + T expected, const T* actual, size_t size, L eq_compare, cudaStream_t stream = 0) +{ + std::unique_ptr act_h(new T[size]); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < size; ++i) { + auto act = act_h.get()[i]; + if (!eq_compare(expected, act)) { + return testing::AssertionFailure() + << "actual=" << act << " != expected=" << expected << " @" << i; + } + } + return testing::AssertionSuccess(); +} + +template +testing::AssertionResult devArrMatch(const T* expected, + const T* actual, + size_t rows, + size_t cols, + L eq_compare, + cudaStream_t stream = 0) +{ + size_t size = rows * cols; + std::unique_ptr exp_h(new T[size]); + std::unique_ptr act_h(new T[size]); + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < rows; ++i) { + for (size_t j(0); j < cols; ++j) { + auto idx = i * cols + j; // row major assumption! + auto exp = exp_h.get()[idx]; + auto act = act_h.get()[idx]; + if (!eq_compare(exp, act)) { + return testing::AssertionFailure() + << "actual=" << act << " != expected=" << exp << " @" << i << "," << j; + } + } + } + return testing::AssertionSuccess(); +} + +template +testing::AssertionResult devArrMatch( + T expected, const T* actual, size_t rows, size_t cols, L eq_compare, cudaStream_t stream = 0) +{ + size_t size = rows * cols; + std::unique_ptr act_h(new T[size]); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < rows; ++i) { + for (size_t j(0); j < cols; ++j) { + auto idx = i * cols + j; // row major assumption! + auto act = act_h.get()[idx]; + if (!eq_compare(expected, act)) { + return testing::AssertionFailure() + << "actual=" << act << " != expected=" << expected << " @" << i << "," << j; + } + } + } + return testing::AssertionSuccess(); +} + +/* + * @brief Helper function to compare a device n-D arrays with an expected array + * on the host, using a custom comparison + * @tparam T the data type of the arrays + * @tparam L the comparator lambda or object function + * @param expected_h host array of expected value(s) + * @param actual_d device array actual values + * @param eq_compare the comparator + * @param stream cuda stream + * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE + */ +template +testing::AssertionResult devArrMatchHost( + const T* expected_h, const T* actual_d, size_t size, L eq_compare, cudaStream_t stream = 0) +{ + std::unique_ptr act_h(new T[size]); + raft::update_host(act_h.get(), actual_d, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + bool ok = true; + auto fail = testing::AssertionFailure(); + for (size_t i(0); i < size; ++i) { + auto exp = expected_h[i]; + auto act = act_h.get()[i]; + if (!eq_compare(exp, act)) { + ok = false; + fail << "actual=" << act << " != expected=" << exp << " @" << i << "; "; + } + } + if (!ok) return fail; + return testing::AssertionSuccess(); +} + +/** + * @brief Helper function to compare host vectors using a custom comparison + * @tparam T the element type + * @tparam L the comparator lambda or object function + * @param expected_h host vector of expected value(s) + * @param actual_h host vector actual values + * @param eq_compare the comparator + * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE + */ +template +testing::AssertionResult hostVecMatch(const std::vector& expected_h, + const std::vector& actual_h, + L eq_compare) +{ + auto n = actual_h.size(); + if (n != expected_h.size()) + return testing::AssertionFailure() + << "vector sizez mismatch: " + << "actual=" << n << " != expected=" << expected_h.size() << "; "; + for (size_t i = 0; i < n; ++i) { + auto exp = expected_h[i]; + auto act = actual_h[i]; + if (!eq_compare(exp, act)) { + return testing::AssertionFailure() + << "actual=" << act << " != expected=" << exp << " @" << i << "; "; + } + } + return testing::AssertionSuccess(); +} + +/* + * @brief Helper function to compare diagonal values of a 2D matrix + * @tparam T the data type of the arrays + * @tparam L the comparator lambda or object function + * @param expected expected value along diagonal + * @param actual actual matrix + * @param eq_compare the comparator + * @param stream cuda stream + * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE + */ +template +testing::AssertionResult diagonalMatch( + T expected, const T* actual, size_t rows, size_t cols, L eq_compare, cudaStream_t stream = 0) +{ + size_t size = rows * cols; + std::unique_ptr act_h(new T[size]); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < rows; ++i) { + for (size_t j(0); j < cols; ++j) { + if (i != j) continue; + auto idx = i * cols + j; // row major assumption! + auto act = act_h.get()[idx]; + if (!eq_compare(expected, act)) { + return testing::AssertionFailure() + << "actual=" << act << " != expected=" << expected << " @" << i << "," << j; + } + } + } + return testing::AssertionSuccess(); +} + +template +typename std::enable_if_t> gen_uniform(T* out, + raft::random::RngState& rng, + IdxT len, + cudaStream_t stream, + T range_min = T(-1), + T range_max = T(1)) +{ + raft::random::uniform(rng, out, len, range_min, range_max, stream); +} + +template +typename std::enable_if_t> gen_uniform(T* out, + raft::random::RngState& rng, + IdxT len, + cudaStream_t stream, + T range_min = T(0), + T range_max = T(100)) +{ + raft::random::uniformInt(rng, out, len, range_min, range_max, stream); +} + +template +void gen_uniform(raft::KeyValuePair* out, + raft::random::RngState& rng, + IdxT len, + cudaStream_t stream) +{ + rmm::device_uvector keys(len, stream); + rmm::device_uvector values(len, stream); + + gen_uniform(keys.data(), rng, len, stream); + gen_uniform(values.data(), rng, len, stream); + + const T1* d_keys = keys.data(); + const T2* d_values = values.data(); + auto counting = thrust::make_counting_iterator(0); + thrust::for_each(rmm::exec_policy(stream), + counting, + counting + len, + [out, d_keys, d_values] __device__(int idx) { + out[idx].key = d_keys[idx]; + out[idx].value = d_values[idx]; + }); +} + +/** @} */ + +/** time the function call 'func' using cuda events */ +#define TIMEIT_LOOP(ms, count, func) \ + do { \ + cudaEvent_t start, stop; \ + RAFT_CUDA_TRY(cudaEventCreate(&start)); \ + RAFT_CUDA_TRY(cudaEventCreate(&stop)); \ + RAFT_CUDA_TRY(cudaEventRecord(start)); \ + for (int i = 0; i < count; ++i) { \ + func; \ + } \ + RAFT_CUDA_TRY(cudaEventRecord(stop)); \ + RAFT_CUDA_TRY(cudaEventSynchronize(stop)); \ + ms = 0.f; \ + RAFT_CUDA_TRY(cudaEventElapsedTime(&ms, start, stop)); \ + ms /= args.runs; \ + } while (0) + +inline std::vector read_csv(std::string filename, bool skip_first_n_columns = 1) +{ + std::vector result; + std::ifstream myFile(filename); + if (!myFile.is_open()) throw std::runtime_error("Could not open file"); + + std::string line, colname; + int val; + + if (myFile.good()) { + std::getline(myFile, line); + std::stringstream ss(line); + while (std::getline(ss, colname, ',')) {} + } + + int n_lines = 0; + while (std::getline(myFile, line)) { + std::stringstream ss(line); + int colIdx = 0; + while (ss >> val) { + if (colIdx >= skip_first_n_columns) { + result.push_back(val); + if (ss.peek() == ',') ss.ignore(); + } + colIdx++; + } + n_lines++; + } + + printf("lines read: %d\n", n_lines); + myFile.close(); + return result; +} + +}; // end namespace raft \ No newline at end of file diff --git a/cpp/test/test_utils.h b/cpp/test/test_utils.h index 26483e6b2d..75590463b0 100644 --- a/cpp/test/test_utils.h +++ b/cpp/test/test_utils.h @@ -15,23 +15,13 @@ */ #pragma once -#include + +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include + +#include namespace raft { @@ -40,13 +30,22 @@ struct Compare { bool operator()(const T& a, const T& b) const { return a == b; } }; +template +struct Compare> { + bool operator()(const raft::KeyValuePair& a, + const raft::KeyValuePair& b) const + { + return a.key == b.key && a.value == b.value; + } +}; + template struct CompareApprox { CompareApprox(T eps_) : eps(eps_) {} bool operator()(const T& a, const T& b) const { - T diff = abs(a - b); - T m = std::max(abs(a), abs(b)); + T diff = std::abs(a - b); + T m = std::max(std::abs(a), std::abs(b)); T ratio = diff > eps ? diff / m : diff; return (ratio <= eps); @@ -85,8 +84,8 @@ struct CompareApproxAbs { CompareApproxAbs(T eps_) : eps(eps_) {} bool operator()(const T& a, const T& b) const { - T diff = abs(abs(a) - abs(b)); - T m = std::max(abs(a), abs(b)); + T diff = std::abs(std::abs(a) - std::abs(b)); + T m = std::max(std::abs(a), std::abs(b)); T ratio = diff >= eps ? diff / m : diff; return (ratio <= eps); } @@ -98,210 +97,14 @@ struct CompareApproxAbs { template struct CompareApproxNoScaling { CompareApproxNoScaling(T eps_) : eps(eps_) {} - bool operator()(const T& a, const T& b) const { return (abs(a - b) <= eps); } + bool operator()(const T& a, const T& b) const { return (std::abs(a - b) <= eps); } private: T eps; }; -template -__host__ __device__ T abs(const T& a) -{ - return a > T(0) ? a : -a; -} - -/* - * @brief Helper function to compare 2 device n-D arrays with custom comparison - * @tparam T the data type of the arrays - * @tparam L the comparator lambda or object function - * @param expected expected value(s) - * @param actual actual values - * @param eq_compare the comparator - * @param stream cuda stream - * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE - * @{ - */ -template -testing::AssertionResult devArrMatch( - const T* expected, const T* actual, size_t size, L eq_compare, cudaStream_t stream = 0) -{ - std::unique_ptr exp_h(new T[size]); - std::unique_ptr act_h(new T[size]); - raft::update_host(exp_h.get(), expected, size, stream); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < size; ++i) { - auto exp = exp_h.get()[i]; - auto act = act_h.get()[i]; - if (!eq_compare(exp, act)) { - return testing::AssertionFailure() << "actual=" << act << " != expected=" << exp << " @" << i; - } - } - return testing::AssertionSuccess(); -} - -template -testing::AssertionResult devArrMatch( - T expected, const T* actual, size_t size, L eq_compare, cudaStream_t stream = 0) -{ - std::unique_ptr act_h(new T[size]); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < size; ++i) { - auto act = act_h.get()[i]; - if (!eq_compare(expected, act)) { - return testing::AssertionFailure() - << "actual=" << act << " != expected=" << expected << " @" << i; - } - } - return testing::AssertionSuccess(); -} - -template -testing::AssertionResult devArrMatch(const T* expected, - const T* actual, - size_t rows, - size_t cols, - L eq_compare, - cudaStream_t stream = 0) -{ - size_t size = rows * cols; - std::unique_ptr exp_h(new T[size]); - std::unique_ptr act_h(new T[size]); - raft::update_host(exp_h.get(), expected, size, stream); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < rows; ++i) { - for (size_t j(0); j < cols; ++j) { - auto idx = i * cols + j; // row major assumption! - auto exp = exp_h.get()[idx]; - auto act = act_h.get()[idx]; - if (!eq_compare(exp, act)) { - return testing::AssertionFailure() - << "actual=" << act << " != expected=" << exp << " @" << i << "," << j; - } - } - } - return testing::AssertionSuccess(); -} - -template -testing::AssertionResult devArrMatch( - T expected, const T* actual, size_t rows, size_t cols, L eq_compare, cudaStream_t stream = 0) -{ - size_t size = rows * cols; - std::unique_ptr act_h(new T[size]); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < rows; ++i) { - for (size_t j(0); j < cols; ++j) { - auto idx = i * cols + j; // row major assumption! - auto act = act_h.get()[idx]; - if (!eq_compare(expected, act)) { - return testing::AssertionFailure() - << "actual=" << act << " != expected=" << expected << " @" << i << "," << j; - } - } - } - return testing::AssertionSuccess(); -} - -/* - * @brief Helper function to compare a device n-D arrays with an expected array - * on the host, using a custom comparison - * @tparam T the data type of the arrays - * @tparam L the comparator lambda or object function - * @param expected_h host array of expected value(s) - * @param actual_d device array actual values - * @param eq_compare the comparator - * @param stream cuda stream - * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE - */ -template -testing::AssertionResult devArrMatchHost( - const T* expected_h, const T* actual_d, size_t size, L eq_compare, cudaStream_t stream = 0) -{ - std::unique_ptr act_h(new T[size]); - raft::update_host(act_h.get(), actual_d, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - bool ok = true; - auto fail = testing::AssertionFailure(); - for (size_t i(0); i < size; ++i) { - auto exp = expected_h[i]; - auto act = act_h.get()[i]; - if (!eq_compare(exp, act)) { - ok = false; - fail << "actual=" << act << " != expected=" << exp << " @" << i << "; "; - } - } - if (!ok) return fail; - return testing::AssertionSuccess(); -} - -/** - * @brief Helper function to compare host vectors using a custom comparison - * @tparam T the element type - * @tparam L the comparator lambda or object function - * @param expected_h host vector of expected value(s) - * @param actual_h host vector actual values - * @param eq_compare the comparator - * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE - */ -template -testing::AssertionResult hostVecMatch(const std::vector& expected_h, - const std::vector& actual_h, - L eq_compare) -{ - auto n = actual_h.size(); - if (n != expected_h.size()) - return testing::AssertionFailure() - << "vector sizez mismatch: " - << "actual=" << n << " != expected=" << expected_h.size() << "; "; - for (size_t i = 0; i < n; ++i) { - auto exp = expected_h[i]; - auto act = actual_h[i]; - if (!eq_compare(exp, act)) { - return testing::AssertionFailure() - << "actual=" << act << " != expected=" << exp << " @" << i << "; "; - } - } - return testing::AssertionSuccess(); -} - -/* - * @brief Helper function to compare diagonal values of a 2D matrix - * @tparam T the data type of the arrays - * @tparam L the comparator lambda or object function - * @param expected expected value along diagonal - * @param actual actual matrix - * @param eq_compare the comparator - * @param stream cuda stream - * @return the testing assertion to be later used by ASSERT_TRUE/EXPECT_TRUE - */ -template -testing::AssertionResult diagonalMatch( - T expected, const T* actual, size_t rows, size_t cols, L eq_compare, cudaStream_t stream = 0) -{ - size_t size = rows * cols; - std::unique_ptr act_h(new T[size]); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < rows; ++i) { - for (size_t j(0); j < cols; ++j) { - if (i != j) continue; - auto idx = i * cols + j; // row major assumption! - auto act = act_h.get()[idx]; - if (!eq_compare(expected, act)) { - return testing::AssertionFailure() - << "actual=" << act << " != expected=" << expected << " @" << i << "," << j; - } - } - } - return testing::AssertionSuccess(); -} - template -testing::AssertionResult match(const T expected, T actual, L eq_compare) +testing::AssertionResult match(const T& expected, const T& actual, L eq_compare) { if (!eq_compare(expected, actual)) { return testing::AssertionFailure() << "actual=" << actual << " != expected=" << expected; @@ -309,103 +112,4 @@ testing::AssertionResult match(const T expected, T actual, L eq_compare) return testing::AssertionSuccess(); } -template -typename std::enable_if_t> gen_uniform(T* out, - raft::random::RngState& rng, - IdxT len, - cudaStream_t stream, - T range_min = T(-1), - T range_max = T(1)) -{ - raft::random::uniform(rng, out, len, range_min, range_max, stream); -} - -template -typename std::enable_if_t> gen_uniform(T* out, - raft::random::RngState& rng, - IdxT len, - cudaStream_t stream, - T range_min = T(0), - T range_max = T(100)) -{ - raft::random::uniformInt(rng, out, len, range_min, range_max, stream); -} - -template -void gen_uniform(raft::KeyValuePair* out, - raft::random::RngState& rng, - IdxT len, - cudaStream_t stream) -{ - rmm::device_uvector keys(len, stream); - rmm::device_uvector values(len, stream); - - gen_uniform(keys.data(), rng, len, stream); - gen_uniform(values.data(), rng, len, stream); - - const T1* d_keys = keys.data(); - const T2* d_values = values.data(); - auto counting = thrust::make_counting_iterator(0); - thrust::for_each(rmm::exec_policy(stream), - counting, - counting + len, - [out, d_keys, d_values] __device__(int idx) { - out[idx].key = d_keys[idx]; - out[idx].value = d_values[idx]; - }); -} - -/** @} */ - -/** time the function call 'func' using cuda events */ -#define TIMEIT_LOOP(ms, count, func) \ - do { \ - cudaEvent_t start, stop; \ - RAFT_CUDA_TRY(cudaEventCreate(&start)); \ - RAFT_CUDA_TRY(cudaEventCreate(&stop)); \ - RAFT_CUDA_TRY(cudaEventRecord(start)); \ - for (int i = 0; i < count; ++i) { \ - func; \ - } \ - RAFT_CUDA_TRY(cudaEventRecord(stop)); \ - RAFT_CUDA_TRY(cudaEventSynchronize(stop)); \ - ms = 0.f; \ - RAFT_CUDA_TRY(cudaEventElapsedTime(&ms, start, stop)); \ - ms /= args.runs; \ - } while (0) - -inline std::vector read_csv(std::string filename, bool skip_first_n_columns = 1) -{ - std::vector result; - std::ifstream myFile(filename); - if (!myFile.is_open()) throw std::runtime_error("Could not open file"); - - std::string line, colname; - int val; - - if (myFile.good()) { - std::getline(myFile, line); - std::stringstream ss(line); - while (std::getline(ss, colname, ',')) {} - } - - int n_lines = 0; - while (std::getline(myFile, line)) { - std::stringstream ss(line); - int colIdx = 0; - while (ss >> val) { - if (colIdx >= skip_first_n_columns) { - result.push_back(val); - if (ss.peek() == ',') ss.ignore(); - } - colIdx++; - } - n_lines++; - } - - printf("lines read: %d\n", n_lines); - myFile.close(); - return result; -} - }; // end namespace raft