Skip to content

Commit

Permalink
Update Faiss engine to allow PQ and HNSW (opensearch-project#1074)
Browse files Browse the repository at this point in the history
Updates faiss engine to enable hnsw and faiss to work together. For
HNSW, code_size must be equal to 8 (refer to
facebookresearch/faiss#3027). Therefore, the
index description string "HNSW32,PQXxY" does not work. Only "HNSW32,PQX"
ends up working.

Additionally, adds several unit tests and integration tests in order to
validate the functionality.

Signed-off-by: John Mazanec <[email protected]>
(cherry picked from commit ce47b1b)
  • Loading branch information
jmazanec15 committed Aug 31, 2023
1 parent f7b4415 commit f0dfb8d
Show file tree
Hide file tree
Showing 7 changed files with 492 additions and 89 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Improved the logic to switch to exact search for restrictive filters search for better recall. [#1059](https://github.com/opensearch-project/k-NN/pull/1059)
* Added max distance computation logic to enhance the switch to exact search in filtered Nearest Neighbor Search. [#1066](https://github.com/opensearch-project/k-NN/pull/1066)
### Bug Fixes
* Update Faiss parameter construction to allow HNSW+PQ to work [#1074](https://github.com/opensearch-project/k-NN/pull/1074)
### Infrastructure
### Documentation
### Maintenance
Expand Down
160 changes: 104 additions & 56 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
Expand Down Expand Up @@ -64,9 +65,7 @@ class Faiss extends NativeLibrary {
Collections.emptyMap()
);

// TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now,
// we do not have a way to base validation off of dimension. Failure will happen during training in JNI.
private final static Map<String, MethodComponent> encoderComponents = ImmutableMap.of(
private final static Map<String, MethodComponent> COMMON_ENCODERS = ImmutableMap.of(
KNNConstants.ENCODER_FLAT,
MethodComponent.Builder.builder(KNNConstants.ENCODER_FLAT)
.setMapGenerator(
Expand All @@ -76,62 +75,111 @@ class Faiss extends NativeLibrary {
methodComponentContext
).build())
)
.build(),
KNNConstants.ENCODER_PQ,
MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ)
.addParameter(
ENCODER_PARAMETER_PQ_M,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_M,
ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT
)
)
.addParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_SIZE_LIMIT
)
)
.setRequiresTraining(true)
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
FAISS_PQ_DESCRIPTION,
methodComponent,
methodComponentContext
).addParameter(ENCODER_PARAMETER_PQ_M, "", "").addParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, "x", "").build())
.build()
);

// TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now,
// we do not have a way to base validation off of dimension. Failure will happen during training in JNI.
// Define methods supported by faiss. See issue here: https://github.com/opensearch-project/k-NN/issues/1075
private final static Map<String, MethodComponent> HNSW_ENCODERS = ImmutableMap.<String, MethodComponent>builder()
.putAll(
ImmutableMap.of(
KNNConstants.ENCODER_PQ,
MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ)
.addParameter(
ENCODER_PARAMETER_PQ_M,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_M,
ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT
)
)
.addParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT,
v -> Objects.equals(v, ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT)
)
)
.setRequiresTraining(true)
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
FAISS_PQ_DESCRIPTION,
methodComponent,
methodComponentContext
).addParameter(ENCODER_PARAMETER_PQ_M, "", "").build())
)
.setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> {
int codeSize = ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT;
return ((4L * (1L << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1;
})
.build()
)
.setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> {
// Size estimate formula: (4 * d * 2^code_size) / 1024 + 1

// Get value of code size passed in by user
Object codeSizeObject = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE);

// If not specified, get default value of code size
if (codeSizeObject == null) {
Parameter<?> codeSizeParameter = methodComponent.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE);
if (codeSizeParameter == null) {
throw new IllegalStateException(
String.format("%s is not a valid parameter. This is a bug.", ENCODER_PARAMETER_PQ_CODE_SIZE)
);
}
)
.putAll(COMMON_ENCODERS)
.build();

codeSizeObject = codeSizeParameter.getDefaultValue();
}
private final static Map<String, MethodComponent> IVF_ENCODERS = ImmutableMap.<String, MethodComponent>builder()
.putAll(
ImmutableMap.of(
KNNConstants.ENCODER_PQ,
MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ)
.addParameter(
ENCODER_PARAMETER_PQ_M,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_M,
ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT
)
)
.addParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_CODE_SIZE,
ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_SIZE_LIMIT
)
)
.setRequiresTraining(true)
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
FAISS_PQ_DESCRIPTION,
methodComponent,
methodComponentContext
).addParameter(ENCODER_PARAMETER_PQ_M, "", "").addParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, "x", "").build())
)
.setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> {
// Size estimate formula: (4 * d * 2^code_size) / 1024 + 1

if (!(codeSizeObject instanceof Integer)) {
throw new IllegalStateException(String.format("%s must be an integer.", ENCODER_PARAMETER_PQ_CODE_SIZE));
}
// Get value of code size passed in by user
Object codeSizeObject = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE);

