Replace matrixVectorOp implementation with matrixLinewiseOp
achirkin committed Nov 25, 2021
1 parent bb30615 · commit 5c63642
Showing 1 changed file with 5 additions and 194 deletions.
cpp/include/raft/linalg/matrix_vector_op.cuh (5 additions, 194 deletions)
@@ -16,84 +16,11 @@

#pragma once

#include <raft/cuda_utils.cuh>
#include <raft/linalg/matrix_linewise_op.cuh>
#include <raft/pow2_utils.cuh>
#include <raft/vectorized.cuh>

namespace raft {
namespace linalg {

namespace {
template <size_t VecBytes>
struct AlignedAccess {
  template <typename T>
  static inline bool test(const T* matrix, size_t strideBytes)
  {
    return Pow2<VecBytes>::isAligned(matrix) && Pow2<VecBytes>::isAligned(strideBytes) &&
           Pow2<sizeof(T)>::isAligned(VecBytes);
  }
};
}; // namespace

template <typename Type, int veclen_, typename Lambda, typename IdxType>
__global__ void matrixVectorOpKernel(Type* out,
                                     const Type* matrix,
                                     const Type* vector,
                                     IdxType D,
                                     IdxType N,
                                     bool rowMajor,
                                     bool bcastAlongRows,
                                     Lambda op)
{
  typedef TxN_t<Type, veclen_> VecType;
  IdxType len = N * D;
  IdxType idx = threadIdx.x;
  idx += (IdxType)blockIdx.x * (IdxType)blockDim.x;
  idx *= VecType::Ratio;
  if (idx >= len) return;
  IdxType vIdx;
  VecType mat, vec;
  ///@todo: yikes! use fast-int-div here.
  ///@todo: shared mem for vector could help with perf
  if (rowMajor && bcastAlongRows) {
    vIdx = idx % D;
    vec.load(vector, vIdx);
  } else if (!rowMajor && !bcastAlongRows) {
    vIdx = idx % N;
    vec.load(vector, vIdx);
  } else if (rowMajor && !bcastAlongRows) {
    vIdx = idx / D;
    vec.fill(vector[vIdx]);
  } else {
    vIdx = idx / N;
    vec.fill(vector[vIdx]);
  }
  mat.load(matrix, idx);
#pragma unroll
  for (int i = 0; i < VecType::Ratio; ++i)
    mat.val.data[i] = op(mat.val.data[i], vec.val.data[i]);
  mat.store(out, idx);
}

template <typename Type, int veclen_, typename Lambda, typename IdxType, int TPB>
void matrixVectorOpImpl(Type* out,
                        const Type* matrix,
                        const Type* vec,
                        IdxType D,
                        IdxType N,
                        bool rowMajor,
                        bool bcastAlongRows,
                        Lambda op,
                        cudaStream_t stream)
{
  IdxType len   = N * D;
  IdxType nblks = raft::ceildiv(veclen_ ? len / veclen_ : veclen_, (IdxType)TPB);
  matrixVectorOpKernel<Type, veclen_, Lambda, IdxType>
    <<<nblks, TPB, 0, stream>>>(out, matrix, vec, D, N, rowMajor, bcastAlongRows, op);
  CUDA_CHECK(cudaPeekAtLastError());
}

/**
 * @brief Operations for all the columns or rows with a given vector.
 * Caution : Threads process multiple elements to speed up processing. These
@@ -129,98 +56,8 @@ void matrixVectorOp(Type* out,
                    cudaStream_t stream)
{
  IdxType stride = rowMajor ? D : N;
  // IdxType nLines = rowMajor ? N : D;
  // return matrixLinewiseOp(out, matrix, stride, nLines,
  //                         rowMajor == bcastAlongRows, op, stream, vec);

  size_t stride_bytes = stride * sizeof(Type);

  if (AlignedAccess<16>::test(matrix, stride_bytes) && AlignedAccess<16>::test(out, stride_bytes)) {
    matrixVectorOpImpl<Type, 16 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
  } else if (AlignedAccess<8>::test(matrix, stride_bytes) &&
             AlignedAccess<8>::test(out, stride_bytes)) {
    matrixVectorOpImpl<Type, 8 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
  } else if (AlignedAccess<4>::test(matrix, stride_bytes) &&
             AlignedAccess<4>::test(out, stride_bytes)) {
    matrixVectorOpImpl<Type, 4 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
  } else if (AlignedAccess<2>::test(matrix, stride_bytes) &&
             AlignedAccess<2>::test(out, stride_bytes)) {
    matrixVectorOpImpl<Type, 2 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
  } else if (AlignedAccess<1>::test(matrix, stride_bytes) &&
             AlignedAccess<1>::test(out, stride_bytes)) {
    matrixVectorOpImpl<Type, 1 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
  } else {
    matrixVectorOpImpl<Type, 1, Lambda, IdxType, TPB>(
      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
  }
}

///@todo: come up with a cleaner interface to support these cases in future!

template <typename Type, int veclen_, typename Lambda, typename IdxType>
__global__ void matrixVectorOpKernel(Type* out,
                                     const Type* matrix,
                                     const Type* vector1,
                                     const Type* vector2,
                                     IdxType D,
                                     IdxType N,
                                     bool rowMajor,
                                     bool bcastAlongRows,
                                     Lambda op)
{
  typedef TxN_t<Type, veclen_> VecType;
  IdxType len = N * D;
  IdxType idx = (threadIdx.x + (blockIdx.x * blockDim.x)) * VecType::Ratio;
  if (idx >= len) return;
  IdxType vIdx;
  VecType mat, vec1, vec2;
  ///@todo: yikes! use fast-int-div here.
  ///@todo: shared mem for vector could help with perf
  if (rowMajor && bcastAlongRows) {
    vIdx = idx % D;
    vec1.load(vector1, vIdx);
    vec2.load(vector2, vIdx);
  } else if (!rowMajor && !bcastAlongRows) {
    vIdx = idx % N;
    vec1.load(vector1, vIdx);
    vec2.load(vector2, vIdx);
  } else if (rowMajor && !bcastAlongRows) {
    vIdx = idx / D;
    vec1.fill(vector1[vIdx]);
    vec2.fill(vector2[vIdx]);
  } else {
    vIdx = idx / N;
    vec1.fill(vector1[vIdx]);
    vec2.fill(vector2[vIdx]);
  }
  mat.load(matrix, idx);
#pragma unroll
  for (int i = 0; i < VecType::Ratio; ++i)
    mat.val.data[i] = op(mat.val.data[i], vec1.val.data[i], vec2.val.data[i]);
  mat.store(out, idx);
}

template <typename Type, int veclen_, typename Lambda, typename IdxType, int TPB>
void matrixVectorOpImpl(Type* out,
                        const Type* matrix,
                        const Type* vec1,
                        const Type* vec2,
                        IdxType D,
                        IdxType N,
                        bool rowMajor,
                        bool bcastAlongRows,
                        Lambda op,
                        cudaStream_t stream)
{
  IdxType nblks = raft::ceildiv(N * D, (IdxType)TPB);
  matrixVectorOpKernel<Type, veclen_, Lambda, IdxType>
    <<<nblks, TPB, 0, stream>>>(out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op);
  CUDA_CHECK(cudaPeekAtLastError());
  IdxType nLines = rowMajor ? N : D;
  return matrixLinewiseOp(out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec);
}

/**
@@ -260,35 +97,9 @@ void matrixVectorOp(Type* out,
                    cudaStream_t stream)
{
  IdxType stride = rowMajor ? D : N;
  // IdxType nLines = rowMajor ? N : D;
  // return matrixLinewiseOp(out, matrix, stride, nLines,
  //                         rowMajor == bcastAlongRows, op, stream, vec1, vec2);

  size_t stride_bytes = stride * sizeof(Type);

  if (AlignedAccess<16>::test(matrix, stride_bytes) && AlignedAccess<16>::test(out, stride_bytes)) {
    matrixVectorOpImpl<Type, 16 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
  } else if (AlignedAccess<8>::test(matrix, stride_bytes) &&
             AlignedAccess<8>::test(out, stride_bytes)) {
    matrixVectorOpImpl<Type, 8 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
  } else if (AlignedAccess<4>::test(matrix, stride_bytes) &&
             AlignedAccess<4>::test(out, stride_bytes)) {
    matrixVectorOpImpl<Type, 4 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
  } else if (AlignedAccess<2>::test(matrix, stride_bytes) &&
             AlignedAccess<2>::test(out, stride_bytes)) {
    matrixVectorOpImpl<Type, 2 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
  } else if (AlignedAccess<1>::test(matrix, stride_bytes) &&
             AlignedAccess<1>::test(out, stride_bytes)) {
    matrixVectorOpImpl<Type, 1 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
  } else {
    matrixVectorOpImpl<Type, 1, Lambda, IdxType, TPB>(
      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
  }
  IdxType nLines = rowMajor ? N : D;
  return matrixLinewiseOp(
    out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec1, vec2);
}

}; // end namespace linalg
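
For context, here is a minimal usage sketch of the public one-vector matrixVectorOp overload that this commit rewires onto matrixLinewiseOp. It is not part of the diff: the argument order (out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream) is inferred from the removed matrixVectorOpImpl calls above, the helper name add_row_bias is hypothetical, and the device lambda assumes nvcc's extended-lambda support is enabled.

// Hedged sketch (not from this commit): broadcast-add a length-D bias vector to
// every row of an N x D row-major matrix via raft::linalg::matrixVectorOp.
#include <raft/linalg/matrix_vector_op.cuh>

void add_row_bias(float* out,           // [N x D] result, row-major
                  const float* matrix,  // [N x D] input, row-major
                  const float* bias,    // [D] vector, one entry per column
                  int D,
                  int N,
                  cudaStream_t stream)
{
  raft::linalg::matrixVectorOp(
    out, matrix, bias, D, N,
    true,  // rowMajor
    true,  // bcastAlongRows: broadcast one vector element per column
    [] __device__(float mat_el, float vec_el) { return mat_el + vec_el; },
    stream);
}

The two-vector overload works the same way, except the lambda receives one element from each of vec1 and vec2 alongside the matrix element.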
