add a distance epilogue function to the bfknn call #1371

Merged 14 commits on Mar 28, 2023
Changes from 7 commits
10 changes: 7 additions & 3 deletions cpp/include/raft/neighbors/brute_force.cuh
@@ -138,13 +138,15 @@ inline void knn_merge_parts(
* is ignored if the metric_type is not Minkowski.
* @param[in] global_id_offset: optional starting global id mapping for the local partition
* (assumes the index contains contiguous ids in the global id space)
+ * @param[in] distance_epilogue: optional epilogue function to run after computing distances
Member commented:
Can you add the expected argument list of the epilogue function here too? Just a small example prototype of the function definition will do.

*/
template <typename idx_t,
typename value_t,
typename value_int,
typename matrix_idx,
typename index_layout,
- typename search_layout>
+ typename search_layout,
+ typename epilogue_op = raft::identity_op>
void knn(raft::device_resources const& handle,
std::vector<raft::device_matrix_view<const value_t, matrix_idx, index_layout>> index,
raft::device_matrix_view<const value_t, matrix_idx, search_layout> search,
@@ -153,7 +155,8 @@ void knn(raft::device_resources const& handle,
value_int k,
distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
std::optional<float> metric_arg = std::make_optional<float>(2.0f),
- std::optional<idx_t> global_id_offset = std::nullopt)
+ std::optional<idx_t> global_id_offset = std::nullopt,
+ epilogue_op distance_epilogue = raft::identity_op())
{
RAFT_EXPECTS(index[0].extent(1) == search.extent(1),
"Number of dimensions for both index and search matrices must be equal");
@@ -194,7 +197,8 @@ void knn(raft::device_resources const& handle,
rowMajorQuery,
trans_arg,
metric,
- metric_arg.value_or(2.0f));
+ metric_arg.value_or(2.0f),
+ distance_epilogue);
}
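
For reference, the call sites in this diff invoke the epilogue as `distance_epilogue(val, row, col)`, so a callable that takes a distance plus the row and column of that entry and returns the adjusted distance appears to be what is expected. The host-only C++ sketch below illustrates that shape together with the defaulted template/argument pair used by `knn`. It is not RAFT code: `identity_op` is a local stand-in for `raft::identity_op`, and `mask_self_op` and `apply_epilogue` are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Local stand-in for raft::identity_op (an assumption for this sketch):
// returns the first argument and ignores any others.
struct identity_op {
  template <typename T, typename... Unused>
  constexpr T operator()(const T& value, Unused&&...) const
  {
    return value;
  }
};

// Hypothetical epilogue, callable as (distance, row, col) -> distance.
// Pushes self-matches (row == col) to infinity so they are never selected.
struct mask_self_op {
  float operator()(float dist, std::size_t row, std::size_t col) const
  {
    return row == col ? std::numeric_limits<float>::infinity() : dist;
  }
};

// Hypothetical helper mirroring the defaulted template/argument pair in the diff.
// Applies the epilogue in place to a row-major n_rows x n_cols distance matrix.
template <typename epilogue_op = identity_op>
void apply_epilogue(std::vector<float>& dists,
                    std::size_t n_rows,
                    std::size_t n_cols,
                    epilogue_op distance_epilogue = identity_op())
{
  assert(dists.size() == n_rows * n_cols);
  for (std::size_t row = 0; row < n_rows; ++row) {
    for (std::size_t col = 0; col < n_cols; ++col) {
      float& d = dists[row * n_cols + col];
      d        = distance_epilogue(d, row, col);
    }
  }
}
```

Calling `apply_epilogue(dists, n_rows, n_cols)` without a fourth argument falls back to the identity stand-in and leaves the distances unchanged, which is the same behavior the defaulted `epilogue_op` parameter gives existing callers of `knn`.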

/**
39 changes: 32 additions & 7 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
@@ -47,7 +47,9 @@ using namespace raft::spatial::knn;
* Calculates brute force knn, using a fixed memory budget
* by tiling over both the rows and columns of pairwise_distances
*/
- template <typename ElementType = float, typename IndexType = int64_t>
+ template <typename ElementType = float,
+ typename IndexType = int64_t,
+ typename DistanceEpilogue = raft::identity_op>
void tiled_brute_force_knn(const raft::device_resources& handle,
Member commented:
As mentioned above, I'd eventually love to find ways we can consolidate (and reuse) this with the Gram matrix APIs. It's not an immediate priority, but the current Gram matrix API is a class that doesn't need to store state, and we have a TODO to convert it into flattened public API functions like the rest of RAFT.

const ElementType* search, // size (m ,d)
const ElementType* index, // size (n ,d)
@@ -58,9 +60,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
ElementType* distances, // size (m, k)
IndexType* indices, // size (m, k)
raft::distance::DistanceType metric,
- float metric_arg = 0.0,
- size_t max_row_tile_size = 0,
- size_t max_col_tile_size = 0)
+ float metric_arg = 0.0,
+ size_t max_row_tile_size = 0,
+ size_t max_col_tile_size = 0,
+ DistanceEpilogue distance_epilogue = raft::identity_op())
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
@@ -169,8 +172,22 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
// cause NaN values in the subsequent sqrt. Clamp to 0
val = val * (val >= 0.0001);
if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); }
+ val = distance_epilogue(val, row, col);
return val;
});
+ } else {
+ // if we're not l2 distance, and we have a distance epilogue - run it now
+ if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) {
+ auto distances_ptr = temp_distances.data();
+ raft::linalg::map_offset(
Member commented:
Very nice!

+ handle,
+ raft::make_device_vector_view(temp_distances.data(),
+ current_query_size * current_centroid_size),
+ [=] __device__(size_t i) {
Member commented:
I think it makes sense to do this as an epilogue on the pairwise distances rather than burdening the (already register-constrained) k-selection with it.

This formulation of the algorithm (tiled knn with the epilogue applied after tiling) makes this portion equivalent to our Gram matrix API for constructing RKHS kernels (see raft::distance::kernel), and it makes me wonder if there's something to be gained by consolidating these eventually (e.g., the brute-force knn primitive becomes a composition of [tiled Gram + epilogue] + k-selection). That could also allow us to reuse more from the Gram APIs.

+ return distance_epilogue(
+ distances_ptr[i], i / current_centroid_size, i % current_centroid_size);
+ });
+ }
}
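
The non-L2 branch above runs the epilogue as a separate pass over the flattened tile, recovering the row and column from the flat offset as `i / current_centroid_size` and `i % current_centroid_size`, and skips the pass at compile time when the epilogue is the identity. A minimal host-only C++ sketch of that idea follows. It is an illustration under stated assumptions, not the RAFT implementation: a plain loop replaces `raft::linalg::map_offset`, `identity_op` is a local stand-in for `raft::identity_op`, and `tag_coords_op` and `run_tile_epilogue` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Local stand-in for raft::identity_op (an assumption for this sketch).
struct identity_op {
  template <typename T, typename... Unused>
  constexpr T operator()(const T& value, Unused&&...) const
  {
    return value;
  }
};

// Hypothetical epilogue that encodes the coordinates it was called with,
// which makes the offset -> (row, col) mapping easy to check.
struct tag_coords_op {
  float operator()(float dist, std::size_t row, std::size_t col) const
  {
    return dist + 10.0f * static_cast<float>(row) + static_cast<float>(col);
  }
};

// Host-side analogue of the map_offset pass: visit every flat offset i of a
// row-major tile with n_cols columns and overwrite tile[i] with the epilogue
// result. Returns true when a pass over the tile was actually performed.
template <typename DistanceEpilogue>
bool run_tile_epilogue(std::vector<float>& tile,
                       std::size_t n_cols,
                       DistanceEpilogue distance_epilogue)
{
  if constexpr (!std::is_same_v<DistanceEpilogue, identity_op>) {
    assert(n_cols > 0 && tile.size() % n_cols == 0);
    for (std::size_t i = 0; i < tile.size(); ++i) {
      // row = i / n_cols, col = i % n_cols, as in the device lambda
      tile[i] = distance_epilogue(tile[i], i / n_cols, i % n_cols);
    }
    return true;
  } else {
    // Identity epilogue: nothing to do, so no extra pass over the tile.
    (void)tile;
    (void)n_cols;
    (void)distance_epilogue;
    return false;
  }
}
```

Because the check is done with `if constexpr`, the identity case compiles down to no work at all, so callers that do not pass an epilogue should not pay for the extra pass.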

select_k<IndexType, ElementType>(temp_distances.data(),
@@ -250,7 +267,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
* @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded)
* @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm
*/
- template <typename IntType = int, typename IdxType = std::int64_t, typename value_t = float>
+ template <typename IntType = int,
+ typename IdxType = std::int64_t,
+ typename value_t = float,
+ typename DistanceEpilogue = raft::identity_op>
void brute_force_knn_impl(
raft::device_resources const& handle,
std::vector<value_t*>& input,
@@ -265,7 +285,8 @@ void brute_force_knn_impl(
bool rowMajorQuery = true,
std::vector<IdxType>* translations = nullptr,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
- float metricArg = 0)
+ float metricArg = 0,
+ DistanceEpilogue distance_epilogue = raft::identity_op())
{
auto userStream = handle.get_stream();

@@ -355,6 +376,7 @@ void brute_force_knn_impl(
auto stream = handle.get_next_usable_stream(i);

if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true &&
+ std::is_same_v<DistanceEpilogue, raft::identity_op> &&
(metric == raft::distance::DistanceType::L2Unexpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::L2Expanded ||
@@ -424,7 +446,10 @@ void brute_force_knn_impl(
out_d_ptr,
out_i_ptr,
tiled_metric,
- metricArg);
+ metricArg,
+ 0,
+ 0,
+ distance_epilogue);
break;
}
}
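
The added `std::is_same_v<DistanceEpilogue, raft::identity_op>` term means the first branch of the dispatch is only taken when no custom epilogue was supplied; in this diff the epilogue is forwarded only to `tiled_brute_force_knn`. The sketch below shows how such a compile-time trait can sit inside an otherwise runtime condition. It is a simplified illustration, not the RAFT dispatch: the metric list is cut off in the diff, so it is collapsed here into a single `metric_matches` flag, the remaining cases are collapsed into one `general_path` value, and `identity_op`, `custom_op`, `knn_path`, and `choose_path` are hypothetical names.

```cpp
#include <cassert>
#include <type_traits>

struct identity_op {};  // stand-in for raft::identity_op (assumption)
struct custom_op {};    // hypothetical user-supplied epilogue type

// Hypothetical labels for the two outcomes of the condition in the diff.
enum class knn_path { fast_path, general_path };

// Mirrors the shape of the condition: runtime checks on k, layout, and metric
// combined with a compile-time check that the epilogue is the identity.
template <typename DistanceEpilogue = identity_op>
knn_path choose_path(int k,
                     bool row_major_query,
                     bool row_major_index,
                     bool metric_matches,
                     DistanceEpilogue = identity_op())
{
  constexpr bool no_epilogue = std::is_same_v<DistanceEpilogue, identity_op>;
  if (k <= 64 && row_major_query == row_major_index && row_major_query && no_epilogue &&
      metric_matches) {
    return knn_path::fast_path;
  }
  return knn_path::general_path;
}
```

With this shape, supplying any epilogue type other than the identity always selects the general path, regardless of `k`, layout, or metric.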
@@ -36,7 +36,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
- float metricArg);
+ float metricArg, \
+ raft::identity_op);
RAFT_INST(long, float, int);
RAFT_INST(long, float, unsigned int);
RAFT_INST(uint32_t, float, int);
@@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
- float metricArg);
+ float metricArg, \
+ raft::identity_op);
RAFT_INST(long, float, int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
@@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
- float metricArg);
+ float metricArg, \
+ raft::identity_op);
RAFT_INST(long, float, unsigned int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
@@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
- float metricArg);
+ float metricArg, \
+ raft::identity_op);
RAFT_INST(uint32_t, float, int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
@@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
- float metricArg);
+ float metricArg, \
+ raft::identity_op);
RAFT_INST(uint32_t, float, unsigned int);
#undef RAFT_INST
} // namespace raft::neighbors::detail