Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dispatch based on compute architecture #1295

Closed

Conversation

ahendriksen
Copy link
Contributor

This PR improves the ability to do dispatch based on compute architecture. It is a follow up to #1142.

It has two goals:

  1. Make it easier to specify which compute architectures a kernel is compatible with / should be compiled for.
  2. Make it easier to compile a kernel only for the architectures for which it is used (if it is unused, the kernel should be empty).

We have a specific use case in RAFT for this feature. For the L2 pairwise distance kernel we have a CUTLASS based implementation that works om SM80+ and a fallback kernel. Preferably, each kernel is only compiled for the architectures on which it is actually used.

The previous approach can be described as follows (using psuedo code):

template <typename DistanceOpT>
__global__ void generic_kernel(DistanceOpT op, other arguments .. ) { .. implementation .. }

__global__ void cutlass_kernel_l2(args.. ) {.. implementation }

template <typename DistanceOpT>
__global__ void generic_kernel_prior_to_sm80(DistanceOpT op, args..) {
#if __CUDA_ARCH__ < 800
  .. copy and paste generic_kernel implementation
#endif
}

template <typename DistanceOpT>
void dispatch(DistanceOpT op, args..) {
  if (! op == l2_distance_op) {
    // run normal generic kernel
    generic_kernel(op, args..);
  } else {
    if (device_capability() >= 800) {
      cutlass_kernel_l2(args...);
    } else {
      generic_kernel_prior_to_sm80(op, args...);
    }
  }
}

The main issue is that the generic kernel is copy pasted. In the new approach, the compute architectures for which the generic kernel has to be compiled (the "compatibility range") is given as an argument (using a compile-time tag type). This allows comparing to the compatibility range inside the kernel and early exiting if the architecture for which it is currently compiling is not supported. Therefore, we do not need to copy and paste generic kernel.

template <typename DistanceOpT, typename SM_compat_t>
__global__ void generic_kernel(DistanceOpT op, SM_compat_t sm_compat_range, args .. ) {
  // Early exit to minimize the size of the kernel when it is not supposed to be compiled.
  if constexpr(! sm_compat_range.contains(raft::arch::SM_compute_arch())) {
    assert(false);
    return;
  }
  .. rest of implementation ..
}

__global__ void cutlass_kernel_l2(args.. ) {.. implementation }

template <typename DistanceOpT>
void dispatch(DistanceOpT op, args..) {
  if (! op == l2_distance_op) {
    // run normal generic kernel and compile for all architectures:
    auto full_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future());
    generic_kernel(op, full_range, args..);
  } else {
    // Get current architecture of device at runtime
    auto runtime_arch = raft::arch::kernel_runtime_arch();
    // Define compatibility ranges for the cutlass and generic kernel
    auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future());
    auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80());

    if (cutlass_range.contains(runtime_arch)) {
      // On SM80+: run cutlass kernel. (this might actually also compile a non-trivial kernel
      // for SM70 but that is unfortunately outside our control)
      cutlass_kernel_l2(args...);
    } else {
      // This will run on architectures < SM80. Also, the compiled kernels from
      // SM80 and higher will be empty (< 10 instructions).
      generic_kernel(op, legacy_range, args...);
    }
  }
}

@ahendriksen ahendriksen added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change and removed CMake labels Feb 22, 2023
@ahendriksen ahendriksen changed the base branch from branch-23.04 to pull-request/1142 February 22, 2023 15:40
Copy link
Contributor

@tfeher tfeher left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Allard for this PR! It indeed provides a cleaner method to dispatch kernels based on GPU arch, and at the same time enables us to compile only for the intended architectures. Overall it looks good, see a few comments below.

cpp/include/raft/util/arch.cuh Outdated Show resolved Hide resolved
cpp/include/raft/util/arch.cuh Outdated Show resolved Hide resolved
@ahendriksen
Copy link
Contributor Author

NOTE: this PR is a follow up to #1142. To keep the diff minimal, I have set the base branch to be the previous PR (instead of 23.04). This should be changed before merging :)

};

