-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a Multi-Vector Similarity Function #13991
Open
vigyasharma
wants to merge
9
commits into
apache:main
Choose a base branch
from
vigyasharma:mv_similarity
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
f1c7352
initial files
vigyasharma 2554e05
remove the mv sim interface
vigyasharma be3f2e9
remove NONE aggregate value
vigyasharma cf97155
first mvsf test
vigyasharma 9533336
add test for mvSim
vigyasharma 5fb88b9
tidy
vigyasharma 4f6070b
lint fail for @override
vigyasharma 51396e8
add license
vigyasharma 0d6f6f1
javadocs
vigyasharma File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
202 changes: 202 additions & 0 deletions
202
lucene/core/src/java/org/apache/lucene/index/MultiVectorSimilarityFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.index; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import org.apache.lucene.util.ArrayUtil; | ||
|
||
/** | ||
* Computes similarity between two multi-vectors. | ||
* | ||
* <p>A multi-vector is a collection of multiple vectors that represent a single document or query. | ||
* MultiVectorSimilarityFunction is used to determine nearest neighbors during indexing and search | ||
* on multi-vectors. | ||
*/ | ||
public class MultiVectorSimilarityFunction { | ||
|
||
/** Aggregation function to combine similarity across multiple vector values */ | ||
public enum Aggregation { | ||
|
||
/** | ||
* Sum_Max Similarity between two multi-vectors. Computes the sum of maximum similarity found | ||
* for each vector in the first multi-vector against all vectors in the second multi-vector. | ||
*/ | ||
SUM_MAX { | ||
@Override | ||
public float aggregate( | ||
float[] outer, | ||
float[] inner, | ||
VectorSimilarityFunction vectorSimilarityFunction, | ||
int dimension) { | ||
if (outer.length % dimension != 0 || inner.length % dimension != 0) { | ||
throw new IllegalArgumentException("Multi vectors do not match provided dimension value"); | ||
} | ||
|
||
// TODO: can we avoid making vector copies? | ||
List<float[]> outerList = new ArrayList<>(); | ||
List<float[]> innerList = new ArrayList<>(); | ||
for (int i = 0; i < outer.length; i += dimension) { | ||
outerList.add(ArrayUtil.copyOfSubArray(outer, i, i + dimension)); | ||
} | ||
for (int i = 0; i < inner.length; i += dimension) { | ||
innerList.add(ArrayUtil.copyOfSubArray(inner, i, i + dimension)); | ||
} | ||
|
||
float result = 0f; | ||
for (float[] o : outerList) { | ||
float maxSim = Float.MIN_VALUE; | ||
for (float[] i : innerList) { | ||
maxSim = Float.max(maxSim, vectorSimilarityFunction.compare(o, i)); | ||
} | ||
result += maxSim; | ||
} | ||
return result; | ||
} | ||
|
||
@Override | ||
public float aggregate( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's so unfortunate that Java doesn't support generic primitive array and has to have duplicate code. |
||
byte[] outer, | ||
byte[] inner, | ||
VectorSimilarityFunction vectorSimilarityFunction, | ||
int dimension) { | ||
if (outer.length % dimension != 0 || inner.length % dimension != 0) { | ||
throw new IllegalArgumentException("Multi vectors do not match provided dimension value"); | ||
} | ||
|
||
// TODO: can we avoid making vector copies? | ||
List<byte[]> outerList = new ArrayList<>(); | ||
List<byte[]> innerList = new ArrayList<>(); | ||
for (int i = 0; i < outer.length; i += dimension) { | ||
outerList.add(ArrayUtil.copyOfSubArray(outer, i, i + dimension)); | ||
} | ||
for (int i = 0; i < inner.length; i += dimension) { | ||
innerList.add(ArrayUtil.copyOfSubArray(inner, i, i + dimension)); | ||
} | ||
|
||
float result = 0f; | ||
for (byte[] o : outerList) { | ||
float maxSim = Float.MIN_VALUE; | ||
for (byte[] i : innerList) { | ||
maxSim = Float.max(maxSim, vectorSimilarityFunction.compare(o, i)); | ||
} | ||
result += maxSim; | ||
} | ||
return result; | ||
} | ||
}; | ||
|
||
/** | ||
* Computes and aggregates similarity over multiple vector values. | ||
* | ||
* <p>Assumes all vector values in both provided multi-vectors have the same dimension. Slices | ||
* inner and outer float[] multi-vectors into dimension sized vector values for comparison. | ||
* | ||
* @param outer first multi-vector | ||
* @param inner second multi-vector | ||
* @param vectorSimilarityFunction distance function for vector proximity | ||
* @param dimension dimension for each vector in the provided multi-vectors | ||
* @return similarity between the two multi-vectors | ||
*/ | ||
public abstract float aggregate( | ||
float[] outer, | ||
float[] inner, | ||
VectorSimilarityFunction vectorSimilarityFunction, | ||
int dimension); | ||
|
||
/** | ||
* Computes and aggregates similarity over multiple vector values. | ||
* | ||
* <p>Assumes all vector values in both provided multi-vectors have the same dimension. Slices | ||
* inner and outer byte[] multi-vectors into dimension sized vector values for comparison. | ||
* | ||
* @param outer first multi-vector | ||
* @param inner second multi-vector | ||
* @param vectorSimilarityFunction distance function for vector proximity | ||
* @param dimension dimension for each vector in the provided multi-vectors | ||
* @return similarity between the two multi-vectors | ||
*/ | ||
public abstract float aggregate( | ||
byte[] outer, | ||
byte[] inner, | ||
VectorSimilarityFunction vectorSimilarityFunction, | ||
int dimension); | ||
} | ||
|
||
/** Similarity function used for multi-vector distance calculations */ | ||
public final VectorSimilarityFunction similarityFunction; | ||
|
||
/** Aggregation function to combine similarity across multiple vector values */ | ||
public final Aggregation aggregation; | ||
|
||
/** | ||
* Similarity function for computing distance between multi-vector values | ||
* | ||
* @param similarityFunction {@link VectorSimilarityFunction} for computing vector proximity | ||
* @param aggregation {@link Aggregation} to combine similarity across multiple vector values | ||
*/ | ||
public MultiVectorSimilarityFunction( | ||
VectorSimilarityFunction similarityFunction, Aggregation aggregation) { | ||
this.similarityFunction = similarityFunction; | ||
this.aggregation = aggregation; | ||
} | ||
|
||
/** | ||
* Compute similarity between two float multi-vectors. | ||
* | ||
* <p>Expects all component vector values as a single packed float[] for each multi-vector. Uses | ||
* configured aggregation function and vector similarity. | ||
*/ | ||
public float compare(float[] t1, float[] t2, int dimension) { | ||
return aggregation.aggregate(t1, t2, similarityFunction, dimension); | ||
} | ||
|
||
/** | ||
* Compute similarity between two byte multi-vectors. | ||
* | ||
* <p>Expects all component vector values as a single packed float[] for each multi-vector. Uses | ||
* configured aggregation function and vector similarity. | ||
*/ | ||
public float compare(byte[] t1, byte[] t2, int dimension) { | ||
return aggregation.aggregate(t1, t2, similarityFunction, dimension); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object obj) { | ||
if (obj instanceof MultiVectorSimilarityFunction == false) { | ||
return false; | ||
} | ||
MultiVectorSimilarityFunction o = (MultiVectorSimilarityFunction) obj; | ||
return this.similarityFunction == o.similarityFunction && this.aggregation == o.aggregation; | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
int result = Integer.hashCode(similarityFunction.ordinal()); | ||
result = 31 * result + Integer.hashCode(aggregation.ordinal()); | ||
return result; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return "MultiVectorSimilarityFunction(similarity=" | ||
+ similarityFunction | ||
+ ", aggregation=" | ||
+ aggregation | ||
+ ")"; | ||
} | ||
} |
98 changes: 98 additions & 0 deletions
98
lucene/core/src/test/org/apache/lucene/index/TestMultiVectorSimilarityFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.index; | ||
|
||
import org.apache.lucene.tests.util.LuceneTestCase; | ||
import org.apache.lucene.util.VectorUtil; | ||
import org.junit.Test; | ||
|
||
public class TestMultiVectorSimilarityFunction extends LuceneTestCase { | ||
|
||
@Test | ||
public void testSumMaxWithDotProduct() { | ||
final int dimension = 3; | ||
final VectorSimilarityFunction vectorSim = VectorSimilarityFunction.DOT_PRODUCT; | ||
|
||
float[][] a = | ||
new float[][] { | ||
VectorUtil.l2normalize(new float[] {1.f, 4.f, -3.f}), | ||
VectorUtil.l2normalize(new float[] {8.f, 3.f, -7.f}) | ||
}; | ||
float[][] b = | ||
new float[][] { | ||
VectorUtil.l2normalize(new float[] {-5.f, 2.f, 4.f}), | ||
VectorUtil.l2normalize(new float[] {7.f, 1.f, -3.f}), | ||
VectorUtil.l2normalize(new float[] {-5.f, 8.f, 3.f}) | ||
}; | ||
|
||
float result = 0f; | ||
float[] a0_bDot = | ||
new float[] { | ||
vectorSim.compare(a[0], b[0]), | ||
vectorSim.compare(a[0], b[1]), | ||
vectorSim.compare(a[0], b[2]) | ||
}; | ||
float max = Float.MIN_VALUE; | ||
for (float k : a0_bDot) { | ||
max = Float.max(max, k); | ||
} | ||
result += max; | ||
|
||
float[] a1_bDot = | ||
new float[] { | ||
vectorSim.compare(a[1], b[0]), | ||
vectorSim.compare(a[1], b[1]), | ||
vectorSim.compare(a[1], b[2]) | ||
}; | ||
max = Float.MIN_VALUE; | ||
for (float k : a1_bDot) { | ||
max = Float.max(max, k); | ||
} | ||
result += max; | ||
|
||
float[] a_Packed = new float[a.length * dimension]; | ||
int i = 0; | ||
for (float[] v : a) { | ||
System.arraycopy(v, 0, a_Packed, i, dimension); | ||
i += dimension; | ||
} | ||
float[] b_Packed = new float[b.length * dimension]; | ||
i = 0; | ||
for (float[] v : b) { | ||
System.arraycopy(v, 0, b_Packed, i, dimension); | ||
i += dimension; | ||
} | ||
|
||
MultiVectorSimilarityFunction mvSim = | ||
new MultiVectorSimilarityFunction( | ||
VectorSimilarityFunction.DOT_PRODUCT, | ||
MultiVectorSimilarityFunction.Aggregation.SUM_MAX); | ||
float score = mvSim.compare(a_Packed, b_Packed, dimension); | ||
assertEquals(result, score, 0.0001f); | ||
} | ||
|
||
@Test | ||
public void testDimensionCheck() { | ||
float[] a = {1f, 2f, 3f, 4f, 5f, 6f}; | ||
float[] b = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}; | ||
MultiVectorSimilarityFunction mvSim = | ||
new MultiVectorSimilarityFunction( | ||
VectorSimilarityFunction.DOT_PRODUCT, | ||
MultiVectorSimilarityFunction.Aggregation.SUM_MAX); | ||
assertThrows(IllegalArgumentException.class, () -> mvSim.compare(a, b, 2)); | ||
} | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we add another
compare
method with start and end indexes for both inner and outer, I guess we won't need to copy the array?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, but it needs to go all the way down to
VectorUtilSupport
, which I think should be a PR of its own.