Skip to content

Commit

Permalink
Update Faiss engine to allow PQ and HNSW
Browse files Browse the repository at this point in the history
Updates faiss engine to enable hnsw and faiss to work together. For
HNSW, code_size must be equal to 8 (refer to
facebookresearch/faiss#3027). Therefore, the
index description string "HNSW32,PQXxY" does not work. Only "HNSW32,PQX"
ends up working.

Additionally, adds several unit tests and integration tests in order to
validate the functionality.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Aug 30, 2023
1 parent 8994de6 commit f71fb24
Show file tree
Hide file tree
Showing 7 changed files with 491 additions and 88 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Improved the logic to switch to exact search for restrictive filters search for better recall. [#1059](https://github.com/opensearch-project/k-NN/pull/1059)
* Added max distance computation logic to enhance the switch to exact search in filtered Nearest Neighbor Search. [#1066](https://github.com/opensearch-project/k-NN/pull/1066)
### Bug Fixes
* Update Faiss parameter construction to allow HNSW+PQ to work [#1074](https://github.com/opensearch-project/k-NN/pull/1074)
### Infrastructure
### Documentation
### Maintenance
Expand Down
160 changes: 104 additions & 56 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
Expand Down Expand Up @@ -64,9 +65,7 @@ class Faiss extends NativeLibrary {
Collections.emptyMap()
);

// TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now,
// we do not have a way to base validation off of dimension. Failure will happen during training in JNI.
private final static Map<String, MethodComponent> encoderComponents = ImmutableMap.of(
private final static Map<String, MethodComponent> COMMON_ENCODERS = ImmutableMap.of(
KNNConstants.ENCODER_FLAT,
MethodComponent.Builder.builder(KNNConstants.ENCODER_FLAT)
.setMapGenerator(
Expand All @@ -76,62 +75,111 @@ class Faiss extends NativeLibrary {
methodComponentContext
).build())
)
.build(),
KNNConstants.ENCODER_PQ,
MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ)
.addParameter(
ENCODER_PARAMETER_PQ_M,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_M,
ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT
)
)
.addParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_SIZE_LIMIT
)
)
.setRequiresTraining(true)
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
FAISS_PQ_DESCRIPTION,
methodComponent,
methodComponentContext
).addParameter(ENCODER_PARAMETER_PQ_M, "", "").addParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, "x", "").build())
.build()
);

// TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now,
// we do not have a way to base validation off of dimension. Failure will happen during training in JNI.
// Define methods supported by faiss
private final static Map<String, MethodComponent> HNSW_ENCODERS = ImmutableMap.<String, MethodComponent>builder()
.putAll(
ImmutableMap.of(
KNNConstants.ENCODER_PQ,
MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ)
.addParameter(
ENCODER_PARAMETER_PQ_M,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_M,
ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT
)
)
.addParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT,
v -> Objects.equals(v, ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT)
)
)
.setRequiresTraining(true)
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
FAISS_PQ_DESCRIPTION,
methodComponent,
methodComponentContext
).addParameter(ENCODER_PARAMETER_PQ_M, "", "").build())
)
.setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> {
int codeSize = ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT;
return ((4L * (1L << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1;
})
.build()
)
.setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> {
// Size estimate formula: (4 * d * 2^code_size) / 1024 + 1

// Get value of code size passed in by user
Object codeSizeObject = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE);

// If not specified, get default value of code size
if (codeSizeObject == null) {
Parameter<?> codeSizeParameter = methodComponent.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE);
if (codeSizeParameter == null) {
throw new IllegalStateException(
String.format("%s is not a valid parameter. This is a bug.", ENCODER_PARAMETER_PQ_CODE_SIZE)
);
}
)
.putAll(COMMON_ENCODERS)
.build();

codeSizeObject = codeSizeParameter.getDefaultValue();
}
private final static Map<String, MethodComponent> IVF_ENCODERS = ImmutableMap.<String, MethodComponent>builder()
.putAll(
ImmutableMap.of(
KNNConstants.ENCODER_PQ,
MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ)
.addParameter(
ENCODER_PARAMETER_PQ_M,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_M,
ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT
)
)
.addParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_SIZE_LIMIT
)
)
.setRequiresTraining(true)
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
FAISS_PQ_DESCRIPTION,
methodComponent,
methodComponentContext
).addParameter(ENCODER_PARAMETER_PQ_M, "", "").addParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, "x", "").build())
)
.setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> {
// Size estimate formula: (4 * d * 2^code_size) / 1024 + 1

if (!(codeSizeObject instanceof Integer)) {
throw new IllegalStateException(String.format("%s must be an integer.", ENCODER_PARAMETER_PQ_CODE_SIZE));
}
// Get value of code size passed in by user
Object codeSizeObject = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE);

