Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove MetricProcessor code from brute_force::knn #1426

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 49 additions & 42 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
size_t m,
size_t n,
size_t d,
int k,
size_t k,
ElementType* distances, // size (m, k)
IndexType* indices, // size (m, k)
raft::distance::DistanceType metric,
Expand All @@ -79,7 +79,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
if (max_col_tile_size && (tile_cols > max_col_tile_size)) { tile_cols = max_col_tile_size; }

// tile_cols must be at least k items
tile_cols = std::max(tile_cols, static_cast<size_t>(k));
tile_cols = std::max(tile_cols, k);

// stores pairwise distances for the current tile
rmm::device_uvector<ElementType> temp_distances(tile_rows * tile_cols, stream);
Expand All @@ -90,13 +90,34 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
rmm::device_uvector<ElementType> search_norms(0, stream);
rmm::device_uvector<ElementType> index_norms(0, stream);
if (metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded) {
metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::CosineExpanded) {
search_norms.resize(m, stream);
index_norms.resize(n, stream);
raft::linalg::rowNorm(
search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream);
raft::linalg::rowNorm(
index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream);
// cosine needs the l2 norm, whereas l2 distances need the squared norm
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you also need to do this for correlation since it's a normalized cosine or does that distance "just work" in the pw dists?

Copy link
Member Author

@benfred benfred Apr 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be a good idea to do this for correlation distance too for performance reasons - though the PW distance api does 'just work', and this change already improves the times quite a bit.

times on 23.06 branch (on a github dataset)

In [10]: %timeit brute_force.knn(repo_embeddings, repo_embeddings[repoids], k=10, metric="cosine")
56.6 ms ± 6.48 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [11]: %timeit brute_force.knn(repo_embeddings, repo_embeddings[repoids], k=10, metric="correlation")
82.8 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

times on this branch:

In [3]: %timeit brute_force.knn(repo_embeddings, repo_embeddings[repoids], k=10, metric="cosine")
   ...: 
26.9 ms ± 32.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [4]: %timeit brute_force.knn(repo_embeddings, repo_embeddings[repoids], k=10, metric="correlation")
49.5 ms ± 51 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

cosine/correlation times are both substantially improved with this change - though it might be worth expanding the correlation distance in the same way as we're doing with cosine/l2 etc in a future PR.


Note that in Wednesday's demo presentation, I was seeing 6.5ms for the l2 metric:

In [12]: %timeit brute_force.knn(repo_embeddings, repo_embeddings[repoids], k=10, metric="l2")
6.51 ms ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

However, this is because it's hitting the fused l2 code, which is substantially faster in this case. The good news is that most of this perf difference is in the select_k call, which I'm trying to get sped up now w/ changes like #1430. I believe we can get the cosine etc. code to around 9ms total on this call - the perf here is suffering since we're using the faiss select, which does poorly on a single row.

if (metric == raft::distance::DistanceType::CosineExpanded) {
raft::linalg::rowNorm(search_norms.data(),
search,
d,
m,
raft::linalg::NormType::L2Norm,
true,
stream,
raft::sqrt_op{});
raft::linalg::rowNorm(index_norms.data(),
index,
d,
n,
raft::linalg::NormType::L2Norm,
true,
stream,
raft::sqrt_op{});
} else {
raft::linalg::rowNorm(
search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream);
raft::linalg::rowNorm(
index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream);
}
pairwise_metric = raft::distance::DistanceType::InnerProduct;
}

Expand All @@ -109,20 +130,17 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
// in which case the number of columns here is too high in the temp output.
// adjust if necessary
auto last_col_tile_size = n % tile_cols;
if (last_col_tile_size && (last_col_tile_size < static_cast<size_t>(k))) {
temp_out_cols -= k - last_col_tile_size;
}
if (last_col_tile_size && (last_col_tile_size < k)) { temp_out_cols -= k - last_col_tile_size; }