int codeSize = (Integer) codeSizeObject;
return ((4L * (1L << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1;
})
.build()
);
// If not specified, get default value of code size
if (codeSizeObject == null) {
Parameter<?> codeSizeParameter = methodComponent.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE);
if (codeSizeParameter == null) {
throw new IllegalStateException(
String.format("%s is not a valid parameter. This is a bug.", ENCODER_PARAMETER_PQ_CODE_SIZE)
);
}

codeSizeObject = codeSizeParameter.getDefaultValue();
}

if (!(codeSizeObject instanceof Integer)) {
throw new IllegalStateException(String.format("%s must be an integer.", ENCODER_PARAMETER_PQ_CODE_SIZE));
}

int codeSize = (Integer) codeSizeObject;
return ((4L * (1L << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1;
})
.build()
)
)
.putAll(COMMON_ENCODERS)
.build();

// Define methods supported by faiss
private final static Map<String, KNNMethod> METHODS = ImmutableMap.of(
METHOD_HNSW,
KNNMethod.Builder.builder(
Expand All @@ -158,7 +206,7 @@ class Faiss extends NativeLibrary {
)
.addParameter(
METHOD_ENCODER_PARAMETER,
new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, encoderComponents)
new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, HNSW_ENCODERS)
)
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
Expand Down Expand Up @@ -190,7 +238,7 @@ class Faiss extends NativeLibrary {
)
.addParameter(
METHOD_ENCODER_PARAMETER,
new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, encoderComponents)
new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, IVF_ENCODERS)
)
.setRequiresTraining(true)
.setMapGenerator(
Expand Down
119 changes: 118 additions & 1 deletion src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@
import java.util.TreeMap;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
Expand All @@ -62,7 +67,8 @@ public static void setUpClass() throws IOException {
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath());
}

public void testEndToEnd_fromMethod() throws Exception {
@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() {
String indexName = "test-index-1";
String fieldName = "test-field-1";

Expand Down Expand Up @@ -150,6 +156,117 @@ public void testEndToEnd_fromMethod() throws Exception {
fail("Graphs are not getting evicted");
}

@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {
String indexName = "test-index";
String fieldName = "test-field";
String trainingIndexName = "training-index";
String trainingFieldName = "training-field";

String modelId = "test-model";
String modelDescription = "test model";

List<Integer> mValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efSearchValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> pqMValues = ImmutableList.of(2, 4, 8);

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because thats the only valid code_size for HNSWPQ
int trainingDataCount = 256;

SpaceType spaceType = SpaceType.L2;

Integer dimension = testData.indexData.vectors[0].length;

XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, METHOD_HNSW)
.field(KNN_ENGINE, FAISS_NAME)
.startObject(PARAMETERS)
.field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size())))
.field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size())))
.field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size())))
.startObject(METHOD_ENCODER_PARAMETER)
.field(NAME, ENCODER_PQ)
.startObject(PARAMETERS)
.field(ENCODER_PARAMETER_PQ_M, pqMValues.get(random().nextInt(pqMValues.size())))
.endObject()
.endObject()
.endObject()
.endObject();
Map<String, Object> in = xContentBuilderToMap(xContentBuilder);

createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, in, trainingDataCount);
assertTrainingSucceeds(modelId, 180, 1000);

// Create an index
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(fieldName)
.field("type", "knn_vector")
.field("model_id", modelId)
.endObject()
.endObject()
.endObject();

Map<String, Object> mappingMap = xContentBuilderToMap(builder);
String mapping = builder.toString();

createKnnIndex(indexName, mapping);
assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName)));

// Index the test data
for (int i = 0; i < testData.indexData.docs.length; i++) {
addKnnDoc(
indexName,
Integer.toString(testData.indexData.docs[i]),
fieldName,
Floats.asList(testData.indexData.vectors[i]).toArray()
);
}

// Assert we have the right number of documents in the index
refreshAllNonSystemIndices();
assertEquals(testData.indexData.docs.length, getDocCount(indexName));

int k = 10;
for (int i = 0; i < testData.queries.length; i++) {
Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, fieldName);
assertEquals(k, knnResults.size());

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
0.0001
);
}
}

// Delete index
deleteKNNIndex(indexName);
deleteModel(modelId);

// Search every 5 seconds 14 times to confirm graph gets evicted
int intervals = 14;
for (int i = 0; i < intervals; i++) {
if (getTotalGraphsInCache() == 0) {
return;
}

Thread.sleep(5 * 1000);
}

fail("Graphs are not getting evicted");
}

public void testDocUpdate() throws IOException {
String indexName = "test-index-1";
String fieldName = "test-field-1";
Expand Down
Loading

0 comments on commit f0dfb8d

Please sign in to comment.