Skip to content

Commit

Permalink
Rename kernel arch finding function for dispatch (#1536)
Browse files Browse the repository at this point in the history
-- as the kernel arch given by the cudaFuncAttributes ptxVersion depends on which archs the kernel was compiled for,
we should rename kernel_runtime_arch() to kernel_virtual_arch().
-- accordingly update comments to reflect this.

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1536
  • Loading branch information
mdoijade authored May 19, 2023
1 parent af7e067 commit 0154e8e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
7 changes: 4 additions & 3 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,13 @@ void fusedL2NNImpl(OutT* min,
decltype(distance_op),
decltype(fin_op)>;

// Get pointer to fp32 SIMT kernel to determine the runtime architecture of the
// current system. Other methods to determine the architecture (that do not
// Get pointer to fp32 SIMT kernel to determine which of the compute
// architectures the kernel was compiled for most closely matches the
// current device. Other methods to determine the architecture (that do not
// require a pointer) can be error prone. See:
// https://github.com/NVIDIA/cub/issues/545
void* kernel_ptr = reinterpret_cast<void*>(kernel);
auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr);
auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr);
auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());

if (cutlass_range.contains(runtime_arch)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,14 @@ void pairwise_matrix_dispatch(OpT distance_op,
auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());
auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80());

// Get pointer to SM60 kernel to determine the runtime architecture of the
// current system. Other methods to determine the architecture (that do not
// Get pointer to SM60 kernel to determine which of the compute
// architectures the kernel was compiled for most closely matches the
// current device. Other methods to determine the architecture (that do not
// require a pointer) can be error prone. See:
// https://github.com/NVIDIA/cub/issues/545
auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range);
void* kernel_ptr = reinterpret_cast<void*>(sm60_wrapper.kernel_ptr);
auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr);
auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr);

if (cutlass_range.contains(runtime_arch)) {
// If device is SM_80 or later, use CUTLASS-based kernel.
Expand Down
15 changes: 8 additions & 7 deletions cpp/include/raft/util/arch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ namespace raft::util::arch {
* compute architecture that a kernel is compiled with. It can only be used
* inside kernels with a template argument.
*
* - raft::util::arch::kernel_runtime_arch : a function that computes at *run-time*
* which version of a kernel will launch (i.e., it will return the compute
* architecture of the version of the kernel that will be launched by the
* driver).
* - raft::util::arch::kernel_virtual_arch : a function that computes at *run-time*
* which version of a kernel will launch (i.e., it will return the virtual compute
* architecture that the version of the kernel launched by the driver was
* compiled for).
*
* - raft::util::arch::SM_range : a compile-time value to represent an open interval
* of compute architectures. This can be used to check if the current
Expand Down Expand Up @@ -97,7 +97,7 @@ struct SM_compute_arch {
// compute architecture of the version of the kernel that the driver picks when
// the kernel runs.
struct SM_runtime {
friend SM_runtime kernel_runtime_arch(void*);
friend SM_runtime kernel_virtual_arch(void*);

private:
const int _version;
Expand All @@ -107,15 +107,16 @@ struct SM_runtime {
__host__ __device__ int value() const { return _version; }
};

// Computes which compute architecture of a kernel will run
// Computes which virtual compute architecture the given kernel was compiled for.
// The driver picks the version of the kernel that closely matches the current hardware.
//
// Semantics are described above in the documentation of SM_runtime.
//
// This function requires a pointer to the kernel that will run. Other methods
// to determine the architecture (that do not require a pointer) can be error
// prone. See:
// https://github.com/NVIDIA/cub/issues/545
inline SM_runtime kernel_runtime_arch(void* kernel)
inline SM_runtime kernel_virtual_arch(void* kernel)
{
cudaFuncAttributes attributes;
RAFT_CUDA_TRY(cudaFuncGetAttributes(&attributes, kernel));
Expand Down

0 comments on commit 0154e8e

Please sign in to comment.