Simplify distance/detail to make it easier to dispatch to different kernel implementations #1142
Reminder/TODO: make the grid-stride loop variables local variables instead of member variables. See: #838 (comment)
TODO: #838 (comment). Also look at fused and maskedL2NN for the shared memory calculations.
The calculation of the tile indices is now performed in ldgXY(). This will make it possible to remove all state related to the tile index from the class in the next commit. Note that the calculation of the tile index can depend on which overloaded constructor is called(!).
This commit moves all grid and tile indexing logic into the caller. Contractions_NT is now only responsible for *intra*-tile indexing. Due to the complexity of the epilog function, the ldgNextGridStride function is not yet called from within the main loop. That is the next goal so that we have all the grid and tile indexing localized in the loop.
This commit removes the epilog function and moves its functionality into the run loop. The next step might be to see whether the ldgNextGridStride() method has to be called at its current location, or whether performance is the same if it is called at the start of the loop.
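As an illustration of the direction these commits describe, here is a minimal host-only sketch in which the caller owns the grid-stride loop and the tile start indices are plain local variables. The names `tile_worker`, `run_block`, and `visit_counts` are hypothetical and are not taken from RAFT; the "blocks" are emulated with ordinary loops rather than a CUDA launch.

```cpp
#include <cassert>
#include <vector>

// Hypothetical tile worker: it only performs intra-tile indexing and keeps
// no grid or tile state of its own.
struct tile_worker {
  int tile_m;  // rows per tile
  int tile_n;  // columns per tile

  // Marks every in-bounds element of the tile starting at (row0, col0).
  void accumulate(std::vector<int>& visits, int m, int n, int row0, int col0) const
  {
    for (int r = row0; r < row0 + tile_m && r < m; ++r) {
      for (int c = col0; c < col0 + tile_n && c < n; ++c) {
        ++visits[r * n + c];
      }
    }
  }
};

// Emulates the work of one block (block_y, block_x) in a grid_y x grid_x grid.
// The grid-stride loop variables row0 and col0 are locals of the caller.
void run_block(const tile_worker& w, std::vector<int>& visits, int m, int n,
               int block_y, int block_x, int grid_y, int grid_x)
{
  for (int row0 = block_y * w.tile_m; row0 < m; row0 += grid_y * w.tile_m) {
    for (int col0 = block_x * w.tile_n; col0 < n; col0 += grid_x * w.tile_n) {
      w.accumulate(visits, m, n, row0, col0);
    }
  }
}

// Runs all emulated blocks and returns how often each output element was visited.
std::vector<int> visit_counts(int m, int n, int grid_y, int grid_x, tile_worker w)
{
  assert(grid_y > 0 && grid_x > 0 && w.tile_m > 0 && w.tile_n > 0);
  std::vector<int> visits(m * n, 0);
  for (int by = 0; by < grid_y; ++by) {
    for (int bx = 0; bx < grid_x; ++bx) {
      run_block(w, visits, m, n, by, bx, grid_y, grid_x);
    }
  }
  return visits;
}
```

Calling `visit_counts(10, 7, 2, 3, tile_worker{4, 2})` should report that each of the 70 output elements is visited exactly once, even though the matrix is not a multiple of the tile size and the grid is smaller than the number of tiles.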
This results in subtle issues with non-square KernelPolicy, as found in fusedL2KNN.
This is more general than just for L1. Making use of it more is work in progress.
By adding yet another struct ^^
This did remove support for the CUTLASS kernels. Has to be put back.
I wasted a lot of time because I had not replaced the op::core() method of the l2_exp_distance_op after I copied it from l2_unexp_distance_op... If I copy something from the template and forget to fill it in, I get a compile error.
I am testing on CUDA 12, where it does not seem to work. Prior to my commits, the CUTLASS kernels were also not working. So not sure what's up. In any case: consider this untested.
The CI error does not seem related to the changes in the PR:
Is there a way to rerun the CI without pushing an empty commit?
@ahendriksen if you click on the "Details" button it will take you to the actions tab where you can use the "Re-run jobs" dropdown. See https://docs.github.com/en/actions/managing-workflow-runs/re-running-workflows-and-jobs for more info. I've queued up a rerun.
Thanks! And thanks for linking to the docs. I will have a look to see what else I am missing.
I really like the idea of abstracting each unique distance formula into composable ops. In this case, not only does it allow us to specify and react differently to various performance conditions, but it allows us to capture the differences of the computations themselves in a single place. This is very much what we did w/ the sparse APIs as well: they are a series of composable operations that can be executed according to their needs (in the sparse case that's binary vs dot-product-based vs full-pairwise evaluation).
I think this looks great. I'd like to see the functions in the public API that accept explicit workspaces deprecated because the memory resource should be getting propagated through the raft::resources instance (and be controllable by the user). I think we have an opportunity here to expand our developer guide as well. Otherwise, I'm completely on board w/ this change.
Thanks Allard for this PR! This looks great! I am very happy to see the reduction of code duplication. I have just a few smaller questions.
Co-authored-by: Tamas Bela Feher <[email protected]>
Thanks Allard for the updates! The PR looks good to me!
Just a heads up: we need to merge this cuml PR before this RAFT PR is merged, otherwise cuml will break downstream.
Good catch. I should be more careful with the non-breaking label. Although that file has been deprecated for a couple of months now.
No problem. To be fair, we had updated cuml in 23.02 to remove most of its uses of deprecated headers but this one slipped through the cracks. Before we merged this I just wanted to do a quick grep to be sure.
LGTM!
/merge
…ernel implementations (rapidsai#1142)

The pairwise distance metrics are quite varied. The table below summarizes the differences, in terms of

- Epilog: whether the metric has a non-empty epilog operation.
- Uses norms: whether the metric requires precalculation of the norms of the vectors.
- Has params: whether the metric has additional parameters. The L2 metric, for instance, has the `sqrt` boolean parameter that determines whether to calculate the squared or actual distance.
- Pre- & post-processing: For some metrics, the norms have to be precalculated. For other metrics, the input matrices are transformed before the kernel launch, and "untransformed" after.
- Expensive inner loop: some metrics use `pow`, `log` or other expensive functions in the inner loop.
- Depends on row-major: the calculation of some metrics depends on whether the input is row-major.
- CUTLASS: some metrics have an implementation using CUTLASS and tensor cores.

| Metric | Epilog | Uses norms | Has params | Pre- & post-processing | Expensive inner loop | Depends on row-major | CUTLASS |
|---|---|---|---|---|---|---|---|
| Canberra | | | | | x | | |
| Chebyshev (Linf) | | | | | | | |
| Correlation | x | x (twice) | x (many) | compute norms | | x | |
| Cosine | x | x | | compute norms | | | x |
| Hamming | x | | x (k) | | | | |
| Hellinger | x | | | sqrt and square | | | |
| Jensen Shannon | x | | | | x | | |
| KL divergence | x | | x (row major, x == y) | yes | x | x | |
| L1 | | | | | | | |
| L2 expanded | x | x | x (sqrt) | compute norms | | | x |
| L2 unexpanded | x | | x (sqrt) | | | | |
| Minkowski (Lp) | x | | x (p) | | x | | |
| Russel-Rao | x | | x (k, 1/k) | | | | |

To keep the complexity that results from all these differences in check, there are several layers between the public API and the kernel launch, each with its own responsibility.

## Before

1. `raft::distance::pairwise_distance` takes the distance type as a run-time argument and dispatches to `raft::distance::detail::pairwise_distance_impl`.
2. `raft::distance::detail::pairwise_distance_impl` allocates workspace as necessary and calls `raft::distance::detail::distance`.
3. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself.
4. `raft::distance::detail::distance` (with `fin_op`) initializes a `DistanceImpl` zero-sized struct with the correct template arguments and runs the `.run()` method of the struct.
5. `raft::distance::detail::DistanceImpl<DistanceType>.run()` calls `raft::distance::detail::XX_Impl`.
6. `raft::distance::detail::XX_Impl` has the following responsibilities:
   - Pre-compute norms if necessary.
   - Transform input if necessary.
   - If the metric supports a CUTLASS operation, dispatch if necessary.
   - Swap inputs if column-major.
   - Based on the run-time parameter `row_major`, dispatch to the function template `raft::distance::detail::XX<bool row_major>`.
7. `raft::distance::detail::XX`, based on the alignment of the input data, dispatches to the function template `raft::distance::detail::XX_Impl<int veclen>` (a different overload of the previous `raft::distance::detail::XX_Impl`).
8. `raft::distance::detail::XX_Impl` has the following responsibilities:
   - Define `core_op` and `epilog_op`.
   - Define `use_norms`.
   - Launch the kernel `pairwiseDistanceMatKernel` with the correct launch parameters.

**Observations**:

- Steps 6 and 7 both convert a run-time value to a compile-time constant (row-major layout and alignment).
- Step 7 is repeated (copy-pasted) for each metric.
- Steps 7 and 8 do a lot of different things and the steps in between do relatively little.
- Steps 1-5 do fairly little (but require a lot of boilerplate).

**Proposal**:

1. Collect as much of the run-time behavior of each metric in a `distance_op` that [contains](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh):
   - The core_op
   - The epilog_op
   - The required shared memory
   - Whether the inner loop is expensive (and thus loop unrolling should be curtailed)
2. Collect the run-time -> compile-time dispatch in one location ([dispatch.cuh](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh#L70)).
3. Collect kernel launching in one [location](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh#L108).
4. Remove some of the boilerplate in steps 1-5.

## After

1. `raft::distance::pairwise_distance` takes the distance type as a run-time argument, allocates workspace as necessary, and dispatches to `raft::distance::detail::distance`.
2. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself.
3. `raft::distance::detail::distance` (with `fin_op`) calls an overload of `raft::distance::detail::distance_impl` for the correct distance type.
4. `raft::distance::detail::distance_impl` has the following responsibilities:
   - Pre-compute norms if necessary.
   - Initialize the distance op with parameters as necessary, see below for more information.
   - Transform input if necessary.
   - If the metric supports a CUTLASS operation, dispatch if necessary.
   - Dispatch to `raft::distance::detail::distance_matrix_dispatch`.
5. `raft::distance::detail::distance_matrix_dispatch` has the following responsibilities:
   - Swap the x and y matrices if column-major.
   - Dispatch to the correct kernel based on the run-time parameters `row_major` and `vec_len`.
   - Determine the kernel policy based on the parameters.
   - Call `raft::distance::detail::pairwise_matrix`.
6. `raft::distance::detail::pairwise_matrix` launches the `raft::distance::detail::pairwise_matrix_kernel` with the correct launch parameters.

**Distance_op**

`raft::distance::detail::ops::XX_distance_op` [[example]](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh) has the following responsibilities:

- Take any parameters (sqrt, k, etc.).
- Define `core_op` and `epilog_op`.
- Define `use_norms`, `expensive_inner_loop`, and `shared_mem_size()`.

Still TODO:

- [x] Rename Minkowski and Chebyshev to Lp and Linf.
- [x] Do something with this note in the comments: "if workspace is passed as nullptr, this will return in worksize, the number of bytes of workspace required", which is wrong.
- [x] Add a mechanism to limit duplicate compilation when a CUTLASS kernel is available. This is done in follow-up PR rapidsai#1295.
- [x] Some distance_ops have additional template parameters. This must be cleared up.

Authors:
- Allard Hendriksen (https://github.com/ahendriksen)
- Corey J. Nolet (https://github.com/cjnolet)

Approvers:
- Tamas Bela Feher (https://github.com/tfeher)
- Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1142
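To make the `distance_op` idea and the single run-time to compile-time dispatch point more concrete, here is a rough host-only sketch. Everything in it is an assumption for illustration: the struct names `l1_op` and `l2_expanded_op`, the `core`/`epilog` signatures, and the functions `pairwise_impl` and `pairwise` are hypothetical and do not reproduce RAFT's device code. Norms are accumulated inline here for brevity, whereas the description above has them precomputed.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical op for the L1 metric: no norms, no parameters, empty epilog.
struct l1_op {
  static constexpr bool use_norms            = false;
  static constexpr bool expensive_inner_loop = false;
  void core(double& acc, double x, double y) const { acc += std::fabs(x - y); }
  void epilog(double&, double, double) const {}
};

// Hypothetical op for expanded L2: accumulates the dot product in core() and
// combines it with the squared norms in epilog().
struct l2_expanded_op {
  bool take_sqrt;  // run-time parameter carried by the op
  static constexpr bool use_norms            = true;
  static constexpr bool expensive_inner_loop = false;
  void core(double& acc, double x, double y) const { acc += x * y; }
  void epilog(double& acc, double x_norm, double y_norm) const
  {
    double d = x_norm + y_norm - 2.0 * acc;
    if (d < 0.0) { d = 0.0; }  // guard against negative rounding error
    acc = take_sqrt ? std::sqrt(d) : d;
  }
};

// Generic loop: the layout is a compile-time constant, the metric is the op.
template <bool RowMajor, typename Op>
std::vector<double> pairwise_impl(const std::vector<double>& x, const std::vector<double>& y,
                                  std::size_t m, std::size_t n, std::size_t k, Op op)
{
  // Element (i, l) of a matrix with `rows` rows and k columns.
  auto at = [k](const std::vector<double>& a, std::size_t rows, std::size_t i, std::size_t l) {
    return RowMajor ? a[i * k + l] : a[l * rows + i];
  };
  std::vector<double> out(m * n);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      double acc = 0.0, x_norm = 0.0, y_norm = 0.0;
      for (std::size_t l = 0; l < k; ++l) {
        const double xv = at(x, m, i, l);
        const double yv = at(y, n, j, l);
        op.core(acc, xv, yv);
        if (Op::use_norms) {  // squared norms, only for ops that ask for them
          x_norm += xv * xv;
          y_norm += yv * yv;
        }
      }
      op.epilog(acc, x_norm, y_norm);
      out[i * n + j] = acc;  // output is always written row-major
    }
  }
  return out;
}

// The one place where the run-time flag becomes a template argument.
template <typename Op>
std::vector<double> pairwise(bool row_major, const std::vector<double>& x,
                             const std::vector<double>& y, std::size_t m, std::size_t n,
                             std::size_t k, Op op)
{
  return row_major ? pairwise_impl<true>(x, y, m, n, k, op)
                   : pairwise_impl<false>(x, y, m, n, k, op);
}
```

In this sketch, adding a metric only means writing another small op struct; the loop and the layout dispatch are shared, which is the reduction in copy-pasted code the observations above are asking for.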
This PR improves the ability to dispatch based on compute architecture. It is a follow-up to #1142. It has two goals:

1. Make it easier to specify which compute architectures a kernel is compatible with / should be compiled for.
2. Make it easier to compile a kernel only for the architectures for which it is used (if it is unused, the kernel should be empty).

We have a specific use case in RAFT for this feature. For the L2 pairwise distance kernel we have a CUTLASS-based implementation that works on SM80+ and a fallback kernel. Preferably, each kernel is only compiled for the architectures on which it is actually used.

Authors:
- Allard Hendriksen (https://github.com/ahendriksen)
- Corey J. Nolet (https://github.com/cjnolet)

Approvers:
- Tamas Bela Feher (https://github.com/tfeher)
- Corey J. Nolet (https://github.com/cjnolet)

URL: #1335
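The two goals of this follow-up can be pictured with a small host-only sketch. It is not the mechanism implemented in the PR: `sm_range`, `kernel_kind`, `select_kernel`, and `kernel_body` are hypothetical names, and architectures are assumed to be encoded as integers in the style of `__CUDA_ARCH__` (800 for SM80). The lower bound of 600 for the fallback range is also an assumption made for the example.

```cpp
#include <cassert>

// Hypothetical half-open range [Min, Max) of compute architectures.
template <int Min, int Max>
struct sm_range {
  static constexpr bool contains(int arch) { return Min <= arch && arch < Max; }
};

// Assumed split for illustration: fallback below SM80, CUTLASS from SM80 up.
using fallback_range = sm_range<600, 800>;
using cutlass_range  = sm_range<800, 100000>;

enum class kernel_kind { none, fallback, cutlass };

// Run-time side: pick the implementation for the architecture of the device.
constexpr kernel_kind select_kernel(int runtime_arch)
{
  return cutlass_range::contains(runtime_arch)    ? kernel_kind::cutlass
         : fallback_range::contains(runtime_arch) ? kernel_kind::fallback
                                                  : kernel_kind::none;
}

// Compile-time side: stands in for a kernel body compiled for CompiledArch.
// Outside its range the body does nothing, mirroring the "empty kernel" goal.
template <typename Range, int CompiledArch>
int kernel_body(int input)
{
  if (!Range::contains(CompiledArch)) { return 0; }  // body compiled out
  return input * 2;                                  // placeholder for real work
}
```

With these definitions, `select_kernel(700)` yields `kernel_kind::fallback` and `select_kernel(800)` yields `kernel_kind::cutlass`, while `kernel_body<cutlass_range, 700>` reduces to the empty branch because the range check is a compile-time constant.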