Skip to content

Commit

Permalink
Add cutlass 3xTF32,DMMA based L2/cosine distance kernels for SM 8.0 o…
Browse files Browse the repository at this point in the history
…r higher (#939)

-- 3xTF32 cutlass based L2 exp/cosine kernel provides 3.5x speedup for fp32 inputs compared to existing pairwise distance kernel for ampere or higher.
-- DMMA cutlass based implementation for L2 exp/cosine provides 2.6x speedup instead of existing double precision FMA pipeline based kernel.
-- add cutlass as header only dependency to RAFT.

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Robert Maynard (https://github.com/robertmaynard)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #939
  • Loading branch information
mdoijade authored Nov 16, 2022
1 parent c7e74bd commit 611abc7
Show file tree
Hide file tree
Showing 17 changed files with 1,798 additions and 144 deletions.
18 changes: 15 additions & 3 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ rapids_cpm_init()
include(cmake/thirdparty/get_thrust.cmake)
include(cmake/thirdparty/get_rmm.cmake)
include(cmake/thirdparty/get_faiss.cmake)
include(cmake/thirdparty/get_cutlass.cmake)

if(RAFT_ENABLE_cuco_DEPENDENCY)
include(${rapids-cmake-dir}/cpm/cuco.cmake)
Expand Down Expand Up @@ -365,7 +366,11 @@ if(RAFT_COMPILE_DIST_LIBRARY)
INTERFACE_POSITION_INDEPENDENT_CODE ON
)

target_link_libraries(raft_distance_lib PUBLIC raft::raft cuco::cuco)
target_link_libraries(
raft_distance_lib
PUBLIC raft::raft cuco::cuco
PRIVATE nvidia::cutlass::cutlass
)
target_compile_options(
raft_distance_lib PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
"$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>"
Expand All @@ -383,6 +388,7 @@ endif()

target_link_libraries(
raft_distance INTERFACE raft::raft $<TARGET_NAME_IF_EXISTS:raft::raft_distance_lib>
nvidia::cutlass::cutlass
)

# ##################################################################################################
Expand Down Expand Up @@ -439,7 +445,11 @@ if(RAFT_COMPILE_NN_LIBRARY)
INTERFACE_POSITION_INDEPENDENT_CODE ON
)

target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft)
target_link_libraries(
raft_nn_lib
PUBLIC faiss::faiss raft::raft
PRIVATE nvidia::cutlass::cutlass
)
target_compile_options(
raft_nn_lib PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
"$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>"
Expand All @@ -454,7 +464,9 @@ if(TARGET raft_nn_lib AND (NOT TARGET raft::raft_nn_lib))
add_library(raft::raft_nn_lib ALIAS raft_nn_lib)
endif()

target_link_libraries(raft_nn INTERFACE raft::raft $<TARGET_NAME_IF_EXISTS:raft::raft_nn_lib>)
target_link_libraries(
raft_nn INTERFACE raft::raft $<TARGET_NAME_IF_EXISTS:raft::raft_nn_lib> nvidia::cutlass::cutlass
)

# ##################################################################################################
# * install targets-----------------------------------------------------------
Expand Down
99 changes: 99 additions & 0 deletions cpp/cmake/thirdparty/get_cutlass.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# =============================================================================
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================

function(find_and_configure_cutlass)
set(oneValueArgs VERSION REPOSITORY PINNED_TAG)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

# if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES)
set(CUTLASS_ENABLE_HEADERS_ONLY
ON
CACHE BOOL "Enable only the header library"
)
set(CUTLASS_NAMESPACE
"raft_cutlass"
CACHE STRING "Top level namespace of CUTLASS"
)
set(CUTLASS_ENABLE_CUBLAS
OFF
CACHE BOOL "Disable CUTLASS to build with cuBLAS library."
)

rapids_cpm_find(
NvidiaCutlass ${PKG_VERSION}
GLOBAL_TARGETS nvidia::cutlass::cutlass
CPM_ARGS
GIT_REPOSITORY ${PKG_REPOSITORY}
GIT_TAG ${PKG_PINNED_TAG}
GIT_SHALLOW TRUE
OPTIONS "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}"
)

if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass)
add_library(nvidia::cutlass::cutlass ALIAS CUTLASS)
endif()

