From 2d3af5e2d2039e884d56ec45f3be143977d31ace Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 26 Jun 2024 12:54:43 -0700 Subject: [PATCH] fix bugs --- cpp/include/raft/neighbors/detail/hnsw.hpp | 1 + cpp/include/raft/neighbors/detail/hnsw_types.hpp | 5 +++++ cpp/include/raft/neighbors/hnsw.hpp | 2 +- cpp/include/raft/neighbors/hnsw_types.hpp | 5 +++++ cpp/src/raft_runtime/neighbors/hnsw.cpp | 8 +++++++- python/pylibraft/pylibraft/neighbors/hnsw.pyx | 6 ++++-- 6 files changed, 23 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/hnsw.hpp b/cpp/include/raft/neighbors/detail/hnsw.hpp index 5deee3c2ba..bd4e6608de 100644 --- a/cpp/include/raft/neighbors/detail/hnsw.hpp +++ b/cpp/include/raft/neighbors/detail/hnsw.hpp @@ -53,6 +53,7 @@ void search(raft::resources const& res, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { + idx.set_ef(params.ef); auto const* hnswlib_index = reinterpret_cast::type> const*>( idx.get_index()); diff --git a/cpp/include/raft/neighbors/detail/hnsw_types.hpp b/cpp/include/raft/neighbors/detail/hnsw_types.hpp index 9d35effd1a..8d601f59ae 100644 --- a/cpp/include/raft/neighbors/detail/hnsw_types.hpp +++ b/cpp/include/raft/neighbors/detail/hnsw_types.hpp @@ -93,6 +93,11 @@ struct index_impl : index { */ auto get_index() const -> void const* override { return appr_alg_.get(); } + /** + @brief Set ef for search + */ + void set_ef(int ef) const override { appr_alg_->ef_ = ef; } + private: std::unique_ptr::type>> appr_alg_; std::unique_ptr::type>> space_; diff --git a/cpp/include/raft/neighbors/hnsw.hpp b/cpp/include/raft/neighbors/hnsw.hpp index 964c3ffacd..ee3f61e550 100644 --- a/cpp/include/raft/neighbors/hnsw.hpp +++ b/cpp/include/raft/neighbors/hnsw.hpp @@ -35,7 +35,7 @@ namespace raft::neighbors::hnsw { /** * @brief Construct an hnswlib base-layer-only index from a CAGRA index - * NOTE: 1. This method uses the filesystem to write the CAGRA index in `/tmp/cagra_index.bin` + * NOTE: 1. This method uses the filesystem to write the CAGRA index in `/tmp/.bin` * before reading it as an hnswlib index, then deleting the temporary file. * 2. This function is only offered as a compiled symbol in `libraft.so` * diff --git a/cpp/include/raft/neighbors/hnsw_types.hpp b/cpp/include/raft/neighbors/hnsw_types.hpp index 645a0903b7..f90de6f01b 100644 --- a/cpp/include/raft/neighbors/hnsw_types.hpp +++ b/cpp/include/raft/neighbors/hnsw_types.hpp @@ -62,6 +62,11 @@ struct index : ann::index { auto metric() const -> raft::distance::DistanceType { return metric_; } + /** + @brief Set ef for search + */ + virtual void set_ef(int ef) const; + private: int dim_; raft::distance::DistanceType metric_; diff --git a/cpp/src/raft_runtime/neighbors/hnsw.cpp b/cpp/src/raft_runtime/neighbors/hnsw.cpp index 6eb770abd6..5356e708d2 100644 --- a/cpp/src/raft_runtime/neighbors/hnsw.cpp +++ b/cpp/src/raft_runtime/neighbors/hnsw.cpp @@ -21,6 +21,8 @@ #include #include +#include +#include namespace raft::neighbors::hnsw { #define RAFT_INST_HNSW(T) \ @@ -28,7 +30,11 @@ namespace raft::neighbors::hnsw { std::unique_ptr> from_cagra( \ raft::resources const& res, raft::neighbors::cagra::index cagra_index) \ { \ - std::string filepath = "/tmp/cagra_index.bin"; \ + std::random_device dev; \ + std::mt19937 rng(dev()); \ + std::uniform_int_distribution dist(0); \ + auto uuid = std::to_string(dist(rng)); \ + std::string filepath = "/tmp/" + uuid + ".bin"; \ raft::runtime::neighbors::cagra::serialize_to_hnswlib(res, filepath, cagra_index); \ auto hnsw_index = raft::runtime::neighbors::hnsw::deserialize_file( \ res, filepath, cagra_index.dim(), cagra_index.metric()); \ diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx index aa589ffb65..e6f2d69eb8 100644 --- a/python/pylibraft/pylibraft/neighbors/hnsw.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -52,6 +52,7 @@ from pylibraft.common.mdspan cimport ( from pylibraft.neighbors.common cimport _get_metric_string import os +import uuid import numpy as np @@ -292,7 +293,7 @@ def from_cagra(Index index, handle=None): Returns an hnswlib base-layer-only index from a CAGRA index. NOTE: This method uses the filesystem to write the CAGRA index in - `/tmp/cagra_index.bin` before reading it as an hnswlib index, + `/tmp/.bin` before reading it as an hnswlib index, then deleting the temporary file. Saving / loading the index is experimental. The serialization format is @@ -320,7 +321,8 @@ def from_cagra(Index index, handle=None): >>> # Serialize the CAGRA index to hnswlib base layer only index format >>> hnsw_index = hnsw.from_cagra(index, handle=handle) """ - filename = "/tmp/cagra_index.bin" + uuid_num = uuid.uuid4() + filename = f"/tmp/{uuid_num}.bin" save(filename, index, handle=handle) hnsw_index = load(filename, index.dim, np.dtype(index.active_index_type), _get_metric_string(index.metric), handle=handle)