[WIP] Custom fusedL2NN kernel for kmeans prediction #2050
Conversation
Tagging @tfeher @achirkin @cjnolet for some early review, before running more benchmarks and writing test cases (although I am pretty sure the accuracy should be fine). Thanks
I also created another PR for subsampling support: #2052
Thanks @abc99lr for the PR and the analysis! LGTM apart from a small nitpick with the `sqrt` flag.
Two suggestions for bonus points:

- Could you please run the prims-bench and maybe add relevant test cases there? e.g. `{100000, 128}, {1000000, 128}, {10000000, 128},` (sketched below)
- Would you consider changing the name of the PR to better reflect that it actually makes changes to the k-means rather than to ivf-pq?
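Purely as an illustration of the suggested cases (the struct and field names here are assumptions, not RAFT's actual prims-bench definitions):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical problem-size list for a fusedL2NN prims-bench; each entry
// is {num_rows, dim}, matching the cases suggested above.
struct bench_shape {
  std::int64_t num_rows;
  std::int64_t dim;
};

const std::vector<bench_shape> suggested_cases = {
  {100000, 128},
  {1000000, 128},
  {10000000, 128},
};
```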
```cpp
 * can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead.
 */
curr_distance =
  curr_distance *
  !((curr_distance * curr_distance <
     raft::distance::detail::ops::get_clamp_precision<DataT>()) *
    (dataset_norm[curr_row] == centers_norm[curr_n]));
if (sqrt) {
```
Change it to `constexpr` to make it clearer for both the reader and the compiler:

```diff
- if (sqrt) {
+ if constexpr (Sqrt) {
```
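As a standalone illustration of the suggestion (the function name is made up, not the PR's code): with `if constexpr` on a `bool` template parameter, the untaken branch is discarded at compile time, so no per-element runtime check remains.

```cpp
#include <cmath>

// Each instantiation of this template contains only one branch; a runtime
// `if (sqrt)` would keep both branches and test the flag on every call.
template <bool Sqrt, typename DataT>
DataT finalize_distance(DataT squared_distance)
{
  if constexpr (Sqrt) {
    return std::sqrt(squared_distance);
  } else {
    return squared_distance;
  }
}
```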
```cpp
@@ -380,6 +382,91 @@ void fusedL2NNImpl(OutT* min,
  }
}

template <bool sqrt, typename DataT, typename IdxT, typename LabelT>
```
nitpick: template parameters should start with a capital
```diff
- template <bool sqrt, typename DataT, typename IdxT, typename LabelT>
+ template <bool Sqrt, typename DataT, typename IdxT, typename LabelT>
```
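A compile-time parameter like `Sqrt` is typically fed from a runtime flag through an explicit dispatch. A minimal sketch with assumed names (`kernel_impl` and `dispatch` are illustrative, not RAFT's actual launcher):

```cpp
#include <cstdio>

// Hypothetical kernel body; one instantiation per Sqrt value.
template <bool Sqrt, typename DataT, typename IdxT, typename LabelT>
void kernel_impl()
{
  std::printf("running with Sqrt = %s\n", Sqrt ? "true" : "false");
}

// Turns a runtime boolean into the compile-time template parameter.
template <typename DataT, typename IdxT, typename LabelT>
void dispatch(bool sqrt)
{
  if (sqrt) {
    kernel_impl<true, DataT, IdxT, LabelT>();
  } else {
    kernel_impl<false, DataT, IdxT, LabelT>();
  }
}

int main()
{
  dispatch<float, int, int>(true);
  dispatch<float, int, int>(false);
  return 0;
}
```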
This PR addresses #1901 by subsampling the input dataset for PQ codebook training to reduce the runtime. Currently, a similar strategy is applied to the `per_cluster` method, but not to the default `per_subset` method. This PR fixes that gap. Similar to the subsampling mechanism of the `per_cluster` method, we pick at minimum `256*max(pq_book_size, pq_dim)` input rows for training each codebook.

https://github.com/rapidsai/raft/blob/cf4e03d0b952c1baac73f695f94d6482d8c391d8/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh#L408

The following performance numbers were generated on the Deep-100M dataset. After subsampling, search time and accuracy are not impacted (within +-5%), except for one case where I saw a 9% performance drop on search (using a 10K batch for search). More extensive benchmarking across datasets seems to be needed for justification.

| Dataset | n_iter | n_list | pq_bits | pq_dim | ratio | Original time (s) | Subsampling (s) | Speedup [subsampling] |
| -- | -- | -- | -- | -- | -- | -- | -- | -- |
| Deep-100M | 25 | 50000 | 4 | 96 | 10 | 129 | 89.5 | 1.44 |
| Deep-100M | 25 | 50000 | 5 | 96 | 10 | 128 | 89.4 | 1.43 |
| Deep-100M | 25 | 50000 | 6 | 96 | 10 | 131 | 90 | 1.46 |
| Deep-100M | 25 | 50000 | 7 | 96 | 10 | 129 | 91.1 | 1.42 |
| Deep-100M | 25 | 50000 | 8 | 96 | 10 | 149 | 93.4 | 1.60 |

Note that after subsampling, PQ codebook generation is no longer a bottleneck in IVF-PQ index building, so further optimizations of the codebook generation seem unnecessary. Although we could in theory apply the custom kernel approach (#2050) together with subsampling, my early tests show the current GEMM approach performs better than the custom kernel after subsampling. Using multiple streams could improve performance further by overlapping the kernels for different `pq_dim` values, since the kernels are small after subsampling and may not fully utilize the GPU. However, as mentioned above, since the entire PQ codebook generation is fast, this optimization may not be worthwhile.

TODO

- [x] Benchmark the performance/accuracy impacts on multiple datasets

Authors:

- Rui Lan (https://github.com/abc99lr)
- Ray Douglass (https://github.com/raydouglass)
- gpuCI (https://github.com/GPUtester)

Approvers:

- Tamas Bela Feher (https://github.com/tfeher)
- Corey J. Nolet (https://github.com/cjnolet)

URL: #2052
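As a rough sketch of the subsampling heuristic described in the PR above (the helper name and signature are assumptions for illustration, not the actual `ivf_pq_build.cuh` code):

```cpp
#include <algorithm>
#include <cstdint>

// Illustrative: number of rows to subsample for training one PQ codebook,
// at least 256 * max(pq_book_size, pq_dim), capped by the dataset size.
inline std::int64_t codebook_train_rows(std::int64_t n_rows,
                                        std::int64_t pq_book_size,
                                        std::int64_t pq_dim)
{
  const std::int64_t min_rows = 256 * std::max(pq_book_size, pq_dim);
  return std::min(n_rows, min_rows);
}
```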
This PR adds a custom fusedL2NN kernel in order to speed up the PQ codebook generation step of IVF-PQ index building. This PR partially addresses #1901 (we are also experimenting with other ideas, such as subsampling, to optimize the PQ codebook generation step).

The key kernel for PQ code generation performance is kmeans predict, which finds the cluster that is nearest (as defined by the distance function) to each data point. Depending on the distance type, this can be viewed as a variant of GEMM between `dataset` and `cluster_centers`. In this PR, we focus on L2 distance. The current implementation uses a CUTLASS kernel for this L2 distance calculation, with a fused reduction part that performs an `argmin` on the distances and outputs the label of the cluster closest to each input data point.
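For intuition, here is a minimal serial sketch of what the fused kernel computes per row; the actual PR implements this as a single fused GPU kernel, and the names and layout here are illustrative:

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Serial reference for fusedL2NN's output: for each dataset row, the label
// of the nearest cluster center under squared L2 distance. The real kernel
// fuses the GEMM-like distance computation and the argmin reduction.
std::vector<std::uint32_t> nearest_cluster(
  const std::vector<float>& dataset,  // [n_row x n_dim], row-major
  const std::vector<float>& centers,  // [n_cluster x n_dim], row-major
  std::size_t n_row,
  std::size_t n_cluster,
  std::size_t n_dim)
{
  std::vector<std::uint32_t> labels(n_row);
  for (std::size_t i = 0; i < n_row; ++i) {
    float best_dist          = std::numeric_limits<float>::max();
    std::uint32_t best_label = 0;
    for (std::size_t j = 0; j < n_cluster; ++j) {
      float dist = 0.0f;
      for (std::size_t d = 0; d < n_dim; ++d) {
        const float diff = dataset[i * n_dim + d] - centers[j * n_dim + d];
        dist += diff * diff;
      }
      if (dist < best_dist) {
        best_dist  = dist;
        best_label = static_cast<std::uint32_t>(j);
      }
    }
    labels[i] = best_label;
  }
  return labels;
}
```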
In the codebook generation step, we have a small input matrix for `cluster_centers` (size `[n_cluster, n_dim]`), since `n_dim` depends on `pq_len`, which is usually 1-24, and `n_cluster` depends on `2^pq_bit`, which ranges from 16 to 256 (`pq_bit` from [4, 8]). Also, the `dataset` is a very tall matrix (size `[n_row, n_dim]`), where `n_row`
is the number of rows in the training set after subsampling, which is usually on the order of millions.

This PR focuses on optimizing the GEMM with thin and small inputs. We found that a custom non-GEMM kernel achieves up to a 33x speedup over the original fused GEMM implementation for this problem when `n_rows` is 10M. For larger `n_clusters` and `n_dim`, the fused GEMM is still the best performer, but for the codebook generation step, most of the problem sizes are not ideal for GEMM. The following benchmark results were produced on an A100-80GB-PCIe machine.

To avoid performance regressions, in this PR we only use the custom kernel when `n_clusters * n_dim` is less than or equal to 256; those cases are circled in green in the picture above. Please let me know if you have ideas for better heuristics.

Here are E2E results on the Deep-100M dataset. More benchmark numbers to be posted. Search performance is not impacted by this PR: most recall/runtime differences are within 5%, and all are within 10%; this variance seems to be run-to-run difference.
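For clarity, the kernel-selection rule described above boils down to something like the following (the function is illustrative, not the PR's actual code):

```cpp
#include <cstdint>

// Illustrative heuristic: the custom non-GEMM kernel is only used for
// small, thin cluster-center matrices; larger problems stay on the
// fused CUTLASS GEMM path.
inline bool use_custom_fused_l2_nn_kernel(std::int64_t n_clusters, std::int64_t n_dim)
{
  return n_clusters * n_dim <= 256;
}
```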
TODO