Add fused cosine 1-NN cutlass based kernel (#2125)
- Adds a cosine 1-NN CUTLASS-based kernel for SM 8.0 or higher, using tensor cores.
- Based on 3xTF32.
- Unifies the fusedDistanceNN kernels for L2/cosine.
- Exposes this API in pylibraft as `fused_distance_nn_arg_min`, supporting the cosine and L2 distance metrics.
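For orientation, here is a host-side reference of what the fused kernel computes; the names ArgMinKVP and cosine_1nn_reference are illustrative only (the GPU kernel fuses the GEMM and the arg-min instead of materializing the m-by-n distance matrix, and xn/yn are assumed to hold precomputed L2 row norms):

#include <cfloat>

struct ArgMinKVP {  // illustrative stand-in for raft::KeyValuePair<int, float>
  int key;
  float value;
};

// For each row x[i], find the nearest row y[j] under cosine distance,
// using the precomputed L2 row norms xn[i] and yn[j].
void cosine_1nn_reference(const float* x, const float* y,
                          const float* xn, const float* yn,
                          int m, int n, int k, ArgMinKVP* out)
{
  for (int i = 0; i < m; ++i) {
    out[i] = {-1, FLT_MAX};
    for (int j = 0; j < n; ++j) {
      float dot = 0.0f;
      for (int l = 0; l < k; ++l) {
        dot += x[i * k + l] * y[j * k + l];
      }
      float dist = 1.0f - dot / (xn[i] * yn[j]);  // cosine distance
      if (dist < out[i].value) { out[i] = {j, dist}; }
    }
  }
}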

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Ben Frederickson (https://github.com/benfred)

URL: #2125
mdoijade authored Mar 19, 2024
1 parent e53aa0c commit 413e34e
Showing 32 changed files with 2,281 additions and 174 deletions.
4 changes: 3 additions & 1 deletion cpp/CMakeLists.txt
@@ -256,7 +256,7 @@ endif()

if(RAFT_NVTX)
# This enables NVTX within the project with no option to disable it downstream.
-target_link_libraries(raft INTERFACE CUDA::nvToolsExt)
+target_link_libraries(raft INTERFACE CUDA::nvtx3)
target_compile_definitions(raft INTERFACE NVTX_ENABLED)
else()
# Allow enabling NVTX downstream if not set here. This creates a new option at build/install time,
@@ -324,6 +324,7 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu
src/distance/distance.cu
src/distance/fused_l2_nn.cu
+src/distance/fused_distance_nn.cu
src/linalg/detail/coalesced_reduction.cu
src/matrix/detail/select_k_double_int64_t.cu
src/matrix/detail/select_k_double_uint32_t.cu
@@ -422,6 +423,7 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/cluster/update_centroids.cuh
src/raft_runtime/cluster/update_centroids_double.cu
src/raft_runtime/cluster/update_centroids_float.cu
+src/raft_runtime/distance/fused_distance_min_arg.cu
src/raft_runtime/distance/fused_l2_min_arg.cu
src/raft_runtime/distance/pairwise_distance.cu
src/raft_runtime/matrix/select_k_float_int64_t.cu
12 changes: 2 additions & 10 deletions cpp/bench/prims/CMakeLists.txt
@@ -122,16 +122,8 @@ if(BUILD_PRIMS_BENCH)
)

ConfigureBench(
-  NAME
-  MATRIX_BENCH
-  PATH
-  bench/prims/matrix/argmin.cu
-  bench/prims/matrix/gather.cu
-  bench/prims/matrix/select_k.cu
-  bench/prims/matrix/main.cpp
-  OPTIONAL
-  LIB
-  EXPLICIT_INSTANTIATE_ONLY
+  NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu
+  bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY
)

ConfigureBench(
97 changes: 97 additions & 0 deletions cpp/include/raft/distance/detail/fused_distance_nn.cuh
@@ -0,0 +1,97 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/l2_exp.cuh> // ops::l2_exp_distance_op
#include <raft/distance/detail/fused_distance_nn/cutlass_base.cuh>
#include <raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh>
#include <raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh>
#include <raft/distance/detail/fused_distance_nn/helper_structs.cuh>
#include <raft/distance/detail/fused_distance_nn/simt_kernel.cuh>
#include <raft/distance/detail/pairwise_distance_base.cuh> // PairwiseDistances
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/contractions.cuh> // Policy
#include <raft/util/arch.cuh> // raft::util::arch::SM_*
#include <raft/util/cuda_utils.cuh> // raft::ceildiv, raft::shfl

#include <cstddef> // size_t
#include <limits> // std::numeric_limits

namespace raft {
namespace distance {

namespace detail {

template <typename DataT,
typename OutT,
typename IdxT,
typename Policy,
typename ReduceOpT,
typename KVPReduceOpT>
void fusedDistanceNNImpl(OutT* min,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
int* workspace,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
bool sqrt,
bool initOutBuffer,
bool isRowMajor,
raft::distance::DistanceType metric,
float metric_arg,
cudaStream_t stream)
{
// The kernel policy is determined by fusedDistanceNN.
typedef Policy P;

dim3 blk(P::Nthreads);
auto nblks = raft::ceildiv<int>(m, P::Nthreads);
constexpr auto maxVal = std::numeric_limits<DataT>::max();
typedef KeyValuePair<IdxT, DataT> KVPair;

RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream));
if (initOutBuffer) {
initKernel<DataT, OutT, IdxT, ReduceOpT>
<<<nblks, P::Nthreads, 0, stream>>>(min, m, maxVal, redOp);
RAFT_CUDA_TRY(cudaGetLastError());
}

switch (metric) {
case raft::distance::DistanceType::CosineExpanded:
fusedCosineNN<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Expanded:
// initOutBuffer is taken care of by fusedDistanceNNImpl(), so we pass false to fusedL2NNImpl.
fusedL2NNImpl<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream);
break;
default: assert(false && "only cosine/l2 metrics are supported with fusedDistanceNN"); break;
}
}

} // namespace detail
} // namespace distance
} // namespace raft
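A hypothetical call into this dispatcher, for illustration only; the reduce-op helpers (MinAndDistanceReduceOp, KVPMinReduce) and the Policy4x4 contraction policy below are assumptions carried over from the fused L2 NN path, and user code should go through the public fused distance API instead:

#include <raft/core/kvp.hpp>
#include <raft/distance/detail/fused_distance_nn.cuh>
#include <raft/distance/fused_l2_nn.cuh>  // MinAndDistanceReduceOp, KVPMinReduce (assumed location)
#include <raft/linalg/contractions.cuh>

void run_fused_cosine_nn_sketch(raft::KeyValuePair<int, float>* min,
                                const float* x, const float* y,
                                const float* xn, const float* yn,
                                int m, int n, int k, int* workspace,
                                cudaStream_t stream)
{
  using DataT = float;
  using IdxT  = int;
  using OutT  = raft::KeyValuePair<IdxT, DataT>;
  using P     = typename raft::linalg::Policy4x4<DataT, 1>::Policy;  // assumed kernel policy

  raft::distance::MinAndDistanceReduceOp<IdxT, DataT> redOp;  // assumed helper
  raft::distance::KVPMinReduce<IdxT, DataT> pairRedOp;        // assumed helper

  raft::distance::detail::fusedDistanceNNImpl<DataT, OutT, IdxT, P>(
    min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp,
    /*sqrt=*/false, /*initOutBuffer=*/true, /*isRowMajor=*/true,
    raft::distance::DistanceType::CosineExpanded, /*metric_arg=*/0.0f, stream);
}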
@@ -611,6 +611,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase<Shape_,
++tensor_iterator;
}
}
+tensor_iterator.dumpToGmem();
}

/// Helper to invoke the output functor over each vector of output
@@ -33,6 +33,8 @@

#include <rmm/device_uvector.hpp>

+#include <cuda/semaphore>

#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_grouped.h>
@@ -46,6 +48,14 @@ namespace raft {
namespace distance {
namespace detail {

+template <typename IdxT>
+RAFT_KERNEL initBinMutexKernel(cuda::binary_semaphore<cuda::thread_scope_device>* mut, IdxT m)
+{
+  auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x;
+
+  if (tid < m) { mut[tid].release(); }
+}

template <typename DataT,
typename AccT,
typename OutT,
@@ -87,8 +97,14 @@ void cutlassFusedDistanceNN(const DataT* x,
KVPReduceOpT>;
constexpr int batch_count = 1;

+rmm::device_uvector<cuda::binary_semaphore<cuda::thread_scope_device>> bin_mutex(m, stream);
+
+int blks_ = (m / 256) + 1;
+
+initBinMutexKernel<<<blks_, 256, 0, stream>>>(bin_mutex.data(), m);

typename EpilogueOutputOp::Params epilog_op_param(
-dist_op, cg_reduce_op, redOp, pairRedOp, mutexes);
+dist_op, cg_reduce_op, redOp, pairRedOp, mutexes, bin_mutex.data());

// Number of pipelines you want to use
constexpr int NumStages = 3;
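The per-row cuda::binary_semaphore initialized above (each slot is release()d once by initBinMutexKernel, so it starts free) is what the epilogue uses to serialize cross-block updates of a row's running minimum. A minimal sketch of that pattern, where updateRowMin is a hypothetical helper rather than the PR's actual epilogue code:

#include <cuda/semaphore>

using BinSem = cuda::binary_semaphore<cuda::thread_scope_device>;

// Merge one candidate (index, distance) pair into the global per-row minimum.
// acquire() blocks until the row's slot is free, so at most one thread at a
// time touches a given row's output.
template <typename IdxT, typename DataT>
__device__ void updateRowMin(IdxT* min_idx, DataT* min_dist, BinSem* row_sem,
                             IdxT row, IdxT cand_idx, DataT cand_dist)
{
  row_sem[row].acquire();
  if (cand_dist < min_dist[row]) {
    min_dist[row] = cand_dist;
    min_idx[row]  = cand_idx;
  }
  row_sem[row].release();
}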
@@ -56,6 +56,8 @@

#pragma once

+#include <cuda/semaphore>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
@@ -121,6 +123,7 @@ class FusedDistanceNNEpilogueElementwise {
KVPReduceOpT_ pair_redop_;
ReduceOpT_ red_op_;
int* mutexes_;
+cuda::binary_semaphore<cuda::thread_scope_device>* bin_mutex_;
using CGReduceT = CGReduceOp_;
//
// Methods
@@ -130,12 +133,14 @@
CGReduceOp cg_reduce_op,
ReduceOpT_ red_op,
KVPReduceOpT_ pair_redop,
-int* mutexes)
+int* mutexes,
+cuda::binary_semaphore<cuda::thread_scope_device>* bin_mutex)
: cg_reduce_op(cg_reduce_op),
dist_op_(dist_op),
pair_redop_(pair_redop),
red_op_(red_op),
-mutexes_(mutexes)
+mutexes_(mutexes),
+bin_mutex_(bin_mutex)
{
}

136 changes: 136 additions & 0 deletions cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh
@@ -0,0 +1,136 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/cosine.cuh> // ops::cosine_distance_op
#include <raft/distance/detail/fused_distance_nn/cutlass_base.cuh>
#include <raft/distance/detail/fused_distance_nn/helper_structs.cuh>
#include <raft/distance/detail/fused_distance_nn/simt_kernel.cuh>
#include <raft/distance/detail/pairwise_distance_base.cuh> // PairwiseDistances
#include <raft/linalg/contractions.cuh> // Policy
#include <raft/util/arch.cuh> // raft::util::arch::SM_*
#include <raft/util/cuda_utils.cuh> // raft::ceildiv, raft::shfl

#include <cstddef> // size_t
#include <limits> // std::numeric_limits

namespace raft {
namespace distance {

namespace detail {

template <typename DataT,
typename OutT,
typename IdxT,
typename Policy,
typename ReduceOpT,
typename KVPReduceOpT>
void fusedCosineNN(OutT* min,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
int* workspace,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
bool sqrt,
cudaStream_t stream)
{
// The kernel policy is determined by fusedDistanceNN.
typedef Policy P;

dim3 blk(P::Nthreads);
constexpr auto maxVal = std::numeric_limits<DataT>::max();
typedef KeyValuePair<IdxT, DataT> KVPair;

namespace arch = raft::util::arch;
using AccT = DataT;
ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};

raft::identity_op fin_op{};

auto kernel = fusedDistanceNNkernel<DataT,
OutT,
IdxT,
P,
ReduceOpT,
KVPReduceOpT,
decltype(distance_op),
decltype(fin_op)>;

// Get pointer to fp32 SIMT kernel to determine the runtime architecture of the
// current system. Other methods to determine the architecture (that do not
// require a pointer) can be error prone. See:
// https://github.com/NVIDIA/cub/issues/545
void* kernel_ptr = reinterpret_cast<void*>(kernel);
auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr);
auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());

if (cutlass_range.contains(runtime_arch)) {
// If the device is SM_80 or later, use the CUTLASS-based kernel.
using cosineOp = raft::distance::detail::ops::cosine_cutlass_op<DataT, DataT>;
using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op<DataT, IdxT, OutT>;
kvp_cg_min_reduce_op_ cg_reduce_op;
cosineOp cosine_dist_op;

IdxT lda, ldb, ldd;
lda = k, ldb = k, ldd = n;

cutlassFusedDistanceNN<DataT,
DataT,
OutT,
IdxT,
P::Veclen,
decltype(cg_reduce_op),
decltype(cosine_dist_op),
ReduceOpT,
KVPReduceOpT>(x,
y,
xn,
yn,
m,
n,
k,
lda,
ldb,
ldd,
min,
workspace,
cg_reduce_op,
cosine_dist_op,
redOp,
pairRedOp,
stream);
} else {
// If the device is older than SM_80, fall back to the fp32 SIMT kernel.
constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT));
dim3 grid = launchConfigGenerator<P>(m, n, shmemSize, kernel);

kernel<<<grid, blk, shmemSize, stream>>>(
min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op);
RAFT_CUDA_TRY(cudaGetLastError());
}
}

} // namespace detail
} // namespace distance
} // namespace raft
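For intuition, all the CUTLASS path's cosine_cutlass_op has to do in the epilogue is turn each GEMM accumulator (a dot product) into a distance using the precomputed norms; a sketch of that transform (illustrative, not the actual functor from distance_ops/cosine.cuh):

// Given the accumulated dot product <x_i, y_j> and the rows' L2 norms, the
// cosine distance is one divide away; the fused arg-min reduction then runs
// on these values inside the same kernel.
template <typename DataT, typename AccT>
struct cosine_epilogue_sketch {
  __device__ __forceinline__ AccT operator()(AccT dot, DataT norm_x, DataT norm_y) const
  {
    return static_cast<AccT>(1) - dot / (norm_x * norm_y);
  }
};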
