
Filter out infinities in radix-based select-k #1742

Conversation

achirkin
Contributor

@achirkin achirkin commented Aug 16, 2023

As a means of filtering, ANN methods can produce a lot of repeated max_bound<T>/min_bound<T> values.
These are fed to a select_k function, which leads to poor performance if the radix-based implementation is used.
This is due to the nature of the algorithm (lots of values with the same bit representation).

This fix filters out max_bound<T>/min_bound<T> values as a special case. It works as follows:

  • In the zero-th pass (first histogram creation), we check the first k values of the input for being max_bound<T>/min_bound<T> and add them to the end of the output if found.
  • In the other passes, the max_bound<T>/min_bound<T> are explicitly ignored; this breaks the assumption that the inputs always have enough values; the PR makes the code not rely on this assumption by slightly modifying comparisons.
  • The back-fill sequence of k-th values (bits == kth_value_bits) is changed to fill the output from k - needed_num_of_kth in order to not override the max_bound<T>/min_bound<T> values written during the zero-th pass.
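The zero-th-pass behavior described above can be illustrated with a simplified host-side sketch (this is not the actual CUDA kernel in select_radix.cuh; the names `zeroth_pass`, `bucket_of`, and `bound_value` are hypothetical helpers invented for illustration):

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical bound value for a min-selection (max_bound<T> in the PR's terms).
template <typename T>
constexpr T bound_value() { return std::numeric_limits<T>::infinity(); }

// Sketch of the zero-th pass: build the first histogram while excluding
// bound values, and back-fill up to k bound values (taken from the first
// k inputs) at the end of the output. Returns the index of the first
// back-filled slot (k if none were written).
template <typename T, typename BucketOf>
std::size_t zeroth_pass(const std::vector<T>& in, std::size_t k,
                        std::vector<T>& out, std::vector<std::size_t>& histogram,
                        BucketOf bucket_of) {
  std::size_t out_back = k;  // slots [out_back, k) hold bound values
  for (std::size_t i = 0; i < in.size(); ++i) {
    if (in[i] == bound_value<T>()) {
      // Only bound values among the first k inputs back-fill the output.
      if (i < k && out_back > 0) out[--out_back] = in[i];
    } else {
      ++histogram[bucket_of(in[i])];  // bound values never enter the histogram
    }
  }
  return out_back;
}
```

Because the histogram never counts bound values, later passes can no longer assume the selected bins accumulate at least k elements, which is why the comparisons had to be adjusted.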

Closes: #1725

@achirkin achirkin added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change 2 - In Progress Currently a work in progress labels Aug 16, 2023
@github-actions github-actions bot added the cpp label Aug 16, 2023
@achirkin achirkin marked this pull request as ready for review August 21, 2023 07:44
@achirkin achirkin requested a review from a team as a code owner August 21, 2023 07:44
@achirkin achirkin requested a review from yong-wang August 21, 2023 07:45
@yong-wang
Contributor

In the zero-th pass (first histogram creation), we check the first k values of the input for being max_bound/min_bound and add them to the end of the output if found.

I think there is a corner case. What if the first k values don't contain enough max_bound<T>/min_bound<T> values? For example, suppose we need 5 infs in the top-k results and none of these 5 infs are in the first k values; then we get no inf in pass 0. During the last filtering step, we also won't get any inf because they are not saved in out_buf.

@achirkin
Contributor Author

achirkin commented Aug 21, 2023

This will work fine. The trick here is that if there are no bound values in the first k, that automatically means there are enough non-bound values there, because any input value should be preferred over the bound values.
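This invariant can be checked against a tiny reference selection. The sketch below (the helper name `topk_needs_bounds` is invented for illustration) shows that whenever no bound value sits among the first k inputs, those first k values alone already supply k non-bound candidates, so the correct top-k contains no bound values:

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

// Reference check: does a correct top-k (min-selection) of `in` have to
// contain any bound (inf) values? Requires k <= in.size().
inline bool topk_needs_bounds(std::vector<float> in, std::size_t k) {
  const float bound = std::numeric_limits<float>::infinity();
  std::partial_sort(in.begin(), in.begin() + k, in.end());
  return std::find(in.begin(), in.begin() + k, bound) != in.begin() + k;
}
```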

@yong-wang
Contributor

Got it. Really smart strategy.

@yong-wang
Contributor

The back-fill sequence of k-th values (bits == kth_value_bits) is changed to fill the output from k - needed_num_of_kth in order to not override the max_bound<T>/min_bound<T> values written during the zero-th pass.

Should the k-th values always be written from the end?

  • If max_bound<T>/min_bound<T> is not the k-th value, they should not appear in the result, and should be overwritten by values that are <= the k-th value.
  • If max_bound<T>/min_bound<T> is the k-th value, then the last filter won't write any k-th values to the output, so the max_bound<T>/min_bound<T> values written during pass 0 are untouched.

@achirkin
Contributor Author

The problem is that with this implementation max_bound<T>/min_bound<T> get special treatment; they do not appear in the histogram. As a result, the select_bucket assumption is broken; I came up with a workaround that selects the last bin for the next pass if there are not enough (non-bound) values. Hence we can end up in a situation where (bits == kth_value_bits) does not necessarily mean there are enough non-bound values (the last bin is selected even though the accumulated count is less than k). Luckily, we have needed_num_of_kth from the previous pass, so we know where to start writing the "k-th" values; it does not really matter in which order we write them. But if there are not enough "k-th" values, they will not override the bound values written during the zero-th pass.
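The modified back-fill described here can be sketched in simplified host-side form (this is not the actual last-filter kernel; the name `backfill_kth_values` is invented for illustration):

```cpp
#include <cstddef>
#include <vector>

// Sketch of the modified back-fill: values whose bits equal kth_value_bits
// are written starting at k - needed_num_of_kth. If fewer than
// needed_num_of_kth such values exist, the tail slots are left untouched,
// preserving the bound values that the zero-th pass placed at the end.
template <typename T>
void backfill_kth_values(const std::vector<T>& kth_values, std::size_t k,
                         std::size_t needed_num_of_kth, std::vector<T>& out) {
  std::size_t pos = k - needed_num_of_kth;  // first slot reserved for k-th values
  for (const T& v : kth_values) {
    if (pos >= k) break;  // never write past the output
    out[pos++] = v;       // order among equal "k-th" values is irrelevant
  }
  // slots [pos, k) keep whatever the zero-th pass wrote (bound values)
}
```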

@achirkin achirkin added 3 - Ready for Review and removed 2 - In Progress Currently a work in progress labels Aug 21, 2023
@yong-wang
Contributor

Thanks for the explanation.

The code looks good to me.

I suggest adding unit tests which contain infs.

However, I'm a little concerned about whether we should add such special treatment. I'll add comments in #1725, which has concrete context.

@achirkin
Contributor Author

Thanks for reminding me to write the tests! Indeed, there was a bug :) I forgot to filter out infinities in the last-filter (broken in the case where it takes the original input data as in_buf).

@yong-wang
Contributor

Found a bug.

The test case is:
len = 32, k = 31, select_min = true
in = {0, 1, 2, 3, inf, inf, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}
in_idx = {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}

Then the results are:
out = {0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, inf}
out_idx = {31, 30, 29, 28, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 4}

The last value in out_idx is 4, but it should be 27 or 26, the values in in_idx corresponding to inf.
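The expected result for this case can be reproduced with a host-side reference using std::stable_sort (the helper name `reference_topk_idx` is invented for illustration); the index paired with the selected inf must be one of the in_idx values that actually point at an inf:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Reference min-selection: stable-sort positions by value, take the first k,
// and return the corresponding in_idx entries.
inline std::vector<uint32_t> reference_topk_idx(const std::vector<float>& in,
                                                const std::vector<uint32_t>& in_idx,
                                                std::size_t k) {
  std::vector<std::size_t> order(in.size());
  for (std::size_t i = 0; i < order.size(); ++i) order[i] = i;
  std::stable_sort(order.begin(), order.end(),
                   [&](std::size_t a, std::size_t b) { return in[a] < in[b]; });
  std::vector<uint32_t> out_idx(k);
  for (std::size_t i = 0; i < k; ++i) out_idx[i] = in_idx[order[i]];
  return out_idx;
}
```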

@achirkin
Contributor Author

Indeed, apparently the index isn't passed in the zero-th pass in the one-block kernel.

https://github.com/rapidsai/raft/blob/09ab49b22ae2a396794beae10ac16ef8524d3ce3/cpp/include/raft/matrix/detail/select_radix.cuh#L749C12-L749C12

Thanks! I'll fix and add your test case tomorrow. This shouldn't affect performance in any way, so I'll skip re-running all benchmarks.

@tfeher
Contributor

tfeher commented Sep 11, 2023

Thank you @achirkin for implementing this workaround and for the detailed benchmarks presented here! Also thanks to @yong-wang for the constructive discussion and for further benchmarks. The conversation here and in issue #1725 was really detailed and illuminating.

After going through the discussion, I have the following picture (please correct me if I am wrong):

  • The PR implements a workaround that improves top-k search for the case when the input has a very high number of infinities.
  • We see a clear advantage of this when there are fewer than k non-inf values in the input.
  • Such a case can occur in practice, but it probably also means that the problem parameters are set up incorrectly (the normalization and precision used, or too-strong filtering).
  • The changes affect register pressure and code complexity. There is a +/- 10% perf diff in the benchmarks. The average perf change (over the benchmark cases with only a modest amount of infs) is close to 0.

I tend to agree with Yong that it would be preferable not to further complicate the radix k-selection code if it only treats a corner case. So whether we should integrate this PR into RAFT depends on whether the corner case needs to be addressed or not.

As Yong has pointed out, having so many infs during ANN search means

some other serious problem has already occurred, and the recall will be low. [...]
The same reasoning applies to ANN pre-filtering. If so many values are deleted that ANN could not return k items with valid distances, it means too many values have been deleted,

According to Artem,

the proposed fix adds the value in that it fixes the x10 slowdown in some edge cases with little to no cost to any of the other cases. Aside from the zero-th pass it doesn't really complicate the logic that much.

These are all important points. While it is true that on average there is practically no perf change for the non-corner cases, we average the results of arbitrarily defined gbench benchmarks. I am a bit concerned about the +/- 10% effect on these benchmarks: are we sure that we average a relevant subset? In Yong's benchmark plot we see a small but noticeable perf degradation. Instead of averaging the gbench benchmarks, I believe we should take a set of relevant ANN benchmarks and see how this PR affects the perf there (alternatively, define gbench tests that correspond to these).

Because of these concerns, for the regular ANN search case, I would be happier to get a warning message like "k-th value is inf, please check your precision/normalization/filtering", instead of modifying the k-selection kernels.

I am not so sure about the pre-filtering. This could still motivate the solution presented in this PR. @cjnolet do we expect to filter so many values that fewer than k non-inf values remain in the end? Do we expect this to occur so often in practice that we should add a special case for radix topk? If yes, then I would be in favor of merging this.

rapids-bot bot pushed a commit that referenced this pull request Sep 18, 2023
Add a few extra test and benchmark cases; in particular:
  1. Allow specifying non-trivial input indices
  2. Allow filling the input data with infinities to see how algorithms perform in edge cases

These tests are borrowed from the controversial workaround #1742

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1821
@cjnolet
Member

cjnolet commented Sep 20, 2023

@cjnolet do we expect to filter so many values, that less than k non-inf values remain in the end? Do we expect this to occur so often in practice, that we should add a special case for radix topk?

For deletion, we haven't gotten a whole lot of consensus on patterns encountered in practice, but we have been told that the actual number of valid values for a query can be less than k for some data points. We really need to be able to support the generalized cases very efficiently, so assume that not everyone is going to be returning a list with <5 materialized values, but folks who need that capability would prefer not to take a perf hit, especially since there's already going to be a perf hit in the filtering functions themselves.

Aside from delete, consider other (upcoming) use cases, like filtering recommendations for items that a user has already purchased, or multi-valued keys where a document might only be returned once even if multiple tokens for the same document end up in the list of nearest neighbors. We want these cases to be fast, but we probably don't want to do it at the expense of the non-filtered case, since I still believe that's going to be the most widely used.

@achirkin achirkin changed the base branch from branch-23.10 to branch-24.02 December 15, 2023 09:16
@achirkin achirkin marked this pull request as draft December 15, 2023 09:17
@achirkin achirkin added 0 - Stale / Orphaned PR is too outdated and needs significant rework, or author is no longer responsible. and removed 3 - Ready for Review labels Dec 15, 2023
@cjnolet cjnolet closed this Jan 10, 2024
Successfully merging this pull request may close these issues.

[BUG] radix::select_k<half> is slow