Commit

Version 2 (all faster, but not fully tested)
achirkin committed Nov 25, 2021
1 parent 635c79b commit ed27367
Showing 2 changed files with 116 additions and 118 deletions.
198 changes: 114 additions & 84 deletions cpp/include/raft/linalg/matrix_linewise_op.cuh
@@ -32,39 +32,49 @@ struct Linewise {
typedef raft::TxN_t<Type, VecElems> Vec;
typedef raft::Pow2<VecBytes> AlignBytes;
typedef raft::Pow2<VecElems> AlignElems;
typedef raft::Pow2<raft::WarpSize> AlignWarp;

template <typename Lambda, typename... Args>
static __device__ __forceinline__ void vectorCols(Type* out, const Type* in,
const IdxType rowLen,
Lambda op, Args... args) {
const IdxType alignedStart = IdxType(AlignBytes::roundUp(in) - in);
const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + rowLen) - in);
IdxType i0 = threadIdx.x + blockIdx.y * blockDim.x;

// First unaligned pieces
if (i0 < alignedStart) out[i0] = op(in[i0], args...);

// aligned core chunk
{
Vec data;
const IdxType d = blockDim.x * gridDim.y * VecElems;
for (IdxType i = alignedStart + i0 * VecElems; i < alignedEnd; i += d) {
data.load(in, i);
template <typename Lambda, typename... Vecs>
static __device__ __forceinline__ void vectorCols(
typename Vec::io_t* out, const typename Vec::io_t* in,
const typename Vec::io_t* in_end, const IdxType rowLen, IdxType rowDiv,
IdxType rowMod, Lambda op, Vecs... vecs) noexcept {
constexpr IdxType warpPad = (AlignWarp::Value - 1) * VecElems;
Type args[sizeof...(Vecs)];
Vec v, w;
bool update = true;
for (; in < in_end;
in += AlignWarp::Value, out += AlignWarp::Value, rowMod += warpPad) {
v.val.internal = __ldcv(in);
while (rowMod >= rowLen) {
rowMod -= rowLen;
rowDiv++;
update = true;
}
if (update) {
int l = 0;
((args[l] = vecs[rowDiv], l++), ...);
update = false;
}
#pragma unroll VecElems
for (int k = 0; k < VecElems; k++)
data.val.data[k] = op(data.val.data[k], args...);
data.store(out, i);
for (int k = 0; k < VecElems; k++, rowMod++) {
if (rowMod == rowLen) {
rowMod = 0;
rowDiv++;
int l = 0;
((args[l] = vecs[rowDiv], l++), ...);
}
int l = 0;
w.val.data[k] = op(v.val.data[k], (std::ignore = vecs, args[l++])...);
}
*out = w.val.internal;
}
// last unaligned pieces
i0 += alignedEnd;
if (i0 < rowLen) out[i0] = op(in[i0], args...);
}
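
As a reading aid: `rowDiv` and `rowMod` track the current element's row index and in-row offset, so the cached `args` are refreshed only when the warp-strided walk crosses a row boundary (`warpPad` accounts for the 31 vectors owned by the other lanes between two consecutive iterations of a lane). A naive scalar kernel with the same observable semantics (a hypothetical sketch for orientation, not part of this commit) would be:

// Hypothetical reference (single-vector case; names are illustrative, not from the commit):
// vectorCols applies one vector element per matrix row.
template <typename Type, typename IdxType, typename Lambda>
__global__ void naiveVecColsRef(Type* out, const Type* in, const IdxType rowLen,
                                const IdxType totalLen, Lambda op, const Type* vec) {
  const IdxType stride = IdxType(gridDim.x) * blockDim.x;
  for (IdxType i = IdxType(blockIdx.x) * blockDim.x + threadIdx.x; i < totalLen; i += stride)
    out[i] = op(in[i], vec[i / rowLen]);  // rowDiv == i / rowLen, rowMod == i % rowLen
}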

template <typename Lambda, typename... Args>
static __device__ __forceinline__ void vectorRows(
typename Vec::io_t* out, const typename Vec::io_t* in, const IdxType len,
Lambda op, Args... args) {
Lambda op, Args... args) noexcept {
Vec v;
const IdxType d = BlockSize * gridDim.x;
for (IdxType i = threadIdx.x + blockIdx.x * BlockSize; i < len; i += d) {
@@ -78,8 +88,7 @@

static __device__ __forceinline__ Vec loadVec(const Type* p,
const IdxType blockOffset,
const IdxType rowLen) {
// 11.096 ms / 34 Regs
const IdxType rowLen) noexcept {
__shared__ alignas(sizeof(Type) * VecElems) Type shm[VecElems * BlockSize];
IdxType j = blockOffset + threadIdx.x;
#pragma unroll VecElems
@@ -95,40 +104,55 @@
reinterpret_cast<typename Vec::io_t*>(shm)[threadIdx.x];
return out;
}

// // 16.686 ms / 66 Regs
// typedef raft::Pow2<raft::WarpSize> AlignWarp;
// int l = AlignWarp::mod(threadIdx.x);
// int d = l >> (AlignWarp::Log2 - AlignElems::Log2);
// Vec out;
// #pragma unroll VecElems
// for (int k = VecElems, j = blockOffset + (threadIdx.x - l) * VecElems + l;
// k > 0; k--, j += AlignWarp::Value) {
// while (j >= rowLen) j -= rowLen;
// const int kd = AlignElems::mod(k + l + d);
// out.val.data[kd] = __ldg(p + j);
// }
// l = AlignWarp::mod(l * VecElems);
// #pragma unroll VecElems
// for (int k = d; k < VecElems + d; k++) {
// const int kd = AlignElems::mod(k);
// out.val.data[kd] = __shfl_sync(0xffffffffu, out.val.data[kd], kd + l);
// }
// return out;
}
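
// A hedged reading of loadVec (part of its body is elided in this diff view):
// each thread t of the block ends up holding VecElems consecutive elements of the
// rowLen-periodic vector, i.e. out.val.data[k] == p[(blockOffset + t * VecElems + k) % rowLen],
// staged through shared memory so that the final read is a single aligned io_t access.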
};

template <typename Type, typename IdxType, std::size_t VecBytes, int BlockSize,
typename Lambda, typename... Vecs>
__global__ void __launch_bounds__(BlockSize)
matrixLinewiseVecColsKernel(Type* out, const Type* in, const IdxType rowLen,
const IdxType nRows, Lambda op, Vecs... vecs) {
const IdxType j = threadIdx.y + blockIdx.x * blockDim.y;
if (j < nRows) {
const IdxType shift = rowLen * j;
Linewise<Type, IdxType, VecBytes, BlockSize>::vectorCols(
out + shift, in + shift, rowLen, op, vecs[j]...);
matrixLinewiseVecColsMainKernel(Type* out, const Type* in,
const IdxType arrOffset, const IdxType rowLen,
const IdxType len,
const IdxType elemsPerThread, Lambda op,
Vecs... vecs) {
typedef Linewise<Type, IdxType, VecBytes, BlockSize> L;

IdxType t = L::AlignWarp::mod(threadIdx.x);
t = arrOffset + elemsPerThread * (blockIdx.x * BlockSize + threadIdx.x - t) +
t * L::VecElems;

return L::vectorCols(
reinterpret_cast<typename L::Vec::io_t*>(out + t),
reinterpret_cast<const typename L::Vec::io_t*>(in + t),
reinterpret_cast<const typename L::Vec::io_t*>(
in + min(t + elemsPerThread * L::AlignWarp::Value, len)),
rowLen, t / rowLen, t % rowLen, op, vecs...);
}
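
// How the index math above distributes work (a sketch; `lane` is an illustrative name):
//   lane     = threadIdx.x % warpSize;
//   warpBase = elemsPerThread * (blockIdx.x * BlockSize + threadIdx.x - lane);
//   t        = arrOffset + warpBase + lane * VecElems;
// elemsPerThread is a multiple of VecElems (see the host launcher below), so each warp
// owns a contiguous span of elemsPerThread * warpSize scalars, consecutive lanes start
// at consecutive VecElems-wide vectors, and every lane advances warpSize vectors per
// vectorCols iteration until it reaches the clamped in_end.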

template <typename Type, typename IdxType, std::size_t MaxOffset,
typename Lambda, typename... Vecs>
__global__ void __launch_bounds__(MaxOffset, 2)
matrixLinewiseVecColsTailKernel(Type* out, const Type* in,
const IdxType arrOffset,
const IdxType arrTail, const IdxType rowLen,
const IdxType len, Lambda op, Vecs... vecs) {
typedef Linewise<Type, IdxType, sizeof(Type), MaxOffset> L;
IdxType threadOffset, elemsPerWarp;
if (blockIdx.x == 0) {
threadOffset = threadIdx.x;
elemsPerWarp = threadOffset < arrOffset;
} else {
threadOffset = arrTail + threadIdx.x;
elemsPerWarp = threadOffset < len;
}
const IdxType rowDiv = threadOffset / rowLen;
const IdxType rowMod = threadOffset % rowLen;
return L::vectorCols(
reinterpret_cast<typename L::Vec::io_t*>(out + threadOffset),
reinterpret_cast<const typename L::Vec::io_t*>(in + threadOffset),
reinterpret_cast<const typename L::Vec::io_t*>(in + threadOffset +
elemsPerWarp),
rowLen, rowDiv, rowMod, op, vecs...);
}
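
// A hedged sketch of the head/tail split: the launcher below starts this kernel with
// exactly two blocks of MaxOffset threads. Block 0 covers the misaligned head
// in[0 .. arrOffset), block 1 the misaligned tail in[arrTail .. len); elemsPerWarp is
// 0 or 1 per thread, and VecBytes == sizeof(Type) here, so every in-range thread
// processes exactly one scalar.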

template <typename Type, typename IdxType, std::size_t VecBytes, int BlockSize,
@@ -140,10 +164,10 @@ __global__ void __launch_bounds__(BlockSize)
typedef Linewise<Type, IdxType, VecBytes, BlockSize> L;
const IdxType blockOffset =
(arrOffset + BlockSize * L::VecElems * blockIdx.x) % rowLen;
L::vectorRows(reinterpret_cast<typename L::Vec::io_t*>(out),
reinterpret_cast<const typename L::Vec::io_t*>(in),
L::AlignElems::div(len), op,
L::loadVec(vecs, blockOffset, rowLen)...);
return L::vectorRows(reinterpret_cast<typename L::Vec::io_t*>(out),
reinterpret_cast<const typename L::Vec::io_t*>(in),
L::AlignElems::div(len), op,
L::loadVec(vecs, blockOffset, rowLen)...);
}
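
For the row-wise case the vector is rowLen-periodic along the flattened matrix, and each block loads its phase-shifted copy once via loadVec. A naive scalar kernel with the same observable semantics (a hypothetical sketch, not part of this commit):

// Hypothetical reference (single-vector case; names are illustrative, not from the commit):
// vectorRows applies the vector elementwise along each row.
template <typename Type, typename IdxType, typename Lambda>
__global__ void naiveVecRowsRef(Type* out, const Type* in, const IdxType rowLen,
                                const IdxType totalLen, Lambda op, const Type* vec) {
  const IdxType stride = IdxType(gridDim.x) * blockDim.x;
  for (IdxType i = IdxType(blockIdx.x) * blockDim.x + threadIdx.x; i < totalLen; i += stride)
    out[i] = op(in[i], vec[i % rowLen]);
}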

template <typename Type, typename IdxType, std::size_t MaxOffset,
@@ -153,9 +177,6 @@ __global__ void __launch_bounds__(MaxOffset, 2)
const IdxType arrOffset,
const IdxType arrTail, const IdxType rowLen,
const IdxType len, Lambda op, Vecs... vecs) {
constexpr std::size_t MaxOffsetMod = MaxOffset - 1;
static_assert((MaxOffset & MaxOffsetMod) == 0,
"MaxOffset must be power of two.");
typedef Linewise<Type, IdxType, sizeof(Type), MaxOffset> L;
if (blockIdx.x == 0)
L::vectorRows(reinterpret_cast<typename L::Vec::io_t*>(out),
@@ -174,27 +195,36 @@ template <typename Type, typename IdxType, std::size_t VecBytes,
void matrixLinewiseVecCols(Type* out, const Type* in, const IdxType rowLen,
const IdxType nRows, Lambda op, cudaStream_t stream,
Vecs... vecs) {
typedef raft::Pow2<VecBytes> AlignBytes;
constexpr std::size_t VecElems = VecBytes / sizeof(Type);
IdxType bsx = 32;
IdxType bsy = 8;
const IdxType totalLen = rowLen * nRows;
const Type* alignedStart = AlignBytes::roundUp(in);
const IdxType alignedOff = IdxType(alignedStart - in);
const IdxType alignedEnd = IdxType(AlignBytes::roundDown(in + totalLen) - in);
const IdxType alignedLen = alignedEnd - alignedOff;
// blockSize
constexpr int BlockSize = 256;
while (bsy > nRows * 2) {
bsy >>= 1;
bsx <<= 1;
}
IdxType gsy = raft::ceildiv<IdxType>(nRows, bsy);
IdxType gsx =
min(raft::ceildiv<IdxType>(raft::getMultiProcessorCount() * 64, gsy),
raft::ceildiv<IdxType>(rowLen, bsx * VecElems));
// NB: gridSize.x and gridSize.y are swapped, because gsx is bounded by a small number,
// but gsy can grow uncontrollably with the number of rows.
// (there is a tight limit on the max grid size in the `y` direction).
dim3 bs(bsx, bsy, 1);
dim3 gs(gsy, gsx, 1);
matrixLinewiseVecColsKernel<Type, IdxType, VecBytes, BlockSize, Lambda,
Vecs...>
<<<gs, bs, 0, stream>>>(out, in, rowLen, nRows, op, vecs...);
constexpr dim3 bs(BlockSize, 1, 1);
// Minimum size of the grid to make the device well occupied
const uint occupy = raft::getMultiProcessorCount() * 64;
// does not make sense to have more blocks than this
const uint maxBlocks = raft::ceildiv<uint>(uint(alignedLen), bs.x * VecElems);
const dim3 gs(min(maxBlocks, occupy), 1, 1);

const IdxType elemsPerThread =
raft::ceildiv<IdxType>(alignedLen, gs.x * VecElems * BlockSize) * VecElems;
matrixLinewiseVecColsMainKernel<Type, IdxType, VecBytes, BlockSize, Lambda,
Vecs...><<<gs, bs, 0, stream>>>(
out, in, alignedOff, rowLen, alignedLen, elemsPerThread, op, vecs...);
CUDA_CHECK(cudaPeekAtLastError());
if (alignedLen < totalLen) {
// should be no smaller than the warp size for better branching
constexpr std::size_t MaxOffset = std::max(std::size_t(32), VecBytes);
matrixLinewiseVecColsTailKernel<Type, IdxType, MaxOffset, Lambda, Vecs...>
<<<dim3(2, 1, 1), dim3(MaxOffset, 1, 1), 0, stream>>>(
out, in, alignedOff, alignedEnd, rowLen, totalLen, op, vecs...);
CUDA_CHECK(cudaPeekAtLastError());
}
}
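
A hypothetical call site, assuming a row-major matrix with rows of length rowLen and a device vector holding one value per row; the identifiers (addRowBias, bias) and the extended device lambda are illustrative assumptions, not part of the commit:

// Hypothetical usage sketch: add bias[i] to every element of row i.
// Requires nvcc --extended-lambda for the __device__ lambda.
inline void addRowBias(float* out, const float* in, int rowLen, int nRows,
                       const float* bias, cudaStream_t stream) {
  auto op = [] __device__(float x, float b) { return x + b; };
  // VecBytes = 16 requests float4-wide accesses; the misaligned head and tail
  // are handled by the tail kernel above.
  matrixLinewiseVecCols<float, int, 16>(out, in, rowLen, nRows, op, stream, bias);
}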

template <typename Type, typename IdxType, std::size_t VecBytes,
@@ -214,12 +244,12 @@ void matrixLinewiseVecRows(Type* out, const Type* in, const IdxType rowLen,
(rowLen / raft::gcd(bs.x * uint(VecElems), uint(rowLen))) * VecElems;
// Minimum size of the grid to make the device well occupied
const uint occupy = raft::getMultiProcessorCount() * 64;
const dim3 gs = dim3(min(
// does not make sense to have more blocks than this
raft::ceildiv<uint>(uint(totalLen), bs.x * VecElems),
// increase the stride size if necessary
raft::ceildiv<uint>(occupy, stride) * stride),
1, 1);
const dim3 gs(min(
// does not make sense to have more blocks than this
raft::ceildiv<uint>(uint(totalLen), bs.x * VecElems),
// increase the stride size if necessary
raft::ceildiv<uint>(occupy, stride) * stride),
1, 1);
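
// A hedged note on the rounding above: blockOffset (in the main kernel) advances by
// BlockSize * VecElems modulo rowLen from one block to the next, so making gridDim.x a
// multiple of `stride` keeps gridDim.x * BlockSize * VecElems divisible by rowLen; a
// thread's grid-stride jump then preserves its phase within the row, and the vector
// loaded once by loadVec stays valid for all of its iterations. When the min picks the
// data-bound block count instead, each thread runs at most one iteration, so the phase
// never changes.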

const Type* alignedStart = AlignBytes::roundUp(in);
const IdxType alignedOff = IdxType(alignedStart - in);
36 changes: 2 additions & 34 deletions cpp/test/linalg/matrix_linewise_op.cu
@@ -180,36 +180,6 @@ struct LinewiseTest

return testing::AssertionSuccess();
}

testing::AssertionResult runTemp() {
rmm::device_uvector<T> blob = genData();
I n = 257;
I m = 420227;
auto [out, in, vec1, vec2] = assignSafePtrs(blob, n, m);

stream.synchronize();
cudaProfilerStart();
PUSH_RANGE(stream, params.useVanillaMatrixVectorOp ? "method: original"
: "method: linewise");
for (auto alongRows : ::testing::Bool()) {
PUSH_RANGE(stream, alongRows ? "alongRows" : "acrossRows");
I lineLen = alongRows ? m : n;
I nLines = alongRows ? n : m;
{
PUSH_RANGE(stream, "one vec");
runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1);
POP_RANGE(stream);
PUSH_RANGE(stream, "two vecs");
runLinewiseSum(out, in, lineLen, nLines, alongRows, vec1, vec2);
POP_RANGE(stream);
}
POP_RANGE(stream);
}
POP_RANGE(stream);
cudaProfilerStop();

return testing::AssertionSuccess();
}
};

#define TEST_IT(fun, TestClass, ElemType, IndexType) \
Expand Down Expand Up @@ -257,12 +227,10 @@ struct TenGigs {

TEST_IT(run, Megabyte, float, int);
TEST_IT(run, Megabyte, double, int);
// TEST_IT(run, Gigabyte, float, int);
// TEST_IT(run, Gigabyte, double, int);
TEST_IT(run, Gigabyte, float, int);
TEST_IT(run, Gigabyte, double, int);
TEST_IT(run, TenGigs, float, uint64_t);
TEST_IT(run, TenGigs, double, uint64_t);

TEST_IT(runTemp, Gigabyte, float, int);

} // end namespace linalg
} // end namespace raft
