Skip to content

Commit

Permalink
Mocobeta/lucene 9004 2 (apache#3)
Browse files Browse the repository at this point in the history
* Added KnnGraphTester; fixed some small bugs and measured recall

* use FloatBuffer for encode/decode of vector values. Share a single float[] when fetching values. 25% latency reduction

* Don't try to precompute nearest neighbors when generating test data in KnnGraphTester

* added visitedCounter for debugging/optimization
restored shrinking to limit fanout of graph
fixed an upside-down priority queue

* don't terminate exploration of graph until we have a full result queue

* fixing up some small bugs I had introduced in earlier push
  • Loading branch information
msokolov authored and jtibshirani committed Feb 6, 2020
1 parent b5c32b2 commit 43dc87b
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.lucene.codecs.lucene90;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -258,26 +260,40 @@ private final static class RandomAccessVectorValuesReader extends VectorValues {

final int maxDoc;
final int numDims;
final int byteSize;
final KnnGraphEntry entry;
final IndexInput dataIn;
final BytesRef binaryValue;
final ByteBuffer byteBuffer;
final FloatBuffer floatBuffer;
final float[] value;

int doc = -1;
BytesRef binaryValue;

RandomAccessVectorValuesReader(int maxDoc, int numDims, KnnGraphEntry entry, IndexInput dataIn) {
  this.maxDoc = maxDoc;
  this.numDims = numDims;
  this.entry = entry;
  this.dataIn = dataIn;
  // TODO: if we had more direct access to file system (a FileChannel) we could use a
  // MappedByteBuffer here. EG if we know that dataIn is a ByteBufferIndexInput, we could get a
  // FloatBuffer from that and provide access to it
  byteSize = Float.BYTES * numDims;
  // Single reusable buffer: seek() reads raw bytes into byteBuffer and vectorValue()
  // decodes them through the FloatBuffer view, avoiding per-document allocation.
  byteBuffer = ByteBuffer.allocate(byteSize);
  floatBuffer = byteBuffer.asFloatBuffer();
  value = new float[numDims];
  // binaryValue wraps the shared byteBuffer's backing array so callers see the same bytes
  // that seek() fills in place; do not allocate a separate array here.
  binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}

@Override
public float[] vectorValue() throws IOException {
  // Iteration exhausted: no current document, so no vector to return.
  if (doc == NO_MORE_DOCS) {
    return null;
  }
  // Decode the bytes most recently read by seek() through the shared FloatBuffer view
  // into the reusable value array (no per-call allocation). Callers must not hold on
  // to the returned array across calls.
  floatBuffer.position(0);
  floatBuffer.get(value, 0, numDims);
  return value;
}

@Override
Expand All @@ -296,10 +312,11 @@ public boolean seek(int target) throws IOException {
if (ord == null) {
return false;
}
int offset = Float.BYTES * numDims * ord;
int offset = ord * byteSize;
assert offset >= 0;
dataIn.seek(offset);
dataIn.readBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
//dataIn.readBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
return true;
}

Expand Down
12 changes: 3 additions & 9 deletions lucene/core/src/java/org/apache/lucene/index/VectorValues.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,9 @@ public static float distance(float[] v1, float[] v2, DistanceFunction distFunc)
* Encodes float array to byte array.
*/
public static BytesRef encode(float[] value) {
  // ByteBuffer is big-endian by default, which matches the previous manual
  // Float.floatToIntBits byte-by-byte encoding, so the on-disk format is unchanged.
  ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * value.length);
  buffer.asFloatBuffer().put(value);
  return new BytesRef(buffer.array());
}

public static boolean verifyNumDimensions(int numBytes, int numDims) {
Expand Down
25 changes: 19 additions & 6 deletions lucene/core/src/java/org/apache/lucene/search/KnnGraphQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Arrays;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.lucene.document.Document;
import org.apache.lucene.index.FieldInfo;
Expand All @@ -44,8 +45,8 @@ public class KnnGraphQuery extends Query implements Accountable {
private final String field;
private final float[] queryVector;
private final int ef;

private long bytesUsed = 0L;
private final long bytesUsed;
private final AtomicLong visitedCounter;

/**
* Creates a nearest neighbor search query with the default {@code ef} parameter (={@link #DEFAULT_EF}).
Expand All @@ -54,7 +55,6 @@ public class KnnGraphQuery extends Query implements Accountable {
*/
public KnnGraphQuery(String field, float[] queryVector) {
  // Delegates to the full constructor, which initializes the final bytesUsed and
  // visitedCounter fields; nothing else may be assigned here.
  this(field, queryVector, DEFAULT_EF);
}

/**
Expand All @@ -68,7 +68,8 @@ public KnnGraphQuery(String field, float[] queryVector, int ef) {
this.field = field;
this.queryVector = queryVector;
this.ef = ef;
this.bytesUsed = RamUsageEstimator.shallowSizeOfInstance(getClass());
visitedCounter = new AtomicLong();
bytesUsed = RamUsageEstimator.shallowSizeOfInstance(getClass());
}

/**
Expand All @@ -85,8 +86,11 @@ public KnnGraphQuery(String field, float[] queryVector, int ef, IndexReader read
this.field = field;
this.queryVector = queryVector;
this.ef = ef;
visitedCounter = new AtomicLong();
if (reader != null) {
this.bytesUsed = HNSWGraphReader.loadGraphs(field, reader, forceReload);
bytesUsed = HNSWGraphReader.loadGraphs(field, reader, forceReload);
} else {
bytesUsed = 0;
}
}

Expand All @@ -106,7 +110,9 @@ public static KnnGraphQuery like(String field, int docId, int ef, IndexReader re

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
  KnnScoreWeight weight = new KnnScoreWeight(this, boost, scoreMode, field, queryVector, ef);
  // Share this query's counter so per-segment scorers accumulate the number of
  // graph nodes visited; see getVisitedCount().
  weight.setVisitedCounter(visitedCounter);
  return weight;
}

@Override
Expand Down Expand Up @@ -143,4 +149,11 @@ public String toString(String field) {
// Accountable: reports heap usage fixed at construction time — the shallow instance
// size, or the size of the graphs loaded via HNSWGraphReader.loadGraphs when a
// reader was supplied.
public long ramBytesUsed() {
return bytesUsed;
}

/**
* @return the total number of documents visited while executing this query, accumulated
*         across all per-segment scorers via the shared counter
*/
public long getVisitedCount() {
return visitedCounter.get();
}
}
12 changes: 12 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/KnnScoreWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.lucene.search;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
Expand All @@ -32,6 +33,7 @@ class KnnScoreWeight extends ConstantScoreWeight {
private final ScoreMode scoreMode;
private final float[] queryVector;
private final int ef;
private AtomicLong visitedCounter;

KnnScoreWeight(Query query, float score, ScoreMode scoreMode, String field, float[] queryVector, int ef) {
super(query, score);
Expand Down Expand Up @@ -70,6 +72,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
@Override
public Scorer get(long leadCost) throws IOException {
Neighbors neighbors = hnswReader.searchNeighbors(queryVector, ef, vectorValues);
visitedCounter.addAndGet(hnswReader.getVisitedCount());
return new Scorer(weight) {

int doc = -1;
Expand Down Expand Up @@ -110,6 +113,7 @@ public int advance(int target) throws IOException {
switch (fi.getVectorDistFunc()) {
case MANHATTAN:
case EUCLIDEAN:
// is it necessary to normalize these scores?
score = 1.0f / (next.distance() / numDimensions + 0.01f);
break;
case COSINE:
Expand Down Expand Up @@ -158,4 +162,12 @@ public long cost() {
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}

/** Sets the shared counter into which this weight accumulates visited-document counts. */
public void setVisitedCounter(AtomicLong counter) {
visitedCounter = counter;
}

/** @return the current value of the shared visited-documents counter. */
public long getVisitedCount() {
return visitedCounter.get();
}
}
27 changes: 10 additions & 17 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/HNSWGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,51 +55,47 @@ public HNSWGraph(VectorValues.DistanceFunction distFunc) {
* @param ef the number of nodes to be searched
* @param level graph level
* @param vectorValues vector values
* @return number of candidates visited
*/
void searchLayer(float[] query, FurthestNeighbors results, int ef, int level, VectorValues vectorValues) throws IOException {
int searchLayer(float[] query, FurthestNeighbors results, int ef, int level, VectorValues vectorValues) throws IOException {
if (level >= layers.size()) {
throw new IllegalArgumentException("layer does not exist for the level: " + level);
}

Layer layer = layers.get(level);
TreeSet<Neighbor> candidates = new TreeSet<>();
// set of docids that have been visited by search on this layer, used to avoid backtracking
Set<Integer> visited = new HashSet<>();
for (Neighbor n : results) {
// TODO: candidates should get neighbors of neighbors
candidates.add(n);
visited.add(n.docId());
}
// set of docids that have been visited by search on this layer, used to avoid backtracking
Set<Integer> visited = new HashSet<>();
// We want to efficiently pop the best (nearest, least distance) candidate, so use NearestNeighbors,
// but we don't want to overflow the heap and lose the best candidate!
Neighbor f = results.top();
while (candidates.size() > 0) {
Neighbor c = candidates.pollFirst();
Neighbor f = results.top();
assert c.isDeferred() == false;
assert f.isDeferred() == false;
if (c.distance() > f.distance()) {
if (c.distance() > f.distance() && results.size() >= ef) {
break;
}
for (Neighbor e : layer.getFriends(c.docId())) {
if (visited.contains(e.docId())) {
continue;
}
visited.add(e.docId());
assert f.isDeferred() == false;
float dist = distance(query, e.docId(), vectorValues);
if (dist < f.distance() || results.size() < ef) {
Neighbor n = new ImmutableNeighbor(e.docId(), dist);
candidates.add(n);
Neighbor popped = results.insertWithOverflow(n);
if (popped != null && popped != n) {
f = results.top();
}
results.insertWithOverflow(n);
f = results.top();
}
}
}

//System.out.println("level=" + level + ", visited nodes=" + visited.size());
//return pickNearestNeighbor(results);
return visited.size();
}

private float distance(float[] query, int docId, VectorValues vectorValues) throws IOException {
Expand Down Expand Up @@ -216,13 +212,10 @@ public void connectNodes(int level, int node1, int node2, float dist, int maxCon
throw new IllegalArgumentException("layer does not exist for level: " + level);
}
layer.connectNodes(node1, node2, dist);
/*
// ensure friends size <= maxConnections
//
if (maxConnections > 0) {
layer.shrink(node2, maxConnections);
}
*/
}

/** Connects two nodes; this is supposed to be called when searching */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,16 @@
* This also caches built {@link HNSWGraph}s for repeated use. */
public final class HNSWGraphReader {

// TODO: evict from this cache when a reader's segment is merged. Probably should hold the cache
// somewhere that has a similar lifecycle, like Lucene90KnnGraphReader
private static final Map<GraphKey, HNSWGraph> cache = new ConcurrentHashMap<>();

private final String field;
private final LeafReaderContext context;
private final VectorValues.DistanceFunction distFunc;

private long visitedCount;

public HNSWGraphReader(String field, LeafReaderContext context) {
this.field = field;
this.context = context;
Expand All @@ -58,9 +62,9 @@ public Neighbors searchNeighbors(float[] query, int ef, VectorValues vectorValue
Neighbor ep = new ImmutableNeighbor(enterPoint, VectorValues.distance(query, vectorValues.vectorValue(), distFunc));
FurthestNeighbors neighbors = new FurthestNeighbors(ef, ep);
for (int l = hnsw.topLevel(); l > 0; l--) {
hnsw.searchLayer(query, neighbors, 1, l, vectorValues);
visitedCount += hnsw.searchLayer(query, neighbors, 1, l, vectorValues);
}
hnsw.searchLayer(query, neighbors, ef, 0, vectorValues);
visitedCount += hnsw.searchLayer(query, neighbors, ef, 0, vectorValues);
return neighbors;
}

Expand Down Expand Up @@ -122,6 +126,15 @@ public static HNSWGraph load(String field, VectorValues.DistanceFunction distFun
return hnsw;
}

/**
* @return the number of documents visited by this reader. For each visited document, the reader
* computed the distance to a target vector. This count accumulates over the lifetime of the
* reader.
*/
// NOTE(review): visitedCount is a plain long updated without synchronization in
// searchNeighbors — assumes this reader is used from a single thread; confirm.
public long getVisitedCount() {
return visitedCount;
}

private static class GraphKey {
final String field;
final Object readerId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public HNSWGraphWriter(int numDimensions, VectorValues.DistanceFunction distFunc
}

/** Full constructor */
public HNSWGraphWriter(int maxConn, int maxConn0, int efConst, long seed, int numDimensions, VectorValues.DistanceFunction distFunc) {
HNSWGraphWriter(int maxConn, int maxConn0, int efConst, long seed, int numDimensions, VectorValues.DistanceFunction distFunc) {
this.maxConn = maxConn;
this.maxConn0 = maxConn0;
this.efConst = efConst;
Expand Down Expand Up @@ -115,7 +115,7 @@ public void insert(int docId, BytesRef binaryValue) throws IOException {
FurthestNeighbors results = new FurthestNeighbors(efConst, ep);
// down to the level from the hnsw's top level
for (int l = hnsw.topLevel(); l > level; l--) {
hnsw.searchLayer(value, results, efConst, l, vectorValues);
hnsw.searchLayer(value, results, 1, l, vectorValues);
}

// down to level 0 with placing the doc to each layer
Expand All @@ -128,12 +128,10 @@ public void insert(int docId, BytesRef binaryValue) throws IOException {

hnsw.searchLayer(value, results, efConst, l, vectorValues);
int maxConnections = l == 0 ? maxConn0 : maxConn;
NearestNeighbors neighbors = new NearestNeighbors(maxConnections, results.top());
for (Neighbor n : results) {
neighbors.insertWithOverflow(n);
while (results.size() > maxConnections) {
results.pop();
}
for (Neighbor n : neighbors) {
// TODO: limit *total* num connections by pruning (shrinking)
for (Neighbor n : results) {
hnsw.connectNodes(l, docId, n.docId(), n.distance(), maxConnections);
}
}
Expand Down
Loading

0 comments on commit 43dc87b

Please sign in to comment.