diff --git a/lucene/core/src/java/org/apache/lucene/index/MultiVectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/MultiVectorSimilarityFunction.java new file mode 100644 index 000000000000..8ed0e8bc500b --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/index/MultiVectorSimilarityFunction.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.index; + +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.util.ArrayUtil; + +/** + * Computes similarity between two multi-vectors. + * + *

A multi-vector is a collection of multiple vectors that represent a single document or query. + * MultiVectorSimilarityFunction is used to determine nearest neighbors during indexing and search + * on multi-vectors. + */ +public class MultiVectorSimilarityFunction { + + /** Aggregation function to combine similarity across multiple vector values */ + public enum Aggregation { + + /** + * Sum_Max Similarity between two multi-vectors. Computes the sum of maximum similarity found + * for each vector in the first multi-vector against all vectors in the second multi-vector. + */ + SUM_MAX { + @Override + public float aggregate( + float[] outer, + float[] inner, + VectorSimilarityFunction vectorSimilarityFunction, + int dimension) { + if (outer.length % dimension != 0 || inner.length % dimension != 0) { + throw new IllegalArgumentException("Multi vectors do not match provided dimension value"); + } + + // TODO: can we avoid making vector copies? + List outerList = new ArrayList<>(); + List innerList = new ArrayList<>(); + for (int i = 0; i < outer.length; i += dimension) { + outerList.add(ArrayUtil.copyOfSubArray(outer, i, i + dimension)); + } + for (int i = 0; i < inner.length; i += dimension) { + innerList.add(ArrayUtil.copyOfSubArray(inner, i, i + dimension)); + } + + float result = 0f; + for (float[] o : outerList) { + float maxSim = Float.MIN_VALUE; + for (float[] i : innerList) { + maxSim = Float.max(maxSim, vectorSimilarityFunction.compare(o, i)); + } + result += maxSim; + } + return result; + } + + @Override + public float aggregate( + byte[] outer, + byte[] inner, + VectorSimilarityFunction vectorSimilarityFunction, + int dimension) { + if (outer.length % dimension != 0 || inner.length % dimension != 0) { + throw new IllegalArgumentException("Multi vectors do not match provided dimension value"); + } + + // TODO: can we avoid making vector copies? + List outerList = new ArrayList<>(); + List innerList = new ArrayList<>(); + for (int i = 0; i < outer.length; i += dimension) { + outerList.add(ArrayUtil.copyOfSubArray(outer, i, i + dimension)); + } + for (int i = 0; i < inner.length; i += dimension) { + innerList.add(ArrayUtil.copyOfSubArray(inner, i, i + dimension)); + } + + float result = 0f; + for (byte[] o : outerList) { + float maxSim = Float.MIN_VALUE; + for (byte[] i : innerList) { + maxSim = Float.max(maxSim, vectorSimilarityFunction.compare(o, i)); + } + result += maxSim; + } + return result; + } + }; + + /** + * Computes and aggregates similarity over multiple vector values. + * + *

Assumes all vector values in both provided multi-vectors have the same dimension. Slices + * inner and outer float[] multi-vectors into dimension sized vector values for comparison. + * + * @param outer first multi-vector + * @param inner second multi-vector + * @param vectorSimilarityFunction distance function for vector proximity + * @param dimension dimension for each vector in the provided multi-vectors + * @return similarity between the two multi-vectors + */ + public abstract float aggregate( + float[] outer, + float[] inner, + VectorSimilarityFunction vectorSimilarityFunction, + int dimension); + + /** + * Computes and aggregates similarity over multiple vector values. + * + *

Assumes all vector values in both provided multi-vectors have the same dimension. Slices + * inner and outer byte[] multi-vectors into dimension sized vector values for comparison. + * + * @param outer first multi-vector + * @param inner second multi-vector + * @param vectorSimilarityFunction distance function for vector proximity + * @param dimension dimension for each vector in the provided multi-vectors + * @return similarity between the two multi-vectors + */ + public abstract float aggregate( + byte[] outer, + byte[] inner, + VectorSimilarityFunction vectorSimilarityFunction, + int dimension); + } + + /** Similarity function used for multi-vector distance calculations */ + public final VectorSimilarityFunction similarityFunction; + + /** Aggregation function to combine similarity across multiple vector values */ + public final Aggregation aggregation; + + /** + * Similarity function for computing distance between multi-vector values + * + * @param similarityFunction {@link VectorSimilarityFunction} for computing vector proximity + * @param aggregation {@link Aggregation} to combine similarity across multiple vector values + */ + public MultiVectorSimilarityFunction( + VectorSimilarityFunction similarityFunction, Aggregation aggregation) { + this.similarityFunction = similarityFunction; + this.aggregation = aggregation; + } + + /** + * Compute similarity between two float multi-vectors. + * + *

Expects all component vector values as a single packed float[] for each multi-vector. Uses + * configured aggregation function and vector similarity. + */ + public float compare(float[] t1, float[] t2, int dimension) { + return aggregation.aggregate(t1, t2, similarityFunction, dimension); + } + + /** + * Compute similarity between two byte multi-vectors. + * + *

Expects all component vector values as a single packed float[] for each multi-vector. Uses + * configured aggregation function and vector similarity. + */ + public float compare(byte[] t1, byte[] t2, int dimension) { + return aggregation.aggregate(t1, t2, similarityFunction, dimension); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof MultiVectorSimilarityFunction == false) { + return false; + } + MultiVectorSimilarityFunction o = (MultiVectorSimilarityFunction) obj; + return this.similarityFunction == o.similarityFunction && this.aggregation == o.aggregation; + } + + @Override + public int hashCode() { + int result = Integer.hashCode(similarityFunction.ordinal()); + result = 31 * result + Integer.hashCode(aggregation.ordinal()); + return result; + } + + @Override + public String toString() { + return "MultiVectorSimilarityFunction(similarity=" + + similarityFunction + + ", aggregation=" + + aggregation + + ")"; + } +} diff --git a/lucene/core/src/test/org/apache/lucene/index/TestMultiVectorSimilarityFunction.java b/lucene/core/src/test/org/apache/lucene/index/TestMultiVectorSimilarityFunction.java new file mode 100644 index 000000000000..507c3dca26a4 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/index/TestMultiVectorSimilarityFunction.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.index; + +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.VectorUtil; +import org.junit.Test; + +public class TestMultiVectorSimilarityFunction extends LuceneTestCase { + + @Test + public void testSumMaxWithDotProduct() { + final int dimension = 3; + final VectorSimilarityFunction vectorSim = VectorSimilarityFunction.DOT_PRODUCT; + + float[][] a = + new float[][] { + VectorUtil.l2normalize(new float[] {1.f, 4.f, -3.f}), + VectorUtil.l2normalize(new float[] {8.f, 3.f, -7.f}) + }; + float[][] b = + new float[][] { + VectorUtil.l2normalize(new float[] {-5.f, 2.f, 4.f}), + VectorUtil.l2normalize(new float[] {7.f, 1.f, -3.f}), + VectorUtil.l2normalize(new float[] {-5.f, 8.f, 3.f}) + }; + + float result = 0f; + float[] a0_bDot = + new float[] { + vectorSim.compare(a[0], b[0]), + vectorSim.compare(a[0], b[1]), + vectorSim.compare(a[0], b[2]) + }; + float max = Float.MIN_VALUE; + for (float k : a0_bDot) { + max = Float.max(max, k); + } + result += max; + + float[] a1_bDot = + new float[] { + vectorSim.compare(a[1], b[0]), + vectorSim.compare(a[1], b[1]), + vectorSim.compare(a[1], b[2]) + }; + max = Float.MIN_VALUE; + for (float k : a1_bDot) { + max = Float.max(max, k); + } + result += max; + + float[] a_Packed = new float[a.length * dimension]; + int i = 0; + for (float[] v : a) { + System.arraycopy(v, 0, a_Packed, i, dimension); + i += dimension; + } + float[] b_Packed = new float[b.length * dimension]; + i = 0; + for (float[] v : b) { + System.arraycopy(v, 0, b_Packed, i, dimension); + i += dimension; + } + + MultiVectorSimilarityFunction mvSim = + new MultiVectorSimilarityFunction( + VectorSimilarityFunction.DOT_PRODUCT, + MultiVectorSimilarityFunction.Aggregation.SUM_MAX); + float score = mvSim.compare(a_Packed, b_Packed, dimension); + assertEquals(result, score, 0.0001f); + } + + @Test + public void testDimensionCheck() { + float[] a = {1f, 2f, 3f, 4f, 5f, 6f}; + float[] b = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}; + MultiVectorSimilarityFunction mvSim = + new MultiVectorSimilarityFunction( + VectorSimilarityFunction.DOT_PRODUCT, + MultiVectorSimilarityFunction.Aggregation.SUM_MAX); + assertThrows(IllegalArgumentException.class, () -> mvSim.compare(a, b, 2)); + } +}