From 611abc710ca3e564da2d61dfd589df44f7e48ae7 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Thu, 17 Nov 2022 04:27:26 +0530 Subject: [PATCH] Add cutlass 3xTF32,DMMA based L2/cosine distance kernels for SM 8.0 or higher (#939) -- 3xTF32 cutlass based L2 exp/cosine kernel provides 3.5x speedup for fp32 inputs compared to the existing pairwise distance kernel on Ampere or higher. -- DMMA cutlass based implementation for L2 exp/cosine provides 2.6x speedup compared to the existing double precision FMA pipeline based kernel. -- Add cutlass as a header-only dependency to RAFT. Authors: - Mahesh Doijade (https://github.com/mdoijade) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Robert Maynard (https://github.com/robertmaynard) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/939 --- cpp/CMakeLists.txt | 18 +- cpp/cmake/thirdparty/get_cutlass.cmake | 99 +++ cpp/include/raft/distance/detail/cosine.cuh | 138 +++-- cpp/include/raft/distance/detail/distance.cuh | 25 +- .../raft/distance/detail/euclidean.cuh | 145 +++-- .../detail/pairwise_distance_base.cuh | 85 +++ .../detail/pairwise_distance_cutlass_base.cuh | 178 ++++++ .../detail/pairwise_distance_epilogue.h | 101 +++ .../pairwise_distance_epilogue_elementwise.h | 171 ++++++ .../distance/detail/pairwise_distance_gemm.h | 239 +++++++ .../detail/predicated_tile_iterator_normvec.h | 581 ++++++++++++++++++ cpp/include/raft/util/cudart_utils.hpp | 12 + cpp/test/CMakeLists.txt | 3 +- cpp/test/distance/dist_adj.cu | 63 +- cpp/test/distance/dist_euc_exp.cu | 3 + cpp/test/distance/dist_eucsqrt_exp.cu | 76 +++ docs/source/build.md | 5 +- 17 files changed, 1798 insertions(+), 144 deletions(-) create mode 100644 cpp/cmake/thirdparty/get_cutlass.cmake create mode 100644 cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh create mode 100644 
cpp/include/raft/distance/detail/pairwise_distance_epilogue.h create mode 100644 cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h create mode 100644 cpp/include/raft/distance/detail/pairwise_distance_gemm.h create mode 100644 cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h create mode 100644 cpp/test/distance/dist_eucsqrt_exp.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index af08a1a2a4..94e693f861 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -182,6 +182,7 @@ rapids_cpm_init() include(cmake/thirdparty/get_thrust.cmake) include(cmake/thirdparty/get_rmm.cmake) include(cmake/thirdparty/get_faiss.cmake) +include(cmake/thirdparty/get_cutlass.cmake) if(RAFT_ENABLE_cuco_DEPENDENCY) include(${rapids-cmake-dir}/cpm/cuco.cmake) @@ -365,7 +366,11 @@ if(RAFT_COMPILE_DIST_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) - target_link_libraries(raft_distance_lib PUBLIC raft::raft cuco::cuco) + target_link_libraries( + raft_distance_lib + PUBLIC raft::raft cuco::cuco + PRIVATE nvidia::cutlass::cutlass + ) target_compile_options( raft_distance_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" @@ -383,6 +388,7 @@ endif() target_link_libraries( raft_distance INTERFACE raft::raft $ + nvidia::cutlass::cutlass ) # ################################################################################################## @@ -439,7 +445,11 @@ if(RAFT_COMPILE_NN_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) - target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft) + target_link_libraries( + raft_nn_lib + PUBLIC faiss::faiss raft::raft + PRIVATE nvidia::cutlass::cutlass + ) target_compile_options( raft_nn_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" @@ -454,7 +464,9 @@ if(TARGET raft_nn_lib AND (NOT TARGET raft::raft_nn_lib)) add_library(raft::raft_nn_lib ALIAS raft_nn_lib) endif() -target_link_libraries(raft_nn INTERFACE raft::raft $) +target_link_libraries( + raft_nn INTERFACE 
raft::raft $ nvidia::cutlass::cutlass +) # ################################################################################################## # * install targets----------------------------------------------------------- diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake new file mode 100644 index 0000000000..811a5466c3 --- /dev/null +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -0,0 +1,99 @@ +# ============================================================================= +# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +function(find_and_configure_cutlass) + set(oneValueArgs VERSION REPOSITORY PINNED_TAG) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + # if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) + set(CUTLASS_ENABLE_HEADERS_ONLY + ON + CACHE BOOL "Enable only the header library" + ) + set(CUTLASS_NAMESPACE + "raft_cutlass" + CACHE STRING "Top level namespace of CUTLASS" + ) + set(CUTLASS_ENABLE_CUBLAS + OFF + CACHE BOOL "Disable CUTLASS to build with cuBLAS library." 
+ ) + + rapids_cpm_find( + NvidiaCutlass ${PKG_VERSION} + GLOBAL_TARGETS nvidia::cutlass::cutlass + CPM_ARGS + GIT_REPOSITORY ${PKG_REPOSITORY} + GIT_TAG ${PKG_PINNED_TAG} + GIT_SHALLOW TRUE + OPTIONS "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" + ) + + if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass) + add_library(nvidia::cutlass::cutlass ALIAS CUTLASS) + endif() + + if(NvidiaCutlass_ADDED) + rapids_export( + BUILD NvidiaCutlass + EXPORT_SET NvidiaCutlass + GLOBAL_TARGETS nvidia::cutlass::cutlass + NAMESPACE nvidia::cutlass:: + ) + endif() + # endif() + + # We generate the cutlass-config files when we built cutlass locally, so always do + # `find_dependency` + rapids_export_package( + BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + ) + rapids_export_package( + INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + ) + rapids_export_package( + BUILD NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + ) + rapids_export_package( + INSTALL NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + ) + + # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote. 
+ include("${rapids-cmake-dir}/export/find_package_root.cmake") + rapids_export_find_package_root( + INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports + ) + rapids_export_find_package_root( + BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports + ) + include("${rapids-cmake-dir}/export/find_package_root.cmake") + rapids_export_find_package_root( + INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-nn-exports + ) + rapids_export_find_package_root( + BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-nn-exports + ) +endfunction() + +if(NOT RAFT_CUTLASS_GIT_TAG) + set(RAFT_CUTLASS_GIT_TAG v2.9.1) +endif() + +if(NOT RAFT_CUTLASS_GIT_REPOSITORY) + set(RAFT_CUTLASS_GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git) +endif() + +find_and_configure_cutlass( + VERSION 2.9.1 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG ${RAFT_CUTLASS_GIT_TAG} +) diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index b7eed3e2a8..f06051962f 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -17,12 +17,23 @@ #pragma once #include +#include #include namespace raft { namespace distance { namespace detail { +template +struct CosineOp { + __device__ CosineOp() noexcept {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); + } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + /** * @brief the cosine distance matrix calculation implementer * It computes the following equation: @@ -71,61 +82,74 @@ void cosineImpl(const DataT* x, FinalLambda fin_op, cudaStream_t stream) { - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; +#if (__CUDACC_VER_MAJOR__ < 12) + const auto deviceVersion = 
getComputeCapability(); + if (deviceVersion.first >= 8) { + using CosineOp_ = CosineOp; + CosineOp_ cosine_dist_op; + + cutlassDistanceKernel( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, cosine_dist_op, stream); - typedef typename std::conditional::type KPolicy; + } else +#endif + { + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - dim3 blk(KPolicy::Nthreads); + typedef typename std::conditional::type KPolicy; - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; + dim3 blk(KPolicy::Nthreads); - // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { #pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = acc[i][j] / (regxn[i] * regyn[j]); + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j])); + } } - } - }; + }; - constexpr size_t shmemSize = - KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); - if (isRowMajor) { - auto cosineRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineRowMajor); - cosineRowMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto cosineColMajor = 
pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineColMajor); - cosineColMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + constexpr size_t shmemSize = + KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); + if (isRowMajor) { + auto cosineRowMajor = pairwiseDistanceMatKernelPriorToAmpere; + dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineRowMajor); + cosineRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + } else { + auto cosineColMajor = pairwiseDistanceMatKernelPriorToAmpere; + dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineColMajor); + cosineColMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + } } RAFT_CUDA_TRY(cudaGetLastError()); @@ -207,13 +231,11 @@ void cosineAlgo1(Index_ m, { auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); }; - // Wrap fin_op to allow computing 1 - pA before calling fin_op - auto wrapped_fin_op = [fin_op] __device__(AccType d_val, Index_ g_d_idx) { - return fin_op(static_cast(1.0) - d_val, g_d_idx); - }; - - typedef std::is_same is_bool; - typedef typename std::conditional::type CosOutType; + // raft distance support inputs as float/double and output as uint8_t/float/double. 
+ static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), + "OutType can be uint8_t, float, double," + "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); + typedef typename std::conditional::type CosOutType; CosOutType* pDcast = reinterpret_cast(pD); ASSERT( @@ -234,12 +256,12 @@ void cosineAlgo1(Index_ m, if (isRowMajor) { lda = k, ldb = k, ldd = n; - cosine( - m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, pDcast, wrapped_fin_op, stream); + cosine( + m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, pDcast, fin_op, stream); } else { lda = n, ldb = m, ldd = m; - cosine( - n, m, k, lda, ldb, ldd, pB, pA, row_vec, col_vec, pDcast, wrapped_fin_op, stream); + cosine( + n, m, k, lda, ldb, ldd, pB, pA, row_vec, col_vec, pDcast, fin_op, stream); } } diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 5072a85fc0..b459c73bee 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -615,6 +615,19 @@ void distance(const InType* x, * @note if workspace is passed as nullptr, this will return in * worksize, the number of bytes of workspace required */ + +// Default final op functor which facilitates elementwise operation on +// final distance value if any. +template +struct default_fin_op { + __host__ __device__ default_fin_op() noexcept {}; + // functor signature. + __host__ __device__ OutType operator()(AccType d_val, Index g_d_idx) const noexcept + { + return d_val; + } +}; + template ( - x, y, dist, m, n, k, workspace, worksize, default_fin_op, stream, isRowMajor, metric_arg); + using final_op_type = default_fin_op; + final_op_type fin_op; + + // raft distance support inputs as float/double and output as uint8_t/float/double. 
+ static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), + "OutType can be uint8_t, float, double," + "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); + distance( + x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index d83e81b6a9..5ea74fa884 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -15,13 +15,30 @@ */ #pragma once + #include +#include #include namespace raft { namespace distance { namespace detail { +template +struct L2ExpandedOp { + bool sqrt; + + __device__ L2ExpandedOp() noexcept : sqrt(false) {} + __device__ L2ExpandedOp(bool isSqrt) noexcept : sqrt(isSqrt) {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; + return sqrt ? 
raft::mySqrt(outVal) : outVal; + } + + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + /** * @brief the expanded euclidean distance matrix calculation implementer * It computes the following equation: C = op(A^2 + B^2 - 2AB) @@ -71,71 +88,85 @@ void euclideanExpImpl(const DataT* x, FinalLambda fin_op, cudaStream_t stream) { - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; +#if (__CUDACC_VER_MAJOR__ < 12) + const auto deviceVersion = getComputeCapability(); + if (deviceVersion.first >= 8) { + using L2Op = L2ExpandedOp; + L2Op L2_dist_op(sqrt); - typedef typename std::conditional::type KPolicy; + cutlassDistanceKernel( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, L2_dist_op, stream); - dim3 blk(KPolicy::Nthreads); + } else +#endif + { - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - // epilogue operation lambda for final value calculation - auto epilog_lambda = [sqrt] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { + typedef typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [sqrt] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { #pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + 
regyn[j] - (DataT)2.0 * acc[i][j]; + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + } } - } - if (sqrt) { + if (sqrt) { #pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + acc[i][j] = raft::mySqrt(acc[i][j]); + } } } - } - }; + }; - constexpr size_t shmemSize = - KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); - if (isRowMajor) { - auto euclideanExpRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); - - euclideanExpRowMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto euclideanExpColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); - euclideanExpColMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + constexpr size_t shmemSize = + KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); + if (isRowMajor) { + auto euclideanExpRowMajor = pairwiseDistanceMatKernelPriorToAmpere; + dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); + + euclideanExpRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + } else { + auto euclideanExpColMajor = pairwiseDistanceMatKernelPriorToAmpere; + dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); + euclideanExpColMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); + } } RAFT_CUDA_TRY(cudaGetLastError()); @@ -164,6 +195,7 @@ void euclideanExp(IdxT m, { size_t bytesA = sizeof(DataT) * lda; size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) 
== 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { euclideanExpImpl( x, y, xn, yn, m, n, k, lda, ldb, ldd, sqrt, dOutput, fin_op, stream); @@ -217,8 +249,11 @@ void euclideanAlgo1(Index_ m, { auto norm_op = [] __device__(InType in) { return in; }; - typedef std::is_same is_bool; - typedef typename std::conditional::type ExpOutType; + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), + "OutType can be uint8_t, float, double," + "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); + typedef typename std::conditional::type ExpOutType; ExpOutType* pDcast = reinterpret_cast(pD); ASSERT( diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 27e9935358..26536d13cd 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -364,6 +364,91 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) obj.run(); } +/** + * @brief the distance matrix calculation kernel for L2 and cosine + * for GPU arch < SM 8.0, this version is to make sure we don't recompile + * these kernels for ampere or higher as we use cutlass kernel for it. + * @tparam useNorms whether norms are needed + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Policy struct which tunes the Contraction kernel + * @tparam CoreLambda lambda which implements accumulation operation + * @tparam EpilogueLambda lambda which implements operation for calculating + final value. 
+ * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major(default), + false for column major + * + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] xn row norms of input matrix A. + * @param[in] yn row norms of input matrix B. + * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] pD output matrix + * @param core_op the core lambda + * @param epilog_op the epilogue lambda + * @param fin_op the final gemm epilogue lambda + */ + +template +__global__ __launch_bounds__(Policy::Nthreads, 2) + + void pairwiseDistanceMatKernelPriorToAmpere(const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + CoreLambda core_op, + EpilogueLambda epilog_op, + FinalLambda fin_op) +{ + //#if __CUDA_ARCH__ < 800 + // TODO: re-enable the CUDA_ARCH guard for below Ampere once cutlass based + // kernels are enabled for CUDA 12.0 + extern __shared__ char smem[]; + auto rowEpilog = [] __device__(IdxT starty) { return; }; + + PairwiseDistances + obj( + x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, epilog_op, fin_op, rowEpilog); + obj.run(); + //#endif +} + template dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) { diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh new file mode 100644 index 0000000000..f39d880da4 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" + +#if (__CUDACC_VER_MAJOR__ < 12) + +// We define CUTLASS_NAMESPACE in case +// RAFT cmake is not used +#ifndef CUTLASS_NAMESPACE +#define cutlass raft_cutlass +#endif + +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include "./pairwise_distance_epilogue_elementwise.h" +#include "./pairwise_distance_gemm.h" + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +namespace raft { +namespace distance { +namespace detail { + +template +void cutlassDistanceKernel(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + FinalLambda fin_op, + DistanceFn dist_op, + cudaStream_t stream) +{ + static_assert(!(std::is_same::value), + "OutType bool is not supported use uint8_t instead"); + + using EpilogueOutputOp = + cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise; + constexpr int batch_count = 1; + + constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; + + typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op); + + const DataT 
*a, *b; + + IdxT gemm_lda, gemm_ldb; + + // Number of pipelines you want to use + constexpr int NumStages = 3; + // Alignment + constexpr int Alignment = VecLen; + + // default initialize problem size with row major inputs + auto problem_size = cutlass::gemm::GemmCoord(n, m, k); + + using cutlassDistKernel = + typename cutlass::gemm::kernel::PairwiseDistanceGemm::GemmKernel; + + using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; + + if constexpr (isRowMajor) { + a = y; + b = x; + gemm_lda = ldb; + gemm_ldb = lda; + } else { + problem_size = cutlass::gemm::GemmCoord(m, n, k); + a = x; + b = y; + gemm_lda = lda; + gemm_ldb = ldb; + } + + typename cutlassDist::Arguments arguments{ + mode, problem_size, batch_count, epilog_op_param, a, b, + xn, // C matrix eq vector param, which here is A norm + nullptr, // tensor_Z, + (DataT*)yn, // this is broadcast vec, which is required to be non-const param + dOutput, // Output distance matrix + (int64_t)0, // batch stride A + (int64_t)0, // batch stride B + (int64_t)0, // batch stride Norm A + (int64_t)0, + (int64_t)0, // batch stride Norm B + (int64_t)0, // batch stride Output + gemm_lda, // stride A + gemm_ldb, // stride B + 1, // stride A norm + 0, // this is no-op for Z + 0, // This must be zero + ldd // stride Output matrix + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = cutlassDist::get_workspace_size(arguments); + // Allocate workspace memory + rmm::device_uvector workspace(workspace_size, stream); + // Instantiate CUTLASS kernel depending on templates + cutlassDist cutlassDist_op; + // Check the problem size is supported or not + cutlass::Status status = cutlassDist_op.can_implement(arguments); + CUTLASS_CHECK(status); + // Initialize CUTLASS kernel with arguments and workspace pointer + status = cutlassDist_op.initialize(arguments, workspace.data(), stream); + CUTLASS_CHECK(status); + // Launch initialized CUTLASS kernel + 
status = cutlassDist_op(); + CUTLASS_CHECK(status); +} + +}; // namespace detail +}; // namespace distance +}; // namespace raft +#endif // (__CUDACC_VER_MAJOR__ < 12) +#pragma GCC diagnostic pop diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h new file mode 100644 index 0000000000..21e7d18854 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) + +This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec +and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise +operation. 
+-- A norm load is provided PredicatedTileIteratorNormVec +-- B norm load is provided by EpilogueWithBroadcast +-- elementwise operation is provided by OutputOp +*/ + +#pragma once + +#include +#include +#include + +#include + +#include "./predicated_tile_iterator_normvec.h" +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. +template +struct PairwiseDistanceEpilogue { + /// Use defaults related to the existing epilogue + using Base = + DefaultEpilogueTensorOp; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using OutputTileIterator = cutlass::epilogue::threadblock:: + PredicatedTileIteratorNormVec; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using TensorTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIterator; + + /// Define the epilogue + using Epilogue = EpilogueWithBroadcast; +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h new file mode 100644 index 0000000000..3e33f4d833 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +/*! \file + \brief Functor performing distance operations used by epilogues of pairwise distance + * kernels. +* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 +* customized for applying elementwise distance formula on accumulated GEMM value +* and applying user-defined final custom operation on the distance value. +*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template +class PairwiseDistanceEpilogueElementwise { + public: + using ElementOutput = ElementC_; + using ElementC = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + + using DistanceOp = DistanceOp_; + using FinalOp = FinalOp_; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using FragmentT = Array; + + using FragmentOutput = FragmentZ; + + static bool const kIsHeavy = false; // ElementwiseOp::kIsHeavy; + + /// If true, the 'Z' tensor is stored + 
static bool const kStoreZ = false; // We don't store anything in Z, + + /// If true, the 'T' tensor is stored + static bool const kStoreT = true; // this is our final output storage. + + /// Host-constructable parameters structure + struct Params { + FinalOp_ final_op_; + DistanceOp_ dist_op_; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params(DistanceOp_ dist_op, FinalOp final_op) : final_op_(final_op), dist_op_(dist_op) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + // + // Data members + // + FinalOp_ final_op; + DistanceOp_ elementwise_op; + + public: + // + // Methods + // + + /// Constructor from Params + CUTLASS_HOST_DEVICE + PairwiseDistanceEpilogueElementwise(Params const& params) + : final_op(params.final_op_), elementwise_op(params.dist_op_) + { + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const + { + // Always true: this forces the C matrix path to be taken, which is used to load the A matrix norm. + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentC const& frag_C, + FragmentCompute const& V) const + { + FragmentCompute tmp_Accum = + NumericArrayConverter()(AB); + FragmentCompute tmp_C = + NumericArrayConverter()(frag_C); + FragmentCompute result_Z; + FragmentCompute result_T; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); + result_T[i] = final_op(result_Z[i], 0); + } + + NumericArrayConverter convert_t; + frag_T = convert_t(result_T); + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentCompute
const& V) const + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h new file mode 100644 index 0000000000..ea9ed77fb5 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include +#include +#include +#include + +#include "./pairwise_distance_epilogue.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Alignment for A matrix operand + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Alignment for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Element type for final output + // typename ElementOutT, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct PairwiseDistanceGemm { + // This struct is specialized for fp32/3xTF32 + + /// Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = + cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 + /// Instruction-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = + cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastF32; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass =
cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. + // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementAccumulator, + typename EpilogueOutputOp::ElementT, + ElementAccumulator, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue; +}; + +template < + /// Alignment for A matrix operand + int kAlignmentA, + /// Alignment for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct PairwiseDistanceGemm { + // using Transform = cutlass::ComplexTransform::kNone; + // Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = + cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 64, N = 64, K = 16 + ///
Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 32, N = 32, K = 16 + /// Instruction-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + // Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAdd; + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. + // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + ElementC_, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git
a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h new file mode 100644 index 0000000000..67c01448dc --- /dev/null +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -0,0 +1,581 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) + +Changes: +- added `Layout_` template param +- Only the row index is used to load the data in load_with_byte_offset(). + This way the same normalization data is used across all columns in a row. + +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template +class PredicatedTileIteratorNormVec { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = Layout_; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + 
for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in columns + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorNormVec(PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { +
mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { mask_.clear(); } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride); + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; 
++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[0], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + 
memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % 
convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * 
convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) row_add_P = 0; + if (output_Q > convolution_Q - 2) row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start column from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store +
CUTLASS_HOST_DEVICE + PredicatedTileIteratorNormVec& operator++() + { + ++state_[0]; + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Gets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 1346c6073a..68a95da587 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -355,6 +355,18 @@ inline int getMultiProcessorCount() return mpCount; } +/** helper method to get major minor compute capability version */ +inline std::pair getComputeCapability() +{ + int devId; + RAFT_CUDA_TRY(cudaGetDevice(&devId)); + int majorVer, minorVer; +
RAFT_CUDA_TRY(cudaDeviceGetAttribute(&majorVer, cudaDevAttrComputeCapabilityMajor, devId)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minorVer, cudaDevAttrComputeCapabilityMinor, devId)); + + return std::make_pair(majorVer, minorVer); +} + /** helper method to convert an array on device to a string on host */ template std::string arr2Str(const T* arr, int size, std::string name, cudaStream_t stream, int width = 4) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 3192330639..31144f6ffd 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -108,6 +108,7 @@ if(BUILD_TESTS) test/distance/dist_cos.cu test/distance/dist_euc_exp.cu test/distance/dist_euc_unexp.cu + test/distance/dist_eucsqrt_exp.cu test/distance/dist_hamming.cu test/distance/dist_hellinger.cu test/distance/dist_jensen_shannon.cu @@ -194,7 +195,7 @@ if(BUILD_TESTS) ConfigureTest( NAME SOLVERS_TEST PATH test/cluster_solvers_deprecated.cu test/eigen_solvers.cu test/lap/lap.cu - test/mst.cu + test/mst.cu OPTIONAL DIST ) ConfigureTest( diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index 72906af1b2..f3f36b4576 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -26,7 +26,7 @@ namespace raft { namespace distance { template -__global__ void naiveDistanceAdjKernel(bool* dist, +__global__ void naiveDistanceAdjKernel(uint8_t* dist, const DataType* x, const DataType* y, int m, @@ -50,7 +50,7 @@ __global__ void naiveDistanceAdjKernel(bool* dist, } template -void naiveDistanceAdj(bool* dist, +void naiveDistanceAdj(uint8_t* dist, const DataType* x, const DataType* y, int m, @@ -74,6 +74,18 @@ struct DistanceAdjInputs { unsigned long long int seed; }; +template +struct threshold_final_op { + DataT threshold_val; + + __device__ __host__ threshold_final_op() noexcept : threshold_val(0.0) {} + __device__ __host__ threshold_final_op(DataT val) noexcept : threshold_val(val) {} + __device__ __host__ OutT operator()(AccT d_val, Index 
g_idx) const noexcept + { + return d_val <= threshold_val; + } +}; + template ::std::ostream& operator<<(::std::ostream& os, const DistanceAdjInputs& dims) { @@ -109,25 +121,28 @@ class DistanceAdjTest : public ::testing::TestWithParam( + getWorkspaceSize( x.data(), y.data(), m, n, k); rmm::device_uvector workspace(worksize, stream); - auto fin_op = [threshold] __device__(DataType d_val, int g_d_idx) { - return d_val <= threshold; - }; - raft::distance::distance( - x.data(), - y.data(), - dist.data(), - m, - n, - k, - workspace.data(), - workspace.size(), - fin_op, - stream, - isRowMajor); + using threshold_final_op_ = threshold_final_op; + threshold_final_op_ threshold_op(threshold); + + raft::distance::distance(x.data(), + y.data(), + dist.data(), + m, + n, + k, + workspace.data(), + workspace.size(), + threshold_op, + stream, + isRowMajor); handle.sync_stream(stream); } @@ -135,8 +150,12 @@ class DistanceAdjTest : public ::testing::TestWithParam params; - rmm::device_uvector dist_ref; - rmm::device_uvector dist; + // We use uint8_t even if the output in this test is a bool because + // cutlass doesn't support bool as output buffer yet. In cuda + // sizeof(bool) is 1 byte hence it doesn't increase + // memory consumption if we use uint8_t instead of bool. + rmm::device_uvector dist_ref; + rmm::device_uvector dist; raft::handle_t handle; cudaStream_t stream; }; @@ -156,7 +175,7 @@ TEST_P(DistanceAdjTestF, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare(), stream)); + ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare(), stream)); } INSTANTIATE_TEST_CASE_P(DistanceAdjTests, DistanceAdjTestF, ::testing::ValuesIn(inputsf)); @@ -175,7 +194,7 @@ TEST_P(DistanceAdjTestD, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? 
params.n : params.m; - ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare(), stream)); + ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare(), stream)); } INSTANTIATE_TEST_CASE_P(DistanceAdjTests, DistanceAdjTestD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/distance/dist_euc_exp.cu b/cpp/test/distance/dist_euc_exp.cu index ff142da7fa..5371b8a3e2 100644 --- a/cpp/test/distance/dist_euc_exp.cu +++ b/cpp/test/distance/dist_euc_exp.cu @@ -25,14 +25,17 @@ class DistanceEucExpTest : public DistanceTest> inputsf = { + {0.001f, 2048, 4096, 128, true, 1234ULL}, {0.001f, 1024, 1024, 32, true, 1234ULL}, {0.001f, 1024, 32, 1024, true, 1234ULL}, {0.001f, 32, 1024, 1024, true, 1234ULL}, {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.003f, 1021, 1021, 1021, true, 1234ULL}, {0.001f, 1024, 1024, 32, false, 1234ULL}, {0.001f, 1024, 32, 1024, false, 1234ULL}, {0.001f, 32, 1024, 1024, false, 1234ULL}, {0.003f, 1024, 1024, 1024, false, 1234ULL}, + {0.003f, 1021, 1021, 1021, false, 1234ULL}, }; typedef DistanceEucExpTest DistanceEucExpTestF; TEST_P(DistanceEucExpTestF, Result) diff --git a/cpp/test/distance/dist_eucsqrt_exp.cu b/cpp/test/distance/dist_eucsqrt_exp.cu new file mode 100644 index 0000000000..c4f2dc80c2 --- /dev/null +++ b/cpp/test/distance/dist_eucsqrt_exp.cu @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.h" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceEucSqrtExpTest + : public DistanceTest { +}; + +const std::vector> inputsf = { + {0.001f, 2048, 4096, 128, true, 1234ULL}, + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.003f, 1021, 1021, 1021, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, + {0.003f, 1021, 1021, 1021, false, 1234ULL}, +}; +typedef DistanceEucSqrtExpTest DistanceEucSqrtExpTestF; +TEST_P(DistanceEucSqrtExpTestF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceEucSqrtExpTest DistanceEucSqrtExpTestD; +TEST_P(DistanceEucSqrtExpTestD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? 
params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestD, ::testing::ValuesIn(inputsd)); + +class BigMatrixEucSqrtExp + : public BigMatrixDistanceTest { +}; +TEST_F(BigMatrixEucSqrtExp, Result) {} +} // end namespace distance +} // end namespace raft diff --git a/docs/source/build.md b/docs/source/build.md index fe34e0fb41..30acc1b399 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -19,8 +19,9 @@ In addition to the libraries included with cudatoolkit 11.0+, there are some oth #### Optional - [cuCollections](https://github.com/NVIDIA/cuCollections) - Used in `raft::sparse::distance` API. -- [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0 -- [FAISS](https://github.com/facebookresearch/faiss) v1.7.0 - Used in `raft::neighbors` API.. +- [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0 - Used by cuCollections +- [CUTLASS](https://github.com/NVIDIA/cutlass) v2.9.1 - Used in `raft::distance` API. +- [FAISS](https://github.com/facebookresearch/faiss) v1.7.0 - Used in `raft::neighbors` API. - [NCCL](https://github.com/NVIDIA/nccl) - Used in `raft::comms` API and needed to build `raft-dask`. - [UCX](https://github.com/openucx/ucx) - Used in `raft::comms` API and needed to build `raft-dask`. - [Googletest](https://github.com/google/googletest) - Needed to build tests