Skip to content

Commit

Permalink
Add L2SqrtExpanded support to ivf_flat ANN indices (#1133)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1133
  • Loading branch information
benfred authored Jan 13, 2023
1 parent 3f3a59e commit ab4f1fd
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 50 deletions.
19 changes: 13 additions & 6 deletions cpp/include/raft/neighbors/ivf_flat_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -252,14 +252,21 @@ struct index : ann::index {
* Replace the content of the index with new uninitialized mdarrays to hold the indicated amount
* of data.
*/
void allocate(const handle_t& handle, IdxT index_size, bool allocate_center_norms)
void allocate(const handle_t& handle, IdxT index_size)
{
// Fresh storage: `index_size` rows of `dim()` elements, plus one `IdxT` entry per row.
data_ = make_device_mdarray<T>(handle, make_extents<IdxT>(index_size, dim()));
indices_ = make_device_mdarray<IdxT>(handle, make_extents<IdxT>(index_size));
center_norms_ =
allocate_center_norms
? std::optional(make_device_mdarray<float>(handle, make_extents<uint32_t>(n_lists())))
: std::nullopt;

// The need for center norms is derived from the index metric: every L2 variant gets an array
// of `n_lists()` floats, any other metric leaves the optional empty.
switch (metric_) {
case raft::distance::DistanceType::L2Expanded:
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Unexpanded:
case raft::distance::DistanceType::L2SqrtUnexpanded:
center_norms_ = make_device_mdarray<float>(handle, make_extents<uint32_t>(n_lists()));
break;
default: center_norms_ = std::nullopt;
}

// NOTE(review): presumably validates the freshly allocated members against each other;
// the implementation is not visible here — confirm.
check_consistency();
}

Expand Down
18 changes: 14 additions & 4 deletions cpp/include/raft/spatial/knn/detail/ann_quantized.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -94,6 +94,18 @@ void approx_knn_ivfsq_build_index(knnIndex* index, const IVFSQParam& params, Int
index->gpu_res.get(), D, params.nlist, faiss_qtype, faiss_metric, params.encodeResidual));
}

/**
 * Tell whether the given distance metric is one of those accepted for the ivf_flat index.
 *
 * @param metric the distance type to check
 * @return true for the four L2 variants (expanded / unexpanded, with or without the square
 *         root) and for the inner product; false for every other metric.
 */
inline bool ivf_flat_supported_metric(raft::distance::DistanceType metric)
{
  return metric == raft::distance::DistanceType::L2Unexpanded ||
         metric == raft::distance::DistanceType::L2Expanded ||
         metric == raft::distance::DistanceType::L2SqrtExpanded ||
         metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
         metric == raft::distance::DistanceType::InnerProduct;
}

