Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add random subsampling for IVF methods #2077

Merged
merged 11 commits into from
Jan 23, 2024
74 changes: 73 additions & 1 deletion cpp/include/raft/matrix/detail/gather.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,15 @@
#pragma once

#include <functional>
#include <raft/common/nvtx.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/pinned_mdarray.hpp>
#include <raft/core/pinned_mdspan.hpp>
#include <raft/util/cuda_dev_essentials.cuh>
#include <raft/util/cudart_utils.hpp>

namespace raft {
Expand Down Expand Up @@ -335,6 +343,70 @@ void gather_if(const InputIteratorT in,
gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream);
}

template <typename T, typename IdxT = int64_t>
void gather_buff(host_matrix_view<const T, IdxT> dataset,
host_vector_view<const IdxT, IdxT> indices,
IdxT offset,
pinned_matrix_view<T, IdxT> buff)
{
raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("gather_host_buff");
IdxT batch_size = std::min<IdxT>(buff.extent(0), indices.extent(0) - offset);

#pragma omp for
for (IdxT i = 0; i < batch_size; i++) {
IdxT in_idx = indices(offset + i);
for (IdxT k = 0; k < buff.extent(1); k++) {
buff(i, k) = dataset(in_idx, k);
}
tfeher marked this conversation as resolved.
Show resolved Hide resolved
}
}

template <typename T, typename IdxT>
void gather(raft::resources const& res,
host_matrix_view<const T, IdxT> dataset,
device_vector_view<const IdxT, IdxT> indices,
raft::device_matrix_view<T, IdxT> output)
{
raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("gather");
IdxT n_dim = output.extent(1);
IdxT n_train = output.extent(0);
auto indices_host = raft::make_host_vector<IdxT, IdxT>(n_train);
raft::copy(
indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res));
resource::sync_stream(res);

const size_t max_batch_size = 32768;
// Gather the vector on the host in tmp buffers. We use two buffers to overlap H2D sync
// and gathering the data.
raft::common::nvtx::push_range("gather::alloc_buffers");
auto out_tmp1 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
auto out_tmp2 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
auto view1 = out_tmp1.view();
auto view2 = out_tmp2.view();
raft::common::nvtx::pop_range();

gather_buff(dataset, make_const_mdspan(indices_host.view()), (IdxT)0, view1);
#pragma omp parallel
for (IdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) {
IdxT batch_size = std::min<IdxT>(max_batch_size, n_train - device_offset);
#pragma omp master
raft::copy(output.data_handle() + device_offset * n_dim,
view1.data_handle(),
batch_size * n_dim,
resource::get_cuda_stream(res));
// Start gathering the next batch on the host.
IdxT host_offset = device_offset + batch_size;
batch_size = std::min<IdxT>(max_batch_size, n_train - host_offset);
if (batch_size > 0) {
gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2);
}
#pragma omp master
resource::sync_stream(res);
#pragma omp barrier
std::swap(view1, view2);
}
}

} // namespace detail
} // namespace matrix
} // namespace raft
25 changes: 10 additions & 15 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -361,28 +361,23 @@ inline auto build(raft::resources const& handle,

// Train the kmeans clustering
{
int random_seed = 137;
auto trainset_ratio = std::max<size_t>(
1, n_rows / std::max<size_t>(params.kmeans_trainset_fraction * n_rows, index.n_lists()));
auto n_rows_train = n_rows / trainset_ratio;
rmm::device_uvector<T> trainset(n_rows_train * index.dim(), stream);
// TODO: a proper sampling
RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(),
sizeof(T) * index.dim(),
dataset,
sizeof(T) * index.dim() * trainset_ratio,
sizeof(T) * index.dim(),
n_rows_train,
cudaMemcpyDefault,
stream));
auto trainset_const_view =
raft::make_device_matrix_view<const T, IdxT>(trainset.data(), n_rows_train, index.dim());
auto trainset = make_device_matrix<T, IdxT>(handle, n_rows_train, index.dim());
raft::spatial::knn::detail::utils::subsample(
handle, dataset, n_rows, trainset.view(), random_seed);
auto centers_view = raft::make_device_matrix_view<float, IdxT>(
index.centers().data_handle(), index.n_lists(), index.dim());
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.n_iters = params.kmeans_n_iters;
kmeans_params.metric = index.metric();
raft::cluster::kmeans_balanced::fit(
handle, kmeans_params, trainset_const_view, centers_view, utils::mapping<float>{});
raft::cluster::kmeans_balanced::fit(handle,
kmeans_params,
make_const_mdspan(trainset.view()),
centers_view,
utils::mapping<float>{});
}

