diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java index b294fe97c7e7c..fdb09594a1cda 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java @@ -83,6 +83,8 @@ public class VectorScorerBenchmark { RandomVectorScorer luceneDotScorerQuery; RandomVectorScorer nativeDotScorerQuery; + RandomVectorScorer luceneSqrScorerQuery; + RandomVectorScorer nativeSqrScorerQuery; @Setup public void setup() throws IOException { @@ -130,6 +132,8 @@ public void setup() throws IOException { } luceneDotScorerQuery = luceneScorer(values, VectorSimilarityFunction.DOT_PRODUCT, queryVec); nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get(); + luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec); + nativeSqrScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get(); // sanity var f1 = dotProductLucene(); @@ -157,6 +161,12 @@ public void setup() throws IOException { if (q1 != q2) { throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]"); } + + var sqr1 = squareDistanceLuceneQuery(); + var sqr2 = squareDistanceNativeQuery(); + if (sqr1 != sqr2) { + throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]"); + } } @TearDown @@ -217,6 +227,16 @@ public float squareDistanceScalar() { return 1 / (1f + adjustedDistance); } + @Benchmark + public float squareDistanceLuceneQuery() throws IOException { + return luceneSqrScorerQuery.score(1); + } + + @Benchmark + public float squareDistanceNativeQuery() throws IOException { + return nativeSqrScorerQuery.score(1); + } + QuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { var sq = new ScalarQuantizer(0.1f, 0.9f, (byte) 7); var slice = in.slice("values", 0, in.length());