Skip to content

Commit

Permalink
Add fused cosine 1-NN kernel and unify the fused distance 1-NN kernels
Browse files Browse the repository at this point in the history
fix doc issue in fused_distance_nn runtime API
  • Loading branch information
mdoijade committed Feb 1, 2024
1 parent 3c87b92 commit 1417a2e
Show file tree
Hide file tree
Showing 23 changed files with 1,789 additions and 333 deletions.
4 changes: 3 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ endif()

if(RAFT_NVTX)
# This enables NVTX within the project with no option to disable it downstream.
target_link_libraries(raft INTERFACE CUDA::nvToolsExt)
target_link_libraries(raft INTERFACE CUDA::nvtx3)
target_compile_definitions(raft INTERFACE NVTX_ENABLED)
else()
# Allow enable NVTX downstream if not set here. This creates a new option at build/install time,
Expand Down Expand Up @@ -327,6 +327,7 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu
src/distance/distance.cu
src/distance/fused_l2_nn.cu
src/distance/fused_distance_nn.cu
src/linalg/detail/coalesced_reduction.cu
src/matrix/detail/select_k_double_int64_t.cu
src/matrix/detail/select_k_double_uint32_t.cu
Expand Down Expand Up @@ -425,6 +426,7 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/cluster/update_centroids.cuh
src/raft_runtime/cluster/update_centroids_double.cu
src/raft_runtime/cluster/update_centroids_float.cu
src/raft_runtime/distance/fused_distance_min_arg.cu
src/raft_runtime/distance/fused_l2_min_arg.cu
src/raft_runtime/distance/pairwise_distance.cu
src/raft_runtime/matrix/select_k_float_int64_t.cu
Expand Down
12 changes: 2 additions & 10 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,8 @@ if(BUILD_PRIMS_BENCH)
)

ConfigureBench(
NAME
MATRIX_BENCH
PATH
bench/prims/matrix/argmin.cu
bench/prims/matrix/gather.cu
bench/prims/matrix/select_k.cu
bench/prims/matrix/main.cpp
OPTIONAL
LIB
EXPLICIT_INSTANTIATE_ONLY
NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu
bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY
)

ConfigureBench(
Expand Down
89 changes: 89 additions & 0 deletions cpp/include/raft/distance/detail/fused_distance_nn.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstddef> // size_t
#include <limits> // std::numeric_limits
#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/l2_exp.cuh> // ops::l2_exp_distance_op
#include <raft/distance/detail/fused_distance_nn/cutlass_base.cuh>
#include <raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh>
#include <raft/distance/detail/fused_distance_nn/helper_structs.cuh>
#include <raft/distance/detail/fused_distance_nn/simt_kernel.cuh>
#include <raft/distance/detail/pairwise_distance_base.cuh> // PairwiseDistances
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/contractions.cuh> // Policy
#include <raft/util/arch.cuh> // raft::util::arch::SM_*
#include <raft/util/cuda_utils.cuh> // raft::ceildiv, raft::shfl

namespace raft {
namespace distance {

namespace detail {

template <typename DataT,
typename OutT,
typename IdxT,
typename Policy,
typename ReduceOpT,
typename KVPReduceOpT>
void fusedDistanceNNImpl(OutT* min,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
int* workspace,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
bool sqrt,
bool initOutBuffer,
bool isRowMajor,
raft::distance::DistanceType metric,
float metric_arg,
cudaStream_t stream)
{
// The kernel policy is determined by fusedDistanceNN.
typedef Policy P;

dim3 blk(P::Nthreads);
auto nblks = raft::ceildiv<int>(m, P::Nthreads);
constexpr auto maxVal = std::numeric_limits<DataT>::max();
typedef KeyValuePair<IdxT, DataT> KVPair;

RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream));
if (initOutBuffer) {
initKernel<DataT, OutT, IdxT, ReduceOpT>
<<<nblks, P::Nthreads, 0, stream>>>(min, m, maxVal, redOp);
RAFT_CUDA_TRY(cudaGetLastError());
}

switch (metric) {
case DistanceType::CosineExpanded:
fusedCosineNN<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream);
break;
default: assert("only cosine metric is supported with fusedDistanceNN\n"); break;
}
}

} // namespace detail
} // namespace distance
} // namespace raft
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
**************************************************************************************************/

/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -615,6 +615,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase<Shape_,
++tensor_iterator;
}
}
tensor_iterator.dumpToGmem();
}

/// Helper to invoke the output functor over each vector of output
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,6 +37,7 @@
#include <cutlass/matrix_coord.h>
#include <cutlass/tensor_view.h>

