diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt index 875b9a1dac3e8..85dba97a392b4 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt @@ -132,8 +132,8 @@ class org.elasticsearch.script.field.SeqNoDocValuesField @dynamic_type { class org.elasticsearch.script.field.VersionDocValuesField @dynamic_type { } -class org.elasticsearch.script.field.vectors.MultiDenseVector { - MultiDenseVector EMPTY +class org.elasticsearch.script.field.vectors.RankVectors { + RankVectors EMPTY float[] getMagnitudes() Iterator getVectors() @@ -142,9 +142,9 @@ class org.elasticsearch.script.field.vectors.MultiDenseVector { int size() } -class org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField { - MultiDenseVector get() - MultiDenseVector get(MultiDenseVector) +class org.elasticsearch.script.field.vectors.RankVectorsDocValuesField { + RankVectors get() + RankVectors get(RankVectors) } class org.elasticsearch.script.field.vectors.DenseVector { diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt index 5a1d8c002aa17..a5118db4876cb 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt @@ -50,7 +50,7 @@ static_import { double cosineSimilarity(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity double dotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct double hamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$Hamming - double maxSimDotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.MultiVectorScoreScriptUtils$MaxSimDotProduct - double maxSimInvHamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.MultiVectorScoreScriptUtils$MaxSimInvHamming + double maxSimDotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.RankVectorsScoreScriptUtils$MaxSimDotProduct + double maxSimInvHamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.RankVectorsScoreScriptUtils$MaxSimInvHamming } diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt index b2db0d1006d40..4815b9c10e733 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt @@ -123,7 +123,7 @@ class org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues { float getMagnitude() } -class org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues { +class org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues { Iterator getVectorValues() float[] getMagnitudes() } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/RankVectorsDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/RankVectorsDVLeafFieldData.java index ffa4852c44a9f..0125d0249ec2b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/RankVectorsDVLeafFieldData.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/RankVectorsDVLeafFieldData.java @@ -15,9 +15,9 @@ import org.elasticsearch.index.fielddata.LeafFieldData; import org.elasticsearch.index.fielddata.SortedBinaryDocValues; import org.elasticsearch.script.field.DocValuesScriptFieldFactory; -import org.elasticsearch.script.field.vectors.BitMultiDenseVectorDocValuesField; -import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField; -import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.BitRankVectorsDocValuesField; +import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; +import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; import java.io.IOException; @@ -40,9 +40,9 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { BinaryDocValues values = DocValues.getBinary(reader, field); BinaryDocValues magnitudeValues = DocValues.getBinary(reader, field + RankVectorsFieldMapper.VECTOR_MAGNITUDES_SUFFIX); return switch (elementType) { - case BYTE -> new ByteMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims); - case FLOAT -> new FloatMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims); - case BIT -> new BitMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims); + case BYTE -> new ByteRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); + case FLOAT -> new FloatRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); + case BIT -> new BitRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); }; } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for multi-vector field!", e); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/RankVectorsScriptDocValues.java similarity index 69% rename from server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java rename to server/src/main/java/org/elasticsearch/index/mapper/vectors/RankVectorsScriptDocValues.java index 23b8f3f2bb4fb..e663df86c67ca 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/RankVectorsScriptDocValues.java @@ -11,18 +11,18 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.fielddata.ScriptDocValues; -import org.elasticsearch.script.field.vectors.MultiDenseVector; +import org.elasticsearch.script.field.vectors.RankVectors; import java.util.Iterator; -public class MultiDenseVectorScriptDocValues extends ScriptDocValues { +public class RankVectorsScriptDocValues extends ScriptDocValues { - public static final String MISSING_VECTOR_FIELD_MESSAGE = "A document doesn't have a value for a vector field!"; + public static final String MISSING_VECTOR_FIELD_MESSAGE = "A document doesn't have a value for a rank-vectors field!"; private final int dims; - protected final MultiDenseVectorSupplier dvSupplier; + protected final RankVectorsSupplier dvSupplier; - public MultiDenseVectorScriptDocValues(MultiDenseVectorSupplier supplier, int dims) { + public RankVectorsScriptDocValues(RankVectorsSupplier supplier, int dims) { super(supplier); this.dvSupplier = supplier; this.dims = dims; @@ -32,8 +32,8 @@ public int dims() { return dims; } - private MultiDenseVector getCheckedVector() { - MultiDenseVector vector = dvSupplier.getInternal(); + private RankVectors getCheckedVector() { + RankVectors vector = dvSupplier.getInternal(); if (vector == null) { throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); } @@ -57,25 +57,25 @@ public float[] getMagnitudes() { @Override public BytesRef get(int index) { throw new UnsupportedOperationException( - "accessing a vector field's value through 'get' or 'value' is not supported, use 'vectorValues' or 'magnitudes' instead." + "accessing a rank-vectors field's value through 'get' or 'value' is not supported, use 'vectorValues' or 'magnitudes' instead." ); } @Override public int size() { - MultiDenseVector mdv = dvSupplier.getInternal(); + RankVectors mdv = dvSupplier.getInternal(); if (mdv != null) { return mdv.size(); } return 0; } - public interface MultiDenseVectorSupplier extends Supplier { + public interface RankVectorsSupplier extends Supplier { @Override default BytesRef getInternal(int index) { throw new UnsupportedOperationException(); } - MultiDenseVector getInternal(); + RankVectors getInternal(); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java index 3db2d164846bd..54b369ab1f377 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java @@ -94,14 +94,4 @@ public static float[] getMultiMagnitudes(BytesRef magnitudes) { return multiMagnitudes; } - public static void decodeMultiDenseVector(BytesRef vectorBR, int numVectors, float[][] multiVectorValue) { - if (vectorBR == null) { - throw new IllegalArgumentException(MultiDenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE); - } - FloatBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); - for (int i = 0; i < numVectors; i++) { - fb.get(multiVectorValue[i]); - } - } - } diff --git a/server/src/main/java/org/elasticsearch/script/MultiVectorScoreScriptUtils.java b/server/src/main/java/org/elasticsearch/script/RankVectorsScoreScriptUtils.java similarity index 85% rename from server/src/main/java/org/elasticsearch/script/MultiVectorScoreScriptUtils.java rename to server/src/main/java/org/elasticsearch/script/RankVectorsScoreScriptUtils.java index 136c5e7b57d4b..2d11641cb5aa7 100644 --- a/server/src/main/java/org/elasticsearch/script/MultiVectorScoreScriptUtils.java +++ b/server/src/main/java/org/elasticsearch/script/RankVectorsScoreScriptUtils.java @@ -12,19 +12,19 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.script.field.vectors.DenseVector; -import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.RankVectorsDocValuesField; import java.io.IOException; import java.util.HexFormat; import java.util.List; -public class MultiVectorScoreScriptUtils { +public class RankVectorsScoreScriptUtils { - public static class MultiDenseVectorFunction { + public static class RankVectorsFunction { protected final ScoreScript scoreScript; - protected final MultiDenseVectorDocValuesField field; + protected final RankVectorsDocValuesField field; - public MultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field) { + public RankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field) { this.scoreScript = scoreScript; this.field = field; } @@ -41,7 +41,7 @@ void setNextVector() { } } - public static class ByteMultiDenseVectorFunction extends MultiDenseVectorFunction { + public static class ByteRankVectorsFunction extends RankVectorsFunction { protected final byte[][] queryVector; /** @@ -51,7 +51,7 @@ public static class ByteMultiDenseVectorFunction extends MultiDenseVectorFunctio * @param field The vector field. * @param queryVector The query vector. */ - public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List> queryVector) { + public ByteRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field, List> queryVector) { super(scoreScript, field); if (queryVector.isEmpty()) { throw new IllegalArgumentException("The query vector is empty."); @@ -84,13 +84,13 @@ public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDoc * @param field The vector field. * @param queryVector The query vector. */ - public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) { + public ByteRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) { super(scoreScript, field); this.queryVector = queryVector; } } - public static class FloatMultiDenseVectorFunction extends MultiDenseVectorFunction { + public static class FloatRankVectorsFunction extends RankVectorsFunction { protected final float[][] queryVector; /** @@ -100,11 +100,7 @@ public static class FloatMultiDenseVectorFunction extends MultiDenseVectorFuncti * @param field The vector field. * @param queryVector The query vector. */ - public FloatMultiDenseVectorFunction( - ScoreScript scoreScript, - MultiDenseVectorDocValuesField field, - List> queryVector - ) { + public FloatRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field, List> queryVector) { super(scoreScript, field); if (queryVector.isEmpty()) { throw new IllegalArgumentException("The query vector is empty."); @@ -133,13 +129,13 @@ public interface MaxSimInvHammingDistanceInterface { float maxSimInvHamming(); } - public static class ByteMaxSimInvHammingDistance extends ByteMultiDenseVectorFunction implements MaxSimInvHammingDistanceInterface { + public static class ByteMaxSimInvHammingDistance extends ByteRankVectorsFunction implements MaxSimInvHammingDistanceInterface { - public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List> queryVector) { + public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, RankVectorsDocValuesField field, List> queryVector) { super(scoreScript, field, queryVector); } - public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) { + public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) { super(scoreScript, field, queryVector); } @@ -183,7 +179,7 @@ public static final class MaxSimInvHamming { private final MaxSimInvHammingDistanceInterface function; public MaxSimInvHamming(ScoreScript scoreScript, Object queryVector, String fieldName) { - MultiDenseVectorDocValuesField field = (MultiDenseVectorDocValuesField) scoreScript.field(fieldName); + RankVectorsDocValuesField field = (RankVectorsDocValuesField) scoreScript.field(fieldName); if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) { throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors"); } @@ -205,11 +201,11 @@ public interface MaxSimDotProductInterface { double maxSimDotProduct(); } - public static class MaxSimBitDotProduct extends MultiDenseVectorFunction implements MaxSimDotProductInterface { + public static class MaxSimBitDotProduct extends RankVectorsFunction implements MaxSimDotProductInterface { private final byte[][] byteQueryVector; private final float[][] floatQueryVector; - public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) { + public MaxSimBitDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) { super(scoreScript, field); if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) { throw new IllegalArgumentException("Cannot calculate bit dot product for non-bit vectors"); @@ -230,7 +226,7 @@ public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesFie this.floatQueryVector = null; } - public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List> queryVector) { + public MaxSimBitDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, List> queryVector) { super(scoreScript, field); if (queryVector.isEmpty()) { throw new IllegalArgumentException("The query vector is empty."); @@ -304,13 +300,13 @@ public double maxSimDotProduct() { } } - public static class MaxSimByteDotProduct extends ByteMultiDenseVectorFunction implements MaxSimDotProductInterface { + public static class MaxSimByteDotProduct extends ByteRankVectorsFunction implements MaxSimDotProductInterface { - public MaxSimByteDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List> queryVector) { + public MaxSimByteDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, List> queryVector) { super(scoreScript, field, queryVector); } - public MaxSimByteDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) { + public MaxSimByteDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) { super(scoreScript, field, queryVector); } @@ -320,9 +316,9 @@ public double maxSimDotProduct() { } } - public static class MaxSimFloatDotProduct extends FloatMultiDenseVectorFunction implements MaxSimDotProductInterface { + public static class MaxSimFloatDotProduct extends FloatRankVectorsFunction implements MaxSimDotProductInterface { - public MaxSimFloatDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List> queryVector) { + public MaxSimFloatDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, List> queryVector) { super(scoreScript, field, queryVector); } @@ -338,7 +334,7 @@ public static final class MaxSimDotProduct { @SuppressWarnings("unchecked") public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) { - MultiDenseVectorDocValuesField field = (MultiDenseVectorDocValuesField) scoreScript.field(fieldName); + RankVectorsDocValuesField field = (RankVectorsDocValuesField) scoreScript.field(fieldName); function = switch (field.getElementType()) { case BIT -> { BytesOrList bytesOrList = parseBytes(queryVector); diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitRankVectors.java similarity index 94% rename from server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java rename to server/src/main/java/org/elasticsearch/script/field/vectors/BitRankVectors.java index 7805816090d51..0e2984c2a7dff 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitRankVectors.java @@ -15,8 +15,8 @@ import java.util.Arrays; -public class BitMultiDenseVector extends ByteMultiDenseVector { - public BitMultiDenseVector(VectorIterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) { +public class BitRankVectors extends ByteRankVectors { + public BitRankVectors(VectorIterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) { super(vectorValues, magnitudesBytes, numVecs, dims); } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitRankVectorsDocValuesField.java similarity index 64% rename from server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java rename to server/src/main/java/org/elasticsearch/script/field/vectors/BitRankVectorsDocValuesField.java index 35a43eabb8f0c..6d38621440fbf 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitRankVectorsDocValuesField.java @@ -12,20 +12,14 @@ import org.apache.lucene.index.BinaryDocValues; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; -public class BitMultiDenseVectorDocValuesField extends ByteMultiDenseVectorDocValuesField { +public class BitRankVectorsDocValuesField extends ByteRankVectorsDocValuesField { - public BitMultiDenseVectorDocValuesField( - BinaryDocValues input, - BinaryDocValues magnitudes, - String name, - ElementType elementType, - int dims - ) { + public BitRankVectorsDocValuesField(BinaryDocValues input, BinaryDocValues magnitudes, String name, ElementType elementType, int dims) { super(input, magnitudes, name, elementType, dims / 8); } @Override - protected MultiDenseVector getVector() { - return new BitMultiDenseVector(vectorValue, magnitudesValue, numVecs, dims); + protected RankVectors getVector() { + return new BitRankVectors(vectorValue, magnitudesValue, numVecs, dims); } } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteRankVectors.java similarity index 95% rename from server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java rename to server/src/main/java/org/elasticsearch/script/field/vectors/ByteRankVectors.java index 5e9d3e05746c8..f8e82046037c4 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteRankVectors.java @@ -16,7 +16,7 @@ import java.util.Arrays; import java.util.Iterator; -public class ByteMultiDenseVector implements MultiDenseVector { +public class ByteRankVectors implements RankVectors { protected final VectorIterator vectorValues; protected final int numVecs; @@ -25,7 +25,7 @@ public class ByteMultiDenseVector implements MultiDenseVector { private float[] magnitudes; private final BytesRef magnitudesBytes; - public ByteMultiDenseVector(VectorIterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) { + public ByteRankVectors(VectorIterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) { assert magnitudesBytes.length == numVecs * Float.BYTES; this.vectorValues = vectorValues; this.numVecs = numVecs; diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteRankVectorsDocValuesField.java similarity index 85% rename from server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java rename to server/src/main/java/org/elasticsearch/script/field/vectors/ByteRankVectorsDocValuesField.java index d45c5b85137f5..db81bb6ebe1cb 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteRankVectorsDocValuesField.java @@ -12,12 +12,12 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; -import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues; +import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; import java.io.IOException; import java.util.Iterator; -public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValuesField { +public class ByteRankVectorsDocValuesField extends RankVectorsDocValuesField { protected final BinaryDocValues input; private final BinaryDocValues magnitudes; @@ -29,7 +29,7 @@ public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValue protected BytesRef magnitudesValue; private byte[] buffer; - public ByteMultiDenseVectorDocValuesField( + public ByteRankVectorsDocValuesField( BinaryDocValues input, BinaryDocValues magnitudes, String name, @@ -63,25 +63,25 @@ public void setNextDocId(int docId) throws IOException { } @Override - public MultiDenseVectorScriptDocValues toScriptDocValues() { - return new MultiDenseVectorScriptDocValues(this, dims); + public RankVectorsScriptDocValues toScriptDocValues() { + return new RankVectorsScriptDocValues(this, dims); } - protected MultiDenseVector getVector() { - return new ByteMultiDenseVector(vectorValue, magnitudesValue, numVecs, dims); + protected RankVectors getVector() { + return new ByteRankVectors(vectorValue, magnitudesValue, numVecs, dims); } @Override - public MultiDenseVector get() { + public RankVectors get() { if (isEmpty()) { - return MultiDenseVector.EMPTY; + return RankVectors.EMPTY; } decodeVectorIfNecessary(); return getVector(); } @Override - public MultiDenseVector get(MultiDenseVector defaultValue) { + public RankVectors get(RankVectors defaultValue) { if (isEmpty()) { return defaultValue; } @@ -90,7 +90,7 @@ public MultiDenseVector get(MultiDenseVector defaultValue) { } @Override - public MultiDenseVector getInternal() { + public RankVectors getInternal() { return get(null); } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatRankVectors.java similarity index 93% rename from server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java rename to server/src/main/java/org/elasticsearch/script/field/vectors/FloatRankVectors.java index 9c2f7eb6a86d4..3ad5e53c047ae 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatRankVectors.java @@ -17,7 +17,7 @@ import static org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder.getMultiMagnitudes; -public class FloatMultiDenseVector implements MultiDenseVector { +public class FloatRankVectors implements RankVectors { private final BytesRef magnitudes; private float[] magnitudesArray = null; @@ -25,7 +25,7 @@ public class FloatMultiDenseVector implements MultiDenseVector { private final int numVectors; private final VectorIterator vectorValues; - public FloatMultiDenseVector(VectorIterator decodedDocVector, BytesRef magnitudes, int numVectors, int dims) { + public FloatRankVectors(VectorIterator decodedDocVector, BytesRef magnitudes, int numVectors, int dims) { assert magnitudes.length == numVectors * Float.BYTES; this.vectorValues = decodedDocVector; this.magnitudes = magnitudes; diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatRankVectorsDocValuesField.java similarity index 85% rename from server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java rename to server/src/main/java/org/elasticsearch/script/field/vectors/FloatRankVectorsDocValuesField.java index c7ac7842afd96..39bc1e621113b 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatRankVectorsDocValuesField.java @@ -12,7 +12,7 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; -import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues; +import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; import java.io.IOException; import java.nio.ByteBuffer; @@ -20,7 +20,7 @@ import java.nio.FloatBuffer; import java.util.Iterator; -public class FloatMultiDenseVectorDocValuesField extends MultiDenseVectorDocValuesField { +public class FloatRankVectorsDocValuesField extends RankVectorsDocValuesField { private final BinaryDocValues input; private final BinaryDocValues magnitudes; @@ -32,7 +32,7 @@ public class FloatMultiDenseVectorDocValuesField extends MultiDenseVectorDocValu private int numVectors; private float[] buffer; - public FloatMultiDenseVectorDocValuesField( + public FloatRankVectorsDocValuesField( BinaryDocValues input, BinaryDocValues magnitudes, String name, @@ -66,8 +66,8 @@ public void setNextDocId(int docId) throws IOException { } @Override - public MultiDenseVectorScriptDocValues toScriptDocValues() { - return new MultiDenseVectorScriptDocValues(this, dims); + public RankVectorsScriptDocValues toScriptDocValues() { + return new RankVectorsScriptDocValues(this, dims); } @Override @@ -76,25 +76,25 @@ public boolean isEmpty() { } @Override - public MultiDenseVector get() { + public RankVectors get() { if (isEmpty()) { - return MultiDenseVector.EMPTY; + return RankVectors.EMPTY; } decodeVectorIfNecessary(); - return new FloatMultiDenseVector(vectorValues, magnitudesValue, numVectors, dims); + return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); } @Override - public MultiDenseVector get(MultiDenseVector defaultValue) { + public RankVectors get(RankVectors defaultValue) { if (isEmpty()) { return defaultValue; } decodeVectorIfNecessary(); - return new FloatMultiDenseVector(vectorValues, magnitudesValue, numVectors, dims); + return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); } @Override - public MultiDenseVector getInternal() { + public RankVectors getInternal() { return get(null); } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/RankVectors.java similarity index 92% rename from server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java rename to server/src/main/java/org/elasticsearch/script/field/vectors/RankVectors.java index daf2e3529869f..ec0157c2708c8 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/RankVectors.java @@ -11,7 +11,7 @@ import java.util.Iterator; -public interface MultiDenseVector { +public interface RankVectors { default void checkDimensions(int qvDims) { checkDimensions(getDims(), qvDims); @@ -45,9 +45,9 @@ private static String badQueryVectorType(Object queryVector) { return "Cannot use vector [" + queryVector + "] with class [" + queryVector.getClass().getName() + "] as query vector"; } - MultiDenseVector EMPTY = new MultiDenseVector() { - public static final String MISSING_VECTOR_FIELD_MESSAGE = "vector value missing for a field," - + " use isEmpty() to check for a missing vector value"; + RankVectors EMPTY = new RankVectors() { + public static final String MISSING_VECTOR_FIELD_MESSAGE = "rank-vectors value missing for a field," + + " use isEmpty() to check for a missing value"; @Override public Iterator getVectors() { diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/RankVectorsDocValuesField.java similarity index 71% rename from server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java rename to server/src/main/java/org/elasticsearch/script/field/vectors/RankVectorsDocValuesField.java index cf838f3b93f96..2362561ea88c5 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/RankVectorsDocValuesField.java @@ -9,7 +9,7 @@ package org.elasticsearch.script.field.vectors; -import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues; +import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; import org.elasticsearch.script.field.AbstractScriptFieldFactory; import org.elasticsearch.script.field.DocValuesScriptFieldFactory; import org.elasticsearch.script.field.Field; @@ -18,15 +18,15 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; -public abstract class MultiDenseVectorDocValuesField extends AbstractScriptFieldFactory +public abstract class RankVectorsDocValuesField extends AbstractScriptFieldFactory implements - Field, + Field, DocValuesScriptFieldFactory, - MultiDenseVectorScriptDocValues.MultiDenseVectorSupplier { + RankVectorsScriptDocValues.RankVectorsSupplier { protected final String name; protected final ElementType elementType; - public MultiDenseVectorDocValuesField(String name, ElementType elementType) { + public RankVectorsDocValuesField(String name, ElementType elementType) { this.name = name; this.elementType = elementType; } @@ -43,15 +43,15 @@ public ElementType getElementType() { /** * Get the DenseVector for a document if one exists, DenseVector.EMPTY otherwise */ - public abstract MultiDenseVector get(); + public abstract RankVectors get(); - public abstract MultiDenseVector get(MultiDenseVector defaultValue); + public abstract RankVectors get(RankVectors defaultValue); - public abstract MultiDenseVectorScriptDocValues toScriptDocValues(); + public abstract RankVectorsScriptDocValues toScriptDocValues(); // DenseVector fields are single valued, so Iterable does not make sense. @Override - public Iterator iterator() { + public Iterator iterator() { throw new UnsupportedOperationException("Cannot iterate over single valued rank_vectors field, use get() instead"); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/RankVectorsFieldTypeTests.java similarity index 98% rename from server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorFieldTypeTests.java rename to server/src/test/java/org/elasticsearch/index/mapper/vectors/RankVectorsFieldTypeTests.java index af58c7e734ebf..b4cbbc4730d7c 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/RankVectorsFieldTypeTests.java @@ -23,7 +23,7 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; -public class MultiDenseVectorFieldTypeTests extends FieldTypeTestCase { +public class RankVectorsFieldTypeTests extends FieldTypeTestCase { @BeforeClass public static void setup() { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/RankVectorsScriptDocValuesTests.java similarity index 77% rename from server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java rename to server/src/test/java/org/elasticsearch/index/mapper/vectors/RankVectorsScriptDocValuesTests.java index 42276c81fd161..c38ed0f60f0ae 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/RankVectorsScriptDocValuesTests.java @@ -13,10 +13,10 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; -import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField; -import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField; -import org.elasticsearch.script.field.vectors.MultiDenseVector; -import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; +import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; +import org.elasticsearch.script.field.vectors.RankVectors; +import org.elasticsearch.script.field.vectors.RankVectorsDocValuesField; import org.elasticsearch.test.ESTestCase; import org.junit.BeforeClass; @@ -27,7 +27,7 @@ import static org.hamcrest.Matchers.containsString; -public class MultiDenseVectorScriptDocValuesTests extends ESTestCase { +public class RankVectorsScriptDocValuesTests extends ESTestCase { @BeforeClass public static void setup() { @@ -41,14 +41,8 @@ public void testFloatGetVectorValueAndGetMagnitude() throws IOException { BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT); BinaryDocValues magnitudeValues = wrap(expectedMagnitudes); - MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.FLOAT, - dims - ); - MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + RankVectorsDocValuesField field = new FloatRankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.FLOAT, dims); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); for (int i = 0; i < vectors.length; i++) { field.setNextDocId(i); assertEquals(vectors[i].length, field.size()); @@ -71,14 +65,8 @@ public void testByteGetVectorValueAndGetMagnitude() throws IOException { BinaryDocValues docValues = wrap(vectors, ElementType.BYTE); BinaryDocValues magnitudeValues = wrap(expectedMagnitudes); - MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.BYTE, - dims - ); - MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + RankVectorsDocValuesField field = new ByteRankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BYTE, dims); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); for (int i = 0; i < vectors.length; i++) { field.setNextDocId(i); assertEquals(vectors[i].length, field.size()); @@ -101,16 +89,10 @@ public void testFloatMetadataAndIterator() throws IOException { BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT); BinaryDocValues magnitudeValues = wrap(magnitudes); - MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.FLOAT, - dims - ); + RankVectorsDocValuesField field = new FloatRankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.FLOAT, dims); for (int i = 0; i < vectors.length; i++) { field.setNextDocId(i); - MultiDenseVector dv = field.get(); + RankVectors dv = field.get(); assertEquals(vectors[i].length, dv.size()); assertFalse(dv.isEmpty()); assertEquals(dims, dv.getDims()); @@ -118,8 +100,8 @@ public void testFloatMetadataAndIterator() throws IOException { assertEquals("Cannot iterate over single valued rank_vectors field, use get() instead", e.getMessage()); } field.setNextDocId(vectors.length); - MultiDenseVector dv = field.get(); - assertEquals(dv, MultiDenseVector.EMPTY); + RankVectors dv = field.get(); + assertEquals(dv, RankVectors.EMPTY); } public void testByteMetadataAndIterator() throws IOException { @@ -128,16 +110,10 @@ public void testByteMetadataAndIterator() throws IOException { float[][] magnitudes = new float[][] { new float[3], new float[2] }; BinaryDocValues docValues = wrap(vectors, ElementType.BYTE); BinaryDocValues magnitudeValues = wrap(magnitudes); - MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.BYTE, - dims - ); + RankVectorsDocValuesField field = new ByteRankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BYTE, dims); for (int i = 0; i < vectors.length; i++) { field.setNextDocId(i); - MultiDenseVector dv = field.get(); + RankVectors dv = field.get(); assertEquals(vectors[i].length, dv.size()); assertFalse(dv.isEmpty()); assertEquals(dims, dv.getDims()); @@ -145,8 +121,8 @@ public void testByteMetadataAndIterator() throws IOException { assertEquals("Cannot iterate over single valued rank_vectors field, use get() instead", e.getMessage()); } field.setNextDocId(vectors.length); - MultiDenseVector dv = field.get(); - assertEquals(dv, MultiDenseVector.EMPTY); + RankVectors dv = field.get(); + assertEquals(dv, RankVectors.EMPTY); } protected float[][] fill(float[][] vectors, ElementType elementType) { @@ -164,22 +140,16 @@ public void testFloatMissingValues() throws IOException { float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT); BinaryDocValues magnitudeValues = wrap(magnitudes); - MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.FLOAT, - dims - ); - MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + RankVectorsDocValuesField field = new FloatRankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.FLOAT, dims); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); field.setNextDocId(3); assertEquals(0, field.size()); Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValues); - assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage()); + assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); e = expectThrows(IllegalArgumentException.class, scriptDocValues::getMagnitudes); - assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage()); + assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); } public void testByteMissingValues() throws IOException { @@ -188,22 +158,16 @@ public void testByteMissingValues() throws IOException { float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; BinaryDocValues docValues = wrap(vectors, ElementType.BYTE); BinaryDocValues magnitudeValues = wrap(magnitudes); - MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.BYTE, - dims - ); - MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + RankVectorsDocValuesField field = new ByteRankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BYTE, dims); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); field.setNextDocId(3); assertEquals(0, field.size()); Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValues); - assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage()); + assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); e = expectThrows(IllegalArgumentException.class, scriptDocValues::getMagnitudes); - assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage()); + assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); } public void testFloatGetFunctionIsNotAccessible() throws IOException { @@ -212,21 +176,15 @@ public void testFloatGetFunctionIsNotAccessible() throws IOException { float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT); BinaryDocValues magnitudeValues = wrap(magnitudes); - MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.FLOAT, - dims - ); - MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + RankVectorsDocValuesField field = new FloatRankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.FLOAT, dims); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); field.setNextDocId(0); Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); assertThat( e.getMessage(), containsString( - "accessing a multi-vector field's value through 'get' or 'value' is not supported," + "accessing a rank-vectors field's value through 'get' or 'value' is not supported," + " use 'vectorValues' or 'magnitudes' instead." ) ); @@ -238,21 +196,15 @@ public void testByteGetFunctionIsNotAccessible() throws IOException { float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; BinaryDocValues docValues = wrap(vectors, ElementType.BYTE); BinaryDocValues magnitudeValues = wrap(magnitudes); - MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField( - docValues, - magnitudeValues, - "test", - ElementType.BYTE, - dims - ); - MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + RankVectorsDocValuesField field = new ByteRankVectorsDocValuesField(docValues, magnitudeValues, "test", ElementType.BYTE, dims); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); field.setNextDocId(0); Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); assertThat( e.getMessage(), containsString( - "accessing a multi-vector field's value through 'get' or 'value' is not supported," + "accessing a rank-vectors field's value through 'get' or 'value' is not supported," + " use 'vectorValues' or 'magnitudes' instead." ) ); diff --git a/server/src/test/java/org/elasticsearch/script/MultiVectorScoreScriptUtilsTests.java b/server/src/test/java/org/elasticsearch/script/RankVectorsScoreScriptUtilsTests.java similarity index 83% rename from server/src/test/java/org/elasticsearch/script/MultiVectorScoreScriptUtilsTests.java rename to server/src/test/java/org/elasticsearch/script/RankVectorsScoreScriptUtilsTests.java index fecd7924e4c6d..917cc2069a293 100644 --- a/server/src/test/java/org/elasticsearch/script/MultiVectorScoreScriptUtilsTests.java +++ b/server/src/test/java/org/elasticsearch/script/RankVectorsScoreScriptUtilsTests.java @@ -11,14 +11,14 @@ import org.apache.lucene.util.VectorUtil; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; -import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValuesTests; import org.elasticsearch.index.mapper.vectors.RankVectorsFieldMapper; -import org.elasticsearch.script.MultiVectorScoreScriptUtils.MaxSimDotProduct; -import org.elasticsearch.script.MultiVectorScoreScriptUtils.MaxSimInvHamming; -import org.elasticsearch.script.field.vectors.BitMultiDenseVectorDocValuesField; -import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField; -import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField; -import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField; +import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValuesTests; +import org.elasticsearch.script.RankVectorsScoreScriptUtils.MaxSimDotProduct; +import org.elasticsearch.script.RankVectorsScoreScriptUtils.MaxSimInvHamming; +import org.elasticsearch.script.field.vectors.BitRankVectorsDocValuesField; +import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; +import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; +import org.elasticsearch.script.field.vectors.RankVectorsDocValuesField; import org.elasticsearch.test.ESTestCase; import org.junit.BeforeClass; @@ -31,7 +31,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class MultiVectorScoreScriptUtilsTests extends ESTestCase { +public class RankVectorsScoreScriptUtilsTests extends ESTestCase { @BeforeClass public static void setup() { @@ -53,23 +53,23 @@ public void testFloatMultiVectorClassBindings() throws IOException { List> queryVector = List.of(Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f)); List> invalidQueryVector = List.of(Arrays.asList(0.5, 111.3)); - List fields = List.of( - new FloatMultiDenseVectorDocValuesField( - MultiDenseVectorScriptDocValuesTests.wrap(docVectors, ElementType.FLOAT), - MultiDenseVectorScriptDocValuesTests.wrap(docMagnitudes), + List fields = List.of( + new FloatRankVectorsDocValuesField( + RankVectorsScriptDocValuesTests.wrap(docVectors, ElementType.FLOAT), + RankVectorsScriptDocValuesTests.wrap(docMagnitudes), "test", ElementType.FLOAT, dims ), - new FloatMultiDenseVectorDocValuesField( - MultiDenseVectorScriptDocValuesTests.wrap(docVectors, ElementType.FLOAT), - MultiDenseVectorScriptDocValuesTests.wrap(docMagnitudes), + new FloatRankVectorsDocValuesField( + RankVectorsScriptDocValuesTests.wrap(docVectors, ElementType.FLOAT), + RankVectorsScriptDocValuesTests.wrap(docMagnitudes), "test", ElementType.FLOAT, dims ) ); - for (MultiDenseVectorDocValuesField field : fields) { + for (RankVectorsDocValuesField field : fields) { field.setNextDocId(0); ScoreScript scoreScript = mock(ScoreScript.class); @@ -88,7 +88,7 @@ public void testFloatMultiVectorClassBindings() throws IOException { // Check each function rejects query vectors with the wrong dimension IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new MultiVectorScoreScriptUtils.MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName) + () -> new RankVectorsScoreScriptUtils.MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName) ); assertThat( e.getMessage(), @@ -120,16 +120,16 @@ public void testByteMultiVectorClassBindings() throws IOException { List> invalidQueryVector = List.of(Arrays.asList((byte) 1, (byte) 1)); List hexidecimalString = List.of(HexFormat.of().formatHex(new byte[] { 1, 125, -12, 2, 4 })); - List fields = List.of( - new ByteMultiDenseVectorDocValuesField( - MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BYTE), - MultiDenseVectorScriptDocValuesTests.wrap(magnitudes), + List fields = List.of( + new ByteRankVectorsDocValuesField( + RankVectorsScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BYTE), + RankVectorsScriptDocValuesTests.wrap(magnitudes), "test", ElementType.BYTE, dims ) ); - for (MultiDenseVectorDocValuesField field : fields) { + for (RankVectorsDocValuesField field : fields) { field.setNextDocId(0); ScoreScript scoreScript = mock(ScoreScript.class); @@ -174,16 +174,16 @@ public void testBitMultiVectorClassBindingsDotProduct() throws IOException { List> invalidQueryVector = List.of(Arrays.asList((byte) 1, (byte) 1)); List hexidecimalString = List.of(HexFormat.of().formatHex(new byte[] { 124 })); - List fields = List.of( - new BitMultiDenseVectorDocValuesField( - MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BIT), - MultiDenseVectorScriptDocValuesTests.wrap(new float[][] { { 5 } }), + List fields = List.of( + new BitRankVectorsDocValuesField( + RankVectorsScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BIT), + RankVectorsScriptDocValuesTests.wrap(new float[][] { { 5 } }), "test", ElementType.BIT, dims ) ); - for (MultiDenseVectorDocValuesField field : fields) { + for (RankVectorsDocValuesField field : fields) { field.setNextDocId(0); ScoreScript scoreScript = mock(ScoreScript.class); @@ -240,23 +240,23 @@ public void testByteVsFloatSimilarity() throws IOException { float[][] floatVector = new float[][] { { 1f, 125f, -12f, 2f, 4f } }; byte[][] byteVector = new byte[][] { { (byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4 } }; - List fields = List.of( - new FloatMultiDenseVectorDocValuesField( - MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.FLOAT), - MultiDenseVectorScriptDocValuesTests.wrap(magnitudes), + List fields = List.of( + new FloatRankVectorsDocValuesField( + RankVectorsScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.FLOAT), + RankVectorsScriptDocValuesTests.wrap(magnitudes), "field1", ElementType.FLOAT, dims ), - new ByteMultiDenseVectorDocValuesField( - MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BYTE), - MultiDenseVectorScriptDocValuesTests.wrap(magnitudes), + new ByteRankVectorsDocValuesField( + RankVectorsScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BYTE), + RankVectorsScriptDocValuesTests.wrap(magnitudes), "field3", ElementType.BYTE, dims ) ); - for (MultiDenseVectorDocValuesField field : fields) { + for (RankVectorsDocValuesField field : fields) { field.setNextDocId(0); ScoreScript scoreScript = mock(ScoreScript.class); @@ -296,17 +296,17 @@ public void testByteBoundaries() throws IOException { List> lessThanVector = List.of(List.of(-129)); List> decimalVector = List.of(List.of(0.5)); - List fields = List.of( - new ByteMultiDenseVectorDocValuesField( - MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { { docVector } }, ElementType.BYTE), - MultiDenseVectorScriptDocValuesTests.wrap(new float[][] { { 1 } }), + List fields = List.of( + new ByteRankVectorsDocValuesField( + RankVectorsScriptDocValuesTests.wrap(new float[][][] { { docVector } }, ElementType.BYTE), + RankVectorsScriptDocValuesTests.wrap(new float[][] { { 1 } }), "test", ElementType.BYTE, dims ) ); - for (MultiDenseVectorDocValuesField field : fields) { + for (RankVectorsDocValuesField field : fields) { field.setNextDocId(0); ScoreScript scoreScript = mock(ScoreScript.class); diff --git a/server/src/test/java/org/elasticsearch/script/field/vectors/MultiDenseVectorTests.java b/server/src/test/java/org/elasticsearch/script/field/vectors/RankVectorsTests.java similarity index 85% rename from server/src/test/java/org/elasticsearch/script/field/vectors/MultiDenseVectorTests.java rename to server/src/test/java/org/elasticsearch/script/field/vectors/RankVectorsTests.java index 6fc24a26135b4..ca7608b10aed9 100644 --- a/server/src/test/java/org/elasticsearch/script/field/vectors/MultiDenseVectorTests.java +++ b/server/src/test/java/org/elasticsearch/script/field/vectors/RankVectorsTests.java @@ -19,7 +19,7 @@ import java.nio.ByteOrder; import java.util.function.IntFunction; -public class MultiDenseVectorTests extends ESTestCase { +public class RankVectorsTests extends ESTestCase { @BeforeClass public static void setup() { @@ -38,7 +38,7 @@ public void testByteUnsupported() { } } - MultiDenseVector knn = newByteVector(docVector); + RankVectors knn = newByteVector(docVector); UnsupportedOperationException e; e = expectThrows(UnsupportedOperationException.class, () -> knn.maxSimDotProduct(queryVector)); @@ -57,20 +57,20 @@ public void testFloatUnsupported() { } } - MultiDenseVector knn = newFloatVector(docVector); + RankVectors knn = newFloatVector(docVector); UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, () -> knn.maxSimDotProduct(queryVector)); assertEquals(e.getMessage(), "use [float maxSimDotProduct(float[][] queryVector)] instead"); } - static MultiDenseVector newFloatVector(float[][] vector) { + static RankVectors newFloatVector(float[][] vector) { BytesRef magnitudes = magnitudes(vector.length, i -> (float) Math.sqrt(VectorUtil.dotProduct(vector[i], vector[i]))); - return new FloatMultiDenseVector(VectorIterator.from(vector), magnitudes, vector.length, vector[0].length); + return new FloatRankVectors(VectorIterator.from(vector), magnitudes, vector.length, vector[0].length); } - static MultiDenseVector newByteVector(byte[][] vector) { + static RankVectors newByteVector(byte[][] vector) { BytesRef magnitudes = magnitudes(vector.length, i -> (float) Math.sqrt(VectorUtil.dotProduct(vector[i], vector[i]))); - return new ByteMultiDenseVector(VectorIterator.from(vector), magnitudes, vector.length, vector[0].length); + return new ByteRankVectors(VectorIterator.from(vector), magnitudes, vector.length, vector[0].length); } static BytesRef magnitudes(int count, IntFunction magnitude) {