Skip to content

Commit

Permalink
Fix ef setting in HNSW wrapper (#2367)
Browse files Browse the repository at this point in the history
Closes #2363 

Bugs fixed: 

1. Setting `ef` in search, it was not being set at all before
2. `from_cagra` used a hard coded filename to serialize CAGRA graph and deserialize to HNSW graph. The PR changes the hardcoded filename to a random string so that multiple graphs may be converted concurrently

cc @Presburger thank you for reporting these bugs

Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2367
  • Loading branch information
divyegala authored Jun 27, 2024
1 parent b863f18 commit 53e7982
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 4 deletions.
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/detail/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ void search(raft::resources const& res,
raft::host_matrix_view<uint64_t, int64_t, row_major> neighbors,
raft::host_matrix_view<float, int64_t, row_major> distances)
{
idx.set_ef(params.ef);
auto const* hnswlib_index =
reinterpret_cast<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const*>(
idx.get_index());
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/neighbors/detail/hnsw_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ struct index_impl : index<T> {
*/
auto get_index() const -> void const* override { return appr_alg_.get(); }

/**
@brief Set ef for search
*/
void set_ef(int ef) const override { appr_alg_->ef_ = ef; }

private:
std::unique_ptr<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>> appr_alg_;
std::unique_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_;
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace raft::neighbors::hnsw {

/**
* @brief Construct an hnswlib base-layer-only index from a CAGRA index
* NOTE: 1. This method uses the filesystem to write the CAGRA index in `/tmp/cagra_index.bin`
* NOTE: 1. This method uses the filesystem to write the CAGRA index in `/tmp/<random_number>.bin`
* before reading it as an hnswlib index, then deleting the temporary file.
* 2. This function is only offered as a compiled symbol in `libraft.so`
*
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/neighbors/hnsw_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ struct index : ann::index {

auto metric() const -> raft::distance::DistanceType { return metric_; }

/**
@brief Set ef for search
*/
virtual void set_ef(int ef) const;

private:
int dim_;
raft::distance::DistanceType metric_;
Expand Down
8 changes: 7 additions & 1 deletion cpp/src/raft_runtime/neighbors/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@
#include <raft_runtime/neighbors/hnsw.hpp>

#include <filesystem>
#include <random>
#include <string>

namespace raft::neighbors::hnsw {
#define RAFT_INST_HNSW(T) \
template <> \
std::unique_ptr<raft::neighbors::hnsw::index<T>> from_cagra( \
raft::resources const& res, raft::neighbors::cagra::index<T, uint32_t> cagra_index) \
{ \
std::string filepath = "/tmp/cagra_index.bin"; \
std::random_device dev; \
std::mt19937 rng(dev()); \
std::uniform_int_distribution<std::mt19937::result_type> dist(0); \
auto uuid = std::to_string(dist(rng)); \
std::string filepath = "/tmp/" + uuid + ".bin"; \
raft::runtime::neighbors::cagra::serialize_to_hnswlib(res, filepath, cagra_index); \
auto hnsw_index = raft::runtime::neighbors::hnsw::deserialize_file<T>( \
res, filepath, cagra_index.dim(), cagra_index.metric()); \
Expand Down
6 changes: 4 additions & 2 deletions python/pylibraft/pylibraft/neighbors/hnsw.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ from pylibraft.common.mdspan cimport (
from pylibraft.neighbors.common cimport _get_metric_string

import os
import uuid

import numpy as np

Expand Down Expand Up @@ -292,7 +293,7 @@ def from_cagra(Index index, handle=None):
Returns an hnswlib base-layer-only index from a CAGRA index.
NOTE: This method uses the filesystem to write the CAGRA index in
`/tmp/cagra_index.bin` before reading it as an hnswlib index,
`/tmp/<random_number>.bin` before reading it as an hnswlib index,
then deleting the temporary file.
Saving / loading the index is experimental. The serialization format is
Expand Down Expand Up @@ -320,7 +321,8 @@ def from_cagra(Index index, handle=None):
>>> # Serialize the CAGRA index to hnswlib base layer only index format
>>> hnsw_index = hnsw.from_cagra(index, handle=handle)
"""
filename = "/tmp/cagra_index.bin"
uuid_num = uuid.uuid4()
filename = f"/tmp/{uuid_num}.bin"
save(filename, index, handle=handle)
hnsw_index = load(filename, index.dim, np.dtype(index.active_index_type),
_get_metric_string(index.metric), handle=handle)
Expand Down

0 comments on commit 53e7982

Please sign in to comment.