// A dummy kernel that is used to determine the runtime architecture.
__global__ inline void dummy_runtime_kernel() {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be static so we don't run into the issue where multiple consumers of raft build with different arch values and we get incorrect kernel selection.

For more info see: NVIDIA/cub#545

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. It looks like the dummy kernel approach requires making the kernel static to get a reliable solution, at the cost of littering the final binary with many empty kernels.

In kernel_runtime_arch, we are currently taking a pointer to the dummy_runtime_kernel. If instead, we took a runtime argument that was a pointer to one of the candidate kernels that is going to be called, would that solve the problem? That is, I would remove the dummy_runtime_kernel and the kernel pointer would have to be provided by the user. I think it does solve the linking problem that you described above and it doesn't create spurious kernels, but I want to double check before I change the code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Requiring a kernel pointer would work as well since we would now be querying based a specific kernel that was only compiled once.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot! I will go for that direction then.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little late to the party, but I came up with an idea for an alternative way of doing this that I like better because it avoids the empty kernel. See https://github.com/NVIDIA/cub/issues/556

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointer! I've been meaning to respond to this for a while, but never found the time to test my assertions.

We are currently (that is: in the PR that was merged) avoiding the empty kernel by forcing the caller to provide a pointer to one of the kernel versions. We then query the func attributes of that kernel.

The __CUDA_ARCH_LIST__ looks like a worthwile approach. However, it may break when kernels are weakly linked (e.g. templated). You describe the issue very well in #1722. I had not considered outlawing weak linking completely.. Let's see how that goes!

@ahendriksen ahendriksen added the 5 - Merge After Dependencies Depends on another PR: do not merge out of order label Mar 7, 2023
rapids-bot bot pushed a commit that referenced this pull request Mar 10, 2023
…ernel implementations (#1142)

The pairwise distance metrics are quite varied. The table below summarizes the differences, in terms of 

- Epilog : whether the metric has a non-empty epilog operation.
- Uses norms: whether the metric requires precalculation of the norms of the vectors.
- Has params: whether the norm has additional parameters. The L2 metric, for instance, has the `sqrt` boolean parameter that determines whether to calculate the squared or actual distance.
- Pre- & post-processing: For some metrics, the norms have to be precalculated. For other metrics, the input matrices are transformed before the kernel launch, and "untransformed" after.
- Expensive inner loop: some metrics use `pow`, `log` or other expensive functions in the inner loop. 
- Depends on row-major: the calculation of some metrics depend on whether the input is row-major. 
- CUTLASS: some metrics have an implementation using CUTLASS and tensor cores.


<table border="2" cellspacing="0" cellpadding="6" rules="groups" frame="hsides">


<colgroup>
<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />
</colgroup>
<thead>
<tr>
<th scope="col" class="org-left">Metric</th>
<th scope="col" class="org-left">Epilog</th>
<th scope="col" class="org-left">Uses norms</th>
<th scope="col" class="org-left">Has params</th>
<th scope="col" class="org-left">Pre- &amp; post-processing</th>
<th scope="col" class="org-left">Expensive inner loop</th>
<th scope="col" class="org-left">Depends on row-major</th>
<th scope="col" class="org-left">CUTLASS</th>
</tr>
</thead>

<tbody>
<tr>
<td class="org-left">Canberra</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Chebyshev (Linf)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Correlation</td>
<td class="org-left">x</td>
<td class="org-left">x (twice)</td>
<td class="org-left">x (many)</td>
<td class="org-left">compute norms</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Cosine</td>
<td class="org-left">x</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">compute norms</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
</tr>


<tr>
<td class="org-left">Hamming</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (k)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Hellinger</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">sqrt and square</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Jensen Shannon</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">KL divergence</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (row major, x == y)</td>
<td class="org-left">yes</td>
<td class="org-left">x</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">L1</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">L2 expanded</td>
<td class="org-left">x</td>
<td class="org-left">x</td>
<td class="org-left">x (sqrt)</td>
<td class="org-left">compute norms</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
</tr>


<tr>
<td class="org-left">L2 unexpanded</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (sqrt)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Minkowski (Lp)</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (p)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Russel-Rao</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (k, 1/k)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>
</tbody>
</table>


To keep the complexity that results from all these differences in check, there are several layers between the public API and the kernel launch, each with their own responsibility. 

## Before
1. `raft::distance::pairwise_distance` takes distance type as a run-time argument and dispatches to `raft::distance::detail::pairwise_distance_impl`.
2. `raft::distance::detail::pairwise_distance_impl` allocates workspace as necessary and calls `raft::distance::detail::distance`
3. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself.
4. `raft::distance::detail::distance` (with `fin_op`) initializes a `DistanceImpl` zero-sized struct with the correct template arguments and runs the `.run()` method of the struct.
5. `raft::distance::detail::DistanceImpl<DistanceType>.run()` calls `raft::distance::detail::XX_Impl`.
6. `raft::distance::detail::XX_Impl` has the following responsibilities:
   - Pre-compute norms if necessary
   - Transform input if necessary
   - If metric supports a CUTLASS operation, dispatch if necessary.
   - Swap inputs if column-major.
   - Based on runtime parameter `row_major` dispatch to function template `raft::distance::detail::XX<bool row_major>`
7. `raft::distance::detail::XX` based on alignment of input data dispatch to function template `raft::distance::detail::XX_Impl<int veclen>` (different overload of previous `raft::distance::detail::XX_Impl`)
8. `raft::distance::detail::XX_Impl` has the following responsibilities:
   - Define `core_op` and `epilog_op`
   - Define `use_norms`
   - Launch kernel `pairwiseDistanceMatKernel` with correct launch parameters

**Observations**: 
- Steps 6 and 7 both convert a runtime value to a compile time constant (row-major layout and alignment). 
- Step 7 is repeated (copy pasted) for each metric.
- Steps 7 and 8 do a lot of different things and the steps in between do relatively little.
- Steps 1-5 do fairly little (but require a lot of boilerplate)

**Proposal**:

1. Collect as much of the runtime behavior of each metric in a `distance_op` that [contains](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh):
    - The core_op
    - The epilog_op
    - The required shared memory
    - Whether the inner loop is expensive (and thus loop unrolling should be curtailed)
2. Collect the runtime -> compile-time dispatch in one location ([dispatch.cuh](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh#L70))
3. Collect kernel launching in one [location](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh#L108)
4. Remove some of the boilerplate in steps 1-5.

## After
1. `raft::distance::pairwise_distance` takes distance type as a run-time argument, allocates workspace as necessary, and dispatches to `raft::distance::detail::distance`.
2. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself.
3. `raft::distance::detail::distance` (with `fin_op`) calls an overload of `raft::distance::detail::distance_impl` for the correct distance type.
4. `raft::distance::detail::distance_impl` has the following responsibilities:
   - Pre-compute norms if necessary
   - Initialize distance op with parameters as necessary, see below for more information.
   - Transform input if necessary
   - If metric supports a CUTLASS operation, dispatch if necessary.
   - Dispatch to `raft::distance::detail::distance_matrix_dispatch`
5. `raft::distance::detail::distance_matrix_dispatch` has the following responsibilities:
   - swap x, y matrices if column major
   - dispatch to correct kernel based on run-time parameters `row_major` and `vec_len`
   - Determine kernel policy based on parameters
   - Call `raft::distance::detail::pairwise_matrix`
6. `raft::distance::detail::pairwise_matrix` launches the `raft::distance::detail::pairwise_matrix_kernel` with the correct launch parameters.

**Distance_op**
 `raft::distance::detail::ops::XX_distance_op` [[example]](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh) has the following responsibilities:
   - Take any parameters (sqrt, k, etc)
   - Define `core_op` and `epilog_op`
   - Define `use_norms`, `expensive_inner_loop`, and `shared_mem_size()`.


Still TODO:

- [x] Rename Minkowski and Chebyshev to Lp and Linf.
- [x] Do something with this note in the comments: "if workspace is passed as nullptr, this will return in worksize, the number of bytes of workspace required", which is wrong.
- [x] Add a mechanism to limit duplicate compilation when a CUTLASS kernel is available. This is done in follow up PR #1295.
- [x] Some distance_ops have additional template parameters. This must be cleared up.

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1142
@rapids-bot rapids-bot bot deleted the branch rapidsai:pull-request/1142 March 10, 2023 09:38
@rapids-bot rapids-bot bot closed this Mar 10, 2023
@ahendriksen
Copy link
Contributor Author

rapids-bot closed this PR. I set the target of this PR to #1142 (so that the diff was reasonably small). Now, I cannot reopen this PR (because the target branch has been deleted) or retarget it (because this PR is closed).

I fear I have to open a new PR. I will get back to your reviews! Apologies for the inconvenience.

lowener pushed a commit to lowener/raft that referenced this pull request Mar 15, 2023
…ernel implementations (rapidsai#1142)

The pairwise distance metrics are quite varied. The table below summarizes the differences, in terms of 

- Epilog : whether the metric has a non-empty epilog operation.
- Uses norms: whether the metric requires precalculation of the norms of the vectors.
- Has params: whether the norm has additional parameters. The L2 metric, for instance, has the `sqrt` boolean parameter that determines whether to calculate the squared or actual distance.
- Pre- & post-processing: For some metrics, the norms have to be precalculated. For other metrics, the input matrices are transformed before the kernel launch, and "untransformed" after.
- Expensive inner loop: some metrics use `pow`, `log` or other expensive functions in the inner loop. 
- Depends on row-major: the calculation of some metrics depend on whether the input is row-major. 
- CUTLASS: some metrics have an implementation using CUTLASS and tensor cores.


<table border="2" cellspacing="0" cellpadding="6" rules="groups" frame="hsides">


<colgroup>
<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />
</colgroup>
<thead>
<tr>
<th scope="col" class="org-left">Metric</th>
<th scope="col" class="org-left">Epilog</th>
<th scope="col" class="org-left">Uses norms</th>
<th scope="col" class="org-left">Has params</th>
<th scope="col" class="org-left">Pre- &amp; post-processing</th>
<th scope="col" class="org-left">Expensive inner loop</th>
<th scope="col" class="org-left">Depends on row-major</th>
<th scope="col" class="org-left">CUTLASS</th>
</tr>
</thead>

<tbody>
<tr>
<td class="org-left">Canberra</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Chebyshev (Linf)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Correlation</td>
<td class="org-left">x</td>
<td class="org-left">x (twice)</td>
<td class="org-left">x (many)</td>
<td class="org-left">compute norms</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Cosine</td>
<td class="org-left">x</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">compute norms</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
</tr>


<tr>
<td class="org-left">Hamming</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (k)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Hellinger</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">sqrt and square</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Jensen Shannon</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">KL divergence</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (row major, x == y)</td>
<td class="org-left">yes</td>
<td class="org-left">x</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">L1</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">L2 expanded</td>
<td class="org-left">x</td>
<td class="org-left">x</td>
<td class="org-left">x (sqrt)</td>
<td class="org-left">compute norms</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
</tr>


<tr>
<td class="org-left">L2 unexpanded</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (sqrt)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Minkowski (Lp)</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (p)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Russel-Rao</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (k, 1/k)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>
</tbody>
</table>


To keep the complexity that results from all these differences in check, there are several layers between the public API and the kernel launch, each with their own responsibility. 

## Before
1. `raft::distance::pairwise_distance` takes distance type as a run-time argument and dispatches to `raft::distance::detail::pairwise_distance_impl`.
2. `raft::distance::detail::pairwise_distance_impl` allocates workspace as necessary and calls `raft::distance::detail::distance`
3. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself.
4. `raft::distance::detail::distance` (with `fin_op`) initializes a `DistanceImpl` zero-sized struct with the correct template arguments and runs the `.run()` method of the struct.
5. `raft::distance::detail::DistanceImpl<DistanceType>.run()` calls `raft::distance::detail::XX_Impl`.
6. `raft::distance::detail::XX_Impl` has the following responsibilities:
   - Pre-compute norms if necessary
   - Transform input if necessary
   - If metric supports a CUTLASS operation, dispatch if necessary.
   - Swap inputs if column-major.
   - Based on runtime parameter `row_major` dispatch to function template `raft::distance::detail::XX<bool row_major>`
7. `raft::distance::detail::XX` based on alignment of input data dispatch to function template `raft::distance::detail::XX_Impl<int veclen>` (different overload of previous `raft::distance::detail::XX_Impl`)
8. `raft::distance::detail::XX_Impl` has the following responsibilities:
   - Define `core_op` and `epilog_op`
   - Define `use_norms`
   - Launch kernel `pairwiseDistanceMatKernel` with correct launch parameters

**Observations**: 
- Steps 6 and 7 both convert a runtime value to a compile time constant (row-major layout and alignment). 
- Step 7 is repeated (copy pasted) for each metric.
- Steps 7 and 8 do a lot of different things and the steps in between do relatively little.
- Steps 1-5 do fairly little (but require a lot of boilerplate)

**Proposal**:

1. Collect as much of the runtime behavior of each metric in a `distance_op` that [contains](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh):
    - The core_op
    - The epilog_op
    - The required shared memory
    - Whether the inner loop is expensive (and thus loop unrolling should be curtailed)
2. Collect the runtime -> compile-time dispatch in one location ([dispatch.cuh](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh#L70))
3. Collect kernel launching in one [location](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh#L108)
4. Remove some of the boilerplate in steps 1-5.

## After
1. `raft::distance::pairwise_distance` takes distance type as a run-time argument, allocates workspace as necessary, and dispatches to `raft::distance::detail::distance`.
2. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself.
3. `raft::distance::detail::distance` (with `fin_op`) calls an overload of `raft::distance::detail::distance_impl` for the correct distance type.
4. `raft::distance::detail::distance_impl` has the following responsibilities:
   - Pre-compute norms if necessary
   - Initialize distance op with parameters as necessary, see below for more information.
   - Transform input if necessary
   - If metric supports a CUTLASS operation, dispatch if necessary.
   - Dispatch to `raft::distance::detail::distance_matrix_dispatch`
5. `raft::distance::detail::distance_matrix_dispatch` has the following responsibilities:
   - swap x, y matrices if column major
   - dispatch to correct kernel based on run-time parameters `row_major` and `vec_len`
   - Determine kernel policy based on parameters
   - Call `raft::distance::detail::pairwise_matrix`
6. `raft::distance::detail::pairwise_matrix` launches the `raft::distance::detail::pairwise_matrix_kernel` with the correct launch parameters.

**Distance_op**
 `raft::distance::detail::ops::XX_distance_op` [[example]](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh) has the following responsibilities:
   - Take any parameters (sqrt, k, etc)
   - Define `core_op` and `epilog_op`
   - Define `use_norms`, `expensive_inner_loop`, and `shared_mem_size()`.


Still TODO:

- [x] Rename Minkowski and Chebyshev to Lp and Linf.
- [x] Do something with this note in the comments: "if workspace is passed as nullptr, this will return in worksize, the number of bytes of workspace required", which is wrong.
- [x] Add a mechanism to limit duplicate compilation when a CUTLASS kernel is available. This is done in follow up PR rapidsai#1295.
- [x] Some distance_ops have additional template parameters. This must be cleared up.

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1142
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
3 - Ready for Review 5 - Merge After Dependencies Depends on another PR: do not merge out of order cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
Projects
Development

Successfully merging this pull request may close these issues.

4 participants