From 0e96662f9b4fc77cd4ac6e528fe6103c81715287 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Sat, 21 Jan 2023 21:55:43 +0100 Subject: [PATCH] Improvements in `matrix::gather`: test coverage, compilation errors, performance (#1126) In order to deprecate `copy_selected` from `ann_utils.cuh`, I wanted to make sure that the performance of `matrix::gather` was on par. But in the process I discovered that: - Map transforms and conditional copy were not tested at all. - In fact, most of the API in `gather.cuh` wasn't covered in tests and some of the functions didn't even compile. - The same type `MatrixIteratorT` was used for the input and output iterators, which made it impossible to take advantage of custom iterators, as is needed in `kmeans_balanced` to convert the dataset from `T` to `float` and gather in a single step. - The performance was really poor when `D` is small because the kernel assigns one block per row (so a block could be working on only 2 or 3 elements...) This PR addresses all the aforementioned issues. Authors: - Louis Sugy (https://github.com/Nyrio) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1126 --- cpp/bench/CMakeLists.txt | 2 +- cpp/bench/matrix/argmin.cu | 17 +- cpp/bench/matrix/gather.cu | 101 +++++ .../raft/cluster/detail/kmeans_common.cuh | 2 +- cpp/include/raft/core/operators.hpp | 51 ++- cpp/include/raft/matrix/detail/gather.cuh | 236 ++++++----- cpp/include/raft/matrix/gather.cuh | 371 ++++++++---------- cpp/test/matrix/gather.cu | 208 ++++++---- 8 files changed, 578 insertions(+), 410 deletions(-) create mode 100644 cpp/bench/matrix/gather.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 813483adc5..8dcdb325e9 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -103,7 +103,7 @@ if(BUILD_BENCH) bench/main.cpp ) - ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/main.cpp) + ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/main.cpp) ConfigureBench( NAME RANDOM_BENCH PATH bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu index 0d0dea0fdb..52f5aab7f3 100644 --- a/cpp/bench/matrix/argmin.cu +++ b/cpp/bench/matrix/argmin.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,10 +17,11 @@ #include #include #include +#include #include -namespace raft::bench::linalg { +namespace raft::bench::matrix { template struct ArgminParams { @@ -57,15 +58,11 @@ struct Argmin : public fixture { raft::device_vector indices; }; // struct Argmin -const std::vector> argmin_inputs_i64{ - {1000, 64}, {1000, 128}, {1000, 256}, {1000, 512}, {1000, 1024}, - {10000, 64}, {10000, 128}, {10000, 256}, {10000, 512}, {10000, 1024}, - {100000, 64}, {100000, 128}, {100000, 256}, {100000, 512}, {100000, 1024}, - {1000000, 64}, {1000000, 128}, {1000000, 256}, {1000000, 512}, {1000000, 1024}, - {10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024}, -}; +const std::vector> argmin_inputs_i64 = + raft::util::itertools::product>({1000, 10000, 100000, 1000000, 10000000}, + {64, 128, 256, 512, 1024}); RAFT_BENCH_REGISTER((Argmin), "", argmin_inputs_i64); RAFT_BENCH_REGISTER((Argmin), "", argmin_inputs_i64); -} // namespace raft::bench::linalg +} // namespace raft::bench::matrix diff --git a/cpp/bench/matrix/gather.cu b/cpp/bench/matrix/gather.cu new file mode 100644 index 0000000000..97812c20a1 --- /dev/null +++ b/cpp/bench/matrix/gather.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +namespace raft::bench::matrix { + +template +struct GatherParams { + IdxT rows, cols, map_length; +}; + +template +inline auto operator<<(std::ostream& os, const GatherParams& p) -> std::ostream& +{ + os << p.rows << "#" << p.cols << "#" << p.map_length; + return os; +} + +template +struct Gather : public fixture { + Gather(const GatherParams& p) : params(p) {} + + void allocate_data(const ::benchmark::State& state) override + { + matrix = raft::make_device_matrix(handle, params.rows, params.cols); + map = raft::make_device_vector(handle, params.map_length); + out = raft::make_device_matrix(handle, params.map_length, params.cols); + stencil = raft::make_device_vector(handle, Conditional ? params.map_length : IdxT(0)); + + raft::random::RngState rng{1234}; + raft::random::uniform( + rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream); + raft::random::uniformInt( + handle, rng, map.data_handle(), params.map_length, (MapT)0, (MapT)params.rows); + if constexpr (Conditional) { + raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream); + } + handle.sync_stream(stream); + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + loop_on_state(state, [this]() { + auto matrix_const_view = raft::make_device_matrix_view( + matrix.data_handle(), matrix.extent(0), matrix.extent(1)); + auto map_const_view = + raft::make_device_vector_view(map.data_handle(), map.extent(0)); + if constexpr (Conditional) { + auto stencil_const_view = + raft::make_device_vector_view(stencil.data_handle(), stencil.extent(0)); + auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op()); + raft::matrix::gather_if( + handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); + } else { + raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + } + }); + } + + private: + GatherParams params; + raft::device_matrix matrix, out; + raft::device_vector stencil; + raft::device_vector map; +}; // struct Gather + +template +using GatherIf = Gather; + +const std::vector> gather_inputs_i64 = + raft::util::itertools::product>( + {1000000}, {10, 20, 50, 100, 200, 500}, {1000, 10000, 100000, 1000000}); + +RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); +} // namespace raft::bench::matrix diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 2fd33ac759..559793442f 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -335,7 +335,7 @@ void shuffleAndGather(const raft::handle_t& handle, in.extent(1), in.extent(0), indices.data_handle(), - n_samples_to_gather, + static_cast(n_samples_to_gather), out.data_handle(), stream); } diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index de27c2b271..edb437c880 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -147,6 +147,14 @@ struct div_checkzero_op { } }; +struct modulo_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a % b; + } +}; + struct pow_op { template RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const @@ -189,17 +197,49 @@ struct argmax_op { } }; +struct greater_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a > b; + } +}; + +struct less_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a < b; + } +}; + +struct greater_or_equal_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a >= b; + } +}; + +struct less_or_equal_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a <= b; + } +}; + struct equal_op { - template - constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const { return a == b; } }; struct notequal_op { - template - constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const { return a != b; } @@ -267,6 +307,9 @@ using div_const_op = plug_const_op; template using div_checkzero_const_op = plug_const_op; +template +using modulo_const_op = plug_const_op; + template using pow_const_op = plug_const_op; diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index c006f69e47..a8efc2d0d0 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,41 +17,63 @@ #pragma once #include +#include namespace raft { namespace matrix { namespace detail { -// gatherKernel conditionally copies rows from the source matrix 'in' into the destination matrix -// 'out' according to a map (or a transformed map) -template +struct gather_policy { + static constexpr int n_threads = tpb; + static constexpr int work_per_thread = wpt; + static constexpr int stride = tpb * wpt; +}; + +/** Conditionally copies rows from the source matrix 'in' into the destination matrix + * 'out' according to a map (or a transformed map) */ +template -__global__ void gatherKernel(const MatrixIteratorT in, - IndexT D, - IndexT N, - MapIteratorT map, - StencilIteratorT stencil, - MatrixIteratorT out, - PredicateOp pred_op, - MapTransformOp transform_op) + typename OutputIteratorT, + typename IndexT> +__global__ void gather_kernel(const InputIteratorT in, + IndexT D, + IndexT len, + const MapIteratorT map, + StencilIteratorT stencil, + OutputIteratorT out, + PredicateOp pred_op, + MapTransformOp transform_op) { typedef typename std::iterator_traits::value_type MapValueT; typedef typename std::iterator_traits::value_type StencilValueT; - IndexT outRowStart = blockIdx.x * D; - MapValueT map_val = map[blockIdx.x]; - StencilValueT stencil_val = stencil[blockIdx.x]; +#pragma unroll + for (IndexT wid = 0; wid < Policy::work_per_thread; wid++) { + IndexT tid = threadIdx.x + (Policy::work_per_thread * static_cast(blockIdx.x) + wid) * + Policy::n_threads; + if (tid < len) { + IndexT i_dst = tid / D; + IndexT j = tid % D; + + MapValueT map_val = map[i_dst]; + StencilValueT stencil_val = stencil[i_dst]; - bool predicate = pred_op(stencil_val); - if (predicate) { - IndexT inRowStart = transform_op(map_val) * D; - for (int i = threadIdx.x; i < D; i += TPB) { - out[outRowStart + i] = in[inRowStart + i]; + bool predicate = pred_op(stencil_val); + if (predicate) { + IndexT i_src = transform_op(map_val); + out[tid] = in[i_src * D + j]; + } } } } @@ -60,7 +82,7 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @brief gather conditionally copies rows from a source matrix into a destination matrix according * to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -69,7 +91,10 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -83,18 +108,20 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gatherImpl(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gatherImpl(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) @@ -102,9 +129,6 @@ void gatherImpl(const MatrixIteratorT in, // skip in case of 0 length input if (map_length <= 0 || N <= 0 || D <= 0) return; - // signed integer type for indexing or global offsets - typedef int IndexT; - // map value type typedef typename std::iterator_traits::value_type MapValueT; @@ -121,38 +145,26 @@ void gatherImpl(const MatrixIteratorT in, static_assert((std::is_convertible::value), "UnaryPredicateOp's result type must be convertible to bool type"); - if (D <= 32) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); - } else if (D <= 64) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); - } else if (D <= 128) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + IndexT len = map_length * D; + constexpr int TPB = 128; + const int n_sm = raft::getMultiProcessorCount(); + // The following empirical heuristics enforce that we keep a good balance between having enough + // blocks and enough work per thread. + if (len < 32 * TPB * n_sm) { + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); + } else if (len < 32 * 4 * TPB * n_sm) { + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); } else { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -160,10 +172,13 @@ void gatherImpl(const MatrixIteratorT in, /** * @brief gather copies rows from a source matrix into a destination matrix according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -174,13 +189,13 @@ void gatherImpl(const MatrixIteratorT in, * @param out Pointer to the output matrix (assumed to be row-major) * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; @@ -192,12 +207,15 @@ void gather(const MatrixIteratorT in, * @brief gather copies rows from a source matrix into a destination matrix according to a * transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -209,13 +227,17 @@ void gather(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, MapTransformOp transform_op, cudaStream_t stream) { @@ -227,7 +249,7 @@ void gather(const MatrixIteratorT in, * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -235,6 +257,9 @@ void gather(const MatrixIteratorT in, * simple pointer type). * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -247,17 +272,19 @@ void gather(const MatrixIteratorT in, * @param pred_op Predicate to apply to the stencil values * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename UnaryPredicateOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, cudaStream_t stream) { @@ -269,7 +296,7 @@ void gather_if(const MatrixIteratorT in, * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -278,7 +305,10 @@ void gather_if(const MatrixIteratorT in, * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT type. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -292,18 +322,20 @@ void gather_if(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 6a923fb0cc..9487da35b5 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include #include #include +#include namespace raft::matrix { @@ -28,62 +29,68 @@ namespace raft::matrix { */ /** - * @brief gather copies rows from a source matrix into a destination matrix according to a map. + * @brief Copies rows from a source matrix into a destination matrix according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). + * For each output row, read the index in the input matrix from the map and copy the row. * - * @param in Pointer to the input matrix (assumed to be row-major) - * @param D Leading dimension of the input matrix 'in', which in-case of row-major - * storage is the number of columns - * @param N Second dimension - * @param map Pointer to the input sequence of gather locations - * @param map_length The length of 'map' and 'stencil' - * @param out Pointer to the output matrix (assumed to be row-major) + * @tparam InputIteratorT Input iterator type, for the input matrix (may be a pointer type). + * @tparam MapIteratorT Input iterator type, for the map (may be a pointer type). + * @tparam OutputIteratorT Output iterator type, for the output matrix (may be a pointer type). + * @tparam IndexT Index type. + * + * @param in Input matrix, dim = [N, D] (row-major) + * @param D Number of columns of the input/output matrices + * @param N Number of rows of the input matrix + * @param map Map of row indices to gather, dim = [map_length] + * @param map_length The length of 'map', number of rows of the output matrix + * @param out Output matrix, dim = [map_length, D] (row-major) * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, cudaStream_t stream) { detail::gather(in, D, N, map, map_length, out, stream); } /** - * @brief gather copies rows from a source matrix into a destination matrix according to a - * transformed map. + * @brief Copies rows from a source matrix into a destination matrix according to a transformed map. + * + * For each output row, read the index in the input matrix from the map, apply a transformation to + * this input index and copy the row. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). - * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * @tparam InputIteratorT Input iterator type, for the input matrix (may be a pointer type). + * @tparam MapIteratorT Input iterator type, for the map (may be a pointer type). + * @tparam MapTransformOp Unary lambda expression or operator type. MapTransformOp's result type + * must be convertible to IndexT. + * @tparam OutputIteratorT Output iterator type, for the output matrix (may be a pointer type). + * @tparam IndexT Index type. * - * @param in Pointer to the input matrix (assumed to be row-major) - * @param D Leading dimension of the input matrix 'in', which in-case of row-major - * storage is the number of columns - * @param N Second dimension - * @param map Pointer to the input sequence of gather locations - * @param map_length The length of 'map' and 'stencil' - * @param out Pointer to the output matrix (assumed to be row-major) - * @param transform_op The transformation operation, transforms the map values to IndexT + * @param in Input matrix, dim = [N, D] (row-major) + * @param D Number of columns of the input/output matrices + * @param N Number of rows of the input matrix + * @param map Map of row indices to gather, dim = [map_length] + * @param map_length The length of 'map', number of rows of the output matrix + * @param out Output matrix, dim = [map_length, D] (row-major) + * @param transform_op Transformation to apply to map values * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, MapTransformOp transform_op, cudaStream_t stream) { @@ -91,40 +98,42 @@ void gather(const MatrixIteratorT in, } /** - * @brief gather_if conditionally copies rows from a source matrix into a destination matrix - * according to a map. + * @brief Conditionally copies rows from a source matrix into a destination matrix. + * + * For each output row, read the index in the input matrix from the map, read a stencil value, apply + * a predicate to the stencil value, and if true, copy the row. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). - * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a - * simple pointer type). - * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result - * type must be convertible to bool type. + * @tparam InputIteratorT Input iterator type, for the input matrix (may be a pointer type). + * @tparam MapIteratorT Input iterator type, for the map (may be a pointer type). + * @tparam StencilIteratorT Input iterator type, for the stencil (may be a pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type. UnaryPredicateOp's result type + * must be convertible to bool type. + * @tparam OutputIteratorT Output iterator type, for the output matrix (may be a pointer type). + * @tparam IndexT Index type. * - * @param in Pointer to the input matrix (assumed to be row-major) - * @param D Leading dimension of the input matrix 'in', which in-case of row-major - * storage is the number of columns - * @param N Second dimension - * @param map Pointer to the input sequence of gather locations - * @param stencil Pointer to the input sequence of stencil or predicate values - * @param map_length The length of 'map' and 'stencil' - * @param out Pointer to the output matrix (assumed to be row-major) + * @param in Input matrix, dim = [N, D] (row-major) + * @param D Number of columns of the input/output matrices + * @param N Number of rows of the input matrix + * @param map Map of row indices to gather, dim = [map_length] + * @param stencil Sequence of stencil values, dim = [map_length] + * @param map_length The length of 'map' and 'stencil', number of rows of the output matrix + * @param out Output matrix, dim = [map_length, D] (row-major) * @param pred_op Predicate to apply to the stencil values * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename UnaryPredicateOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, cudaStream_t stream) { @@ -132,44 +141,47 @@ void gather_if(const MatrixIteratorT in, } /** - * @brief gather_if conditionally copies rows from a source matrix into a destination matrix - * according to a transformed map. + * @brief Conditionally copies rows according to a transformed map. + * + * For each output row, read the index in the input matrix from the map, read a stencil value, + * apply a predicate to the stencil value, and if true, apply a transformation to the input index + * and copy the row. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). - * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a - * simple pointer type). - * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result - * type must be convertible to bool type. - * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * @tparam InputIteratorT Input iterator type, for the input matrix (may be a pointer type). + * @tparam MapIteratorT Input iterator type, for the map (may be a pointer type). + * @tparam MapTransformOp Unary lambda expression or operator type. MapTransformOp's result type + * must be convertible to IndexT. + * @tparam StencilIteratorT Input iterator type, for the stencil (may be a pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type. UnaryPredicateOp's result type + * must be convertible to bool type. + * @tparam OutputIteratorT Output iterator type, for the output matrix (may be a pointer type). + * @tparam IndexT Index type. * - * @param in Pointer to the input matrix (assumed to be row-major) - * @param D Leading dimension of the input matrix 'in', which in-case of row-major - * storage is the number of columns - * @param N Second dimension - * @param map Pointer to the input sequence of gather locations - * @param stencil Pointer to the input sequence of stencil or predicate values - * @param map_length The length of 'map' and 'stencil' - * @param out Pointer to the output matrix (assumed to be row-major) + * @param in Input matrix, dim = [N, D] (row-major) + * @param D Number of columns of the input/output matrices + * @param N Number of rows of the input matrix + * @param map Map of row indices to gather, dim = [map_length] + * @param stencil Sequence of stencil values, dim = [map_length] + * @param map_length The length of 'map' and 'stencil', number of rows of the output matrix + * @param out Output matrix, dim = [map_length, D] (row-major) * @param pred_op Predicate to apply to the stencil values - * @param transform_op The transformation operation, transforms the map values to IndexT + * @param transform_op Transformation to apply to map values * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) @@ -178,58 +190,31 @@ void gather_if(const MatrixIteratorT in, } /** - * @brief gather copies rows from a source matrix into a destination matrix according to a map. + * @brief Copies rows from a source matrix into a destination matrix according to a transformed map. * - * @tparam matrix_t Matrix element type - * @tparam map_t Map vector type - * @tparam idx_t integer type used for indexing - * @param[in] handle raft handle for managing resources - * @param[in] in Input matrix (assumed to be row-major) - * @param[in] map Vector of gather locations - * @param[out] out Output matrix (assumed to be row-major) - */ -template -void gather(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_vector_view map, - raft::device_matrix_view out) -{ - RAFT_EXPECTS(out.extent(0) == map.extent(0), - "Number of rows in output matrix must equal the size of the map vector"); - RAFT_EXPECTS(out.extent(1) == in.extent(1), - "Number of columns in input and output matrices must be equal."); - - raft::matrix::detail::gather( - const_cast(in.data_handle()), // TODO: There's a better way to handle this - static_cast(in.extent(1)), - static_cast(in.extent(0)), - map.data_handle(), - static_cast(map.extent(0)), - out.data_handle(), - handle.get_stream()); -} - -/** - * @brief gather copies rows from a source matrix into a destination matrix according to a - * transformed map. + * For each output row, read the index in the input matrix from the map, apply a transformation to + * this input index if specified, and copy the row. * - * @tparam matrix_t Matrix type - * @tparam map_t Map vector type - * @tparam map_xform_t Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to idx_t (= int) type. - * @tparam idx_t integer type for indexing - * @param[in] handle raft handle for managing resources - * @param[in] in Input matrix (assumed to be row-major) - * @param[in] map Input vector of gather locations - * @param[out] out Output matrix (assumed to be row-major) - * @param[in] transform_op The transformation operation, transforms the map values to idx_t + * @tparam matrix_t Matrix element type + * @tparam map_t Integer type of map elements + * @tparam idx_t Integer type used for indexing + * @tparam map_xform_t Unary lambda expression or operator type. MapTransformOp's result type must + * be convertible to idx_t. + * @param[in] handle raft handle for managing resources + * @param[in] in Input matrix, dim = [N, D] (row-major) + * @param[in] map Map of row indices to gather, dim = [map_length] + * @param[out] out Output matrix, dim = [map_length, D] (row-major) + * @param[in] transform_op (optional) Transformation to apply to map values */ -template +template void gather(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_vector_view map, - raft::device_matrix_view out, - map_xform_t transform_op) + raft::device_matrix_view out, + map_xform_t transform_op = raft::identity_op()) { RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); @@ -238,95 +223,51 @@ void gather(const raft::handle_t& handle, detail::gather( const_cast(in.data_handle()), // TODO: There's a better way to handle this - static_cast(in.extent(1)), - static_cast(in.extent(0)), - map, - static_cast(map.extent(0)), + in.extent(1), + in.extent(0), + map.data_handle(), + map.extent(0), out.data_handle(), transform_op, handle.get_stream()); } /** - * @brief gather_if conditionally copies rows from a source matrix into a destination matrix - * according to a map. + * @brief Conditionally copies rows according to a transformed map. + * + * For each output row, read the index in the input matrix from the map, read a stencil value, + * apply a predicate to the stencil value, and if true, apply a transformation if specified to the + * input index, and copy the row. * - * @tparam matrix_t Matrix value type - * @tparam map_t Map vector type - * @tparam stencil_t Stencil vector type - * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result - * type must be convertible to bool type. - * @tparam idx_t integer type for indexing - * @param[in] handle raft handle for managing resources - * @param[in] in Input matrix (assumed to be row-major) - * @param[in] map Input vector of gather locations - * @param[in] stencil Input vector of stencil or predicate values - * @param[out] out Output matrix (assumed to be row-major) - * @param[in] pred_op Predicate to apply to the stencil values + * @tparam matrix_t Matrix element type + * @tparam map_t Integer type of map elements + * @tparam stencil_t Value type for stencil (input type for the pred_op) + * @tparam unary_pred_t Unary lambda expression or operator type. unary_pred_t's result + * type must be convertible to bool type. + * @tparam map_xform_t Unary lambda expression or operator type. MapTransformOp's result type must + * be convertible to idx_t. + * @tparam idx_t Integer type used for indexing + * @param[in] handle raft handle for managing resources + * @param[in] in Input matrix, dim = [N, D] (row-major) + * @param[in] map Map of row indices to gather, dim = [map_length] + * @param[in] stencil Vector of stencil values, dim = [map_length] + * @param[out] out Output matrix, dim = [map_length, D] (row-major) + * @param[in] pred_op Predicate to apply to the stencil values + * @param[in] transform_op (optional) Transformation to apply to map values */ template + typename idx_t, + typename map_xform_t = raft::identity_op> void gather_if(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::device_vector_view map, raft::device_vector_view stencil, - unary_pred_t pred_op) -{ - RAFT_EXPECTS(out.extent(0) == map.extent(0), - "Number of rows in output matrix must equal the size of the map vector"); - RAFT_EXPECTS(out.extent(1) == in.extent(1), - "Number of columns in input and output matrices must be equal."); - RAFT_EXPECTS(map.extent(0) == stencil.extent(0), - "Number of elements in stencil must equal number of elements in map"); - - detail::gather_if(const_cast(in.data_handle()), - out.extent(1), - out.extent(0), - map.data_handle(), - stencil.data_handle(), - map.extent(0), - out.data_handle(), - pred_op, - handle.get_stream()); -} - -/** - * @brief gather_if conditionally copies rows from a source matrix into a destination matrix - * according to a transformed map. - * - * @tparam matrix_t Matrix value type, for reading input matrix - * @tparam map_t Vector value type for map - * @tparam stencil_t Vector value type for stencil - * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result - * type must be convertible to bool type. - * @tparam map_xform_t Unary lambda expression or operator type, map_xform_t's result - * type must be convertible to idx_t (= int) type. - * @tparam idx_t integer type for indexing - * @param[in] handle raft handle for managing resources - * @param[in] in Input matrix (assumed to be row-major) - * @param[in] map Vector of gather locations - * @param[in] stencil Vector of stencil or predicate values - * @param[out] out Output matrix (assumed to be row-major) - * @param[in] pred_op Predicate to apply to the stencil values - * @param[in] transform_op The transformation operation, transforms the map values to idx_t - */ -template -void gather_if(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - raft::device_vector_view map, - raft::device_vector_view stencil, unary_pred_t pred_op, - map_xform_t transform_op) + map_xform_t transform_op = raft::identity_op()) { RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index 0bea62e9cf..3659265e84 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,50 +18,72 @@ #include #include #include +#include #include #include #include +#include #include namespace raft { -template -void naiveGatherImpl( - MatrixIteratorT in, int D, int N, MapIteratorT map, int map_length, MatrixIteratorT out) +template +void naiveGather(InputIteratorT in, + IdxT D, + IdxT N, + MapIteratorT map, + StencilIteratorT stencil, + IdxT map_length, + OutputIteratorT out, + UnaryPredicateOp pred_op, + MapTransformOp transform_op) { - for (int outRow = 0; outRow < map_length; ++outRow) { + for (IdxT outRow = 0; outRow < map_length; ++outRow) { + if constexpr (Conditional) { + auto stencil_val = stencil[outRow]; + if (!pred_op(stencil_val)) continue; + } typename std::iterator_traits::value_type map_val = map[outRow]; - int inRowStart = map_val * D; - int outRowStart = outRow * D; - for (int i = 0; i < D; ++i) { + IdxT transformed_val; + if constexpr (MapTransform) { + transformed_val = transform_op(map_val); + } else { + transformed_val = map_val; + } + IdxT inRowStart = transformed_val * D; + IdxT outRowStart = outRow * D; + for (IdxT i = 0; i < D; ++i) { out[outRowStart + i] = in[inRowStart + i]; } } } -template -void naiveGather( - MatrixIteratorT in, int D, int N, MapIteratorT map, int map_length, MatrixIteratorT out) -{ - naiveGatherImpl(in, D, N, map, map_length, out); -} - +template struct GatherInputs { - uint32_t nrows; - uint32_t ncols; - uint32_t map_length; + IdxT nrows; + IdxT ncols; + IdxT map_length; unsigned long long int seed; }; -template -class GatherTest : public ::testing::TestWithParam { +template +class GatherTest : public ::testing::TestWithParam> { protected: GatherTest() : stream(handle.get_stream()), - params(::testing::TestWithParam::GetParam()), + params(::testing::TestWithParam>::GetParam()), d_in(0, stream), d_out_exp(0, stream), d_out_act(0, stream), + d_stencil(0, stream), d_map(0, stream) { } @@ -71,44 +93,71 @@ class GatherTest : public ::testing::TestWithParam { raft::random::RngState r(params.seed); raft::random::RngState r_int(params.seed); - uint32_t nrows = params.nrows; - uint32_t ncols = params.ncols; - uint32_t map_length = params.map_length; - uint32_t len = nrows * ncols; + IdxT map_length = params.map_length; + IdxT len = params.nrows * params.ncols; // input matrix setup - d_in.resize(nrows * ncols, stream); - h_in.resize(nrows * ncols); + d_in.resize(params.nrows * params.ncols, stream); + h_in.resize(params.nrows * params.ncols); raft::random::uniform(handle, r, d_in.data(), len, MatrixT(-1.0), MatrixT(1.0)); raft::update_host(h_in.data(), d_in.data(), len, stream); // map setup d_map.resize(map_length, stream); h_map.resize(map_length); - raft::random::uniformInt(handle, r_int, d_map.data(), map_length, (MapT)0, nrows); + raft::random::uniformInt(handle, r_int, d_map.data(), map_length, (MapT)0, (MapT)params.nrows); raft::update_host(h_map.data(), d_map.data(), map_length, stream); - // expected and actual output matrix setup - h_out.resize(map_length * ncols); - d_out_exp.resize(map_length * ncols, stream); - d_out_act.resize(map_length * ncols, stream); + // stencil setup + if (Conditional) { + d_stencil.resize(map_length, stream); + h_stencil.resize(map_length); + raft::random::uniform(handle, r, d_stencil.data(), map_length, MatrixT(-1.0), MatrixT(1.0)); + raft::update_host(h_stencil.data(), d_stencil.data(), map_length, stream); + } - // launch gather on the host and copy the results to device - naiveGather(h_in.data(), ncols, nrows, h_map.data(), map_length, h_out.data()); - raft::update_device(d_out_exp.data(), h_out.data(), map_length * ncols, stream); + // unary predicate op (used only when Conditional is true) + auto pred_op = raft::plug_const_op(MatrixT(0.0), raft::greater_op()); - auto in_view = raft::make_device_matrix_view( - d_in.data(), nrows, ncols); - auto out_view = - raft::make_device_matrix_view(d_out_act.data(), map_length, ncols); - auto map_view = - raft::make_device_vector_view(d_map.data(), map_length); + // map transform op (used only when MapTransform is true) + auto transform_op = + raft::compose_op(raft::modulo_const_op(params.nrows), raft::add_const_op(10)); - raft::matrix::gather(handle, in_view, map_view, out_view); + // expected and actual output matrix setup + h_out.resize(map_length * params.ncols); + d_out_exp.resize(map_length * params.ncols, stream); + d_out_act.resize(map_length * params.ncols, stream); - // // launch device version of the kernel - // gatherLaunch( - // handle, d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); + // launch gather on the host and copy the results to device + naiveGather(h_in.data(), + params.ncols, + params.nrows, + h_map.data(), + h_stencil.data(), + map_length, + h_out.data(), + pred_op, + transform_op); + raft::update_device(d_out_exp.data(), h_out.data(), map_length * params.ncols, stream); + + auto in_view = raft::make_device_matrix_view( + d_in.data(), params.nrows, params.ncols); + auto out_view = raft::make_device_matrix_view( + d_out_act.data(), map_length, params.ncols); + auto map_view = raft::make_device_vector_view(d_map.data(), map_length); + auto stencil_view = + raft::make_device_vector_view(d_stencil.data(), map_length); + + if (Conditional && MapTransform) { + raft::matrix::gather_if( + handle, in_view, out_view, map_view, stencil_view, pred_op, transform_op); + } else if (Conditional) { + raft::matrix::gather_if(handle, in_view, out_view, map_view, stencil_view, pred_op); + } else if (MapTransform) { + raft::matrix::gather(handle, in_view, map_view, out_view, transform_op); + } else { + raft::matrix::gather(handle, in_view, map_view, out_view); + } handle.sync_stream(stream); } @@ -116,41 +165,46 @@ class GatherTest : public ::testing::TestWithParam { protected: raft::handle_t handle; cudaStream_t stream = 0; - GatherInputs params; - std::vector h_in, h_out; + GatherInputs params; + std::vector h_in, h_out, h_stencil; std::vector h_map; - rmm::device_uvector d_in, d_out_exp, d_out_act; + rmm::device_uvector d_in, d_out_exp, d_out_act, d_stencil; rmm::device_uvector d_map; }; -const std::vector inputs = {{1024, 32, 128, 1234ULL}, - {1024, 32, 256, 1234ULL}, - {1024, 32, 512, 1234ULL}, - {1024, 32, 1024, 1234ULL}, - {1024, 64, 128, 1234ULL}, - {1024, 64, 256, 1234ULL}, - {1024, 64, 512, 1234ULL}, - {1024, 64, 1024, 1234ULL}, - {1024, 128, 128, 1234ULL}, - {1024, 128, 256, 1234ULL}, - {1024, 128, 512, 1234ULL}, - {1024, 128, 1024, 1234ULL}}; - -typedef GatherTest GatherTestF; -TEST_P(GatherTestF, Result) -{ - ASSERT_TRUE(devArrMatch( - d_out_exp.data(), d_out_act.data(), params.map_length * params.ncols, raft::Compare())); -} - -typedef GatherTest GatherTestD; -TEST_P(GatherTestD, Result) -{ - ASSERT_TRUE(devArrMatch( - d_out_exp.data(), d_out_act.data(), params.map_length * params.ncols, raft::Compare())); -} - -INSTANTIATE_TEST_CASE_P(GatherTests, GatherTestF, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(GatherTests, GatherTestD, ::testing::ValuesIn(inputs)); +#define GATHER_TEST(test_type, test_name, test_inputs) \ + typedef RAFT_DEPAREN(test_type) test_name; \ + TEST_P(test_name, Result) \ + { \ + ASSERT_TRUE(devArrMatch(d_out_exp.data(), \ + d_out_act.data(), \ + params.map_length* params.ncols, \ + raft::Compare())); \ + } \ + INSTANTIATE_TEST_CASE_P(GatherTests, test_name, ::testing::ValuesIn(test_inputs)) + +const std::vector> inputs_i32 = + raft::util::itertools::product>({25, 2000}, {6, 31, 129}, {11, 999}, {1234ULL}); +const std::vector> inputs_i64 = + raft::util::itertools::product>( + {25, 2000}, {6, 31, 129}, {11, 999}, {1234ULL}); + +GATHER_TEST((GatherTest), GatherTestFU32I32, inputs_i32); +GATHER_TEST((GatherTest), + GatherTransformTestFU32I32, + inputs_i32); +GATHER_TEST((GatherTest), GatherIfTestFU32I32, inputs_i32); +GATHER_TEST((GatherTest), + GatherIfTransformTestFU32I32, + inputs_i32); +GATHER_TEST((GatherTest), + GatherIfTransformTestDU32I32, + inputs_i32); +GATHER_TEST((GatherTest), + GatherIfTransformTestFU32I64, + inputs_i64); +GATHER_TEST((GatherTest), + GatherIfTransformTestFI64I64, + inputs_i64); } // end namespace raft \ No newline at end of file