diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b29bbba4..1b17c8f63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements * Added support for ignore_unmapped in KNN queries. [#1071](https://github.com/opensearch-project/k-NN/pull/1071) +* Add graph creation stats to the KNNStats API. [#1141](https://github.com/opensearch-project/k-NN/pull/1141) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index d9a30f75b..bcc2bf369 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -9,6 +9,7 @@ import lombok.NonNull; import lombok.extern.log4j.Log4j2; import org.apache.lucene.store.ChecksumIndexInput; +import org.opensearch.common.StopWatch; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.MediaTypeRegistry; @@ -35,6 +36,7 @@ import org.apache.lucene.store.FilterDirectory; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.plugin.stats.KNNGraphValue; import java.io.Closeable; import java.io.IOException; @@ -53,6 +55,7 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; /** * This class writes the KNN docvalues to the segments @@ -76,7 +79,13 @@ class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable { public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException { delegatee.addBinaryField(field, valuesProducer); if (isKNNBinaryFieldRequired(field)) { - addKNNBinaryField(field, valuesProducer); + StopWatch stopWatch = new StopWatch(); + stopWatch.start(); + addKNNBinaryField(field, valuesProducer, false, true); + stopWatch.stop(); + long time_in_millis = stopWatch.totalTime().millis(); + KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis); + logger.warn("Refresh operation complete in " + time_in_millis + " ms"); } } @@ -97,7 +106,8 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) { return KNNEngine.getEngine(engineName); } - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException { + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) + throws IOException { // Get values to be indexed BinaryDocValues values = valuesProducer.getBinary(field); KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); @@ -105,6 +115,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) logger.info("Skipping engine index creation as there are no vectors or docs in the documents"); return; } + long arraySize = calculateArraySize(pair.vectors, pair.serializationMode); + if (isMerge) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); + } // Increment counter for number of graph index requests KNNCounter.GRAPH_INDEX_REQUESTS.increment(); // Create library index either from model or from scratch @@ -135,6 +151,14 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); } + if (isMerge) { + recordMergeStats(pair.docs.length, arraySize); + } + + if (isRefresh) { + recordRefreshStats(); + } + // This is a bit of a hack. We have to create an output here and then immediately close it to ensure that // engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will // not be marked as added to the directory. @@ -143,6 +167,19 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) writeFooter(indexPath, engineFileName); } + private void recordMergeStats(int length, long arraySize) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); + KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize); + KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); + KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(length); + KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize); + } + + private void recordRefreshStats() { + KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); + } + private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { Map parameters = ImmutableMap.of( KNNConstants.INDEX_THREAD_QTY, @@ -210,7 +247,13 @@ public void merge(MergeState mergeState) { for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) { DocValuesType type = fieldInfo.getDocValuesType(); if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) { - addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState)); + StopWatch stopWatch = new StopWatch(); + stopWatch.start(); + addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState), true, false); + stopWatch.stop(); + long time_in_millis = stopWatch.totalTime().millis(); + KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis); + logger.warn("Merge operation complete in " + time_in_millis + " ms"); } } } catch (Exception e) { diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index eef9e6863..02ab2d833 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -17,30 +17,72 @@ public class KNNCodecUtil { public static final String HNSW_EXTENSION = ".hnsw"; public static final String HNSW_COMPOUND_EXTENSION = ".hnswc"; + // Floats are 4 bytes in size + public static final int FLOAT_BYTE_SIZE = 4; + // References to objects are 4 bytes in size + public static final int JAVA_REFERENCE_SIZE = 4; + // Each array in Java has a header that is 12 bytes + public static final int JAVA_ARRAY_HEADER_SIZE = 12; + // Java rounds each array size up to multiples of 8 bytes + public static final int JAVA_ROUNDING_NUMBER = 8; public static final class Pair { - public Pair(int[] docs, float[][] vectors) { + public Pair(int[] docs, float[][] vectors, SerializationMode serializationMode) { this.docs = docs; this.vectors = vectors; + this.serializationMode = serializationMode; } public int[] docs; public float[][] vectors; + public SerializationMode serializationMode; } public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException { ArrayList vectorList = new ArrayList<>(); ArrayList docIdList = new ArrayList<>(); + SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { BytesRef bytesref = values.binaryValue(); try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) { + serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream); final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); final float[] vector = vectorSerializer.byteToFloatArray(byteStream); vectorList.add(vector); } docIdList.add(doc); } - return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorList.toArray(new float[][] {})); + return new KNNCodecUtil.Pair( + docIdList.stream().mapToInt(Integer::intValue).toArray(), + vectorList.toArray(new float[][] {}), + serializationMode + ); + } + + public static long calculateArraySize(float[][] vectors, SerializationMode serializationMode) { + int vectorLength = vectors[0].length; + int numVectors = vectors.length; + if (serializationMode == SerializationMode.ARRAY) { + int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; + if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { + vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; + } + int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE; + if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { + vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; + } + return vectorsSize; + } else { + int vectorSize = vectorLength * FLOAT_BYTE_SIZE; + if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { + vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; + } + int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); + if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { + vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; + } + return vectorsSize; + } } public static String buildEngineFileName(String segmentName, String latestBuildVersion, String fieldName, String extension) { diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java index f02da0949..5c1e4ca9b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java @@ -56,7 +56,7 @@ public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayIn return getSerializerBySerializationMode(serializationMode); } - private static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) { + static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) { int numberOfAvailableBytesInStream = byteStream.available(); if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) { return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS); diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNGraphValue.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNGraphValue.java new file mode 100644 index 000000000..b33b59e36 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNGraphValue.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.stats; + +import java.util.concurrent.atomic.AtomicLong; + +/** + * Contains a map to keep track of different graph values + */ +public enum KNNGraphValue { + + REFRESH_TOTAL_OPERATIONS("total"), + REFRESH_TOTAL_TIME_IN_MILLIS("total_time_in_millis"), + MERGE_CURRENT_OPERATIONS("current"), + MERGE_CURRENT_DOCS("current_docs"), + MERGE_CURRENT_SIZE_IN_BYTES("current_size_in_bytes"), + MERGE_TOTAL_OPERATIONS("total"), + MERGE_TOTAL_TIME_IN_MILLIS("total_time_in_millis"), + MERGE_TOTAL_DOCS("total_docs"), + MERGE_TOTAL_SIZE_IN_BYTES("total_size_in_bytes"); + + private String name; + private AtomicLong value; + + /** + * Constructor + * + * @param name name of the graph value + */ + KNNGraphValue(String name) { + this.name = name; + this.value = new AtomicLong(0); + } + + /** + * Get name of value + * + * @return name + */ + public String getName() { + return name; + } + + /** + * Get the graph value + * + * @return value + */ + public Long getValue() { + return value.get(); + } + + /** + * Increment the graph value + */ + public void increment() { + value.getAndIncrement(); + } + + /** + * Decrement the graph value + */ + public void decrement() { + value.getAndDecrement(); + } + + /** + * Increment the graph value by a specified amount + * + * @param delta The amount to increment + */ + public void incrementBy(long delta) { + value.getAndAdd(delta); + } + + /** + * Decrement the graph value by a specified amount + * + * @param delta The amount to decrement + */ + public void decrementBy(long delta) { + value.set(value.get() - delta); + } + + /** + * @param value graph value + * Set the graph value + */ + public void set(long value) { + this.value.set(value); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java index 66b3f215b..07d129652 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java @@ -24,6 +24,7 @@ import java.time.temporal.ChronoUnit; import java.util.HashMap; import java.util.Map; +import java.util.function.Supplier; /** * Class represents all stats the plugin keeps track of @@ -84,6 +85,7 @@ private Map> buildStatsMap() { addEngineStats(builder); addScriptStats(builder); addModelStats(builder); + addGraphStats(builder); return builder.build(); } @@ -169,4 +171,31 @@ private void addModelStats(ImmutableMap.Builder> builder) { new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getTrainingSizeAsPercentage)) ); } + + private void addGraphStats(ImmutableMap.Builder> builder) { + builder.put(StatNames.GRAPH_STATS.getName(), new KNNStat<>(false, new Supplier>>() { + @Override + public Map> get() { + return createGraphStatsMap(); + } + })); + } + + private Map> createGraphStatsMap() { + Map mergeMap = new HashMap<>(); + mergeMap.put(KNNGraphValue.MERGE_CURRENT_OPERATIONS.getName(), KNNGraphValue.MERGE_CURRENT_OPERATIONS.getValue()); + mergeMap.put(KNNGraphValue.MERGE_CURRENT_DOCS.getName(), KNNGraphValue.MERGE_CURRENT_DOCS.getValue()); + mergeMap.put(KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.getName(), KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.getValue()); + mergeMap.put(KNNGraphValue.MERGE_TOTAL_OPERATIONS.getName(), KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + mergeMap.put(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getName(), KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue()); + mergeMap.put(KNNGraphValue.MERGE_TOTAL_DOCS.getName(), KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); + mergeMap.put(KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getName(), KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + Map refreshMap = new HashMap<>(); + refreshMap.put(KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getName(), KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); + refreshMap.put(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getName(), KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue()); + Map> graphStatsMap = new HashMap<>(); + graphStatsMap.put(StatNames.MERGE.getName(), mergeMap); + graphStatsMap.put(StatNames.REFRESH.getName(), refreshMap); + return graphStatsMap; + } } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java index a098dd8b5..e9ed2b126 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java @@ -41,7 +41,10 @@ public enum StatNames { TRAINING_MEMORY_USAGE("training_memory_usage"), TRAINING_MEMORY_USAGE_PERCENTAGE("training_memory_usage_percentage"), SCRIPT_QUERY_ERRORS(KNNCounter.SCRIPT_QUERY_ERRORS.getName()), - KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()); + KNN_QUERY_WITH_FILTER_REQUESTS(KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.getName()), + GRAPH_STATS("graph_stats"), + REFRESH("refresh"), + MERGE("merge"); private String name; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index 58f4b6e39..6af83de87 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -39,12 +39,14 @@ import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.stats.KNNCounter; +import org.opensearch.knn.plugin.stats.KNNGraphValue; import java.io.IOException; import java.util.Map; import java.util.concurrent.ExecutionException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -94,7 +96,7 @@ public void testAddBinaryField_withKNN() throws IOException { KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { @Override - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) { + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) { called[0] = true; } }; @@ -118,7 +120,7 @@ public void testAddBinaryField_withoutKNN() throws IOException { KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { @Override - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) { + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) { called[0] = true; } }; @@ -133,9 +135,17 @@ public void testAddKNNBinaryField_noVectors() throws IOException { // When there are no new vectors, no more graph index requests should be added RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(0, 128); Long initialGraphIndexRequests = KNNCounter.GRAPH_INDEX_REQUESTS.getCount(); + Long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); + Long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + Long initialMergeSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue(); + Long initialMergeDocs = KNNGraphValue.MERGE_TOTAL_DOCS.getValue(); KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, null); - knn80DocValuesConsumer.addKNNBinaryField(null, randomVectorDocValuesProducer); + knn80DocValuesConsumer.addKNNBinaryField(null, randomVectorDocValuesProducer, true, true); assertEquals(initialGraphIndexRequests, KNNCounter.GRAPH_INDEX_REQUESTS.getCount()); + assertEquals(initialRefreshOperations, KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); + assertEquals(initialMergeOperations, KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + assertEquals(initialMergeSize, KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + assertEquals(initialMergeDocs, KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); } public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException { @@ -174,10 +184,13 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); + long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -188,6 +201,12 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException // The document should be readable by nmslib assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + + // The graph creation statistics should be updated + assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); + assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException { @@ -218,10 +237,13 @@ public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); + long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -232,6 +254,12 @@ public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException // The document should be readable by nmslib assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + + // The graph creation statistics should be updated + assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); + assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException { @@ -269,10 +297,13 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); + long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -283,6 +314,12 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException // The document should be readable by faiss assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + + // The graph creation statistics should be updated + assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); + assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } public void testAddKNNBinaryField_fromModel_faiss() throws IOException, ExecutionException, InterruptedException { @@ -345,10 +382,13 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); + long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); RandomVectorDocValuesProducer randomVectorDocValuesProducer = new RandomVectorDocValuesProducer(docsInSegment, dimension); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -359,6 +399,13 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio // The document should be readable by faiss assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + + // The graph creation statistics should be updated + assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); + assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + } public void testMerge_exception() throws IOException { @@ -426,6 +473,6 @@ public void testAddBinaryField_luceneEngine_noInvocations_addKNNBinary() throws knn80DocValuesConsumer.addBinaryField(fieldInfo, docValuesProducer); verify(delegate, times(1)).addBinaryField(fieldInfo, docValuesProducer); - verify(knn80DocValuesConsumer, never()).addKNNBinaryField(any(), any()); + verify(knn80DocValuesConsumer, never()).addKNNBinaryField(any(), any(), eq(false), eq(true)); } } diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index 12751dff8..f22e4b267 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -16,6 +16,7 @@ import java.io.FileReader; import java.io.IOException; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.util.SerializationMode; import org.opensearch.knn.plugin.script.KNNScoringUtil; import java.util.Comparator; import java.util.Random; @@ -283,7 +284,7 @@ private KNNCodecUtil.Pair readIndexData(String path) throws IOException { } } - return new KNNCodecUtil.Pair(idsArray, vectorsArray); + return new KNNCodecUtil.Pair(idsArray, vectorsArray, SerializationMode.COLLECTION_OF_FLOATS); } private float[][] readQueries(String path) throws IOException {