-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Create cub-based argmin primitive and replace argmin_along_rows
in ANN kmeans
#912
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for these changes, @Nyrio! Definitely happy to see prims for both argmin and argmax and more consolidation being done on the new ANN algos. Mostly minor things.
…ong-rows # Conflicts: # build.sh
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Louis for the PR, I just have a few small comments.
@lowener @cjnolet The new argmax header was slightly wrong, and the test was incorrect as well (see explanation below). Please have a detailed look at my latest change to see if you agree with the way I solved this. The wrapper takes a row-major matrix view, but the implementation is in column-major semantics. As a result:
I've kept the column-major convention in the deprecated header, fixed the new ones for the row-major convention, and switched the implementation to drop row and columns in favor of N (number of reductions) and D (elements to reduce) and documented that. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, your changes on argmax make sense for me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for fixing the intermediate issues as well!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Louis for addressing the issues, it looks good to me!
@gpucibot merge |
This PR follows up on a suggestion from @cjnolet. The new
argmin
primitive is up to 5x faster thanargmin_along_rows
for dimensions relevant to ANN kmeans, and removes code duplication.The reasons why it is faster are:
argmin_along_rows
often misses on doing a sequential reduction before the tree reduction, especially as it uses large block sizes, as much as 1024.argmin_along_rows
.argmin
prim to using thecub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
algorithm, we can get up to 30% further speedup! (I believe it's safe to use the commutative algorithm here since the offset is contained in the key-value pair so the reduction operation is commutative).The speedup that I have measured for IVF-Flat build with the
InnerProduct
metric is around 15%.