int codeSize = (Integer) codeSizeObject;
return ((4L * (1L << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1;
})
.build()
);
// If not specified, get default value of code size
if (codeSizeObject == null) {
Parameter<?> codeSizeParameter = methodComponent.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE);
if (codeSizeParameter == null) {
throw new IllegalStateException(
String.format("%s is not a valid parameter. This is a bug.", ENCODER_PARAMETER_PQ_CODE_SIZE)
);
}

codeSizeObject = codeSizeParameter.getDefaultValue();
}

if (!(codeSizeObject instanceof Integer)) {
throw new IllegalStateException(String.format("%s must be an integer.", ENCODER_PARAMETER_PQ_CODE_SIZE));
}

int codeSize = (Integer) codeSizeObject;
return ((4L * (1L << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1;
})
.build()
)
)
.putAll(COMMON_ENCODERS)
.build();

// Define methods supported by faiss
private final static Map<String, KNNMethod> METHODS = ImmutableMap.of(
METHOD_HNSW,
KNNMethod.Builder.builder(
Expand All @@ -158,7 +206,7 @@ class Faiss extends NativeLibrary {
)
.addParameter(
METHOD_ENCODER_PARAMETER,
new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, encoderComponents)
new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, HNSW_ENCODERS)
)
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
Expand Down Expand Up @@ -190,7 +238,7 @@ class Faiss extends NativeLibrary {
)
.addParameter(
METHOD_ENCODER_PARAMETER,
new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, encoderComponents)
new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, IVF_ENCODERS)
)
.setRequiresTraining(true)
.setMapGenerator(
Expand Down
119 changes: 118 additions & 1 deletion src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@
import java.util.TreeMap;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
Expand All @@ -62,7 +67,8 @@ public static void setUpClass() throws IOException {
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath());
}

public void testEndToEnd_fromMethod() throws Exception {
@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() {
String indexName = "test-index-1";
String fieldName = "test-field-1";

Expand Down Expand Up @@ -150,6 +156,117 @@ public void testEndToEnd_fromMethod() throws Exception {
fail("Graphs are not getting evicted");
}

@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {
String indexName = "test-index";
String fieldName = "test-field";
String trainingIndexName = "training-index";
String trainingFieldName = "training-field";

String modelId = "test-model";
String modelDescription = "test model";

List<Integer> mValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efSearchValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> pqMValues = ImmutableList.of(2, 4, 8);

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because thats the only valid code_size for HNSWPQ
int trainingDataCount = 256;

SpaceType spaceType = SpaceType.L2;

Integer dimension = testData.indexData.vectors[0].length;

XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, METHOD_HNSW)
.field(KNN_ENGINE, FAISS_NAME)
.startObject(PARAMETERS)
.field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size())))
.field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size())))
.field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size())))
.startObject(METHOD_ENCODER_PARAMETER)
.field(NAME, ENCODER_PQ)
.startObject(PARAMETERS)
.field(ENCODER_PARAMETER_PQ_M, pqMValues.get(random().nextInt(pqMValues.size())))
.endObject()
.endObject()
.endObject()
.endObject();
Map<String, Object> in = xContentBuilderToMap(xContentBuilder);

createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, in, trainingDataCount);
assertTrainingSucceeds(modelId, 180, 1000);

// Create an index
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(fieldName)
.field("type", "knn_vector")
.field("model_id", modelId)
.endObject()
.endObject()
.endObject();

Map<String, Object> mappingMap = xContentBuilderToMap(builder);
String mapping = builder.toString();

createKnnIndex(indexName, mapping);
assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName)));

// Index the test data
for (int i = 0; i < testData.indexData.docs.length; i++) {
addKnnDoc(
indexName,
Integer.toString(testData.indexData.docs[i]),
fieldName,
Floats.asList(testData.indexData.vectors[i]).toArray()
);
}

// Assert we have the right number of documents in the index
refreshAllNonSystemIndices();
assertEquals(testData.indexData.docs.length, getDocCount(indexName));

int k = 10;
for (int i = 0; i < testData.queries.length; i++) {
Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, fieldName);
assertEquals(k, knnResults.size());

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
0.0001
);
}
}

// Delete index
deleteKNNIndex(indexName);
deleteModel(modelId);

// Search every 5 seconds 14 times to confirm graph gets evicted
int intervals = 14;
for (int i = 0; i < intervals; i++) {
if (getTotalGraphsInCache() == 0) {
return;
}

Thread.sleep(5 * 1000);
}

fail("Graphs are not getting evicted");
}

public void testDocUpdate() throws IOException {
String indexName = "test-index-1";
String fieldName = "test-field-1";
Expand Down
Loading

0 comments on commit f71fb24

Please sign in to comment.