;
- dim3 grid = launchConfigGenerator(m, n, shmemSize, fusedL2NNSqrt);
-
- fusedL2NNSqrt<<>>(min,
- x,
- y,
- xn,
- yn,
- m,
- n,
- k,
- maxVal,
- workspace,
- redOp,
- pairRedOp,
- core_lambda,
- raft::identity_op{});
- } else {
- auto fusedL2NN = fusedL2NNkernel;
- dim3 grid = launchConfigGenerator(m, n, shmemSize, fusedL2NN);
- fusedL2NN<<>>(min,
- x,
- y,
- xn,
- yn,
- m,
- n,
- k,
- maxVal,
- workspace,
- redOp,
- pairRedOp,
- core_lambda,
- raft::identity_op{});
- }
+ using AccT = DataT;
+ ops::l2_exp_distance_op distance_op{sqrt};
+
+ raft::identity_op fin_op{};
+
+ auto kernel = fusedL2NNkernel;
+
+ dim3 grid = launchConfigGenerator(m, n, shmemSize, kernel);
+
+ kernel<<>>(
+ min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op);
RAFT_CUDA_TRY(cudaGetLastError());
}
diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh
index 0293f10c29..c6b09be31e 100644
--- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh
+++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh
@@ -14,14 +14,11 @@
* limitations under the License.
*/
#pragma once
-#include
-#include
-#include
-#include
-#include
-#include
+#include // raft::linalg::Contractions_NT
+#include // ceildiv
+#include // RAFT_CUDA_TRY
-#include
+#include <cstddef> // size_t
namespace raft {
namespace distance {
@@ -29,16 +26,12 @@ namespace detail {
/**
* @brief Device class for L1, L2 and cosine distance metrics.
- * @tparam useNorms whether norms are needed
* @tparam DataT input data-type (for A and B matrices)
* @tparam AccT accumulation data-type
* @tparam OutT output data-type (for C and D matrices)
* @tparam IdxT index data-type
* @tparam Policy struct which tunes the Contraction kernel
- * @tparam CoreLambda tells how to accumulate an x and y into
- acc. its signature:
- template void core_lambda(AccT& acc,
- const DataT& x, const DataT& y)
+ * @tparam OpT A distance operation, e.g., cosine_distance_op.
* @tparam EpilogueLambda applies an elementwise function to compute final
values. Its signature is:
template void epilogue_lambda
@@ -56,19 +49,17 @@ namespace detail {
* @param[in] yn row norms of input matrix B. Required for expanded L2, cosine
* @param[output] pD output matrix
* @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn.
- * @param core_op the core accumulation operation lambda
+ * @param distance_op the distance operation, e.g. cosine_distance_op
* @param epilog_op the epilog operation lambda
* @param fin_op the final gemm epilogue lambda
* @param rowEpilog_op epilog lambda that executes when a full row has been processed
*/
-template >
struct PairwiseDistances : public BaseClass {
+ // Get accumulation type from distance_op
+ using AccT = typename OpT::AccT;
+
private:
typedef Policy P;
const DataT* xn;
@@ -83,7 +77,7 @@ struct PairwiseDistances : public BaseClass {
const DataT* const yBase;
OutT* dOutput;
char* smem;
- CoreLambda core_op;
+ OpT distance_op;
EpilogueLambda epilog_op;
FinalLambda fin_op;
rowEpilogueLambda rowEpilog_op;
@@ -109,7 +103,7 @@ struct PairwiseDistances : public BaseClass {
const DataT* _yn,
OutT* _dOutput,
char* _smem,
- CoreLambda _core_op,
+ OpT _distance_op,
EpilogueLambda _epilog_op,
FinalLambda _fin_op,
rowEpilogueLambda _rowEpilog_op)
@@ -119,7 +113,7 @@ struct PairwiseDistances : public BaseClass {
yBase(_y),
dOutput(_dOutput),
smem(_smem),
- core_op(_core_op),
+ distance_op(_distance_op),
epilog_op(_epilog_op),
fin_op(_fin_op),
rowEpilog_op(_rowEpilog_op),
@@ -159,15 +153,25 @@ struct PairwiseDistances : public BaseClass {
this->switch_read_buffer();
// Epilog:
- if (useNorms) {
+ if (distance_op.use_norms) {
DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
load_norms(tile_idx_m, tile_idx_n, regxn, regyn);
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
+ // Calculate distance_op epilog.
+ // Use .template to disambiguate (See:
+ // https://en.cppreference.com/w/cpp/language/dependent_name)
+ distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m);
+ // And any possible additional epilogs
epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
+ // Calculate distance_op epilog.
+ // Use .template to disambiguate (See:
+ // https://en.cppreference.com/w/cpp/language/dependent_name)
+ distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
+ // And any possible additional epilogs
epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
}
if (writeOut) { store_output(tile_idx_m, tile_idx_n); }
@@ -201,24 +205,41 @@ struct PairwiseDistances : public BaseClass {
}
}
- DI void accumulate()
+ DI void accumulate_reg_tile(DataT (&reg_x)[P::AccRowsPerTh][P::Veclen],
+ DataT (&reg_y)[P::AccColsPerTh][P::Veclen])
{
#pragma unroll
- for (int ki = 0; ki < P::Kblk; ki += P::Veclen) {
- this->ldsXY(ki);
+ for (int v = 0; v < P::Veclen; ++v) {
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
-#pragma unroll
- for (int v = 0; v < P::Veclen; ++v) {
- core_op(acc[i][j], this->regx[i][v], this->regy[j][v]);
- }
+ distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]);
}
}
}
}
+ DI void accumulate()
+ {
+ // We have a separate ldsXY and accumulate_reg_tile outside the loop body,
+ // so that these separated calls can be interspersed with preceding and
+ // following instructions, thereby hiding latency.
+ this->ldsXY(0);
+
+ // If the distance op has an expensive inner loop, do not unroll the loop.
+ constexpr int num_iterations = P::Kblk / P::Veclen - 1;
+ constexpr int unroll_count = decltype(distance_op)::expensive_inner_loop ? 1 : num_iterations;
+#pragma unroll unroll_count
+ for (int ki = P::Veclen; ki < P::Kblk; ki += P::Veclen) {
+ accumulate_reg_tile(this->regx, this->regy);
+ this->ldsXY(ki);
+ }
+
+ // Accumulate last loaded tile.
+ accumulate_reg_tile(this->regx, this->regy);
+ }
+
DI void load_norms(IdxT tile_idx_m,
IdxT tile_idx_n,
DataT (&regxn)[P::AccRowsPerTh],
@@ -274,7 +295,11 @@ struct PairwiseDistances : public BaseClass {
template
dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func)
{
- const auto numSMs = raft::getMultiProcessorCount();
+ int devId;
+ RAFT_CUDA_TRY(cudaGetDevice(&devId));
+ int numSMs;
+ RAFT_CUDA_TRY(cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId));
+
int numBlocksPerSm = 0;
dim3 grid;
diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh
index c5fdd28117..efcd5d9389 100644
--- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh
+++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh
@@ -64,21 +64,20 @@ template
-typename std::enable_if::value>::type cutlassDistanceKernel(
- const DataT* x,
- const DataT* y,
- const DataT* xn,
- const DataT* yn,
- IdxT m,
- IdxT n,
- IdxT k,
- IdxT lda,
- IdxT ldb,
- IdxT ldd,
- OutT* dOutput,
- FinalLambda fin_op,
- OpT distance_op,
- cudaStream_t stream)
+std::enable_if_t::value> cutlassDistanceKernel(const DataT* x,
+ const DataT* y,
+ const DataT* xn,
+ const DataT* yn,
+ IdxT m,
+ IdxT n,
+ IdxT k,
+ IdxT lda,
+ IdxT ldb,
+ IdxT ldd,
+ OutT* dOutput,
+ FinalLambda fin_op,
+ OpT distance_op,
+ cudaStream_t stream)
{
static_assert(!(std::is_same<OutT, bool>::value),
"OutType bool is not supported use uint8_t instead");
diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh
index 8524ce6fdf..e04b56ee8a 100644
--- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh
+++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh
@@ -15,63 +15,74 @@
*/
#pragma once
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
+/* This file has two responsibilities:
+ *
+ * 1. Dispatch to the correct implementation of a kernel based on the
+ * architecture of the device on which the kernel will be launched. For
+ * instance, the cosine distance has a CUTLASS-based implementation that can
+ * be used on SM80+ and the normal implementation that is used on older
+ * architectures.
+ *
+ * 2. Provide concise function templates that can be instantiated in
+ * src/distance/distance/specializations/detail/. Previously,
+ * raft::distance::detail::distance was instantiated. The function
+ * necessarily required a large set of include files, which slowed down the
+ * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions
+ * do not require as large an include files set, which speeds up the build.
+ */
+
+#include // ops::has_cutlass_op
+#include // dispatch_sm60
+#include // pairwise_matrix_params
+#include // raft::util::arch::SM_*
+
+// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh.
+// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS).
+// Therefore, it is the including file's responsibility to include the correct
+// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh
+// and the specializations in src/distance/distance/specializations/detail/.
namespace raft::distance::detail {
+// This forward-declaration ensures that we do not need to include
+// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling
+// all the non-CUTLASS based distance specializations faster. For CUTLASS-based
+// distances, dispatch_sm80.cuh has to be included by the file including this
+// file.
template
-void pairwise_matrix_dispatch(OpT distance_op,
- IdxT m,
- IdxT n,
- IdxT k,
- const DataT* x,
- const DataT* y,
- const DataT* x_norm,
- const DataT* y_norm,
- OutT* out,
- FinOpT fin_op,
- cudaStream_t stream,
- bool is_row_major)
-{
- // Create kernel parameter struct. Flip x and y if column major.
- IdxT ldx = is_row_major ? k : m;
- IdxT ldy = is_row_major ? k : n;
- IdxT ld_out = is_row_major ? n : m;
-
- pairwise_matrix_params params{
- m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major};
-
- if (!params.is_row_major) { params.flip_x_and_y(); }
+ typename SM_compat_t>
+void pairwise_matrix_sm80_dispatch(OpT,
+ pairwise_matrix_params,
+ SM_compat_t,
+ cudaStream_t);
+template
+void pairwise_matrix_instantiation_point(OpT distance_op,
+ pairwise_matrix_params params,
+ cudaStream_t stream)
+{
// On CUDA 12:
// - always execute normal kernel
//
// On CUDA 11 and below:
// - execute CUTLASS-based kernel on SM_80 and above
// - execute normal kernel below SM_80
+ namespace arch = raft::util::arch;
constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12;
constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op();
if constexpr (is_ctk_12 || cutlass_op_unavailable) {
// Always execute legacy kernels on CUDA 12
- auto any_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future());
+ auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future());
pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream);
} else {
- auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future());
- auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80());
+ auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());
+ auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80());
// Get pointer to SM60 kernel to determine the runtime architecture of the
// current system. Other methods to determine the architecture (that do not
@@ -79,7 +90,7 @@ void pairwise_matrix_dispatch(OpT distance_op,
// https://github.com/NVIDIA/cub/issues/545
auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range);
void* kernel_ptr = reinterpret_cast<void*>(sm60_wrapper.kernel_ptr);
- auto runtime_arch = raft::arch::kernel_runtime_arch(kernel_ptr);
+ auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr);
if (cutlass_range.contains(runtime_arch)) {
// If device is SM_80 or later, use CUTLASS-based kernel.
@@ -92,4 +103,35 @@ void pairwise_matrix_dispatch(OpT distance_op,
}
}
+template
+void pairwise_matrix_dispatch(OpT distance_op,
+ IdxT m,
+ IdxT n,
+ IdxT k,
+ const DataT* x,
+ const DataT* y,
+ const DataT* x_norm,
+ const DataT* y_norm,
+ OutT* out,
+ FinOpT fin_op,
+ cudaStream_t stream,
+ bool is_row_major)
+{
+ // Create kernel parameter struct. Flip x and y if column major.
+ IdxT ldx = is_row_major ? k : m;
+ IdxT ldy = is_row_major ? k : n;
+ IdxT ld_out = is_row_major ? n : m;
+
+ pairwise_matrix_params params{
+ m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major};
+
+ if (!params.is_row_major) { params.flip_x_and_y(); }
+ pairwise_matrix_instantiation_point(distance_op, params, stream);
+}
+
}; // namespace raft::distance::detail
diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh
index c1e4c08af4..f2b0e59822 100644
--- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh
+++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh
@@ -15,10 +15,11 @@
*/
#pragma once
-#include "kernel_sm60.cuh"
-#include
-#include
-
+#include <algorithm> // std::min
+#include <cstddef> // size_t
+#include // RAFT_EXPECTS
+#include // pairwise_matrix_params
+#include <type_traits> // std::integral_constant
namespace raft::distance::detail {
/**
@@ -99,15 +100,15 @@ auto dispatch_layout(bool row_major, int vec_len, F&& f)
{
if (row_major) {
switch (vec_len) {
- case 4: return f(std::bool_constant<true>(), vec_len_constant<4>());
- case 2: return f(std::bool_constant<true>(), vec_len_constant<2>());
- default: return f(std::bool_constant<true>(), vec_len_constant<1>());
+ case 4: return f(std::true_type(), vec_len_constant<4>());
+ case 2: return f(std::true_type(), vec_len_constant<2>());
+ default: return f(std::true_type(), vec_len_constant<1>());
}
} else {
switch (vec_len) {
- case 4: return f(std::bool_constant<false>(), vec_len_constant<4>());
- case 2: return f(std::bool_constant<false>(), vec_len_constant<2>());
- default: return f(std::bool_constant<false>(), vec_len_constant<1>());
+ case 4: return f(std::false_type(), vec_len_constant<4>());
+ case 2: return f(std::false_type(), vec_len_constant<2>());
+ default: return f(std::false_type(), vec_len_constant<1>());
}
}
}
diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh
index 6e284007ea..2080fbe9cd 100644
--- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh
+++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh
@@ -15,10 +15,10 @@
*/
#pragma once
-#include
-#include
-#include
-#include
+#include <algorithm> // std::min
+#include // dispatch_layout
+#include