diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 0a05682bcb..1d729d728b 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -94,6 +94,7 @@ if(BUILD_BENCH) bench/linalg/add.cu bench/linalg/map_then_reduce.cu bench/linalg/matrix_vector_op.cu + bench/linalg/norm.cu bench/linalg/normalize.cu bench/linalg/reduce_rows_by_key.cu bench/linalg/reduce.cu diff --git a/cpp/bench/linalg/norm.cu b/cpp/bench/linalg/norm.cu new file mode 100644 index 0000000000..cce4195cf1 --- /dev/null +++ b/cpp/bench/linalg/norm.cu @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +namespace raft::bench::linalg { + +template +struct norm_input { + IdxT rows, cols; +}; + +template +inline auto operator<<(std::ostream& os, const norm_input& p) -> std::ostream& +{ + os << p.rows << "#" << p.cols; + return os; +} + +template +struct rowNorm : public fixture { + rowNorm(const norm_input& p) : params(p), in(p.rows * p.cols, stream), dots(p.rows, stream) + { + raft::random::RngState rng{1234}; + raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream); + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + loop_on_state(state, [this]() { + auto input_view = raft::make_device_matrix_view( + in.data(), params.rows, params.cols); + auto output_view = + raft::make_device_vector_view(dots.data(), params.rows); + raft::linalg::norm(handle, + input_view, + output_view, + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::SqrtOp()); + }); + } + + private: + norm_input params; + rmm::device_uvector in, dots; +}; // struct rowNorm + +const std::vector> norm_inputs_i32 = + raft::util::itertools::product>({10, 100, 1000, 10000, 100000}, + {16, 32, 64, 128, 256, 512, 1024}); +const std::vector> norm_inputs_i64 = + raft::util::itertools::product>({10, 100, 1000, 10000, 100000}, + {16, 32, 64, 128, 256, 512, 1024}); + +RAFT_BENCH_REGISTER((rowNorm), "", norm_inputs_i32); +RAFT_BENCH_REGISTER((rowNorm), "", norm_inputs_i32); +RAFT_BENCH_REGISTER((rowNorm), "", norm_inputs_i64); +RAFT_BENCH_REGISTER((rowNorm), "", norm_inputs_i64); + +} // namespace raft::bench::linalg diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index 6ef0d52e62..e9e5a99f46 100644 --- a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ -112,21 +112,21 @@ void coalescedReduction(OutType* dots, template , + typename IdxType, + typename MainLambda = raft::Nop, typename ReduceLambda = raft::Sum, typename FinalLambda = raft::Nop> void coalesced_reduction(const raft::handle_t& handle, - raft::device_matrix_view data, - raft::device_vector_view dots, + raft::device_matrix_view data, + raft::device_vector_view dots, OutValueType init, bool inplace = false, - MainLambda main_op = raft::Nop(), + MainLambda main_op = raft::Nop(), ReduceLambda reduce_op = raft::Sum(), FinalLambda final_op = raft::Nop()) { if constexpr (std::is_same_v) { - RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), + RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), "Output should be equal to number of rows in Input"); coalescedReduction(dots.data_handle(), @@ -140,7 +140,7 @@ void coalesced_reduction(const raft::handle_t& handle, reduce_op, final_op); } else if constexpr (std::is_same_v) { - RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), + RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), "Output should be equal to number of columns in Input"); coalescedReduction(dots.data_handle(), diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh index cf1b8cf5a5..63351f5475 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh @@ -17,34 +17,136 @@ #pragma once #include +#include #include +#include namespace raft { namespace linalg { namespace detail { -// Kernel (based on norm.cuh) to perform reductions along the coalesced dimension -// of the matrix, i.e. reduce along rows for row major or reduce along columns -// for column major layout. Kernel does an inplace reduction adding to original -// values of dots. +template +struct ReductionThinPolicy { + static constexpr int LogicalWarpSize = warpSize; + static constexpr int RowsPerBlock = rpb; + static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock; +}; + +template +__global__ void __launch_bounds__(Policy::ThreadsPerBlock) + coalescedReductionThinKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda final_op, + bool inplace = false) +{ + IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast(blockIdx.x)); + if (i >= N) return; + + OutType acc = init; + for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { + acc = reduce_op(acc, main_op(data[j + (D * i)], j)); + } + acc = raft::logicalWarpReduce(acc, reduce_op); + if (threadIdx.x == 0) { + if (inplace) { + dots[i] = final_op(reduce_op(dots[i], acc)); + } else { + dots[i] = final_op(acc); + } + } +} + +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void coalescedReductionThin(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda final_op = raft::Nop()) +{ + common::nvtx::range fun_scope( + "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); + dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1); + dim3 blocks(ceildiv(N, Policy::RowsPerBlock), 1, 1); + coalescedReductionThinKernel + <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void coalescedReductionThinDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda final_op = raft::Nop()) +{ + if (D <= IdxType(2)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(4)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(8)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(16)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } +} + +template -__global__ void coalescedReductionKernel(OutType* dots, - const InType* data, - int D, - int N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op, - FinalLambda final_op, - bool inplace = false) +__global__ void __launch_bounds__(TPB) coalescedReductionMediumKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda final_op, + bool inplace = false) { - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; OutType thread_data = init; IdxType rowStart = blockIdx.x * D; @@ -62,6 +164,169 @@ __global__ void coalescedReductionKernel(OutType* dots, } } +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void coalescedReductionMedium(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda final_op = raft::Nop()) +{ + common::nvtx::range fun_scope("coalescedReductionMedium<%d>", TPB); + coalescedReductionMediumKernel + <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void coalescedReductionMediumDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda final_op = raft::Nop()) +{ + // Note: for now, this kernel is only used when D > 256. If this changes in the future, use + // smaller block sizes when relevant. + coalescedReductionMedium<256>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); +} + +template +struct ReductionThickPolicy { + static constexpr int ThreadsPerBlock = tpb; + static constexpr int BlocksPerRow = bpr; + static constexpr int BlockStride = tpb * bpr; +}; + +template +__global__ void __launch_bounds__(Policy::ThreadsPerBlock) + coalescedReductionThickKernel(OutType* buffer, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + OutType thread_data = init; + IdxType rowStart = blockIdx.x * D; + for (IdxType i = blockIdx.y * Policy::ThreadsPerBlock + threadIdx.x; i < D; + i += Policy::BlockStride) { + IdxType idx = rowStart + i; + thread_data = reduce_op(thread_data, main_op(data[idx], i)); + } + OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); + if (threadIdx.x == 0) { buffer[Policy::BlocksPerRow * blockIdx.x + blockIdx.y] = acc; } +} + +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void coalescedReductionThick(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda final_op = raft::Nop()) +{ + common::nvtx::range fun_scope( + "coalescedReductionThick<%d,%d>", ThickPolicy::ThreadsPerBlock, ThickPolicy::BlocksPerRow); + + dim3 threads(ThickPolicy::ThreadsPerBlock, 1, 1); + dim3 blocks(N, ThickPolicy::BlocksPerRow, 1); + + rmm::device_uvector buffer(N * ThickPolicy::BlocksPerRow, stream); + + /* We apply a two-step reduction: + * 1. coalescedReductionThickKernel reduces the [N x D] input data to [N x BlocksPerRow]. It + * applies the main_op but not the final op. + * 2. coalescedReductionThinKernel reduces [N x BlocksPerRow] to [N x 1]. It doesn't apply any + * main_op but applies final_op. If in-place, the existing and new values are reduced. + */ + + coalescedReductionThickKernel + <<>>(buffer.data(), data, D, N, init, main_op, reduce_op); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + coalescedReductionThin(dots, + buffer.data(), + static_cast(ThickPolicy::BlocksPerRow), + N, + init, + stream, + inplace, + raft::Nop(), + reduce_op, + final_op); +} + +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void coalescedReductionThickDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda final_op = raft::Nop()) +{ + // Note: multiple elements per thread to take advantage of the sequential reduction and loop + // unrolling + if (D < IdxType(32768)) { + coalescedReductionThick, ReductionThinPolicy<32, 4>>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else { + coalescedReductionThick, ReductionThinPolicy<32, 4>>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } +} + +// Primitive to perform reductions along the coalesced dimension of the matrix, i.e. reduce along +// rows for row major or reduce along columns for column major layout. Can do an inplace reduction +// adding to original values of dots if requested. template > void coalescedReduction(OutType* dots, const InType* data, - int D, - int N, + IdxType D, + IdxType N, OutType init, cudaStream_t stream, bool inplace = false, @@ -79,22 +344,22 @@ void coalescedReduction(OutType* dots, ReduceLambda reduce_op = raft::Sum(), FinalLambda final_op = raft::Nop()) { - // One block per reduction - // Efficient only for large leading dimensions - if (D <= 32) { - coalescedReductionKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); - } else if (D <= 64) { - coalescedReductionKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); - } else if (D <= 128) { - coalescedReductionKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + /* The primitive selects one of three implementations based on heuristics: + * - Thin: very efficient when D is small and/or N is large + * - Thick: used when N is very small and D very large + * - Medium: used when N is too small to fill the GPU with the thin kernel + */ + const IdxType numSMs = raft::getMultiProcessorCount(); + if (D <= IdxType(256) || N >= IdxType(4) * numSMs) { + coalescedReductionThinDispatcher( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (N < numSMs && D >= IdxType(16384)) { + coalescedReductionThickDispatcher( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else { - coalescedReductionKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + coalescedReductionMediumDispatcher( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } - RAFT_CUDA_TRY(cudaPeekAtLastError()); } } // namespace detail diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index b756744755..9abfd3bdb0 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -33,8 +33,7 @@ namespace linalg { * * Row-wise norm is useful while computing pairwise distance matrix, for * example. - * This is used in many clustering algos like knn, kmeans, dbscan, etc... The - * current implementation is optimized only for bigger values of 'D'. + * This is used in many clustering algos like knn, kmeans, dbscan, etc... * * @tparam Type the data type * @tparam Lambda device final lambda diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index 9b3f4ee347..5579acf355 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -117,17 +117,17 @@ void reduce(OutType* dots, template , + typename IdxType = std::uint32_t, + typename MainLambda = raft::Nop, typename ReduceLambda = raft::Sum, typename FinalLambda = raft::Nop> void reduce(const raft::handle_t& handle, - raft::device_matrix_view data, - raft::device_vector_view dots, + raft::device_matrix_view data, + raft::device_vector_view dots, OutElementType init, Apply apply, bool inplace = false, - MainLambda main_op = raft::Nop(), + MainLambda main_op = raft::Nop(), ReduceLambda reduce_op = raft::Sum(), FinalLambda final_op = raft::Nop()) { @@ -137,10 +137,10 @@ void reduce(const raft::handle_t& handle, bool along_rows = apply == Apply::ALONG_ROWS; if (along_rows) { - RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), + RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), "Output should be equal to number of columns in Input"); } else { - RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), + RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), "Output should be equal to number of rows in Input"); } diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index 9147692c03..0aa4aecef5 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -24,6 +24,8 @@ #include #include +#include + namespace raft { namespace linalg { @@ -71,8 +73,16 @@ void stridedReduction(OutType* dots, ReduceLambda reduce_op = raft::Sum(), FinalLambda final_op = raft::Nop()) { - detail::stridedReduction( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + // Only compile for types supported by myAtomicReduce, but don't make the compilation fail in + // other cases, because coalescedReduction supports arbitrary types. + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + detail::stridedReduction( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else { + THROW("Unsupported type for stridedReduction: %s", typeid(OutType).name()); + } } /** diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index ec03476252..b721915187 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -204,54 +204,6 @@ inline void memzero(T* ptr, IdxT n_elems, rmm::cuda_stream_view stream) } } -template -__global__ void dots_along_rows_kernel(IdxT n_rows, IdxT n_cols, const float* a, float* out) -{ - IdxT i = threadIdx.y + (blockDim.y * static_cast(blockIdx.x)); - if (i >= n_rows) return; - - float sqsum = 0.0; - for (IdxT j = threadIdx.x; j < n_cols; j += blockDim.x) { - float val = a[j + (n_cols * i)]; - sqsum += val * val; - } - sqsum += __shfl_xor_sync(0xffffffff, sqsum, 1); - sqsum += __shfl_xor_sync(0xffffffff, sqsum, 2); - sqsum += __shfl_xor_sync(0xffffffff, sqsum, 4); - sqsum += __shfl_xor_sync(0xffffffff, sqsum, 8); - sqsum += __shfl_xor_sync(0xffffffff, sqsum, 16); - if (threadIdx.x == 0) { out[i] = sqsum; } -} - -/** - * @brief Square sum of values in each row (row-major matrix). - * - * NB: device-only function - * - * @tparam IdxT index type - * - * @param n_rows - * @param n_cols - * @param[in] a device pointer to the row-major matrix [n_rows, n_cols] - * @param[out] out device pointer to the vector of dot-products [n_rows] - * @param stream - */ -template -inline void dots_along_rows( - IdxT n_rows, IdxT n_cols, const float* a, float* out, rmm::cuda_stream_view stream) -{ - dim3 threads(32, 4, 1); - dim3 blocks(ceildiv(n_rows, threads.y), 1, 1); - dots_along_rows_kernel<<>>(n_rows, n_cols, a, out); - /** - * TODO: this can be replaced with the rowNorm helper as shown below. - * However, the rowNorm helper seems to incur a significant performance penalty - * (example case ann-search slowed down from 150ms to 186ms). - * - * raft::linalg::rowNorm(out, a, n_cols, n_rows, raft::linalg::L2Norm, true, stream); - */ -} - template __global__ void outer_add_kernel(const T* a, IdxT len_a, const T* b, IdxT len_b, T* c) { diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index 82d498a789..e9af97b547 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -231,12 +232,14 @@ inline auto extend(const handle_t& handle, orig_index.center_norms()->size(), stream); } else { - // todo(lsugy): use other prim and remove this one - utils::dots_along_rows(n_lists, - dim, - ext_index.centers().data_handle(), - ext_index.center_norms()->data_handle(), - stream); + raft::linalg::rowNorm(ext_index.center_norms()->data_handle(), + ext_index.centers().data_handle(), + dim, + n_lists, + raft::linalg::L2Norm, + true, + stream, + raft::SqrtOp()); RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min(dim, 20)); } } diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 8c6bde06de..f6b9e62008 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -1102,8 +1103,14 @@ void search_impl(const handle_t& handle, if (index.metric() == raft::distance::DistanceType::L2Expanded) { alpha = -2.0f; beta = 1.0f; - utils::dots_along_rows( - n_queries, index.dim(), converted_queries_ptr, query_norm_dev.data(), stream); + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(index.dim()), + static_cast(n_queries), + raft::linalg::L2Norm, + true, + stream, + raft::SqrtOp()); utils::outer_add(query_norm_dev.data(), (IdxT)n_queries, index.center_norms()->data_handle(), diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh index de7fa1b2f7..9262ef6baf 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -1178,8 +1179,14 @@ inline auto build_device( stream)); rmm::device_uvector center_norms(index.n_lists(), stream, device_memory); - utils::dots_along_rows( - index.n_lists(), index.dim(), cluster_centers, center_norms.data(), stream); + raft::linalg::rowNorm(center_norms.data(), + cluster_centers, + index.dim(), + index.n_lists(), + raft::linalg::L2Norm, + true, + stream, + raft::SqrtOp()); RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle() + index.dim(), sizeof(float) * index.dim_ext(), center_norms.data(), diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index a64fbdb1be..5818fc21f3 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -18,6 +18,7 @@ #include #include +#include #include @@ -636,9 +637,42 @@ DI uint32_t ballot(bool inFlag, uint32_t mask = 0xffffffffu) #endif } +/** True CUDA alignment of a type (adapted from CUB) */ +template +struct cuda_alignment { + struct Pad { + T val; + char byte; + }; + + static constexpr int bytes = sizeof(Pad) - sizeof(T); +}; + +template +struct is_multiple { + static constexpr int large_align_bytes = cuda_alignment::bytes; + static constexpr int unit_align_bytes = cuda_alignment::bytes; + static constexpr bool value = + (sizeof(LargeT) % sizeof(UnitT) == 0) && (large_align_bytes % unit_align_bytes == 0); +}; + +template +inline constexpr bool is_multiple_v = is_multiple::value; + +template +struct is_shuffleable { + static constexpr bool value = + std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v; +}; + +template +inline constexpr bool is_shuffleable_v = is_shuffleable::value; + /** * @brief Shuffle the data inside a warp - * @tparam T the data type (currently assumed to be 4B) + * @tparam T the data type * @param val value to be shuffled * @param srcLane lane from where to shuffle * @param width lane width @@ -646,7 +680,10 @@ DI uint32_t ballot(bool inFlag, uint32_t mask = 0xffffffffu) * @return the shuffled data */ template -DI T shfl(T val, int srcLane, int width = WarpSize, uint32_t mask = 0xffffffffu) +DI std::enable_if_t, T> shfl(T val, + int srcLane, + int width = WarpSize, + uint32_t mask = 0xffffffffu) { #if CUDART_VERSION >= 9000 return __shfl_sync(mask, val, srcLane, width); @@ -655,9 +692,40 @@ DI T shfl(T val, int srcLane, int width = WarpSize, uint32_t mask = 0xffffffffu) #endif } +/// Overload of shfl for data types not supported by the CUDA intrinsics +template +DI std::enable_if_t, T> shfl(T val, + int srcLane, + int width = WarpSize, + uint32_t mask = 0xffffffffu) +{ + using UnitT = + std::conditional_t, + unsigned int, + std::conditional_t, unsigned short, unsigned char>>; + + constexpr int n_words = sizeof(T) / sizeof(UnitT); + + T output; + UnitT* output_alias = reinterpret_cast(&output); + UnitT* input_alias = reinterpret_cast(&val); + + unsigned int shuffle_word; + shuffle_word = shfl((unsigned int)input_alias[0], srcLane, width, mask); + output_alias[0] = shuffle_word; + +#pragma unroll + for (int i = 1; i < n_words; ++i) { + shuffle_word = shfl((unsigned int)input_alias[i], srcLane, width, mask); + output_alias[i] = shuffle_word; + } + + return output; +} + /** * @brief Shuffle the data inside a warp from lower lane IDs - * @tparam T the data type (currently assumed to be 4B) + * @tparam T the data type * @param val value to be shuffled * @param delta lower lane ID delta from where to shuffle * @param width lane width @@ -665,7 +733,10 @@ DI T shfl(T val, int srcLane, int width = WarpSize, uint32_t mask = 0xffffffffu) * @return the shuffled data */ template -DI T shfl_up(T val, int delta, int width = WarpSize, uint32_t mask = 0xffffffffu) +DI std::enable_if_t, T> shfl_up(T val, + int delta, + int width = WarpSize, + uint32_t mask = 0xffffffffu) { #if CUDART_VERSION >= 9000 return __shfl_up_sync(mask, val, delta, width); @@ -674,9 +745,40 @@ DI T shfl_up(T val, int delta, int width = WarpSize, uint32_t mask = 0xffffffffu #endif } +/// Overload of shfl_up for data types not supported by the CUDA intrinsics +template +DI std::enable_if_t, T> shfl_up(T val, + int delta, + int width = WarpSize, + uint32_t mask = 0xffffffffu) +{ + using UnitT = + std::conditional_t, + unsigned int, + std::conditional_t, unsigned short, unsigned char>>; + + constexpr int n_words = sizeof(T) / sizeof(UnitT); + + T output; + UnitT* output_alias = reinterpret_cast(&output); + UnitT* input_alias = reinterpret_cast(&val); + + unsigned int shuffle_word; + shuffle_word = shfl_up((unsigned int)input_alias[0], delta, width, mask); + output_alias[0] = shuffle_word; + +#pragma unroll + for (int i = 1; i < n_words; ++i) { + shuffle_word = shfl_up((unsigned int)input_alias[i], delta, width, mask); + output_alias[i] = shuffle_word; + } + + return output; +} + /** * @brief Shuffle the data inside a warp - * @tparam T the data type (currently assumed to be 4B) + * @tparam T the data type * @param val value to be shuffled * @param laneMask mask to be applied in order to perform xor shuffle * @param width lane width @@ -684,7 +786,10 @@ DI T shfl_up(T val, int delta, int width = WarpSize, uint32_t mask = 0xffffffffu * @return the shuffled data */ template -DI T shfl_xor(T val, int laneMask, int width = WarpSize, uint32_t mask = 0xffffffffu) +DI std::enable_if_t, T> shfl_xor(T val, + int laneMask, + int width = WarpSize, + uint32_t mask = 0xffffffffu) { #if CUDART_VERSION >= 9000 return __shfl_xor_sync(mask, val, laneMask, width); @@ -693,6 +798,37 @@ DI T shfl_xor(T val, int laneMask, int width = WarpSize, uint32_t mask = 0xfffff #endif } +/// Overload of shfl_xor for data types not supported by the CUDA intrinsics +template +DI std::enable_if_t, T> shfl_xor(T val, + int laneMask, + int width = WarpSize, + uint32_t mask = 0xffffffffu) +{ + using UnitT = + std::conditional_t, + unsigned int, + std::conditional_t, unsigned short, unsigned char>>; + + constexpr int n_words = sizeof(T) / sizeof(UnitT); + + T output; + UnitT* output_alias = reinterpret_cast(&output); + UnitT* input_alias = reinterpret_cast(&val); + + unsigned int shuffle_word; + shuffle_word = shfl_xor((unsigned int)input_alias[0], laneMask, width, mask); + output_alias[0] = shuffle_word; + +#pragma unroll + for (int i = 1; i < n_words; ++i) { + shuffle_word = shfl_xor((unsigned int)input_alias[i], laneMask, width, mask); + output_alias[i] = shuffle_word; + } + + return output; +} + /** * @brief Four-way byte dot product-accumulate. * @tparam T Four-byte integer: int or unsigned int @@ -775,19 +911,35 @@ DI T logicalWarpReduce(T val, ReduceLambda reduce_op) } /** - * @brief Warp-level sum reduction + * @brief Warp-level reduction + * @tparam T Value type to be reduced + * @tparam ReduceLambda Reduction operation type * @param val input value + * @param reduce_op Reduction operation + * @return Reduction result. All lanes will have the valid result. + * @note Why not cub? Because cub doesn't seem to allow working with arbitrary + * number of warps in a block. All threads in the warp must enter this + * function together + */ +template +DI T warpReduce(T val, ReduceLambda reduce_op) +{ + return logicalWarpReduce(val, reduce_op); +} + +/** + * @brief Warp-level sum reduction * @tparam T Value type to be reduced + * @param val input value * @return Reduction result. All lanes will have the valid result. * @note Why not cub? Because cub doesn't seem to allow working with arbitrary * number of warps in a block. All threads in the warp must enter this * function together - * @todo Expand this to support arbitrary reduction ops */ template DI T warpReduce(T val) { - return logicalWarpReduce(val, raft::Sum()); + return warpReduce(val, raft::Sum{}); } /** diff --git a/cpp/test/linalg/coalesced_reduction.cu b/cpp/test/linalg/coalesced_reduction.cu index cc2acef565..791537b430 100644 --- a/cpp/test/linalg/coalesced_reduction.cu +++ b/cpp/test/linalg/coalesced_reduction.cu @@ -70,11 +70,31 @@ class coalescedReductionTest : public ::testing::TestWithParam{}, + raft::Sum{}, + raft::Nop{}); + naiveCoalescedReduction(dots_exp.data(), + data.data(), + cols, + rows, + stream, + T(0), + true, + raft::L2Op{}, + raft::Sum{}, + raft::Nop{}); + coalescedReductionLaunch(handle, dots_act.data(), data.data(), cols, rows); - // Add to result with inplace = true next coalescedReductionLaunch(handle, dots_act.data(), data.data(), cols, rows, true); handle.sync_stream(stream); diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index 5243f2435f..f0b8d3bb55 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -19,22 +19,23 @@ #include #include #include +#include namespace raft { namespace linalg { -template +template struct NormInputs { T tolerance; - int rows, cols; + IdxT rows, cols; NormType type; bool do_sqrt; bool rowMajor; unsigned long long int seed; }; -template -::std::ostream& operator<<(::std::ostream& os, const NormInputs& I) +template +::std::ostream& operator<<(::std::ostream& os, const NormInputs& I) { os << "{ " << I.tolerance << ", " << I.rows << ", " << I.cols << ", " << I.type << ", " << I.do_sqrt << ", " << I.seed << '}' << std::endl; @@ -42,14 +43,14 @@ template } ///// Row-wise norm test definitions -template +template __global__ void naiveRowNormKernel( - Type* dots, const Type* data, int D, int N, NormType type, bool do_sqrt) + Type* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt) { - Type acc = (Type)0; - int rowStart = threadIdx.x + blockIdx.x * blockDim.x; + Type acc = (Type)0; + IdxT rowStart = threadIdx.x + static_cast(blockIdx.x) * blockDim.x; if (rowStart < N) { - for (int i = 0; i < D; ++i) { + for (IdxT i = 0; i < D; ++i) { if (type == L2Norm) { acc += data[rowStart * D + i] * data[rowStart * D + i]; } else { @@ -60,21 +61,21 @@ __global__ void naiveRowNormKernel( } } -template +template void naiveRowNorm( - Type* dots, const Type* data, int D, int N, NormType type, bool do_sqrt, cudaStream_t stream) + Type* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt, cudaStream_t stream) { - static const int TPB = 64; - int nblks = raft::ceildiv(N, TPB); + static const IdxT TPB = 64; + IdxT nblks = raft::ceildiv(N, TPB); naiveRowNormKernel<<>>(dots, data, D, N, type, do_sqrt); RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template -class RowNormTest : public ::testing::TestWithParam> { +template +class RowNormTest : public ::testing::TestWithParam> { public: RowNormTest() - : params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(handle.get_stream()), data(params.rows * params.cols, stream), dots_exp(params.rows, stream), @@ -85,13 +86,13 @@ class RowNormTest : public ::testing::TestWithParam> { void SetUp() override { raft::random::RngState r(params.seed); - int rows = params.rows, cols = params.cols, len = rows * cols; + IdxT rows = params.rows, cols = params.cols, len = rows * cols; uniform(handle, r, data.data(), len, T(-1.0), T(1.0)); naiveRowNorm(dots_exp.data(), data.data(), cols, rows, params.type, params.do_sqrt, stream); - auto output_view = raft::make_device_vector_view(dots_act.data(), params.rows); - auto input_row_major = raft::make_device_matrix_view( + auto output_view = raft::make_device_vector_view(dots_act.data(), params.rows); + auto input_row_major = raft::make_device_matrix_view( data.data(), params.rows, params.cols); - auto input_col_major = raft::make_device_matrix_view( + auto input_col_major = raft::make_device_matrix_view( data.data(), params.rows, params.cols); if (params.do_sqrt) { auto fin_op = [] __device__(const T in) { return raft::mySqrt(in); }; @@ -114,20 +115,20 @@ class RowNormTest : public ::testing::TestWithParam> { raft::handle_t handle; cudaStream_t stream; - NormInputs params; + NormInputs params; rmm::device_uvector data, dots_exp, dots_act; }; ///// Column-wise norm test definitisons -template +template __global__ void naiveColNormKernel( - Type* dots, const Type* data, int D, int N, NormType type, bool do_sqrt) + Type* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt) { - int colID = threadIdx.x + blockIdx.x * blockDim.x; - if (colID > D) return; // avoid out-of-bounds thread + IdxT colID = threadIdx.x + static_cast(blockIdx.x) * blockDim.x; + if (colID >= D) return; // avoid out-of-bounds thread Type acc = 0; - for (int i = 0; i < N; i++) { + for (IdxT i = 0; i < N; i++) { Type v = data[colID + i * D]; acc += type == L2Norm ? v * v : raft::myAbs(v); } @@ -135,21 +136,21 @@ __global__ void naiveColNormKernel( dots[colID] = do_sqrt ? raft::mySqrt(acc) : acc; } -template +template void naiveColNorm( - Type* dots, const Type* data, int D, int N, NormType type, bool do_sqrt, cudaStream_t stream) + Type* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt, cudaStream_t stream) { - static const int TPB = 64; - int nblks = raft::ceildiv(D, TPB); + static const IdxT TPB = 64; + IdxT nblks = raft::ceildiv(D, TPB); naiveColNormKernel<<>>(dots, data, D, N, type, do_sqrt); RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template -class ColNormTest : public ::testing::TestWithParam> { +template +class ColNormTest : public ::testing::TestWithParam> { public: ColNormTest() - : params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(handle.get_stream()), data(params.rows * params.cols, stream), dots_exp(params.cols, stream), @@ -160,14 +161,14 @@ class ColNormTest : public ::testing::TestWithParam> { void SetUp() override { raft::random::RngState r(params.seed); - int rows = params.rows, cols = params.cols, len = rows * cols; + IdxT rows = params.rows, cols = params.cols, len = rows * cols; uniform(handle, r, data.data(), len, T(-1.0), T(1.0)); naiveColNorm(dots_exp.data(), data.data(), cols, rows, params.type, params.do_sqrt, stream); - auto output_view = raft::make_device_vector_view(dots_act.data(), params.cols); - auto input_row_major = raft::make_device_matrix_view( + auto output_view = raft::make_device_vector_view(dots_act.data(), params.cols); + auto input_row_major = raft::make_device_matrix_view( data.data(), params.rows, params.cols); - auto input_col_major = raft::make_device_matrix_view( + auto input_col_major = raft::make_device_matrix_view( data.data(), params.rows, params.cols); if (params.do_sqrt) { auto fin_op = [] __device__(const T in) { return raft::mySqrt(in); }; @@ -190,121 +191,81 @@ class ColNormTest : public ::testing::TestWithParam> { raft::handle_t handle; cudaStream_t stream; - NormInputs params; + NormInputs params; rmm::device_uvector data, dots_exp, dots_act; }; ///// Row- and column-wise tests -const std::vector> inputsf = {{0.00001f, 1024, 32, L1Norm, false, true, 1234ULL}, - {0.00001f, 1024, 64, L1Norm, false, true, 1234ULL}, - {0.00001f, 1024, 128, L1Norm, false, true, 1234ULL}, - {0.00001f, 1024, 256, L1Norm, false, true, 1234ULL}, - {0.00001f, 1024, 32, L2Norm, false, true, 1234ULL}, - {0.00001f, 1024, 64, L2Norm, false, true, 1234ULL}, - {0.00001f, 1024, 128, L2Norm, false, true, 1234ULL}, - {0.00001f, 1024, 256, L2Norm, false, true, 1234ULL}, - - {0.00001f, 1024, 32, L1Norm, true, true, 1234ULL}, - {0.00001f, 1024, 64, L1Norm, true, true, 1234ULL}, - {0.00001f, 1024, 128, L1Norm, true, true, 1234ULL}, - {0.00001f, 1024, 256, L1Norm, true, true, 1234ULL}, - {0.00001f, 1024, 32, L2Norm, true, true, 1234ULL}, - {0.00001f, 1024, 64, L2Norm, true, true, 1234ULL}, - {0.00001f, 1024, 128, L2Norm, true, true, 1234ULL}, - {0.00001f, 1024, 256, L2Norm, true, true, 1234ULL}}; - -const std::vector> inputsd = { - {0.00000001, 1024, 32, L1Norm, false, true, 1234ULL}, - {0.00000001, 1024, 64, L1Norm, false, true, 1234ULL}, - {0.00000001, 1024, 128, L1Norm, false, true, 1234ULL}, - {0.00000001, 1024, 256, L1Norm, false, true, 1234ULL}, - {0.00000001, 1024, 32, L2Norm, false, true, 1234ULL}, - {0.00000001, 1024, 64, L2Norm, false, true, 1234ULL}, - {0.00000001, 1024, 128, L2Norm, false, true, 1234ULL}, - {0.00000001, 1024, 256, L2Norm, false, true, 1234ULL}, - - {0.00000001, 1024, 32, L1Norm, true, true, 1234ULL}, - {0.00000001, 1024, 64, L1Norm, true, true, 1234ULL}, - {0.00000001, 1024, 128, L1Norm, true, true, 1234ULL}, - {0.00000001, 1024, 256, L1Norm, true, true, 1234ULL}, - {0.00000001, 1024, 32, L2Norm, true, true, 1234ULL}, - {0.00000001, 1024, 64, L2Norm, true, true, 1234ULL}, - {0.00000001, 1024, 128, L2Norm, true, true, 1234ULL}, - {0.00000001, 1024, 256, L2Norm, true, true, 1234ULL}}; - -typedef RowNormTest RowNormTestF; -TEST_P(RowNormTestF, Result) -{ - ASSERT_TRUE(raft::devArrMatch( - dots_exp.data(), dots_act.data(), params.rows, raft::CompareApprox(params.tolerance))); -} - -typedef RowNormTest RowNormTestD; -TEST_P(RowNormTestD, Result) -{ - ASSERT_TRUE(raft::devArrMatch( - dots_exp.data(), dots_act.data(), params.rows, raft::CompareApprox(params.tolerance))); -} - -INSTANTIATE_TEST_CASE_P(RowNormTests, RowNormTestF, ::testing::ValuesIn(inputsf)); - -INSTANTIATE_TEST_CASE_P(RowNormTests, RowNormTestD, ::testing::ValuesIn(inputsd)); - -const std::vector> inputscf = { - {0.00001f, 32, 1024, L1Norm, false, true, 1234ULL}, - {0.00001f, 64, 1024, L1Norm, false, true, 1234ULL}, - {0.00001f, 128, 1024, L1Norm, false, true, 1234ULL}, - {0.00001f, 256, 1024, L1Norm, false, true, 1234ULL}, - {0.00001f, 32, 1024, L2Norm, false, true, 1234ULL}, - {0.00001f, 64, 1024, L2Norm, false, true, 1234ULL}, - {0.00001f, 128, 1024, L2Norm, false, true, 1234ULL}, - {0.00001f, 256, 1024, L2Norm, false, true, 1234ULL}, - - {0.00001f, 32, 1024, L1Norm, true, true, 1234ULL}, - {0.00001f, 64, 1024, L1Norm, true, true, 1234ULL}, - {0.00001f, 128, 1024, L1Norm, true, true, 1234ULL}, - {0.00001f, 256, 1024, L1Norm, true, true, 1234ULL}, - {0.00001f, 32, 1024, L2Norm, true, true, 1234ULL}, - {0.00001f, 64, 1024, L2Norm, true, true, 1234ULL}, - {0.00001f, 128, 1024, L2Norm, true, true, 1234ULL}, - {0.00001f, 256, 1024, L2Norm, true, true, 1234ULL}}; - -const std::vector> inputscd = { - {0.00000001, 32, 1024, L1Norm, false, true, 1234ULL}, - {0.00000001, 64, 1024, L1Norm, false, true, 1234ULL}, - {0.00000001, 128, 1024, L1Norm, false, true, 1234ULL}, - {0.00000001, 256, 1024, L1Norm, false, true, 1234ULL}, - {0.00000001, 32, 1024, L2Norm, false, true, 1234ULL}, - {0.00000001, 64, 1024, L2Norm, false, true, 1234ULL}, - {0.00000001, 128, 1024, L2Norm, false, true, 1234ULL}, - {0.00000001, 256, 1024, L2Norm, false, true, 1234ULL}, - - {0.00000001, 32, 1024, L1Norm, true, true, 1234ULL}, - {0.00000001, 64, 1024, L1Norm, true, true, 1234ULL}, - {0.00000001, 128, 1024, L1Norm, true, true, 1234ULL}, - {0.00000001, 256, 1024, L1Norm, true, true, 1234ULL}, - {0.00000001, 32, 1024, L2Norm, true, true, 1234ULL}, - {0.00000001, 64, 1024, L2Norm, true, true, 1234ULL}, - {0.00000001, 128, 1024, L2Norm, true, true, 1234ULL}, - {0.00000001, 256, 1024, L2Norm, true, true, 1234ULL}}; - -typedef ColNormTest ColNormTestF; -TEST_P(ColNormTestF, Result) -{ - ASSERT_TRUE(raft::devArrMatch( - dots_exp.data(), dots_act.data(), params.cols, raft::CompareApprox(params.tolerance))); -} - -typedef ColNormTest ColNormTestD; -TEST_P(ColNormTestD, Result) -{ - ASSERT_TRUE(raft::devArrMatch( - dots_exp.data(), dots_act.data(), params.cols, raft::CompareApprox(params.tolerance))); -} - -INSTANTIATE_TEST_CASE_P(ColNormTests, ColNormTestF, ::testing::ValuesIn(inputscf)); - -INSTANTIATE_TEST_CASE_P(ColNormTests, ColNormTestD, ::testing::ValuesIn(inputscd)); +const std::vector> inputsf_i32 = + raft::util::itertools::product>( + {0.00001f}, {11, 1234}, {7, 33, 128, 500}, {L1Norm, L2Norm}, {false, true}, {true}, {1234ULL}); +const std::vector> inputsd_i32 = + raft::util::itertools::product>({0.00000001}, + {11, 1234}, + {7, 33, 128, 500}, + {L1Norm, L2Norm}, + {false, true}, + {true}, + {1234ULL}); +const std::vector> inputsf_i64 = + raft::util::itertools::product>( + {0.00001f}, {11, 1234}, {7, 33, 128, 500}, {L1Norm, L2Norm}, {false, true}, {true}, {1234ULL}); +const std::vector> inputsd_i64 = + raft::util::itertools::product>({0.00000001}, + {11, 1234}, + {7, 33, 128, 500}, + {L1Norm, L2Norm}, + {false, true}, + {true}, + {1234ULL}); +const std::vector> inputscf_i32 = + raft::util::itertools::product>( + {0.00001f}, {7, 33, 128, 500}, {11, 1234}, {L1Norm, L2Norm}, {false, true}, {true}, {1234ULL}); +const std::vector> inputscd_i32 = + raft::util::itertools::product>({0.00000001}, + {7, 33, 128, 500}, + {11, 1234}, + {L1Norm, L2Norm}, + {false, true}, + {true}, + {1234ULL}); +const std::vector> inputscf_i64 = + raft::util::itertools::product>( + {0.00001f}, {7, 33, 128, 500}, {11, 1234}, {L1Norm, L2Norm}, {false, true}, {true}, {1234ULL}); +const std::vector> inputscd_i64 = + raft::util::itertools::product>({0.00000001}, + {7, 33, 128, 500}, + {11, 1234}, + {L1Norm, L2Norm}, + {false, true}, + {true}, + {1234ULL}); + +typedef RowNormTest RowNormTestF_i32; +typedef RowNormTest RowNormTestD_i32; +typedef RowNormTest RowNormTestF_i64; +typedef RowNormTest RowNormTestD_i64; +typedef ColNormTest ColNormTestF_i32; +typedef ColNormTest ColNormTestD_i32; +typedef ColNormTest ColNormTestF_i64; +typedef ColNormTest ColNormTestD_i64; + +#define ROWNORM_TEST(test_type, test_inputs) \ + TEST_P(test_type, Result) \ + { \ + ASSERT_TRUE(raft::devArrMatch( \ + dots_exp.data(), dots_act.data(), dots_exp.size(), raft::CompareApprox(params.tolerance))); \ + } \ + INSTANTIATE_TEST_CASE_P(RowNormTests, test_type, ::testing::ValuesIn(test_inputs)) + +ROWNORM_TEST(RowNormTestF_i32, inputsf_i32); +ROWNORM_TEST(RowNormTestD_i32, inputsd_i32); +ROWNORM_TEST(RowNormTestF_i64, inputsf_i64); +ROWNORM_TEST(RowNormTestD_i64, inputsd_i64); +ROWNORM_TEST(ColNormTestF_i32, inputscf_i32); +ROWNORM_TEST(ColNormTestD_i32, inputscd_i32); +ROWNORM_TEST(ColNormTestF_i64, inputscf_i64); +ROWNORM_TEST(ColNormTestD_i64, inputscd_i64); } // end namespace linalg } // end namespace raft diff --git a/cpp/test/linalg/reduce.cu b/cpp/test/linalg/reduce.cu index 57654f88ab..00f3810d28 100644 --- a/cpp/test/linalg/reduce.cu +++ b/cpp/test/linalg/reduce.cu @@ -17,79 +17,97 @@ #include "../test_utils.h" #include "reduce.cuh" #include +#include #include -#include #include #include +#include namespace raft { namespace linalg { -template +template struct ReduceInputs { OutType tolerance; - int rows, cols; + IdxType rows, cols; bool rowMajor, alongRows; + OutType init; unsigned long long int seed; }; -template -::std::ostream& operator<<(::std::ostream& os, const ReduceInputs& dims) +template +::std::ostream& operator<<(::std::ostream& os, const ReduceInputs& dims) { + os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", " << dims.rowMajor + << ", " << dims.alongRows << ", " << dims.init << " " << dims.seed << '}'; return os; } // Or else, we get the following compilation error // for an extended __device__ lambda cannot have private or protected access // within its class -template +template void reduceLaunch(OutType* dots, const InType* data, - int cols, - int rows, + IdxType cols, + IdxType rows, bool rowMajor, bool alongRows, + OutType init, bool inplace, - cudaStream_t stream) + cudaStream_t stream, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda final_op) { - Apply apply = alongRows ? Apply::ALONG_ROWS : Apply::ALONG_COLUMNS; - int output_size = alongRows ? cols : rows; + Apply apply = alongRows ? Apply::ALONG_ROWS : Apply::ALONG_COLUMNS; + IdxType output_size = alongRows ? cols : rows; - auto output_view_row_major = raft::make_device_vector_view(dots, output_size); - auto input_view_row_major = raft::make_device_matrix_view(data, rows, cols); - - auto output_view_col_major = raft::make_device_vector_view(dots, output_size); + auto output_view = raft::make_device_vector_view(dots, output_size); + auto input_view_row_major = raft::make_device_matrix_view(data, rows, cols); auto input_view_col_major = - raft::make_device_matrix_view(data, rows, cols); + raft::make_device_matrix_view(data, rows, cols); raft::handle_t handle{stream}; if (rowMajor) { reduce(handle, input_view_row_major, - output_view_row_major, - (OutType)0, - + output_view, + init, apply, inplace, - [] __device__(InType in, int i) { return static_cast(in * in); }); + main_op, + reduce_op, + final_op); } else { reduce(handle, input_view_col_major, - output_view_col_major, - (OutType)0, - + output_view, + init, apply, inplace, - [] __device__(InType in, int i) { return static_cast(in * in); }); + main_op, + reduce_op, + final_op); } } -template -class ReduceTest : public ::testing::TestWithParam> { +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::SqrtOp> +class ReduceTest : public ::testing::TestWithParam> { public: ReduceTest() - : params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(handle.get_stream()), data(params.rows * params.cols, stream), dots_exp(params.alongRows ? params.rows : params.cols, stream), @@ -101,22 +119,66 @@ class ReduceTest : public ::testing::TestWithParam void SetUp() override { raft::random::RngState r(params.seed); - int rows = params.rows, cols = params.cols; - int len = rows * cols; - outlen = params.alongRows ? rows : cols; - uniform(handle, r, data.data(), len, InType(-1.0), InType(1.0)); - naiveReduction( - dots_exp.data(), data.data(), cols, rows, params.rowMajor, params.alongRows, stream); - - // Perform reduction with default inplace = false first - reduceLaunch( - dots_act.data(), data.data(), cols, rows, params.rowMajor, params.alongRows, false, stream); - // Add to result with inplace = true next, which shouldn't affect - // in the case of coalescedReduction! - if (!(params.rowMajor ^ params.alongRows)) { - reduceLaunch( - dots_act.data(), data.data(), cols, rows, params.rowMajor, params.alongRows, true, stream); - } + IdxType rows = params.rows, cols = params.cols; + IdxType len = rows * cols; + gen_uniform(data.data(), r, len, stream); + + MainLambda main_op; + ReduceLambda reduce_op; + FinalLambda fin_op; + + // For both the naive and the actual implementation, execute first with inplace=false then true + + naiveReduction(dots_exp.data(), + data.data(), + cols, + rows, + params.rowMajor, + params.alongRows, + stream, + params.init, + false, + main_op, + reduce_op, + fin_op); + naiveReduction(dots_exp.data(), + data.data(), + cols, + rows, + params.rowMajor, + params.alongRows, + stream, + params.init, + true, + main_op, + reduce_op, + fin_op); + + reduceLaunch(dots_act.data(), + data.data(), + cols, + rows, + params.rowMajor, + params.alongRows, + params.init, + false, + stream, + main_op, + reduce_op, + fin_op); + reduceLaunch(dots_act.data(), + data.data(), + cols, + rows, + params.rowMajor, + params.alongRows, + params.init, + true, + stream, + main_op, + reduce_op, + fin_op); + handle.sync_stream(stream); } @@ -124,92 +186,140 @@ class ReduceTest : public ::testing::TestWithParam raft::handle_t handle; cudaStream_t stream; - ReduceInputs params; + ReduceInputs params; rmm::device_uvector data; rmm::device_uvector dots_exp, dots_act; - int outlen; }; -const std::vector> inputsff = { - {0.000002f, 1024, 32, true, true, 1234ULL}, - {0.000002f, 1024, 64, true, true, 1234ULL}, - {0.000002f, 1024, 128, true, true, 1234ULL}, - {0.000002f, 1024, 256, true, true, 1234ULL}, - {0.000002f, 1024, 32, true, false, 1234ULL}, - {0.000002f, 1024, 64, true, false, 1234ULL}, - {0.000002f, 1024, 128, true, false, 1234ULL}, - {0.000002f, 1024, 256, true, false, 1234ULL}, - {0.000002f, 1024, 32, false, true, 1234ULL}, - {0.000002f, 1024, 64, false, true, 1234ULL}, - {0.000002f, 1024, 128, false, true, 1234ULL}, - {0.000002f, 1024, 256, false, true, 1234ULL}, - {0.000002f, 1024, 32, false, false, 1234ULL}, - {0.000002f, 1024, 64, false, false, 1234ULL}, - {0.000002f, 1024, 128, false, false, 1234ULL}, - {0.000002f, 1024, 256, false, false, 1234ULL}}; - -const std::vector> inputsdd = { - {0.000000001, 1024, 32, true, true, 1234ULL}, - {0.000000001, 1024, 64, true, true, 1234ULL}, - {0.000000001, 1024, 128, true, true, 1234ULL}, - {0.000000001, 1024, 256, true, true, 1234ULL}, - {0.000000001, 1024, 32, true, false, 1234ULL}, - {0.000000001, 1024, 64, true, false, 1234ULL}, - {0.000000001, 1024, 128, true, false, 1234ULL}, - {0.000000001, 1024, 256, true, false, 1234ULL}, - {0.000000001, 1024, 32, false, true, 1234ULL}, - {0.000000001, 1024, 64, false, true, 1234ULL}, - {0.000000001, 1024, 128, false, true, 1234ULL}, - {0.000000001, 1024, 256, false, true, 1234ULL}, - {0.000000001, 1024, 32, false, false, 1234ULL}, - {0.000000001, 1024, 64, false, false, 1234ULL}, - {0.000000001, 1024, 128, false, false, 1234ULL}, - {0.000000001, 1024, 256, false, false, 1234ULL}}; - -const std::vector> inputsfd = { - {0.000002f, 1024, 32, true, true, 1234ULL}, - {0.000002f, 1024, 64, true, true, 1234ULL}, - {0.000002f, 1024, 128, true, true, 1234ULL}, - {0.000002f, 1024, 256, true, true, 1234ULL}, - {0.000002f, 1024, 32, true, false, 1234ULL}, - {0.000002f, 1024, 64, true, false, 1234ULL}, - {0.000002f, 1024, 128, true, false, 1234ULL}, - {0.000002f, 1024, 256, true, false, 1234ULL}, - {0.000002f, 1024, 32, false, true, 1234ULL}, - {0.000002f, 1024, 64, false, true, 1234ULL}, - {0.000002f, 1024, 128, false, true, 1234ULL}, - {0.000002f, 1024, 256, false, true, 1234ULL}, - {0.000002f, 1024, 32, false, false, 1234ULL}, - {0.000002f, 1024, 64, false, false, 1234ULL}, - {0.000002f, 1024, 128, false, false, 1234ULL}, - {0.000002f, 1024, 256, false, false, 1234ULL}}; - -typedef ReduceTest ReduceTestFF; -TEST_P(ReduceTestFF, Result) -{ - ASSERT_TRUE(devArrMatch( - dots_exp.data(), dots_act.data(), outlen, raft::CompareApprox(params.tolerance))); -} +#define REDUCE_TEST(test_type, test_name, test_inputs) \ + typedef RAFT_DEPAREN(test_type) test_name; \ + TEST_P(test_name, Result) \ + { \ + ASSERT_TRUE(raft::devArrMatch( \ + dots_exp.data(), dots_act.data(), dots_exp.size(), raft::CompareApprox(params.tolerance))); \ + } \ + INSTANTIATE_TEST_CASE_P(ReduceTests, test_name, ::testing::ValuesIn(test_inputs)) -typedef ReduceTest ReduceTestDD; -TEST_P(ReduceTestDD, Result) -{ - ASSERT_TRUE(devArrMatch( - dots_exp.data(), dots_act.data(), outlen, raft::CompareApprox(params.tolerance))); -} +const std::vector> inputsff_i32 = + raft::util::itertools::product>( + {0.000002f}, {11, 1234}, {7, 33, 128, 500}, {true, false}, {true, false}, {0.0f}, {1234ULL}); +const std::vector> inputsdd_i32 = + raft::util::itertools::product>( + {0.000000001}, {11, 1234}, {7, 33, 128, 500}, {true, false}, {true, false}, {0.0}, {1234ULL}); +const std::vector> inputsfd_i32 = + raft::util::itertools::product>( + {0.000000001}, {11, 1234}, {7, 33, 128, 500}, {true, false}, {true, false}, {0.0f}, {1234ULL}); +const std::vector> inputsff_u32 = + raft::util::itertools::product>({0.000002f}, + {11u, 1234u}, + {7u, 33u, 128u, 500u}, + {true, false}, + {true, false}, + {0.0f}, + {1234ULL}); +const std::vector> inputsff_i64 = + raft::util::itertools::product>( + {0.000002f}, {11, 1234}, {7, 33, 128, 500}, {true, false}, {true, false}, {0.0f}, {1234ULL}); -typedef ReduceTest ReduceTestFD; -TEST_P(ReduceTestFD, Result) -{ - ASSERT_TRUE(devArrMatch( - dots_exp.data(), dots_act.data(), outlen, raft::CompareApprox(params.tolerance))); -} +REDUCE_TEST((ReduceTest), ReduceTestFFI32, inputsff_i32); +REDUCE_TEST((ReduceTest), ReduceTestDDI32, inputsdd_i32); +REDUCE_TEST((ReduceTest), ReduceTestFDI32, inputsfd_i32); +REDUCE_TEST((ReduceTest), ReduceTestFFU32, inputsff_u32); +REDUCE_TEST((ReduceTest), ReduceTestFFI64, inputsff_i64); + +// The following test cases are for "thick" coalesced reductions + +const std::vector> inputsff_thick_i32 = + raft::util::itertools::product>( + {0.0001f}, {3, 9}, {17771, 33333, 100000}, {true}, {true}, {0.0f}, {1234ULL}); +const std::vector> inputsdd_thick_i32 = + raft::util::itertools::product>( + {0.000001}, {3, 9}, {17771, 33333, 100000}, {true}, {true}, {0.0}, {1234ULL}); +const std::vector> inputsfd_thick_i32 = + raft::util::itertools::product>( + {0.000001}, {3, 9}, {17771, 33333, 100000}, {true}, {true}, {0.0f}, {1234ULL}); +const std::vector> inputsff_thick_u32 = + raft::util::itertools::product>( + {0.0001f}, {3u, 9u}, {17771u, 33333u, 100000u}, {true}, {true}, {0.0f}, {1234ULL}); +const std::vector> inputsff_thick_i64 = + raft::util::itertools::product>( + {0.0001f}, {3, 9}, {17771, 33333, 100000}, {true}, {true}, {0.0f}, {1234ULL}); + +REDUCE_TEST((ReduceTest), ReduceTestFFI32Thick, inputsff_thick_i32); +REDUCE_TEST((ReduceTest), ReduceTestDDI32Thick, inputsdd_thick_i32); +REDUCE_TEST((ReduceTest), ReduceTestFDI32Thick, inputsfd_thick_i32); +REDUCE_TEST((ReduceTest), ReduceTestFFU32Thick, inputsff_thick_u32); +REDUCE_TEST((ReduceTest), ReduceTestFFI64Thick, inputsff_thick_i64); + +// Test key-value-pair reductions. This is important because shuffle intrinsics can't be used +// directly with those types. -INSTANTIATE_TEST_CASE_P(ReduceTests, ReduceTestFF, ::testing::ValuesIn(inputsff)); +template +struct ValueToKVP { + HDI raft::KeyValuePair operator()(T value, IdxT idx) { return {idx, value}; } +}; + +template +struct ArgMaxOp { + HDI raft::KeyValuePair operator()(raft::KeyValuePair a, + raft::KeyValuePair b) + { + return (a.value > b.value || (a.value == b.value && a.key <= b.key)) ? a : b; + } +}; -INSTANTIATE_TEST_CASE_P(ReduceTests, ReduceTestDD, ::testing::ValuesIn(inputsdd)); +const std::vector, int>> inputs_kvpis_i32 = + raft::util::itertools::product, int>>( + {raft::KeyValuePair{0, short(0)}}, + {11, 1234}, + {7, 33, 128, 500}, + {true}, + {true}, + {raft::KeyValuePair{0, short(0)}}, + {1234ULL}); +const std::vector, int>> inputs_kvpif_i32 = + raft::util::itertools::product, int>>( + {raft::KeyValuePair{0, 0.0001f}}, + {11, 1234}, + {7, 33, 128, 500}, + {true}, + {true}, + {raft::KeyValuePair{0, 0.0f}}, + {1234ULL}); +const std::vector, int>> inputs_kvpid_i32 = + raft::util::itertools::product, int>>( + {raft::KeyValuePair{0, 0.000001}}, + {11, 1234}, + {7, 33, 128, 500}, + {true}, + {true}, + {raft::KeyValuePair{0, 0.0}}, + {1234ULL}); -INSTANTIATE_TEST_CASE_P(ReduceTests, ReduceTestFD, ::testing::ValuesIn(inputsfd)); +REDUCE_TEST((ReduceTest, + int, + ValueToKVP, + ArgMaxOp, + raft::Nop, int>>), + ReduceTestKVPISI32, + inputs_kvpis_i32); +REDUCE_TEST((ReduceTest, + int, + ValueToKVP, + ArgMaxOp, + raft::Nop, int>>), + ReduceTestKVPIFI32, + inputs_kvpif_i32); +REDUCE_TEST((ReduceTest, + int, + ValueToKVP, + ArgMaxOp, + raft::Nop, int>>), + ReduceTestKVPIDI32, + inputs_kvpid_i32); } // end namespace linalg } // end namespace raft diff --git a/cpp/test/linalg/reduce.cuh b/cpp/test/linalg/reduce.cuh index 162bf9f2c1..0dcffd3f41 100644 --- a/cpp/test/linalg/reduce.cuh +++ b/cpp/test/linalg/reduce.cuh @@ -28,70 +28,141 @@ namespace raft { namespace linalg { -template -__global__ void naiveCoalescedReductionKernel(OutType* dots, const InType* data, int D, int N) +template +__global__ void naiveCoalescedReductionKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + bool inplace, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda fin_op) { - OutType acc = (OutType)0; - int rowStart = threadIdx.x + blockIdx.x * blockDim.x; + OutType acc = init; + IdxType rowStart = threadIdx.x + static_cast(blockIdx.x) * blockDim.x; if (rowStart < N) { - for (int i = 0; i < D; ++i) { - acc += static_cast(data[rowStart * D + i] * data[rowStart * D + i]); + for (IdxType i = 0; i < D; ++i) { + acc = reduce_op(acc, main_op(data[rowStart * D + i], i)); + } + if (inplace) { + dots[rowStart] = fin_op(reduce_op(dots[rowStart], acc)); + } else { + dots[rowStart] = fin_op(acc); } - dots[rowStart] = 2 * acc; } } -template -void naiveCoalescedReduction(OutType* dots, const InType* data, int D, int N, cudaStream_t stream) +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void naiveCoalescedReduction(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + cudaStream_t stream, + OutType init, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda fin_op = raft::Nop()) { - static const int TPB = 64; - int nblks = raft::ceildiv(N, TPB); - naiveCoalescedReductionKernel<<>>(dots, data, D, N); + static const IdxType TPB = 64; + IdxType nblks = raft::ceildiv(N, TPB); + naiveCoalescedReductionKernel<<>>( + dots, data, D, N, init, inplace, main_op, reduce_op, fin_op); RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template -void unaryAndGemv(OutType* dots, const InType* data, int D, int N, cudaStream_t stream) +template +__global__ void naiveStridedReductionKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + bool inplace, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda fin_op) { - // computes a MLCommon unary op on data (squares it), then computes Ax - //(A input matrix and x column vector) to sum columns - rmm::device_uvector sq(D * N, stream); - raft::linalg::unaryOp( - thrust::raw_pointer_cast(sq.data()), - data, - D * N, - [] __device__(InType v) { return static_cast(v * v); }, - stream); - cublasHandle_t handle; - RAFT_CUBLAS_TRY(cublasCreate(&handle)); - rmm::device_uvector ones(N, stream); // column vector [1...1] - raft::linalg::unaryOp( - ones.data(), ones.data(), ones.size(), [=] __device__(OutType input) { return 1; }, stream); - OutType alpha = 1, beta = 0; - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv( - handle, CUBLAS_OP_N, D, N, &alpha, sq.data(), D, ones.data(), 1, &beta, dots, 1, stream)); - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - RAFT_CUBLAS_TRY(cublasDestroy(handle)); + OutType acc = init; + IdxType col = threadIdx.x + static_cast(blockIdx.x) * blockDim.x; + if (col < D) { + for (IdxType i = 0; i < N; ++i) { + acc = reduce_op(acc, main_op(data[i * D + col], i)); + } + if (inplace) { + dots[col] = fin_op(reduce_op(dots[col], acc)); + } else { + dots[col] = fin_op(acc); + } + } +} + +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void naiveStridedReduction(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + cudaStream_t stream, + OutType init, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda fin_op = raft::Nop()) +{ + static const IdxType TPB = 64; + IdxType nblks = raft::ceildiv(D, TPB); + naiveStridedReductionKernel<<>>( + dots, data, D, N, init, inplace, main_op, reduce_op, fin_op); + RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> void naiveReduction(OutType* dots, const InType* data, - int D, - int N, + IdxType D, + IdxType N, bool rowMajor, bool alongRows, - cudaStream_t stream) + cudaStream_t stream, + OutType init, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda fin_op = raft::Nop()) { if (rowMajor && alongRows) { - naiveCoalescedReduction(dots, data, D, N, stream); + naiveCoalescedReduction(dots, data, D, N, stream, init, inplace, main_op, reduce_op, fin_op); } else if (rowMajor && !alongRows) { - unaryAndGemv(dots, data, D, N, stream); + naiveStridedReduction(dots, data, D, N, stream, init, inplace, main_op, reduce_op, fin_op); } else if (!rowMajor && alongRows) { - unaryAndGemv(dots, data, N, D, stream); + naiveStridedReduction(dots, data, N, D, stream, init, inplace, main_op, reduce_op, fin_op); } else { - naiveCoalescedReduction(dots, data, N, D, stream); + naiveCoalescedReduction(dots, data, N, D, stream, init, inplace, main_op, reduce_op, fin_op); } RAFT_CUDA_TRY(cudaDeviceSynchronize()); } diff --git a/cpp/test/linalg/strided_reduction.cu b/cpp/test/linalg/strided_reduction.cu index 39e2764def..77ca585ea5 100644 --- a/cpp/test/linalg/strided_reduction.cu +++ b/cpp/test/linalg/strided_reduction.cu @@ -32,13 +32,13 @@ struct stridedReductionInputs { }; template -void stridedReductionLaunch(T* dots, const T* data, int cols, int rows, cudaStream_t stream) +void stridedReductionLaunch( + T* dots, const T* data, int cols, int rows, bool inplace, cudaStream_t stream) { raft::handle_t handle{stream}; auto dots_view = raft::make_device_vector_view(dots, cols); auto data_view = raft::make_device_matrix_view(data, rows, cols); - strided_reduction( - handle, data_view, dots_view, (T)0, false, [] __device__(T in, int i) { return in * in; }); + strided_reduction(handle, data_view, dots_view, (T)0, inplace, raft::L2Op{}); } template @@ -61,8 +61,30 @@ class stridedReductionTest : public ::testing::TestWithParam{}, + raft::Sum{}, + raft::Nop{}); + naiveStridedReduction(dots_exp.data(), + data.data(), + cols, + rows, + stream, + T(0), + true, + raft::L2Op{}, + raft::Sum{}, + raft::Nop{}); + stridedReductionLaunch(dots_act.data(), data.data(), cols, rows, false, stream); + stridedReductionLaunch(dots_act.data(), data.data(), cols, rows, true, stream); handle.sync_stream(stream); } diff --git a/cpp/test/test_utils.h b/cpp/test/test_utils.h index 14319b85e1..26483e6b2d 100644 --- a/cpp/test/test_utils.h +++ b/cpp/test/test_utils.h @@ -18,13 +18,18 @@ #include #include #include +#include +#include #include #include +#include +#include #include #include #include #include +#include #include #include @@ -42,7 +47,7 @@ struct CompareApprox { { T diff = abs(a - b); T m = std::max(abs(a), abs(b)); - T ratio = diff >= eps ? diff / m : diff; + T ratio = diff > eps ? diff / m : diff; return (ratio <= eps); } @@ -51,6 +56,30 @@ struct CompareApprox { T eps; }; +template +::std::ostream& operator<<(::std::ostream& os, const raft::KeyValuePair& kv) +{ + os << "{ " << kv.key << ", " << kv.value << '}'; + return os; +} + +template +struct CompareApprox> { + CompareApprox(raft::KeyValuePair eps) + : compare_keys(eps.key), compare_values(eps.value) + { + } + bool operator()(const raft::KeyValuePair& a, + const raft::KeyValuePair& b) const + { + return compare_keys(a.key, b.key) && compare_values(a.value, b.value); + } + + private: + CompareApprox compare_keys; + CompareApprox compare_values; +}; + template struct CompareApproxAbs { CompareApproxAbs(T eps_) : eps(eps_) {} @@ -280,6 +309,52 @@ testing::AssertionResult match(const T expected, T actual, L eq_compare) return testing::AssertionSuccess(); } +template +typename std::enable_if_t> gen_uniform(T* out, + raft::random::RngState& rng, + IdxT len, + cudaStream_t stream, + T range_min = T(-1), + T range_max = T(1)) +{ + raft::random::uniform(rng, out, len, range_min, range_max, stream); +} + +template +typename std::enable_if_t> gen_uniform(T* out, + raft::random::RngState& rng, + IdxT len, + cudaStream_t stream, + T range_min = T(0), + T range_max = T(100)) +{ + raft::random::uniformInt(rng, out, len, range_min, range_max, stream); +} + +template +void gen_uniform(raft::KeyValuePair* out, + raft::random::RngState& rng, + IdxT len, + cudaStream_t stream) +{ + rmm::device_uvector keys(len, stream); + rmm::device_uvector values(len, stream); + + gen_uniform(keys.data(), rng, len, stream); + gen_uniform(values.data(), rng, len, stream); + + const T1* d_keys = keys.data(); + const T2* d_values = values.data(); + auto counting = thrust::make_counting_iterator(0); + thrust::for_each(rmm::exec_policy(stream), + counting, + counting + len, + [out, d_keys, d_values] __device__(int idx) { + out[idx].key = d_keys[idx]; + out[idx].value = d_values[idx]; + }); +} + /** @} */ /** time the function call 'func' using cuda events */