-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add a distance epilogue function to the bfknn call #1371
Changes from 7 commits
fbe6f29
319ffea
c341c97
acc674f
f103faf
87c22ff
d5e407b
d0e007f
962f495
cd335a2
78f7b1f
5e4919f
18e19a9
eea9baa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,7 +47,9 @@ using namespace raft::spatial::knn; | |
* Calculates brute force knn, using a fixed memory budget | ||
* by tiling over both the rows and columns of pairwise_distances | ||
*/ | ||
template <typename ElementType = float, typename IndexType = int64_t> | ||
template <typename ElementType = float, | ||
typename IndexType = int64_t, | ||
typename DistanceEpilogue = raft::identity_op> | ||
void tiled_brute_force_knn(const raft::device_resources& handle, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As mentioned above, id eventually love to find ways we can consolidate (and reuse) this with the gram matrix APIs. Not an immediate priority but the current gramm matrix API is a class that doesn't need to store state and we have a todo to convert it into flattened public API functions like the rest of RAFT. |
||
const ElementType* search, // size (m ,d) | ||
const ElementType* index, // size (n ,d) | ||
|
@@ -58,9 +60,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, | |
ElementType* distances, // size (m, k) | ||
IndexType* indices, // size (m, k) | ||
raft::distance::DistanceType metric, | ||
float metric_arg = 0.0, | ||
size_t max_row_tile_size = 0, | ||
size_t max_col_tile_size = 0) | ||
float metric_arg = 0.0, | ||
size_t max_row_tile_size = 0, | ||
size_t max_col_tile_size = 0, | ||
DistanceEpilogue distance_epilogue = raft::identity_op()) | ||
{ | ||
// Figure out the number of rows/cols to tile for | ||
size_t tile_rows = 0; | ||
|
@@ -169,8 +172,22 @@ void tiled_brute_force_knn(const raft::device_resources& handle, | |
// cause NaN values in the subsequent sqrt. Clamp to 0 | ||
val = val * (val >= 0.0001); | ||
if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); } | ||
val = distance_epilogue(val, row, col); | ||
return val; | ||
}); | ||
} else { | ||
// if we're not l2 distance, and we have a distance epilogue - run it now | ||
if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) { | ||
auto distances_ptr = temp_distances.data(); | ||
raft::linalg::map_offset( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very nice! |
||
handle, | ||
raft::make_device_vector_view(temp_distances.data(), | ||
current_query_size * current_centroid_size), | ||
[=] __device__(size_t i) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it makes sense to do this as an epilogue on the pairwise distances rather than burdening the (already register constrained) k selection with it. This formulation of the algorithm (tiled knn with epilogue after tiling) makes this portion equivalent to our gramm matrix API for constructing RKHS kernels (see raft::distance::kernel) and it makes me wonder if there's something to be gained by consolidating these eventually (eg brute force knn primitive becomes a composition of [tiled gramm + epilogue] + k-selection. That could also allow us to reuse more from the gram APIs. |
||
return distance_epilogue( | ||
distances_ptr[i], i / current_centroid_size, i % current_centroid_size); | ||
}); | ||
} | ||
} | ||
|
||
select_k<IndexType, ElementType>(temp_distances.data(), | ||
|
@@ -250,7 +267,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, | |
* @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) | ||
* @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm | ||
*/ | ||
template <typename IntType = int, typename IdxType = std::int64_t, typename value_t = float> | ||
template <typename IntType = int, | ||
typename IdxType = std::int64_t, | ||
typename value_t = float, | ||
typename DistanceEpilogue = raft::identity_op> | ||
void brute_force_knn_impl( | ||
raft::device_resources const& handle, | ||
std::vector<value_t*>& input, | ||
|
@@ -265,7 +285,8 @@ void brute_force_knn_impl( | |
bool rowMajorQuery = true, | ||
std::vector<IdxType>* translations = nullptr, | ||
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, | ||
float metricArg = 0) | ||
float metricArg = 0, | ||
benfred marked this conversation as resolved.
Show resolved
Hide resolved
|
||
DistanceEpilogue distance_epilogue = raft::identity_op()) | ||
{ | ||
auto userStream = handle.get_stream(); | ||
|
||
|
@@ -355,6 +376,7 @@ void brute_force_knn_impl( | |
auto stream = handle.get_next_usable_stream(i); | ||
|
||
if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && | ||
std::is_same_v<DistanceEpilogue, raft::identity_op> && | ||
(metric == raft::distance::DistanceType::L2Unexpanded || | ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded || | ||
metric == raft::distance::DistanceType::L2Expanded || | ||
|
@@ -424,7 +446,10 @@ void brute_force_knn_impl( | |
out_d_ptr, | ||
out_i_ptr, | ||
tiled_metric, | ||
metricArg); | ||
metricArg, | ||
0, | ||
0, | ||
distance_epilogue); | ||
break; | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add the expectation if the epilogue function's argument list here too? Just a small example prototype of the function definition will do.