// add the data if necessary
Expand Down
88 changes: 30 additions & 58 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@

#include <raft/cluster/kmeans_balanced.cuh>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/add.cuh>
Expand All @@ -46,7 +48,6 @@
#include <raft/util/pow2_utils.cuh>
#include <raft/util/vectorized.cuh>

#include <raft/core/resource/device_memory_resource.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
Expand Down Expand Up @@ -1754,76 +1755,47 @@ auto build(raft::resources const& handle,
utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream);

{
int random_seed = 137;
auto trainset_ratio = std::max<size_t>(
1,
size_t(n_rows) / std::max<size_t>(params.kmeans_trainset_fraction * n_rows, index.n_lists()));
size_t n_rows_train = n_rows / trainset_ratio;

auto* device_memory = resource::get_workspace_resource(handle);
rmm::mr::managed_memory_resource managed_memory_upstream;
auto* device_mr = resource::get_workspace_resource(handle);
rmm::mr::managed_memory_resource managed_mr;

// Besides just sampling, we transform the input dataset into floats to make it easier
// to use gemm operations from cublas.
rmm::device_uvector<float> trainset(n_rows_train * index.dim(), stream, device_memory);
// TODO: a proper sampling
auto trainset =
make_device_mdarray<float>(handle, device_mr, make_extents<IdxT>(n_rows_train, dim));

if constexpr (std::is_same_v<T, float>) {
RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(),
sizeof(T) * index.dim(),
dataset,
sizeof(T) * index.dim() * trainset_ratio,
sizeof(T) * index.dim(),
n_rows_train,
cudaMemcpyDefault,
stream));
raft::spatial::knn::detail::utils::subsample(
handle, dataset, n_rows, trainset.view(), random_seed);
} else {
size_t dim = index.dim();
cudaPointerAttributes dataset_attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&dataset_attr, dataset));
if (dataset_attr.devicePointer != nullptr) {
// data is available on device: just run the kernel to copy and map the data
auto p = reinterpret_cast<T*>(dataset_attr.devicePointer);
auto trainset_view =
raft::make_device_vector_view<float, IdxT>(trainset.data(), dim * n_rows_train);
linalg::map_offset(handle, trainset_view, [p, trainset_ratio, dim] __device__(size_t i) {
auto col = i % dim;
return utils::mapping<float>{}(p[(i - col) * size_t(trainset_ratio) + col]);
});
} else {
// data is not available: first copy, then map inplace
auto trainset_tmp = reinterpret_cast<T*>(reinterpret_cast<uint8_t*>(trainset.data()) +
(sizeof(float) - sizeof(T)) * index.dim());
// We copy the data in strides, one row at a time, and place the smaller rows of type T
// at the end of float rows.
RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset_tmp,
sizeof(float) * index.dim(),
dataset,
sizeof(T) * index.dim() * trainset_ratio,
sizeof(T) * index.dim(),
n_rows_train,
cudaMemcpyDefault,
stream));
// Transform the input `{T -> float}`, one row per warp.
// The threads in each warp copy the data synchronously; this and the layout of the data
// (content is aligned to the end of the rows) together allow doing the transform in-place.
copy_warped(trainset.data(),
index.dim(),
trainset_tmp,
index.dim() * sizeof(float) / sizeof(T),
index.dim(),
n_rows_train,
stream);
tfeher marked this conversation as resolved.
Show resolved Hide resolved
}
// TODO(tfeher): Enable codebook generation with any type T, and then remove
// trainset tmp.
auto trainset_tmp =
make_device_mdarray<T>(handle, &managed_mr, make_extents<IdxT>(n_rows_train, dim));
raft::spatial::knn::detail::utils::subsample(
handle, dataset, n_rows, trainset_tmp.view(), random_seed);
cudaDeviceSynchronize();
RAFT_LOG_INFO("Subsampling done, converting to float");
raft::linalg::unaryOp(trainset.data_handle(),
trainset_tmp.data_handle(),
trainset.size(),
utils::mapping<float>{}, // raft::cast_op<float>(),
raft::resource::get_cuda_stream(handle));
}

// NB: here cluster_centers is used as if it is [n_clusters, data_dim] not [n_clusters,
// dim_ext]!
rmm::device_uvector<float> cluster_centers_buf(
index.n_lists() * index.dim(), stream, device_memory);
index.n_lists() * index.dim(), stream, device_mr);
auto cluster_centers = cluster_centers_buf.data();

