Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diversity check bugfix #11781

Merged
merged 2 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ private HnswGraphBuilder(
this.M = M;
this.beamWidth = beamWidth;
// normalization factor for level generation; currently not configurable
this.ml = 1 / Math.log(1.0 * M);
this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M);
this.random = new SplittableRandom(seed);
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
Expand Down Expand Up @@ -316,49 +316,49 @@ private boolean isDiverse(BytesRef candidate, NeighborArray neighbors, float sco
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
for (int i = neighbors.size() - 1; i > 0; i--) {
if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
if (isWorstNonDiverse(i, neighbors)) {
return i;
}
}
return neighbors.size() - 1;
}

private boolean isWorstNonDiverse(
int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
throws IOException {
int candidateNode = neighbors.node[candidateIndex];
return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(
candidate, vectors.binaryValue(candidate), neighbors, minAcceptedSimilarity);
case BYTE -> isWorstNonDiverse(candidateIndex, vectors.binaryValue(candidateNode), neighbors);
case FLOAT32 -> isWorstNonDiverse(
candidate, vectors.vectorValue(candidate), neighbors, minAcceptedSimilarity);
candidateIndex, vectors.vectorValue(candidateNode), neighbors);
};
}

private boolean isWorstNonDiverse(
int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
throws IOException {
for (int i = candidateIndex - 1; i > -0; i--) {
int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
// node i is too similar to node j given its score relative to the base node
similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return false;
return true;
}
}
return true;
return false;
}

private boolean isWorstNonDiverse(
int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
throws IOException {
for (int i = candidateIndex - 1; i > -0; i--) {
int candidateIndex, BytesRef candidateVector, NeighborArray neighbors) throws IOException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am surprised that with this big change, we had only a small reduction in recall. I guess the reason could be that in our tests the diversity check was really relevant only for a small number of nodes; in the majority of cases the algorithm just eliminated the most distant node.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know - how did this garbage even work at all! ☹️ It's kind of astonishing how insensitive this whole process is to the diversity checking. Initially we didn't have it at all though (just always pick the closest neighbors), and things still kind of work. Then I had the wonky implementation that did not sort the neighbors while indexing, but did some best effort kind of thing, and still it mostly worked. So we need good tests here to ensure we are doing the right thing! Because bugs here can lead to small degradation.

float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
// node i is too similar to node j given its score relative to the base node
similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return false;
return true;
}
}
return true;
return false;
}

private static int getRandomGraphLevel(double ml, SplittableRandom random) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ public void testDiversity() throws IOException {
unitVector2d(0.9),
unitVector2d(0.8),
unitVector2d(0.77),
unitVector2d(0.6)
};
if (vectorEncoding == VectorEncoding.BYTE) {
for (float[] v : values) {
Expand Down Expand Up @@ -555,6 +556,78 @@ public void testDiversity() throws IOException {
assertLevel0Neighbors(builder.hnsw, 5, 1, 4);
}

// Checks the fallback path of neighbor pruning: when a new node overfills a neighbor list
// but every neighbor is still diverse, the farthest neighbor is the one that gets dropped.
public void testDiversityFallback() throws IOException {
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
// Some test cases can't be exercised in two dimensions;
// in particular if a new neighbor displaces an existing neighbor
// by being closer to the target, yet none of the existing neighbors is closer to the new vector
// than to the target -- i.e. they all remain diverse, so we simply drop the farthest one.
// NOTE(review): the last vector (index 4) is never added to the graph in this test.
float[][] values = {
{0, 0, 0},
{0, 10, 0},
{0, 0, 20},
{10, 0, 0},
{0, 4, 0}
};
MockVectorValues vectors = new MockVectorValues(values);
// First add nodes until everybody gets a full neighbor list
// (the literals 1 and 10 are presumably M and beamWidth -- confirm against create()'s signature)
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
RandomAccessVectorValues vectorsCopy = vectors.copy();
builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// 2 is closer to 0 than to 1, so it is excluded from 1's neighbors as non-diverse
assertLevel0Neighbors(builder.hnsw, 1, 0);
// 1 is closer to 0 than to 2, so it is excluded from 2's neighbors as non-diverse
assertLevel0Neighbors(builder.hnsw, 2, 0);

builder.addGraphNode(3, vectorsCopy);
// this is the case under test: among 0's neighbors, 2 (the farthest) has been displaced by 3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice test!

assertLevel0Neighbors(builder.hnsw, 0, 1, 3);
assertLevel0Neighbors(builder.hnsw, 1, 0);
assertLevel0Neighbors(builder.hnsw, 2, 0);
assertLevel0Neighbors(builder.hnsw, 3, 0);
}

public void testDiversity3d() throws IOException {
  vectorEncoding = randomVectorEncoding();
  similarityFunction = VectorSimilarityFunction.EUCLIDEAN;

  // Scenario: a neighbor that used to be diverse stops being so once a newer, better
  // neighbor shows up right next to it.
  float[][] points = {
    {0, 0, 0},
    {0, 10, 0},
    {0, 0, 20},
    {0, 9, 0}
  };
  MockVectorValues mockVectors = new MockVectorValues(points);

  // The builder's constructor already inserts node 0.
  int seed = random().nextInt();
  HnswGraphBuilder<?> builder =
      HnswGraphBuilder.create(mockVectors, vectorEncoding, similarityFunction, 1, 10, seed);
  RandomAccessVectorValues copy = mockVectors.copy();

  // Insert nodes 1 and 2 so that node 0 ends up with both of them as neighbors.
  for (int node = 1; node <= 2; node++) {
    builder.addGraphNode(node, copy);
  }
  assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
  // node 2 lies nearer to node 0 than to node 1, hence it is rejected from 1's list
  assertLevel0Neighbors(builder.hnsw, 1, 0);
  // node 1 lies nearer to node 0 than to node 2, hence it is rejected from 2's list
  assertLevel0Neighbors(builder.hnsw, 2, 0);

  // Node 3 sits right next to node 1 and is slightly nearer to node 0.
  builder.addGraphNode(3, copy);
  // the case under test: in node 0's list, 1 has been displaced by 3
  assertLevel0Neighbors(builder.hnsw, 0, 2, 3);
  assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
  assertLevel0Neighbors(builder.hnsw, 2, 0);
  assertLevel0Neighbors(builder.hnsw, 3, 0, 1);
}

private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... expected) {
Arrays.sort(expected);
NeighborArray nn = graph.getNeighbors(0, node);
Expand Down