Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a HNSW collector that exits early when nearest neighbor queue saturates #14094

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.HnswQueueSaturationCollector;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand Down Expand Up @@ -314,10 +315,16 @@ private void search(
return;
}
final RandomVectorScorer scorer = scorerSupplier.get();
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
if (knnCollector.k() < scorer.maxOrd()) {
final KnnCollector collector;
OrdinalTranslatedKnnCollector ordinalTranslatedKnnCollector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
if (scorer.maxOrd() > 1000) {
collector = new HnswQueueSaturationCollector(ordinalTranslatedKnnCollector);
} else {
collector = ordinalTranslatedKnnCollector;
}
HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds);
} else {
// if k is larger than the number of vectors, we can just iterate over all vectors
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

/**
 * {@link KnnCollector} that exposes methods to hook into specific parts of the HNSW algorithm.
 *
 * <p>In addition to the regular {@link KnnCollector} contract, implementations receive a {@link
 * #nextCandidate()} notification, which lets them keep per-candidate state (for example to decide
 * on early termination).
 */
public interface HnswKnnCollector extends KnnCollector {

  /**
   * Indicates exploration of the next HNSW candidate graph node.
   *
   * <p>The graph searcher invokes this hook on the results collector when that collector is an
   * {@code HnswKnnCollector}; the hook takes no arguments and returns nothing.
   */
  void nextCandidate();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.search;

/**
* A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a
* 'patience' parameter. This records the rate of collection of new nearest neighbors in the {@code
* delegate} KnnCollector queue, at each HNSW node candidate visit. Once it saturates for a number
* of consecutive node visits (e.g., the patience parameter), this early terminates.
*/
public class HnswQueueSaturationCollector implements HnswKnnCollector {
tteofili marked this conversation as resolved.
Show resolved Hide resolved

private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d;

private final KnnCollector delegate;
private double saturationThreshold;
private int patience;
private boolean patienceFinished;
private int countSaturated;
private int previousQueueSize;
private int currentQueueSize;

public HnswQueueSaturationCollector(
KnnCollector delegate, double saturationThreshold, int patience) {
this.delegate = delegate;
this.previousQueueSize = 0;
this.currentQueueSize = 0;
this.countSaturated = 0;
this.patienceFinished = false;
this.saturationThreshold = saturationThreshold;
this.patience = patience;
}

public HnswQueueSaturationCollector(KnnCollector delegate) {
this.delegate = delegate;
this.previousQueueSize = 0;
this.currentQueueSize = 0;
this.countSaturated = 0;
this.patienceFinished = false;
this.saturationThreshold = DEFAULT_SATURATION_THRESHOLD;
this.patience = defaultPatience();
}

private int defaultPatience() {
return Math.max(7, (int) (k() * 0.3));
}

@Override
public boolean earlyTerminated() {
return delegate.earlyTerminated() || patienceFinished;
tteofili marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
public void incVisitedCount(int count) {
delegate.incVisitedCount(count);
}

@Override
public long visitedCount() {
return delegate.visitedCount();
}

@Override
public long visitLimit() {
return delegate.visitLimit();
}

@Override
public int k() {
return delegate.k();
}

@Override
public boolean collect(int docId, float similarity) {
boolean collect = delegate.collect(docId, similarity);
if (collect) {
currentQueueSize++;
}
return collect;
}

@Override
public float minCompetitiveSimilarity() {
return delegate.minCompetitiveSimilarity();
}

@Override
public TopDocs topDocs() {
TopDocs topDocs;
if (patienceFinished && delegate.earlyTerminated() == false) {
TopDocs delegateDocs = delegate.topDocs();
TotalHits totalHits =
new TotalHits(delegateDocs.totalHits.value(), TotalHits.Relation.EQUAL_TO);
topDocs = new TopDocs(totalHits, delegateDocs.scoreDocs);
} else {
topDocs = delegate.topDocs();
}
return topDocs;
}

@Override
public void nextCandidate() {
double queueSaturation =
(double) Math.min(currentQueueSize, previousQueueSize) / currentQueueSize;
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
previousQueueSize = currentQueueSize;
if (queueSaturation >= saturationThreshold) {
countSaturated++;
} else {
countSaturated = 0;
}
if (countSaturated > patience) {
patienceFinished = true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HnswKnnCollector;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.knn.EntryPointProvider;
Expand Down Expand Up @@ -272,6 +273,9 @@ void searchLevel(
}
}
}
if (results instanceof HnswKnnCollector hnswKnnCollector) {
hnswKnnCollector.nextCandidate();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import java.util.Random;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.junit.Test;

/** Tests for {@link HnswQueueSaturationCollector} */
public class HnswQueueSaturationCollectorTest extends LuceneTestCase {

@Test
public void testDelegate() {
Random random = random();
int numDocs = 100;
int k = random.nextInt(10);
KnnCollector delegate = new TopKnnCollector(k, numDocs);
HnswQueueSaturationCollector queueSaturationCollector =
new HnswQueueSaturationCollector(delegate);
for (int i = 0; i < random.nextInt(numDocs); i++) {
queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f));
}
assertEquals(delegate.k(), queueSaturationCollector.k());
assertEquals(delegate.visitedCount(), queueSaturationCollector.visitedCount());
assertEquals(delegate.visitLimit(), queueSaturationCollector.visitLimit());
assertEquals(
delegate.minCompetitiveSimilarity(),
queueSaturationCollector.minCompetitiveSimilarity(),
1e-3);
}

@Test
public void testEarlyExit() {
Random random = random();
int numDocs = 10000;
int k = random.nextInt(100);
KnnCollector delegate = new TopKnnCollector(k, numDocs);
HnswQueueSaturationCollector queueSaturationCollector =
new HnswQueueSaturationCollector(delegate);
for (int i = 0; i < random.nextInt(numDocs); i++) {
queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f));
boolean earlyTerminatedSaturation = queueSaturationCollector.earlyTerminated();
boolean earlyTerminatedDelegate = delegate.earlyTerminated();
assertTrue(earlyTerminatedSaturation || !earlyTerminatedDelegate);
if (earlyTerminatedDelegate) {
assertTrue(earlyTerminatedSaturation);
}
if (!earlyTerminatedSaturation) {
assertFalse(earlyTerminatedSaturation);
tteofili marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}
Loading