Add dispatch based on compute architecture #1295
Thanks Allard for this PR! It indeed provides a cleaner method to dispatch kernels based on GPU arch, and at the same time enables us to compile only for the intended architectures. Overall it looks good, see a few comments below.
NOTE: this PR is a follow-up to #1142. To keep the diff minimal, I have set the base branch to be the previous PR (instead of 23.04). This should be changed before merging :)
};

// A dummy kernel that is used to determine the runtime architecture.
__global__ inline void dummy_runtime_kernel() {}
This needs to be static so we don't run into the issue where multiple consumers of raft build with different arch
values and we get incorrect kernel selection.
For more info see: NVIDIA/cub#545
That's a good point. It looks like the dummy kernel approach requires making the kernel static to get a reliable solution, at the cost of littering the final binary with many empty kernels.
In `kernel_runtime_arch`, we are currently taking a pointer to the `dummy_runtime_kernel`. If instead we took a runtime argument that was a pointer to one of the candidate kernels that is going to be called, would that solve the problem? That is, I would remove the `dummy_runtime_kernel` and the kernel pointer would have to be provided by the user. I think it does solve the linking problem that you described above and it doesn't create spurious kernels, but I want to double check before I change the code.
Requiring a kernel pointer would work as well, since we would now be querying based on a specific kernel that was only compiled once.
Thanks a lot! I will go for that direction then.
I'm a little late to the party, but I came up with an idea for an alternative way of doing this that I like better because it avoids the empty kernel. See https://github.com/NVIDIA/cub/issues/556
Thanks for the pointer! I've been meaning to respond to this for a while, but never found the time to test my assertions.
We are currently (that is: in the PR that was merged) avoiding the empty kernel by forcing the caller to provide a pointer to one of the kernel versions. We then query the func attributes of that kernel.
The `__CUDA_ARCH_LIST__` approach looks worthwhile. However, it may break when kernels are weakly linked (e.g. templated). You describe the issue very well in #1722. I had not considered outlawing weak linking completely. Let's see how that goes!
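To make the `__CUDA_ARCH_LIST__` idea concrete, here is a minimal host-side sketch of how such a list could be used to guess which compiled architecture applies to a given device without launching or inspecting any kernel. It is not taken from the linked issue or from RAFT: the helper name `closest_compiled_arch` and the stand-in list `600, 700, 800` are assumptions for illustration, and the "largest entry not exceeding the device architecture" rule is a simplification of how the CUDA runtime actually selects code.

```cpp
#include <cassert>
#include <cstddef>

// nvcc (11.5 and later) defines __CUDA_ARCH_LIST__ itself as a comma-separated
// list of __CUDA_ARCH__ values. This stand-in is an assumption that only exists
// so the sketch also compiles with a plain C++ compiler.
#ifndef __CUDA_ARCH_LIST__
#define __CUDA_ARCH_LIST__ 600, 700, 800
#endif

// Hypothetical helper: returns the largest architecture in the list that does
// not exceed runtime_arch, or 0 if every listed architecture is newer.
inline int closest_compiled_arch(int runtime_arch)
{
  const int archs[] = {__CUDA_ARCH_LIST__};
  int best          = 0;
  for (std::size_t i = 0; i < sizeof(archs) / sizeof(archs[0]); ++i) {
    if (archs[i] <= runtime_arch && archs[i] > best) { best = archs[i]; }
  }
  return best;
}
```

Because the list is fixed per translation unit, two copies of the same weakly linked (e.g. templated) kernel built with different architecture lists can disagree with whichever copy the linker keeps, which is the caveat raised above.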
…ernel implementations (#1142)

The pairwise distance metrics are quite varied. The table below summarizes the differences, in terms of

- Epilog: whether the metric has a non-empty epilog operation.
- Uses norms: whether the metric requires precalculation of the norms of the vectors.
- Has params: whether the metric has additional parameters. The L2 metric, for instance, has the `sqrt` boolean parameter that determines whether to calculate the squared or actual distance.
- Pre- & post-processing: For some metrics, the norms have to be precalculated. For other metrics, the input matrices are transformed before the kernel launch, and "untransformed" after.
- Expensive inner loop: some metrics use `pow`, `log` or other expensive functions in the inner loop.
- Depends on row-major: the calculation of some metrics depends on whether the input is row-major.
- CUTLASS: some metrics have an implementation using CUTLASS and tensor cores.

| Metric | Epilog | Uses norms | Has params | Pre- & post-processing | Expensive inner loop | Depends on row-major | CUTLASS |
|---|---|---|---|---|---|---|---|
| Canberra | | | | | x | | |
| Chebyshev (Linf) | | | | | | | |
| Correlation | x | x (twice) | x (many) | compute norms | | x | |
| Cosine | x | x | | compute norms | | | x |
| Hamming | x | | x (k) | | | | |
| Hellinger | x | | | sqrt and square | | | |
| Jensen Shannon | x | | | | x | | |
| KL divergence | x | | x (row major, x == y) | yes | x | x | |
| L1 | | | | | | | |
| L2 expanded | x | x | x (sqrt) | compute norms | | | x |
| L2 unexpanded | x | | x (sqrt) | | | | |
| Minkowski (Lp) | x | | x (p) | | x | | |
| Russel-Rao | x | | x (k, 1/k) | | | | |

To keep the complexity that results from all these differences in check, there are several layers between the public API and the kernel launch, each with their own responsibility.

## Before

1. `raft::distance::pairwise_distance` takes distance type as a run-time argument and dispatches to `raft::distance::detail::pairwise_distance_impl`.
2. `raft::distance::detail::pairwise_distance_impl` allocates workspace as necessary and calls `raft::distance::detail::distance`.
3. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself.
4. `raft::distance::detail::distance` (with `fin_op`) initializes a `DistanceImpl` zero-sized struct with the correct template arguments and runs the `.run()` method of the struct.
5. `raft::distance::detail::DistanceImpl<DistanceType>.run()` calls `raft::distance::detail::XX_Impl`.
6. `raft::distance::detail::XX_Impl` has the following responsibilities:
   - Pre-compute norms if necessary
   - Transform input if necessary
   - If metric supports a CUTLASS operation, dispatch if necessary.
   - Swap inputs if column-major.
   - Based on runtime parameter `row_major`, dispatch to function template `raft::distance::detail::XX<bool row_major>`
7. `raft::distance::detail::XX`, based on alignment of input data, dispatches to function template `raft::distance::detail::XX_Impl<int veclen>` (different overload of previous `raft::distance::detail::XX_Impl`)
8. `raft::distance::detail::XX_Impl` has the following responsibilities:
   - Define `core_op` and `epilog_op`
   - Define `use_norms`
   - Launch kernel `pairwiseDistanceMatKernel` with correct launch parameters

**Observations**:

- Steps 6 and 7 both convert a runtime value to a compile-time constant (row-major layout and alignment).
- Step 7 is repeated (copy-pasted) for each metric.
- Steps 7 and 8 do a lot of different things and the steps in between do relatively little.
- Steps 1-5 do fairly little (but require a lot of boilerplate)

**Proposal**:

1. Collect as much of the runtime behavior of each metric as possible in a `distance_op` that [contains](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh):
   - The core_op
   - The epilog_op
   - The required shared memory
   - Whether the inner loop is expensive (and thus loop unrolling should be curtailed)
2. Collect the runtime -> compile-time dispatch in one location ([dispatch.cuh](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh#L70))
3. Collect kernel launching in one [location](https://github.com/ahendriksen/raft/blob/486393eff4e0cf1d45ab9d7990b64d607e835d70/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh#L108)
4. Remove some of the boilerplate in steps 1-5.

## After

1. `raft::distance::pairwise_distance` takes distance type as a run-time argument, allocates workspace as necessary, and dispatches to `raft::distance::detail::distance`.
2. `raft::distance::detail::distance` defines a default final operation (the identity) and calls an overload of itself.
3. `raft::distance::detail::distance` (with `fin_op`) calls an overload of `raft::distance::detail::distance_impl` for the correct distance type.
4. `raft::distance::detail::distance_impl` has the following responsibilities:
   - Pre-compute norms if necessary
   - Initialize distance op with parameters as necessary, see below for more information.
   - Transform input if necessary
   - If metric supports a CUTLASS operation, dispatch if necessary.
   - Dispatch to `raft::distance::detail::distance_matrix_dispatch`
5. `raft::distance::detail::distance_matrix_dispatch` has the following responsibilities:
   - Swap x, y matrices if column major
   - Dispatch to correct kernel based on run-time parameters `row_major` and `vec_len`
   - Determine kernel policy based on parameters
   - Call `raft::distance::detail::pairwise_matrix`
6. `raft::distance::detail::pairwise_matrix` launches the `raft::distance::detail::pairwise_matrix_kernel` with the correct launch parameters.

**Distance_op**

`raft::distance::detail::ops::XX_distance_op` [[example]](https://github.com/ahendriksen/raft/blob/wip-refactor-distance/cpp/include/raft/distance/detail/distance_ops/canberra.cuh) has the following responsibilities:

- Take any parameters (sqrt, k, etc)
- Define `core_op` and `epilog_op`
- Define `use_norms`, `expensive_inner_loop`, and `shared_mem_size()`.

Still TODO:

- [x] Rename Minkowski and Chebyshev to Lp and Linf.
- [x] Do something with this note in the comments: "if workspace is passed as nullptr, this will return in worksize, the number of bytes of workspace required", which is wrong.
- [x] Add a mechanism to limit duplicate compilation when a CUTLASS kernel is available. This is done in follow up PR #1295.
- [x] Some distance_ops have additional template parameters. This must be cleared up.

Authors:

- Allard Hendriksen (https://github.com/ahendriksen)
- Corey J. Nolet (https://github.com/cjnolet)

Approvers:

- Tamas Bela Feher (https://github.com/tfeher)
- Corey J. Nolet (https://github.com/cjnolet)

URL: #1142
rapids-bot closed this PR. I set the target of this PR to #1142 (so that the diff was reasonably small). Now, I cannot reopen this PR (because the target branch has been deleted) or retarget it (because this PR is closed). I fear I have to open a new PR. I will get back to your reviews! Apologies for the inconvenience.
This PR improves the ability to do dispatch based on compute architecture. It is a follow up to #1142.
It has two goals:
We have a specific use case in RAFT for this feature. For the L2 pairwise distance kernel we have a CUTLASS-based implementation that works on SM80+ and a fallback kernel. Preferably, each kernel is only compiled for the architectures on which it is actually used.
The previous approach can be described as follows (using pseudocode):
The main issue is that the generic kernel is copy-pasted. In the new approach, the compute architectures for which the generic kernel has to be compiled (the "compatibility range") are given as an argument (using a compile-time tag type). This allows the kernel to compare the architecture it is currently being compiled for against the compatibility range and exit early if that architecture is not supported. Therefore, we do not need to copy and paste the generic kernel.
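The compatibility-range idea can be sketched on the host in plain C++. This is an illustration under stated assumptions, not the code in this PR: `sm_range`, `generic_kernel`, `select_kernel` and `kernel_choice` are hypothetical names, and the template parameter `CompileArch` stands in for `__CUDA_ARCH__`, which is only available when compiling device code.

```cpp
#include <cassert>

// Hypothetical tag type: a half-open range [Min, Max) of compute architectures,
// numbered like __CUDA_ARCH__ (e.g. 800 for SM80).
template <int Min, int Max>
struct sm_range {
  static constexpr bool contains(int arch) { return Min <= arch && arch < Max; }
};

// Stand-in for the generic kernel. In device code CompileArch would be
// __CUDA_ARCH__; here it is a template parameter so the sketch runs on the host.
// Returns true if the kernel body ran, false if it exited early.
template <int CompileArch, typename Range>
bool generic_kernel(Range)
{
  if (!Range::contains(CompileArch)) {
    return false;  // outside the compatibility range: effectively an empty kernel
  }
  // ... the actual distance computation would go here ...
  return true;
}

enum class kernel_choice { generic, cutlass };

// Host-side selection: use the generic kernel inside its compatibility range,
// otherwise the other implementation (assumed here to be the CUTLASS kernel).
template <typename Range>
kernel_choice select_kernel(int runtime_arch, Range)
{
  return Range::contains(runtime_arch) ? kernel_choice::generic : kernel_choice::cutlass;
}

// Example: the generic kernel is only needed below SM80.
inline void example()
{
  using generic_range = sm_range<600, 800>;
  assert(generic_kernel<700>(generic_range{}));   // compiled for SM70: body kept
  assert(!generic_kernel<800>(generic_range{}));  // compiled for SM80: early exit
  assert(select_kernel(860, generic_range{}) == kernel_choice::cutlass);
}
```

The same tag is used twice: once inside the kernel to drop the body for unsupported architectures, and once on the host to choose which kernel to launch, so the range is stated in a single place instead of being repeated in copies of the kernel.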