Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update KNN80BinaryDocValues reader count live docs and use live docs as initial capacity to initialize vector address #1595

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public class KNNSettings {
public static final String KNN_ALGO_PARAM_INDEX_THREAD_QTY = "knn.algo_param.index_thread_qty";
public static final String KNN_MEMORY_CIRCUIT_BREAKER_ENABLED = "knn.memory.circuit_breaker.enabled";
public static final String KNN_MEMORY_CIRCUIT_BREAKER_LIMIT = "knn.memory.circuit_breaker.limit";
public static final String KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB = "knn.vector_streaming_memory.limit";
public static final String KNN_CIRCUIT_BREAKER_TRIGGERED = "knn.circuit_breaker.triggered";
public static final String KNN_CACHE_ITEM_EXPIRY_ENABLED = "knn.cache.item.expiry.enabled";
public static final String KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES = "knn.cache.item.expiry.minutes";
Expand All @@ -93,13 +94,23 @@ public class KNNSettings {
public static final Integer KNN_DEFAULT_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // By default, set aside 10% of the JVM for the limit
public static final Integer KNN_MAX_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE = 25; // Model cache limit cannot exceed 25% of the JVM heap
public static final String KNN_DEFAULT_MEMORY_CIRCUIT_BREAKER_LIMIT = "50%";
public static final String KNN_DEFAULT_VECTOR_STREAMING_MEMORY_LIMIT_PCT = "1%";

public static final Integer ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE = -1;

/**
* Settings Definition
*/

// This setting controls how much memory should be used to transfer vectors from Java to JNI Layer. The default
// 1% of the JVM heap
public static final Setting<ByteSizeValue> KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING = Setting.memorySizeSetting(
KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB,
KNN_DEFAULT_VECTOR_STREAMING_MEMORY_LIMIT_PCT,
Setting.Property.Dynamic,
Setting.Property.NodeScope
);

public static final Setting<String> INDEX_KNN_SPACE_TYPE = Setting.simpleString(
KNN_SPACE_TYPE,
INDEX_KNN_DEFAULT_SPACE_TYPE,
Expand Down Expand Up @@ -354,6 +365,10 @@ private Setting<?> getSetting(String key) {
return KNN_FAISS_AVX2_DISABLED_SETTING;
}

if (KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB.equals(key)) {
return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING;
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
}

Expand All @@ -371,7 +386,8 @@ public List<Setting<?>> getSettings() {
MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING,
MODEL_CACHE_SIZE_LIMIT_SETTING,
ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING,
KNN_FAISS_AVX2_DISABLED_SETTING
KNN_FAISS_AVX2_DISABLED_SETTING,
KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING
);
return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList());
}
Expand Down Expand Up @@ -475,6 +491,10 @@ public void onFailure(Exception e) {
});
}

/**
 * Returns the memory limit used when streaming vectors from the Java heap to the JNI layer.
 * Backed by the dynamic node setting {@code knn.vector_streaming_memory.limit}, which
 * defaults to 1% of the JVM heap (see KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING).
 *
 * @return the configured vector-streaming memory limit as a {@link ByteSizeValue}
 */
public static ByteSizeValue getVectorStreamingMemoryLimit() {
    return KNNSettings.state().getSettingValue(KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB);
}

