Skip to content

Commit

Permalink
add a distance epilogue function to the bfknn call (#1371)
Browse files Browse the repository at this point in the history
Add the ability for a user to specify an epilogue function to run after the distance in the brute_force::knn call.

This lets us remove faiss from cuml, by updating the hdbscan reachability code (rapidsai/cuml#5293)

Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1371
  • Loading branch information
benfred authored Mar 28, 2023
1 parent 76c828d commit 0d3bd3d
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 44 deletions.
32 changes: 17 additions & 15 deletions cpp/include/raft/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -122,54 +122,55 @@ inline void knn_merge_parts(
*
* raft::raft::device_resources handle;
* ...
* int k = 10;
* auto metric = raft::distance::DistanceType::L2SqrtExpanded;
* brute_force::knn(handle, index, search, indices, distances, k, metric);
* brute_force::knn(handle, index, search, indices, distances, metric);
* @endcode
*
* @param[in] handle: the cuml handle to use
* @param[in] index: vector of device matrices (each size m_i*d) to be used as the knn index
* @param[in] search: matrix (size n*d) to be used for searching the index
* @param[out] indices: matrix (size n*k) to store output knn indices
* @param[out] distances: matrix (size n*k) to store the output knn distance
* @param[in] k: the number of nearest neighbors to return
* @param[in] metric: distance metric to use. Euclidean (L2) is used by default
* @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This
* is ignored if the metric_type is not Minkowski.
* @param[in] global_id_offset: optional starting global id mapping for the local partition
* (assumes the index contains contiguous ids in the global id space)
* @param[in] distance_epilogue: optional epilogue function to run after computing distances. This
function takes a triple of the (value, rowid, colid) for each
element in the pairwise distances and returns a transformed value
back.
*/
template <typename idx_t,
typename value_t,
typename value_int,
typename matrix_idx,
typename index_layout,
typename search_layout>
typename search_layout,
typename epilogue_op = raft::identity_op>
void knn(raft::device_resources const& handle,
std::vector<raft::device_matrix_view<const value_t, matrix_idx, index_layout>> index,
raft::device_matrix_view<const value_t, matrix_idx, search_layout> search,
raft::device_matrix_view<idx_t, matrix_idx, row_major> indices,
raft::device_matrix_view<value_t, matrix_idx, row_major> distances,
value_int k,
distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
std::optional<float> metric_arg = std::make_optional<float>(2.0f),
std::optional<idx_t> global_id_offset = std::nullopt)
std::optional<idx_t> global_id_offset = std::nullopt,
epilogue_op distance_epilogue = raft::identity_op())
{
RAFT_EXPECTS(index[0].extent(1) == search.extent(1),
"Number of dimensions for both index and search matrices must be equal");

RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0),
"Number of rows in output indices and distances matrices must equal number of rows "
"in search matrix.");
RAFT_EXPECTS(
indices.extent(1) == distances.extent(1) && distances.extent(1) == static_cast<matrix_idx>(k),
"Number of columns in output indices and distances matrices must be equal to k");
RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1),
"Number of columns in output indices and distances matrices must the same");

bool rowMajorIndex = std::is_same_v<index_layout, layout_c_contiguous>;
bool rowMajorQuery = std::is_same_v<search_layout, layout_c_contiguous>;

std::vector<value_t*> inputs;
std::vector<value_int> sizes;
std::vector<matrix_idx> sizes;
for (std::size_t i = 0; i < index.size(); ++i) {
inputs.push_back(const_cast<value_t*>(index[i].data_handle()));
sizes.push_back(index[i].extent(0));
Expand All @@ -183,18 +184,19 @@ void knn(raft::device_resources const& handle,
raft::neighbors::detail::brute_force_knn_impl(handle,
inputs,
sizes,
static_cast<value_int>(index[0].extent(1)),
index[0].extent(1),
// TODO: This is unfortunate. Need to fix.
const_cast<value_t*>(search.data_handle()),
static_cast<value_int>(search.extent(0)),
search.extent(0),
indices.data_handle(),
distances.data_handle(),
k,
indices.extent(1),
rowMajorIndex,
rowMajorQuery,
trans_arg,
metric,
metric_arg.value_or(2.0f));
metric_arg.value_or(2.0f),
distance_epilogue);
}