// Train balanced hierarchical kmeans clustering
auto trainset_const_view =
raft::make_device_matrix_view<const float, IdxT>(trainset.data(), n_rows_train, index.dim());
auto trainset_const_view = raft::make_const_mdspan(trainset.view());
auto centers_view =
raft::make_device_matrix_view<float, IdxT>(cluster_centers, index.n_lists(), index.dim());
raft::cluster::kmeans_balanced_params kmeans_params;
Expand All @@ -1833,7 +1805,7 @@ auto build(raft::resources const& handle,
handle, kmeans_params, trainset_const_view, centers_view, utils::mapping<float>{});

// Trainset labels are needed for training PQ codebooks
rmm::device_uvector<uint32_t> labels(n_rows_train, stream, device_memory);
rmm::device_uvector<uint32_t> labels(n_rows_train, stream, device_mr);
auto centers_const_view = raft::make_device_matrix_view<const float, IdxT>(
cluster_centers, index.n_lists(), index.dim());
auto labels_view = raft::make_device_vector_view<uint32_t, IdxT>(labels.data(), n_rows_train);
Expand All @@ -1859,19 +1831,19 @@ auto build(raft::resources const& handle,
train_per_subset(handle,
index,
n_rows_train,
trainset.data(),
trainset.data_handle(),
labels.data(),
params.kmeans_n_iters,
&managed_memory_upstream);
&managed_mr);
break;
case codebook_gen::PER_CLUSTER:
train_per_cluster(handle,
index,
n_rows_train,
trainset.data(),
trainset.data_handle(),
labels.data(),
params.kmeans_n_iters,
&managed_memory_upstream);
&managed_mr);
break;
default: RAFT_FAIL("Unreachable code");
}
Expand Down
62 changes: 61 additions & 1 deletion cpp/include/raft/spatial/knn/detail/ann_utils.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,8 +16,16 @@

#pragma once

#include <raft/common/nvtx.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>

#include <raft/core/logger.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/matrix/gather.cuh>
#include <raft/random/sample_without_replacement.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/integer_utils.hpp>
Expand All @@ -30,6 +38,10 @@
#include <optional>

#include <cuda_fp16.hpp>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/sort.h>

namespace raft::spatial::knn::detail::utils {

Expand Down Expand Up @@ -573,4 +585,52 @@ struct batch_load_iterator {
size_type cur_pos_;
};

template <typename IdxT>
auto get_subsample_indices(raft::resources const& res, IdxT n_samples, IdxT n_subsamples, int seed)
-> raft::device_vector<IdxT, IdxT>
{
RAFT_EXPECTS(n_subsamples <= n_samples, "Cannot have more training samples than dataset vectors");

auto data_indices = raft::make_device_vector<IdxT, IdxT>(res, n_samples);
thrust::counting_iterator<IdxT> first(0);
thrust::device_ptr<IdxT> ptr(data_indices.data_handle());
thrust::copy(raft::resource::get_thrust_policy(res), first, first + n_samples, ptr);
tfeher marked this conversation as resolved.
Show resolved Hide resolved
raft::random::RngState rng(seed);
auto train_indices = raft::make_device_vector<IdxT, IdxT>(res, n_subsamples);
raft::random::sample_without_replacement(res,
rng,
raft::make_const_mdspan(data_indices.view()),
std::nullopt,
train_indices.view(),
std::nullopt);
return train_indices;
}

/** Subsample the dataset to create a training set*/
template <typename T, typename IdxT = int64_t>
void subsample(raft::resources const& res,
const T* input,
IdxT n_samples,
raft::device_matrix_view<T, IdxT> output,
int seed)
{
IdxT n_dim = output.extent(1);
IdxT n_train = output.extent(0);

raft::device_vector<IdxT, IdxT> train_indices =
get_subsample_indices<IdxT>(res, n_samples, n_train, seed);

cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input));
T* ptr = reinterpret_cast<T*>(attr.devicePointer);
if (ptr != nullptr) {
tfeher marked this conversation as resolved.
Show resolved Hide resolved
raft::matrix::gather(res,
raft::make_device_matrix_view<const T, IdxT>(ptr, n_samples, n_dim),
raft::make_const_mdspan(train_indices.view()),
output);
} else {
auto dataset = raft::make_host_matrix_view<const T, IdxT>(input, n_samples, n_dim);
raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output);
}
}
} // namespace raft::spatial::knn::detail::utils
4 changes: 2 additions & 2 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,7 +68,7 @@ struct ivf_pq_inputs {
ivf_pq_inputs()
{
index_params.n_lists = max(32u, min(1024u, num_db_vecs / 128u));
index_params.kmeans_trainset_fraction = 1.0;
index_params.kmeans_trainset_fraction = 0.95;
}
};

Expand Down