Skip to content

Commit

Permalink
Fix bug in HNSW diversity checks introduced in LUCENE-10577
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Sokolov committed Sep 18, 2022
1 parent 3e62e8c commit 1154045
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ private HnswGraphBuilder(
this.M = M;
this.beamWidth = beamWidth;
// normalization factor for level generation; currently not configurable
this.ml = 1 / Math.log(1.0 * M);
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
Expand Down Expand Up @@ -316,49 +316,50 @@ private boolean isDiverse(BytesRef candidate, NeighborArray neighbors, float sco
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
for (int i = neighbors.size() - 1; i > 0; i--) {
if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
if (isWorstNonDiverse(i, neighbors)) {
return i;
}
}
return neighbors.size() - 1;
}

private boolean isWorstNonDiverse(
int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
throws IOException {
int candidateNode = neighbors.node[candidateIndex];
return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(
candidate, vectors.binaryValue(candidate), neighbors, minAcceptedSimilarity);
candidateIndex, vectors.binaryValue(candidateNode), neighbors);
case FLOAT32 -> isWorstNonDiverse(
candidate, vectors.vectorValue(candidate), neighbors, minAcceptedSimilarity);
candidateIndex, vectors.vectorValue(candidateNode), neighbors);
};
}

private boolean isWorstNonDiverse(
int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
private boolean isWorstNonDiverse(int candidateIndex, float[] candidateVector, NeighborArray neighbors)
throws IOException {
for (int i = candidateIndex - 1; i > -0; i--) {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
// node i is too similar to node j given its score relative to the base node
similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return false;
return true;
}
}
return true;
return false;
}

private boolean isWorstNonDiverse(
int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
private boolean isWorstNonDiverse(int candidateIndex, BytesRef candidateVector, NeighborArray neighbors)
throws IOException {
for (int i = candidateIndex - 1; i > -0; i--) {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
// node i is too similar to node j given its score relative to the base node
similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return false;
return true;
}
}
return true;
return false;
}

private static int getRandomGraphLevel(double ml, SplittableRandom random) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ public void testDiversity() throws IOException {
unitVector2d(0.9),
unitVector2d(0.8),
unitVector2d(0.77),
unitVector2d(0.6)
};
if (vectorEncoding == VectorEncoding.BYTE) {
for (float[] v : values) {
Expand Down Expand Up @@ -555,6 +556,77 @@ public void testDiversity() throws IOException {
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
}

public void testDiversityFallback() throws IOException {
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
// Some test cases can't be exercised in two dimensions;
// in particular if a new neighbor displaces an existing neighbor
// by being closer to the target, yet none of the existing neighbors is closer to the new vector
// than to the target -- ie they all remain diverse, so we simply drop the farthest one.
float[][] values = {
{0, 0, 0},
{0, 1, 0},
{0, 0, 2},
{1, 0, 0},
{0, 0.4f, 0}
};
MockVectorValues vectors = new MockVectorValues(values);
// First add nodes until everybody gets a full neighbor list
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
RandomAccessVectorValues vectorsCopy = vectors.copy();
builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// 2 is closer to 0 than 1, so it is excluded as non-diverse
assertLevel0Neighbors(builder.hnsw, 1, 0);
// 1 is closer to 0 than 2, so it is excluded as non-diverse
assertLevel0Neighbors(builder.hnsw, 2, 0);

builder.addGraphNode(3, vectorsCopy);
// this is one case we are testing; 2 has been displaced by 3
assertLevel0Neighbors(builder.hnsw, 0, 1, 3);
assertLevel0Neighbors(builder.hnsw, 1, 0);
assertLevel0Neighbors(builder.hnsw, 2, 0);
assertLevel0Neighbors(builder.hnsw, 3, 0);
}
public void testDiversity3d() throws IOException {
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
// test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives
float[][] values = {
{0, 0, 0},
{0, 10, 0},
{0, 0, 20},
{0, 9, 0}
};
MockVectorValues vectors = new MockVectorValues(values);
// First add nodes until everybody gets a full neighbor list
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
RandomAccessVectorValues vectorsCopy = vectors.copy();
builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// 2 is closer to 0 than 1, so it is excluded as non-diverse
assertLevel0Neighbors(builder.hnsw, 1, 0);
// 1 is closer to 0 than 2, so it is excluded as non-diverse
assertLevel0Neighbors(builder.hnsw, 2, 0);

builder.addGraphNode(3, vectorsCopy);
// this is one case we are testing; 1 has been displaced by 3
assertLevel0Neighbors(builder.hnsw, 0, 2, 3);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
assertLevel0Neighbors(builder.hnsw, 2, 0);
assertLevel0Neighbors(builder.hnsw, 3, 0, 1);
}

private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) {
Arrays.sort(expected);
NeighborArray nn = graph.getNeighbors(0, node);
Expand Down

0 comments on commit 1154045

Please sign in to comment.