Reduce compile times of distance specializations #1307

Merged
Commits (110)
8d3e8a0  contractions: Concentrate tile index calculations (ahendriksen, Sep 2, 2022)
cb7baab  pairwise_distance_base: Remove all ldgXY(0) calls (ahendriksen, Sep 2, 2022)
066bf3b  pairwise_distance_base: Move all logic into run loop (ahendriksen, Sep 2, 2022)
a15d5fc  pairwise_distance_base: Fix typo (ahendriksen, Oct 5, 2022)
71c6da6  Remove deprecated header (ahendriksen, Jan 11, 2023)
4bbedf6  Replace lambdas by raft::void_op (ahendriksen, Jan 12, 2023)
c3d1f6e  Use an operator for L1 distance (ahendriksen, Jan 12, 2023)
3e3478b  Add launch function (ahendriksen, Jan 12, 2023)
264a9d2  l1: Replace run-time -> compile-time dispatch (ahendriksen, Jan 13, 2023)
b232057  pairwise matrix: move files into subdirectories (ahendriksen, Jan 13, 2023)
06f6ffa  pairwise matrix: Untangle dispatching and kernel template parameters (ahendriksen, Jan 13, 2023)
2f41faa  l2 unexp: Use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
7938614  l2 exp: Use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
7afe6cc  Add template for distance operator (ahendriksen, Jan 13, 2023)
5fe3292  Reenable cutlass-based kernels for CUDA 12.0 (ahendriksen, Jan 13, 2023)
c623332  pairwise matrix l2: Add support for CUTLASS kernels (ahendriksen, Jan 13, 2023)
27511fc  Canberra: use dispatching mechanism (ahendriksen, Jan 13, 2023)
58ce6f8  Chebyshev: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
d397c17  Correlation: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
7005a4f  Hamming: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
7831deb  Hellinger: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
4dc72ce  Jensen-Shannon: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
b0d36c1  remove old hamming code (ahendriksen, Jan 13, 2023)
e95a65b  KL divergence: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
f1c105b  Minkowski: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
ac66e3f  Russel-Rao: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
a89896a  Cosine: use pairwise matrix dispatch (ahendriksen, Jan 13, 2023)
16b2acd  Fix include for l1 op (ahendriksen, Jan 13, 2023)
1326e34  kl_divergence: Use raft::log instead of raft::myLog (ahendriksen, Feb 10, 2023)
0169b26  distance_op: Add expensive_inner_loop marker (ahendriksen, Feb 10, 2023)
52e95e1  Update copyright notices (ahendriksen, Feb 10, 2023)
28cd57b  Reusable dispatch mechanism (ahendriksen, Feb 10, 2023)
c44aece  Dispatch mechanism using switch statement (ahendriksen, Feb 10, 2023)
7c3bd76  Remove one ".template" from kernel_sm60 (ahendriksen, Feb 10, 2023)
d62eeb7  Dispatch on veclen instead of byte_alignment (ahendriksen, Feb 10, 2023)
5c3dcaf  Use many template parameters again (ahendriksen, Feb 20, 2023)
2613e8a  Remove duplicate DistanceType enum definition (ahendriksen, Feb 20, 2023)
62ed53a  Remove pairwiseDistanceMatKernel (ahendriksen, Feb 20, 2023)
c334ba3  Remove distance::detail::pairwise_distance_impl (ahendriksen, Feb 20, 2023)
8e43238  distance_ops: Include cuda_utils.cuh (ahendriksen, Feb 21, 2023)
e176351  Replace DistanceImpl with method overloads (ahendriksen, Feb 21, 2023)
6ddd14f  Remove impl files and move doc strings (ahendriksen, Feb 21, 2023)
34ccddc  Update readme (ahendriksen, Feb 21, 2023)
b27cdca  Merge branch 'rapids/branch-23.04' into wip-refactor-distance (ahendriksen, Feb 21, 2023)
6a12ded  Reenable device code generation (ahendriksen, Feb 21, 2023)
486393e  Readd overload of raft::distance::detail::distance (ahendriksen, Feb 21, 2023)
ca29e2d  Fix style (ahendriksen, Feb 21, 2023)
28c95a1  Fix 11.8 compilation error (ahendriksen, Feb 22, 2023)
a5592b9  Rename minkowski -> lp_unexp (ahendriksen, Feb 22, 2023)
265ba07  Rename Chebyshev -> l_inf (ahendriksen, Feb 22, 2023)
7ccb8a7  Rename euc -> l2 (ahendriksen, Feb 22, 2023)
874d014  Update copyright headers (ahendriksen, Feb 22, 2023)
757fb44  Remove misleading note about workspace nullptr (ahendriksen, Feb 22, 2023)
d6e9261  Remove notes file (ahendriksen, Feb 22, 2023)
885bda6  Put template on struct instead of methods (ahendriksen, Feb 22, 2023)
cd38ec6  Fix style (ahendriksen, Feb 22, 2023)
749d000  Add dispatch based on compute architecture (ahendriksen, Feb 22, 2023)
7262861  Fix style (ahendriksen, Feb 22, 2023)
e7a8e89  Merge branch 'branch-23.04' into wip-refactor-distance (cjnolet, Feb 22, 2023)
1ef8520  Merge remote-tracking branch 'rapids/pull-request/1142' into enh-arch… (ahendriksen, Feb 23, 2023)
09a3050  Fix linker error: multiple definition.. (ahendriksen, Mar 6, 2023)
6467221  Update cpp/include/raft/distance/detail/distance_ops/canberra.cuh (ahendriksen, Mar 6, 2023)
a83461e  Update cpp/include/raft/distance/detail/distance.cuh (ahendriksen, Mar 6, 2023)
393edf3  Add note about alignment in case of byte input (ahendriksen, Mar 6, 2023)
48a0c21  Fix (ahendriksen, Mar 7, 2023)
f8daf48  Merge remote-tracking branch 'rapids/pull-request/1142' into enh-arch… (ahendriksen, Mar 7, 2023)
1a6636f  Implement review feedback (ahendriksen, Mar 7, 2023)
e696f24  Merge branch 'branch-23.04' into enh-arch-dispatch (ahendriksen, Mar 10, 2023)
3164802  Determine runtime arch using kernel pointer (ahendriksen, Mar 13, 2023)
35014e6  Fix Gram compilation error (ahendriksen, Mar 14, 2023)
1516198  Reformat comments (ahendriksen, Mar 14, 2023)
37cdc51  Merge branch 'branch-23.04' into enh-arch-dispatch (ahendriksen, Mar 15, 2023)
e399afa  Fix kl_divergence index type (ahendriksen, Mar 14, 2023)
f738d0d  Remove spurious includes from pairwise_distance_base (ahendriksen, Mar 14, 2023)
fa09bf7  Instantiate kernel launch code (ahendriksen, Mar 14, 2023)
cf5b236  Add instantiation point (ahendriksen, Mar 14, 2023)
da2eb69  Add *_essentials headers (ahendriksen, Mar 14, 2023)
9370a29  Decouple test and pairwise distance code (ahendriksen, Mar 14, 2023)
14a9477  Take distance_op in pairwise_distance_base (ahendriksen, Mar 14, 2023)
f54e7a4  Add tuning benchmark for pairwise distances (ahendriksen, Mar 15, 2023)
35a2ad4  Limit loop unrolling for expensive distance ops (ahendriksen, Mar 15, 2023)
a8d98ca  Merge branch 'branch-23.04' into enh-arch-dispatch (cjnolet, Mar 15, 2023)
9bd7a83  Fix column major errors on SM80 (ahendriksen, Mar 16, 2023)
c9ab1b8  Fix col major errors on SM80 (ahendriksen, Mar 16, 2023)
584897f  Merge remote-tracking branch 'rapids/branch-23.04' into enh-arch-disp… (ahendriksen, Mar 16, 2023)
1581afa  Merge branch 'enh-arch-dispatch' into enh-reduce-compile-times-specia… (ahendriksen, Mar 16, 2023)
5d1f6c2  Use raft::util::arch namespace (ahendriksen, Mar 16, 2023)
30c3391  Fix build failure (ahendriksen, Mar 17, 2023)
9eaf9b5  Fix spelling (ahendriksen, Mar 17, 2023)
8136afa  Merge branch 'enh-arch-dispatch' into enh-reduce-compile-times-specia… (ahendriksen, Mar 17, 2023)
2b3b203  Fix pairwise tune benchmark (ahendriksen, Mar 18, 2023)
3b686a9  Fix compilation error pairwise tuning bench (ahendriksen, Mar 20, 2023)
e5eb772  Implement reviewer feedback (ahendriksen, Mar 20, 2023)
7f58194  Merge remote-tracking branch 'rapids/branch-23.04' into enh-specializ… (ahendriksen, Mar 20, 2023)
a5a3629  Fix merge (ahendriksen, Mar 20, 2023)
c2970ba  tune_distance: Enable changing distance op without recompile (ahendriksen, Mar 20, 2023)
6c0d944  Use std::declval (ahendriksen, Mar 20, 2023)
05a4743  Merge remote-tracking branch 'rapids/branch-23.04' into enh-specializ… (ahendriksen, Mar 21, 2023)
1df17be  distance_ops: Use static shared_mem_size (ahendriksen, Mar 21, 2023)
6f4e77d  Pinning dask temporarily because a recent commit broke things (cjnolet, Mar 21, 2023)
9931ce3  Merge branch 'branch-23.04' into build-2304-pin_dask (cjnolet, Mar 21, 2023)
7958a32  Updating raft-dask recipe for now. Not yet able to fix the issue w/ t… (cjnolet, Mar 22, 2023)
857710e  Merge branch 'build-2304-pin_dask' of github.com:cjnolet/raft into bu… (cjnolet, Mar 22, 2023)
490c844  Merge branch 'branch-23.04' into enh-specializations-reduce-compile-t… (ahendriksen, Mar 22, 2023)
e2f1aa2  Pinning dask for wheel (cjnolet, Mar 22, 2023)
2991617  Merge remote-tracking branch 'rapids/pull-request/1363' into enh-spec… (ahendriksen, Mar 22, 2023)
d57bca4  Revert "Pinning dask for wheel" (cjnolet, Mar 22, 2023)
b9383e1  Revert "Updating raft-dask recipe for now. Not yet able to fix the is… (cjnolet, Mar 22, 2023)
4d16e5a  Revert "Pinning dask temporarily because a recent commit broke things" (cjnolet, Mar 22, 2023)
3a86004  Merge branch 'branch-23.04' into enh-specializations-reduce-compile-t… (cjnolet, Mar 22, 2023)
4 changes: 2 additions & 2 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -19,8 +19,8 @@ dependencies:
 - cxx-compiler
 - cython>=0.29,<0.30
 - dask-cuda=23.04
-- dask>=2023.1.1
-- distributed>=2023.1.1
+- dask<=2023.3.1
+- distributed>=2023.3.1
 - doxygen>=1.8.20
 - gcc_linux-64=11.*
 - graphviz
2 changes: 1 addition & 1 deletion conda/recipes/raft-dask/meta.yaml
@@ -46,7 +46,7 @@ requirements:
   run:
     - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }}
     - cuda-python >=11.7.1,<12.0
-    - dask >=2023.1.1
+    - dask <=2023.3.1
     - dask-cuda ={{ minor_version }}
     - distributed >=2023.1.1
     - joblib >=0.11
4 changes: 0 additions & 4 deletions cpp/CMakeLists.txt
@@ -312,10 +312,6 @@ if(RAFT_COMPILE_LIBRARY)
     src/distance/specializations/detail/l1_double_double_double_int.cu
     src/distance/specializations/detail/l2_expanded_float_float_float_int.cu
     src/distance/specializations/detail/l2_expanded_double_double_double_int.cu
-    src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu
-    src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu
-    src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu
-    src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
     src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
     src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
     src/distance/specializations/detail/l_inf_double_double_double_int.cu
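The four l2_sqrt_* translation units are deleted outright. A plausible reading, consistent with the operator commits above ("Use an operator for L1 distance", "l2 exp: Use pairwise matrix dispatch"), is that the sqrt variant becomes a runtime flag on a single L2 operator instead of a separate template instantiation, so one .cu file covers both. A minimal sketch of that pattern, with hypothetical names:

// Sketch only: fold the sqrt variant into one operator instead of a
// separate template instantiation per variant (names hypothetical).
struct l2_op {
  bool perform_sqrt;  // runtime flag replaces the l2_sqrt_* instantiations

  __device__ float finalize(float acc) const
  {
    return perform_sqrt ? sqrtf(acc) : acc;
  }
};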
5 changes: 5 additions & 0 deletions cpp/bench/CMakeLists.txt
@@ -72,6 +72,11 @@ if(BUILD_BENCH)
     OPTIONAL LIB
   )

+  ConfigureBench(
+    NAME TUNE_DISTANCE PATH bench/distance/tune_pairwise/kernel.cu
+    bench/distance/tune_pairwise/bench.cu bench/main.cpp
+  )
+
   ConfigureBench(
     NAME
     DISTANCE_BENCH
151 changes: 151 additions & 0 deletions cpp/bench/distance/tune_pairwise/bench.cu
@@ -0,0 +1,151 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Tuning benchmarks.
//
// Goals:
//
// 1. Fast compile times to maintain iteration speed.
// 2. Create benchmarks that can inform the design of the kernels.
//
// Non-goals:
//
// 1. Measure every distance operation. Instead, only one distance operation
//    is measured at a time.
// 2. Be useful for finding performance regressions. This is handled by the
// normal benchmarks.
//
// So far, both goals are partly achieved.
//
// RE (1), COMPILE TIMES: kernel.cu is fast to compile. This file is not.
// When the internals of a pairwise distance kernel are changed, this file does
// not have to be recompiled.
//
// RE (2), BENCHMARKS WITH INTENT: this file contains a benchmark to check the
// maximal throughput of a kernel. Measuring other things, like performance on
// skinny or wide matrices, is not yet implemented.

#include "kernel.cuh" // launch_kernel
#include <algorithm> // std::min
#include <common/benchmark.hpp> // RAFT_BENCH_REGISTER
#include <raft/distance/detail/pairwise_matrix/params.cuh> // pairwise_matrix_params
#include <rmm/device_uvector.hpp> // rmm::device_uvector
#include <vector> // std::vector

namespace raft::bench::distance::tune {

// Max throughput benchmark.
//
// Goal: Measure the maximum distances/sec that can be computed.
//
// To achieve this, we make sure that:
//
// - Input data size is a multiple of the block tile size.
//
// - Perfect distribution of work between SMs, i.e. the number of block tiles is
// a large multiple (num_waves) of the number of blocks (#SMs * occupancy).
//
// - Multiple iterations over Kblk are executed (num_k_iters).
struct throughput_param {
int num_waves;
int occupancy;
int num_k_iters;
};

const std::vector<throughput_param> throughput_params{
// 32 waves, requested occupancy of 4, and 32 k iterations typically achieves
// maximum throughput. No need to pick higher values.
{32, 4, 32},
};

struct throughput_bench : public fixture {
const throughput_param p;

throughput_bench(const throughput_param& p_) : p(p_) {}

void run_benchmark(::benchmark::State& state) override
{
// Get block size:
int block_m, block_n, block_k;
get_block_size(block_m, block_n, block_k);

// Determine number of blocks that will be launched. This informs the size
// of the inputs as well as the grid size.
const int num_sms = raft::getMultiProcessorCount();
const int max_occupancy = get_max_occupancy();
const int occupancy = std::min(p.occupancy, max_occupancy);
const int num_blocks = occupancy * num_sms;
dim3 grid(num_blocks);

// Create input sizes that are a multiple of the block tile size.
size_t m = block_m;
size_t n = block_n * p.num_waves * num_blocks;
size_t k = block_k * p.num_k_iters;

// DataT, OutT, IdxT, etc. are defined in kernel.cuh
rmm::device_uvector<DataT> x_vec(m * k, stream);
rmm::device_uvector<DataT> y_vec(n * k, stream);
rmm::device_uvector<DataT> x_norm_vec(m, stream);
rmm::device_uvector<DataT> y_norm_vec(n, stream);
rmm::device_uvector<OutT> out_vec(m * n, stream);

auto x = x_vec.data();
auto y = y_vec.data();
auto x_norm = x_norm_vec.data();
auto y_norm = y_norm_vec.data();
auto out = out_vec.data();
FinOpT fin_op{};

// Create kernel parameter struct. Flip x and y if column major.
IdxT ldx = row_major ? k : m;
IdxT ldy = row_major ? k : n;
IdxT ld_out = row_major ? n : m;

// Template parameters of pairwise_matrix_params are defined in kernel.cuh
pairwise_matrix_params kparams{
IdxT(m), IdxT(n), IdxT(k), ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, row_major};

// Run benchmark
loop_on_state(state, [&]() { launch_kernel(kparams, grid, stream); });

// Report metrics. We don't report flop/s because we do not know how many
// flops each distance operation costs. For l2_unexp and l1, doubling this
// number gives the flop/s. For l2 expanded, core_ops/s should equal flop/s
// (modulo the sqrt and subtracting from the norm).
size_t num_core_ops = m * n * k;
size_t read_elts = n * k + m * k;
size_t write_elts = m * n;

state.counters["m"] = benchmark::Counter(m);
state.counters["n"] = benchmark::Counter(n);
state.counters["k"] = benchmark::Counter(k);
state.counters["occupancy"] = benchmark::Counter(occupancy);
state.counters["# waves"] = benchmark::Counter(p.num_waves);
state.counters["# k iters"] = benchmark::Counter(p.num_k_iters);

state.counters["core_ops/s"] = benchmark::Counter(num_core_ops,
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);

state.counters["BW"] = benchmark::Counter(write_elts * sizeof(OutT) + read_elts * sizeof(DataT),
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);
}
};

RAFT_BENCH_REGISTER(throughput_bench, "", throughput_params);

} // namespace raft::bench::distance::tune
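As a cross-check of the sizing and counter logic above, here is a small standalone sketch that plugs in illustrative numbers. The tile shape and SM count are assumptions (e.g. a 108-SM A100 at the requested occupancy of 4), not values taken from this PR:

#include <cstddef>  // size_t
#include <cstdio>   // printf

int main()
{
  // Assumed for illustration only: tile shape and GPU configuration.
  int block_m = 32, block_n = 32, block_k = 16;         // hypothetical tile
  int num_sms = 108, occupancy = 4;                     // hypothetical GPU
  int num_waves = 32, num_k_iters = 32;                 // as in throughput_params

  int num_blocks = occupancy * num_sms;                 // 432 blocks in flight
  size_t m = block_m;                                   // one tile tall
  size_t n = size_t(block_n) * num_waves * num_blocks;  // full waves of tiles
  size_t k = size_t(block_k) * num_k_iters;             // many K iterations

  size_t num_core_ops = m * n * k;                      // one core op per (i, j, l)
  size_t read_elts    = n * k + m * k;                  // both input matrices
  size_t write_elts   = m * n;                          // output matrix

  printf("m=%zu n=%zu k=%zu core_ops=%zu read=%zu write=%zu\n",
         m, n, k, num_core_ops, read_elts, write_elts);
  return 0;
}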
88 changes: 88 additions & 0 deletions cpp/bench/distance/tune_pairwise/kernel.cu
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel.cuh"
#include <raft/distance/detail/pairwise_matrix/kernel_sm60.cuh> // pairwise_matrix_sm60_wrapper
#include <raft/linalg/contractions.cuh> // raft::linalg::Policy4x4
#include <raft/util/arch.cuh> // raft::util::arch::SM_compute_arch

namespace raft::bench::distance::tune {

// Distance op
using OpT = raft::distance::detail::ops::lp_unexp_distance_op<DataT, AccT, IdxT>;
constexpr float metric_arg = 2.0;
OpT distance_op{metric_arg};

// Kernel policy
constexpr int vec_len = 1;
using Policy = typename raft::linalg::Policy4x4<DataT, vec_len>::Policy;

// Architecture
namespace arch = raft::util::arch;
constexpr auto sm_compat_range = arch::SM_range(arch::SM_min(), arch::SM_future());

void launch_kernel(pairwise_matrix_params params, dim3 grid, cudaStream_t stream)
{
dim3 block(Policy::Nthreads);
int smem_size = OpT::shared_mem_size<Policy>();

// Obtain function pointer to kernel
auto kernel = raft::distance::detail::pairwise_matrix_kernel<Policy,
row_major,
decltype(sm_compat_range),
OpT,
IdxT,
DataT,
OutT,
FinOpT>;

kernel<<<grid, block, smem_size, stream>>>(distance_op, params);
RAFT_CUDA_TRY(cudaGetLastError());
}

void get_block_size(int& m, int& n, int& k)
{
m = Policy::Mblk;
n = Policy::Nblk;
k = Policy::Kblk;
}

void* get_kernel_ptr()
{
auto kernel = raft::distance::detail::pairwise_matrix_kernel<Policy,
row_major,
decltype(sm_compat_range),
OpT,
IdxT,
DataT,
OutT,
FinOpT>;
return reinterpret_cast<void*>(kernel);
}

int get_max_occupancy()
{
void* kernel_ptr = get_kernel_ptr();
int max_occupancy;
int smem_size = OpT::shared_mem_size<Policy>();

RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy, kernel_ptr, Policy::Nthreads, smem_size));

return max_occupancy;
}

} // namespace raft::bench::distance::tune
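The kernel.cu / kernel.cuh split is the classic separate-compilation trick: the expensive template instantiation lives behind a small template-free interface, so editing the benchmark driver never rebuilds the kernel. A stripped-down sketch of the same pattern, with hypothetical file and function names:

// my_kernel.cuh (hypothetical): small, template-free interface.
#include <cuda_runtime.h>
void launch_scale(const float* in, float* out, int n, cudaStream_t stream);

// my_kernel.cu (hypothetical): the only translation unit that instantiates
// the template, so callers can change without triggering a kernel rebuild.
template <int VecLen>  // stands in for the heavy policy/op parameters
__global__ void scale_kernel(const float* in, float* out, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { out[i] = 2.0f * in[i]; }
}

void launch_scale(const float* in, float* out, int n, cudaStream_t stream)
{
  int threads = 256;
  int blocks  = (n + threads - 1) / threads;
  scale_kernel<1><<<blocks, threads, 0, stream>>>(in, out, n);
}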
44 changes: 44 additions & 0 deletions cpp/bench/distance/tune_pairwise/kernel.cuh
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/detail/distance_ops/all_ops.cuh> // lp_unexp_distance_op
#include <raft/distance/detail/pairwise_matrix/params.cuh> // pairwise_matrix_params

namespace raft::bench::distance::tune {

// Launch one specific kernel with the following template parameters
constexpr bool row_major = true;
using DataT = float;
using AccT = float;
using OutT = DataT;
using IdxT = int;

using FinOpT = raft::identity_op;

using pairwise_matrix_params =
raft::distance::detail::pairwise_matrix_params<IdxT, DataT, OutT, FinOpT>;

// Launches kernel
void launch_kernel(pairwise_matrix_params, dim3, cudaStream_t);

// Reports, via its out-parameters, the block size decided by the policy
void get_block_size(int& m, int& n, int& k);

int get_max_occupancy();

} // namespace raft::bench::distance::tune
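Retuning a different configuration only requires touching the constants in kernel.cu and this header; e.g., probing wider vectorized loads or a different Minkowski exponent could look like the sketch below (assuming Policy4x4 provides a specialization for vec_len = 4):

// In kernel.cu (sketch): try 4-wide vectorized loads instead of 1-wide.
constexpr int vec_len = 4;  // assumes Policy4x4<DataT, 4> exists
using Policy = typename raft::linalg::Policy4x4<DataT, vec_len>::Policy;

// Or probe a different exponent for the lp_unexp distance op:
constexpr float metric_arg = 3.0;
OpT distance_op{metric_arg};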
2 changes: 1 addition & 1 deletion cpp/include/raft/core/kvp.hpp
@@ -20,7 +20,7 @@

 #ifdef _RAFT_HAS_CUDA
 #include <cub/cub.cuh>
-#include <raft/util/cuda_utils.cuh>
+#include <raft/util/cuda_utils.cuh>  // raft::shfl_xor
 #endif
 namespace raft {
 /**