Skip to content

Commit

Permalink
Merge branch 'branch-24.04' into fusedCosineNN
Browse files Browse the repository at this point in the history
  • Loading branch information
mdoijade committed Jan 25, 2024
2 parents d4ecc12 + ea0d2f6 commit 6d69589
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 46 deletions.
18 changes: 8 additions & 10 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -287,11 +287,11 @@ void bench_search(::benchmark::State& state,
std::make_shared<buf<std::size_t>>(current_algo_props->query_memory_type, k * query_set_size);

cuda_timer gpu_timer;
auto start = std::chrono::high_resolution_clock::now();
{
nvtx_case nvtx{state.name()};

auto algo = dynamic_cast<ANN<T>*>(current_algo.get())->copy();
auto algo = dynamic_cast<ANN<T>*>(current_algo.get())->copy();
auto start = std::chrono::high_resolution_clock::now();
for (auto _ : state) {
[[maybe_unused]] auto ntx_lap = nvtx.lap();
[[maybe_unused]] auto gpu_lap = gpu_timer.lap();
Expand All @@ -314,17 +314,15 @@ void bench_search(::benchmark::State& state,

queries_processed += n_queries;
}
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
if (state.thread_index() == 0) { state.counters.insert({{"end_to_end", duration}}); }
state.counters.insert({"Latency", {duration, benchmark::Counter::kAvgIterations}});
}
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
if (state.thread_index() == 0) { state.counters.insert({{"end_to_end", duration}}); }
state.counters.insert(
{"Latency", {duration / double(state.iterations()), benchmark::Counter::kAvgThreads}});

state.SetItemsProcessed(queries_processed);
if (cudart.found()) {
double gpu_time_per_iteration = gpu_timer.total_time() / (double)state.iterations();
state.counters.insert({"GPU", {gpu_time_per_iteration, benchmark::Counter::kAvgThreads}});
state.counters.insert({"GPU", {gpu_timer.total_time(), benchmark::Counter::kAvgIterations}});
}

// This will be the total number of queries across all threads
Expand Down
5 changes: 4 additions & 1 deletion cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -87,6 +87,9 @@ void parse_build_param(const nlohmann::json& conf,
"', should be either 'cluster' or 'subspace'");
}
}
if (conf.contains("max_train_points_per_pq_code")) {
param.max_train_points_per_pq_code = conf.at("max_train_points_per_pq_code");
}
}

template <typename T, typename IdxT>
Expand Down
32 changes: 31 additions & 1 deletion cpp/include/raft/core/math.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -49,12 +49,42 @@ RAFT_INLINE_FUNCTION auto abs(T x)
template <typename T>
constexpr RAFT_INLINE_FUNCTION auto abs(T x)
-> std::enable_if_t<!std::is_same_v<float, T> && !std::is_same_v<double, T> &&
#if defined(_RAFT_HAS_CUDA)
!std::is_same_v<__half, T> && !std::is_same_v<nv_bfloat16, T> &&
#endif
!std::is_same_v<int, T> && !std::is_same_v<long int, T> &&
!std::is_same_v<long long int, T>,
T>
{
return x < T{0} ? -x : x;
}
#if defined(_RAFT_HAS_CUDA)
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, __half>, __half> abs(T x)
{
#if (__CUDA_ARCH__ >= 530)
return ::__habs(x);
#else
// Fail during template instantiation if the compute capability doesn't support this operation
static_assert(sizeof(T) != sizeof(T), "__half is only supported on __CUDA_ARCH__ >= 530");
return T{};
#endif
}

template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, nv_bfloat16>, nv_bfloat16>
abs(T x)
{
#if (__CUDA_ARCH__ >= 800)
return ::__habs(x);
#else
// Fail during template instantiation if the compute capability doesn't support this operation
static_assert(sizeof(T) != sizeof(T), "nv_bfloat16 is only supported on __CUDA_ARCH__ >= 800");
return T{};
#endif
}
#endif
/** @} */

