Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add graph creation stats to the KNNStats API #1141

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.10...2.x)
### Features
### Enhancements
- Add graph creation stats to the KNNStats API. [#1141](https://github.com/opensearch-project/k-NN/pull/1141)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.store.ChecksumIndexInput;
import org.opensearch.common.StopWatch;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.xcontent.MediaTypeRegistry;
Expand All @@ -35,6 +36,7 @@
import org.apache.lucene.store.FilterDirectory;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.stats.KNNGraphValue;

import java.io.Closeable;
import java.io.IOException;
Expand All @@ -53,6 +55,7 @@
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize;

/**
* This class writes the KNN docvalues to the segments
Expand All @@ -76,7 +79,13 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {
public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
delegatee.addBinaryField(field, valuesProducer);
if (isKNNBinaryFieldRequired(field)) {
addKNNBinaryField(field, valuesProducer);
StopWatch stopWatch = new StopWatch();
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
stopWatch.start();
addKNNBinaryField(field, valuesProducer, false, true);
stopWatch.stop();
long time_in_millis = stopWatch.totalTime().millis();
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis);
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
logger.warn("Refresh operation complete in " + time_in_millis + " ms");
}
}

Expand All @@ -97,14 +106,21 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) {
return KNNEngine.getEngine(engineName);
}

public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh)
throws IOException {
// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.vectors.length == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
return;
}
long arraySize = calculateArraySize(pair.vectors, pair.serializationMode);
if (isMerge) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize);
}
// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
// Create library index either from model or from scratch
Expand Down Expand Up @@ -135,6 +151,14 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
}

if (isMerge) {
recordMergeStats(pair.docs.length, arraySize);
}

if (isRefresh) {
recordRefreshStats();
}

// This is a bit of a hack. We have to create an output here and then immediately close it to ensure that
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
// not be marked as added to the directory.
Expand All @@ -143,6 +167,19 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
writeFooter(indexPath, engineFileName);
}

private void recordMergeStats(int length, long arraySize) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement();
KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize);
KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment();
KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(length);
KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize);
}

private void recordRefreshStats() {
KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment();
}

private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
Map<String, Object> parameters = ImmutableMap.of(
KNNConstants.INDEX_THREAD_QTY,
Expand Down Expand Up @@ -210,7 +247,13 @@ public void merge(MergeState mergeState) {
for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
DocValuesType type = fieldInfo.getDocValuesType();
if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState));
StopWatch stopWatch = new StopWatch();
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
stopWatch.start();
addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState), true, false);
stopWatch.stop();
long time_in_millis = stopWatch.totalTime().millis();
KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis);
logger.warn("Merge operation complete in " + time_in_millis + " ms");
}
}
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,72 @@

public static final String HNSW_EXTENSION = ".hnsw";
public static final String HNSW_COMPOUND_EXTENSION = ".hnswc";
// Floats are 4 bytes in size
public static final int FLOAT_BYTE_SIZE = 4;
// References to objects are 4 bytes in size
public static final int JAVA_REFERENCE_SIZE = 4;
// Each array in Java has a header that is 12 bytes
public static final int JAVA_ARRAY_HEADER_SIZE = 12;
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
// Java rounds each array size up to multiples of 8 bytes
public static final int JAVA_ROUNDING_NUMBER = 8;

public static final class Pair {
public Pair(int[] docs, float[][] vectors) {
public Pair(int[] docs, float[][] vectors, SerializationMode serializationMode) {
this.docs = docs;
this.vectors = vectors;
this.serializationMode = serializationMode;
}

public int[] docs;
public float[][] vectors;
public SerializationMode serializationMode;
}

public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException {
ArrayList<float[]> vectorList = new ArrayList<>();
ArrayList<Integer> docIdList = new ArrayList<>();
SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS;
for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
BytesRef bytesref = values.binaryValue();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) {
serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
vectorList.add(vector);
}
docIdList.add(doc);
}
return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorList.toArray(new float[][] {}));
return new KNNCodecUtil.Pair(
docIdList.stream().mapToInt(Integer::intValue).toArray(),
vectorList.toArray(new float[][] {}),
serializationMode
);
}

public static long calculateArraySize(float[][] vectors, SerializationMode serializationMode) {
int vectorLength = vectors[0].length;
int numVectors = vectors.length;
if (serializationMode == SerializationMode.ARRAY) {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE;

Check warning on line 66 in src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java#L66

Added line #L66 was not covered by tests
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;

Check warning on line 68 in src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java#L68

Added line #L68 was not covered by tests
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE;

Check warning on line 70 in src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java#L70

Added line #L70 was not covered by tests
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;

Check warning on line 72 in src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java#L72

Added line #L72 was not covered by tests
}
return vectorsSize;

Check warning on line 74 in src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java#L74

Added line #L74 was not covered by tests
} else {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += vectorSize % JAVA_ROUNDING_NUMBER;
}
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER;
}
return vectorsSize;
}
}

public static String buildEngineFileName(String segmentName, String latestBuildVersion, String fieldName, String extension) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayIn
return getSerializerBySerializationMode(serializationMode);
}

private static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) {
static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) {
int numberOfAvailableBytesInStream = byteStream.available();
if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) {
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.stats;

import java.util.concurrent.atomic.AtomicLong;

/**
* Contains a map to keep track of different graph values
*/
public enum KNNGraphValue {

REFRESH_TOTAL_OPERATIONS("total"),
REFRESH_TOTAL_TIME_IN_MILLIS("total_time_in_millis"),
MERGE_CURRENT_OPERATIONS("current"),
MERGE_CURRENT_DOCS("current_docs"),
MERGE_CURRENT_SIZE_IN_BYTES("current_size_in_bytes"),
MERGE_TOTAL_OPERATIONS("total"),
MERGE_TOTAL_TIME_IN_MILLIS("total_time_in_millis"),
MERGE_TOTAL_DOCS("total_docs"),
MERGE_TOTAL_SIZE_IN_BYTES("total_size_in_bytes");

private String name;
private AtomicLong value;

/**
* Constructor
*
* @param name name of the graph value
*/
KNNGraphValue(String name) {
this.name = name;
this.value = new AtomicLong(0);
}

/**
* Get name of value
*
* @return name
*/
public String getName() {
return name;
}

/**
* Get the graph value
*
* @return value
*/
public Long getValue() {
return value.get();
}

/**
* Increment the graph value
*/
public void increment() {
value.getAndIncrement();
}

/**
* Decrement the graph value
*/
public void decrement() {
value.getAndDecrement();
}

/**
* Increment the graph value by a specified amount
*
* @param delta The amount to increment
*/
public void incrementBy(long delta) {
value.getAndAdd(delta);
}

/**
* Decrement the graph value by a specified amount
*
* @param delta The amount to decrement
*/
public void decrementBy(long delta) {
value.set(value.get() - delta);
}

/**
* @param value graph value
* Set the graph value
*/
public void set(long value) {
this.value.set(value);
}
}
29 changes: 29 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.time.temporal.ChronoUnit;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

/**
* Class represents all stats the plugin keeps track of
Expand Down Expand Up @@ -84,6 +85,7 @@ private Map<String, KNNStat<?>> buildStatsMap() {
addEngineStats(builder);
addScriptStats(builder);
addModelStats(builder);
addGraphStats(builder);
return builder.build();
}

Expand Down Expand Up @@ -169,4 +171,31 @@ private void addModelStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getTrainingSizeAsPercentage))
);
}

private void addGraphStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
builder.put(StatNames.GRAPH_STATS.getName(), new KNNStat<>(false, new Supplier<Map<String, Map<String, Object>>>() {
@Override
public Map<String, Map<String, Object>> get() {
return createGraphStatsMap();
}
}));
}

private Map<String, Map<String, Object>> createGraphStatsMap() {
Map<String, Object> mergeMap = new HashMap<>();
mergeMap.put(KNNGraphValue.MERGE_CURRENT_OPERATIONS.getName(), KNNGraphValue.MERGE_CURRENT_OPERATIONS.getValue());
mergeMap.put(KNNGraphValue.MERGE_CURRENT_DOCS.getName(), KNNGraphValue.MERGE_CURRENT_DOCS.getValue());
mergeMap.put(KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.getName(), KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.getValue());
mergeMap.put(KNNGraphValue.MERGE_TOTAL_OPERATIONS.getName(), KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue());
mergeMap.put(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getName(), KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue());
mergeMap.put(KNNGraphValue.MERGE_TOTAL_DOCS.getName(), KNNGraphValue.MERGE_TOTAL_DOCS.getValue());
mergeMap.put(KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getName(), KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue());
Map<String, Object> refreshMap = new HashMap<>();
refreshMap.put(KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getName(), KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue());
refreshMap.put(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getName(), KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue());
Map<String, Map<String, Object>> graphStatsMap = new HashMap<>();
graphStatsMap.put(StatNames.MERGE.getName(), mergeMap);
graphStatsMap.put(StatNames.REFRESH.getName(), refreshMap);
return graphStatsMap;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ public enum StatNames {
TRAINING_MEMORY_USAGE("training_memory_usage"),
TRAINING_MEMORY_USAGE_PERCENTAGE("training_memory_usage_percentage"),
SCRIPT_QUERY_ERRORS(KNNCounter.SCRIPT_QUERY_ERRORS.getName()),
KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName());
KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()),
GRAPH_STATS("graph_stats"),
REFRESH("refresh"),
MERGE("merge");
navneet1v marked this conversation as resolved.
Show resolved Hide resolved

private String name;

Expand Down
Loading
Loading