Skip to content

Commit

Permalink
Add tuning benchmark for pairwise distances
Browse files Browse the repository at this point in the history
  • Loading branch information
ahendriksen committed Mar 15, 2023
1 parent 14a9477 commit f54e7a4
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 1 deletion.
5 changes: 5 additions & 0 deletions cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ if(BUILD_BENCH)
OPTIONAL DIST NN
)

# Tuning benchmark for pairwise distances. Built from the kernel wrapper
# (kernel.cu), the benchmark definition (bench.cu), and bench/main.cpp.
ConfigureBench(
NAME TUNE_DISTANCE PATH bench/distance/tune_pairwise/kernel.cu
bench/distance/tune_pairwise/bench.cu bench/main.cpp
)

ConfigureBench(
NAME
DISTANCE_BENCH
Expand Down
146 changes: 146 additions & 0 deletions cpp/bench/distance/tune_pairwise/bench.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Tuning benchmarks.
//
// Goals:
//
// 1. Fast compile times to maintain iteration speed.
// 2. Create benchmarks that can inform the design of the kernels.
//
// Non-goals:
//
// 1. Measure every distance operation. Instead, this benchmark measures just
// one distance operation at a time.
// 2. Be useful for finding performance regressions. This is handled by the
// normal benchmarks.
//
// So far, both goals are partly achieved.
//
// RE (1), COMPILE TIMES: kernel.cu is fast to compile. This file is not.
// When the internals of a pairwise distance kernel are changed, this file is
// not recompiled.
//
// RE (2), BENCHMARKS WITH INTENT: this file contains a benchmark to check the
// maximal throughput of a kernel. Measuring other things, like performance on
// skinny or wide matrices, is not yet implemented.

#include "kernel.cuh" // launch_kernel
#include <algorithm> // std::min
#include <common/benchmark.hpp> // RAFT_BENCH_REGISTER
#include <raft/distance/detail/pairwise_matrix/params.cuh> // pairwise_matrix_params
#include <rmm/device_uvector.hpp> // rmm::device_uvector
#include <vector> // std::vector

namespace raft::bench::distance::tune {

// Max throughput benchmark.
//
// Goal: Measure the maximum distances/sec that can be computed.
//
// To achieve this, we make sure that:
//
// - Input data size is a multiple of the block tile size.
//
// - Perfect distribution of work between SMs, i.e. the number of block tiles is
// a large multiple (num_waves) of the number of blocks (#SMs * occupancy).
//
// - Multiple iterations over Kblk are executed (num_k_iters).
// One configuration of the max-throughput benchmark. The field order matters:
// entries of throughput_params are brace-initialized positionally.
struct throughput_param {
// Multiplier applied to the number of launched blocks when sizing n
// (n = Nblk * num_waves * num_blocks).
int num_waves;
// Requested number of blocks per SM. The benchmark uses the minimum of this
// value and the maximum occupancy reported for the kernel.
int occupancy;
// Multiplier applied to Kblk when sizing k (k = Kblk * num_k_iters).
int num_k_iters;
};

const std::vector<throughput_param> throughput_params{
// 32 waves, requested occupancy of 4, and 32 k iterations typically achieves
// maximum throughput. No need to pick higher values.
{32, 4, 32},
};

// Max-throughput benchmark fixture.
//
// Sizes the inputs so that every launched block processes exactly
// `p.num_waves` full tiles and `p.num_k_iters` iterations over Kblk, launches
// the kernel through launch_kernel(), and reports distances/s and bandwidth.
struct throughput_bench : public fixture {
  // Configuration of this benchmark instance.
  const throughput_param p;

  throughput_bench(const throughput_param& p_) : p(p_) {}

  // Runs the benchmark body.
  //
  // If the effective occupancy is not positive (either the requested value or
  // the value reported for the kernel), the benchmark is skipped with an error
  // instead of launching an empty grid.
  void run_benchmark(::benchmark::State& state) override
  {
    // Get block size:
    int block_m, block_n, block_k;
    get_block_size(block_m, block_n, block_k);

    // Determine number of blocks that will be launched. This informs the size
    // of the inputs as well as the grid size.
    const int num_sms       = raft::getMultiProcessorCount();
    const int max_occupancy = get_max_occupancy(distance_op);
    const int occupancy     = std::min(p.occupancy, max_occupancy);
    if (occupancy <= 0) {
      // A zero-sized grid is an invalid launch configuration; report the
      // actual cause rather than failing inside the kernel launch.
      state.SkipWithError(
        "Requested or achievable occupancy is not positive; cannot size the launch grid");
      return;
    }
    const int num_blocks = occupancy * num_sms;
    dim3 grid(num_blocks);

    // Create input sizes that are a multiple of the block tile size. The
    // products are evaluated in size_t so that large tuning parameters cannot
    // overflow int before the result is widened.
    const size_t m = static_cast<size_t>(block_m);
    const size_t n = static_cast<size_t>(block_n) * p.num_waves * num_blocks;
    const size_t k = static_cast<size_t>(block_k) * p.num_k_iters;

    // DataT, OutT, IdxT, etc, are defined in kernel.cuh
    rmm::device_uvector<DataT> x_vec(m * k, stream);
    rmm::device_uvector<DataT> y_vec(n * k, stream);
    rmm::device_uvector<DataT> x_norm_vec(m, stream);
    rmm::device_uvector<DataT> y_norm_vec(n, stream);
    rmm::device_uvector<OutT> out_vec(m * n, stream);

    auto x      = x_vec.data();
    auto y      = y_vec.data();
    auto x_norm = x_norm_vec.data();
    auto y_norm = y_norm_vec.data();
    auto out    = out_vec.data();
    FinOpT fin_op{};

    // For column-major inputs the roles of x and y (and of m and n) are
    // swapped when building the kernel parameters.
    // NOTE(review): m, n, k are size_t here; make_params presumably takes IdxT,
    // so sizes must fit in IdxT — confirm against params.cuh.
    auto make_params = raft::distance::detail::make_params<IdxT, DataT, OutT, FinOpT>;
    pairwise_matrix_params kparams =
      row_major ? make_params(m, n, k, x, y, x_norm, y_norm, out, fin_op, row_major)
                : make_params(n, m, k, y, x, y_norm, x_norm, out, fin_op, row_major);

    // Run benchmark
    loop_on_state(state, [&]() { launch_kernel(distance_op, kparams, grid, stream); });

    // Report metrics. We don't report flop/s because we do not know for each
    // distance operation how many flops it costs. For L2_unexp and l1, we can
    // double this number to get the flop/s. For l2 expanded, dist/s should
    // equal flop/s (modulo the sqrt and subtracting from the norm).
    size_t num_dists  = m * n * k;
    size_t read_elts  = n * k + m * k;
    size_t write_elts = m * n;

    state.counters["m"]         = benchmark::Counter(m);
    state.counters["n"]         = benchmark::Counter(n);
    state.counters["k"]         = benchmark::Counter(k);
    state.counters["occupancy"] = benchmark::Counter(occupancy);
    state.counters["# waves"]   = benchmark::Counter(p.num_waves);
    state.counters["# k iters"] = benchmark::Counter(p.num_k_iters);

    state.counters["dist/s"] = benchmark::Counter(
      num_dists, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000);

    state.counters["BW"] = benchmark::Counter(write_elts * sizeof(OutT) + read_elts * sizeof(DataT),
                                              benchmark::Counter::kIsIterationInvariantRate,
                                              benchmark::Counter::OneK::kIs1000);
  }
};

// Register throughput_bench with the parameter sets listed in throughput_params.
// NOTE(review): the second argument is an empty string, presumably a benchmark
// name suffix — confirm against RAFT_BENCH_REGISTER in common/benchmark.hpp.
RAFT_BENCH_REGISTER(throughput_bench, "", throughput_params);

} // namespace raft::bench::distance::tune
85 changes: 85 additions & 0 deletions cpp/bench/distance/tune_pairwise/kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel.cuh"
#include <raft/distance/detail/pairwise_matrix/kernel_sm60.cuh> // pairwise_matrix_sm60_wrapper
#include <raft/linalg/contractions.cuh> // raft::linalg::Policy4x4
#include <raft/util/arch.cuh> // raft::arch::SM_compute_arch

namespace raft::bench::distance::tune {

// Vector length handed to Policy4x4 when selecting the kernel policy.
constexpr int vec_len = 1;
// Kernel policy. In this file it supplies the thread count (Nthreads) and the
// block tile sizes (Mblk, Nblk, Kblk).
using Policy = typename raft::linalg::Policy4x4<DataT, vec_len>::Policy;
// SM range spanning SM_min() to SM_future(). Only its type is used, as a
// template argument of pairwise_matrix_kernel.
constexpr auto sm_compat_range =
raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future());

// Launches the pairwise distance kernel selected by the aliases in kernel.cuh.
//
// distance_op: distance operation passed to the kernel; also determines the
//              dynamic shared memory size of the launch.
// params:      kernel parameters, forwarded to the kernel unchanged.
// grid:        grid dimensions of the launch.
// stream:      CUDA stream on which the kernel is enqueued.
//
// The launch is asynchronous. Only launch errors are checked here, via
// cudaGetLastError().
void launch_kernel(OpT distance_op, pairwise_matrix_params params, dim3 grid, cudaStream_t stream)
{
  // The kernel instantiation that is being tuned.
  const auto kernel_fn = raft::distance::detail::pairwise_matrix_kernel<Policy,
                                                                        row_major,
                                                                        decltype(sm_compat_range),
                                                                        OpT,
                                                                        IdxT,
                                                                        DataT,
                                                                        OutT,
                                                                        FinOpT>;

  // Dynamic shared memory requested by the distance op for this policy.
  // The .template keyword disambiguates the member template call (see
  // https://en.cppreference.com/w/cpp/language/dependent_name).
  const int dyn_smem_bytes = distance_op.template shared_mem_size<Policy>();

  // The policy fixes the number of threads per block.
  const dim3 threads_per_block(Policy::Nthreads);

  kernel_fn<<<grid, threads_per_block, dyn_smem_bytes, stream>>>(distance_op, params);
  RAFT_CUDA_TRY(cudaGetLastError());
}

// Reports the block tile size chosen by the kernel policy.
//
// On return, m, n and k hold Policy::Mblk, Policy::Nblk and Policy::Kblk,
// respectively.
void get_block_size(int& m, int& n, int& k)
{
  const int tile_m = Policy::Mblk;
  const int tile_n = Policy::Nblk;
  const int tile_k = Policy::Kblk;

  m = tile_m;
  n = tile_n;
  k = tile_k;
}

// Returns the address of the tuned kernel instantiation as an untyped pointer.
// get_max_occupancy() passes this pointer to the CUDA occupancy query.
void* get_kernel_ptr()
{
  // Same template arguments as the launch in launch_kernel().
  const auto kernel_fn = raft::distance::detail::pairwise_matrix_kernel<Policy,
                                                                        row_major,
                                                                        decltype(sm_compat_range),
                                                                        OpT,
                                                                        IdxT,
                                                                        DataT,
                                                                        OutT,
                                                                        FinOpT>;
  void* const untyped = reinterpret_cast<void*>(kernel_fn);
  return untyped;
}

// Returns the number of blocks per multiprocessor that
// cudaOccupancyMaxActiveBlocksPerMultiprocessor reports for the tuned kernel,
// given Policy::Nthreads threads per block and the dynamic shared memory size
// requested by distance_op.
int get_max_occupancy(OpT distance_op)
{
  // Dynamic shared memory per block. The .template keyword disambiguates the
  // member template call (see
  // https://en.cppreference.com/w/cpp/language/dependent_name).
  const int dyn_smem_bytes = distance_op.template shared_mem_size<Policy>();

  int blocks_per_sm = 0;
  RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &blocks_per_sm, get_kernel_ptr(), Policy::Nthreads, dyn_smem_bytes));

  return blocks_per_sm;
}

} // namespace raft::bench::distance::tune
51 changes: 51 additions & 0 deletions cpp/bench/distance/tune_pairwise/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/detail/distance_ops/all_ops.cuh> // lp_unexp_distance_op
#include <raft/distance/detail/pairwise_matrix/params.cuh> // pairwise_matrix_params

