Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output non-normalized distances in IVF-PQ and brute-force KNN #892

Merged
merged 6 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1074,9 +1074,9 @@ void search_impl(const handle_t& handle,
rmm::device_uvector<float> coarse_distances_dev(n_queries * n_probes, stream, search_mr);
// The topk index of cluster(list) and queries
rmm::device_uvector<uint32_t> coarse_indices_dev(n_queries * n_probes, stream, search_mr);
// The topk distance value of candicate vectors from each cluster(list)
// The topk distance value of candidate vectors from each cluster(list)
rmm::device_uvector<AccT> refined_distances_dev(n_queries * n_probes * k, stream, search_mr);
// The topk index of candicate vectors from each cluster(list)
// The topk index of candidate vectors from each cluster(list)
rmm::device_uvector<IdxT> refined_indices_dev(n_queries * n_probes * k, stream, search_mr);

size_t float_query_size;
Expand Down
29 changes: 25 additions & 4 deletions cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -405,23 +405,40 @@ void postprocess_distances(float* out, // [n_queries, topk]
distance::DistanceType metric,
uint32_t n_queries,
uint32_t topk,
float scaling_factor,
rmm::cuda_stream_view stream)
{
size_t len = size_t(n_queries) * size_t(topk);
switch (metric) {
case distance::DistanceType::L2Unexpanded:
case distance::DistanceType::L2Expanded: {
linalg::unaryOp(
out, in, len, [] __device__(ScoreT x) -> float { return float(x); }, stream);
out,
in,
len,
[scaling_factor] __device__(ScoreT x) -> float {
return scaling_factor * scaling_factor * float(x);
},
stream);
} break;
case distance::DistanceType::L2SqrtUnexpanded:
case distance::DistanceType::L2SqrtExpanded: {
linalg::unaryOp(
out, in, len, [] __device__(ScoreT x) -> float { return sqrtf(float(x)); }, stream);
out,
in,
len,
[scaling_factor] __device__(ScoreT x) -> float { return scaling_factor * sqrtf(float(x)); },
stream);
} break;
case distance::DistanceType::InnerProduct: {
linalg::unaryOp(
out, in, len, [] __device__(ScoreT x) -> float { return -float(x); }, stream);
out,
in,
len,
[scaling_factor] __device__(ScoreT x) -> float {
return -scaling_factor * scaling_factor * float(x);
},
stream);
} break;
default: RAFT_FAIL("Unexpected metric.");
}
Expand Down Expand Up @@ -998,6 +1015,7 @@ void ivfpq_search_worker(const handle_t& handle,
const float* query, // [n_queries, rot_dim]
IdxT* neighbors, // [n_queries, topK]
float* distances, // [n_queries, topK]
float scaling_factor,
rmm::mr::device_memory_resource* mr)
{
auto stream = handle.get_stream();
Expand Down Expand Up @@ -1125,7 +1143,8 @@ void ivfpq_search_worker(const handle_t& handle,
mr);

// Postprocessing
postprocess_distances(distances, topk_dists.data(), index.metric(), n_queries, topK, stream);
postprocess_distances(
distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, stream);
postprocess_neighbors(neighbors,
manage_local_topk,
data_indices,
Expand Down Expand Up @@ -1156,6 +1175,7 @@ struct ivfpq_search {
const float*,
IdxT*,
float*,
float,
rmm::mr::device_memory_resource*);

/**
Expand Down Expand Up @@ -1366,6 +1386,7 @@ inline void search(const handle_t& handle,
rot_queries.data() + uint64_t(index.rot_dim()) * offset_b,
neighbors + uint64_t(k) * (offset_q + offset_b),
distances + uint64_t(k) * (offset_q + offset_b),
utils::config<T>::kDivisor / utils::config<float>::kDivisor,
mr);
}
}
Expand Down
5 changes: 2 additions & 3 deletions cpp/test/spatial/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ __global__ void naive_distance_kernel(EvalT* dist,
IdxT k,
raft::distance::DistanceType type)
{
detail::utils::mapping<EvalT> f{};
IdxT midx = threadIdx.x + blockIdx.x * blockDim.x;
if (midx >= m) return;
for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n;
Expand All @@ -118,8 +117,8 @@ __global__ void naive_distance_kernel(EvalT* dist,
for (IdxT i = 0; i < k; ++i) {
IdxT xidx = i + midx * k;
IdxT yidx = i + nidx * k;
EvalT xv = f(x[xidx]);
EvalT yv = f(y[yidx]);
EvalT xv = (EvalT)x[xidx];
EvalT yv = (EvalT)y[yidx];
if (type == raft::distance::DistanceType::InnerProduct) {
acc += xv * yv;
} else {
Expand Down