/**
*
* @param index Name of the index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import lombok.Getter;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocIDMerger;
Expand All @@ -15,10 +16,13 @@
/**
* A per-document kNN numeric value.
*/
class KNN80BinaryDocValues extends BinaryDocValues {
public class KNN80BinaryDocValues extends BinaryDocValues {

private DocIDMerger<BinaryDocValuesSub> docIDMerger;

@Getter
private long totalLiveDocs;

// Package-private: instances are created during merge with a DocIDMerger that
// iterates the binary doc values of all merging sub-readers.
KNN80BinaryDocValues(DocIDMerger<BinaryDocValuesSub> docIdMerger) {
    this.docIDMerger = docIdMerger;
}
Expand Down Expand Up @@ -61,4 +65,14 @@ public long cost() {
public BytesRef binaryValue() throws IOException {
return current.getValues().binaryValue();
}
};

/**
* Builder pattern like setter for setting totalLiveDocs. We can use setter also. But this way the code is clean.
* @param totalLiveDocs int
* @return {@link KNN80BinaryDocValues}
*/
public KNN80BinaryDocValues setTotalLiveDocs(long totalLiveDocs) {
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
this.totalLiveDocs = totalLiveDocs;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.codec.util.BinaryDocValuesSub;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.index.BinaryDocValues;
Expand All @@ -14,12 +18,14 @@
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
* Reader for KNNDocValues from the segments
*/
@Log4j2
class KNN80DocValuesReader extends EmptyDocValuesProducer {

private final MergeState mergeState;
Expand All @@ -30,6 +36,7 @@ class KNN80DocValuesReader extends EmptyDocValuesProducer {

@Override
public BinaryDocValues getBinary(FieldInfo field) {
long totalLiveDocs = 0;
try {
List<BinaryDocValuesSub> subs = new ArrayList<>(this.mergeState.docValuesProducers.length);
for (int i = 0; i < this.mergeState.docValuesProducers.length; i++) {
Expand All @@ -41,13 +48,49 @@ public BinaryDocValues getBinary(FieldInfo field) {
values = docValuesProducer.getBinary(readerFieldInfo);
}
if (values != null) {
totalLiveDocs = totalLiveDocs + getLiveDocsCount(values, this.mergeState.liveDocs[i]);
// docValues will be consumed when liveDocs are not null, hence resetting the docsValues
// pointer.
values = this.mergeState.liveDocs[i] != null ? docValuesProducer.getBinary(readerFieldInfo) : values;
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved

subs.add(new BinaryDocValuesSub(mergeState.docMaps[i], values));
}
}
}
return new KNN80BinaryDocValues(DocIDMerger.of(subs, mergeState.needsIndexSort));
return new KNN80BinaryDocValues(DocIDMerger.of(subs, mergeState.needsIndexSort)).setTotalLiveDocs(totalLiveDocs);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

/**
* This function return the liveDocs count present in the BinaryDocValues. If the liveDocsBits is null, then we
* can use {@link BinaryDocValues#cost()} function to get max docIds. But if LiveDocsBits is not null, then we
* iterate over the BinaryDocValues and validate if the docId is present in the live docs bits or not.
*
* @param binaryDocValues {@link BinaryDocValues}
* @param liveDocsBits {@link Bits}
* @return total number of liveDocs.
* @throws IOException
*/
private long getLiveDocsCount(final BinaryDocValues binaryDocValues, final Bits liveDocsBits) throws IOException {
long liveDocs = 0;
if (liveDocsBits != null) {
int docId;
// This is not the right way to log the time. I create a github issue for adding an annotation to track
// the time. https://github.com/opensearch-project/k-NN/issues/1594
StopWatch stopWatch = new StopWatch();
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
stopWatch.start();
for (docId = binaryDocValues.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS; docId = binaryDocValues.nextDoc()) {
if (liveDocsBits.get(docId)) {
liveDocs++;
}
}
stopWatch.stop();
log.debug("Time taken to iterate over binary doc values: {} ms", stopWatch.totalTime().millis());
} else {
liveDocs = binaryDocValues.cost();
}
return liveDocs;
}
}
46 changes: 36 additions & 10 deletions src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues;
import org.opensearch.knn.jni.JNICommons;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class KNNCodecUtil {

public static final String HNSW_EXTENSION = ".hnsw";
public static final String HNSW_COMPOUND_EXTENSION = ".hnswc";
// Floats are 4 bytes in size
public static final int FLOAT_BYTE_SIZE = 4;
// References to objects are 4 bytes in size
Expand All @@ -44,28 +44,44 @@ public static final class Pair {
}

/**
 * Reads all vectors from the given {@link BinaryDocValues} and copies them off-heap via
 * {@link JNICommons#storeVectorData}, streaming in chunks so that at most
 * {@code knn.vector_streaming_memory.limit} bytes of vectors are buffered on the Java
 * heap at once.
 *
 * @param values binary doc values containing serialized float vectors
 * @return a {@link KNNCodecUtil.Pair} of (docIds, native vector address, dimension,
 *         serialization mode)
 * @throws IOException if reading the doc values fails
 */
public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException {
    List<float[]> vectorList = new ArrayList<>();
    List<Integer> docIdList = new ArrayList<>();
    long vectorAddress = 0;
    int dimension = 0;
    SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS;

    // Exact live-doc count when merging (KNN80BinaryDocValues), otherwise cost() estimate.
    // Used to size the native allocation up front.
    long totalLiveDocs = getTotalLiveDocsCount(values);
    long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes();
    long vectorsPerTransfer = Integer.MIN_VALUE;

    for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
        BytesRef bytesref = values.binaryValue();
        try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) {
            serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream);
            final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
            final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
            dimension = vector.length;

            if (vectorsPerTransfer == Integer.MIN_VALUE) {
                // Number of vectors whose combined size fits within the streaming memory
                // limit, at least 1. NOTE(review): the original computed
                // (dimension * Float.BYTES * totalLiveDocs) / limit, which is the number of
                // transfers rather than the number of vectors per transfer — that grows with
                // index size and defeats the memory cap. Confirm intent against the setting's
                // documentation ("how much memory should be used to transfer vectors").
                vectorsPerTransfer = Math.max(1, vectorsStreamingMemoryLimit / ((long) dimension * Float.BYTES));
            }
            if (vectorList.size() == vectorsPerTransfer) {
                vectorAddress = JNICommons.storeVectorData(
                    vectorAddress,
                    vectorList.toArray(new float[][] {}),
                    totalLiveDocs * dimension
                );
                // We should probably come up with a better way to reuse the vectorList memory
                // which we have created. Problem here is doing like this can lead to a lot of
                // list memory which is of no use and will be garbage collected later on, but
                // it creates pressure on JVM. We should revisit this.
                vectorList = new ArrayList<>();
            }
            vectorList.add(vector);
        }
        docIdList.add(doc);
    }
    // Flush any remaining vectors that did not fill a complete transfer chunk.
    if (vectorList.isEmpty() == false) {
        vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
    }
    return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorAddress, dimension, serializationMode);
}
Expand Down Expand Up @@ -105,4 +121,14 @@ public static String buildEngineFilePrefix(String segmentName) {
public static String buildEngineFileSuffix(String fieldName, String extension) {
return String.format("_%s%s", fieldName, extension);
}

/**
 * Resolves the total live-document count for the supplied doc values. A
 * {@link KNN80BinaryDocValues} produced during merge carries an exact live-doc count;
 * for any other implementation we fall back to {@link BinaryDocValues#cost()}.
 *
 * @param binaryDocValues doc values whose live-doc count is needed
 * @return the exact live-doc count when available, otherwise the cost() estimate
 */
private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) {
    if (binaryDocValues instanceof KNN80BinaryDocValues) {
        return ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs();
    }
    return binaryDocValues.cost();
}
}
104 changes: 104 additions & 0 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -173,6 +175,108 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() {
fail("Graphs are not getting evicted");
}

@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() {
    String indexName = "test-index-1";
    String fieldName = "test-field-1";

    KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW);
    SpaceType spaceType = SpaceType.L2;

