Skip to content

Commit

Permalink
Remove MetricProcessor code from brute_force::knn (rapidsai#1426)
Browse files Browse the repository at this point in the history
Stop using the MetricProcessor code to preprocess the inputs to the bfknn calls. Since the pairwise distance API supports both cosine and correlation distance, this wasn't required anymore - and it introduced NaN values to the input when passed a dataset with one of the rows being all zero.

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1426
  • Loading branch information
benfred authored and ahendriksen committed Apr 27, 2023
1 parent 91c152a commit 2e9c61c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 45 deletions.
91 changes: 49 additions & 42 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
size_t m,
size_t n,
size_t d,
int k,
size_t k,
ElementType* distances, // size (m, k)
IndexType* indices, // size (m, k)
raft::distance::DistanceType metric,
Expand All @@ -80,7 +80,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
if (max_col_tile_size && (tile_cols > max_col_tile_size)) { tile_cols = max_col_tile_size; }

// tile_cols must be at least k items
tile_cols = std::max(tile_cols, static_cast<size_t>(k));
tile_cols = std::max(tile_cols, k);

// stores pairwise distances for the current tile
rmm::device_uvector<ElementType> temp_distances(tile_rows * tile_cols, stream);
Expand All @@ -91,13 +91,34 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
rmm::device_uvector<ElementType> search_norms(0, stream);
rmm::device_uvector<ElementType> index_norms(0, stream);
if (metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded) {
metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::CosineExpanded) {
search_norms.resize(m, stream);
index_norms.resize(n, stream);
raft::linalg::rowNorm(
search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream);
raft::linalg::rowNorm(
index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream);
// cosine needs the l2norm, where as l2 distances needs the squared norm
if (metric == raft::distance::DistanceType::CosineExpanded) {
raft::linalg::rowNorm(search_norms.data(),
search,
d,
m,
raft::linalg::NormType::L2Norm,
true,
stream,
raft::sqrt_op{});
raft::linalg::rowNorm(index_norms.data(),
index,
d,
n,
raft::linalg::NormType::L2Norm,
true,
stream,
raft::sqrt_op{});
} else {
raft::linalg::rowNorm(
search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream);
raft::linalg::rowNorm(
index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream);
}
pairwise_metric = raft::distance::DistanceType::InnerProduct;
}

Expand All @@ -110,20 +131,17 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
// in which case the number of columns here is too high in the temp output.
// adjust if necessary
auto last_col_tile_size = n % tile_cols;
if (last_col_tile_size && (last_col_tile_size < static_cast<size_t>(k))) {
temp_out_cols -= k - last_col_tile_size;
}
if (last_col_tile_size && (last_col_tile_size < k)) { temp_out_cols -= k - last_col_tile_size; }

// if we have less than k items in the index, we should fill out the result
// to indicate that we are missing items (and match behaviour in faiss)
if (n < static_cast<size_t>(k)) {
if (n < k) {
raft::matrix::fill(handle,
raft::make_device_matrix_view(distances, m, static_cast<size_t>(k)),
raft::make_device_matrix_view(distances, m, k),
std::numeric_limits<ElementType>::lowest());

if constexpr (std::is_signed_v<IndexType>) {
raft::matrix::fill(
handle, raft::make_device_matrix_view(indices, m, static_cast<size_t>(k)), IndexType{-1});
raft::matrix::fill(handle, raft::make_device_matrix_view(indices, m, k), IndexType{-1});
}
}

Expand All @@ -137,7 +155,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle,

for (size_t j = 0; j < n; j += tile_cols) {
size_t current_centroid_size = std::min(tile_cols, n - j);
size_t current_k = std::min(current_centroid_size, static_cast<size_t>(k));
size_t current_k = std::min(current_centroid_size, k);

// calculate the top-k elements for the current tile, by calculating the
// full pairwise distance for the tile - and then selecting the top-k from that
Expand Down Expand Up @@ -177,6 +195,21 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
val = distance_epilogue(val, row, col);
return val;
});
} else if (metric == raft::distance::DistanceType::CosineExpanded) {
auto row_norms = search_norms.data();
auto col_norms = index_norms.data();
auto dist = temp_distances.data();

raft::linalg::map_offset(
handle,
raft::make_device_vector_view(dist, current_query_size * current_centroid_size),
[=] __device__(IndexType idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
auto val = 1.0 - dist[idx] / (row_norms[row] * col_norms[col]);
val = distance_epilogue(val, row, col);
return val;
});
} else {
// if we're not l2 distance, and we have a distance epilogue - run it now
if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) {
Expand Down Expand Up @@ -311,18 +344,6 @@ void brute_force_knn_impl(
id_ranges = translations;
}

// perform preprocessing
std::unique_ptr<MetricProcessor<value_t>> query_metric_processor =
create_processor<value_t>(metric, n, D, k, rowMajorQuery, userStream);
query_metric_processor->preprocess(search_items);

std::vector<std::unique_ptr<MetricProcessor<value_t>>> metric_processors(input.size());
for (size_t i = 0; i < input.size(); i++) {
metric_processors[i] =
create_processor<value_t>(metric, sizes[i], D, k, rowMajorQuery, userStream);
metric_processors[i]->preprocess(input[i]);
}

int device;
RAFT_CUDA_TRY(cudaGetDevice(&device));

Expand Down Expand Up @@ -431,14 +452,6 @@ void brute_force_knn_impl(
raft::linalg::transpose(handle, input[i], index, sizes[i], D, stream);
}

// cosine/correlation are handled by metric processor, use IP distance
// for brute force knn call.
auto tiled_metric = metric;
if (metric == raft::distance::DistanceType::CosineExpanded ||
metric == raft::distance::DistanceType::CorrelationExpanded) {
tiled_metric = raft::distance::DistanceType::InnerProduct;
}

tiled_brute_force_knn<value_t, IdxType>(stream_pool_handle,
search,
index,
Expand All @@ -448,7 +461,7 @@ void brute_force_knn_impl(
k,
out_d_ptr,
out_i_ptr,
tiled_metric,
metric,
metricArg,
0,
0,
Expand All @@ -471,12 +484,6 @@ void brute_force_knn_impl(
knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data());
}

query_metric_processor->revert(search_items);
query_metric_processor->postprocess(out_D);
for (size_t i = 0; i < input.size(); i++) {
metric_processors[i]->revert(input[i]);
}

if (translations == nullptr) delete id_ranges;
};

Expand Down
3 changes: 0 additions & 3 deletions python/pylibraft/pylibraft/test/test_brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ def test_knn(
expected_indices = argsort[i]
gpu_dists = actual_distances[i]

if metric == "correlation" or metric == "cosine":
gpu_dists = gpu_dists[::-1]

cpu_ordered = pw_dists[i, expected_indices]
np.testing.assert_allclose(
cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4
Expand Down

0 comments on commit 2e9c61c

Please sign in to comment.