Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNMG ANN #1993

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3b74685
SNMG ANN
viclafargue Nov 14, 2023
c655d91
Complete main parts and add tests
viclafargue Nov 15, 2023
9aeb456
Debugging
viclafargue Nov 20, 2023
224a59f
Implement search on shards
viclafargue Nov 20, 2023
ebca042
Debugging
viclafargue Nov 23, 2023
8141f72
ANN benchmark integration + offset fix + translations fix
viclafargue Jan 10, 2024
6d29eb1
Adding CAGRA capability
viclafargue Jan 26, 2024
34bb6c4
Testing serialization + use of pre-computed methods
viclafargue Feb 1, 2024
2c3fbc2
Add distribution feature
viclafargue Feb 2, 2024
b5680b2
Merge remote-tracking branch 'origin/branch-24.04' into snmg-ann
viclafargue Mar 26, 2024
4581fbd
SNMG ANN bench update
viclafargue Apr 23, 2024
e3b03b8
OpenMP
viclafargue May 3, 2024
a6707c3
Answering reviews
viclafargue May 6, 2024
4a91a7f
NCCL clique helper
viclafargue May 8, 2024
0a37d63
SNMG ANN IVF-Flat & IVF-PQ bench + fixes
viclafargue May 24, 2024
8417684
Fixes & improvements
viclafargue May 28, 2024
34d4fd3
Setting NCCL init apart for bench
viclafargue May 28, 2024
3f15c43
Mempool + NCCL fix
viclafargue Jun 12, 2024
b77f938
SNMG cagra bench
viclafargue Jun 12, 2024
fc748ae
SNMG CAGRA bench
viclafargue Jun 13, 2024
410562b
Merge branch 'branch-24.08' into snmg-ann
viclafargue Jun 17, 2024
666d47f
Increase search batch size + fix build
viclafargue Jun 17, 2024
1a559a6
mdspan feature for build and extend
viclafargue Jul 8, 2024
11d30da
style fix
viclafargue Jul 8, 2024
9af470e
Merge branch 'branch-24.08' into snmg-ann
viclafargue Jul 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,12 @@ _text
# clang tooling
compile_commands.json
.clangd/

ann_mg_ivf_flat_index
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
ann_mg_ivf_pq_index
datasets/
index/
ivf_flat_index
local_cagra_index
local_ivf_flat_index
local_ivf_pq_index
1 change: 1 addition & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
$CMAKE_TARGET == *"NEIGHBORS_ANN_BRUTE_FORCE_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_MG_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_TEST"* || \
$CMAKE_TARGET == *"SPARSE_DIST_TEST" || \
Expand Down
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,8 @@ if(RAFT_COMPILE_LIBRARY)
${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this
# will just be cublas
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
nccl
ucp
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
)

# So consumers know when using libraft.so/libraft.a
Expand Down
14 changes: 14 additions & 0 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchm
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE "Include raft's brute force knn in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG "Include raft's MG ANN in benchmark" ON)
option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_SINGLE_EXE
Expand All @@ -55,6 +56,7 @@ if(BUILD_CPU_ONLY)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF)
set(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF)
set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG OFF)
set(RAFT_ANN_BENCH_USE_GGNN OFF)
else()
# Disable faiss benchmarks on CUDA 12 since faiss is not yet CUDA 12-enabled.
Expand Down Expand Up @@ -90,6 +92,7 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG
)
set(RAFT_ANN_BENCH_USE_RAFT ON)
endif()
Expand Down Expand Up @@ -272,6 +275,17 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA)
)
endif()

# Register the multi-GPU ANN benchmark (target name RAFT_ANN_MG) when the
# RAFT_ANN_BENCH_USE_RAFT_ANN_MG option is enabled. It is built from
# raft_benchmark.cu and links against the precompiled raft library.
if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG)
ConfigureAnnBench(
NAME
RAFT_ANN_MG
PATH
bench/ann/src/raft/raft_benchmark.cu
LINKS
raft::compiled
)
endif()

