Skip to content

Commit

Permalink
more renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent committed Dec 9, 2024
1 parent e1d2466 commit c11966d
Show file tree
Hide file tree
Showing 19 changed files with 172 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ class org.elasticsearch.script.field.SeqNoDocValuesField @dynamic_type {
class org.elasticsearch.script.field.VersionDocValuesField @dynamic_type {
}

class org.elasticsearch.script.field.vectors.MultiDenseVector {
MultiDenseVector EMPTY
class org.elasticsearch.script.field.vectors.RankVectors {
RankVectors EMPTY
float[] getMagnitudes()

Iterator getVectors()
Expand All @@ -142,9 +142,9 @@ class org.elasticsearch.script.field.vectors.MultiDenseVector {
int size()
}

class org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField {
MultiDenseVector get()
MultiDenseVector get(MultiDenseVector)
class org.elasticsearch.script.field.vectors.RankVectorsDocValuesField {
RankVectors get()
RankVectors get(RankVectors)
}

class org.elasticsearch.script.field.vectors.DenseVector {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ static_import {
double cosineSimilarity(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity
double dotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct
double hamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$Hamming
double maxSimDotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.MultiVectorScoreScriptUtils$MaxSimDotProduct
double maxSimInvHamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.MultiVectorScoreScriptUtils$MaxSimInvHamming
double maxSimDotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.RankVectorsScoreScriptUtils$MaxSimDotProduct
double maxSimInvHamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.RankVectorsScoreScriptUtils$MaxSimInvHamming
}

Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues {
float getMagnitude()
}

class org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues {
class org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues {
Iterator getVectorValues()
float[] getMagnitudes()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
import org.elasticsearch.index.fielddata.LeafFieldData;
import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
import org.elasticsearch.script.field.vectors.BitMultiDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.BitRankVectorsDocValuesField;
import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField;
import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField;

import java.io.IOException;

Expand All @@ -40,9 +40,9 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) {
BinaryDocValues values = DocValues.getBinary(reader, field);
BinaryDocValues magnitudeValues = DocValues.getBinary(reader, field + RankVectorsFieldMapper.VECTOR_MAGNITUDES_SUFFIX);
return switch (elementType) {
case BYTE -> new ByteMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims);
case FLOAT -> new FloatMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims);
case BIT -> new BitMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims);
case BYTE -> new ByteRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
case FLOAT -> new FloatRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
case BIT -> new BitRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
};
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values for multi-vector field!", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.script.field.vectors.MultiDenseVector;
import org.elasticsearch.script.field.vectors.RankVectors;

import java.util.Iterator;

public class MultiDenseVectorScriptDocValues extends ScriptDocValues<BytesRef> {
public class RankVectorsScriptDocValues extends ScriptDocValues<BytesRef> {

public static final String MISSING_VECTOR_FIELD_MESSAGE = "A document doesn't have a value for a vector field!";
public static final String MISSING_VECTOR_FIELD_MESSAGE = "A document doesn't have a value for a rank-vectors field!";

private final int dims;
protected final MultiDenseVectorSupplier dvSupplier;
protected final RankVectorsSupplier dvSupplier;

public MultiDenseVectorScriptDocValues(MultiDenseVectorSupplier supplier, int dims) {
public RankVectorsScriptDocValues(RankVectorsSupplier supplier, int dims) {
super(supplier);
this.dvSupplier = supplier;
this.dims = dims;
Expand All @@ -32,8 +32,8 @@ public int dims() {
return dims;
}

private MultiDenseVector getCheckedVector() {
MultiDenseVector vector = dvSupplier.getInternal();
private RankVectors getCheckedVector() {
RankVectors vector = dvSupplier.getInternal();
if (vector == null) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}
Expand All @@ -57,25 +57,25 @@ public float[] getMagnitudes() {
@Override
public BytesRef get(int index) {
throw new UnsupportedOperationException(
"accessing a vector field's value through 'get' or 'value' is not supported, use 'vectorValues' or 'magnitudes' instead."
"accessing a rank-vectors field's value through 'get' or 'value' is not supported, use 'vectorValues' or 'magnitudes' instead."
);
}

@Override
public int size() {
MultiDenseVector mdv = dvSupplier.getInternal();
RankVectors mdv = dvSupplier.getInternal();
if (mdv != null) {
return mdv.size();
}
return 0;
}

public interface MultiDenseVectorSupplier extends Supplier<BytesRef> {
public interface RankVectorsSupplier extends Supplier<BytesRef> {
@Override
default BytesRef getInternal(int index) {
throw new UnsupportedOperationException();
}

MultiDenseVector getInternal();
RankVectors getInternal();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,4 @@ public static float[] getMultiMagnitudes(BytesRef magnitudes) {
return multiMagnitudes;
}

public static void decodeMultiDenseVector(BytesRef vectorBR, int numVectors, float[][] multiVectorValue) {
if (vectorBR == null) {
throw new IllegalArgumentException(MultiDenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE);
}
FloatBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
for (int i = 0; i < numVectors; i++) {
fb.get(multiVectorValue[i]);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.script.field.vectors.DenseVector;
import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.RankVectorsDocValuesField;

import java.io.IOException;
import java.util.HexFormat;
import java.util.List;

public class MultiVectorScoreScriptUtils {
public class RankVectorsScoreScriptUtils {

public static class MultiDenseVectorFunction {
public static class RankVectorsFunction {
protected final ScoreScript scoreScript;
protected final MultiDenseVectorDocValuesField field;
protected final RankVectorsDocValuesField field;

public MultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field) {
public RankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field) {
this.scoreScript = scoreScript;
this.field = field;
}
Expand All @@ -41,7 +41,7 @@ void setNextVector() {
}
}

public static class ByteMultiDenseVectorFunction extends MultiDenseVectorFunction {
public static class ByteRankVectorsFunction extends RankVectorsFunction {
protected final byte[][] queryVector;

/**
Expand All @@ -51,7 +51,7 @@ public static class ByteMultiDenseVectorFunction extends MultiDenseVectorFunctio
* @param field The vector field.
* @param queryVector The query vector.
*/
public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
public ByteRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field);
if (queryVector.isEmpty()) {
throw new IllegalArgumentException("The query vector is empty.");
Expand Down Expand Up @@ -84,13 +84,13 @@ public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDoc
* @param field The vector field.
* @param queryVector The query vector.
*/
public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
public ByteRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) {
super(scoreScript, field);
this.queryVector = queryVector;
}
}

public static class FloatMultiDenseVectorFunction extends MultiDenseVectorFunction {
public static class FloatRankVectorsFunction extends RankVectorsFunction {
protected final float[][] queryVector;

/**
Expand All @@ -100,11 +100,7 @@ public static class FloatMultiDenseVectorFunction extends MultiDenseVectorFuncti
* @param field The vector field.
* @param queryVector The query vector.
*/
public FloatMultiDenseVectorFunction(
ScoreScript scoreScript,
MultiDenseVectorDocValuesField field,
List<List<Number>> queryVector
) {
public FloatRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field);
if (queryVector.isEmpty()) {
throw new IllegalArgumentException("The query vector is empty.");
Expand Down Expand Up @@ -133,13 +129,13 @@ public interface MaxSimInvHammingDistanceInterface {
float maxSimInvHamming();
}

public static class ByteMaxSimInvHammingDistance extends ByteMultiDenseVectorFunction implements MaxSimInvHammingDistanceInterface {
public static class ByteMaxSimInvHammingDistance extends ByteRankVectorsFunction implements MaxSimInvHammingDistanceInterface {

public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field, queryVector);
}

public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) {
super(scoreScript, field, queryVector);
}

Expand Down Expand Up @@ -183,7 +179,7 @@ public static final class MaxSimInvHamming {
private final MaxSimInvHammingDistanceInterface function;

public MaxSimInvHamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
MultiDenseVectorDocValuesField field = (MultiDenseVectorDocValuesField) scoreScript.field(fieldName);
RankVectorsDocValuesField field = (RankVectorsDocValuesField) scoreScript.field(fieldName);
if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
}
Expand All @@ -205,11 +201,11 @@ public interface MaxSimDotProductInterface {
double maxSimDotProduct();
}

public static class MaxSimBitDotProduct extends MultiDenseVectorFunction implements MaxSimDotProductInterface {
public static class MaxSimBitDotProduct extends RankVectorsFunction implements MaxSimDotProductInterface {
private final byte[][] byteQueryVector;
private final float[][] floatQueryVector;

public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
public MaxSimBitDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) {
super(scoreScript, field);
if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) {
throw new IllegalArgumentException("Cannot calculate bit dot product for non-bit vectors");
Expand All @@ -230,7 +226,7 @@ public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesFie
this.floatQueryVector = null;
}

public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
public MaxSimBitDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field);
if (queryVector.isEmpty()) {
throw new IllegalArgumentException("The query vector is empty.");
Expand Down Expand Up @@ -304,13 +300,13 @@ public double maxSimDotProduct() {
}
}

public static class MaxSimByteDotProduct extends ByteMultiDenseVectorFunction implements MaxSimDotProductInterface {
public static class MaxSimByteDotProduct extends ByteRankVectorsFunction implements MaxSimDotProductInterface {

public MaxSimByteDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
public MaxSimByteDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field, queryVector);
}

public MaxSimByteDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
public MaxSimByteDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) {
super(scoreScript, field, queryVector);
}

Expand All @@ -320,9 +316,9 @@ public double maxSimDotProduct() {
}
}

public static class MaxSimFloatDotProduct extends FloatMultiDenseVectorFunction implements MaxSimDotProductInterface {
public static class MaxSimFloatDotProduct extends FloatRankVectorsFunction implements MaxSimDotProductInterface {

public MaxSimFloatDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
public MaxSimFloatDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field, queryVector);
}

Expand All @@ -338,7 +334,7 @@ public static final class MaxSimDotProduct {

@SuppressWarnings("unchecked")
public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) {
MultiDenseVectorDocValuesField field = (MultiDenseVectorDocValuesField) scoreScript.field(fieldName);
RankVectorsDocValuesField field = (RankVectorsDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
case BIT -> {
BytesOrList bytesOrList = parseBytes(queryVector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

import java.util.Arrays;

public class BitMultiDenseVector extends ByteMultiDenseVector {
public BitMultiDenseVector(VectorIterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
public class BitRankVectors extends ByteRankVectors {
public BitRankVectors(VectorIterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
super(vectorValues, magnitudesBytes, numVecs, dims);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,14 @@
import org.apache.lucene.index.BinaryDocValues;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;

public class BitMultiDenseVectorDocValuesField extends ByteMultiDenseVectorDocValuesField {
public class BitRankVectorsDocValuesField extends ByteRankVectorsDocValuesField {

public BitMultiDenseVectorDocValuesField(
BinaryDocValues input,
BinaryDocValues magnitudes,
String name,
ElementType elementType,
int dims
) {
public BitRankVectorsDocValuesField(BinaryDocValues input, BinaryDocValues magnitudes, String name, ElementType elementType, int dims) {
super(input, magnitudes, name, elementType, dims / 8);
}

@Override
protected MultiDenseVector getVector() {
return new BitMultiDenseVector(vectorValue, magnitudesValue, numVecs, dims);
protected RankVectors getVector() {
return new BitRankVectors(vectorValue, magnitudesValue, numVecs, dims);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.Arrays;
import java.util.Iterator;

public class ByteMultiDenseVector implements MultiDenseVector {
public class ByteRankVectors implements RankVectors {

protected final VectorIterator<byte[]> vectorValues;
protected final int numVecs;
Expand All @@ -25,7 +25,7 @@ public class ByteMultiDenseVector implements MultiDenseVector {
private float[] magnitudes;
private final BytesRef magnitudesBytes;

public ByteMultiDenseVector(VectorIterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
public ByteRankVectors(VectorIterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
assert magnitudesBytes.length == numVecs * Float.BYTES;
this.vectorValues = vectorValues;
this.numVecs = numVecs;
Expand Down
Loading

0 comments on commit c11966d

Please sign in to comment.