Skip to content

Commit

Permalink
Add ABS_INNER_PRODUCT metric (facebookresearch#3524)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#3524

Searches with the metric abs(dot(query, database))
This makes it possible to search vectors that are closest to a hyperplane

* adds support for alternative metrics in faiss.knn in python

* checks that it works with HNSW

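* simplifies the extra distances interface by removing the template on the heap comparator class (knn_extra_metrics now takes k and the output distance/label arrays directly instead of a HeapArray<C>)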

Reviewed By: asadoughi

Differential Revision: D58695971

fbshipit-source-id: 2a0ff49c7f7ac2c005d85f141cc5de148081c9c4
  • Loading branch information
mdouze authored and facebook-github-bot committed Jun 18, 2024
1 parent e188eb3 commit e758973
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 50 deletions.
18 changes: 11 additions & 7 deletions faiss/IndexFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,19 @@ void IndexFlat::search(
} else if (metric_type == METRIC_L2) {
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
} else if (is_similarity_metric(metric_type)) {
float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
knn_extra_metrics(
x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
} else {
FAISS_THROW_IF_NOT(!sel);
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
knn_extra_metrics(
x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
x,
get_xb(),
d,
n,
ntotal,
metric_type,
metric_arg,
k,
distances,
labels);
}
}

Expand Down
8 changes: 6 additions & 2 deletions faiss/MetricType.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ enum MetricType {
METRIC_Canberra = 20,
METRIC_BrayCurtis,
METRIC_JensenShannon,
METRIC_Jaccard, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i))
///< where a_i, b_i > 0

/// sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i)) where a_i, b_i > 0
METRIC_Jaccard,
/// Squared Euclidean distance, ignoring NaNs
METRIC_NaNEuclidean,
/// abs(x | y): the distance to a hyperplane
METRIC_ABS_INNER_PRODUCT,
};

/// all vector indices are this type
Expand Down
12 changes: 9 additions & 3 deletions faiss/python/extra_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def lookup(self, keys):
# KNN function
######################################################

def knn(xq, xb, k, metric=METRIC_L2):
def knn(xq, xb, k, metric=METRIC_L2, metric_arg=0.0):
"""
Compute the k nearest neighbors of a vector without constructing an index
Expand Down Expand Up @@ -374,10 +374,16 @@ def knn(xq, xb, k, metric=METRIC_L2):
swig_ptr(xq), swig_ptr(xb),
d, nq, nb, k, swig_ptr(D), swig_ptr(I)
)
else:
raise NotImplementedError("only L2 and INNER_PRODUCT are supported")
else:
knn_extra_metrics(
swig_ptr(xq), swig_ptr(xb),
d, nq, nb, metric, metric_arg, k,
swig_ptr(D), swig_ptr(I)
)

return D, I


def knn_hamming(xq, xb, k, variant="hc"):
"""
Compute the k nearest neighbors of a set of vectors without constructing an index.
Expand Down
12 changes: 12 additions & 0 deletions faiss/utils/extra_distances-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,16 @@ inline float VectorDistance<METRIC_NaNEuclidean>::operator()(
}
return float(d) / float(present) * accu;
}

/// Absolute value of the inner product: |sum_i x[i] * y[i]|.
///
/// @param x  first vector, read over the first `d` components
/// @param y  second vector, read over the first `d` components
/// @return   fabs of the dot product of x and y
template <>
inline float VectorDistance<METRIC_ABS_INNER_PRODUCT>::operator()(
const float* x,
const float* y) const {
// Accumulate the *signed* dot product and take the absolute value once at
// the end. Applying fabs() to every term would instead yield
// sum_i |x[i] * y[i]|, which is not abs(dot(x, y)) as soon as the
// per-component products have mixed signs.
float accu = 0;
for (size_t i = 0; i < d; i++) {
accu += x[i] * y[i];
}
return fabs(accu);
}

} // namespace faiss
55 changes: 19 additions & 36 deletions faiss/utils/extra_distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,18 @@ void pairwise_extra_distances_template(
}
}

template <class VD, class C>
template <class VD>
void knn_extra_metrics_template(
VD vd,
const float* x,
const float* y,
size_t nx,
size_t ny,
HeapArray<C>* res) {
size_t k = res->k;
size_t k,
float* distances,
int64_t* labels) {
size_t d = vd.d;
using C = typename VD::C;
size_t check_period = InterruptCallback::get_period_hint(ny * d);
check_period *= omp_get_max_threads();

Expand All @@ -71,18 +73,15 @@ void knn_extra_metrics_template(
const float* x_i = x + i * d;
const float* y_j = y;
size_t j;
float* simi = res->get_val(i);
int64_t* idxi = res->get_ids(i);
float* simi = distances + k * i;
int64_t* idxi = labels + k * i;

// maxheap_heapify(k, simi, idxi);
heap_heapify<C>(k, simi, idxi);
for (j = 0; j < ny; j++) {
float disij = vd(x_i, y_j);

// if (disij < simi[0]) {
if ((!vd.is_similarity && (disij < simi[0])) ||
(vd.is_similarity && (disij > simi[0]))) {
// maxheap_replace_top(k, simi, idxi, disij, j);
if (C::cmp(simi[0], disij)) {
heap_replace_top<C>(k, simi, idxi, disij, j);
}
y_j += d;
Expand Down Expand Up @@ -165,13 +164,13 @@ void pairwise_extra_distances(
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
HANDLE_VAR(ABS_INNER_PRODUCT);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
}
}

template <class C>
void knn_extra_metrics(
const float* x,
const float* y,
Expand All @@ -180,13 +179,15 @@ void knn_extra_metrics(
size_t ny,
MetricType mt,
float metric_arg,
HeapArray<C>* res) {
size_t k,
float* distances,
int64_t* indexes) {
switch (mt) {
#define HANDLE_VAR(kw) \
case METRIC_##kw: { \
VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
knn_extra_metrics_template(vd, x, y, nx, ny, res); \
break; \
#define HANDLE_VAR(kw) \
case METRIC_##kw: { \
VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
knn_extra_metrics_template(vd, x, y, nx, ny, k, distances, indexes); \
break; \
}
HANDLE_VAR(L2);
HANDLE_VAR(L1);
Expand All @@ -197,32 +198,13 @@ void knn_extra_metrics(
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
HANDLE_VAR(ABS_INNER_PRODUCT);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
}
}

template void knn_extra_metrics<CMax<float, int64_t>>(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
MetricType mt,
float metric_arg,
HeapArray<CMax<float, int64_t>>* res);

template void knn_extra_metrics<CMin<float, int64_t>>(
const float* x,
const float* y,
size_t d,
size_t nx,
size_t ny,
MetricType mt,
float metric_arg,
HeapArray<CMin<float, int64_t>>* res);

FlatCodesDistanceComputer* get_extra_distance_computer(
size_t d,
MetricType mt,
Expand All @@ -245,6 +227,7 @@ FlatCodesDistanceComputer* get_extra_distance_computer(
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
HANDLE_VAR(ABS_INNER_PRODUCT);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
Expand Down
5 changes: 3 additions & 2 deletions faiss/utils/extra_distances.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ void pairwise_extra_distances(
int64_t ldb = -1,
int64_t ldd = -1);

template <class C>
void knn_extra_metrics(
const float* x,
const float* y,
Expand All @@ -42,7 +41,9 @@ void knn_extra_metrics(
size_t ny,
MetricType mt,
float metric_arg,
HeapArray<C>* res);
size_t k,
float* distances,
int64_t* indexes);

/** get a DistanceComputer that refers to this type of distance and
* indexes a flat array of size nb */
Expand Down
7 changes: 7 additions & 0 deletions tests/test_extra_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def test_nan_euclidean(self):
new_dis = faiss.pairwise_distances(x, q, faiss.METRIC_NaNEuclidean)
self.assertTrue(np.isnan(new_dis[0]))

def test_abs_inner_product(self):
    """pairwise_distances with METRIC_ABS_INNER_PRODUCT must match numpy's |xq . yb^T|."""
    queries, database = self.make_example()

    # reference: absolute value of the full query x database product matrix
    expected = np.absolute(np.matmul(queries, database.T))
    computed = faiss.pairwise_distances(
        queries, database, faiss.METRIC_ABS_INNER_PRODUCT
    )

    np.testing.assert_allclose(computed, expected, atol=1e-5)


class TestKNN(unittest.TestCase):
""" test that the knn search gives the same as distance matrix + argmin """
Expand Down
15 changes: 15 additions & 0 deletions tests/test_graph_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,21 @@ def test_io_no_storage(self):
)
self.assertEquals(index3.storage, None)

def test_abs_inner_product(self):
    """HNSW search under METRIC_ABS_INNER_PRODUCT should mostly agree with brute-force knn.

    Abs inner product is not a real distance, so it is dubious that the
    triangular inequality holds; the check is therefore only on the overlap
    between the two result sets.
    """
    metric = faiss.METRIC_ABS_INNER_PRODUCT
    n_neighbors = 10

    # both sets need to be centered to give interesting directions
    database = self.xb - self.xb.mean(axis=0)
    queries = self.xq - self.xq.mean(axis=0)
    dim = self.xq.shape[1]

    # brute-force reference result
    _, ref_ids = faiss.knn(queries, database, n_neighbors, metric)

    hnsw = faiss.IndexHNSWFlat(dim, 32, metric)
    hnsw.add(database)
    _, hnsw_ids = hnsw.search(queries, n_neighbors)

    overlap = faiss.eval_intersection(ref_ids, hnsw_ids)
    # figure noted by the original author: 4769 vs. 500*10
    self.assertGreater(overlap, ref_ids.size * 0.9)


class TestNSG(unittest.TestCase):

Expand Down

0 comments on commit e758973

Please sign in to comment.