Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: opensearch-project/k-NN
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: dd837799584effb6fdb2b2001934d4ac95070e00
Choose a base ref
..
head repository: opensearch-project/k-NN
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 5bc6606b60350c645da9fcc3a27e1fe27240dcc0
Choose a head ref
Showing with 34 additions and 28 deletions.
  1. +34 −28 jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch
62 changes: 34 additions & 28 deletions jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
From 8f89fbf1cf445a8216b9cc4ee52e7b82e24e906e Mon Sep 17 00:00:00 2001
From e678b7827b63195bc78c3384b1cf37ed6c079adc Mon Sep 17 00:00:00 2001
From: Heemin Kim <heemin@amazon.com>
Date: Wed, 6 Dec 2023 16:33:52 -0800
Subject: [PATCH] Custom patch to support multi-vector
@@ -7,12 +7,12 @@ Signed-off-by: Heemin Kim <heemin@amazon.com>
---
faiss/CMakeLists.txt | 2 +
faiss/Index.h | 6 ++-
faiss/IndexIDMap.cpp | 20 ++++++++++
faiss/IndexIDMap.cpp | 21 ++++++++++
faiss/IndexIDMap.h | 1 +
faiss/impl/HNSW.cpp | 25 ++++++++-----
faiss/impl/ResultCollector.h | 58 +++++++++++++++++++++++++++++
faiss/impl/ResultCollectorFactory.h | 29 +++++++++++++++
7 files changed, 129 insertions(+), 12 deletions(-)
faiss/impl/HNSW.cpp | 27 ++++++++-----
faiss/impl/ResultCollector.h | 61 +++++++++++++++++++++++++++++
faiss/impl/ResultCollectorFactory.h | 29 ++++++++++++++
7 files changed, 135 insertions(+), 12 deletions(-)
create mode 100644 faiss/impl/ResultCollector.h
create mode 100644 faiss/impl/ResultCollectorFactory.h

@@ -57,46 +57,47 @@ index 4b4b302b..13eab0c0 100644
virtual ~SearchParameters() {}
};
diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp
index 7972bec9..a5c017a9 100644
index 7972bec9..407e0e61 100644
--- a/faiss/IndexIDMap.cpp
+++ b/faiss/IndexIDMap.cpp
@@ -102,6 +102,20 @@ struct ScopedSelChange {
@@ -102,6 +102,22 @@ struct ScopedSelChange {
}
};

+/// RAII object to reset the ResultCollectorFactory in the params object
+// RAII object to reset the id_map parameter in ResultCollectorFactory object
+// This object makes sure to reset the id_map parameter in ResultCollectorFactory once
+// the program exits the current method scope.
+struct ScopedColChange {
+ SearchParameters* params = nullptr;
+ void set(SearchParameters* params, const std::vector<int64_t>* id_map) {
+ this->params = params;
+ params->col->id_map = id_map;
+ ResultCollectorFactory* collector_factory = nullptr;
+ void set(ResultCollectorFactory* collector_factory, const std::vector<int64_t>* id_map) {
+ this->collector_factory = collector_factory;
+ collector_factory->id_map = id_map;
+ }
+ ~ScopedColChange() {
+ if (params) {
+ params->col->id_map = nullptr;
+ if (collector_factory) {
+ collector_factory->id_map = nullptr;
+ }
+ }
+};
+
} // namespace

template <typename IndexT>
@@ -114,6 +128,7 @@ void IndexIDMapTemplate<IndexT>::search(
@@ -114,6 +130,7 @@ void IndexIDMapTemplate<IndexT>::search(
const SearchParameters* params) const {
IDSelectorTranslated this_idtrans(this->id_map, nullptr);
ScopedSelChange sel_change;
+ ScopedColChange col_change;

if (params && params->sel) {
auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
@@ -131,6 +146,11 @@ void IndexIDMapTemplate<IndexT>::search(
@@ -131,6 +148,10 @@ void IndexIDMapTemplate<IndexT>::search(
sel_change.set(params_non_const, &this_idtrans);
}
}
+
+ if (params && params->col) {
+ auto params_non_const = const_cast<SearchParameters*>(params);
+ col_change.set(params_non_const, &this->id_map);
+ if (params && params->col && !params->col->id_map) {
+ col_change.set(params->col, &this->id_map);
+ }
index->search(n, x, k, distances, labels, params);
idx_t* li = labels;
@@ -114,7 +115,7 @@ index 2d164123..c6a1be73 100644
#include <unordered_map>
#include <vector>
diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp
index 9fc201ea..540210a6 100644
index 9fc201ea..f7f89b7a 100644
--- a/faiss/impl/HNSW.cpp
+++ b/faiss/impl/HNSW.cpp
@@ -14,6 +14,7 @@
@@ -167,22 +168,24 @@ index 9fc201ea..540210a6 100644
}
candidates.push(idx, dis);
};
@@ -660,6 +662,9 @@ int search_from_candidates(
@@ -660,6 +662,11 @@ int search_from_candidates(
}
}

+ // Completed collection of result. Run post processor.
+ collector->finalize(nres, I);
+ // Collector completed its task. Release all resource of the collector.
+ collectorFactory->deleteCollector(collector);
+
if (level == 0) {
stats.n1++;
if (candidates.size() == 0) {
diff --git a/faiss/impl/ResultCollector.h b/faiss/impl/ResultCollector.h
new file mode 100644
index 00000000..3e4dac34
index 00000000..40ca20c8
--- /dev/null
+++ b/faiss/impl/ResultCollector.h
@@ -0,0 +1,58 @@
@@ -0,0 +1,61 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
@@ -233,6 +236,9 @@ index 00000000..3e4dac34
+ }
+ }
+
+ // This method is called once all results are collected so that final post processing can be done.
+ // For example, if the results are collected using group id, the group id can be converted back to
+ // its original id inside this method.
+ void finalize(idx_t nres, idx_t* bh_ids) override {
+ // Do nothing
+ }
@@ -243,7 +249,7 @@ index 00000000..3e4dac34
+} // namespace faiss
diff --git a/faiss/impl/ResultCollectorFactory.h b/faiss/impl/ResultCollectorFactory.h
new file mode 100644
index 00000000..4d903f8d
index 00000000..077c2ce7
--- /dev/null
+++ b/faiss/impl/ResultCollectorFactory.h
@@ -0,0 +1,29 @@
@@ -258,20 +264,20 @@ index 00000000..4d903f8d
+#include <faiss/impl/ResultCollector.h>
+namespace faiss {
+
+/** ResultCollector is intended to define how to collect search result */
+/** ResultCollectorFactory to create a ResultCollector object */
+struct ResultCollectorFactory {
+ DefaultCollector default_collector;
+ const std::vector<int64_t>* id_map;
+
+ // For each result, collect method is called to store result
+ // Create a new ResultCollector object
+ virtual ResultCollector* newCollector() {
+ return &default_collector;
+ }
+
+ virtual void deleteCollector(ResultCollector* collector) {
+ // Do nothing
+ }
+ // This method is called after all result is collected
+
+ virtual ~ResultCollectorFactory() {}
+};
+