    List<Integer> mValues = ImmutableList.of(16, 32, 64, 128);
    List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
    List<Integer> efSearchValues = ImmutableList.of(16, 32, 64, 128);

    Integer dimension = testData.indexData.vectors[0].length;

    // Create an index with a faiss HNSW-flat field using randomized m/ef parameters
    XContentBuilder builder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject("properties")
        .startObject(fieldName)
        .field("type", "knn_vector")
        .field("dimension", dimension)
        .startObject(KNNConstants.KNN_METHOD)
        .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName())
        .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
        .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
        .startObject(KNNConstants.PARAMETERS)
        .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size())))
        .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size())))
        .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size())))
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject();

    Map<String, Object> mappingMap = xContentBuilderToMap(builder);
    String mapping = builder.toString();

    createKnnIndex(indexName, mapping);
    assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName)));

    // Index the test data
    for (int i = 0; i < testData.indexData.docs.length; i++) {
        addKnnDoc(
            indexName,
            Integer.toString(testData.indexData.docs[i]),
            fieldName,
            Floats.asList(testData.indexData.vectors[i]).toArray()
        );
    }

    // Assert we have the right number of documents in the index
    refreshAllNonSystemIndices();
    assertEquals(testData.indexData.docs.length, getDocCount(indexName));

    // Pick 10 distinct random indices into the docs array to delete.
    // NOTE: RandomizedTest.randomInt(max) is inclusive of max, so the bound must be
    // length - 1 to avoid an ArrayIndexOutOfBoundsException below.
    final Set<Integer> docIdsToBeDeleted = new HashSet<>();
    while (docIdsToBeDeleted.size() < 10) {
        docIdsToBeDeleted.add(randomInt(testData.indexData.docs.length - 1));
    }

    for (Integer id : docIdsToBeDeleted) {
        deleteKnnDoc(indexName, Integer.toString(testData.indexData.docs[id]));
    }
    refreshAllNonSystemIndices();
    // Force-merge so the merge path (live-docs counting in KNN80DocValuesReader) is exercised
    forceMergeKnnIndex(indexName, 3);

    assertEquals(testData.indexData.docs.length - 10, getDocCount(indexName));

    // Validate top-k search still returns k results with correct scores after deletions
    int k = 10;
    for (int i = 0; i < testData.queries.length; i++) {
        Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k);
        String responseBody = EntityUtils.toString(response.getEntity());
        List<KNNResult> knnResults = parseSearchResponse(responseBody, fieldName);
        assertEquals(k, knnResults.size());

        List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
        for (int j = 0; j < k; j++) {
            float[] primitiveArray = knnResults.get(j).getVector();
            assertEquals(
                KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
                actualScores.get(j),
                0.0001
            );
        }
    }

    // Delete index
    deleteKNNIndex(indexName);

    // Search every 5 seconds 14 times to confirm graph gets evicted
    int intervals = 14;
    for (int i = 0; i < intervals; i++) {
        if (getTotalGraphsInCache() == 0) {
            return;
        }

        Thread.sleep(5 * 1000);
    }

    fail("Graphs are not getting evicted");
}

@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {
String indexName = "test-index";
Expand Down
Loading
Loading