diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt
index 3c3d296b06..0a05682bcb 100644
--- a/cpp/bench/CMakeLists.txt
+++ b/cpp/bench/CMakeLists.txt
@@ -94,6 +94,7 @@ if(BUILD_BENCH)
     bench/linalg/add.cu
     bench/linalg/map_then_reduce.cu
     bench/linalg/matrix_vector_op.cu
+    bench/linalg/normalize.cu
     bench/linalg/reduce_rows_by_key.cu
     bench/linalg/reduce.cu
     bench/main.cpp
diff --git a/cpp/bench/linalg/normalize.cu b/cpp/bench/linalg/normalize.cu
new file mode 100644
index 0000000000..d01473ffeb
--- /dev/null
+++ b/cpp/bench/linalg/normalize.cu
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <common/benchmark.hpp>
+#include <raft/linalg/normalize.cuh>
+#include <raft/random/rng.cuh>
+#include <raft/util/itertools.hpp>
+
+#include <rmm/device_uvector.hpp>
+
+namespace raft::bench::linalg {
+
+template <typename IdxT>
+struct normalize_input {
+  IdxT rows, cols;
+};
+
+template <typename IdxT>
+inline auto operator<<(std::ostream& os, const normalize_input<IdxT>& p) -> std::ostream&
+{
+  os << p.rows << "#" << p.cols;
+  return os;
+}
+
+template <typename T, typename IdxT>
+struct rowNormalize : public fixture {
+  rowNormalize(const normalize_input<IdxT>& p)
+    : params(p), in(p.rows * p.cols, stream), out(p.rows * p.cols, stream)
+  {
+    raft::random::RngState rng{1234};
+    raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream);
+  }
+
+  void run_benchmark(::benchmark::State& state) override
+  {
+    std::ostringstream label_stream;
+    label_stream << params;
+    state.SetLabel(label_stream.str());
+
+    loop_on_state(state, [this]() {
+      auto input_view = raft::make_device_matrix_view<const T, IdxT, raft::row_major>(
+        in.data(), params.rows, params.cols);
+      auto output_view = raft::make_device_matrix_view<T, IdxT, raft::row_major>(
+        out.data(), params.rows, params.cols);
+      raft::linalg::row_normalize(handle, input_view, output_view, raft::linalg::L2Norm);
+    });
+  }
+
+ private:
+  normalize_input<IdxT> params;
+  rmm::device_uvector<T> in, out;
+};  // struct rowNormalize
+
+const std::vector<normalize_input<int>> normalize_inputs_i32 =
+  raft::util::itertools::product<normalize_input<int>>(
+    {10, 100, 1000, 10000, 100000}, {8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384});
+const std::vector<normalize_input<int64_t>> normalize_inputs_i64 =
+  raft::util::itertools::product<normalize_input<int64_t>>(
+    {10, 100, 1000, 10000, 100000}, {8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384});
+
+RAFT_BENCH_REGISTER((rowNormalize<float, int>), "", normalize_inputs_i32);
+RAFT_BENCH_REGISTER((rowNormalize<double, int>), "", normalize_inputs_i32);
+RAFT_BENCH_REGISTER((rowNormalize<float, int64_t>), "", normalize_inputs_i64);
+RAFT_BENCH_REGISTER((rowNormalize<double, int64_t>), "", normalize_inputs_i64);
+
+}  // namespace raft::bench::linalg
diff --git a/cpp/include/raft/linalg/detail/norm.cuh b/cpp/include/raft/linalg/detail/norm.cuh
index a0b557211c..f2f08233d5 100644
--- a/cpp/include/raft/linalg/detail/norm.cuh
+++ b/cpp/include/raft/linalg/detail/norm.cuh
@@ -16,15 +16,13 @@
 
 #pragma once
 
+#include <raft/linalg/norm_types.hpp>
 #include <raft/linalg/reduce.cuh>
 
 namespace raft {
 namespace linalg {
 namespace detail {
 
-/** different types of norms supported on the input buffers */
-enum NormType { L1Norm = 0, L2Norm };
-
 template <typename Type, typename IdxType, typename Lambda>
 void rowNormCaller(Type* dots,
                    const Type* data,
@@ -64,7 +62,21 @@ void rowNormCaller(Type* dots,
                            raft::Sum<Type>(),
                            fin_op);
       break;
-    default: ASSERT(false, "Invalid norm type passed! [%d]", type);
+    case LinfNorm:
+      raft::linalg::reduce(dots,
+                           data,
+                           D,
+                           N,
+                           (Type)0,
+                           rowMajor,
+                           true,
+                           stream,
+                           false,
+                           raft::L1Op<Type, IdxType>(),
+                           raft::Max<Type>(),
+                           fin_op);
+      break;
+    default: THROW("Unsupported norm type: %d", type);
   };
 }
 
@@ -89,7 +101,7 @@ void colNormCaller(Type* dots,
                            false,
                            stream,
                            false,
-                           raft::L1Op<Type>(),
+                           raft::L1Op<Type, IdxType>(),
                            raft::Sum<Type>(),
                            fin_op);
       break;
@@ -103,11 +115,25 @@ void colNormCaller(Type* dots,
                            false,
                            stream,
                            false,
-                           raft::L2Op<Type>(),
+                           raft::L2Op<Type, IdxType>(),
                            raft::Sum<Type>(),
                            fin_op);
       break;
-    default: ASSERT(false, "Invalid norm type passed! [%d]", type);
+    case LinfNorm:
+      raft::linalg::reduce(dots,
+                           data,
+                           D,
+                           N,
+                           (Type)0,
+                           rowMajor,
+                           false,
+                           stream,
+                           false,
+                           raft::L1Op<Type, IdxType>(),
+                           raft::Max<Type>(),
+                           fin_op);
+      break;
+    default: THROW("Unsupported norm type: %d", type);
   };
 }
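A note on the `LinfNorm` cases added above: no new reduction primitive is required, since the Chebyshev norm is just `raft::linalg::reduce` composed with an absolute-value `main_op` and a max `reduce_op`. Through the existing public API this works as sketched below (`norms`, `data` and the extents are hypothetical):

```cpp
// Row-wise Linf (Chebyshev) norms of an N x D row-major matrix;
// `norms` must have room for N elements on the device.
raft::linalg::rowNorm(norms, data, D, N, raft::linalg::LinfNorm, true /* rowMajor */, stream);
```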
diff --git a/cpp/include/raft/linalg/detail/normalize.cuh b/cpp/include/raft/linalg/detail/normalize.cuh
new file mode 100644
index 0000000000..78c773ab35
--- /dev/null
+++ b/cpp/include/raft/linalg/detail/normalize.cuh
@@ -0,0 +1,187 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <raft/util/cuda_utils.cuh>
+
+#include <cub/cub.cuh>
+
+namespace raft {
+namespace linalg {
+namespace detail {
+
+template <int warpSize, int rpb>
+struct NormalizeThinPolicy {
+  static constexpr int LogicalWarpSize = warpSize;
+  static constexpr int RowsPerBlock    = rpb;
+  static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock;
+};
+
+template <typename Policy,
+          typename Type,
+          typename IdxType,
+          typename MainLambda,
+          typename ReduceLambda,
+          typename FinalLambda>
+__global__ void __launch_bounds__(Policy::ThreadsPerBlock)
+  coalesced_normalize_thin_kernel(Type* out,
+                                  const Type* in,
+                                  IdxType D,
+                                  IdxType N,
+                                  Type init,
+                                  MainLambda main_op,
+                                  ReduceLambda reduce_op,
+                                  FinalLambda fin_op,
+                                  Type eps)
+{
+  IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast<IdxType>(blockIdx.x));
+  if (i >= N) return;
+
+  Type acc = init;
+  for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) {
+    Type val = in[j + D * i];
+    acc      = reduce_op(acc, main_op(val, j));
+  }
+  acc = raft::logicalWarpReduce<Policy::LogicalWarpSize>(acc, reduce_op);
+  acc = fin_op(acc);
+  if (acc <= eps) return;
+  for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) {
+    out[j + D * i] = in[j + D * i] / acc;
+  }
+}
+
+template <typename Type,
+          typename IdxType,
+          typename Policy,
+          typename MainLambda,
+          typename ReduceLambda,
+          typename FinalLambda>
+inline void coalesced_normalize_thin(Type* out,
+                                     const Type* in,
+                                     IdxType D,
+                                     IdxType N,
+                                     Type init,
+                                     cudaStream_t stream,
+                                     MainLambda main_op,
+                                     ReduceLambda reduce_op,
+                                     FinalLambda fin_op,
+                                     Type eps)
+{
+  dim3 grid(ceildiv(N, (IdxType)Policy::RowsPerBlock), 1, 1);
+  dim3 block(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1);
+  coalesced_normalize_thin_kernel<Policy>
+    <<<grid, block, 0, stream>>>(out, in, D, N, init, main_op, reduce_op, fin_op, eps);
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+}
+
+template <int TPB,
+          typename Type,
+          typename IdxType,
+          typename MainLambda,
+          typename ReduceLambda,
+          typename FinalLambda>
+__global__ void __launch_bounds__(TPB) coalesced_normalize_medium_kernel(Type* out,
+                                                                         const Type* in,
+                                                                         IdxType D,
+                                                                         IdxType N,
+                                                                         Type init,
+                                                                         MainLambda main_op,
+                                                                         ReduceLambda reduce_op,
+                                                                         FinalLambda fin_op,
+                                                                         Type eps)
+{
+  typedef cub::BlockReduce<Type, TPB> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ Type bcast_acc;
+  Type thread_data = init;
+  IdxType rowStart = blockIdx.x * D;
+  for (IdxType i = threadIdx.x; i < D; i += TPB) {
+    IdxType idx = rowStart + i;
+    thread_data = reduce_op(thread_data, main_op(in[idx], i));
+  }
+  Type acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op);
+  if (threadIdx.x == 0) { bcast_acc = fin_op(acc); }
+  __syncthreads();
+  if (bcast_acc <= eps) return;
+  for (IdxType i = threadIdx.x; i < D; i += TPB) {
+    IdxType idx = rowStart + i;
+    out[idx]    = in[idx] / bcast_acc;
+  }
+}
+
+template <int TPB,
+          typename Type,
+          typename IdxType,
+          typename MainLambda,
+          typename ReduceLambda,
+          typename FinalLambda>
+inline void coalesced_normalize_medium(Type* out,
+                                       const Type* in,
+                                       IdxType D,
+                                       IdxType N,
+                                       Type init,
+                                       cudaStream_t stream,
+                                       MainLambda main_op,
+                                       ReduceLambda reduce_op,
+                                       FinalLambda fin_op,
+                                       Type eps)
+{
+  coalesced_normalize_medium_kernel<TPB>
+    <<<N, TPB, 0, stream>>>(out, in, D, N, init, main_op, reduce_op, fin_op, eps);
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+}
+
+template <typename Type,
+          typename IdxType,
+          typename MainLambda,
+          typename ReduceLambda,
+          typename FinalLambda>
+void coalesced_normalize(Type* out,
+                         const Type* in,
+                         IdxType D,
+                         IdxType N,
+                         Type init,
+                         cudaStream_t stream,
+                         MainLambda main_op,
+                         ReduceLambda reduce_op,
+                         FinalLambda fin_op,
+                         Type eps)
+{
+  const IdxType numSMs = raft::getMultiProcessorCount();
+  if (D <= IdxType(256) || (D <= IdxType(512) && N >= 4 * numSMs)) {
+    if (D <= IdxType(2)) {
+      coalesced_normalize_thin<Type, IdxType, NormalizeThinPolicy<2, 64>>(
+        out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
+    } else if (D <= IdxType(4)) {
+      coalesced_normalize_thin<Type, IdxType, NormalizeThinPolicy<4, 32>>(
+        out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
+    } else if (D <= IdxType(8)) {
+      coalesced_normalize_thin<Type, IdxType, NormalizeThinPolicy<8, 16>>(
+        out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
+    } else if (D <= IdxType(16)) {
+      coalesced_normalize_thin<Type, IdxType, NormalizeThinPolicy<16, 8>>(
+        out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
+    } else {
+      coalesced_normalize_thin<Type, IdxType, NormalizeThinPolicy<32, 4>>(
+        out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
+    }
+  } else {
+    coalesced_normalize_medium<256>(out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
+  }
+}
+
+}  // namespace detail
+}  // namespace linalg
+}  // namespace raft
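An illustration of how the dispatch at the bottom behaves (using the policy choices visible in the code above): short rows get one logical warp of 2 to 32 lanes per row with 128 threads per block, e.g. `D = 8` selects `NormalizeThinPolicy<8, 16>` (8-lane groups, 16 rows per block), so `N = 1000` rows launch `ceildiv(1000, 16) = 63` blocks of `dim3(8, 16)`; longer rows fall back to one 256-thread block per row. A hedged sketch of what the public API below ultimately calls for an L2 normalization:

```cpp
// Sketch only: `out`/`in` are hypothetical device pointers to an n x d row-major matrix.
void l2_normalize_rows_detail(float* out, const float* in, int d, int n, cudaStream_t stream)
{
  raft::linalg::detail::coalesced_normalize(out, in, d, n,
                                            0.0f,                   // identity of the sum
                                            stream,
                                            raft::L2Op<float>(),    // main_op: x -> x * x
                                            raft::Sum<float>(),     // reduce_op: +
                                            raft::SqrtOp<float>(),  // fin_op: sqrt of the sum
                                            1e-8f);  // rows with norm <= eps are not written
}
```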
diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh
index 389affef13..b756744755 100644
--- a/cpp/include/raft/linalg/norm.cuh
+++ b/cpp/include/raft/linalg/norm.cuh
@@ -22,16 +22,12 @@
 #include "linalg_types.hpp"
 
 #include <raft/core/device_mdspan.hpp>
+#include <raft/linalg/norm_types.hpp>
 #include <raft/util/input_validation.hpp>
 
 namespace raft {
 namespace linalg {
 
-/** different types of norms supported on the input buffers */
-using detail::L1Norm;
-using detail::L2Norm;
-using detail::NormType;
-
 /**
  * @brief Compute row-wise norm of the input matrix and perform fin_op lambda
  *
@@ -44,7 +40,7 @@
  * @tparam Lambda device final lambda
  * @tparam IdxType Integer type used to for addressing
  * @param dots the output vector of row-wise dot products
- * @param data the input matrix (currently assumed to be row-major)
+ * @param data the input matrix
 * @param D number of columns of data
 * @param N number of rows of data
 * @param type the type of norm to be applied
@@ -71,7 +67,7 @@ void rowNorm(Type* dots,
 * @tparam Lambda device final lambda
 * @tparam IdxType Integer type used to for addressing
 * @param dots the output vector of column-wise dot products
- * @param data the input matrix (currently assumed to be row-major)
+ * @param data the input matrix
 * @param D number of columns of data
 * @param N number of rows of data
 * @param type the type of norm to be applied
diff --git a/cpp/include/raft/linalg/norm_types.hpp b/cpp/include/raft/linalg/norm_types.hpp
new file mode 100644
index 0000000000..d399e588ce
--- /dev/null
+++ b/cpp/include/raft/linalg/norm_types.hpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace raft {
+namespace linalg {
+
+/** Enum to tell how to compute a norm */
+enum NormType : unsigned short {
+  /** L0 (actually not a norm): sum((x_i != 0 ? 1 : 0)) */
+  L0PseudoNorm = 0,
+  /** L1 norm or Manhattan: sum(abs(x_i)) */
+  L1Norm = 1,
+  /** L2 norm or Euclidean: sqrt(sum(x_i^2)). Note that in some prims the square root is optional,
+      in which case it can be specified using a boolean or a functor final_op */
+  L2Norm = 2,
+  /** Linf norm or Chebyshev: max(abs(x_i)) */
+  LinfNorm
+};
+
+}  // namespace linalg
+}  // namespace raft
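Because `NormType` now lives directly in `raft::linalg` (previously it was defined in `detail` and re-exported through using-declarations), existing call sites keep compiling unchanged, e.g.:

```cpp
// Still resolves exactly as before the move out of detail::
raft::linalg::NormType t = raft::linalg::L2Norm;
```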
diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh
new file mode 100644
index 0000000000..4bdf697581
--- /dev/null
+++ b/cpp/include/raft/linalg/normalize.cuh
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "detail/normalize.cuh"
+
+#include <raft/core/device_mdspan.hpp>
+
+namespace raft {
+namespace linalg {
+
+/**
+ * @brief Divide rows by their norm defined by main_op, reduce_op and fin_op
+ *
+ * @tparam ElementType Input/Output data type
+ * @tparam IndexType Integer type used to for addressing
+ * @tparam MainLambda Type of main_op
+ * @tparam ReduceLambda Type of reduce_op
+ * @tparam FinalLambda Type of fin_op
+ * @param[in] handle raft::handle_t
+ * @param[in] in the input raft::device_matrix_view
+ * @param[out] out the output raft::device_matrix_view
+ * @param[in] init Initialization value, i.e identity element for the reduction operation
+ * @param[in] main_op Operation to apply to the elements before reducing them (e.g square for L2)
+ * @param[in] reduce_op Operation to reduce a pair of elements (e.g sum for L2)
+ * @param[in] fin_op Operation to apply once to the reduction result to finalize the norm
+ * computation (e.g sqrt for L2)
+ * @param[in] eps If the norm is below eps, the row is considered zero and no division is applied
+ */
+template <typename ElementType,
+          typename IndexType,
+          typename MainLambda,
+          typename ReduceLambda,
+          typename FinalLambda>
+void row_normalize(const raft::handle_t& handle,
+                   raft::device_matrix_view<const ElementType, IndexType, raft::row_major> in,
+                   raft::device_matrix_view<ElementType, IndexType, raft::row_major> out,
+                   ElementType init,
+                   MainLambda main_op,
+                   ReduceLambda reduce_op,
+                   FinalLambda fin_op,
+                   ElementType eps = ElementType(1e-8))
+{
+  RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous");
+  RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
+  RAFT_EXPECTS(in.extent(0) == out.extent(0),
+               "The number of rows of the input and output should be equal");
+  RAFT_EXPECTS(in.extent(1) == out.extent(1),
+               "The number of columns of the input and output should be equal");
+
+  detail::coalesced_normalize(out.data_handle(),
+                              in.data_handle(),
+                              in.extent(1),
+                              in.extent(0),
+                              init,
+                              handle.get_stream(),
+                              main_op,
+                              reduce_op,
+                              fin_op,
+                              eps);
+}
+
+/**
+ * @brief Divide rows by their norm.
+ *
+ * @tparam ElementType Input/Output data type
+ * @tparam IndexType Integer type used to for addressing
+ * @param[in] handle raft::handle_t
+ * @param[in] in the input raft::device_matrix_view
+ * @param[out] out the output raft::device_matrix_view
+ * @param[in] norm_type the type of norm to be applied
+ * @param[in] eps If the norm is below eps, the row is considered zero and no division is applied
+ */
+template <typename ElementType, typename IndexType>
+void row_normalize(const raft::handle_t& handle,
+                   raft::device_matrix_view<const ElementType, IndexType, raft::row_major> in,
+                   raft::device_matrix_view<ElementType, IndexType, raft::row_major> out,
+                   NormType norm_type,
+                   ElementType eps = ElementType(1e-8))
+{
+  switch (norm_type) {
+    case L1Norm:
+      row_normalize(handle,
+                    in,
+                    out,
+                    ElementType(0),
+                    raft::L1Op<ElementType, IndexType>(),
+                    raft::Sum<ElementType>(),
+                    raft::Nop<ElementType>(),
+                    eps);
+      break;
+    case L2Norm:
+      row_normalize(handle,
+                    in,
+                    out,
+                    ElementType(0),
+                    raft::L2Op<ElementType, IndexType>(),
+                    raft::Sum<ElementType>(),
+                    raft::SqrtOp<ElementType>(),
+                    eps);
+      break;
+    case LinfNorm:
+      row_normalize(handle,
+                    in,
+                    out,
+                    ElementType(0),
+                    raft::L1Op<ElementType, IndexType>(),
+                    raft::Max<ElementType>(),
+                    raft::Nop<ElementType>(),
+                    eps);
+      break;
+    default: THROW("Unsupported norm type: %d", norm_type);
+  }
+}
+
+}  // namespace linalg
+}  // namespace raft
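A minimal usage sketch of the new public API (the handle, pointers and extents here are hypothetical):

```cpp
#include <raft/core/handle.hpp>
#include <raft/linalg/normalize.cuh>

void normalize_rows_l2(const raft::handle_t& handle, const float* d_in, float* d_out, int n, int d)
{
  auto in  = raft::make_device_matrix_view<const float, int, raft::row_major>(d_in, n, d);
  auto out = raft::make_device_matrix_view<float, int, raft::row_major>(d_out, n, d);
  raft::linalg::row_normalize(handle, in, out, raft::linalg::L2Norm);
}
```

The generic overload accepts any `main_op`/`reduce_op`/`fin_op` triple, so e.g. an L1 normalization is `row_normalize(handle, in, out, 0.0f, raft::L1Op<float>(), raft::Sum<float>(), raft::Nop<float>())`.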
diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
index f92a74918b..961cc76381 100644
--- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
@@ -32,6 +32,7 @@
 #include <raft/linalg/gemm.cuh>
 #include <raft/linalg/matrix_vector_op.cuh>
 #include <raft/linalg/norm.cuh>
+#include <raft/linalg/normalize.cuh>
 #include <raft/linalg/unary_op.cuh>
 #include <raft/matrix/matrix.cuh>
 #include <raft/random/rng.cuh>
@@ -663,8 +664,16 @@ void balancing_em_iters(const handle_t& handle,
     // To avoid converging to zero, we normalize the center vectors on every iteration.
     case raft::distance::DistanceType::InnerProduct:
     case raft::distance::DistanceType::CosineExpanded:
-    case raft::distance::DistanceType::CorrelationExpanded:
-      utils::normalize_rows(n_clusters, dim, cluster_centers, stream);
+    case raft::distance::DistanceType::CorrelationExpanded: {
+      auto clusters_in_view =
+        raft::make_device_matrix_view<const float, uint32_t, raft::row_major>(
+          cluster_centers, n_clusters, dim);
+      auto clusters_out_view = raft::make_device_matrix_view<float, uint32_t, raft::row_major>(
+        cluster_centers, n_clusters, dim);
+      raft::linalg::row_normalize(
+        handle, clusters_in_view, clusters_out_view, raft::linalg::L2Norm);
+      break;
+    }
     default: break;
   }
   // E: Expectation step - predict labels
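One subtlety in the hunk above: `clusters_in_view` and `clusters_out_view` both alias `cluster_centers`, so the normalization runs in place, matching the semantics of the removed `utils::normalize_rows`. This is safe with the new kernels because each element is read and written by the same thread, and only after that row's norm has been fully reduced.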
diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
index 5d031cc51d..f27aeb24f9 100644
--- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
@@ -199,49 +199,6 @@ inline void dots_along_rows(
    */
 }
 
-template <typename IdxT>
-__global__ void normalize_rows_kernel(IdxT n_rows, IdxT n_cols, float* a)
-{
-  IdxT i = threadIdx.y + (blockDim.y * static_cast<IdxT>(blockIdx.x));
-  if (i >= n_rows) return;
-
-  float sqsum = 0.0;
-  for (IdxT j = threadIdx.x; j < n_cols; j += blockDim.x) {
-    float val = a[j + (n_cols * i)];
-    sqsum += val * val;
-  }
-  sqsum += __shfl_xor_sync(0xffffffff, sqsum, 1);
-  sqsum += __shfl_xor_sync(0xffffffff, sqsum, 2);
-  sqsum += __shfl_xor_sync(0xffffffff, sqsum, 4);
-  sqsum += __shfl_xor_sync(0xffffffff, sqsum, 8);
-  sqsum += __shfl_xor_sync(0xffffffff, sqsum, 16);
-  if (sqsum <= 1e-8) return;
-  sqsum = rsqrtf(sqsum);  // reciprocal of the square root
-  for (IdxT j = threadIdx.x; j < n_cols; j += blockDim.x) {
-    a[j + n_cols * i] *= sqsum;
-  }
-}
-
-/**
- * @brief Divide rows by their L2 norm (square root of sum of squares).
- *
- * NB: device-only function
- *
- * @tparam IdxT index type
- *
- * @param[in] n_rows
- * @param[in] n_cols
- * @param[inout] a device pointer to a row-major matrix [n_rows, n_cols]
- * @param stream
- */
-template <typename IdxT>
-inline void normalize_rows(IdxT n_rows, IdxT n_cols, float* a, rmm::cuda_stream_view stream)
-{
-  dim3 threads(32, 4, 1);  // DO NOT CHANGE
-  dim3 blocks(ceildiv<IdxT>(n_rows, threads.y), 1, 1);
-  normalize_rows_kernel<<<blocks, threads, 0, stream>>>(n_rows, n_cols, a);
-}
-
 template <typename T, typename IdxT>
 __global__ void outer_add_kernel(const T* a, IdxT len_a, const T* b, IdxT len_b, T* c)
 {
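The helper removed here is subsumed by the generic path: its hard-coded `dim3 threads(32, 4)` layout and five-step `__shfl_xor_sync` butterfly correspond exactly to `coalesced_normalize_thin` with `NormalizeThinPolicy<32, 4>` and `logicalWarpReduce<32>`, without the float-only, L2-only restriction.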
diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh
index 1d1c82eb94..e5b58718a0 100644
--- a/cpp/include/raft/util/cuda_utils.cuh
+++ b/cpp/include/raft/util/cuda_utils.cuh
@@ -516,6 +516,16 @@ struct Nop {
   HDI Type operator()(Type in, IdxType i = 0) { return in; }
 };
 
+template <typename Type, typename IdxType = int>
+struct SqrtOp {
+  HDI Type operator()(Type in, IdxType i = 0) { return mySqrt(in); }
+};
+
+template <typename Type, typename IdxType = int>
+struct L0Op {
+  HDI Type operator()(Type in, IdxType i = 0) { return in != Type(0) ? Type(1) : Type(0); }
+};
+
 template <typename Type, typename IdxType = int>
 struct L1Op {
   HDI Type operator()(Type in, IdxType i = 0) { return myAbs(in); }
@@ -530,6 +540,11 @@ template <typename Type>
 struct Sum {
   HDI Type operator()(Type a, Type b) { return a + b; }
 };
+
+template <typename Type>
+struct Max {
+  HDI Type operator()(Type a, Type b) { return myMax(a, b); }
+};
 /** @} */
 
 /**
@@ -729,6 +744,26 @@ DI auto dp4a(unsigned int a, unsigned int b, unsigned int c) -> unsigned int
 #endif
 }
 
+/**
+ * @brief Logical-warp-level reduction
+ * @tparam logicalWarpSize Logical warp size (2, 4, 8, 16 or 32)
+ * @tparam T Value type to be reduced
+ * @tparam ReduceLambda Reduction operation type
+ * @param val input value
+ * @param reduce_op Reduction operation
+ * @return Reduction result. All lanes will have the valid result.
+ */
+template <int logicalWarpSize, typename T, typename ReduceLambda>
+DI T logicalWarpReduce(T val, ReduceLambda reduce_op)
+{
+#pragma unroll
+  for (int i = logicalWarpSize / 2; i > 0; i >>= 1) {
+    T tmp = shfl_xor(val, i);
+    val   = reduce_op(val, tmp);
+  }
+  return val;
+}
+
 /**
  * @brief Warp-level sum reduction
  * @param val input value
@@ -742,12 +777,7 @@ DI auto dp4a(unsigned int a, unsigned int b, unsigned int c) -> unsigned int
 template <typename T>
 DI T warpReduce(T val)
 {
-#pragma unroll
-  for (int i = WarpSize / 2; i > 0; i >>= 1) {
-    T tmp = shfl_xor(val, i);
-    val += tmp;
-  }
-  return val;
+  return logicalWarpReduce<WarpSize>(val, raft::Sum<T>());
 }
 
 /**
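A quick sketch of the new `logicalWarpReduce` building block (the device helper below is illustrative, not part of the patch):

```cpp
// Each aligned group of 8 lanes reduces with max; after the shfl_xor
// butterfly, all 8 lanes of a group hold that group's maximum.
__device__ float group_max_8(float val)
{
  return raft::logicalWarpReduce<8>(val, raft::Max<float>());
}
```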
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index 31144f6ffd..3e8f944a5b 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -146,6 +146,7 @@ if(BUILD_TESTS)
     test/linalg/mean_squared_error.cu
     test/linalg/multiply.cu
     test/linalg/norm.cu
+    test/linalg/normalize.cu
     test/linalg/power.cu
     test/linalg/reduce.cu
     test/linalg/reduce_cols_by_key.cu
diff --git a/cpp/test/linalg/normalize.cu b/cpp/test/linalg/normalize.cu
new file mode 100644
index 0000000000..cb949b6a5d
--- /dev/null
+++ b/cpp/test/linalg/normalize.cu
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "../test_utils.h"
+#include <gtest/gtest.h>
+#include <raft/linalg/matrix_vector_op.cuh>
+#include <raft/linalg/norm.cuh>
+#include <raft/linalg/normalize.cuh>
+#include <raft/random/rng.cuh>
+#include <raft/util/cuda_utils.cuh>
+#include <raft/util/itertools.hpp>
+
+namespace raft {
+namespace linalg {
+
+template <typename T, typename IdxT>
+struct RowNormalizeInputs {
+  T tolerance;
+  IdxT rows, cols;
+  raft::linalg::NormType norm_type;
+  unsigned long long int seed;
+};
+
+template <typename T, typename IdxT>
+::std::ostream& operator<<(::std::ostream& os, const RowNormalizeInputs<T, IdxT>& I)
+{
+  os << "{ " << I.tolerance << ", " << I.rows << ", " << I.cols << ", " << I.norm_type << ", "
+     << I.seed << '}' << std::endl;
+  return os;
+}
+
+template <typename T, typename IdxT>
+void rowNormalizeRef(
+  T* out, const T* in, IdxT cols, IdxT rows, raft::linalg::NormType norm_type, cudaStream_t stream)
+{
+  rmm::device_uvector<T> norm(rows, stream);
+  if (norm_type == raft::linalg::L2Norm) {
+    raft::linalg::rowNorm(norm.data(), in, cols, rows, norm_type, true, stream, raft::SqrtOp<T>());
+  } else {
+    raft::linalg::rowNorm(norm.data(), in, cols, rows, norm_type, true, stream, raft::Nop<T>());
+  }
+  raft::linalg::matrixVectorOp(
+    out,
+    in,
+    norm.data(),
+    cols,
+    rows,
+    true,
+    false,
+    [] __device__(T a, T b) { return a / b; },
+    stream);
+}
+
+template <typename T, typename IdxT>
+class RowNormalizeTest : public ::testing::TestWithParam<RowNormalizeInputs<T, IdxT>> {
+ public:
+  RowNormalizeTest()
+    : params(::testing::TestWithParam<RowNormalizeInputs<T, IdxT>>::GetParam()),
+      stream(handle.get_stream()),
+      data(params.rows * params.cols, stream),
+      out_exp(params.rows * params.cols, stream),
+      out_act(params.rows * params.cols, stream)
+  {
+  }
+
+  void SetUp() override
+  {
+    raft::random::RngState r(params.seed);
+    int len = params.rows * params.cols;
+    uniform(handle, r, data.data(), len, T(-10.0), T(10.0));
+
+    rowNormalizeRef(
+      out_exp.data(), data.data(), params.cols, params.rows, params.norm_type, stream);
+
+    auto input_view = raft::make_device_matrix_view<const T, IdxT, raft::row_major>(
+      data.data(), params.rows, params.cols);
+    auto output_view = raft::make_device_matrix_view<T, IdxT, raft::row_major>(
+      out_act.data(), params.rows, params.cols);
+    raft::linalg::row_normalize(handle, input_view, output_view, params.norm_type);
+
+    handle.sync_stream(stream);
+  }
+
+ protected:
+  raft::handle_t handle;
+  cudaStream_t stream;
+
+  RowNormalizeInputs<T, IdxT> params;
+  rmm::device_uvector<T> data, out_exp, out_act;
+};
+
+const std::vector<RowNormalizeInputs<float, int>> inputsf_i32 =
+  raft::util::itertools::product<RowNormalizeInputs<float, int>>(
+    {0.00001f},
+    {11, 101, 12345},
+    {2, 3, 7, 12, 33, 125, 254},
+    {raft::linalg::L1Norm, raft::linalg::L2Norm, raft::linalg::LinfNorm},
+    {1234ULL});
+const std::vector<RowNormalizeInputs<double, int>> inputsd_i32 =
+  raft::util::itertools::product<RowNormalizeInputs<double, int>>(
+    {0.00000001},
+    {11, 101, 12345},
+    {2, 3, 7, 12, 33, 125, 254},
+    {raft::linalg::L1Norm, raft::linalg::L2Norm, raft::linalg::LinfNorm},
+    {1234ULL});
+const std::vector<RowNormalizeInputs<float, uint32_t>> inputsf_u32 =
+  raft::util::itertools::product<RowNormalizeInputs<float, uint32_t>>(
+    {0.00001f},
+    {11u, 101u, 12345u},
+    {2u, 3u, 7u, 12u, 33u, 125u, 254u},
+    {raft::linalg::L1Norm, raft::linalg::L2Norm, raft::linalg::LinfNorm},
+    {1234ULL});
+const std::vector<RowNormalizeInputs<double, uint32_t>> inputsd_u32 =
+  raft::util::itertools::product<RowNormalizeInputs<double, uint32_t>>(
+    {0.00000001},
+    {11u, 101u, 12345u},
+    {2u, 3u, 7u, 12u, 33u, 125u, 254u},
+    {raft::linalg::L1Norm, raft::linalg::L2Norm, raft::linalg::LinfNorm},
+    {1234ULL});
+
+#define ROWNORMALIZE_TEST(test_type, test_name, test_inputs)               \
+  typedef RAFT_DEPAREN(test_type) test_name;                               \
+  TEST_P(test_name, Result)                                                \
+  {                                                                        \
+    ASSERT_TRUE(raft::devArrMatch(out_exp.data(),                          \
+                                  out_act.data(),                          \
+                                  params.rows* params.cols,                \
+                                  raft::CompareApprox(params.tolerance))); \
+  }                                                                        \
+  INSTANTIATE_TEST_CASE_P(RowNormalizeTests, test_name, ::testing::ValuesIn(test_inputs))
+
+ROWNORMALIZE_TEST((RowNormalizeTest<float, int>), RowNormalizeTestFI32, inputsf_i32);
+ROWNORMALIZE_TEST((RowNormalizeTest<double, int>), RowNormalizeTestDI32, inputsd_i32);
+ROWNORMALIZE_TEST((RowNormalizeTest<float, uint32_t>), RowNormalizeTestFU32, inputsf_u32);
+ROWNORMALIZE_TEST((RowNormalizeTest<double, uint32_t>), RowNormalizeTestDU32, inputsd_u32);
+
+}  // end namespace linalg
+}  // end namespace raft
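Note that the test's reference path is deliberately independent of the fused kernels: `rowNorm` computes the per-row norms (passing `SqrtOp` as the final op for L2, since `rowNorm` otherwise returns the sum of squares) and `matrixVectorOp` performs the broadcast division, so `row_normalize` is validated against the unfused primitives.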