Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable host dataset for IVF-Flat #1635

Merged
merged 4 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class RaftIvfFlatGpu : public ANN<T> {
AlgoProperty get_preference() const override
{
AlgoProperty property;
property.dataset_memory_type = MemoryType::Device;
property.dataset_memory_type = MemoryType::HostMmap;
property.query_memory_type = MemoryType::Device;
return property;
}
Expand Down
106 changes: 74 additions & 32 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
uint32_t* list_sizes_ptr,
IdxT n_rows,
uint32_t dim,
uint32_t veclen)
uint32_t veclen,
IdxT batch_offset = 0)
{
const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x;
if (i >= n_rows) { return; }
Expand All @@ -131,7 +132,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
auto* list_data = list_data_ptrs[list_id];

// Record the source vector id in the index
list_index[inlist_id] = source_ixs == nullptr ? i : source_ixs[i];
list_index[inlist_id] = source_ixs == nullptr ? i + batch_offset : source_ixs[i];

// The data is written in interleaved groups of `index::kGroupSize` vectors
using interleaved_group = Pow2<kIndexGroupSize>;
Expand Down Expand Up @@ -180,16 +181,33 @@ void extend(raft::resources const& handle,

auto new_labels = raft::make_device_vector<LabelT, IdxT>(handle, n_rows);
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.metric = index->metric();
auto new_vectors_view = raft::make_device_matrix_view<const T, IdxT>(new_vectors, n_rows, dim);
kmeans_params.metric = index->metric();
auto orig_centroids_view =
raft::make_device_matrix_view<const float, IdxT>(index->centers().data_handle(), n_lists, dim);
raft::cluster::kmeans_balanced::predict(handle,
kmeans_params,
new_vectors_view,
orig_centroids_view,
new_labels.view(),
utils::mapping<float>{});
// Calculate the batch size for the input data if it's not accessible directly from the device
constexpr size_t kReasonableMaxBatchSize = 65536;
size_t max_batch_size = std::min<size_t>(n_rows, kReasonableMaxBatchSize);

// Predict the cluster labels for the new data, in batches if necessary
utils::batch_load_iterator<T> vec_batches(new_vectors,
n_rows,
index->dim(),
max_batch_size,
stream,
resource::get_workspace_resource(handle));

for (const auto& batch : vec_batches) {
auto batch_data_view =
raft::make_device_matrix_view<const T, IdxT>(batch.data(), batch.size(), index->dim());
auto batch_labels_view = raft::make_device_vector_view<LabelT, IdxT>(
new_labels.data_handle() + batch.offset(), batch.size());
raft::cluster::kmeans_balanced::predict(handle,
kmeans_params,
batch_data_view,
orig_centroids_view,
batch_labels_view,
utils::mapping<float>{});
}

auto* list_sizes_ptr = index->list_sizes().data_handle();
auto old_list_sizes_dev = raft::make_device_vector<uint32_t, IdxT>(handle, n_lists);
Expand All @@ -202,14 +220,19 @@ void extend(raft::resources const& handle,
auto list_sizes_view =
raft::make_device_vector_view<std::remove_pointer_t<decltype(list_sizes_ptr)>, IdxT>(
list_sizes_ptr, n_lists);
auto const_labels_view = make_const_mdspan(new_labels.view());
raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle,
new_vectors_view,
const_labels_view,
centroids_view,
list_sizes_view,
false,
utils::mapping<float>{});
for (const auto& batch : vec_batches) {
auto batch_data_view =
raft::make_device_matrix_view<const T, IdxT>(batch.data(), batch.size(), index->dim());
auto batch_labels_view = raft::make_device_vector_view<const LabelT, IdxT>(
new_labels.data_handle() + batch.offset(), batch.size());
raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle,
batch_data_view,
batch_labels_view,
centroids_view,
list_sizes_view,
false,
utils::mapping<float>{});
}
} else {
raft::stats::histogram<uint32_t, IdxT>(raft::stats::HistTypeAuto,
reinterpret_cast<int32_t*>(list_sizes_ptr),
Expand Down Expand Up @@ -244,20 +267,39 @@ void extend(raft::resources const& handle,
// we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter.
raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream);

// Kernel to insert the new vectors
const dim3 block_dim(256);
const dim3 grid_dim(raft::ceildiv<IdxT>(n_rows, block_dim.x));
build_index_kernel<<<grid_dim, block_dim, 0, stream>>>(new_labels.data_handle(),
new_vectors,
new_indices,
index->data_ptrs().data_handle(),
index->inds_ptrs().data_handle(),
list_sizes_ptr,
n_rows,
dim,
index->veclen());
RAFT_CUDA_TRY(cudaPeekAtLastError());

utils::batch_load_iterator<IdxT> vec_indices(
new_indices, n_rows, 1, max_batch_size, stream, resource::get_workspace_resource(handle));
utils::batch_load_iterator<IdxT> idx_batch = vec_indices.begin();
size_t next_report_offset = 0;
size_t d_report_offset = n_rows * 5 / 100;
for (const auto& batch : vec_batches) {
auto batch_data_view =
raft::make_device_matrix_view<const T, IdxT>(batch.data(), batch.size(), index->dim());
// Kernel to insert the new vectors
const dim3 block_dim(256);
const dim3 grid_dim(raft::ceildiv<IdxT>(batch.size(), block_dim.x));
build_index_kernel<T, IdxT, LabelT>
<<<grid_dim, block_dim, 0, stream>>>(new_labels.data_handle() + batch.offset(),
batch_data_view.data_handle(),
idx_batch->data(),
index->data_ptrs().data_handle(),
index->inds_ptrs().data_handle(),
list_sizes_ptr,
batch.size(),
dim,
index->veclen(),
batch.offset());
RAFT_CUDA_TRY(cudaPeekAtLastError());

if (batch.offset() > next_report_offset) {
float progress = batch.offset() * 100.0f / n_rows;
RAFT_LOG_DEBUG("ivf_flat::extend added vectors %zu, %6.1f%% complete",
static_cast<size_t>(batch.offset()),
progress);
next_report_offset += d_report_offset;
}
++idx_batch;
}
// Precompute the centers vector norms for L2Expanded distance
if (!index->center_norms().has_value()) {
index->allocate_center_norms(handle);
Expand Down
52 changes: 51 additions & 1 deletion cpp/include/raft/neighbors/ivf_flat-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ void build(raft::resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> dataset,
raft::neighbors::ivf_flat::index<T, IdxT>& idx) RAFT_EXPLICIT;

template <typename T, typename IdxT>
auto build(raft::resources const& handle,
const index_params& params,
raft::host_matrix_view<const T, IdxT, row_major> dataset)
-> index<T, IdxT> RAFT_EXPLICIT;

template <typename T, typename IdxT>
void build(raft::resources const& handle,
const index_params& params,
raft::host_matrix_view<const T, IdxT, row_major> dataset,
raft::neighbors::ivf_flat::index<T, IdxT>& idx) RAFT_EXPLICIT;

template <typename T, typename IdxT>
auto extend(raft::resources const& handle,
const index<T, IdxT>& orig_index,
Expand All @@ -74,6 +86,19 @@ void extend(raft::resources const& handle,
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices,
index<T, IdxT>* index) RAFT_EXPLICIT;

template <typename T, typename IdxT>
auto extend(raft::resources const& handle,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,
const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index)
-> raft::neighbors::ivf_flat::index<T, IdxT> RAFT_EXPLICIT;

template <typename T, typename IdxT>
void extend(raft::resources const& handle,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,
index<T, IdxT>* index) RAFT_EXPLICIT;

template <typename T, typename IdxT, typename IvfSampleFilterT>
void search_with_filtering(raft::resources const& handle,
const search_params& params,
Expand Down Expand Up @@ -137,6 +162,18 @@ void search(raft::resources const& handle,
raft::resources const& handle, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset, \
raft::neighbors::ivf_flat::index<T, IdxT>& idx); \
\
extern template auto raft::neighbors::ivf_flat::build<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::host_matrix_view<const T, IdxT, row_major> dataset) \
->raft::neighbors::ivf_flat::index<T, IdxT>; \
\
extern template void raft::neighbors::ivf_flat::build<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::host_matrix_view<const T, IdxT, row_major> dataset, \
raft::neighbors::ivf_flat::index<T, IdxT>& idx);

instantiate_raft_neighbors_ivf_flat_build(float, int64_t);
Expand Down Expand Up @@ -171,7 +208,20 @@ instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t);
raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_flat::index<T, IdxT>* index);
raft::neighbors::ivf_flat::index<T, IdxT>* index); \
\
extern template void raft::neighbors::ivf_flat::extend<T, IdxT>( \
raft::resources const& handle, \
raft::host_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_flat::index<T, IdxT>* index); \
\
extern template auto raft::neighbors::ivf_flat::extend<T, IdxT>( \
const raft::resources& handle, \
raft::host_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices, \
const raft::neighbors::ivf_flat::index<T, IdxT>& idx) \
->raft::neighbors::ivf_flat::index<T, IdxT>;

