Batch cutlass distance kernels along N matrix dim (#2215)
- Batch the input matrix along its N dimension into chunks of at most 65535 × ThreadblockShape::kN columns, so each CUTLASS launch stays within the CUDA gridDim.y limit of 65535 (a short sketch of the chunk arithmetic follows the commit metadata).

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #2215
mdoijade authored Mar 14, 2024
1 parent cb80657 commit 56c0b3a
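To make the chunking easier to follow, here is a minimal host-side sketch of the arithmetic the new loop in pairwise_distance_cutlass_base.cuh uses to split N. The tile width of 128 is an assumption for illustration only; in the kernel it comes from cutlassDistKernel::ThreadblockShape::kN.

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
  constexpr uint32_t gridYMax      = 65535;  // CUDA gridDim.y (and .z) limit
  constexpr uint32_t kThreadblockN = 128;    // assumed tile width; the real value is ThreadblockShape::kN
  constexpr uint32_t maxBatchSize  = gridYMax * kThreadblockN;

  const int64_t n           = (65536 + 128) * 128LL;  // large-N case from the new tests
  const int64_t numNbatches = (n - 1 + maxBatchSize) / maxBatchSize;

  for (int64_t i = 0; i < numNbatches; ++i) {
    const int64_t offsetN  = i * maxBatchSize;
    const int64_t currentN = ((i + 1) * int64_t{maxBatchSize} < n) ? maxBatchSize : (n - offsetN);
    // One CUTLASS launch per iteration covers currentN columns; the kernel reads
    // yn + offsetN and writes dOutput + offsetN for that chunk.
    std::printf("chunk %lld: offsetN=%lld currentN=%lld\n",
                (long long)i, (long long)offsetN, (long long)currentN);
  }
  return 0;
}
```

For the (65536 + 128) * 128 = 8,404,992 columns used in the new tests, this yields two launches: one full chunk of 8,388,480 columns and a remainder of 16,512.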
Showing 4 changed files with 73 additions and 56 deletions.
119 changes: 66 additions & 53 deletions cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh
@@ -91,18 +91,11 @@ std::enable_if_t<ops::has_cutlass_op<OpT>::value> cutlassDistanceKernel(const Da

typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op);

-  const DataT *a, *b;
-
-  IdxT gemm_lda, gemm_ldb;
-
// Number of pipelines you want to use
constexpr int NumStages = 3;
// Alignment
constexpr int Alignment = VecLen;

-  // default initialize problem size with row major inputs
-  auto problem_size = cutlass::gemm::GemmCoord(n, m, k);
-
using cutlassDistKernel =
typename cutlass::gemm::kernel::PairwiseDistanceGemm<DataT,
Alignment,
@@ -116,53 +109,73 @@ std::enable_if_t<ops::has_cutlass_op<OpT>::value> cutlassDistanceKernel(const Da

using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter<cutlassDistKernel>;

-  if constexpr (isRowMajor) {
-    a = y;
-    b = x;
-    gemm_lda = ldb;
-    gemm_ldb = lda;
-  } else {
-    problem_size = cutlass::gemm::GemmCoord(m, n, k);
-    a = x;
-    b = y;
-    gemm_lda = lda;
-    gemm_ldb = ldb;
-  }
+  constexpr uint32_t gridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
+  constexpr uint32_t max_batch_size = gridYZMax * cutlassDistKernel::ThreadblockShape::kN;
+  IdxT numNbatches = (n - 1 + max_batch_size) / max_batch_size;
+
+  for (IdxT i = 0; i < numNbatches; i++) {
+    const DataT *a, *b;
+    IdxT gemm_lda, gemm_ldb;
+    size_t offsetN = i * max_batch_size;
+
+    if constexpr (isRowMajor) {
+      gemm_lda = ldb;
+      gemm_ldb = lda;
+      a = y + offsetN * gemm_lda;
+      b = x;
+    } else {
+      gemm_lda = lda;
+      gemm_ldb = ldb;
+      a = x;
+      b = y + offsetN;
+    }
+    IdxT chunkN = (i + 1) * max_batch_size;
+    IdxT currentN = (chunkN < n) ? max_batch_size : (n - offsetN);
+
+    // default initialize problem size with row major inputs
+    auto problem_size = isRowMajor ? cutlass::gemm::GemmCoord(currentN, m, k)
+                                   : cutlass::gemm::GemmCoord(m, currentN, k);
+
+    typename cutlassDist::Arguments arguments{
+      mode,
+      problem_size,
+      batch_count,
+      epilog_op_param,
+      a,
+      b,
+      xn,       // C matrix eq vector param, which here is A norm
+      nullptr,  // tensor_Z,
+      (DataT*)yn + offsetN,  // this is broadcast vec, which is required to be non-const param
+      dOutput + offsetN,     // Output distance matrix
+      (int64_t)0,  // batch stride A
+      (int64_t)0,  // batch stride B
+      (int64_t)0,  // batch stride Norm A
+      (int64_t)0,
+      (int64_t)0,  // batch stride Norm B
+      (int64_t)0,  // batch stride Output
+      gemm_lda,    // stride A
+      gemm_ldb,    // stride B
+      1,           // stride A norm
+      0,           // this is no-op for Z
+      0,           // This must be zero
+      ldd          // stride Output matrix
+    };
+
+    // Using the arguments, query for extra workspace required for matrix multiplication computation
+    size_t workspace_size = cutlassDist::get_workspace_size(arguments);
+    // Allocate workspace memory
+    rmm::device_uvector<uint8_t> workspace(workspace_size, stream);
+    // Instantiate CUTLASS kernel depending on templates
+    cutlassDist cutlassDist_op;
+    // Check the problem size is supported or not
+    RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments));
+
+    // Initialize CUTLASS kernel with arguments and workspace pointer
+    RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream));
+
+    // Launch initialized CUTLASS kernel
+    RAFT_CUTLASS_TRY(cutlassDist_op(stream));
+  }

-  typename cutlassDist::Arguments arguments{
-    mode, problem_size, batch_count, epilog_op_param, a, b,
-    xn,          // C matrix eq vector param, which here is A norm
-    nullptr,     // tensor_Z,
-    (DataT*)yn,  // this is broadcast vec, which is required to be non-const param
-    dOutput,     // Output distance matrix
-    (int64_t)0,  // batch stride A
-    (int64_t)0,  // batch stride B
-    (int64_t)0,  // batch stride Norm A
-    (int64_t)0,
-    (int64_t)0,  // batch stride Norm B
-    (int64_t)0,  // batch stride Output
-    gemm_lda,    // stride A
-    gemm_ldb,    // stride B
-    1,           // stride A norm
-    0,           // this is no-op for Z
-    0,           // This must be zero
-    ldd          // stride Output matrix
-  };
-
-  // Using the arguments, query for extra workspace required for matrix multiplication computation
-  size_t workspace_size = cutlassDist::get_workspace_size(arguments);
-  // Allocate workspace memory
-  rmm::device_uvector<uint8_t> workspace(workspace_size, stream);
-  // Instantiate CUTLASS kernel depending on templates
-  cutlassDist cutlassDist_op;
-  // Check the problem size is supported or not
-  RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments));
-
-  // Initialize CUTLASS kernel with arguments and workspace pointer
-  RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream));
-
-  // Launch initialized CUTLASS kernel
-  RAFT_CUTLASS_TRY(cutlassDist_op(stream));
}

}; // namespace detail
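In the row-major case, each launch writes its chunk into the full output at column offsetN while keeping the original row stride ldd, so the per-chunk results tile the complete m × n distance matrix. A small CPU-side sketch of that addressing; write_chunk is a hypothetical helper used only for illustration, not part of the commit:

```cpp
#include <cstddef>
#include <vector>

// Fill a chunk of a row-major m x n matrix starting at column offsetN,
// using the same addressing the batched kernel uses: base + offsetN, row stride ldd.
void write_chunk(float* dOutput, std::size_t ldd, std::size_t m,
                 std::size_t offsetN, std::size_t currentN, float value)
{
  for (std::size_t row = 0; row < m; ++row)
    for (std::size_t col = 0; col < currentN; ++col)
      dOutput[row * ldd + offsetN + col] = value;
}

int main()
{
  const std::size_t m = 4, n = 10, ldd = n;  // ldd is typically n for a row-major m x n output
  std::vector<float> out(m * n, 0.0f);
  write_chunk(out.data(), ldd, m, /*offsetN=*/0, /*currentN=*/6, 1.0f);  // first chunk
  write_chunk(out.data(), ldd, m, /*offsetN=*/6, /*currentN=*/4, 2.0f);  // remainder chunk
  // Every column is now covered exactly once: 0..5 by the first launch, 6..9 by the second.
  return 0;
}
```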
4 changes: 3 additions & 1 deletion cpp/test/distance/dist_cos.cu
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -29,10 +29,12 @@ class DistanceExpCosXequalY
: public DistanceTestSameBuffer<raft::distance::DistanceType::CosineExpanded, DataType> {};

const std::vector<DistanceInputs<float>> inputsf = {
+  {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL},
{0.001f, 1024, 1024, 32, true, 1234ULL},
{0.001f, 1024, 32, 1024, true, 1234ULL},
{0.001f, 32, 1024, 1024, true, 1234ULL},
{0.003f, 1024, 1024, 1024, true, 1234ULL},
+  {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL},
{0.001f, 1024, 1024, 32, false, 1234ULL},
{0.001f, 1024, 32, 1024, false, 1234ULL},
{0.001f, 32, 1024, 1024, false, 1234ULL},
4 changes: 3 additions & 1 deletion cpp/test/distance/dist_l2_exp.cu
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -29,12 +29,14 @@ class DistanceEucExpTestXequalY
: public DistanceTestSameBuffer<raft::distance::DistanceType::L2Expanded, DataType> {};

const std::vector<DistanceInputs<float>> inputsf = {
+  {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL},
{0.001f, 2048, 4096, 128, true, 1234ULL},
{0.001f, 1024, 1024, 32, true, 1234ULL},
{0.001f, 1024, 32, 1024, true, 1234ULL},
{0.001f, 32, 1024, 1024, true, 1234ULL},
{0.003f, 1024, 1024, 1024, true, 1234ULL},
{0.003f, 1021, 1021, 1021, true, 1234ULL},
+  {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL},
{0.001f, 1024, 1024, 32, false, 1234ULL},
{0.001f, 1024, 32, 1024, false, 1234ULL},
{0.001f, 32, 1024, 1024, false, 1234ULL},
2 changes: 1 addition & 1 deletion cpp/test/distance/distance_base.cuh
@@ -339,7 +339,7 @@ void naiveDistance(DataType* dist,
DataType metric_arg = 2.0f,
cudaStream_t stream = 0)
{
-  static const dim3 TPB(16, 32, 1);
+  static const dim3 TPB(4, 256, 1);
dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1);

switch (type) {
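The thread-block change in the naive reference kernel appears to serve the same goal: with the new large-n inputs, a y-block dimension of 32 would push gridDim.y past the 65535 limit, while 256 keeps it well inside. A quick sketch of the launch geometry, using the large-N test case; ceildiv mirrors raft::ceildiv:

```cpp
#include <cstdio>

// Integer ceiling division, matching raft::ceildiv for positive operands.
static long long ceildiv(long long a, long long b) { return (a + b - 1) / b; }

int main()
{
  const long long n = (65536 + 128) * 128LL;  // 8,404,992 columns from the new tests

  // Old block shape (16, 32): gridDim.y = 262,656, above the 65,535 limit.
  std::printf("old gridDim.y = %lld\n", ceildiv(n, 32));

  // New block shape (4, 256): gridDim.y = 32,832, within the limit.
  std::printf("new gridDim.y = %lld\n", ceildiv(n, 256));
  return 0;
}
```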
