Skip to content

Commit

Permalink
Utilize optimized dot_product where possible when calculating vector …
Browse files Browse the repository at this point in the history
…magnitude (elastic#99448)

Lucene provides an optimized `dot_product` calculation for vectors. We
should use that when calculating a vector's magnitude.
  • Loading branch information
benwtrent authored Sep 12, 2023
1 parent e26dca4 commit 2072be9
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.fielddata.FieldDataContext;
Expand Down Expand Up @@ -859,10 +860,7 @@ public Query createKnnQuery(byte[] queryVector, int numCands, Query filter, Floa
}

if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = 0.0f;
for (byte b : queryVector) {
squaredMagnitude += b * b;
}
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, elementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
Query knnQuery = new KnnByteVectorQuery(name(), queryVector, numCands, filter);
Expand Down Expand Up @@ -891,10 +889,7 @@ public Query createKnnQuery(float[] queryVector, int numCands, Query filter, Flo
elementType.checkVectorBounds(queryVector);

if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = 0.0f;
for (float e : queryVector) {
squaredMagnitude += e * e;
}
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, elementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
}
Query knnQuery = switch (elementType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.elasticsearch.index.mapper.vectors;

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.IndexVersion;

import java.nio.ByteBuffer;
Expand Down Expand Up @@ -46,12 +47,7 @@ public static float decodeMagnitude(IndexVersion indexVersion, BytesRef vectorBR
* Calculates vector magnitude
*/
// Computes the magnitude of decodedVector: the square root of the sum of its squared
// elements, narrowed to float on return.
private static float calculateMagnitude(float[] decodedVector) {
// NOTE(review): this span is a rendered diff hunk with its +/- markers stripped. The loop,
// sqrt and first return below look like the removed implementation, and the final return
// (via VectorUtil.dotProduct) like its replacement -- confirm against the commit before
// treating this as compilable code.
// Sum of squares accumulated in a double.
double magnitude = 0.0f;
for (int i = 0; i < decodedVector.length; i++) {
magnitude += decodedVector[i] * decodedVector[i];
}
magnitude = Math.sqrt(magnitude);
return (float) magnitude;
// Same sum of squares, delegated to Lucene's VectorUtil as the dot product of the vector
// with itself.
// NOTE(review): the loop above accumulates in a double; the accumulator precision here
// depends on VectorUtil.dotProduct's return type (presumably float) -- confirm the
// difference is acceptable for long vectors.
return (float) Math.sqrt(VectorUtil.dotProduct(decodedVector, decodedVector));
}

public static float getMagnitude(IndexVersion indexVersion, BytesRef vectorBR, float[] decodedVector) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

package org.elasticsearch.script.field.vectors;

import org.apache.lucene.util.VectorUtil;

import java.util.List;

/**
Expand Down Expand Up @@ -151,11 +153,7 @@ default double cosineSimilarity(Object queryVector) {
int size();

/**
 * Returns the magnitude of {@code vector}: the square root of the sum of its squared
 * elements, narrowed to {@code float}.
 */
static float getMagnitude(byte[] vector) {
// NOTE(review): this span is a rendered diff hunk with its +/- markers stripped. The loop
// and first return below look like the removed implementation, and the final return (via
// VectorUtil.dotProduct) like its replacement -- confirm against the commit before
// treating this as compilable code.
// Each byte is widened to int, squared, and accumulated in an int.
int mag = 0;
for (int elem : vector) {
mag += elem * elem;
}
return (float) Math.sqrt(mag);
// Same sum of squares, delegated to Lucene's VectorUtil as the dot product of the vector
// with itself.
return (float) Math.sqrt(VectorUtil.dotProduct(vector, vector));
}

static float getMagnitude(byte[] vector, int dims) {
Expand All @@ -170,11 +168,7 @@ static float getMagnitude(byte[] vector, int dims) {
}

/**
 * Returns the magnitude of {@code vector}: the square root of the sum of its squared
 * elements, narrowed to {@code float}.
 */
static float getMagnitude(float[] vector) {
// NOTE(review): this span is a rendered diff hunk with its +/- markers stripped. The loop
// and first return below look like the removed implementation, and the final return (via
// VectorUtil.dotProduct) like its replacement -- confirm against the commit before
// treating this as compilable code.
// Each square is computed in float, then accumulated in a double.
double mag = 0.0f;
for (float elem : vector) {
mag += elem * elem;
}
return (float) Math.sqrt(mag);
// Same sum of squares, delegated to Lucene's VectorUtil as the dot product of the vector
// with itself.
// NOTE(review): accumulator precision here depends on VectorUtil.dotProduct's return type
// (presumably float, versus the double used above) -- confirm the difference is acceptable.
return (float) Math.sqrt(VectorUtil.dotProduct(vector, vector));
}

static float getMagnitude(List<Number> vector) {
Expand Down

0 comments on commit 2072be9

Please sign in to comment.