Add fused cosine 1-NN cutlass based kernel #2125

Merged 23 commits into branch-24.04 from fusedCosineNN on Mar 19, 2024

Changes from 21 commits

Commits (23)
1417a2e: Add fused cosine 1-NN kernel and unify the fused distance 1-NN kernels (mdoijade, Feb 1, 2024)
9c5592b: Merge branch 'branch-24.04' into fusedCosineNN (mdoijade, Feb 1, 2024)
b69f2dc: Merge branch 'branch-24.04' into fusedCosineNN (mdoijade, Feb 12, 2024)
138eac6: Merge branch 'branch-24.04' into fusedCosineNN (mdoijade, Feb 12, 2024)
42841da: Merge branch 'branch-24.04' into fusedCosineNN_tmp (mdoijade, Feb 16, 2024)
5384408: remove double datatype API, code cleanup and other review comments (mdoijade, Feb 16, 2024)
0ab7a84: fix formatting issues (mdoijade, Feb 16, 2024)
80e1358: Merge branch 'branch-24.04' into fusedCosineNN (mdoijade, Feb 21, 2024)
e9090d6: unify fusedl2nn with fuseddistanceNN, add deprecation warning for fus… (mdoijade, Feb 22, 2024)
5a48625: expose fused_distance_nn in pylibraft and add unit test for it with a… (mdoijade, Feb 23, 2024)
63261c6: correct the description for fused distance nn arg min pylibraft API (mdoijade, Feb 23, 2024)
65d0a95: Merge branch 'branch-24.04' into fusedCosineNN (mdoijade, Feb 23, 2024)
f53e439: fix the fused_l2_nn header name in masked_nn.cuh (mdoijade, Feb 23, 2024)
958a2b3: fix copyright year in masked_nn.cuh (mdoijade, Feb 23, 2024)
7a0a6db: fix copyright year for newly added source files (mdoijade, Feb 28, 2024)
24f38dd: Merge branch 'branch-24.04' into fusedCosineNN (mdoijade, Feb 28, 2024)
f4974db: fix fused_distance_nn.pyx file permission to be rw instead of rwx (mdoijade, Feb 29, 2024)
f02902c: merge branch-24.04 and resolve conflicts (mdoijade, Mar 7, 2024)
7571459: fix clang formatting issues (mdoijade, Mar 7, 2024)
b8c33d7: Merge branch 'branch-24.04' into fusedCosineNN (tfeher, Mar 15, 2024)
78f53e8: Merge branch 'branch-24.04' into fusedCosineNN (mdoijade, Mar 16, 2024)
9f63c56: Update python/pylibraft/pylibraft/distance/fused_distance_nn.pyx (mdoijade, Mar 18, 2024)
a0632f2: Merge branch 'branch-24.04' into fusedCosineNN (mdoijade, Mar 18, 2024)
4 changes: 3 additions & 1 deletion cpp/CMakeLists.txt
@@ -256,7 +256,7 @@ endif()

if(RAFT_NVTX)
# This enables NVTX within the project with no option to disable it downstream.
- target_link_libraries(raft INTERFACE CUDA::nvToolsExt)
+ target_link_libraries(raft INTERFACE CUDA::nvtx3)
target_compile_definitions(raft INTERFACE NVTX_ENABLED)
else()
# Allow enable NVTX downstream if not set here. This creates a new option at build/install time,
@@ -324,6 +324,7 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu
src/distance/distance.cu
src/distance/fused_l2_nn.cu
+ src/distance/fused_distance_nn.cu
src/linalg/detail/coalesced_reduction.cu
src/matrix/detail/select_k_double_int64_t.cu
src/matrix/detail/select_k_double_uint32_t.cu
@@ -422,6 +423,7 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/cluster/update_centroids.cuh
src/raft_runtime/cluster/update_centroids_double.cu
src/raft_runtime/cluster/update_centroids_float.cu
+ src/raft_runtime/distance/fused_distance_min_arg.cu
src/raft_runtime/distance/fused_l2_min_arg.cu
src/raft_runtime/distance/pairwise_distance.cu
src/raft_runtime/matrix/select_k_float_int64_t.cu
12 changes: 2 additions & 10 deletions cpp/bench/prims/CMakeLists.txt
@@ -122,16 +122,8 @@ if(BUILD_PRIMS_BENCH)
)

ConfigureBench(
-   NAME
-   MATRIX_BENCH
-   PATH
-   bench/prims/matrix/argmin.cu
-   bench/prims/matrix/gather.cu
-   bench/prims/matrix/select_k.cu
-   bench/prims/matrix/main.cpp
-   OPTIONAL
-   LIB
-   EXPLICIT_INSTANTIATE_ONLY
+   NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu
+   bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY
)

ConfigureBench(
97 changes: 97 additions & 0 deletions cpp/include/raft/distance/detail/fused_distance_nn.cuh
@@ -0,0 +1,97 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/l2_exp.cuh> // ops::l2_exp_distance_op
#include <raft/distance/detail/fused_distance_nn/cutlass_base.cuh>
#include <raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh>
#include <raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh>
#include <raft/distance/detail/fused_distance_nn/helper_structs.cuh>
#include <raft/distance/detail/fused_distance_nn/simt_kernel.cuh>
#include <raft/distance/detail/pairwise_distance_base.cuh> // PairwiseDistances
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/contractions.cuh> // Policy
#include <raft/util/arch.cuh> // raft::util::arch::SM_*
#include <raft/util/cuda_utils.cuh> // raft::ceildiv, raft::shfl

#include <cstddef> // size_t
#include <limits> // std::numeric_limits

namespace raft {
namespace distance {

namespace detail {

template <typename DataT,
typename OutT,
typename IdxT,
typename Policy,
typename ReduceOpT,
typename KVPReduceOpT>
void fusedDistanceNNImpl(OutT* min,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
int* workspace,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
bool sqrt,
bool initOutBuffer,
bool isRowMajor,
raft::distance::DistanceType metric,
float metric_arg,
cudaStream_t stream)
{
// The kernel policy is determined by fusedDistanceNN.
typedef Policy P;

dim3 blk(P::Nthreads);
auto nblks = raft::ceildiv<int>(m, P::Nthreads);
constexpr auto maxVal = std::numeric_limits<DataT>::max();
typedef KeyValuePair<IdxT, DataT> KVPair;

RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream));
if (initOutBuffer) {
initKernel<DataT, OutT, IdxT, ReduceOpT>
<<<nblks, P::Nthreads, 0, stream>>>(min, m, maxVal, redOp);
RAFT_CUDA_TRY(cudaGetLastError());
}

switch (metric) {
case raft::distance::DistanceType::CosineExpanded:
fusedCosineNN<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Expanded:
      // initOutBuffer is taken care of by fusedDistanceNNImpl(), so we pass false to fusedL2NNImpl.
fusedL2NNImpl<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream);
break;
    default: assert(false && "only cosine/l2 metric is supported with fusedDistanceNN"); break;
}
}

} // namespace detail
} // namespace distance
} // namespace raft
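
For orientation, here is what the cosine branch above computes, written as a naive host-side reference. This is an illustrative sketch only (it assumes row-major inputs and that xn/yn hold the L2 norms of the rows of x and y, which is what the expanded cosine formulation needs); the fused kernel produces the same per-row result without materializing the m x n distance matrix:

```cuda
// Naive reference for cosine 1-NN: for each row of x, find the index and
// distance of the nearest row of y. KVP mirrors raft::KeyValuePair<IdxT, DataT>;
// all names in this sketch are illustrative, not the library's API.
#include <limits>

struct KVP {
  int key;      // index of the nearest row of y
  float value;  // its cosine distance
};

void cosine_1nn_reference(KVP* min, const float* x, const float* y,
                          const float* xn, const float* yn, int m, int n, int k)
{
  for (int i = 0; i < m; ++i) {
    KVP best{-1, std::numeric_limits<float>::max()};
    for (int j = 0; j < n; ++j) {
      float dot = 0.f;
      for (int l = 0; l < k; ++l) { dot += x[i * k + l] * y[j * k + l]; }
      float dist = 1.f - dot / (xn[i] * yn[j]);  // CosineExpanded distance
      if (dist < best.value) { best = KVP{j, dist}; }
    }
    min[i] = best;
  }
}
```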
@@ -611,6 +611,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase<Shape_,
++tensor_iterator;
}
}
+   tensor_iterator.dumpToGmem();
}

/// Helper to invoke the output functor over each vector of output
@@ -33,6 +33,8 @@

#include <rmm/device_uvector.hpp>

+ #include <cuda/semaphore>

#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_grouped.h>
@@ -46,6 +48,14 @@ namespace raft {
namespace distance {
namespace detail {

+ template <typename IdxT>
+ RAFT_KERNEL initBinMutexKernel(cuda::binary_semaphore<cuda::thread_scope_device>* mut, IdxT m)
+ {
+   auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x;
+
+   if (tid < m) { mut[tid].release(); }
+ }

template <typename DataT,
typename AccT,
typename OutT,
@@ -87,8 +97,14 @@ void cutlassFusedDistanceNN(const DataT* x,
KVPReduceOpT>;
constexpr int batch_count = 1;

+   rmm::device_uvector<cuda::binary_semaphore<cuda::thread_scope_device>> bin_mutex(m, stream);
+
+   int blks_ = (m / 256) + 1;
+
+   initBinMutexKernel<<<blks_, 256, 0, stream>>>(bin_mutex.data(), m);

typename EpilogueOutputOp::Params epilog_op_param(
-     dist_op, cg_reduce_op, redOp, pairRedOp, mutexes);
+     dist_op, cg_reduce_op, redOp, pairRedOp, mutexes, bin_mutex.data());

// Number of pipelines you want to use
constexpr int NumStages = 3;
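
The bin_mutex array above holds one device-scope binary semaphore per output row. Since rmm::device_uvector allocates uninitialized device memory, initBinMutexKernel calls release() on each element once so every row's lock starts out available. A minimal sketch of the locking pattern this enables (the function and struct names here are hypothetical; the PR's actual per-row reduction lives in the CUTLASS epilogue, which this hunk does not show):

```cuda
// Sketch: serialize updates to one output row across thread blocks using a
// per-row device-scope binary semaphore.
#include <cuda/semaphore>

struct KVP { int key; float value; };  // stand-in for raft::KeyValuePair
using RowMutex = cuda::binary_semaphore<cuda::thread_scope_device>;

__device__ void commit_row_min(RowMutex* mutexes, KVP* out, int row, KVP partial)
{
  mutexes[row].acquire();  // wait until this row's lock is available
  if (partial.value < out[row].value) { out[row] = partial; }
  mutexes[row].release();  // hand the lock to the next contender
}
```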
@@ -56,6 +56,8 @@

#pragma once

+ #include <cuda/semaphore>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
@@ -121,6 +123,7 @@ class FusedDistanceNNEpilogueElementwise {
KVPReduceOpT_ pair_redop_;
ReduceOpT_ red_op_;
int* mutexes_;
+   cuda::binary_semaphore<cuda::thread_scope_device>* bin_mutex_;
using CGReduceT = CGReduceOp_;
//
// Methods
@@ -130,12 +133,14 @@ class FusedDistanceNNEpilogueElementwise {
CGReduceOp cg_reduce_op,
ReduceOpT_ red_op,
KVPReduceOpT_ pair_redop,
-     int* mutexes)
+     int* mutexes,
+     cuda::binary_semaphore<cuda::thread_scope_device>* bin_mutex)
: cg_reduce_op(cg_reduce_op),
dist_op_(dist_op),
pair_redop_(pair_redop),
red_op_(red_op),
-     mutexes_(mutexes)
+     mutexes_(mutexes),
+     bin_mutex_(bin_mutex)
{
}

@@ -0,0 +1,136 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/cosine.cuh> // ops::cosine_distance_op
#include <raft/distance/detail/fused_distance_nn/cutlass_base.cuh>
#include <raft/distance/detail/fused_distance_nn/helper_structs.cuh>
#include <raft/distance/detail/fused_distance_nn/simt_kernel.cuh>
#include <raft/distance/detail/pairwise_distance_base.cuh> // PairwiseDistances
#include <raft/linalg/contractions.cuh> // Policy
#include <raft/util/arch.cuh> // raft::util::arch::SM_*
#include <raft/util/cuda_utils.cuh> // raft::ceildiv, raft::shfl

#include <cstddef> // size_t
#include <limits> // std::numeric_limits

namespace raft {
namespace distance {

namespace detail {

template <typename DataT,
typename OutT,
typename IdxT,
typename Policy,
typename ReduceOpT,
typename KVPReduceOpT>
void fusedCosineNN(OutT* min,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
int* workspace,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
bool sqrt,
cudaStream_t stream)
{
  // The kernel policy is determined by fusedDistanceNN.
typedef Policy P;

dim3 blk(P::Nthreads);
constexpr auto maxVal = std::numeric_limits<DataT>::max();
typedef KeyValuePair<IdxT, DataT> KVPair;

namespace arch = raft::util::arch;
using AccT = DataT;
ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};

raft::identity_op fin_op{};

auto kernel = fusedDistanceNNkernel<DataT,
OutT,
IdxT,
P,
ReduceOpT,
KVPReduceOpT,
decltype(distance_op),
decltype(fin_op)>;

// Get pointer to fp32 SIMT kernel to determine the runtime architecture of the
// current system. Other methods to determine the architecture (that do not
// require a pointer) can be error prone. See:
// https://github.com/NVIDIA/cub/issues/545
void* kernel_ptr = reinterpret_cast<void*>(kernel);
auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr);
auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());

if (cutlass_range.contains(runtime_arch)) {
// If device is SM_80 or later, use CUTLASS-based kernel.
using cosineOp = raft::distance::detail::ops::cosine_cutlass_op<DataT, DataT>;
using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op<DataT, IdxT, OutT>;
kvp_cg_min_reduce_op_ cg_reduce_op;
cosineOp cosine_dist_op;

IdxT lda, ldb, ldd;
lda = k, ldb = k, ldd = n;

cutlassFusedDistanceNN<DataT,
DataT,
OutT,
IdxT,
P::Veclen,
decltype(cg_reduce_op),
decltype(cosine_dist_op),
ReduceOpT,
KVPReduceOpT>(x,
y,
xn,
yn,
m,
n,
k,
lda,
ldb,
ldd,
min,
workspace,
cg_reduce_op,
cosine_dist_op,
redOp,
pairRedOp,
stream);
} else {
    // If the device is older than SM_80, fall back to the fp32 SIMT kernel.
constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT));
dim3 grid = launchConfigGenerator<P>(m, n, shmemSize, kernel);

kernel<<<grid, blk, shmemSize, stream>>>(
min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op);
RAFT_CUDA_TRY(cudaGetLastError());
}
}

} // namespace detail
} // namespace distance
} // namespace raft
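
As an aside on why the cosine kernel can share so much machinery with the fused L2 kernel: on unit-normalized rows the two metrics pick the same nearest neighbor, since

```math
\left\lVert \frac{x}{\lVert x \rVert} - \frac{y}{\lVert y \rVert} \right\rVert_2^2
  = 2 - 2\,\frac{\langle x, y \rangle}{\lVert x \rVert \, \lVert y \rVert}
  = 2\, d_{\cos}(x, y),
```

so the argmin under the cosine distance matches the argmin under expanded L2 on normalized inputs, and both reduce to the same fused GEMM-plus-min-reduction structure over x, y, and the precomputed norms xn, yn.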