template <typename T = float, typename IntType = int>
void approx_knn_build_index(const handle_t& handle,
knnIndex* index,
Expand All @@ -120,9 +132,7 @@ void approx_knn_build_index(const handle_t& handle,
}
if constexpr (std::is_same_v<T, float>) { index->metric_processor->preprocess(index_array); }

if (ivf_ft_pams && (metric == raft::distance::DistanceType::L2Unexpanded ||
metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::InnerProduct)) {
if (ivf_ft_pams && ivf_flat_supported_metric(metric)) {
auto new_params = from_legacy_index_params(*ivf_ft_pams, metric, metricArg);
index->ivf_flat<T, int64_t>() = std::make_unique<const ivf_flat::index<T, int64_t>>(
ivf_flat::build(handle, new_params, index_array, int64_t(n), D));
Expand Down
10 changes: 4 additions & 6 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -191,8 +191,7 @@ inline auto extend(const handle_t& handle,
update_host(&index_size, list_offsets_ptr + n_lists, 1, stream);
handle.sync_stream(stream);

ext_index.allocate(
handle, index_size, ext_index.metric() == raft::distance::DistanceType::L2Expanded);
ext_index.allocate(handle, index_size);

// Populate index with the old data
if (orig_index.size() > 0) {
Expand Down Expand Up @@ -359,8 +358,7 @@ inline void fill_refinement_index(const handle_t& handle,
stream);

IdxT index_size = n_roundup * n_lists;
refinement_index->allocate(
handle, index_size, refinement_index->metric() == raft::distance::DistanceType::L2Expanded);
refinement_index->allocate(handle, index_size);

RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream));

Expand Down Expand Up @@ -454,7 +452,7 @@ auto load(const handle_t& handle, const std::string& filename) -> index<T, IdxT>
index<T, IdxT> index_ =
raft::spatial::knn::ivf_flat::index<T, IdxT>(handle, metric, n_lists, adaptive_centers, dim);

index_.allocate(handle, n_rows, metric == raft::distance::DistanceType::L2Expanded);
index_.allocate(handle, n_rows);
auto data = index_.data();
read_mdspan(handle, infile, data);
read_mdspan(handle, infile, index_.indices());
Expand Down
78 changes: 50 additions & 28 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,11 @@ template <int Capacity,
typename T,
typename AccT,
typename IdxT,
typename Lambda>
typename Lambda,
typename PostLambda>
__global__ void __launch_bounds__(kThreadsPerBlock)
interleaved_scan_kernel(Lambda compute_dist,
PostLambda post_process,
const uint32_t query_smem_elems,
const T* query,
const uint32_t* coarse_index,
Expand Down Expand Up @@ -777,7 +779,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)

// finalize and store selected neighbours
queue.done();
queue.store(distances, neighbors);
queue.store(distances, neighbors, post_process);
}

/**
Expand Down Expand Up @@ -805,8 +807,10 @@ template <int Capacity,
typename T,
typename AccT,
typename IdxT,
typename Lambda>
typename Lambda,
typename PostLambda>
void launch_kernel(Lambda lambda,
PostLambda post_process,
const ivf_flat::index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_index,
Expand All @@ -821,7 +825,7 @@ void launch_kernel(Lambda lambda,
RAFT_EXPECTS(Veclen == index.veclen(),
"Configured Veclen does not match the index interleaving pattern.");
constexpr auto kKernel =
interleaved_scan_kernel<Capacity, Veclen, Ascending, T, AccT, IdxT, Lambda>;
interleaved_scan_kernel<Capacity, Veclen, Ascending, T, AccT, IdxT, Lambda, PostLambda>;
const int max_query_smem = 16384;
int query_smem_elems =
std::min<int>(max_query_smem / sizeof(T), Pow2<Veclen * WarpSize>::roundUp(index.dim()));
Expand Down Expand Up @@ -851,6 +855,7 @@ void launch_kernel(Lambda lambda,
n_probes,
smem_size);
kKernel<<<grid_dim, block_dim, smem_size, stream>>>(lambda,
post_process,
query_smem_elems,
queries,
coarse_index,
Expand Down Expand Up @@ -941,15 +946,27 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg
T,
AccT,
IdxT,
euclidean_dist<Veclen, T, AccT>>({}, std::forward<Args>(args)...);
euclidean_dist<Veclen, T, AccT>,
raft::identity_op>({}, {}, std::forward<Args>(args)...);
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2SqrtUnexpanded:
return launch_kernel<Capacity,
Veclen,
Ascending,
T,
AccT,
IdxT,
euclidean_dist<Veclen, T, AccT>,
raft::sqrt_op>({}, {}, std::forward<Args>(args)...);
case raft::distance::DistanceType::InnerProduct:
return launch_kernel<Capacity,
Veclen,
Ascending,
T,
AccT,
IdxT,
inner_prod_dist<Veclen, T, AccT>>({}, std::forward<Args>(args)...);
inner_prod_dist<Veclen, T, AccT>,
raft::identity_op>({}, {}, std::forward<Args>(args)...);
// NB: update the description of `knn::ivf_flat::build` when adding here a new metric.
default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric));
}
Expand Down Expand Up @@ -1105,28 +1122,33 @@ void search_impl(const handle_t& handle,
float beta = 0.0f;

// todo(lsugy): raft distance? (if performance is similar/better than gemm)
if (index.metric() == raft::distance::DistanceType::L2Expanded) {
alpha = -2.0f;
beta = 1.0f;
raft::linalg::rowNorm(query_norm_dev.data(),
converted_queries_ptr,
static_cast<IdxT>(index.dim()),
static_cast<IdxT>(n_queries),
raft::linalg::L2Norm,
true,
stream,
raft::sqrt_op());
utils::outer_add(query_norm_dev.data(),
(IdxT)n_queries,
index.center_norms()->data_handle(),
(IdxT)index.n_lists(),
distance_buffer_dev.data(),
stream);
RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min<uint32_t>(20, index.dim()));
RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min<uint32_t>(20, index.n_lists()));
} else {
alpha = 1.0f;
beta = 0.0f;
switch (index.metric()) {
case raft::distance::DistanceType::L2Expanded:
case raft::distance::DistanceType::L2SqrtExpanded: {
alpha = -2.0f;
beta = 1.0f;
raft::linalg::rowNorm(query_norm_dev.data(),
converted_queries_ptr,
static_cast<IdxT>(index.dim()),
static_cast<IdxT>(n_queries),
raft::linalg::L2Norm,
true,
stream,
raft::sqrt_op());
utils::outer_add(query_norm_dev.data(),
(IdxT)n_queries,
index.center_norms()->data_handle(),
(IdxT)index.n_lists(),
distance_buffer_dev.data(),
stream);
RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min<uint32_t>(20, index.dim()));
RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min<uint32_t>(20, index.n_lists()));
break;
}
default: {
alpha = 1.0f;
beta = 0.0f;
}
}

linalg::gemm(handle,
Expand Down
12 changes: 7 additions & 5 deletions cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -212,12 +212,13 @@ class warp_sort {
* device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth`
* (length: k <= kWarpWidth * kMaxArrLen).
*/
__device__ void store(T* out, IdxT* out_idx) const
template <typename Lambda = raft::identity_op>
__device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const
{
// Start at the lane id reduced modulo kWarpWidth and advance by kWarpWidth per element,
// so that positions at or beyond `k` are never written.
int idx = Pow2<kWarpWidth>::mod(laneId());
#pragma unroll kMaxArrLen
for (int i = 0; i < kMaxArrLen && idx < k; i++, idx += kWarpWidth) {
out[idx] = val_arr_[i];
// `post_process` is applied to the value only, right before it is written out;
// the matching index is stored unchanged.
out[idx] = post_process(val_arr_[i]);
out_idx[idx] = idx_arr_[i];
}
}
Expand Down Expand Up @@ -591,9 +592,10 @@ class block_sort {
}

/**
 * Save the content by the pointer location.
 *
 * Only threads with `threadIdx.x < subwarp_align::Value` forward the call to the underlying
 * queue's `store`; every other thread in the block does nothing.
 *
 * @param out pointer to the output values
 * @param out_idx pointer to the output indices
 * @param post_process operation passed through to the queue's `store`
 *   (defaults to `raft::identity_op`)
 */
__device__ void store(T* out, IdxT* out_idx) const
template <typename Lambda = raft::identity_op>
__device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const
{
if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx); }
if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, post_process); }
}

private:
Expand Down
4 changes: 3 additions & 1 deletion cpp/test/neighbors/ann_ivf_flat.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -294,6 +294,8 @@ const std::vector<AnnIvfFlatInputs<int64_t>> inputs = {
{1000, 10000, 4, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false},
{1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false},
{1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, true},
{1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, false},
{1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, true},

// test dims that do not fit into kernel shared memory limits
{1000, 10000, 2048, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false},
Expand Down

0 comments on commit ab4f1fd

Please sign in to comment.