From d796a4759c7bf2394e91e8a1f4bb49f0f59ebe0a Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 1 Feb 2024 02:01:36 -0800 Subject: [PATCH] Add fused cosine 1-NN kernel and unify the fused distance 1-NN kernels --- cpp/CMakeLists.txt | 4 +- cpp/bench/prims/CMakeLists.txt | 12 +- .../distance/detail/fused_distance_nn.cuh | 89 ++++ .../custom_epilogue_with_broadcast.h | 3 +- .../detail/fused_distance_nn/cutlass_base.cuh | 19 +- .../epilogue_elementwise.cuh | 10 +- .../fused_distance_nn/fused_cosine_nn.cuh | 135 ++++++ .../fused_distance_nn/helper_structs.cuh | 145 ++++++ .../fused_distance_nn/persistent_gemm.h | 7 +- .../predicated_tile_iterator_reduced_vec.h | 101 +++-- .../detail/fused_distance_nn/simt_kernel.cuh | 186 ++++++++ .../raft/distance/detail/fused_l2_nn.cuh | 262 +---------- .../raft/distance/fused_distance_nn-ext.cuh | 91 ++++ .../raft/distance/fused_distance_nn-inl.cuh | 325 ++++++++++++++ .../raft/distance/fused_distance_nn.cuh | 24 + ...pers.cuh => fused_distance_nn_helpers.cuh} | 5 +- cpp/include/raft/distance/fused_l2_nn-ext.cuh | 12 +- cpp/include/raft/distance/fused_l2_nn-inl.cuh | 4 +- .../distance/fused_distance_nn.hpp | 71 +++ cpp/src/distance/fused_distance_nn.cu | 60 +++ .../distance/fused_distance_min_arg.cu | 137 ++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/fused_cosine_nn.cu | 416 ++++++++++++++++++ 23 files changed, 1786 insertions(+), 333 deletions(-) create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh create mode 100644 cpp/include/raft/distance/fused_distance_nn-ext.cuh create mode 100644 cpp/include/raft/distance/fused_distance_nn-inl.cuh create mode 100755 cpp/include/raft/distance/fused_distance_nn.cuh rename cpp/include/raft/distance/{fused_l2_nn_helpers.cuh => fused_distance_nn_helpers.cuh} (89%) create mode 100644 cpp/include/raft_runtime/distance/fused_distance_nn.hpp create mode 100644 cpp/src/distance/fused_distance_nn.cu create mode 100644 cpp/src/raft_runtime/distance/fused_distance_min_arg.cu create mode 100644 cpp/test/distance/fused_cosine_nn.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 650bc1a059..61ebbe9978 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -259,7 +259,7 @@ endif() if(RAFT_NVTX) # This enables NVTX within the project with no option to disable it downstream. - target_link_libraries(raft INTERFACE CUDA::nvToolsExt) + target_link_libraries(raft INTERFACE CUDA::nvtx3) target_compile_definitions(raft INTERFACE NVTX_ENABLED) else() # Allow enable NVTX downstream if not set here. 
This creates a new option at build/install time, @@ -327,6 +327,7 @@ if(RAFT_COMPILE_LIBRARY) src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu src/distance/distance.cu src/distance/fused_l2_nn.cu + src/distance/fused_distance_nn.cu src/linalg/detail/coalesced_reduction.cu src/matrix/detail/select_k_double_int64_t.cu src/matrix/detail/select_k_double_uint32_t.cu @@ -425,6 +426,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/cluster/update_centroids.cuh src/raft_runtime/cluster/update_centroids_double.cu src/raft_runtime/cluster/update_centroids_float.cu + src/raft_runtime/distance/fused_distance_min_arg.cu src/raft_runtime/distance/fused_l2_min_arg.cu src/raft_runtime/distance/pairwise_distance.cu src/raft_runtime/matrix/select_k_float_int64_t.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 3a2431cd34..d031431946 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -122,16 +122,8 @@ if(BUILD_PRIMS_BENCH) ) ConfigureBench( - NAME - MATRIX_BENCH - PATH - bench/prims/matrix/argmin.cu - bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu - bench/prims/matrix/main.cpp - OPTIONAL - LIB - EXPLICIT_INSTANTIATE_ONLY + NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu + bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( diff --git a/cpp/include/raft/distance/detail/fused_distance_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn.cuh new file mode 100644 index 0000000000..94f199275d --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn.cuh @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include +#include // PairwiseDistances +#include +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedDistanceNNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedDistanceNN. 
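For orientation while reading fusedDistanceNNImpl, here is a hedged caller-side sketch of the public wrapper this detail function backs (fusedDistanceNNMinReduce, added later in this patch); the implementation resumes with the policy typedef below. The buffer names are illustrative assumptions; for CosineExpanded, the xn/yn arguments are the square-rooted L2 row norms, as computed by the runtime wrapper elsewhere in this patch.

#include <raft/core/kvp.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/fused_distance_nn.cuh>

// Hypothetical caller; every pointer is device memory.
void cosine_1nn_sketch(raft::KeyValuePair<int, float>* min,  // m reduced results
                       const float* x,                       // m x k, row major
                       const float* y,                       // n x k, row major
                       const float* xn,                      // L2 row norms of x, length m
                       const float* yn,                      // L2 row norms of y, length n
                       int* workspace,                       // scratch, m ints
                       int m, int n, int k, cudaStream_t stream)
{
  raft::distance::fusedDistanceNNMinReduce<float, raft::KeyValuePair<int, float>, int>(
    min, x, y, xn, yn, m, n, k, static_cast<void*>(workspace),
    /*sqrt=*/false,
    /*initOutBuffer=*/true,
    /*isRowMajor=*/true,
    raft::distance::DistanceType::CosineExpanded,
    /*metric_arg=*/0.0f,
    stream);
}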
+ typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + switch (metric) { + case DistanceType::CosineExpanded: + fusedCosineNN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + break; + default: assert("only cosine metric is supported with fusedDistanceNN\n"); break; + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index f659ed256d..ac20578083 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -30,7 +30,7 @@ **************************************************************************************************/ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -615,6 +615,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase #include +#include #include // FusedDistanceNNEpilogueElementwise #include // FusedDistanceNNGemm #include // getMultiProcessorCount @@ -46,6 +47,14 @@ namespace raft { namespace distance { namespace detail { +template +RAFT_KERNEL initBinMutexKernel(cuda::binary_semaphore* mut, IdxT m) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + + if (tid < m) { mut[tid].release(); } +} + template ; constexpr int batch_count = 1; + rmm::device_uvector> bin_mutex(m, stream); + + int blks_ = (m / 256) + 1; + + initBinMutexKernel<<>>(bin_mutex.data(), m); + typename EpilogueOutputOp::Params epilog_op_param( - dist_op, cg_reduce_op, redOp, pairRedOp, mutexes); + dist_op, cg_reduce_op, redOp, pairRedOp, mutexes, bin_mutex.data()); // Number of pipelines you want to use constexpr int NumStages = 3; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh index a21f3d60e0..d65d2df4a4 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -29,7 +29,7 @@ * **************************************************************************************************/ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -62,6 +62,7 @@ #include #include +#include #include ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -122,6 +123,7 @@ class FusedDistanceNNEpilogueElementwise { KVPReduceOpT_ pair_redop_; ReduceOpT_ red_op_; int* mutexes_; + cuda::binary_semaphore* bin_mutex_; using CGReduceT = CGReduceOp_; // // Methods @@ -131,12 +133,14 @@ class FusedDistanceNNEpilogueElementwise { CGReduceOp cg_reduce_op, ReduceOpT_ red_op, KVPReduceOpT_ pair_redop, - int* mutexes) + int* mutexes, + cuda::binary_semaphore* bin_mutex) : cg_reduce_op(cg_reduce_op), dist_op_(dist_op), pair_redop_(pair_redop), red_op_(red_op), - mutexes_(mutexes) + mutexes_(mutexes), + bin_mutex_(bin_mutex) { } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh new file mode 100644 index 0000000000..e86db734a5 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedCosineNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::cosine_distance_op distance_op{}; + + raft::identity_op fin_op{}; + + auto kernel = fusedDistanceNNkernel; + + // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. 
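Before the CUTLASS arguments are assembled below, a note on this gate: the runtime-architecture probe follows the same pattern fusedL2NN already uses. A minimal standalone sketch, assuming only the raft::util::arch helpers this file includes:

#include <raft/util/arch.cuh>

__global__ void dispatch_probe_kernel() {}

// True when the virtual architecture the runtime selects for this translation
// unit is SM80 (Ampere) or newer, i.e. the CUTLASS-based kernel applies.
inline bool use_cutlass_path()
{
  namespace arch = raft::util::arch;
  void* kernel_ptr   = reinterpret_cast<void*>(dispatch_probe_kernel);
  auto runtime_arch  = arch::kernel_virtual_arch(kernel_ptr);
  auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());
  return cutlass_range.contains(runtime_arch);
}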
+ using cosineOp = raft::distance::detail::ops::cosine_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + cosineOp cosine_dist_op; + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + cosine_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh new file mode 100644 index 0000000000..e88ea9cfc8 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl +#include + +namespace raft { +namespace distance { + +namespace detail { + +template +struct KVPMinReduceImpl { + typedef raft::KeyValuePair KVP; + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? 
b : a; } + +}; // KVPMinReduce + +template +struct MinAndDistanceReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + + DI void operator()(LabelT rid, KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + + DI void operator()(LabelT rid, DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } + DI void init(KVP* out, DataT maxVal) const + { + out->value = maxVal; + out->key = 0xfffffff0; + } + + DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } + + DI DataT get_value(KVP& out) const { return out.value; } + DI DataT get_value(DataT& out) const { return out; } +}; + +template +struct MinReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + DI void operator()(LabelT rid, DataT* out, const KVP& other) + { + if (other.value < *out) { *out = other.value; } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } +}; + +template +RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { redOp.init(min + tid, maxVal); } +} + +template +void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) +{ + auto blks = raft::ceildiv(m, 256); + initKernel<<>>(min, m, maxVal, redOp); +} + +// cg::reduce functor for FusedDistanceNN used in its cutlass version +// to output the min distance value & key(loc id). +// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h +// store_with_byte_offset() passed to cg::reduce() & select_reduce. +template +struct kvp_cg_min_reduce_op { + typedef typename raft::KeyValuePair KVP; + + __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; + + using AccTypeT = AccType; + using IndexT = Index; + // functor signature. + __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } + + __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } + + __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } +}; + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 3a8d6c8655..a04fe36b79 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -29,7 +29,7 @@ * **************************************************************************************************/ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -181,8 +181,7 @@ struct FusedDistanceNNPersistent { /// Default ctor CUTLASS_HOST_DEVICE Arguments() - : // problem_count(0), - threadblock_count(0), + : threadblock_count(0), ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), @@ -206,6 +205,7 @@ struct FusedDistanceNNPersistent { void const* ptr_B, void const* ptr_C, void* ptr_Vector, + // volatile void* ptr_Tensor, void* ptr_Tensor, typename LayoutA::Stride::Index lda, typename LayoutB::Stride::Index ldb, @@ -236,7 +236,6 @@ struct FusedDistanceNNPersistent { /// Parameters structure struct Params { - // typename ProblemVisitor::Params problem_visitor; temp_problem_visitor problem_visitor; int threadblock_count; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index dc224c5c96..4591fa7855 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -29,7 +29,7 @@ * **************************************************************************************************/ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -322,9 +322,10 @@ class PredicatedTileIteratorReducedVec { Params params_; /// Byte-level pointer - uint8_t* byte_pointer_; + // uint8_t* byte_pointer_; /// Byte-level pointer first tile offset of this threadblock. 
- uint8_t* first_tile_byte_pointer_; + volatile uint8_t* first_tile_byte_pointer_; + // uint8_t* first_tile_byte_pointer_; /// Array of boolean values to contain steady-state predicates Mask mask_; @@ -349,6 +350,8 @@ class PredicatedTileIteratorReducedVec { /// Scatter indices int const* indices_; + const int do_gmem_reduction_; + // // Static asserts about internal strides // @@ -359,7 +362,6 @@ class PredicatedTileIteratorReducedVec { protected: SharedStorage& shared_storage_; - const bool& do_gmem_reduction_; private: // @@ -373,10 +375,10 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE PredicatedTileIteratorReducedVec(SharedStorage& shared_storage, Params const& params, - Element* pointer, + volatile Element* pointer, TensorCoord extent, int thread_idx, - const bool& do_gmem_reduction, + const bool do_gmem_reduction, TensorCoord threadblock_offset = TensorCoord(), int const* indices = nullptr) : params_(params), @@ -408,6 +410,7 @@ class PredicatedTileIteratorReducedVec { EpilogueOpParams const& user_params = params_.user_param; shared_storage_.initSmem(user_params); } + __syncthreads(); // Null pointer performs no accesses if (!pointer) { mask_.clear(); } @@ -415,65 +418,61 @@ class PredicatedTileIteratorReducedVec { if (ScatterD && !indices) { mask_.clear(); } // Initialize pointer - first_tile_byte_pointer_ = reinterpret_cast(pointer) + + first_tile_byte_pointer_ = reinterpret_cast(pointer) + LongIndex(block_offset.row()) * LongIndex(params_.stride); - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - + // first_tile_byte_pointer_ = reinterpret_cast(pointer) + + // LongIndex(block_offset.row()) * LongIndex(params_.stride); // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; } - /// Destructor - CUTLASS_DEVICE - ~PredicatedTileIteratorReducedVec() + CUTLASS_DEVICE void dumpToGmem() { + if (block_start_row_first_tile_ >= extent_row_) { return; } + if (do_gmem_reduction_) { EpilogueOpParams const& user_params = params_.user_param; - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - Element* shared_elem_arr = shared_storage_.data(); const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); - bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); - // If this is not optimal grid size perform mutex based gmem reduce. - if (useGmemMutex) { - // single lock per block for multiple rows - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // acquire mutex lock. 
- unsigned int ns = 8; - while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) { - __nanosleep(ns); - if (ns < 256) { ns *= 2; } - } + const bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); + int row = threadIdx.x; + Element* shared_elem_arr = shared_storage_.data(); + Element row_local_min; + if (row < total_rows) { row_local_min = shared_elem_arr[row]; } + + // single lock per block for multiple rows + if (useGmemMutex && threadIdx.x == 0) { user_params.bin_mutex_[mutex_id].acquire(); } + __syncthreads(); + + if (row < total_rows) { + volatile Element* gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + if ((block_start_row_first_tile_ + row) < extent_row_) { + user_params.red_op_(block_start_row_first_tile_ + row, (gmem_ptr + row), row_local_min); } } __syncthreads(); - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_( - block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); - } - } + __threadfence(); - if (useGmemMutex) { - __threadfence(); - __syncthreads(); - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // release mutex lock. - atomicExch(user_params.mutexes_ + mutex_id, 0); - } + if (useGmemMutex && (threadIdx.x == 0)) { + // release mutex lock. + user_params.bin_mutex_[mutex_id].release(); } + shared_storage_.initSmem(user_params); + __syncthreads(); } } + /// Destructor + CUTLASS_DEVICE + ~PredicatedTileIteratorReducedVec() {} + /// Adds a pointer offset in units of Element CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset) { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + // byte_pointer_ += pointer_offset * sizeof_bits::value / 8; } /// Performs reduction and Stores a reduced output to memory @@ -514,9 +513,6 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_.init(&red_val, maxVal); if (row_guard) { - const int iter_row = (row_id % total_rows); - const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); - CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++column) { @@ -535,6 +531,10 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_(row_id, &red_val, this_val); } } + } + const int iter_row = (row_id % total_rows); + const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); + if (row_guard) { // select_reduce doesn't need to use `red_op_` as at the warp level we use cg_reduce_op, // this satisfies the requirement of mst/single linkage of checking colors buffer. 
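Before continuing with the select_reduce call below: the dumpToGmem() path above swaps the old atomicCAS spin lock for one cuda::binary_semaphore per output tile. A minimal standalone sketch of that guarded read-modify-write, with hypothetical names (the semaphores must start released, which is what initBinMutexKernel in cutlass_base.cuh arranges):

#include <cuda/semaphore>

using BinSem = cuda::binary_semaphore<cuda::thread_scope_device>;

// Each threadblock merges its per-row partial minima into global memory under
// a per-tile semaphore, mirroring dumpToGmem() above.
__global__ void guarded_merge(
  BinSem* mutexes, float* gmem_min, const float* block_min, int rows_per_tile, int m)
{
  const int tile_id = blockIdx.x;  // stand-in for block_start_row / total_rows
  if (threadIdx.x == 0) { mutexes[tile_id].acquire(); }
  __syncthreads();

  const int row = tile_id * rows_per_tile + threadIdx.x;
  if (threadIdx.x < rows_per_tile && row < m) {
    gmem_min[row] = fminf(gmem_min[row], block_min[threadIdx.x]);  // stand-in for red_op_
  }

  __syncthreads();
  __threadfence();  // publish the updates before dropping the lock
  if (threadIdx.x == 0) { mutexes[tile_id].release(); }
}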
select_reduce red_obj( @@ -543,6 +543,7 @@ class PredicatedTileIteratorReducedVec { } } } + __syncthreads(); } /// Stores a fragment to memory @@ -573,15 +574,14 @@ class PredicatedTileIteratorReducedVec { PredicatedTileIteratorReducedVec& operator++() { ++state_[0]; - - if (!ScatterD) { byte_pointer_ += params_.advance_row; } + // if (!ScatterD) { byte_pointer_ += params_.advance_row; } thread_start_row_ += ThreadMap::Shape::kRow; if (state_[0] == ThreadMap::Count::kRow) { state_[0] = 0; ++state_[1]; - byte_pointer_ += params_.advance_group; + // byte_pointer_ += params_.advance_group; thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; @@ -589,18 +589,17 @@ class PredicatedTileIteratorReducedVec { if (state_[1] == ThreadMap::Count::kGroup) { state_[1] = 0; ++state_[2]; - byte_pointer_ += params_.advance_cluster; + // byte_pointer_ += params_.advance_cluster; thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; - byte_pointer_ += params_.advance_tile; + // byte_pointer_ += params_.advance_tile; } } } - return *this; } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh new file mode 100644 index 0000000000..f5e4c725d6 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy + +namespace raft { +namespace distance { +namespace detail { + +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal( + int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) +{ + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; + + // Update each output row in order within a warp. 
This will resolve hang + // issues with pre-Volta architectures +#pragma unroll + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { + if (lid == j * P::AccThCols) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + i * P::AccThRows; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + } +} + +template +__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedDistanceNNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + OpT distance_op, + FinalLambda fin_op) +{ +// compile only if below non-ampere arch. +#if __CUDA_ARCH__ < 800 + extern __shared__ char smem[]; + + typedef KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, + // but the shfl op applies the modulo internally. + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, gridStrideY); + + // reset the val array. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + constexpr bool row_major = true; + constexpr bool write_out = false; + PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + nullptr, // Output pointer + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +#endif +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 2468dcd740..75275d40b3 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,8 @@ #include // raft::identity_op #include // ops::l2_exp_distance_op #include +#include +#include #include // PairwiseDistances #include // Policy #include // raft::util::arch::SM_* @@ -32,248 +34,6 @@ namespace distance { namespace detail { -template -struct KVPMinReduceImpl { - typedef raft::KeyValuePair KVP; - DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - -}; // KVPMinReduce - -template -struct MinAndDistanceReduceOpImpl { - typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, KVP* out, const KVP& other) const - { - if (other.value < out->value) { - out->key = other.key; - out->value = other.value; - } - } - - DI void operator()(LabelT rid, DataT* out, const KVP& other) const - { - if (other.value < *out) { *out = other.value; } - } - - DI void operator()(LabelT rid, DataT* out, const DataT& other) const - { - if (other < *out) { *out = other; } - } - - DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } - DI void init(KVP* out, DataT maxVal) const { out->value = maxVal; } - - DI void init_key(DataT& out, LabelT idx) const { return; } - DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } - - DI DataT get_value(KVP& out) const - { - return out.value; - ; - } - DI DataT get_value(DataT& out) const { return out; } -}; - -template -struct MinReduceOpImpl { - typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, DataT* out, const KVP& other) - { - if (other.value < *out) { *out = other.value; } - } - - DI void init(DataT* out, DataT maxVal) { *out = maxVal; } -}; - -template -RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; - if (tid < m) { redOp.init(min + tid, maxVal); } -} - -template -void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) -{ - auto blks = raft::ceildiv(m, 256); - initKernel<<>>(min, m, maxVal, redOp); -} - -// TODO: specialize this function for MinAndDistanceReduceOp -// with atomicCAS of 64 bit which will eliminate mutex and shfls -template -DI void updateReducedVal( - int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) -{ - const auto lid = threadIdx.x % raft::WarpSize; - const auto accrowid = threadIdx.x / P::AccThCols; - - // Update each output row in order within a warp. This will resolve hang - // issues with pre-Volta architectures -#pragma unroll - for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { - if (lid == j * P::AccThCols) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rid = gridStrideY + accrowid + i * P::AccThRows; - if (rid < m) { - auto value = val[i]; - while (atomicCAS(mutex + rid, 0, 1) == 1) - ; - __threadfence(); - red_op(rid, min + rid, value); - __threadfence(); - atomicCAS(mutex + rid, 1, 0); - } - } - } - } -} - -template -__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedL2NNkernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - OpT distance_op, - FinalLambda fin_op) -{ -// compile only if below non-ampere arch. 
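That guard, which reappears verbatim in the new simt_kernel.cuh, works as in the minimal illustrative sketch here (the removed block continues below): the kernel symbol exists for every device pass, but its body is compiled only below SM80; on SM80 and newer the CUTLASS path is taken at runtime and the kernel reduces to an empty stub.

__global__ void pre_ampere_only(float* out, int n)
{
#if __CUDA_ARCH__ < 800
  // pre-SM80 fallback body
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { out[i] = 0.0f; }
#endif
}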
-#if __CUDA_ARCH__ < 800 - extern __shared__ char smem[]; - - typedef KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < n) { - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - } - }; - - auto rowEpilog_lambda = - [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, - // but the shfl op applies the modulo internally. - auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); - auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - - updateReducedVal(mutex, min, val, red_op, m, gridStrideY); - - // reset the val array. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - }; - - IdxT lda = k, ldb = k, ldd = n; - constexpr bool row_major = true; - constexpr bool write_out = false; - PairwiseDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - xn, - yn, - nullptr, // Output pointer - smem, - distance_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -#endif -} - -// cg::reduce functor for FusedDistanceNN used in its cutlass version -// to output the min distance value & key(loc id). -// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h -// store_with_byte_offset() passed to cg::reduce() & select_reduce. -template -struct kvp_cg_min_reduce_op { - typedef typename raft::KeyValuePair KVP; - - __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; - - using AccTypeT = AccType; - using IndexT = Index; - // functor signature. - __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } - - __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } - - __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } -}; - template ; + auto kernel = fusedDistanceNNkernel; // Get pointer to fp32 SIMT kernel to determine the best compute architecture // out of all for which the kernel was compiled for that matches closely diff --git a/cpp/include/raft/distance/fused_distance_nn-ext.cuh b/cpp/include/raft/distance/fused_distance_nn-ext.cuh new file mode 100644 index 0000000000..9dd236a3bd --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn-ext.cuh @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t +#include // raft::KeyValuePair +#include // raft::resources +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void fusedDistanceNNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) RAFT_EXPLICIT; + +} // namespace distance +} // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ + extern template void raft::distance::fusedDistanceNNMinReduce( \ + OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + bool isRowMajor, \ + raft::distance::DistanceType metric, \ + float metric_arg, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedDistanceNNMinReduce(double, + raft::KeyValuePair, + int); +instantiate_raft_distance_fusedDistanceNNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/include/raft/distance/fused_distance_nn-inl.cuh b/cpp/include/raft/distance/fused_distance_nn-inl.cuh new file mode 100644 index 0000000000..342bde828d --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn-inl.cuh @@ -0,0 +1,325 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __FUSED_DISTANCE_NN_H +#define __FUSED_DISTANCE_NN_H + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +/** + * \ingroup fused_l2_nn + * @{ + */ +/** + * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * @tparam KVPReduceOpT A struct providing functions for key-value pair comparison. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + * @param[in] stream cuda stream + */ +template +void fusedDistanceNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + ASSERT(isRowMajor, "fusedDistanceNN only supports row major inputs"); + // When k is smaller than 32, the Policy4x4 results in redundant calculations + // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead + // that uses tiles with a smaller value of k. 
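A condensed sketch of the dispatch that follows this comment (the alignment test mirrors the code below; the exact policy spellings are assumptions based on raft/linalg/contractions.cuh):

#include <cstddef>
#include <cstdint>

// Illustrative only: returns the vector-load width in bytes that the dispatch
// below selects. The tiling policy is then, for width W:
//   linalg::Policy4x4<DataT, W / sizeof(DataT)>::Policy        when k >= 32
//   linalg::Policy4x4Skinny<DataT, W / sizeof(DataT)>::Policy  when k <  32
template <typename DataT>
size_t vec_width_bytes(const DataT* x, const DataT* y, size_t k)
{
  const size_t bytes = sizeof(DataT) * k;
  const auto px      = reinterpret_cast<uintptr_t>(x);
  const auto py      = reinterpret_cast<uintptr_t>(y);
  if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) return 16;
  if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) return 8;
  return sizeof(DataT);  // scalar fallback
}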
+ bool is_skinny = k < 32; + + size_t bytes = sizeof(DataT) * k; + auto px = reinterpret_cast(x); + auto py = reinterpret_cast(y); + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { + if (is_skinny) { + detail::fusedDistanceNNImpl< + DataT, + OutT, + IdxT, + typename linalg::Policy4x4Skinny::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { + if (is_skinny) { + detail::fusedDistanceNNImpl< + DataT, + OutT, + IdxT, + typename linalg::Policy4x4Skinny::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } else { + if (is_skinny) { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } +} + +/** + * @brief Wrapper around fusedDistanceNN with minimum reduction operators. + * + * fusedDistanceNN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances (e.g. raft::KeyValuePair) or store only the min + * distances. + * @tparam IdxT indexing arithmetic type + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] isRowMajor whether the input/output is row or column major. 
+ * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + * @param[in] stream cuda stream + */ +template +void fusedDistanceNNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + MinAndDistanceReduceOp redOp; + KVPMinReduce pairRedOp; + + fusedDistanceNN(min, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); +} + +/** @} */ + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/distance/fused_distance_nn.cuh b/cpp/include/raft/distance/fused_distance_nn.cuh new file mode 100755 index 0000000000..0c22df72f1 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "fused_distance_nn-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "fused_distance_nn-ext.cuh" +#endif diff --git a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh similarity index 89% rename from cpp/include/raft/distance/fused_l2_nn_helpers.cuh rename to cpp/include/raft/distance/fused_distance_nn_helpers.cuh index 996f696ef6..e70d098d09 100644 --- a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh +++ b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,8 @@ #pragma once #include -#include +// #include +#include namespace raft::distance { diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh index c99c1eb015..66e9960f1d 100644 --- a/cpp/include/raft/distance/fused_l2_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,11 +16,11 @@ #pragma once -#include // int64_t -#include // raft::KeyValuePair -#include // raft::resources -#include // include initialize and reduce operations -#include // RAFT_EXPLICIT +#include // int64_t +#include // raft::KeyValuePair +#include // raft::resources +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh index 17373e3bcc..4cb6b367a5 100644 --- a/cpp/include/raft/distance/fused_l2_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp new file mode 100644 index 0000000000..70fc884474 --- /dev/null +++ b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::distance { + +/** + * @defgroup fused_distance_nn_min_arg_runtime Fused Distance 1NN Runtime API + * @{ + */ + +/** + * @brief Wrapper around fusedDistanceNN with minimum reduction operators. + * + * fusedDistanceNN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * + * @param[in] handle raft handle + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + */ +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +/** @} */ // end group fused_distance_nn_min_arg_runtime + +} // end namespace raft::runtime::distance diff --git a/cpp/src/distance/fused_distance_nn.cu b/cpp/src/distance/fused_distance_nn.cu new file mode 100644 index 0000000000..c3d1301e29 --- /dev/null +++ b/cpp/src/distance/fused_distance_nn.cu @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // int64_t +#include // raft::KeyValuePair +#include + +#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ + template void raft::distance::fusedDistanceNNMinReduce( \ + OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + bool isRowMajor, \ + raft::distance::DistanceType metric, \ + float metric_arg, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedDistanceNNMinReduce(double, + raft::KeyValuePair, + int); +instantiate_raft_distance_fusedDistanceNNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu new file mode 100644 index 0000000000..90d00d9f6b --- /dev/null +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu
new file mode 100644
index 0000000000..90d00d9f6b
--- /dev/null
+++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <raft/core/device_mdarray.hpp>
+#include <raft/core/kvp.hpp>
+#include <raft/core/resource/cuda_stream.hpp>
+#include <raft/core/resource/thrust_policy.hpp>
+#include <raft/distance/distance_types.hpp>
+#include <raft/distance/fused_distance_nn.cuh>
+#include <raft/linalg/norm.cuh>
+#include <raft_runtime/distance/fused_distance_nn.hpp>
+#include <rmm/device_uvector.hpp>
+#include <thrust/transform.h>
+
+#include <cassert>
+
+namespace raft::runtime::distance {
+
+template <typename IndexT, typename DataT>
+struct KeyValueIndexOp {
+  __host__ __device__ __forceinline__ IndexT
+  operator()(const raft::KeyValuePair<IndexT, DataT>& a) const
+  {
+    return a.key;
+  }
+};
+
+template <typename value_t, typename idx_t>
+void compute_fused_cosine_nn_min_arg(raft::resources const& handle,
+                                     idx_t* min,
+                                     const value_t* x,
+                                     const value_t* y,
+                                     idx_t m,
+                                     idx_t n,
+                                     idx_t k,
+                                     bool sqrt)
+{
+  rmm::device_uvector<int> workspace(m, resource::get_cuda_stream(handle));
+  auto kvp = raft::make_device_vector<raft::KeyValuePair<idx_t, value_t>>(handle, m);
+
+  rmm::device_uvector<value_t> x_norms(m, resource::get_cuda_stream(handle));
+  rmm::device_uvector<value_t> y_norms(n, resource::get_cuda_stream(handle));
+  constexpr bool is_row_major = true;
+  raft::linalg::rowNorm(x_norms.data(),
+                        x,
+                        k,
+                        m,
+                        raft::linalg::L2Norm,
+                        is_row_major,
+                        resource::get_cuda_stream(handle),
+                        raft::sqrt_op{});
+  raft::linalg::rowNorm(y_norms.data(),
+                        y,
+                        k,
+                        n,
+                        raft::linalg::L2Norm,
+                        is_row_major,
+                        resource::get_cuda_stream(handle),
+                        raft::sqrt_op{});
+
+  raft::distance::fusedDistanceNNMinReduce(kvp.data_handle(),
+                                           x,
+                                           y,
+                                           x_norms.data(),
+                                           y_norms.data(),
+                                           m,
+                                           n,
+                                           k,
+                                           (void*)workspace.data(),
+                                           sqrt,
+                                           true,
+                                           is_row_major,
+                                           raft::distance::DistanceType::CosineExpanded,
+                                           0.0f,
+                                           resource::get_cuda_stream(handle));
+
+  KeyValueIndexOp<idx_t, value_t> conversion_op;
+  thrust::transform(resource::get_thrust_policy(handle),
+                    kvp.data_handle(),
+                    kvp.data_handle() + m,
+                    min,
+                    conversion_op);
+  resource::sync_stream(handle);
+}
+
+void fused_distance_nn_min_arg(raft::resources const& handle,
+                               int* min,
+                               const float* x,
+                               const float* y,
+                               int m,
+                               int n,
+                               int k,
+                               bool sqrt,
+                               raft::distance::DistanceType metric,
+                               bool isRowMajor,
+                               float metric_arg)
+{
+  switch (metric) {
+    case raft::distance::DistanceType::CosineExpanded:
+      compute_fused_cosine_nn_min_arg(handle, min, x, y, m, n, k, sqrt);
+      break;
+    default: assert(false && "only cosine metric is supported with fusedDistanceNN"); break;
+  }
+}
+
+void fused_distance_nn_min_arg(raft::resources const& handle,
+                               int* min,
+                               const double* x,
+                               const double* y,
+                               int m,
+                               int n,
+                               int k,
+                               bool sqrt,
+                               raft::distance::DistanceType metric,
+                               bool isRowMajor,
+                               float metric_arg)
+{
+  switch (metric) {
+    case raft::distance::DistanceType::CosineExpanded:
+      compute_fused_cosine_nn_min_arg(handle, min, x, y, m, n, k, sqrt);
+      break;
+    default: assert(false && "only cosine metric is supported with fusedDistanceNN"); break;
+  }
+}
+
+}  // end namespace raft::runtime::distance
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index fe29409d9b..2a1384e96e 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -170,6 +170,7 @@ if(BUILD_TESTS)
     test/distance/masked_nn.cu
     test/distance/masked_nn_compress_to_bits.cu
     test/distance/fused_l2_nn.cu
+    test/distance/fused_cosine_nn.cu
     test/distance/gram.cu
     LIB
     EXPLICIT_INSTANTIATE_ONLY
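
For reference, the quantity that both the fused path above and the naive verification kernel in the new test below compute is the expanded cosine distance. Precomputing the L2 row norms (the rowNorm calls above) is what makes the "expanded" form possible: the cross term then comes out of a single GEMM-like pass over x and y,

    d(x_i, y_j) = 1 - \frac{x_i \cdot y_j}{\lVert x_i \rVert_2 \, \lVert y_j \rVert_2}
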
diff --git a/cpp/test/distance/fused_cosine_nn.cu b/cpp/test/distance/fused_cosine_nn.cu
new file mode 100644
index 0000000000..5a89e71608
--- /dev/null
+++ b/cpp/test/distance/fused_cosine_nn.cu
@@ -0,0 +1,416 @@
+/*
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "../test_utils.cuh"
+
+#include <raft/core/kvp.hpp>
+#include <raft/core/resource/cuda_stream.hpp>
+#include <raft/distance/fused_distance_nn.cuh>
+#include <raft/linalg/norm.cuh>
+#include <raft/random/rng.cuh>
+#include <raft/util/cuda_utils.cuh>
+#include <raft/util/cudart_utils.hpp>
+
+#include <rmm/device_uvector.hpp>
+
+#include <cub/cub.cuh>
+
+#include <gtest/gtest.h>
+
+namespace raft {
+namespace distance {
+
+template <typename LabelT, typename DataT>
+struct RaftKVPMinReduce {
+  typedef raft::KeyValuePair<LabelT, DataT> KVP;
+
+  DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; }
+
+  DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; }
+
+};  // KVPMinReduce
+
+template <typename DataT, typename ReduceOpT, int NWARPS>
+__global__ void naiveCosKernel(raft::KeyValuePair<int, DataT>* min,
+                               DataT* x,
+                               DataT* y,
+                               int m,
+                               int n,
+                               int k,
+                               int* workspace,
+                               DataT maxVal)
+{
+  int midx = threadIdx.y + blockIdx.y * blockDim.y;
+  int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+  DataT acc_a  = DataT(0);
+  DataT acc_b  = DataT(0);
+  DataT acc_ab = DataT(0);
+  // No early return for out-of-range threads: every thread in the warp must
+  // participate in the WarpReduce below. Clamp the row indices instead, so the
+  // accumulation loop never reads out of bounds; the clamped result is masked out.
+  int mclamp = midx < m ? midx : 0;
+  int nclamp = nidx < n ? nidx : 0;
+
+  for (int i = 0; i < k; ++i) {
+    int xidx = i + mclamp * k;
+    int yidx = i + nclamp * k;
+    auto a   = x[xidx];
+    auto b   = y[yidx];
+    acc_a += a * a;
+    acc_b += b * b;
+    acc_ab += a * b;
+  }
+
+  // Use 1.0 - (cosine similarity) to calc the distance
+  DataT acc = maxVal;
+  if (midx < m && nidx < n) { acc = (DataT)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); }
+
+  ReduceOpT redOp;
+  typedef cub::WarpReduce<raft::KeyValuePair<int, DataT>> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp[NWARPS];
+  int warpId = threadIdx.x / raft::WarpSize;
+  raft::KeyValuePair<int, DataT> tmp;
+  tmp.key   = nidx;
+  tmp.value = midx >= m || nidx >= n ? maxVal : acc;
+  tmp       = WarpReduce(temp[warpId]).Reduce(tmp, RaftKVPMinReduce<int, DataT>());
+  if (threadIdx.x % raft::WarpSize == 0 && midx < m) {
+    // Spin on a per-row lock so concurrent warps serialize their updates to min[midx].
+    while (atomicCAS(workspace + midx, 0, 1) == 1)
+      ;
+    __threadfence();
+    redOp(midx, min + midx, tmp);
+    __threadfence();
+    atomicCAS(workspace + midx, 1, 0);
+  }
+}
+
+template <typename DataT>
+void naive(raft::KeyValuePair<int, DataT>* min,
+           DataT* x,
+           DataT* y,
+           int m,
+           int n,
+           int k,
+           int* workspace,
+           cudaStream_t stream)
+{
+  static const dim3 TPB(32, 16, 1);
+  dim3 nblks(raft::ceildiv(n, (int)TPB.x), raft::ceildiv(m, (int)TPB.y), 1);
+  RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream));
+  auto blks = raft::ceildiv(m, 256);
+  MinAndDistanceReduceOp<int, DataT> op;
+  detail::initKernel<DataT, raft::KeyValuePair<int, DataT>, int>
+    <<<blks, 256, 0, stream>>>(min, m, std::numeric_limits<DataT>::max(), op);
+  RAFT_CUDA_TRY(cudaGetLastError());
+  naiveCosKernel<DataT, MinAndDistanceReduceOp<int, DataT>, 16>
+    <<<nblks, TPB, 0, stream>>>(min, x, y, m, n, k, workspace, std::numeric_limits<DataT>::max());
+  RAFT_CUDA_TRY(cudaGetLastError());
+}
+
+template <typename DataT>
+struct Inputs {
+  DataT tolerance;
+  int m, n, k;
+  unsigned long long int seed;
+
+  friend std::ostream& operator<<(std::ostream& os, const Inputs& p)
+  {
+    return os << "m: " << p.m << ", n: " << p.n << ", k: " << p.k << ", seed: " << p.seed
+              << ", tol: " << p.tolerance;
+  }
+};
+
+template <typename DataT>
+class FusedCosineNNTest : public ::testing::TestWithParam<Inputs<DataT>> {
+ public:
+  FusedCosineNNTest()
+    : params(::testing::TestWithParam<Inputs<DataT>>::GetParam()),
+      stream(resource::get_cuda_stream(handle)),
+      x(params.m * params.k, stream),
+      y(params.n * params.k, stream),
+      xn(params.m, stream),
+      yn(params.n, stream),
+      min(params.m, stream),
+      min_ref(params.m, stream),
+      workspace(params.m * sizeof(int), stream)
+  {
+  }
+
+ protected:
+  void SetUp() override
+  {
+    raft::random::RngState r(params.seed);
+    int m = params.m;
+    int n = params.n;
+    int k = params.k;
+    uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0));
+    uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0));
+    generateGoldenResult();
+    raft::linalg::rowNorm(
+      xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream, raft::sqrt_op{});
+    raft::linalg::rowNorm(
+      yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream, raft::sqrt_op{});
+    RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  }
+
+ protected:
+  raft::resources handle;
+  cudaStream_t stream;
+  Inputs<DataT> params;
+  rmm::device_uvector<DataT> x;
+  rmm::device_uvector<DataT> y;
+  rmm::device_uvector<DataT> xn;
+  rmm::device_uvector<DataT> yn;
+  rmm::device_uvector<raft::KeyValuePair<int, DataT>> min;
+  rmm::device_uvector<raft::KeyValuePair<int, DataT>> min_ref;
+  rmm::device_uvector<char> workspace;
+
+  virtual void generateGoldenResult()
+  {
+    int m = params.m;
+    int n = params.n;
+    int k = params.k;
+    naive(min_ref.data(), x.data(), y.data(), m, n, k, (int*)workspace.data(), stream);
+  }
+
+  void runTest(raft::KeyValuePair<int, DataT>* out)
+  {
+    int m = params.m;
+    int n = params.n;
+    int k = params.k;
+    raft::distance::DistanceType metric = raft::distance::DistanceType::CosineExpanded;
+    constexpr bool init_out_buffer      = true;
+    fusedDistanceNNMinReduce<DataT, raft::KeyValuePair<int, DataT>, int>(out,
+                                                                         x.data(),
+                                                                         y.data(),
+                                                                         xn.data(),
+                                                                         yn.data(),
+                                                                         m,
+                                                                         n,
+                                                                         k,
+                                                                         (void*)workspace.data(),
+                                                                         false,
+                                                                         init_out_buffer,
+                                                                         true,
+                                                                         metric,
+                                                                         0.0f,
+                                                                         stream);
+    RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  }
+};
+
+template <typename T>
+struct CompareApproxAbsKVP {
+  typedef typename raft::KeyValuePair<int, T> KVP;
+  CompareApproxAbsKVP(T eps_) : eps(eps_) {}
+  bool operator()(const KVP& a, const KVP& b) const
+  {
+    T diff  = std::abs(std::abs(a.value) - std::abs(b.value));
+    T m     = std::max(std::abs(a.value), std::abs(b.value));
+    T ratio = m >= eps ? diff / m : diff;
+    return (ratio <= eps);
+  }
+
+ private:
+  T eps;
+};
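
CompareApproxAbsKVP accepts a pair of key-value results when the distance values agree to within a relative tolerance, falling back to an absolute comparison when both values are smaller than eps (where a relative ratio would be numerically unstable). Note that only the values are compared; ties in distance may legitimately resolve to different keys:

    \text{ratio} =
      \begin{cases}
        \dfrac{\big|\,|a| - |b|\,\big|}{\max(|a|, |b|)} & \text{if } \max(|a|, |b|) \ge \varepsilon \\[1ex]
        \big|\,|a| - |b|\,\big| & \text{otherwise}
      \end{cases}
    \qquad \text{accept} \iff \text{ratio} \le \varepsilon
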
+
+template <typename T>
+struct CompareExactKVP {
+  typedef typename raft::KeyValuePair<int, T> KVP;
+  bool operator()(const KVP& a, const KVP& b) const
+  {
+    if (a.value != b.value) return false;
+    return true;
+  }
+};
+
+template <typename K, typename V, typename L>
+::testing::AssertionResult devArrMatch(const raft::KeyValuePair<K, V>* expected,
+                                       const raft::KeyValuePair<K, V>* actual,
+                                       size_t size,
+                                       L eq_compare,
+                                       cudaStream_t stream = 0)
+{
+  typedef typename raft::KeyValuePair<K, V> KVP;
+  std::shared_ptr<KVP[]> exp_h(new KVP[size]);
+  std::shared_ptr<KVP[]> act_h(new KVP[size]);
+  raft::update_host(exp_h.get(), expected, size, stream);
+  raft::update_host(act_h.get(), actual, size, stream);
+  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  for (size_t i(0); i < size; ++i) {
+    auto exp = exp_h.get()[i];
+    auto act = act_h.get()[i];
+    if (!eq_compare(exp, act)) {
+      return ::testing::AssertionFailure()
+             << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << ","
+             << exp.value << " @" << i;
+    }
+  }
+  return ::testing::AssertionSuccess();
+}
+
+const std::vector<Inputs<float>> inputsf = {
+  {0.001f, 32, 32, 32, 1234ULL},
+  {0.001f, 32, 64, 32, 1234ULL},
+  {0.001f, 64, 32, 32, 1234ULL},
+  {0.001f, 64, 64, 32, 1234ULL},
+  {0.001f, 128, 32, 32, 1234ULL},
+  {0.001f, 128, 64, 32, 1234ULL},
+  {0.001f, 128, 128, 64, 1234ULL},
+  {0.001f, 64, 128, 128, 1234ULL},
+
+  {0.001f, 32, 32, 34, 1234ULL},
+  {0.001f, 32, 64, 34, 1234ULL},
+  {0.001f, 64, 32, 34, 1234ULL},
+  {0.001f, 64, 64, 34, 1234ULL},
+  {0.001f, 128, 32, 34, 1234ULL},
+  {0.001f, 128, 64, 34, 1234ULL},
+  {0.001f, 128, 128, 66, 1234ULL},
+  {0.001f, 64, 128, 130, 1234ULL},
+
+  {0.001f, 32, 32, 33, 1234ULL},
+  {0.001f, 32, 64, 33, 1234ULL},
+  {0.001f, 64, 32, 33, 1234ULL},
+  {0.001f, 64, 64, 33, 1234ULL},
+  {0.001f, 128, 32, 33, 1234ULL},
+  {0.001f, 128, 64, 33, 1234ULL},
+  {0.001f, 128, 128, 65, 1234ULL},
+  {0.001f, 64, 128, 129, 1234ULL},
+  {0.006f, 1805, 134, 2, 1234ULL},
+  {0.006f, 8192, 1024, 64, 1234ULL},
+  {0.006f, 8192, 1025, 64, 1234ULL},
+
+  // Repeat with smaller values of k
+  {0.006f, 32, 32, 1, 1234ULL},
+  {0.001f, 32, 64, 2, 1234ULL},
+  {0.001f, 64, 32, 3, 1234ULL},
+  {0.001f, 64, 64, 4, 1234ULL},
+  {0.001f, 128, 32, 5, 1234ULL},
+  {0.001f, 128, 64, 6, 1234ULL},
+  {0.001f, 128, 128, 7, 1234ULL},
+  {0.001f, 64, 128, 8, 1234ULL},
+
+  {0.001f, 32, 32, 9, 1234ULL},
+  {0.001f, 32, 64, 10, 1234ULL},
+  {0.001f, 64, 32, 11, 1234ULL},
+  {0.001f, 64, 64, 12, 1234ULL},
+  {0.001f, 128, 32, 13, 1234ULL},
+  {0.001f, 128, 64, 14, 1234ULL},
+  {0.001f, 128, 128, 15, 1234ULL},
+  {0.001f, 64, 128, 16, 1234ULL},
+
+  {0.001f, 32, 32, 17, 1234ULL},
+  {0.001f, 32, 64, 18, 1234ULL},
+  {0.001f, 64, 32, 19, 1234ULL},
+  {0.001f, 64, 64, 20, 1234ULL},
+  {0.001f, 128, 32, 21, 1234ULL},
+  {0.001f, 128, 64, 22, 1234ULL},
+  {0.001f, 128, 128, 23, 1234ULL},
+  {0.00001, 64, 128, 24, 1234ULL},
+  {0.001f, 1805, 134, 25, 1234ULL},
+  {0.006f, 8192, 1024, 25, 1234ULL},
+  {0.006f, 8192, 1024, 66, 1234ULL},
+};
+typedef FusedCosineNNTest<float> FusedCosineNNTestF;
+TEST_P(FusedCosineNNTestF, Result)
+{
+  runTest(min.data());
+  ASSERT_TRUE(devArrMatch(
+    min_ref.data(), min.data(), params.m, CompareApproxAbsKVP<float>(params.tolerance), stream));
+}
+INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestF, ::testing::ValuesIn(inputsf));
+
+const std::vector<Inputs<double>> inputsd = {
+  {0.00001, 32, 32, 32, 1234ULL},   {0.00001, 32, 64, 32, 1234ULL},
+  {0.00001, 64, 32, 32, 1234ULL},   {0.00001, 64, 64, 32, 1234ULL},
+  {0.00001, 128, 32, 32, 1234ULL},  {0.00001, 128, 64, 32, 1234ULL},
+  {0.00001, 128, 128, 64, 1234ULL}, {0.00001, 64, 128, 128, 1234ULL},
+
+  {0.00001, 32, 32, 34, 1234ULL},   {0.00001, 32, 64, 34, 1234ULL},
+  {0.00001, 64, 32, 34, 1234ULL},   {0.00001, 64, 64, 34, 1234ULL},
+  {0.00001, 128, 32, 34, 1234ULL},  {0.00001, 128, 64, 34, 1234ULL},
+  {0.00001, 128, 128, 66, 1234ULL}, {0.00001, 64, 128, 130, 1234ULL},
+
+  {0.00001, 32, 32, 33, 1234ULL},   {0.00001, 32, 64, 33, 1234ULL},
+  {0.00001, 64, 32, 33, 1234ULL},   {0.00001, 64, 64, 33, 1234ULL},
+  {0.00001, 128, 32, 33, 1234ULL},  {0.00001, 128, 64, 33, 1234ULL},
+  {0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL},
+
+  {0.00001, 1805, 134, 2, 1234ULL}, {0.00001, 8192, 1024, 25, 1234ULL},
+};
+typedef FusedCosineNNTest<double> FusedCosineNNTestD;
+TEST_P(FusedCosineNNTestD, Result)
+{
+  runTest(min.data());
+  ASSERT_TRUE(devArrMatch(
+    min_ref.data(), min.data(), params.m, CompareApproxAbsKVP<double>(params.tolerance), stream));
+}
+INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestD, ::testing::ValuesIn(inputsd));
+
+/// This is to test output determinism of the prim
+template <typename DataT>
+class FusedCosineNNDetTest : public FusedCosineNNTest<DataT> {
+ public:
+  FusedCosineNNDetTest() : stream(resource::get_cuda_stream(handle)), min1(0, stream) {}
+
+  void SetUp() override
+  {
+    FusedCosineNNTest<DataT>::SetUp();
+    int m = this->params.m;
+    min1.resize(m, stream);
+    RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  }
+
+  void TearDown() override { FusedCosineNNTest<DataT>::TearDown(); }
+
+ protected:
+  raft::resources handle;
+  cudaStream_t stream;
+
+  rmm::device_uvector<raft::KeyValuePair<int, DataT>> min1;
+
+  static const int NumRepeats = 3;
+
+  void generateGoldenResult() override {}
+};
+
+typedef FusedCosineNNDetTest<float> FusedCosineNNDetTestF;
+TEST_P(FusedCosineNNDetTestF, Result)
+{
+  runTest(min.data());  // assumed to be golden
+  for (int i = 0; i < NumRepeats; ++i) {
+    runTest(min1.data());
+    ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP<float>(), stream));
+    cudaMemsetAsync(min1.data(), 0, sizeof(*min.data()) * params.m, stream);
+  }
+}
+INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestF, ::testing::ValuesIn(inputsf));
+
+typedef FusedCosineNNDetTest<double> FusedCosineNNDetTestD;
+TEST_P(FusedCosineNNDetTestD, Result)
+{
+  runTest(min.data());  // assumed to be golden
+  for (int i = 0; i < NumRepeats; ++i) {
+    runTest(min1.data());
+    ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP<double>(), stream));
+  }
+}
+INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestD, ::testing::ValuesIn(inputsd));
+
+}  // end namespace distance
+}  // end namespace raft
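
Once the patch is applied, the new cases run under the distance gtest target. A hedged example invocation (the binary path and target name assume RAFT's usual CMake test layout and may differ locally):

    ./cpp/build/gtests/DISTANCE_TEST --gtest_filter='FusedCosineNNTests/*:FusedCosineNNDetTests/*'
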