namespace raft::bench::distance::tune {

// Launch one specific kernel with the following template parameters.
// These aliases and constants are the single place where the tuned kernel
// configuration is chosen.
constexpr bool row_major = true;
using DataT = float;
using AccT = float;
using OutT = DataT;
using IdxT = int;

// Distance op
// C++17 inline variable: this header is included by both kernel.cu and
// bench.cu, so distance_op must be defined inline.
// See: https://open-std.org/JTC1/SC22/WG21/docs/papers/2016/p0386r0.pdf
using OpT = raft::distance::detail::ops::lp_unexp_distance_op<DataT, AccT, IdxT>;
// Constructor argument of the distance op.
// NOTE(review): presumably the exponent p of the Lp distance — confirm against
// lp_unexp_distance_op.
constexpr float metric_arg = 2.0;
inline const OpT distance_op{metric_arg};
// Final operation type used to instantiate the kernel and its parameters.
using FinOpT = raft::identity_op;

// Parameter struct passed to the kernel, specialized for the types above.
using pairwise_matrix_params =
raft::distance::detail::pairwise_matrix_params<IdxT, DataT, OutT, FinOpT>;

// Launches the kernel with the given distance op and kernel parameters, using
// the given grid dimensions, on the given stream. Defined in kernel.cu.
void launch_kernel(OpT, pairwise_matrix_params, dim3, cudaStream_t);

// Describes the block size that is decided by the policy: writes the policy's
// Mblk, Nblk and Kblk into m, n and k.
void get_block_size(int& m, int& n, int& k);

// Returns the kernel's function pointer cast to void*.
void* get_kernel_ptr();
// Returns the maximum number of active blocks per multiprocessor reported by
// the CUDA occupancy API for the kernel, given the op's shared memory size.
int get_max_occupancy(OpT);

} // namespace raft::bench::distance::tune
1 change: 1 addition & 0 deletions cpp/include/raft/distance/detail/distance_ops/canberra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/core/operators.hpp> // raft::abs
#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace raft::distance::detail::ops {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once
#include <raft/core/operators.hpp> // raft::log
#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace raft::distance::detail::ops {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once
#include <raft/core/operators.hpp> // raft::log
#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace raft::distance::detail::ops {
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once
#include <raft/core/operators.hpp> // raft::pow, raft::abs
#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace raft::distance::detail::ops {
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/util/device_loads_stores.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

#pragma once

#include <raft/util/cuda_dev_essentials.cuh>
#include <cstdint> // uintX_t
#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace raft {

Expand Down

0 comments on commit f54e7a4

Please sign in to comment.