Skip to content

Commit

Permalink
Enabled the IVF algorithm to work with Filters of K-NN Query.
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Jul 31, 2023
1 parent ae9c9f4 commit fc4381f
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.9...2.x)
### Features
### Enhancements
1. Enabled the IVF algorithm to work with Filters of K-NN Query. [#1013](https://github.com/opensearch-project/k-NN/pull/)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
18 changes: 17 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.search.FilteredDocIdSetIterator;
Expand Down Expand Up @@ -285,12 +286,13 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final i
}

private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) {
log.error("Doing Exact search");
final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
float[] queryVector = this.knnQuery.getQueryVector();
try {
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));
final SpaceType spaceType = getSpaceType(fieldInfo);
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(this.knnQuery.getK(), true);
ScoreDoc topDoc = queue.top();
Expand Down Expand Up @@ -351,4 +353,18 @@ public static float normalizeScore(float score) {
if (score >= 0) return 1 / (1 + score);
return -score + 1;
}

private SpaceType getSpaceType(final FieldInfo fieldInfo) {
final String spaceTypeString = fieldInfo.getAttribute(SPACE_TYPE);
if(StringUtils.isNotEmpty(spaceTypeString)) {
return SpaceType.getSpace(spaceTypeString);
}

final String modelId = fieldInfo.getAttribute(MODEL_ID);
if(StringUtils.isNotEmpty(modelId)) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
return modelMetadata.getSpaceType();
}
throw new IllegalArgumentException("Unable to find the Model Id from Field Info attributes for field " + fieldInfo.getName());
}
}
26 changes: 22 additions & 4 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public void testDocDeletion() throws IOException {
deleteKnnDoc(INDEX_NAME, "1");
}

public void testEndToEnd_fromModel() throws IOException, InterruptedException {
public void testKNNQuery_withModelDifferentCombination_thenSuccess() throws IOException, InterruptedException {
String modelId = "test-model";
int dimension = 128;

Expand Down Expand Up @@ -270,10 +270,10 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException {
// Index some documents
int numDocs = 100;
for (int i = 0; i < numDocs; i++) {
Float[] indexVector = new Float[dimension];
float[] indexVector = new float[dimension];
Arrays.fill(indexVector, (float) i);

addKnnDoc(indexName, Integer.toString(i), fieldName, indexVector);
addKnnDocWithAttributes(indexName, Integer.toString(i), fieldName, indexVector, ImmutableMap.of("rating",
String.valueOf(i)));
}

// Run search and ensure that the values returned are expected
Expand All @@ -287,6 +287,24 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException {
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId()));
}

// doing exact search with filters
Response exactSearchFilteredResponse = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, queryVector, k,
QueryBuilders.rangeQuery("rating").gte("90").lte("99")), k);
List<KNNResult> exactSearchFilteredResults =
parseSearchResponse(EntityUtils.toString(exactSearchFilteredResponse.getEntity()), fieldName);
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(exactSearchFilteredResults.get(i).getDocId()));
}

// doing exact search with filters
Response aNNSearchFilteredResponse = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, queryVector, k,
QueryBuilders.rangeQuery("rating").gte("80").lte("99")), k);
List<KNNResult> aNNSearchFilteredResults =
parseSearchResponse(EntityUtils.toString(aNNSearchFilteredResponse.getEntity()), fieldName);
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(aNNSearchFilteredResults.get(i).getDocId()));
}
}

@SneakyThrows
Expand Down
17 changes: 17 additions & 0 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -1322,4 +1322,21 @@ protected void addKnnDocWithAttributes(String docId, float[] vector, Map<String,
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

protected void addKnnDocWithAttributes(String indexName, String docId, String vectorFieldName, float[] vector,
Map<String, String> fieldValues) throws IOException {
Request request = new Request("POST", "/" + indexName + "/_doc/" + docId + "?refresh=true");

final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(vectorFieldName, vector);
for (String fieldName : fieldValues.keySet()) {
builder.field(fieldName, fieldValues.get(fieldName));
}
builder.endObject();
request.setJsonEntity(Strings.toString(builder));
client().performRequest(request);

request = new Request("POST", "/" + indexName + "/_refresh");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}
}

0 comments on commit fc4381f

Please sign in to comment.