/** Inverse cosine */
template <typename T>
Expand Down
29 changes: 20 additions & 9 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -353,14 +353,19 @@ void train_per_subset(raft::resources const& handle,
const float* trainset, // [n_rows, dim]
const uint32_t* labels, // [n_rows]
uint32_t kmeans_n_iters,
uint32_t max_train_points_per_pq_code,
rmm::mr::device_memory_resource* managed_memory)
{
auto stream = resource::get_cuda_stream(handle);
auto device_memory = resource::get_workspace_resource(handle);

rmm::device_uvector<float> pq_centers_tmp(index.pq_centers().size(), stream, device_memory);
rmm::device_uvector<float> sub_trainset(n_rows * size_t(index.pq_len()), stream, device_memory);
rmm::device_uvector<uint32_t> sub_labels(n_rows, stream, device_memory);
// Subsampling the train set for codebook generation based on max_train_points_per_pq_code.
size_t big_enough = max_train_points_per_pq_code * size_t(index.pq_book_size());
auto pq_n_rows = uint32_t(std::min(big_enough, n_rows));
rmm::device_uvector<float> sub_trainset(
pq_n_rows * size_t(index.pq_len()), stream, device_memory);
rmm::device_uvector<uint32_t> sub_labels(pq_n_rows, stream, device_memory);

rmm::device_uvector<uint32_t> pq_cluster_sizes(index.pq_book_size(), stream, device_memory);

Expand All @@ -371,7 +376,7 @@ void train_per_subset(raft::resources const& handle,
// Get the rotated cluster centers for each training vector.
// This will be subtracted from the input vectors afterwards.
utils::copy_selected<float, float, size_t, uint32_t>(
n_rows,
pq_n_rows,
index.pq_len(),
index.centers_rot().data_handle() + index.pq_len() * j,
labels,
Expand All @@ -387,7 +392,7 @@ void train_per_subset(raft::resources const& handle,
true,
false,
index.pq_len(),
n_rows,
pq_n_rows,
index.dim(),
&alpha,
index.rotation_matrix().data_handle() + index.dim() * index.pq_len() * j,
Expand All @@ -400,13 +405,14 @@ void train_per_subset(raft::resources const& handle,
stream);

// train PQ codebook for this subspace
auto sub_trainset_view =
raft::make_device_matrix_view<const float, IdxT>(sub_trainset.data(), n_rows, index.pq_len());
auto sub_trainset_view = raft::make_device_matrix_view<const float, IdxT>(
sub_trainset.data(), pq_n_rows, index.pq_len());
auto centers_tmp_view = raft::make_device_matrix_view<float, IdxT>(
pq_centers_tmp.data() + index.pq_book_size() * index.pq_len() * j,
index.pq_book_size(),
index.pq_len());
auto sub_labels_view = raft::make_device_vector_view<uint32_t, IdxT>(sub_labels.data(), n_rows);
auto sub_labels_view =
raft::make_device_vector_view<uint32_t, IdxT>(sub_labels.data(), pq_n_rows);
auto cluster_sizes_view =
raft::make_device_vector_view<uint32_t, IdxT>(pq_cluster_sizes.data(), index.pq_book_size());
raft::cluster::kmeans_balanced_params kmeans_params;
Expand All @@ -430,6 +436,7 @@ void train_per_cluster(raft::resources const& handle,
const float* trainset, // [n_rows, dim]
const uint32_t* labels, // [n_rows]
uint32_t kmeans_n_iters,
uint32_t max_train_points_per_pq_code,
rmm::mr::device_memory_resource* managed_memory)
{
auto stream = resource::get_cuda_stream(handle);
Expand Down Expand Up @@ -477,9 +484,11 @@ void train_per_cluster(raft::resources const& handle,
indices + cluster_offsets[l],
device_memory);

// limit the cluster size to bound the training time.
// limit the cluster size to bound the training time based on max_train_points_per_pq_code
// If pq_book_size is less than pq_dim, use max_train_points_per_pq_code per pq_dim instead
// [sic] we interpret the data as pq_len-dimensional
size_t big_enough = 256ul * std::max<size_t>(index.pq_book_size(), index.pq_dim());
size_t big_enough =
max_train_points_per_pq_code * std::max<size_t>(index.pq_book_size(), index.pq_dim());
size_t available_rows = size_t(cluster_size) * size_t(index.pq_dim());
auto pq_n_rows = uint32_t(std::min(big_enough, available_rows));
// train PQ codebook for this cluster
Expand Down Expand Up @@ -1788,6 +1797,7 @@ auto build(raft::resources const& handle,
trainset.data_handle(),
labels.data(),
params.kmeans_n_iters,
params.max_train_points_per_pq_code,
&managed_mr);
break;
case codebook_gen::PER_CLUSTER:
Expand All @@ -1797,6 +1807,7 @@ auto build(raft::resources const& handle,
trainset.data_handle(),
labels.data(),
params.kmeans_n_iters,
params.max_train_points_per_pq_code,
&managed_mr);
break;
default: RAFT_FAIL("Unreachable code");
Expand Down
10 changes: 9 additions & 1 deletion cpp/include/raft/neighbors/ivf_pq_types.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -105,6 +105,14 @@ struct index_params : ann::index_params {
* flag to `true` if you prefer to use as little GPU memory for the database as possible.
*/
bool conservative_memory_allocation = false;
/**
* The max number of data points to use per PQ code during PQ codebook training. Using more data
* points per PQ code may increase the quality of PQ codebook but may also increase the build
* time. The parameter is applied to both PQ codebook generation methods, i.e., PER_SUBSPACE and
* PER_CLUSTER. In both cases, we will use `pq_book_size * max_train_points_per_pq_code` training
* points to train each codebook.
*/
uint32_t max_train_points_per_pq_code = 256;
};

struct search_params : ann::search_params {
Expand Down
8 changes: 2 additions & 6 deletions cpp/include/raft/sparse/linalg/spmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace linalg {
* @param[in] x input raft::device_csr_matrix_view
* @param[in] y input raft::device_matrix_view
* @param[in] beta scalar
* @param[out] z output raft::device_matrix_view
* @param[inout] z input-output raft::device_matrix_view
*/
template <typename ValueType,
typename IndexType,
Expand All @@ -60,19 +60,15 @@ void spmm(raft::resources const& handle,
{
bool is_row_major = detail::is_row_major(y, z);

auto z_tmp = raft::make_device_matrix<ValueType, IndexType>(handle, z.extent(0), z.extent(1));
auto z_tmp_view = raft::make_device_strided_matrix_view<ValueType, IndexType, LayoutPolicyZ>(
z_tmp.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1));
z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1));

auto descr_x = detail::create_descriptor(x);
auto descr_y = detail::create_descriptor(y);
auto descr_z = detail::create_descriptor(z_tmp_view);

detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z);

raft::copy(
z.data_handle(), z_tmp.data_handle(), z_tmp.size(), raft::resource::get_cuda_stream(handle));

RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x));
RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y));
RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z));
Expand Down
55 changes: 44 additions & 11 deletions cpp/include/raft/util/cuda_utils.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
* Copyright (c) 2018-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,15 +16,12 @@

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <math_constants.h>
#include <stdint.h>
#include <type_traits>