instantiate_raft_neighbors_ivf_flat_extend(float, int64_t);
instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t);
Expand Down
75 changes: 69 additions & 6 deletions cpp/include/raft/neighbors/ivf_flat-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace raft::neighbors::ivf_flat {
*
* @param[in] handle
* @param[in] params configure the index building
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] dataset a host or device pointer to a row-major matrix [n_rows, dim]
* @param[in] n_rows the number of samples
* @param[in] dim the dimensionality of the data
*
Expand Down Expand Up @@ -102,7 +102,7 @@ auto build(raft::resources const& handle,
*
* @param[in] handle
* @param[in] params configure the index building
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] dataset a device matrix [n_rows, dim]
*
* @return the constructed ivf-flat index
*/
Expand All @@ -118,6 +118,20 @@ auto build(raft::resources const& handle,
static_cast<IdxT>(dataset.extent(1)));
}

/**
* @brief Build the index from a dataset in host memory.
*/
template <typename T, typename IdxT>
auto build(raft::resources const& handle,
const index_params& params,
raft::host_matrix_view<const T, IdxT, row_major> dataset) -> index<T, IdxT>
{
return raft::neighbors::ivf_flat::detail::build(handle,
params,
dataset.data_handle(),
static_cast<IdxT>(dataset.extent(0)),
static_cast<IdxT>(dataset.extent(1)));
}
/**
* @brief Build the index from the dataset for efficient search.
*
Expand Down Expand Up @@ -162,6 +176,21 @@ void build(raft::resources const& handle,
static_cast<IdxT>(dataset.extent(1)));
}

/**
* @brief Build the index from a dataset in host memory.
*/
template <typename T, typename IdxT>
void build(raft::resources const& handle,
const index_params& params,
raft::host_matrix_view<const T, IdxT, row_major> dataset,
raft::neighbors::ivf_flat::index<T, IdxT>& idx)
{
idx = raft::neighbors::ivf_flat::detail::build(handle,
params,
dataset.data_handle(),
static_cast<IdxT>(dataset.extent(0)),
static_cast<IdxT>(dataset.extent(1)));
}
/** @} */

