Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add access to dense_vector values #71313

Merged
merged 13 commits into from
Apr 19, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions docs/reference/mapping/types/dense-vector.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ A `dense_vector` field stores dense vectors of float values.
The maximum number of dimensions that can be in a vector should
not exceed 2048. A `dense_vector` field is a single-valued field.

These vectors can be used for <<vector-functions,document scoring>>.
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
For example, a document score can represent a distance between
a given query vector and the indexed document vector.
`dense_vector` fields do not support querying, sorting or aggregating. They can
only be accessed in scripts through the dedicated <<vector-functions,vector functions>>.

You index a dense vector as an array of floats.

Expand Down Expand Up @@ -47,4 +46,4 @@ PUT my-index-000001/_doc/2

--------------------------------------------------

<1> dimsthe number of dimensions in the vector, required parameter.
<1> dimsthe number of dimensions in the vector, required parameter.
58 changes: 58 additions & 0 deletions docs/reference/vectors/vector-functions.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ linearly scanned. Thus, expect the query time grow linearly
with the number of matched documents. For this reason, we recommend
to limit the number of matched documents with a `query` parameter.

This is the list of available vector functions and vector access methods:

1. `cosineSimilarity` – calculates cosine similarity
2. `dotProduct` – calculates dot product
3. `l1norm` – calculates L^1^ distance
4. `l2norm` - calculates L^2^ distance
5. `doc[<field>].vectorValue` – returns a vector's value as an array of floats
6. `doc[<field>].magnitude` – returns a vector's magnitude


Let's create an index with a `dense_vector` mapping and index a couple
of documents into it.

Expand Down Expand Up @@ -195,3 +205,51 @@ You can check if a document has a value for the field `my_vector` by
"source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, 'my_vector')"
--------------------------------------------------
// NOTCONSOLE

The recommended way to access dense vectors is through `cosineSimilarity`,
`dotProduct`, `l1norm` or `l2norm` functions. But for custom use cases,
you can access dense vectors's values directly through the following functions:

- `doc[<field>].vectorValue` – returns a vector's value as an array of floats

- `doc[<field>].magnitude` – returns a vector's magnitude as a float
(for vectors created prior to version 7.5 the magnitude is not stored.
So this function calculates it anew every time it is called).

For example, the script below implements a cosine similarity using these
two functions:

[source,console]
--------------------------------------------------
GET my-index-000001/_search
{
"query": {
"script_score": {
"query" : {
"bool" : {
"filter" : {
"term" : {
"status" : "published"
}
}
}
},
"script": {
"source": """
float[] v = doc['my_dense_vector'].vectorValue;
float vm = doc['my_dense_vector'].magnitude;
float dotProduct = 0;
for (int i = 0; i < v.length; i++) {
dotProduct += v[i] * params.queryVector[i];
}
return dotProduct / (vm * (float) params.queryVectorMag);
""",
"params": {
"queryVector": [4, 3.4, -0.2],
"queryVectorMag": 5.25357
}
}
}
}
}
--------------------------------------------------
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.elasticsearch.script.ExplainableScoreScript;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.Version;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -42,15 +41,13 @@ public float score() {

private final int shardId;
private final String indexName;
private final Version indexVersion;

public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) {
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId) {
super(CombineFunction.REPLACE);
this.sScript = sScript;
this.script = script;
this.indexName = indexName;
this.shardId = shardId;
this.indexVersion = indexVersion;
}

