Replace map_along_rows with matrixVectorOp #911

Merged
30 commits merged into branch-22.12 on Nov 8, 2022

Commits
7f11225
Integrate accumulate_into_selected into raft prims
Nyrio Oct 7, 2022
529931d
Remove accumulate_into_selected
Nyrio Oct 10, 2022
0f284b0
Merge remote-tracking branch 'origin/branch-22.12' into enh-ann-accum…
Nyrio Oct 10, 2022
6a61537
Replace map_along_rows with matrixVectorOp
Nyrio Oct 10, 2022
e184ebf
Merge remote-tracking branch 'origin/branch-22.12' into enh-map-along…
Nyrio Oct 17, 2022
0a98482
Start adding support for arbitrary types in linewiseOp
Nyrio Oct 18, 2022
eafa61e
Allow different types for output, matrix and vector(s) in mdspan-base…
Nyrio Oct 18, 2022
74309ed
Call cub histogram with signed type to avoid a warning breaking compi…
Nyrio Oct 18, 2022
df7cfb5
Start adding support for different output/matrix/vector types in MatV…
Nyrio Oct 18, 2022
8264462
Fix shared mem buffering offset in linewise kernels
Nyrio Oct 18, 2022
e4a8b66
Pass custom op to naiveMat
Nyrio Oct 18, 2022
938b180
Support different output/matrix/vector(s) types in naiveMatVec
Nyrio Oct 18, 2022
c85ee9b
Test matrix-vector op with different matrix / vector types
Nyrio Oct 18, 2022
be142c6
Fix linewiseOp VecRows kernels
Nyrio Oct 19, 2022
a850d85
Fix linewiseOp VecCols kernels
Nyrio Oct 20, 2022
e44bcc4
Merge OutT=MatT because linewiseOp only supports one input/output mat…
Nyrio Oct 20, 2022
351fddc
Merge remote-tracking branch 'origin/branch-22.12' into enh-map-along…
Nyrio Oct 20, 2022
dcfd962
Replace for_each with linalg::add + fix syntax error
Nyrio Oct 20, 2022
4718e5b
Clang-format fix
Nyrio Oct 20, 2022
fc55921
used alignedLen instead of totalLen in max block number calculation
Nyrio Oct 21, 2022
5421dda
Add misalignments to matrix-vector-op test
Nyrio Oct 26, 2022
ccdc961
Extend matrix-vector op benchmark
Nyrio Oct 26, 2022
b5c4d01
Merge remote-tracking branch 'origin/branch-22.12' into enh-map-along…
Nyrio Oct 28, 2022
f2f67db
Apply changes to new padded kernel (note: test is still failing but a…
Nyrio Oct 28, 2022
1450bae
Put itertools in util namespace
Nyrio Oct 28, 2022
7bcbbe2
Remove TPB from public API (it wasn't even forwarded to the actual im…
Nyrio Oct 28, 2022
c7942b0
Fix utils -> util
Nyrio Oct 28, 2022
d28a3c5
Merge remote-tracking branch 'origin/branch-22.12' into enh-map-along…
Nyrio Nov 8, 2022
68fd977
Move product auxiliary function to itertools::detail
Nyrio Nov 8, 2022
b2be0c1
Add test case for int8_t matrix with float vectors
Nyrio Nov 8, 2022
93 changes: 14 additions & 79 deletions cpp/include/raft/linalg/detail/matrix_vector_op.cuh
@@ -22,80 +22,10 @@ namespace raft {
 namespace linalg {
 namespace detail {
 
-namespace {
[Review comment · Member]
Are any of the changes in this file going to break users downstream (such as cuml)?

[Reply · @Nyrio (Contributor, Author) · Oct 28, 2022]
It won't break cuML because all template types are inferred in calls to matrixVectorOp.
It could in theory break other projects if they provide the template list explicitly, but I'm not aware of any.
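
To illustrate the inference point above (an editorial sketch, not part of the original thread; function and buffer names are hypothetical, and the device lambda assumes nvcc's --extended-lambda):

#include <raft/linalg/matrix_vector_op.cuh>

// Sketch: subtract a length-D mean vector from every row of a row-major
// N x D matrix. All template parameters are inferred from the arguments.
void center_rows(float* out, const float* mat, const float* means,
                 int D, int N, cudaStream_t stream)
{
  // MatT, Lambda, VecT and IdxType are deduced, so a call like this
  // compiles unchanged after the Type -> MatT/VecT split.
  raft::linalg::matrixVectorOp(
    out, mat, means, D, N, /*rowMajor=*/true, /*bcastAlongRows=*/true,
    [] __device__(float m, float v) { return m - v; }, stream);
}

// By contrast, a caller that spelled out the old template list explicitly,
// e.g. matrixVectorOp<float, SubOp, int, 256>(...) with a hypothetical SubOp,
// would now bind those arguments to <MatT, Lambda, VecT, IdxType, TPB> and
// fail to compile -- the potential downstream breakage discussed above.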

-template <size_t VecBytes>
-struct AlignedAccess {
-  template <typename T>
-  static inline bool test(const T* matrix, size_t strideBytes)
-  {
-    return Pow2<VecBytes>::isAligned(matrix) && Pow2<VecBytes>::isAligned(strideBytes) &&
-           Pow2<sizeof(T)>::isAligned(VecBytes);
-  }
-};
-};  // namespace
-
-template <typename Type, int veclen_, typename Lambda, typename IdxType>
-__global__ void matrixVectorOpKernel(Type* out,
-                                     const Type* matrix,
-                                     const Type* vector,
-                                     IdxType D,
-                                     IdxType N,
-                                     bool rowMajor,
-                                     bool bcastAlongRows,
-                                     Lambda op)
-{
-  typedef TxN_t<Type, veclen_> VecType;
-  IdxType len = N * D;
-  IdxType idx = threadIdx.x;
-  idx += (IdxType)blockIdx.x * (IdxType)blockDim.x;
-  idx *= VecType::Ratio;
-  if (idx >= len) return;
-  IdxType vIdx;
-  VecType mat, vec;
-  ///@todo: yikes! use fast-int-div here.
-  ///@todo: shared mem for vector could help with perf
-  if (rowMajor && bcastAlongRows) {
-    vIdx = idx % D;
-    vec.load(vector, vIdx);
-  } else if (!rowMajor && !bcastAlongRows) {
-    vIdx = idx % N;
-    vec.load(vector, vIdx);
-  } else if (rowMajor && !bcastAlongRows) {
-    vIdx = idx / D;
-    vec.fill(vector[vIdx]);
-  } else {
-    vIdx = idx / N;
-    vec.fill(vector[vIdx]);
-  }
-  mat.load(matrix, idx);
-#pragma unroll
-  for (int i = 0; i < VecType::Ratio; ++i)
-    mat.val.data[i] = op(mat.val.data[i], vec.val.data[i]);
-  mat.store(out, idx);
-}
-
-template <typename Type, int veclen_, typename Lambda, typename IdxType, int TPB>
-void matrixVectorOpImpl(Type* out,
-                        const Type* matrix,
-                        const Type* vec,
-                        IdxType D,
-                        IdxType N,
-                        bool rowMajor,
-                        bool bcastAlongRows,
-                        Lambda op,
-                        cudaStream_t stream)
-{
-  IdxType len = N * D;
-  IdxType nblks = raft::ceildiv(veclen_ ? len / veclen_ : veclen_, (IdxType)TPB);
-  matrixVectorOpKernel<Type, veclen_, Lambda, IdxType>
-    <<<nblks, TPB, 0, stream>>>(out, matrix, vec, D, N, rowMajor, bcastAlongRows, op);
-  RAFT_CUDA_TRY(cudaPeekAtLastError());
-}
-
-template <typename Type, typename Lambda, typename IdxType = int, int TPB = 256>
-void matrixVectorOp(Type* out,
-                    const Type* matrix,
-                    const Type* vec,
+template <typename MatT, typename Lambda, typename VecT, typename IdxType = int, int TPB = 256>
+void matrixVectorOp(MatT* out,
+                    const MatT* matrix,
+                    const VecT* vec,
                     IdxType D,
                     IdxType N,
                     bool rowMajor,
@@ -109,11 +39,16 @@ void matrixVectorOp(Type* out,
     out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec);
 }
 
-template <typename Type, typename Lambda, typename IdxType = int, int TPB = 256>
-void matrixVectorOp(Type* out,
-                    const Type* matrix,
-                    const Type* vec1,
-                    const Type* vec2,
+template <typename MatT,
+          typename Lambda,
+          typename Vec1T,
+          typename Vec2T,
+          typename IdxType = int,
+          int TPB = 256>
+void matrixVectorOp(MatT* out,
+                    const MatT* matrix,
+                    const Vec1T* vec1,
+                    const Vec2T* vec2,
                     IdxType D,
                     IdxType N,
                     bool rowMajor,
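
With the detail-layer change above, the matrix and vector element types are now independent template parameters. A short usage sketch (editorial, not code from the PR; names are hypothetical, and the device lambda assumes --extended-lambda), mirroring the PR's new test case of an int8_t matrix combined with float vectors:

#include <raft/linalg/matrix_vector_op.cuh>

#include <cstdint>

// Sketch: scale every row of a row-major N x D int8_t matrix elementwise
// by a length-D float vector.
void scale_rows_i8(int8_t* out, const int8_t* mat, const float* scales,
                   int D, int N, cudaStream_t stream)
{
  // MatT = int8_t and VecT = float are deduced independently; the output
  // element type stays MatT because linewiseOp supports only one
  // input/output matrix type (see commit e44bcc4).
  raft::linalg::matrixVectorOp(
    out, mat, scales, D, N, /*rowMajor=*/true, /*bcastAlongRows=*/true,
    [] __device__(int8_t m, float s) { return static_cast<int8_t>(m * s); },
    stream);
}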
62 changes: 36 additions & 26 deletions cpp/include/raft/linalg/matrix_vector_op.cuh
@@ -35,8 +35,9 @@ namespace linalg {
  * Note : the function will also check that the size of the window of accesses
  * is a multiple of the number of elements processed by a thread in order to
  * enable faster processing
- * @tparam Type the matrix/vector type
+ * @tparam MatT the matrix type
  * @tparam Lambda a device function which represents a binary operator
+ * @tparam VecT the input vector type
  * @tparam IdxType Integer type used to for addressing
  * @tparam TPB threads per block of the cuda kernel launched
  * @param out the output matrix (passing out = matrix makes it in-place)
@@ -50,10 +51,10 @@
  * @param op the mathematical operation
  * @param stream cuda stream where to launch work
  */
-template <typename Type, typename Lambda, typename IdxType = int, int TPB = 256>
-void matrixVectorOp(Type* out,
-                    const Type* matrix,
-                    const Type* vec,
+template <typename MatT, typename Lambda, typename VecT, typename IdxType = int, int TPB = 256>
+void matrixVectorOp(MatT* out,
+                    const MatT* matrix,
+                    const VecT* vec,
                     IdxType D,
                     IdxType N,
                     bool rowMajor,
@@ -72,8 +73,10 @@ void matrixVectorOp(Type* out,
  * Note : the function will also check that the size of the window of accesses
  * is a multiple of the number of elements processed by a thread in order to
  * enable faster processing
- * @tparam Type the matrix/vector type
+ * @tparam MatT the matrix type
  * @tparam Lambda a device function which represents a binary operator
+ * @tparam Vec1T the first input vector type
+ * @tparam Vec2T the second input vector type
  * @tparam IdxType Integer type used to for addressing
  * @tparam TPB threads per block of the cuda kernel launched
  * @param out the output matrix (passing out = matrix makes it in-place)
@@ -88,11 +91,16 @@
  * @param op the mathematical operation
  * @param stream cuda stream where to launch work
  */
-template <typename Type, typename Lambda, typename IdxType = int, int TPB = 256>
-void matrixVectorOp(Type* out,
-                    const Type* matrix,
-                    const Type* vec1,
-                    const Type* vec2,
+template <typename MatT,
+          typename Lambda,
+          typename Vec1T,
+          typename Vec2T,
+          typename IdxType = int,
+          int TPB = 256>
+void matrixVectorOp(MatT* out,
+                    const MatT* matrix,
+                    const Vec1T* vec1,
+                    const Vec2T* vec2,
                     IdxType D,
                     IdxType N,
                     bool rowMajor,
@@ -116,10 +124,10 @@ void matrixVectorOp(Type* out,
  * Note : the function will also check that the size of the window of accesses
  * is a multiple of the number of elements processed by a thread in order to
  * enable faster processing
- * @tparam InValueType the data-type of the input matrices and vectors
+ * @tparam MatValueType the data-type of the input matrix
+ * @tparam VecValueType the data-type of the input vector
  * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major)
  * @tparam Lambda a device function which represents a binary operator
- * @tparam OutElementType the data-type of the output raft::matrix_view
  * @tparam IndexType Integer used for addressing
  * @tparam TPB threads per block of the cuda kernel launched
  * @param[in] handle raft::handle_t
@@ -130,16 +138,16 @@ void matrixVectorOp(Type* out,
  * the rows of the matrix or columns using enum class raft::linalg::Apply
  * @param[in] op the mathematical operation
  */
-template <typename InValueType,
+template <typename MatValueType,
+          typename VecValueType,
           typename LayoutPolicy,
           typename Lambda,
-          typename OutValueType,
           typename IndexType,
           int TPB = 256>
 void matrix_vector_op(const raft::handle_t& handle,
-                      raft::device_matrix_view<const InValueType, IndexType, LayoutPolicy> matrix,
-                      raft::device_vector_view<const InValueType, IndexType> vec,
-                      raft::device_matrix_view<OutValueType, IndexType, LayoutPolicy> out,
+                      raft::device_matrix_view<const MatValueType, IndexType, LayoutPolicy> matrix,
+                      raft::device_vector_view<const VecValueType, IndexType> vec,
+                      raft::device_matrix_view<MatValueType, IndexType, LayoutPolicy> out,
                       Apply apply,
                       Lambda op)
 {
@@ -177,10 +185,11 @@ void matrix_vector_op(const raft::handle_t& handle,
  * Note : the function will also check that the size of the window of accesses
  * is a multiple of the number of elements processed by a thread in order to
  * enable faster processing
- * @tparam InValueType the data-type of the input matrices and vectors
+ * @tparam MatValueType the data-type of the input and output matrices
+ * @tparam Vec1ValueType the data-type of the first input vector
+ * @tparam Vec2ValueType the data-type of the second input vector
  * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major)
  * @tparam Lambda a device function which represents a binary operator
- * @tparam OutElementType the data-type of the output raft::matrix_view
  * @tparam IndexType Integer used for addressing
  * @tparam TPB threads per block of the cuda kernel launched
  * @param handle raft::handle_t
@@ -192,17 +201,18 @@ void matrix_vector_op(const raft::handle_t& handle,
  * the rows of the matrix or columns using enum class raft::linalg::Apply
  * @param op the mathematical operation
  */
-template <typename InValueType,
+template <typename MatValueType,
+          typename Vec1ValueType,
+          typename Vec2ValueType,
           typename LayoutPolicy,
           typename Lambda,
-          typename OutValueType,
           typename IndexType,
           int TPB = 256>
 void matrix_vector_op(const raft::handle_t& handle,
-                      raft::device_matrix_view<const InValueType, IndexType, LayoutPolicy> matrix,
-                      raft::device_vector_view<const InValueType, IndexType> vec1,
-                      raft::device_vector_view<const InValueType, IndexType> vec2,
-                      raft::device_matrix_view<OutValueType, IndexType, LayoutPolicy> out,
+                      raft::device_matrix_view<const MatValueType, IndexType, LayoutPolicy> matrix,
+                      raft::device_vector_view<const Vec1ValueType, IndexType> vec1,
+                      raft::device_vector_view<const Vec2ValueType, IndexType> vec2,
+                      raft::device_matrix_view<MatValueType, IndexType, LayoutPolicy> out,
                       Apply apply,
                       Lambda op)
 {
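
For the mdspan-based overloads, the renamed template parameters are likewise invisible to typical callers, because the value types are deduced from the views. A hedged sketch (editorial, not code from the PR; it assumes the raft 22.12 view factories make_device_matrix_view / make_device_vector_view, the raft::linalg::Apply enum shown in the diff, and --extended-lambda):

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/matrix_vector_op.cuh>

// Hypothetical caller: add a bias vector to each row of a row-major
// N x D matrix in device memory.
void add_row_bias(const raft::handle_t& handle,
                  float* d_out, const float* d_mat, const float* d_bias,
                  int N, int D)
{
  auto mat  = raft::make_device_matrix_view<const float, int, raft::row_major>(d_mat, N, D);
  auto bias = raft::make_device_vector_view<const float, int>(d_bias, D);
  auto out  = raft::make_device_matrix_view<float, int, raft::row_major>(d_out, N, D);

  // MatValueType and VecValueType are deduced from the views. After this PR
  // the out view's value type must match the matrix's, since the separate
  // OutValueType template parameter was removed.
  raft::linalg::matrix_vector_op(handle, mat, bias, out,
                                 raft::linalg::Apply::ALONG_ROWS,
                                 [] __device__(float m, float b) { return m + b; });
}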