Skip to content

Commit

Permalink
Enabled the IVF algorithm to work with Filters of K-NN Query. (opense…
Browse files Browse the repository at this point in the history
…arch-project#1013)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Aug 1, 2023
1 parent 3aedd94 commit 49c0033
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.9...2.x)
### Features
### Enhancements
* Enabled the IVF algorithm to work with Filters of K-NN Query. [#1013](https://github.com/opensearch-project/k-NN/pull/1013)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
20 changes: 19 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.search.FilteredDocIdSetIterator;
Expand Down Expand Up @@ -49,6 +50,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -290,7 +292,7 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont
float[] queryVector = this.knnQuery.getQueryVector();
try {
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));
final SpaceType spaceType = getSpaceType(fieldInfo);
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(this.knnQuery.getK(), true);
ScoreDoc topDoc = queue.top();
Expand Down Expand Up @@ -351,4 +353,20 @@ public static float normalizeScore(float score) {
if (score >= 0) return 1 / (1 + score);
return -score + 1;
}

private SpaceType getSpaceType(final FieldInfo fieldInfo) {
final String spaceTypeString = fieldInfo.getAttribute(SPACE_TYPE);
if (StringUtils.isNotEmpty(spaceTypeString)) {
return SpaceType.getSpace(spaceTypeString);
}

final String modelId = fieldInfo.getAttribute(MODEL_ID);
if (StringUtils.isNotEmpty(modelId)) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
return modelMetadata.getSpaceType();
}
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Unable to find the Space Type from Field Info attribute for field %s", fieldInfo.getName())
);
}
}
37 changes: 33 additions & 4 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ public void testDocDeletion() throws IOException {
deleteKnnDoc(INDEX_NAME, "1");
}

public void testEndToEnd_fromModel() throws Exception {
@SneakyThrows
public void testKNNQuery_withModelDifferentCombination_thenSuccess() {

String modelId = "test-model";
int dimension = 128;

Expand Down Expand Up @@ -270,10 +272,9 @@ public void testEndToEnd_fromModel() throws Exception {
// Index some documents
int numDocs = 100;
for (int i = 0; i < numDocs; i++) {
Float[] indexVector = new Float[dimension];
float[] indexVector = new float[dimension];
Arrays.fill(indexVector, (float) i);

addKnnDoc(indexName, Integer.toString(i), fieldName, indexVector);
addKnnDocWithAttributes(indexName, Integer.toString(i), fieldName, indexVector, ImmutableMap.of("rating", String.valueOf(i)));
}

// Run search and ensure that the values returned are expected
Expand All @@ -287,6 +288,34 @@ public void testEndToEnd_fromModel() throws Exception {
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId()));
}

// doing exact search with filters
Response exactSearchFilteredResponse = searchKNNIndex(
indexName,
new KNNQueryBuilder(fieldName, queryVector, k, QueryBuilders.rangeQuery("rating").gte("90").lte("99")),
k
);
List<KNNResult> exactSearchFilteredResults = parseSearchResponse(
EntityUtils.toString(exactSearchFilteredResponse.getEntity()),
fieldName
);
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(exactSearchFilteredResults.get(i).getDocId()));
}

// doing exact search with filters
Response aNNSearchFilteredResponse = searchKNNIndex(
indexName,
new KNNQueryBuilder(fieldName, queryVector, k, QueryBuilders.rangeQuery("rating").gte("80").lte("99")),
k
);
List<KNNResult> aNNSearchFilteredResults = parseSearchResponse(
EntityUtils.toString(aNNSearchFilteredResponse.getEntity()),
fieldName
);
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(aNNSearchFilteredResults.get(i).getDocId()));
}
}

@SneakyThrows
Expand Down
22 changes: 22 additions & 0 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -1338,4 +1338,26 @@ protected void addKnnDocWithAttributes(String docId, float[] vector, Map<String,
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

protected void addKnnDocWithAttributes(
String indexName,
String docId,
String vectorFieldName,
float[] vector,
Map<String, String> fieldValues
) throws IOException {
Request request = new Request("POST", "/" + indexName + "/_doc/" + docId + "?refresh=true");

final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(vectorFieldName, vector);
for (String fieldName : fieldValues.keySet()) {
builder.field(fieldName, fieldValues.get(fieldName));
}
builder.endObject();
request.setJsonEntity(Strings.toString(builder));
client().performRequest(request);

request = new Request("POST", "/" + indexName + "/_refresh");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}
}

0 comments on commit 49c0033

Please sign in to comment.