From a455b08b4a70b93da0c9b548e269e8c82b81f696 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Wed, 3 Jan 2024 11:23:42 +0100 Subject: [PATCH 01/11] Add random subsampling for IVF methods --- .../src/raft/raft_ann_bench_param_parser.h | 2 + .../raft/neighbors/detail/ivf_flat_build.cuh | 22 ++--- .../raft/neighbors/detail/ivf_pq_build.cuh | 87 ++++++------------ cpp/include/raft/neighbors/ivf_flat_types.hpp | 6 ++ cpp/include/raft/neighbors/ivf_pq_types.hpp | 7 ++ .../raft/spatial/knn/detail/ann_utils.cuh | 91 +++++++++++++++++++ cpp/test/neighbors/ann_ivf_flat.cuh | 8 +- cpp/test/neighbors/ann_ivf_pq.cuh | 2 +- .../neighbors/ivf_flat/cpp/c_ivf_flat.pxd | 1 + .../pylibraft/neighbors/ivf_flat/ivf_flat.pyx | 13 ++- .../neighbors/ivf_pq/cpp/c_ivf_pq.pxd | 1 + .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 12 ++- 12 files changed, 176 insertions(+), 76 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 2a021a8a12..b61ade91d2 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -55,6 +55,7 @@ void parse_build_param(const nlohmann::json& conf, param.n_lists = conf.at("nlist"); if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } + if (conf.contains("random_seed")) { param.random_seed = conf.at("random_seed"); } } template @@ -87,6 +88,7 @@ void parse_build_param(const nlohmann::json& conf, "', should be either 'cluster' or 'subspace'"); } } + if (conf.contains("random_seed")) { param.random_seed = conf.at("random_seed"); } } template diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index a35cb9e1f1..2469dd014e 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ 
b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -364,25 +364,19 @@ inline auto build(raft::resources const& handle, auto trainset_ratio = std::max( 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); auto n_rows_train = n_rows / trainset_ratio; - rmm::device_uvector trainset(n_rows_train * index.dim(), stream); - // TODO: a proper sampling - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), - sizeof(T) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - auto trainset_const_view = - raft::make_device_matrix_view(trainset.data(), n_rows_train, index.dim()); + auto trainset = make_device_matrix(handle, n_rows_train, index.dim()); + raft::spatial::knn::detail::utils::subsample( + handle, dataset, n_rows, trainset.view(), params.random_seed); auto centers_view = raft::make_device_matrix_view( index.centers().data_handle(), index.n_lists(), index.dim()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = index.metric(); - raft::cluster::kmeans_balanced::fit( - handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); + raft::cluster::kmeans_balanced::fit(handle, + kmeans_params, + make_const_mdspan(trainset.view()), + centers_view, + utils::mapping{}); } // add the data if necessary diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 2dfb261f32..2b71db4066 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -25,9 +25,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -46,7 +48,6 @@ #include #include -#include #include #include #include @@ -1759,71 +1760,41 @@ auto build(raft::resources const& handle, size_t(n_rows) / std::max(params.kmeans_trainset_fraction * 
n_rows, index.n_lists())); size_t n_rows_train = n_rows / trainset_ratio; - auto* device_memory = resource::get_workspace_resource(handle); - rmm::mr::managed_memory_resource managed_memory_upstream; + auto* device_mr = resource::get_workspace_resource(handle); + rmm::mr::managed_memory_resource managed_mr; // Besides just sampling, we transform the input dataset into floats to make it easier // to use gemm operations from cublas. - rmm::device_uvector trainset(n_rows_train * index.dim(), stream, device_memory); - // TODO: a proper sampling + auto trainset = + make_device_mdarray(handle, device_mr, make_extents(n_rows_train, dim)); + if constexpr (std::is_same_v) { - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), - sizeof(T) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); + raft::spatial::knn::detail::utils::subsample( + handle, dataset, n_rows, trainset.view(), params.random_seed); } else { - size_t dim = index.dim(); - cudaPointerAttributes dataset_attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&dataset_attr, dataset)); - if (dataset_attr.devicePointer != nullptr) { - // data is available on device: just run the kernel to copy and map the data - auto p = reinterpret_cast(dataset_attr.devicePointer); - auto trainset_view = - raft::make_device_vector_view(trainset.data(), dim * n_rows_train); - linalg::map_offset(handle, trainset_view, [p, trainset_ratio, dim] __device__(size_t i) { - auto col = i % dim; - return utils::mapping{}(p[(i - col) * size_t(trainset_ratio) + col]); - }); - } else { - // data is not available: first copy, then map inplace - auto trainset_tmp = reinterpret_cast(reinterpret_cast(trainset.data()) + - (sizeof(float) - sizeof(T)) * index.dim()); - // We copy the data in strides, one row at a time, and place the smaller rows of type T - // at the end of float rows. 
- RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset_tmp, - sizeof(float) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - // Transform the input `{T -> float}`, one row per warp. - // The threads in each warp copy the data synchronously; this and the layout of the data - // (content is aligned to the end of the rows) together allow doing the transform in-place. - copy_warped(trainset.data(), - index.dim(), - trainset_tmp, - index.dim() * sizeof(float) / sizeof(T), - index.dim(), - n_rows_train, - stream); - } + // TODO(tfeher): Enable codebook generation with any type T, and then remove + // trainset tmp. + auto trainset_tmp = + make_device_mdarray(handle, device_mr, make_extents(n_rows_train, dim)); + raft::spatial::knn::detail::utils::subsample( + handle, dataset, n_rows, trainset_tmp.view(), params.random_seed); + cudaDeviceSynchronize(); + RAFT_LOG_INFO("Subsampling done, converting to float"); + raft::linalg::unaryOp(trainset.data_handle(), + trainset_tmp.data_handle(), + trainset.size(), + utils::mapping{}, // raft::cast_op(), + raft::resource::get_cuda_stream(handle)); } // NB: here cluster_centers is used as if it is [n_clusters, data_dim] not [n_clusters, // dim_ext]! 
rmm::device_uvector cluster_centers_buf( - index.n_lists() * index.dim(), stream, device_memory); + index.n_lists() * index.dim(), stream, device_mr); auto cluster_centers = cluster_centers_buf.data(); // Train balanced hierarchical kmeans clustering - auto trainset_const_view = - raft::make_device_matrix_view(trainset.data(), n_rows_train, index.dim()); + auto trainset_const_view = raft::make_const_mdspan(trainset.view()); auto centers_view = raft::make_device_matrix_view(cluster_centers, index.n_lists(), index.dim()); raft::cluster::kmeans_balanced_params kmeans_params; @@ -1833,7 +1804,7 @@ auto build(raft::resources const& handle, handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); // Trainset labels are needed for training PQ codebooks - rmm::device_uvector labels(n_rows_train, stream, device_memory); + rmm::device_uvector labels(n_rows_train, stream, device_mr); auto centers_const_view = raft::make_device_matrix_view( cluster_centers, index.n_lists(), index.dim()); auto labels_view = raft::make_device_vector_view(labels.data(), n_rows_train); @@ -1859,19 +1830,19 @@ auto build(raft::resources const& handle, train_per_subset(handle, index, n_rows_train, - trainset.data(), + trainset.data_handle(), labels.data(), params.kmeans_n_iters, - &managed_memory_upstream); + &managed_mr); break; case codebook_gen::PER_CLUSTER: train_per_cluster(handle, index, n_rows_train, - trainset.data(), + trainset.data_handle(), labels.data(), params.kmeans_n_iters, - &managed_memory_upstream); + &managed_mr); break; default: RAFT_FAIL("Unreachable code"); } diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index 180fe2e21b..12b73d1aa2 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -76,6 +76,12 @@ struct index_params : ann::index_params { * flag to `true` if you prefer to use as little GPU memory for the database as possible. 
*/ bool conservative_memory_allocation = false; + /** + * Seed used for random sampling if kmeans_trainset_fraction < 1. + * + * Value -1 disables random sampling, and results in sampling with a fixed stride. + */ + int random_seed = 0; }; struct search_params : ann::search_params { diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 45ab18c84f..47d534414b 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -105,6 +105,13 @@ struct index_params : ann::index_params { * flag to `true` if you prefer to use as little GPU memory for the database as possible. */ bool conservative_memory_allocation = false; + + /** + * Seed used for random sampling if kmeans_trainset_fraction < 1. + * + * Value -1 disables random sampling, and results in sampling with a fixed stride. + */ + int random_seed = 0; }; struct search_params : ann::search_params { diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index c7823c2d38..9d107f870f 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -16,8 +16,14 @@ #pragma once +#include +#include +#include +#include #include #include +#include +#include #include #include #include @@ -30,6 +36,10 @@ #include #include +#include +#include +#include +#include namespace raft::spatial::knn::detail::utils { @@ -573,4 +583,85 @@ struct batch_load_iterator { size_type cur_pos_; }; +template +auto get_subsample_indices(raft::resources const& res, IdxT n_samples, IdxT n_subsamples, int seed) + -> raft::device_vector +{ + RAFT_EXPECTS(n_subsamples <= n_samples, "Cannot have more training samples than dataset vectors"); + + auto data_indices = raft::make_device_vector(res, n_samples); + thrust::counting_iterator first(0); + thrust::device_ptr ptr(data_indices.data_handle()); + 
thrust::copy(raft::resource::get_thrust_policy(res), first, first + n_samples, ptr); + raft::random::RngState rng(seed); + auto train_indices = raft::make_device_vector(res, n_subsamples); + raft::random::sample_without_replacement(res, + rng, + raft::make_const_mdspan(data_indices.view()), + std::nullopt, + train_indices.view(), + std::nullopt); + + thrust::sort(resource::get_thrust_policy(res), + train_indices.data_handle(), + train_indices.data_handle() + n_subsamples); + return train_indices; +} + +/** Subsample the dataset to create a training set*/ +template +void subsample(raft::resources const& res, + const T* input, + IdxT n_samples, + raft::device_matrix_view output, + int seed) +{ + int64_t n_dim = output.extent(1); + int64_t n_train = output.extent(0); + if (seed == -1 || n_train == n_samples) { + IdxT trainset_ratio = n_samples / n_train; + RAFT_LOG_INFO("Fixed stride subsampling"); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(output.data_handle(), + sizeof(T) * n_dim, + input, + sizeof(T) * n_dim * trainset_ratio, + sizeof(T) * n_dim, + n_train, + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + return; + } + RAFT_LOG_DEBUG("Random subsampling"); + raft::device_vector train_indices = + get_subsample_indices(res, n_samples, n_train, seed); + + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input)); + T* ptr = reinterpret_cast(attr.devicePointer); + if (ptr != nullptr) { + raft::matrix::copy_rows(res, + raft::make_device_matrix_view(ptr, n_samples, n_dim), + output, + raft::make_const_mdspan(train_indices.view())); + } else { + auto dataset = raft::make_host_matrix_view(input, n_samples, n_dim); + auto train_indices_host = raft::make_host_vector(n_train); + raft::copy(train_indices_host.data_handle(), + train_indices.data_handle(), + n_train, + resource::get_cuda_stream(res)); + resource::sync_stream(res); + auto out_tmp = raft::make_host_matrix(n_train, n_dim); +#pragma omp parallel for + for (IdxT i = 0; i < n_train; i++) { + 
IdxT in_idx = train_indices_host(i); + for (IdxT k = 0; k < n_dim; k++) { + out_tmp(i, k) = dataset(in_idx, k); + } + } + raft::copy( + output.data_handle(), out_tmp.data_handle(), output.size(), resource::get_cuda_stream(res)); + resource::sync_stream(res); + } +} } // namespace raft::spatial::knn::detail::utils diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 39439d392d..86ff5964a3 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -73,6 +73,7 @@ struct AnnIvfFlatInputs { raft::distance::DistanceType metric; bool adaptive_centers; bool host_dataset; + int seed; }; template @@ -80,7 +81,7 @@ template { os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " << p.nprobe << ", " << p.nlist << ", " << static_cast(p.metric) << ", " - << p.adaptive_centers << ", " << p.host_dataset << '}' << std::endl; + << p.adaptive_centers << ", " << p.host_dataset << "," << p.seed << '}' << std::endl; return os; } @@ -178,6 +179,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 0.5; index_params.metric_arg = 0; + index_params.random_seed = ps.seed; ivf_flat::index idx(handle_, index_params, ps.dim); ivf_flat::index index_2(handle_, index_params, ps.dim); @@ -327,6 +329,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; + index_params.random_seed = ps.seed; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.num_db_vecs, ps.dim); @@ -497,6 +500,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { index_params.add_data_on_build = true; index_params.kmeans_trainset_fraction = 0.5; index_params.metric_arg = 0; + index_params.random_seed = ps.seed; // Create IVF Flat index auto database_view = 
raft::make_device_matrix_view( @@ -607,6 +611,8 @@ const std::vector> inputs = { {20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, {10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false}, + {10000, 1000000, 96, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false, true, -1}, + {10000, 1000000, 96, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false, false, -1}, // host input data {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded, false, true}, diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index eb30b60eca..c6c9096109 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -68,7 +68,7 @@ struct ivf_pq_inputs { ivf_pq_inputs() { index_params.n_lists = max(32u, min(1024u, num_db_vecs / 128u)); - index_params.kmeans_trainset_fraction = 1.0; + index_params.kmeans_trainset_fraction = 0.95; } }; diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd index a281d33310..9ddab5eba2 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd @@ -56,6 +56,7 @@ cdef extern from "raft/neighbors/ivf_flat_types.hpp" \ double kmeans_trainset_fraction bool adaptive_centers bool conservative_memory_allocation + int random_seed cdef cppclass index[T, IdxT](ann_index): index(const device_resources& handle, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx index d8fbdc74da..b9fe7060cf 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx @@ -113,6 +113,11 @@ cdef class IndexParams: adding new data 
(through the classification of the added data); that is, `index.centers()` "drift" together with the changing distribution of the newly added data. + random_seed : int, default = 0 + Seed used for random sampling if kmeans_trainset_fraction < 1. + Value -1 disables random sampling, and results in sampling with a + fixed stride. + """ cdef c_ivf_flat.index_params params @@ -122,7 +127,8 @@ cdef class IndexParams: kmeans_n_iters=20, kmeans_trainset_fraction=0.5, add_data_on_build=True, - bool adaptive_centers=False): + bool adaptive_centers=False, + random_seed=0): self.params.n_lists = n_lists self.params.metric = _get_metric(metric) self.params.metric_arg = 0 @@ -130,6 +136,7 @@ cdef class IndexParams: self.params.kmeans_trainset_fraction = kmeans_trainset_fraction self.params.add_data_on_build = add_data_on_build self.params.adaptive_centers = adaptive_centers + self.params.random_seed = random_seed @property def n_lists(self): @@ -155,6 +162,10 @@ cdef class IndexParams: def adaptive_centers(self): return self.params.adaptive_centers + @property + def random_seed(self): + return self.params.random_seed + cdef class Index: cdef readonly bool trained diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index 531c2428e9..418b2445d0 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -78,6 +78,7 @@ cdef extern from "raft/neighbors/ivf_pq_types.hpp" \ codebook_gen codebook_kind bool force_random_rotation bool conservative_memory_allocation + int random_seed cdef cppclass index[IdxT](ann_index): index(const device_resources& handle, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index 0c1bbf6b9c..b1d95caf66 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ 
b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -156,6 +156,10 @@ cdef class IndexParams: repeated calls to `extend` (extending the database). To disable this behavior and use as little GPU memory for the database as possible, set this flat to `True`. + random_seed : int, default = 0 + Seed used for random sampling if kmeans_trainset_fraction < 1. + Value -1 disables random sampling, and results in sampling with a + fixed stride. """ def __init__(self, *, n_lists=1024, @@ -167,7 +171,8 @@ cdef class IndexParams: codebook_kind="subspace", force_random_rotation=False, add_data_on_build=True, - conservative_memory_allocation=False): + conservative_memory_allocation=False, + random_seed=0): self.params.n_lists = n_lists self.params.metric = _get_metric(metric) self.params.metric_arg = 0 @@ -185,6 +190,7 @@ cdef class IndexParams: self.params.add_data_on_build = add_data_on_build self.params.conservative_memory_allocation = \ conservative_memory_allocation + self.params.random_seed = random_seed @property def n_lists(self): @@ -225,6 +231,10 @@ cdef class IndexParams: @property def conservative_memory_allocation(self): return self.params.conservative_memory_allocation + + @property + def random_seed(self): + return self.params.random_seed cdef class Index: From ee39ddc589be42b6ac5b4677a1e3549a788e2a33 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Wed, 3 Jan 2024 21:13:46 +0100 Subject: [PATCH 02/11] fix style --- cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h | 2 +- cpp/include/raft/neighbors/detail/ivf_flat_build.cuh | 2 +- cpp/include/raft/neighbors/ivf_flat_types.hpp | 2 +- cpp/include/raft/neighbors/ivf_pq_types.hpp | 2 +- cpp/include/raft/spatial/knn/detail/ann_utils.cuh | 2 +- cpp/test/neighbors/ann_ivf_flat.cuh | 2 +- cpp/test/neighbors/ann_ivf_pq.cuh | 2 +- .../pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd | 2 +- python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx | 4 ++-- 
python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd | 2 +- python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 4 ++-- 11 files changed, 13 insertions(+), 13 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index b61ade91d2..38e473f9ae 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index 2469dd014e..0f710d5b81 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index 12b73d1aa2..317e10cf92 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 47d534414b..51536583f8 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index 9d107f870f..ab11a2178a 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 86ff5964a3..014048a068 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index c6c9096109..f7e85db1c2 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd index 9ddab5eba2..035d2814fc 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx index b9fe7060cf..64ac1a9ce9 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -117,7 +117,7 @@ cdef class IndexParams: Seed used for random sampling if kmeans_trainset_fraction < 1. Value -1 disables random sampling, and results in sampling with a fixed stride. - + """ cdef c_ivf_flat.index_params params diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index 418b2445d0..f40e5465b7 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index b1d95caf66..b8a3cf4887 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -231,7 +231,7 @@ cdef class IndexParams: @property def conservative_memory_allocation(self): return self.params.conservative_memory_allocation - + @property def random_seed(self): return self.params.random_seed From ca3eec72ab6737c56a36ab348a8be9aa879fc155 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Thu, 18 Jan 2024 01:04:32 +0100 Subject: [PATCH 03/11] Fix IdxT --- cpp/include/raft/spatial/knn/detail/ann_utils.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index ab11a2178a..0f585c1572 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -616,8 +616,8 @@ void subsample(raft::resources const& res, raft::device_matrix_view output, int seed) { - int64_t n_dim = output.extent(1); - int64_t n_train = output.extent(0); + IdxT n_dim = output.extent(1); + IdxT n_train = output.extent(0); if (seed == -1 || n_train == n_samples) { IdxT trainset_ratio = n_samples / n_train; RAFT_LOG_INFO("Fixed stride subsampling"); From 2ee6aff7987b9411d1ccfddc314bccba2d3e2245 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Thu, 18 Jan 2024 01:16:00 +0100 Subject: [PATCH 04/11] remove random_seed parameter --- .../ann/src/raft/raft_ann_bench_param_parser.h | 4 +--- .../raft/neighbors/detail/ivf_flat_build.cuh | 3 ++- 
.../raft/neighbors/detail/ivf_pq_build.cuh | 5 +++-- cpp/include/raft/neighbors/ivf_flat_types.hpp | 8 +------- cpp/include/raft/neighbors/ivf_pq_types.hpp | 9 +-------- cpp/test/neighbors/ann_ivf_flat.cuh | 10 ++-------- .../neighbors/ivf_flat/cpp/c_ivf_flat.pxd | 3 +-- .../pylibraft/neighbors/ivf_flat/ivf_flat.pyx | 15 ++------------- .../pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd | 3 +-- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 14 ++------------ 10 files changed, 16 insertions(+), 58 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 38e473f9ae..2a021a8a12 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -55,7 +55,6 @@ void parse_build_param(const nlohmann::json& conf, param.n_lists = conf.at("nlist"); if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } - if (conf.contains("random_seed")) { param.random_seed = conf.at("random_seed"); } } template @@ -88,7 +87,6 @@ void parse_build_param(const nlohmann::json& conf, "', should be either 'cluster' or 'subspace'"); } } - if (conf.contains("random_seed")) { param.random_seed = conf.at("random_seed"); } } template diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index 0f710d5b81..ab30b4009d 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -361,12 +361,13 @@ inline auto build(raft::resources const& handle, // Train the kmeans clustering { + int random_seed = 137; auto trainset_ratio = std::max( 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); auto n_rows_train = n_rows / trainset_ratio; auto trainset = make_device_matrix(handle, n_rows_train, index.dim()); raft::spatial::knn::detail::utils::subsample( - handle, dataset, n_rows, trainset.view(), params.random_seed); + handle, dataset, n_rows, trainset.view(), random_seed); auto centers_view = raft::make_device_matrix_view( index.centers().data_handle(), index.n_lists(), index.dim()); raft::cluster::kmeans_balanced_params kmeans_params; diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 2b71db4066..11d520d989 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -1755,6 +1755,7 @@ auto build(raft::resources const& handle, utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); { + int random_seed = 137; auto 
trainset_ratio = std::max( 1, size_t(n_rows) / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); @@ -1770,14 +1771,14 @@ auto build(raft::resources const& handle, if constexpr (std::is_same_v) { raft::spatial::knn::detail::utils::subsample( - handle, dataset, n_rows, trainset.view(), params.random_seed); + handle, dataset, n_rows, trainset.view(), random_seed); } else { // TODO(tfeher): Enable codebook generation with any type T, and then remove // trainset tmp. auto trainset_tmp = make_device_mdarray(handle, device_mr, make_extents(n_rows_train, dim)); raft::spatial::knn::detail::utils::subsample( - handle, dataset, n_rows, trainset_tmp.view(), params.random_seed); + handle, dataset, n_rows, trainset_tmp.view(), random_seed); cudaDeviceSynchronize(); RAFT_LOG_INFO("Subsampling done, converting to float"); raft::linalg::unaryOp(trainset.data_handle(), diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index 317e10cf92..180fe2e21b 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -76,12 +76,6 @@ struct index_params : ann::index_params { * flag to `true` if you prefer to use as little GPU memory for the database as possible. */ bool conservative_memory_allocation = false; - /** - * Seed used for random sampling if kmeans_trainset_fraction < 1. - * - * Value -1 disables random sampling, and results in sampling with a fixed stride. 
- */ - int random_seed = 0; }; struct search_params : ann::search_params { diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 51536583f8..45ab18c84f 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -105,13 +105,6 @@ struct index_params : ann::index_params { * flag to `true` if you prefer to use as little GPU memory for the database as possible. */ bool conservative_memory_allocation = false; - - /** - * Seed used for random sampling if kmeans_trainset_fraction < 1. - * - * Value -1 disables random sampling, and results in sampling with a fixed stride. - */ - int random_seed = 0; }; struct search_params : ann::search_params { diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 014048a068..39439d392d 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -73,7 +73,6 @@ struct AnnIvfFlatInputs { raft::distance::DistanceType metric; bool adaptive_centers; bool host_dataset; - int seed; }; template @@ -81,7 +80,7 @@ template { os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " << p.nprobe << ", " << p.nlist << ", " << static_cast(p.metric) << ", " - << p.adaptive_centers << ", " << p.host_dataset << "," << p.seed << '}' << std::endl; + << p.adaptive_centers << ", " << p.host_dataset << '}' << std::endl; return os; } @@ -179,7 +178,6 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 0.5; index_params.metric_arg = 0; - index_params.random_seed = ps.seed; ivf_flat::index idx(handle_, index_params, ps.dim); ivf_flat::index index_2(handle_, index_params, ps.dim); @@ -329,7 +327,6 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; - index_params.random_seed = ps.seed; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.num_db_vecs, ps.dim); @@ -500,7 +497,6 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { index_params.add_data_on_build = true; index_params.kmeans_trainset_fraction = 0.5; index_params.metric_arg = 0; - index_params.random_seed = ps.seed; // Create IVF Flat index auto database_view = raft::make_device_matrix_view( @@ -611,8 +607,6 @@ const std::vector> inputs = { {20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, {10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false}, - {10000, 1000000, 96, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false, true, -1}, - {10000, 1000000, 96, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false, false, -1}, // host input 
data {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded, false, true}, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd index 035d2814fc..a281d33310 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Copyright (c) 2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -56,7 +56,6 @@ cdef extern from "raft/neighbors/ivf_flat_types.hpp" \ double kmeans_trainset_fraction bool adaptive_centers bool conservative_memory_allocation - int random_seed cdef cppclass index[T, IdxT](ann_index): index(const device_resources& handle, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx index 64ac1a9ce9..d8fbdc74da 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Copyright (c) 2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -113,11 +113,6 @@ cdef class IndexParams: adding new data (through the classification of the added data); that is, `index.centers()` "drift" together with the changing distribution of the newly added data. - random_seed : int, default = 0 - Seed used for random sampling if kmeans_trainset_fraction < 1. - Value -1 disables random sampling, and results in sampling with a - fixed stride. 
- """ cdef c_ivf_flat.index_params params @@ -127,8 +122,7 @@ cdef class IndexParams: kmeans_n_iters=20, kmeans_trainset_fraction=0.5, add_data_on_build=True, - bool adaptive_centers=False, - random_seed=0): + bool adaptive_centers=False): self.params.n_lists = n_lists self.params.metric = _get_metric(metric) self.params.metric_arg = 0 @@ -136,7 +130,6 @@ cdef class IndexParams: self.params.kmeans_trainset_fraction = kmeans_trainset_fraction self.params.add_data_on_build = add_data_on_build self.params.adaptive_centers = adaptive_centers - self.params.random_seed = random_seed @property def n_lists(self): @@ -162,10 +155,6 @@ cdef class IndexParams: def adaptive_centers(self): return self.params.adaptive_centers - @property - def random_seed(self): - return self.params.random_seed - cdef class Index: cdef readonly bool trained diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index f40e5465b7..531c2428e9 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -78,7 +78,6 @@ cdef extern from "raft/neighbors/ivf_pq_types.hpp" \ codebook_gen codebook_kind bool force_random_rotation bool conservative_memory_allocation - int random_seed cdef cppclass index[IdxT](ann_index): index(const device_resources& handle, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index b8a3cf4887..0c1bbf6b9c 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2022-2024, NVIDIA CORPORATION. 
+# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -156,10 +156,6 @@ cdef class IndexParams: repeated calls to `extend` (extending the database). To disable this behavior and use as little GPU memory for the database as possible, set this flat to `True`. - random_seed : int, default = 0 - Seed used for random sampling if kmeans_trainset_fraction < 1. - Value -1 disables random sampling, and results in sampling with a - fixed stride. """ def __init__(self, *, n_lists=1024, @@ -171,8 +167,7 @@ cdef class IndexParams: codebook_kind="subspace", force_random_rotation=False, add_data_on_build=True, - conservative_memory_allocation=False, - random_seed=0): + conservative_memory_allocation=False): self.params.n_lists = n_lists self.params.metric = _get_metric(metric) self.params.metric_arg = 0 @@ -190,7 +185,6 @@ cdef class IndexParams: self.params.add_data_on_build = add_data_on_build self.params.conservative_memory_allocation = \ conservative_memory_allocation - self.params.random_seed = random_seed @property def n_lists(self): @@ -232,10 +226,6 @@ cdef class IndexParams: def conservative_memory_allocation(self): return self.params.conservative_memory_allocation - @property - def random_seed(self): - return self.params.random_seed - cdef class Index: # We store a pointer to the index because it dose not have a trivial From abf68a627f76d4b3f2601d50f465455ad9cc29f4 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Thu, 18 Jan 2024 01:22:39 +0100 Subject: [PATCH 05/11] Remove strided subsampling --- cpp/include/raft/spatial/knn/detail/ann_utils.cuh | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index 0f585c1572..bbe4c081e2 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ 
b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -618,17 +618,9 @@ void subsample(raft::resources const& res, { IdxT n_dim = output.extent(1); IdxT n_train = output.extent(0); - if (seed == -1 || n_train == n_samples) { - IdxT trainset_ratio = n_samples / n_train; - RAFT_LOG_INFO("Fixed stride subsampling"); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(output.data_handle(), - sizeof(T) * n_dim, - input, - sizeof(T) * n_dim * trainset_ratio, - sizeof(T) * n_dim, - n_train, - cudaMemcpyDefault, - resource::get_cuda_stream(res))); + if (n_train == n_samples) { + RAFT_LOG_INFO("No subsampling"); + raft::copy(output.data_handle(), input, n_dim * n_samples, resource::get_cuda_stream(res)); return; } RAFT_LOG_DEBUG("Random subsampling"); From 7cfd9b5688787527341b44267a0923ea6295b9ef Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Thu, 18 Jan 2024 13:35:16 +0100 Subject: [PATCH 06/11] Use managed memory for the new temporary buffer --- cpp/include/raft/neighbors/detail/ivf_pq_build.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 11d520d989..b942ae2f88 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -1776,7 +1776,7 @@ auto build(raft::resources const& handle, // TODO(tfeher): Enable codebook generation with any type T, and then remove // trainset tmp.
auto trainset_tmp = - make_device_mdarray(handle, device_mr, make_extents(n_rows_train, dim)); + make_device_mdarray(handle, &managed_mr, make_extents(n_rows_train, dim)); raft::spatial::knn::detail::utils::subsample( handle, dataset, n_rows, trainset_tmp.view(), random_seed); cudaDeviceSynchronize(); From 548555766b9acb485000c24e32816d6d874f58b5 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Sat, 20 Jan 2024 01:31:58 +0100 Subject: [PATCH 07/11] Batched index gather overlapped with H2D copies --- cpp/include/raft/matrix/detail/gather.cuh | 76 ++++++++++++++++++- .../raft/spatial/knn/detail/ann_utils.cuh | 43 +++-------- 2 files changed, 85 insertions(+), 34 deletions(-) diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 73072ec841..767b8721a9 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -17,9 +17,15 @@ #pragma once #include +#include +#include +#include +#include #include +#include +#include +#include #include - namespace raft { namespace matrix { namespace detail { @@ -335,6 +341,74 @@ void gather_if(const InputIteratorT in, gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); } +template +void gather_buff(host_matrix_view dataset, + host_vector_view indices, + IdxT offset, + pinned_matrix_view buff) +{ + raft::common::nvtx::range fun_scope("Gather vectors"); + + IdxT batch_size = std::min(buff.extent(0), indices.extent(0) - offset); + +#pragma omp for + for (IdxT i = 0; i < batch_size; i++) { + IdxT in_idx = indices(offset + i); + for (IdxT k = 0; k < buff.extent(1); k++) { + buff(i, k) = dataset(in_idx, k); + } + } +} + +template +void gather(raft::resources const& res, + host_matrix_view dataset, + device_vector_view indices, + raft::device_matrix_view output) +{ + IdxT n_dim = output.extent(1); + IdxT n_train = output.extent(0); + auto indices_host = raft::make_host_vector(n_train); + raft::copy( + 
indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); + resource::sync_stream(res); + + const size_t max_batch_size = 32768; + // Gather the vector on the host in tmp buffers. We use two buffers to overlap H2D sync + // and gathering the data. + raft::common::nvtx::push_range("subsample::alloc_buffers"); + // rmm::mr::pinned_memory_resource mr_pinned; + // auto out_tmp1 = make_host_mdarray(res, mr_pinned, make_extents(max_batch_size, + // n_dim)); auto out_tmp2 = make_host_mdarray(res, mr_pinned, + // make_extents(max_batch_size, n_dim)); + auto out_tmp1 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + auto out_tmp2 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + auto view1 = out_tmp1.view(); + auto view2 = out_tmp2.view(); + raft::common::nvtx::pop_range(); + + gather_buff(dataset, make_const_mdspan(indices_host.view()), (IdxT)0, view1); +#pragma omp parallel + for (IdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) { + IdxT batch_size = std::min(max_batch_size, n_train - device_offset); +#pragma omp master + raft::copy(output.data_handle() + device_offset * n_dim, + view1.data_handle(), + batch_size * n_dim, + resource::get_cuda_stream(res)); + // Start gathering the next batch on the host. 
+ IdxT host_offset = device_offset + batch_size; + batch_size = std::min(max_batch_size, n_train - host_offset); + if (batch_size > 0) { + gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2); + } +#pragma omp master + resource::sync_stream(res); +#pragma omp barrier + std::swap(view1, view2); + } +} + } // namespace detail } // namespace matrix } // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index bbe4c081e2..bd25506d44 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -16,13 +16,15 @@ #pragma once +#include #include #include #include #include + #include #include -#include +#include #include #include #include @@ -601,10 +603,6 @@ auto get_subsample_indices(raft::resources const& res, IdxT n_samples, IdxT n_su std::nullopt, train_indices.view(), std::nullopt); - - thrust::sort(resource::get_thrust_policy(res), - train_indices.data_handle(), - train_indices.data_handle() + n_subsamples); return train_indices; } @@ -618,12 +616,7 @@ void subsample(raft::resources const& res, { IdxT n_dim = output.extent(1); IdxT n_train = output.extent(0); - if (n_train == n_samples) { - RAFT_LOG_INFO("No subsampling"); - raft::copy(output.data_handle(), input, n_dim * n_samples, resource::get_cuda_stream(res)); - return; - } - RAFT_LOG_DEBUG("Random subsampling"); + raft::device_vector train_indices = get_subsample_indices(res, n_samples, n_train, seed); @@ -631,29 +624,13 @@ void subsample(raft::resources const& res, RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input)); T* ptr = reinterpret_cast(attr.devicePointer); if (ptr != nullptr) { - raft::matrix::copy_rows(res, - raft::make_device_matrix_view(ptr, n_samples, n_dim), - output, - raft::make_const_mdspan(train_indices.view())); + raft::matrix::gather(res, + raft::make_device_matrix_view(ptr, n_samples, n_dim), + 
raft::make_const_mdspan(train_indices.view()), + output); } else { - auto dataset = raft::make_host_matrix_view(input, n_samples, n_dim); - auto train_indices_host = raft::make_host_vector(n_train); - raft::copy(train_indices_host.data_handle(), - train_indices.data_handle(), - n_train, - resource::get_cuda_stream(res)); - resource::sync_stream(res); - auto out_tmp = raft::make_host_matrix(n_train, n_dim); -#pragma omp parallel for - for (IdxT i = 0; i < n_train; i++) { - IdxT in_idx = train_indices_host(i); - for (IdxT k = 0; k < n_dim; k++) { - out_tmp(i, k) = dataset(in_idx, k); - } - } - raft::copy( - output.data_handle(), out_tmp.data_handle(), output.size(), resource::get_cuda_stream(res)); - resource::sync_stream(res); + auto dataset = raft::make_host_matrix_view(input, n_samples, n_dim); + raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output); } } } // namespace raft::spatial::knn::detail::utils From 9177356b13067ef6ebec791e041cd2b4b9a601bc Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Sun, 21 Jan 2024 13:10:47 +0100 Subject: [PATCH 08/11] Fix copyright year --- cpp/include/raft/matrix/detail/gather.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 767b8721a9..f34a5f1e39 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
From 790c0e6e267bc0501639cf9e436fd0c94dcd5582 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Sun, 21 Jan 2024 13:35:53 +0100 Subject: [PATCH 09/11] Fix nvtx markers --- cpp/include/raft/matrix/detail/gather.cuh | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index f34a5f1e39..b41176b533 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -26,6 +27,7 @@ #include #include #include + namespace raft { namespace matrix { namespace detail { @@ -347,8 +349,7 @@ void gather_buff(host_matrix_view dataset, IdxT offset, pinned_matrix_view buff) { - raft::common::nvtx::range fun_scope("Gather vectors"); - + raft::common::nvtx::range fun_scope("gather_host_buff"); IdxT batch_size = std::min(buff.extent(0), indices.extent(0) - offset); #pragma omp for @@ -366,6 +367,7 @@ void gather(raft::resources const& res, device_vector_view indices, raft::device_matrix_view output) { + raft::common::nvtx::range fun_scope("gather"); IdxT n_dim = output.extent(1); IdxT n_train = output.extent(0); auto indices_host = raft::make_host_vector(n_train); @@ -376,11 +378,7 @@ void gather(raft::resources const& res, const size_t max_batch_size = 32768; // Gather the vector on the host in tmp buffers. We use two buffers to overlap H2D sync // and gathering the data. 
- raft::common::nvtx::push_range("subsample::alloc_buffers"); - // rmm::mr::pinned_memory_resource mr_pinned; - // auto out_tmp1 = make_host_mdarray(res, mr_pinned, make_extents(max_batch_size, - // n_dim)); auto out_tmp2 = make_host_mdarray(res, mr_pinned, - // make_extents(max_batch_size, n_dim)); + raft::common::nvtx::push_range("gather::alloc_buffers"); auto out_tmp1 = raft::make_pinned_matrix(res, max_batch_size, n_dim); auto out_tmp2 = raft::make_pinned_matrix(res, max_batch_size, n_dim); auto view1 = out_tmp1.view(); From 8223259fe2fd7545acd216dd04514ae363fba07c Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 23 Jan 2024 08:21:22 +0100 Subject: [PATCH 10/11] replace thrust call with map_offset --- cpp/include/raft/spatial/knn/detail/ann_utils.cuh | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index bd25506d44..e55dc82f5d 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -23,7 +23,9 @@ #include #include +#include #include +#include #include #include #include @@ -38,10 +40,6 @@ #include #include -#include -#include -#include -#include namespace raft::spatial::knn::detail::utils { @@ -592,9 +590,7 @@ auto get_subsample_indices(raft::resources const& res, IdxT n_samples, IdxT n_su RAFT_EXPECTS(n_subsamples <= n_samples, "Cannot have more training samples than dataset vectors"); auto data_indices = raft::make_device_vector(res, n_samples); - thrust::counting_iterator first(0); - thrust::device_ptr ptr(data_indices.data_handle()); - thrust::copy(raft::resource::get_thrust_policy(res), first, first + n_samples, ptr); + raft::linalg::map_offset(res, data_indices.view(), identity_op()); raft::random::RngState rng(seed); auto train_indices = raft::make_device_vector(res, n_subsamples); raft::random::sample_without_replacement(res, From 
f431ae792dcd8491ee09e0fb688d0cbfe28ac601 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 23 Jan 2024 08:21:52 +0100 Subject: [PATCH 11/11] Remove unused copy_warped --- .../raft/neighbors/detail/ivf_pq_build.cuh | 46 ------------------- 1 file changed, 46 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index b942ae2f88..cc94511fe7 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -64,51 +64,6 @@ namespace raft::neighbors::ivf_pq::detail { using namespace raft::spatial::knn::detail; // NOLINT -template -__launch_bounds__(BlockDim) RAFT_KERNEL copy_warped_kernel( - T* out, uint32_t ld_out, const S* in, uint32_t ld_in, uint32_t n_cols, size_t n_rows) -{ - using warp = Pow2; - size_t row_ix = warp::div(size_t(threadIdx.x) + size_t(BlockDim) * size_t(blockIdx.x)); - uint32_t i = warp::mod(threadIdx.x); - if (row_ix >= n_rows) return; - out += row_ix * ld_out; - in += row_ix * ld_in; - auto f = utils::mapping{}; - for (uint32_t col_ix = i; col_ix < n_cols; col_ix += warp::Value) { - auto x = f(in[col_ix]); - __syncwarp(); - out[col_ix] = x; - } -} - -/** - * Copy the data one warp-per-row: - * - * 1. load the data per-warp - * 2. apply the `utils::mapping{}` - * 3. sync within warp - * 4. store the data. - * - * Assuming sizeof(T) >= sizeof(S) and the data is properly aligned (see the usage in `build`), this - * allows to re-structure the data within rows in-place. - */ -template -void copy_warped(T* out, - uint32_t ld_out, - const S* in, - uint32_t ld_in, - uint32_t n_cols, - size_t n_rows, - rmm::cuda_stream_view stream) -{ - constexpr uint32_t kBlockDim = 128; - dim3 threads(kBlockDim, 1, 1); - dim3 blocks(div_rounding_up_safe(n_rows, kBlockDim / WarpSize), 1, 1); - copy_warped_kernel - <<>>(out, ld_out, in, ld_in, n_cols, n_rows); -} - /** * @brief Fill-in a random orthogonal transformation matrix. 
* @@ -1780,7 +1735,6 @@ auto build(raft::resources const& handle, raft::spatial::knn::detail::utils::subsample( handle, dataset, n_rows, trainset_tmp.view(), random_seed); cudaDeviceSynchronize(); - RAFT_LOG_INFO("Subsampling done, converting to float"); raft::linalg::unaryOp(trainset.data_handle(), trainset_tmp.data_handle(), trainset.size(),