diff --git a/lucene/core/src/java/org/apache/lucene/index/MultiVectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/MultiVectorSimilarityFunction.java new file mode 100644 index 000000000000..8ed0e8bc500b --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/index/MultiVectorSimilarityFunction.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.index; + +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.util.ArrayUtil; + +/** + * Computes similarity between two multi-vectors. + * + *
A multi-vector is a collection of multiple vectors that represent a single document or query.
+ * MultiVectorSimilarityFunction is used to determine nearest neighbors during indexing and search
+ * on multi-vectors.
+ */
+public class MultiVectorSimilarityFunction {
+
+ /** Aggregation function to combine similarity across multiple vector values */
+ public enum Aggregation {
+
+ /**
+ * Sum_Max Similarity between two multi-vectors. Computes the sum of maximum similarity found
+ * for each vector in the first multi-vector against all vectors in the second multi-vector.
+ */
+ SUM_MAX {
+ @Override
+ public float aggregate(
+ float[] outer,
+ float[] inner,
+ VectorSimilarityFunction vectorSimilarityFunction,
+ int dimension) {
+ if (outer.length % dimension != 0 || inner.length % dimension != 0) {
+ throw new IllegalArgumentException("Multi vectors do not match provided dimension value");
+ }
+
+ // TODO: can we avoid making vector copies?
+ List Assumes all vector values in both provided multi-vectors have the same dimension. Slices
+ * inner and outer float[] multi-vectors into dimension sized vector values for comparison.
+ *
+ * @param outer first multi-vector
+ * @param inner second multi-vector
+ * @param vectorSimilarityFunction distance function for vector proximity
+ * @param dimension dimension for each vector in the provided multi-vectors
+ * @return similarity between the two multi-vectors
+ */
+ public abstract float aggregate(
+ float[] outer,
+ float[] inner,
+ VectorSimilarityFunction vectorSimilarityFunction,
+ int dimension);
+
+ /**
+ * Computes and aggregates similarity over the vector values of two byte multi-vectors.
+ *
+ * <p>Assumes all vector values in both provided multi-vectors have the same dimension. Slices
+ * inner and outer byte[] multi-vectors into {@code dimension}-sized vector values to compare.
+ *
+ * @param outer first multi-vector, packed as a single byte[]
+ * @param inner second multi-vector, packed as a single byte[]
+ * @param vectorSimilarityFunction similarity function applied to individual vector values
+ * @param dimension dimension for each vector in the provided multi-vectors
+ * @return similarity between the two multi-vectors
+ */
+ public abstract float aggregate(
+ byte[] outer,
+ byte[] inner,
+ VectorSimilarityFunction vectorSimilarityFunction,
+ int dimension);
+ }
+
+ /** Per-vector similarity handed to {@link #aggregation} by both {@code compare} overloads */
+ public final VectorSimilarityFunction similarityFunction;
+
+ /** Aggregation that combines per-vector similarities into one multi-vector score */
+ public final Aggregation aggregation;
+
+  /**
+   * Creates a multi-vector similarity from a per-vector similarity and an aggregation.
+   *
+   * @param similarityFunction {@link VectorSimilarityFunction} applied to pairs of vector values
+   * @param aggregation {@link Aggregation} that combines those similarities into a single score
+   */
+  public MultiVectorSimilarityFunction(
+      VectorSimilarityFunction similarityFunction, Aggregation aggregation) {
+    this.aggregation = aggregation;
+    this.similarityFunction = similarityFunction;
+  }
+
+  /**
+   * Scores two float multi-vectors against each other.
+   *
+   * <p>Each multi-vector is passed as one packed float[] holding all of its component vector
+   * values; the result comes from the configured aggregation and vector similarity.
+   */
+  public float compare(float[] t1, float[] t2, int dimension) {
+    return this.aggregation.aggregate(t1, t2, this.similarityFunction, dimension);
+  }
+
+  /**
+   * Compute similarity between two byte multi-vectors.
+   *
+   * <p>Expects all component vector values as a single packed byte[] for each multi-vector, every
+   * component spanning {@code dimension} bytes. Uses configured aggregation and vector similarity.
+   */
+  public float compare(byte[] t1, byte[] t2, int dimension) {
+    return aggregation.aggregate(t1, t2, similarityFunction, dimension);
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (obj instanceof MultiVectorSimilarityFunction) {
+      MultiVectorSimilarityFunction other = (MultiVectorSimilarityFunction) obj;
+      return aggregation == other.aggregation && similarityFunction == other.similarityFunction;
+    }
+    return false;
+  }
+
+  @Override
+  public int hashCode() {
+    // Hash on the ordinals of both components, folded with the conventional multiplier 31;
+    // equal (similarityFunction, aggregation) pairs therefore always hash alike.
+    return 31 * similarityFunction.ordinal() + aggregation.ordinal();
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder sb = new StringBuilder("MultiVectorSimilarityFunction(similarity=");
+    sb.append(similarityFunction);
+    sb.append(", aggregation=");
+    sb.append(aggregation);
+    return sb.append(")").toString();
+  }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestMultiVectorSimilarityFunction.java b/lucene/core/src/test/org/apache/lucene/index/TestMultiVectorSimilarityFunction.java
new file mode 100644
index 000000000000..507c3dca26a4
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/index/TestMultiVectorSimilarityFunction.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.VectorUtil;
+import org.junit.Test;
+
+/** Tests for {@link MultiVectorSimilarityFunction}. */
+public class TestMultiVectorSimilarityFunction extends LuceneTestCase {
+
+  /**
+   * Checks SUM_MAX over DOT_PRODUCT against a reference computed directly from the component
+   * vectors: for each vector of the first multi-vector take its best similarity against every
+   * vector of the second multi-vector, then sum those maxima.
+   */
+  @Test
+  public void testSumMaxWithDotProduct() {
+    final int dimension = 3;
+    final VectorSimilarityFunction vectorSim = VectorSimilarityFunction.DOT_PRODUCT;
+
+    // Every component vector is l2-normalized before it is scored.
+    float[][] a =
+        new float[][] {
+          VectorUtil.l2normalize(new float[] {1.f, 4.f, -3.f}),
+          VectorUtil.l2normalize(new float[] {8.f, 3.f, -7.f})
+        };
+    float[][] b =
+        new float[][] {
+          VectorUtil.l2normalize(new float[] {-5.f, 2.f, 4.f}),
+          VectorUtil.l2normalize(new float[] {7.f, 1.f, -3.f}),
+          VectorUtil.l2normalize(new float[] {-5.f, 8.f, 3.f})
+        };
+
+    // Reference score, computed pairwise on the unpacked vectors.
+    float expected = 0f;
+    for (float[] outer : a) {
+      // Float.MIN_VALUE is the smallest positive float, not the most negative one, so it is not
+      // a safe seed for a running maximum; negative infinity is replaced by any similarity.
+      float max = Float.NEGATIVE_INFINITY;
+      for (float[] inner : b) {
+        max = Float.max(max, vectorSim.compare(outer, inner));
+      }
+      expected += max;
+    }
+
+    // The function under test takes each multi-vector as a single packed float[].
+    MultiVectorSimilarityFunction mvSim =
+        new MultiVectorSimilarityFunction(
+            VectorSimilarityFunction.DOT_PRODUCT,
+            MultiVectorSimilarityFunction.Aggregation.SUM_MAX);
+    float score = mvSim.compare(pack(a, dimension), pack(b, dimension), dimension);
+    assertEquals(expected, score, 0.0001f);
+  }
+
+  /** A multi-vector whose length is not a multiple of the dimension must be rejected. */
+  @Test
+  public void testDimensionCheck() {
+    float[] a = {1f, 2f, 3f, 4f, 5f, 6f};
+    float[] b = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f};
+    MultiVectorSimilarityFunction mvSim =
+        new MultiVectorSimilarityFunction(
+            VectorSimilarityFunction.DOT_PRODUCT,
+            MultiVectorSimilarityFunction.Aggregation.SUM_MAX);
+    assertThrows(IllegalArgumentException.class, () -> mvSim.compare(a, b, 2));
+  }
+
+  /**
+   * Concatenates vectors, in order, into a single packed array.
+   *
+   * @param vectors component vectors, each holding at least {@code dimension} values
+   * @param dimension number of values copied from each vector
+   * @return packed array of length {@code vectors.length * dimension}
+   */
+  private static float[] pack(float[][] vectors, int dimension) {
+    float[] packed = new float[vectors.length * dimension];
+    int offset = 0;
+    for (float[] v : vectors) {
+      System.arraycopy(v, 0, packed, offset, dimension);
+      offset += dimension;
+    }
+    return packed;
+  }
+}