/**
Expand All @@ -188,8 +217,8 @@ void build(raft::resources const& handle,
*
* @param[in] handle
* @param[in] orig_index original index
* @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device pointer to a vector of indices [n_rows].
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param[in] n_rows number of rows in `new_vectors`
Expand Down Expand Up @@ -257,6 +286,23 @@ auto extend(raft::resources const& handle,
new_vectors.extent(0));
}

/**
* @brief Extend the index with additional vectors.
*
* This overloads takes input data in host memory.
*/
template <typename T, typename IdxT>
auto extend(raft::resources const& handle,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,
const index<T, IdxT>& orig_index) -> index<T, IdxT>
{
return extend<T, IdxT>(handle,
orig_index,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
new_vectors.extent(0));
}
/** @} */

/**
Expand All @@ -279,8 +325,8 @@ auto extend(raft::resources const& handle,
*
* @param handle
* @param[inout] index
* @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device pointer to a vector of indices [n_rows].
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param[in] n_rows the number of samples
Expand Down Expand Up @@ -339,6 +385,23 @@ void extend(raft::resources const& handle,
static_cast<IdxT>(new_vectors.extent(0)));
}

/**
* @brief Extend the index with additional vectors.
*
* This overloads takes input data in host memory.
*/
template <typename T, typename IdxT>
void extend(raft::resources const& handle,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,
index<T, IdxT>* index)
{
extend(handle,
index,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
static_cast<IdxT>(new_vectors.extent(0)));
}
/** @} */

/**
Expand Down
Loading