diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index bdaacb4a85..034dc059b0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -312,10 +312,6 @@ if(RAFT_COMPILE_LIBRARY) src/distance/specializations/detail/l1_double_double_double_int.cu src/distance/specializations/detail/l2_expanded_float_float_float_int.cu src/distance/specializations/detail/l2_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu - src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu src/distance/specializations/detail/l_inf_double_double_double_int.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 8049074c09..d92ccba8e3 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -72,6 +72,11 @@ if(BUILD_BENCH) OPTIONAL LIB ) + ConfigureBench( + NAME TUNE_DISTANCE PATH bench/distance/tune_pairwise/kernel.cu + bench/distance/tune_pairwise/bench.cu bench/main.cpp + ) + ConfigureBench( NAME DISTANCE_BENCH diff --git a/cpp/bench/distance/tune_pairwise/bench.cu b/cpp/bench/distance/tune_pairwise/bench.cu new file mode 100644 index 0000000000..87159ab1b1 --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/bench.cu @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Tuning benchmarks. +// +// Goals: +// +// 1. Fast compile times to maintain iteration speed. +// 2. Create benchmarks that can inform the design of the kernels. +// +// Non-goals: +// +// 1. Measure every distance operation. Instead measures just one distance +// operation at the same time. +// 2. Be useful for finding performance regressions. This is handled by the +// normal benchmarks. +// +// So far, both goals are partly achieved. +// +// RE (1), COMPILE TIMES: kernel.cu is fast to compile. This file is not. +// When the internals of a pairwise distance kernel is changed, this file is not +// recompiled. +// +// RE 2, benchmarks with intent: this file contains a benchmark to check the +// maximal throughput of a kernel. Measuring other things, like performance on +// skinny or wide matrices is not yet implemented. + +#include "kernel.cuh" // launch_kernel +#include // std::min +#include // RAFT_BENCH_REGISTER +#include // pairwise_matrix_params +#include // rmm::device_uvector +#include // std::vector + +namespace raft::bench::distance::tune { + +// Max throughput benchmark. +// +// Goal: Measure the maximum distances/sec that can be computed. +// +// To achieve this, we make sure that: +// +// - Input data size is a multiple of the block tile size. +// +// - Perfect distribution of work between SMs, i.e. 
the number of block tiles is +// a large multiple (num_waves) of the number of blocks (#SMs * occupancy). +// +// - Multiple iterations over Kblk are executed (num_k_iters). +struct throughput_param { + int num_waves; + int occupancy; + int num_k_iters; +}; + +const std::vector throughput_params{ + // 32 waves, requested occupancy of 4, and 32 k iterations typically achieves + // maximum throughput. No need to pick higher values. + {32, 4, 32}, +}; + +struct throughput_bench : public fixture { + const throughput_param p; + + throughput_bench(const throughput_param& p_) : p(p_) {} + + void run_benchmark(::benchmark::State& state) override + { + // Get block size: + int block_m, block_n, block_k; + get_block_size(block_m, block_n, block_k); + + // Determine number of blocks that will be launched. This informs the size + // of the inputs as well as the grid size. + const int num_sms = raft::getMultiProcessorCount(); + const int max_occupancy = get_max_occupancy(); + const int occupancy = std::min(p.occupancy, max_occupancy); + const int num_blocks = occupancy * num_sms; + dim3 grid(num_blocks); + + // Create input sizes that are a multiple of the block tile size. + size_t m = block_m; + size_t n = block_n * p.num_waves * num_blocks; + size_t k = block_k * p.num_k_iters; + + // DataT, OutT, IdxT, etc, are defined in tuned_kernel.cuh + rmm::device_uvector x_vec(m * k, stream); + rmm::device_uvector y_vec(n * k, stream); + rmm::device_uvector x_norm_vec(m, stream); + rmm::device_uvector y_norm_vec(n, stream); + rmm::device_uvector out_vec(m * n, stream); + + auto x = x_vec.data(); + auto y = y_vec.data(); + auto x_norm = x_norm_vec.data(); + auto y_norm = y_norm_vec.data(); + auto out = out_vec.data(); + FinOpT fin_op{}; + + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = row_major ? k : m; + IdxT ldy = row_major ? k : n; + IdxT ld_out = row_major ? n : m; + + // Template parameters of pairwise_matrix_params are defined in kernel.cuh + pairwise_matrix_params kparams{ + IdxT(m), IdxT(n), IdxT(k), ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, row_major}; + + // Run benchmark + loop_on_state(state, [&]() { launch_kernel(kparams, grid, stream); }); + + // Report metrics. We don't report flop/s because we do not know for each + // distance operation how many flops it costs. For L2_unexp and l1, we can + // double this number to get the flop/s. For l2 expanded, core_ops/s should + // equal flop/s (modulo the sqrt and subtracting from the norm). 
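+    //
+    // Illustrative example (hypothetical numbers, not measured results): if
+    // core_ops/s is reported as 20e12 for the l1 kernel, the doubling rule
+    // above gives an estimated arithmetic throughput of roughly
+    // 2 * 20e12 = 40 TFLOP/s.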
+ size_t num_core_ops = m * n * k; + size_t read_elts = n * k + m * k; + size_t write_elts = m * n; + + state.counters["m"] = benchmark::Counter(m); + state.counters["n"] = benchmark::Counter(n); + state.counters["k"] = benchmark::Counter(k); + state.counters["occupancy"] = benchmark::Counter(occupancy); + state.counters["# waves"] = benchmark::Counter(p.num_waves); + state.counters["# k iters"] = benchmark::Counter(p.num_k_iters); + + state.counters["core_ops/s"] = benchmark::Counter(num_core_ops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + + state.counters["BW"] = benchmark::Counter(write_elts * sizeof(OutT) + read_elts * sizeof(DataT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + } +}; + +RAFT_BENCH_REGISTER(throughput_bench, "", throughput_params); + +} // namespace raft::bench::distance::tune diff --git a/cpp/bench/distance/tune_pairwise/kernel.cu b/cpp/bench/distance/tune_pairwise/kernel.cu new file mode 100644 index 0000000000..3112e1ea9a --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/kernel.cu @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel.cuh" +#include // pairwise_matrix_sm60_wrapper +#include // raft::linalg::Policy4x4 +#include // raft::util::arch::SM_compute_arch + +namespace raft::bench::distance::tune { + +// Distance op +using OpT = raft::distance::detail::ops::lp_unexp_distance_op; +constexpr float metric_arg = 2.0; +OpT distance_op{metric_arg}; + +// Kernel policy +constexpr int vec_len = 1; +using Policy = typename raft::linalg::Policy4x4::Policy; + +// Architecture +namespace arch = raft::util::arch; +constexpr auto sm_compat_range = arch::SM_range(arch::SM_min(), arch::SM_future()); + +void launch_kernel(pairwise_matrix_params params, dim3 grid, cudaStream_t stream) +{ + dim3 block(Policy::Nthreads); + int smem_size = OpT::shared_mem_size(); + + // Obtain function pointer to kernel + auto kernel = raft::distance::detail::pairwise_matrix_kernel; + + kernel<<>>(distance_op, params); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +void get_block_size(int& m, int& n, int& k) +{ + m = Policy::Mblk; + n = Policy::Nblk; + k = Policy::Kblk; +} + +void* get_kernel_ptr() +{ + auto kernel = raft::distance::detail::pairwise_matrix_kernel; + return reinterpret_cast(kernel); +} + +int get_max_occupancy() +{ + void* kernel_ptr = get_kernel_ptr(); + int max_occupancy; + int smem_size = OpT::shared_mem_size(); + + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, kernel_ptr, Policy::Nthreads, smem_size)); + + return max_occupancy; +} + +} // namespace raft::bench::distance::tune diff --git a/cpp/bench/distance/tune_pairwise/kernel.cuh b/cpp/bench/distance/tune_pairwise/kernel.cuh new file mode 100644 index 0000000000..5da54a343c --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/kernel.cuh @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // lp_unexp_distance_op +#include // pairwise_matrix_params + +namespace raft::bench::distance::tune { + +// Launch one specific kernel with the following template parameters +constexpr bool row_major = true; +using DataT = float; +using AccT = float; +using OutT = DataT; +using IdxT = int; + +using FinOpT = raft::identity_op; + +using pairwise_matrix_params = + raft::distance::detail::pairwise_matrix_params; + +// Launches kernel +void launch_kernel(pairwise_matrix_params, dim3, cudaStream_t); + +// Describes the block size that is decided by the policy +void get_block_size(int& m, int& n, int& k); + +int get_max_occupancy(); + +} // namespace raft::bench::distance::tune diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp index 8d3321eb77..192d160d45 100644 --- a/cpp/include/raft/core/kvp.hpp +++ b/cpp/include/raft/core/kvp.hpp @@ -20,7 +20,7 @@ #ifdef _RAFT_HAS_CUDA #include -#include +#include // raft::shfl_xor #endif namespace raft { /** diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index f469250b45..7493c4e558 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -16,25 +16,18 @@ #pragma once -#include -#include - -#include -#include -#include -#include -#include - #include - +#include #include #include - +#include +#include #include #include -#include -#include -#include +#include +#include +#include +#include namespace raft { namespace distance { @@ -140,14 +133,14 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - AccT* norm_col_vec = workspace; - AccT* norm_row_vec = workspace; - AccT* sq_norm_col_vec = workspace; - AccT* sq_norm_row_vec = workspace; + AccT* x_norm = workspace; + AccT* y_norm = workspace; + AccT* sq_x_norm = workspace; + AccT* sq_y_norm = workspace; if (x != y) { - norm_row_vec += m; + y_norm += m; - raft::linalg::reduce(norm_col_vec, + raft::linalg::reduce(x_norm, x, k, m, @@ -158,7 +151,7 @@ void distance_impl(raft::resources const& handle, false, raft::identity_op(), raft::add_op()); - raft::linalg::reduce(norm_row_vec, + raft::linalg::reduce(y_norm, y, k, n, @@ -170,12 +163,12 @@ void distance_impl(raft::resources const& handle, raft::identity_op(), raft::add_op()); - sq_norm_col_vec += (m + n); - sq_norm_row_vec = sq_norm_col_vec + m; - raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); - raft::linalg::rowNorm(sq_norm_row_vec, y, k, n, raft::linalg::L2Norm, is_row_major, stream); + sq_x_norm += (m + n); + sq_y_norm = sq_x_norm + m; + raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); + raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream); } else { - raft::linalg::reduce(norm_col_vec, + raft::linalg::reduce(x_norm, x, k, m, @@ -186,15 +179,15 @@ void 
distance_impl(raft::resources const& handle, false, raft::identity_op(), raft::add_op()); - sq_norm_col_vec += m; - sq_norm_row_vec = sq_norm_col_vec; - raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); + sq_x_norm += m; + sq_y_norm = sq_x_norm; + raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); } using OpT = ops::correlation_distance_op; - OpT corr_op(is_row_major, sq_norm_col_vec, sq_norm_row_vec, m, n, k); + OpT corr_op(is_row_major, sq_x_norm, sq_y_norm, m, n, k); pairwise_matrix_dispatch( - corr_op, m, n, k, x, y, norm_col_vec, norm_row_vec, out, fin_op, stream, is_row_major); + corr_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -223,22 +216,22 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - DataT* norm_A = workspace; - DataT* norm_B = workspace; + DataT* x_norm = workspace; + DataT* y_norm = workspace; if (x != y) { - norm_B += m; + y_norm += m; raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); raft::linalg::rowNorm( - norm_B, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } else { raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } ops::cosine_distance_op distance_op{}; pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -389,10 +382,6 @@ void distance_impl(raft::resources const& handle, return (!x_zero) * raft::exp(input); }; - // This op takes some shortcuts when x equals y. So its behavior changes based - // on this. - ops::kl_divergence_op kl_divergence{is_row_major, x == y}; - if (x != y) { raft::linalg::unaryOp( (DataT*)y, y, n * k, unaryOp_lambda, stream); @@ -401,8 +390,12 @@ void distance_impl(raft::resources const& handle, const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; - pairwise_matrix_dispatch( - kl_divergence, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + // This op takes some shortcuts when x equals y. So its behavior changes based + // on this. 
+ ops::kl_divergence_op distance_op{is_row_major, x == y}; + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); if (x != y) { // Now reverse previous log (x) back to x using (e ^ log(x)) @@ -464,22 +457,22 @@ void distance_impl_l2_expanded( // NOTE: different name "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); - DataT* norm_A = workspace; - DataT* norm_B = workspace; + DataT* x_norm = workspace; + DataT* y_norm = workspace; if (x != y) { - norm_B += m; + y_norm += m; raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); raft::linalg::rowNorm( - norm_B, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } else { raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } ops::l2_exp_distance_op distance_op{perform_sqrt}; pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -543,13 +536,13 @@ void distance_impl(raft::resources const& handle, ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* norm_A = nullptr; - const DataT* norm_B = nullptr; + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -571,13 +564,13 @@ void distance_impl(raft::resources const& handle, ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* norm_A = nullptr; - const DataT* norm_B = nullptr; + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 930294ce31..eaf37b7e9c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -16,7 +16,8 @@ #pragma once -#include +#include // raft::abs +#include // DI namespace raft::distance::detail::ops { @@ -42,7 +43,7 @@ struct canberra_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. 
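  // Note: because shared_mem_size is static, callers can query it from the op
  // type alone, e.g. OpT::template shared_mem_size<Policy>(), without
  // constructing an op instance (as done in pairwise_matrix/kernel_sm60.cuh).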
template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index 289b69070a..4fc4bb8297 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -61,7 +61,7 @@ struct correlation_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index 7c37c27b4e..0883136c9f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -26,7 +26,7 @@ struct cosine_cutlass_op { __device__ cosine_cutlass_op() noexcept {} __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { - return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); + return static_cast(1.0) - static_cast(accVal / (aNorm * bNorm)); } __device__ AccT operator()(DataT aData) const noexcept { return aData; } }; @@ -53,7 +53,7 @@ struct cosine_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } @@ -76,7 +76,10 @@ struct cosine_distance_op { } } - cosine_cutlass_op get_cutlass_op() { return cosine_cutlass_op(); } + constexpr cosine_cutlass_op get_cutlass_op() const + { + return cosine_cutlass_op(); + } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh index d3eb90467b..7a4fe0ce83 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh @@ -16,7 +16,8 @@ #pragma once -#include +#include // std::false_type +#include // std::declval namespace raft::distance::detail::ops { @@ -34,7 +35,8 @@ struct has_cutlass_op : std::false_type { // Specialization recognizes types that do support CUTLASS template -struct has_cutlass_op> : std::true_type { +struct has_cutlass_op().get_cutlass_op())>> + : std::true_type { }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index 1cfdcfdc73..475b8892e9 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -45,7 +45,7 @@ struct hamming_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. 
template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh index c4aecc7a6f..0489b45854 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -42,7 +42,7 @@ struct hellinger_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index 41eeb9dd83..e46c63734c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include // raft::log +#include // DI namespace raft::distance::detail::ops { @@ -44,7 +45,7 @@ struct jensen_shannon_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index d046b62c30..d083c5ddcc 100644 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include // raft::log +#include // DI namespace raft::distance::detail::ops { @@ -49,7 +50,7 @@ struct kl_divergence_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index 8ec4000827..7e86fd3603 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -41,7 +41,7 @@ struct l1_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. 
template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 2a7af53813..95577fd311 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -54,7 +54,7 @@ struct l2_exp_distance_op { using AccT = AccType; using IdxT = IdxType; - bool sqrt; + const bool sqrt; l2_exp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} @@ -67,7 +67,7 @@ struct l2_exp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } @@ -102,7 +102,10 @@ struct l2_exp_distance_op { } } - l2_exp_cutlass_op get_cutlass_op() { return l2_exp_cutlass_op(sqrt); } + constexpr l2_exp_cutlass_op get_cutlass_op() const + { + return l2_exp_cutlass_op(sqrt); + } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index f0ea591eaf..62c212ee8f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -46,7 +46,7 @@ struct l2_unexp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh index fb21fb1a21..88853a3083 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -42,7 +42,7 @@ struct l_inf_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh index 71dfd51a6e..290f4af1b4 100644 --- a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include // raft::pow, raft::abs +#include // DI namespace raft::distance::detail::ops { @@ -45,7 +46,7 @@ struct lp_unexp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. 
template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index ea09e4d1db..63dbf350d1 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -47,7 +47,7 @@ struct russel_rao_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index 6998f3cad4..4320068361 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -42,8 +42,8 @@ struct template_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template - constexpr size_t shared_mem_size() + template + static constexpr size_t shared_mem_size() { return Policy::SmemSize + TODO; } @@ -59,6 +59,10 @@ struct template_distance_op { { TODO; } + + // If exist, returns a cutlass op that performs the same operation. + // See cosine and l2_exp distance ops for an example. + constexpr l2_exp_cutlass_op get_cutlass_op() const { TODO; } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 8fbd7a9c69..be6fed9f10 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -16,23 +16,20 @@ #pragma once -#include -#include -#include -#include -#include -#include +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy +#include // raft::ceildiv, raft::shfl namespace raft { namespace distance { namespace detail { -#if (ENABLE_MEMCPY_ASYNC == 1) -#include -using namespace nvcuda::experimental; -#endif - template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; @@ -124,11 +121,10 @@ DI void updateReducedVal( template __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, const DataT* x, @@ -142,7 +138,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, int* mutex, ReduceOpT redOp, KVPReduceOpT pairRedOp, - CoreLambda core_op, + OpT distance_op, FinalLambda fin_op) { extern __shared__ char smem[]; @@ -163,24 +159,6 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, IdxT gridStrideY) { KVPReduceOpT pairRed_op(pairRedOp); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (Sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto acc_ij = acc[i][j]; - 
acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0}; - } - } - } - // intra thread reduce const auto acccolid = threadIdx.x % P::AccThCols; const auto accrowid = threadIdx.x / P::AccThCols; @@ -229,18 +207,18 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, }; IdxT lda = k, ldb = k, ldd = n; - PairwiseDistances + row_major, + write_out> obj(x, y, m, @@ -251,9 +229,9 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, ldd, xn, yn, - nullptr, + nullptr, // Output pointer smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -289,9 +267,6 @@ void fusedL2NNImpl(OutT* min, constexpr auto maxVal = std::numeric_limits::max(); typedef KeyValuePair KVPair; - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel @@ -300,59 +275,25 @@ void fusedL2NNImpl(OutT* min, } constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - if (sqrt) { - auto fusedL2NNSqrt = fusedL2NNkernel; - dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, fusedL2NNSqrt); - - fusedL2NNSqrt<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } else { - auto fusedL2NN = fusedL2NNkernel; - dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedL2NNkernel; + + dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 0293f10c29..c6b09be31e 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -14,14 +14,11 @@ * limitations under the License. */ #pragma once -#include -#include -#include -#include -#include -#include +#include // raft::linalg::Contractions_NT +#include // ceildiv +#include // RAFT_CUDA_TRY -#include +#include // size_t namespace raft { namespace distance { @@ -29,16 +26,12 @@ namespace detail { /** * @brief Device class for L1, L2 and cosine distance metrics. - * @tparam useNorms whether norms are needed * @tparam DataT input data-type (for A and B matrices) * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) * @tparam IdxT index data-type * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda tells how to accumulate an x and y into - acc. its signature: - template void core_lambda(AccT& acc, - const DataT& x, const DataT& y) + * @tparam OpT A distance operation, e.g., cosine_distance_op. * @tparam EpilogueLambda applies an elementwise function to compute final values. Its signature is: template void epilogue_lambda @@ -56,19 +49,17 @@ namespace detail { * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine * @param[output] pD output matrix * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. - * @param core_op the core accumulation operation lambda + * @param distance_op the distance operation, e.g. cosine_distance_op * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ -template > struct PairwiseDistances : public BaseClass { + // Get accumulation type from distance_op + using AccT = typename OpT::AccT; + private: typedef Policy P; const DataT* xn; @@ -83,7 +77,7 @@ struct PairwiseDistances : public BaseClass { const DataT* const yBase; OutT* dOutput; char* smem; - CoreLambda core_op; + OpT distance_op; EpilogueLambda epilog_op; FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; @@ -109,7 +103,7 @@ struct PairwiseDistances : public BaseClass { const DataT* _yn, OutT* _dOutput, char* _smem, - CoreLambda _core_op, + OpT _distance_op, EpilogueLambda _epilog_op, FinalLambda _fin_op, rowEpilogueLambda _rowEpilog_op) @@ -119,7 +113,7 @@ struct PairwiseDistances : public BaseClass { yBase(_y), dOutput(_dOutput), smem(_smem), - core_op(_core_op), + distance_op(_distance_op), epilog_op(_epilog_op), fin_op(_fin_op), rowEpilog_op(_rowEpilog_op), @@ -159,15 +153,25 @@ struct PairwiseDistances : public BaseClass { this->switch_read_buffer(); // Epilog: - if (useNorms) { + if (distance_op.use_norms) { DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; load_norms(tile_idx_m, tile_idx_n, regxn, regyn); // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. 
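+        // For ops that use norms (e.g. expanded L2, cosine, correlation),
+        // this combines the dot products accumulated in acc with the row
+        // norms regxn/regyn loaded above.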
+ // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } if (writeOut) { store_output(tile_idx_m, tile_idx_n); } @@ -201,24 +205,41 @@ struct PairwiseDistances : public BaseClass { } } - DI void accumulate() + DI void accumulate_reg_tile(DataT (®_x)[P::AccRowsPerTh][P::Veclen], + DataT (®_y)[P::AccColsPerTh][P::Veclen]) { #pragma unroll - for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { - this->ldsXY(ki); + for (int v = 0; v < P::Veclen; ++v) { #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { - core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); - } + distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); } } } } + DI void accumulate() + { + // We have a separate ldsXY and accumulate_reg_tile outside the loop body, + // so that these separated calls can be interspersed with preceding and + // following instructions, thereby hiding latency. + this->ldsXY(0); + + // If expensive inner loop, do not unroll loop. + constexpr int num_iterations = P::Kblk / P::Veclen - 1; + constexpr int unroll_count = decltype(distance_op)::expensive_inner_loop ? 1 : num_iterations; +#pragma unroll unroll_count + for (int ki = P::Veclen; ki < P::Kblk; ki += P::Veclen) { + accumulate_reg_tile(this->regx, this->regy); + this->ldsXY(ki); + } + + // Accumulate last loaded tile. 
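+    // (regx/regy still hold the tile fetched by the final ldsXY(ki) call in
+    // the loop above, so this call completes the load/compute pipeline.)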
+ accumulate_reg_tile(this->regx, this->regy); + } + DI void load_norms(IdxT tile_idx_m, IdxT tile_idx_n, DataT (®xn)[P::AccRowsPerTh], @@ -274,7 +295,11 @@ struct PairwiseDistances : public BaseClass { template dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) { - const auto numSMs = raft::getMultiProcessorCount(); + int devId; + RAFT_CUDA_TRY(cudaGetDevice(&devId)); + int numSMs; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId)); + int numBlocksPerSm = 0; dim3 grid; diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index c5fdd28117..efcd5d9389 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -64,21 +64,20 @@ template -typename std::enable_if::value>::type cutlassDistanceKernel( - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - OpT distance_op, - cudaStream_t stream) +std::enable_if_t::value> cutlassDistanceKernel(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + FinalLambda fin_op, + OpT distance_op, + cudaStream_t stream) { static_assert(!(std::is_same::value), "OutType bool is not supported use uint8_t instead"); diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 8524ce6fdf..e04b56ee8a 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -15,63 +15,74 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include +/* This file has two responsibilities: + * + * 1. Dispatch to the correct implementation of a kernel based on the + * architecture of the device on which the kernel will be launched. For + * instance, the cosine distance has a CUTLASS-based implementation that can + * be used on SM80+ and the normal implementation that is used on older + * architectures. + * + * 2. Provide concise function templates that can be instantiated in + * src/distance/distance/specializations/detail/. Previously, + * raft::distance::detail::distance was instantiated. The function + * necessarily required a large set of include files, which slowed down the + * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions + * do not require as large an include files set, which speeds up the build. + */ + +#include // ops::has_cutlass_op +#include // dispatch_sm60 +#include // pairwise_matrix_params +#include // raft::util::arch::SM_* + +// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. +// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). +// Therefore, it is the including file's responsibility to include the correct +// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh +// and the specializations in src/distance/distance/specializations/detail/. namespace raft::distance::detail { +// This forward-declaration ensures that we do not need to include +// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling +// all the non-CUTLASS based distance specializations faster. 
For CUTLASS-based +// distances, dispatch_sm80.cuh has to be included by the file including this +// file. template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = is_row_major ? k : m; - IdxT ldy = is_row_major ? k : n; - IdxT ld_out = is_row_major ? n : m; - - pairwise_matrix_params params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - - if (!params.is_row_major) { params.flip_x_and_y(); } + typename SM_compat_t> +void pairwise_matrix_sm80_dispatch(OpT, + pairwise_matrix_params, + SM_compat_t, + cudaStream_t); +template +void pairwise_matrix_instantiation_point(OpT distance_op, + pairwise_matrix_params params, + cudaStream_t stream) +{ // On CUDA 12: // - always execute normal kernel // // On CUDA 11 and below: // - execute CUTLASS-based kernel on SM_80 and above // - execute normal kernel below SM_80 + namespace arch = raft::util::arch; constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); if constexpr (is_ctk_12 || cutlass_op_unavailable) { // Always execute legacy kernels on CUDA 12 - auto any_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()); + auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); } else { - auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); - auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); // Get pointer to SM60 kernel to determine the runtime architecture of the // current system. Other methods to determine the architecture (that do not @@ -79,7 +90,7 @@ void pairwise_matrix_dispatch(OpT distance_op, // https://github.com/NVIDIA/cub/issues/545 auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = raft::arch::kernel_runtime_arch(kernel_ptr); + auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. @@ -92,4 +103,35 @@ void pairwise_matrix_dispatch(OpT distance_op, } } +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? 
n : m; + + pairwise_matrix_params params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { params.flip_x_and_y(); } + pairwise_matrix_instantiation_point(distance_op, params, stream); +} + }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh index c1e4c08af4..f2b0e59822 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh @@ -15,10 +15,11 @@ */ #pragma once -#include "kernel_sm60.cuh" -#include -#include - +#include // std::min +#include // size_t +#include // RAFT_EXPECTS +#include // pairwise_matrix_params +#include // std::integral_constant namespace raft::distance::detail { /** @@ -99,15 +100,15 @@ auto dispatch_layout(bool row_major, int vec_len, F&& f) { if (row_major) { switch (vec_len) { - case 4: return f(std::bool_constant(), vec_len_constant<4>()); - case 2: return f(std::bool_constant(), vec_len_constant<2>()); - default: return f(std::bool_constant(), vec_len_constant<1>()); + case 4: return f(std::true_type(), vec_len_constant<4>()); + case 2: return f(std::true_type(), vec_len_constant<2>()); + default: return f(std::true_type(), vec_len_constant<1>()); } } else { switch (vec_len) { - case 4: return f(std::bool_constant(), vec_len_constant<4>()); - case 2: return f(std::bool_constant(), vec_len_constant<2>()); - default: return f(std::bool_constant(), vec_len_constant<1>()); + case 4: return f(std::false_type(), vec_len_constant<4>()); + case 2: return f(std::false_type(), vec_len_constant<2>()); + default: return f(std::false_type(), vec_len_constant<1>()); } } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh index 6e284007ea..2080fbe9cd 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh @@ -15,10 +15,10 @@ */ #pragma once -#include -#include -#include -#include +#include // std::min +#include // dispatch_layout +#include // pairwise_matrix_sm60_wrapper +#include // raft::linalg::Policy4x4 namespace raft::distance::detail { @@ -35,7 +35,11 @@ pairwise_matrix_sm60_wrapper pairwise_matrix_sm6 { int vec_len = determine_vec_len(params); - return dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // f takes compile-time constants row_major and vec_len aligned and returns + // the corresponding kernel wrapper. The wrapper contains the launch + // parameters of the kernel: a pointer to the kernel function, grid size, + // block size, and shared memory size. + auto f = [&](auto row_major, auto vec_len_aligned) { // row_major and vec_len are std::integral_constants of type bool and int // respectively. 
@@ -46,15 +50,19 @@ pairwise_matrix_sm60_wrapper pairwise_matrix_sm6 // Prevent double, vec_len=4 combination (this is not supported) constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type Policy; + using RowPolicy = typename raft::linalg::Policy4x4::Policy; + using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; + using Policy = typename std::conditional::type; auto wrapper = make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); return wrapper; - }); + }; + + // Dispatch_layout calls f with appropriate compile time constants based on + // the runtime values of params.is_row_major and vec_len. + return dispatch_layout(params.is_row_major, vec_len, f); } template // std::min -#include -#include +#include // std::min +#include // cutlassDistanceKernel +#include // dispatch_layout namespace raft::distance::detail { @@ -34,7 +34,9 @@ void pairwise_matrix_sm80_dispatch(OpT distance_op, { int vec_len = determine_vec_len(params); - dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // f takes compile-time constants row_major and vec_len aligned and runs the + // corresponding cutlass launch code. + auto f = [&](auto row_major, auto vec_len_aligned) { // row_major and vec_len are std::integral_constants of type bool and int // respectively. @@ -56,7 +58,11 @@ void pairwise_matrix_sm80_dispatch(OpT distance_op, params.fin_op, distance_op, stream); - }); + }; + + // Dispatch_layout calls f with appropriate compile time constants based on + // the runtime values of params.is_row_major and vec_len. + dispatch_layout(params.is_row_major, vec_len, f); } }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 6e3ab7b26b..2d0a98862e 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -15,11 +15,11 @@ */ #pragma once -#include -#include -#include -#include -#include +#include // assert +#include // raft::void_op +#include // PairwiseDistances +#include // pairwise_matrix_params +#include // raft::util::arch::SM_compute_arch namespace raft::distance::detail { @@ -36,43 +36,27 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( { // Early exit to minimize the size of the kernel when it is not supposed to be compiled. constexpr SM_compat_t sm_compat_range{}; - if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) { + if constexpr (!sm_compat_range.contains(raft::util::arch::SM_compute_arch())) { assert(false); return; } extern __shared__ char smem[]; - using AccT = typename OpT::AccT; - - // Wrap operator back into lambdas. This is temporary and should be removed. 
- // See: https://github.com/rapidsai/raft/issues/1323 - auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { - distance_op.core(acc, x, y); - }; - auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, regxn, regyn, gridStrideX, gridStrideY); - }; - + // The epilog is already provided by distance_op. Do not provide additional + // epilogs. + auto epilog_op = raft::void_op(); // No support for row_epilog_op. auto row_epilog_op = raft::void_op(); // Always write output constexpr bool write_out = true; constexpr bool use_norms = distance_op.use_norms; - PairwiseDistances -void pairwise_matrix(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) -{ - dim3 blk(Policy::Nthreads); - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - size_t smem_size = distance_op.template shared_mem_size(); - // Obtain function pointer to kernel - auto kernel = - pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); - - kernel<<>>(distance_op, params); - RAFT_CUDA_TRY(cudaGetLastError()); -} - // The type of a pointer to the pairwise matrix kernel. The following template // arguments are type-erased: // @@ -181,9 +140,9 @@ pairwise_matrix_sm60_wrapper make_pairwise_matri SM_compat_t sm_compat_range) { dim3 block(Policy::Nthreads); - // Use .template to disambiguate (See: + // Use ::template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_size = distance_op.template shared_mem_size(); + int smem_size = OpT::template shared_mem_size(); // Obtain function pointer to kernel auto kernel = pairwise_matrix_kernel; diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py new file mode 100644 index 0000000000..63ae6580b4 --- /dev/null +++ b/cpp/include/raft/distance/specializations/detail/00_write_template.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 + +# This template manages all files in this directory, apart from +# inner_product.cuh and kernels.cuh. + + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +start_template = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +namespace raft::distance::detail { + +""" + +extern_template = """ +extern template void pairwise_matrix_instantiation_point( + OpT, + pairwise_matrix_params, + cudaStream_t); +""" + +end_template = """} // namespace raft::distance::detail +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + + + + +op_instances = [ + dict( + path_prefix="canberra", + OpT="ops::canberra_distance_op", + ), + dict( + path_prefix="correlation", + OpT="ops::correlation_distance_op", + ), + dict( + path_prefix="cosine", + OpT="ops::cosine_distance_op", + # cosine uses CUTLASS for SM80+ + ), + dict( + path_prefix="hamming_unexpanded", + OpT="ops::hamming_distance_op", + ), + dict( + path_prefix="hellinger_expanded", + OpT="ops::hellinger_distance_op", + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="ops::jensen_shannon_distance_op", + ), + dict( + path_prefix="kl_divergence", + OpT="ops::kl_divergence_op", + ), + dict( + path_prefix="l1", + OpT="ops::l1_distance_op", + ), + dict( + path_prefix="l2_expanded", + OpT="ops::l2_exp_distance_op", + # L2 expanded uses CUTLASS for SM80+ + ), + dict( + path_prefix="l2_unexpanded", + OpT="ops::l2_unexp_distance_op", + ), + dict( + path_prefix="l_inf", + OpT="ops::l_inf_distance_op", + ), + dict( + path_prefix="lp_unexpanded", + OpT="ops::lp_unexp_distance_op", + ), + dict( + path_prefix="russel_rao", + OpT="ops::russel_rao_distance_op", + ), +] + +def fill_in(s, template): + for k, v in template.items(): + s = s.replace(k, v) + return s + +for op_instance in op_instances: + path = fill_in("path_prefix.cuh", op_instance) + with open(path, "w") as f: + f.write(start_template) + + for data_type_instance in data_type_instances: + op_data_instance = { + k : fill_in(v, data_type_instance) + for k, v in op_instance.items() + } + instance = { + **op_data_instance, + **data_type_instance, + "FinopT": "raft::identity_op", + } + + text = fill_in(extern_template, instance) + + f.write(text) + + f.write(end_template) diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh index badce715a5..276c85e5f6 100644 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ b/cpp/include/raft/distance/specializations/detail/canberra.cuh @@ -16,37 +16,25 @@ #pragma once -#include #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::canberra_distance_op, + int, + float, + float, + raft::identity_op>(ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::canberra_distance_op, + int, + double, + double, + raft::identity_op>(ops::canberra_distance_op, + pairwise_matrix_params, + 
cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh index 013a0d43a3..f019f678df 100644 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ b/cpp/include/raft/distance/specializations/detail/correlation.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::correlation_distance_op, + int, + float, + float, + raft::identity_op>(ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::correlation_distance_op, + int, + double, + double, + raft::identity_op>(ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh index c88bd1b0f6..dcde4ec286 100644 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ b/cpp/include/raft/distance/specializations/detail/cosine.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::cosine_distance_op, + int, + double, + double, + raft::identity_op>(ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh index 3c5cad3315..1d6964fbce 100644 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern 
template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::hamming_distance_op, + int, + float, + float, + raft::identity_op>(ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::hamming_distance_op, + int, + double, + double, + raft::identity_op>(ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh index bf214c046f..f96a06f919 100644 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh @@ -18,37 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::hellinger_distance_op, + int, + float, + float, + raft::identity_op>(ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::hellinger_distance_op, + int, + double, + double, + raft::identity_op>(ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh index 145834fb70..0b58646582 100644 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh @@ -18,37 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::jensen_shannon_distance_op, + int, + float, + float, + raft::identity_op>(ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::jensen_shannon_distance_op, + int, + double, + double, + raft::identity_op>(ops::jensen_shannon_distance_op, + 
pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh index f0928916cd..5c164e0fd4 100644 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point, + int, + double, + double, + raft::identity_op>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh index 23261a2571..870627d909 100644 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ b/cpp/include/raft/distance/specializations/detail/l1.cuh @@ -18,35 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point, + int, + double, + double, + raft::identity_op>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh index f953018b7d..ee3207bcce 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - 
int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::l2_exp_distance_op, + int, + double, + double, + raft::identity_op>(ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh deleted file mode 100644 index 9f5f6a3706..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh deleted file mode 100644 index 94531ddc33..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh index 224b21fce8..1fbf57632b 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::l2_unexp_distance_op, + int, + float, + float, + raft::identity_op>(ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::l2_unexp_distance_op, + int, + double, + double, + raft::identity_op>(ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l_inf.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh index 9a46d7b488..388d3bf439 100644 --- a/cpp/include/raft/distance/specializations/detail/l_inf.cuh +++ b/cpp/include/raft/distance/specializations/detail/l_inf.cuh @@ -18,35 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::l_inf_distance_op, + int, + double, + double, + raft::identity_op>(ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh index 
e05ef02c42..d8e86ce6f2 100644 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::lp_unexp_distance_op, + int, + float, + float, + raft::identity_op>(ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::lp_unexp_distance_op, + int, + double, + double, + raft::identity_op>(ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh index afc87997c0..4803fb8ab0 100644 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh @@ -18,37 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::russel_rao_distance_op, + int, + float, + float, + raft::identity_op>(ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::russel_rao_distance_op, + int, + double, + double, + raft::identity_op>(ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index 8daa398b49..a34f696e9e 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -27,8 +27,6 @@ #include #include #include -#include -#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 4e18a210d4..4a571c1447 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -22,6 +22,8 @@ #include "processing.cuh" #include #include +#include +#include #include #include @@ -183,13 +185,11 @@ DI void updateSortedWarpQ( } } -template Pair; @@ -222,295 +222,279 @@ __global__ 
__launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x using namespace raft::neighbors::detail::faiss_select; typedef WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; - auto rowEpilog_lambda = [m, n, numOfNN, out_dists, out_inds, mutexes] __device__( - IdxT gridStrideY) { - if (gridDim.x == 1) { return; } - - Pair* shDumpKV = nullptr; - if (useNorms) { - shDumpKV = (Pair*)(&smem[Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } - - const int lid = threadIdx.x % warpSize; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - - // 0 -> consumer done consuming the buffer. - // -1 -> consumer started consuming the buffer - // -2 -> producer done filling the buffer - // 1 -> prod acquired to fill the buffer - if (blockIdx.x == 0) { - auto cta_processed = 0; - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - __syncwarp(); - - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - - while (cta_processed < gridDim.x - 1) { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) - ; - } - __threadfence(); - __syncthreads(); + auto rowEpilog_lambda = + [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { + if (gridDim.x == 1) { return; } + + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + const int lid = threadIdx.x % warpSize; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + + // 0 -> consumer done consuming the buffer. 
+ // -1 -> consumer started consuming the buffer + // -2 -> producer done filling the buffer + // 1 -> prod acquired to fill the buffer + if (blockIdx.x == 0) { + auto cta_processed = 0; + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + __syncwarp(); + + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + + while (cta_processed < gridDim.x - 1) { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) + ; + } + __threadfence(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - otherKV.value = out_dists[rowId * numOfNN + idx]; - otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } } } } - } - __threadfence(); - __syncthreads(); + __threadfence(); + __syncthreads(); - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } - __threadfence(); + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } + __threadfence(); // Perform merging of otherKV with topk's across warp. 
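              // In the merge loop below, the owner block (blockIdx.x == 0) folds the
              // candidates staged in shDumpKV by each producer CTA into its per-row
              // warp-select queues via heapArr[i]->add(); the deferred reduce() and the
              // final store to global memory run only after all gridDim.x - 1 producer
              // CTAs have been consumed.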
#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } + heapArr[i]->add(otherKV.value, otherKV.key); } - heapArr[i]->add(otherKV.value, otherKV.key); } } + cta_processed++; } - cta_processed++; - } #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(0xffffffff, needSort); - if (needSort) { heapArr[i]->reduce(); } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(0xffffffff, needSort); + if (needSort) { heapArr[i]->reduce(); } + } } - } - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } else { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) - ; - } - __threadfence(); - __syncthreads(); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } else { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) + ; + } + __threadfence(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - for (int idx = lid; idx < numOfNN; idx += warpSize) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; - out_dists[rowId * numOfNN + idx] = KVPair.value; - out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + for (int idx = lid; idx < numOfNN; idx += warpSize) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; + out_dists[rowId * numOfNN + idx] = KVPair.value; + out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + } } } - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } - __threadfence(); - } - }; + __threadfence(); + __syncthreads(); - // epilogue operation lambda for final value calculation - auto epilog_lambda = [numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( - AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - if (useNorms) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < 
Policy::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } + __threadfence(); } - } + }; - Pair* shDumpKV = nullptr; - if (useNorms) { - constexpr size_t shmemSize = - Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - shDumpKV = (Pair*)(&smem[shmemSize]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } + // epilogue operation lambda for final value calculation + auto epilog_lambda = + [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + constexpr uint32_t mask = 0xffffffffu; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); + const int lid = raft::laneId(); - constexpr uint32_t mask = 0xffffffffu; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); - const int lid = raft::laneId(); - - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - if (usePrevTopKs) { - if (gridStrideX == blockIdx.x * Policy::Nblk) { - loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + if (usePrevTopKs) { + if (gridStrideX == blockIdx.x * Policy::Nblk) { + loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + } } - } - if (gridStrideX > blockIdx.x * Policy::Nblk) { + if (gridStrideX > blockIdx.x * Policy::Nblk) { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; - heapArr[i]->warpKTop = tempKV.value; - } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; + heapArr[i]->warpKTop = tempKV.value; + } - // total vals can atmost be 256, (32*8) - int numValsWarpTopK[Policy::AccRowsPerTh]; - int anyWarpTopKs = 0; + // total vals can atmost be 256, (32*8) + int numValsWarpTopK[Policy::AccRowsPerTh]; + int anyWarpTopKs = 0; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - numValsWarpTopK[i] = 0; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + numValsWarpTopK[i] = 0; + if (rowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + } } + 
anyWarpTopKs += numValsWarpTopK[i]; } - anyWarpTopKs += numValsWarpTopK[i]; } - } - anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); - if (anyWarpTopKs) { - Pair* allWarpTopKs = (Pair*)(&smem[0]); - uint32_t needScanSort[Policy::AccRowsPerTh]; + anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); + if (anyWarpTopKs) { + Pair* allWarpTopKs = (Pair*)(&smem[0]); + uint32_t needScanSort[Policy::AccRowsPerTh]; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - needScanSort[i] = 0; - if (gmemRowId < m) { - int myVals = numValsWarpTopK[i]; - needScanSort[i] = __ballot_sync(mask, myVals > 0); - if (needScanSort[i]) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + needScanSort[i] = 0; + if (gmemRowId < m) { + int myVals = numValsWarpTopK[i]; + needScanSort[i] = __ballot_sync(mask, myVals > 0); + if (needScanSort[i]) { #pragma unroll - for (unsigned int k = 1; k <= 16; k *= 2) { - const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); - if (lid >= k) { numValsWarpTopK[i] += n; } + for (unsigned int k = 1; k <= 16; k *= 2) { + const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); + if (lid >= k) { numValsWarpTopK[i] += n; } + } } + // As each thread will know its total vals to write. + // we only store its starting location. + numValsWarpTopK[i] -= myVals; } - // As each thread will know its total vals to write. - // we only store its starting location. - numValsWarpTopK[i] -= myVals; - } - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { - if (needScanSort[i] & ((uint32_t)1 << lid)) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { + if (needScanSort[i] & ((uint32_t)1 << lid)) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - Pair otherKV = {colId, acc[i][j]}; - allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; - numValsWarpTopK[i]++; + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + Pair otherKV = {colId, acc[i][j]}; + allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; + numValsWarpTopK[i]++; + } } } } + __syncwarp(); + const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); + loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); + updateSortedWarpQ( + heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } - __syncwarp(); - const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); - loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); - updateSortedWarpQ( - heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } } - } - __syncthreads(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { - storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { + 
storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + } } } } - } - } else { + } else { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - Pair otherKV = {keyMax, identity}; - if (colId < ldd) { - otherKV.value = acc[i][j]; - otherKV.key = colId; + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + Pair otherKV = {keyMax, identity}; + if (colId < ldd) { + otherKV.value = acc[i][j]; + otherKV.key = colId; + } + heapArr[i]->add(otherKV.value, otherKV.key); } - heapArr[i]->add(otherKV.value, otherKV.key); - } - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(mask, needSort); - if (needSort) { heapArr[i]->reduce(); } - storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(mask, needSort); + if (needSort) { heapArr[i]->reduce(); } + storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + } } } - } - if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { - // This is last iteration of grid stride X - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } - }; + if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { + // This is last iteration of grid stride X + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + }; - raft::distance::detail::PairwiseDistances + write_out> obj(x, y, m, @@ -521,9 +505,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x ldd, _xn, _yn, - nullptr, + nullptr, // output ptr, can be null as write_out == false. 
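                 // distance_op below takes the place of the removed core_op lambda: the
                 // operator object now supplies the per-element accumulation (and, for the
                 // expanded-L2 path, the norm-based epilog), so only the kNN-specific
                 // epilog and row-epilog lambdas are forwarded alongside it.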
smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -562,38 +546,32 @@ void fusedL2UnexpKnnImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = x - y; - acc += diff * diff; - }; - typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; if (numOfNN <= 32) { @@ -604,8 +582,10 @@ void fusedL2UnexpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); - dim3 grid = raft::distance::detail::launchConfigGenerator( + const auto sharedMemSize = + distance_op.template shared_mem_size() + KPolicy::Mblk * numOfNN * sizeof(Pair); + + dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); if (grid.x > 1) { @@ -628,9 +608,8 @@ void fusedL2UnexpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, (int*)workspace, out_dists, @@ -753,36 +732,33 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(workspace != nullptr, "workspace is null"); dim3 blk(KPolicy::Nthreads); - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; if (numOfNN <= 32) { @@ -793,9 +769,8 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + - ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) + - (KPolicy::Mblk * numOfNN * sizeof(Pair)); + const auto sharedMemSize = + distance_op.template shared_mem_size() + (KPolicy::Mblk * numOfNN * sizeof(Pair)); dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2ExpKnnRowMajor); int32_t* mutexes = nullptr; @@ -835,9 +810,8 @@ void fusedL2ExpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, mutexes, out_dists, diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index 8c48b87269..dc35b10063 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -15,25 +15,27 @@ */ #pragma once -namespace raft::arch { +#include // RAFT_CUDA_TRY -/* raft::arch provides the following facilities: +namespace raft::util::arch { + +/* raft::util::arch provides the following facilities: * - * - raft::arch::SM_XX : hardcoded compile-time constants for various compute - * architectures. 
The values raft::arch::SM_min and raft::arch::SM_future + * - raft::util::arch::SM_XX : hardcoded compile-time constants for various compute + * architectures. The values raft::util::arch::SM_min and raft::util::arch::SM_future * represent architectures that are always smaller and larger (respectively) * than any architecture that can be encountered in practice. * - * - raft::arch::SM_compute_arch : a compile-time value for the *current* + * - raft::util::arch::SM_compute_arch : a compile-time value for the *current* * compute architecture that a kernel is compiled with. It can only be used * inside kernels with a template argument. * - * - raft::arch::kernel_runtime_arch : a function that computes at *run-time* + * - raft::util::arch::kernel_runtime_arch : a function that computes at *run-time* * which version of a kernel will launch (i.e., it will return the compute * architecture of the version of the kernel that will be launched by the * driver). * - * - raft::arch::SM_range : a compile-time value to represent an open interval + * - raft::util::arch::SM_range : a compile-time value to represent an open interval * of compute architectures. This can be used to check if the current * compile-time architecture is in a specified compatibility range. */ @@ -46,9 +48,6 @@ struct SM_generic { public: __host__ __device__ constexpr int value() const { return n; } }; - -// A dummy kernel that is used to determine the runtime architecture. -__global__ inline void dummy_runtime_kernel() {} } // namespace detail // A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90) @@ -119,7 +118,7 @@ struct SM_runtime { inline SM_runtime kernel_runtime_arch(void* kernel) { cudaFuncAttributes attributes; - cudaFuncGetAttributes(&attributes, kernel); + RAFT_CUDA_TRY(cudaFuncGetAttributes(&attributes, kernel)); return SM_runtime(10 * attributes.ptxVersion); } @@ -143,4 +142,4 @@ struct SM_range { } }; -} // namespace raft::arch +} // namespace raft::util::arch diff --git a/cpp/include/raft/util/cuda_dev_essentials.cuh b/cpp/include/raft/util/cuda_dev_essentials.cuh new file mode 100644 index 0000000000..5080dc33ee --- /dev/null +++ b/cpp/include/raft/util/cuda_dev_essentials.cuh @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// This file provides a few essential functions for use in __device__ code. The +// scope is necessarily limited to ensure that compilation times are minimized. +// Please make sure not to include large / expensive files from here. + +namespace raft { + +/** helper macro for device inlined functions */ +#define DI inline __device__ +#define HDI inline __host__ __device__ +#define HD __host__ __device__ + +/** + * @brief Provide a ceiling division operation ie. ceil(a / b) + * @tparam IntType supposed to be only integers for now! 
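+ * For example, ceildiv(10, 4) == 3: covering 10 elements with blocks of 4
+ * threads requires 3 blocks.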
+ */ +template +constexpr HDI IntType ceildiv(IntType a, IntType b) +{ + return (a + b - 1) / b; +} + +/** + * @brief Provide an alignment function ie. ceil(a / b) * b + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr HDI IntType alignTo(IntType a, IntType b) +{ + return ceildiv(a, b) * b; +} + +/** + * @brief Provide an alignment function ie. (a / b) * b + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr HDI IntType alignDown(IntType a, IntType b) +{ + return (a / b) * b; +} + +/** + * @brief Check if the input is a power of 2 + * @tparam IntType data type (checked only for integers) + */ +template +constexpr HDI bool isPo2(IntType num) +{ + return (num && !(num & (num - 1))); +} + +/** + * @brief Give logarithm of the number to base-2 + * @tparam IntType data type (checked only for integers) + */ +template +constexpr HDI IntType log2(IntType num, IntType ret = IntType(0)) +{ + return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret); +} + +/** number of threads per warp */ +static const int WarpSize = 32; + +/** get the laneId of the current thread */ +DI int laneId() +{ + int id; + asm("mov.s32 %0, %%laneid;" : "=r"(id)); + return id; +} + +} // namespace raft diff --git a/cpp/include/raft/util/cuda_rt_essentials.hpp b/cpp/include/raft/util/cuda_rt_essentials.hpp new file mode 100644 index 0000000000..e5f3af4e61 --- /dev/null +++ b/cpp/include/raft/util/cuda_rt_essentials.hpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// This file provides a few essential functions that wrap the CUDA runtime API. +// The scope is necessarily limited to ensure that compilation times are +// minimized. Please make sure not to include large / expensive files from here. + +#include +#include + +namespace raft { + +/** + * @brief Exception thrown when a CUDA error is encountered. + */ +struct cuda_error : public raft::exception { + explicit cuda_error(char const* const message) : raft::exception(message) {} + explicit cuda_error(std::string const& message) : raft::exception(message) {} +}; + +} // namespace raft + +/** + * @brief Error checking macro for CUDA runtime API functions. 
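+ *
+ * Typical usage (a sketch; `dst`, `src`, `n_bytes` and `stream` are assumed to
+ * be in scope):
+ *   RAFT_CUDA_TRY(cudaMemcpyAsync(dst, src, n_bytes, cudaMemcpyDefault, stream));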
+ * + * Invokes a CUDA runtime API function call, if the call does not return + * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an + * exception detailing the CUDA error that occurred + * + */ +#define RAFT_CUDA_TRY(call) \ + do { \ + cudaError_t const status = call; \ + if (status != cudaSuccess) { \ + cudaGetLastError(); \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "CUDA error encountered at: ", \ + "call='%s', Reason=%s:%s", \ + #call, \ + cudaGetErrorName(status), \ + cudaGetErrorString(status)); \ + throw raft::cuda_error(msg); \ + } \ + } while (0) diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index 5be9dc999a..687a6b4651 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -23,113 +23,10 @@ #include #include #include - -#ifndef ENABLE_MEMCPY_ASYNC -// enable memcpy_async interface by default for newer GPUs -#if __CUDA_ARCH__ >= 800 -#define ENABLE_MEMCPY_ASYNC 1 -#endif -#else // ENABLE_MEMCPY_ASYNC -// disable memcpy_async for all older GPUs -#if __CUDA_ARCH__ < 800 -#define ENABLE_MEMCPY_ASYNC 0 -#endif -#endif // ENABLE_MEMCPY_ASYNC +#include namespace raft { -/** helper macro for device inlined functions */ -#define DI inline __device__ -#define HDI inline __host__ __device__ -#define HD __host__ __device__ - -/** - * @brief Provide a ceiling division operation ie. ceil(a / b) - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType ceildiv(IntType a, IntType b) -{ - return (a + b - 1) / b; -} - -/** - * @brief Provide an alignment function ie. ceil(a / b) * b - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType alignTo(IntType a, IntType b) -{ - return ceildiv(a, b) * b; -} - -/** - * @brief Provide an alignment function ie. (a / b) * b - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType alignDown(IntType a, IntType b) -{ - return (a / b) * b; -} - -/** - * @brief Check if the input is a power of 2 - * @tparam IntType data type (checked only for integers) - */ -template -constexpr HDI bool isPo2(IntType num) -{ - return (num && !(num & (num - 1))); -} - -/** - * @brief Give logarithm of the number to base-2 - * @tparam IntType data type (checked only for integers) - */ -template -constexpr HDI IntType log2(IntType num, IntType ret = IntType(0)) -{ - return num <= IntType(1) ? 
ret : log2(num >> IntType(1), ++ret); -} - -/** Device function to apply the input lambda across threads in the grid */ -template -DI void forEach(int num, L lambda) -{ - int idx = (blockDim.x * blockIdx.x) + threadIdx.x; - const int numThreads = blockDim.x * gridDim.x; -#pragma unroll - for (int itr = 0; itr < ItemsPerThread; ++itr, idx += numThreads) { - if (idx < num) lambda(idx, itr); - } -} - -/** number of threads per warp */ -static const int WarpSize = 32; - -/** get the laneId of the current thread */ -DI int laneId() -{ - int id; - asm("mov.s32 %0, %%laneid;" : "=r"(id)); - return id; -} - -/** - * @brief Swap two values - * @tparam T the datatype of the values - * @param a first input - * @param b second input - */ -template -HDI void swapVals(T& a, T& b) -{ - T tmp = a; - a = b; - b = tmp; -} - /** Device function to have atomic add support for older archs */ template DI void myAtomicAdd(Type* address, Type val) diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 0feb188ad8..0a7ca23028 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -25,6 +25,7 @@ #pragma once #include +#include #include #include #include @@ -40,42 +41,7 @@ #include #include #include - -namespace raft { - -/** - * @brief Exception thrown when a CUDA error is encountered. - */ -struct cuda_error : public raft::exception { - explicit cuda_error(char const* const message) : raft::exception(message) {} - explicit cuda_error(std::string const& message) : raft::exception(message) {} -}; - -} // namespace raft - -/** - * @brief Error checking macro for CUDA runtime API functions. - * - * Invokes a CUDA runtime API function call, if the call does not return - * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an - * exception detailing the CUDA error that occurred - * - */ -#define RAFT_CUDA_TRY(call) \ - do { \ - cudaError_t const status = call; \ - if (status != cudaSuccess) { \ - cudaGetLastError(); \ - std::string msg{}; \ - SET_ERROR_MSG(msg, \ - "CUDA error encountered at: ", \ - "call='%s', Reason=%s:%s", \ - #call, \ - cudaGetErrorName(status), \ - cudaGetErrorString(status)); \ - throw raft::cuda_error(msg); \ - } \ - } while (0) +#include // FIXME: Remove after consumers rename #ifndef CUDA_TRY diff --git a/cpp/include/raft/util/device_loads_stores.cuh b/cpp/include/raft/util/device_loads_stores.cuh index 2b87c44d60..c9bda26b81 100644 --- a/cpp/include/raft/util/device_loads_stores.cuh +++ b/cpp/include/raft/util/device_loads_stores.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ #pragma once -#include +#include // uintX_t +#include // DI namespace raft { diff --git a/cpp/src/distance/specializations/detail/00_write_template.py b/cpp/src/distance/specializations/detail/00_write_template.py new file mode 100644 index 0000000000..3f2f853569 --- /dev/null +++ b/cpp/src/distance/specializations/detail/00_write_template.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +template = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +INCLUDE_SM_HEADERS + +namespace raft::distance::detail { + +template void pairwise_matrix_instantiation_point( + OpT, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + +op_instances = [ + dict( + path_prefix="canberra", + OpT="ops::canberra_distance_op", + archs = [60], + ), + dict( + path_prefix="correlation", + OpT="ops::correlation_distance_op", + archs = [60], + ), + dict( + path_prefix="cosine", + OpT="ops::cosine_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="hamming_unexpanded", + OpT="ops::hamming_distance_op", + archs = [60], + ), + dict( + path_prefix="hellinger_expanded", + OpT="ops::hellinger_distance_op", + archs = [60], + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="ops::jensen_shannon_distance_op", + archs = [60], + ), + dict( + path_prefix="kl_divergence", + OpT="ops::kl_divergence_op", + archs = [60], + ), + dict( + path_prefix="l1", + OpT="ops::l1_distance_op", + archs = [60], + ), + dict( + path_prefix="l2_expanded", + OpT="ops::l2_exp_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="l2_unexpanded", + OpT="ops::l2_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="l_inf", + OpT="ops::l_inf_distance_op", + archs = [60], + ), + dict( + path_prefix="lp_unexpanded", + OpT="ops::lp_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="russel_rao", + OpT="ops::russel_rao_distance_op", + archs = [60], + ), +] + +def fill_in(s, template): + for k, v in template.items(): + s = s.replace(k, v) + return s + +def fill_include_sm_headers(op_instance): + include_headers ="\n".join([ + f"#include " + for arch in op_instance["archs"] + ]) + + return { + "path_prefix": op_instance["path_prefix"], + "OpT": op_instance["OpT"], + "INCLUDE_SM_HEADERS": include_headers + } + +for op_instance in op_instances: + op_instance = fill_include_sm_headers(op_instance) + + for data_type_instance in data_type_instances: + op_data_instance = { + k : fill_in(v, data_type_instance) + for k, v in op_instance.items() + } + instance = { + **op_data_instance, + **data_type_instance, + "FinopT": "decltype(raft::identity_op())", + } + + text = fill_in(template, instance) + + path = fill_in("path_prefix_DataT_AccT_OutT_IdxT.cu", instance) + with open(path, "w") as f: + f.write(text) diff --git a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu index 4e9e608792..037d218178 100644 --- a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu @@ -14,24 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu b/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu index 6dfc385e55..0ed8ea7bb0 100644 --- a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu b/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu index 2df77a4b5d..0c11f0621e 100644 --- a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu @@ -14,27 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu b/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu index 76ed00afa6..396e158554 100644 --- a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu b/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu index 3e0bcb92ed..e9afb6f563 100644 --- a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu @@ -14,26 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu b/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu index 23131ce2c7..1033c491d6 100644 --- a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu @@ -14,26 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu index b618fd024c..195115914d 100644 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu @@ -14,27 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu index 18e7aad9e9..a74c6c404e 100644 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu index 08ab20cfe5..bac1dd7bd0 100644 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu @@ -14,27 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu index 79eed075fb..77c113b1a9 100644 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu @@ -14,26 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu index ed84ee6dc4..188e52c152 100644 --- a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu @@ -14,25 +14,21 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu index a241af767c..b0afbf7bb2 100644 --- a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu @@ -14,25 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu index c4c944d123..f06ae85414 100644 --- a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu index aa1db5a837..00d5a5ee5b 100644 --- a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu index 391a1c2aa4..5c235316da 100644 --- a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu index 7b45e52ca1..fb293ca83d 100644 --- a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu index 8c5f746fa2..2c02f0224f 100644 --- a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu @@ -14,24 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu index c266125f98..85e25a25ca 100644 --- a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu @@ -14,25 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu deleted file mode 100644 index 399b120527..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu deleted file mode 100644 index 66de212b8e..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu deleted file mode 100644 index 562d93b2de..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { - -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu deleted file mode 100644 index 386bbafc5f..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu index 7733c3af48..5b4d995d14 100644 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu index 4ea18d31de..a63c3f0bb8 100644 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu index 74414f8fd6..831167523f 100644 --- a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu @@ -14,26 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu index e418fc455f..02e667cbe3 100644 --- a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu index 402cb51b7e..ebd71065ec 100644 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu index 7efe2b3349..b94a81fdce 100644 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu index b1e6f5e1f4..6f952fcc37 100644 --- a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu @@ -14,26 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu index 1e12bcd705..3223ce33a7 100644 --- a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 0e084f2ad8..438e212fbd 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -16,16 +16,24 @@ #include "../test_utils.cuh" #include -#include -#include -#include -#include -#include -#include +#include // common::nvtx::range + +#include // make_device_matrix_view +#include // raft::device_resources +#include // raft::sqrt +#include // raft::distance::DistanceType +#include +#include // rmm::device_uvector + +// When the distance library is precompiled, include only the raft_runtime +// headers. This way, a small change in one of the kernel internals does not +// trigger a rebuild of the test files (it of course still triggers a rebuild of +// the raft specializations) #if defined RAFT_COMPILED -#include +#include +#else +#include #endif -#include namespace raft { namespace distance { @@ -409,6 +417,25 @@ template return os; } +// TODO: Remove when mdspan-based raft::runtime::distance::pairwise_distance is +// implemented. +// +// Context: +// https://github.com/rapidsai/raft/issues/1338 +template +constexpr bool layout_to_row_major(); + +template <> +constexpr bool layout_to_row_major() +{ + return true; +} +template <> +constexpr bool layout_to_row_major() +{ + return false; +} + template void distanceLauncher(raft::device_resources const& handle, DataType* x, @@ -422,12 +449,23 @@ void distanceLauncher(raft::device_resources const& handle, DataType threshold, DataType metric_arg = 2.0f) { +#if defined RAFT_COMPILED + // TODO: Implement and use mdspan-based + // raft::runtime::distance::pairwise_distance here. 
+ // + // Context: + // https://github.com/rapidsai/raft/issues/1338 + bool row_major = layout_to_row_major(); + raft::runtime::distance::pairwise_distance( + handle, x, y, dist, m, n, k, distanceType, row_major, metric_arg); +#else auto x_v = make_device_matrix_view(x, m, k); auto y_v = make_device_matrix_view(y, n, k); auto dist_v = make_device_matrix_view(dist, m, n); raft::distance::distance( handle, x_v, y_v, dist_v, metric_arg); +#endif } template @@ -523,9 +561,25 @@ class BigMatrixDistanceTest : public ::testing::Test { auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + void pairwise_distance(raft::device_resources const& handle, + float* x, + float* y, + float* dists, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + constexpr bool row_major = true; + constexpr float metric_arg = 0.0f; +#if defined RAFT_COMPILED + raft::runtime::distance::pairwise_distance( + handle, x.data(), x.data(), dist.data(), m, n, k, distanceType, row_major, metric_arg); +#else raft::distance::distance( - handle, x.data(), x.data(), dist.data(), m, n, k, true, 0.0f); - + handle, x.data(), x.data(), dist.data(), m, n, k, row_major, metric_arg); +#endif RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 4a74d7f16a..383ad39319 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -182,22 +182,20 @@ class FusedL2NNTest : public ::testing::TestWithParam> { int m = params.m; int n = params.n; int k = params.k; - MinAndDistanceReduceOp redOp; - fusedL2NN, int>( - out, - x.data(), - y.data(), - xn.data(), - yn.data(), - m, - n, - k, - (void*)workspace.data(), - redOp, - raft::distance::KVPMinReduce(), - Sqrt, - true, - stream); + + const bool init_out_buffer = true; + fusedL2NNMinReduce, int>(out, + x.data(), + y.data(), + xn.data(), + yn.data(), + m, + n, + k, + (void*)workspace.data(), + Sqrt, + init_out_buffer, + stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } };
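
Note on the fused_l2_nn.cu hunk above: the test now calls fusedL2NNMinReduce instead of fusedL2NN, dropping the MinAndDistanceReduceOp and KVPMinReduce arguments that the old call constructed explicitly. A minimal sketch of what such a convenience wrapper is assumed to do follows; the helper name, the functor template parameters, and the header location are assumptions inferred from the removed lines, not the actual RAFT implementation.

#include <raft/distance/fused_l2_nn.cuh>  // fusedL2NN (assumed header location)

// Hedged sketch, not RAFT source: fusedL2NNMinReduce is assumed to forward to
// fusedL2NN and to default the two reduction functors that the old test code
// passed in explicitly. Argument order matches the new call in the diff above.
template <typename DataT, typename OutT, typename IdxT>
void fused_l2_nn_min_reduce_sketch(OutT* out, const DataT* x, const DataT* y,
                                   const DataT* xn, const DataT* yn,
                                   IdxT m, IdxT n, IdxT k, void* workspace,
                                   bool sqrt, bool init_out_buffer, cudaStream_t stream)
{
  // The removed lines built these functors before calling fusedL2NN; the
  // template parameters shown here are assumptions.
  raft::distance::MinAndDistanceReduceOp<IdxT, OutT> red_op;
  raft::distance::KVPMinReduce<IdxT, DataT> pair_red_op;
  raft::distance::fusedL2NN<DataT, OutT, IdxT>(out, x, y, xn, yn, m, n, k, workspace,
                                               red_op, pair_red_op, sqrt,
                                               init_out_buffer, stream);
}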