
Add dispatch based on compute architecture #1335

Merged
77 commits merged on Mar 17, 2023

Conversation

@ahendriksen (Contributor) commented Mar 13, 2023

This PR improves the ability to do dispatch based on compute architecture. It is a follow up to #1142.

It has two goals:

  1. Make it easier to specify which compute architectures a kernel is compatible with / should be compiled for.
  2. Make it easier to compile a kernel only for the architectures for which it is used (if it is unused, the kernel should be empty).

We have a specific use case in RAFT for this feature. For the L2 pairwise distance kernel, we have a CUTLASS-based implementation that works on SM80+ and a fallback kernel. Preferably, each kernel is compiled only for the architectures on which it is actually used.

The calculation of the tile indices is now performed in ldgXY(). This will make it possible to move all state related to the tile index out of the class in the next commit.

Note that the calculation of the tile index can depend on which overloaded constructor is called (!).
This commit moves all grid and tile indexing logic into the caller.
Contractions_NT is now only responsible for *intra*-tile indexing.

Due to the complexity of the epilog function, the ldgNextGridStride function is not yet called from within the main loop. That is the next goal, so that we have all the grid and tile indexing localized in the loop.
This commit removes the epilog function and moves its functionality into the run loop. The next step might be to see if the ldgNextGridStride() method has to be called at its current location, or if performance is the same when it is called at the start of the loop.
This results in subtle issues with non-square KernelPolicy, as found in
fusedL2KNN.
This is more general than just for L1. Making use of it more is work in
progress.
This removed support for the CUTLASS kernels; it has to be put back.
I wasted a lot of time because I had not replaced the op::core() method
of the l2_exp_distance_op after I copied it from l2_unexp_distance_op...

If I copy something from the template and forget to fill it in, I get a
compile error.
I am testing on CUDA 12, where it does not seem to work. Prior to my
commits, the CUTLASS kernels were also not working. So not sure what's
up.

In any case: consider this untested.
This indicates that the operator uses expensive operations (pow, exp, log) in the inner loop. Therefore, the unrolling and/or veclen parameters should be adjusted.
@ahendriksen ahendriksen self-assigned this Mar 13, 2023
@ahendriksen ahendriksen added labels 3 - Ready for Review, improvement (Improvement / enhancement to an existing function), non-breaking (Non-breaking change) on Mar 13, 2023
@ahendriksen (Contributor, Author)

This is a retry of PR #1295, which got inadvertently closed. I have taken into account the reviews that were performed there. In addition, I have centralized the dispatching more than in the previous PR. I am confident that this structure will support adding Hopper kernels without compiling duplicate and redundant kernels.

@tfeher tfeher changed the title Add dispatch based on compute architecture (attempt 2) Add dispatch based on compute architecture Mar 14, 2023
@tfeher (Contributor) left a comment

Thanks Allard for the PR! I find the new helpers for dispatching based on compute arch useful, and overall the PR looks good to me.

It might make sense to shorten the PR description by moving the part starting from "The previous approach can be described ..." into a separate comment.

@ahendriksen (Contributor, Author)

Thanks @tfeher! Moved PR description below:

The previous approach can be described as follows (using pseudocode):

template <typename DistanceOpT>
__global__ void generic_kernel(DistanceOpT op, other arguments .. ) { .. implementation .. }

void cutlass_launch(args..) { .. implementation .. }

template <typename DistanceOpT>
__global__ void generic_kernel_prior_to_sm80(DistanceOpT op, args..) {
#if __CUDA_ARCH__ < 800
  .. copy and paste generic_kernel implementation
#endif
}

template <typename DistanceOpT>
void dispatch(DistanceOpT op, args..) {
  if (op != l2_distance_op) {
    // run normal generic kernel
    generic_kernel<<<grid, block, ...>>>(op, args..);
  } else {
    if (device_capability() >= 800) {
      cutlass_launch(args...);
    } else {
      generic_kernel_prior_to_sm80<<<grid, block, ...>>>(op, args...);
    }
  }
}

The main issue is that the generic kernel is copy-pasted. In the new approach, the compute architectures for which the generic kernel has to be compiled (the "compatibility range") are given as an argument (using a compile-time tag type). This makes it possible to compare against the compatibility range inside the kernel and exit early when the architecture currently being compiled for is not supported. Therefore, we do not need to copy and paste the generic kernel.

template <typename DistanceOpT, typename SM_compat_t>
__global__ void generic_kernel(DistanceOpT op, SM_compat_t sm_compat_range, args .. ) {
  // Early exit to minimize the size of the kernel when it is not supposed to be compiled.
  if constexpr(! sm_compat_range.contains(raft::arch::SM_compute_arch())) {
    assert(false);
    return;
  }
  .. rest of implementation ..
}

void cutlass_launch(args..) { .. implementation .. }

template <typename DistanceOpT>
void dispatch(DistanceOpT op, args..) {
  if (op != l2_distance_op) {
    // run normal generic kernel and compile for all architectures:
    auto full_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future());
    generic_kernel<<<grid, block, ...>>>(op, full_range, args..);
  } else {
    // Define compatibility ranges for the cutlass and generic kernel
    auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future());
    auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80());
    // Get current architecture of device at runtime
    void* kernel_ptr = generic_kernel<OpT, decltype(legacy_range), ...>;
    auto runtime_arch = raft::arch::kernel_runtime_arch(kernel_ptr);

    if (cutlass_range.contains(runtime_arch)) {
      // On SM80+: run cutlass kernel. (this might actually also compile a non-trivial kernel
      // for SM70 but that is unfortunately outside our control)
      cutlass_launch(args...);
    } else {
      // This will run on architectures < SM80. Also, the compiled kernels from
      // SM80 and higher will be empty (< 10 instructions).
      generic_kernel<<<grid, block, ....>>>(op, legacy_range, args...);
    }
  }
}

A Member commented on the line namespace raft::arch { in raft/util/arch.cuh:

For new utility files, I'd like to start using raft::util namespace prefix to make them consistent w/ the other namespaces. At some point, we're going to scrape through the other utility files as well. raft::util::arch would be great here.

@cjnolet (Member) left a comment

I'd like the namespace in raft/util/arch.cuh to match the nesting raft::util::arch so that we can start making util consistent w/ the other APIs in RAFT. I'm okay doing that in #1307 though so that we can get this one merged.

@cjnolet (Member) commented Mar 16, 2023

/merge

@cjnolet (Member) commented Mar 16, 2023

@ahendriksen it looks like there are a few CI failures for distance tests

@cjnolet cjnolet closed this Mar 16, 2023
@cjnolet cjnolet reopened this Mar 16, 2023
@ahendriksen (Contributor, Author)

Hi Corey, yes, there was a problem with column-major input on Ampere. It took a bit longer to fix than expected, but I expect that the last commit will fix the issue :)

@cjnolet (Member) commented Mar 17, 2023

/merge