#if defined(_RAFT_HAS_CUDA)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif

#include <raft/core/cudart_utils.hpp>
#include <raft/core/math.hpp>
#include <raft/core/operators.hpp>
Expand Down Expand Up @@ -278,17 +275,53 @@ template <>
* @{
*/
template <typename T>
inline __device__ T myInf();
template <>
inline __device__ float myInf<float>()
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, float>, float> myInf()
{
return CUDART_INF_F;
}
template <>
inline __device__ double myInf<double>()
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, double>, double> myInf()
{
return CUDART_INF;
}
// Half/Bfloat constants only defined after CUDA 12.2
#if __CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2)
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, __half>, __half> myInf()
{
#if (__CUDA_ARCH__ >= 530)
return __ushort_as_half((unsigned short)0x7C00U);
#else
// Fail during template instantiation if the compute capability doesn't support this operation
static_assert(sizeof(T) != sizeof(T), "__half is only supported on __CUDA_ARCH__ >= 530");
return T{};
#endif
}
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, nv_bfloat16>, nv_bfloat16>
myInf()
{
#if (__CUDA_ARCH__ >= 800)
return __ushort_as_bfloat16((unsigned short)0x7F80U);
#else
// Fail during template instantiation if the compute capability doesn't support this operation
static_assert(sizeof(T) != sizeof(T), "nv_bfloat16 is only supported on __CUDA_ARCH__ >= 800");
return T{};
#endif
}
#else
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, __half>, __half> myInf()
{
return CUDART_INF_FP16;
}
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, nv_bfloat16>, nv_bfloat16>
myInf()
{
return CUDART_INF_BF16;
}
#endif
/** @} */

/**
Expand Down
1 change: 1 addition & 0 deletions docs/source/ann_benchmarks_param_tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ IVF-pq is an inverted-file index, which partitions the vectors into a series of
| `pq_bits` | `build` | N | Positive Integer. [4-8] | 8 | Bit length of the vector element after quantization. |
| `codebook_kind` | `build` | N | ["cluster", "subspace"] | "subspace" | Type of codebook. See the [API docs](https://docs.rapids.ai/api/raft/nightly/cpp_api/neighbors_ivf_pq/#_CPPv412codebook_gen) for more detail |
| `dataset_memory_type` | `build` | N | ["device", "host", "mmap"] | "host" | What memory type should the dataset reside? |
| `max_train_points_per_pq_code` | `build` | N | Positive Number >=1 | 256 | Max number of data points per PQ code used for PQ code book creation. Depending on input dataset size, the data points could be less than what user specifies. |
| `query_memory_type` | `search` | N | ["device", "host", "mmap"] | "device | What memory type should the queries reside? |
| `nprobe` | `search` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. |
| `internalDistanceDtype` | `search` | N | [`float`, `half`] | `half` | The precision to use for the distance computations. Lower precision can increase performance at the cost of accuracy. |
Expand Down
9 changes: 5 additions & 4 deletions docs/source/wiki_all_dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@

The `wiki-all` dataset was created to stress vector search algorithms at scale with both a large number of vectors and dimensions. The entire dataset contains 88M vectors with 768 dimensions and is meant for testing the types of vectors one would typically encounter in retrieval augmented generation (RAG) workloads. The full dataset is ~251GB in size, which is intentionally larger than the typical memory of GPUs. The massive scale is intended to promote the use of compression and efficient out-of-core methods for both indexing and search.

The dataset is composed of all the available languages of in the [Cohere Wikipedia dataset](https://huggingface.co/datasets/Cohere/wikipedia-22-12). An [English version]( https://www.kaggle.com/datasets/jjinho/wikipedia-20230701) is also available.


The dataset is composed of English wiki texts from [Kaggle](https://www.kaggle.com/datasets/jjinho/wikipedia-20230701) and multi-lingual wiki texts from [Cohere Wikipedia](https://huggingface.co/datasets/Cohere/wikipedia-22-12).

Cohere's English Texts are older (2022) and smaller than the Kaggle English Wiki texts (2023) so the English texts have been removed from Cohere completely. The final Wiki texts include English Wiki from Kaggle and the other languages from Cohere. The English texts constitute 50% of the total text size.

To form the final dataset, the Wiki texts were chunked into 85 million 128-token pieces. For reference, Cohere chunks Wiki texts into 104-token pieces. Finally, the embeddings of each chunk were computed using the [paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2) embedding model. The resulting dataset is an embedding matrix of size 88 million by 768. Also included with the dataset is a query file containing 10k query vectors and a groundtruth file to evaluate nearest neighbors algorithms.
To form the final dataset, the Wiki texts were chunked into 85 million 128-token pieces. For reference, Cohere chunks Wiki texts into 104-token pieces. Finally, the embeddings of each chunk were computed using the [paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2) embedding model. The resulting dataset is an embedding matrix of size 88 million by 768. Also included with the dataset is a query file containing 10k query vectors and a groundtruth file to evaluate nearest neighbors algorithms.

## Getting the dataset

Expand Down Expand Up @@ -44,3 +41,7 @@ curl -s https://data.rapids.ai/raft/datasets/wiki_all_10M/wiki_all_10M.tar
## Using the dataset

After the dataset is downloaded and extracted to the `wiki_all_88M` directory (or `wiki_all_1M`/`wiki_all_10M` depending on whether the subsets are used), the files can be used in the benchmarking tool. The dataset name is `wiki_all` (or `wiki_all_1M`/`wiki_all_10M`), and the benchmarking tool can be used by specifying the appropriate name `--dataset wiki_all_88M` in the scripts.

## License info

The English wiki texts available on Kaggle come with the [CC BY-NCSA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license and the Cohere wikipedia data set comes with the [Apache 2.0](https://choosealicense.com/licenses/apache-2.0/) license.
3 changes: 2 additions & 1 deletion python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -78,6 +78,7 @@ cdef extern from "raft/neighbors/ivf_pq_types.hpp" \
codebook_gen codebook_kind
bool force_random_rotation
bool conservative_memory_allocation
uint32_t max_train_points_per_pq_code

cdef cppclass index[IdxT](ann_index):
index(const device_resources& handle,
Expand Down
Loading

0 comments on commit 6d69589

Please sign in to comment.