// if we have less than k items in the index, we should fill out the result
// to indicate that we are missing items (and match behaviour in faiss)
if (n < static_cast<size_t>(k)) {
if (n < k) {
raft::matrix::fill(handle,
raft::make_device_matrix_view(distances, m, static_cast<size_t>(k)),
raft::make_device_matrix_view(distances, m, k),
std::numeric_limits<ElementType>::lowest());

if constexpr (std::is_signed_v<IndexType>) {
raft::matrix::fill(
handle, raft::make_device_matrix_view(indices, m, static_cast<size_t>(k)), IndexType{-1});
raft::matrix::fill(handle, raft::make_device_matrix_view(indices, m, k), IndexType{-1});
}
}

Expand All @@ -136,7 +154,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle,

for (size_t j = 0; j < n; j += tile_cols) {
size_t current_centroid_size = std::min(tile_cols, n - j);
size_t current_k = std::min(current_centroid_size, static_cast<size_t>(k));
size_t current_k = std::min(current_centroid_size, k);

// calculate the top-k elements for the current tile, by calculating the
// full pairwise distance for the tile - and then selecting the top-k from that
Expand Down Expand Up @@ -176,6 +194,21 @@ void tiled_brute_force_knn(const raft::device_resources& handle,
val = distance_epilogue(val, row, col);
return val;
});
} else if (metric == raft::distance::DistanceType::CosineExpanded) {
auto row_norms = search_norms.data();
auto col_norms = index_norms.data();
auto dist = temp_distances.data();

raft::linalg::map_offset(
handle,
raft::make_device_vector_view(dist, current_query_size * current_centroid_size),
[=] __device__(IndexType idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
auto val = 1.0 - dist[idx] / (row_norms[row] * col_norms[col]);
val = distance_epilogue(val, row, col);
return val;
});
} else {
// if we're not l2 distance, and we have a distance epilogue - run it now
if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) {
Expand Down Expand Up @@ -310,18 +343,6 @@ void brute_force_knn_impl(
id_ranges = translations;
}

// perform preprocessing
std::unique_ptr<MetricProcessor<value_t>> query_metric_processor =
create_processor<value_t>(metric, n, D, k, rowMajorQuery, userStream);
query_metric_processor->preprocess(search_items);

std::vector<std::unique_ptr<MetricProcessor<value_t>>> metric_processors(input.size());
for (size_t i = 0; i < input.size(); i++) {
metric_processors[i] =
create_processor<value_t>(metric, sizes[i], D, k, rowMajorQuery, userStream);
metric_processors[i]->preprocess(input[i]);
}

int device;
RAFT_CUDA_TRY(cudaGetDevice(&device));

Expand Down Expand Up @@ -430,14 +451,6 @@ void brute_force_knn_impl(
raft::linalg::transpose(handle, input[i], index, sizes[i], D, stream);
}

// cosine/correlation are handled by metric processor, use IP distance
// for brute force knn call.
auto tiled_metric = metric;
if (metric == raft::distance::DistanceType::CosineExpanded ||
metric == raft::distance::DistanceType::CorrelationExpanded) {
tiled_metric = raft::distance::DistanceType::InnerProduct;
}

tiled_brute_force_knn<value_t, IdxType>(stream_pool_handle,
search,
index,
Expand All @@ -447,7 +460,7 @@ void brute_force_knn_impl(
k,
out_d_ptr,
out_i_ptr,
tiled_metric,
metric,
metricArg,
0,
0,
Expand All @@ -470,12 +483,6 @@ void brute_force_knn_impl(
knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data());
}

query_metric_processor->revert(search_items);
query_metric_processor->postprocess(out_D);
for (size_t i = 0; i < input.size(); i++) {
metric_processors[i]->revert(input[i]);
}

if (translations == nullptr) delete id_ranges;
};

Expand Down
3 changes: 0 additions & 3 deletions python/pylibraft/pylibraft/test/test_brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ def test_knn(
expected_indices = argsort[i]
gpu_dists = actual_distances[i]

if metric == "correlation" or metric == "cosine":
gpu_dists = gpu_dists[::-1]

cpu_ordered = pw_dists[i, expected_indices]
np.testing.assert_allclose(
cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4
Expand Down