Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Add stats for custom scoring feature #233

Merged
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,18 @@ The total time in nanoseconds it has taken to load items into cache (cumulative)
#### indices_in_cache
For each index that has graphs in the cache, this stat provides the number of graphs that index has and the total graph_memory_usage that index is using in Kilobytes.

#### script_compilations
The number of times the k-NN score script has been compiled. This value should be 0 or 1 most of the time. However, if the cache containing the compiled scripts fills up, the script may be evicted and recompiled.

#### script_compilation_errors
The number of errors during script compilation.

#### script_query_requests
The number of query requests that use the k-NN score script. One query request corresponds to one query for a given shard.

#### script_query_errors
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these errors are counted at shard level. We could update the documentation accordingly. Also, you might want to check whether all of the above metrics are at shard level or index level.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they are at shard level. I updated README

The number of errors that have occurred during use of the k-NN score script. One error corresponds to one error for a given shard.

#### Examples
```

Expand Down Expand Up @@ -386,7 +398,11 @@ GET /_opendistro/_knn/stats?pretty
"load_exception_count" : 0,
"hit_count" : 0,
"load_success_count" : 1,
"total_load_time" : 2878745
"total_load_time" : 2878745,
"script_compilations" : 1,
"script_compilation_errors" : 0,
"script_query_requests" : 534,
"script_query_errors" : 0
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.amazon.opendistroforelasticsearch.knn.plugin.script;

import com.amazon.opendistroforelasticsearch.knn.plugin.stats.KNNCounter;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptEngine;
Expand All @@ -22,11 +23,14 @@ public String getType() {

@Override
public <FactoryType> FactoryType compile(String name, String code, ScriptContext<FactoryType> context, Map<String, String> params) {
KNNCounter.SCRIPT_COMPILATIONS.increment();
if (!ScoreScript.CONTEXT.equals(context)) {
KNNCounter.SCRIPT_COMPILATION_ERRORS.increment();
throw new IllegalArgumentException(getType() + " KNN Vector scoring scripts cannot be used for context [" + context.name + "]");
}
// we use the script "source" as the script identifier
if (!SCRIPT_SOURCE.equals(code)) {
KNNCounter.SCRIPT_COMPILATION_ERRORS.increment();
throw new IllegalArgumentException("Unknown script name " + code);
}
ScoreScript.Factory factory = KNNVectorScoreScript.VectorScoreScriptFactory::new;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.amazon.opendistroforelasticsearch.knn.plugin.script;

import com.amazon.opendistroforelasticsearch.knn.index.util.KNNConstants;
import com.amazon.opendistroforelasticsearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BytesRef;
Expand Down Expand Up @@ -53,10 +54,12 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) {
ObjectInputStream objectStream = new ObjectInputStream(byteStream)) {
doc_vector = (float[]) objectStream.readObject();
} catch (ClassNotFoundException e) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new RuntimeException(e);
}

if(doc_vector.length != queryVector.length) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalStateException("[KNN] query vector and field vector dimensions mismatch. " +
"query vector: " + queryVector.length + ", stored vector: " + doc_vector.length);
}
Expand All @@ -69,6 +72,7 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) {
score = 1 + KNNScoringUtil.cosinesimilOptimized(this.queryVector, doc_vector, this.queryVectorSquaredMagnitude);
}
} catch (IOException e) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new UncheckedIOException(e);
}
return score;
Expand All @@ -94,6 +98,7 @@ public KNNVectorScoreScript(Map<String, Object> params, String field, float[] qu
this.queryVectorSquaredMagnitude = queryVectorSquaredMagnitude;
this.binaryDocValuesReader = leafContext.reader().getBinaryDocValues(field);
if(this.binaryDocValuesReader == null) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalStateException("Binary Doc values not enabled for the field " + field
+ " Please ensure the field type is knn_vector in mappings for this field");
}
Expand All @@ -108,6 +113,8 @@ public static class VectorScoreScriptFactory implements ScoreScript.LeafFactory
private float qVectorSquaredMagnitude; // Used for cosine optimization

public VectorScoreScriptFactory(Map<String, Object> params, SearchLookup lookup) {
KNNCounter.SCRIPT_QUERY_REQUESTS.increment();

this.params = params;
this.lookup = lookup;
validateAndInitParams(params);
Expand All @@ -124,23 +131,29 @@ public VectorScoreScriptFactory(Map<String, Object> params, SearchLookup lookup)
/**
 * Validates the script parameters and initializes the fields derived from them.
 *
 * Checks that "field", "vector" and "space_type" are present in {@code params} and that
 * the space type is one of the supported values. Every validation failure increments
 * {@link KNNCounter#SCRIPT_QUERY_ERRORS} before the exception is thrown, so the failure
 * is reflected in the script stats.
 *
 * @param params script parameters supplied with the query
 * @throws IllegalArgumentException if "field", "vector" or "space_type" is missing, or if
 *                                  "space_type" is neither cosine similarity nor L2
 */
private void validateAndInitParams(Map<String, Object> params) {
// Name of the document field to score against; required.
final Object field = params.get("field");
// NOTE(review): the two consecutive `if (field == null)` lines below look like a diff
// artifact (the old unbraced line followed by its braced replacement) — confirm against
// the actual file before editing.
if (field == null)
if (field == null) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalArgumentException("Missing parameter [field]");
}

this.field = field.toString();

// Query vector; only its presence is checked here, it is not stored by this method.
final Object qVector = params.get("vector");
if (qVector == null) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalArgumentException("Missing query vector parameter [vector]");
}

// Space type; required, and must match a supported constant (compared case-insensitively).
final Object space = params.get("space_type");
if (space == null) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalArgumentException("Missing parameter [space_type]");
}
this.similaritySpace = (String)space;
if (!KNNConstants.COSINESIMIL.equalsIgnoreCase(similaritySpace) && !KNNConstants.L2.equalsIgnoreCase(similaritySpace)) {
KNNCounter.SCRIPT_QUERY_ERRORS.increment();
throw new IllegalArgumentException("Invalid space type. Please refer to the available space types.");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ public enum KNNCounter {
GRAPH_QUERY_REQUESTS("graph_query_requests"),
GRAPH_INDEX_ERRORS("graph_index_errors"),
GRAPH_INDEX_REQUESTS("graph_index_requests"),
KNN_QUERY_REQUESTS("knn_query_requests");
KNN_QUERY_REQUESTS("knn_query_requests"),
SCRIPT_COMPILATIONS("script_compilations"),
SCRIPT_COMPILATION_ERRORS("script_compilation_errors"),
SCRIPT_QUERY_REQUESTS("script_query_requests"),
SCRIPT_QUERY_ERRORS("script_query_errors");

private String name;
private AtomicLong count;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,14 @@ public class KNNStatsConfig {
.put(StatNames.KNN_QUERY_REQUESTS.getName(), new KNNStat<>(false,
new KNNCounterSupplier(KNNCounter.KNN_QUERY_REQUESTS)))
.put(StatNames.INDICES_IN_CACHE.getName(), new KNNStat<>(false,
new KNNCacheSupplier<>(KNNIndexCache::getIndicesCacheStats))).build();
new KNNCacheSupplier<>(KNNIndexCache::getIndicesCacheStats)))
.put(StatNames.SCRIPT_COMPILATIONS.getName(), new KNNStat<>(false,
new KNNCounterSupplier(KNNCounter.SCRIPT_COMPILATIONS)))
.put(StatNames.SCRIPT_COMPILATION_ERRORS.getName(), new KNNStat<>(false,
new KNNCounterSupplier(KNNCounter.SCRIPT_COMPILATION_ERRORS)))
.put(StatNames.SCRIPT_QUERY_REQUESTS.getName(), new KNNStat<>(false,
new KNNCounterSupplier(KNNCounter.SCRIPT_QUERY_REQUESTS)))
.put(StatNames.SCRIPT_QUERY_ERRORS.getName(), new KNNStat<>(false,
new KNNCounterSupplier(KNNCounter.SCRIPT_QUERY_ERRORS)))
.build();
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ public enum StatNames {
GRAPH_QUERY_REQUESTS(KNNCounter.GRAPH_QUERY_REQUESTS.getName()),
GRAPH_INDEX_ERRORS(KNNCounter.GRAPH_INDEX_ERRORS.getName()),
GRAPH_INDEX_REQUESTS(KNNCounter.GRAPH_INDEX_REQUESTS.getName()),
KNN_QUERY_REQUESTS(KNNCounter.KNN_QUERY_REQUESTS.getName());
KNN_QUERY_REQUESTS(KNNCounter.KNN_QUERY_REQUESTS.getName()),
SCRIPT_COMPILATIONS(KNNCounter.SCRIPT_COMPILATIONS.getName()),
SCRIPT_COMPILATION_ERRORS(KNNCounter.SCRIPT_COMPILATION_ERRORS.getName()),
SCRIPT_QUERY_REQUESTS(KNNCounter.SCRIPT_QUERY_REQUESTS.getName()),
SCRIPT_QUERY_ERRORS(KNNCounter.SCRIPT_QUERY_ERRORS.getName());

private String name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,27 @@

import com.amazon.opendistroforelasticsearch.knn.KNNRestTestCase;
import com.amazon.opendistroforelasticsearch.knn.index.KNNQueryBuilder;
import com.amazon.opendistroforelasticsearch.knn.index.util.KNNConstants;
import com.amazon.opendistroforelasticsearch.knn.plugin.stats.KNNStats;

import com.amazon.opendistroforelasticsearch.knn.plugin.stats.StatNames;
import org.apache.http.util.EntityUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.rest.RestStatus;
import org.junit.rules.DisableOnDebug;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -164,6 +170,127 @@ public void testInvalidNodeIdStats() throws Exception {
assertEquals(0, nodeStats.size());
}

/**
 * Test checks that script stats are properly updated for single shard.
 *
 * Records the initial script query request/error counters, runs one valid script query
 * and one query with an invalid space type, and asserts that each counter grew by one.
 * The script compilation counter is asserted against the absolute value 1.
 */
public void testScriptStats_singleShard() throws Exception {
// Get initial stats so the request/error assertions below can be relative to them
Response response = getKnnStats(Collections.emptyList(), Arrays.asList(
StatNames.SCRIPT_QUERY_REQUESTS.getName(),
StatNames.SCRIPT_QUERY_ERRORS.getName())
);
List<Map<String, Object>> nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity()));
int initialScriptQueryRequests = (int)(nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName()));
int initialScriptQueryErrors = (int)(nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName()));

// Create an index with a single vector
// (presumably the default index settings yield one shard — the "+ 1" expectations rely on it)
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
Float[] vector = {6.0f, 6.0f};
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector);

// Check l2 query and script compilation stats: run one valid script query with space_type L2
QueryBuilder qb = new MatchAllQueryBuilder();
Map<String, Object> params = new HashMap<>();
float[] queryVector = {1.0f, 1.0f};
params.put("field", FIELD_NAME);
params.put("vector", queryVector);
params.put("space_type", KNNConstants.L2);
Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params, queryVector);
response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK,
RestStatus.fromCode(response.getStatusLine().getStatusCode()));

response = getKnnStats(Collections.emptyList(), Arrays.asList(
StatNames.SCRIPT_COMPILATIONS.getName(),
StatNames.SCRIPT_QUERY_REQUESTS.getName())
);
nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity()));
// Compilation count is checked as an absolute value, unlike the relative checks for requests/errors
assertEquals(1, (int)(nodeStats.get(0).get(StatNames.SCRIPT_COMPILATIONS.getName())));
assertEquals(initialScriptQueryRequests + 1,
(int)(nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName())));

// Check query error stats: an unsupported space_type must make the request fail
params = new HashMap<>();
params.put("field", FIELD_NAME);
params.put("vector", queryVector);
params.put("space_type", "invalid_space");
request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params, queryVector);
// `request` was reassigned above, so an effectively final copy is needed for the lambda
Request finalRequest = request;
expectThrows(ResponseException.class, () -> client().performRequest(finalRequest));

response = getKnnStats(Collections.emptyList(), Collections.singletonList(
StatNames.SCRIPT_QUERY_ERRORS.getName())
);
nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity()));
assertEquals(initialScriptQueryErrors + 1,
(int)(nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName())));
}

/**
 * Test checks that script stats are properly updated for multiple shards.
 *
 * Uses an index with two shards and no replicas. After one valid script query the
 * script query request counter is expected to grow by two, and after one query with an
 * invalid space type the script query error counter is expected to grow by two.
 * The script compilation counter is asserted against the absolute value 1.
 */
public void testScriptStats_multipleShards() throws Exception {
// Get initial stats so the request/error assertions below can be relative to them
Response response = getKnnStats(Collections.emptyList(), Arrays.asList(
StatNames.SCRIPT_QUERY_REQUESTS.getName(),
StatNames.SCRIPT_QUERY_ERRORS.getName())
);
List<Map<String, Object>> nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity()));
int initialScriptQueryRequests = (int)(nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName()));
int initialScriptQueryErrors = (int)(nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName()));

// Create a k-NN index with two shards and no replicas, then add four documents
createKnnIndex(INDEX_NAME, Settings.builder()
.put("number_of_shards", 2)
.put("number_of_replicas", 0)
.put("index.knn", true)
.build(),
createKnnIndexMapping(FIELD_NAME, 2));

Float[] vector = {6.0f, 6.0f};
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector);
addKnnDoc(INDEX_NAME, "2", FIELD_NAME, vector);
addKnnDoc(INDEX_NAME, "3", FIELD_NAME, vector);
addKnnDoc(INDEX_NAME, "4", FIELD_NAME, vector);

// Check l2 query and script compilation stats: run one valid script query with space_type L2
QueryBuilder qb = new MatchAllQueryBuilder();
Map<String, Object> params = new HashMap<>();
float[] queryVector = {1.0f, 1.0f};
params.put("field", FIELD_NAME);
params.put("vector", queryVector);
params.put("space_type", KNNConstants.L2);
Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params, queryVector);
response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK,
RestStatus.fromCode(response.getStatusLine().getStatusCode()));

response = getKnnStats(Collections.emptyList(), Arrays.asList(
StatNames.SCRIPT_COMPILATIONS.getName(),
StatNames.SCRIPT_QUERY_REQUESTS.getName())
);
nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity()));
// Compilation count is checked as an absolute value, unlike the relative checks for requests/errors
assertEquals(1, (int)(nodeStats.get(0).get(StatNames.SCRIPT_COMPILATIONS.getName())));
// One query against a two-shard index is expected to count as two script query requests
assertEquals(initialScriptQueryRequests + 2,
(int)(nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName())));

// Check query error stats: an unsupported space_type must make the request fail
params = new HashMap<>();
params.put("field", FIELD_NAME);
params.put("vector", queryVector);
params.put("space_type", "invalid_space");
request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params, queryVector);
// `request` was reassigned above, so an effectively final copy is needed for the lambda
Request finalRequest = request;
expectThrows(ResponseException.class, () -> client().performRequest(finalRequest));

response = getKnnStats(Collections.emptyList(), Collections.singletonList(
StatNames.SCRIPT_QUERY_ERRORS.getName())
);
nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity()));
// One failing query against a two-shard index is expected to count as two script query errors
assertEquals(initialScriptQueryErrors + 2,
(int)(nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName())));
}

// Useful settings when debugging to prevent timeouts
@Override
protected Settings restClientSettings() {
Expand Down