Skip to content

Commit

Permalink
Revert "Subsampling for IVF-PQ codebook generation (rapidsai#2052)"
Browse files Browse the repository at this point in the history
This reverts commit e272176.
  • Loading branch information
cjnolet committed Jan 31, 2024
1 parent c70d17a commit 6224649
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 52 deletions.
5 changes: 1 addition & 4 deletions cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -87,9 +87,6 @@ void parse_build_param(const nlohmann::json& conf,
"', should be either 'cluster' or 'subspace'");
}
}
if (conf.contains("max_train_points_per_pq_code")) {
param.max_train_points_per_pq_code = conf.at("max_train_points_per_pq_code");
}
}

template <typename T, typename IdxT>
Expand Down
29 changes: 9 additions & 20 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -353,19 +353,14 @@ void train_per_subset(raft::resources const& handle,
const float* trainset, // [n_rows, dim]
const uint32_t* labels, // [n_rows]
uint32_t kmeans_n_iters,
uint32_t max_train_points_per_pq_code,
rmm::mr::device_memory_resource* managed_memory)
{
auto stream = resource::get_cuda_stream(handle);
auto device_memory = resource::get_workspace_resource(handle);

rmm::device_uvector<float> pq_centers_tmp(index.pq_centers().size(), stream, device_memory);
// Subsampling the train set for codebook generation based on max_train_points_per_pq_code.
size_t big_enough = max_train_points_per_pq_code * size_t(index.pq_book_size());
auto pq_n_rows = uint32_t(std::min(big_enough, n_rows));
rmm::device_uvector<float> sub_trainset(
pq_n_rows * size_t(index.pq_len()), stream, device_memory);
rmm::device_uvector<uint32_t> sub_labels(pq_n_rows, stream, device_memory);
rmm::device_uvector<float> sub_trainset(n_rows * size_t(index.pq_len()), stream, device_memory);
rmm::device_uvector<uint32_t> sub_labels(n_rows, stream, device_memory);

rmm::device_uvector<uint32_t> pq_cluster_sizes(index.pq_book_size(), stream, device_memory);

Expand All @@ -376,7 +371,7 @@ void train_per_subset(raft::resources const& handle,
// Get the rotated cluster centers for each training vector.
// This will be subtracted from the input vectors afterwards.
utils::copy_selected<float, float, size_t, uint32_t>(
pq_n_rows,
n_rows,
index.pq_len(),
index.centers_rot().data_handle() + index.pq_len() * j,
labels,
Expand All @@ -392,7 +387,7 @@ void train_per_subset(raft::resources const& handle,
true,
false,
index.pq_len(),
pq_n_rows,
n_rows,
index.dim(),
&alpha,
index.rotation_matrix().data_handle() + index.dim() * index.pq_len() * j,
Expand All @@ -405,14 +400,13 @@ void train_per_subset(raft::resources const& handle,
stream);

// train PQ codebook for this subspace
auto sub_trainset_view = raft::make_device_matrix_view<const float, IdxT>(
sub_trainset.data(), pq_n_rows, index.pq_len());
auto sub_trainset_view =
raft::make_device_matrix_view<const float, IdxT>(sub_trainset.data(), n_rows, index.pq_len());
auto centers_tmp_view = raft::make_device_matrix_view<float, IdxT>(
pq_centers_tmp.data() + index.pq_book_size() * index.pq_len() * j,
index.pq_book_size(),
index.pq_len());
auto sub_labels_view =
raft::make_device_vector_view<uint32_t, IdxT>(sub_labels.data(), pq_n_rows);
auto sub_labels_view = raft::make_device_vector_view<uint32_t, IdxT>(sub_labels.data(), n_rows);
auto cluster_sizes_view =
raft::make_device_vector_view<uint32_t, IdxT>(pq_cluster_sizes.data(), index.pq_book_size());
raft::cluster::kmeans_balanced_params kmeans_params;
Expand All @@ -436,7 +430,6 @@ void train_per_cluster(raft::resources const& handle,
const float* trainset, // [n_rows, dim]
const uint32_t* labels, // [n_rows]
uint32_t kmeans_n_iters,
uint32_t max_train_points_per_pq_code,
rmm::mr::device_memory_resource* managed_memory)
{
auto stream = resource::get_cuda_stream(handle);
Expand Down Expand Up @@ -484,11 +477,9 @@ void train_per_cluster(raft::resources const& handle,
indices + cluster_offsets[l],
device_memory);

// limit the cluster size to bound the training time based on max_train_points_per_pq_code
// If pq_book_size is less than pq_dim, use max_train_points_per_pq_code per pq_dim instead
// limit the cluster size to bound the training time.
// [sic] we interpret the data as pq_len-dimensional
size_t big_enough =
max_train_points_per_pq_code * std::max<size_t>(index.pq_book_size(), index.pq_dim());
size_t big_enough = 256ul * std::max<size_t>(index.pq_book_size(), index.pq_dim());
size_t available_rows = size_t(cluster_size) * size_t(index.pq_dim());
auto pq_n_rows = uint32_t(std::min(big_enough, available_rows));
// train PQ codebook for this cluster
Expand Down Expand Up @@ -1797,7 +1788,6 @@ auto build(raft::resources const& handle,
trainset.data_handle(),
labels.data(),
params.kmeans_n_iters,
params.max_train_points_per_pq_code,
&managed_mr);
break;
case codebook_gen::PER_CLUSTER:
Expand All @@ -1807,7 +1797,6 @@ auto build(raft::resources const& handle,
trainset.data_handle(),
labels.data(),
params.kmeans_n_iters,
params.max_train_points_per_pq_code,
&managed_mr);
break;
default: RAFT_FAIL("Unreachable code");
Expand Down
10 changes: 1 addition & 9 deletions cpp/include/raft/neighbors/ivf_pq_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -105,14 +105,6 @@ struct index_params : ann::index_params {
* flag to `true` if you prefer to use as little GPU memory for the database as possible.
*/
bool conservative_memory_allocation = false;
/**
* The max number of data points to use per PQ code during PQ codebook training. Using more data
* points per PQ code may increase the quality of PQ codebook but may also increase the build
* time. The parameter is applied to both PQ codebook generation methods, i.e., PER_SUBSPACE and
* PER_CLUSTER. In both cases, we will use `pq_book_size * max_train_points_per_pq_code` training
* points to train each codebook.
*/
uint32_t max_train_points_per_pq_code = 256;
};

struct search_params : ann::search_params {
Expand Down
1 change: 0 additions & 1 deletion docs/source/ann_benchmarks_param_tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ IVF-pq is an inverted-file index, which partitions the vectors into a series of
| `pq_bits` | `build` | N | Positive Integer. [4-8] | 8 | Bit length of the vector element after quantization. |
| `codebook_kind` | `build` | N | ["cluster", "subspace"] | "subspace" | Type of codebook. See the [API docs](https://docs.rapids.ai/api/raft/nightly/cpp_api/neighbors_ivf_pq/#_CPPv412codebook_gen) for more detail |
| `dataset_memory_type` | `build` | N | ["device", "host", "mmap"] | "host" | What memory type should the dataset reside? |
| `max_train_points_per_pq_code` | `build` | N | Positive Number >=1 | 256 | Max number of data points per PQ code used for PQ code book creation. Depending on input dataset size, the data points could be less than what user specifies. |
| `query_memory_type` | `search` | N | ["device", "host", "mmap"] | "device | What memory type should the queries reside? |
| `nprobe` | `search` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. |
| `internalDistanceDtype` | `search` | N | [`float`, `half`] | `half` | The precision to use for the distance computations. Lower precision can increase performance at the cost of accuracy. |
Expand Down
3 changes: 1 addition & 2 deletions python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -78,7 +78,6 @@ cdef extern from "raft/neighbors/ivf_pq_types.hpp" \
codebook_gen codebook_kind
bool force_random_rotation
bool conservative_memory_allocation
uint32_t max_train_points_per_pq_code

cdef cppclass index[IdxT](ann_index):
index(const device_resources& handle,
Expand Down
18 changes: 2 additions & 16 deletions python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -156,14 +156,6 @@ cdef class IndexParams:
repeated calls to `extend` (extending the database).
To disable this behavior and use as little GPU memory for the
database as possible, set this flat to `True`.
max_train_points_per_pq_code : int, default = 256
The max number of data points to use per PQ code during PQ codebook
training. Using more data points per PQ code may increase the
quality of PQ codebook but may also increase the build time. The
parameter is applied to both PQ codebook generation methods, i.e.,
PER_SUBSPACE and PER_CLUSTER. In both cases, we will use
pq_book_size * max_train_points_per_pq_code training points to
train each codebook.
"""
def __init__(self, *,
n_lists=1024,
Expand All @@ -175,8 +167,7 @@ cdef class IndexParams:
codebook_kind="subspace",
force_random_rotation=False,
add_data_on_build=True,
conservative_memory_allocation=False,
max_train_points_per_pq_code=256):
conservative_memory_allocation=False):
self.params.n_lists = n_lists
self.params.metric = _get_metric(metric)
self.params.metric_arg = 0
Expand All @@ -194,8 +185,6 @@ cdef class IndexParams:
self.params.add_data_on_build = add_data_on_build
self.params.conservative_memory_allocation = \
conservative_memory_allocation
self.params.max_train_points_per_pq_code = \
max_train_points_per_pq_code

@property
def n_lists(self):
Expand Down Expand Up @@ -237,9 +226,6 @@ cdef class IndexParams:
def conservative_memory_allocation(self):
return self.params.conservative_memory_allocation

@property
def max_train_points_per_pq_code(self):
return self.params.max_train_points_per_pq_code

cdef class Index:
# We store a pointer to the index because it dose not have a trivial
Expand Down

0 comments on commit 6224649

Please sign in to comment.