diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h old mode 100644 new mode 100755 index ec71de9cff..fe07d02452 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -159,7 +159,8 @@ void RaftCagra::build(const T* dataset, size_t nrow, cudaStream_t strea index_params_.nn_descent_params, index_params_.ivf_pq_refine_rate, index_params_.ivf_pq_build_params, - index_params_.ivf_pq_search_params))); + index_params_.ivf_pq_search_params, + false))); handle_.stream_wait(stream); // RAFT stream -> bench stream } diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index e8a0b8a7bd..00c363b377 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -145,7 +145,7 @@ struct index : ann::index { /** Total length of the index (number of vectors). */ [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { - return dataset_view_.extent(0); + return dataset_view_.extent(0) ? dataset_view_.extent(0) : graph_view_.extent(0); } /** Dimensionality of the data. */ diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh old mode 100644 new mode 100755 index 812cca5b3b..e7f7d96a13 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -296,7 +296,8 @@ index build( std::optional nn_descent_params = std::nullopt, std::optional refine_rate = std::nullopt, std::optional pq_build_params = std::nullopt, - std::optional search_params = std::nullopt) + std::optional search_params = std::nullopt, + bool construct_index_with_dataset = true) { size_t intermediate_degree = params.intermediate_graph_degree; size_t graph_degree = params.graph_degree; @@ -334,12 +335,22 @@ index build( auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); + RAFT_LOG_INFO("optimizing graph"); optimize(res, knn_graph->view(), cagra_graph.view()); // free intermediate graph before trying to create the index knn_graph.reset(); + RAFT_LOG_INFO("Graph optimized, creating index"); // Construct an index from dataset and optimized knn graph. - return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); + if (construct_index_with_dataset) { + return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); + } else { + // We just add the graph. User is expected to update dataset separately. This branch is used + // if user needs special control of memory allocations for the dataset. + index idx(res, params.metric); + idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); + return idx; + } } } // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 42a979f059..84a9ce96cd 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -65,8 +65,11 @@ void serialize(raft::resources const& res, serialize_scalar(res, os, index_.metric()); serialize_mdspan(res, os, index_.graph()); + include_dataset &= (index_.dataset().extent(0) > 0); + serialize_scalar(res, os, include_dataset); if (include_dataset) { + RAFT_LOG_INFO("Saving CAGRA index with dataset"); auto dataset = index_.dataset(); // Remove padding before saving the dataset auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); @@ -80,6 +83,8 @@ void serialize(raft::resources const& res, resource::get_cuda_stream(res))); resource::sync_stream(res); serialize_mdspan(res, os, host_dataset.view()); + } else { + RAFT_LOG_INFO("Saving CAGRA index WITHOUT dataset"); } }