if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
ConfigureAnnBench(
NAME RAFT_CAGRA_HNSWLIB PATH bench/ann/src/raft/raft_cagra_hnswlib.cu LINKS raft::compiled
Expand Down
1 change: 1 addition & 0 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <numeric>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

namespace raft::bench::ann {
Expand Down
121 changes: 121 additions & 0 deletions cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "../common/ann_types.hpp"
#include "raft_ann_bench_utils.h"
#include <raft/neighbors/ann_mg.cuh>

namespace raft::bench::ann {

/**
 * ANN benchmark adapter for the multi-GPU (`raft::neighbors::mg`) index.
 *
 * This adapter is wired to the IVF-Flat flavour of the multi-GPU index: build
 * parameters are `ivf_flat::index_params`, search parameters are
 * `ivf_flat::search_params`, and the stored index wraps `ivf_flat::index<T, IdxT>`.
 *
 * @tparam T    element type of dataset and query vectors
 * @tparam IdxT index type used for neighbor ids and matrix extents
 */
template <typename T, typename IdxT>
class RaftAnnMG : public ANN<T> {
public:
using typename ANN<T>::AnnSearchParam;

// Search-time parameters: the generic benchmark parameters plus the IVF-Flat ones.
struct SearchParam : public AnnSearchParam {
raft::neighbors::ivf_flat::search_params ivf_flat_params;
};

using BuildParam = raft::neighbors::ivf_flat::index_params;

// Copies the build parameters, then overrides `metric` with the value parsed from
// the benchmark metric and forces `conservative_memory_allocation` on. Also records
// the CUDA device that is current at construction time.
RaftAnnMG(Metric metric, int dim, const BuildParam& param)
: ANN<T>(metric, dim), index_params_(param), dimension_(dim)
{
index_params_.metric = parse_metric_type(metric);
index_params_.conservative_memory_allocation = true;
RAFT_CUDA_TRY(cudaGetDevice(&device_));
}

~RaftAnnMG() noexcept {}

// Builds the multi-GPU index from `nrow` vectors of `dataset`.
void build(const T* dataset, size_t nrow, cudaStream_t stream) final;

// Stores the IVF-Flat search parameters carried by `param` (must be a SearchParam).
void set_search_param(const AnnSearchParam& param) override;

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(const T* queries,
int batch_size,
int k,
size_t* neighbors,
float* distances,
cudaStream_t stream = 0) const override;

// Tells the benchmark harness that both the dataset and the queries should be
// provided in host memory (the build/search implementations wrap them in host
// matrix views).
AlgoProperty get_preference() const override
{
AlgoProperty property;
property.dataset_memory_type = MemoryType::Host;
property.query_memory_type = MemoryType::Host;
return property;
}
// Serializes / deserializes the index to / from the given file name.
void save(const std::string& file) const override;
void load(const std::string&) override;

private:
raft::device_resources handle_;  // resources handle used for (de)serialization and the final stream sync
BuildParam index_params_;  // build parameters (metric and memory policy overridden in the constructor)
raft::neighbors::ivf_flat::search_params search_params_;  // set by set_search_param()
// Empty until build() or load() has run; save() and search() call value() on it.
std::optional<raft::neighbors::mg::detail::ann_mg_index<raft::neighbors::ivf_flat::index<T, IdxT>, T, IdxT>> index_;
int device_;  // CUDA device current at construction time
int dimension_;  // number of components per vector
};

/**
 * Build the multi-GPU index from a dataset wrapped as a host matrix view.
 *
 * The dataset is viewed as an (nrow x dimension_) row-major host matrix and handed to
 * `neighbors::mg::build` in INDEX_DUPLICATION mode. The stream argument is unused.
 *
 * NOTE(review): the device list is hard-coded to GPUs 0 and 1, so this presumably
 * requires at least two visible devices - confirm whether it should be configurable.
 */
template <typename T, typename IdxT>
void RaftAnnMG<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
{
  const std::vector<int> gpu_ids{0, 1};
  const auto host_dataset = raft::make_host_matrix_view<const T, IdxT, row_major>(
    dataset, IdxT(nrow), IdxT(dimension_));
  index_ = neighbors::mg::build<T, IdxT>(
    gpu_ids, raft::neighbors::mg::dist_mode::INDEX_DUPLICATION, index_params_, host_dataset);
}

/**
 * Store the IVF-Flat search parameters for subsequent search() calls.
 *
 * @param param must actually be a `SearchParam`; the reference `dynamic_cast`
 *              throws `std::bad_cast` otherwise.
 */
template <typename T, typename IdxT>
void RaftAnnMG<T, IdxT>::set_search_param(const AnnSearchParam& param)
{
  // Bind by reference: a plain `auto` would copy the whole SearchParam object.
  const auto& search_param = dynamic_cast<const SearchParam&>(param);
  search_params_           = search_param.ivf_flat_params;
  // Debug-only sanity check: cannot probe more lists than the index was built with.
  assert(search_params_.n_probes <= index_params_.n_lists);
}

/**
 * Write the index to `file` through `raft::neighbors::mg::serialize`.
 *
 * Calls `value()` on the optional index, so this throws `std::bad_optional_access`
 * when no index has been built or loaded yet.
 */
template <typename T, typename IdxT>
void RaftAnnMG<T, IdxT>::save(const std::string& file) const
{
  const auto& built_index = index_.value();
  raft::neighbors::mg::serialize<T, IdxT>(handle_, built_index, file);
}

/**
 * Replace the held index with one read from `file`.
 *
 * Uses the IVF-Flat deserializer (`deserialize_flat`), matching the index type this
 * wrapper stores. Any previously held index is discarded by `emplace`.
 */
template <typename T, typename IdxT>
void RaftAnnMG<T, IdxT>::load(const std::string& file)
{
index_.emplace(raft::neighbors::mg::deserialize_flat<T, IdxT>(handle_, file));
}

/**
 * Run a k-NN search for a batch of queries.
 *
 * All pointers are wrapped in host matrix views: queries as (batch_size x dimension_),
 * neighbors and distances as (batch_size x k). The stream argument is unused.
 *
 * @param queries    query vectors
 * @param batch_size number of queries
 * @param k          number of neighbors requested per query
 * @param neighbors  output neighbor ids; the buffer is reinterpreted as `IdxT`
 * @param distances  output distances
 *
 * Throws `std::bad_optional_access` when no index has been built or loaded.
 */
template <typename T, typename IdxT>
void RaftAnnMG<T, IdxT>::search(
  const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const
{
  // The neighbors buffer is handed over as IdxT*, so both types must have the same width.
  static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t");

  auto query_matrix = raft::make_host_matrix_view<const T, IdxT, row_major>(
    queries, IdxT(batch_size), IdxT(dimension_));
  // Named cast instead of a C-style cast: makes the pointer reinterpretation explicit.
  auto neighbors_matrix = raft::make_host_matrix_view<IdxT, IdxT, row_major>(
    reinterpret_cast<IdxT*>(neighbors), IdxT(batch_size), IdxT(k));
  auto distances_matrix = raft::make_host_matrix_view<float, IdxT, row_major>(
    distances, IdxT(batch_size), IdxT(k));

  raft::neighbors::mg::search<T, IdxT>(
    index_.value(), search_params_, query_matrix, neighbors_matrix, distances_matrix);
  // Synchronize the stream of this wrapper's own handle before returning.
  resource::sync_stream(handle_);
}
} // namespace raft::bench::ann
168 changes: 168 additions & 0 deletions cpp/include/raft/neighbors/ann_mg.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/neighbors/detail/ann_mg.cuh>

namespace raft::neighbors::mg {

template <typename T, typename IdxT>
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
auto build(const std::vector<int> device_ids,
raft::neighbors::mg::dist_mode mode,
const ivf_flat::index_params& index_params,
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
raft::host_matrix_view<const T, IdxT, row_major> index_dataset)
-> detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>
{
return mg::detail::build<T, IdxT>(device_ids, mode, index_params, index_dataset);
}

/**
 * @brief Build a multi-GPU index backed by IVF-PQ from a host dataset.
 *
 * Thin public entry point that forwards to `mg::detail::build`.
 *
 * @tparam T    element type of the dataset
 * @tparam IdxT index type
 *
 * @param device_ids    ids of the GPUs to use (taken by const reference to avoid
 *                      copying the vector, consistent with `distribute_*`)
 * @param mode          distribution mode of the index across the devices
 * @param index_params  IVF-PQ build parameters
 * @param index_dataset row-major host matrix holding the vectors to index
 * @return the multi-GPU index
 */
template <typename T, typename IdxT>
auto build(const std::vector<int>& device_ids,
           raft::neighbors::mg::dist_mode mode,
           const ivf_pq::index_params& index_params,
           raft::host_matrix_view<const T, IdxT, row_major> index_dataset)
  -> detail::ann_mg_index<ivf_pq::index<IdxT>, T, IdxT>
{
  return mg::detail::build<T>(device_ids, mode, index_params, index_dataset);
}

/**
 * @brief Build a multi-GPU index backed by CAGRA from a host dataset.
 *
 * Thin public entry point that forwards to `mg::detail::build`.
 *
 * @tparam T    element type of the dataset
 * @tparam IdxT index type
 *
 * @param device_ids    ids of the GPUs to use (taken by const reference to avoid
 *                      copying the vector, consistent with `distribute_*`)
 * @param mode          distribution mode of the index across the devices
 * @param index_params  CAGRA build parameters
 * @param index_dataset row-major host matrix holding the vectors to index
 * @return the multi-GPU index
 */
template <typename T, typename IdxT>
auto build(const std::vector<int>& device_ids,
           raft::neighbors::mg::dist_mode mode,
           const cagra::index_params& index_params,
           raft::host_matrix_view<const T, IdxT, row_major> index_dataset)
  -> detail::ann_mg_index<cagra::index<T, IdxT>, T, IdxT>
{
  return mg::detail::build<T, IdxT>(device_ids, mode, index_params, index_dataset);
}

/**
 * @brief Extend a multi-GPU IVF-Flat index with new vectors.
 *
 * Forwards to `mg::detail::extend`.
 *
 * @param index       index to extend (taken by non-const reference)
 * @param new_vectors row-major host matrix of vectors to add
 * @param new_indices optional host vector of ids for the new vectors
 *                    NOTE(review): handling of an empty optional is decided in
 *                    `detail::extend` - confirm there.
 */
template <typename T, typename IdxT>
void extend(detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>& index,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices)
{
mg::detail::extend<T, IdxT>(index, new_vectors, new_indices);
}

/**
 * @brief Extend a multi-GPU IVF-PQ index with new vectors.
 *
 * Forwards to `mg::detail::extend`.
 *
 * @param index       index to extend (taken by non-const reference)
 * @param new_vectors row-major host matrix of vectors to add
 * @param new_indices optional host vector of ids for the new vectors
 *                    NOTE(review): handling of an empty optional is decided in
 *                    `detail::extend` - confirm there.
 */
template <typename T, typename IdxT>
void extend(detail::ann_mg_index<ivf_pq::index<IdxT>, T, IdxT>& index,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices)
{
mg::detail::extend<T>(index, new_vectors, new_indices);
}

/**
 * @brief Search a multi-GPU IVF-Flat index.
 *
 * Forwards to `mg::detail::search`.
 *
 * @param index         index to query
 * @param search_params IVF-Flat search parameters
 * @param query_dataset row-major host matrix of query vectors
 * @param neighbors     row-major host matrix receiving neighbor ids
 * @param distances     row-major host matrix receiving distances
 *                      (assumes both outputs are (n_queries x k) - TODO confirm)
 */
template <typename T, typename IdxT>
void search(const detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>& index,
const ivf_flat::search_params& search_params,
raft::host_matrix_view<const T, IdxT, row_major> query_dataset,
raft::host_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::host_matrix_view<float, IdxT, row_major> distances)
{
mg::detail::search<T, IdxT>(index, search_params, query_dataset, neighbors, distances);
}

/**
 * @brief Search a multi-GPU IVF-PQ index.
 *
 * Forwards to `mg::detail::search`.
 *
 * @param index         index to query
 * @param search_params IVF-PQ search parameters
 * @param query_dataset row-major host matrix of query vectors
 * @param neighbors     row-major host matrix receiving neighbor ids
 * @param distances     row-major host matrix receiving distances
 *                      (assumes both outputs are (n_queries x k) - TODO confirm)
 */
template <typename T, typename IdxT>
void search(const detail::ann_mg_index<ivf_pq::index<IdxT>, T, IdxT>& index,
const ivf_pq::search_params& search_params,
raft::host_matrix_view<const T, IdxT, row_major> query_dataset,
raft::host_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::host_matrix_view<float, IdxT, row_major> distances)
{
mg::detail::search<T>(index, search_params, query_dataset, neighbors, distances);
}

/**
 * @brief Search a multi-GPU CAGRA index.
 *
 * Forwards to `mg::detail::search`.
 *
 * @param index         index to query
 * @param search_params CAGRA search parameters
 * @param query_dataset row-major host matrix of query vectors
 * @param neighbors     row-major host matrix receiving neighbor ids
 * @param distances     row-major host matrix receiving distances
 *                      (assumes both outputs are (n_queries x k) - TODO confirm)
 */
template <typename T, typename IdxT>
void search(const detail::ann_mg_index<cagra::index<T, IdxT>, T, IdxT>& index,
const cagra::search_params& search_params,
raft::host_matrix_view<const T, IdxT, row_major> query_dataset,
raft::host_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::host_matrix_view<float, IdxT, row_major> distances)
{
mg::detail::search<T, IdxT>(index, search_params, query_dataset, neighbors, distances);
}

/**
 * @brief Serialize a multi-GPU IVF-Flat index to a file.
 *
 * Forwards to `mg::detail::serialize`.
 *
 * @param handle   raft resources handle
 * @param index    index to serialize
 * @param filename destination file name
 */
template <typename T, typename IdxT>
void serialize(const raft::resources& handle,
const detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>& index,
const std::string& filename)
{
mg::detail::serialize(handle, index, filename);
}

/**
 * @brief Serialize a multi-GPU IVF-PQ index to a file.
 *
 * Forwards to `mg::detail::serialize`.
 *
 * @param handle   raft resources handle
 * @param index    index to serialize
 * @param filename destination file name
 */
template <typename T, typename IdxT>
void serialize(const raft::resources& handle,
               const detail::ann_mg_index<ivf_pq::index<IdxT>, T, IdxT>& index,
               const std::string& filename)
{
  mg::detail::serialize(handle, index, filename);
}

/**
 * @brief Serialize a multi-GPU CAGRA index to a file.
 *
 * Forwards to `mg::detail::serialize`.
 *
 * @param handle   raft resources handle
 * @param index    index to serialize
 * @param filename destination file name
 */
template <typename T, typename IdxT>
void serialize(const raft::resources& handle,
const detail::ann_mg_index<cagra::index<T, IdxT>, T, IdxT>& index,
const std::string& filename)
{
mg::detail::serialize(handle, index, filename);
}

/**
 * @brief Load a multi-GPU IVF-Flat index from a file.
 *
 * Forwards to `mg::detail::deserialize_flat`.
 *
 * @param handle   raft resources handle
 * @param filename source file name
 * @return the deserialized multi-GPU index
 */
template <typename T, typename IdxT>
detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT> deserialize_flat(const raft::resources& handle,
const std::string& filename)
{
return mg::detail::deserialize_flat<T, IdxT>(handle, filename);
}

/**
 * @brief Load a multi-GPU IVF-PQ index from a file.
 *
 * Forwards to `mg::detail::deserialize_pq`.
 *
 * @param handle   raft resources handle
 * @param filename source file name
 * @return the deserialized multi-GPU index
 */
template <typename T, typename IdxT>
detail::ann_mg_index<ivf_pq::index<IdxT>, T, IdxT> deserialize_pq(const raft::resources& handle,
const std::string& filename)
{
return mg::detail::deserialize_pq<T, IdxT>(handle, filename);
}

/**
 * @brief Load a multi-GPU CAGRA index from a file.
 *
 * Forwards to `mg::detail::deserialize_cagra`.
 *
 * @param handle   raft resources handle
 * @param filename source file name
 * @return the deserialized multi-GPU index
 */
template <typename T, typename IdxT>
detail::ann_mg_index<cagra::index<T, IdxT>, T, IdxT> deserialize_cagra(const raft::resources& handle,
const std::string& filename)
{
return mg::detail::deserialize_cagra<T, IdxT>(handle, filename);
}

/**
 * @brief Create a multi-GPU IVF-Flat index from a file and a list of devices.
 *
 * Forwards to `mg::detail::distribute_flat`.
 * NOTE(review): presumably the file holds a single-GPU index that gets placed on
 * every device of `dev_list` - confirm against the detail implementation.
 *
 * @param handle   raft resources handle
 * @param dev_list ids of the target GPUs
 * @param filename source file name
 * @return the resulting multi-GPU index
 */
template <typename T, typename IdxT>
detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT> distribute_flat(const raft::resources& handle,
const std::vector<int>& dev_list,
const std::string& filename)
{
return mg::detail::distribute_flat<T, IdxT>(handle, dev_list, filename);
}

/**
 * @brief Create a multi-GPU IVF-PQ index from a file and a list of devices.
 *
 * Forwards to `mg::detail::distribute_pq`.
 * NOTE(review): presumably the file holds a single-GPU index that gets placed on
 * every device of `dev_list` - confirm against the detail implementation.
 *
 * @param handle   raft resources handle
 * @param dev_list ids of the target GPUs
 * @param filename source file name
 * @return the resulting multi-GPU index
 */
template <typename T, typename IdxT>
detail::ann_mg_index<ivf_pq::index<IdxT>, T, IdxT> distribute_pq(const raft::resources& handle,
const std::vector<int>& dev_list,
const std::string& filename)
{
return mg::detail::distribute_pq<T, IdxT>(handle, dev_list, filename);
}

/**
 * @brief Create a multi-GPU CAGRA index from a file and a list of devices.
 *
 * Forwards to `mg::detail::distribute_cagra`.
 * NOTE(review): presumably the file holds a single-GPU index that gets placed on
 * every device of `dev_list` - confirm against the detail implementation.
 *
 * @param handle   raft resources handle
 * @param dev_list ids of the target GPUs
 * @param filename source file name
 * @return the resulting multi-GPU index
 */
template <typename T, typename IdxT>
detail::ann_mg_index<cagra::index<T, IdxT>, T, IdxT> distribute_cagra(const raft::resources& handle,
const std::vector<int>& dev_list,
const std::string& filename)
{
return mg::detail::distribute_cagra<T, IdxT>(handle, dev_list, filename);
}

} // namespace raft::neighbors::mg
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ inline void knn_merge_parts(
RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0),
"in_keys and in_values must have the same shape.");
RAFT_EXPECTS(
out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == n_samples,
out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == idx_t(n_samples),
"Number of rows in output keys and val matrices must equal number of rows in search matrix.");
RAFT_EXPECTS(
out_keys.extent(1) == out_values.extent(1) && out_keys.extent(1) == in_keys.extent(1),
Expand Down
Loading
Loading