Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a Multi-Vector Similarity Function #13991

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;

import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.util.ArrayUtil;

/**
* Computes similarity between two multi-vectors.
*
* <p>A multi-vector is a collection of multiple vectors that represent a single document or query.
* MultiVectorSimilarityFunction is used to determine nearest neighbors during indexing and search
* on multi-vectors.
*/
public class MultiVectorSimilarityFunction {

/** Aggregation function to combine similarity across multiple vector values */
public enum Aggregation {

/**
* Sum_Max Similarity between two multi-vectors. Computes the sum of maximum similarity found
* for each vector in the first multi-vector against all vectors in the second multi-vector.
*/
SUM_MAX {
@Override
public float aggregate(
float[] outer,
float[] inner,
VectorSimilarityFunction vectorSimilarityFunction,
int dimension) {
if (outer.length % dimension != 0 || inner.length % dimension != 0) {
throw new IllegalArgumentException("Multi vectors do not match provided dimension value");
}

// TODO: can we avoid making vector copies?
List<float[]> outerList = new ArrayList<>();
List<float[]> innerList = new ArrayList<>();
for (int i = 0; i < outer.length; i += dimension) {
outerList.add(ArrayUtil.copyOfSubArray(outer, i, i + dimension));
}
for (int i = 0; i < inner.length; i += dimension) {
innerList.add(ArrayUtil.copyOfSubArray(inner, i, i + dimension));
}

float result = 0f;
for (float[] o : outerList) {
float maxSim = Float.MIN_VALUE;
for (float[] i : innerList) {
maxSim = Float.max(maxSim, vectorSimilarityFunction.compare(o, i));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add another compare method with start and end indexes for both inner and outer, I guess we won't need to copy the array?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but it needs to go all the way down to VectorUtilSupport, which I think should be a PR of its own.

}
result += maxSim;
}
return result;
}

@Override
public float aggregate(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's so unfortunate that Java doesn't support generic primitive array and has to have duplicate code.

byte[] outer,
byte[] inner,
VectorSimilarityFunction vectorSimilarityFunction,
int dimension) {
if (outer.length % dimension != 0 || inner.length % dimension != 0) {
throw new IllegalArgumentException("Multi vectors do not match provided dimension value");
}

// TODO: can we avoid making vector copies?
List<byte[]> outerList = new ArrayList<>();
List<byte[]> innerList = new ArrayList<>();
for (int i = 0; i < outer.length; i += dimension) {
outerList.add(ArrayUtil.copyOfSubArray(outer, i, i + dimension));
}
for (int i = 0; i < inner.length; i += dimension) {
innerList.add(ArrayUtil.copyOfSubArray(inner, i, i + dimension));
}

float result = 0f;
for (byte[] o : outerList) {
float maxSim = Float.MIN_VALUE;
for (byte[] i : innerList) {
maxSim = Float.max(maxSim, vectorSimilarityFunction.compare(o, i));
}
result += maxSim;
}
return result;
}
};

/**
* Computes and aggregates similarity over multiple vector values.
*
* <p>Assumes all vector values in both provided multi-vectors have the same dimension. Slices
* inner and outer float[] multi-vectors into dimension sized vector values for comparison.
*
* @param outer first multi-vector
* @param inner second multi-vector
* @param vectorSimilarityFunction distance function for vector proximity
* @param dimension dimension for each vector in the provided multi-vectors
* @return similarity between the two multi-vectors
*/
public abstract float aggregate(
float[] outer,
float[] inner,
VectorSimilarityFunction vectorSimilarityFunction,
int dimension);

/**
* Computes and aggregates similarity over multiple vector values.
*
* <p>Assumes all vector values in both provided multi-vectors have the same dimension. Slices
* inner and outer byte[] multi-vectors into dimension sized vector values for comparison.
*
* @param outer first multi-vector
* @param inner second multi-vector
* @param vectorSimilarityFunction distance function for vector proximity
* @param dimension dimension for each vector in the provided multi-vectors
* @return similarity between the two multi-vectors
*/
public abstract float aggregate(
byte[] outer,
byte[] inner,
VectorSimilarityFunction vectorSimilarityFunction,
int dimension);
}

/** Similarity function used for multi-vector distance calculations */
public final VectorSimilarityFunction similarityFunction;

/** Aggregation function to combine similarity across multiple vector values */
public final Aggregation aggregation;

/**
* Similarity function for computing distance between multi-vector values
*
* @param similarityFunction {@link VectorSimilarityFunction} for computing vector proximity
* @param aggregation {@link Aggregation} to combine similarity across multiple vector values
*/
public MultiVectorSimilarityFunction(
VectorSimilarityFunction similarityFunction, Aggregation aggregation) {
this.similarityFunction = similarityFunction;
this.aggregation = aggregation;
}

/**
* Compute similarity between two float multi-vectors.
*
* <p>Expects all component vector values as a single packed float[] for each multi-vector. Uses
* configured aggregation function and vector similarity.
*/
public float compare(float[] t1, float[] t2, int dimension) {
return aggregation.aggregate(t1, t2, similarityFunction, dimension);
}

/**
* Compute similarity between two byte multi-vectors.
*
* <p>Expects all component vector values as a single packed float[] for each multi-vector. Uses
* configured aggregation function and vector similarity.
*/
public float compare(byte[] t1, byte[] t2, int dimension) {
return aggregation.aggregate(t1, t2, similarityFunction, dimension);
}

@Override
public boolean equals(Object obj) {
if (obj instanceof MultiVectorSimilarityFunction == false) {
return false;
}
MultiVectorSimilarityFunction o = (MultiVectorSimilarityFunction) obj;
return this.similarityFunction == o.similarityFunction && this.aggregation == o.aggregation;
}

@Override
public int hashCode() {
int result = Integer.hashCode(similarityFunction.ordinal());
result = 31 * result + Integer.hashCode(aggregation.ordinal());
return result;
}

@Override
public String toString() {
return "MultiVectorSimilarityFunction(similarity="
+ similarityFunction
+ ", aggregation="
+ aggregation
+ ")";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;

import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
import org.junit.Test;

/** Tests for {@link MultiVectorSimilarityFunction}. */
public class TestMultiVectorSimilarityFunction extends LuceneTestCase {

  /**
   * Checks SUM_MAX over DOT_PRODUCT against a reference computed directly from the per-vector
   * similarity function: for each vector in {@code a}, take the best match in {@code b}, then sum.
   */
  @Test
  public void testSumMaxWithDotProduct() {
    final int dimension = 3;
    final VectorSimilarityFunction vectorSim = VectorSimilarityFunction.DOT_PRODUCT;

    float[][] a =
        new float[][] {
          VectorUtil.l2normalize(new float[] {1.f, 4.f, -3.f}),
          VectorUtil.l2normalize(new float[] {8.f, 3.f, -7.f})
        };
    float[][] b =
        new float[][] {
          VectorUtil.l2normalize(new float[] {-5.f, 2.f, 4.f}),
          VectorUtil.l2normalize(new float[] {7.f, 1.f, -3.f}),
          VectorUtil.l2normalize(new float[] {-5.f, 8.f, 3.f})
        };

    float result = 0f;
    for (float[] outer : a) {
      // Seed with -Infinity: Float.MIN_VALUE is the smallest positive float, so it would mask
      // any maximum that is zero or negative.
      float max = Float.NEGATIVE_INFINITY;
      for (float[] inner : b) {
        max = Float.max(max, vectorSim.compare(outer, inner));
      }
      result += max;
    }

    MultiVectorSimilarityFunction mvSim =
        new MultiVectorSimilarityFunction(
            VectorSimilarityFunction.DOT_PRODUCT,
            MultiVectorSimilarityFunction.Aggregation.SUM_MAX);
    float score = mvSim.compare(pack(a, dimension), pack(b, dimension), dimension);
    assertEquals(result, score, 0.0001f);
  }

  /** A packed length that is not a multiple of the dimension must be rejected. */
  @Test
  public void testDimensionCheck() {
    float[] a = {1f, 2f, 3f, 4f, 5f, 6f};
    float[] b = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f};
    MultiVectorSimilarityFunction mvSim =
        new MultiVectorSimilarityFunction(
            VectorSimilarityFunction.DOT_PRODUCT,
            MultiVectorSimilarityFunction.Aggregation.SUM_MAX);
    assertThrows(IllegalArgumentException.class, () -> mvSim.compare(a, b, 2));
  }

  /** Concatenates the given vectors, each of length {@code dimension}, into one packed array. */
  private static float[] pack(float[][] vectors, int dimension) {
    float[] packed = new float[vectors.length * dimension];
    int offset = 0;
    for (float[] v : vectors) {
      System.arraycopy(v, 0, packed, offset, dimension);
      offset += dimension;
    }
    return packed;
  }
}