Skip to content

Commit

Permalink
Simplify distance/detail to make is easier to dispatch to different k…
Browse files Browse the repository at this point in the history
…ernel implementations (#1142)

The pairwise distance metrics are quite varied. The table below summarizes the differences, in terms of 

- Epilog : whether the metric has a non-empty epilog operation.
- Uses norms: whether the metric requires precalculation of the norms of the vectors.
- Has params: whether the norm has additional parameters. The L2 metric, for instance, has the `sqrt` boolean parameter that determines whether to calculate the squared or actual distance.
- Pre- & post-processing: For some metrics, the norms have to be precalculated. For other metrics, the input matrices are transformed before the kernel launch, and "untransformed" after.
- Expensive inner loop: some metrics use `pow`, `log` or other expensive functions in the inner loop. 
- Depends on row-major: the calculation of some metrics depend on whether the input is row-major. 
- CUTLASS: some metrics have an implementation using CUTLASS and tensor cores.


<table border="2" cellspacing="0" cellpadding="6" rules="groups" frame="hsides">


<colgroup>
<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />

<col  class="org-left" />
</colgroup>
<thead>
<tr>
<th scope="col" class="org-left">Metric</th>
<th scope="col" class="org-left">Epilog</th>
<th scope="col" class="org-left">Uses norms</th>
<th scope="col" class="org-left">Has params</th>
<th scope="col" class="org-left">Pre- &amp; post-processing</th>
<th scope="col" class="org-left">Expensive inner loop</th>
<th scope="col" class="org-left">Depends on row-major</th>
<th scope="col" class="org-left">CUTLASS</th>
</tr>
</thead>

<tbody>
<tr>
<td class="org-left">Canberra</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Chebyshev (Linf)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Correlation</td>
<td class="org-left">x</td>
<td class="org-left">x (twice)</td>
<td class="org-left">x (many)</td>
<td class="org-left">compute norms</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Cosine</td>
<td class="org-left">x</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">compute norms</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
</tr>


<tr>
<td class="org-left">Hamming</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (k)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Hellinger</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">sqrt and square</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Jensen Shannon</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">KL divergence</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (row major, x == y)</td>
<td class="org-left">yes</td>
<td class="org-left">x</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">L1</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">L2 expanded</td>
<td class="org-left">x</td>
<td class="org-left">x</td>
<td class="org-left">x (sqrt)</td>
<td class="org-left">compute norms</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
</tr>


<tr>
<td class="org-left">L2 unexpanded</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (sqrt)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Minkowski (Lp)</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (p)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>


<tr>
<td class="org-left">Russel-Rao</td>
<td class="org-left">x</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">x (k, 1/k)</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
<td class="org-left">&#xa0;</td>
</tr>
</tbody>
</table>


To keep the complexity that results from all these differences in check, there are several layers between the public API and the kernel launch, each with their own responsibility. 

## Before
1. `raft::distance::pairwise_distance` takes distance type as a run-time argument and dispatches to `raft::distance::detail::pairwise_distance_impl`.
2. `raft::distance::detail::pairwise_distance_impl` allocates workspace as necessary and calls `raft::distance::detail::distance`
3. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself.
4. `raft::distance::detail::distance` (with `fin_op`) initializes a `DistanceImpl` zero-sized struct with the correct template arguments and runs the `.run()` method of the struct.
5. `raft::distance::detail::DistanceImpl<DistanceType>.run()` calls `raft::distance::detail::XX_Impl`.
6. `raft::distance::detail::XX_Impl` has the following responsibilities:
   - Pre-compute norms if necessary
   - Transform input if necessary
   - If metric supports a CUTLASS operation, dispatch if necessary.
   - Swap inputs if column-major.
   - Based on runtime parameter `row_major` dispatch to function template `raft::distance::detail::XX<bool row_major>`
7. `raft::distance::detail::XX` based on alignment of input data dispatch to function template `raft::distance::detail::XX_Impl<int veclen>` (different overload of previous `raft::distance::detail::XX_Impl`)
8. `raft::distance::detail::XX_Impl` has the following responsibilities:
   - Define `core_op` and `epilog_op`
   - Define `use_norms`
   - Launch kernel `pairwiseDistanceMatKernel` with correct launch parameters

**Observations**: 
- Steps 6 and 7 both convert a runtime value to a compile time constant (row-major layout and alignment). 
- Step 7 is repeated (copy pasted) for each metric.
- Steps 7 and 8 do a lot of different things and the steps in between do relatively little.
- Steps 1-5 do fairly little (but require a lot of boilerplate)

**Proposal**:

1. Collect as much of the runtime behavior of each metric in a `distance_op` that [contains](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh):
    - The core_op
    - The epilog_op
    - The required shared memory
    - Whether the inner loop is expensive (and thus loop unrolling should be curtailed)
2. Collect the runtime -> compile-time dispatch in one location ([dispatch.cuh](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh#L70))
3. Collect kernel launching in one [location](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh#L108)
4. Remove some of the boilerplate in steps 1-5.

## After
1. `raft::distance::pairwise_distance` takes distance type as a run-time argument, allocates workspace as necessary, and dispatches to `raft::distance::detail::distance`.
2. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself.
3. `raft::distance::detail::distance` (with `fin_op`) calls an overload of `raft::distance::detail::distance_impl` for the correct distance type.
4. `raft::distance::detail::distance_impl` has the following responsibilities:
   - Pre-compute norms if necessary
   - Initialize distance op with parameters as necessary, see below for more information.
   - Transform input if necessary
   - If metric supports a CUTLASS operation, dispatch if necessary.
   - Dispatch to `raft::distance::detail::distance_matrix_dispatch`
5. `raft::distance::detail::distance_matrix_dispatch` has the following responsibilities:
   - swap x, y matrices if column major
   - dispatch to correct kernel based on run-time parameters `row_major` and `vec_len`
   - Determine kernel policy based on parameters
   - Call `raft::distance::detail::pairwise_matrix`
6. `raft::distance::detail::pairwise_matrix` launches the `raft::distance::detail::pairwise_matrix_kernel` with the correct launch parameters.

**Distance_op**
 `raft::distance::detail::ops::XX_distance_op` [[example]](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh) has the following responsibilities:
   - Take any parameters (sqrt, k, etc)
   - Define `core_op` and `epilog_op`
   - Define `use_norms`, `expensive_inner_loop`, and `shared_mem_size()`.


Still TODO:

- [x] Rename Minkowski and Chebyshev to Lp and Linf.
- [x] Do something with this note in the comments: "if workspace is passed as nullptr, this will return in worksize, the number of bytes of workspace required", which is wrong.
- [x] Add a mechanism to limit duplicate compilation when a CUTLASS kernel is available. This is done in follow up PR #1295.
- [x] Some distance_ops have additional template parameters. This must be cleared up.

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1142
  • Loading branch information
ahendriksen authored Mar 10, 2023
1 parent bcb0976 commit e4aec7b
Show file tree
Hide file tree
Showing 45 changed files with 2,185 additions and 4,043 deletions.
4 changes: 2 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,6 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/cluster/kmeans_init_plus_plus_float.cu
src/distance/distance/specializations/detail/canberra_double_double_double_int.cu
src/distance/distance/specializations/detail/canberra_float_float_float_int.cu
src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu
src/distance/distance/specializations/detail/chebyshev_float_float_float_int.cu
src/distance/distance/specializations/detail/correlation_double_double_double_int.cu
src/distance/distance/specializations/detail/correlation_float_float_float_int.cu
src/distance/distance/specializations/detail/cosine_double_double_double_int.cu
Expand Down Expand Up @@ -352,6 +350,8 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu
src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu
src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu
Expand Down
194 changes: 0 additions & 194 deletions cpp/include/raft/distance/detail/canberra.cuh

This file was deleted.

Loading

0 comments on commit e4aec7b

Please sign in to comment.