if(NvidiaCutlass_ADDED)
rapids_export(
BUILD NvidiaCutlass
EXPORT_SET NvidiaCutlass
GLOBAL_TARGETS nvidia::cutlass::cutlass
NAMESPACE nvidia::cutlass::
)
endif()
# endif()

# We generate the cutlass-config files when we built cutlass locally, so always do
# `find_dependency`
rapids_export_package(
BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
)
rapids_export_package(
INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
)
rapids_export_package(
BUILD NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
)
rapids_export_package(
INSTALL NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
)

# Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote.
include("${rapids-cmake-dir}/export/find_package_root.cmake")
rapids_export_find_package_root(
INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports
)
rapids_export_find_package_root(
BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports
)
include("${rapids-cmake-dir}/export/find_package_root.cmake")
rapids_export_find_package_root(
INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-nn-exports
)
rapids_export_find_package_root(
BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-nn-exports
)
endfunction()

if(NOT RAFT_CUTLASS_GIT_TAG)
set(RAFT_CUTLASS_GIT_TAG v2.9.1)
endif()

if(NOT RAFT_CUTLASS_GIT_REPOSITORY)
set(RAFT_CUTLASS_GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git)
endif()

find_and_configure_cutlass(
VERSION 2.9.1 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG ${RAFT_CUTLASS_GIT_TAG}
)
138 changes: 80 additions & 58 deletions cpp/include/raft/distance/detail/cosine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,23 @@
#pragma once

#include <raft/distance/detail/pairwise_distance_base.cuh>
#include <raft/distance/detail/pairwise_distance_cutlass_base.cuh>
#include <raft/linalg/norm.cuh>

