Skip to content

Commit

Permalink
[HOTFIX] 24.02 Revert Random Sampling (rapidsai#2144)
Browse files Browse the repository at this point in the history
Closes rapidsai#2141.

Authors:
   - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
   - Divye Gala (https://github.com/divyegala)
   - Tamas Bela Feher (https://github.com/tfeher)
  • Loading branch information
cjnolet authored Feb 1, 2024
1 parent c70d17a commit 7edd372
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 215 deletions.
3 changes: 0 additions & 3 deletions cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ void parse_build_param(const nlohmann::json& conf,
"', should be either 'cluster' or 'subspace'");
}
}
if (conf.contains("max_train_points_per_pq_code")) {
param.max_train_points_per_pq_code = conf.at("max_train_points_per_pq_code");
}
}

template <typename T, typename IdxT>
Expand Down
72 changes: 0 additions & 72 deletions cpp/include/raft/matrix/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,7 @@
#pragma once

#include <functional>
#include <raft/common/nvtx.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/pinned_mdarray.hpp>
#include <raft/core/pinned_mdspan.hpp>
#include <raft/util/cuda_dev_essentials.cuh>
#include <raft/util/cudart_utils.hpp>

namespace raft {
Expand Down Expand Up @@ -343,70 +335,6 @@ void gather_if(const InputIteratorT in,
gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream);
}

template <typename T, typename IdxT = int64_t>
void gather_buff(host_matrix_view<const T, IdxT> dataset,
host_vector_view<const IdxT, IdxT> indices,
IdxT offset,
pinned_matrix_view<T, IdxT> buff)
{
raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("gather_host_buff");
IdxT batch_size = std::min<IdxT>(buff.extent(0), indices.extent(0) - offset);

#pragma omp for
for (IdxT i = 0; i < batch_size; i++) {
IdxT in_idx = indices(offset + i);
for (IdxT k = 0; k < buff.extent(1); k++) {
buff(i, k) = dataset(in_idx, k);
}
}
}

template <typename T, typename IdxT>
void gather(raft::resources const& res,
host_matrix_view<const T, IdxT> dataset,
device_vector_view<const IdxT, IdxT> indices,
raft::device_matrix_view<T, IdxT> output)
{
raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("gather");
IdxT n_dim = output.extent(1);
IdxT n_train = output.extent(0);
auto indices_host = raft::make_host_vector<IdxT, IdxT>(n_train);
raft::copy(
indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res));
resource::sync_stream(res);

const size_t max_batch_size = 32768;
// Gather the vector on the host in tmp buffers. We use two buffers to overlap H2D sync
// and gathering the data.
raft::common::nvtx::push_range("gather::alloc_buffers");
auto out_tmp1 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
auto out_tmp2 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
auto view1 = out_tmp1.view();
auto view2 = out_tmp2.view();
raft::common::nvtx::pop_range();

gather_buff(dataset, make_const_mdspan(indices_host.view()), (IdxT)0, view1);
#pragma omp parallel
for (IdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) {
IdxT batch_size = std::min<IdxT>(max_batch_size, n_train - device_offset);
#pragma omp master
raft::copy(output.data_handle() + device_offset * n_dim,
view1.data_handle(),
batch_size * n_dim,
resource::get_cuda_stream(res));
// Start gathering the next batch on the host.
IdxT host_offset = device_offset + batch_size;
batch_size = std::min<IdxT>(max_batch_size, n_train - host_offset);
if (batch_size > 0) {
gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2);
}
#pragma omp master
resource::sync_stream(res);
#pragma omp barrier
std::swap(view1, view2);
}
}

} // namespace detail
} // namespace matrix
} // namespace raft
23 changes: 14 additions & 9 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -361,23 +361,28 @@ inline auto build(raft::resources const& handle,

// Train the kmeans clustering
{
int random_seed = 137;
auto trainset_ratio = std::max<size_t>(
1, n_rows / std::max<size_t>(params.kmeans_trainset_fraction * n_rows, index.n_lists()));
auto n_rows_train = n_rows / trainset_ratio;
auto trainset = make_device_matrix<T, IdxT>(handle, n_rows_train, index.dim());
raft::spatial::knn::detail::utils::subsample(
handle, dataset, n_rows, trainset.view(), random_seed);
rmm::device_uvector<T> trainset(n_rows_train * index.dim(), stream);
// TODO: a proper sampling
RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(),
sizeof(T) * index.dim(),
dataset,
sizeof(T) * index.dim() * trainset_ratio,
sizeof(T) * index.dim(),
n_rows_train,
cudaMemcpyDefault,
stream));
auto trainset_const_view =
raft::make_device_matrix_view<const T, IdxT>(trainset.data(), n_rows_train, index.dim());
auto centers_view = raft::make_device_matrix_view<float, IdxT>(
index.centers().data_handle(), index.n_lists(), index.dim());
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.n_iters = params.kmeans_n_iters;
kmeans_params.metric = index.metric();
raft::cluster::kmeans_balanced::fit(handle,
kmeans_params,
make_const_mdspan(trainset.view()),
centers_view,
utils::mapping<float>{});
raft::cluster::kmeans_balanced::fit(
handle, kmeans_params, trainset_const_view, centers_view, utils::mapping<float>{});
}

// add the data if necessary
Expand Down
Loading

0 comments on commit 7edd372

Please sign in to comment.