diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index 282097742c..1cc272f74e 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -91,18 +91,11 @@ std::enable_if_t::value> cutlassDistanceKernel(const Da typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op); - const DataT *a, *b; - - IdxT gemm_lda, gemm_ldb; - // Number of pipelines you want to use constexpr int NumStages = 3; // Alignment constexpr int Alignment = VecLen; - // default initialize problem size with row major inputs - auto problem_size = cutlass::gemm::GemmCoord(n, m, k); - using cutlassDistKernel = typename cutlass::gemm::kernel::PairwiseDistanceGemm::value> cutlassDistanceKernel(const Da using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; - if constexpr (isRowMajor) { - a = y; - b = x; - gemm_lda = ldb; - gemm_ldb = lda; - } else { - problem_size = cutlass::gemm::GemmCoord(m, n, k); - a = x; - b = y; - gemm_lda = lda; - gemm_ldb = ldb; + constexpr uint32_t gridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + constexpr uint32_t max_batch_size = gridYZMax * cutlassDistKernel::ThreadblockShape::kN; + IdxT numNbatches = (n - 1 + max_batch_size) / max_batch_size; + + for (IdxT i = 0; i < numNbatches; i++) { + const DataT *a, *b; + IdxT gemm_lda, gemm_ldb; + size_t offsetN = i * max_batch_size; + + if constexpr (isRowMajor) { + gemm_lda = ldb; + gemm_ldb = lda; + a = y + offsetN * gemm_lda; + b = x; + } else { + gemm_lda = lda; + gemm_ldb = ldb; + a = x; + b = y + offsetN; + } + IdxT chunkN = (i + 1) * max_batch_size; + IdxT currentN = (chunkN < n) ? max_batch_size : (n - offsetN); + + // default initialize problem size with row major inputs + auto problem_size = isRowMajor ? cutlass::gemm::GemmCoord(currentN, m, k) + : cutlass::gemm::GemmCoord(m, currentN, k); + + typename cutlassDist::Arguments arguments{ + mode, + problem_size, + batch_count, + epilog_op_param, + a, + b, + xn, // C matrix eq vector param, which here is A norm + nullptr, // tensor_Z, + (DataT*)yn + offsetN, // this is broadcast vec, which is required to be non-const param + dOutput + offsetN, // Output distance matrix + (int64_t)0, // batch stride A + (int64_t)0, // batch stride B + (int64_t)0, // batch stride Norm A + (int64_t)0, + (int64_t)0, // batch stride Norm B + (int64_t)0, // batch stride Output + gemm_lda, // stride A + gemm_ldb, // stride B + 1, // stride A norm + 0, // this is no-op for Z + 0, // This must be zero + ldd // stride Output matrix + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = cutlassDist::get_workspace_size(arguments); + // Allocate workspace memory + rmm::device_uvector workspace(workspace_size, stream); + // Instantiate CUTLASS kernel depending on templates + cutlassDist cutlassDist_op; + // Check the problem size is supported or not + RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); + + // Launch initialized CUTLASS kernel + RAFT_CUTLASS_TRY(cutlassDist_op(stream)); } - - typename cutlassDist::Arguments arguments{ - mode, problem_size, batch_count, epilog_op_param, a, b, - xn, // C matrix eq vector param, which here is A norm - nullptr, // tensor_Z, - (DataT*)yn, // this is broadcast vec, which is required to be non-const param - dOutput, // Output distance matrix - (int64_t)0, // batch stride A - (int64_t)0, // batch stride B - (int64_t)0, // batch stride Norm A - (int64_t)0, - (int64_t)0, // batch stride Norm B - (int64_t)0, // batch stride Output - gemm_lda, // stride A - gemm_ldb, // stride B - 1, // stride A norm - 0, // this is no-op for Z - 0, // This must be zero - ldd // stride Output matrix - }; - - // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = cutlassDist::get_workspace_size(arguments); - // Allocate workspace memory - rmm::device_uvector workspace(workspace_size, stream); - // Instantiate CUTLASS kernel depending on templates - cutlassDist cutlassDist_op; - // Check the problem size is supported or not - RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); - - // Initialize CUTLASS kernel with arguments and workspace pointer - RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); - - // Launch initialized CUTLASS kernel - RAFT_CUTLASS_TRY(cutlassDist_op(stream)); } }; // namespace detail diff --git a/cpp/test/distance/dist_cos.cu b/cpp/test/distance/dist_cos.cu index caf55529ed..b792ec4039 100644 --- a/cpp/test/distance/dist_cos.cu +++ b/cpp/test/distance/dist_cos.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,10 +29,12 @@ class DistanceExpCosXequalY : public DistanceTestSameBuffer {}; const std::vector> inputsf = { + {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, {0.001f, 1024, 1024, 32, true, 1234ULL}, {0.001f, 1024, 32, 1024, true, 1234ULL}, {0.001f, 32, 1024, 1024, true, 1234ULL}, {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, {0.001f, 1024, 1024, 32, false, 1234ULL}, {0.001f, 1024, 32, 1024, false, 1234ULL}, {0.001f, 32, 1024, 1024, false, 1234ULL}, diff --git a/cpp/test/distance/dist_l2_exp.cu b/cpp/test/distance/dist_l2_exp.cu index 7bdbb44362..0203d9ed9d 100644 --- a/cpp/test/distance/dist_l2_exp.cu +++ b/cpp/test/distance/dist_l2_exp.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,12 +29,14 @@ class DistanceEucExpTestXequalY : public DistanceTestSameBuffer {}; const std::vector> inputsf = { + {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, {0.001f, 2048, 4096, 128, true, 1234ULL}, {0.001f, 1024, 1024, 32, true, 1234ULL}, {0.001f, 1024, 32, 1024, true, 1234ULL}, {0.001f, 32, 1024, 1024, true, 1234ULL}, {0.003f, 1024, 1024, 1024, true, 1234ULL}, {0.003f, 1021, 1021, 1021, true, 1234ULL}, + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, {0.001f, 1024, 1024, 32, false, 1234ULL}, {0.001f, 1024, 32, 1024, false, 1234ULL}, {0.001f, 32, 1024, 1024, false, 1234ULL}, diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 938cd219d0..2854a8f3df 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -339,7 +339,7 @@ void naiveDistance(DataType* dist, DataType metric_arg = 2.0f, cudaStream_t stream = 0) { - static const dim3 TPB(16, 32, 1); + static const dim3 TPB(4, 256, 1); dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1); switch (type) {