namespace raft {
namespace distance {
namespace detail {

template <typename DataT, typename AccT>
struct CosineOp {
__device__ CosineOp() noexcept {}
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
{
return static_cast<AccT>(1.0) - (AccT)(accVal / (aNorm * bNorm));
}
__device__ AccT operator()(DataT aData) const noexcept { return aData; }
};

/**
* @brief the cosine distance matrix calculation implementer
* It computes the following equation:
Expand Down Expand Up @@ -71,61 +82,74 @@ void cosineImpl(const DataT* x,
FinalLambda fin_op,
cudaStream_t stream)
{
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;
#if (__CUDACC_VER_MAJOR__ < 12)
const auto deviceVersion = getComputeCapability();
if (deviceVersion.first >= 8) {
using CosineOp_ = CosineOp<DataT, AccT>;
CosineOp_ cosine_dist_op;

cutlassDistanceKernel<DataT, AccT, OutT, IdxT, VecLen, FinalLambda, CosineOp_, isRowMajor>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, cosine_dist_op, stream);

typedef typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;
} else
#endif
{
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;

dim3 blk(KPolicy::Nthreads);
typedef typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;

// Accumulation operation lambda
auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; };
dim3 blk(KPolicy::Nthreads);

// epilogue operation lambda for final value calculation
auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
DataT * regxn,
DataT * regyn,
IdxT gridStrideX,
IdxT gridStrideY) {
// Accumulation operation lambda
auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; };

// epilogue operation lambda for final value calculation
auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
DataT * regxn,
DataT * regyn,
IdxT gridStrideX,
IdxT gridStrideY) {
#pragma unroll
for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < KPolicy::AccColsPerTh; ++j) {
acc[i][j] = acc[i][j] / (regxn[i] * regyn[j]);
for (int j = 0; j < KPolicy::AccColsPerTh; ++j) {
acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j]));
}
}
}
};
};

constexpr size_t shmemSize =
KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT));
if (isRowMajor) {
auto cosineRowMajor = pairwiseDistanceMatKernel<true,
DataT,
AccT,
OutT,
IdxT,
KPolicy,
decltype(core_lambda),
decltype(epilog_lambda),
FinalLambda,
true>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineRowMajor);
cosineRowMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op);
} else {
auto cosineColMajor = pairwiseDistanceMatKernel<true,
DataT,
AccT,
OutT,
IdxT,
KPolicy,
decltype(core_lambda),
decltype(epilog_lambda),
FinalLambda,
false>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineColMajor);
cosineColMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op);
constexpr size_t shmemSize =
KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT));
if (isRowMajor) {
auto cosineRowMajor = pairwiseDistanceMatKernelPriorToAmpere<true,
DataT,
AccT,
OutT,
IdxT,
KPolicy,
decltype(core_lambda),
decltype(epilog_lambda),
FinalLambda,
true>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineRowMajor);
cosineRowMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op);
} else {
auto cosineColMajor = pairwiseDistanceMatKernelPriorToAmpere<true,
DataT,
AccT,
OutT,
IdxT,
KPolicy,
decltype(core_lambda),
decltype(epilog_lambda),
FinalLambda,
false>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineColMajor);
cosineColMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op);
}
}

RAFT_CUDA_TRY(cudaGetLastError());
Expand Down Expand Up @@ -207,13 +231,11 @@ void cosineAlgo1(Index_ m,
{
auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); };

// Wrap fin_op to allow computing 1 - pA before calling fin_op
auto wrapped_fin_op = [fin_op] __device__(AccType d_val, Index_ g_d_idx) {
return fin_op(static_cast<AccType>(1.0) - d_val, g_d_idx);
};

typedef std::is_same<OutType, bool> is_bool;
typedef typename std::conditional<is_bool::value, OutType, AccType>::type CosOutType;
// raft distance support inputs as float/double and output as uint8_t/float/double.
static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))),
"OutType can be uint8_t, float, double,"
"if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType).");
typedef typename std::conditional<sizeof(OutType) == 1, OutType, AccType>::type CosOutType;
CosOutType* pDcast = reinterpret_cast<CosOutType*>(pD);

ASSERT(
Expand All @@ -234,12 +256,12 @@ void cosineAlgo1(Index_ m,

if (isRowMajor) {
lda = k, ldb = k, ldd = n;
cosine<InType, AccType, CosOutType, Index_, decltype(wrapped_fin_op), true>(
m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, pDcast, wrapped_fin_op, stream);
cosine<InType, AccType, CosOutType, Index_, FinalLambda, true>(
m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, pDcast, fin_op, stream);
} else {
lda = n, ldb = m, ldd = m;
cosine<InType, AccType, CosOutType, Index_, decltype(wrapped_fin_op), false>(
n, m, k, lda, ldb, ldd, pB, pA, row_vec, col_vec, pDcast, wrapped_fin_op, stream);
cosine<InType, AccType, CosOutType, Index_, FinalLambda, false>(
n, m, k, lda, ldb, ldd, pB, pA, row_vec, col_vec, pDcast, fin_op, stream);
}
}

Expand Down
25 changes: 22 additions & 3 deletions cpp/include/raft/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,19 @@ void distance(const InType* x,
* @note if workspace is passed as nullptr, this will return in
* worksize, the number of bytes of workspace required
*/

// Default final op functor which facilitates elementwise operation on
// final distance value if any.
template <typename AccType, typename OutType, typename Index>
struct default_fin_op {
__host__ __device__ default_fin_op() noexcept {};
// functor signature.
__host__ __device__ OutType operator()(AccType d_val, Index g_d_idx) const noexcept
{
return d_val;
}
};

template <raft::distance::DistanceType distanceType,
typename InType,
typename AccType,
Expand All @@ -632,9 +645,15 @@ void distance(const InType* x,
bool isRowMajor = true,
InType metric_arg = 2.0f)
{
auto default_fin_op = [] __device__(AccType d_val, Index_ g_d_idx) { return d_val; };
distance<distanceType, InType, AccType, OutType, decltype(default_fin_op), Index_>(
x, y, dist, m, n, k, workspace, worksize, default_fin_op, stream, isRowMajor, metric_arg);
using final_op_type = default_fin_op<AccType, OutType, Index_>;
final_op_type fin_op;

// raft distance support inputs as float/double and output as uint8_t/float/double.
static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))),
"OutType can be uint8_t, float, double,"
"if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType).");
distance<distanceType, InType, AccType, OutType, final_op_type, Index_>(
x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

Expand Down
Loading

0 comments on commit 611abc7

Please sign in to comment.