-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fused L2 1-NN based on cutlass 3xTF32 / DMMA #1118
Fused L2 1-NN based on cutlass 3xTF32 / DMMA #1118
Conversation
…ck/warp shape size, this now touches the perf of fusedL2NN simt kernel
… from per row to multi-rows. now this kernel is 1.3x to 1.8x faster as k value increases perf gets better than fusedL2NN simt kernel
Pull requests from external contributors require approval from a |
… to predicated tile iterator
… than per warp multi row lock, cleanup and doc update
/okay to test |
/ok to test |
8169088
to
cc3e669
Compare
…election, add comments on register spills tile shape, add test case for veclen=2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @mdoijade for the updates! The PR has changed significantly since the previous review round to enable an impressive speedup. Here are the first batch of comments of my second review.
This PR adds customized cutlass headers to implement the fused L2 NN operation.
While the changeset seem to be large, it is actually much smaller, if we compare the new cutlass headers to their original version (based on the reference added to files, and after applying raft formatting to the originals). Comparing that way makes it easier to follow how the cutlass code was adapted to our needs, and reveals a clean implementation. Great work, I have only smaller comments (so far)!
cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h
Outdated
Show resolved
Hide resolved
cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h
Show resolved
Hide resolved
cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh
Outdated
Show resolved
Hide resolved
cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh
Outdated
Show resolved
Hide resolved
cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh
Show resolved
Hide resolved
…fix cutlass_utils.h comments
…ing 2 copies of the same code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Mahesh for the updates so far. Here is my second batch of comments.
cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h
Outdated
Show resolved
Hide resolved
cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h
Outdated
Show resolved
Hide resolved
I've lost track of where we are with this PR. Do you guys think this will make it into 23.06? (Burndown is in 3 days). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Mahesh for the updates!
@cjnolet, the PR is in a good shape, most of the issues have been addressed, it shall make it to 23.06.
cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Mahesh for resolving the issues! The PR looks good to me.
/merge |
-- 3xTF32 & DMMA cutlass based persistent FusedL2NN kernel version loosely based on grouped gemm but customized for single problem size.
-- as the value of
k
increases the performance benefit of this implementation gets better.for k==64 upto 1.3x, for k ==128 upto 1.53x, k == 256, up to 1.67x.
-- for all the sizes of
k
this kernel out performs previous implementation.-- attaching the results of FusedL2NN Benchmark of previous implementation with this cutlass version.