@Override
Expand All @@ -60,7 +57,6 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx
leafScript.setScorer(scorer);
leafScript._setIndexName(indexName);
leafScript._setShard(shardId);
leafScript._setIndexVersion(indexVersion);
return new LeafScoreFunction() {
@Override
public double score(int docId, float subQueryScore) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ private ScoreScript makeScoreScript(LeafReaderContext context) throws IOExceptio
final ScoreScript scoreScript = scriptBuilder.newInstance(context);
scoreScript._setIndexName(indexName);
scoreScript._setShard(shardId);
scoreScript._setIndexVersion(indexVersion);
return scoreScript;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ protected ScoreFunction doToFunction(SearchExecutionContext context) {
try {
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
return new ScriptScoreFunction(script, searchScript,
context.index().getName(), context.getShardId(), context.indexVersionCreated());
return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId());
} catch (Exception e) {
throw new QueryShardException(context, "script_score: the script could not be loaded", e);
}
Expand Down
22 changes: 0 additions & 22 deletions server/src/main/java/org/elasticsearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorable;
import org.elasticsearch.Version;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.index.fielddata.ScriptDocValues;
Expand Down Expand Up @@ -85,7 +84,6 @@ public Explanation get(double score, Explanation subQueryExplanation) {
private int docId;
private int shardId = -1;
private String indexName = null;
private Version indexVersion = null;

public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
// null check needed b/c of expression engine subclass
Expand Down Expand Up @@ -185,19 +183,6 @@ public String _getIndex() {
}
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
* It is only used within predefined painless functions.
* @return index version or throws an exception if the index version is not set up for this script instance
*/
public Version _getIndexVersion() {
if (indexVersion != null) {
return indexVersion;
} else {
throw new IllegalArgumentException("index version can not be looked up!");
}
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
*/
Expand All @@ -212,13 +197,6 @@ public void _setIndexName(String indexName) {
this.indexName = indexName;
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
*/
public void _setIndexVersion(Version indexVersion) {
this.indexVersion = indexVersion;
}


/** A factory to construct {@link ScoreScript} instances. */
public interface LeafFactory {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
---
"Access to values of dense_vector in script":
- skip:
version: " - 7.12.99"
reason: "Access to values of dense_vector in script was added in 7.13"
- do:
indices.create:
index: test-index
body:
mappings:
properties:
v:
type: dense_vector
dims: 3

- do:
bulk:
index: test-index
refresh: true
body:
- '{"index": {"_id": "1"}}'
- '{"v": [1, 1, 1]}'
- '{"index": {"_id": "2"}}'
- '{"v": [1, 1, 2]}'
- '{"index": {"_id": "3"}}'
- '{"v": [1, 1, 3]}'
- '{"index": {"_id": "missing_vector"}}'
- '{}'

# vector functions in loop – return the index of the closest parameter vector based on cosine similarity
- do:
search:
body:
query:
script_score:
query: { "exists": { "field": "v" } }
script:
source: |
float[] v = doc['v'].vectorValue;
float vm = doc['v'].magnitude;

int closestPv = 0;
float maxCosSim = -1;
for (int i = 0; i < params.pvs.length; i++) {
float dotProduct = 0;
for (int j = 0; j < v.length; j++) {
dotProduct += v[j] * params.pvs[i][j];
}
float cosSim = dotProduct / (vm * (float) params.pvs_magnts[i]);
if (maxCosSim < cosSim) {
maxCosSim = cosSim;
closestPv = i;
}
}
closestPv;
params:
pvs: [ [ 1, 1, 1 ], [ 1, 1, 2 ], [ 1, 1, 3 ] ]
pvs_magnts: [1.7320, 2.4495, 3.3166]

- match: { hits.hits.0._id: "3" }
- match: { hits.hits.0._score: 2 }
- match: { hits.hits.1._id: "2" }
- match: { hits.hits.1._score: 1 }
- match: { hits.hits.2._id: "1" }
- match: { hits.hits.2._score: 0 }
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ protected List<Parameter<?>> getParameters() {
public DenseVectorFieldMapper build(ContentPath contentPath) {
return new DenseVectorFieldMapper(
name,
new DenseVectorFieldType(buildFullName(contentPath), dims.getValue(), meta.getValue()),
new DenseVectorFieldType(buildFullName(contentPath), indexVersionCreated, dims.getValue(), meta.getValue()),
dims.getValue(),
indexVersionCreated,
multiFieldsBuilder.build(this, contentPath),
Expand All @@ -94,10 +94,12 @@ public DenseVectorFieldMapper build(ContentPath contentPath) {

public static final class DenseVectorFieldType extends MappedFieldType {
private final int dims;
private final Version indexVersionCreated;

public DenseVectorFieldType(String name, int dims, Map<String, String> meta) {
public DenseVectorFieldType(String name, Version indexVersionCreated, int dims, Map<String, String> meta) {
super(name, false, false, true, TextSearchInfo.NONE, meta);
this.dims = dims;
this.indexVersionCreated = indexVersionCreated;
}

int dims() {
Expand All @@ -124,7 +126,7 @@ protected Object parseSourceValue(Object value) {

@Override
public DocValueFormat docValueFormat(String format, ZoneId timeZone) {
throw new UnsupportedOperationException(
throw new IllegalArgumentException(
"Field [" + name() + "] of type [" + typeName() + "] doesn't support docvalue_fields or aggregations");
}

Expand All @@ -135,7 +137,7 @@ public boolean isAggregatable() {

@Override
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
return new VectorIndexFieldData.Builder(name(), CoreValuesSourceType.KEYWORD);
return new VectorIndexFieldData.Builder(name(), CoreValuesSourceType.KEYWORD, indexVersionCreated, dims);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import java.nio.ByteBuffer;


public final class VectorEncoderDecoder {
public static final byte INT_BYTES = 4;

Expand All @@ -29,9 +30,51 @@ public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) {
* NOTE: this function can only be called on vectors from an index version greater than or
* equal to 7.5.0, since vectors created prior to that do not store the magnitude.
*/
public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) {
public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) {
assert indexVersion.onOrAfter(Version.V_7_5_0);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - INT_BYTES);
}

/**
* Calculates vector magnitude
*/
private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) {
final int length = denseVectorLength(indexVersion, vectorBR);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
double magnitude = 0.0f;
for (int i = 0; i < length; i++) {
float value = byteBuffer.getFloat();
magnitude += value * value;
}
magnitude = Math.sqrt(magnitude);
return (float) magnitude;
}
jtibshirani marked this conversation as resolved.
Show resolved Hide resolved

public static float getMagnitude(Version indexVersion, BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
if (indexVersion.onOrAfter(Version.V_7_5_0)) {
return decodeMagnitude(indexVersion, vectorBR);
} else {
return calculateMagnitude(indexVersion, vectorBR);
}
}

/**
* Decodes a BytesRef into the provided array of floats
* @param vectorBR - dense vector encoded in BytesRef
* @param vector - array of floats where the decoded vector should be stored
*/
public static void decodeDenseVector(BytesRef vectorBR, float[] vector) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
for (int dim = 0; dim < vector.length; dim++) {
vector[dim] = byteBuffer.getFloat();
}
}

}
Loading