From 2072be90b28c5e0fd7dff1ac26cec0b448cb7b48 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 12 Sep 2023 10:05:29 -0400 Subject: [PATCH] Utilize optimized dot_product where possible when calculating vector magnitude (#99448) Lucene provides an optimized `dot_product` calculation for vectors. We should use that when calculating a vector's magnitude. --- .../mapper/vectors/DenseVectorFieldMapper.java | 11 +++-------- .../index/mapper/vectors/VectorEncoderDecoder.java | 8 ++------ .../script/field/vectors/DenseVector.java | 14 ++++---------- 3 files changed, 9 insertions(+), 24 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 0e4f871fbb8ca..bd9b9df68aff2 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -26,6 +26,7 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.FieldDataContext; @@ -859,10 +860,7 @@ public Query createKnnQuery(byte[] queryVector, int numCands, Query filter, Floa } if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { - float squaredMagnitude = 0.0f; - for (byte b : queryVector) { - squaredMagnitude += b * b; - } + float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, elementType.errorByteElementsAppender(queryVector), squaredMagnitude); } Query knnQuery = new KnnByteVectorQuery(name(), queryVector, numCands, filter); @@ -891,10 +889,7 @@ public Query createKnnQuery(float[] queryVector, int numCands, Query filter, Flo elementType.checkVectorBounds(queryVector); if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { - float squaredMagnitude = 0.0f; - for (float e : queryVector) { - squaredMagnitude += e * e; - } + float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, elementType.errorFloatElementsAppender(queryVector), squaredMagnitude); } Query knnQuery = switch (elementType) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java index 381c1767edff3..e3285c4dc8644 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java @@ -9,6 +9,7 @@ package org.elasticsearch.index.mapper.vectors; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.VectorUtil; import org.elasticsearch.index.IndexVersion; import java.nio.ByteBuffer; @@ -46,12 +47,7 @@ public static float decodeMagnitude(IndexVersion indexVersion, BytesRef vectorBR * Calculates vector magnitude */ private static float calculateMagnitude(float[] decodedVector) { - double magnitude = 0.0f; - for (int i = 0; i < decodedVector.length; i++) { - magnitude += decodedVector[i] * decodedVector[i]; - } - magnitude = Math.sqrt(magnitude); - return (float) magnitude; + return (float) Math.sqrt(VectorUtil.dotProduct(decodedVector, decodedVector)); } public static float getMagnitude(IndexVersion indexVersion, BytesRef vectorBR, float[] decodedVector) { diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/DenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/DenseVector.java index 84649d9954b6a..79a4c3fa1b2ee 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/DenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/DenseVector.java @@ -8,6 +8,8 @@ package org.elasticsearch.script.field.vectors; +import org.apache.lucene.util.VectorUtil; + import java.util.List; /** @@ -151,11 +153,7 @@ default double cosineSimilarity(Object queryVector) { int size(); static float getMagnitude(byte[] vector) { - int mag = 0; - for (int elem : vector) { - mag += elem * elem; - } - return (float) Math.sqrt(mag); + return (float) Math.sqrt(VectorUtil.dotProduct(vector, vector)); } static float getMagnitude(byte[] vector, int dims) { @@ -170,11 +168,7 @@ static float getMagnitude(byte[] vector, int dims) { } static float getMagnitude(float[] vector) { - double mag = 0.0f; - for (float elem : vector) { - mag += elem * elem; - } - return (float) Math.sqrt(mag); + return (float) Math.sqrt(VectorUtil.dotProduct(vector, vector)); } static float getMagnitude(List vector) {