diff --git a/docs/reference/mapping/types/dense-vector.asciidoc b/docs/reference/mapping/types/dense-vector.asciidoc index 39e33bae94e12..95987477e4993 100644 --- a/docs/reference/mapping/types/dense-vector.asciidoc +++ b/docs/reference/mapping/types/dense-vector.asciidoc @@ -10,9 +10,8 @@ A `dense_vector` field stores dense vectors of float values. The maximum number of dimensions that can be in a vector should not exceed 2048. A `dense_vector` field is a single-valued field. -These vectors can be used for <>. -For example, a document score can represent a distance between -a given query vector and the indexed document vector. +`dense_vector` fields do not support querying, sorting or aggregating. They can +only be accessed in scripts through the dedicated <>. You index a dense vector as an array of floats. @@ -47,4 +46,4 @@ PUT my-index-000001/_doc/2 -------------------------------------------------- -<1> dims—the number of dimensions in the vector, required parameter. +<1> dims – the number of dimensions in the vector, required parameter. diff --git a/docs/reference/vectors/vector-functions.asciidoc b/docs/reference/vectors/vector-functions.asciidoc index 4fb197f803099..d3e72625c4a6a 100644 --- a/docs/reference/vectors/vector-functions.asciidoc +++ b/docs/reference/vectors/vector-functions.asciidoc @@ -10,6 +10,15 @@ to limit the number of matched documents with a `query` parameter. ====== `dense_vector` functions +This is the list of available vector functions and vector access methods: + +1. `cosineSimilarity` – calculates cosine similarity +2. `dotProduct` – calculates dot product +3. `l1norm` – calculates L^1^ distance +4. `l2norm` - calculates L^2^ distance +5. `doc[].vectorValue` – returns a vector's value as an array of floats +6. `doc[].magnitude` – returns a vector's magnitude + Let's create an index with a `dense_vector` mapping and index a couple of documents into it. @@ -198,6 +207,54 @@ You can check if a document has a value for the field `my_vector` by -------------------------------------------------- // NOTCONSOLE +The recommended way to access dense vectors is through `cosineSimilarity`, +`dotProduct`, `l1norm` or `l2norm` functions. But for custom use cases, +you can access dense vectors's values directly through the following functions: + +- `doc[].vectorValue` – returns a vector's value as an array of floats + +- `doc[].magnitude` – returns a vector's magnitude as a float +(for vectors created prior to version 7.5 the magnitude is not stored. +So this function calculates it anew every time it is called). + +For example, the script below implements a cosine similarity using these +two functions: + +[source,console] +-------------------------------------------------- +GET my-index-000001/_search +{ + "query": { + "script_score": { + "query" : { + "bool" : { + "filter" : { + "term" : { + "status" : "published" + } + } + } + }, + "script": { + "source": """ + float[] v = doc['my_dense_vector'].vectorValue; + float vm = doc['my_dense_vector'].magnitude; + float dotProduct = 0; + for (int i = 0; i < v.length; i++) { + dotProduct += v[i] * params.queryVector[i]; + } + return dotProduct / (vm * (float) params.queryVectorMag); + """, + "params": { + "queryVector": [4, 3.4, -0.2], + "queryVectorMag": 5.25357 + } + } + } + } +} +-------------------------------------------------- + ====== `sparse_vector` functions deprecated[7.6, The `sparse_vector` type is deprecated and will be removed in 8.0.] diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java index e3a9b44fd3197..1573f11b88eca 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java @@ -14,7 +14,6 @@ import org.elasticsearch.script.ExplainableScoreScript; import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.Script; -import org.elasticsearch.Version; import java.io.IOException; import java.util.Objects; @@ -42,15 +41,13 @@ public float score() { private final int shardId; private final String indexName; - private final Version indexVersion; - public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) { + public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId) { super(CombineFunction.REPLACE); this.sScript = sScript; this.script = script; this.indexName = indexName; this.shardId = shardId; - this.indexVersion = indexVersion; } @Override @@ -60,7 +57,6 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx leafScript.setScorer(scorer); leafScript._setIndexName(indexName); leafScript._setShard(shardId); - leafScript._setIndexVersion(indexVersion); return new LeafScoreFunction() { @Override public double score(int docId, float subQueryScore) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java index fae31341458ee..c33f588ac3670 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java @@ -146,7 +146,6 @@ private ScoreScript makeScoreScript(LeafReaderContext context) throws IOExceptio final ScoreScript scoreScript = scriptBuilder.newInstance(context); scoreScript._setIndexName(indexName); scoreScript._setShard(shardId); - scoreScript._setIndexVersion(indexVersion); return scoreScript; } diff --git a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java index 16ccc4be3a22f..4833d0c7ef727 100644 --- a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java @@ -83,8 +83,7 @@ protected ScoreFunction doToFunction(SearchExecutionContext context) { try { ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT); ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup()); - return new ScriptScoreFunction(script, searchScript, - context.index().getName(), context.getShardId(), context.indexVersionCreated()); + return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId()); } catch (Exception e) { throw new QueryShardException(context, "script_score: the script could not be loaded", e); } diff --git a/server/src/main/java/org/elasticsearch/script/ScoreScript.java b/server/src/main/java/org/elasticsearch/script/ScoreScript.java index 4b09eff66dd0c..f702e44588584 100644 --- a/server/src/main/java/org/elasticsearch/script/ScoreScript.java +++ b/server/src/main/java/org/elasticsearch/script/ScoreScript.java @@ -10,7 +10,6 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Scorable; -import org.elasticsearch.Version; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.index.fielddata.ScriptDocValues; @@ -85,7 +84,6 @@ public Explanation get(double score, Explanation subQueryExplanation) { private int docId; private int shardId = -1; private String indexName = null; - private Version indexVersion = null; public ScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { // null check needed b/c of expression engine subclass @@ -185,19 +183,6 @@ public String _getIndex() { } } - /** - * Starting a name with underscore, so that the user cannot access this function directly through a script - * It is only used within predefined painless functions. - * @return index version or throws an exception if the index version is not set up for this script instance - */ - public Version _getIndexVersion() { - if (indexVersion != null) { - return indexVersion; - } else { - throw new IllegalArgumentException("index version can not be looked up!"); - } - } - /** * Starting a name with underscore, so that the user cannot access this function directly through a script */ @@ -212,13 +197,6 @@ public void _setIndexName(String indexName) { this.indexName = indexName; } - /** - * Starting a name with underscore, so that the user cannot access this function directly through a script - */ - public void _setIndexVersion(Version indexVersion) { - this.indexVersion = indexVersion; - } - /** A factory to construct {@link ScoreScript} instances. */ public interface LeafFactory { diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/30_dense_vector_script_access.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/30_dense_vector_script_access.yml new file mode 100644 index 0000000000000..ef670b004507f --- /dev/null +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/30_dense_vector_script_access.yml @@ -0,0 +1,65 @@ +--- +"Access to values of dense_vector in script": + - skip: + version: " - 7.12.99" + reason: "Access to values of dense_vector in script was added in 7.13" + - do: + indices.create: + index: test-index + body: + mappings: + properties: + v: + type: dense_vector + dims: 3 + + - do: + bulk: + index: test-index + refresh: true + body: + - '{"index": {"_id": "1"}}' + - '{"v": [1, 1, 1]}' + - '{"index": {"_id": "2"}}' + - '{"v": [1, 1, 2]}' + - '{"index": {"_id": "3"}}' + - '{"v": [1, 1, 3]}' + - '{"index": {"_id": "missing_vector"}}' + - '{}' + + # vector functions in loop – return the index of the closest parameter vector based on cosine similarity + - do: + search: + body: + query: + script_score: + query: { "exists": { "field": "v" } } + script: + source: | + float[] v = doc['v'].vectorValue; + float vm = doc['v'].magnitude; + + int closestPv = 0; + float maxCosSim = -1; + for (int i = 0; i < params.pvs.length; i++) { + float dotProduct = 0; + for (int j = 0; j < v.length; j++) { + dotProduct += v[j] * params.pvs[i][j]; + } + float cosSim = dotProduct / (vm * (float) params.pvs_magnts[i]); + if (maxCosSim < cosSim) { + maxCosSim = cosSim; + closestPv = i; + } + } + closestPv; + params: + pvs: [ [ 1, 1, 1 ], [ 1, 1, 2 ], [ 1, 1, 3 ] ] + pvs_magnts: [1.7320, 2.4495, 3.3166] + + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0._score: 2 } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1._score: 1 } + - match: { hits.hits.2._id: "1" } + - match: { hits.hits.2._score: 0 } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java index 2237dbbbbbd90..df374adedd119 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java @@ -83,7 +83,7 @@ protected List> getParameters() { public DenseVectorFieldMapper build(ContentPath contentPath) { return new DenseVectorFieldMapper( name, - new DenseVectorFieldType(buildFullName(contentPath), dims.getValue(), meta.getValue()), + new DenseVectorFieldType(buildFullName(contentPath), indexVersionCreated, dims.getValue(), meta.getValue()), dims.getValue(), indexVersionCreated, multiFieldsBuilder.build(this, contentPath), @@ -95,10 +95,12 @@ public DenseVectorFieldMapper build(ContentPath contentPath) { public static final class DenseVectorFieldType extends MappedFieldType { private final int dims; + private final Version indexVersionCreated; - public DenseVectorFieldType(String name, int dims, Map meta) { + public DenseVectorFieldType(String name, Version indexVersionCreated, int dims, Map meta) { super(name, false, false, true, TextSearchInfo.NONE, meta); this.dims = dims; + this.indexVersionCreated = indexVersionCreated; } int dims() { @@ -125,7 +127,7 @@ protected Object parseSourceValue(Object value) { @Override public DocValueFormat docValueFormat(String format, ZoneId timeZone) { - throw new UnsupportedOperationException( + throw new IllegalArgumentException( "Field [" + name() + "] of type [" + typeName() + "] doesn't support docvalue_fields or aggregations"); } @@ -136,7 +138,7 @@ public boolean isAggregatable() { @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier searchLookup) { - return new VectorIndexFieldData.Builder(name(), true, CoreValuesSourceType.KEYWORD); + return new VectorIndexFieldData.Builder(name(), true, CoreValuesSourceType.KEYWORD, indexVersionCreated, dims); } @Override diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapper.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapper.java index 34ca7cbd1c38c..756e5aa9c8362 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapper.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapper.java @@ -71,7 +71,7 @@ protected List> getParameters() { @Override public SparseVectorFieldMapper build(ContentPath contentPath) { return new SparseVectorFieldMapper( - name, new SparseVectorFieldType(buildFullName(contentPath), meta.getValue()), + name, new SparseVectorFieldType(buildFullName(contentPath), indexCreatedVersion, meta.getValue()), multiFieldsBuilder.build(this, contentPath), copyTo.build(), indexCreatedVersion); } } @@ -83,8 +83,10 @@ name, new SparseVectorFieldType(buildFullName(contentPath), meta.getValue()), public static final class SparseVectorFieldType extends MappedFieldType { - public SparseVectorFieldType(String name, Map meta) { + private final Version indexVersionCreated; + public SparseVectorFieldType(String name, Version indexVersionCreated, Map meta) { super(name, false, false, true, TextSearchInfo.NONE, meta); + this.indexVersionCreated = indexVersionCreated; } @Override @@ -94,7 +96,7 @@ public String typeName() { @Override public DocValueFormat docValueFormat(String format, ZoneId timeZone) { - throw new UnsupportedOperationException( + throw new IllegalArgumentException( "Field [" + name() + "] of type [" + typeName() + "] doesn't support docvalue_fields or aggregations"); } @@ -118,7 +120,7 @@ public Query existsQuery(SearchExecutionContext context) { @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier searchLookup) { - return new VectorIndexFieldData.Builder(name(), false, CoreValuesSourceType.KEYWORD); + return new VectorIndexFieldData.Builder(name(), false, CoreValuesSourceType.KEYWORD, indexVersionCreated, -1); } @Override diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java index 619c96d4e5030..5bb9538f87429 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java @@ -14,7 +14,6 @@ import java.nio.ByteBuffer; -// static utility functions for encoding and decoding dense_vector and sparse_vector fields public final class VectorEncoderDecoder { static final byte INT_BYTES = 4; static final byte SHORT_BYTES = 2; @@ -168,9 +167,51 @@ public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) { * NOTE: this function can only be called on vectors from an index version greater than or * equal to 7.5.0, since vectors created prior to that do not store the magnitude. */ - public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) { + public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) { assert indexVersion.onOrAfter(Version.V_7_5_0); ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); - return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4); + return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - INT_BYTES); } + + /** + * Calculates vector magnitude + */ + private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) { + final int length = denseVectorLength(indexVersion, vectorBR); + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); + double magnitude = 0.0f; + for (int i = 0; i < length; i++) { + float value = byteBuffer.getFloat(); + magnitude += value * value; + } + magnitude = Math.sqrt(magnitude); + return (float) magnitude; + } + + public static float getMagnitude(Version indexVersion, BytesRef vectorBR) { + if (vectorBR == null) { + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); + } + if (indexVersion.onOrAfter(Version.V_7_5_0)) { + return decodeMagnitude(indexVersion, vectorBR); + } else { + return calculateMagnitude(indexVersion, vectorBR); + } + } + + /** + * Decodes a BytesRef into the provided array of floats + * @param vectorBR - dense vector encoded in BytesRef + * @param vector - array of floats where the decoded vector should be stored + */ + public static void decodeDenseVector(BytesRef vectorBR, float[] vector) { + if (vectorBR == null) { + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); + } + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); + for (int dim = 0; dim < vector.length; dim++) { + vector[dim] = byteBuffer.getFloat(); + } + } + } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java index 9424320419b13..714522776a03d 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java @@ -60,6 +60,21 @@ public DenseVectorFunction(ScoreScript scoreScript, Object field, boolean normalizeQuery) { this.scoreScript = scoreScript; + if (field instanceof String) { + String fieldName = (String) field; + docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName); + } else if (field instanceof DenseVectorScriptDocValues) { + docValues = (DenseVectorScriptDocValues) field; + deprecationLogger.deprecate(DeprecationCategory.SCRIPTING, "vector_function_signature", DEPRECATION_MESSAGE); + } else { + throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " + + "VectorScriptDocValues"); + } + + if (docValues.dims() != queryVector.size()){ + throw new IllegalArgumentException("The query vector has a different number of dimensions [" + + queryVector.size() + "] than the document vectors [" + docValues.dims() + "]."); + } this.queryVector = new float[queryVector.size()]; double queryMagnitude = 0.0; @@ -75,17 +90,6 @@ public DenseVectorFunction(ScoreScript scoreScript, this.queryVector[dim] /= queryMagnitude; } } - - if (field instanceof String) { - String fieldName = (String) field; - docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName); - } else if (field instanceof DenseVectorScriptDocValues) { - docValues = (DenseVectorScriptDocValues) field; - deprecationLogger.deprecate(DeprecationCategory.SCRIPTING, "vector_function_signature", DEPRECATION_MESSAGE); - } else { - throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " + - "VectorScriptDocValues"); - } } BytesRef getEncodedVector() { @@ -94,18 +98,10 @@ BytesRef getEncodedVector() { } catch (IOException e) { throw ExceptionsHelper.convertToElastic(e); } - - // Validate the encoded vector's length. BytesRef vector = docValues.getEncodedValue(); if (vector == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } - - int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector); - if (queryVector.length != vectorLength) { - throw new IllegalArgumentException("The query vector has a different number of dimensions [" + - queryVector.length + "] than the document vectors [" + vectorLength + "]."); - } return vector; } } @@ -179,23 +175,11 @@ public CosineSimilarity(ScoreScript scoreScript, List queryVector, Objec public double cosineSimilarity() { BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); - double dotProduct = 0.0; - double vectorMagnitude = 0.0f; - if (scoreScript._getIndexVersion().onOrAfter(Version.V_7_5_0)) { - for (float queryValue : queryVector) { - dotProduct += queryValue * byteBuffer.getFloat(); - } - vectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(scoreScript._getIndexVersion(), vector); - } else { - for (float queryValue : queryVector) { - float docValue = byteBuffer.getFloat(); - dotProduct += queryValue * docValue; - vectorMagnitude += docValue * docValue; - } - vectorMagnitude = (float) Math.sqrt(vectorMagnitude); + for (float queryValue : queryVector) { + dotProduct += queryValue * byteBuffer.getFloat(); } - return dotProduct / vectorMagnitude; + return dotProduct / docValues.getMagnitude(); } } @@ -272,19 +256,19 @@ public L1NormSparse(ScoreScript scoreScript,Map queryVector, Obj public double l1normSparse() { BytesRef vector = getEncodedVector(); - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(docValues.indexVersion(), vector); + float[] values = VectorEncoderDecoder.decodeSparseVector(docValues.indexVersion(), vector); int queryIndex = 0; int docIndex = 0; double l1norm = 0; while (queryIndex < queryDims.length && docIndex < docDims.length) { if (queryDims[queryIndex] == docDims[docIndex]) { - l1norm += Math.abs(queryValues[queryIndex] - docValues[docIndex]); + l1norm += Math.abs(queryValues[queryIndex] - values[docIndex]); queryIndex++; docIndex++; } else if (queryDims[queryIndex] > docDims[docIndex]) { - l1norm += Math.abs(docValues[docIndex]); // 0 for missing query dim + l1norm += Math.abs(values[docIndex]); // 0 for missing query dim docIndex++; } else { l1norm += Math.abs(queryValues[queryIndex]); // 0 for missing doc dim @@ -296,7 +280,7 @@ public double l1normSparse() { queryIndex++; } while (docIndex < docDims.length) { - l1norm += Math.abs(docValues[docIndex]); // 0 for missing query dim + l1norm += Math.abs(values[docIndex]); // 0 for missing query dim docIndex++; } return l1norm; @@ -311,20 +295,20 @@ public L2NormSparse(ScoreScript scoreScript, Map queryVector, Ob public double l2normSparse() { BytesRef vector = getEncodedVector(); - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(docValues.indexVersion(), vector); + float[] values = VectorEncoderDecoder.decodeSparseVector(docValues.indexVersion(), vector); int queryIndex = 0; int docIndex = 0; double l2norm = 0; while (queryIndex < queryDims.length && docIndex < docDims.length) { if (queryDims[queryIndex] == docDims[docIndex]) { - double diff = queryValues[queryIndex] - docValues[docIndex]; + double diff = queryValues[queryIndex] - values[docIndex]; l2norm += diff * diff; queryIndex++; docIndex++; } else if (queryDims[queryIndex] > docDims[docIndex]) { - double diff = docValues[docIndex]; // 0 for missing query dim + double diff = values[docIndex]; // 0 for missing query dim l2norm += diff * diff; docIndex++; } else { @@ -338,7 +322,7 @@ public double l2normSparse() { queryIndex++; } while (docIndex < docDims.length) { - l2norm += docValues[docIndex]* docValues[docIndex]; // 0 for missing query dims + l2norm += values[docIndex]* values[docIndex]; // 0 for missing query dims docIndex++; } return Math.sqrt(l2norm); @@ -353,10 +337,10 @@ public DotProductSparse(ScoreScript scoreScript, Map queryVector public double dotProductSparse() { BytesRef vector = getEncodedVector(); - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(docValues.indexVersion(), vector); + float[] values = VectorEncoderDecoder.decodeSparseVector(docValues.indexVersion(), vector); - return intDotProductSparse(queryValues, queryDims, docValues, docDims); + return intDotProductSparse(queryValues, queryDims, values, docDims); } } @@ -375,15 +359,15 @@ public CosineSimilaritySparse(ScoreScript scoreScript, Map query public double cosineSimilaritySparse() { BytesRef vector = getEncodedVector(); - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(docValues.indexVersion(), vector); + float[] values = VectorEncoderDecoder.decodeSparseVector(docValues.indexVersion(), vector); - double docQueryDotProduct = intDotProductSparse(queryValues, queryDims, docValues, docDims); + double docQueryDotProduct = intDotProductSparse(queryValues, queryDims, values, docDims); double docVectorMagnitude = 0.0f; - if (scoreScript._getIndexVersion().onOrAfter(Version.V_7_5_0)) { - docVectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(scoreScript._getIndexVersion(), vector); + if (docValues.indexVersion().onOrAfter(Version.V_7_5_0)) { + docVectorMagnitude = VectorEncoderDecoder.decodeMagnitude(docValues.indexVersion(), vector); } else { - for (float docValue : docValues) { + for (float docValue : values) { docVectorMagnitude += docValue * docValue; } docVectorMagnitude = (float) Math.sqrt(docVectorMagnitude); diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java index d31c269c55919..620c66624206d 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; import org.elasticsearch.index.fielddata.LeafFieldData; import org.elasticsearch.index.fielddata.ScriptDocValues; import org.elasticsearch.index.fielddata.SortedBinaryDocValues; @@ -26,11 +27,15 @@ final class VectorDVLeafFieldData implements LeafFieldData { private final LeafReader reader; private final String field; private final boolean isDense; + private final Version indexVersion; + private final int dims; - VectorDVLeafFieldData(LeafReader reader, String field, boolean isDense) { + VectorDVLeafFieldData(LeafReader reader, String field, boolean isDense, Version indexVersion, int dims) { this.reader = reader; this.field = field; this.isDense = isDense; + this.indexVersion = indexVersion; + this.dims = dims; } @Override @@ -53,9 +58,9 @@ public ScriptDocValues getScriptValues() { try { final BinaryDocValues values = DocValues.getBinary(reader, field); if (isDense) { - return new VectorScriptDocValues.DenseVectorScriptDocValues(values); + return new VectorScriptDocValues.DenseVectorScriptDocValues(values, indexVersion, dims); } else { - return new VectorScriptDocValues.SparseVectorScriptDocValues(values); + return new VectorScriptDocValues.SparseVectorScriptDocValues(values, indexVersion); } } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for vector field!", e); diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorIndexFieldData.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorIndexFieldData.java index 84917a60d57ec..4e456b301eb02 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorIndexFieldData.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorIndexFieldData.java @@ -10,6 +10,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.SortField; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -21,6 +22,7 @@ import org.elasticsearch.search.aggregations.support.ValuesSourceType; import org.elasticsearch.search.sort.BucketedSort; import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper; public class VectorIndexFieldData implements IndexFieldData { @@ -28,11 +30,15 @@ public class VectorIndexFieldData implements IndexFieldData build(IndexFieldDataCache cache, CircuitBreakerService breakerService) { - return new VectorIndexFieldData(name, isDense, valuesSourceType); + return new VectorIndexFieldData(name, isDense, valuesSourceType, indexVersion, dims); } } } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorScriptDocValues.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorScriptDocValues.java index 657a2150342b5..6d14b4407d547 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorScriptDocValues.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorScriptDocValues.java @@ -10,7 +10,9 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; import org.elasticsearch.index.fielddata.ScriptDocValues; +import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; import java.io.IOException; @@ -20,10 +22,12 @@ public abstract class VectorScriptDocValues extends ScriptDocValues { private final BinaryDocValues in; - private BytesRef value; + final Version indexVersion; + BytesRef value; - VectorScriptDocValues(BinaryDocValues in) { + VectorScriptDocValues(BinaryDocValues in, Version indexVersion) { this.in = in; + this.indexVersion = indexVersion; } @Override @@ -56,15 +60,51 @@ public int size() { // not final, as it needs to be extended by Mockito for tests public static class DenseVectorScriptDocValues extends VectorScriptDocValues { - public DenseVectorScriptDocValues(BinaryDocValues in) { - super(in); + private final int dims; + private final float[] vector; + + public DenseVectorScriptDocValues(BinaryDocValues in, Version indexVersion, int dims) { + super(in, indexVersion); + this.dims = dims; + this.vector = new float[dims]; + } + + @Override + public BytesRef get(int index) { + throw new UnsupportedOperationException("accessing a vector field's value through 'get' or 'value' is not supported!" + + "Use 'vectorValue' or 'magnitude' instead!'"); + } + + // package private access only for {@link ScoreScriptUtils} + int dims() { + return dims; + } + + /** + * Get dense vector's value as an array of floats + */ + public float[] getVectorValue() { + VectorEncoderDecoder.decodeDenseVector(value, vector); + return vector; + } + + /** + * Get dense vector's magnitude + */ + public float getMagnitude() { + return VectorEncoderDecoder.getMagnitude(indexVersion, value); } } // not final, as it needs to be extended by Mockito for tests public static class SparseVectorScriptDocValues extends VectorScriptDocValues { - public SparseVectorScriptDocValues(BinaryDocValues in) { - super(in); + public SparseVectorScriptDocValues(BinaryDocValues in, Version indexVersion) { + super(in, indexVersion); + } + + // package private access only for {@link ScoreScriptUtils} + Version indexVersion() { + return indexVersion; } } diff --git a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt index 43cc8bbb9e9fc..869fa7ab367c3 100644 --- a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt +++ b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt @@ -7,6 +7,8 @@ class org.elasticsearch.xpack.vectors.query.VectorScriptDocValues { } class org.elasticsearch.xpack.vectors.query.VectorScriptDocValues$DenseVectorScriptDocValues { + float[] getVectorValue() + float getMagnitude() } class org.elasticsearch.xpack.vectors.query.VectorScriptDocValues$SparseVectorScriptDocValues { } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java index 890ff8757d38f..d7f43c1cee924 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java @@ -101,7 +101,7 @@ public void testDefaults() throws Exception { // assert that after decoding the indexed value is equal to expected BytesRef vectorBR = fields[0].binaryValue(); float[] decodedValues = decodeDenseVector(Version.CURRENT, vectorBR); - float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(Version.CURRENT, vectorBR); + float decodedMagnitude = VectorEncoderDecoder.decodeMagnitude(Version.CURRENT, vectorBR); assertEquals(expectedMagnitude, decodedMagnitude, 0.001f); assertArrayEquals( "Decoded dense vector values is not equal to the indexed one.", diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java index ad77daea45cbe..275c814347259 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.vectors.mapper; +import org.elasticsearch.Version; import org.elasticsearch.index.mapper.FieldTypeTestCase; import java.io.IOException; @@ -16,29 +17,34 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase { public void testHasDocValues() { - DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap()); + DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType( + "f", Version.CURRENT, 1, Collections.emptyMap()); assertTrue(ft.hasDocValues()); } public void testIsAggregatable() { - DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap()); + DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType( + "f", Version.CURRENT,1, Collections.emptyMap()); assertFalse(ft.isAggregatable()); } public void testFielddataBuilder() { - DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap()); + DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType( + "f", Version.CURRENT,1, Collections.emptyMap()); assertNotNull(ft.fielddataBuilder("index", () -> { throw new UnsupportedOperationException(); })); } public void testDocValueFormat() { - DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap()); - expectThrows(UnsupportedOperationException.class, () -> ft.docValueFormat(null, null)); + DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType( + "f", Version.CURRENT,1, Collections.emptyMap()); + expectThrows(IllegalArgumentException.class, () -> ft.docValueFormat(null, null)); } public void testFetchSourceValue() throws IOException { - DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 5, Collections.emptyMap()); + DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType( + "f", Version.CURRENT, 5, Collections.emptyMap()); List vector = org.elasticsearch.common.collect.List.of(0.0, 1.0, 2.0, 3.0, 4.0); assertEquals(vector, fetchSourceValue(ft, vector)); } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java index 987ddcf75dfd1..8a3bc48a81108 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java @@ -111,7 +111,7 @@ public void testDefaults() throws Exception { decodedValues, 0.001f ); - float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, vectorBR); + float decodedMagnitude = VectorEncoderDecoder.decodeMagnitude(indexVersion, vectorBR); assertEquals(expectedMagnitude, decodedMagnitude, 0.001f); assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldTypeTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldTypeTests.java index e73bcd5746f43..4aaa18f70342e 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldTypeTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldTypeTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.vectors.mapper; +import org.elasticsearch.Version; import org.elasticsearch.index.mapper.FieldTypeTestCase; import org.elasticsearch.index.mapper.MappedFieldType; @@ -16,30 +17,30 @@ public class SparseVectorFieldTypeTests extends FieldTypeTestCase { public void testHasDocValues() { - MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Collections.emptyMap()); + MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Version.CURRENT, Collections.emptyMap()); assertTrue(fieldType.hasDocValues()); } public void testFielddataBuilder() { - MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Collections.emptyMap()); + MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Version.CURRENT, Collections.emptyMap()); assertNotNull(fieldType.fielddataBuilder("index", () -> { throw new UnsupportedOperationException(); })); } public void testIsNotAggregatable() { - MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Collections.emptyMap()); + MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Version.CURRENT, Collections.emptyMap()); assertFalse(fieldType.isAggregatable()); } public void testDocValueFormatIsNotSupported() { - MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Collections.emptyMap()); - UnsupportedOperationException exc = expectThrows(UnsupportedOperationException.class, () -> fieldType.docValueFormat(null, null)); + MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Version.CURRENT, Collections.emptyMap()); + IllegalArgumentException exc = expectThrows(IllegalArgumentException.class, () -> fieldType.docValueFormat(null, null)); assertEquals("Field [field] of type [sparse_vector] doesn't support docvalue_fields or aggregations", exc.getMessage()); } public void testTermQueryIsNotSupported() { - MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Collections.emptyMap()); + MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Version.CURRENT, Collections.emptyMap()); IllegalArgumentException exc = expectThrows(IllegalArgumentException.class, () -> fieldType.termQuery(null, null)); assertEquals("Field [field] of type [sparse_vector] doesn't support queries", exc.getMessage()); } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java index 9d888be6e9dc9..aa73f4d3b5e25 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java @@ -44,7 +44,7 @@ public void testSparseVectorEncodingDecoding() { BytesRef encodedSparseVector = VectorEncoderDecoder.encodeSparseVector(indexVersion, expectedDims, expectedValues, dimCount); int[] decodedDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, encodedSparseVector); float[] decodedValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, encodedSparseVector); - float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, encodedSparseVector); + float decodedMagnitude = VectorEncoderDecoder.decodeMagnitude(indexVersion, encodedSparseVector); assertEquals(expectedMagnitude, decodedMagnitude, 0.0f); assertArrayEquals( "Decoded sparse vector dims are not equal to their original!", diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java index 4802e6dba140d..15e9e9f5e420b 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.Version; import org.elasticsearch.script.ScoreScript; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilarity; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProduct; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1Norm; @@ -44,11 +45,14 @@ public void setUpVectors() { public void testDenseVectorFunctions() { for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) { BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion); + float magnitude = VectorEncoderDecoder.getMagnitude(indexVersion, encodedDocVector); + DenseVectorScriptDocValues docValues = mock(DenseVectorScriptDocValues.class); when(docValues.getEncodedValue()).thenReturn(encodedDocVector); + when(docValues.getMagnitude()).thenReturn(magnitude); + when(docValues.dims()).thenReturn(docVector.length); ScoreScript scoreScript = mock(ScoreScript.class); - when(scoreScript._getIndexVersion()).thenReturn(indexVersion); when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues)); testDotProduct(docValues, scoreScript); @@ -68,8 +72,8 @@ private void testDotProduct(DenseVectorScriptDocValues docValues, ScoreScript sc assertEquals("dotProduct result is not equal to the expected value!", 65425.624, deprecatedResult, 0.001); assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); - DotProduct invalidFunction = new DotProduct(scoreScript, invalidQueryVector, field); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::dotProduct); + IllegalArgumentException e = + expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, invalidQueryVector, field)); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } @@ -83,8 +87,8 @@ private void testCosineSimilarity(DenseVectorScriptDocValues docValues, ScoreScr assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, deprecatedResult, 0.001); assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); - CosineSimilarity invalidFunction = new CosineSimilarity(scoreScript, invalidQueryVector, field); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::cosineSimilarity); + IllegalArgumentException e = + expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, invalidQueryVector, field)); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } @@ -98,8 +102,8 @@ private void testL1Norm(DenseVectorScriptDocValues docValues, ScoreScript scoreS assertEquals("l1norm result is not equal to the expected value!", 485.184, deprecatedResult, 0.001); assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); - L1Norm invalidFunction = new L1Norm(scoreScript, invalidQueryVector, field); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l1norm); + IllegalArgumentException e = + expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, invalidQueryVector, field)); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } @@ -113,8 +117,7 @@ private void testL2Norm(DenseVectorScriptDocValues docValues, ScoreScript scoreS assertEquals("l2norm result is not equal to the expected value!", 301.361, deprecatedResult, 0.001); assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); - L2Norm invalidFunction = new L2Norm(scoreScript, invalidQueryVector, field); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l2norm); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, invalidQueryVector, field)); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValuesTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValuesTests.java new file mode 100644 index 0000000000000..5bb5777c1fd2e --- /dev/null +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValuesTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.vectors.query; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.DenseVectorScriptDocValues; + +import java.io.IOException; +import java.util.Arrays; + +import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoderTests.mockEncodeDenseVector; +import static org.hamcrest.Matchers.containsString; + +public class DenseVectorScriptDocValuesTests extends ESTestCase { + + private static BinaryDocValues wrap(float[][] vectors, Version indexVersion) { + return new BinaryDocValues() { + int idx = -1; + int maxIdx = vectors.length; + @Override + public BytesRef binaryValue() { + if (idx >= maxIdx) { + throw new IllegalStateException("max index exceeded"); + } + return mockEncodeDenseVector(vectors[idx], indexVersion); + } + + @Override + public boolean advanceExact(int target) { + idx = target; + if (target < maxIdx) { + return true; + } + return false; + } + + @Override + public int docID() { + return idx; + } + + @Override + public int nextDoc() { + return idx++; + } + + @Override + public int advance(int target) { + throw new IllegalArgumentException("not defined!"); + } + + @Override + public long cost() { + throw new IllegalArgumentException("not defined!"); + } + }; + } + + public void testGetVectorValueAndGetMagnitude() throws IOException { + final int dims = 3; + float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; + float[] expectedMagnitudes = { 1.7320f, 2.4495f, 3.3166f }; + + for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) { + BinaryDocValues docValues = wrap(vectors, indexVersion); + final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, indexVersion, dims); + for (int i = 0; i < vectors.length; i++) { + scriptDocValues.setNextDocId(i); + assertArrayEquals(vectors[i], scriptDocValues.getVectorValue(), 0.0001f); + assertEquals(expectedMagnitudes[i], scriptDocValues.getMagnitude(), 0.0001f); + } + } + } + + public void testMissingValues() throws IOException { + final int dims = 3; + float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; + BinaryDocValues docValues = wrap(vectors, Version.CURRENT); + final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, Version.CURRENT, dims); + + scriptDocValues.setNextDocId(3); + Exception e = expectThrows(IllegalArgumentException.class, () -> scriptDocValues.getVectorValue()); + assertEquals("A document doesn't have a value for a vector field!", e.getMessage()); + + e = expectThrows(IllegalArgumentException.class, () -> scriptDocValues.getMagnitude()); + assertEquals("A document doesn't have a value for a vector field!", e.getMessage()); + } + + public void testGetFunctionIsNotAccessible() throws IOException { + final int dims = 3; + float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; + BinaryDocValues docValues = wrap(vectors, Version.CURRENT); + final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, Version.CURRENT, dims); + + scriptDocValues.setNextDocId(0); + Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); + assertThat(e.getMessage(), containsString("accessing a vector field's value through 'get' or 'value' is not supported!")); + } +} diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/SparseVectorFunctionTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/SparseVectorFunctionTests.java index 5715595d0b0c6..278a4c9d65524 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/SparseVectorFunctionTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/SparseVectorFunctionTests.java @@ -54,9 +54,9 @@ public void testSparseVectorFunctions() { docVectorDims, docVectorValues, docVectorDims.length); SparseVectorScriptDocValues docValues = mock(SparseVectorScriptDocValues.class); when(docValues.getEncodedValue()).thenReturn(encodedDocVector); + when(docValues.indexVersion()).thenReturn(indexVersion); ScoreScript scoreScript = mock(ScoreScript.class); - when(scoreScript._getIndexVersion()).thenReturn(indexVersion); when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues)); testDotProduct(docValues, scoreScript); @@ -120,9 +120,9 @@ public void testSparseVectorMissingDimensions1() { Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + when(dvs.indexVersion()).thenReturn(Version.CURRENT); ScoreScript scoreScript = mock(ScoreScript.class); - when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs)); Map queryVector = new HashMap() {{ @@ -169,9 +169,9 @@ public void testSparseVectorMissingDimensions2() { Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + when(dvs.indexVersion()).thenReturn(Version.CURRENT); ScoreScript scoreScript = mock(ScoreScript.class); - when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs)); Map queryVector = new HashMap() {{