Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a distance epilogue function to the bfknn call #1371

Merged
merged 14 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions cpp/include/raft/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -138,38 +138,41 @@ inline void knn_merge_parts(
* is ignored if the metric_type is not Minkowski.
* @param[in] global_id_offset: optional starting global id mapping for the local partition
* (assumes the index contains contiguous ids in the global id space)
* @param[in] distance_epilogue: optional epilogue function to run after computing distances. This
function takes a triple of the (value, rowid, colid) for each
element in the pairwise distances and returns a transformed value
back.
*/
template <typename idx_t,
typename value_t,
typename value_int,
typename matrix_idx,
typename index_layout,
typename search_layout>
typename search_layout,
typename epilogue_op = raft::identity_op>
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
void knn(raft::device_resources const& handle,
std::vector<raft::device_matrix_view<const value_t, matrix_idx, index_layout>> index,
raft::device_matrix_view<const value_t, matrix_idx, search_layout> search,
raft::device_matrix_view<idx_t, matrix_idx, row_major> indices,
raft::device_matrix_view<value_t, matrix_idx, row_major> distances,
value_int k,
distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
std::optional<float> metric_arg = std::make_optional<float>(2.0f),
std::optional<idx_t> global_id_offset = std::nullopt)
std::optional<idx_t> global_id_offset = std::nullopt,
epilogue_op distance_epilogue = raft::identity_op())
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
{
RAFT_EXPECTS(index[0].extent(1) == search.extent(1),
"Number of dimensions for both index and search matrices must be equal");

RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0),
"Number of rows in output indices and distances matrices must equal number of rows "
"in search matrix.");
RAFT_EXPECTS(
indices.extent(1) == distances.extent(1) && distances.extent(1) == static_cast<matrix_idx>(k),
"Number of columns in output indices and distances matrices must be equal to k");
RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1),
"Number of columns in output indices and distances matrices must the same");

bool rowMajorIndex = std::is_same_v<index_layout, layout_c_contiguous>;
bool rowMajorQuery = std::is_same_v<search_layout, layout_c_contiguous>;

std::vector<value_t*> inputs;
std::vector<value_int> sizes;
std::vector<matrix_idx> sizes;
for (std::size_t i = 0; i < index.size(); ++i) {
inputs.push_back(const_cast<value_t*>(index[i].data_handle()));
sizes.push_back(index[i].extent(0));
Expand All @@ -183,18 +186,19 @@ void knn(raft::device_resources const& handle,
raft::neighbors::detail::brute_force_knn_impl(handle,
inputs,
sizes,
static_cast<value_int>(index[0].extent(1)),
index[0].extent(1),
// TODO: This is unfortunate. Need to fix.
const_cast<value_t*>(search.data_handle()),
static_cast<value_int>(search.extent(0)),
search.extent(0),
indices.data_handle(),
distances.data_handle(),
k,
indices.extent(1),
rowMajorIndex,
rowMajorQuery,
trans_arg,
metric,
metric_arg.value_or(2.0f));
metric_arg.value_or(2.0f),
distance_epilogue);
}

/**
Expand Down
51 changes: 39 additions & 12 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ using namespace raft::spatial::knn;
* Calculates brute force knn, using a fixed memory budget
* by tiling over both the rows and columns of pairwise_distances
*/
template <typename ElementType = float, typename IndexType = int64_t>
template <typename ElementType = float,
typename IndexType = int64_t,
typename DistanceEpilogue = raft::identity_op>
void tiled_brute_force_knn(const raft::device_resources& handle,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above, id eventually love to find ways we can consolidate (and reuse) this with the gram matrix APIs. Not an immediate priority but the current gramm matrix API is a class that doesn't need to store state and we have a todo to convert it into flattened public API functions like the rest of RAFT.

const ElementType* search, // size (m ,d)
const ElementType* index, // size (n ,d)
Expand All @@ -58,9 +60,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
ElementType* distances, // size (m, k)
IndexType* indices, // size (m, k)
raft::distance::DistanceType metric,
float metric_arg = 0.0,
size_t max_row_tile_size = 0,
size_t max_col_tile_size = 0)
float metric_arg = 2.0,
size_t max_row_tile_size = 0,
size_t max_col_tile_size = 0,
DistanceEpilogue distance_epilogue = raft::identity_op())
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
Expand Down Expand Up @@ -152,25 +155,41 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
metric_arg);
if (metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded) {
auto row_norms = search_norms.data() + i;
auto col_norms = index_norms.data() + j;
auto row_norms = search_norms.data();
auto col_norms = index_norms.data();
auto dist = temp_distances.data();

raft::linalg::map_offset(
handle,
raft::make_device_vector_view(dist, current_query_size * current_centroid_size),
[=] __device__(IndexType i) {
IndexType row = i / current_centroid_size, col = i % current_centroid_size;
[=] __device__(IndexType idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);

auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i];
auto val = row_norms[row] + col_norms[col] - 2.0 * dist[idx];

// due to numerical instability (especially around self-distance)
// the distances here could be slightly negative, which will
// cause NaN values in the subsequent sqrt. Clamp to 0
val = val * (val >= 0.0001);
if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); }
val = distance_epilogue(val, row, col);
return val;
});
} else {
// if we're not l2 distance, and we have a distance epilogue - run it now
if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) {
auto distances_ptr = temp_distances.data();
raft::linalg::map_offset(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

handle,
raft::make_device_vector_view(temp_distances.data(),
current_query_size * current_centroid_size),
[=] __device__(size_t idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
return distance_epilogue(distances_ptr[idx], row, col);
});
}
}