#include <cuda/semaphore>
#include <raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh> // FusedDistanceNNEpilogueElementwise
#include <raft/distance/detail/fused_distance_nn/gemm.h> // FusedDistanceNNGemm
#include <raft/util/cudart_utils.hpp> // getMultiProcessorCount
Expand All @@ -46,6 +47,14 @@ namespace raft {
namespace distance {
namespace detail {

template <typename IdxT>
RAFT_KERNEL initBinMutexKernel(cuda::binary_semaphore<cuda::thread_scope_device>* mut, IdxT m)
{
auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x;

if (tid < m) { mut[tid].release(); }
}

template <typename DataT,
typename AccT,
typename OutT,
Expand Down Expand Up @@ -87,8 +96,14 @@ void cutlassFusedDistanceNN(const DataT* x,
KVPReduceOpT>;
constexpr int batch_count = 1;

rmm::device_uvector<cuda::binary_semaphore<cuda::thread_scope_device>> bin_mutex(m, stream);

int blks_ = (m / 256) + 1;

initBinMutexKernel<<<blks_, 256, 0, stream>>>(bin_mutex.data(), m);

typename EpilogueOutputOp::Params epilog_op_param(
dist_op, cg_reduce_op, redOp, pairRedOp, mutexes);
dist_op, cg_reduce_op, redOp, pairRedOp, mutexes, bin_mutex.data());

// Number of pipelines you want to use
constexpr int NumStages = 3;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
*
**************************************************************************************************/
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -62,6 +62,7 @@
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#include <cuda/semaphore>
#include <cutlass/epilogue/thread/activation.h>

/////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -122,6 +123,7 @@ class FusedDistanceNNEpilogueElementwise {
KVPReduceOpT_ pair_redop_;
ReduceOpT_ red_op_;
int* mutexes_;
cuda::binary_semaphore<cuda::thread_scope_device>* bin_mutex_;
using CGReduceT = CGReduceOp_;
//
// Methods
Expand All @@ -131,12 +133,14 @@ class FusedDistanceNNEpilogueElementwise {
CGReduceOp cg_reduce_op,
ReduceOpT_ red_op,
KVPReduceOpT_ pair_redop,
int* mutexes)
int* mutexes,
cuda::binary_semaphore<cuda::thread_scope_device>* bin_mutex)
: cg_reduce_op(cg_reduce_op),
dist_op_(dist_op),
pair_redop_(pair_redop),
red_op_(red_op),
mutexes_(mutexes)
mutexes_(mutexes),
bin_mutex_(bin_mutex)
{
}

Expand Down
135 changes: 135 additions & 0 deletions cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstddef> // size_t
#include <limits> // std::numeric_limits
#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/cosine.cuh> // ops::l2_exp_distance_op
#include <raft/distance/detail/fused_distance_nn/cutlass_base.cuh>
#include <raft/distance/detail/fused_distance_nn/helper_structs.cuh>
#include <raft/distance/detail/fused_distance_nn/simt_kernel.cuh>
#include <raft/distance/detail/pairwise_distance_base.cuh> // PairwiseDistances
#include <raft/linalg/contractions.cuh> // Policy
#include <raft/util/arch.cuh> // raft::util::arch::SM_*
#include <raft/util/cuda_utils.cuh> // raft::ceildiv, raft::shfl

namespace raft {
namespace distance {

namespace detail {

template <typename DataT,
typename OutT,
typename IdxT,
typename Policy,
typename ReduceOpT,
typename KVPReduceOpT>
void fusedCosineNN(OutT* min,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
int* workspace,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
bool sqrt,
cudaStream_t stream)
{
// The kernel policy is determined by fusedL2NN.
typedef Policy P;

dim3 blk(P::Nthreads);
constexpr auto maxVal = std::numeric_limits<DataT>::max();
typedef KeyValuePair<IdxT, DataT> KVPair;

namespace arch = raft::util::arch;
using AccT = DataT;
ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};

raft::identity_op fin_op{};

auto kernel = fusedDistanceNNkernel<DataT,
OutT,
IdxT,
P,
ReduceOpT,
KVPReduceOpT,
decltype(distance_op),
decltype(fin_op)>;

// Get pointer to fp32 SIMT kernel to determine the runtime architecture of the
// current system. Other methods to determine the architecture (that do not
// require a pointer) can be error prone. See:
// https://github.com/NVIDIA/cub/issues/545
void* kernel_ptr = reinterpret_cast<void*>(kernel);
auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr);
auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());

if (cutlass_range.contains(runtime_arch)) {
// If device is SM_80 or later, use CUTLASS-based kernel.
using cosineOp = raft::distance::detail::ops::cosine_cutlass_op<DataT, DataT>;
using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op<DataT, IdxT, OutT>;
kvp_cg_min_reduce_op_ cg_reduce_op;
cosineOp cosine_dist_op;

IdxT lda, ldb, ldd;
lda = k, ldb = k, ldd = n;

cutlassFusedDistanceNN<DataT,
DataT,
OutT,
IdxT,
P::Veclen,
decltype(cg_reduce_op),
decltype(cosine_dist_op),
ReduceOpT,
KVPReduceOpT>(x,
y,
xn,
yn,
m,
n,
k,
lda,
ldb,
ldd,
min,
workspace,
cg_reduce_op,
cosine_dist_op,
redOp,
pairRedOp,
stream);
} else {
// If device less than SM_80, use fp32 SIMT kernel.
constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT));
dim3 grid = launchConfigGenerator<P>(m, n, shmemSize, kernel);

kernel<<<grid, blk, shmemSize, stream>>>(
min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op);
RAFT_CUDA_TRY(cudaGetLastError());
}
}

} // namespace detail
} // namespace distance
} // namespace raft
Loading

0 comments on commit 1417a2e

Please sign in to comment.