/**
Expand Down
51 changes: 39 additions & 12 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ using namespace raft::spatial::knn;
* Calculates brute force knn, using a fixed memory budget
* by tiling over both the rows and columns of pairwise_distances
*/
template <typename ElementType = float, typename IndexType = int64_t>
template <typename ElementType = float,
typename IndexType = int64_t,
typename DistanceEpilogue = raft::identity_op>
void tiled_brute_force_knn(const raft::device_resources& handle,
const ElementType* search, // size (m ,d)
const ElementType* index, // size (n ,d)
Expand All @@ -58,9 +60,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
ElementType* distances, // size (m, k)
IndexType* indices, // size (m, k)
raft::distance::DistanceType metric,
float metric_arg = 0.0,
size_t max_row_tile_size = 0,
size_t max_col_tile_size = 0)
float metric_arg = 2.0,
size_t max_row_tile_size = 0,
size_t max_col_tile_size = 0,
DistanceEpilogue distance_epilogue = raft::identity_op())
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
Expand Down Expand Up @@ -152,25 +155,41 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
metric_arg);
if (metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded) {
auto row_norms = search_norms.data() + i;
auto col_norms = index_norms.data() + j;
auto row_norms = search_norms.data();
auto col_norms = index_norms.data();
auto dist = temp_distances.data();

raft::linalg::map_offset(
handle,
raft::make_device_vector_view(dist, current_query_size * current_centroid_size),
[=] __device__(IndexType i) {
IndexType row = i / current_centroid_size, col = i % current_centroid_size;
[=] __device__(IndexType idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);

auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i];
auto val = row_norms[row] + col_norms[col] - 2.0 * dist[idx];

// due to numerical instability (especially around self-distance)
// the distances here could be slightly negative, which will
// cause NaN values in the subsequent sqrt. Clamp to 0
val = val * (val >= 0.0001);
if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); }
val = distance_epilogue(val, row, col);
return val;
});
} else {
// if we're not l2 distance, and we have a distance epilogue - run it now
if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) {
auto distances_ptr = temp_distances.data();
raft::linalg::map_offset(
handle,
raft::make_device_vector_view(temp_distances.data(),
current_query_size * current_centroid_size),
[=] __device__(size_t idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
return distance_epilogue(distances_ptr[idx], row, col);
});
}
}

select_k<IndexType, ElementType>(temp_distances.data(),
Expand Down Expand Up @@ -250,7 +269,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
* @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded)
* @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm
*/
template <typename IntType = int, typename IdxType = std::int64_t, typename value_t = float>
template <typename IntType = int,
typename IdxType = std::int64_t,
typename value_t = float,
typename DistanceEpilogue = raft::identity_op>
void brute_force_knn_impl(
raft::device_resources const& handle,
std::vector<value_t*>& input,
Expand All @@ -265,7 +287,8 @@ void brute_force_knn_impl(
bool rowMajorQuery = true,
std::vector<IdxType>* translations = nullptr,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
float metricArg = 0)
float metricArg = 0,
DistanceEpilogue distance_epilogue = raft::identity_op())
{
auto userStream = handle.get_stream();

Expand Down Expand Up @@ -355,6 +378,7 @@ void brute_force_knn_impl(
auto stream = handle.get_next_usable_stream(i);

if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true &&
std::is_same_v<DistanceEpilogue, raft::identity_op> &&
(metric == raft::distance::DistanceType::L2Unexpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::L2Expanded ||
Expand Down Expand Up @@ -424,7 +448,10 @@ void brute_force_knn_impl(
out_d_ptr,
out_i_ptr,
tiled_metric,
metricArg);
metricArg,
0,
0,
distance_epilogue);
break;
}
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/neighbors/specializations/brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
float metricArg);
float metricArg, \
raft::identity_op);
RAFT_INST(long, float, int);
RAFT_INST(long, float, unsigned int);
RAFT_INST(uint32_t, float, int);
Expand Down
1 change: 0 additions & 1 deletion cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ void k_closest_landmarks(raft::device_resources const& handle,
make_device_matrix_view(query_pts, n_query_pts, inputs[0].extent(1)),
make_device_matrix_view(R_knn_inds, n_query_pts, k),
make_device_matrix_view(R_knn_dists, n_query_pts, k),
k,
index.get_metric());
}

Expand Down
11 changes: 2 additions & 9 deletions cpp/src/neighbors/brute_force_knn_int64_t_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,8 @@ namespace raft::runtime::neighbors::brute_force {
{ \
std::vector<raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT>> vec; \
vec.push_back(index); \
raft::neighbors::brute_force::knn(handle, \
vec, \
search, \
indices, \
distances, \
static_cast<int>(distances.extent(1)), \
metric, \
metric_arg, \
global_id_offset); \
raft::neighbors::brute_force::knn( \
handle, vec, search, indices, distances, metric, metric_arg, global_id_offset); \
}

RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
float metricArg);
float metricArg, \
raft::identity_op);
RAFT_INST(long, float, int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
float metricArg);
float metricArg, \
raft::identity_op);
RAFT_INST(long, float, unsigned int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
float metricArg);
float metricArg, \
raft::identity_op);
RAFT_INST(uint32_t, float, int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
float metricArg);
float metricArg, \
raft::identity_op);
RAFT_INST(uint32_t, float, unsigned int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
1 change: 0 additions & 1 deletion cpp/test/neighbors/ball_cover.cu
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ void compute_bfknn(const raft::device_resources& handle,
make_device_matrix_view(X2, n_query_rows, d),
make_device_matrix_view(inds, n_query_rows, k),
make_device_matrix_view(dists, n_query_rows, k),
k,
metric);
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/test/neighbors/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class KNNTest : public ::testing::TestWithParam<KNNInputs> {
raft::make_device_matrix_view<T, IdxT, row_major>(distances_.data(), rows_, k_);

auto metric = raft::distance::DistanceType::L2Unexpanded;
knn(handle, index, search, indices, distances, k_, metric, std::make_optional<IdxT>(0));
knn(handle, index, search, indices, distances, metric, std::make_optional<IdxT>(0));

build_actual_output<<<raft::ceildiv(rows_ * k_, 32), 32, 0, stream>>>(
actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data());
Expand Down

0 comments on commit 0d3bd3d

Please sign in to comment.