select_k<IndexType, ElementType>(temp_distances.data(),
Expand Down Expand Up @@ -250,7 +269,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
* @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded)
* @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm
*/
template <typename IntType = int, typename IdxType = std::int64_t, typename value_t = float>
template <typename IntType = int,
typename IdxType = std::int64_t,
typename value_t = float,
typename DistanceEpilogue = raft::identity_op>
void brute_force_knn_impl(
raft::device_resources const& handle,
std::vector<value_t*>& input,
Expand All @@ -265,7 +287,8 @@ void brute_force_knn_impl(
bool rowMajorQuery = true,
std::vector<IdxType>* translations = nullptr,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
float metricArg = 0)
float metricArg = 0,
benfred marked this conversation as resolved.
Show resolved Hide resolved
DistanceEpilogue distance_epilogue = raft::identity_op())
{
auto userStream = handle.get_stream();

Expand Down Expand Up @@ -355,6 +378,7 @@ void brute_force_knn_impl(
auto stream = handle.get_next_usable_stream(i);

if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true &&
std::is_same_v<DistanceEpilogue, raft::identity_op> &&
(metric == raft::distance::DistanceType::L2Unexpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::L2Expanded ||
Expand Down Expand Up @@ -424,7 +448,10 @@ void brute_force_knn_impl(
out_d_ptr,
out_i_ptr,
tiled_metric,
metricArg);
metricArg,
0,
0,
distance_epilogue);
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
float metricArg);
float metricArg, \
raft::identity_op);
RAFT_INST(long, float, int);
RAFT_INST(long, float, unsigned int);
RAFT_INST(uint32_t, float, int);
Expand Down
1 change: 0 additions & 1 deletion cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ void k_closest_landmarks(raft::device_resources const& handle,
make_device_matrix_view(query_pts, n_query_pts, inputs[0].extent(1)),
make_device_matrix_view(R_knn_inds, n_query_pts, k),
make_device_matrix_view(R_knn_dists, n_query_pts, k),
k,
index.get_metric());
}

Expand Down
11 changes: 2 additions & 9 deletions cpp/src/neighbors/brute_force_knn_int64_t_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,8 @@ namespace raft::runtime::neighbors::brute_force {
{ \
std::vector<raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT>> vec; \
vec.push_back(index); \
raft::neighbors::brute_force::knn(handle, \
vec, \
search, \
indices, \
distances, \
static_cast<int>(distances.extent(1)), \
metric, \
metric_arg, \
global_id_offset); \
raft::neighbors::brute_force::knn( \
handle, vec, search, indices, distances, metric, metric_arg, global_id_offset); \
}

RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
float metricArg);
float metricArg, \
raft::identity_op);
RAFT_INST(long, float, int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
float metricArg);
float metricArg, \
raft::identity_op);
RAFT_INST(long, float, unsigned int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
float metricArg);
float metricArg, \
raft::identity_op);
RAFT_INST(uint32_t, float, int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace raft::neighbors::detail {
bool rowMajorQuery, \
std::vector<IdxT>* translations, \
raft::distance::DistanceType metric, \
float metricArg);
float metricArg, \
raft::identity_op);
RAFT_INST(uint32_t, float, unsigned int);
#undef RAFT_INST
} // namespace raft::neighbors::detail
1 change: 0 additions & 1 deletion cpp/test/neighbors/ball_cover.cu
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ void compute_bfknn(const raft::device_resources& handle,
make_device_matrix_view(X2, n_query_rows, d),
make_device_matrix_view(inds, n_query_rows, k),
make_device_matrix_view(dists, n_query_rows, k),
k,
metric);
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/test/neighbors/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class KNNTest : public ::testing::TestWithParam<KNNInputs> {
raft::make_device_matrix_view<T, IdxT, row_major>(distances_.data(), rows_, k_);

auto metric = raft::distance::DistanceType::L2Unexpanded;
knn(handle, index, search, indices, distances, k_, metric, std::make_optional<IdxT>(0));
knn(handle, index, search, indices, distances, metric, std::make_optional<IdxT>(0));

build_actual_output<<<raft::ceildiv(rows_ * k_, 32), 32, 0, stream>>>(
actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data());
Expand Down