Skip to content

Commit

Permalink
Output non-normalized distances in IVF-PQ and brute-force KNN (#892)
Browse files Browse the repository at this point in the history
Solves point 9 of #711 (the observed errors were due to comparing distances at different scales).

This PR does two things:

- Changes the naive BF KNN to no longer normalize integer inputs when converting them to float; it now simply casts them, so its distances are on the same scale as the output of IVF-Flat (which also resolves the errors observed in tests).
- Changes IVF-PQ to output distances in the original scale rather than the normalized scale, consistent with IVF-Flat.

Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Artem M. Chirkin (https://github.com/achirkin)

URL: #892
  • Loading branch information
Nyrio authored Oct 6, 2022
1 parent 11e00f7 commit e2b6399
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
4 changes: 2 additions & 2 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1074,9 +1074,9 @@ void search_impl(const handle_t& handle,
rmm::device_uvector<float> coarse_distances_dev(n_queries * n_probes, stream, search_mr);
// The topk index of cluster(list) and queries
rmm::device_uvector<uint32_t> coarse_indices_dev(n_queries * n_probes, stream, search_mr);
// The topk distance value of candicate vectors from each cluster(list)
// The topk distance value of candidate vectors from each cluster(list)
rmm::device_uvector<AccT> refined_distances_dev(n_queries * n_probes * k, stream, search_mr);
// The topk index of candicate vectors from each cluster(list)
// The topk index of candidate vectors from each cluster(list)
rmm::device_uvector<IdxT> refined_indices_dev(n_queries * n_probes * k, stream, search_mr);

size_t float_query_size;
Expand Down
29 changes: 25 additions & 4 deletions cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -405,23 +405,40 @@ void postprocess_distances(float* out, // [n_queries, topk]
distance::DistanceType metric,
uint32_t n_queries,
uint32_t topk,
float scaling_factor,
rmm::cuda_stream_view stream)
{
size_t len = size_t(n_queries) * size_t(topk);
switch (metric) {
case distance::DistanceType::L2Unexpanded:
case distance::DistanceType::L2Expanded: {
linalg::unaryOp(
out, in, len, [] __device__(ScoreT x) -> float { return float(x); }, stream);
out,
in,
len,
[scaling_factor] __device__(ScoreT x) -> float {
return scaling_factor * scaling_factor * float(x);
},
stream);
} break;
case distance::DistanceType::L2SqrtUnexpanded:
case distance::DistanceType::L2SqrtExpanded: {
linalg::unaryOp(
out, in, len, [] __device__(ScoreT x) -> float { return sqrtf(float(x)); }, stream);
out,
in,
len,
[scaling_factor] __device__(ScoreT x) -> float { return scaling_factor * sqrtf(float(x)); },
stream);
} break;
case distance::DistanceType::InnerProduct: {
linalg::unaryOp(
out, in, len, [] __device__(ScoreT x) -> float { return -float(x); }, stream);
out,
in,
len,
[scaling_factor] __device__(ScoreT x) -> float {
return -scaling_factor * scaling_factor * float(x);
},
stream);
} break;
default: RAFT_FAIL("Unexpected metric.");
}
Expand Down Expand Up @@ -998,6 +1015,7 @@ void ivfpq_search_worker(const handle_t& handle,
const float* query, // [n_queries, rot_dim]
IdxT* neighbors, // [n_queries, topK]
float* distances, // [n_queries, topK]
float scaling_factor,
rmm::mr::device_memory_resource* mr)
{
auto stream = handle.get_stream();
Expand Down Expand Up @@ -1125,7 +1143,8 @@ void ivfpq_search_worker(const handle_t& handle,
mr);

// Postprocessing
postprocess_distances(distances, topk_dists.data(), index.metric(), n_queries, topK, stream);
postprocess_distances(
distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, stream);
postprocess_neighbors(neighbors,
manage_local_topk,
data_indices,
Expand Down Expand Up @@ -1156,6 +1175,7 @@ struct ivfpq_search {
const float*,
IdxT*,
float*,
float,
rmm::mr::device_memory_resource*);

/**
Expand Down Expand Up @@ -1366,6 +1386,7 @@ inline void search(const handle_t& handle,
rot_queries.data() + uint64_t(index.rot_dim()) * offset_b,
neighbors + uint64_t(k) * (offset_q + offset_b),
distances + uint64_t(k) * (offset_q + offset_b),
utils::config<T>::kDivisor / utils::config<float>::kDivisor,
mr);
}
}
Expand Down
5 changes: 2 additions & 3 deletions cpp/test/spatial/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ __global__ void naive_distance_kernel(EvalT* dist,
IdxT k,
raft::distance::DistanceType type)
{
detail::utils::mapping<EvalT> f{};
IdxT midx = threadIdx.x + blockIdx.x * blockDim.x;
if (midx >= m) return;
for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n;
Expand All @@ -118,8 +117,8 @@ __global__ void naive_distance_kernel(EvalT* dist,
for (IdxT i = 0; i < k; ++i) {
IdxT xidx = i + midx * k;
IdxT yidx = i + nidx * k;
EvalT xv = f(x[xidx]);
EvalT yv = f(y[yidx]);
EvalT xv = (EvalT)x[xidx];
EvalT yv = (EvalT)y[yidx];
if (type == raft::distance::DistanceType::InnerProduct) {
acc += xv * yv;
} else {
Expand Down

0 comments on commit e2b6399

Please sign in to comment.