From 635c79baa0572344e528fc39299976b417718e31 Mon Sep 17 00:00:00 2001 From: achirkin Date: Mon, 22 Nov 2021 17:09:47 +0100 Subject: [PATCH 01/17] Version 1 (acrossRows sometimes slower) --- .../raft/linalg/matrix_linewise_op.cuh | 297 ++++++++++++++++++ cpp/include/raft/pow2_utils.cuh | 44 +-- cpp/test/CMakeLists.txt | 2 + cpp/test/linalg/matrix_linewise_op.cu | 268 ++++++++++++++++ 4 files changed, 591 insertions(+), 20 deletions(-) create mode 100644 cpp/include/raft/linalg/matrix_linewise_op.cuh create mode 100644 cpp/test/linalg/matrix_linewise_op.cu diff --git a/cpp/include/raft/linalg/matrix_linewise_op.cuh b/cpp/include/raft/linalg/matrix_linewise_op.cuh new file mode 100644 index 0000000000..9fd80e1fff --- /dev/null +++ b/cpp/include/raft/linalg/matrix_linewise_op.cuh @@ -0,0 +1,297 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft { +namespace linalg { + +namespace linewise_impl { + +template +struct Linewise { + static constexpr IdxType VecElems = VecBytes / sizeof(Type); + + typedef raft::TxN_t Vec; + typedef raft::Pow2 AlignBytes; + typedef raft::Pow2 AlignElems; + + template + static __device__ __forceinline__ void vectorCols(Type* out, const Type* in, + const IdxType rowLen, + Lambda op, Args... args) { + const IdxType alignedStart = IdxType(AlignBytes::roundUp(in) - in); + const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + rowLen) - in); + IdxType i0 = threadIdx.x + blockIdx.y * blockDim.x; + + // First unaligned pieces + if (i0 < alignedStart) out[i0] = op(in[i0], args...); + + // aligned core chunk + { + Vec data; + const IdxType d = blockDim.x * gridDim.y * VecElems; + for (IdxType i = alignedStart + i0 * VecElems; i < alignedEnd; i += d) { + data.load(in, i); +#pragma unroll VecElems + for (int k = 0; k < VecElems; k++) + data.val.data[k] = op(data.val.data[k], args...); + data.store(out, i); + } + } + // last unaligned pieces + i0 += alignedEnd; + if (i0 < rowLen) out[i0] = op(in[i0], args...); + } + + template + static __device__ __forceinline__ void vectorRows( + typename Vec::io_t* out, const typename Vec::io_t* in, const IdxType len, + Lambda op, Args... 
args) { + Vec v; + const IdxType d = BlockSize * gridDim.x; + for (IdxType i = threadIdx.x + blockIdx.x * BlockSize; i < len; i += d) { + v.val.internal = __ldcv(in + i); +#pragma unroll VecElems + for (int k = 0; k < VecElems; k++) + v.val.data[k] = op(v.val.data[k], args.val.data[k]...); + __stwt(out + i, v.val.internal); + } + } + + static __device__ __forceinline__ Vec loadVec(const Type* p, + const IdxType blockOffset, + const IdxType rowLen) { + // 11.096 ms / 34 Regs + __shared__ alignas(sizeof(Type) * VecElems) Type shm[VecElems * BlockSize]; + IdxType j = blockOffset + threadIdx.x; +#pragma unroll VecElems + for (int k = threadIdx.x; k < VecElems * BlockSize; + k += BlockSize, j += BlockSize) { + while (j >= rowLen) j -= rowLen; + shm[k] = p[j]; + } + __syncthreads(); + { + Vec out; + out.val.internal = + reinterpret_cast(shm)[threadIdx.x]; + return out; + } + + // // 16.686 ms / 66 Regs + // typedef raft::Pow2 AlignWarp; + // int l = AlignWarp::mod(threadIdx.x); + // int d = l >> (AlignWarp::Log2 - AlignElems::Log2); + // Vec out; + // #pragma unroll VecElems + // for (int k = VecElems, j = blockOffset + (threadIdx.x - l) * VecElems + l; + // k > 0; k--, j += AlignWarp::Value) { + // while (j >= rowLen) j -= rowLen; + // const int kd = AlignElems::mod(k + l + d); + // out.val.data[kd] = __ldg(p + j); + // } + // l = AlignWarp::mod(l * VecElems); + // #pragma unroll VecElems + // for (int k = d; k < VecElems + d; k++) { + // const int kd = AlignElems::mod(k); + // out.val.data[kd] = __shfl_sync(0xffffffffu, out.val.data[kd], kd + l); + // } + // return out; + } +}; + +template +__global__ void __launch_bounds__(BlockSize) + matrixLinewiseVecColsKernel(Type* out, const Type* in, const IdxType rowLen, + const IdxType nRows, Lambda op, Vecs... vecs) { + const IdxType j = threadIdx.y + blockIdx.x * blockDim.y; + if (j < nRows) { + const IdxType shift = rowLen * j; + Linewise::vectorCols( + out + shift, in + shift, rowLen, op, vecs[j]...); + } +} + +template +__global__ void __launch_bounds__(BlockSize) + matrixLinewiseVecRowsMainKernel(Type* out, const Type* in, + const IdxType arrOffset, const IdxType rowLen, + const IdxType len, Lambda op, Vecs... vecs) { + typedef Linewise L; + const IdxType blockOffset = + (arrOffset + BlockSize * L::VecElems * blockIdx.x) % rowLen; + L::vectorRows(reinterpret_cast(out), + reinterpret_cast(in), + L::AlignElems::div(len), op, + L::loadVec(vecs, blockOffset, rowLen)...); +} + +template +__global__ void __launch_bounds__(MaxOffset, 2) + matrixLinewiseVecRowsTailKernel(Type* out, const Type* in, + const IdxType arrOffset, + const IdxType arrTail, const IdxType rowLen, + const IdxType len, Lambda op, Vecs... vecs) { + constexpr std::size_t MaxOffsetMod = MaxOffset - 1; + static_assert((MaxOffset & MaxOffsetMod) == 0, + "MaxOffset must be power of two."); + typedef Linewise L; + if (blockIdx.x == 0) + L::vectorRows(reinterpret_cast(out), + reinterpret_cast(in), arrOffset, + op, L::loadVec(vecs, 0, rowLen)...); + else + L::vectorRows( + reinterpret_cast(out + arrTail - MaxOffset), + reinterpret_cast(in + arrTail - MaxOffset), + len - arrTail + MaxOffset, op, + L::loadVec(vecs, arrTail % rowLen, rowLen)...); +} + +template +void matrixLinewiseVecCols(Type* out, const Type* in, const IdxType rowLen, + const IdxType nRows, Lambda op, cudaStream_t stream, + Vecs... 
vecs) { + constexpr std::size_t VecElems = VecBytes / sizeof(Type); + IdxType bsx = 32; + IdxType bsy = 8; + constexpr int BlockSize = 256; + while (bsy > nRows * 2) { + bsy >>= 1; + bsx <<= 1; + } + IdxType gsy = raft::ceildiv(nRows, bsy); + IdxType gsx = + min(raft::ceildiv(raft::getMultiProcessorCount() * 64, gsy), + raft::ceildiv(rowLen, bsx * VecElems)); + // NB: gridSize.x and gridSize.y are swapped, because gsx is bounded by a small number, + // but gsy can grow uncontrollably with the number of rows. + // (there is a tight limit on the max grid size in `y` direction). + dim3 bs(bsx, bsy, 1); + dim3 gs(gsy, gsx, 1); + matrixLinewiseVecColsKernel + <<>>(out, in, rowLen, nRows, op, vecs...); + CUDA_CHECK(cudaPeekAtLastError()); +} + +template +void matrixLinewiseVecRows(Type* out, const Type* in, const IdxType rowLen, + const IdxType nRows, Lambda op, cudaStream_t stream, + Vecs... vecs) { + typedef raft::Pow2 AlignBytes; + constexpr std::size_t VecElems = VecBytes / sizeof(Type); + const IdxType totalLen = rowLen * nRows; + // blockSize + constexpr int BlockSize = 256; + constexpr dim3 bs(BlockSize, 1, 1); + // if we have `stride` number of blocks, then each block processes always the same + // indices along dimension rowLen; this means a block needs to index `vecs` only once! + const uint stride = + (rowLen / raft::gcd(bs.x * uint(VecElems), uint(rowLen))) * VecElems; + // Minimum size of the grid to make device well occupied + const uint occupy = raft::getMultiProcessorCount() * 64; + const dim3 gs = dim3(min( + // does not make sense to have more blocks than this + raft::ceildiv(uint(totalLen), bs.x * VecElems), + // increase the stride size if necessary + raft::ceildiv(occupy, stride) * stride), + 1, 1); + + const Type* alignedStart = AlignBytes::roundUp(in); + const IdxType alignedOff = IdxType(alignedStart - in); + const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + totalLen) - in); + const IdxType alignedLen = alignedEnd - alignedOff; + matrixLinewiseVecRowsMainKernel + <<>>(out + alignedOff, alignedStart, alignedOff, rowLen, + alignedLen, op, vecs...); + CUDA_CHECK(cudaPeekAtLastError()); + if (alignedLen < totalLen) { + // should be not smaller than the warp size for better branching + constexpr std::size_t MaxOffset = std::max(std::size_t(32), VecBytes); + matrixLinewiseVecRowsTailKernel + <<>>( + out, in, alignedOff, alignedEnd, rowLen, totalLen, op, vecs...); + CUDA_CHECK(cudaPeekAtLastError()); + } +} + +template +struct MatrixLinewiseOp { + template + static void run(Type* out, const Type* in, const IdxType lineLen, + const IdxType nLines, const bool alongLines, Lambda op, + cudaStream_t stream, Vecs... vecs) { + if constexpr (VecBytes > sizeof(Type)) { + if (!raft::Pow2::areSameAlignOffsets(in, out)) + return MatrixLinewiseOp> 1), sizeof(Type))>::run( + out, in, lineLen, nLines, alongLines, op, stream, vecs...); + } + if (alongLines) + return matrixLinewiseVecRows( + out, in, lineLen, nLines, op, stream, vecs...); + else + return matrixLinewiseVecCols( + out, in, lineLen, nLines, op, stream, vecs...); + } +}; + +}; // namespace linewise_impl + +/** + * Run a function over matrix lines (rows or columns) with a variable number + * row-vectors or column-vectors. + * The term `line` here signifies that the lines can be either columns or rows, + * depending on the matrix layout. + * What matters is if vectors are applied along lines (indices of vectors correspond + * indices within lines), or across lines (indices of vectors correspond to line indices). 
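+ *
+ * A usage sketch (hypothetical names `out`, `in`, `bias`; the lambda mirrors the
+ * sums exercised by the tests in this patch): to add a length-`nCols` bias
+ * vector to every row of a row-major `nRows x nCols` matrix, one would call
+ *
+ *   matrixLinewiseOp(out, in, nCols, nRows, true,
+ *                    [] __device__(Type a, Type b) -> Type { return a + b; },
+ *                    stream, bias);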
+ * + * @param out result of the operation; can be same as `in`; should be aligned the same as `in` + * to allow faster vectorized memory transfers. + * @param in input matrix consisting of `nLines` lines, each `lineLen`-long. + * @param lineLen length of matrix line in elements (`=nCols` in row-major or `=nRows` in col-major) + * @param nLines number of matrix lines (`=nRows` in row-major or `=nCols` in col-major) + * @param alongLines whether vectors are indices along or across lines. + * @param op the operation applied on each line: + * for i in [0..lineLen) and j in [0..nLines): + * out[i, j] = op(in[i, j], vec1[i], vec2[i], ... veck[i]) if alongLines = true + * out[i, j] = op(in[i, j], vec1[j], vec2[j], ... veck[j]) if alongLines = false + * where matrix indexing is row-major ([i, j] = [i + lineLen * j]). + * @param stream a cuda stream for the kernels + * @param vecs zero or more vectors to be passed as arguments, + * size of each vector is `alongLines ? lineLen : nLines`. + */ +template +void matrixLinewiseOp(Type* out, const Type* in, const IdxType lineLen, + const IdxType nLines, const bool alongLines, Lambda op, + cudaStream_t stream, Vecs... vecs) { + linewise_impl::MatrixLinewiseOp<16>::run( + out, in, lineLen, nLines, alongLines, op, stream, vecs...); +} + +}; // end namespace linalg +}; // end namespace raft diff --git a/cpp/include/raft/pow2_utils.cuh b/cpp/include/raft/pow2_utils.cuh index 56a3192f9f..b1f0b21c7b 100644 --- a/cpp/include/raft/pow2_utils.cuh +++ b/cpp/include/raft/pow2_utils.cuh @@ -35,7 +35,9 @@ struct Pow2 { static_assert(std::is_integral::value, "Value must be integral."); static_assert(Value && !(Value & Mask), "Value must be power of two."); -#define Pow2_IsRepresentableAs(I) (std::is_integral::value && Type(I(Value)) == Value) +#define Pow2_CALL static constexpr __host__ __device__ __forceinline__ +#define Pow2_WHEN_INTEGRAL(I) std::enable_if_t +#define Pow2_IS_REPRESENTABLE_AS(I) (std::is_integral::value && Type(I(Value)) == Value) /** * Integer division by Value truncated toward zero @@ -44,7 +46,7 @@ struct Pow2 { * Invariant: `x = Value * quot(x) + rem(x)` */ template - static constexpr HDI std::enable_if_t quot(I x) noexcept + Pow2_CALL Pow2_WHEN_INTEGRAL(I) quot(I x) noexcept { if constexpr (std::is_signed::value) return (x >> I(Log2)) + (x < 0 && (x & I(Mask))); if constexpr (std::is_unsigned::value) return x >> I(Log2); @@ -57,7 +59,7 @@ struct Pow2 { * Invariant: `x = Value * quot(x) + rem(x)`. */ template - static constexpr HDI std::enable_if_t rem(I x) noexcept + Pow2_CALL Pow2_WHEN_INTEGRAL(I) rem(I x) noexcept { if constexpr (std::is_signed::value) return x < 0 ? -((-x) & I(Mask)) : (x & I(Mask)); if constexpr (std::is_unsigned::value) return x & I(Mask); @@ -74,7 +76,7 @@ struct Pow2 { * compared to normal C++ operators `/` and `%`. */ template - static constexpr HDI std::enable_if_t div(I x) noexcept + Pow2_CALL Pow2_WHEN_INTEGRAL(I) div(I x) noexcept { return x >> I(Log2); } @@ -91,7 +93,7 @@ struct Pow2 { * compared to normal C++ operators `/` and `%`. */ template - static constexpr HDI std::enable_if_t mod(I x) noexcept + Pow2_CALL Pow2_WHEN_INTEGRAL(I) mod(I x) noexcept { return x & I(Mask); } @@ -105,25 +107,25 @@ struct Pow2 { * NB: for pointers, the alignment is checked in bytes, not in elements. 
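 * E.g., with Value = 16: isAligned(std::size_t(48)) is true (48 mod 16 == 0),
 * and for a pointer `p`, isAligned(p) holds iff the address of `p` is a
 * multiple of 16 bytes.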
*/ template - static constexpr HDI bool isAligned(PtrT p) noexcept + Pow2_CALL bool isAligned(PtrT p) noexcept { Pow2_CHECK_TYPE(PtrT); - if constexpr (Pow2_IsRepresentableAs(PtrT)) return mod(p) == 0; - if constexpr (!Pow2_IsRepresentableAs(PtrT)) return mod(reinterpret_cast(p)) == 0; + if constexpr (Pow2_IS_REPRESENTABLE_AS(PtrT)) return mod(p) == 0; + if constexpr (!Pow2_IS_REPRESENTABLE_AS(PtrT)) return mod(reinterpret_cast(p)) == 0; } /** Tell whether two pointers have the same address modulo Value. */ template - static constexpr HDI bool areSameAlignOffsets(PtrT a, PtrS b) noexcept + Pow2_CALL bool areSameAlignOffsets(PtrT a, PtrS b) noexcept { Pow2_CHECK_TYPE(PtrT); Pow2_CHECK_TYPE(PtrS); Type x, y; - if constexpr (Pow2_IsRepresentableAs(PtrT)) + if constexpr (Pow2_IS_REPRESENTABLE_AS(PtrT)) x = Type(mod(a)); else x = mod(reinterpret_cast(a)); - if constexpr (Pow2_IsRepresentableAs(PtrS)) + if constexpr (Pow2_IS_REPRESENTABLE_AS(PtrS)) y = Type(mod(b)); else y = mod(reinterpret_cast(b)); @@ -132,29 +134,31 @@ struct Pow2 { /** Get this or next Value-aligned address (in bytes) or integral. */ template - static constexpr HDI PtrT roundUp(PtrT p) noexcept + Pow2_CALL PtrT roundUp(PtrT p) noexcept { Pow2_CHECK_TYPE(PtrT); - if constexpr (Pow2_IsRepresentableAs(PtrT)) return p + PtrT(Mask) - mod(p + PtrT(Mask)); - if constexpr (!Pow2_IsRepresentableAs(PtrT)) { + if constexpr (Pow2_IS_REPRESENTABLE_AS(PtrT)) return (p + PtrT(Mask)) & PtrT(~Mask); + if constexpr (!Pow2_IS_REPRESENTABLE_AS(PtrT)) { auto x = reinterpret_cast(p); - return reinterpret_cast(x + Mask - mod(x + Mask)); + return reinterpret_cast((x + Mask) & (~Mask)); } } /** Get this or previous Value-aligned address (in bytes) or integral. */ template - static constexpr HDI PtrT roundDown(PtrT p) noexcept + Pow2_CALL PtrT roundDown(PtrT p) noexcept { Pow2_CHECK_TYPE(PtrT); - if constexpr (Pow2_IsRepresentableAs(PtrT)) return p - mod(p); - if constexpr (!Pow2_IsRepresentableAs(PtrT)) { + if constexpr (Pow2_IS_REPRESENTABLE_AS(PtrT)) return p & PtrT(~Mask); + if constexpr (!Pow2_IS_REPRESENTABLE_AS(PtrT)) { auto x = reinterpret_cast(p); - return reinterpret_cast(x - mod(x)); + return reinterpret_cast(x & (~Mask)); } } #undef Pow2_CHECK_TYPE -#undef Pow2_IsRepresentableAs +#undef Pow2_IS_REPRESENTABLE_AS +#undef Pow2_CALL +#undef Pow2_WHEN_INTEGRAL }; }; // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 14052293cf..8f90c1abba 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -51,6 +51,7 @@ add_executable(test_raft test/linalg/gemv.cu test/linalg/map.cu test/linalg/map_then_reduce.cu + test/linalg/matrix_linewise_op.cu test/linalg/matrix_vector_op.cu test/linalg/multiply.cu test/linalg/norm.cu @@ -130,6 +131,7 @@ PRIVATE CUDA::cusolver CUDA::cudart CUDA::cusparse + $<$:CUDA::nvToolsExt> rmm::rmm cuco::cuco FAISS::FAISS diff --git a/cpp/test/linalg/matrix_linewise_op.cu b/cpp/test/linalg/matrix_linewise_op.cu new file mode 100644 index 0000000000..a4973b0637 --- /dev/null +++ b/cpp/test/linalg/matrix_linewise_op.cu @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "../test_utils.h" + +namespace raft { +namespace linalg { + +constexpr std::size_t PTR_PADDING = 128; + +template +void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name, Args... args) { + int length = std::snprintf(nullptr, 0, name, args...); + assert(length >= 0); + auto buf = std::make_unique(length + 1); + std::snprintf(buf.get(), length + 1, name, args...); + stream.synchronize(); + nvtxRangePushA(buf.get()); +} +template <> +void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name) { + stream.synchronize(); + nvtxRangePushA(name); +} + +void POP_RANGE(rmm::cuda_stream_view stream) { + stream.synchronize(); + nvtxRangePop(); +} + +struct LinewiseTestParams { + double tolerance; + std::size_t workSizeBytes; + uint64_t seed; + bool useVanillaMatrixVectorOp; +}; + +template +struct LinewiseTest + : public ::testing::TestWithParam { + const LinewiseTestParams params; + const raft::handle_t handle; + rmm::cuda_stream_view stream; + + LinewiseTest() + : testing::TestWithParam(), + params(ParamsReader::read( + ::testing::TestWithParam::GetParam())), + handle(), + stream(handle.get_stream_view()) {} + + void runLinewiseSum(T* out, const T* in, const I lineLen, const I nLines, + const bool alongLines, const T* vec) { + auto f = [] __device__(T a, T b) -> T { return a + b; }; + if (params.useVanillaMatrixVectorOp) + matrixVectorOp(out, in, vec, lineLen, nLines, true, alongLines, f, + stream); + else + matrixLinewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec); + } + + void runLinewiseSum(T* out, const T* in, const I lineLen, const I nLines, + const bool alongLines, const T* vec1, const T* vec2) { + auto f = [] __device__(T a, T b, T c) -> T { return a + b + c; }; + if (params.useVanillaMatrixVectorOp) + matrixVectorOp(out, in, vec1, vec2, lineLen, nLines, true, alongLines, f, + stream); + else + matrixLinewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec1, + vec2); + } + + rmm::device_uvector genData() { + raft::random::Rng r(params.seed); + const std::size_t workSizeElems = params.workSizeBytes / sizeof(T); + rmm::device_uvector blob(workSizeElems, stream); + r.uniform(blob.data(), workSizeElems, T(-1.0), T(1.0), stream); + return blob; + } + + /** + * Suggest multiple versions of matrix dimensions (n, m), such that + * + * (2 * n * m + numVectors * m + minUnused) * sizeof(T) <= workSize. + * + * This way I know I can create two matrices and numVectors vectors of size m, + * such that they fit into the allocated workSet. 
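+ *
+ * (A derivation sketch for `squareN` below: write b = numVectors and let s be
+ * the element budget left after padding; setting n = m turns the constraint
+ * into 2*m^2 + b*m - s = 0, whose positive root is
+ * m = (sqrt(b^2 + 8*s) - b) / 4.)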
+ */ + std::vector> suggestDimensions(I numVectors) { + const std::size_t workSizeElems = params.workSizeBytes / sizeof(T); + std::vector> out; + const double b = double(numVectors); + const double s = double(workSizeElems) - double(PTR_PADDING * 2 * (2 + b)); + double squareN = 0.25 * (sqrt(8.0 * s + b * b) - b); + + auto solveForN = [s, b](I m) -> double { + return (s - b * double(m)) / double(2 * m); + }; + auto solveForM = [s, b](I n) -> double { return s / double(2 * n + b); }; + auto addIfMakesSense = [&out](double x, double y) { + if (x <= 0 || y <= 0) return; + I n = I(floor(x)); + I m = I(floor(y)); + if (n > 0 && m > 0) out.push_back(std::make_tuple(n, m)); + }; + std::vector sizes = {15, 16, 17, 256, 257, 263}; + addIfMakesSense(squareN, squareN); + for (I k : sizes) { + addIfMakesSense(solveForN(k), k); + addIfMakesSense(k, solveForM(k)); + } + + return out; + } + + std::tuple assignSafePtrs( + rmm::device_uvector& blob, I n, I m) { + typedef raft::Pow2 Align; + T* out = Align::roundUp(blob.data()); + const T* in = + const_cast(Align::roundUp(out + n * m + PTR_PADDING)); + const T* vec1 = Align::roundUp(in + n * m + PTR_PADDING); + const T* vec2 = Align::roundUp(vec1 + m + PTR_PADDING); + ASSERT(blob.data() + blob.size() >= vec2 + PTR_PADDING, + "Failed to allocate pointers: the workset is not big enough."); + return std::make_tuple(out, in, vec1, vec2); + } + + testing::AssertionResult run() { + rmm::device_uvector blob = genData(); + + auto dims = suggestDimensions(2); + + stream.synchronize(); + cudaProfilerStart(); + PUSH_RANGE(stream, params.useVanillaMatrixVectorOp ? "method: original" + : "method: linewise"); + for (auto [n, m] : dims) { + auto [out, in, vec1, vec2] = assignSafePtrs(blob, n, m); + PUSH_RANGE(stream, "Dims-%zu-%zu", std::size_t(n), std::size_t(m)); + for (auto alongRows : ::testing::Bool()) { + PUSH_RANGE(stream, alongRows ? "alongRows" : "acrossRows"); + auto lineLen = alongRows ? m : n; + auto nLines = alongRows ? n : m; + { + PUSH_RANGE(stream, "one vec"); + runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1); + POP_RANGE(stream); + PUSH_RANGE(stream, "two vecs"); + runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1, vec2); + POP_RANGE(stream); + } + POP_RANGE(stream); + } + POP_RANGE(stream); + } + POP_RANGE(stream); + cudaProfilerStop(); + + return testing::AssertionSuccess(); + } + + testing::AssertionResult runTemp() { + rmm::device_uvector blob = genData(); + I n = 257; + I m = 420227; + auto [out, in, vec1, vec2] = assignSafePtrs(blob, n, m); + + stream.synchronize(); + cudaProfilerStart(); + PUSH_RANGE(stream, params.useVanillaMatrixVectorOp ? "method: original" + : "method: linewise"); + for (auto alongRows : ::testing::Bool()) { + PUSH_RANGE(stream, alongRows ? "alongRows" : "acrossRows"); + I lineLen = alongRows ? m : n; + I nLines = alongRows ? 
n : m; + { + PUSH_RANGE(stream, "one vec"); + runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1); + POP_RANGE(stream); + PUSH_RANGE(stream, "two vecs"); + runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1, vec2); + POP_RANGE(stream); + } + POP_RANGE(stream); + } + POP_RANGE(stream); + cudaProfilerStop(); + + return testing::AssertionSuccess(); + } +}; + +#define TEST_IT(fun, TestClass, ElemType, IndexType) \ + typedef LinewiseTest \ + TestClass##_##ElemType##_##IndexType; \ + TEST_P(TestClass##_##ElemType##_##IndexType, fun) { ASSERT_TRUE(fun()); } \ + INSTANTIATE_TEST_SUITE_P(LinewiseOp, TestClass##_##ElemType##_##IndexType, \ + TestClass##Params) + +auto MegabyteParams = ::testing::Bool(); + +struct Megabyte { + typedef bool Params; + static LinewiseTestParams read(Params ps) { + return {/** .tolerance */ 0.00001, + /** .workSizeBytes */ 1024 * 1024, + /** .seed */ 42ULL, + /** .useVanillaMatrixVectorOp */ ps}; + } +}; + +auto GigabyteParams = ::testing::Bool(); + +struct Gigabyte { + typedef bool Params; + static LinewiseTestParams read(Params ps) { + return {/** .tolerance */ 0.00001, + /** .workSizeBytes */ 1024 * 1024 * 1024, + /** .seed */ 42ULL, + /** .useVanillaMatrixVectorOp */ ps}; + } +}; + +auto TenGigsParams = ::testing::Bool(); + +struct TenGigs { + typedef bool Params; + static LinewiseTestParams read(Params ps) { + return {/** .tolerance */ 0.00001, + /** .workSizeBytes */ 10ULL * 1024ULL * 1024ULL * 1024ULL, + /** .seed */ 42ULL, + /** .useVanillaMatrixVectorOp */ ps}; + } +}; + +TEST_IT(run, Megabyte, float, int); +TEST_IT(run, Megabyte, double, int); +// TEST_IT(run, Gigabyte, float, int); +// TEST_IT(run, Gigabyte, double, int); +TEST_IT(run, TenGigs, float, uint64_t); +TEST_IT(run, TenGigs, double, uint64_t); + +TEST_IT(runTemp, Gigabyte, float, int); + +} // end namespace linalg +} // end namespace raft From ed27367da1f8fca710f3d6fe2d95ca72c8667263 Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 25 Nov 2021 13:45:48 +0100 Subject: [PATCH 02/17] Version 2 (all faster, but not fully tested) --- .../raft/linalg/matrix_linewise_op.cuh | 198 ++++++++++-------- cpp/test/linalg/matrix_linewise_op.cu | 36 +--- 2 files changed, 116 insertions(+), 118 deletions(-) diff --git a/cpp/include/raft/linalg/matrix_linewise_op.cuh b/cpp/include/raft/linalg/matrix_linewise_op.cuh index 9fd80e1fff..d3789be504 100644 --- a/cpp/include/raft/linalg/matrix_linewise_op.cuh +++ b/cpp/include/raft/linalg/matrix_linewise_op.cuh @@ -32,39 +32,49 @@ struct Linewise { typedef raft::TxN_t Vec; typedef raft::Pow2 AlignBytes; typedef raft::Pow2 AlignElems; + typedef raft::Pow2 AlignWarp; - template - static __device__ __forceinline__ void vectorCols(Type* out, const Type* in, - const IdxType rowLen, - Lambda op, Args... args) { - const IdxType alignedStart = IdxType(AlignBytes::roundUp(in) - in); - const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + rowLen) - in); - IdxType i0 = threadIdx.x + blockIdx.y * blockDim.x; - - // First unaligned pieces - if (i0 < alignedStart) out[i0] = op(in[i0], args...); - - // aligned core chunk - { - Vec data; - const IdxType d = blockDim.x * gridDim.y * VecElems; - for (IdxType i = alignedStart + i0 * VecElems; i < alignedEnd; i += d) { - data.load(in, i); + template + static __device__ __forceinline__ void vectorCols( + typename Vec::io_t* out, const typename Vec::io_t* in, + const typename Vec::io_t* in_end, const IdxType rowLen, IdxType rowDiv, + IdxType rowMod, Lambda op, Vecs... 
vecs) noexcept { + constexpr IdxType warpPad = (AlignWarp::Value - 1) * VecElems; + Type args[sizeof...(Vecs)]; + Vec v, w; + bool update = true; + for (; in < in_end; + in += AlignWarp::Value, out += AlignWarp::Value, rowMod += warpPad) { + v.val.internal = __ldcv(in); + while (rowMod >= rowLen) { + rowMod -= rowLen; + rowDiv++; + update = true; + } + if (update) { + int l = 0; + ((args[l] = vecs[rowDiv], l++), ...); + update = false; + } #pragma unroll VecElems - for (int k = 0; k < VecElems; k++) - data.val.data[k] = op(data.val.data[k], args...); - data.store(out, i); + for (int k = 0; k < VecElems; k++, rowMod++) { + if (rowMod == rowLen) { + rowMod = 0; + rowDiv++; + int l = 0; + ((args[l] = vecs[rowDiv], l++), ...); + } + int l = 0; + w.val.data[k] = op(v.val.data[k], (std::ignore = vecs, args[l++])...); } + *out = w.val.internal; } - // last unaligned pieces - i0 += alignedEnd; - if (i0 < rowLen) out[i0] = op(in[i0], args...); } template static __device__ __forceinline__ void vectorRows( typename Vec::io_t* out, const typename Vec::io_t* in, const IdxType len, - Lambda op, Args... args) { + Lambda op, Args... args) noexcept { Vec v; const IdxType d = BlockSize * gridDim.x; for (IdxType i = threadIdx.x + blockIdx.x * BlockSize; i < len; i += d) { @@ -78,8 +88,7 @@ struct Linewise { static __device__ __forceinline__ Vec loadVec(const Type* p, const IdxType blockOffset, - const IdxType rowLen) { - // 11.096 ms / 34 Regs + const IdxType rowLen) noexcept { __shared__ alignas(sizeof(Type) * VecElems) Type shm[VecElems * BlockSize]; IdxType j = blockOffset + threadIdx.x; #pragma unroll VecElems @@ -95,40 +104,55 @@ struct Linewise { reinterpret_cast(shm)[threadIdx.x]; return out; } - - // // 16.686 ms / 66 Regs - // typedef raft::Pow2 AlignWarp; - // int l = AlignWarp::mod(threadIdx.x); - // int d = l >> (AlignWarp::Log2 - AlignElems::Log2); - // Vec out; - // #pragma unroll VecElems - // for (int k = VecElems, j = blockOffset + (threadIdx.x - l) * VecElems + l; - // k > 0; k--, j += AlignWarp::Value) { - // while (j >= rowLen) j -= rowLen; - // const int kd = AlignElems::mod(k + l + d); - // out.val.data[kd] = __ldg(p + j); - // } - // l = AlignWarp::mod(l * VecElems); - // #pragma unroll VecElems - // for (int k = d; k < VecElems + d; k++) { - // const int kd = AlignElems::mod(k); - // out.val.data[kd] = __shfl_sync(0xffffffffu, out.val.data[kd], kd + l); - // } - // return out; } }; template __global__ void __launch_bounds__(BlockSize) - matrixLinewiseVecColsKernel(Type* out, const Type* in, const IdxType rowLen, - const IdxType nRows, Lambda op, Vecs... vecs) { - const IdxType j = threadIdx.y + blockIdx.x * blockDim.y; - if (j < nRows) { - const IdxType shift = rowLen * j; - Linewise::vectorCols( - out + shift, in + shift, rowLen, op, vecs[j]...); + matrixLinewiseVecColsMainKernel(Type* out, const Type* in, + const IdxType arrOffset, const IdxType rowLen, + const IdxType len, + const IdxType elemsPerThread, Lambda op, + Vecs... 
vecs) { + typedef Linewise L; + + IdxType t = L::AlignWarp::mod(threadIdx.x); + t = arrOffset + elemsPerThread * (blockIdx.x * BlockSize + threadIdx.x - t) + + t * L::VecElems; + + return L::vectorCols( + reinterpret_cast(out + t), + reinterpret_cast(in + t), + reinterpret_cast( + in + min(t + elemsPerThread * L::AlignWarp::Value, len)), + rowLen, t / rowLen, t % rowLen, op, vecs...); +} + +template +__global__ void __launch_bounds__(MaxOffset, 2) + matrixLinewiseVecColsTailKernel(Type* out, const Type* in, + const IdxType arrOffset, + const IdxType arrTail, const IdxType rowLen, + const IdxType len, Lambda op, Vecs... vecs) { + typedef Linewise L; + IdxType threadOffset, elemsPerWarp; + if (blockIdx.x == 0) { + threadOffset = threadIdx.x; + elemsPerWarp = threadOffset < arrOffset; + } else { + threadOffset = arrTail + threadIdx.x; + elemsPerWarp = threadOffset < len; } + const IdxType rowDiv = threadOffset / rowLen; + const IdxType rowMod = threadOffset % rowLen; + return L::vectorCols( + reinterpret_cast(out + threadOffset), + reinterpret_cast(in + threadOffset), + reinterpret_cast(in + threadOffset + + elemsPerWarp), + rowLen, rowDiv, rowMod, op, vecs...); } template L; const IdxType blockOffset = (arrOffset + BlockSize * L::VecElems * blockIdx.x) % rowLen; - L::vectorRows(reinterpret_cast(out), - reinterpret_cast(in), - L::AlignElems::div(len), op, - L::loadVec(vecs, blockOffset, rowLen)...); + return L::vectorRows(reinterpret_cast(out), + reinterpret_cast(in), + L::AlignElems::div(len), op, + L::loadVec(vecs, blockOffset, rowLen)...); } template L; if (blockIdx.x == 0) L::vectorRows(reinterpret_cast(out), @@ -174,27 +195,36 @@ template AlignBytes; constexpr std::size_t VecElems = VecBytes / sizeof(Type); - IdxType bsx = 32; - IdxType bsy = 8; + const IdxType totalLen = rowLen * nRows; + const Type* alignedStart = AlignBytes::roundUp(in); + const IdxType alignedOff = IdxType(alignedStart - in); + const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + totalLen) - in); + const IdxType alignedLen = alignedEnd - alignedOff; + // blockSize constexpr int BlockSize = 256; - while (bsy > nRows * 2) { - bsy >>= 1; - bsx <<= 1; - } - IdxType gsy = raft::ceildiv(nRows, bsy); - IdxType gsx = - min(raft::ceildiv(raft::getMultiProcessorCount() * 64, gsy), - raft::ceildiv(rowLen, bsx * VecElems)); - // NB: gridSize.x and gridSize.y are swapped, because gsx is bounded by a small number, - // but gsy can grow uncontrollably with the number of rows. - // (there is a tight limit on the max grid size in `y` direction). 
- dim3 bs(bsx, bsy, 1); - dim3 gs(gsy, gsx, 1); - matrixLinewiseVecColsKernel - <<>>(out, in, rowLen, nRows, op, vecs...); + constexpr dim3 bs(BlockSize, 1, 1); + // Minimum size of the grid to make device well occupied + const uint occupy = raft::getMultiProcessorCount() * 64; + // does not make sense to have more blocks than this + const uint maxBlocks = raft::ceildiv(uint(alignedLen), bs.x * VecElems); + const dim3 gs(min(maxBlocks, occupy), 1, 1); + + const IdxType elemsPerThread = + raft::ceildiv(alignedLen, gs.x * VecElems * BlockSize) * VecElems; + matrixLinewiseVecColsMainKernel<<>>( + out, in, alignedOff, rowLen, alignedLen, elemsPerThread, op, vecs...); CUDA_CHECK(cudaPeekAtLastError()); + if (alignedLen < totalLen) { + // should be not smaller than the warp size for better branching + constexpr std::size_t MaxOffset = std::max(std::size_t(32), VecBytes); + matrixLinewiseVecColsTailKernel + <<>>( + out, in, alignedOff, alignedEnd, rowLen, totalLen, op, vecs...); + CUDA_CHECK(cudaPeekAtLastError()); + } } template (uint(totalLen), bs.x * VecElems), - // increase the stride size if necessary - raft::ceildiv(occupy, stride) * stride), - 1, 1); + const dim3 gs(min( + // does not make sense to have more blocks than this + raft::ceildiv(uint(totalLen), bs.x * VecElems), + // increase the stride size if necessary + raft::ceildiv(occupy, stride) * stride), + 1, 1); const Type* alignedStart = AlignBytes::roundUp(in); const IdxType alignedOff = IdxType(alignedStart - in); diff --git a/cpp/test/linalg/matrix_linewise_op.cu b/cpp/test/linalg/matrix_linewise_op.cu index a4973b0637..c2f70d4525 100644 --- a/cpp/test/linalg/matrix_linewise_op.cu +++ b/cpp/test/linalg/matrix_linewise_op.cu @@ -180,36 +180,6 @@ struct LinewiseTest return testing::AssertionSuccess(); } - - testing::AssertionResult runTemp() { - rmm::device_uvector blob = genData(); - I n = 257; - I m = 420227; - auto [out, in, vec1, vec2] = assignSafePtrs(blob, n, m); - - stream.synchronize(); - cudaProfilerStart(); - PUSH_RANGE(stream, params.useVanillaMatrixVectorOp ? "method: original" - : "method: linewise"); - for (auto alongRows : ::testing::Bool()) { - PUSH_RANGE(stream, alongRows ? "alongRows" : "acrossRows"); - I lineLen = alongRows ? m : n; - I nLines = alongRows ? 
n : m; - { - PUSH_RANGE(stream, "one vec"); - runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1); - POP_RANGE(stream); - PUSH_RANGE(stream, "two vecs"); - runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1, vec2); - POP_RANGE(stream); - } - POP_RANGE(stream); - } - POP_RANGE(stream); - cudaProfilerStop(); - - return testing::AssertionSuccess(); - } }; #define TEST_IT(fun, TestClass, ElemType, IndexType) \ @@ -257,12 +227,10 @@ struct TenGigs { TEST_IT(run, Megabyte, float, int); TEST_IT(run, Megabyte, double, int); -// TEST_IT(run, Gigabyte, float, int); -// TEST_IT(run, Gigabyte, double, int); +TEST_IT(run, Gigabyte, float, int); +TEST_IT(run, Gigabyte, double, int); TEST_IT(run, TenGigs, float, uint64_t); TEST_IT(run, TenGigs, double, uint64_t); -TEST_IT(runTemp, Gigabyte, float, int); - } // end namespace linalg } // end namespace raft From bb3061549d0b866e42f157948daa04ffa3cc2c75 Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 25 Nov 2021 15:26:15 +0100 Subject: [PATCH 03/17] Cosmetics and tests --- .../raft/linalg/matrix_linewise_op.cuh | 66 +++++++++++-------- cpp/include/raft/linalg/matrix_vector_op.cuh | 41 ++++++++---- cpp/test/linalg/matrix_linewise_op.cu | 59 +++++++++++++---- 3 files changed, 114 insertions(+), 52 deletions(-) diff --git a/cpp/include/raft/linalg/matrix_linewise_op.cuh b/cpp/include/raft/linalg/matrix_linewise_op.cuh index d3789be504..af7dd41971 100644 --- a/cpp/include/raft/linalg/matrix_linewise_op.cuh +++ b/cpp/include/raft/linalg/matrix_linewise_op.cuh @@ -190,7 +190,7 @@ __global__ void __launch_bounds__(MaxOffset, 2) L::loadVec(vecs, arrTail % rowLen, rowLen)...); } -template void matrixLinewiseVecCols(Type* out, const Type* in, const IdxType rowLen, const IdxType nRows, Lambda op, cudaStream_t stream, @@ -202,11 +202,9 @@ void matrixLinewiseVecCols(Type* out, const Type* in, const IdxType rowLen, const IdxType alignedOff = IdxType(alignedStart - in); const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + totalLen) - in); const IdxType alignedLen = alignedEnd - alignedOff; - // blockSize - constexpr int BlockSize = 256; constexpr dim3 bs(BlockSize, 1, 1); - // Minimum size of the grid to make device well occupied - const uint occupy = raft::getMultiProcessorCount() * 64; + // Minimum size of the grid to make the device well occupied + const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize); // does not make sense to have more blocks than this const uint maxBlocks = raft::ceildiv(uint(alignedLen), bs.x * VecElems); const dim3 gs(min(maxBlocks, occupy), 1, 1); @@ -227,7 +225,7 @@ void matrixLinewiseVecCols(Type* out, const Type* in, const IdxType rowLen, } } -template void matrixLinewiseVecRows(Type* out, const Type* in, const IdxType rowLen, const IdxType nRows, Lambda op, cudaStream_t stream, @@ -236,14 +234,13 @@ void matrixLinewiseVecRows(Type* out, const Type* in, const IdxType rowLen, constexpr std::size_t VecElems = VecBytes / sizeof(Type); const IdxType totalLen = rowLen * nRows; // blockSize - constexpr int BlockSize = 256; constexpr dim3 bs(BlockSize, 1, 1); // if we have `stride` number of blocks, then each block processes always the same // indices along dimension rowLen; this means a block needs to index `vecs` only once! 
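  // (Worked example: with rowLen = 4096, BlockSize = 256 and VecElems = 4, each
  //  block advances by BlockSize * VecElems = 1024 elements, so the offset
  //  pattern repeats every 4 blocks; the `stride` computed below is a multiple
  //  of that period.)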
const uint stride = (rowLen / raft::gcd(bs.x * uint(VecElems), uint(rowLen))) * VecElems; - // Minimum size of the grid to make device well occupied - const uint occupy = raft::getMultiProcessorCount() * 64; + // Minimum size of the grid to make the device well occupied + const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize); const dim3 gs(min( // does not make sense to have more blocks than this raft::ceildiv(uint(totalLen), bs.x * VecElems), @@ -270,7 +267,16 @@ void matrixLinewiseVecRows(Type* out, const Type* in, const IdxType rowLen, } } -template +/** + * Select one of the implementations: + * a. vectors applied along/across lines + * b. recursively try different VecBytes, such that alignments of `in` and `out` + * are the same. + * + * @tparam VecBytes - size of the load/store ops in bytes. + * @tparam BlockSize - is fixed and should not affect the performance. + */ +template struct MatrixLinewiseOp { template static void run(Type* out, const Type* in, const IdxType lineLen, @@ -278,15 +284,19 @@ struct MatrixLinewiseOp { cudaStream_t stream, Vecs... vecs) { if constexpr (VecBytes > sizeof(Type)) { if (!raft::Pow2::areSameAlignOffsets(in, out)) - return MatrixLinewiseOp> 1), sizeof(Type))>::run( - out, in, lineLen, nLines, alongLines, op, stream, vecs...); + return MatrixLinewiseOp> 1), sizeof(Type)), + BlockSize>::run(out, in, lineLen, nLines, + alongLines, op, stream, + vecs...); } if (alongLines) - return matrixLinewiseVecRows( - out, in, lineLen, nLines, op, stream, vecs...); + return matrixLinewiseVecRows(out, in, lineLen, nLines, op, + stream, vecs...); else - return matrixLinewiseVecCols( - out, in, lineLen, nLines, op, stream, vecs...); + return matrixLinewiseVecCols(out, in, lineLen, nLines, op, + stream, vecs...); } }; @@ -297,29 +307,29 @@ struct MatrixLinewiseOp { * row-vectors or column-vectors. * The term `line` here signifies that the lines can be either columns or rows, * depending on the matrix layout. - * What matters is if vectors are applied along lines (indices of vectors correspond - * indices within lines), or across lines (indices of vectors correspond to line indices). + * What matters is if the vectors are applied along lines (indices of vectors correspond to + * indices within lines), or across lines (indices of vectors correspond to line numbers). * - * @param out result of the operation; can be same as `in`; should be aligned the same as `in` - * to allow faster vectorized memory transfers. - * @param in input matrix consisting of `nLines` lines, each `lineLen`-long. - * @param lineLen length of matrix line in elements (`=nCols` in row-major or `=nRows` in col-major) - * @param nLines number of matrix lines (`=nRows` in row-major or `=nCols` in col-major) - * @param alongLines whether vectors are indices along or across lines. - * @param op the operation applied on each line: + * @param [out] out result of the operation; can be same as `in`; should be aligned the same + * as `in` to allow faster vectorized memory transfers. + * @param [in] in input matrix consisting of `nLines` lines, each `lineLen`-long. + * @param [in] lineLen length of matrix line in elements (`=nCols` in row-major or `=nRows` in col-major) + * @param [in] nLines number of matrix lines (`=nRows` in row-major or `=nCols` in col-major) + * @param [in] alongLines whether vectors are indices along or across lines. + * @param [in] op the operation applied on each line: * for i in [0..lineLen) and j in [0..nLines): * out[i, j] = op(in[i, j], vec1[i], vec2[i], ... 
veck[i]) if alongLines = true * out[i, j] = op(in[i, j], vec1[j], vec2[j], ... veck[j]) if alongLines = false * where matrix indexing is row-major ([i, j] = [i + lineLen * j]). - * @param stream a cuda stream for the kernels - * @param vecs zero or more vectors to be passed as arguments, + * @param [in] stream a cuda stream for the kernels + * @param [in] vecs zero or more vectors to be passed as arguments, * size of each vector is `alongLines ? lineLen : nLines`. */ template void matrixLinewiseOp(Type* out, const Type* in, const IdxType lineLen, const IdxType nLines, const bool alongLines, Lambda op, cudaStream_t stream, Vecs... vecs) { - linewise_impl::MatrixLinewiseOp<16>::run( + linewise_impl::MatrixLinewiseOp<16, 256>::run( out, in, lineLen, nLines, alongLines, op, stream, vecs...); } diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh index 81c1919b2e..3082d92d9f 100644 --- a/cpp/include/raft/linalg/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/matrix_vector_op.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -127,22 +128,30 @@ void matrixVectorOp(Type* out, Lambda op, cudaStream_t stream) { - IdxType stride = rowMajor ? D : N; + IdxType stride = rowMajor ? D : N; + // IdxType nLines = rowMajor ? N : D; + // return matrixLinewiseOp(out, matrix, stride, nLines, + // rowMajor == bcastAlongRows, op, stream, vec); + size_t stride_bytes = stride * sizeof(Type); - if (AlignedAccess<16>::test(matrix, stride_bytes)) { + if (AlignedAccess<16>::test(matrix, stride_bytes) && AlignedAccess<16>::test(out, stride_bytes)) { matrixVectorOpImpl( out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<8>::test(matrix, stride_bytes)) { + } else if (AlignedAccess<8>::test(matrix, stride_bytes) && + AlignedAccess<8>::test(out, stride_bytes)) { matrixVectorOpImpl( out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<4>::test(matrix, stride_bytes)) { + } else if (AlignedAccess<4>::test(matrix, stride_bytes) && + AlignedAccess<4>::test(out, stride_bytes)) { matrixVectorOpImpl( out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<2>::test(matrix, stride_bytes)) { + } else if (AlignedAccess<2>::test(matrix, stride_bytes) && + AlignedAccess<2>::test(out, stride_bytes)) { matrixVectorOpImpl( out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<1>::test(matrix, stride_bytes)) { + } else if (AlignedAccess<1>::test(matrix, stride_bytes) && + AlignedAccess<1>::test(out, stride_bytes)) { matrixVectorOpImpl( out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); } else { @@ -250,22 +259,30 @@ void matrixVectorOp(Type* out, Lambda op, cudaStream_t stream) { - IdxType stride = rowMajor ? D : N; + IdxType stride = rowMajor ? D : N; + // IdxType nLines = rowMajor ? 
N : D; + // return matrixLinewiseOp(out, matrix, stride, nLines, + // rowMajor == bcastAlongRows, op, stream, vec1, vec2); + size_t stride_bytes = stride * sizeof(Type); - if (AlignedAccess<16>::test(matrix, stride_bytes)) { + if (AlignedAccess<16>::test(matrix, stride_bytes) && AlignedAccess<16>::test(out, stride_bytes)) { matrixVectorOpImpl( out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<8>::test(matrix, stride_bytes)) { + } else if (AlignedAccess<8>::test(matrix, stride_bytes) && + AlignedAccess<8>::test(out, stride_bytes)) { matrixVectorOpImpl( out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<4>::test(matrix, stride_bytes)) { + } else if (AlignedAccess<4>::test(matrix, stride_bytes) && + AlignedAccess<4>::test(out, stride_bytes)) { matrixVectorOpImpl( out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<2>::test(matrix, stride_bytes)) { + } else if (AlignedAccess<2>::test(matrix, stride_bytes) && + AlignedAccess<2>::test(out, stride_bytes)) { matrixVectorOpImpl( out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<1>::test(matrix, stride_bytes)) { + } else if (AlignedAccess<1>::test(matrix, stride_bytes) && + AlignedAccess<1>::test(out, stride_bytes)) { matrixVectorOpImpl( out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); } else { diff --git a/cpp/test/linalg/matrix_linewise_op.cu b/cpp/test/linalg/matrix_linewise_op.cu index c2f70d4525..4a77fe57aa 100644 --- a/cpp/test/linalg/matrix_linewise_op.cu +++ b/cpp/test/linalg/matrix_linewise_op.cu @@ -23,6 +23,7 @@ #include #include #include "../test_utils.h" +#include "matrix_vector_op.cuh" namespace raft { namespace linalg { @@ -54,6 +55,9 @@ struct LinewiseTestParams { std::size_t workSizeBytes; uint64_t seed; bool useVanillaMatrixVectorOp; + bool checkCorrectness; + int inAlignOffset; + int outAlignOffset; }; template @@ -124,7 +128,7 @@ struct LinewiseTest I m = I(floor(y)); if (n > 0 && m > 0) out.push_back(std::make_tuple(n, m)); }; - std::vector sizes = {15, 16, 17, 256, 257, 263}; + std::vector sizes = {15, 16, 17, 256, 257, 263, 1024}; addIfMakesSense(squareN, squareN); for (I k : sizes) { addIfMakesSense(solveForN(k), k); @@ -137,9 +141,10 @@ struct LinewiseTest std::tuple assignSafePtrs( rmm::device_uvector& blob, I n, I m) { typedef raft::Pow2 Align; - T* out = Align::roundUp(blob.data()); + T* out = Align::roundUp(blob.data()) + params.outAlignOffset; const T* in = - const_cast(Align::roundUp(out + n * m + PTR_PADDING)); + const_cast(Align::roundUp(out + n * m + PTR_PADDING)) + + params.inAlignOffset; const T* vec1 = Align::roundUp(in + n * m + PTR_PADDING); const T* vec2 = Align::roundUp(vec1 + m + PTR_PADDING); ASSERT(blob.data() + blob.size() >= vec2 + PTR_PADDING, @@ -149,6 +154,8 @@ struct LinewiseTest testing::AssertionResult run() { rmm::device_uvector blob = genData(); + rmm::device_uvector blob_val( + params.checkCorrectness ? 
blob.size() / 2 : 0, stream); auto dims = suggestDimensions(2); @@ -167,9 +174,25 @@ struct LinewiseTest PUSH_RANGE(stream, "one vec"); runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1); POP_RANGE(stream); + if (params.checkCorrectness) { + naiveMatVec(blob_val.data(), in, vec1, lineLen, nLines, true, + alongRows, T(1)); + EXPECT_NO_FATAL_FAILURE( + devArrMatch(blob_val.data(), out, n * m, + CompareApprox(params.tolerance))) + << "with one vec"; + } PUSH_RANGE(stream, "two vecs"); runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1, vec2); POP_RANGE(stream); + if (params.checkCorrectness) { + naiveMatVec(blob_val.data(), in, vec1, vec2, lineLen, nLines, true, + alongRows, T(1)); + EXPECT_NO_FATAL_FAILURE( + devArrMatch(blob_val.data(), out, n * m, + CompareApprox(params.tolerance))) + << "with two vecs"; + } } POP_RANGE(stream); } @@ -189,39 +212,51 @@ struct LinewiseTest INSTANTIATE_TEST_SUITE_P(LinewiseOp, TestClass##_##ElemType##_##IndexType, \ TestClass##Params) -auto MegabyteParams = ::testing::Bool(); +auto MegabyteParams = + ::testing::Combine(::testing::Bool(), ::testing::Values(0, 1, 2, 4), + ::testing::Values(0, 1, 2, 3)); struct Megabyte { - typedef bool Params; + typedef std::tuple Params; static LinewiseTestParams read(Params ps) { return {/** .tolerance */ 0.00001, /** .workSizeBytes */ 1024 * 1024, /** .seed */ 42ULL, - /** .useVanillaMatrixVectorOp */ ps}; + /** .useVanillaMatrixVectorOp */ std::get<0>(ps), + /** .checkCorrectness */ true, + /** .inAlignOffset */ std::get<1>(ps), + /** .outAlignOffset */ std::get<2>(ps)}; } }; -auto GigabyteParams = ::testing::Bool(); +auto GigabyteParams = ::testing::Combine( + ::testing::Bool(), ::testing::Values(0, 1, 2), ::testing::Values(0, 1, 2)); struct Gigabyte { - typedef bool Params; + typedef std::tuple Params; static LinewiseTestParams read(Params ps) { return {/** .tolerance */ 0.00001, /** .workSizeBytes */ 1024 * 1024 * 1024, /** .seed */ 42ULL, - /** .useVanillaMatrixVectorOp */ ps}; + /** .useVanillaMatrixVectorOp */ std::get<0>(ps), + /** .checkCorrectness */ false, + /** .inAlignOffset */ std::get<1>(ps), + /** .outAlignOffset */ std::get<2>(ps)}; } }; -auto TenGigsParams = ::testing::Bool(); +auto TenGigsParams = GigabyteParams; struct TenGigs { - typedef bool Params; + typedef std::tuple Params; static LinewiseTestParams read(Params ps) { return {/** .tolerance */ 0.00001, /** .workSizeBytes */ 10ULL * 1024ULL * 1024ULL * 1024ULL, /** .seed */ 42ULL, - /** .useVanillaMatrixVectorOp */ ps}; + /** .useVanillaMatrixVectorOp */ std::get<0>(ps), + /** .checkCorrectness */ false, + /** .inAlignOffset */ std::get<1>(ps), + /** .outAlignOffset */ std::get<2>(ps)}; } }; From 5c6364265084c49852d1f309eec2410f4d2b62e2 Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 25 Nov 2021 15:30:57 +0100 Subject: [PATCH 04/17] Replace matrixVectorOp implementation with matrixLinewiseOp --- cpp/include/raft/linalg/matrix_vector_op.cuh | 199 +------------------ 1 file changed, 5 insertions(+), 194 deletions(-) diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh index 3082d92d9f..9e38ebd167 100644 --- a/cpp/include/raft/linalg/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/matrix_vector_op.cuh @@ -16,84 +16,11 @@ #pragma once -#include #include -#include -#include namespace raft { namespace linalg { -namespace { -template -struct AlignedAccess { - template - static inline bool test(const T* matrix, size_t strideBytes) - { - return Pow2::isAligned(matrix) && 
Pow2::isAligned(strideBytes) && - Pow2::isAligned(VecBytes); - } -}; -}; // namespace - -template -__global__ void matrixVectorOpKernel(Type* out, - const Type* matrix, - const Type* vector, - IdxType D, - IdxType N, - bool rowMajor, - bool bcastAlongRows, - Lambda op) -{ - typedef TxN_t VecType; - IdxType len = N * D; - IdxType idx = threadIdx.x; - idx += (IdxType)blockIdx.x * (IdxType)blockDim.x; - idx *= VecType::Ratio; - if (idx >= len) return; - IdxType vIdx; - VecType mat, vec; - ///@todo: yikes! use fast-int-div here. - ///@todo: shared mem for vector could help with perf - if (rowMajor && bcastAlongRows) { - vIdx = idx % D; - vec.load(vector, vIdx); - } else if (!rowMajor && !bcastAlongRows) { - vIdx = idx % N; - vec.load(vector, vIdx); - } else if (rowMajor && !bcastAlongRows) { - vIdx = idx / D; - vec.fill(vector[vIdx]); - } else { - vIdx = idx / N; - vec.fill(vector[vIdx]); - } - mat.load(matrix, idx); -#pragma unroll - for (int i = 0; i < VecType::Ratio; ++i) - mat.val.data[i] = op(mat.val.data[i], vec.val.data[i]); - mat.store(out, idx); -} - -template -void matrixVectorOpImpl(Type* out, - const Type* matrix, - const Type* vec, - IdxType D, - IdxType N, - bool rowMajor, - bool bcastAlongRows, - Lambda op, - cudaStream_t stream) -{ - IdxType len = N * D; - IdxType nblks = raft::ceildiv(veclen_ ? len / veclen_ : veclen_, (IdxType)TPB); - matrixVectorOpKernel - <<>>(out, matrix, vec, D, N, rowMajor, bcastAlongRows, op); - CUDA_CHECK(cudaPeekAtLastError()); -} - /** * @brief Operations for all the columns or rows with a given vector. * Caution : Threads process multiple elements to speed up processing. These @@ -129,98 +56,8 @@ void matrixVectorOp(Type* out, cudaStream_t stream) { IdxType stride = rowMajor ? D : N; - // IdxType nLines = rowMajor ? N : D; - // return matrixLinewiseOp(out, matrix, stride, nLines, - // rowMajor == bcastAlongRows, op, stream, vec); - - size_t stride_bytes = stride * sizeof(Type); - - if (AlignedAccess<16>::test(matrix, stride_bytes) && AlignedAccess<16>::test(out, stride_bytes)) { - matrixVectorOpImpl( - out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<8>::test(matrix, stride_bytes) && - AlignedAccess<8>::test(out, stride_bytes)) { - matrixVectorOpImpl( - out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<4>::test(matrix, stride_bytes) && - AlignedAccess<4>::test(out, stride_bytes)) { - matrixVectorOpImpl( - out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<2>::test(matrix, stride_bytes) && - AlignedAccess<2>::test(out, stride_bytes)) { - matrixVectorOpImpl( - out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<1>::test(matrix, stride_bytes) && - AlignedAccess<1>::test(out, stride_bytes)) { - matrixVectorOpImpl( - out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); - } else { - matrixVectorOpImpl( - out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream); - } -} - -///@todo: come up with a cleaner interface to support these cases in future! - -template -__global__ void matrixVectorOpKernel(Type* out, - const Type* matrix, - const Type* vector1, - const Type* vector2, - IdxType D, - IdxType N, - bool rowMajor, - bool bcastAlongRows, - Lambda op) -{ - typedef TxN_t VecType; - IdxType len = N * D; - IdxType idx = (threadIdx.x + (blockIdx.x * blockDim.x)) * VecType::Ratio; - if (idx >= len) return; - IdxType vIdx; - VecType mat, vec1, vec2; - ///@todo: yikes! use fast-int-div here. 
- ///@todo: shared mem for vector could help with perf - if (rowMajor && bcastAlongRows) { - vIdx = idx % D; - vec1.load(vector1, vIdx); - vec2.load(vector2, vIdx); - } else if (!rowMajor && !bcastAlongRows) { - vIdx = idx % N; - vec1.load(vector1, vIdx); - vec2.load(vector2, vIdx); - } else if (rowMajor && !bcastAlongRows) { - vIdx = idx / D; - vec1.fill(vector1[vIdx]); - vec2.fill(vector2[vIdx]); - } else { - vIdx = idx / N; - vec1.fill(vector1[vIdx]); - vec2.fill(vector2[vIdx]); - } - mat.load(matrix, idx); -#pragma unroll - for (int i = 0; i < VecType::Ratio; ++i) - mat.val.data[i] = op(mat.val.data[i], vec1.val.data[i], vec2.val.data[i]); - mat.store(out, idx); -} - -template -void matrixVectorOpImpl(Type* out, - const Type* matrix, - const Type* vec1, - const Type* vec2, - IdxType D, - IdxType N, - bool rowMajor, - bool bcastAlongRows, - Lambda op, - cudaStream_t stream) -{ - IdxType nblks = raft::ceildiv(N * D, (IdxType)TPB); - matrixVectorOpKernel - <<>>(out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op); - CUDA_CHECK(cudaPeekAtLastError()); + IdxType nLines = rowMajor ? N : D; + return matrixLinewiseOp(out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec); } /** @@ -260,35 +97,9 @@ void matrixVectorOp(Type* out, cudaStream_t stream) { IdxType stride = rowMajor ? D : N; - // IdxType nLines = rowMajor ? N : D; - // return matrixLinewiseOp(out, matrix, stride, nLines, - // rowMajor == bcastAlongRows, op, stream, vec1, vec2); - - size_t stride_bytes = stride * sizeof(Type); - - if (AlignedAccess<16>::test(matrix, stride_bytes) && AlignedAccess<16>::test(out, stride_bytes)) { - matrixVectorOpImpl( - out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<8>::test(matrix, stride_bytes) && - AlignedAccess<8>::test(out, stride_bytes)) { - matrixVectorOpImpl( - out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<4>::test(matrix, stride_bytes) && - AlignedAccess<4>::test(out, stride_bytes)) { - matrixVectorOpImpl( - out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<2>::test(matrix, stride_bytes) && - AlignedAccess<2>::test(out, stride_bytes)) { - matrixVectorOpImpl( - out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); - } else if (AlignedAccess<1>::test(matrix, stride_bytes) && - AlignedAccess<1>::test(out, stride_bytes)) { - matrixVectorOpImpl( - out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); - } else { - matrixVectorOpImpl( - out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); - } + IdxType nLines = rowMajor ? 
N : D; + return matrixLinewiseOp( + out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec1, vec2); } }; // end namespace linalg From 6a261876deae6566dff5afcd379d88d5f513995a Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 25 Nov 2021 16:06:47 +0100 Subject: [PATCH 05/17] Update to the new styles --- .../raft/linalg/matrix_linewise_op.cuh | 266 +++++++++++------- cpp/test/linalg/matrix_linewise_op.cu | 118 ++++---- 2 files changed, 226 insertions(+), 158 deletions(-) diff --git a/cpp/include/raft/linalg/matrix_linewise_op.cuh b/cpp/include/raft/linalg/matrix_linewise_op.cuh index af7dd41971..4f1061bcf3 100644 --- a/cpp/include/raft/linalg/matrix_linewise_op.cuh +++ b/cpp/include/raft/linalg/matrix_linewise_op.cuh @@ -35,16 +35,20 @@ struct Linewise { typedef raft::Pow2 AlignWarp; template - static __device__ __forceinline__ void vectorCols( - typename Vec::io_t* out, const typename Vec::io_t* in, - const typename Vec::io_t* in_end, const IdxType rowLen, IdxType rowDiv, - IdxType rowMod, Lambda op, Vecs... vecs) noexcept { + static __device__ __forceinline__ void vectorCols(typename Vec::io_t* out, + const typename Vec::io_t* in, + const typename Vec::io_t* in_end, + const IdxType rowLen, + IdxType rowDiv, + IdxType rowMod, + Lambda op, + Vecs... vecs) noexcept + { constexpr IdxType warpPad = (AlignWarp::Value - 1) * VecElems; Type args[sizeof...(Vecs)]; Vec v, w; bool update = true; - for (; in < in_end; - in += AlignWarp::Value, out += AlignWarp::Value, rowMod += warpPad) { + for (; in < in_end; in += AlignWarp::Value, out += AlignWarp::Value, rowMod += warpPad) { v.val.internal = __ldcv(in); while (rowMod >= rowLen) { rowMod -= rowLen; @@ -64,7 +68,7 @@ struct Linewise { int l = 0; ((args[l] = vecs[rowDiv], l++), ...); } - int l = 0; + int l = 0; w.val.data[k] = op(v.val.data[k], (std::ignore = vecs, args[l++])...); } *out = w.val.internal; @@ -72,9 +76,12 @@ struct Linewise { } template - static __device__ __forceinline__ void vectorRows( - typename Vec::io_t* out, const typename Vec::io_t* in, const IdxType len, - Lambda op, Args... args) noexcept { + static __device__ __forceinline__ void vectorRows(typename Vec::io_t* out, + const typename Vec::io_t* in, + const IdxType len, + Lambda op, + Args... args) noexcept + { Vec v; const IdxType d = BlockSize * gridDim.x; for (IdxType i = threadIdx.x + blockIdx.x * BlockSize; i < len; i += d) { @@ -88,54 +95,68 @@ struct Linewise { static __device__ __forceinline__ Vec loadVec(const Type* p, const IdxType blockOffset, - const IdxType rowLen) noexcept { + const IdxType rowLen) noexcept + { __shared__ alignas(sizeof(Type) * VecElems) Type shm[VecElems * BlockSize]; IdxType j = blockOffset + threadIdx.x; #pragma unroll VecElems - for (int k = threadIdx.x; k < VecElems * BlockSize; - k += BlockSize, j += BlockSize) { - while (j >= rowLen) j -= rowLen; + for (int k = threadIdx.x; k < VecElems * BlockSize; k += BlockSize, j += BlockSize) { + while (j >= rowLen) + j -= rowLen; shm[k] = p[j]; } __syncthreads(); { Vec out; - out.val.internal = - reinterpret_cast(shm)[threadIdx.x]; + out.val.internal = reinterpret_cast(shm)[threadIdx.x]; return out; } } }; -template +template __global__ void __launch_bounds__(BlockSize) - matrixLinewiseVecColsMainKernel(Type* out, const Type* in, - const IdxType arrOffset, const IdxType rowLen, + matrixLinewiseVecColsMainKernel(Type* out, + const Type* in, + const IdxType arrOffset, + const IdxType rowLen, const IdxType len, - const IdxType elemsPerThread, Lambda op, - Vecs... 
vecs) { + const IdxType elemsPerThread, + Lambda op, + Vecs... vecs) +{ typedef Linewise L; IdxType t = L::AlignWarp::mod(threadIdx.x); - t = arrOffset + elemsPerThread * (blockIdx.x * BlockSize + threadIdx.x - t) + - t * L::VecElems; + t = arrOffset + elemsPerThread * (blockIdx.x * BlockSize + threadIdx.x - t) + t * L::VecElems; - return L::vectorCols( - reinterpret_cast(out + t), - reinterpret_cast(in + t), - reinterpret_cast( - in + min(t + elemsPerThread * L::AlignWarp::Value, len)), - rowLen, t / rowLen, t % rowLen, op, vecs...); + return L::vectorCols(reinterpret_cast(out + t), + reinterpret_cast(in + t), + reinterpret_cast( + in + min(t + elemsPerThread * L::AlignWarp::Value, len)), + rowLen, + t / rowLen, + t % rowLen, + op, + vecs...); } -template +template __global__ void __launch_bounds__(MaxOffset, 2) - matrixLinewiseVecColsTailKernel(Type* out, const Type* in, + matrixLinewiseVecColsTailKernel(Type* out, + const Type* in, const IdxType arrOffset, - const IdxType arrTail, const IdxType rowLen, - const IdxType len, Lambda op, Vecs... vecs) { + const IdxType arrTail, + const IdxType rowLen, + const IdxType len, + Lambda op, + Vecs... vecs) +{ typedef Linewise L; IdxType threadOffset, elemsPerWarp; if (blockIdx.x == 0) { @@ -150,58 +171,85 @@ __global__ void __launch_bounds__(MaxOffset, 2) return L::vectorCols( reinterpret_cast(out + threadOffset), reinterpret_cast(in + threadOffset), - reinterpret_cast(in + threadOffset + - elemsPerWarp), - rowLen, rowDiv, rowMod, op, vecs...); + reinterpret_cast(in + threadOffset + elemsPerWarp), + rowLen, + rowDiv, + rowMod, + op, + vecs...); } -template +template __global__ void __launch_bounds__(BlockSize) - matrixLinewiseVecRowsMainKernel(Type* out, const Type* in, - const IdxType arrOffset, const IdxType rowLen, - const IdxType len, Lambda op, Vecs... vecs) { + matrixLinewiseVecRowsMainKernel(Type* out, + const Type* in, + const IdxType arrOffset, + const IdxType rowLen, + const IdxType len, + Lambda op, + Vecs... vecs) +{ typedef Linewise L; - const IdxType blockOffset = - (arrOffset + BlockSize * L::VecElems * blockIdx.x) % rowLen; + const IdxType blockOffset = (arrOffset + BlockSize * L::VecElems * blockIdx.x) % rowLen; return L::vectorRows(reinterpret_cast(out), reinterpret_cast(in), - L::AlignElems::div(len), op, + L::AlignElems::div(len), + op, L::loadVec(vecs, blockOffset, rowLen)...); } -template +template __global__ void __launch_bounds__(MaxOffset, 2) - matrixLinewiseVecRowsTailKernel(Type* out, const Type* in, + matrixLinewiseVecRowsTailKernel(Type* out, + const Type* in, const IdxType arrOffset, - const IdxType arrTail, const IdxType rowLen, - const IdxType len, Lambda op, Vecs... vecs) { + const IdxType arrTail, + const IdxType rowLen, + const IdxType len, + Lambda op, + Vecs... 
vecs) +{ typedef Linewise L; if (blockIdx.x == 0) L::vectorRows(reinterpret_cast(out), - reinterpret_cast(in), arrOffset, - op, L::loadVec(vecs, 0, rowLen)...); + reinterpret_cast(in), + arrOffset, + op, + L::loadVec(vecs, 0, rowLen)...); else - L::vectorRows( - reinterpret_cast(out + arrTail - MaxOffset), - reinterpret_cast(in + arrTail - MaxOffset), - len - arrTail + MaxOffset, op, - L::loadVec(vecs, arrTail % rowLen, rowLen)...); + L::vectorRows(reinterpret_cast(out + arrTail - MaxOffset), + reinterpret_cast(in + arrTail - MaxOffset), + len - arrTail + MaxOffset, + op, + L::loadVec(vecs, arrTail % rowLen, rowLen)...); } -template -void matrixLinewiseVecCols(Type* out, const Type* in, const IdxType rowLen, - const IdxType nRows, Lambda op, cudaStream_t stream, - Vecs... vecs) { +template +void matrixLinewiseVecCols(Type* out, + const Type* in, + const IdxType rowLen, + const IdxType nRows, + Lambda op, + cudaStream_t stream, + Vecs... vecs) +{ typedef raft::Pow2 AlignBytes; constexpr std::size_t VecElems = VecBytes / sizeof(Type); - const IdxType totalLen = rowLen * nRows; - const Type* alignedStart = AlignBytes::roundUp(in); - const IdxType alignedOff = IdxType(alignedStart - in); - const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + totalLen) - in); - const IdxType alignedLen = alignedEnd - alignedOff; + const IdxType totalLen = rowLen * nRows; + const Type* alignedStart = AlignBytes::roundUp(in); + const IdxType alignedOff = IdxType(alignedStart - in); + const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + totalLen) - in); + const IdxType alignedLen = alignedEnd - alignedOff; constexpr dim3 bs(BlockSize, 1, 1); // Minimum size of the grid to make the device well occupied const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize); @@ -211,9 +259,8 @@ void matrixLinewiseVecCols(Type* out, const Type* in, const IdxType rowLen, const IdxType elemsPerThread = raft::ceildiv(alignedLen, gs.x * VecElems * BlockSize) * VecElems; - matrixLinewiseVecColsMainKernel<<>>( - out, in, alignedOff, rowLen, alignedLen, elemsPerThread, op, vecs...); + matrixLinewiseVecColsMainKernel + <<>>(out, in, alignedOff, rowLen, alignedLen, elemsPerThread, op, vecs...); CUDA_CHECK(cudaPeekAtLastError()); if (alignedLen < totalLen) { // should be not smaller than the warp size for better branching @@ -225,20 +272,28 @@ void matrixLinewiseVecCols(Type* out, const Type* in, const IdxType rowLen, } } -template -void matrixLinewiseVecRows(Type* out, const Type* in, const IdxType rowLen, - const IdxType nRows, Lambda op, cudaStream_t stream, - Vecs... vecs) { +template +void matrixLinewiseVecRows(Type* out, + const Type* in, + const IdxType rowLen, + const IdxType nRows, + Lambda op, + cudaStream_t stream, + Vecs... vecs) +{ typedef raft::Pow2 AlignBytes; constexpr std::size_t VecElems = VecBytes / sizeof(Type); - const IdxType totalLen = rowLen * nRows; + const IdxType totalLen = rowLen * nRows; // blockSize constexpr dim3 bs(BlockSize, 1, 1); // if we have `stride` number of blocks, then each block processes always the same // indices along dimension rowLen; this means a block needs to index `vecs` only once! 
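    // [Editor's sketch, not part of the original commit; BlockSize = 256 and
    //  VecElems = 4 are hypothetical values used only for this arithmetic.]
    // One block consumes bs.x * VecElems = 1024 scalars per iteration; for
    // rowLen = 1000, gcd(1024, 1000) = 8, so the pattern of in-row offsets
    // repeats every 1000 / 8 = 125 blocks: 125 * 1024 = lcm(1024, 1000)
    // = 128000 elements, i.e. exactly 128 full rows. The `stride` computed
    // below is a multiple of that 125-block period, so sizing the grid as a
    // multiple of `stride` lets each block resolve its `vecs` indices once
    // and reuse them on every iteration.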
- const uint stride = - (rowLen / raft::gcd(bs.x * uint(VecElems), uint(rowLen))) * VecElems; + const uint stride = (rowLen / raft::gcd(bs.x * uint(VecElems), uint(rowLen))) * VecElems; // Minimum size of the grid to make the device well occupied const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize); const dim3 gs(min( @@ -246,16 +301,16 @@ void matrixLinewiseVecRows(Type* out, const Type* in, const IdxType rowLen, raft::ceildiv(uint(totalLen), bs.x * VecElems), // increase the stride size if necessary raft::ceildiv(occupy, stride) * stride), - 1, 1); + 1, + 1); const Type* alignedStart = AlignBytes::roundUp(in); const IdxType alignedOff = IdxType(alignedStart - in); const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + totalLen) - in); const IdxType alignedLen = alignedEnd - alignedOff; - matrixLinewiseVecRowsMainKernel - <<>>(out + alignedOff, alignedStart, alignedOff, rowLen, - alignedLen, op, vecs...); + matrixLinewiseVecRowsMainKernel + <<>>( + out + alignedOff, alignedStart, alignedOff, rowLen, alignedLen, op, vecs...); CUDA_CHECK(cudaPeekAtLastError()); if (alignedLen < totalLen) { // should be not smaller than the warp size for better branching @@ -279,24 +334,26 @@ void matrixLinewiseVecRows(Type* out, const Type* in, const IdxType rowLen, template struct MatrixLinewiseOp { template - static void run(Type* out, const Type* in, const IdxType lineLen, - const IdxType nLines, const bool alongLines, Lambda op, - cudaStream_t stream, Vecs... vecs) { + static void run(Type* out, + const Type* in, + const IdxType lineLen, + const IdxType nLines, + const bool alongLines, + Lambda op, + cudaStream_t stream, + Vecs... vecs) + { if constexpr (VecBytes > sizeof(Type)) { if (!raft::Pow2::areSameAlignOffsets(in, out)) - return MatrixLinewiseOp> 1), sizeof(Type)), - BlockSize>::run(out, in, lineLen, nLines, - alongLines, op, stream, - vecs...); + return MatrixLinewiseOp> 1), sizeof(Type)), BlockSize>::run( + out, in, lineLen, nLines, alongLines, op, stream, vecs...); } if (alongLines) - return matrixLinewiseVecRows(out, in, lineLen, nLines, op, - stream, vecs...); + return matrixLinewiseVecRows( + out, in, lineLen, nLines, op, stream, vecs...); else - return matrixLinewiseVecCols(out, in, lineLen, nLines, op, - stream, vecs...); + return matrixLinewiseVecCols( + out, in, lineLen, nLines, op, stream, vecs...); } }; @@ -313,7 +370,8 @@ struct MatrixLinewiseOp { * @param [out] out result of the operation; can be same as `in`; should be aligned the same * as `in` to allow faster vectorized memory transfers. * @param [in] in input matrix consisting of `nLines` lines, each `lineLen`-long. - * @param [in] lineLen length of matrix line in elements (`=nCols` in row-major or `=nRows` in col-major) + * @param [in] lineLen length of matrix line in elements (`=nCols` in row-major or `=nRows` in + * col-major) * @param [in] nLines number of matrix lines (`=nRows` in row-major or `=nCols` in col-major) * @param [in] alongLines whether vectors are indices along or across lines. * @param [in] op the operation applied on each line: @@ -326,9 +384,15 @@ struct MatrixLinewiseOp { * size of each vector is `alongLines ? lineLen : nLines`. */ template -void matrixLinewiseOp(Type* out, const Type* in, const IdxType lineLen, - const IdxType nLines, const bool alongLines, Lambda op, - cudaStream_t stream, Vecs... vecs) { +void matrixLinewiseOp(Type* out, + const Type* in, + const IdxType lineLen, + const IdxType nLines, + const bool alongLines, + Lambda op, + cudaStream_t stream, + Vecs... 
vecs) +{ linewise_impl::MatrixLinewiseOp<16, 256>::run( out, in, lineLen, nLines, alongLines, op, stream, vecs...); } diff --git a/cpp/test/linalg/matrix_linewise_op.cu b/cpp/test/linalg/matrix_linewise_op.cu index 4a77fe57aa..430a9981e0 100644 --- a/cpp/test/linalg/matrix_linewise_op.cu +++ b/cpp/test/linalg/matrix_linewise_op.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,7 +31,8 @@ namespace linalg { constexpr std::size_t PTR_PADDING = 128; template -void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name, Args... args) { +void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name, Args... args) +{ int length = std::snprintf(nullptr, 0, name, args...); assert(length >= 0); auto buf = std::make_unique(length + 1); @@ -40,12 +41,14 @@ void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name, Args... args) { nvtxRangePushA(buf.get()); } template <> -void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name) { +void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name) +{ stream.synchronize(); nvtxRangePushA(name); } -void POP_RANGE(rmm::cuda_stream_view stream) { +void POP_RANGE(rmm::cuda_stream_view stream) +{ stream.synchronize(); nvtxRangePop(); } @@ -61,41 +64,47 @@ struct LinewiseTestParams { }; template -struct LinewiseTest - : public ::testing::TestWithParam { +struct LinewiseTest : public ::testing::TestWithParam { const LinewiseTestParams params; const raft::handle_t handle; rmm::cuda_stream_view stream; LinewiseTest() : testing::TestWithParam(), - params(ParamsReader::read( - ::testing::TestWithParam::GetParam())), + params( + ParamsReader::read(::testing::TestWithParam::GetParam())), handle(), - stream(handle.get_stream_view()) {} + stream(handle.get_stream_view()) + { + } - void runLinewiseSum(T* out, const T* in, const I lineLen, const I nLines, - const bool alongLines, const T* vec) { + void runLinewiseSum( + T* out, const T* in, const I lineLen, const I nLines, const bool alongLines, const T* vec) + { auto f = [] __device__(T a, T b) -> T { return a + b; }; if (params.useVanillaMatrixVectorOp) - matrixVectorOp(out, in, vec, lineLen, nLines, true, alongLines, f, - stream); + matrixVectorOp(out, in, vec, lineLen, nLines, true, alongLines, f, stream); else matrixLinewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec); } - void runLinewiseSum(T* out, const T* in, const I lineLen, const I nLines, - const bool alongLines, const T* vec1, const T* vec2) { + void runLinewiseSum(T* out, + const T* in, + const I lineLen, + const I nLines, + const bool alongLines, + const T* vec1, + const T* vec2) + { auto f = [] __device__(T a, T b, T c) -> T { return a + b + c; }; if (params.useVanillaMatrixVectorOp) - matrixVectorOp(out, in, vec1, vec2, lineLen, nLines, true, alongLines, f, - stream); + matrixVectorOp(out, in, vec1, vec2, lineLen, nLines, true, alongLines, f, stream); else - matrixLinewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec1, - vec2); + matrixLinewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec1, vec2); } - rmm::device_uvector genData() { + rmm::device_uvector genData() + { raft::random::Rng r(params.seed); const std::size_t workSizeElems = params.workSizeBytes / sizeof(T); rmm::device_uvector blob(workSizeElems, stream); @@ -111,17 +120,16 @@ struct LinewiseTest * This way I know I can create two matrices and 
numVectors vectors of size m, * such that they fit into the allocated workSet. */ - std::vector> suggestDimensions(I numVectors) { + std::vector> suggestDimensions(I numVectors) + { const std::size_t workSizeElems = params.workSizeBytes / sizeof(T); std::vector> out; const double b = double(numVectors); const double s = double(workSizeElems) - double(PTR_PADDING * 2 * (2 + b)); double squareN = 0.25 * (sqrt(8.0 * s + b * b) - b); - auto solveForN = [s, b](I m) -> double { - return (s - b * double(m)) / double(2 * m); - }; - auto solveForM = [s, b](I n) -> double { return s / double(2 * n + b); }; + auto solveForN = [s, b](I m) -> double { return (s - b * double(m)) / double(2 * m); }; + auto solveForM = [s, b](I n) -> double { return s / double(2 * n + b); }; auto addIfMakesSense = [&out](double x, double y) { if (x <= 0 || y <= 0) return; I n = I(floor(x)); @@ -138,13 +146,14 @@ struct LinewiseTest return out; } - std::tuple assignSafePtrs( - rmm::device_uvector& blob, I n, I m) { + std::tuple assignSafePtrs(rmm::device_uvector& blob, + I n, + I m) + { typedef raft::Pow2 Align; T* out = Align::roundUp(blob.data()) + params.outAlignOffset; const T* in = - const_cast(Align::roundUp(out + n * m + PTR_PADDING)) + - params.inAlignOffset; + const_cast(Align::roundUp(out + n * m + PTR_PADDING)) + params.inAlignOffset; const T* vec1 = Align::roundUp(in + n * m + PTR_PADDING); const T* vec2 = Align::roundUp(vec1 + m + PTR_PADDING); ASSERT(blob.data() + blob.size() >= vec2 + PTR_PADDING, @@ -152,45 +161,40 @@ struct LinewiseTest return std::make_tuple(out, in, vec1, vec2); } - testing::AssertionResult run() { + testing::AssertionResult run() + { rmm::device_uvector blob = genData(); - rmm::device_uvector blob_val( - params.checkCorrectness ? blob.size() / 2 : 0, stream); + rmm::device_uvector blob_val(params.checkCorrectness ? blob.size() / 2 : 0, stream); auto dims = suggestDimensions(2); stream.synchronize(); cudaProfilerStart(); - PUSH_RANGE(stream, params.useVanillaMatrixVectorOp ? "method: original" - : "method: linewise"); + PUSH_RANGE(stream, params.useVanillaMatrixVectorOp ? "method: original" : "method: linewise"); for (auto [n, m] : dims) { auto [out, in, vec1, vec2] = assignSafePtrs(blob, n, m); PUSH_RANGE(stream, "Dims-%zu-%zu", std::size_t(n), std::size_t(m)); for (auto alongRows : ::testing::Bool()) { PUSH_RANGE(stream, alongRows ? "alongRows" : "acrossRows"); auto lineLen = alongRows ? m : n; - auto nLines = alongRows ? n : m; + auto nLines = alongRows ? 
n : m; { PUSH_RANGE(stream, "one vec"); runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1); POP_RANGE(stream); if (params.checkCorrectness) { - naiveMatVec(blob_val.data(), in, vec1, lineLen, nLines, true, - alongRows, T(1)); + naiveMatVec(blob_val.data(), in, vec1, lineLen, nLines, true, alongRows, T(1)); EXPECT_NO_FATAL_FAILURE( - devArrMatch(blob_val.data(), out, n * m, - CompareApprox(params.tolerance))) + devArrMatch(blob_val.data(), out, n * m, CompareApprox(params.tolerance))) << "with one vec"; } PUSH_RANGE(stream, "two vecs"); runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1, vec2); POP_RANGE(stream); if (params.checkCorrectness) { - naiveMatVec(blob_val.data(), in, vec1, vec2, lineLen, nLines, true, - alongRows, T(1)); + naiveMatVec(blob_val.data(), in, vec1, vec2, lineLen, nLines, true, alongRows, T(1)); EXPECT_NO_FATAL_FAILURE( - devArrMatch(blob_val.data(), out, n * m, - CompareApprox(params.tolerance))) + devArrMatch(blob_val.data(), out, n * m, CompareApprox(params.tolerance))) << "with two vecs"; } } @@ -205,20 +209,18 @@ struct LinewiseTest } }; -#define TEST_IT(fun, TestClass, ElemType, IndexType) \ - typedef LinewiseTest \ - TestClass##_##ElemType##_##IndexType; \ - TEST_P(TestClass##_##ElemType##_##IndexType, fun) { ASSERT_TRUE(fun()); } \ - INSTANTIATE_TEST_SUITE_P(LinewiseOp, TestClass##_##ElemType##_##IndexType, \ - TestClass##Params) +#define TEST_IT(fun, TestClass, ElemType, IndexType) \ + typedef LinewiseTest TestClass##_##ElemType##_##IndexType; \ + TEST_P(TestClass##_##ElemType##_##IndexType, fun) { ASSERT_TRUE(fun()); } \ + INSTANTIATE_TEST_SUITE_P(LinewiseOp, TestClass##_##ElemType##_##IndexType, TestClass##Params) -auto MegabyteParams = - ::testing::Combine(::testing::Bool(), ::testing::Values(0, 1, 2, 4), - ::testing::Values(0, 1, 2, 3)); +auto MegabyteParams = ::testing::Combine( + ::testing::Bool(), ::testing::Values(0, 1, 2, 4), ::testing::Values(0, 1, 2, 3)); struct Megabyte { typedef std::tuple Params; - static LinewiseTestParams read(Params ps) { + static LinewiseTestParams read(Params ps) + { return {/** .tolerance */ 0.00001, /** .workSizeBytes */ 1024 * 1024, /** .seed */ 42ULL, @@ -229,12 +231,13 @@ struct Megabyte { } }; -auto GigabyteParams = ::testing::Combine( - ::testing::Bool(), ::testing::Values(0, 1, 2), ::testing::Values(0, 1, 2)); +auto GigabyteParams = + ::testing::Combine(::testing::Bool(), ::testing::Values(0, 1, 2), ::testing::Values(0, 1, 2)); struct Gigabyte { typedef std::tuple Params; - static LinewiseTestParams read(Params ps) { + static LinewiseTestParams read(Params ps) + { return {/** .tolerance */ 0.00001, /** .workSizeBytes */ 1024 * 1024 * 1024, /** .seed */ 42ULL, @@ -249,7 +252,8 @@ auto TenGigsParams = GigabyteParams; struct TenGigs { typedef std::tuple Params; - static LinewiseTestParams read(Params ps) { + static LinewiseTestParams read(Params ps) + { return {/** .tolerance */ 0.00001, /** .workSizeBytes */ 10ULL * 1024ULL * 1024ULL * 1024ULL, /** .seed */ 42ULL, From 1d9378b20b2fe8b4a793d75d36f437caff46fdb4 Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 26 Nov 2021 08:03:24 +0100 Subject: [PATCH 06/17] Add NVTX flag --- cpp/test/CMakeLists.txt | 5 +++++ cpp/test/linalg/matrix_linewise_op.cu | 11 ++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 8f90c1abba..7ddd7f6fa5 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -141,3 +141,8 @@ PRIVATE $ $ ) + +target_compile_definitions(test_raft +PRIVATE + 
$<$:NVTX_ENABLED> +) diff --git a/cpp/test/linalg/matrix_linewise_op.cu b/cpp/test/linalg/matrix_linewise_op.cu index 430a9981e0..8c04de1355 100644 --- a/cpp/test/linalg/matrix_linewise_op.cu +++ b/cpp/test/linalg/matrix_linewise_op.cu @@ -16,7 +16,6 @@ #include #include -#include #include #include #include @@ -25,6 +24,10 @@ #include "../test_utils.h" #include "matrix_vector_op.cuh" +#ifdef NVTX_ENABLED +#include +#endif + namespace raft { namespace linalg { @@ -38,19 +41,25 @@ void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name, Args... args) auto buf = std::make_unique(length + 1); std::snprintf(buf.get(), length + 1, name, args...); stream.synchronize(); +#ifdef NVTX_ENABLED nvtxRangePushA(buf.get()); +#endif } template <> void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name) { stream.synchronize(); +#ifdef NVTX_ENABLED nvtxRangePushA(name); +#endif } void POP_RANGE(rmm::cuda_stream_view stream) { stream.synchronize(); +#ifdef NVTX_ENABLED nvtxRangePop(); +#endif } struct LinewiseTestParams { From cfaa82ec11e184510d52ba014eda08f00b3adba7 Mon Sep 17 00:00:00 2001 From: achirkin Date: Tue, 7 Dec 2021 17:29:36 +0100 Subject: [PATCH 07/17] Fix incorrect behaviour on tiny matrices --- .../raft/linalg/matrix_linewise_op.cuh | 69 ++++++++++--------- cpp/test/linalg/matrix_linewise_op.cu | 66 ++++++++++++++---- 2 files changed, 88 insertions(+), 47 deletions(-) diff --git a/cpp/include/raft/linalg/matrix_linewise_op.cuh b/cpp/include/raft/linalg/matrix_linewise_op.cuh index 4f1061bcf3..349837c01a 100644 --- a/cpp/include/raft/linalg/matrix_linewise_op.cuh +++ b/cpp/include/raft/linalg/matrix_linewise_op.cuh @@ -250,18 +250,20 @@ void matrixLinewiseVecCols(Type* out, const IdxType alignedOff = IdxType(alignedStart - in); const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + totalLen) - in); const IdxType alignedLen = alignedEnd - alignedOff; - constexpr dim3 bs(BlockSize, 1, 1); - // Minimum size of the grid to make the device well occupied - const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize); - // does not make sense to have more blocks than this - const uint maxBlocks = raft::ceildiv(uint(alignedLen), bs.x * VecElems); - const dim3 gs(min(maxBlocks, occupy), 1, 1); + if (alignedLen > 0) { + constexpr dim3 bs(BlockSize, 1, 1); + // Minimum size of the grid to make the device well occupied + const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize); + // does not make sense to have more blocks than this + const uint maxBlocks = raft::ceildiv(uint(alignedLen), bs.x * VecElems); + const dim3 gs(min(maxBlocks, occupy), 1, 1); - const IdxType elemsPerThread = - raft::ceildiv(alignedLen, gs.x * VecElems * BlockSize) * VecElems; - matrixLinewiseVecColsMainKernel - <<>>(out, in, alignedOff, rowLen, alignedLen, elemsPerThread, op, vecs...); - CUDA_CHECK(cudaPeekAtLastError()); + const IdxType elemsPerThread = + raft::ceildiv(alignedLen, gs.x * VecElems * BlockSize) * VecElems; + matrixLinewiseVecColsMainKernel + <<>>(out, in, alignedOff, rowLen, alignedLen, elemsPerThread, op, vecs...); + CUDA_CHECK(cudaPeekAtLastError()); + } if (alignedLen < totalLen) { // should be not smaller than the warp size for better branching constexpr std::size_t MaxOffset = std::max(std::size_t(32), VecBytes); @@ -289,29 +291,30 @@ void matrixLinewiseVecRows(Type* out, typedef raft::Pow2 AlignBytes; constexpr std::size_t VecElems = VecBytes / sizeof(Type); const IdxType totalLen = rowLen * nRows; - // blockSize - constexpr dim3 bs(BlockSize, 1, 1); - // if we 
have `stride` number of blocks, then each block processes always the same - // indices along dimension rowLen; this means a block needs to index `vecs` only once! - const uint stride = (rowLen / raft::gcd(bs.x * uint(VecElems), uint(rowLen))) * VecElems; - // Minimum size of the grid to make the device well occupied - const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize); - const dim3 gs(min( - // does not make sense to have more blocks than this - raft::ceildiv(uint(totalLen), bs.x * VecElems), - // increase the stride size if necessary - raft::ceildiv(occupy, stride) * stride), - 1, - 1); + const Type* alignedStart = AlignBytes::roundUp(in); + const IdxType alignedOff = IdxType(alignedStart - in); + const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + totalLen) - in); + const IdxType alignedLen = alignedEnd - alignedOff; + if (alignedLen > 0) { + constexpr dim3 bs(BlockSize, 1, 1); + // if we have `stride` number of blocks, then each block processes always the same + // indices along dimension rowLen; this means a block needs to index `vecs` only once! + const uint stride = (rowLen / raft::gcd(bs.x * uint(VecElems), uint(rowLen))) * VecElems; + // Minimum size of the grid to make the device well occupied + const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize); + const dim3 gs(min( + // does not make sense to have more blocks than this + raft::ceildiv(uint(totalLen), bs.x * VecElems), + // increase the stride size if necessary + raft::ceildiv(occupy, stride) * stride), + 1, + 1); - const Type* alignedStart = AlignBytes::roundUp(in); - const IdxType alignedOff = IdxType(alignedStart - in); - const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + totalLen) - in); - const IdxType alignedLen = alignedEnd - alignedOff; - matrixLinewiseVecRowsMainKernel - <<>>( - out + alignedOff, alignedStart, alignedOff, rowLen, alignedLen, op, vecs...); - CUDA_CHECK(cudaPeekAtLastError()); + matrixLinewiseVecRowsMainKernel + <<>>( + out + alignedOff, alignedStart, alignedOff, rowLen, alignedLen, op, vecs...); + CUDA_CHECK(cudaPeekAtLastError()); + } if (alignedLen < totalLen) { // should be not smaller than the warp size for better branching constexpr std::size_t MaxOffset = std::max(std::size_t(32), VecBytes); diff --git a/cpp/test/linalg/matrix_linewise_op.cu b/cpp/test/linalg/matrix_linewise_op.cu index 8c04de1355..a132b6e6d7 100644 --- a/cpp/test/linalg/matrix_linewise_op.cu +++ b/cpp/test/linalg/matrix_linewise_op.cu @@ -112,10 +112,10 @@ struct LinewiseTest : public ::testing::TestWithParam genData() + rmm::device_uvector genData(size_t workSizeBytes) { raft::random::Rng r(params.seed); - const std::size_t workSizeElems = params.workSizeBytes / sizeof(T); + const std::size_t workSizeElems = workSizeBytes / sizeof(T); rmm::device_uvector blob(workSizeElems, stream); r.uniform(blob.data(), workSizeElems, T(-1.0), T(1.0), stream); return blob; @@ -170,17 +170,16 @@ struct LinewiseTest : public ::testing::TestWithParam>&& dims, rmm::device_uvector&& blob) { - rmm::device_uvector blob = genData(); rmm::device_uvector blob_val(params.checkCorrectness ? blob.size() / 2 : 0, stream); - auto dims = suggestDimensions(2); - stream.synchronize(); cudaProfilerStart(); + testing::AssertionResult r = testing::AssertionSuccess(); PUSH_RANGE(stream, params.useVanillaMatrixVectorOp ? 
"method: original" : "method: linewise"); for (auto [n, m] : dims) { + if (!r) break; auto [out, in, vec1, vec2] = assignSafePtrs(blob, n, m); PUSH_RANGE(stream, "Dims-%zu-%zu", std::size_t(n), std::size_t(m)); for (auto alongRows : ::testing::Bool()) { @@ -193,18 +192,20 @@ struct LinewiseTest : public ::testing::TestWithParam(params.tolerance))) - << "with one vec"; + r = devArrMatch(blob_val.data(), out, n * m, CompareApprox(params.tolerance)) + << " " << (alongRows ? "alongRows" : "acrossRows") + << " with one vec; lineLen: " << lineLen << "; nLines " << nLines; + if (!r) break; } PUSH_RANGE(stream, "two vecs"); runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1, vec2); POP_RANGE(stream); if (params.checkCorrectness) { naiveMatVec(blob_val.data(), in, vec1, vec2, lineLen, nLines, true, alongRows, T(1)); - EXPECT_NO_FATAL_FAILURE( - devArrMatch(blob_val.data(), out, n * m, CompareApprox(params.tolerance))) - << "with two vecs"; + r = devArrMatch(blob_val.data(), out, n * m, CompareApprox(params.tolerance)) + << " " << (alongRows ? "alongRows" : "acrossRows") + << " with two vecs; lineLen: " << lineLen << "; nLines " << nLines; + if (!r) break; } } POP_RANGE(stream); @@ -214,7 +215,26 @@ struct LinewiseTest : public ::testing::TestWithParam sizes = {1, 2, 3, 4, 7, 16}; + std::vector> dims; + for (auto m : sizes) { + for (auto n : sizes) { + dims.push_back(std::make_tuple(n, m)); + dims.push_back(std::make_tuple(m, n)); + } + } + + return run(std::move(dims), genData(1024 * 1024)); } }; @@ -223,9 +243,25 @@ struct LinewiseTest : public ::testing::TestWithParam Params; + static LinewiseTestParams read(Params ps) + { + return {/** .tolerance */ 0.00001, + /** .workSizeBytes */ 0 /* not used anyway */, + /** .seed */ 42ULL, + /** .useVanillaMatrixVectorOp */ std::get<0>(ps), + /** .checkCorrectness */ true, + /** .inAlignOffset */ std::get<1>(ps), + /** .outAlignOffset */ std::get<2>(ps)}; + } +}; + +auto MegabyteParams = TinyParams; + struct Megabyte { typedef std::tuple Params; static LinewiseTestParams read(Params ps) @@ -273,6 +309,8 @@ struct TenGigs { } }; +TEST_IT(runEdgeCases, Tiny, float, int); +TEST_IT(runEdgeCases, Tiny, double, int); TEST_IT(run, Megabyte, float, int); TEST_IT(run, Megabyte, double, int); TEST_IT(run, Gigabyte, float, int); From 6f1b92d15bd3646f51747a93d0faaf6d64d31c1e Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 8 Dec 2021 09:25:44 +0100 Subject: [PATCH 08/17] Hide implementation details --- cpp/include/raft/linalg/matrix_vector_op.cuh | 7 +-- .../detail/linewise_op.cuh} | 50 ++----------------- cpp/include/raft/matrix/matrix.hpp | 41 ++++++++++++++- cpp/test/linalg/matrix_linewise_op.cu | 19 +++---- 4 files changed, 59 insertions(+), 58 deletions(-) rename cpp/include/raft/{linalg/matrix_linewise_op.cuh => matrix/detail/linewise_op.cuh} (87%) diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh index 9e38ebd167..750eca0742 100644 --- a/cpp/include/raft/linalg/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/matrix_vector_op.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include namespace raft { namespace linalg { @@ -57,7 +57,8 @@ void matrixVectorOp(Type* out, { IdxType stride = rowMajor ? D : N; IdxType nLines = rowMajor ? 
N : D; - return matrixLinewiseOp(out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec); + return matrix::linewiseOp( + out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec); } /** @@ -98,7 +99,7 @@ void matrixVectorOp(Type* out, { IdxType stride = rowMajor ? D : N; IdxType nLines = rowMajor ? N : D; - return matrixLinewiseOp( + return matrix::linewiseOp( out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec1, vec2); } diff --git a/cpp/include/raft/linalg/matrix_linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh similarity index 87% rename from cpp/include/raft/linalg/matrix_linewise_op.cuh rename to cpp/include/raft/matrix/detail/linewise_op.cuh index 349837c01a..a67e9d2d0b 100644 --- a/cpp/include/raft/linalg/matrix_linewise_op.cuh +++ b/cpp/include/raft/matrix/detail/linewise_op.cuh @@ -21,9 +21,8 @@ #include namespace raft { -namespace linalg { - -namespace linewise_impl { +namespace matrix { +namespace detail { template struct Linewise { @@ -360,45 +359,6 @@ struct MatrixLinewiseOp { } }; -}; // namespace linewise_impl - -/** - * Run a function over matrix lines (rows or columns) with a variable number - * row-vectors or column-vectors. - * The term `line` here signifies that the lines can be either columns or rows, - * depending on the matrix layout. - * What matters is if the vectors are applied along lines (indices of vectors correspond to - * indices within lines), or across lines (indices of vectors correspond to line numbers). - * - * @param [out] out result of the operation; can be same as `in`; should be aligned the same - * as `in` to allow faster vectorized memory transfers. - * @param [in] in input matrix consisting of `nLines` lines, each `lineLen`-long. - * @param [in] lineLen length of matrix line in elements (`=nCols` in row-major or `=nRows` in - * col-major) - * @param [in] nLines number of matrix lines (`=nRows` in row-major or `=nCols` in col-major) - * @param [in] alongLines whether vectors are indices along or across lines. - * @param [in] op the operation applied on each line: - * for i in [0..lineLen) and j in [0..nLines): - * out[i, j] = op(in[i, j], vec1[i], vec2[i], ... veck[i]) if alongLines = true - * out[i, j] = op(in[i, j], vec1[j], vec2[j], ... veck[j]) if alongLines = false - * where matrix indexing is row-major ([i, j] = [i + lineLen * j]). - * @param [in] stream a cuda stream for the kernels - * @param [in] vecs zero or more vectors to be passed as arguments, - * size of each vector is `alongLines ? lineLen : nLines`. - */ -template -void matrixLinewiseOp(Type* out, - const Type* in, - const IdxType lineLen, - const IdxType nLines, - const bool alongLines, - Lambda op, - cudaStream_t stream, - Vecs... vecs) -{ - linewise_impl::MatrixLinewiseOp<16, 256>::run( - out, in, lineLen, nLines, alongLines, op, stream, vecs...); -} - -}; // end namespace linalg -}; // end namespace raft +} // end namespace detail +} // end namespace matrix +} // end namespace raft diff --git a/cpp/include/raft/matrix/matrix.hpp b/cpp/include/raft/matrix/matrix.hpp index c4cd30b7bc..f5827bf4bd 100644 --- a/cpp/include/raft/matrix/matrix.hpp +++ b/cpp/include/raft/matrix/matrix.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020, NVIDIA CORPORATION. + * Copyright (c) 2018-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
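(Editorial aside, not part of the patch: a sketch of what the new delegation in
matrix_vector_op.cuh above means for a caller; the lambda and names are
illustrative, assuming a row-major matrix of N rows of length D.)

    // linalg::matrixVectorOp(out, matrix, vec, D, N, /*rowMajor=*/true,
    //                        /*bcastAlongRows=*/true,
    //                        [] __device__(Type m, Type v) { return m + v; },
    //                        stream);
    // ...now lowers to:
    // matrix::linewiseOp(out, matrix, /*lineLen=*/D, /*nLines=*/N,
    //                    /*alongLines=*/true, op, stream, vec);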
@@ -16,6 +16,7 @@ #pragma once +#include "detail/linewise_op.cuh" #include "detail/matrix.cuh" #include @@ -289,5 +290,43 @@ m_t getL2Norm(const raft::handle_t& handle, m_t* in, idx_t size, cudaStream_t st return normval; } +/** + * Run a function over matrix lines (rows or columns) with a variable number + * row-vectors or column-vectors. + * The term `line` here signifies that the lines can be either columns or rows, + * depending on the matrix layout. + * What matters is if the vectors are applied along lines (indices of vectors correspond to + * indices within lines), or across lines (indices of vectors correspond to line numbers). + * + * @param [out] out result of the operation; can be same as `in`; should be aligned the same + * as `in` to allow faster vectorized memory transfers. + * @param [in] in input matrix consisting of `nLines` lines, each `lineLen`-long. + * @param [in] lineLen length of matrix line in elements (`=nCols` in row-major or `=nRows` in + * col-major) + * @param [in] nLines number of matrix lines (`=nRows` in row-major or `=nCols` in col-major) + * @param [in] alongLines whether vectors are indices along or across lines. + * @param [in] op the operation applied on each line: + * for i in [0..lineLen) and j in [0..nLines): + * out[i, j] = op(in[i, j], vec1[i], vec2[i], ... veck[i]) if alongLines = true + * out[i, j] = op(in[i, j], vec1[j], vec2[j], ... veck[j]) if alongLines = false + * where matrix indexing is row-major ([i, j] = [i + lineLen * j]). + * @param [in] stream a cuda stream for the kernels + * @param [in] vecs zero or more vectors to be passed as arguments, + * size of each vector is `alongLines ? lineLen : nLines`. + */ +template +void linewiseOp(m_t* out, + const m_t* in, + const idx_t lineLen, + const idx_t nLines, + const bool alongLines, + Lambda op, + cudaStream_t stream, + Vecs... 
vecs) +{ + detail::MatrixLinewiseOp<16, 256>::run( + out, in, lineLen, nLines, alongLines, op, stream, vecs...); +} + }; // end namespace matrix }; // end namespace raft diff --git a/cpp/test/linalg/matrix_linewise_op.cu b/cpp/test/linalg/matrix_linewise_op.cu index a132b6e6d7..4c24ce585e 100644 --- a/cpp/test/linalg/matrix_linewise_op.cu +++ b/cpp/test/linalg/matrix_linewise_op.cu @@ -17,8 +17,8 @@ #include #include #include -#include #include +#include #include #include #include "../test_utils.h" @@ -29,7 +29,7 @@ #endif namespace raft { -namespace linalg { +namespace matrix { constexpr std::size_t PTR_PADDING = 128; @@ -92,9 +92,9 @@ struct LinewiseTest : public ::testing::TestWithParam T { return a + b; }; if (params.useVanillaMatrixVectorOp) - matrixVectorOp(out, in, vec, lineLen, nLines, true, alongLines, f, stream); + linalg::matrixVectorOp(out, in, vec, lineLen, nLines, true, alongLines, f, stream); else - matrixLinewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec); + matrix::linewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec); } void runLinewiseSum(T* out, @@ -107,9 +107,9 @@ struct LinewiseTest : public ::testing::TestWithParam T { return a + b + c; }; if (params.useVanillaMatrixVectorOp) - matrixVectorOp(out, in, vec1, vec2, lineLen, nLines, true, alongLines, f, stream); + linalg::matrixVectorOp(out, in, vec1, vec2, lineLen, nLines, true, alongLines, f, stream); else - matrixLinewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec1, vec2); + matrix::linewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec1, vec2); } rmm::device_uvector genData(size_t workSizeBytes) @@ -191,7 +191,7 @@ struct LinewiseTest : public ::testing::TestWithParam(params.tolerance)) << " " << (alongRows ? "alongRows" : "acrossRows") << " with one vec; lineLen: " << lineLen << "; nLines " << nLines; @@ -201,7 +201,8 @@ struct LinewiseTest : public ::testing::TestWithParam(params.tolerance)) << " " << (alongRows ? 
"alongRows" : "acrossRows") << " with two vecs; lineLen: " << lineLen << "; nLines " << nLines; @@ -318,5 +319,5 @@ TEST_IT(run, Gigabyte, double, int); TEST_IT(run, TenGigs, float, uint64_t); TEST_IT(run, TenGigs, double, uint64_t); -} // end namespace linalg +} // namespace matrix } // end namespace raft From c8d917b465b3d5ea85bc14f4f3caceca5f6919e8 Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 8 Dec 2021 09:42:21 +0100 Subject: [PATCH 09/17] Move the tests as well --- cpp/test/CMakeLists.txt | 2 +- .../{linalg/matrix_linewise_op.cu => matrix/linewise_op.cu} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename cpp/test/{linalg/matrix_linewise_op.cu => matrix/linewise_op.cu} (99%) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 7ddd7f6fa5..6ea7a17bd0 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -51,7 +51,6 @@ add_executable(test_raft test/linalg/gemv.cu test/linalg/map.cu test/linalg/map_then_reduce.cu - test/linalg/matrix_linewise_op.cu test/linalg/matrix_vector_op.cu test/linalg/multiply.cu test/linalg/norm.cu @@ -63,6 +62,7 @@ add_executable(test_raft test/linalg/unary_op.cu test/matrix/math.cu test/matrix/matrix.cu + test/matrix/linewise_op.cu test/mr/device/buffer.cpp test/mr/host/buffer.cpp test/mst.cu diff --git a/cpp/test/linalg/matrix_linewise_op.cu b/cpp/test/matrix/linewise_op.cu similarity index 99% rename from cpp/test/linalg/matrix_linewise_op.cu rename to cpp/test/matrix/linewise_op.cu index 4c24ce585e..26bfa13148 100644 --- a/cpp/test/linalg/matrix_linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -21,8 +21,8 @@ #include #include #include +#include "../linalg/matrix_vector_op.cuh" #include "../test_utils.h" -#include "matrix_vector_op.cuh" #ifdef NVTX_ENABLED #include From 92390d70335ee45b413b007d9ecdf7775d626b47 Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 10 Dec 2021 15:34:10 +0100 Subject: [PATCH 10/17] use NVTX helpers from future --- cpp/test/CMakeLists.txt | 6 -- cpp/test/matrix/linewise_op.cu | 101 +++++++++------------------------ 2 files changed, 27 insertions(+), 80 deletions(-) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6ea7a17bd0..56a1ab5356 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -131,7 +131,6 @@ PRIVATE CUDA::cusolver CUDA::cudart CUDA::cusparse - $<$:CUDA::nvToolsExt> rmm::rmm cuco::cuco FAISS::FAISS @@ -141,8 +140,3 @@ PRIVATE $ $ ) - -target_compile_definitions(test_raft -PRIVATE - $<$:NVTX_ENABLED> -) diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 26bfa13148..930c3537e3 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -24,49 +25,15 @@ #include "../linalg/matrix_vector_op.cuh" #include "../test_utils.h" -#ifdef NVTX_ENABLED -#include -#endif - namespace raft { namespace matrix { constexpr std::size_t PTR_PADDING = 128; -template -void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name, Args... 
args) -{ - int length = std::snprintf(nullptr, 0, name, args...); - assert(length >= 0); - auto buf = std::make_unique(length + 1); - std::snprintf(buf.get(), length + 1, name, args...); - stream.synchronize(); -#ifdef NVTX_ENABLED - nvtxRangePushA(buf.get()); -#endif -} -template <> -void PUSH_RANGE(rmm::cuda_stream_view stream, const char* name) -{ - stream.synchronize(); -#ifdef NVTX_ENABLED - nvtxRangePushA(name); -#endif -} - -void POP_RANGE(rmm::cuda_stream_view stream) -{ - stream.synchronize(); -#ifdef NVTX_ENABLED - nvtxRangePop(); -#endif -} - struct LinewiseTestParams { double tolerance; std::size_t workSizeBytes; uint64_t seed; - bool useVanillaMatrixVectorOp; bool checkCorrectness; int inAlignOffset; int outAlignOffset; @@ -91,10 +58,7 @@ struct LinewiseTest : public ::testing::TestWithParam T { return a + b; }; - if (params.useVanillaMatrixVectorOp) - linalg::matrixVectorOp(out, in, vec, lineLen, nLines, true, alongLines, f, stream); - else - matrix::linewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec); + matrix::linewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec); } void runLinewiseSum(T* out, @@ -106,10 +70,7 @@ struct LinewiseTest : public ::testing::TestWithParam T { return a + b + c; }; - if (params.useVanillaMatrixVectorOp) - linalg::matrixVectorOp(out, in, vec1, vec2, lineLen, nLines, true, alongLines, f, stream); - else - matrix::linewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec1, vec2); + matrix::linewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec1, vec2); } rmm::device_uvector genData(size_t workSizeBytes) @@ -177,19 +138,19 @@ struct LinewiseTest : public ::testing::TestWithParam(params.tolerance)) @@ -197,9 +158,10 @@ struct LinewiseTest : public ::testing::TestWithParam Params; + typedef std::tuple Params; static LinewiseTestParams read(Params ps) { return {/** .tolerance */ 0.00001, /** .workSizeBytes */ 0 /* not used anyway */, /** .seed */ 42ULL, - /** .useVanillaMatrixVectorOp */ std::get<0>(ps), /** .checkCorrectness */ true, - /** .inAlignOffset */ std::get<1>(ps), - /** .outAlignOffset */ std::get<2>(ps)}; + /** .inAlignOffset */ std::get<0>(ps), + /** .outAlignOffset */ std::get<1>(ps)}; } }; auto MegabyteParams = TinyParams; struct Megabyte { - typedef std::tuple Params; + typedef std::tuple Params; static LinewiseTestParams read(Params ps) { return {/** .tolerance */ 0.00001, /** .workSizeBytes */ 1024 * 1024, /** .seed */ 42ULL, - /** .useVanillaMatrixVectorOp */ std::get<0>(ps), /** .checkCorrectness */ true, - /** .inAlignOffset */ std::get<1>(ps), - /** .outAlignOffset */ std::get<2>(ps)}; + /** .inAlignOffset */ std::get<0>(ps), + /** .outAlignOffset */ std::get<1>(ps)}; } }; -auto GigabyteParams = - ::testing::Combine(::testing::Bool(), ::testing::Values(0, 1, 2), ::testing::Values(0, 1, 2)); +auto GigabyteParams = ::testing::Combine(::testing::Values(0, 1, 2), ::testing::Values(0, 1, 2)); struct Gigabyte { - typedef std::tuple Params; + typedef std::tuple Params; static LinewiseTestParams read(Params ps) { return {/** .tolerance */ 0.00001, /** .workSizeBytes */ 1024 * 1024 * 1024, /** .seed */ 42ULL, - /** .useVanillaMatrixVectorOp */ std::get<0>(ps), /** .checkCorrectness */ false, - /** .inAlignOffset */ std::get<1>(ps), - /** .outAlignOffset */ std::get<2>(ps)}; + /** .inAlignOffset */ std::get<0>(ps), + /** .outAlignOffset */ std::get<1>(ps)}; } }; auto TenGigsParams = GigabyteParams; struct TenGigs { - typedef std::tuple Params; + typedef std::tuple Params; static LinewiseTestParams 
read(Params ps)
  {
    return {/** .tolerance */ 0.00001,
            /** .workSizeBytes */ 10ULL * 1024ULL * 1024ULL * 1024ULL,
            /** .seed */ 42ULL,
-            /** .useVanillaMatrixVectorOp */ std::get<0>(ps),
            /** .checkCorrectness */ false,
-            /** .inAlignOffset */ std::get<1>(ps),
-            /** .outAlignOffset */ std::get<2>(ps)};
+            /** .inAlignOffset */ std::get<0>(ps),
+            /** .outAlignOffset */ std::get<1>(ps)};
  }
};

From b747a0e80c69e834e8361ea9cdbbc38c690dd03a Mon Sep 17 00:00:00 2001
From: achirkin
Date: Thu, 16 Dec 2021 14:41:26 +0100
Subject: [PATCH 11/17] Add more docstrings and comments

---
 .../raft/matrix/detail/linewise_op.cuh        | 162 +++++++++++++++++-
 cpp/include/raft/pow2_utils.cuh               |  20 +--
 cpp/test/matrix/linewise_op.cu                |  10 +-
 3 files changed, 168 insertions(+), 24 deletions(-)

diff --git a/cpp/include/raft/matrix/detail/linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh
index a67e9d2d0b..4e7facb509 100644
--- a/cpp/include/raft/matrix/detail/linewise_op.cuh
+++ b/cpp/include/raft/matrix/detail/linewise_op.cuh
@@ -33,6 +33,39 @@ struct Linewise {
   typedef raft::Pow2 AlignElems;
   typedef raft::Pow2 AlignWarp;

+  /**
+   * Compute op(matrix_in, vec_1, vec_2, ...) where vectors are applied across the
+   * matrix rows (one vector element per matrix row).
+   *
+   * It's assumed that `in` and `out` are aligned to the cuda-vector-size,
+   * and their length is a multiple of that.
+   *
+   * Block work arrangement: blocked;
+   * one warp works on a contiguous chunk of a matrix. Since the matrix is represented
+   * as a flat array, such an arrangement minimizes the number of times when a single
+   * thread needs to reload the vector value at an index corresponding to the current
+   * matrix row. Ideally, a thread would load a value from a vector only once, but that
+   * is not possible if the vector size (= number of matrix rows) is too small or not
+   * aligned with the cuda-vector-size.
+   *
+   * Note about rowDiv/rowMod:
+   * these two represent the row/column indices in the original input matrices, before
+   * they were converted to (Vec::io_t*) type (which possibly involves shifting a pointer
+   * a bit to align to the cuda-vector-size). Thus, they are used to track the index for
+   * the argument vectors only (the vector pointers are not altered in any way).
+   *
+   * @tparam Vecs a pack of pointers to vectors (Type*)
+   * @param [out] out (aligned part of) the output matrix
+   * @param [in] in (aligned part of) the input matrix
+   * @param [in] in_end end of the (aligned part of the) input matrix
+   * @param [in] rowLen number of elements in a row (NOT the vector size)
+   * @param [in] rowDiv the index in the vectors (= row num in the original unaligned input matrix)
+   * @param [in] rowMod the index within a row in the original unaligned input matrix.
+   * @param [in] op the function to apply
+   * @param [in] vecs pointers to the argument vectors.
+   */
   template 
   static __device__ __forceinline__ void vectorCols(typename Vec::io_t* out,
                                                     const typename Vec::io_t* in,
                                                     const typename Vec::io_t* in_end,
@@ -74,6 +107,25 @@ struct Linewise {
     }
   }

+  /**
+   * Compute op(matrix_in, vec_1, vec_2, ...) where vectors are applied along
+   * matrix rows (vector and matrix indices are 1-1).
+   *
+   * It's assumed that `in` and `out` are aligned to the cuda-vector-size,
+   * and their length is a multiple of that.
+   *
+   * Block work arrangement: striped;
+   * the grid size is chosen in such a way, that one thread always processes
+   * the same vector elements. That's why there is no need to read the
+   * vector arguments multiple times.
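+   * (Editorial illustration, not in the original commit: in `vectorRows` a
+   * thread at flat position t visits chunks t, t + gridDim.x * BlockSize,
+   * t + 2 * gridDim.x * BlockSize, and so on; whenever
+   * gridDim.x * BlockSize * VecElems is a multiple of rowLen, the position of
+   * those chunks within a row, and hence the vector elements the thread needs,
+   * never changes across iterations.)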
+ * + * @tparam Args a pack of raft::TxN_t + * @param [out] out (aligned part of) the output matrix + * @param [in] in (aligned part of) the input matrix + * @param [in] len total length of (the aligned part of) the input/output matrices + * @param [in] op the function to apply + * @param [in] args the cuda-vector-sized chunks on input vectors (raft::TxN_t) + */ template static __device__ __forceinline__ void vectorRows(typename Vec::io_t* out, const typename Vec::io_t* in, @@ -92,6 +144,16 @@ struct Linewise { } } + /** + * The helper for `vectorRows`. Loads the `raft::TxN_t` chunk + * of a vector. Most of the time this is not aligned, so we load it thread-striped + * within a block and then use the shared memory to get a contiguous chunk. + * + * @param [in] p pointer to a vector + * @param [in] blockOffset the offset of the current block into a vector. + * @param [in] rowLen the length of a vector. + * @return a contiguous chunk of a vector, suitable for `vectorRows`. + */ static __device__ __forceinline__ Vec loadVec(const Type* p, const IdxType blockOffset, const IdxType rowLen) noexcept @@ -113,6 +175,20 @@ struct Linewise { } }; +/** + * This kernel prepares the inputs for the `vectorCols` function where the most of the + * work happens; see `vectorCols` for details. + * + * @param [out] out the output matrix + * @param [in] in the input matrix + * @param [in] arrOffset such an offset into the matrices that makes them aligned to the + * cuda-vector-size + * @param [in] rowLen number of elements in a row (NOT the vector size) + * @param [in] len the total length of the aligned part of the matrices + * @param [in] elemsPerThread how many elements are processed by a single thread in total + * @param [in] op the function to apply + * @param [in] vecs pointers to the argument vectors + */ template __global__ void __launch_bounds__(MaxOffset, 2) matrixLinewiseVecColsTailKernel(Type* out, @@ -156,12 +249,15 @@ __global__ void __launch_bounds__(MaxOffset, 2) Lambda op, Vecs... vecs) { + // Note, L::VecElems == 1 typedef Linewise L; IdxType threadOffset, elemsPerWarp; if (blockIdx.x == 0) { + // first block: offset = 0, length = arrOffset threadOffset = threadIdx.x; elemsPerWarp = threadOffset < arrOffset; } else { + // second block: offset = arrTail, length = len - arrTail threadOffset = arrTail + threadIdx.x; elemsPerWarp = threadOffset < len; } @@ -178,6 +274,18 @@ __global__ void __launch_bounds__(MaxOffset, 2) vecs...); } +/** + * This kernel prepares the inputs for the `vectorRows` function where the most of the + * work happens; see `vectorRows` for details. + * + * @param [out] out the start of the *aligned* part of the output matrix + * @param [in] in the start of the *aligned* part of the input matrix + * @param [in] arrOffset such an offset into the matrices that makes them aligned to `VecBytes` + * @param [in] rowLen number of elements in a row (= the vector size) + * @param [in] len the total length of the aligned part of the matrices + * @param [in] op the function to apply + * @param [in] vecs pointers to the argument vectors + */ template __global__ void __launch_bounds__(MaxOffset, 2) matrixLinewiseVecRowsTailKernel(Type* out, @@ -213,19 +338,23 @@ __global__ void __launch_bounds__(MaxOffset, 2) Lambda op, Vecs... 
vecs)
 {
+  // Note, L::VecElems == 1
   typedef Linewise L;
-  if (blockIdx.x == 0)
+  if (blockIdx.x == 0) {
+    // first block: offset = 0, length = arrOffset
     L::vectorRows(reinterpret_cast(out),
                   reinterpret_cast(in),
                   arrOffset,
                   op,
                   L::loadVec(vecs, 0, rowLen)...);
-  else
+  } else {
+    // second block: offset = arrTail, length = len - arrTail
     L::vectorRows(reinterpret_cast(out + arrTail - MaxOffset),
                   reinterpret_cast(in + arrTail - MaxOffset),
                   len - arrTail + MaxOffset,
                   op,
                   L::loadVec(vecs, arrTail % rowLen, rowLen)...);
+  }
 }

 template (uint(alignedLen), bs.x * VecElems);
     const dim3 gs(min(maxBlocks, occupy), 1, 1);
-
+    // The work arrangement is blocked on the block and warp levels;
+    // see more details at Linewise::vectorCols.
+    // The value below determines how many scalar elements are processed by one thread in total.
     const IdxType elemsPerThread =
       raft::ceildiv(alignedLen, gs.x * VecElems * BlockSize) * VecElems;
     matrixLinewiseVecColsMainKernel
@@ ... @@ void matrixLinewiseVecRows(Type* out,
   if (alignedLen > 0) {
     constexpr dim3 bs(BlockSize, 1, 1);
-    // if we have `stride` number of blocks, then each block processes always the same
-    // indices along dimension rowLen; this means a block needs to index `vecs` only once!
-    const uint stride = (rowLen / raft::gcd(bs.x * uint(VecElems), uint(rowLen))) * VecElems;
+    // The work arrangement is striped;
+    // see more details at Linewise::vectorRows.
+    // Below is the work amount performed by one block in one iteration.
+    constexpr uint block_work_size = bs.x * uint(VecElems);
+    /* Here I would define `grid_work_size = lcm(block_work_size, rowLen)` (Least Common Multiple)
+       This way, the grid spans a set of one or more rows each iteration, and, most importantly,
+       on every iteration each row processes the same set of indices within a row (= the same set
+       of vector indices).
+       This means, each block needs to load the values from the vector arguments only once.
+       Sadly, sometimes `grid_work_size > rowLen*nRows`, and sometimes grid_work_size > UINT_MAX.
+       That's why I don't declare it here explicitly.
+ Instead, I straightaway compute the + expected_grid_size = lcm(block_work_size, rowLen) / block_work_size + */ + const uint expected_grid_size = rowLen / raft::gcd(block_work_size, uint(rowLen)); // Minimum size of the grid to make the device well occupied const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize); const dim3 gs(min( // does not make sense to have more blocks than this - raft::ceildiv(uint(totalLen), bs.x * VecElems), - // increase the stride size if necessary - raft::ceildiv(occupy, stride) * stride), + raft::ceildiv(uint(totalLen), block_work_size), + // increase the grid size to be not less than `occupy` while + // still being the multiple of `expected_grid_size` + raft::ceildiv(occupy, expected_grid_size) * expected_grid_size), 1, 1); diff --git a/cpp/include/raft/pow2_utils.cuh b/cpp/include/raft/pow2_utils.cuh index b1f0b21c7b..93f81db1ac 100644 --- a/cpp/include/raft/pow2_utils.cuh +++ b/cpp/include/raft/pow2_utils.cuh @@ -35,7 +35,7 @@ struct Pow2 { static_assert(std::is_integral::value, "Value must be integral."); static_assert(Value && !(Value & Mask), "Value must be power of two."); -#define Pow2_CALL static constexpr __host__ __device__ __forceinline__ +#define Pow2_FUNC_QUALIFIER static constexpr __host__ __device__ __forceinline__ #define Pow2_WHEN_INTEGRAL(I) std::enable_if_t #define Pow2_IS_REPRESENTABLE_AS(I) (std::is_integral::value && Type(I(Value)) == Value) @@ -46,7 +46,7 @@ struct Pow2 { * Invariant: `x = Value * quot(x) + rem(x)` */ template - Pow2_CALL Pow2_WHEN_INTEGRAL(I) quot(I x) noexcept + Pow2_FUNC_QUALIFIER Pow2_WHEN_INTEGRAL(I) quot(I x) noexcept { if constexpr (std::is_signed::value) return (x >> I(Log2)) + (x < 0 && (x & I(Mask))); if constexpr (std::is_unsigned::value) return x >> I(Log2); @@ -59,7 +59,7 @@ struct Pow2 { * Invariant: `x = Value * quot(x) + rem(x)`. */ template - Pow2_CALL Pow2_WHEN_INTEGRAL(I) rem(I x) noexcept + Pow2_FUNC_QUALIFIER Pow2_WHEN_INTEGRAL(I) rem(I x) noexcept { if constexpr (std::is_signed::value) return x < 0 ? -((-x) & I(Mask)) : (x & I(Mask)); if constexpr (std::is_unsigned::value) return x & I(Mask); @@ -76,7 +76,7 @@ struct Pow2 { * compared to normal C++ operators `/` and `%`. */ template - Pow2_CALL Pow2_WHEN_INTEGRAL(I) div(I x) noexcept + Pow2_FUNC_QUALIFIER Pow2_WHEN_INTEGRAL(I) div(I x) noexcept { return x >> I(Log2); } @@ -93,7 +93,7 @@ struct Pow2 { * compared to normal C++ operators `/` and `%`. */ template - Pow2_CALL Pow2_WHEN_INTEGRAL(I) mod(I x) noexcept + Pow2_FUNC_QUALIFIER Pow2_WHEN_INTEGRAL(I) mod(I x) noexcept { return x & I(Mask); } @@ -107,7 +107,7 @@ struct Pow2 { * NB: for pointers, the alignment is checked in bytes, not in elements. */ template - Pow2_CALL bool isAligned(PtrT p) noexcept + Pow2_FUNC_QUALIFIER bool isAligned(PtrT p) noexcept { Pow2_CHECK_TYPE(PtrT); if constexpr (Pow2_IS_REPRESENTABLE_AS(PtrT)) return mod(p) == 0; @@ -116,7 +116,7 @@ struct Pow2 { /** Tell whether two pointers have the same address modulo Value. */ template - Pow2_CALL bool areSameAlignOffsets(PtrT a, PtrS b) noexcept + Pow2_FUNC_QUALIFIER bool areSameAlignOffsets(PtrT a, PtrS b) noexcept { Pow2_CHECK_TYPE(PtrT); Pow2_CHECK_TYPE(PtrS); @@ -134,7 +134,7 @@ struct Pow2 { /** Get this or next Value-aligned address (in bytes) or integral. 
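   *  (editorial example, not in the original header: with Pow2<4>,
   *   roundUp(5) == 8 and roundUp(8) == 8)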
*/ template - Pow2_CALL PtrT roundUp(PtrT p) noexcept + Pow2_FUNC_QUALIFIER PtrT roundUp(PtrT p) noexcept { Pow2_CHECK_TYPE(PtrT); if constexpr (Pow2_IS_REPRESENTABLE_AS(PtrT)) return (p + PtrT(Mask)) & PtrT(~Mask); @@ -146,7 +146,7 @@ struct Pow2 { /** Get this or previous Value-aligned address (in bytes) or integral. */ template - Pow2_CALL PtrT roundDown(PtrT p) noexcept + Pow2_FUNC_QUALIFIER PtrT roundDown(PtrT p) noexcept { Pow2_CHECK_TYPE(PtrT); if constexpr (Pow2_IS_REPRESENTABLE_AS(PtrT)) return p & PtrT(~Mask); @@ -157,7 +157,7 @@ struct Pow2 { } #undef Pow2_CHECK_TYPE #undef Pow2_IS_REPRESENTABLE_AS -#undef Pow2_CALL +#undef Pow2_FUNC_QUALIFIER #undef Pow2_WHEN_INTEGRAL }; diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 930c3537e3..c5322cc056 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -17,7 +17,7 @@ #include #include #include -#include +// #include #include #include #include @@ -141,14 +141,14 @@ struct LinewiseTest : public ::testing::TestWithParam Date: Thu, 16 Dec 2021 14:55:21 +0100 Subject: [PATCH 12/17] Add even more docstrings and comments --- cpp/include/raft/matrix/detail/linewise_op.cuh | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cpp/include/raft/matrix/detail/linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh index 4e7facb509..f0b7ff775b 100644 --- a/cpp/include/raft/matrix/detail/linewise_op.cuh +++ b/cpp/include/raft/matrix/detail/linewise_op.cuh @@ -179,6 +179,9 @@ struct Linewise { * This kernel prepares the inputs for the `vectorCols` function where the most of the * work happens; see `vectorCols` for details. * + * The work arrangement is blocked; a single block works on a contiguous chunk of flattened + * matrix data and does not care about the gridDim. + * * @param [out] out the output matrix * @param [in] in the input matrix * @param [in] arrOffset such an offset into the matrices that makes them aligned to the @@ -278,6 +281,10 @@ __global__ void __launch_bounds__(MaxOffset, 2) * This kernel prepares the inputs for the `vectorRows` function where the most of the * work happens; see `vectorRows` for details. * + * The work arrangement is striped; the gridDim should be selected in such a way, that + * on each iteration a thread processes the same indices along rows: + * `(gridDim.x * BlockSize * VecElems) % rowLen == 0`. 
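+ * (Editorial numeric check, not in the original commit, with hypothetical
+ * values BlockSize = 256 and VecElems = 4: one grid pass covers
+ * gridDim.x * 1024 elements, so for rowLen = 1000 the condition holds exactly
+ * when gridDim.x is a multiple of 1000 / gcd(1024, 1000) = 125.)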
From 2d818da36343fca05270b20420cab1f9b57f8a40 Mon Sep 17 00:00:00 2001
From: achirkin
Date: Thu, 16 Dec 2021 15:05:55 +0100
Subject: [PATCH 13/17] Adapt to the new .style

---
 cpp/test/matrix/linewise_op.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu
index c5322cc056..77b5ea3a31 100644
--- a/cpp/test/matrix/linewise_op.cu
+++ b/cpp/test/matrix/linewise_op.cu
@@ -18,12 +18,12 @@
 #include
 #include
 // #include
+#include "../linalg/matrix_vector_op.cuh"
+#include "../test_utils.h"
 #include
 #include
 #include
 #include
-#include "../linalg/matrix_vector_op.cuh"
-#include "../test_utils.h"

 namespace raft {
 namespace matrix {

From f4099a23809f0ad4c636483565f2ba8077edf36a Mon Sep 17 00:00:00 2001
From: achirkin
Date: Fri, 17 Dec 2021 12:11:26 +0100
Subject: [PATCH 14/17] Use a double-buffered-style shared memory in loadVec

---
 .../raft/matrix/detail/linewise_op.cuh | 36 ++++++++++++-------
 1 file changed, 23 insertions(+), 13 deletions(-)

diff --git a/cpp/include/raft/matrix/detail/linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh
index f0b7ff775b..8651ed95a3 100644
--- a/cpp/include/raft/matrix/detail/linewise_op.cuh
+++ b/cpp/include/raft/matrix/detail/linewise_op.cuh
@@ -149,16 +149,17 @@ struct Linewise {
    * of a vector. Most of the time this is not aligned, so we load it thread-striped
    * within a block and then use the shared memory to get a contiguous chunk.
    *
+   * @param [in] shm a shared memory region for rearranging the data among threads
    * @param [in] p pointer to a vector
    * @param [in] blockOffset the offset of the current block into a vector.
    * @param [in] rowLen the length of a vector.
    * @return a contiguous chunk of a vector, suitable for `vectorRows`.
    */
-  static __device__ __forceinline__ Vec loadVec(const Type* p,
+  static __device__ __forceinline__ Vec loadVec(Type* shm,
+                                                const Type* p,
                                                 const IdxType blockOffset,
                                                 const IdxType rowLen) noexcept
   {
-    __shared__ alignas(sizeof(Type) * VecElems) Type shm[VecElems * BlockSize];
     IdxType j = blockOffset + threadIdx.x;
 #pragma unroll VecElems
     for (int k = threadIdx.x; k < VecElems * BlockSize; k += BlockSize, j += BlockSize) {
@@ -309,12 +310,17 @@ __global__ void __launch_bounds__(BlockSize)
                                   Vecs... vecs)
 {
   typedef Linewise L;
+  constexpr uint workSize = L::VecElems * BlockSize;
+  uint workOffset         = workSize;
+  __shared__ alignas(sizeof(Type) * L::VecElems)
+    Type shm[workSize * ((sizeof...(Vecs)) > 1 ? 2 : 1)];
   const IdxType blockOffset = (arrOffset + BlockSize * L::VecElems * blockIdx.x) % rowLen;
-  return L::vectorRows(reinterpret_cast(out),
-                       reinterpret_cast(in),
-                       L::AlignElems::div(len),
-                       op,
-                       L::loadVec(vecs, blockOffset, rowLen)...);
+  return L::vectorRows(
+    reinterpret_cast(out),
+    reinterpret_cast(in),
+    L::AlignElems::div(len),
+    op,
+    (workOffset ^= workSize, L::loadVec(shm + workOffset, vecs, blockOffset, rowLen))...);
 }

 /**
@@ -346,6 +352,9 @@ __global__ void __launch_bounds__(MaxOffset, 2)
                                   Vecs... vecs)
 {
   // Note, L::VecElems == 1
+  constexpr uint workSize = MaxOffset;
+  uint workOffset         = workSize;
+  __shared__ Type shm[workSize * ((sizeof...(Vecs)) > 1 ? 2 : 1)];
   typedef Linewise L;
   if (blockIdx.x == 0) {
     // first block: offset = 0, length = arrOffset
@@ -353,15 +362,16 @@ __global__ void __launch_bounds__(MaxOffset, 2)
                   reinterpret_cast(in),
                   arrOffset,
                   op,
-                  L::loadVec(vecs, 0, rowLen)...);
+                  (workOffset ^= workSize, L::loadVec(shm + workOffset, vecs, 0, rowLen))...);
   } else {
     // second block: offset = arrTail, length = len - arrTail
     // NB: I subtract MaxOffset (= blockDim.x) to get the correct indexing for block 1
-    L::vectorRows(reinterpret_cast(out + arrTail - MaxOffset),
-                  reinterpret_cast(in + arrTail - MaxOffset),
-                  len - arrTail + MaxOffset,
-                  op,
-                  L::loadVec(vecs, arrTail % rowLen, rowLen)...);
+    L::vectorRows(
+      reinterpret_cast(out + arrTail - MaxOffset),
+      reinterpret_cast(in + arrTail - MaxOffset),
+      len - arrTail + MaxOffset,
+      op,
+      (workOffset ^= workSize, L::loadVec(shm + workOffset, vecs, arrTail % rowLen, rowLen))...);
   }
 }
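The `workOffset ^= workSize` expression introduced by this patch toggles between the two halves of the shared-memory buffer, so that the `loadVec` calls expanded from the parameter pack do not overwrite each other's staging area. A host-side sketch of just the toggle arithmetic, with an illustrative buffer size:

```cpp
// Sketch of the two-buffer offset toggle used in the kernels above:
// each evaluation flips workOffset between 0 and workSize, so two
// in-flight loadVec calls always receive disjoint shm regions.
#include <cstdio>

int main() {
  const unsigned workSize = 256;  // stands in for VecElems * BlockSize
  unsigned workOffset     = workSize;
  for (int call = 0; call < 4; call++) {
    workOffset ^= workSize;  // yields 0, workSize, 0, workSize, ...
    printf("call %d uses shm[%u..%u)\n", call, workOffset, workOffset + workSize);
  }
  return 0;
}
```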
From 41f6475311c0bfb9f53683019b90b429f15316d1 Mon Sep 17 00:00:00 2001
From: achirkin
Date: Fri, 17 Dec 2021 12:12:31 +0100
Subject: [PATCH 15/17] Adapt to changes in raft api

---
 cpp/test/matrix/linewise_op.cu | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu
index 77b5ea3a31..af40b88ec0 100644
--- a/cpp/test/matrix/linewise_op.cu
+++ b/cpp/test/matrix/linewise_op.cu
@@ -50,7 +50,7 @@ struct LinewiseTest : public ::testing::TestWithParam
 ::GetParam())),
       handle(),
-      stream(handle.get_stream_view())
+      stream(handle.get_stream())
   {
   }
@@ -152,7 +152,8 @@ struct LinewiseTest : public ::testing::TestWithParam
 (params.tolerance))
       << " " << (alongRows ? "alongRows" : "acrossRows")
       << " with one vec; lineLen: " << lineLen << "; nLines " << nLines;
@@ -164,7 +165,7 @@ struct LinewiseTest : public ::testing::TestWithParam
 (params.tolerance))
       << " " << (alongRows ? "alongRows" : "acrossRows")
       << " with two vecs; lineLen: " << lineLen << "; nLines " << nLines;
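A minimal sketch of the API change this patch adapts to: `handle_t::get_stream()` now hands back a plain `cudaStream_t` where the test previously used `get_stream_view()`. The header path and the helper name are assumptions for illustration, matching raft of this era.

```cpp
// Hedged sketch of the updated handle/stream usage (illustrative only).
#include <raft/handle.hpp>

void runOnHandleStream()  // hypothetical helper
{
  raft::handle_t handle;
  cudaStream_t stream = handle.get_stream();  // was: handle.get_stream_view()
  // ... enqueue kernels / async copies on `stream` ...
  cudaStreamSynchronize(stream);
}
```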
From eec616a20325e85b6a232dc7b3a60e0a5ed0bfea Mon Sep 17 00:00:00 2001
From: achirkin
Date: Fri, 17 Dec 2021 17:00:03 +0100
Subject: [PATCH 16/17] Tested NVTX ranges working fine

---
 cpp/test/matrix/linewise_op.cu | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu
index af40b88ec0..1cd00b8adc 100644
--- a/cpp/test/matrix/linewise_op.cu
+++ b/cpp/test/matrix/linewise_op.cu
@@ -14,12 +14,12 @@
  * limitations under the License.
  */

+#include "../linalg/matrix_vector_op.cuh"
+#include "../test_utils.h"
 #include
 #include
+#include
 #include
-// #include
-#include "../linalg/matrix_vector_op.cuh"
-#include "../test_utils.h"
 #include
 #include
 #include
@@ -141,14 +141,14 @@ struct LinewiseTest : public ::testing::TestWithParam
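The scoped NVTX ranges this patch exercises can be sketched against the raw nvToolsExt C API, as a self-contained stand-in for raft's `common::nvtx` wrappers (link with `-lnvToolsExt`; the wrapper and function names here are illustrative):

```cpp
// Minimal RAII sketch of a scoped NVTX range, visible in Nsight Systems.
#include <nvToolsExt.h>

struct NvtxRange {
  explicit NvtxRange(const char* name) { nvtxRangePushA(name); }  // open range
  ~NvtxRange() { nvtxRangePop(); }                                // close on scope exit
};

void runLinewiseBenchmark()  // hypothetical caller
{
  NvtxRange scope("linewiseOp-benchmark");
  // ... launch the matrix linewise kernels here ...
}
```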
Date: Tue, 21 Dec 2021 08:43:59 +0100
Subject: [PATCH 17/17] Removed/explained some magic constants and refactored a bit

---
 .../raft/matrix/detail/linewise_op.cuh | 36 ++++++++++++-----
 cpp/include/raft/matrix/matrix.hpp     |  6 ++++
 2 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/cpp/include/raft/matrix/detail/linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh
index 8651ed95a3..63fa872f9d 100644
--- a/cpp/include/raft/matrix/detail/linewise_op.cuh
+++ b/cpp/include/raft/matrix/detail/linewise_op.cuh
@@ -375,6 +375,26 @@ __global__ void __launch_bounds__(MaxOffset, 2)
   }
 }

+/** Fully occupy GPU this many times for better work balancing. */
+static inline constexpr uint OptimalSmOccupancy = 16;
+
+/**
+ * Calculate the grid size to be `OptimalSmOccupancy * FullyOccupiedGPU`, where `FullyOccupiedGPU`
+ * is the maximum number of blocks fitting in all available SMs.
+ *
+ * @tparam BlockSize blockDim of the kernel.
+ * @return OptimalSmOccupancy * FullyOccupiedGPU
+ */
+template
+inline uint getOptimalGridSize()
+{
+  int devId, smCount, maxBlockSize;
+  RAFT_CUDA_TRY(cudaGetDevice(&devId));
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, devId));
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&maxBlockSize, cudaDevAttrMaxThreadsPerBlock, devId));
+  return OptimalSmOccupancy * static_cast(smCount * maxBlockSize / BlockSize);
+}
+
 template
 0) {
     constexpr dim3 bs(BlockSize, 1, 1);
     // Minimum size of the grid to make the device well occupied
-    const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize);
+    const uint occupy = getOptimalGridSize();
     // does not make sense to have more blocks than this
     const uint maxBlocks = raft::ceildiv(uint(alignedLen), bs.x * VecElems);
     const dim3 gs(min(maxBlocks, occupy), 1, 1);
@@ -410,15 +430,15 @@ void matrixLinewiseVecCols(Type* out,
       raft::ceildiv(alignedLen, gs.x * VecElems * BlockSize) * VecElems;
     matrixLinewiseVecColsMainKernel
       <<>>(out, in, alignedOff, rowLen, alignedLen, elemsPerThread, op, vecs...);
-    CUDA_CHECK(cudaPeekAtLastError());
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
   }
   if (alignedLen < totalLen) {
     // should be not smaller than the warp size for better branching
-    constexpr std::size_t MaxOffset = std::max(std::size_t(32), VecBytes);
+    constexpr std::size_t MaxOffset = std::max(std::size_t(raft::WarpSize), VecBytes);
     matrixLinewiseVecColsTailKernel
       <<>>(out, in, alignedOff, alignedEnd, rowLen, totalLen, op, vecs...);
-    CUDA_CHECK(cudaPeekAtLastError());
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
   }
 }
@@ -461,7 +481,7 @@ void matrixLinewiseVecRows(Type* out,
    */
   const uint expected_grid_size = rowLen / raft::gcd(block_work_size, uint(rowLen));
   // Minimum size of the grid to make the device well occupied
-  const uint occupy = raft::getMultiProcessorCount() * (16384 / BlockSize);
+  const uint occupy = getOptimalGridSize();
   const dim3 gs(min(
                   // does not make sense to have more blocks than this
                   raft::ceildiv(uint(totalLen), block_work_size),
                   // increase the grid size to be not less than `occupy` while
                   // still being a multiple of `expected_grid_size`
                   raft::ceildiv(occupy, expected_grid_size) * expected_grid_size),
                 1, 1);
@@ -474,15 +494,15 @@ void matrixLinewiseVecRows(Type* out,
     matrixLinewiseVecRowsMainKernel
       <<>>(out + alignedOff, alignedStart, alignedOff, rowLen, alignedLen, op, vecs...);
-    CUDA_CHECK(cudaPeekAtLastError());
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
   }
   if (alignedLen < totalLen) {
     // should be not smaller than the warp size for better branching
-    constexpr std::size_t MaxOffset = std::max(std::size_t(32), VecBytes);
+    constexpr std::size_t MaxOffset = std::max(std::size_t(raft::WarpSize), VecBytes);
     matrixLinewiseVecRowsTailKernel
       <<>>(out, in, alignedOff, alignedEnd, rowLen, totalLen, op, vecs...);
-    CUDA_CHECK(cudaPeekAtLastError());
+    RAFT_CUDA_TRY(cudaPeekAtLastError());
   }
 }
diff --git a/cpp/include/raft/matrix/matrix.hpp b/cpp/include/raft/matrix/matrix.hpp
index 07eba8c1a3..bf2dc963ad 100644
--- a/cpp/include/raft/matrix/matrix.hpp
+++ b/cpp/include/raft/matrix/matrix.hpp
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -325,6 +326,11 @@ void linewiseOp(m_t* out,
                 cudaStream_t stream,
                 Vecs... vecs)
 {
+  common::nvtx::range fun_scope("linewiseOp-%c-%zu (%zu, %zu)",
+                                alongLines ? 'l' : 'x',
+                                sizeof...(Vecs),
+                                size_t(lineLen),
+                                size_t(nLines));
   detail::MatrixLinewiseOp<16, 256>::run(
     out, in, lineLen, nLines, alongLines, op, stream, vecs...);
 }
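To show how the patched `linewiseOp` entry point is meant to be called, here is a hedged usage sketch: the parameter order is inferred from the `detail::MatrixLinewiseOp<16, 256>::run` call above, `centerColumns` is a hypothetical helper, and the device lambda requires nvcc's `--extended-lambda` flag.

```cpp
// Usage sketch (assumptions noted above): subtract a per-column mean
// from every row of a row-major nRows x nCols matrix.
#include <raft/matrix/matrix.hpp>

void centerColumns(float* out, const float* in, const float* colMeans,
                   std::size_t nCols, std::size_t nRows, cudaStream_t stream)
{
  // alongLines = true: `colMeans` is indexed along each line (row) of length nCols,
  // so element (i, j) becomes in[i * nCols + j] - colMeans[j].
  raft::matrix::linewiseOp(
    out, in, nCols, nRows, /* alongLines = */ true,
    [] __device__(float x, float mu) { return x - mu; },
    stream, colMeans);
}
```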