Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch cutlass distance kernels along N matrix dim #2215

Merged
merged 5 commits into from
Mar 14, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 66 additions & 53 deletions cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,11 @@ std::enable_if_t<ops::has_cutlass_op<OpT>::value> cutlassDistanceKernel(const Da

typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op);

const DataT *a, *b;

IdxT gemm_lda, gemm_ldb;

// Number of pipelines you want to use
constexpr int NumStages = 3;
// Alignment
constexpr int Alignment = VecLen;

// default initialize problem size with row major inputs
auto problem_size = cutlass::gemm::GemmCoord(n, m, k);

using cutlassDistKernel =
typename cutlass::gemm::kernel::PairwiseDistanceGemm<DataT,
Alignment,
Expand All @@ -116,53 +109,73 @@ std::enable_if_t<ops::has_cutlass_op<OpT>::value> cutlassDistanceKernel(const Da

using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter<cutlassDistKernel>;

if constexpr (isRowMajor) {
a = y;
b = x;
gemm_lda = ldb;
gemm_ldb = lda;
} else {
problem_size = cutlass::gemm::GemmCoord(m, n, k);
a = x;
b = y;
gemm_lda = lda;
gemm_ldb = ldb;
constexpr uint32_t gridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
tfeher marked this conversation as resolved.
Show resolved Hide resolved
constexpr uint32_t max_batch_size = gridYZMax * cutlassDistKernel::ThreadblockShape::kN;
IdxT numNbatches = (n - 1 + max_batch_size) / max_batch_size;

for (IdxT i = 0; i < numNbatches; i++) {
const DataT *a, *b;
IdxT gemm_lda, gemm_ldb;
size_t offsetN = i * max_batch_size;

if constexpr (isRowMajor) {
gemm_lda = ldb;
gemm_ldb = lda;
a = y + offsetN * gemm_lda;
b = x;
} else {
gemm_lda = lda;
gemm_ldb = ldb;
a = x;
b = y + offsetN;
}
IdxT chunkN = (i + 1) * max_batch_size;
IdxT currentN = (chunkN < n) ? max_batch_size : (n - offsetN);

// default initialize problem size with row major inputs
auto problem_size = isRowMajor ? cutlass::gemm::GemmCoord(currentN, m, k)
: cutlass::gemm::GemmCoord(m, currentN, k);

typename cutlassDist::Arguments arguments{
mode,
problem_size,
batch_count,
epilog_op_param,
a,
b,
xn, // C matrix eq vector param, which here is A norm
nullptr, // tensor_Z,
(DataT*)yn + offsetN, // this is broadcast vec, which is required to be non-const param
dOutput + offsetN, // Output distance matrix
(int64_t)0, // batch stride A
(int64_t)0, // batch stride B
(int64_t)0, // batch stride Norm A
(int64_t)0,
(int64_t)0, // batch stride Norm B
(int64_t)0, // batch stride Output
gemm_lda, // stride A
gemm_ldb, // stride B
1, // stride A norm
0, // this is no-op for Z
0, // This must be zero
ldd // stride Output matrix
};

// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = cutlassDist::get_workspace_size(arguments);
// Allocate workspace memory
rmm::device_uvector<uint8_t> workspace(workspace_size, stream);
// Instantiate CUTLASS kernel depending on templates
cutlassDist cutlassDist_op;
// Check the problem size is supported or not
RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments));

// Initialize CUTLASS kernel with arguments and workspace pointer
RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream));

// Launch initialized CUTLASS kernel
RAFT_CUTLASS_TRY(cutlassDist_op(stream));
}

typename cutlassDist::Arguments arguments{
mode, problem_size, batch_count, epilog_op_param, a, b,
xn, // C matrix eq vector param, which here is A norm
nullptr, // tensor_Z,
(DataT*)yn, // this is broadcast vec, which is required to be non-const param
dOutput, // Output distance matrix
(int64_t)0, // batch stride A
(int64_t)0, // batch stride B
(int64_t)0, // batch stride Norm A
(int64_t)0,
(int64_t)0, // batch stride Norm B
(int64_t)0, // batch stride Output
gemm_lda, // stride A
gemm_ldb, // stride B
1, // stride A norm
0, // this is no-op for Z
0, // This must be zero
ldd // stride Output matrix
};

// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = cutlassDist::get_workspace_size(arguments);
// Allocate workspace memory
rmm::device_uvector<uint8_t> workspace(workspace_size, stream);
// Instantiate CUTLASS kernel depending on templates
cutlassDist cutlassDist_op;
// Check the problem size is supported or not
RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments));

// Initialize CUTLASS kernel with arguments and workspace pointer
RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream));

// Launch initialized CUTLASS kernel
RAFT_CUTLASS_TRY(cutlassDist_op(stream));
}

}; // namespace detail
Expand Down
Loading