Add dispatch based on compute architecture #1295

Closed

Changes from all commits
49 changes: 40 additions & 9 deletions cpp/include/raft/distance/detail/distance.cuh
@@ -45,6 +45,7 @@

#include <raft/distance/distance_types.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/util/arch.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_uvector.hpp>

@@ -261,8 +262,11 @@ void distance_impl(raft::resources const& handle,
distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
} else {
const auto deviceVersion = getComputeCapability();
if (deviceVersion.first >= 8) {
auto runtime_arch = raft::arch::kernel_runtime_arch();
auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future());
auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80());

if (cutlass_range.contains(runtime_arch)) {
// If device is SM_80 or later, use CUTLASS-based kernel.
using Op = ops::cosine_cutlass_op<DataT, AccT>;
Op distance_op{};
@@ -272,8 +276,25 @@
} else {
// Else use "legacy" cosine kernel
ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};
distance_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
distance_matrix_dispatch<decltype(distance_op),
DataT,
AccT,
OutT,
FinOpT,
IdxT,
decltype(legacy_range)>(distance_op,
m,
n,
k,
x,
y,
norm_A,
norm_B,
out,
fin_op,
stream,
is_row_major,
legacy_range);
}
}
}
@@ -527,19 +548,29 @@ void distance_impl_l2_expanded( // NOTE: different name
distance_matrix_dispatch<decltype(l2_op), DataT, AccT, OutT, FinOpT, IdxT>(
l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
} else {
const auto deviceVersion = getComputeCapability();
if (deviceVersion.first >= 8) {
auto runtime_arch = raft::arch::kernel_runtime_arch();
auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future());
auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80());

if (cutlass_range.contains(runtime_arch)) {
// If device is SM_80 or later, use CUTLASS-based kernel.
using L2Op = ops::l2_exp_cutlass_op<DataT, AccT>;
L2Op l2_op(perform_sqrt);

distance_matrix_cutlass_dispatch<decltype(l2_op), DataT, AccT, OutT, FinOpT, IdxT>(
l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
} else {
// Else use "legacy" L2
// Else use "legacy" L2. Compile *only* for architectures in the legacy
// range. For newer architectures, compile empty kernels.
ops::l2_exp_distance_op<DataT, AccT, IdxT> l2_op(perform_sqrt);
distance_matrix_dispatch<decltype(l2_op), DataT, AccT, OutT, FinOpT, IdxT>(
l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major);
distance_matrix_dispatch<decltype(l2_op),
DataT,
AccT,
OutT,
FinOpT,
IdxT,
decltype(legacy_range)>(
l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range);
}
}
}
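In isolation, the dispatch pattern added in both hunks above looks like the sketch below. The wrapper function is a placeholder of ours; the raft::arch calls are the ones introduced by this PR, with the runtime query deciding between the CUTLASS path and the legacy path.

#include <raft/util/arch.cuh>

// Sketch only: dispatch_sketch and the launch comments are placeholders.
void dispatch_sketch(/* ... problem size, data pointers, stream ... */)
{
  // Compute capability of the kernel version the driver would launch here.
  auto runtime_arch = raft::arch::kernel_runtime_arch();

  // CUTLASS kernels require SM_80 or newer; older devices take the legacy path.
  auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future());
  auto legacy_range  = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80());

  if (cutlass_range.contains(runtime_arch)) {
    // ... dispatch to the CUTLASS-based kernel ...
  } else {
    // ... dispatch to the legacy kernel, compiled only for legacy_range ...
  }
}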
46 changes: 32 additions & 14 deletions cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh
@@ -20,6 +20,7 @@
#include <cstdio>
#include <raft/distance/detail/pairwise_distance_cutlass_base.cuh>
#include <raft/linalg/contractions.cuh>
#include <raft/util/arch.cuh>
#include <utility>

namespace raft::distance::detail {
@@ -90,19 +91,22 @@ template <typename OpT,
typename AccT,
typename OutT,
typename FinOpT,
typename IdxT = int>
void distance_matrix_dispatch(OpT distance_op,
IdxT m,
IdxT n,
IdxT k,
const DataT* x,
const DataT* y,
const DataT* x_norm,
const DataT* y_norm,
OutT* out,
FinOpT fin_op,
cudaStream_t stream,
bool is_row_major)
typename IdxT = int,
typename SM_compat_t = raft::arch::SM_range<raft::arch::SM_min, raft::arch::SM_future>>
void distance_matrix_dispatch(
OpT distance_op,
IdxT m,
IdxT n,
IdxT k,
const DataT* x,
const DataT* y,
const DataT* x_norm,
const DataT* y_norm,
OutT* out,
FinOpT fin_op,
cudaStream_t stream,
bool is_row_major,
SM_compat_t sm_compat_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()))
{
// Determine leading dimensions and, if column-major, flip order of passing x
// and y.
@@ -154,7 +158,21 @@ void distance_matrix_dispatch(OpT distance_op,
typedef typename std::conditional<row_major(), RowPolicy, ColPolicy>::type Policy;

return pairwise_matrix<Policy, row_major(), DataT, AccT, OutT, IdxT, OpT, FinOpT>(
distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream);
distance_op,
fin_op,
x,
y,
x_norm,
y_norm,
m,
n,
k,
ldx,
ldy,
ld_out,
out,
stream,
sm_compat_range);
});
}
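Note that the new SM_compat_t template parameter and the sm_compat_range argument are both defaulted to the all-inclusive range, so existing call sites compile unchanged. A minimal sketch of that defaulting pattern, assuming only raft/util/arch.cuh (launch_for is a made-up name, not the real dispatcher):

#include <raft/util/arch.cuh>

template <typename SM_compat_t = raft::arch::SM_range<raft::arch::SM_min, raft::arch::SM_future>>
void launch_for(
  SM_compat_t sm_compat_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()))
{
  // ... forward sm_compat_range to the kernel, as distance_matrix_dispatch does ...
}

void callers()
{
  launch_for();  // old-style call: kernel is compiled for every architecture
  launch_for(raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()));  // legacy range only
}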

56 changes: 38 additions & 18 deletions cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh
@@ -18,6 +18,7 @@
#include <cstddef>
#include <raft/core/operators.hpp>
#include <raft/distance/detail/pairwise_distance_base.cuh>
#include <raft/util/arch.cuh>

namespace raft::distance::detail {

@@ -28,21 +29,30 @@ template <typename Policy,
typename OutT,
typename IdxT,
typename opT,
typename FinOpT>
__global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(const DataT* x,
const DataT* y,
const DataT* _xn,
const DataT* _yn,
IdxT m,
IdxT n,
IdxT k,
IdxT lda,
IdxT ldb,
IdxT ldd,
OutT* dOutput,
opT distance_op,
FinOpT fin_op)
typename FinOpT,
typename SM_compat_t>
__global__ __launch_bounds__(Policy::Nthreads,
2) void pairwise_matrix_kernel(const DataT* x,
const DataT* y,
const DataT* _xn,
const DataT* _yn,
IdxT m,
IdxT n,
IdxT k,
IdxT lda,
IdxT ldb,
IdxT ldd,
OutT* dOutput,
opT distance_op,
FinOpT fin_op,
SM_compat_t sm_compat_range)
{
// Early exit to minimize the size of the kernel when it is not supposed to be compiled.
if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) {
assert(false);
return;
}

extern __shared__ char smem[];

// Wrap operator back into lambdas. This is temporary and should be removed. (TODO)
@@ -103,7 +113,8 @@ template <typename Policy,
typename OutT,
typename IdxT,
typename OpT,
typename FinOpT>
typename FinOpT,
typename SM_compat_t>
void pairwise_matrix(OpT distance_op,
FinOpT fin_op,
const DataT* x,
@@ -117,18 +128,27 @@ void pairwise_matrix(OpT distance_op,
IdxT ldb,
IdxT ldd,
OutT* dOutput,
cudaStream_t stream)
cudaStream_t stream,
SM_compat_t sm_compat_range)
{
dim3 blk(Policy::Nthreads);
// Use .template to disambiguate (See:
// https://en.cppreference.com/w/cpp/language/dependent_name)
size_t smem_size = distance_op.template shared_mem_size<Policy>();
// Obtain function pointer to kernel
auto kernel = pairwise_matrix_kernel<Policy, row_major, DataT, AccT, OutT, IdxT, OpT, FinOpT>;
auto kernel = pairwise_matrix_kernel<Policy,
row_major,
DataT,
AccT,
OutT,
IdxT,
OpT,
FinOpT,
decltype(sm_compat_range)>;
dim3 grid = launchConfigGenerator<Policy>(m, n, smem_size, kernel);

kernel<<<grid, blk, smem_size, stream>>>(
x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op);
x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op, sm_compat_range);
RAFT_CUDA_TRY(cudaGetLastError());
}
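The if constexpr guard at the top of pairwise_matrix_kernel is what makes per-range compilation pay off: when the kernel is instantiated for an architecture outside sm_compat_range, the body collapses to the assert and return, so that cubin stays nearly empty. A standalone sketch of the same guard (guarded_kernel is a made-up name; the raft::arch machinery is from this PR):

#include <cassert>
#include <raft/util/arch.cuh>

template <typename SM_compat_t>
__global__ void guarded_kernel(int* out, SM_compat_t sm_compat_range)
{
  // Compile-time check against the architecture this instantiation targets.
  if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) {
    assert(false);  // must never be launched on an out-of-range device
    return;
  }
  // ... real work; only compiled for architectures inside the range ...
  if (out != nullptr) { *out = 0; }
}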

141 changes: 141 additions & 0 deletions cpp/include/raft/util/arch.cuh
@@ -0,0 +1,141 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

namespace raft::arch {

/* raft::arch provides the following facilities:
*
* - raft::arch::SM_XX : hardcoded compile-time constants for various compute
* architectures. The values raft::arch::SM_min and raft::arch::SM_future
* represent architectures that are always smaller and larger (respectively)
* than any architecture that can be encountered in practice.
*
* - raft::arch::SM_compute_arch : a compile-time value for the *current*
* compute architecture that a kernel is compiled with. It can only be used
* inside kernels with a template argument.
*
* - raft::arch::kernel_runtime_arch : a function that computes at *run-time*
* which version of a kernel will launch (i.e., it will return the compute
* architecture of the version of the kernel that will be launched by the
* driver).
*
* - raft::arch::SM_range : a compile-time value to represent a half-open
*   interval [min, max) of compute architectures. This can be used to check
*   if the current compile-time architecture is in a specified compatibility
*   range.
*/

// detail::SM_generic is a template to create a generic compile-time SM
// architecture constant.
namespace detail {
template <int n>
struct SM_generic {
public:
__host__ __device__ constexpr int value() const { return n; }
};

// A dummy kernel that is used to determine the runtime architecture.
__global__ inline void dummy_runtime_kernel() {}
Review thread on dummy_runtime_kernel:

Contributor:
This needs to be static so we don't run into the issue where multiple consumers of raft build with different arch values and we get incorrect kernel selection.

For more info see: NVIDIA/cub#545

Contributor (author):
That's a good point. It looks like the dummy kernel approach requires making the kernel static to get a reliable solution, at the cost of littering the final binary with many empty kernels.

In kernel_runtime_arch, we are currently taking a pointer to the dummy_runtime_kernel. If, instead, we took a runtime argument that was a pointer to one of the candidate kernels that is going to be called, would that solve the problem? That is, I would remove the dummy_runtime_kernel and the kernel pointer would have to be provided by the user. I think it solves the linking problem described above without creating spurious kernels, but I want to double-check before I change the code.

Contributor:
Requiring a kernel pointer would work as well, since we would then be querying based on a specific kernel that was only compiled once.

Contributor (author):
Thanks a lot! I will go in that direction then.

Comment:
I'm a little late to the party, but I came up with an idea for an alternative way of doing this that I like better because it avoids the empty kernel. See https://github.com/NVIDIA/cub/issues/556

Contributor (author):
Thanks for the pointer! I've been meaning to respond to this for a while, but never found the time to test my assertions.

We are currently (that is: in the PR that was merged) avoiding the empty kernel by forcing the caller to provide a pointer to one of the kernel versions. We then query the func attributes of that kernel.

The __CUDA_ARCH_LIST__ looks like a worthwhile approach. However, it may break when kernels are weakly linked (e.g. templated). You describe the issue very well in #1722. I had not considered outlawing weak linking completely. Let's see how that goes!

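For reference, a hypothetical sketch of the kernel-pointer approach agreed on in the thread above; the function name and plain int return type are ours, and this is not the code in this diff (which still queries the dummy kernel):

// Query the attributes of the kernel that will actually be launched,
// instead of a dedicated dummy kernel. Error handling elided.
template <typename KernelT>
int runtime_arch_of(KernelT* kernel)  // e.g. returns 800 on an SM_80 device
{
  cudaFuncAttributes attributes;
  cudaFuncGetAttributes(&attributes, reinterpret_cast<const void*>(kernel));
  return 10 * attributes.ptxVersion;
}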
} // namespace detail

// A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90),
// plus SM_min and SM_future, which allow specifying a half-open interval of
// compatible compute architectures.
using SM_min = detail::SM_generic<350>;
using SM_60 = detail::SM_generic<600>;
using SM_70 = detail::SM_generic<700>;
using SM_75 = detail::SM_generic<750>;
using SM_80 = detail::SM_generic<800>;
using SM_86 = detail::SM_generic<860>;
using SM_90 = detail::SM_generic<900>;
using SM_future = detail::SM_generic<99999>;

// This is a type that uses the __CUDA_ARCH__ macro to obtain the compile-time
// compute architecture. It can only be used where __CUDA_ARCH__ is defined,
// i.e., inside a __global__ function template.
struct SM_compute_arch {
template <int dummy = 0>
__device__ constexpr int value() const
{
#ifdef __CUDA_ARCH__
return __CUDA_ARCH__;
#else
// This function should not be called in host code (because __CUDA_ARCH__ is
// not defined). This function is constexpr and thus can be called in host
// code (due to the --expt-relaxed-constexpr compile flag). We would like to
// provide an intelligible error message when this function is called in
// host code, which we do below.
//
// To make sure the static_assert only fires in host code, we use a dummy
// template parameter as described in P2593:
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html
static_assert(dummy != 0,
"SM_compute_arch.value() is only callable from a __global__ function template. "
"A way to create a function template is by adding 'template <int dummy = 0>'.");
return -1;
#endif
}
};

// A runtime value for the actual compute architecture of a kernel.
//
// A single kernel can be compiled for several "virtual" compute architectures.
// When a program runs, the driver picks the version of the kernel that most
// closely matches the current hardware. This struct reflects the virtual
// compute architecture of the version of the kernel that the driver picks when
// the kernel runs.
struct SM_runtime {
friend SM_runtime kernel_runtime_arch();

private:
const int _version;
SM_runtime(int version) : _version(version) {}

public:
__host__ __device__ int value() const { return _version; }
};

// Determines which compute-architecture version of a kernel will run.
//
// Semantics are described above in the documentation of SM_runtime.
inline SM_runtime kernel_runtime_arch()
{
auto kernel = detail::dummy_runtime_kernel;
cudaFuncAttributes attributes;
cudaFuncGetAttributes(&attributes, kernel);

return SM_runtime(10 * attributes.ptxVersion);
}

// SM_range represents a range of SM architectures. It can be used to
// conditionally compile a kernel.
template <typename SM_MIN, typename SM_MAX>
struct SM_range {
private:
const SM_MIN _min;
const SM_MAX _max;

public:
__host__ __device__ constexpr SM_range(SM_MIN min, SM_MAX max) : _min(min), _max(max) {}

template <typename SM_t>
__host__ __device__ constexpr bool contains(SM_t current) const
{
return _min.value() <= current.value() && current.value() < _max.value();
}
};

} // namespace raft::arch
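Putting the pieces together, a minimal end-to-end sketch of how this header is meant to be used; toy_kernel and launch_toy are made-up names and error handling is elided:

#include <cassert>
#include <raft/util/arch.cuh>

template <typename SM_compat_t>
__global__ void toy_kernel(int* out, SM_compat_t compat)
{
  // Real body only for architectures inside the compatibility range.
  if constexpr (!compat.contains(raft::arch::SM_compute_arch())) {
    assert(false);
    return;
  }
  *out = 42;
}

void launch_toy(int* d_out, cudaStream_t stream)
{
  auto range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future());

  // Launch only if the version the driver would pick falls inside the range.
  if (range.contains(raft::arch::kernel_runtime_arch())) {
    toy_kernel<<<1, 1, 0, stream>>>(d_out, range);
  }
}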