Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Simplify distance/detail to make is easier to dispatch to different k…
…ernel implementations (#1142) The pairwise distance metrics are quite varied. The table below summarizes the differences, in terms of - Epilog : whether the metric has a non-empty epilog operation. - Uses norms: whether the metric requires precalculation of the norms of the vectors. - Has params: whether the norm has additional parameters. The L2 metric, for instance, has the `sqrt` boolean parameter that determines whether to calculate the squared or actual distance. - Pre- & post-processing: For some metrics, the norms have to be precalculated. For other metrics, the input matrices are transformed before the kernel launch, and "untransformed" after. - Expensive inner loop: some metrics use `pow`, `log` or other expensive functions in the inner loop. - Depends on row-major: the calculation of some metrics depend on whether the input is row-major. - CUTLASS: some metrics have an implementation using CUTLASS and tensor cores. <table border="2" cellspacing="0" cellpadding="6" rules="groups" frame="hsides"> <colgroup> <col class="org-left" /> <col class="org-left" /> <col class="org-left" /> <col class="org-left" /> <col class="org-left" /> <col class="org-left" /> <col class="org-left" /> <col class="org-left" /> </colgroup> <thead> <tr> <th scope="col" class="org-left">Metric</th> <th scope="col" class="org-left">Epilog</th> <th scope="col" class="org-left">Uses norms</th> <th scope="col" class="org-left">Has params</th> <th scope="col" class="org-left">Pre- & post-processing</th> <th scope="col" class="org-left">Expensive inner loop</th> <th scope="col" class="org-left">Depends on row-major</th> <th scope="col" class="org-left">CUTLASS</th> </tr> </thead> <tbody> <tr> <td class="org-left">Canberra</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left"> </td> </tr> <tr> <td class="org-left">Chebyshev (Linf)</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> </tr> <tr> <td class="org-left">Correlation</td> <td class="org-left">x</td> <td class="org-left">x (twice)</td> <td class="org-left">x (many)</td> <td class="org-left">compute norms</td> <td class="org-left"> </td> <td class="org-left">x</td> <td class="org-left"> </td> </tr> <tr> <td class="org-left">Cosine</td> <td class="org-left">x</td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left">compute norms</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left">x</td> </tr> <tr> <td class="org-left">Hamming</td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left">x (k)</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> </tr> <tr> <td class="org-left">Hellinger</td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left">sqrt and square</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> </tr> <tr> <td class="org-left">Jensen Shannon</td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left"> </td> </tr> <tr> <td class="org-left">KL divergence</td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left">x (row major, x == y)</td> <td class="org-left">yes</td> <td class="org-left">x</td> <td class="org-left">x</td> <td class="org-left"> </td> </tr> <tr> <td class="org-left">L1</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> </tr> <tr> <td class="org-left">L2 expanded</td> <td class="org-left">x</td> <td class="org-left">x</td> <td class="org-left">x (sqrt)</td> <td class="org-left">compute norms</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left">x</td> </tr> <tr> <td class="org-left">L2 unexpanded</td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left">x (sqrt)</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> </tr> <tr> <td class="org-left">Minkowski (Lp)</td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left">x (p)</td> <td class="org-left"> </td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left"> </td> </tr> <tr> <td class="org-left">Russel-Rao</td> <td class="org-left">x</td> <td class="org-left"> </td> <td class="org-left">x (k, 1/k)</td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> <td class="org-left"> </td> </tr> </tbody> </table> To keep the complexity that results from all these differences in check, there are several layers between the public API and the kernel launch, each with their own responsibility. ## Before 1. `raft::distance::pairwise_distance` takes distance type as a run-time argument and dispatches to `raft::distance::detail::pairwise_distance_impl`. 2. `raft::distance::detail::pairwise_distance_impl` allocates workspace as necessary and calls `raft::distance::detail::distance` 3. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself. 4. `raft::distance::detail::distance` (with `fin_op`) initializes a `DistanceImpl` zero-sized struct with the correct template arguments and runs the `.run()` method of the struct. 5. `raft::distance::detail::DistanceImpl<DistanceType>.run()` calls `raft::distance::detail::XX_Impl`. 6. `raft::distance::detail::XX_Impl` has the following responsibilities: - Pre-compute norms if necessary - Transform input if necessary - If metric supports a CUTLASS operation, dispatch if necessary. - Swap inputs if column-major. - Based on runtime parameter `row_major` dispatch to function template `raft::distance::detail::XX<bool row_major>` 7. `raft::distance::detail::XX` based on alignment of input data dispatch to function template `raft::distance::detail::XX_Impl<int veclen>` (different overload of previous `raft::distance::detail::XX_Impl`) 8. `raft::distance::detail::XX_Impl` has the following responsibilities: - Define `core_op` and `epilog_op` - Define `use_norms` - Launch kernel `pairwiseDistanceMatKernel` with correct launch parameters **Observations**: - Steps 6 and 7 both convert a runtime value to a compile time constant (row-major layout and alignment). - Step 7 is repeated (copy pasted) for each metric. - Steps 7 and 8 do a lot of different things and the steps in between do relatively little. - Steps 1-5 do fairly little (but require a lot of boilerplate) **Proposal**: 1. Collect as much of the runtime behavior of each metric in a `distance_op` that [contains](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh): - The core_op - The epilog_op - The required shared memory - Whether the inner loop is expensive (and thus loop unrolling should be curtailed) 2. Collect the runtime -> compile-time dispatch in one location ([dispatch.cuh](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh#L70)) 3. Collect kernel launching in one [location](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh#L108) 4. Remove some of the boilerplate in steps 1-5. ## After 1. `raft::distance::pairwise_distance` takes distance type as a run-time argument, allocates workspace as necessary, and dispatches to `raft::distance::detail::distance`. 2. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself. 3. `raft::distance::detail::distance` (with `fin_op`) calls an overload of `raft::distance::detail::distance_impl` for the correct distance type. 4. `raft::distance::detail::distance_impl` has the following responsibilities: - Pre-compute norms if necessary - Initialize distance op with parameters as necessary, see below for more information. - Transform input if necessary - If metric supports a CUTLASS operation, dispatch if necessary. - Dispatch to `raft::distance::detail::distance_matrix_dispatch` 5. `raft::distance::detail::distance_matrix_dispatch` has the following responsibilities: - swap x, y matrices if column major - dispatch to correct kernel based on run-time parameters `row_major` and `vec_len` - Determine kernel policy based on parameters - Call `raft::distance::detail::pairwise_matrix` 6. `raft::distance::detail::pairwise_matrix` launches the `raft::distance::detail::pairwise_matrix_kernel` with the correct launch parameters. **Distance_op** `raft::distance::detail::ops::XX_distance_op` [[example]](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh) has the following responsibilities: - Take any parameters (sqrt, k, etc) - Define `core_op` and `epilog_op` - Define `use_norms`, `expensive_inner_loop`, and `shared_mem_size()`. Still TODO: - [x] Rename Minkowski and Chebyshev to Lp and Linf. - [x] Do something with this note in the comments: "if workspace is passed as nullptr, this will return in worksize, the number of bytes of workspace required", which is wrong. - [x] Add a mechanism to limit duplicate compilation when a CUTLASS kernel is available. This is done in follow up PR #1295. - [x] Some distance_ops have additional template parameters. This must be cleared up. Authors: - Allard Hendriksen (https://github.com/ahendriksen) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: #1142
- Loading branch information