diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh
index 3082d92d9f..9e38ebd167 100644
--- a/cpp/include/raft/linalg/matrix_vector_op.cuh
+++ b/cpp/include/raft/linalg/matrix_vector_op.cuh
@@ -16,84 +16,11 @@
 
 #pragma once
 
-#include <raft/cuda_utils.cuh>
 #include <raft/linalg/matrix_linewise_op.cuh>
-#include <raft/pow2_utils.cuh>
-#include <raft/vectorized.cuh>
 
 namespace raft {
 namespace linalg {
 
-namespace {
-template <size_t VecBytes>
-struct AlignedAccess {
-  template <typename T>
-  static inline bool test(const T* matrix, size_t strideBytes)
-  {
-    return Pow2<VecBytes>::isAligned(matrix) && Pow2<VecBytes>::isAligned(strideBytes) &&
-           Pow2<sizeof(T)>::isAligned(VecBytes);
-  }
-};
-};  // namespace
-
-template <typename Type, int veclen_, typename Lambda, typename IdxType>
-__global__ void matrixVectorOpKernel(Type* out,
-                                     const Type* matrix,
-                                     const Type* vector,
-                                     IdxType D,
-                                     IdxType N,
-                                     bool rowMajor,
-                                     bool bcastAlongRows,
-                                     Lambda op)
-{
-  typedef TxN_t<Type, veclen_> VecType;
-  IdxType len = N * D;
-  IdxType idx = threadIdx.x;
-  idx += (IdxType)blockIdx.x * (IdxType)blockDim.x;
-  idx *= VecType::Ratio;
-  if (idx >= len) return;
-  IdxType vIdx;
-  VecType mat, vec;
-  ///@todo: yikes! use fast-int-div here.
-  ///@todo: shared mem for vector could help with perf
-  if (rowMajor && bcastAlongRows) {
-    vIdx = idx % D;
-    vec.load(vector, vIdx);
-  } else if (!rowMajor && !bcastAlongRows) {
-    vIdx = idx % N;
-    vec.load(vector, vIdx);
-  } else if (rowMajor && !bcastAlongRows) {
-    vIdx = idx / D;
-    vec.fill(vector[vIdx]);
-  } else {
-    vIdx = idx / N;
-    vec.fill(vector[vIdx]);
-  }
-  mat.load(matrix, idx);
-#pragma unroll
-  for (int i = 0; i < VecType::Ratio; ++i)
-    mat.val.data[i] = op(mat.val.data[i], vec.val.data[i]);
-  mat.store(out, idx);
-}
-
-template <typename Type, int veclen_, typename Lambda, typename IdxType, int TPB>
-void matrixVectorOpImpl(Type* out,
-                        const Type* matrix,
-                        const Type* vec,
-                        IdxType D,
-                        IdxType N,
-                        bool rowMajor,
-                        bool bcastAlongRows,
-                        Lambda op,
-                        cudaStream_t stream)
-{
-  IdxType len   = N * D;
-  IdxType nblks = raft::ceildiv(veclen_ ? len / veclen_ : veclen_, (IdxType)TPB);
-  matrixVectorOpKernel<Type, veclen_, Lambda, IdxType>
-    <<<nblks, TPB, 0, stream>>>(out, matrix, vec, D, N, rowMajor, bcastAlongRows, op);
-  CUDA_CHECK(cudaPeekAtLastError());
-}
-
 /**
  * @brief Operations for all the columns or rows with a given vector.
  * Caution : Threads process multiple elements to speed up processing. These
@@ -129,98 +56,8 @@ void matrixVectorOp(Type* out,
                     cudaStream_t stream)
 {
   IdxType stride = rowMajor ? D : N;
-  // IdxType nLines = rowMajor ? N : D;
-  // return matrixLinewiseOp(out, matrix, stride, nLines,
-  //                         rowMajor == bcastAlongRows, op, stream, vec);
-
-  size_t stride_bytes = stride * sizeof(Type);
-
-  if (AlignedAccess<16>::test(matrix, stride_bytes) && AlignedAccess<16>::test(out, stride_bytes)) {
-    matrixVectorOpImpl<Type, 16 / sizeof(Type), Lambda, IdxType, TPB>(
-      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (AlignedAccess<8>::test(matrix, stride_bytes) &&
-             AlignedAccess<8>::test(out, stride_bytes)) {
-    matrixVectorOpImpl<Type, 8 / sizeof(Type), Lambda, IdxType, TPB>(
-      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (AlignedAccess<4>::test(matrix, stride_bytes) &&
-             AlignedAccess<4>::test(out, stride_bytes)) {
-    matrixVectorOpImpl<Type, 4 / sizeof(Type), Lambda, IdxType, TPB>(
-      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (AlignedAccess<2>::test(matrix, stride_bytes) &&
-             AlignedAccess<2>::test(out, stride_bytes)) {
-    matrixVectorOpImpl<Type, 2 / sizeof(Type), Lambda, IdxType, TPB>(
-      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (AlignedAccess<1>::test(matrix, stride_bytes) &&
-             AlignedAccess<1>::test(out, stride_bytes)) {
-    matrixVectorOpImpl<Type, 1 / sizeof(Type), Lambda, IdxType, TPB>(
-      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else {
-    matrixVectorOpImpl<Type, 1, Lambda, IdxType, TPB>(
-      out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  }
-}
-
-///@todo: come up with a cleaner interface to support these cases in future!
-
-template <typename Type, int veclen_, typename Lambda, typename IdxType>
-__global__ void matrixVectorOpKernel(Type* out,
-                                     const Type* matrix,
-                                     const Type* vector1,
-                                     const Type* vector2,
-                                     IdxType D,
-                                     IdxType N,
-                                     bool rowMajor,
-                                     bool bcastAlongRows,
-                                     Lambda op)
-{
-  typedef TxN_t<Type, veclen_> VecType;
-  IdxType len = N * D;
-  IdxType idx = (threadIdx.x + (blockIdx.x * blockDim.x)) * VecType::Ratio;
-  if (idx >= len) return;
-  IdxType vIdx;
-  VecType mat, vec1, vec2;
-  ///@todo: yikes! use fast-int-div here.
-  ///@todo: shared mem for vector could help with perf
-  if (rowMajor && bcastAlongRows) {
-    vIdx = idx % D;
-    vec1.load(vector1, vIdx);
-    vec2.load(vector2, vIdx);
-  } else if (!rowMajor && !bcastAlongRows) {
-    vIdx = idx % N;
-    vec1.load(vector1, vIdx);
-    vec2.load(vector2, vIdx);
-  } else if (rowMajor && !bcastAlongRows) {
-    vIdx = idx / D;
-    vec1.fill(vector1[vIdx]);
-    vec2.fill(vector2[vIdx]);
-  } else {
-    vIdx = idx / N;
-    vec1.fill(vector1[vIdx]);
-    vec2.fill(vector2[vIdx]);
-  }
-  mat.load(matrix, idx);
-#pragma unroll
-  for (int i = 0; i < VecType::Ratio; ++i)
-    mat.val.data[i] = op(mat.val.data[i], vec1.val.data[i], vec2.val.data[i]);
-  mat.store(out, idx);
-}
-
-template <typename Type, int veclen_, typename Lambda, typename IdxType, int TPB>
-void matrixVectorOpImpl(Type* out,
-                        const Type* matrix,
-                        const Type* vec1,
-                        const Type* vec2,
-                        IdxType D,
-                        IdxType N,
-                        bool rowMajor,
-                        bool bcastAlongRows,
-                        Lambda op,
-                        cudaStream_t stream)
-{
-  IdxType nblks = raft::ceildiv(N * D, (IdxType)TPB);
-  matrixVectorOpKernel<Type, veclen_, Lambda, IdxType>
-    <<<nblks, TPB, 0, stream>>>(out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op);
-  CUDA_CHECK(cudaPeekAtLastError());
+  IdxType nLines = rowMajor ? N : D;
+  return matrixLinewiseOp(out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec);
 }
 
 /**
@@ -260,35 +97,9 @@ void matrixVectorOp(Type* out,
                     cudaStream_t stream)
 {
   IdxType stride = rowMajor ? D : N;
-  // IdxType nLines = rowMajor ? N : D;
-  // return matrixLinewiseOp(out, matrix, stride, nLines,
-  //                         rowMajor == bcastAlongRows, op, stream, vec1, vec2);
-
-  size_t stride_bytes = stride * sizeof(Type);
-
-  if (AlignedAccess<16>::test(matrix, stride_bytes) && AlignedAccess<16>::test(out, stride_bytes)) {
-    matrixVectorOpImpl<Type, 16 / sizeof(Type), Lambda, IdxType, TPB>(
-      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (AlignedAccess<8>::test(matrix, stride_bytes) &&
-             AlignedAccess<8>::test(out, stride_bytes)) {
-    matrixVectorOpImpl<Type, 8 / sizeof(Type), Lambda, IdxType, TPB>(
-      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (AlignedAccess<4>::test(matrix, stride_bytes) &&
-             AlignedAccess<4>::test(out, stride_bytes)) {
-    matrixVectorOpImpl<Type, 4 / sizeof(Type), Lambda, IdxType, TPB>(
-      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (AlignedAccess<2>::test(matrix, stride_bytes) &&
-             AlignedAccess<2>::test(out, stride_bytes)) {
-    matrixVectorOpImpl<Type, 2 / sizeof(Type), Lambda, IdxType, TPB>(
-      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (AlignedAccess<1>::test(matrix, stride_bytes) &&
-             AlignedAccess<1>::test(out, stride_bytes)) {
-    matrixVectorOpImpl<Type, 1 / sizeof(Type), Lambda, IdxType, TPB>(
-      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else {
-    matrixVectorOpImpl<Type, 1, Lambda, IdxType, TPB>(
-      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  }
+  IdxType nLines = rowMajor ? N : D;
+  return matrixLinewiseOp(
+    out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec1, vec2);
 }
 
 };  // end namespace linalg
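
For context, the public entry point is unchanged by this diff; only its body now forwards to matrixLinewiseOp. Below is a minimal usage sketch (not part of the change) showing how the arguments map. The wrapper scale_rows and the device buffers are hypothetical, and the device lambda requires nvcc's --extended-lambda flag.

#include <raft/linalg/matrix_vector_op.cuh>

// Hypothetical helper: multiply each row of a row-major N x D matrix
// element-wise by a length-D vector, writing the result to `out`.
void scale_rows(
  float* out, const float* matrix, const float* vec, int D, int N, cudaStream_t stream)
{
  // With rowMajor == bcastAlongRows == true, the call below is forwarded as
  // matrixLinewiseOp(out, matrix, /*stride=*/D, /*nLines=*/N,
  //                  rowMajor == bcastAlongRows, op, stream, vec).
  raft::linalg::matrixVectorOp(
    out, matrix, vec, D, N,
    /*rowMajor=*/true, /*bcastAlongRows=*/true,
    [] __device__(float mat_val, float vec_val) { return mat_val * vec_val; },
    stream);
}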