[FEA] support of prefiltered brute force #2294
Conversation
- This PR is one part of the feature in rapidsai#1969: it adds the `search_with_filtering` API for brute force. Authors: - James Rong (https://github.com/rhdong)
…DDMM with faster_dot_on_csr
Hey @cjnolet @benfred, I implemented an initial version of the `faster_dot_on_csr` kernel; here are the benchmark results:

***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
------------------------------------------------------------------------------------------------------
Benchmark                                                    cuSparseSDDMM   faster_dot   Iterations   Configuration: n_sample#dim#n_query#top_k#remove_rate#metric ("No filtered" means the normal `search`)
------------------------------------------------------------------------------------------------------
KNN/float/int64_t/brute_force_filter_knn/0/0/0/manual_time 11.4 ms 11.4 ms 61 1000000#4096#1#255#[No filtered]#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/1/0/0/manual_time 11.5 ms 11.4 ms 61 1000000#4096#1#255#[No filtered]#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/2/0/0/manual_time 137 ms 7.02 ms 4 1000000#4096#1#255#0.8#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/3/0/0/manual_time 137 ms 7.03 ms 5 1000000#4096#1#255#0.8#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/4/0/0/manual_time 70.3 ms 4.72 ms 10 1000000#4096#1#255#0.9#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/5/0/0/manual_time 70.2 ms 4.72 ms 10 1000000#4096#1#255#0.9#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/6/0/0/manual_time 8.94 ms 2.05 ms 77 1000000#4096#1#255#0.99#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/7/0/0/manual_time 8.97 ms 2.06 ms 77 1000000#4096#1#255#0.99#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/8/0/0/manual_time 17.1 ms 17.1 ms 41 1000000#4096#10#255#[No filtered]#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/9/0/0/manual_time 17.2 ms 17.2 ms 41 1000000#4096#10#255#[No filtered]#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/10/0/0/manual_time 149 ms 36.0 ms 5 1000000#4096#10#255#0.8#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/11/0/0/manual_time 149 ms 36.1 ms 5 1000000#4096#10#255#0.8#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/12/0/0/manual_time 76.0 ms 19.2 ms 9 1000000#4096#10#255#0.9#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/13/0/0/manual_time 76.2 ms 19.2 ms 9 1000000#4096#10#255#0.9#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/14/0/0/manual_time 9.43 ms 3.39 ms 72 1000000#4096#10#255#0.99#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/15/0/0/manual_time 9.45 ms 3.40 ms 72 1000000#4096#10#255#0.99#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/16/0/0/manual_time 488 ms 489 ms 2 1000000#4096#1000#255#[No filtered]#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/17/0/0/manual_time 495 ms 494 ms 2 1000000#4096#1000#255#[No filtered]#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/18/0/0/manual_time 1916 ms 3277 ms 1 1000000#4096#1000#255#0.8#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/19/0/0/manual_time 1915 ms 3280 ms 1 1000000#4096#1000#255#0.8#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/20/0/0/manual_time 968 ms 1641 ms 1 1000000#4096#1000#255#0.9#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/21/0/0/manual_time 969 ms 1643 ms 1 1000000#4096#1000#255#0.9#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/22/0/0/manual_time 108 ms 167 ms 6 1000000#4096#1000#255#0.99#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/23/0/0/manual_time 108 ms 167 ms 6 1000000#4096#1000#255#0.99#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/24/0/0/manual_time 2.64 ms 2.64 ms 265 1000000#512#1#255#[No filtered]#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/25/0/0/manual_time 2.65 ms 2.65 ms 264 1000000#512#1#255#[No filtered]#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/26/0/0/manual_time 21.4 ms 4.84 ms 32 1000000#512#1#255#0.8#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/27/0/0/manual_time 21.4 ms 4.84 ms 33 1000000#512#1#255#0.8#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/28/0/0/manual_time 12.0 ms 3.60 ms 57 1000000#512#1#255#0.9#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/29/0/0/manual_time 12.0 ms 3.62 ms 57 1000000#512#1#255#0.9#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/30/0/0/manual_time 2.83 ms 1.92 ms 246 1000000#512#1#255#0.99#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/31/0/0/manual_time 2.84 ms 1.94 ms 245 1000000#512#1#255#0.99#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/32/0/0/manual_time 3.57 ms 3.57 ms 196 1000000#512#10#255#[No filtered]#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/33/0/0/manual_time 3.63 ms 3.63 ms 193 1000000#512#10#255#[No filtered]#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/34/0/0/manual_time 22.3 ms 14.4 ms 31 1000000#512#10#255#0.8#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/35/0/0/manual_time 22.3 ms 14.4 ms 31 1000000#512#10#255#0.8#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/36/0/0/manual_time 12.6 ms 8.33 ms 55 1000000#512#10#255#0.9#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/37/0/0/manual_time 12.6 ms 8.35 ms 55 1000000#512#10#255#0.9#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/38/0/0/manual_time 2.80 ms 2.28 ms 249 1000000#512#10#255#0.99#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/39/0/0/manual_time 2.81 ms 2.29 ms 247 1000000#512#10#255#0.99#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/40/0/0/manual_time 75.0 ms 74.5 ms 9 1000000#512#1000#255#[No filtered]#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/41/0/0/manual_time 79.9 ms 79.7 ms 9 1000000#512#1000#255#[No filtered]#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/42/0/0/manual_time 210 ms 1085 ms 3 1000000#512#1000#255#0.8#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/43/0/0/manual_time 215 ms 1088 ms 3 1000000#512#1000#255#0.8#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/44/0/0/manual_time 105 ms 543 ms 7 1000000#512#1000#255#0.9#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/45/0/0/manual_time 106 ms 544 ms 7 1000000#512#1000#255#0.9#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/46/0/0/manual_time 12.6 ms 56.6 ms 55 1000000#512#1000#255#0.99#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/47/0/0/manual_time 12.7 ms 56.8 ms 54 1000000#512#1000#255#0.99#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/48/0/0/manual_time 1.72 ms 1.71 ms 407 1000000#128#1#255#[No filtered]#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/49/0/0/manual_time 1.73 ms 1.72 ms 406 1000000#128#1#255#[No filtered]#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/50/0/0/manual_time 8.86 ms 3.99 ms 79 1000000#128#1#255#0.8#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/51/0/0/manual_time 8.87 ms 4.00 ms 79 1000000#128#1#255#0.8#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/52/0/0/manual_time 5.64 ms 3.19 ms 123 1000000#128#1#255#0.9#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/53/0/0/manual_time 5.62 ms 3.19 ms 120 1000000#128#1#255#0.9#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/54/0/0/manual_time 2.15 ms 1.87 ms 326 1000000#128#1#255#0.99#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/55/0/0/manual_time 2.16 ms 1.89 ms 323 1000000#128#1#255#0.99#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/56/0/0/manual_time 2.13 ms 2.12 ms 328 1000000#128#10#255#[No filtered]#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/57/0/0/manual_time 2.19 ms 2.18 ms 320 1000000#128#10#255#[No filtered]#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/58/0/0/manual_time 9.41 ms 6.07 ms 75 1000000#128#10#255#0.8#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/59/0/0/manual_time 9.44 ms 6.09 ms 74 1000000#128#10#255#0.8#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/60/0/0/manual_time 5.94 ms 4.16 ms 114 1000000#128#10#255#0.9#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/61/0/0/manual_time 5.97 ms 4.18 ms 113 1000000#128#10#255#0.9#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/62/0/0/manual_time 2.10 ms 1.86 ms 332 1000000#128#10#255#0.99#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/63/0/0/manual_time 2.11 ms 1.87 ms 331 1000000#128#10#255#0.99#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/64/0/0/manual_time 33.1 ms 33.1 ms 21 1000000#128#1000#255#[No filtered]#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/65/0/0/manual_time 38.0 ms 38.0 ms 18 1000000#128#1000#255#[No filtered]#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/66/0/0/manual_time 54.2 ms 244 ms 13 1000000#128#1000#255#0.8#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/67/0/0/manual_time 57.4 ms 247 ms 12 1000000#128#1000#255#0.8#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/68/0/0/manual_time 24.6 ms 121 ms 29 1000000#128#1000#255#0.9#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/69/0/0/manual_time 26.1 ms 123 ms 27 1000000#128#1000#255#0.9#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/70/0/0/manual_time 4.69 ms 14.4 ms 149 1000000#128#1000#255#0.99#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/71/0/0/manual_time 4.83 ms 14.5 ms 138 1000000#128#1000#255#0.99#L2Expanded#NO_COPY#SEARCH
@rhdong these benchmarks look promising so far. Can you run these for the following data shapes and also provide the standard (non-filtered) brute-force kNN numbers for reference, please? Num vectors: [100k, 1M, 10M]

Can you also see how far we can take the sparsity? It would be helpful to know what the results look like for 60%, 70%, and 80% sparsity.

From what I can see in your benchmarks above, it appears as though the performance degrades significantly for your new kernel as k grows, but shouldn't the k-selection be a separate operation altogether? Do you have any ideas or insights into this behavior?
Hi @cjnolet, this benchmark is end-to-end, not only the new kernel, so the chart may not be entirely clear on its own; the conclusion should be that performance decreases with …
Hello @cjnolet, here is the latest benchmark. I improved the new kernel a little: it now has more stable performance (without the large, odd performance gaps) and it no longer needs an extra buffer (which could cause OOM at lower sparsity, as with cuSparseSDDMM). However, in some cases (roughly 30-40% of them) it is not faster than cuSparseSDDMM, so I intend to combine the two behind a branch that selects between them. Before that, I'd like to hear your advice, thank you!

The following cases are still running.
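To make the "combine the two behind a branch" idea above concrete, here is a minimal dispatch sketch. The enum, function name, and threshold are placeholders (assumptions for illustration), not the actual RAFT code; the heuristic simply reflects the benchmark results shown earlier.

```cpp
#include <cstdint>

// Placeholder names and threshold; the real dispatch point would live inside
// the filtered brute-force search implementation.
enum class masked_matmul_impl { cusparse_sddmm, faster_dot_on_csr };

inline masked_matmul_impl choose_masked_matmul_impl(std::int64_t n_queries)
{
  // In the benchmarks above, the custom kernel wins clearly for small query
  // batches (n_query = 1 or 10) but loses to cuSPARSE SDDMM at n_query = 1000;
  // the exact crossover is assumed here and would need tuning per GPU/dataset.
  constexpr std::int64_t kQueryBatchThreshold = 100;  // assumed, not measured
  return n_queries <= kQueryBatchThreshold ? masked_matmul_impl::faster_dot_on_csr
                                           : masked_matmul_impl::cusparse_sddmm;
}
```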
value_t l_dot_ = 0.0;
#pragma unroll
for (value_idx k = vec_id; k < dim; k += blockDim.x) {
  asm("prefetch.global.L2 [%0];" ::"l"(B_col + k + blockDim.x));
Hi @achirkin, I guess this is the symbol the check is concerned about here. What should I do to keep the same functionality? (This prefetch is very useful for performance.)
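For reference, a minimal sketch of one way to keep the prefetch while confining the inline PTX to device compilation. This is an assumption about the concern being raised, not the change actually made in the PR.

```cpp
// Sketch only: emit the inline-PTX prefetch in device compilation and fall
// back to a no-op otherwise, so the rest of the code path stays unchanged.
__device__ __forceinline__ void prefetch_global_l2(const void* ptr)
{
#if defined(__CUDA_ARCH__)
  asm volatile("prefetch.global.L2 [%0];" ::"l"(ptr));
#else
  (void)ptr;  // nothing to do outside device code
#endif
}
```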
@@ -25,12 +28,18 @@
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
Just a heads up, this becomes `distance/distance.hpp` in cuvs.
…into rhdong/prefiltered-bf
cpp/CMakeLists.txt
@@ -334,6 +334,14 @@ if(RAFT_COMPILE_LIBRARY)
src/matrix/detail/select_k_float_int32.cu
src/matrix/detail/select_k_half_int64_t.cu
src/matrix/detail/select_k_half_uint32_t.cu
src/sparse/matrix/detail/select_k_half_uint32_t.cu
We're removing libraft.so in a future version, so I think we probably want to avoid adding these instantiations.
@@ -16,7 +16,7 @@

#pragma once

#include <raft/core/bitset.cuh>
#include <raft/core/bitmap.hpp>
Sounds great! Just make sure to keep all CUDA stuff out of the hpp file.
`bitmap.cuh` should include `bitset.cuh` since it is inheriting from the bitset class. This will avoid having undefined references for `n_elements()`, for example.
> `bitmap.cuh` should include `bitset.cuh` since it is inheriting from the bitset class. This will avoid having undefined references for `n_elements()`, for example.
Thank you @lowener, that makes sense; just fixed it!
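An illustrative sketch of the header layout being agreed on here, under the assumption that CUDA-dependent pieces live in the `.cuh` headers; the file contents below are placeholders, not the actual RAFT sources.

```cpp
// Layout sketch (illustrative only):
//   bitmap.hpp -- host-safe declarations, no CUDA headers included;
//   bitmap.cuh -- device-side definitions, which must see the full bitset
//                 definition because the bitmap view inherits from it
//                 (otherwise members such as n_elements() show up as
//                 undefined references).

// bitmap.cuh
#include <raft/core/bitmap.hpp>  // the CUDA-free declarations
#include <raft/core/bitset.cuh>  // full definition of the bitset base class
```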
value_t* s_A = (value_t*)smem;
value_idx cur_row = -1;

#pragma unroll
Have you verified in the underlying SASS that this pragma unroll is doing anything? This does a static analysis when it's compiled and relies on compile-time constants to pull the runtime loops out into independent compiled loops. When the number of iterations is a runtime value, this will be silently skipped.
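A small assumed illustration of this point (not the PR's kernel): `#pragma unroll` can only fully unroll a loop whose trip count is known at compile time.

```cpp
// Compile-time trip count: the compiler can fully unroll the loop.
template <int Dim>
__device__ float dot_fixed(const float* a, const float* b)
{
  float acc = 0.f;
#pragma unroll
  for (int k = 0; k < Dim; ++k) { acc += a[k] * b[k]; }
  return acc;
}

// Runtime trip count: full unrolling is impossible, so the pragma is silently
// ignored (or at best applied as a partial unroll the compiler chooses).
__device__ float dot_dynamic(const float* a, const float* b, int dim)
{
  float acc = 0.f;
#pragma unroll
  for (int k = 0; k < dim; ++k) { acc += a[k] * b[k]; }
  return acc;
}
```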
Follow-up question: why are we iterating through all of the rows in each block? Shouldn't each block or warp be launched on an independent row? Otherwise, every block is going to duplicate the work, isn't it?

Edit: I should have looked more closely at where you are launching the kernel; I see this is actually chunks of rows. To keep future readers from making the same mistake, can you please rename `nrows` to something more specific? Even something like `nrows_in_chunk` or `nrows_for_block` would be more readable.

Another thought: converting the CSR to a COO would be a more efficient way to do this, especially for CSR matrices of high sparsity. That way you could just have the kernel go through the edge list element-wise and perform the needed dots. Each item in the bitmap would be one dot product, so you would guarantee more uniformly distributed chunks by chunking over the COO instead of over rows of the CSR.
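To make the COO suggestion concrete, here is a hedged sketch (illustrative only, not the PR kernel, and all names are assumed): iterate the filter's nonzeros element-wise in a grid-stride loop so each nonzero maps to exactly one dot product, which keeps the work per thread uniform even when per-row nonzero counts are skewed.

```cpp
#include <cstddef>

// coo_rows/coo_cols hold one (query, dataset) pair per surviving candidate.
__global__ void masked_dot_coo(const int* coo_rows,   // query index per nonzero
                               const int* coo_cols,   // dataset index per nonzero
                               int nnz,
                               const float* queries,  // n_queries x dim, row-major
                               const float* dataset,  // n_dataset x dim, row-major
                               int dim,
                               float* out)            // one dot product per nonzero
{
  // Grid-stride loop over nonzeros: work is balanced per element, not per row.
  for (int e = blockIdx.x * blockDim.x + threadIdx.x; e < nnz;
       e += gridDim.x * blockDim.x) {
    const float* a = queries + static_cast<std::size_t>(coo_rows[e]) * dim;
    const float* b = dataset + static_cast<std::size_t>(coo_cols[e]) * dim;
    float acc = 0.f;
    for (int k = 0; k < dim; ++k) { acc += a[k] * b[k]; }
    out[e] = acc;
  }
}
```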
Got it! I will remove the unhelpful unrolls. I did try a COO version, but it was somewhat less efficient because this kernel is memory-intensive and the COO format increases the amount of memory read.
> I did try a COO version, but it was somewhat less efficient because this kernel is memory-intensive and the COO format increases the amount of memory read.

That's not true; COO is for load balancing. Let's revisit this after 24.08.
C++ maintainability review
template <typename value_t, typename index_t>
void popc(const raft::resources& res,
          device_vector_view<value_t, index_t> values,
          index_t max_len,
Should this be a `raft::host_scalar_view<index_t>` instead, for consistency?
It's OK with me, but I'm not so sure how to judge when a primitive type needs to be a `raft::host_scalar_view<index_t>`?
I think that's just the API consistency that we are going for. If it's an array, use vector/matrix_view and if it's a scalar use scalar_view.
cc @cjnolet for his opinion
I think primitive types are okay for values that might never need to originate or be returned on device, as opposed to host.
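For concreteness, a sketch of the two shapes being discussed for the length parameter. These are declarations only, the trailing parameters from the truncated snippet above are elided, and neither is claimed to be RAFT's final signature.

```cpp
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resources.hpp>

// Option A: plain primitive host value, as in the current snippet.
template <typename value_t, typename index_t>
void popc(const raft::resources& res,
          raft::device_vector_view<value_t, index_t> values,
          index_t max_len /* , ...remaining parameters elided... */);

// Option B: wrap the scalar in a host_scalar_view for API consistency.
template <typename value_t, typename index_t>
void popc(const raft::resources& res,
          raft::device_vector_view<value_t, index_t> values,
          raft::host_scalar_view<const index_t> max_len /* , ... */);
```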
@@ -68,7 +68,7 @@ RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(cons

while (offset < num_cols) {
  index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP;
  bitmap_t l_bitmap = bitmap_t(0);
  typename std::remove_const<bitmap_t>::type l_bitmap = 0;
Suggested change:
- typename std::remove_const<bitmap_t>::type l_bitmap = 0;
+ std::remove_const_t<bitmap_t> l_bitmap = 0;
- plus Improve C++ maintainability
…into rhdong/prefiltered-bf
/merged
/merge
- This PR is one part of the prefiltered brute force feature and should work with the raft PR rapidsai/raft#2294. Authors: - rhdong (https://github.com/rhdong) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #146
Authors: