diff --git a/cpp/bench/linalg/matrix_vector_op.cu b/cpp/bench/linalg/matrix_vector_op.cu
index aa8f2667ed..aa388955da 100644
--- a/cpp/bench/linalg/matrix_vector_op.cu
+++ b/cpp/bench/linalg/matrix_vector_op.cu
@@ -20,61 +20,137 @@
 namespace raft::bench::linalg {
 
+template <typename IdxT>
 struct mat_vec_op_inputs {
-  int rows, cols;
+  IdxT rows, cols;
   bool rowMajor, bcastAlongRows;
+  IdxT inAlignOffset, outAlignOffset;
 };  // struct mat_vec_op_inputs
 
-template <typename T>
+template <typename IdxT>
+inline auto operator<<(std::ostream& os, const mat_vec_op_inputs<IdxT>& p) -> std::ostream&
+{
+  os << p.rows << "#" << p.cols << "#" << p.rowMajor << "#" << p.bcastAlongRows << "#"
+     << p.inAlignOffset << "#" << p.outAlignOffset;
+  return os;
+}
+
+template <typename OpT, typename T, typename IdxT>
 struct mat_vec_op : public fixture {
-  mat_vec_op(const mat_vec_op_inputs& p)
+  mat_vec_op(const mat_vec_op_inputs<IdxT>& p)
     : params(p),
-      out(p.rows * p.cols, stream),
-      in(p.rows * p.cols, stream),
-      vec(p.bcastAlongRows ? p.cols : p.rows, stream)
+      out(p.rows * p.cols + params.outAlignOffset, stream),
+      in(p.rows * p.cols + params.inAlignOffset, stream),
+      vec1(p.bcastAlongRows ? p.cols : p.rows, stream),
+      vec2(p.bcastAlongRows ? p.cols : p.rows, stream)
   {
   }
 
   void run_benchmark(::benchmark::State& state) override
   {
+    std::ostringstream label_stream;
+    label_stream << params;
+    state.SetLabel(label_stream.str());
+
     loop_on_state(state, [this]() {
-      raft::linalg::matrixVectorOp(out.data(),
-                                   in.data(),
-                                   vec.data(),
-                                   params.cols,
-                                   params.rows,
-                                   params.rowMajor,
-                                   params.bcastAlongRows,
-                                   raft::Sum<T>(),
-                                   stream);
+      if constexpr (OpT::useTwoVectors) {
+        raft::linalg::matrixVectorOp(out.data() + params.outAlignOffset,
+                                     in.data() + params.inAlignOffset,
+                                     vec1.data(),
+                                     vec2.data(),
+                                     params.cols,
+                                     params.rows,
+                                     params.rowMajor,
+                                     params.bcastAlongRows,
+                                     OpT{},
+                                     stream);
+      } else {
+        raft::linalg::matrixVectorOp(out.data() + params.outAlignOffset,
+                                     in.data() + params.inAlignOffset,
+                                     vec1.data(),
+                                     params.cols,
+                                     params.rows,
+                                     params.rowMajor,
+                                     params.bcastAlongRows,
+                                     OpT{},
+                                     stream);
+      }
     });
   }
 
 private:
-  mat_vec_op_inputs params;
-  rmm::device_uvector<T> out, in, vec;
+  mat_vec_op_inputs<IdxT> params;
+  rmm::device_uvector<T> out, in, vec1, vec2;
 };  // struct MatVecOp
 
-const std::vector<mat_vec_op_inputs> mat_vec_op_input_vecs{
-  {1024, 128, true, true},     {1024 * 1024, 128, true, true},
-  {1024, 128 + 2, true, true}, {1024 * 1024, 128 + 2, true, true},
-  {1024, 128 + 1, true, true}, {1024 * 1024, 128 + 1, true, true},
+template <typename IdxT>
+std::vector<mat_vec_op_inputs<IdxT>> get_mv_inputs()
+{
+  std::vector<mat_vec_op_inputs<IdxT>> out;
 
-  {1024, 128, true, false},     {1024 * 1024, 128, true, false},
-  {1024, 128 + 2, true, false}, {1024 * 1024, 128 + 2, true, false},
-  {1024, 128 + 1, true, false}, {1024 * 1024, 128 + 1, true, false},
+  // Scalability benchmark with round dimensions
+  std::vector<IdxT> rows = {1000, 100000, 1000000};
+  std::vector<IdxT> cols = {8, 64, 256, 1024};
+  for (bool rowMajor : {true, false}) {
+    for (bool alongRows : {true, false}) {
+      for (IdxT rows_ : rows) {
+        for (IdxT cols_ : cols) {
+          out.push_back({rows_, cols_, rowMajor, alongRows, 0, 0});
+        }
+      }
+    }
+  }
 
-  {1024, 128, false, false},     {1024 * 1024, 128, false, false},
-  {1024, 128 + 2, false, false}, {1024 * 1024, 128 + 2, false, false},
-  {1024, 128 + 1, false, false}, {1024 * 1024, 128 + 1, false, false},
+  // Odd dimensions, misalignment
+  std::vector<std::tuple<IdxT, IdxT>> rowcols = {
+    {44739207, 7},
+    {44739207, 15},
+    {44739207, 16},
+    {44739207, 17},
+    {2611236, 256},
+    {2611236, 257},
+    {2611236, 263},
+  };
+  for (bool rowMajor : {true, false}) {
+    for (bool alongRows : {true, false}) {
+      for (auto rc : rowcols) {
+        for (IdxT inAlignOffset : {0, 1}) {
+          for (IdxT outAlignOffset : {0, 1}) {
+            out.push_back({std::get<0>(rc),
+                           std::get<1>(rc),
+                           rowMajor,
+                           alongRows,
+                           inAlignOffset,
+                           outAlignOffset});
+          }
+        }
+      }
+    }
+  }
+  return out;
+}
 
-  {1024, 128, false, true},     {1024 * 1024, 128, false, true},
-  {1024, 128 + 2, false, true}, {1024 * 1024, 128 + 2, false, true},
-  {1024, 128 + 1, false, true}, {1024 * 1024, 128 + 1, false, true},
+const std::vector<mat_vec_op_inputs<int>> mv_input_i32     = get_mv_inputs<int>();
+const std::vector<mat_vec_op_inputs<int64_t>> mv_input_i64 = get_mv_inputs<int64_t>();
 
+template <typename T>
+struct Add1Vec {
+  static constexpr bool useTwoVectors = false;
+  HDI T operator()(T a, T b) const { return a + b; };
+};
+template <typename T>
+struct Add2Vec {
+  static constexpr bool useTwoVectors = true;
+  HDI T operator()(T a, T b, T c) const { return a + b + c; };
 };
 
-RAFT_BENCH_REGISTER(mat_vec_op<float>, "", mat_vec_op_input_vecs);
-RAFT_BENCH_REGISTER(mat_vec_op<double>, "", mat_vec_op_input_vecs);
+RAFT_BENCH_REGISTER((mat_vec_op<Add1Vec<float>, float, int>), "", mv_input_i32);
+RAFT_BENCH_REGISTER((mat_vec_op<Add1Vec<double>, double, int>), "", mv_input_i32);
+RAFT_BENCH_REGISTER((mat_vec_op<Add2Vec<float>, float, int>), "", mv_input_i32);
+RAFT_BENCH_REGISTER((mat_vec_op<Add2Vec<double>, double, int>), "", mv_input_i32);
+RAFT_BENCH_REGISTER((mat_vec_op<Add1Vec<float>, float, int64_t>), "", mv_input_i64);
+RAFT_BENCH_REGISTER((mat_vec_op<Add1Vec<double>, double, int64_t>), "", mv_input_i64);
+RAFT_BENCH_REGISTER((mat_vec_op<Add2Vec<float>, float, int64_t>), "", mv_input_i64);
+RAFT_BENCH_REGISTER((mat_vec_op<Add2Vec<double>, double, int64_t>), "", mv_input_i64);
 
 }  // namespace raft::bench::linalg
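
The benchmark above exercises both raw-pointer overloads through a single functor type. For orientation, a minimal usage sketch of the two-vector overload follows; the function, buffer, and parameter names are illustrative (not from the PR), and it assumes valid device allocations, a live stream, and nvcc with extended lambdas enabled:

```cuda
#include <raft/linalg/matrix_vector_op.cuh>

// out[i][j] = (X[i][j] - mean[j]) / stddev[j] for a row-major n_rows x n_cols matrix.
void center_and_scale(float* out, const float* X, const float* mean, const float* stddev,
                      int n_rows, int n_cols, cudaStream_t stream)
{
  raft::linalg::matrixVectorOp(
    out, X, mean, stddev,
    n_cols,  // D: length of a row
    n_rows,  // N: number of rows
    true,    // rowMajor
    true,    // bcastAlongRows: one vector element per column
    [] __device__(float x, float mu, float sigma) { return (x - mu) / sigma; },
    stream);
}
```
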
diff --git a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh
index 4cfccdcaa3..62ec9bb7a4 100644
--- a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh
+++ b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh
@@ -22,80 +22,10 @@ namespace raft {
 namespace linalg {
 namespace detail {
 
-namespace {
-template <size_t VecBytes>
-struct AlignedAccess {
-  template <typename T>
-  static inline bool test(const T* matrix, size_t strideBytes)
-  {
-    return Pow2<VecBytes>::isAligned(matrix) && Pow2<VecBytes>::isAligned(strideBytes) &&
-           Pow2<sizeof(T)>::isAligned(VecBytes);
-  }
-};
-};  // namespace
-
-template <typename Type, int veclen_, typename Lambda, typename IdxType>
-__global__ void matrixVectorOpKernel(Type* out,
-                                     const Type* matrix,
-                                     const Type* vector,
-                                     IdxType D,
-                                     IdxType N,
-                                     bool rowMajor,
-                                     bool bcastAlongRows,
-                                     Lambda op)
-{
-  typedef TxN_t<Type, veclen_> VecType;
-  IdxType len = N * D;
-  IdxType idx = threadIdx.x;
-  idx += (IdxType)blockIdx.x * (IdxType)blockDim.x;
-  idx *= VecType::Ratio;
-  if (idx >= len) return;
-  IdxType vIdx;
-  VecType mat, vec;
-  ///@todo: yikes! use fast-int-div here.
-  ///@todo: shared mem for vector could help with perf
-  if (rowMajor && bcastAlongRows) {
-    vIdx = idx % D;
-    vec.load(vector, vIdx);
-  } else if (!rowMajor && !bcastAlongRows) {
-    vIdx = idx % N;
-    vec.load(vector, vIdx);
-  } else if (rowMajor && !bcastAlongRows) {
-    vIdx = idx / D;
-    vec.fill(vector[vIdx]);
-  } else {
-    vIdx = idx / N;
-    vec.fill(vector[vIdx]);
-  }
-  mat.load(matrix, idx);
-#pragma unroll
-  for (int i = 0; i < VecType::Ratio; ++i)
-    mat.val.data[i] = op(mat.val.data[i], vec.val.data[i]);
-  mat.store(out, idx);
-}
-
-template <typename Type, int veclen_, typename Lambda, typename IdxType, int TPB>
-void matrixVectorOpImpl(Type* out,
-                        const Type* matrix,
-                        const Type* vec,
-                        IdxType D,
-                        IdxType N,
-                        bool rowMajor,
-                        bool bcastAlongRows,
-                        Lambda op,
-                        cudaStream_t stream)
-{
-  IdxType len   = N * D;
-  IdxType nblks = raft::ceildiv(veclen_ ? len / veclen_ : veclen_, (IdxType)TPB);
-  matrixVectorOpKernel<Type, veclen_, Lambda, IdxType>
-    <<<nblks, TPB, 0, stream>>>(out, matrix, vec, D, N, rowMajor, bcastAlongRows, op);
-  RAFT_CUDA_TRY(cudaPeekAtLastError());
-}
-
-template <typename Type, typename Lambda, typename IdxType = int, int TPB = 256>
-void matrixVectorOp(Type* out,
-                    const Type* matrix,
-                    const Type* vec,
+template <typename MatT, typename Lambda, typename VecT, typename IdxType = int>
+void matrixVectorOp(MatT* out,
+                    const MatT* matrix,
+                    const VecT* vec,
                     IdxType D,
                     IdxType N,
                     bool rowMajor,
@@ -109,11 +39,16 @@ void matrixVectorOp(Type* out,
     out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec);
 }
 
-template <typename Type, typename Lambda, typename IdxType = int, int TPB = 256>
-void matrixVectorOp(Type* out,
-                    const Type* matrix,
-                    const Type* vec1,
-                    const Type* vec2,
+template <typename MatT,
+          typename Lambda,
+          typename Vec1T,
+          typename Vec2T,
+          typename IdxType = int>
+void matrixVectorOp(MatT* out,
+                    const MatT* matrix,
+                    const Vec1T* vec1,
+                    const Vec2T* vec2,
                     IdxType D,
                     IdxType N,
                     bool rowMajor,
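
The deleted kernel's four broadcast cases collapse into the single `matrix::linewiseOp` call whose context line survives above. The key observation is that a "line" is the contiguous memory dimension, so the vector runs along a line exactly when `rowMajor == bcastAlongRows`. The enumeration below is an annotation of that condition, not code from the PR:

```cuda
// alongLines = (rowMajor == bcastAlongRows)
//   rowMajor=true,  bcastAlongRows=true  -> lines are rows;    vec element varies within a line -> true
//   rowMajor=false, bcastAlongRows=false -> lines are columns; vec element varies within a line -> true
//   rowMajor=true,  bcastAlongRows=false -> one vec element broadcast over each whole line      -> false
//   rowMajor=false, bcastAlongRows=true  -> one vec element broadcast over each whole line      -> false
```
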
diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh
index 1438a09bd3..d68be838b0 100644
--- a/cpp/include/raft/linalg/matrix_vector_op.cuh
+++ b/cpp/include/raft/linalg/matrix_vector_op.cuh
@@ -35,10 +35,10 @@ namespace linalg {
  * Note : the function will also check that the size of the window of accesses
  * is a multiple of the number of elements processed by a thread in order to
  * enable faster processing
- * @tparam Type the matrix/vector type
+ * @tparam MatT the matrix type
  * @tparam Lambda a device function which represents a binary operator
+ * @tparam VecT the input vector type
  * @tparam IdxType Integer type used to for addressing
- * @tparam TPB threads per block of the cuda kernel launched
  * @param out the output matrix (passing out = matrix makes it in-place)
  * @param matrix the input matrix
  * @param vec the vector
@@ -50,10 +50,10 @@
  * @param op the mathematical operation
  * @param stream cuda stream where to launch work
  */
-template <typename Type, typename Lambda, typename IdxType = int, int TPB = 256>
-void matrixVectorOp(Type* out,
-                    const Type* matrix,
-                    const Type* vec,
+template <typename MatT, typename Lambda, typename VecT, typename IdxType = int>
+void matrixVectorOp(MatT* out,
+                    const MatT* matrix,
+                    const VecT* vec,
                     IdxType D,
                     IdxType N,
                     bool rowMajor,
@@ -72,10 +72,11 @@ void matrixVectorOp(Type* out,
  * Note : the function will also check that the size of the window of accesses
  * is a multiple of the number of elements processed by a thread in order to
  * enable faster processing
- * @tparam Type the matrix/vector type
+ * @tparam MatT the matrix type
  * @tparam Lambda a device function which represents a binary operator
+ * @tparam Vec1T the first input vector type
+ * @tparam Vec2T the second input vector type
  * @tparam IdxType Integer type used to for addressing
- * @tparam TPB threads per block of the cuda kernel launched
  * @param out the output matrix (passing out = matrix makes it in-place)
  * @param matrix the input matrix
  * @param vec1 the first vector
@@ -88,11 +89,11 @@ void matrixVectorOp(Type* out,
  * @param op the mathematical operation
  * @param stream cuda stream where to launch work
  */
-template <typename Type, typename Lambda, typename IdxType = int, int TPB = 256>
-void matrixVectorOp(Type* out,
-                    const Type* matrix,
-                    const Type* vec1,
-                    const Type* vec2,
+template <typename MatT, typename Lambda, typename Vec1T, typename Vec2T, typename IdxType = int>
+void matrixVectorOp(MatT* out,
+                    const MatT* matrix,
+                    const Vec1T* vec1,
+                    const Vec2T* vec2,
                     IdxType D,
                     IdxType N,
                     bool rowMajor,
@@ -116,12 +117,11 @@ void matrixVectorOp(Type* out,
  * Note : the function will also check that the size of the window of accesses
  * is a multiple of the number of elements processed by a thread in order to
  * enable faster processing
- * @tparam InValueType the data-type of the input matrices and vectors
+ * @tparam MatValueType the data-type of the input matrix
+ * @tparam VecValueType the data-type of the input vector
  * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major)
  * @tparam Lambda a device function which represents a binary operator
- * @tparam OutElementType the data-type of the output raft::matrix_view
  * @tparam IndexType Integer used for addressing
- * @tparam TPB threads per block of the cuda kernel launched
  * @param[in] handle raft::handle_t
  * @param[in] matrix input raft::matrix_view
  * @param[in] vec vector raft::vector_view
@@ -130,16 +130,15 @@ void matrixVectorOp(Type* out,
  *            the rows of the matrix or columns using enum class raft::linalg::Apply
  * @param[in] op the mathematical operation
  */
-template <typename InValueType,
+template <typename MatValueType,
+          typename VecValueType,
           typename LayoutPolicy,
           typename Lambda,
-          typename OutElementType,
-          typename IndexType,
-          int TPB>
+          typename IndexType>
 void matrix_vector_op(const raft::handle_t& handle,
-                      raft::device_matrix_view<const InValueType, IndexType, LayoutPolicy> matrix,
-                      raft::device_vector_view<const InValueType, IndexType> vec,
-                      raft::device_matrix_view<OutElementType, IndexType, LayoutPolicy> out,
+                      raft::device_matrix_view<const MatValueType, IndexType, LayoutPolicy> matrix,
+                      raft::device_vector_view<const VecValueType, IndexType> vec,
+                      raft::device_matrix_view<MatValueType, IndexType, LayoutPolicy> out,
                       Apply apply,
                       Lambda op)
 {
@@ -177,12 +176,12 @@ void matrix_vector_op(const raft::handle_t& handle,
  * Note : the function will also check that the size of the window of accesses
  * is a multiple of the number of elements processed by a thread in order to
  * enable faster processing
- * @tparam InValueType the data-type of the input matrices and vectors
+ * @tparam MatValueType the data-type of the input and output matrices
+ * @tparam Vec1ValueType the data-type of the first input vector
+ * @tparam Vec2ValueType the data-type of the second input vector
  * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major)
  * @tparam Lambda a device function which represents a binary operator
- * @tparam OutElementType the data-type of the output raft::matrix_view
  * @tparam IndexType Integer used for addressing
- * @tparam TPB threads per block of the cuda kernel launched
  * @param handle raft::handle_t
  * @param matrix input raft::matrix_view
  * @param vec1 the first vector raft::vector_view
@@ -192,17 +191,17 @@ void matrix_vector_op(const raft::handle_t& handle,
  *            the rows of the matrix or columns using enum class raft::linalg::Apply
  * @param op the mathematical operation
  */
-template <typename InValueType,
+template <typename MatValueType,
+          typename Vec1ValueType,
+          typename Vec2ValueType,
           typename LayoutPolicy,
           typename Lambda,
-          typename OutElementType,
-          typename IndexType,
-          int TPB>
+          typename IndexType>
 void matrix_vector_op(const raft::handle_t& handle,
-                      raft::device_matrix_view<const InValueType, IndexType, LayoutPolicy> matrix,
-                      raft::device_vector_view<const InValueType, IndexType> vec1,
-                      raft::device_vector_view<const InValueType, IndexType> vec2,
-                      raft::device_matrix_view<OutElementType, IndexType, LayoutPolicy> out,
+                      raft::device_matrix_view<const MatValueType, IndexType, LayoutPolicy> matrix,
+                      raft::device_vector_view<const Vec1ValueType, IndexType> vec1,
+                      raft::device_vector_view<const Vec2ValueType, IndexType> vec2,
+                      raft::device_matrix_view<MatValueType, IndexType, LayoutPolicy> out,
                       Apply apply,
                       Lambda op)
 {
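
A short sketch of the mdspan-based overload with mixed value types, mirroring how the reworked tests further down call it. The handle, device pointers (`d_in`, `d_out`, `d_bias`), and sizes are assumed to exist; the lambda is illustrative:

```cuda
auto in   = raft::make_device_matrix_view<const float, int, raft::row_major>(d_in, n_rows, n_cols);
auto out  = raft::make_device_matrix_view<float, int, raft::row_major>(d_out, n_rows, n_cols);
auto bias = raft::make_device_vector_view<const double, int>(d_bias, n_cols);

raft::linalg::matrix_vector_op(
  handle, in, bias, out, raft::linalg::Apply::ALONG_ROWS,
  [] __device__(float m, double b) { return m + static_cast<float>(b); });
```
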
diff --git a/cpp/include/raft/matrix/detail/linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh
index 8180b88c8a..37198684ee 100644
--- a/cpp/include/raft/matrix/detail/linewise_op.cuh
+++ b/cpp/include/raft/matrix/detail/linewise_op.cuh
@@ -23,11 +23,28 @@
 #include <raft/pow2_utils.cuh>
 #include <raft/vectorized.cuh>
+#include <thrust/tuple.h>
 
 namespace raft {
 namespace matrix {
 namespace detail {
 
+/** This type simplifies returning arrays and passing them as arguments */
+template <typename Type, std::size_t VecElems>
+struct VecArg {
+  Type val[VecElems];
+};
+
+/** Executes the operation with the given matrix element and an arbitrary number of vector elements
+ * contained in the given tuple. The index_sequence is used here for compile-time indexing of the
+ * tuple in the fold expression. */
+template <typename Lambda, typename MatT, typename Tuple, std::size_t... Is>
+__device__ __forceinline__ MatT
+RunMatVecOp(Lambda op, MatT mat, Tuple&& args, std::index_sequence<Is...>)
+{
+  return op(mat, (thrust::get<Is>(args))...);
+}
+
 template <typename Type, typename IdxType, std::size_t VecBytes, int BlockSize>
 struct Linewise {
   static constexpr IdxType VecElems = VecBytes / sizeof(Type);
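
`RunMatVecOp` is the standard `index_sequence` trick for exploding a tuple into an argument list. A host-side analogue using `std::tuple` (purely illustrative, not part of the PR):

```cuda
#include <cstdio>
#include <tuple>
#include <utility>

template <typename Op, typename Tuple, std::size_t... Is>
float apply_with_mat(Op op, float mat, Tuple&& args, std::index_sequence<Is...>)
{
  return op(mat, std::get<Is>(args)...);  // expands to op(mat, get<0>(args), get<1>(args), ...)
}

int main()
{
  auto op   = [](float m, float b1, double b2) { return m * b1 + static_cast<float>(b2); };
  auto args = std::make_tuple(2.0f, 0.5);
  std::printf("%f\n", apply_with_mat(op, 3.0f, args, std::index_sequence_for<float, double>{}));  // 6.5
}
```
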
@@ -78,10 +95,13 @@ struct Linewise {
                                                    IdxType rowDiv,
                                                    IdxType rowMod,
                                                    Lambda op,
-                                                   Vecs... vecs) noexcept
+                                                   const Vecs*... vecs) noexcept
   {
     constexpr IdxType warpPad = (AlignWarp::Value - 1) * VecElems;
-    Type args[sizeof...(Vecs)];
+    constexpr auto index      = std::index_sequence_for<Vecs...>();
+    // todo(lsugy): switch to cuda::std::tuple from libcudacxx if we add it as a required
+    // dependency. Note that thrust::tuple is limited to 10 elements.
+    thrust::tuple<Vecs...> args;
     Vec v, w;
     bool update = true;
     for (; in < in_end; in += AlignWarp::Value, out += AlignWarp::Value, rowMod += warpPad) {
@@ -92,8 +112,7 @@ struct Linewise {
         update = true;
       }
       if (update) {
-        int l = 0;
-        ((args[l] = vecs[rowDiv], l++), ...);
+        args   = thrust::make_tuple((vecs[rowDiv])...);
         update = false;
       }
 #pragma unroll VecElems
@@ -101,11 +120,9 @@ struct Linewise {
         if (rowMod == rowLen) {
           rowMod = 0;
           rowDiv++;
-          int l = 0;
-          ((args[l] = vecs[rowDiv], l++), ...);
+          args = thrust::make_tuple((vecs[rowDiv])...);
         }
-        int l = 0;
-        w.val.data[k] = op(v.val.data[k], (std::ignore = vecs, args[l++])...);
+        w.val.data[k] = RunMatVecOp(op, v.val.data[k], args, index);
       }
       *out = *w.vectorized_data();
     }
@@ -143,7 +160,7 @@ struct Linewise {
       *v.vectorized_data() = __ldcv(in + i);
 #pragma unroll VecElems
       for (int k = 0; k < VecElems; k++)
-        v.val.data[k] = op(v.val.data[k], args.val.data[k]...);
+        v.val.data[k] = op(v.val.data[k], args.val[k]...);
       __stwt(out + i, *v.vectorized_data());
     }
   }
@@ -153,16 +170,18 @@
    * of a vector. Most of the time this is not aligned, so we load it thread-striped
    * within a block and then use the shared memory to get a contiguous chunk.
    *
+   * @tparam VecT Type of the vector to load
    * @param [in] shm a shared memory region for rearranging the data among threads
    * @param [in] p pointer to a vector
    * @param [in] blockOffset the offset of the current block into a vector.
    * @param [in] rowLen the length of a vector.
    * @return a contiguous chunk of a vector, suitable for `vectorRows`.
    */
-  static __device__ __forceinline__ Vec loadVec(Type* shm,
-                                                const Type* p,
-                                                const IdxType blockOffset,
-                                                const IdxType rowLen) noexcept
+  template <typename VecT>
+  static __device__ __forceinline__ VecArg<VecT, VecElems> loadVec(VecT* shm,
+                                                                   const VecT* p,
+                                                                   const IdxType blockOffset,
+                                                                   const IdxType rowLen) noexcept
   {
     IdxType j = blockOffset + threadIdx.x;
 #pragma unroll VecElems
@@ -173,8 +192,10 @@
     }
     __syncthreads();
     {
-      Vec out;
-      *out.vectorized_data() = reinterpret_cast<typename Vec::io_t*>(shm)[threadIdx.x];
+      VecArg<VecT, VecElems> out;
+#pragma unroll VecElems
+      for (int i = 0; i < VecElems; i++)
+        out.val[i] = shm[threadIdx.x * VecElems + i];
       return out;
     }
   }
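
`loadVec`'s access pattern is worth seeing in isolation: threads cooperatively copy a block-striped window of the vector into shared memory, and after the barrier each thread reads back its own contiguous `VecElems`-wide chunk. A reduced sketch (the free-standing function and its name are illustrative; the PR keeps this logic inside `Linewise`):

```cuda
template <typename T, int VecElems, int BlockSize>
__device__ void striped_to_contiguous(T* shm, const T* vec, int blockOffset, int rowLen,
                                      T out[VecElems])
{
  // Cooperative fill: thread t writes shm[t], shm[t + BlockSize], ...
  int j = blockOffset + threadIdx.x;
  for (int k = threadIdx.x; k < VecElems * BlockSize; k += BlockSize, j += BlockSize) {
    while (j >= rowLen) j -= rowLen;  // wrap around the vector
    shm[k] = vec[j];
  }
  __syncthreads();
  // Contiguous read-back: thread t owns shm[t * VecElems .. t * VecElems + VecElems - 1]
  for (int i = 0; i < VecElems; i++)
    out[i] = shm[threadIdx.x * VecElems + i];
}
```
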
@@ -182,6 +203,7 @@ struct Linewise {
   /**
    * @brief Same as loadVec, but padds data with Ones
    *
+   * @tparam VecT Type of the vector to load
    * @param shm
    * @param p
    * @param blockOffset
@@ -189,23 +211,27 @@ struct Linewise {
    * @param rowLenPadded
    * @return a contiguous chunk of a vector, suitable for `vectorRows`.
    */
-  static __device__ __forceinline__ Vec loadVecPadded(Type* shm,
-                                                      const Type* p,
-                                                      const IdxType blockOffset,
-                                                      const IdxType rowLen,
-                                                      const IdxType rowLenPadded) noexcept
+  template <typename VecT>
+  static __device__ __forceinline__ VecArg<VecT, VecElems> loadVecPadded(
+    VecT* shm,
+    const VecT* p,
+    const IdxType blockOffset,
+    const IdxType rowLen,
+    const IdxType rowLenPadded) noexcept
   {
     IdxType j = blockOffset + threadIdx.x;
 #pragma unroll VecElems
     for (int k = threadIdx.x; k < VecElems * BlockSize; k += BlockSize, j += BlockSize) {
       while (j >= rowLenPadded) j -= rowLenPadded;
-      shm[k] = j < rowLen ? p[j] : Type(1);
+      shm[k] = j < rowLen ? p[j] : VecT(1);
     }
     __syncthreads();
     {
-      Vec out;
-      *out.vectorized_data() = reinterpret_cast<typename Vec::io_t*>(shm)[threadIdx.x];
+      VecArg<VecT, VecElems> out;
+#pragma unroll VecElems
+      for (int i = 0; i < VecElems; i++)
+        out.val[i] = shm[threadIdx.x * VecElems + i];
       return out;
     }
   }
@@ -242,7 +268,7 @@ __global__ void __launch_bounds__(BlockSize)
     const IdxType len,
     const IdxType elemsPerThread,
     Lambda op,
-    Vecs... vecs)
+    const Vecs*... vecs)
 {
   typedef Linewise<Type, IdxType, VecBytes, BlockSize> L;
 
@@ -286,7 +312,7 @@ __global__ void __launch_bounds__(MaxOffset, 2)
     const IdxType rowLen,
    const IdxType len,
     Lambda op,
-    Vecs... vecs)
+    const Vecs*... vecs)
 {
   // Note, L::VecElems == 1
   typedef Linewise<Type, IdxType, sizeof(Type), MaxOffset> L;
@@ -313,6 +339,15 @@ __global__ void __launch_bounds__(MaxOffset, 2)
     vecs...);
 }
 
+/** Helper function to get the largest type from a variadic list of types */
+template <typename... Types>
+constexpr size_t maxSizeOf()
+{
+  size_t maxSize = 0;
+  ((maxSize = std::max(maxSize, sizeof(Types))), ...);
+  return maxSize;
+}
+
 /**
  * This kernel prepares the inputs for the `vectorRows` function where the most of the
  * work happens; see `vectorRows` for details.
@@ -342,20 +377,22 @@ __global__ void __launch_bounds__(BlockSize)
     const IdxType rowLen,
     const IdxType len,
     Lambda op,
-    Vecs... vecs)
+    const Vecs*... vecs)
 {
   typedef Linewise<Type, IdxType, VecBytes, BlockSize> L;
-  constexpr uint workSize = L::VecElems * BlockSize;
-  uint workOffset         = workSize;
-  __shared__ __align__(sizeof(Type) * L::VecElems)
-    Type shm[workSize * ((sizeof...(Vecs)) > 1 ? 2 : 1)];
+  constexpr uint workSize         = L::VecElems * BlockSize;
+  constexpr size_t maxVecItemSize = maxSizeOf<Vecs...>();
+  uint workOffset                 = workSize * maxVecItemSize;
+  __shared__ __align__(
+    maxVecItemSize *
+    L::VecElems) char shm[workSize * maxVecItemSize * ((sizeof...(Vecs)) > 1 ? 2 : 1)];
   const IdxType blockOffset = (arrOffset + BlockSize * L::VecElems * blockIdx.x) % rowLen;
-  return L::vectorRows(
-    reinterpret_cast<typename L::Vec::io_t*>(out),
-    reinterpret_cast<const typename L::Vec::io_t*>(in),
-    L::AlignElems::div(len),
-    op,
-    (workOffset ^= workSize, L::loadVec(shm + workOffset, vecs, blockOffset, rowLen))...);
+  return L::vectorRows(reinterpret_cast<typename L::Vec::io_t*>(out),
+                       reinterpret_cast<const typename L::Vec::io_t*>(in),
+                       L::AlignElems::div(len),
+                       op,
+                       (workOffset ^= workSize * maxVecItemSize,
+                        L::loadVec((Vecs*)(shm + workOffset), vecs, blockOffset, rowLen))...);
 }
 
@@ -383,21 +420,23 @@ __global__ void __launch_bounds__(BlockSize)
     const IdxType rowLenPadded,
     const IdxType lenPadded,
     Lambda op,
-    Vecs... vecs)
+    const Vecs*... vecs)
 {
   typedef Linewise<Type, IdxType, VecBytes, BlockSize> L;
-  constexpr uint workSize = L::VecElems * BlockSize;
-  uint workOffset         = workSize;
-  __shared__ __align__(sizeof(Type) * L::VecElems)
-    Type shm[workSize * ((sizeof...(Vecs)) > 1 ? 2 : 1)];
+  constexpr uint workSize         = L::VecElems * BlockSize;
+  constexpr size_t maxVecItemSize = maxSizeOf<Vecs...>();
+  uint workOffset                 = workSize * maxVecItemSize;
+  __shared__ __align__(
+    maxVecItemSize *
+    L::VecElems) char shm[workSize * maxVecItemSize * ((sizeof...(Vecs)) > 1 ? 2 : 1)];
   const IdxType blockOffset = (BlockSize * L::VecElems * blockIdx.x) % rowLenPadded;
   return L::vectorRows(
     reinterpret_cast<typename L::Vec::io_t*>(out),
     reinterpret_cast<const typename L::Vec::io_t*>(in),
     L::AlignElems::div(lenPadded),
     op,
-    (workOffset ^= workSize,
-     L::loadVecPadded(shm + workOffset, vecs, blockOffset, rowLen, rowLenPadded))...);
+    (workOffset ^= workSize * maxVecItemSize,
+     L::loadVecPadded((Vecs*)(shm + workOffset), vecs, blockOffset, rowLen, rowLenPadded))...);
 }
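
With heterogeneous vector types, the shared scratch above is declared as raw `char` storage, sized and aligned by the widest element type via `maxSizeOf`. What the fold expression evaluates to, for example (the `static_assert` is illustrative):

```cuda
// maxSizeOf<int8_t, float, double>() unrolls to:
//   maxSize = std::max(maxSize, sizeof(int8_t));  // 1
//   maxSize = std::max(maxSize, sizeof(float));   // 4
//   maxSize = std::max(maxSize, sizeof(double));  // 8
static_assert(maxSizeOf<int8_t, float, double>() == sizeof(double), "widest type wins");
```
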
@@ -426,12 +465,13 @@ __global__ void __launch_bounds__(MaxOffset, 2)
     const IdxType rowLen,
     const IdxType len,
     Lambda op,
-    Vecs... vecs)
+    const Vecs*... vecs)
 {
   // Note, L::VecElems == 1
-  constexpr uint workSize = MaxOffset;
-  uint workOffset         = workSize;
-  __shared__ Type shm[workSize * ((sizeof...(Vecs)) > 1 ? 2 : 1)];
+  constexpr uint workSize         = MaxOffset;
+  constexpr size_t maxVecItemSize = maxSizeOf<Vecs...>();
+  uint workOffset                 = workSize * maxVecItemSize;
+  __shared__ char shm[workSize * maxVecItemSize * ((sizeof...(Vecs)) > 1 ? 2 : 1)];
   typedef Linewise<Type, IdxType, sizeof(Type), MaxOffset> L;
   if (blockIdx.x == 0) {
     // first block: offset = 0, length = arrOffset
@@ -439,16 +479,17 @@ __global__ void __launch_bounds__(MaxOffset, 2)
                   reinterpret_cast<const typename L::Vec::io_t*>(in),
                   arrOffset,
                   op,
-                  (workOffset ^= workSize, L::loadVec(shm + workOffset, vecs, 0, rowLen))...);
+                  (workOffset ^= workSize * maxVecItemSize,
+                   L::loadVec((Vecs*)(shm + workOffset), vecs, 0, rowLen))...);
   } else {
     // second block: offset = arrTail, length = len - arrTail
     // NB: I substract MaxOffset (= blockDim.x) to get the correct indexing for block 1
-    L::vectorRows(
-      reinterpret_cast<typename L::Vec::io_t*>(out + arrTail - MaxOffset),
-      reinterpret_cast<const typename L::Vec::io_t*>(in + arrTail - MaxOffset),
-      len - arrTail + MaxOffset,
-      op,
-      (workOffset ^= workSize, L::loadVec(shm + workOffset, vecs, arrTail % rowLen, rowLen))...);
+    L::vectorRows(reinterpret_cast<typename L::Vec::io_t*>(out + arrTail - MaxOffset),
+                  reinterpret_cast<const typename L::Vec::io_t*>(in + arrTail - MaxOffset),
+                  len - arrTail + MaxOffset,
+                  op,
+                  (workOffset ^= workSize * maxVecItemSize,
+                   L::loadVec((Vecs*)(shm + workOffset), vecs, arrTail % rowLen, rowLen))...);
   }
 }
 
@@ -484,7 +525,7 @@ void matrixLinewiseVecCols(Type* out,
                            const IdxType nRows,
                            Lambda op,
                            cudaStream_t stream,
-                           Vecs... vecs)
+                           const Vecs*... vecs)
 {
   typedef raft::Pow2<VecBytes> AlignBytes;
   constexpr std::size_t VecElems = VecBytes / sizeof(Type);
@@ -537,7 +578,7 @@ void matrixLinewiseVecColsSpan(
   const IdxType nRows,
   Lambda op,
   cudaStream_t stream,
-  Vecs... vecs)
+  const Vecs*... vecs)
 {
   typedef raft::Pow2<VecBytes> AlignBytes;
   constexpr std::size_t VecElems = VecBytes / sizeof(Type);
@@ -584,7 +625,7 @@ void matrixLinewiseVecRows(Type* out,
                            const IdxType nRows,
                            Lambda op,
                            cudaStream_t stream,
-                           Vecs... vecs)
+                           const Vecs*... vecs)
 {
   typedef raft::Pow2<VecBytes> AlignBytes;
   constexpr std::size_t VecElems = VecBytes / sizeof(Type);
@@ -614,7 +655,7 @@ void matrixLinewiseVecRows(Type* out,
     const uint occupy = getOptimalGridSize<BlockSize>();
     const dim3 gs(std::min(
       // does not make sense to have more blocks than this
-      raft::ceildiv(uint(totalLen), block_work_size),
+      raft::ceildiv(uint(alignedLen), block_work_size),
       // increase the grid size to be not less than `occupy` while
       // still being the multiple of `expected_grid_size`
       raft::ceildiv(occupy, expected_grid_size) * expected_grid_size),
@@ -655,7 +696,7 @@ void matrixLinewiseVecRowsSpan(
   const IdxType nRows,
   Lambda op,
   cudaStream_t stream,
-  Vecs... vecs)
+  const Vecs*... vecs)
 {
   constexpr std::size_t VecElems = VecBytes / sizeof(Type);
   typedef raft::Pow2<VecBytes> AlignBytes;
@@ -719,7 +760,7 @@ struct MatrixLinewiseOp {
                   const bool alongLines,
                   Lambda op,
                   cudaStream_t stream,
-                  Vecs... vecs)
+                  const Vecs*... vecs)
   {
     if constexpr (VecBytes > sizeof(Type)) {
      if (!raft::Pow2<VecBytes>::areSameAlignOffsets(in, out))
@@ -746,7 +787,7 @@ struct MatrixLinewiseOp {
                   const bool alongLines,
                   Lambda op,
                   cudaStream_t stream,
-                  Vecs... vecs)
+                  const Vecs*... vecs)
   {
     constexpr auto is_rowmajor = std::is_same_v<LayoutPolicy, layout_right_padded<Type>>;
     constexpr auto is_colmajor = std::is_same_v<LayoutPolicy, layout_left_padded<Type>>;
diff --git a/cpp/include/raft/matrix/matrix.cuh b/cpp/include/raft/matrix/matrix.cuh
index 3a7e0dad47..cd6c4fa219 100644
--- a/cpp/include/raft/matrix/matrix.cuh
+++ b/cpp/include/raft/matrix/matrix.cuh
@@ -289,7 +289,7 @@ void linewiseOp(m_t* out,
                 const bool alongLines,
                 Lambda op,
                 cudaStream_t stream,
-                Vecs... vecs)
+                const Vecs*... vecs)
 {
   common::nvtx::range<common::nvtx::domain::raft> fun_scope("linewiseOp-%c-%zu (%zu, %zu)",
                                                             alongLines ? 'l' : 'x',
diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
index ff4708bb7b..fd009b30af 100644
--- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
@@ -245,18 +245,19 @@ void calc_centers_and_sizes(const handle_t& handle,
   if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }
 
   if (!reset_counters) {
-    utils::map_along_rows(
-      n_clusters,
-      dim,
+    raft::linalg::matrixVectorOp(
+      centers,
       centers,
       cluster_sizes,
-      [] __device__(float c, uint32_t s) -> float { return c * s; },
+      (int64_t)dim,
+      (int64_t)n_clusters,
+      true,
+      false,
+      [=] __device__(float c, uint32_t s) -> float { return c * s; },
       stream);
   }
 
   rmm::device_uvector<char> workspace(0, stream, mr);
-  rmm::device_uvector<float> cluster_sizes_f(n_clusters, stream, mr);
-  float* sizes_f = cluster_sizes_f.data();
 
   // If we reset the counters, we can compute directly the new sizes in cluster_sizes.
   // If we don't reset, we compute in a temporary buffer and add in a separate step.
@@ -291,28 +292,21 @@ void calc_centers_and_sizes(const handle_t& handle,
                                       static_cast<uint32_t>(n_clusters),
                                       workspace);
 
-  // Add previous sizes if necessary and cast to float
-  auto counting = thrust::make_counting_iterator(0);
-  thrust::for_each(
-    handle.get_thrust_policy(), counting, counting + n_clusters, [=] __device__(int idx) {
-      uint32_t temp_size = temp_sizes[idx];
-      if (!reset_counters) {
-        temp_size += cluster_sizes[idx];
-        cluster_sizes[idx] = temp_size;
-      }
-      sizes_f[idx] = static_cast<float>(temp_size);
-    });
+  // Add previous sizes if necessary
+  if (!reset_counters) {
+    raft::linalg::add(cluster_sizes, cluster_sizes, temp_sizes, n_clusters, stream);
+  }
 
   raft::linalg::matrixVectorOp(
     centers,
     centers,
-    sizes_f,
+    cluster_sizes,
     static_cast<int64_t>(dim),
     static_cast<int64_t>(n_clusters),
     true,
     false,
-    [=] __device__(float mat, float vec) {
-      if (vec == 0.0f)
+    [=] __device__(float mat, uint32_t vec) {
+      if (vec == 0u)
         return 0.0f;
       else
         return mat / vec;
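
The ann_utils.cuh hunk below deletes `map_along_rows`, which the k-means changes above have just replaced. The two formulations compute the same row-wise rescale; restated with the arguments annotated (this mirrors the call in the diff, with comments added here for orientation):

```cuda
// centers[i][j] = centers[i][j] * cluster_sizes[i]  (one vector element per row)
raft::linalg::matrixVectorOp(
  centers, centers, cluster_sizes,
  (int64_t)dim,         // D: row length of the row-major centers matrix
  (int64_t)n_clusters,  // N: number of rows
  true,                 // row-major
  false,                // broadcast one vector element across each row
  [=] __device__(float c, uint32_t s) -> float { return c * s; },
  stream);
```
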
diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
index dbd509216b..7b26ccfb42 100644
--- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
@@ -302,45 +302,6 @@ inline void normalize_rows(IdxT n_rows, IdxT n_cols, float* a, rmm::cuda_stream_
   normalize_rows_kernel<<<blocks, threads, 0, stream>>>(n_rows, n_cols, a);
 }
 
-template <typename IdxT, typename Lambda>
-__global__ void map_along_rows_kernel(
-  IdxT n_rows, uint32_t n_cols, float* a, const uint32_t* d, Lambda map)
-{
-  IdxT gid = threadIdx.x + blockDim.x * static_cast<IdxT>(blockIdx.x);
-  IdxT i   = gid / n_cols;
-  if (i >= n_rows) return;
-  float& x = a[gid];
-  x        = map(x, d[i]);
-}
-
-/**
- * @brief Map a binary function over a matrix and a vector element-wise, broadcasting the vector
- * values along rows: `m[i, j] = op(m[i,j], v[i])`
- *
- * NB: device-only function
- *
- * @tparam IdxT index type
- * @tparam Lambda
- *
- * @param n_rows
- * @param n_cols
- * @param[inout] m device pointer to a row-major matrix [n_rows, n_cols]
- * @param[in] v device pointer to a vector [n_rows]
- * @param op the binary operation to apply on every element of matrix rows and of the vector
- */
-template <typename IdxT, typename Lambda>
-inline void map_along_rows(IdxT n_rows,
-                           uint32_t n_cols,
-                           float* m,
-                           const uint32_t* v,
-                           Lambda op,
-                           rmm::cuda_stream_view stream)
-{
-  dim3 threads(128, 1, 1);
-  dim3 blocks(ceildiv<IdxT>(n_rows * n_cols, threads.x), 1, 1);
-  map_along_rows_kernel<<<blocks, threads, 0, stream>>>(n_rows, n_cols, m, v, op);
-}
-
 template <typename T, typename IdxT>
 __global__ void outer_add_kernel(const T* a, IdxT len_a, const T* b, IdxT len_b, T* c)
 {
diff --git a/cpp/include/raft/util/detail/itertools.hpp b/cpp/include/raft/util/detail/itertools.hpp
new file mode 100644
index 0000000000..1908d90b95
--- /dev/null
+++ b/cpp/include/raft/util/detail/itertools.hpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <tuple>
+#include <vector>
+
+namespace raft::util::itertools::detail {
+
+template <class S, typename... Args, size_t... Is>
+inline std::vector<S> product(std::index_sequence<Is...> index, const std::vector<Args>&... vecs)
+{
+  size_t len = 1;
+  ((len *= vecs.size()), ...);
+  std::vector<S> out;
+  out.reserve(len);
+  for (size_t i = 0; i < len; i++) {
+    std::tuple<Args...> tup;
+    size_t mod = len, new_mod;
+    ((new_mod = mod / vecs.size(), std::get<Is>(tup) = vecs[(i % mod) / new_mod], mod = new_mod),
+     ...);
+    out.push_back({std::get<Is>(tup)...});
+  }
+  return out;
+}
+
+}  // namespace raft::util::itertools::detail
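
The fold in `detail::product` decodes the flat loop index `i` as a mixed-radix number, one digit per input list, with the last list varying fastest. A worked example of the same arithmetic:

```cuda
// product<P>({1000, 100000}, {8, 64, 256}) -> len = 2 * 3 = 6 combinations.
// For i = 4: mod starts at 6.
//   list 0 (size 2): new_mod = 6 / 2 = 3; picks vecs[(4 % 6) / 3] = vecs[1] -> 100000
//   list 1 (size 3): new_mod = 3 / 3 = 1; picks vecs[(4 % 3) / 1] = vecs[1] -> 64
// So out[4] == {100000, 64}; the full order is (1000,8), (1000,64), (1000,256),
// (100000,8), (100000,64), (100000,256).
```
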
diff --git a/cpp/include/raft/util/itertools.hpp b/cpp/include/raft/util/itertools.hpp
new file mode 100644
index 0000000000..493ac9befe
--- /dev/null
+++ b/cpp/include/raft/util/itertools.hpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <raft/util/detail/itertools.hpp>
+
+/**
+ * Helpers inspired by the Python itertools library
+ *
+ */
+
+namespace raft::util::itertools {
+
+/**
+ * @brief Cartesian product of the given initializer lists.
+ *
+ * This helper can be used to easily define input parameters in tests/benchmarks.
+ * Note that it's not optimized for use with large lists / many lists in performance-critical code!
+ *
+ * @tparam S Type of the output structures.
+ * @tparam Args Types of the elements of the initilizer lists, matching the types of the first
+ *              fields of the structure (if the structure has more fields, some might be initialized
+ *              with their default value).
+ * @param lists One or more initializer lists.
+ * @return std::vector<S> A vector of structures containing the cartesian product.
+ */
+template <class S, typename... Args>
+std::vector<S> product(std::initializer_list<Args>... lists)
+{
+  return detail::product<S>(std::index_sequence_for<Args...>(), (std::vector<Args>(lists))...);
+}
+
+}  // namespace raft::util::itertools
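
Usage is a one-liner; this is the same pattern the reworked tests below rely on (`MyParams` here is a hypothetical aggregate, not from the PR):

```cuda
struct MyParams {
  int rows, cols;
  bool rowMajor;
};

auto cases = raft::util::itertools::product<MyParams>({1024, 4096}, {32, 64}, {true, false});
// cases.size() == 8; e.g. cases[1] == MyParams{1024, 32, false}
```
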
diff --git a/cpp/test/linalg/matrix_vector.cu b/cpp/test/linalg/matrix_vector.cu
index 9062f3be4d..f103b5918b 100644
--- a/cpp/test/linalg/matrix_vector.cu
+++ b/cpp/test/linalg/matrix_vector.cu
@@ -130,16 +130,17 @@ void naive_matrix_vector_op_launch(const raft::handle_t& handle,
   };
 
   if (operation_type == 0) {
-    naiveMatVecOp(
-      in, vec1, D, N, row_major, bcast_along_rows, operation_bin_mult_skip_zero, stream);
+    naiveMatVec(
+      in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_mult_skip_zero, stream);
   } else if (operation_type == 1) {
-    naiveMatVecOp(in, vec1, D, N, row_major, bcast_along_rows, operation_div, stream);
+    naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, operation_div, stream);
   } else if (operation_type == 2) {
-    naiveMatVecOp(in, vec1, D, N, row_major, bcast_along_rows, operation_bin_div_skip_zero, stream);
+    naiveMatVec(
+      in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_div_skip_zero, stream);
   } else if (operation_type == 3) {
-    naiveMatVecOp(in, vec1, D, N, row_major, bcast_along_rows, operation_bin_add, stream);
+    naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_add, stream);
   } else if (operation_type == 4) {
-    naiveMatVecOp(in, vec1, D, N, row_major, bcast_along_rows, operation_bin_sub, stream);
+    naiveMatVec(in, in, vec1, D, N, row_major, bcast_along_rows, operation_bin_sub, stream);
   } else {
     THROW("Unknown operation type '%d'!", (int)operation_type);
   }
diff --git a/cpp/test/linalg/matrix_vector_op.cu b/cpp/test/linalg/matrix_vector_op.cu
index b5a3168a06..1c96c3fc74 100644
--- a/cpp/test/linalg/matrix_vector_op.cu
+++ b/cpp/test/linalg/matrix_vector_op.cu
@@ -20,140 +20,138 @@
 #include <gtest/gtest.h>
 #include <raft/linalg/matrix_vector_op.cuh>
 #include <raft/random/rng.cuh>
+#include <raft/util/cuda_utils.cuh>
+#include <raft/util/itertools.hpp>
 
 namespace raft {
 namespace linalg {
 
-template <typename T, typename IdxType = int>
+template <typename IdxType>
 struct MatVecOpInputs {
-  T tolerance;
   IdxType rows, cols;
-  bool rowMajor, bcastAlongRows, useTwoVectors;
+  bool rowMajor, bcastAlongRows;
+  IdxType inAlignOffset, outAlignOffset;
   unsigned long long int seed;
 };
 
-template <typename T, typename IdxType>
-::std::ostream& operator<<(::std::ostream& os, const MatVecOpInputs<T, IdxType>& dims)
+template <typename IdxType>
+::std::ostream& operator<<(::std::ostream& os, const MatVecOpInputs<IdxType>& dims)
 {
   return os;
 }
 
+template <typename T, typename LenT>
+inline void gen_uniform(const raft::handle_t& handle, raft::random::RngState& rng, T* ptr, LenT len)
+{
+  if constexpr (std::is_integral_v<T>) {
+    raft::random::uniformInt(handle, rng, ptr, len, (T)0, (T)100);
+  } else {
+    raft::random::uniform(handle, rng, ptr, len, (T)-10.0, (T)10.0);
+  }
+}
+
 // Or else, we get the following compilation error
 // for an extended __device__ lambda cannot have private or protected access
 // within its class
-template <typename T, typename IdxType>
+template <typename OpT, typename MatT, typename Vec1T, typename Vec2T, typename IdxType>
 void matrixVectorOpLaunch(const raft::handle_t& handle,
-                          T* out,
-                          const T* in,
-                          const T* vec1,
-                          const T* vec2,
+                          MatT* out,
+                          const MatT* in,
+                          const Vec1T* vec1,
+                          const Vec2T* vec2,
                           IdxType D,
                           IdxType N,
                           bool rowMajor,
-                          bool bcastAlongRows,
-                          bool useTwoVectors)
+                          bool bcastAlongRows)
 {
-  auto out_row_major = raft::make_device_matrix_view<T, IdxType, raft::row_major>(out, N, D);
-  auto in_row_major  = raft::make_device_matrix_view<const T, IdxType, raft::row_major>(in, N, D);
+  auto out_row_major = raft::make_device_matrix_view<MatT, IdxType, raft::row_major>(out, N, D);
+  auto in_row_major  = raft::make_device_matrix_view<const MatT, IdxType, raft::row_major>(in, N, D);
 
-  auto out_col_major = raft::make_device_matrix_view<T, IdxType, raft::col_major>(out, N, D);
-  auto in_col_major  = raft::make_device_matrix_view<const T, IdxType, raft::col_major>(in, N, D);
+  auto out_col_major = raft::make_device_matrix_view<MatT, IdxType, raft::col_major>(out, N, D);
+  auto in_col_major  = raft::make_device_matrix_view<const MatT, IdxType, raft::col_major>(in, N, D);
 
   auto apply     = bcastAlongRows ? Apply::ALONG_ROWS : Apply::ALONG_COLUMNS;
   auto len       = bcastAlongRows ? D : N;
-  auto vec1_view = raft::make_device_vector_view<const T, IdxType>(vec1, len);
-  auto vec2_view = raft::make_device_vector_view<const T, IdxType>(vec2, len);
+  auto vec1_view = raft::make_device_vector_view<const Vec1T, IdxType>(vec1, len);
 
-  if (useTwoVectors) {
+  if constexpr (OpT::useTwoVectors) {
+    auto vec2_view = raft::make_device_vector_view<const Vec2T, IdxType>(vec2, len);
    if (rowMajor) {
-      matrix_vector_op(handle,
-                       in_row_major,
-                       vec1_view,
-                       vec2_view,
-                       out_row_major,
-                       apply,
-                       [] __device__(T a, T b, T c) { return a + b + c; });
+      matrix_vector_op(handle, in_row_major, vec1_view, vec2_view, out_row_major, apply, OpT{});
     } else {
-      matrix_vector_op(handle,
-                       in_col_major,
-                       vec1_view,
-                       vec2_view,
-                       out_col_major,
-
-                       apply,
-                       [] __device__(T a, T b, T c) { return a + b + c; });
+      matrix_vector_op(handle, in_col_major, vec1_view, vec2_view, out_col_major, apply, OpT{});
     }
   } else {
     if (rowMajor) {
-      matrix_vector_op(
-        handle, in_row_major, vec1_view, out_row_major, apply, [] __device__(T a, T b) {
-          return a + b;
-        });
+      matrix_vector_op(handle, in_row_major, vec1_view, out_row_major, apply, OpT{});
     } else {
-      matrix_vector_op(
-        handle, in_col_major, vec1_view, out_col_major, apply, [] __device__(T a, T b) {
-          return a + b;
-        });
+      matrix_vector_op(handle, in_col_major, vec1_view, out_col_major, apply, OpT{});
     }
   }
 }
 
-template <typename T, typename IdxType = int>
-class MatVecOpTest : public ::testing::TestWithParam<MatVecOpInputs<T, IdxType>> {
+template <typename OpT,
+          typename MatT,
+          typename IdxType,
+          typename Vec1T = MatT,
+          typename Vec2T = Vec1T>
+class MatVecOpTest : public ::testing::TestWithParam<MatVecOpInputs<IdxType>> {
 public:
   MatVecOpTest()
-    : params(::testing::TestWithParam<MatVecOpInputs<T, IdxType>>::GetParam()),
-      stream(handle.get_stream()),
-      in(params.rows * params.cols, stream),
-      out_ref(params.rows * params.cols, stream),
-      out(params.rows * params.cols, stream),
-      vec1(params.bcastAlongRows ? params.cols : params.rows, stream),
-      vec2(params.bcastAlongRows ? params.cols : params.rows, stream)
+    : stream(handle.get_stream()),
+      params(::testing::TestWithParam<MatVecOpInputs<IdxType>>::GetParam()),
+      vec_size(params.bcastAlongRows ? params.cols : params.rows),
+      in(params.rows * params.cols + params.inAlignOffset, stream),
+      out_ref(params.rows * params.cols + params.outAlignOffset, stream),
+      out(params.rows * params.cols + params.outAlignOffset, stream),
+      vec1(vec_size, stream),
+      vec2(vec_size, stream)
   {
   }
 
 protected:
   void SetUp() override
   {
+    MatT* in_ptr      = in.data() + params.inAlignOffset;
+    MatT* out_ptr     = out.data() + params.outAlignOffset;
+    MatT* out_ref_ptr = out_ref.data() + params.outAlignOffset;
+
     raft::random::RngState r(params.seed);
-    IdxType N = params.rows, D = params.cols;
-    IdxType len    = N * D;
-    IdxType vecLen = params.bcastAlongRows ? D : N;
-    uniform(handle, r, in.data(), len, (T)-1.0, (T)1.0);
-    uniform(handle, r, vec1.data(), vecLen, (T)-1.0, (T)1.0);
-    uniform(handle, r, vec2.data(), vecLen, (T)-1.0, (T)1.0);
-    if (params.useTwoVectors) {
-      naiveMatVec(out_ref.data(),
-                  in.data(),
+    IdxType len = params.rows * params.cols;
+    gen_uniform(handle, r, in_ptr, len);
+    gen_uniform(handle, r, vec1.data(), vec_size);
+    gen_uniform(handle, r, vec2.data(), vec_size);
+    if constexpr (OpT::useTwoVectors) {
+      naiveMatVec(out_ref_ptr,
+                  in_ptr,
                   vec1.data(),
                   vec2.data(),
-                  D,
-                  N,
+                  params.cols,
+                  params.rows,
                   params.rowMajor,
                   params.bcastAlongRows,
-                  (T)1.0,
+                  OpT{},
                  stream);
     } else {
-      naiveMatVec(out_ref.data(),
-                  in.data(),
+      naiveMatVec(out_ref_ptr,
+                  in_ptr,
                   vec1.data(),
-                  D,
-                  N,
+                  params.cols,
+                  params.rows,
                   params.rowMajor,
                   params.bcastAlongRows,
-                  (T)1.0,
+                  OpT{},
                   stream);
     }
-    matrixVectorOpLaunch(handle,
-                         out.data(),
-                         in.data(),
-                         vec1.data(),
-                         vec2.data(),
-                         D,
-                         N,
-                         params.rowMajor,
-                         params.bcastAlongRows,
-                         params.useTwoVectors);
+    matrixVectorOpLaunch<OpT>(handle,
+                              out_ptr,
+                              in_ptr,
+                              vec1.data(),
+                              vec2.data(),
+                              params.cols,
+                              params.rows,
+                              params.rowMajor,
+                              params.bcastAlongRows);
     handle.sync_stream();
   }
 
 protected:
   raft::handle_t handle;
   cudaStream_t stream;
 
-  MatVecOpInputs<T, IdxType> params;
-  rmm::device_uvector<T> in, out, out_ref, vec1, vec2;
+  MatVecOpInputs<IdxType> params;
+  IdxType vec_size;
+  rmm::device_uvector<MatT> in;
+  rmm::device_uvector<MatT> out;
+  rmm::device_uvector<MatT> out_ref;
+  rmm::device_uvector<Vec1T> vec1;
+  rmm::device_uvector<Vec2T> vec2;
 };
 
-const std::vector<MatVecOpInputs<float, int>> inputsf_i32 = {
-  {0.00001f, 1024, 32, true, true, false, 1234ULL},
-  {0.00001f, 1024, 64, true, true, false, 1234ULL},
-  {0.00001f, 1024, 32, true, false, false, 1234ULL},
-  {0.00001f, 1024, 64, true, false, false, 1234ULL},
-  {0.00001f, 1024, 32, false, true, false, 1234ULL},
-  {0.00001f, 1024, 64, false, true, false, 1234ULL},
-  {0.00001f, 1024, 32, false, false, false, 1234ULL},
-  {0.00001f, 1024, 64, false, false, false, 1234ULL},
-
-  {0.00001f, 1024, 32, true, true, true, 1234ULL},
-  {0.00001f, 1024, 64, true, true, true, 1234ULL},
-  {0.00001f, 1024, 32, true, false, true, 1234ULL},
-  {0.00001f, 1024, 64, true, false, true, 1234ULL},
-  {0.00001f, 1024, 32, false, true, true, 1234ULL},
-  {0.00001f, 1024, 64, false, true, true, 1234ULL},
-  {0.00001f, 1024, 32, false, false, true, 1234ULL},
-  {0.00001f, 1024, 64, false, false, true, 1234ULL}};
-typedef MatVecOpTest<float, int> MatVecOpTestF_i32;
-TEST_P(MatVecOpTestF_i32, Result)
-{
-  ASSERT_TRUE(devArrMatch(
-    out_ref.data(), out.data(), params.rows * params.cols, CompareApprox<float>(params.tolerance)));
-}
-INSTANTIATE_TEST_SUITE_P(MatVecOpTests, MatVecOpTestF_i32, ::testing::ValuesIn(inputsf_i32));
+#define MVTEST(TestClass, OutType, inputs, tolerance)                  \
+  TEST_P(TestClass, Result)                                            \
+  {                                                                    \
+    if constexpr (std::is_floating_point_v<OutType>) {                 \
+      ASSERT_TRUE(devArrMatch(out_ref.data() + params.outAlignOffset,  \
+                              out.data() + params.outAlignOffset,      \
+                              params.rows * params.cols,               \
+                              CompareApprox<OutType>(tolerance)));     \
+    } else {                                                           \
+      ASSERT_TRUE(devArrMatch(out_ref.data() + params.outAlignOffset,  \
+                              out.data() + params.outAlignOffset,      \
+                              params.rows * params.cols,               \
+                              Compare<OutType>()));                    \
+    }                                                                  \
+  }                                                                    \
+  INSTANTIATE_TEST_SUITE_P(MatVecOpTests, TestClass, ::testing::ValuesIn(inputs))
 
-const std::vector<MatVecOpInputs<float, int64_t>> inputsf_i64 = {
-  {0.00001f, 2500, 250, false, false, false, 1234ULL},
-  {0.00001f, 2500, 250, false, false, true, 1234ULL}};
-typedef MatVecOpTest<float, int64_t> MatVecOpTestF_i64;
-TEST_P(MatVecOpTestF_i64, Result)
-{
-  ASSERT_TRUE(devArrMatch(
-    out_ref.data(), out.data(), params.rows * params.cols, CompareApprox<float>(params.tolerance)));
-}
-INSTANTIATE_TEST_SUITE_P(MatVecOpTests, MatVecOpTestF_i64, ::testing::ValuesIn(inputsf_i64));
-
-const std::vector<MatVecOpInputs<double, int>> inputsd_i32 = {
-  {0.0000001, 1024, 32, true, true, false, 1234ULL},
-  {0.0000001, 1024, 64, true, true, false, 1234ULL},
-  {0.0000001, 1024, 32, true, false, false, 1234ULL},
-  {0.0000001, 1024, 64, true, false, false, 1234ULL},
-  {0.0000001, 1024, 32, false, true, false, 1234ULL},
-  {0.0000001, 1024, 64, false, true, false, 1234ULL},
-  {0.0000001, 1024, 32, false, false, false, 1234ULL},
-  {0.0000001, 1024, 64, false, false, false, 1234ULL},
-
-  {0.0000001, 1024, 32, true, true, true, 1234ULL},
-  {0.0000001, 1024, 64, true, true, true, 1234ULL},
-  {0.0000001, 1024, 32, true, false, true, 1234ULL},
-  {0.0000001, 1024, 64, true, false, true, 1234ULL},
-  {0.0000001, 1024, 32, false, true, true, 1234ULL},
-  {0.0000001, 1024, 64, false, true, true, 1234ULL},
-  {0.0000001, 1024, 32, false, false, true, 1234ULL},
-  {0.0000001, 1024, 64, false, false, true, 1234ULL}};
-typedef MatVecOpTest<double, int> MatVecOpTestD_i32;
-TEST_P(MatVecOpTestD_i32, Result)
-{
-  ASSERT_TRUE(devArrMatch(out_ref.data(),
-                          out.data(),
-                          params.rows * params.cols,
-                          CompareApprox<double>(params.tolerance)));
-}
-INSTANTIATE_TEST_SUITE_P(MatVecOpTests, MatVecOpTestD_i32, ::testing::ValuesIn(inputsd_i32));
+#define MV_EPS_F 0.00001f
+#define MV_EPS_D 0.0000001
 
-const std::vector<MatVecOpInputs<double, int64_t>> inputsd_i64 = {
-  {0.0000001, 2500, 250, false, false, false, 1234ULL},
-  {0.0000001, 2500, 250, false, false, true, 1234ULL}};
-typedef MatVecOpTest<double, int64_t> MatVecOpTestD_i64;
-TEST_P(MatVecOpTestD_i64, Result)
-{
-  ASSERT_TRUE(devArrMatch(out_ref.data(),
-                          out.data(),
-                          params.rows * params.cols,
-                          CompareApprox<double>(params.tolerance)));
-}
-INSTANTIATE_TEST_SUITE_P(MatVecOpTests, MatVecOpTestD_i64, ::testing::ValuesIn(inputsd_i64));
+/*
+ * This set of tests covers cases where all the types are the same.
+ */
+
+const std::vector<MatVecOpInputs<int>> inputs_i32 =
+  raft::util::itertools::product<MatVecOpInputs<int>>(
+    {1024}, {32, 64}, {true, false}, {true, false}, {0, 1, 2}, {0, 1, 2}, {1234ULL});
+const std::vector<MatVecOpInputs<int64_t>> inputs_i64 =
+  raft::util::itertools::product<MatVecOpInputs<int64_t>>(
+    {2500}, {250}, {false}, {false}, {0, 1}, {0, 1}, {1234ULL});
+
+template <typename T>
+struct Add1Vec {
+  static constexpr bool useTwoVectors = false;
+  HDI T operator()(T a, T b) const { return a + b; };
+};
+template <typename T>
+struct Add2Vec {
+  static constexpr bool useTwoVectors = true;
+  HDI T operator()(T a, T b, T c) const { return a + b + c; };
+};
+
+typedef MatVecOpTest<Add1Vec<float>, float, int> MatVecOpTestF_i32_add1vec;
+typedef MatVecOpTest<Add2Vec<float>, float, int> MatVecOpTestF_i32_add2vec;
+typedef MatVecOpTest<Add1Vec<float>, float, int64_t> MatVecOpTestF_i64_add1vec;
+typedef MatVecOpTest<Add2Vec<float>, float, int64_t> MatVecOpTestF_i64_add2vec;
+typedef MatVecOpTest<Add1Vec<double>, double, int> MatVecOpTestD_i32_add1vec;
+typedef MatVecOpTest<Add2Vec<double>, double, int> MatVecOpTestD_i32_add2vec;
+typedef MatVecOpTest<Add1Vec<double>, double, int64_t> MatVecOpTestD_i64_add1vec;
+typedef MatVecOpTest<Add2Vec<double>, double, int64_t> MatVecOpTestD_i64_add2vec;
+
+MVTEST(MatVecOpTestF_i32_add1vec, float, inputs_i32, MV_EPS_F);
+MVTEST(MatVecOpTestF_i32_add2vec, float, inputs_i32, MV_EPS_F);
+MVTEST(MatVecOpTestF_i64_add1vec, float, inputs_i64, MV_EPS_F);
+MVTEST(MatVecOpTestF_i64_add2vec, float, inputs_i64, MV_EPS_F);
+MVTEST(MatVecOpTestD_i32_add1vec, double, inputs_i32, MV_EPS_D);
+MVTEST(MatVecOpTestD_i32_add2vec, double, inputs_i32, MV_EPS_D);
+MVTEST(MatVecOpTestD_i64_add1vec, double, inputs_i64, MV_EPS_D);
+MVTEST(MatVecOpTestD_i64_add2vec, double, inputs_i64, MV_EPS_D);
+ */ + +const std::vector> inputs_i32 = + raft::util::itertools::product>( + {1024}, {32, 64}, {true, false}, {true, false}, {0, 1, 2}, {0, 1, 2}, {1234ULL}); +const std::vector> inputs_i64 = + raft::util::itertools::product>( + {2500}, {250}, {false}, {false}, {0, 1}, {0, 1}, {1234ULL}); + +template +struct Add1Vec { + static constexpr bool useTwoVectors = false; + HDI T operator()(T a, T b) const { return a + b; }; +}; +template +struct Add2Vec { + static constexpr bool useTwoVectors = true; + HDI T operator()(T a, T b, T c) const { return a + b + c; }; +}; + +typedef MatVecOpTest, float, int> MatVecOpTestF_i32_add1vec; +typedef MatVecOpTest, float, int> MatVecOpTestF_i32_add2vec; +typedef MatVecOpTest, float, int64_t> MatVecOpTestF_i64_add1vec; +typedef MatVecOpTest, float, int64_t> MatVecOpTestF_i64_add2vec; +typedef MatVecOpTest, double, int> MatVecOpTestD_i32_add1vec; +typedef MatVecOpTest, double, int> MatVecOpTestD_i32_add2vec; +typedef MatVecOpTest, double, int64_t> MatVecOpTestD_i64_add1vec; +typedef MatVecOpTest, double, int64_t> MatVecOpTestD_i64_add2vec; + +MVTEST(MatVecOpTestF_i32_add1vec, float, inputs_i32, MV_EPS_F); +MVTEST(MatVecOpTestF_i32_add2vec, float, inputs_i32, MV_EPS_F); +MVTEST(MatVecOpTestF_i64_add1vec, float, inputs_i64, MV_EPS_F); +MVTEST(MatVecOpTestF_i64_add2vec, float, inputs_i64, MV_EPS_F); +MVTEST(MatVecOpTestD_i32_add1vec, double, inputs_i32, MV_EPS_D); +MVTEST(MatVecOpTestD_i32_add2vec, double, inputs_i32, MV_EPS_D); +MVTEST(MatVecOpTestD_i64_add1vec, double, inputs_i64, MV_EPS_D); +MVTEST(MatVecOpTestD_i64_add2vec, double, inputs_i64, MV_EPS_D); + +/* + * This set of tests covers cases with different types. + */ + +template +struct MulAndAdd { + static constexpr bool useTwoVectors = true; + HDI MatT operator()(MatT a, Vec1T b, Vec2T c) const { return a * b + c; }; +}; + +typedef MatVecOpTest, float, int, int32_t, float> + MatVecOpTestF_i32_MulAndAdd_i32_f; +typedef MatVecOpTest, float, int, int32_t, double> + MatVecOpTestF_i32_MulAndAdd_i32_d; +typedef MatVecOpTest, float, int, int64_t, float> + MatVecOpTestF_i32_MulAndAdd_i64_f; +typedef MatVecOpTest, double, int, int32_t, float> + MatVecOpTestD_i32_MulAndAdd_i32_f; + +MVTEST(MatVecOpTestF_i32_MulAndAdd_i32_f, float, inputs_i32, MV_EPS_F); +MVTEST(MatVecOpTestF_i32_MulAndAdd_i32_d, float, inputs_i32, MV_EPS_F); +MVTEST(MatVecOpTestF_i32_MulAndAdd_i64_f, float, inputs_i32, MV_EPS_F); +MVTEST(MatVecOpTestD_i32_MulAndAdd_i32_f, double, inputs_i32, (double)MV_EPS_F); + +struct DQMultiply { + static constexpr bool useTwoVectors = true; + HDI int8_t operator()(int8_t a, float b, float c) const + { + return static_cast((static_cast(a) / 100.0f * (b + c) / 20.0f) * 100.0f); + }; +}; + +typedef MatVecOpTest MatVecOpTestI8_i32_DQMultiply_f_f; + +MVTEST(MatVecOpTestI8_i32_DQMultiply_f_f, int8_t, inputs_i32, 0); } // end namespace linalg } // end namespace raft diff --git a/cpp/test/linalg/matrix_vector_op.cuh b/cpp/test/linalg/matrix_vector_op.cuh index 934c2f3e0d..602d05d153 100644 --- a/cpp/test/linalg/matrix_vector_op.cuh +++ b/cpp/test/linalg/matrix_vector_op.cuh @@ -21,57 +21,15 @@ namespace raft { namespace linalg { -template -__global__ void naiveMatVecOpKernel(Type* mat, - const Type* vec, - IdxType D, - IdxType N, - bool rowMajor, - bool bcastAlongRows, - LambdaOp operation) -{ - IdxType idx = threadIdx.x + blockIdx.x * blockDim.x; - IdxType len = N * D; - IdxType col; - if (rowMajor && bcastAlongRows) { - col = idx % D; - } else if (!rowMajor && !bcastAlongRows) { - col = idx % N; - } else if (rowMajor 
-
-template <typename Type, typename IdxType, typename LambdaOp>
-void naiveMatVecOp(Type* mat,
-                   const Type* vec,
-                   IdxType D,
-                   IdxType N,
-                   bool rowMajor,
-                   bool bcastAlongRows,
-                   LambdaOp operation,
-                   cudaStream_t stream)
-{
-  static const IdxType TPB = 64;
-  IdxType len              = N * D;
-  IdxType nblks            = raft::ceildiv(len, TPB);
-  naiveMatVecOpKernel<Type, IdxType, LambdaOp>
-    <<<nblks, TPB, 0, stream>>>(mat, vec, D, N, rowMajor, bcastAlongRows, operation);
-  RAFT_CUDA_TRY(cudaPeekAtLastError());
-}
-
-template <typename Type, typename IdxType>
-__global__ void naiveMatVecKernel(Type* out,
-                                  const Type* mat,
-                                  const Type* vec,
+template <typename OutT, typename MatT, typename VecT, typename IdxType, typename Lambda>
+__global__ void naiveMatVecKernel(OutT* out,
+                                  const MatT* mat,
+                                  const VecT* vec,
                                   IdxType D,
                                   IdxType N,
                                   bool rowMajor,
                                   bool bcastAlongRows,
-                                  Type scalar)
+                                  Lambda op)
 {
   IdxType idx = threadIdx.x + blockIdx.x * blockDim.x;
   IdxType len = N * D;
@@ -85,38 +43,65 @@ __global__ void naiveMatVecKernel(Type* out,
   } else {
     col = idx / N;
   }
-  if (idx < len) { out[idx] = mat[idx] + scalar * vec[col]; }
+  if (idx < len) { out[idx] = op(mat[idx], vec[col]); }
 }
 
-template <typename Type, typename IdxType>
-void naiveMatVec(Type* out,
-                 const Type* mat,
-                 const Type* vec,
+template <typename OutT, typename MatT, typename VecT, typename IdxType, typename Lambda>
+void naiveMatVec(OutT* out,
+                 const MatT* mat,
+                 const VecT* vec,
                  IdxType D,
                  IdxType N,
                  bool rowMajor,
                  bool bcastAlongRows,
-                 Type scalar,
+                 Lambda op,
                  cudaStream_t stream)
 {
   static const IdxType TPB = 64;
   IdxType len              = N * D;
   IdxType nblks            = raft::ceildiv(len, TPB);
-  naiveMatVecKernel<Type>
-    <<<nblks, TPB, 0, stream>>>(out, mat, vec, D, N, rowMajor, bcastAlongRows, scalar);
+  naiveMatVecKernel<<<nblks, TPB, 0, stream>>>(out, mat, vec, D, N, rowMajor, bcastAlongRows, op);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
 
+template <typename OutT, typename MatT, typename VecT, typename IdxType>
+void naiveMatVec(OutT* out,
+                 const MatT* mat,
+                 const VecT* vec,
+                 IdxType D,
+                 IdxType N,
+                 bool rowMajor,
+                 bool bcastAlongRows,
+                 OutT scalar,
+                 cudaStream_t stream)
+{
+  naiveMatVec(
+    out,
+    mat,
+    vec,
+    D,
+    N,
+    rowMajor,
+    bcastAlongRows,
+    [scalar] __device__(MatT a, VecT b) { return (OutT)(a + scalar * b); },
+    stream);
+}
+
-template <typename Type, typename IdxType>
-__global__ void naiveMatVecKernel(Type* out,
-                                  const Type* mat,
-                                  const Type* vec1,
-                                  const Type* vec2,
+template <typename OutT,
+          typename MatT,
+          typename Vec1T,
+          typename Vec2T,
+          typename IdxType,
+          typename Lambda>
+__global__ void naiveMatVecKernel(OutT* out,
+                                  const MatT* mat,
+                                  const Vec1T* vec1,
+                                  const Vec2T* vec2,
                                   IdxType D,
                                   IdxType N,
                                   bool rowMajor,
                                   bool bcastAlongRows,
-                                  Type scalar)
+                                  Lambda op)
 {
   IdxType idx = threadIdx.x + blockIdx.x * blockDim.x;
   IdxType len = N * D;
@@ -130,28 +115,58 @@ __global__ void naiveMatVecKernel(Type* out,
   } else {
     col = idx / N;
   }
-  if (idx < len) { out[idx] = mat[idx] + scalar * vec1[col] + vec2[col]; }
+  if (idx < len) { out[idx] = op(mat[idx], vec1[col], vec2[col]); }
 }
 
-template <typename Type, typename IdxType>
-void naiveMatVec(Type* out,
-                 const Type* mat,
-                 const Type* vec1,
-                 const Type* vec2,
+template <typename OutT,
+          typename MatT,
+          typename Vec1T,
+          typename Vec2T,
+          typename IdxType,
+          typename Lambda>
+void naiveMatVec(OutT* out,
+                 const MatT* mat,
+                 const Vec1T* vec1,
+                 const Vec2T* vec2,
                  IdxType D,
                  IdxType N,
                  bool rowMajor,
                  bool bcastAlongRows,
-                 Type scalar,
+                 Lambda op,
                  cudaStream_t stream)
 {
   static const IdxType TPB = 64;
   IdxType len              = N * D;
   IdxType nblks            = raft::ceildiv(len, TPB);
-  naiveMatVecKernel<Type>
-    <<<nblks, TPB, 0, stream>>>(out, mat, vec1, vec2, D, N, rowMajor, bcastAlongRows, scalar);
+  naiveMatVecKernel<<<nblks, TPB, 0, stream>>>(
+    out, mat, vec1, vec2, D, N, rowMajor, bcastAlongRows, op);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
 
+template <typename OutT, typename MatT, typename Vec1T, typename Vec2T, typename IdxType>
+void naiveMatVec(OutT* out,
+                 const MatT* mat,
+                 const Vec1T* vec1,
+                 const Vec2T* vec2,
+                 IdxType D,
+                 IdxType N,
+                 bool rowMajor,
+                 bool bcastAlongRows,
+                 OutT scalar,
+                 cudaStream_t stream)
+{
+  naiveMatVec(
+    out,
+    mat,
+    vec1,
+    vec2,
+    D,
+    N,
+    rowMajor,
+    bcastAlongRows,
+    [scalar] __device__(MatT a, Vec1T b, Vec2T c) { return (OutT)(a + scalar * b + c); },
+    stream);
+}
+
 }  // end namespace linalg
 }  // end namespace raft
diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu
index c61af89bec..2e3d54dcf5 100644
--- a/cpp/test/matrix/linewise_op.cu
+++ b/cpp/test/matrix/linewise_op.cu
@@ -347,7 +347,6 @@ struct LinewiseTest : public ::testing::TestWithParam<typename ParamsReader::Par
     std::vector<std::tuple<I, I>> dims;
     for (auto m : sizes) {
       for (auto n : sizes) {
-        dims.push_back(std::make_tuple(n, m));
         dims.push_back(std::make_tuple(m, n));
       }
     }