From 20648c599ec8a95eca23820ce2a5ac05f7b86c55 Mon Sep 17 00:00:00 2001
From: Mahesh Doijade
Date: Thu, 20 Oct 2022 23:02:26 +0530
Subject: [PATCH 01/25] CUTLASS-based expanded Euclidean and cosine kernels

---
 cpp/CMakeLists.txt                            |  57 +-
 cpp/include/raft/core/cudart_utils.hpp        |  13 +
 cpp/include/raft/distance/detail/cosine.cuh   | 146 ++--
 cpp/include/raft/distance/detail/distance.cuh |  18 +-
 .../raft/distance/detail/euclidean.cuh        | 144 ++--
 .../detail/pairwise_distance_cutlass_base.cuh | 169 +++++
 .../detail/pairwise_distance_epilogue.h       | 124 ++++
 .../pairwise_distance_epilogue_elementwise.h  | 169 +++++
 .../distance/detail/pairwise_distance_gemm.h  | 206 ++++++
 .../detail/predicated_tile_iterator_normvec.h | 670 ++++++++++++++++++
 cpp/include/raft/distance/distance.cuh        |   2 +
 .../distance/specializations/distance.cuh     |  22 +-
 cpp/test/distance/dist_adj.cu                 |  35 +-
 cpp/test/distance/dist_euc_exp.cu             |   3 +
 cpp/test/distance/dist_eucsqrt_exp.cu         |  74 ++
 15 files changed, 1678 insertions(+), 174 deletions(-)
 create mode 100755 cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh
 create mode 100755 cpp/include/raft/distance/detail/pairwise_distance_epilogue.h
 create mode 100755 cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h
 create mode 100755 cpp/include/raft/distance/detail/pairwise_distance_gemm.h
 create mode 100755 cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h
 create mode 100755 cpp/test/distance/dist_eucsqrt_exp.cu

diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 2c424e9431..ea3ea8674f 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -162,7 +162,8 @@ add_library(raft::raft ALIAS raft)

 target_include_directories(raft INTERFACE "$<BUILD_INTERFACE:${RAFT_SOURCE_DIR}/include>"
-                                          "$<INSTALL_INTERFACE:include>")
+                                          "$<INSTALL_INTERFACE:include>"
+                                          "${CUTLASS_DIR}/include")

 # Keep RAFT as lightweight as possible.
# Only CUDA libs and rmm should @@ -225,21 +226,21 @@ set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance) if(RAFT_COMPILE_DIST_LIBRARY) add_library(raft_distance_lib src/distance/pairwise_distance.cu - src/distance/specializations/detail/canberra.cu - src/distance/specializations/detail/chebyshev.cu - src/distance/specializations/detail/correlation.cu - src/distance/specializations/detail/cosine.cu - src/distance/specializations/detail/hamming_unexpanded.cu - src/distance/specializations/detail/hellinger_expanded.cu - src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu - src/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu - src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu - src/distance/specializations/detail/kl_divergence_float_float_float_int.cu - src/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu - src/distance/specializations/detail/kl_divergence_double_double_double_int.cu - src/distance/specializations/detail/l1_float_float_float_int.cu - src/distance/specializations/detail/l1_float_float_float_uint32.cu - src/distance/specializations/detail/l1_double_double_double_int.cu +# src/distance/specializations/detail/canberra.cu +# src/distance/specializations/detail/chebyshev.cu +# src/distance/specializations/detail/correlation.cu +# src/distance/specializations/detail/cosine.cu +# src/distance/specializations/detail/hamming_unexpanded.cu +# src/distance/specializations/detail/hellinger_expanded.cu +# src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +# src/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu +# src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +# src/distance/specializations/detail/kl_divergence_float_float_float_int.cu +# src/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu +# src/distance/specializations/detail/kl_divergence_double_double_double_int.cu +# src/distance/specializations/detail/l1_float_float_float_int.cu +# src/distance/specializations/detail/l1_float_float_float_uint32.cu +# src/distance/specializations/detail/l1_double_double_double_int.cu src/distance/specializations/detail/l2_expanded_float_float_float_int.cu src/distance/specializations/detail/l2_expanded_float_float_float_uint32.cu src/distance/specializations/detail/l2_expanded_double_double_double_int.cu @@ -249,12 +250,12 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu - src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu - src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +# src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +# src/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu +# src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +# 
src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
+# src/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu
+# src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
 )
 set_target_properties(
   raft_distance_lib
@@ -305,16 +306,16 @@ set_target_properties(raft_nn PROPERTIES EXPORT_NAME nn)

 if(RAFT_COMPILE_NN_LIBRARY)
   add_library(raft_nn_lib
-    src/nn/specializations/ball_cover.cu
-    src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu
-    src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu
-    src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu
-    src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu
+#    src/nn/specializations/ball_cover.cu
+#    src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu
+#    src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu
+#    src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu
+#    src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu
     src/nn/specializations/fused_l2_knn_long_float_true.cu
     src/nn/specializations/fused_l2_knn_long_float_false.cu
     src/nn/specializations/fused_l2_knn_int_float_true.cu
     src/nn/specializations/fused_l2_knn_int_float_false.cu
-    src/nn/specializations/knn.cu
+#    src/nn/specializations/knn.cu
   )
   set_target_properties(
     raft_nn_lib
diff --git a/cpp/include/raft/core/cudart_utils.hpp b/cpp/include/raft/core/cudart_utils.hpp
index e0957ea1f3..a878d3ae77 100644
--- a/cpp/include/raft/core/cudart_utils.hpp
+++ b/cpp/include/raft/core/cudart_utils.hpp
@@ -354,6 +354,19 @@ inline int getMultiProcessorCount()
 {
   return mpCount;
 }

+/** helper method to get the compute capability (major, minor) of the current device */
+inline std::pair<int, int> getMajorMinorVersion()
+{
+  int devId;
+  RAFT_CUDA_TRY(cudaGetDevice(&devId));
+  int majorVer, minorVer;
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&majorVer, cudaDevAttrComputeCapabilityMajor, devId));
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minorVer, cudaDevAttrComputeCapabilityMinor, devId));
+
+  return std::make_pair(majorVer, minorVer);
+}
+
+
 /** helper method to convert an array on device to a string on host */
 template <typename T>
 std::string arr2Str(const T* arr, int size, std::string name, cudaStream_t stream, int width = 4)
diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh
index b7eed3e2a8..01b1388a92 100644
--- a/cpp/include/raft/distance/detail/cosine.cuh
+++ b/cpp/include/raft/distance/detail/cosine.cuh
@@ -18,11 +18,24 @@

 #include
 #include
+#include <raft/distance/detail/pairwise_distance_cutlass_base.cuh>

 namespace raft {
 namespace distance {
 namespace detail {

+template <typename DataT, typename AccT>
+struct CosineOp {
+  __device__ __host__ CosineOp() {}
+  __device__ __host__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const
+  {
+    return static_cast<AccT>(1.0) - (AccT)(accVal / (aNorm * bNorm));
+  }
+  __device__ __host__ AccT operator()(DataT aData) const { return aData; }
+};
+
+
 /**
  * @brief the cosine distance matrix calculation implementer
  * It computes the following equation:
@@ -71,61 +84,72 @@ void cosineImpl(const DataT* x,
                 FinalLambda fin_op,
                 cudaStream_t stream)
 {
-  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
-  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;
-
-  typedef typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;
-
-  dim3 blk(KPolicy::Nthreads);
-
-  // Accumulation operation lambda
-  auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; };
-
-  // epilogue operation lambda for final value calculation
-  auto epilog_lambda = [] __device__(AccT
acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = acc[i][j] / (regxn[i] * regyn[j]); - } - } - }; + const auto deviceVersion = getMajorMinorVersion(); + if (deviceVersion.first >= 8) { + using CosineOp_ = CosineOp; + CosineOp_ cosine_dist_op; + + cutlassDistanceKernel( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, cosine_dist_op, stream); - constexpr size_t shmemSize = - KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); - if (isRowMajor) { - auto cosineRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineRowMajor); - cosineRowMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); } else { - auto cosineColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineColMajor); - cosineColMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + #pragma unroll + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { + #pragma unroll + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j]) ); + } + } + }; + + constexpr size_t shmemSize = + KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); + if (isRowMajor) { + auto cosineRowMajor = pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineRowMajor); + cosineRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + } else { + auto cosineColMajor = pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineColMajor); + cosineColMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + } } RAFT_CUDA_TRY(cudaGetLastError()); @@ -207,13 +231,7 @@ void cosineAlgo1(Index_ m, { auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); }; - // Wrap fin_op to allow computing 1 - pA before calling fin_op - auto wrapped_fin_op = [fin_op] __device__(AccType d_val, Index_ g_d_idx) { - return fin_op(static_cast(1.0) - d_val, g_d_idx); - }; - - typedef std::is_same is_bool; - typedef typename std::conditional::type CosOutType; + typedef typename std::conditional::type CosOutType; CosOutType* pDcast = reinterpret_cast(pD); ASSERT( @@ -234,12 +252,12 @@ void cosineAlgo1(Index_ m, if (isRowMajor) { lda = k, ldb = k, ldd = n; - cosine( - m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, pDcast, wrapped_fin_op, stream); + cosine( + m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, pDcast, fin_op, stream); } else { lda = n, ldb = m, ldd = m; - cosine( - n, m, k, lda, ldb, ldd, pB, pA, row_vec, col_vec, pDcast, wrapped_fin_op, stream); + cosine( + n, m, k, lda, ldb, ldd, pB, pA, 
row_vec, col_vec, pDcast, fin_op, stream);
   }
 }
diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh
index 4782afe46e..1b6df7974e 100644
--- a/cpp/include/raft/distance/detail/distance.cuh
+++ b/cpp/include/raft/distance/detail/distance.cuh
@@ -615,6 +615,16 @@ void distance(const InType* x,
  * @note if workspace is passed as nullptr, this will return in
  * worksize, the number of bytes of workspace required
  */
+
+template <typename AccType, typename OutType, typename Index>
+struct default_fin_op {
+  __host__ __device__ default_fin_op(){};
+  // Identity functor: returns the accumulated distance value unchanged.
+  __host__ __device__ OutType operator()(AccType d_val, Index g_d_idx) const
+  {
+    return d_val;
+  }
+};
+
 template <raft::distance::DistanceType distanceType,
           typename InType,
           typename AccType,
           typename OutType,
           typename Index_ = int>
 void distance(const InType* x,
-  distance<distanceType, InType, AccType, OutType, decltype(default_fin_op), Index_>(
-    x, y, dist, m, n, k, workspace, worksize, default_fin_op, stream, isRowMajor, metric_arg);
+  using final_op_type = default_fin_op<AccType, OutType, Index_>;
+  final_op_type fin_op;
+
+  distance<distanceType, InType, AccType, OutType, final_op_type, Index_>(
+    x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh
index d83e81b6a9..0aade70d17 100644
--- a/cpp/include/raft/distance/detail/euclidean.cuh
+++ b/cpp/include/raft/distance/detail/euclidean.cuh
@@ -15,13 +15,31 @@
  */

 #pragma once
+
 #include
 #include
+#include <raft/distance/detail/pairwise_distance_cutlass_base.cuh>

 namespace raft {
 namespace distance {
 namespace detail {

+template <typename DataT, typename AccT>
+struct L2ExpandedOp {
+  bool sqrt;
+
+  __device__ __host__ L2ExpandedOp() : sqrt(false) {}
+  __device__ __host__ L2ExpandedOp(bool isSqrt) : sqrt(isSqrt) {}
+  __device__ __host__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const
+  {
+    AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;
+    return sqrt ? raft::mySqrt(outVal) : outVal;
+  }
+
+  __device__ __host__ AccT operator()(DataT aData) const { return aData; }
+};
+
 /**
  * @brief the expanded euclidean distance matrix calculation implementer
  * It computes the following equation: C = op(A^2 + B^2 - 2AB)
@@ -71,71 +89,81 @@ void euclideanExpImpl(const DataT* x,
                       FinalLambda fin_op,
                       cudaStream_t stream)
 {
-  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
-  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;
+  const auto deviceVersion = getMajorMinorVersion();
+  if (deviceVersion.first >= 8) {
+    using L2Op = L2ExpandedOp<DataT, AccT>;
+    L2Op L2_dist_op(sqrt);
-  typedef typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;
+    cutlassDistanceKernel<DataT, AccT, OutT, IdxT, VecLen, FinalLambda, L2Op, isRowMajor>(
+      x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, L2_dist_op, stream);
-  dim3 blk(KPolicy::Nthreads);
+  } else {
-  // Accumulation operation lambda
-  auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; };
+    typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
+    typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;
-  // epilogue operation lambda for final value calculation
-  auto epilog_lambda = [sqrt] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
-                                         DataT * regxn,
-                                         DataT * regyn,
-                                         IdxT gridStrideX,
-                                         IdxT gridStrideY) {
-#pragma unroll
-    for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
-#pragma unroll
-      for (int j = 0; j < KPolicy::AccColsPerTh; ++j) {
-        acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j];
-      }
-    }
-    if (sqrt) {
-#pragma unroll
+    typedef typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;
+
+    dim3 blk(KPolicy::Nthreads);
+
+    // Accumulation operation lambda
+    auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; };
+
+    // epilogue operation lambda for final value calculation
+    auto epilog_lambda = [sqrt] __device__(AccT
acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll + #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; } } - } - }; + if (sqrt) { + #pragma unroll + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { + #pragma unroll + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + acc[i][j] = raft::mySqrt(acc[i][j]); + } + } + } + }; - constexpr size_t shmemSize = - KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); - if (isRowMajor) { - auto euclideanExpRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); - - euclideanExpRowMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto euclideanExpColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); - euclideanExpColMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + constexpr size_t shmemSize = KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); + if (isRowMajor) { + auto euclideanExpRowMajor = pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); + + euclideanExpRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + } else { + auto euclideanExpColMajor = pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); + euclideanExpColMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + } } RAFT_CUDA_TRY(cudaGetLastError()); @@ -164,6 +192,7 @@ void euclideanExp(IdxT m, { size_t bytesA = sizeof(DataT) * lda; size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { euclideanExpImpl( x, y, xn, yn, m, n, k, lda, ldb, ldd, sqrt, dOutput, fin_op, stream); @@ -217,8 +246,7 @@ void euclideanAlgo1(Index_ m, { auto norm_op = [] __device__(InType in) { return in; }; - typedef std::is_same is_bool; - typedef typename std::conditional::type ExpOutType; + typedef typename std::conditional::type ExpOutType; ExpOutType* pDcast = reinterpret_cast(pD); ASSERT( diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh new file mode 100755 index 0000000000..216d567444 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#pragma once
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+
+#include <rmm/device_uvector.hpp>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/device/gemm.h"
+
+#include "cutlass/tensor_view.h"
+#include "cutlass/matrix_coord.h"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/layout/tensor.h"
+
+#include "./pairwise_distance_gemm.h"
+#include "./pairwise_distance_epilogue_elementwise.h"
+
+#define CUTLASS_CHECK(status)                                                                    \
+  {                                                                                              \
+    cutlass::Status error = status;                                                              \
+    if (error != cutlass::Status::kSuccess) {                                                    \
+      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
+                << std::endl;                                                                    \
+      exit(EXIT_FAILURE);                                                                        \
+    }                                                                                            \
+  }
+
+namespace raft {
+namespace distance {
+namespace detail {
+
+template <typename DataT,
+          typename AccT,
+          typename OutT,
+          typename IdxT,
+          int VecLen,
+          typename FinalLambda,
+          typename DistanceFn,
+          bool isRowMajor>
+void cutlassDistanceKernel(const DataT* x,
+                           const DataT* y,
+                           const DataT* xn,
+                           const DataT* yn,
+                           IdxT m,
+                           IdxT n,
+                           IdxT k,
+                           IdxT lda,
+                           IdxT ldb,
+                           IdxT ldd,
+                           OutT* dOutput,
+                           FinalLambda fin_op,
+                           DistanceFn dist_op,
+                           cudaStream_t stream) {
+
+  using EpilogueOutputOp = cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise<
+    DataT,  // ElementC_
+    AccT,   // ElementAccumulator_
+    DataT,  // ElementCompute_
+    AccT,   // ElementZ_
+    OutT,   // ElementT_
+    1,      // Elements per access 1
+    DistanceFn,
+    FinalLambda
+  >;
+  constexpr int batch_count = 1;
+
+  constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm;
+
+  typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op);
+
+  const DataT *a, *b;
+
+  IdxT gemm_lda, gemm_ldb;
+
+  // Number of pipelines you want to use
+  constexpr int NumStages = 3;
+  // Alignment
+  constexpr int Alignment = VecLen;
+
+  // default initialize problem size with row major inputs
+  auto problem_size =
+    cutlass::gemm::GemmCoord(static_cast<int>(n), static_cast<int>(m), static_cast<int>(k));
+
+  using cutlassDistKernel = typename cutlass::gemm::kernel::PairwiseDistanceGemm<
+    DataT, Alignment, DataT, Alignment,
+    AccT, AccT,
+    EpilogueOutputOp,
+    NumStages,  // Number of pipeline stages
+    isRowMajor
+  >::GemmKernel;
+
+  using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter<cutlassDistKernel>;
+
+  if constexpr (isRowMajor) {
+    a = y;
+    b = x;
+    gemm_lda = ldb;
+    gemm_ldb = lda;
+  } else {
+    problem_size =
+      cutlass::gemm::GemmCoord(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k));
+    a = x;
+    b = y;
+    gemm_lda = lda;
+    gemm_ldb = ldb;
+  }
+
+  typename cutlassDist::Arguments arguments{
+    mode,
+    problem_size,
+    batch_count,
+    epilog_op_param,
+    a,
+    b,
+    xn,          // C matrix eq vector param, which here is A norm
+    nullptr,     // tensor_Z,
+    (DataT*)yn,  // this is broadcast vec, which is required to be non-const param
+    dOutput,     // Output distance matrix
+    (int64_t)0,  // batch stride A
+    (int64_t)0,  // batch stride B
+    (int64_t)0,  // batch stride Norm A
+    (int64_t)0,
+    (int64_t)0,  // batch stride Norm B
+    (int64_t)0,  // batch stride Output
+    gemm_lda,    // stride A
+    gemm_ldb,    // stride B
+    1,           // stride A norm
+    0,           // this is no-op for Z
+    0,           // This must be zero
+    ldd          // stride Output matrix
+  };
+
+  // Using the arguments, query for extra workspace required for matrix multiplication computation
+  size_t workspace_size = cutlassDist::get_workspace_size(arguments);
+  // Allocate workspace memory
+  rmm::device_uvector<uint8_t> workspace(workspace_size, stream);
+  // Instantiate CUTLASS kernel depending on templates
+  cutlassDist cutlassDist_op;
+  // Check the problem size is supported or not
+  cutlass::Status status = cutlassDist_op.can_implement(arguments);
+  CUTLASS_CHECK(status);
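+  // Note: can_implement() validates the problem size and pointer/stride alignment
+  // against the instantiated tile shapes before anything is launched; initialize()
+  // below then builds the kernel's internal Params from these Arguments.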
+ // Initialize CUTLASS kernel with arguments and workspace pointer + status = cutlassDist_op.initialize(arguments, workspace.data(), stream); + CUTLASS_CHECK(status); + // Launch initialized CUTLASS kernel + status = cutlassDist_op(); + CUTLASS_CHECK(status); +} + +}; +}; +}; +#pragma GCC diagnostic pop \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h new file mode 100755 index 0000000000..a36d473596 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h @@ -0,0 +1,124 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" +#include "predicated_tile_iterator_normvec.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. 
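+/// PairwiseDistanceEpilogue wires the NormVec tile iterator into CUTLASS's
+/// EpilogueWithBroadcast: the C-matrix path re-reads one norm vector while the
+/// broadcast vector supplies the other, so the fused distance op sees both.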
+template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + typename LayoutT, + int ElementsPerAccess, + bool ScatterD = false +> +struct PairwiseDistanceEpilogue { + + /// Use defaults related to the existing epilogue + using Base = DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + ElementsPerAccess + >; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorNormVec< + typename Base::OutputTileThreadMap, + ElementOutput, + LayoutT + >; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + ElementTensor + >; + + /// Define the epilogue + using Epilogue = EpilogueWithBroadcast< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputTileIterator, + TensorTileIterator, + ElementVector, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + OutputOp, + typename Base::Padding, + Base::kFragmentsPerIteration + >; +}; + + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h new file mode 100755 index 0000000000..9833ef6cf3 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h @@ -0,0 +1,169 @@ +// +/*! \file + \brief Functor performing distance operations used by epilogues of pairwise distance + * kernels. 
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/array.h"
+#include "cutlass/functional.h"
+#include "cutlass/numeric_conversion.h"
+
+#include "cutlass/epilogue/thread/activation.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// This base class is meant to define the concept required of the
+/// EpilogueWithBroadcast::OutputOp
+template <
+  typename ElementC_,
+  typename ElementAccumulator_,
+  typename ElementCompute_,
+  typename ElementZ_,
+  typename ElementT_,
+  int ElementsPerAccess,
+  typename DistanceOp_,
+  typename FinalOp_
+>
+class PairwiseDistanceEpilogueElementwise {
+public:
+
+  using ElementOutput = ElementC_;
+  using ElementC = ElementC_;
+  using ElementAccumulator = ElementAccumulator_;
+  using ElementCompute = ElementCompute_;
+  using ElementZ = ElementZ_;
+  using ElementT = ElementT_;
+  static int const kElementsPerAccess = ElementsPerAccess;
+  static int const kCount = kElementsPerAccess;
+
+  using DistanceOp = DistanceOp_;
+  using FinalOp = FinalOp_;
+
+  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
+  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
+  using FragmentC = Array<ElementC, kElementsPerAccess>;
+  using FragmentZ = Array<ElementZ, kElementsPerAccess>;
+  using FragmentT = Array<ElementT, kElementsPerAccess>;
+
+  using FragmentOutput = FragmentZ;
+
+  static bool const kIsHeavy = false;  // ElementwiseOp::kIsHeavy;
+
+  /// If true, the 'Z' tensor is stored
+  static bool const kStoreZ = false;  // We don't store anything in Z.
+
+  /// If true, the 'T' tensor is stored
+  static bool const kStoreT = true;  // this is our final output storage.
+
+  /// Host-constructable parameters structure
+  struct Params {
+    FinalOp_ final_op_tmp;
+    DistanceOp_ dist_op_tmp;
+
+    //
+    // Methods
+    //
+    CUTLASS_HOST_DEVICE
+    Params(DistanceOp_ dist_op, FinalOp lambdafn_):
+      final_op_tmp(lambdafn_),
+      dist_op_tmp(dist_op)
+    {
+    }
+
+    CUTLASS_HOST_DEVICE
+    Params()
+    {
+    }
+  };
+
+private:
+  //
+  // Data members
+  //
+  FinalOp_ final_op;
+  DistanceOp_ elementwise_op;
+
+public:
+  //
+  // Methods
+  //
+
+  /// Constructor from Params
+  CUTLASS_HOST_DEVICE
+  PairwiseDistanceEpilogueElementwise(Params const &params) :
+    final_op(params.final_op_tmp), elementwise_op(params.dist_op_tmp) {
+  }
+
+  /// Returns true if source is needed
+  CUTLASS_HOST_DEVICE
+  bool is_source_needed() const {
+    // return true here to force the C-matrix load path, which carries the A-norm vector.
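+    // (the B-norm vector reaches the epilogue separately, via the broadcast
+    // input V consumed in operator() below.)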
+ return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { +#if 0 + if (k_partition) { + beta_ = ElementCompute(1); + } + + if (k_partition != k_partition_count - 1) { + skip_elementwise_ = true; + } +#endif + } + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()( + FragmentZ &frag_Z, + FragmentT &frag_T, + FragmentAccumulator const &AB, + FragmentC const &frag_C, + FragmentCompute const &V) const { + + FragmentCompute tmp_Accum = NumericArrayConverter()(AB); + FragmentCompute tmp_C = NumericArrayConverter()(frag_C); + FragmentCompute result_Z; + FragmentCompute result_T; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); + result_T[i] = final_op(result_Z[i], 0); + } + + NumericArrayConverter convert_t; + frag_T = convert_t(result_T); + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()( + FragmentZ &frag_Z, + FragmentT &frag_T, + FragmentAccumulator const &AB, + FragmentCompute const &V) const { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h new file mode 100755 index 0000000000..15635edaee --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h @@ -0,0 +1,206 @@ +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/kernel/gemm_with_fused_epilogue.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" + +//#include "./epilogue_with_bcast_threadblock.h" +#include "pairwise_distance_epilogue.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Element type for final output + //typename ElementOutT, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor +> +struct PairwiseDistanceGemm { + + // This struct is specialized for fp32/3xTF32 + + /// Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 + /// Warp-level 
tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastF32; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + /// data layout for final output matrix. + // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std::conditional::type LayoutA_; + + typedef typename std::conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal< + ElementA_, LayoutA_, cutlass::ComplexTransform::kNone, kAlignmentA, + ElementB_, LayoutB_, cutlass::ComplexTransform::kNone, kAlignmentB, + ElementC_, LayoutOutput, ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator + >::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementAccumulator, + typename EpilogueOutputOp::ElementT, + ElementAccumulator, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue< + typename GemmBase::Mma, + Epilogue, + ThreadblockSwizzle + >; +}; + +template < + /// Layout type for A matrix operand + int kAlignmentA, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor +> +struct PairwiseDistanceGemm { + + //using Transform = cutlass::ComplexTransform::kNone; + // Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 128, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + // Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAdd; + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how 
threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + /// data layout for final output matrix. + // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std::conditional::type LayoutA_; + + typedef typename std::conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal< + double, LayoutA_, cutlass::ComplexTransform::kNone, 1, + double, LayoutB_, cutlass::ComplexTransform::kNone, 1, + ElementC_, LayoutOutput, ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator + >::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + ElementC_, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue< + typename GemmBase::Mma, + Epilogue, + ThreadblockSwizzle + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h new file mode 100755 index 0000000000..a1536b2493 --- /dev/null +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -0,0 +1,670 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. 
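+/// This NormVec variant differs from the stock PredicatedTileIterator in its
+/// load path: every column in a row loads element 0 of that row's pointer,
+/// broadcasting a per-row norm value across the whole tile width.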
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + typename Layout_, + bool ScatterD = false, ///< Scatter D operand or not + bool UseCUDAStore = false +> +class PredicatedTileIteratorNormVec { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + //using Layout = layout::RowMajor; + using Layout = Layout_; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout): + PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) + { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const *indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorNormVec( + PredicatedTileIteratorParams const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const *indices = nullptr + ): + params_(params), indices_(indices) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride); + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast(byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + 
CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)&memory_pointer[0], + guard); + + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast(byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) const { + + store_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset(Fragment &frag, int64_t byte_offset, int convolution_P, int convolution_Q, int add_P, int add_Q, int problem_N) const { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < 
ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = (input_row-output_row)*problem_N*sizeof(float); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset(Fragment &frag, int64_t byte_offset, int convolution_P, int convolution_Q, int add_P, int add_Q, int problem_N) const { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) row_add_P = 0; + if (output_Q > convolution_Q - 2) row_add_Q = 0; + + int input_row = output_N * (convolution_P/2) * (convolution_Q/2) + + ((output_P + row_add_P)/2) * (convolution_Q/2) + (output_Q + row_add_Q)/2; + + int64_t byte_offset = (input_row-output_row)*problem_N*sizeof(float); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + 
column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorNormVec &operator++() { + + ++state_[0]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_row; + } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + + + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 3db1749bb4..eabc1030b6 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -315,6 +315,7 @@ void pairwise_distance(const raft::handle_t& handle, detail::pairwise_distance_impl( x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; +#if 0 case raft::distance::DistanceType::L1: detail::pairwise_distance_impl( x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); @@ -364,6 +365,7 @@ void pairwise_distance(const raft::handle_t& handle, pairwise_distance_impl( x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; +#endif default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; } diff --git a/cpp/include/raft/distance/specializations/distance.cuh 
b/cpp/include/raft/distance/specializations/distance.cuh
index 7553f87e39..dce1926e66 100644
--- a/cpp/include/raft/distance/specializations/distance.cuh
+++ b/cpp/include/raft/distance/specializations/distance.cuh
@@ -16,17 +16,17 @@
 
 #pragma once
 
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
+// #include
+// #include
+// #include
+// #include
+// #include
+// #include
+// #include
+// #include
+// #include
 #include
 #include
 #include
-#include
-#include
+// #include
+// #include
diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu
index 16c6e11719..7d3064af7b 100644
--- a/cpp/test/distance/dist_adj.cu
+++ b/cpp/test/distance/dist_adj.cu
@@ -26,7 +26,7 @@ namespace raft {
 namespace distance {
 
 template <typename DataType>
-__global__ void naiveDistanceAdjKernel(bool* dist,
+__global__ void naiveDistanceAdjKernel(uint8_t* dist,
                                        const DataType* x,
                                        const DataType* y,
                                        int m,
@@ -50,7 +50,7 @@ __global__ void naiveDistanceAdjKernel(bool* dist,
 }
 
 template <typename DataType>
-void naiveDistanceAdj(bool* dist,
+void naiveDistanceAdj(uint8_t* dist,
                      const DataType* x,
                      const DataType* y,
                      int m,
@@ -74,6 +74,17 @@ struct DistanceAdjInputs {
   unsigned long long int seed;
 };
 
+template <typename DataT, typename OutT, typename AccT, typename Index>
+struct threshold_final_op {
+  DataT threshold_val;
+
+  __device__ __host__ threshold_final_op() : threshold_val(0.0) {}
+  __device__ __host__ threshold_final_op(DataT val) : threshold_val(val) {}
+  __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const {
+    return d_val <= threshold_val;
+  }
+};
+
 template <typename DataType>
 ::std::ostream& operator<<(::std::ostream& os, const DistanceAdjInputs<DataType>& dims)
 {
@@ -109,14 +120,18 @@ class DistanceAdjTest : public ::testing::TestWithParam<DistanceAdjInputs<DataType>> {
     size_t worksize = raft::distance::
-      getWorkspaceSize<raft::distance::DistanceType::L2Expanded, DataType, DataType, bool>(
+      getWorkspaceSize<raft::distance::DistanceType::L2Expanded, DataType, DataType, uint8_t>(
         x.data(), y.data(), m, n, k);
     rmm::device_uvector<char> workspace(worksize, stream);
-
+#if 0
     auto fin_op = [threshold] __device__(DataType d_val, int g_d_idx) {
       return d_val <= threshold;
     };
-    raft::distance::distance<raft::distance::DistanceType::L2Expanded, DataType, DataType, bool>(
+#else
+    using threshold_final_op_ = threshold_final_op<DataType, uint8_t, DataType, int>;
+    threshold_final_op_ threshold_op(threshold);
+#endif
+    raft::distance::distance<raft::distance::DistanceType::L2Expanded, DataType, DataType, uint8_t, threshold_final_op_>(
       x.data(),
       y.data(),
       dist.data(),
@@ -125,7 +140,7 @@ class DistanceAdjTest : public ::testing::TestWithParam<DistanceAdjInputs<DataType>> {
       k,
       workspace.data(),
       worksize,
-      fin_op,
+      threshold_op,
       stream,
       isRowMajor);
 
   DistanceAdjInputs<DataType> params;
-  rmm::device_uvector<bool> dist_ref;
-  rmm::device_uvector<bool> dist;
+  rmm::device_uvector<uint8_t> dist_ref;
+  rmm::device_uvector<uint8_t> dist;
   raft::handle_t handle;
   cudaStream_t stream;
 };
@@ -156,7 +171,7 @@ TEST_P(DistanceAdjTestF, Result)
 {
   int m = params.isRowMajor ? params.m : params.n;
   int n = params.isRowMajor ? params.n : params.m;
-  ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare<bool>(), stream));
+  ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare<uint8_t>(), stream));
 }
 
 INSTANTIATE_TEST_CASE_P(DistanceAdjTests, DistanceAdjTestF, ::testing::ValuesIn(inputsf));
@@ -175,7 +190,7 @@ TEST_P(DistanceAdjTestD, Result)
 {
   int m = params.isRowMajor ? params.m : params.n;
   int n = params.isRowMajor ? params.n : params.m;
-  ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare<bool>(), stream));
+  ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare<uint8_t>(), stream));
 }
 
 INSTANTIATE_TEST_CASE_P(DistanceAdjTests, DistanceAdjTestD, ::testing::ValuesIn(inputsd));
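For context on dist_adj.cu above: the extended __device__ lambda is swapped for the named threshold_final_op functor because the CUTLASS-backed L2Expanded path takes the final operation as an ordinary class type, one that is host-constructible and usable as a template argument, which an extended device lambda cannot be. A minimal sketch of a functor meeting those constraints; the names below are hypothetical, not the test's exact types.

#include <cstdint>

// Hypothetical final-op functor: constructed on the host, applied per
// element on the device; signature shape is (value, index) -> output,
// matching the fin_op lambda it replaces.
template <typename DataT, typename OutT, typename AccT, typename IdxT>
struct threshold_op_sketch {
  DataT threshold;

  __host__ __device__ threshold_op_sketch(DataT t) : threshold(t) {}

  __device__ __host__ OutT operator()(AccT d_val, IdxT) const
  {
    return static_cast<OutT>(d_val <= threshold);
  }
};

// Usage sketch: threshold_op_sketch<float, uint8_t, float, int> op(0.5f);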
diff --git a/cpp/test/distance/dist_euc_exp.cu b/cpp/test/distance/dist_euc_exp.cu
index ff142da7fa..5371b8a3e2 100644
--- a/cpp/test/distance/dist_euc_exp.cu
+++ b/cpp/test/distance/dist_euc_exp.cu
@@ -25,14 +25,17 @@ class DistanceEucExpTest : public DistanceTest<raft::distance::DistanceType::L2Expanded, DataType> {
 };
 
 const std::vector<DistanceInputs<float>> inputsf = {
+  {0.001f, 2048, 4096, 128, true, 1234ULL},
   {0.001f, 1024, 1024, 32, true, 1234ULL},
   {0.001f, 1024, 32, 1024, true, 1234ULL},
   {0.001f, 32, 1024, 1024, true, 1234ULL},
   {0.003f, 1024, 1024, 1024, true, 1234ULL},
+  {0.003f, 1021, 1021, 1021, true, 1234ULL},
   {0.001f, 1024, 1024, 32, false, 1234ULL},
   {0.001f, 1024, 32, 1024, false, 1234ULL},
   {0.001f, 32, 1024, 1024, false, 1234ULL},
   {0.003f, 1024, 1024, 1024, false, 1234ULL},
+  {0.003f, 1021, 1021, 1021, false, 1234ULL},
 };
 typedef DistanceEucExpTest<float> DistanceEucExpTestF;
 TEST_P(DistanceEucExpTestF, Result)
diff --git a/cpp/test/distance/dist_eucsqrt_exp.cu b/cpp/test/distance/dist_eucsqrt_exp.cu
new file mode 100755
index 0000000000..90b4f4288c
--- /dev/null
+++ b/cpp/test/distance/dist_eucsqrt_exp.cu
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "../test_utils.h"
+#include "distance_base.cuh"
+
+namespace raft {
+namespace distance {
+
+template <typename DataType>
+class DistanceEucSqrtExpTest
+  : public DistanceTest<raft::distance::DistanceType::L2SqrtExpanded, DataType> {
+};
+
+const std::vector<DistanceInputs<float>> inputsf = {
+  {0.001f, 2048, 4096, 128, true, 1234ULL},
+  {0.001f, 1024, 1024, 32, true, 1234ULL},
+  {0.001f, 1024, 32, 1024, true, 1234ULL},
+  {0.001f, 32, 1024, 1024, true, 1234ULL},
+  {0.003f, 1024, 1024, 1024, true, 1234ULL},
+  {0.003f, 1021, 1021, 1021, true, 1234ULL},
+  {0.001f, 1024, 1024, 32, false, 1234ULL},
+  {0.001f, 1024, 32, 1024, false, 1234ULL},
+  {0.001f, 32, 1024, 1024, false, 1234ULL},
+  {0.003f, 1024, 1024, 1024, false, 1234ULL},
+  {0.003f, 1021, 1021, 1021, false, 1234ULL},
+};
+typedef DistanceEucSqrtExpTest<float> DistanceEucSqrtExpTestF;
+TEST_P(DistanceEucSqrtExpTestF, Result)
+{
+  int m = params.isRowMajor ? params.m : params.n;
+  int n = params.isRowMajor ? params.n : params.m;
+  ASSERT_TRUE(devArrMatch(
+    dist_ref.data(), dist.data(), m, n, raft::CompareApprox<float>(params.tolerance), stream));
+}
+INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestF, ::testing::ValuesIn(inputsf));
+
+const std::vector<DistanceInputs<double>> inputsd = {
+  {0.001, 1024, 1024, 32, true, 1234ULL},
+  {0.001, 1024, 32, 1024, true, 1234ULL},
+  {0.001, 32, 1024, 1024, true, 1234ULL},
+  {0.003, 1024, 1024, 1024, true, 1234ULL},
+  {0.001, 1024, 1024, 32, false, 1234ULL},
+  {0.001, 1024, 32, 1024, false, 1234ULL},
+  {0.001, 32, 1024, 1024, false, 1234ULL},
+  {0.003, 1024, 1024, 1024, false, 1234ULL},
+};
+typedef DistanceEucSqrtExpTest<double> DistanceEucSqrtExpTestD;
+TEST_P(DistanceEucSqrtExpTestD, Result)
+{
+  int m = params.isRowMajor ? params.m : params.n;
+  int n = params.isRowMajor ? params.n : params.m;
+  ASSERT_TRUE(devArrMatch(
+    dist_ref.data(), dist.data(), m, n, raft::CompareApprox<double>(params.tolerance), stream));
+}
+INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestD, ::testing::ValuesIn(inputsd));
+
+class BigMatrixEucSqrtExp
+  : public BigMatrixDistanceTest<raft::distance::DistanceType::L2SqrtExpanded> {
+};
+TEST_F(BigMatrixEucSqrtExp, Result) {}
+} // end namespace distance
+} // end namespace raft
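A note on what the new dist_eucsqrt_exp.cu suite above covers before the next patch: the L2SqrtExpanded metric, i.e. the Euclidean distance computed in expanded form, d(x, y) = sqrt(||x||^2 + ||y||^2 - 2<x, y>), which is the identity the CUTLASS epilogue fuses with the GEMM. A host-side reference sketch of that identity, illustrative only and not the suite's naive kernel:

#include <cmath>
#include <cstddef>

// xn2 and yn2 are precomputed squared L2 norms of rows x and y of length k.
float l2_sqrt_expanded_ref(const float* x, const float* y, std::size_t k, float xn2, float yn2)
{
  float dot = 0.f;
  for (std::size_t i = 0; i < k; ++i) dot += x[i] * y[i];
  float d2 = xn2 + yn2 - 2.f * dot;       // expanded squared distance
  return std::sqrt(d2 > 0.f ? d2 : 0.f);  // clamp tiny negative rounding error
}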
From a9dabc8ccf65c75d7596e2f4439ac124367b3a63 Mon Sep 17 00:00:00 2001
From: Mahesh Doijade
Date: Fri, 21 Oct 2022 15:29:39 +0530
Subject: [PATCH 02/25] add prior ampere pairwisedistmat kernel to prevent
 redundant kernel compilation, fix clang formatting and correct line endings

---
 cpp/include/raft/core/cudart_utils.hpp        |    3 +-
 cpp/include/raft/distance/detail/cosine.cuh   |   75 +-
 cpp/include/raft/distance/detail/distance.cuh |    8 +-
 .../raft/distance/detail/euclidean.cuh        |   87 +-
 .../detail/pairwise_distance_base.cuh         |   83 ++
 .../detail/pairwise_distance_cutlass_base.cuh |  335 +++--
 .../detail/pairwise_distance_epilogue.h       |  233 ++-
 .../pairwise_distance_epilogue_elementwise.h  |  332 +++--
 .../distance/detail/pairwise_distance_gemm.h  |  430 +++---
 .../detail/predicated_tile_iterator_normvec.h | 1261 ++++++++---------
 cpp/test/distance/dist_adj.cu                 |   40 +-
 cpp/test/distance/dist_eucsqrt_exp.cu         |  150 +-
 12 files changed, 1516 insertions(+), 1521 deletions(-)
 mode change 100755 => 100644 cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh
 mode change 100755 => 100644 cpp/include/raft/distance/detail/pairwise_distance_epilogue.h
 mode change 100755 => 100644 cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h
 mode change 100755 => 100644 cpp/include/raft/distance/detail/pairwise_distance_gemm.h
 mode change 100755 => 100644 cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h
 mode change 100755 => 100644 cpp/test/distance/dist_eucsqrt_exp.cu

diff --git a/cpp/include/raft/core/cudart_utils.hpp b/cpp/include/raft/core/cudart_utils.hpp
index a878d3ae77..8c6aee0780 100644
--- a/cpp/include/raft/core/cudart_utils.hpp
+++ b/cpp/include/raft/core/cudart_utils.hpp
@@ -355,7 +355,7 @@ inline int getMultiProcessorCount()
 }
 
 /** helper method to get the compute capability (major, minor) of the current device */
-inline std::pair<int, int> getMajorMinorVersion()
+inline std::pair<int, int> getMajorMinorVersion()
 {
   int devId;
   RAFT_CUDA_TRY(cudaGetDevice(&devId));
@@ -366,7 +366,6 @@ inline std::pair<int, int> getMajorMinorVersion()
   return std::make_pair(majorVer, minorVer);
 }
 
-
 /** helper method to convert an array on device to a string on host */
 template <typename T>
 std::string arr2Str(const T* arr, int size, std::string name, cudaStream_t stream, int width = 4)
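For context on cudart_utils.hpp above: getMajorMinorVersion() is the runtime switch that cosine.cuh and euclidean.cuh (below) use to choose between the CUTLASS tensor-op path on compute capability 8.0 and newer and the pre-Ampere pairwise kernel. The dispatch shape, sketched with placeholder launch functions; only getMajorMinorVersion is the real RAFT helper:

#include <utility>

std::pair<int, int> getMajorMinorVersion();  // from raft/core/cudart_utils.hpp
void run_cutlass_path();                     // placeholder for cutlassDistanceKernel<...>
void run_prior_ampere_path();                // placeholder for the legacy kernel launch

void dispatch_by_arch()
{
  const auto cc = getMajorMinorVersion();  // e.g. {8, 0} on an A100
  if (cc.first >= 8) {
    run_cutlass_path();       // Ampere and newer: fused GEMM + epilogue via CUTLASS
  } else {
    run_prior_ampere_path();  // older architectures: existing pairwise kernel
  }
}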
diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh
index 01b1388a92..4a94d02fc4 100644
--- a/cpp/include/raft/distance/detail/cosine.cuh
+++ b/cpp/include/raft/distance/detail/cosine.cuh
@@ -17,8 +17,8 @@
 #pragma once
 
 #include
-#include
 #include
+#include
 
 namespace raft {
 namespace distance {
@@ -26,16 +26,14 @@ namespace detail {
 
 template <typename DataT, typename AccT>
 struct CosineOp {
-  __device__ __host__ CosineOp() { }
-  __device__ __host__ AccT operator() (DataT &aNorm, const DataT &bNorm, DataT &accVal) const {
-    return static_cast<AccT>(1.0) - (AccT) (accVal / (aNorm * bNorm));
-  }
-  __device__ __host__ AccT operator() (DataT aData) const {
-    return aData;
-  }
+  __device__ __host__ CosineOp() {}
+  __device__ __host__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const
+  {
+    return static_cast<AccT>(1.0) - (AccT)(accVal / (aNorm * bNorm));
+  }
+  __device__ __host__ AccT operator()(DataT aData) const { return aData; }
 };
-
 /**
  * @brief the cosine distance matrix calculation implementer
  *  It computes the following equation:
@@ -84,16 +82,15 @@ void cosineImpl(const DataT* x,
                 FinalLambda fin_op,
                 cudaStream_t stream)
 {
-  const auto deviceVersion = getMajorMinorVersion();
+  const auto deviceVersion = getMajorMinorVersion();
   if (deviceVersion.first >= 8) {
     using CosineOp_ = CosineOp<DataT, AccT>;
     CosineOp_ cosine_dist_op;
 
     cutlassDistanceKernel<DataT, AccT, OutT, IdxT, VecLen, FinalLambda, CosineOp_, isRowMajor>(
-        x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, cosine_dist_op, stream);
+      x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, cosine_dist_op, stream);
   } else {
-
     typedef typename raft::linalg::Policy4x4<DataT, veclen_>::Policy RowPolicy;
     typedef typename raft::linalg::Policy4x4<DataT, veclen_>::ColPolicy ColPolicy;
@@ -106,15 +103,15 @@ void cosineImpl(const DataT* x,
     // epilogue operation lambda for final value calculation
     auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
-                                       DataT * regxn,
-                                       DataT * regyn,
-                                       IdxT gridStrideX,
-                                       IdxT gridStrideY) {
-  #pragma unroll
+                                       DataT * regxn,
+                                       DataT * regyn,
+                                       IdxT gridStrideX,
+                                       IdxT gridStrideY) {
+#pragma unroll
       for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
-  #pragma unroll
+#pragma unroll
         for (int j = 0; j < KPolicy::AccColsPerTh; ++j) {
-          acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j]) );
+          acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j]));
        }
      }
    };
@@ -122,30 +119,30 @@ void cosineImpl(const DataT* x,
     constexpr size_t shmemSize =
       KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT));
     if (isRowMajor) {
-      auto cosineRowMajor = pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
-        decltype(core_lambda), decltype(epilog_lambda), FinalLambda, true>;
+      auto cosineRowMajor = pairwiseDistanceMatKernelPriorToAmpere<true, DataT, AccT, OutT, IdxT,
+        KPolicy, decltype(core_lambda), decltype(epilog_lambda), FinalLambda, true>;
       dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineRowMajor);
       cosineRowMajor<<<grid, blk, shmemSize, stream>>>(
         x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op);
     } else {
-      auto cosineColMajor = pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
-        decltype(core_lambda), decltype(epilog_lambda), FinalLambda, false>;
+      auto cosineColMajor = pairwiseDistanceMatKernelPriorToAmpere<true, DataT, AccT, OutT, IdxT,
+        KPolicy, decltype(core_lambda), decltype(epilog_lambda), FinalLambda, false>;
       dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineColMajor);
       cosineColMajor<<<grid, blk, shmemSize, stream>>>(
         x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op);
diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh
index 1b6df7974e..7dda15f573 100644
--- a/cpp/include/raft/distance/detail/distance.cuh
+++ b/cpp/include/raft/distance/detail/distance.cuh
@@ -618,11 +618,9 @@ void distance(const InType* x,
 
 template <typename AccType, typename OutType, typename Index>
 struct default_fin_op {
-  __host__ __device__ default_fin_op() { };
-  // functor signature.
- __host__ __device__ OutType operator()(AccType d_val, Index g_d_idx) const { - return d_val; - } + __host__ __device__ default_fin_op(){}; + // functor signature. + __host__ __device__ OutType operator()(AccType d_val, Index g_d_idx) const { return d_val; } }; template -#include #include +#include namespace raft { namespace distance { @@ -26,18 +26,17 @@ namespace detail { template struct L2ExpandedOp { - bool sqrt; - - __device__ __host__ L2ExpandedOp() : sqrt(false) { } - __device__ __host__ L2ExpandedOp(bool isSqrt) : sqrt(isSqrt) { } - __device__ __host__ AccT operator() (DataT &aNorm, const DataT &bNorm, DataT &accVal) const { - AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; - return sqrt ? raft::mySqrt(outVal) : outVal; - } + bool sqrt; + + __device__ __host__ L2ExpandedOp() : sqrt(false) {} + __device__ __host__ L2ExpandedOp(bool isSqrt) : sqrt(isSqrt) {} + __device__ __host__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const + { + AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; + return sqrt ? raft::mySqrt(outVal) : outVal; + } - __device__ __host__ AccT operator() (DataT aData) const { - return aData; - } + __device__ __host__ AccT operator()(DataT aData) const { return aData; } }; /** @@ -89,16 +88,15 @@ void euclideanExpImpl(const DataT* x, FinalLambda fin_op, cudaStream_t stream) { - const auto deviceVersion = getMajorMinorVersion(); + const auto deviceVersion = getMajorMinorVersion(); if (deviceVersion.first >= 8) { using L2Op = L2ExpandedOp; L2Op L2_dist_op(sqrt); cutlassDistanceKernel( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, L2_dist_op, stream); + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, L2_dist_op, stream); } else { - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; @@ -111,21 +109,21 @@ void euclideanExpImpl(const DataT* x, // epilogue operation lambda for final value calculation auto epilog_lambda = [sqrt] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - #pragma unroll + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { +#pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { - #pragma unroll +#pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; } } if (sqrt) { - #pragma unroll +#pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { - #pragma unroll +#pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { acc[i][j] = raft::mySqrt(acc[i][j]); } @@ -133,33 +131,34 @@ void euclideanExpImpl(const DataT* x, } }; - constexpr size_t shmemSize = KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); + constexpr size_t shmemSize = + KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); if (isRowMajor) { - auto euclideanExpRowMajor = pairwiseDistanceMatKernel; + auto euclideanExpRowMajor = pairwiseDistanceMatKernelPriorToAmpere; dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); euclideanExpRowMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); } else { - auto euclideanExpColMajor = pairwiseDistanceMatKernel; + auto euclideanExpColMajor = pairwiseDistanceMatKernelPriorToAmpere; dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); euclideanExpColMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, 
epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 9d203c0c4f..5faba0486e 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -364,6 +364,89 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) obj.run(); } +/** + * @brief the distance matrix calculation kernel for L2 and cosine + * for GPU arch < SM 8.0, this version is to make sure we don't recompile + * these kernels for ampere or higher as we use cutlass kernel for it. + * @tparam useNorms whether norms are needed + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Policy struct which tunes the Contraction kernel + * @tparam CoreLambda lambda which implements accumulation operation + * @tparam EpilogueLambda lambda which implements operation for calculating + final value. + * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major(default), + false for column major + * + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] xn row norms of input matrix A. + * @param[in] yn row norms of input matrix B. + * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] pD output matrix + * @param core_op the core lambda + * @param epilog_op the epilogue lambda + * @param fin_op the final gemm epilogue lambda + */ + +template +__global__ __launch_bounds__(Policy::Nthreads, 2) + + void pairwiseDistanceMatKernelPriorToAmpere(const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + CoreLambda core_op, + EpilogueLambda epilog_op, + FinalLambda fin_op) +{ +#if __CUDA_ARCH__ < 800 + extern __shared__ char smem[]; + auto rowEpilog = [] __device__(IdxT starty) { return; }; + + PairwiseDistances + obj( + x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, epilog_op, fin_op, rowEpilog); + obj.run(); +#endif +} + template dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) { diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh old mode 100755 new mode 100644 index 216d567444..ea216f8c73 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -1,169 +1,168 @@ -/* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" - -#include - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/device/gemm.h" - -#include "cutlass/tensor_view.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" - -#include "./pairwise_distance_gemm.h" -#include "./pairwise_distance_epilogue_elementwise.h" - -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - if (error != cutlass::Status::kSuccess) { \ - std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ - << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } - -namespace raft { -namespace distance { -namespace detail { - -template -void cutlassDistanceKernel(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - DistanceFn dist_op, - cudaStream_t stream) { - - using EpilogueOutputOp = cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise< - DataT, // ElementC_ - AccT, // ElementAccumulator_ - DataT, // ElementCompute_ - AccT, // ElementZ_ - OutT, // ElementT_ - 1, // Elements per access 1 - DistanceFn, - FinalLambda - >; - constexpr int batch_count = 1; - - constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; - - typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op); - - const DataT *a, *b; - - IdxT gemm_lda, gemm_ldb; - - // Number of pipelines you want to use - constexpr int NumStages = 3; - // Alignment - constexpr int Alignment = VecLen; - - // default initialize problem size with row major inputs - auto problem_size = cutlass::gemm::GemmCoord(static_cast(n), static_cast(m), static_cast(k)); - - using cutlassDistKernel = typename cutlass::gemm::kernel::PairwiseDistanceGemm< - DataT, Alignment, DataT, Alignment, - AccT, AccT, - EpilogueOutputOp, - NumStages, // Number of pipeline stages - isRowMajor - >::GemmKernel; - - using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; - - if constexpr (isRowMajor) { - a = y; - b = x; - gemm_lda = ldb; - gemm_ldb = lda; - } else { - problem_size = cutlass::gemm::GemmCoord(static_cast(m), static_cast(n), static_cast(k)); - a = x; - b = y; - gemm_lda = lda; - gemm_ldb = ldb; - } - - typename cutlassDist::Arguments arguments { - mode, - problem_size, - batch_count, - epilog_op_param, - a, - b, - xn, // C matrix eq vector param, which here is A norm - nullptr, //tensor_Z, - (DataT*)yn, // this is broadcast vec, which is required to be non-const param - dOutput, // Output distance matrix - (int64_t)0, // batch stride A - (int64_t) 0, // batch stride B - (int64_t)0, // batch stride Norm A - (int64_t)0, - (int64_t)0, // batch stride Norm B - (int64_t)0, // batch stride Output - gemm_lda, // stride A - gemm_ldb, // stride B - 1, // stride A norm - 0, // this is no-op for Z - 0, // This must be zero - ldd // stride Output matrix - }; - - // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = cutlassDist::get_workspace_size(arguments); - // Allocate workspace memory - rmm::device_uvector workspace(workspace_size, stream); - // Instantiate CUTLASS kernel depending on templates - cutlassDist cutlassDist_op; - // Check the problem size is supported or not - cutlass::Status status = cutlassDist_op.can_implement(arguments); - CUTLASS_CHECK(status); 
- // Initialize CUTLASS kernel with arguments and workspace pointer - status = cutlassDist_op.initialize(arguments, workspace.data(), stream); - CUTLASS_CHECK(status); - // Launch initialized CUTLASS kernel - status = cutlassDist_op(); - CUTLASS_CHECK(status); -} - -}; -}; -}; +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/tensor_view.h" + +#include "./pairwise_distance_epilogue_elementwise.h" +#include "./pairwise_distance_gemm.h" + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +namespace raft { +namespace distance { +namespace detail { + +template +void cutlassDistanceKernel(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + FinalLambda fin_op, + DistanceFn dist_op, + cudaStream_t stream) +{ + using EpilogueOutputOp = + cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise; + constexpr int batch_count = 1; + + constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; + + typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op); + + const DataT *a, *b; + + IdxT gemm_lda, gemm_ldb; + + // Number of pipelines you want to use + constexpr int NumStages = 3; + // Alignment + constexpr int Alignment = VecLen; + + // default initialize problem size with row major inputs + auto problem_size = + cutlass::gemm::GemmCoord(static_cast(n), static_cast(m), static_cast(k)); + + using cutlassDistKernel = + typename cutlass::gemm::kernel::PairwiseDistanceGemm::GemmKernel; + + using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; + + if constexpr (isRowMajor) { + a = y; + b = x; + gemm_lda = ldb; + gemm_ldb = lda; + } else { + problem_size = + cutlass::gemm::GemmCoord(static_cast(m), static_cast(n), static_cast(k)); + a = x; + b = y; + gemm_lda = lda; + gemm_ldb = ldb; + } + + typename cutlassDist::Arguments arguments{ + mode, problem_size, batch_count, epilog_op_param, a, b, + xn, // C matrix eq vector param, which here is A norm + nullptr, // tensor_Z, + (DataT*)yn, // this is broadcast vec, which is required to be non-const param + dOutput, // Output distance matrix + (int64_t)0, // batch stride A + (int64_t)0, // batch stride B + (int64_t)0, // batch stride Norm A + (int64_t)0, + (int64_t)0, // batch stride Norm B + (int64_t)0, // batch stride Output + gemm_lda, // stride A + gemm_ldb, // stride B + 1, // stride A norm 
+ 0, // this is no-op for Z + 0, // This must be zero + ldd // stride Output matrix + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = cutlassDist::get_workspace_size(arguments); + // Allocate workspace memory + rmm::device_uvector workspace(workspace_size, stream); + // Instantiate CUTLASS kernel depending on templates + cutlassDist cutlassDist_op; + // Check the problem size is supported or not + cutlass::Status status = cutlassDist_op.can_implement(arguments); + CUTLASS_CHECK(status); + // Initialize CUTLASS kernel with arguments and workspace pointer + status = cutlassDist_op.initialize(arguments, workspace.data(), stream); + CUTLASS_CHECK(status); + // Launch initialized CUTLASS kernel + status = cutlassDist_op(); + CUTLASS_CHECK(status); +} + +}; // namespace detail +}; // namespace distance +}; // namespace raft #pragma GCC diagnostic pop \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h old mode 100755 new mode 100644 index a36d473596..e2360809bb --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h @@ -1,124 +1,109 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. 
- -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/array.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -#include "cutlass/epilogue/threadblock/epilogue.h" -#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" -#include "predicated_tile_iterator_normvec.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Defines sensible defaults for epilogues for TensorOps. -template < - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename ElementOutput, - typename ElementTensor, - typename ElementVector, - typename OutputOp, - typename LayoutT, - int ElementsPerAccess, - bool ScatterD = false -> -struct PairwiseDistanceEpilogue { - - /// Use defaults related to the existing epilogue - using Base = DefaultEpilogueTensorOp< - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputOp, - ElementsPerAccess - >; - - // - // Stores the result z = (y = GEMM(A, B, C), broadcast) - // - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorNormVec< - typename Base::OutputTileThreadMap, - ElementOutput, - LayoutT - >; - - // - // Additional tensor tile iterator - stores t = Elementwise(z) - // - using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< - typename Base::OutputTileThreadMap, - ElementTensor - >; - - /// Define the epilogue - using Epilogue = EpilogueWithBroadcast< - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputTileIterator, - TensorTileIterator, - ElementVector, - typename Base::AccumulatorFragmentIterator, - typename Base::WarpTileIterator, - typename Base::SharedLoadIterator, - OutputOp, - typename Base::Padding, - Base::kFragmentsPerIteration - >; -}; - - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" +#include "predicated_tile_iterator_normvec.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. +template +struct PairwiseDistanceEpilogue { + /// Use defaults related to the existing epilogue + using Base = + DefaultEpilogueTensorOp; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using OutputTileIterator = cutlass::epilogue::threadblock:: + PredicatedTileIteratorNormVec; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using TensorTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIterator; + + /// Define the epilogue + using Epilogue = EpilogueWithBroadcast; +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h old mode 100755 new mode 100644 index 9833ef6cf3..68b1815fdd --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h @@ -1,169 +1,163 @@ -// -/*! \file - \brief Functor performing distance operations used by epilogues of pairwise distance - * kernels. 
-*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/array.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/epilogue/thread/activation.h" - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This base class is meant to define the concept required of the -/// EpilogueWithBroadcast::OutputOp -template < - typename ElementC_, - typename ElementAccumulator_, - typename ElementCompute_, - typename ElementZ_, - typename ElementT_, - int ElementsPerAccess, - typename DistanceOp_ , - typename FinalOp_ -> -class PairwiseDistanceEpilogueElementwise { -public: - - using ElementOutput = ElementC_; - using ElementC = ElementC_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using ElementZ = ElementZ_; - using ElementT = ElementT_; - static int const kElementsPerAccess = ElementsPerAccess; - static int const kCount = kElementsPerAccess; - - using DistanceOp = DistanceOp_; - using FinalOp = FinalOp_; - - using FragmentAccumulator = Array; - using FragmentCompute = Array; - using FragmentC = Array; - using FragmentZ = Array; - using FragmentT = Array; - - using FragmentOutput = FragmentZ; - - static bool const kIsHeavy = false; //ElementwiseOp::kIsHeavy; - - /// If true, the 'Z' tensor is stored - static bool const kStoreZ = false; // We don't store anything in Z, - - /// If true, the 'T' tensor is stored - static bool const kStoreT = true; // this is our final output storage. - - /// Host-constructable parameters structure - struct Params { - FinalOp_ final_op_tmp; - DistanceOp_ dist_op_tmp; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params(DistanceOp_ dist_op, FinalOp lambdafn_): - final_op_tmp(lambdafn_), - dist_op_tmp(dist_op) - { - } - - CUTLASS_HOST_DEVICE - Params() - { - } - }; - -private: - // - // Data members - // - FinalOp_ final_op; - DistanceOp_ elementwise_op; - -public: - // - // Methods - // - - /// Constructor from Params - CUTLASS_HOST_DEVICE - PairwiseDistanceEpilogueElementwise(Params const ¶ms) : - final_op(params.final_op_tmp), elementwise_op(params.dist_op_tmp) { - } - - /// Returns true if source is needed - CUTLASS_HOST_DEVICE - bool is_source_needed() const { - // we use for making sure C matrix path is used for A mat norm. 
- return true; - } - - /// Functionally required for serial reduction in the epilogue - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) { -#if 0 - if (k_partition) { - beta_ = ElementCompute(1); - } - - if (k_partition != k_partition_count - 1) { - skip_elementwise_ = true; - } -#endif - } - - /// Applies the operation when is_source_needed() is true - CUTLASS_HOST_DEVICE - void operator()( - FragmentZ &frag_Z, - FragmentT &frag_T, - FragmentAccumulator const &AB, - FragmentC const &frag_C, - FragmentCompute const &V) const { - - FragmentCompute tmp_Accum = NumericArrayConverter()(AB); - FragmentCompute tmp_C = NumericArrayConverter()(frag_C); - FragmentCompute result_Z; - FragmentCompute result_T; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElementsPerAccess; ++i) { - result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); - result_T[i] = final_op(result_Z[i], 0); - } - - NumericArrayConverter convert_t; - frag_T = convert_t(result_T); - } - - /// Applies the operation when is_source_needed() is false - CUTLASS_HOST_DEVICE - void operator()( - FragmentZ &frag_Z, - FragmentT &frag_T, - FragmentAccumulator const &AB, - FragmentCompute const &V) const { - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// +// +/*! \file + \brief Functor performing distance operations used by epilogues of pairwise distance + * kernels. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/epilogue/thread/activation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template +class PairwiseDistanceEpilogueElementwise { + public: + using ElementOutput = ElementC_; + using ElementC = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + + using DistanceOp = DistanceOp_; + using FinalOp = FinalOp_; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using FragmentT = Array; + + using FragmentOutput = FragmentZ; + + static bool const kIsHeavy = false; // ElementwiseOp::kIsHeavy; + + /// If true, the 'Z' tensor is stored + static bool const kStoreZ = false; // We don't store anything in Z, + + /// If true, the 'T' tensor is stored + static bool const kStoreT = true; // this is our final output storage. 
+ + /// Host-constructable parameters structure + struct Params { + FinalOp_ final_op_; + DistanceOp_ dist_op_; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params(DistanceOp_ dist_op, FinalOp final_op) : final_op_(final_op), dist_op_(dist_op) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + // + // Data members + // + FinalOp_ final_op; + DistanceOp_ elementwise_op; + + public: + // + // Methods + // + + /// Constructor from Params + CUTLASS_HOST_DEVICE + PairwiseDistanceEpilogueElementwise(Params const& params) + : final_op(params.final_op_), elementwise_op(params.dist_op_) + { + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const + { + // we use for making sure C matrix path is used for A mat norm. + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) + { +#if 0 + if (k_partition) { + beta_ = ElementCompute(1); + } + + if (k_partition != k_partition_count - 1) { + skip_elementwise_ = true; + } +#endif + } + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentC const& frag_C, + FragmentCompute const& V) const + { + FragmentCompute tmp_Accum = + NumericArrayConverter()(AB); + FragmentCompute tmp_C = + NumericArrayConverter()(frag_C); + FragmentCompute result_Z; + FragmentCompute result_T; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); + result_T[i] = final_op(result_Z[i], 0); + } + + NumericArrayConverter convert_t; + frag_T = convert_t(result_T); + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentCompute const& V) const + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h old mode 100755 new mode 100644 index 15635edaee..a47dcb523e --- a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h @@ -1,206 +1,224 @@ -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/kernel/gemm_with_fused_epilogue.h" -#include "cutlass/gemm/kernel/default_gemm_universal.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" - -//#include "./epilogue_with_bcast_threadblock.h" -#include "pairwise_distance_epilogue.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for 
internal accumulation - typename ElementAccumulator, - /// Element type for final output - //typename ElementOutT, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor -> -struct PairwiseDistanceGemm { - - // This struct is specialized for fp32/3xTF32 - - /// Threadblock-level tile size (concept: GemmShape) - using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 64, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 - - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastF32; - - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? - - /// data layout for final output matrix. - // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std::conditional::type LayoutA_; - - typedef typename std::conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal< - ElementA_, LayoutA_, cutlass::ComplexTransform::kNone, kAlignmentA, - ElementB_, LayoutB_, cutlass::ComplexTransform::kNone, kAlignmentB, - ElementC_, LayoutOutput, ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - Operator - >::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementAccumulator, - typename EpilogueOutputOp::ElementT, - ElementAccumulator, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = GemmWithFusedEpilogue< - typename GemmBase::Mma, - Epilogue, - ThreadblockSwizzle - >; -}; - -template < - /// Layout type for A matrix operand - int kAlignmentA, - /// Layout type for B matrix operand - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor -> -struct PairwiseDistanceGemm { - - //using Transform = cutlass::ComplexTransform::kNone; - // Threadblock-level tile size (concept: GemmShape) - using ThreadblockShape = 
cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 128, N = 64, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; - - // Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAdd; - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? - - /// data layout for final output matrix. - // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std::conditional::type LayoutA_; - - typedef typename std::conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal< - double, LayoutA_, cutlass::ComplexTransform::kNone, 1, - double, LayoutB_, cutlass::ComplexTransform::kNone, 1, - ElementC_, LayoutOutput, ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - Operator - >::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementC_, - typename EpilogueOutputOp::ElementT, - ElementC_, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = GemmWithFusedEpilogue< - typename GemmBase::Mma, - Epilogue, - ThreadblockSwizzle - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} -} -} \ No newline at end of file +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/kernel/gemm_with_fused_epilogue.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" + +//#include "./epilogue_with_bcast_threadblock.h" +#include "pairwise_distance_epilogue.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Element type for final output + // typename ElementOutT, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// 
Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct PairwiseDistanceGemm { + // This struct is specialized for fp32/3xTF32 + + /// Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = + cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = + cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastF32; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + /// data layout for final output matrix. + // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementAccumulator, + typename EpilogueOutputOp::ElementT, + ElementAccumulator, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue; +}; + +template < + /// Layout type for A matrix operand + int kAlignmentA, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct PairwiseDistanceGemm { + // using Transform = cutlass::ComplexTransform::kNone; + // Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = + cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 128, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + // Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAdd; + // This code section describes whether you want to use tensor cores or regular SIMT 
cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + /// data layout for final output matrix. + // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + ElementC_, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h old mode 100755 new mode 100644 index a1536b2493..effb17ee8b --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -1,670 +1,591 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/array.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/memory.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////// - -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load and store output tile from global memory in epilogue. 
-/// -/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -/// -template < - typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) - typename Element_, ///< Element data type - typename Layout_, - bool ScatterD = false, ///< Scatter D operand or not - bool UseCUDAStore = false -> -class PredicatedTileIteratorNormVec { -public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = Element_; - - //using Layout = layout::RowMajor; - using Layout = Layout_; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Count::kTile; - - static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); - static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); - static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); - static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); - - /// Fragment object - using Fragment = Array< - Element, - ThreadMap::Iterations::kColumn * - ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * - ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; - - /// Memory access size - using AccessType = AlignedArray; - - // - // Parameters struct - // - - /// Uses a non-template class - struct Params : PredicatedTileIteratorParams { - using Base = PredicatedTileIteratorParams; - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params(Layout const &layout): - PredicatedTileIteratorParams( - layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc() - ) - { } - - CUTLASS_HOST_DEVICE - Params(Base const &base) : - Base(base) { } - }; - - /// Mask object - struct Mask { - - static int const kCount = ThreadMap::Iterations::kColumn; - - /// Predicate state - bool predicates[kCount]; - - // - // Mask - // - CUTLASS_HOST_DEVICE - Mask() { - enable(); - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_HOST_DEVICE void clear() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = false; - } - } - - ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask - CUTLASS_DEVICE void enable() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = true; - } - } - }; - -private: - - // - // Data members - // - - /// Parameters structure containing reference and precomputed state. 
- PredicatedTileIteratorParams params_; - - /// Byte-level pointer - uint8_t *byte_pointer_; - - /// Array of boolean values to contain steady-state predicates - Mask mask_; - - /// Extent of the matrix tile in rows - Index extent_row_; - - /// Extent of the matrix tile in rows - Index extent_column_; - - /// A thread's starting row position (assuming steady-state predicates have been computed) - Index thread_start_row_; - - /// A thread's starting column - Index thread_start_column_; - - /// Internal state counter - int state_[3]; - - /// Scatter indices - int const *indices_; - - // - // Static asserts about internal strides - // - - static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); - -private: - - // - // Methods - // - -public: - - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - PredicatedTileIteratorNormVec( - PredicatedTileIteratorParams const & params, - Element *pointer, - TensorCoord extent, - int thread_idx, - TensorCoord threadblock_offset = TensorCoord(), - int const *indices = nullptr - ): - params_(params), indices_(indices) - { - - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; - - extent_row_ = extent.row(); - extent_column_ = extent.column(); - - thread_start_row_ = thread_offset.row(); - thread_start_column_ = thread_offset.column(); - - // Initialize predicates - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { - - mask_.predicates[c] = ((thread_offset.column() - + ThreadMap::Delta::kColumn * c) < extent.column()); - } - - // Null pointer performs no accesses - if (!pointer) { - mask_.clear(); - } - - if (ScatterD && !indices) { - mask_.clear(); - } - - // Initialize pointer - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.row()) * LongIndex(params_.stride); - - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - - // Initialize internal state counter - state_[0] = state_[1] = state_[2] = 0; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { - uint8_t *byte_pointer = byte_pointer_; - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow - + group * ThreadMap::Delta::kGroup - + cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - if (ScatterD && row_guard) { - assert(indices_); - - memory_pointer = reinterpret_cast(byte_pointer + byte_offset + - LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); - } - - 
CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void *)&memory_pointer[0], - guard); - - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { - byte_pointer += params_.increment_row; - } - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { - byte_pointer += params_.increment_group; - } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { - uint8_t *byte_pointer = byte_pointer_; - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow - + group * ThreadMap::Delta::kGroup - + cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - if (ScatterD && row_guard) { - assert(indices_); - - memory_pointer = reinterpret_cast(byte_pointer + byte_offset + - LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); - } - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - - bool guard = row_guard && mask_.predicates[column]; - - if (UseCUDAStore) { - if (guard) { - memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; - } - } else { - cutlass::arch::global_store( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { - byte_pointer += params_.increment_row; - } - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { - byte_pointer += params_.increment_group; - } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) const { - - store_with_byte_offset(frag, 0); - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void downsample_load_with_byte_offset(Fragment &frag, int64_t byte_offset, int convolution_P, int convolution_Q, int add_P, int add_Q, int problem_N) const { - - uint8_t *byte_pointer = byte_pointer_; - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < 
ThreadMap::Iterations::kRow; ++row) { - - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow - + group * ThreadMap::Delta::kGroup - + cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - int output_row = row_offset + thread_start_row_; - int output_N = output_row / (convolution_P * convolution_Q); - int output_PQ = output_row % (convolution_P * convolution_Q); - int output_P = output_PQ / convolution_Q; - int output_Q = output_PQ % convolution_Q; - - int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + - (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; - - int64_t byte_offset = (input_row-output_row)*problem_N*sizeof(float); - - AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + - column], - (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / - kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - byte_pointer += params_.increment_row; - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { - byte_pointer += params_.increment_group; - } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void upsample_load_with_byte_offset(Fragment &frag, int64_t byte_offset, int convolution_P, int convolution_Q, int add_P, int add_Q, int problem_N) const { - - uint8_t *byte_pointer = byte_pointer_; - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow - + group * ThreadMap::Delta::kGroup - + cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - int output_row = row_offset + thread_start_row_; - int output_N = output_row / (convolution_P * convolution_Q); - int output_PQ = output_row % (convolution_P * convolution_Q); - int output_P = output_PQ / convolution_Q; - int output_Q = output_PQ % convolution_Q; - int row_add_P = add_P; - int row_add_Q = add_Q; - if (output_P > convolution_P - 2) row_add_P = 0; - if (output_Q > convolution_Q - 2) row_add_Q = 0; - - int input_row = output_N * (convolution_P/2) * (convolution_Q/2) + - ((output_P + row_add_P)/2) * (convolution_Q/2) + (output_Q + row_add_Q)/2; - - int64_t byte_offset = (input_row-output_row)*problem_N*sizeof(float); - - AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + - 
column], - (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / - kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - byte_pointer += params_.increment_row; - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { - byte_pointer += params_.increment_group; - } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - CUTLASS_DEVICE - MatrixCoord thread_start() const { - return MatrixCoord(thread_start_row_, thread_start_column_); - } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_row() const { - return thread_start_row_; - } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_column() const { - return thread_start_column_; - } - - /// Extent of the matrix in rows - CUTLASS_DEVICE - Index extent_row() const { - return extent_row_; - } - - /// Extent of the matrix in columns - CUTLASS_DEVICE - Index extent_column() const { - return extent_column_; - } - - /// Advances to the next position to load or store - CUTLASS_HOST_DEVICE - PredicatedTileIteratorNormVec &operator++() { - - ++state_[0]; - - if (!ScatterD) { - byte_pointer_ += params_.advance_row; - } - - thread_start_row_ += ThreadMap::Shape::kRow; - - if (state_[0] == ThreadMap::Count::kRow) { - - state_[0] = 0; - ++state_[1]; - byte_pointer_ += params_.advance_group; - - thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * - ThreadMap::Shape::kRow * ThreadMap::Count::kRow; - - if (state_[1] == ThreadMap::Count::kGroup) { - - state_[1] = 0; - ++state_[2]; - byte_pointer_ += params_.advance_cluster; - - thread_start_row_ += ThreadMap::Count::kGroup * - ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - byte_pointer_ += params_.advance_tile; - } - } - } - - return *this; - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_DEVICE void clear_mask() { - mask_.clear(); - } - - ///< Efficiently enables all accesses guarded by mask - CUTLASS_DEVICE void enable_mask() { - mask_.enable(); - } - - ///< Sets the mask - CUTLASS_DEVICE void get_mask(Mask &mask) const { - mask = mask_; - } - - ///< Sets the mask - CUTLASS_DEVICE void set_mask(Mask const &mask) { - mask_ = mask; - } -}; - - - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. 
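+///
+/// Unlike the stock CUTLASS PredicatedTileIterator, the load path here always
+/// re-reads element 0 of the current row (see load_with_byte_offset): the
+/// tensor being read is a norm vector holding one value per output row, so no
+/// column stride is applied on loads.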
+///
+/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator
+///
+template <typename ThreadMap_,  ///< Thread map (concept: OutputTileThreadMap)
+          typename Element_,    ///< Element data type
+          typename Layout_,
+          bool ScatterD     = false,  ///< Scatter D operand or not
+          bool UseCUDAStore = false>
+class PredicatedTileIteratorNormVec {
+ public:
+  using ThreadMap = ThreadMap_;
+  using Shape     = typename ThreadMap::Shape;
+
+  using Element = Element_;
+
+  // using Layout = layout::RowMajor;
+  using Layout         = Layout_;
+  using TensorRef      = TensorRef<Element, Layout>;
+  using ConstTensorRef = typename TensorRef::ConstTensorRef;
+
+  using Index       = typename Layout::Index;
+  using LongIndex   = typename Layout::LongIndex;
+  using TensorCoord = MatrixCoord;
+
+  static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
+  static int const kThreads           = ThreadMap::kThreads;
+  static int const kIterations        = ThreadMap::Count::kTile;
+
+  static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0");
+  static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0");
+  static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0");
+  static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0");
+
+  /// Fragment object
+  using Fragment = Array<Element,
+                         ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
+                           ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster *
+                           ThreadMap::kElementsPerAccess>;
+
+  /// Memory access size
+  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
+
+  //
+  // Parameters struct
+  //
+
+  /// Uses a non-template class
+  struct Params : PredicatedTileIteratorParams {
+    using Base = PredicatedTileIteratorParams;
+
+    CUTLASS_HOST_DEVICE
+    Params() {}
+
+    CUTLASS_HOST_DEVICE
+    Params(Layout const& layout)
+      : PredicatedTileIteratorParams(
+          layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
+          make_OutputTileThreadMapDesc<ThreadMap>())
+    {
+    }
+
+    CUTLASS_HOST_DEVICE
+    Params(Base const& base) : Base(base) {}
+  };
+
+  /// Mask object
+  struct Mask {
+    static int const kCount = ThreadMap::Iterations::kColumn;
+
+    /// Predicate state
+    bool predicates[kCount];
+
+    //
+    // Mask
+    //
+    CUTLASS_HOST_DEVICE
+    Mask() { enable(); }
+
+    ///< Efficiently disables all accesses guarded by mask
+    CUTLASS_HOST_DEVICE void clear()
+    {
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kCount; ++i) {
+        predicates[i] = false;
+      }
+    }
+
+    ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask
+    CUTLASS_DEVICE void enable()
+    {
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kCount; ++i) {
+        predicates[i] = true;
+      }
+    }
+  };
+
+ private:
+  //
+  // Data members
+  //
+
+  /// Parameters structure containing reference and precomputed state.
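+  /// All of the precomputed state is plain byte arithmetic: the row stride
+  /// plus the increment_row/group/cluster and advance_* deltas that the
+  /// load/store loops and operator++ below consume, keeping integer
+  /// multiplies out of the hot path.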
+ PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorNormVec(PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { mask_.clear(); } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride); + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int 
column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[0], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + 
ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) row_add_P = 0; + if (output_Q > convolution_Q - 2) row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += 
params_.increment_row; } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorNormVec& operator++() + { + ++state_[0]; + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index 7d3064af7b..8595c82d1e 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -76,11 +76,12 @@ struct DistanceAdjInputs { template struct threshold_final_op { - DataT threshold_val; + DataT threshold_val; __device__ __host__ threshold_final_op() : threshold_val(0.0) {} __device__ __host__ threshold_final_op(DataT val) : threshold_val(val) {} - __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const { + __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const + { return d_val <= threshold_val; } }; @@ -123,26 +124,25 @@ class DistanceAdjTest : public ::testing::TestWithParam( x.data(), y.data(), m, n, k); rmm::device_uvector workspace(worksize, stream); -#if 0 - auto fin_op = [threshold] __device__(DataType d_val, int g_d_idx) { - return d_val <= threshold; - }; -#else + using threshold_final_op_ = threshold_final_op; threshold_final_op_ threshold_op(threshold); -#endif - raft::distance::distance( - x.data(), - y.data(), - dist.data(), - m, - n, - k, - workspace.data(), - workspace.size(), - 
threshold_op, - stream, - isRowMajor); + + raft::distance::distance(x.data(), + y.data(), + dist.data(), + m, + n, + k, + workspace.data(), + workspace.size(), + threshold_op, + stream, + isRowMajor); handle.sync_stream(stream); } diff --git a/cpp/test/distance/dist_eucsqrt_exp.cu b/cpp/test/distance/dist_eucsqrt_exp.cu old mode 100755 new mode 100644 index 90b4f4288c..c4f2dc80c2 --- a/cpp/test/distance/dist_eucsqrt_exp.cu +++ b/cpp/test/distance/dist_eucsqrt_exp.cu @@ -1,74 +1,76 @@ -/* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.h" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceEucSqrtExpTest : public DistanceTest { -}; - -const std::vector> inputsf = { - {0.001f, 2048, 4096, 128, true, 1234ULL}, - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.003f, 1021, 1021, 1021, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, - {0.003f, 1021, 1021, 1021, false, 1234ULL}, -}; -typedef DistanceEucSqrtExpTest DistanceEucSqrtExpTestF; -TEST_P(DistanceEucSqrtExpTestF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceEucSqrtExpTest DistanceEucSqrtExpTestD; -TEST_P(DistanceEucSqrtExpTestD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestD, ::testing::ValuesIn(inputsd)); - -class BigMatrixEucSqrtExp : public BigMatrixDistanceTest { -}; -TEST_F(BigMatrixEucSqrtExp, Result) {} -} // end namespace distance -} // end namespace raft +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "../test_utils.h"
+#include "distance_base.cuh"
+
+namespace raft {
+namespace distance {
+
+template <typename DataType>
+class DistanceEucSqrtExpTest
+  : public DistanceTest<raft::distance::DistanceType::L2SqrtExpanded, DataType> {
+};
+
+const std::vector<DistanceInputs<float>> inputsf = {
+  {0.001f, 2048, 4096, 128, true, 1234ULL},
+  {0.001f, 1024, 1024, 32, true, 1234ULL},
+  {0.001f, 1024, 32, 1024, true, 1234ULL},
+  {0.001f, 32, 1024, 1024, true, 1234ULL},
+  {0.003f, 1024, 1024, 1024, true, 1234ULL},
+  {0.003f, 1021, 1021, 1021, true, 1234ULL},
+  {0.001f, 1024, 1024, 32, false, 1234ULL},
+  {0.001f, 1024, 32, 1024, false, 1234ULL},
+  {0.001f, 32, 1024, 1024, false, 1234ULL},
+  {0.003f, 1024, 1024, 1024, false, 1234ULL},
+  {0.003f, 1021, 1021, 1021, false, 1234ULL},
+};
+typedef DistanceEucSqrtExpTest<float> DistanceEucSqrtExpTestF;
+TEST_P(DistanceEucSqrtExpTestF, Result)
+{
+  int m = params.isRowMajor ? params.m : params.n;
+  int n = params.isRowMajor ? params.n : params.m;
+  ASSERT_TRUE(devArrMatch(
+    dist_ref.data(), dist.data(), m, n, raft::CompareApprox<float>(params.tolerance), stream));
+}
+INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestF, ::testing::ValuesIn(inputsf));
+
+const std::vector<DistanceInputs<double>> inputsd = {
+  {0.001, 1024, 1024, 32, true, 1234ULL},
+  {0.001, 1024, 32, 1024, true, 1234ULL},
+  {0.001, 32, 1024, 1024, true, 1234ULL},
+  {0.003, 1024, 1024, 1024, true, 1234ULL},
+  {0.001, 1024, 1024, 32, false, 1234ULL},
+  {0.001, 1024, 32, 1024, false, 1234ULL},
+  {0.001, 32, 1024, 1024, false, 1234ULL},
+  {0.003, 1024, 1024, 1024, false, 1234ULL},
+};
+typedef DistanceEucSqrtExpTest<double> DistanceEucSqrtExpTestD;
+TEST_P(DistanceEucSqrtExpTestD, Result)
+{
+  int m = params.isRowMajor ? params.m : params.n;
+  int n = params.isRowMajor ?
params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestD, ::testing::ValuesIn(inputsd)); + +class BigMatrixEucSqrtExp + : public BigMatrixDistanceTest { +}; +TEST_F(BigMatrixEucSqrtExp, Result) {} +} // end namespace distance +} // end namespace raft From 1a45bfaf4ad2b0454d8104c4469c7353bb5cf261 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 21 Oct 2022 15:49:06 +0530 Subject: [PATCH 03/25] add noexcept to the functor methods --- cpp/include/raft/distance/detail/cosine.cuh | 6 +++--- cpp/include/raft/distance/detail/distance.cuh | 7 +++++-- cpp/include/raft/distance/detail/euclidean.cuh | 8 ++++---- cpp/test/distance/dist_adj.cu | 6 +++--- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index 4a94d02fc4..16dde7fe51 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -26,12 +26,12 @@ namespace detail { template struct CosineOp { - __device__ __host__ CosineOp() {} - __device__ __host__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const + __device__ __host__ CosineOp() noexcept {} + __device__ __host__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); } - __device__ __host__ AccT operator()(DataT aData) const { return aData; } + __device__ __host__ AccT operator()(DataT aData) const noexcept { return aData; } }; /** diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 7dda15f573..a3b7e20eec 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -618,9 +618,12 @@ void distance(const InType* x, template struct default_fin_op { - __host__ __device__ default_fin_op(){}; + __host__ __device__ default_fin_op() noexcept {}; // functor signature. - __host__ __device__ OutType operator()(AccType d_val, Index g_d_idx) const { return d_val; } + __host__ __device__ OutType operator()(AccType d_val, Index g_d_idx) const noexcept + { + return d_val; + } }; template struct L2ExpandedOp { bool sqrt; - __device__ __host__ L2ExpandedOp() : sqrt(false) {} - __device__ __host__ L2ExpandedOp(bool isSqrt) : sqrt(isSqrt) {} - __device__ __host__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const + __device__ __host__ L2ExpandedOp() noexcept : sqrt(false) {} + __device__ __host__ L2ExpandedOp(bool isSqrt) noexcept : sqrt(isSqrt) {} + __device__ __host__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; return sqrt ? 
raft::mySqrt(outVal) : outVal; } - __device__ __host__ AccT operator()(DataT aData) const { return aData; } + __device__ __host__ AccT operator()(DataT aData) const noexcept { return aData; } }; /** diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index 8595c82d1e..dd0bc706c0 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -78,9 +78,9 @@ template struct threshold_final_op { DataT threshold_val; - __device__ __host__ threshold_final_op() : threshold_val(0.0) {} - __device__ __host__ threshold_final_op(DataT val) : threshold_val(val) {} - __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const + __device__ __host__ threshold_final_op() noexcept : threshold_val(0.0) {} + __device__ __host__ threshold_final_op(DataT val) noexcept : threshold_val(val) {} + __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const noexcept { return d_val <= threshold_val; } From 7786fcb882bda7a9b243a57de5544f1a09b349de Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 27 Oct 2022 15:45:53 +0530 Subject: [PATCH 04/25] fix comments, remove redundant code and fix formatting issues --- cpp/include/raft/distance/detail/cosine.cuh | 8 ++-- cpp/include/raft/distance/detail/distance.cuh | 4 +- .../raft/distance/detail/euclidean.cuh | 10 +++-- .../detail/pairwise_distance_cutlass_base.cuh | 17 ++++---- .../detail/pairwise_distance_epilogue.h | 40 ++++++------------ .../pairwise_distance_epilogue_elementwise.h | 29 +++++++------ .../distance/detail/pairwise_distance_gemm.h | 23 +++++++++-- .../detail/predicated_tile_iterator_normvec.h | 41 ++++++------------- cpp/include/raft/distance/distance.cuh | 2 - 9 files changed, 84 insertions(+), 90 deletions(-) diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index 16dde7fe51..7cb7e9fabc 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -26,12 +26,14 @@ namespace detail { template struct CosineOp { - __device__ __host__ CosineOp() noexcept {} - __device__ __host__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + __device__ CosineOp() noexcept {} + __device__ AccT operator()(DataT& aNorm, + const DataT& bNorm, + DataT& accVal) const noexcept { return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); } - __device__ __host__ AccT operator()(DataT aData) const noexcept { return aData; } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } }; /** diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 9834136ef6..c1a9c7ead3 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -616,7 +616,7 @@ void distance(const InType* x, * worksize, the number of bytes of workspace required */ -// Default final op functor which facilitates elementwise operation on +// Default final op functor which facilitates elementwise operation on // final distance value if any. template struct default_fin_op { @@ -624,7 +624,7 @@ struct default_fin_op { // functor signature. 
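+  // Any custom final-op only needs to match this call shape: it takes the raw
+  // accumulated distance and the element's global index and returns the value
+  // written to the output buffer (e.g. the tests' threshold_final_op returns
+  // d_val <= threshold_val rather than the identity below).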
__host__ __device__ OutType operator()(AccType d_val, Index g_d_idx) const noexcept { - return d_val; + return d_val; } }; diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index b23219f84e..461a8dd0ae 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -28,15 +28,17 @@ template struct L2ExpandedOp { bool sqrt; - __device__ __host__ L2ExpandedOp() noexcept : sqrt(false) {} - __device__ __host__ L2ExpandedOp(bool isSqrt) noexcept : sqrt(isSqrt) {} - __device__ __host__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + __device__ L2ExpandedOp() noexcept : sqrt(false) {} + __device__ L2ExpandedOp(bool isSqrt) noexcept : sqrt(isSqrt) {} + __device__ AccT operator()(DataT& aNorm, + const DataT& bNorm, + DataT& accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; return sqrt ? raft::mySqrt(outVal) : outVal; } - __device__ __host__ AccT operator()(DataT aData) const noexcept { return aData; } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } }; /** diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index ea216f8c73..44408c52f5 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -70,6 +70,9 @@ void cutlassDistanceKernel(const DataT* x, DistanceFn dist_op, cudaStream_t stream) { + static_assert(!(std::is_same::value), + "OutType bool is not supported use uint8_t instead"); + using EpilogueOutputOp = cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise(n), static_cast(m), static_cast(k)); + auto problem_size = cutlass::gemm::GemmCoord(n, m, k); using cutlassDistKernel = typename cutlass::gemm::kernel::PairwiseDistanceGemm(m), static_cast(n), static_cast(k)); - a = x; - b = y; - gemm_lda = lda; - gemm_ldb = ldb; + problem_size = cutlass::gemm::GemmCoord(m, n, k); + a = x; + b = y; + gemm_lda = lda; + gemm_ldb = ldb; } typename cutlassDist::Arguments arguments{ diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h index e2360809bb..d094e9d01d 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h @@ -1,33 +1,19 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * http://www.apache.org/licenses/LICENSE-2.0 * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. 
- * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + /*! \file \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h index 68b1815fdd..f6ed49fcfa 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + // /*! \file \brief Functor performing distance operations used by epilogues of pairwise distance @@ -106,18 +122,7 @@ class PairwiseDistanceEpilogueElementwise { /// Functionally required for serial reduction in the epilogue CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) - { -#if 0 - if (k_partition) { - beta_ = ElementCompute(1); - } - - if (k_partition != k_partition_count - 1) { - skip_elementwise_ = true; - } -#endif - } + void set_k_partition(int k_partition, int k_partition_count) {} /// Applies the operation when is_source_needed() is true CUTLASS_HOST_DEVICE diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h index a47dcb523e..eef48548f2 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #pragma once #include "cutlass/cutlass.h" @@ -7,7 +23,6 @@ #include "cutlass/layout/matrix.h" #include "cutlass/layout/tensor.h" -//#include "./epilogue_with_bcast_threadblock.h" #include "pairwise_distance_epilogue.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -44,7 +59,7 @@ struct PairwiseDistanceGemm { /// Threadblock-level tile size (concept: GemmShape) using ThreadblockShape = - cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 64, K = 16 + cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16 /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 @@ -145,10 +160,10 @@ struct PairwiseDistanceGemm; // <- threadblock tile M = 128, N = 64, K = 16 + cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 64, N = 64, K = 16 /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 32, N = 32, K = 16 /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h index effb17ee8b..97df2aff27 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -1,33 +1,19 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * http://www.apache.org/licenses/LICENSE-2.0 * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its
- * contributors may be used to endorse or promote products derived from
- * this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- **************************************************************************************************/
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 /*! \file
 \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

@@ -78,7 +64,6 @@ class PredicatedTileIteratorNormVec {

 using Element = Element_;

- // using Layout = layout::RowMajor;
 using Layout = Layout_;
 using TensorRef = TensorRef;
 using ConstTensorRef = typename TensorRef::ConstTensorRef;
diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh
index e04e7354fe..6e3f97b45c 100644
--- a/cpp/include/raft/distance/distance.cuh
+++ b/cpp/include/raft/distance/distance.cuh
@@ -315,7 +315,6 @@ void pairwise_distance(const raft::handle_t& handle,
 detail::pairwise_distance_impl(
 x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
 break;
-#if 0
 case raft::distance::DistanceType::L1:
 detail::pairwise_distance_impl(
 x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
 break;
@@ -365,7 +364,6 @@ void pairwise_distance(const raft::handle_t& handle,
 pairwise_distance_impl(
 x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
 break;
-#endif
 default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
 };
 }

From 181fc40c13cae7985ea56fbb99228a3d5fc9e054 Mon Sep 17 00:00:00 2001
From: Mahesh Doijade
Date: Fri, 28 Oct 2022 16:41:37 +0530
Subject: [PATCH 05/25] add cutlass cmake support for raft with custom namespace, fix formatting issues

---
 cpp/CMakeLists.txt | 6 +-
 cpp/cmake/thirdparty/get_cutlass.cmake | 75 +++++++++++++++++++
 .../detail/pairwise_distance_cutlass_base.cuh | 16 ++--
 .../detail/pairwise_distance_epilogue.h | 18 ++---
 .../pairwise_distance_epilogue_elementwise.h | 12 +--
 .../distance/detail/pairwise_distance_gemm.h | 12 +--
 .../detail/predicated_tile_iterator_normvec.h | 24 +++--- 
 cpp/test/CMakeLists.txt | 3 +
 8 files changed, 124 insertions(+), 42 deletions(-)
 create mode 100644 cpp/cmake/thirdparty/get_cutlass.cmake

diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index f54c320b47..aef01581bc 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -62,6 +62,7 @@ option(RAFT_NVTX "Enable nvtx markers" OFF)
option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library 
instantiations" ${BUILD_TESTS}) option(RAFT_COMPILE_NN_LIBRARY "Enable building raft nearest neighbors shared library instantiations" ${RAFT_COMPILE_LIBRARIES}) option(RAFT_COMPILE_DIST_LIBRARY "Enable building raft distant shared library instantiations" ${RAFT_COMPILE_LIBRARIES}) +option(RAFT_ENABLE_DIST_DEPENDENCIES "Search for raft::distance dependencies like cutlass" ${RAFT_COMPILE_LIBRARIES}) option(RAFT_ENABLE_NN_DEPENDENCIES "Search for raft::nn dependencies like faiss" ${RAFT_COMPILE_LIBRARIES}) option(RAFT_ENABLE_thrust_DEPENDENCY "Enable Thrust dependency" ON) @@ -156,6 +157,7 @@ rapids_cpm_init() include(cmake/thirdparty/get_thrust.cmake) include(cmake/thirdparty/get_rmm.cmake) include(cmake/thirdparty/get_faiss.cmake) +include(cmake/thirdparty/get_cutlass.cmake) if(RAFT_ENABLE_cuco_DEPENDENCY) include(${rapids-cmake-dir}/cpm/cuco.cmake) @@ -178,7 +180,7 @@ add_library(raft::raft ALIAS raft) target_include_directories(raft INTERFACE "$" "$" - "${CUTLASS_DIR}/include") + "${CUTLASS_INCLUDE_DIR}") # Keep RAFT as lightweight as possible. # Only CUDA libs and rmm should @@ -320,6 +322,8 @@ if(RAFT_COMPILE_DIST_LIBRARY) ) target_compile_definitions(raft_distance_lib INTERFACE "RAFT_DISTANCE_COMPILED") + target_compile_definitions(raft_distance_lib + INTERFACE "CUTLASS_NAMESPACE=${RAFT_CUTLASS_NAMESPACE}") # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(raft_distance_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake new file mode 100644 index 0000000000..fcc5d3bed3 --- /dev/null +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -0,0 +1,75 @@ +#============================================================================= +# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#============================================================================= + +function(find_and_configure_cutlass) + set(oneValueArgs VERSION REPOSITORY PINNED_TAG) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) + rapids_find_generate_module(cutlass + HEADER_NAMES cutlass/include/* + ) + set(CUTLASS_ENABLE_HEADERS_ONLY ON) + set(RAFT_CUTLASS_NAMESPACE "raft_cutlass" CACHE STRING "Top level namespace of CUTLASS") + set(BUILD_SHARED_LIBS OFF) + # if (PKG_BUILD_STATIC_LIBS) + # set(BUILD_SHARED_LIBS OFF) + # set(CPM_DOWNLOAD_cutlass ON) + # endif() + + rapids_cpm_find(cutlass ${PKG_VERSION} + GLOBAL_TARGETS cutlass::cutlass + CPM_ARGS + GIT_REPOSITORY ${PKG_REPOSITORY} + GIT_TAG ${PKG_PINNED_TAG} + EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL} + OPTIONS + "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" + ) + + if(TARGET cutlass AND NOT TARGET cutlass::cutlass) + add_library(cutlass::cutlass ALIAS cutlass) + endif() + + # if(cutlass_ADDED) + # rapids_export(BUILD cutlass + # EXPORT_SET cutlass-targets + # GLOBAL_TARGETS cutlass + # NAMESPACE cutlass::) + # endif() + endif() + + # We generate the faiss-config files when we built faiss locally, so always do `find_dependency` + #rapids_export_package(BUILD cutlass raft-distance-lib-exports GLOBAL_TARGETS cutlass::cutlass cutlass) + rapids_export_package(INSTALL cutlass raft-distance-lib-exports GLOBAL_TARGETS cutlass::cutlass cutlass) + + # Tell cmake where it can find the generated faiss-config.cmake we wrote. + include("${rapids-cmake-dir}/export/find_package_root.cmake") + rapids_export_find_package_root(INSTALL cutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-lib-exports) +endfunction() + +if(NOT RAFT_CUTLASS_GIT_TAG) + set(RAFT_CUTLASS_GIT_TAG v2.9.0) +endif() + +if(NOT RAFT_CUTLASS_GIT_REPOSITORY) + set(RAFT_CUTLASS_GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git) +endif() + +find_and_configure_cutlass(VERSION 2.9.0 + REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} + PINNED_TAG ${RAFT_CUTLASS_GIT_TAG}) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index 44408c52f5..5b67d77e96 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -21,14 +21,14 @@ #include -#include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/tensor_view.h" +#include +#include +#include + +#include +#include +#include +#include #include "./pairwise_distance_epilogue_elementwise.h" #include "./pairwise_distance_gemm.h" diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h index d094e9d01d..d34af4ff70 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h @@ -24,17 +24,17 @@ #pragma once -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" +#include +#include +#include -#include "cutlass/gemm/gemm.h" +#include -#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -#include 
"cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -#include "cutlass/epilogue/threadblock/epilogue.h" -#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" -#include "predicated_tile_iterator_normvec.h" +#include +#include +#include +#include +#include "./predicated_tile_iterator_normvec.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h index f6ed49fcfa..07f37e0234 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h @@ -22,13 +22,13 @@ #pragma once -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" +#include +#include +#include +#include +#include -#include "cutlass/epilogue/thread/activation.h" +#include ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h index eef48548f2..6768a1b579 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h @@ -16,14 +16,14 @@ #pragma once -#include "cutlass/cutlass.h" +#include -#include "cutlass/gemm/kernel/default_gemm_universal.h" -#include "cutlass/gemm/kernel/gemm_with_fused_epilogue.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" +#include +#include +#include +#include -#include "pairwise_distance_epilogue.h" +#include "./pairwise_distance_epilogue.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h index 97df2aff27..6272770f4f 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -24,18 +24,18 @@ #pragma once -#include "cutlass/arch/arch.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/transform/pitch_linear_thread_map.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include //////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 059a1a792b..645c5d0d03 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -62,6 +62,9 @@ function(ConfigureTest) PUBLIC "$" ) + target_compile_definitions(${TEST_NAME} + INTERFACE "CUTLASS_NAMESPACE=${RAFT_CUTLASS_NAMESPACE}") + install( TARGETS ${TEST_NAME} COMPONENT testing From 3d3454599c1e883a533ef5c4f051f502ded8a1a4 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 28 Oct 2022 18:12:16 +0530 Subject: [PATCH 06/25] fix formatting issues 
--- cpp/include/raft/distance/detail/cosine.cuh | 8 +++----- cpp/include/raft/distance/detail/euclidean.cuh | 10 ++++------ .../raft/distance/detail/pairwise_distance_epilogue.h | 4 ++-- .../detail/pairwise_distance_epilogue_elementwise.h | 2 +- .../distance/detail/predicated_tile_iterator_normvec.h | 2 +- 5 files changed, 11 insertions(+), 15 deletions(-) diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index 7cb7e9fabc..e3d400e799 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -26,14 +26,12 @@ namespace detail { template struct CosineOp { - __device__ CosineOp() noexcept {} - __device__ AccT operator()(DataT& aNorm, - const DataT& bNorm, - DataT& accVal) const noexcept + __device__ CosineOp() noexcept {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } }; /** diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 461a8dd0ae..64359e2270 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -28,17 +28,15 @@ template struct L2ExpandedOp { bool sqrt; - __device__ L2ExpandedOp() noexcept : sqrt(false) {} - __device__ L2ExpandedOp(bool isSqrt) noexcept : sqrt(isSqrt) {} - __device__ AccT operator()(DataT& aNorm, - const DataT& bNorm, - DataT& accVal) const noexcept + __device__ L2ExpandedOp() noexcept : sqrt(false) {} + __device__ L2ExpandedOp(bool isSqrt) noexcept : sqrt(isSqrt) {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; return sqrt ? 
raft::mySqrt(outVal) : outVal; } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } }; /** diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h index d34af4ff70..28eb57fdd0 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h @@ -24,17 +24,17 @@ #pragma once -#include #include +#include #include #include +#include "./predicated_tile_iterator_normvec.h" #include #include #include #include -#include "./predicated_tile_iterator_normvec.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h index 07f37e0234..80c926874a 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h @@ -22,8 +22,8 @@ #pragma once -#include #include +#include #include #include #include diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h index 6272770f4f..c343d09083 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -24,10 +24,10 @@ #pragma once -#include #include #include #include +#include #include #include #include From 02c23ed0a2ae6b552ef09c02bbd2d9cab77c29d7 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 3 Nov 2022 20:03:42 +0530 Subject: [PATCH 07/25] fix the cutlass_include_dir path in cmake --- cpp/CMakeLists.txt | 3 +-- cpp/cmake/thirdparty/get_cutlass.cmake | 8 +++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index aef01581bc..2bf17076e8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -180,7 +180,7 @@ add_library(raft::raft ALIAS raft) target_include_directories(raft INTERFACE "$" "$" - "${CUTLASS_INCLUDE_DIR}") + "$") # Keep RAFT as lightweight as possible. 
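The L2ExpandedOp above is the whole point of the "expanded" L2 path: it rewrites the squared Euclidean distance in terms of precomputed norms and a dot product, so the pairwise computation reduces to a single GEMM plus an epilogue. The identity it implements is

\|x - y\|_2^2 = \|x\|_2^2 + \|y\|_2^2 - 2\,\langle x, y \rangle

which is exactly aNorm + bNorm - 2 * accVal, with accVal holding the GEMM dot product; when sqrt is set, the operator returns raft::mySqrt of that value.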
# Only CUDA libs and rmm should @@ -449,7 +449,6 @@ if(TARGET raft_nn_lib) EXPORT raft-nn-lib-exports) endif() - install(DIRECTORY include/raft COMPONENT raft DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index fcc5d3bed3..cfe7c2f543 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -22,10 +22,11 @@ function(find_and_configure_cutlass) if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) rapids_find_generate_module(cutlass HEADER_NAMES cutlass/include/* + INCLUDE_SUFFIXES cutlass ) set(CUTLASS_ENABLE_HEADERS_ONLY ON) set(RAFT_CUTLASS_NAMESPACE "raft_cutlass" CACHE STRING "Top level namespace of CUTLASS") - set(BUILD_SHARED_LIBS OFF) +# set(BUILD_SHARED_LIBS OFF) # if (PKG_BUILD_STATIC_LIBS) # set(BUILD_SHARED_LIBS OFF) # set(CPM_DOWNLOAD_cutlass ON) @@ -38,6 +39,7 @@ function(find_and_configure_cutlass) GIT_TAG ${PKG_PINNED_TAG} EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL} OPTIONS + "CMAKE_INSTALL_INCLUDEDIR include" "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" ) @@ -54,10 +56,10 @@ function(find_and_configure_cutlass) endif() # We generate the faiss-config files when we built faiss locally, so always do `find_dependency` - #rapids_export_package(BUILD cutlass raft-distance-lib-exports GLOBAL_TARGETS cutlass::cutlass cutlass) + rapids_export_package(BUILD cutlass raft-distance-lib-exports GLOBAL_TARGETS cutlass::cutlass cutlass) rapids_export_package(INSTALL cutlass raft-distance-lib-exports GLOBAL_TARGETS cutlass::cutlass cutlass) - # Tell cmake where it can find the generated faiss-config.cmake we wrote. + # Tell cmake where it can find the generated cutlass-config.cmake we wrote. include("${rapids-cmake-dir}/export/find_package_root.cmake") rapids_export_find_package_root(INSTALL cutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-lib-exports) endfunction() From 79334361caae2279e8a30315c018be96fb87859b Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 4 Nov 2022 15:09:11 +0530 Subject: [PATCH 08/25] fix bugs in get_cutlass cmake to use cutlass provided properties correctly --- cpp/CMakeLists.txt | 6 ++-- cpp/cmake/thirdparty/get_cutlass.cmake | 45 ++++++++++---------------- 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 2bf17076e8..342fc55060 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -179,8 +179,7 @@ add_library(raft::raft ALIAS raft) target_include_directories(raft INTERFACE "$" - "$" - "$") + "$") # Keep RAFT as lightweight as possible. 
# Only CUDA libs and rmm should @@ -192,6 +191,7 @@ target_link_libraries(raft INTERFACE CUDA::cusolver${_ctk_static_suffix} CUDA::cusparse${_ctk_static_suffix} $<$:raft::Thrust> + nvidia::cutlass::cutlass ) target_compile_features(raft INTERFACE cxx_std_17 $) @@ -322,8 +322,6 @@ if(RAFT_COMPILE_DIST_LIBRARY) ) target_compile_definitions(raft_distance_lib INTERFACE "RAFT_DISTANCE_COMPILED") - target_compile_definitions(raft_distance_lib - INTERFACE "CUTLASS_NAMESPACE=${RAFT_CUTLASS_NAMESPACE}") # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(raft_distance_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index cfe7c2f543..4908e3fda7 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -20,58 +20,47 @@ function(find_and_configure_cutlass) "${multiValueArgs}" ${ARGN} ) if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) - rapids_find_generate_module(cutlass - HEADER_NAMES cutlass/include/* - INCLUDE_SUFFIXES cutlass - ) - set(CUTLASS_ENABLE_HEADERS_ONLY ON) - set(RAFT_CUTLASS_NAMESPACE "raft_cutlass" CACHE STRING "Top level namespace of CUTLASS") -# set(BUILD_SHARED_LIBS OFF) - # if (PKG_BUILD_STATIC_LIBS) - # set(BUILD_SHARED_LIBS OFF) - # set(CPM_DOWNLOAD_cutlass ON) - # endif() + set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + set(CUTLASS_NAMESPACE "raft_cutlass" CACHE STRING "Top level namespace of CUTLASS") rapids_cpm_find(cutlass ${PKG_VERSION} - GLOBAL_TARGETS cutlass::cutlass + GLOBAL_TARGETS nvidia::cutlass::CUTLASS CPM_ARGS GIT_REPOSITORY ${PKG_REPOSITORY} GIT_TAG ${PKG_PINNED_TAG} - EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL} OPTIONS - "CMAKE_INSTALL_INCLUDEDIR include" "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" ) - if(TARGET cutlass AND NOT TARGET cutlass::cutlass) - add_library(cutlass::cutlass ALIAS cutlass) + if(TARGET cutlass AND NOT TARGET nvidia::cutlass::cutlass) + add_library(nvidia::cutlass::cutlass ALIAS CUTLASS) endif() - # if(cutlass_ADDED) - # rapids_export(BUILD cutlass - # EXPORT_SET cutlass-targets - # GLOBAL_TARGETS cutlass - # NAMESPACE cutlass::) - # endif() + if(cutlass_ADDED) + rapids_export(BUILD cutlass + EXPORT_SET NvidiaCutlass + GLOBAL_TARGETS nvidia::cutlass::CUTLASS + NAMESPACE nvidia::cutlass::) + endif() endif() - # We generate the faiss-config files when we built faiss locally, so always do `find_dependency` - rapids_export_package(BUILD cutlass raft-distance-lib-exports GLOBAL_TARGETS cutlass::cutlass cutlass) - rapids_export_package(INSTALL cutlass raft-distance-lib-exports GLOBAL_TARGETS cutlass::cutlass cutlass) + # We generate the cutlass-config files when we built cutlass locally, so always do `find_dependency` + rapids_export_package(BUILD cutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::CUTLASS) + rapids_export_package(INSTALL cutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::CUTLASS) # Tell cmake where it can find the generated cutlass-config.cmake we wrote. 
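One detail that is easy to miss in this hunk: building CUTLASS under the custom raft_cutlass top-level namespace is what keeps RAFT's precompiled distance kernels from colliding with an application's own copy of CUTLASS at link time. A sketch of the idea, assuming CUTLASS's CUTLASS_NAMESPACE convention (the identifier remapping shown here mirrors the header-side guard a later patch in this series adds to pairwise_distance_cutlass_base.cuh; it is not part of this hunk):

// With CUTLASS_NAMESPACE=raft_cutlass, CUTLASS's sources compile into
// namespace raft_cutlass { ... } rather than namespace cutlass { ... },
// so raft_distance_lib and a user's own CUTLASS build can coexist.
// RAFT headers can still spell cutlass:: by remapping the identifier:
#ifndef CUTLASS_NAMESPACE
#define cutlass raft_cutlass  // assumption: mirrors the later header guard
#endif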
include("${rapids-cmake-dir}/export/find_package_root.cmake") - rapids_export_find_package_root(INSTALL cutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-lib-exports) + rapids_export_find_package_root(INSTALL cutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports) endfunction() if(NOT RAFT_CUTLASS_GIT_TAG) - set(RAFT_CUTLASS_GIT_TAG v2.9.0) + set(RAFT_CUTLASS_GIT_TAG v2.9.1) endif() if(NOT RAFT_CUTLASS_GIT_REPOSITORY) set(RAFT_CUTLASS_GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git) endif() -find_and_configure_cutlass(VERSION 2.9.0 +find_and_configure_cutlass(VERSION 2.9.1 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG ${RAFT_CUTLASS_GIT_TAG}) From d4bdec587bb3f5b54fa87cdb92db092e4e47e861 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 4 Nov 2022 15:44:15 +0530 Subject: [PATCH 09/25] remove the cutlass namespace setting in test cmakefiles as it is not required now --- cpp/test/CMakeLists.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 645c5d0d03..059a1a792b 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -62,9 +62,6 @@ function(ConfigureTest) PUBLIC "$" ) - target_compile_definitions(${TEST_NAME} - INTERFACE "CUTLASS_NAMESPACE=${RAFT_CUTLASS_NAMESPACE}") - install( TARGETS ${TEST_NAME} COMPONENT testing From d26bcef911076b32c4da8ca42e91dd66f9964d2e Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 4 Nov 2022 18:00:11 +0530 Subject: [PATCH 10/25] temp remove dist dependency from cutlass to check if it works in ci/cd --- cpp/cmake/thirdparty/get_cutlass.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 4908e3fda7..9757418d21 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -19,7 +19,7 @@ function(find_and_configure_cutlass) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) - if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) + #if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") set(CUTLASS_NAMESPACE "raft_cutlass" CACHE STRING "Top level namespace of CUTLASS") @@ -42,7 +42,7 @@ function(find_and_configure_cutlass) GLOBAL_TARGETS nvidia::cutlass::CUTLASS NAMESPACE nvidia::cutlass::) endif() - endif() + #endif() # We generate the cutlass-config files when we built cutlass locally, so always do `find_dependency` rapids_export_package(BUILD cutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::CUTLASS) From 451c3c06eb850ed70df554f79c1663c77f02cee6 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 10 Nov 2022 15:29:26 +0530 Subject: [PATCH 11/25] fix get_cutlass.cmake to work with pylibraft by using NvidiaCutlass instead of cutlass for its cmake config file --- cpp/cmake/thirdparty/get_cutlass.cmake | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 9757418d21..7a1a5701b3 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -22,9 +22,10 @@ function(find_and_configure_cutlass) #if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") set(CUTLASS_NAMESPACE "raft_cutlass" CACHE STRING "Top level namespace of 
CUTLASS") + set(CUTLASS_ENABLE_CUBLAS OFF CACHE BOOL "Disable CUTLASS to build with cuBLAS library.") - rapids_cpm_find(cutlass ${PKG_VERSION} - GLOBAL_TARGETS nvidia::cutlass::CUTLASS + rapids_cpm_find(NvidiaCutlass ${PKG_VERSION} + GLOBAL_TARGETS nvidia::cutlass::cutlass CPM_ARGS GIT_REPOSITORY ${PKG_REPOSITORY} GIT_TAG ${PKG_PINNED_TAG} @@ -32,25 +33,25 @@ function(find_and_configure_cutlass) "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" ) - if(TARGET cutlass AND NOT TARGET nvidia::cutlass::cutlass) + if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass) add_library(nvidia::cutlass::cutlass ALIAS CUTLASS) endif() - if(cutlass_ADDED) - rapids_export(BUILD cutlass + if(NvidiaCutlass_ADDED) + rapids_export(BUILD NvidiaCutlass EXPORT_SET NvidiaCutlass - GLOBAL_TARGETS nvidia::cutlass::CUTLASS + GLOBAL_TARGETS nvidia::cutlass::cutlass NAMESPACE nvidia::cutlass::) endif() #endif() # We generate the cutlass-config files when we built cutlass locally, so always do `find_dependency` - rapids_export_package(BUILD cutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::CUTLASS) - rapids_export_package(INSTALL cutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::CUTLASS) + rapids_export_package(BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) + rapids_export_package(INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) - # Tell cmake where it can find the generated cutlass-config.cmake we wrote. + # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote. include("${rapids-cmake-dir}/export/find_package_root.cmake") - rapids_export_find_package_root(INSTALL cutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports) + rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports) endfunction() if(NOT RAFT_CUTLASS_GIT_TAG) From 7b512f99bbfffee488ad329062787b017ba2f087 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 11 Nov 2022 01:09:09 +0530 Subject: [PATCH 12/25] fix get_cutlass install path, make changes as per review comments --- cpp/cmake/thirdparty/get_cutlass.cmake | 12 +++++++----- cpp/include/raft/distance/detail/cosine.cuh | 6 +++++- cpp/include/raft/distance/detail/distance.cuh | 4 ++++ cpp/include/raft/distance/detail/euclidean.cuh | 6 +++++- .../detail/pairwise_distance_cutlass_base.cuh | 7 +++++++ .../distance/detail/pairwise_distance_epilogue.h | 11 ++++++++--- .../detail/pairwise_distance_epilogue_elementwise.h | 3 +++ .../raft/distance/detail/pairwise_distance_gemm.h | 4 ++-- .../detail/predicated_tile_iterator_normvec.h | 9 +++++++-- cpp/include/raft/util/cudart_utils.hpp | 2 +- cpp/test/distance/dist_adj.cu | 4 ++++ 11 files changed, 53 insertions(+), 15 deletions(-) diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 7a1a5701b3..afae39974d 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -25,11 +25,12 @@ function(find_and_configure_cutlass) set(CUTLASS_ENABLE_CUBLAS OFF CACHE BOOL "Disable CUTLASS to build with cuBLAS library.") rapids_cpm_find(NvidiaCutlass ${PKG_VERSION} - GLOBAL_TARGETS nvidia::cutlass::cutlass + GLOBAL_TARGETS nvidia::cutlass::cutlass CPM_ARGS - GIT_REPOSITORY ${PKG_REPOSITORY} - GIT_TAG ${PKG_PINNED_TAG} - OPTIONS + GIT_REPOSITORY ${PKG_REPOSITORY} + GIT_TAG ${PKG_PINNED_TAG} + GIT_SHALLOW TRUE + OPTIONS "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" ) @@ -51,7 +52,8 @@ 
function(find_and_configure_cutlass) # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote. include("${rapids-cmake-dir}/export/find_package_root.cmake") - rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports) + rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports) + rapids_export_find_package_root(BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports) endfunction() if(NOT RAFT_CUTLASS_GIT_TAG) diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index e3d400e799..eea4af0d06 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -82,7 +82,7 @@ void cosineImpl(const DataT* x, FinalLambda fin_op, cudaStream_t stream) { - const auto deviceVersion = getMajorMinorVersion(); + const auto deviceVersion = getComputeCapability(); if (deviceVersion.first >= 8) { using CosineOp_ = CosineOp; CosineOp_ cosine_dist_op; @@ -228,6 +228,10 @@ void cosineAlgo1(Index_ m, { auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); }; + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), + "OutType can be uint8_t, float, double," + "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); typedef typename std::conditional::type CosOutType; CosOutType* pDcast = reinterpret_cast(pD); diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index c1a9c7ead3..f7049ed981 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -648,6 +648,10 @@ void distance(const InType* x, using final_op_type = default_fin_op; final_op_type fin_op; + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(! ((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), + "OutType can be uint8_t, float, double," + "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); distance( x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg); RAFT_CUDA_TRY(cudaPeekAtLastError()); diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 64359e2270..6a9a03ee03 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -88,7 +88,7 @@ void euclideanExpImpl(const DataT* x, FinalLambda fin_op, cudaStream_t stream) { - const auto deviceVersion = getMajorMinorVersion(); + const auto deviceVersion = getComputeCapability(); if (deviceVersion.first >= 8) { using L2Op = L2ExpandedOp; L2Op L2_dist_op(sqrt); @@ -245,6 +245,10 @@ void euclideanAlgo1(Index_ m, { auto norm_op = [] __device__(InType in) { return in; }; + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(! 
((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), + "OutType can be uint8_t, float, double," + "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); typedef typename std::conditional::type ExpOutType; ExpOutType* pDcast = reinterpret_cast(pD); diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index 5b67d77e96..cc5b33be2f 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -19,6 +19,13 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" +// We define CUTLASS_NAMESPACE in case +// RAFT cmake is not used +#ifndef CUTLASS_NAMESPACE +#define cutlass raft_cutlass +#endif + + #include #include diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h index 28eb57fdd0..c9833fe7ae 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h @@ -17,9 +17,14 @@ /*! \file \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. - +This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) + +This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec +and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise operation. +-- A norm load is provided PredicatedTileIteratorNormVec +-- B norm load is provided by EpilogueWithBroadcast +-- elementwise operation is provided by OutputOp */ #pragma once diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h index 80c926874a..3e33f4d833 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h @@ -18,6 +18,9 @@ /*! \file \brief Functor performing distance operations used by epilogues of pairwise distance * kernels. +* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 +* customized for applying elementwise distance formula on accumulated GEMM value +* and applying user-defined final custom operation on the distance value. */ #pragma once diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h index 6768a1b579..ea9ed77fb5 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h @@ -80,7 +80,7 @@ struct PairwiseDistanceGemm { // This code section describes how threadblocks are scheduled on GPU /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; /// data layout for final output matrix. 
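The static_assert added to cosineAlgo1, euclideanAlgo1, and distance above pins down a small support matrix: the output may be uint8_t (one byte, so it pairs with any accumulator) or it must match the accumulator's width. A compile-time restatement for a few concrete pairs (illustrative only; out_type_ok is a hypothetical helper that mirrors the asserted condition, not a RAFT symbol):

#include <cstdint>

template <typename AccType, typename OutType>
constexpr bool out_type_ok = !((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType)));

static_assert(out_type_ok<float, std::uint8_t>, "1-byte outputs pair with any accumulator");
static_assert(out_type_ok<float, float>, "matching widths are accepted");
static_assert(out_type_ok<double, double>, "matching widths are accepted");
static_assert(!out_type_ok<double, float>, "mixed widths are rejected");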
// we keep this same layout even for column major inputs @@ -179,7 +179,7 @@ struct PairwiseDistanceGemm; // <- ?? + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; /// data layout for final output matrix. // we keep this same layout even for column major inputs diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h index c343d09083..67c01448dc 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -17,8 +17,13 @@ /*! \file \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. +This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) + +Changes: +- added `Layout_` template param +- Only the row index is used to load the data in load_with_byte_offset(). + This way the same normalization data is used across all columns in a row. */ diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 0a3fe1d463..faac17a745 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -356,7 +356,7 @@ inline int getMultiProcessorCount() } /** helper method to get major minor compute capability version */ -inline std::pair getMajorMinorVersion() +inline std::pair getComputeCapability() { int devId; RAFT_CUDA_TRY(cudaGetDevice(&devId)); diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index 19a11ecd2e..f3f36b4576 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -150,6 +150,10 @@ class DistanceAdjTest : public ::testing::TestWithParam params; + // We use uint8_t even if the output in this test is a bool because + // cutlass doesn't support bool as output buffer yet. In cuda + // sizeof(bool) is 1 byte hence it doesn't increase + // memory consumption if we use uint8_t instead of bool. rmm::device_uvector dist_ref; rmm::device_uvector dist; raft::handle_t handle; From d32b4c0e1c5ede36b02aac394e0645203d0e89a1 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 11 Nov 2022 01:30:23 +0530 Subject: [PATCH 13/25] fix clang format issues --- cpp/include/raft/distance/detail/distance.cuh | 2 +- cpp/include/raft/distance/detail/euclidean.cuh | 2 +- .../raft/distance/detail/pairwise_distance_cutlass_base.cuh | 3 +-- .../raft/distance/detail/pairwise_distance_epilogue.h | 5 +++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index f7049ed981..cb41af8746 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -649,7 +649,7 @@ void distance(const InType* x, final_op_type fin_op; // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(! 
((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), + static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), "OutType can be uint8_t, float, double," "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); distance( diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 6a9a03ee03..1b645b5699 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -246,7 +246,7 @@ void euclideanAlgo1(Index_ m, auto norm_op = [] __device__(InType in) { return in; }; // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(! ((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), + static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), "OutType can be uint8_t, float, double," "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); typedef typename std::conditional::type ExpOutType; diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index cc5b33be2f..3f052e4239 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -22,10 +22,9 @@ // We define CUTLASS_NAMESPACE in case // RAFT cmake is not used #ifndef CUTLASS_NAMESPACE -#define cutlass raft_cutlass +#define cutlass raft_cutlass #endif - #include #include diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h index c9833fe7ae..21e7d18854 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h @@ -20,8 +20,9 @@ This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 (https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) -This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec -and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise operation. +This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec +and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise +operation. 
-- A norm load is provided PredicatedTileIteratorNormVec -- B norm load is provided by EpilogueWithBroadcast -- elementwise operation is provided by OutputOp From f7c440ac0a73a7013e6c57e27eca26dd83347fb2 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 11 Nov 2022 18:29:38 +0530 Subject: [PATCH 14/25] temp fix to check if python build works --- cpp/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 07335c8dca..d31e0b4803 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -590,6 +590,9 @@ string( [=[ if(distance IN_LIST raft_FIND_COMPONENTS) enable_language(CUDA) + if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass) + add_library(nvidia::cutlass::cutlass ALIAS CUTLASS) + endif() endif() if(nn IN_LIST raft_FIND_COMPONENTS) From b1a1fd797d5d33303d17f00ec41e66b79e6d26f7 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 15 Nov 2022 19:12:25 +0530 Subject: [PATCH 15/25] add raft-exports instead of raft-distance-exports as other raft components also use distance headers --- cpp/CMakeLists.txt | 3 --- cpp/cmake/thirdparty/get_cutlass.cmake | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d31e0b4803..07335c8dca 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -590,9 +590,6 @@ string( [=[ if(distance IN_LIST raft_FIND_COMPONENTS) enable_language(CUDA) - if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass) - add_library(nvidia::cutlass::cutlass ALIAS CUTLASS) - endif() endif() if(nn IN_LIST raft_FIND_COMPONENTS) diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index afae39974d..77f49b0fb0 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -47,13 +47,13 @@ function(find_and_configure_cutlass) #endif() # We generate the cutlass-config files when we built cutlass locally, so always do `find_dependency` - rapids_export_package(BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) - rapids_export_package(INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) + rapids_export_package(BUILD NvidiaCutlass raft-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) + rapids_export_package(INSTALL NvidiaCutlass raft-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote. 
include("${rapids-cmake-dir}/export/find_package_root.cmake") - rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports) - rapids_export_find_package_root(BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports) + rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-exports) + rapids_export_find_package_root(BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-exports) endfunction() if(NOT RAFT_CUTLASS_GIT_TAG) From 4ef44e75f23608c80189211af907b201228d773f Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 16 Nov 2022 15:15:29 +0530 Subject: [PATCH 16/25] make cutlass to depend only on raft_distance and add raft_distance dependency wherever needed in downstream tests/lib --- cpp/CMakeLists.txt | 9 +++++---- cpp/cmake/thirdparty/get_cutlass.cmake | 8 ++++---- cpp/test/CMakeLists.txt | 2 ++ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 07335c8dca..823e843163 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -218,7 +218,7 @@ target_link_libraries( CUDA::cusolver${_ctk_static_suffix} CUDA::cusparse${_ctk_static_suffix} $<$:raft::Thrust> - nvidia::cutlass::cutlass + #nvidia::cutlass::cutlass ) target_compile_features(raft INTERFACE cxx_std_17 $) @@ -352,7 +352,7 @@ if(RAFT_COMPILE_DIST_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) - target_link_libraries(raft_distance_lib PUBLIC raft::raft cuco::cuco) + target_link_libraries(raft_distance_lib PUBLIC raft::raft cuco::cuco nvidia::cutlass::cutlass) target_compile_options( raft_distance_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" @@ -369,9 +369,10 @@ if(TARGET raft_distance_lib AND (NOT TARGET raft::raft_distance_lib)) endif() target_link_libraries( - raft_distance INTERFACE raft::raft $ + raft_distance INTERFACE raft::raft $ nvidia::cutlass::cutlass ) + # ################################################################################################## # * raft_nn ------------------------------------------------------------------ add_library(raft_nn INTERFACE) @@ -424,7 +425,7 @@ if(RAFT_COMPILE_NN_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) - target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft) + target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft raft_distance) target_compile_options( raft_nn_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 77f49b0fb0..afae39974d 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -47,13 +47,13 @@ function(find_and_configure_cutlass) #endif() # We generate the cutlass-config files when we built cutlass locally, so always do `find_dependency` - rapids_export_package(BUILD NvidiaCutlass raft-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) - rapids_export_package(INSTALL NvidiaCutlass raft-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) + rapids_export_package(BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) + rapids_export_package(INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote. 
include("${rapids-cmake-dir}/export/find_package_root.cmake") - rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-exports) - rapids_export_find_package_root(BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-exports) + rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports) + rapids_export_find_package_root(BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports) endfunction() if(NOT RAFT_CUTLASS_GIT_TAG) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 2752c8b1b4..9982f36847 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -196,6 +196,8 @@ if(BUILD_TESTS) ConfigureTest( NAME SOLVERS_TEST PATH test/cluster_solvers_deprecated.cu test/eigen_solvers.cu test/lap/lap.cu test/mst.cu + OPTIONAL + DIST ) ConfigureTest( From 186fcc79d48c5ff6ba66f099395a021b670d3520 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 16 Nov 2022 15:30:44 +0530 Subject: [PATCH 17/25] fix cmake formatting issues --- cpp/CMakeLists.txt | 5 +- cpp/cmake/thirdparty/get_cutlass.cmake | 111 ++++++++++++++----------- cpp/test/CMakeLists.txt | 4 +- 3 files changed, 67 insertions(+), 53 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 823e843163..51815d3afc 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -218,7 +218,6 @@ target_link_libraries( CUDA::cusolver${_ctk_static_suffix} CUDA::cusparse${_ctk_static_suffix} $<$:raft::Thrust> - #nvidia::cutlass::cutlass ) target_compile_features(raft INTERFACE cxx_std_17 $) @@ -369,10 +368,10 @@ if(TARGET raft_distance_lib AND (NOT TARGET raft::raft_distance_lib)) endif() target_link_libraries( - raft_distance INTERFACE raft::raft $ nvidia::cutlass::cutlass + raft_distance INTERFACE raft::raft $ + nvidia::cutlass::cutlass ) - # ################################################################################################## # * raft_nn ------------------------------------------------------------------ add_library(raft_nn INTERFACE) diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index afae39974d..039fbf6409 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -1,59 +1,76 @@ -#============================================================================= +# ============================================================================= # Copyright (c) 2021-2022, NVIDIA CORPORATION. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-#============================================================================= +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= function(find_and_configure_cutlass) - set(oneValueArgs VERSION REPOSITORY PINNED_TAG) - cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) + set(oneValueArgs VERSION REPOSITORY PINNED_TAG) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - #if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) - set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") - set(CUTLASS_NAMESPACE "raft_cutlass" CACHE STRING "Top level namespace of CUTLASS") - set(CUTLASS_ENABLE_CUBLAS OFF CACHE BOOL "Disable CUTLASS to build with cuBLAS library.") + # if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) + set(CUTLASS_ENABLE_HEADERS_ONLY + ON + CACHE BOOL "Enable only the header library" + ) + set(CUTLASS_NAMESPACE + "raft_cutlass" + CACHE STRING "Top level namespace of CUTLASS" + ) + set(CUTLASS_ENABLE_CUBLAS + OFF + CACHE BOOL "Disable CUTLASS to build with cuBLAS library." + ) - rapids_cpm_find(NvidiaCutlass ${PKG_VERSION} - GLOBAL_TARGETS nvidia::cutlass::cutlass - CPM_ARGS - GIT_REPOSITORY ${PKG_REPOSITORY} - GIT_TAG ${PKG_PINNED_TAG} - GIT_SHALLOW TRUE - OPTIONS - "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" - ) + rapids_cpm_find( + NvidiaCutlass ${PKG_VERSION} + GLOBAL_TARGETS nvidia::cutlass::cutlass + CPM_ARGS + GIT_REPOSITORY ${PKG_REPOSITORY} + GIT_TAG ${PKG_PINNED_TAG} + GIT_SHALLOW TRUE + OPTIONS "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" + ) - if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass) - add_library(nvidia::cutlass::cutlass ALIAS CUTLASS) - endif() + if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass) + add_library(nvidia::cutlass::cutlass ALIAS CUTLASS) + endif() - if(NvidiaCutlass_ADDED) - rapids_export(BUILD NvidiaCutlass - EXPORT_SET NvidiaCutlass - GLOBAL_TARGETS nvidia::cutlass::cutlass - NAMESPACE nvidia::cutlass::) - endif() - #endif() + if(NvidiaCutlass_ADDED) + rapids_export( + BUILD NvidiaCutlass + EXPORT_SET NvidiaCutlass + GLOBAL_TARGETS nvidia::cutlass::cutlass + NAMESPACE nvidia::cutlass:: + ) + endif() + # endif() - # We generate the cutlass-config files when we built cutlass locally, so always do `find_dependency` - rapids_export_package(BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) - rapids_export_package(INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass) + # We generate the cutlass-config files when we built cutlass locally, so always do + # `find_dependency` + rapids_export_package( + BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + ) + rapids_export_package( + INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + ) - # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote. 
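Patch 18 below gates the CUTLASS path twice: at compile time on the CUDA toolkit major version (per the TODO it adds, the CUTLASS-based kernels are not yet enabled for CUDA 12.0), and at run time on compute capability. A condensed, self-contained sketch of the dispatch shape it converges on (getComputeCapabilitySketch and the empty branches are stand-ins for RAFT's real code):

#include <utility>

std::pair<int, int> getComputeCapabilitySketch() { return {8, 0}; }  // stand-in

void pairwiseDistanceDispatchSketch()
{
#if (__CUDACC_VER_MAJOR__ < 12)
  // This whole branch is compiled out on CUDA 12+.
  const auto deviceVersion = getComputeCapabilitySketch();
  if (deviceVersion.first >= 8) {
    // SM80+ at run time: launch the fused CUTLASS kernel here.
  } else
#endif
  {
    // Everything else: launch RAFT's custom pairwise-distance kernel here.
  }
}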
- include("${rapids-cmake-dir}/export/find_package_root.cmake") - rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports) - rapids_export_find_package_root(BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports) + # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote. + include("${rapids-cmake-dir}/export/find_package_root.cmake") + rapids_export_find_package_root( + INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports + ) + rapids_export_find_package_root( + BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports + ) endfunction() if(NOT RAFT_CUTLASS_GIT_TAG) @@ -64,6 +81,6 @@ if(NOT RAFT_CUTLASS_GIT_REPOSITORY) set(RAFT_CUTLASS_GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git) endif() -find_and_configure_cutlass(VERSION 2.9.1 - REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} - PINNED_TAG ${RAFT_CUTLASS_GIT_TAG}) +find_and_configure_cutlass( + VERSION 2.9.1 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG ${RAFT_CUTLASS_GIT_TAG} +) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 9982f36847..d37b6fc37c 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -195,9 +195,7 @@ if(BUILD_TESTS) ConfigureTest( NAME SOLVERS_TEST PATH test/cluster_solvers_deprecated.cu test/eigen_solvers.cu test/lap/lap.cu - test/mst.cu - OPTIONAL - DIST + test/mst.cu OPTIONAL DIST ) ConfigureTest( From 8aa8909634d3724a531172138699ef88ce90bf99 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 16 Nov 2022 17:09:34 +0530 Subject: [PATCH 18/25] prevent cutlass based pairwise dist kernels to be disabled on cuda 12 and enable non-cutlass version for ampere+ temporarily --- cpp/include/raft/distance/detail/cosine.cuh | 5 ++++- cpp/include/raft/distance/detail/euclidean.cuh | 6 +++++- .../raft/distance/detail/pairwise_distance_base.cuh | 6 ++++-- .../detail/pairwise_distance_cutlass_base.cuh | 11 +++++++---- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index eea4af0d06..f06051962f 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -82,6 +82,7 @@ void cosineImpl(const DataT* x, FinalLambda fin_op, cudaStream_t stream) { +#if (__CUDACC_VER_MAJOR__ < 12) const auto deviceVersion = getComputeCapability(); if (deviceVersion.first >= 8) { using CosineOp_ = CosineOp; @@ -90,7 +91,9 @@ void cosineImpl(const DataT* x, cutlassDistanceKernel( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, cosine_dist_op, stream); - } else { + } else +#endif + { typedef typename raft::linalg::Policy4x4::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 1b645b5699..5ea74fa884 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -88,6 +88,7 @@ void euclideanExpImpl(const DataT* x, FinalLambda fin_op, cudaStream_t stream) { +#if (__CUDACC_VER_MAJOR__ < 12) const auto deviceVersion = getComputeCapability(); if (deviceVersion.first >= 8) { using L2Op = L2ExpandedOp; @@ -96,7 +97,10 @@ void euclideanExpImpl(const DataT* x, cutlassDistanceKernel( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, L2_dist_op, stream); - } else { + } else +#endif + { + typedef typename 
diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh
index b314582b5c..26536d13cd 100644
--- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh
+++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh
@@ -425,7 +425,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2)
                     EpilogueLambda epilog_op,
                     FinalLambda fin_op)
 {
-#if __CUDA_ARCH__ < 800
+  //#if __CUDA_ARCH__ < 800
+  // TODO: re-enable the CUDA_ARCH guard for below Ampere once cutlass based
+  // kernels are enabled for CUDA 12.0
   extern __shared__ char smem[];

   auto rowEpilog = [] __device__(IdxT starty) { return; };
@@ -444,7 +446,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2)
   obj(
     x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, epilog_op, fin_op, rowEpilog);
   obj.run();
-#endif
+  //#endif
 }

 template
diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh
index 3f052e4239..f39d880da4 100644
--- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh
+++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh
@@ -19,6 +19,8 @@
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"

+#if (__CUDACC_VER_MAJOR__ < 12)
+
 // We define CUTLASS_NAMESPACE in case
 // RAFT cmake is not used
 #ifndef CUTLASS_NAMESPACE
@@ -169,7 +171,8 @@ void cutlassDistanceKernel(const DataT* x,
   CUTLASS_CHECK(status);
 }

-};  // namespace detail
-};  // namespace distance
-};  // namespace raft
-#pragma GCC diagnostic pop
\ No newline at end of file
+};  // namespace detail
+};  // namespace distance
+};  // namespace raft
+#endif  // (__CUDACC_VER_MAJOR__ < 12)
+#pragma GCC diagnostic pop

From abfd4932d3937a7b5f11b908cde000c41d0fc42f Mon Sep 17 00:00:00 2001
From: "Corey J. Nolet"
Date: Wed, 16 Nov 2022 12:18:18 -0500
Subject: [PATCH 19/25] Moving cutlass dependency to distance and nn to keep
 them separate.

---
 cpp/CMakeLists.txt | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 51815d3afc..b3085069e1 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -424,7 +424,7 @@ if(RAFT_COMPILE_NN_LIBRARY)
     INTERFACE_POSITION_INDEPENDENT_CODE ON
   )

-  target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft raft_distance)
+  target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft nvidia::cutlass::cutlass)
   target_compile_options(
     raft_nn_lib PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
                         "$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>"
@@ -439,7 +439,9 @@ if(TARGET raft_nn_lib AND (NOT TARGET raft::raft_nn_lib))
   add_library(raft::raft_nn_lib ALIAS raft_nn_lib)
 endif()

-target_link_libraries(raft_nn INTERFACE raft::raft $<TARGET_NAME_IF_EXISTS:raft_nn_lib>)
+target_link_libraries(
+  raft_nn INTERFACE raft::raft $<TARGET_NAME_IF_EXISTS:raft_nn_lib> nvidia::cutlass::cutlass
+)

 # ##################################################################################################
 # * install targets-----------------------------------------------------------

From f1b123959aa7a03bcbee9b1e49a6769ecf8b761e Mon Sep 17 00:00:00 2001
From: "Corey J.
Nolet" Date: Wed, 16 Nov 2022 12:24:17 -0500 Subject: [PATCH 20/25] Adding CUTLASS to build docs as dependency --- docs/source/build.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/build.md b/docs/source/build.md index 2a093fcc22..6533c0b783 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -19,8 +19,9 @@ In addition to the libraries included with cudatoolkit 11.0+, there are some oth #### Optional - [cuCollections](https://github.com/NVIDIA/cuCollections) - Used in `raft::sparse::distance` API. -- [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0 -- [FAISS](https://github.com/facebookresearch/faiss) v1.7.0 - Used in `raft::neighbors` API.. +- [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0 - Used by cuCollections +- [CUTLASS](https://github.com/NVIDIA/cutlass) v2.9.1 - Used in `raft::distance` API. +- [FAISS](https://github.com/facebookresearch/faiss) v1.7.0 - Used in `raft::neighbors` API. - [NCCL](https://github.com/NVIDIA/nccl) - Used in `raft::comms` API and needed to build `raft-dask`. - [UCX](https://github.com/openucx/ucx) - Used in `raft::comms` API and needed to build `raft-dask`. - [Googletest](https://github.com/google/googletest) - Needed to build tests From 32e6052fa7e0fb670daf6c81d7e20d483dd20e05 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 16 Nov 2022 13:02:29 -0500 Subject: [PATCH 21/25] Updating to export to both distance and nn --- cpp/cmake/thirdparty/get_cutlass.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 039fbf6409..20127cbe4a 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -60,7 +60,7 @@ function(find_and_configure_cutlass) BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass ) rapids_export_package( - INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + INSTALL NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass ) # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote. @@ -69,7 +69,7 @@ function(find_and_configure_cutlass) INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports ) rapids_export_find_package_root( - BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports + BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-nn-exports ) endfunction() From f6de9ee0f4985498dcc6e61eda7f8eb7e4c6ed70 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 16 Nov 2022 13:06:08 -0500 Subject: [PATCH 22/25] Adding cutlass as private dependency --- cpp/CMakeLists.txt | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index b3085069e1..9b9e88773b 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -351,7 +351,11 @@ if(RAFT_COMPILE_DIST_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) - target_link_libraries(raft_distance_lib PUBLIC raft::raft cuco::cuco nvidia::cutlass::cutlass) + target_link_libraries( + raft_distance_lib + PUBLIC raft::raft cuco::cuco + PRIVATE nvidia::cutlass::cutlass + ) target_compile_options( raft_distance_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" @@ -368,8 +372,9 @@ if(TARGET raft_distance_lib AND (NOT TARGET raft::raft_distance_lib)) endif() target_link_libraries( - raft_distance INTERFACE raft::raft $ - nvidia::cutlass::cutlass + raft_distance + INTERFACE raft::raft $ + PRIVATE nvidia::cutlass::cutlass ) # ################################################################################################## @@ -424,7 +429,11 @@ if(RAFT_COMPILE_NN_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) - target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft nvidia::cutlass::cutlass) + target_link_libraries( + raft_nn_lib + PUBLIC faiss::faiss raft::raft + PRIVATE nvidia::cutlass::cutlass + ) target_compile_options( raft_nn_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" @@ -440,7 +449,9 @@ if(TARGET raft_nn_lib AND (NOT TARGET raft::raft_nn_lib)) endif() target_link_libraries( - raft_nn INTERFACE raft::raft $ nvidia::cutlass::cutlass + raft_nn + INTERFACE raft::raft $ + PRIVATE nvidia::cutlass::cutlass ) # ################################################################################################## From 9bf064777e3838394cb7fceefa1edebb0b2897b5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 16 Nov 2022 13:27:51 -0500 Subject: [PATCH 23/25] Making cutlass INTERFACE in raft::nn and raft::distance --- cpp/CMakeLists.txt | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9b9e88773b..89b2cb1bf9 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -372,9 +372,8 @@ if(TARGET raft_distance_lib AND (NOT TARGET raft::raft_distance_lib)) endif() target_link_libraries( - raft_distance - INTERFACE raft::raft $ - PRIVATE nvidia::cutlass::cutlass + raft_distance INTERFACE raft::raft $ + nvidia::cutlass::cutlass ) # ################################################################################################## @@ -449,9 +448,7 @@ if(TARGET raft_nn_lib AND (NOT TARGET raft::raft_nn_lib)) endif() target_link_libraries( - raft_nn - INTERFACE raft::raft $ - PRIVATE nvidia::cutlass::cutlass + raft_nn INTERFACE raft::raft $ nvidia::cutlass::cutlass ) # ################################################################################################## From 8f0119a8e87eda20a3a399cedb691773e8ad3ba7 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 16 Nov 2022 13:37:49 -0500 Subject: [PATCH 24/25] Using proper exports per Robert Maynard's suggestion. 
From 8f0119a8e87eda20a3a399cedb691773e8ad3ba7 Mon Sep 17 00:00:00 2001
From: "Corey J. Nolet"
Date: Wed, 16 Nov 2022 13:37:49 -0500
Subject: [PATCH 24/25] Using proper exports per Robert Maynard's suggestion.

---
 cpp/CMakeLists.txt                     | 12 ++----------
 cpp/cmake/thirdparty/get_cutlass.cmake | 21 +++++++++++++++++----
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 89b2cb1bf9..97640f6738 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -351,11 +351,7 @@ if(RAFT_COMPILE_DIST_LIBRARY)
     INTERFACE_POSITION_INDEPENDENT_CODE ON
   )

-  target_link_libraries(
-    raft_distance_lib
-    PUBLIC raft::raft cuco::cuco
-    PRIVATE nvidia::cutlass::cutlass
-  )
+  target_link_libraries(raft_distance_lib PUBLIC raft::raft cuco::cuco)
   target_compile_options(
     raft_distance_lib PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
                               "$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>"
@@ -428,11 +424,7 @@ if(RAFT_COMPILE_NN_LIBRARY)
     INTERFACE_POSITION_INDEPENDENT_CODE ON
   )

-  target_link_libraries(
-    raft_nn_lib
-    PUBLIC faiss::faiss raft::raft
-    PRIVATE nvidia::cutlass::cutlass
-  )
+  target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft)
   target_compile_options(
     raft_nn_lib PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
                         "$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>"
diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake
index 20127cbe4a..811a5466c3 100644
--- a/cpp/cmake/thirdparty/get_cutlass.cmake
+++ b/cpp/cmake/thirdparty/get_cutlass.cmake
@@ -57,19 +57,32 @@ function(find_and_configure_cutlass)
   # We generate the cutlass-config files when we built cutlass locally, so always do
   # `find_dependency`
   rapids_export_package(
-    BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
+    BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
   )
   rapids_export_package(
-    INSTALL NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
+    INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
+  )
+  rapids_export_package(
+    BUILD NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
+  )
+  rapids_export_package(
+    INSTALL NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
   )

   # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote.
   include("${rapids-cmake-dir}/export/find_package_root.cmake")
   rapids_export_find_package_root(
-    INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports
+    INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports
+  )
+  rapids_export_find_package_root(
+    BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports
+  )
+  include("${rapids-cmake-dir}/export/find_package_root.cmake")
+  rapids_export_find_package_root(
+    INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-nn-exports
   )
   rapids_export_find_package_root(
-    BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-nn-exports
+    BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-nn-exports
   )
 endfunction()
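With NvidiaCutlass now registered against both raft-distance-exports and
raft-nn-exports, for BUILD and INSTALL alike, the generated config file of
either component can `find_dependency` CUTLASS on its own. A sketch of
downstream consumption; the `find_package(raft COMPONENTS ...)` form and the
raft::distance target name are assumed from RAFT's build documentation rather
than from this series:

    # Downstream CMakeLists.txt fragment (assumed consumer API).
    find_package(raft COMPONENTS distance)
    add_executable(app app.cu)
    # raft-distance-config.cmake resolves NvidiaCutlass transitively, so no
    # explicit find_package(NvidiaCutlass) is needed here.
    target_link_libraries(app PRIVATE raft::raft raft::distance)

One nit in the hunk above: the added
include("${rapids-cmake-dir}/export/find_package_root.cmake") duplicates the
include a few lines earlier; it is redundant but harmless.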
Nolet" Date: Wed, 16 Nov 2022 13:49:32 -0500 Subject: [PATCH 25/25] Adding cutlass as private dependency of lib targets --- cpp/CMakeLists.txt | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 97640f6738..89b2cb1bf9 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -351,7 +351,11 @@ if(RAFT_COMPILE_DIST_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) - target_link_libraries(raft_distance_lib PUBLIC raft::raft cuco::cuco) + target_link_libraries( + raft_distance_lib + PUBLIC raft::raft cuco::cuco + PRIVATE nvidia::cutlass::cutlass + ) target_compile_options( raft_distance_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" @@ -424,7 +428,11 @@ if(RAFT_COMPILE_NN_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) - target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft) + target_link_libraries( + raft_nn_lib + PUBLIC faiss::faiss raft::raft + PRIVATE nvidia::cutlass::cutlass + ) target_compile_options( raft_nn_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>"