Skip to content

Commit

Permalink
Add Querying Support to Lucene Byte Sized Vector (opensearch-project#956
Browse files Browse the repository at this point in the history
)

* Add Querying Support to Lucene Byte Sized Vector

Signed-off-by: Naveen Tatikonda <[email protected]>

* Add CHANGELOG

Signed-off-by: Naveen Tatikonda <[email protected]>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <[email protected]>

---------

Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 12, 2023
1 parent be6f699 commit 77db3ab
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 51 deletions.
21 changes: 18 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
Expand All @@ -30,6 +31,8 @@
import java.util.List;
import java.util.Objects;

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;

/**
* Helper class to build the KNN query
*/
Expand Down Expand Up @@ -235,6 +238,7 @@ protected Query doToQuery(QueryShardContext context) {
int fieldDimension = knnVectorFieldType.getDimension();
KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext();
KNNEngine knnEngine = KNNEngine.DEFAULT;
VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType();

if (fieldDimension == -1) {
// If dimension is not set, the field uses a model and the information needs to be retrieved from there
Expand All @@ -252,9 +256,18 @@ protected Query doToQuery(QueryShardContext context) {
);
}

byte[] byteVector = new byte[0];
if (VectorDataType.BYTE == vectorDataType) {
byteVector = new byte[vector.length];
for (int i = 0; i < vector.length; i++) {
validateByteVectorValue(vector[i]);
byteVector[i] = (byte) vector[i];
}
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)
&& filter != null
&& !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
&& filter != null
&& !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
}

Expand All @@ -263,7 +276,9 @@ protected Query doToQuery(QueryShardContext context) {
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(this.vector)
.vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null)
.byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null)
.vectorDataType(vectorDataType)
.k(this.k)
.filter(this.filter)
.context(context)
Expand Down
68 changes: 64 additions & 4 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Locale;
import java.util.Optional;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;

/**
* Creates the Lucene k-NN queries
*/
Expand All @@ -36,12 +42,20 @@ public class KNNQueryFactory {
* @param k the number of nearest neighbors to return
* @return Lucene Query
*/
public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k) {
public static Query create(
KNNEngine knnEngine,
String indexName,
String fieldName,
float[] vector,
int k,
VectorDataType vectorDataType
) {
final CreateQueryRequest createQueryRequest = CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(fieldName)
.vector(vector)
.vectorDataType(vectorDataType)
.k(k)
.build();
return create(createQueryRequest);
Expand All @@ -59,6 +73,8 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final String fieldName = createQueryRequest.getFieldName();
final int k = createQueryRequest.getK();
final float[] vector = createQueryRequest.getVector();
final byte[] byteVector = createQueryRequest.getByteVector();
final VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
final Query filterQuery = getFilterQuery(createQueryRequest);

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
Expand All @@ -77,14 +93,54 @@ public static Query create(CreateQueryRequest createQueryRequest) {
return new KNNQuery(fieldName, vector, k, indexName);
}

if (VectorDataType.BYTE == vectorDataType) {
return getKnnByteVectorQuery(indexName, fieldName, byteVector, k, filterQuery);
} else if (VectorDataType.FLOAT == vectorDataType) {
return getKnnFloatVectorQuery(indexName, fieldName, vector, k, filterQuery);
} else {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Invalid value provided for [%s] field. Supported values are [%s]",
VECTOR_DATA_TYPE_FIELD,
SUPPORTED_VECTOR_DATA_TYPES
)
);
}
}

private static Query getKnnByteVectorQuery(String indexName, String fieldName, byte[] byteVector, int k, Query filterQuery) {
if (filterQuery != null) {
log.debug(
String.format(
Locale.ROOT,
"Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d",
indexName,
fieldName,
k
)
);
return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
}
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KnnByteVectorQuery(fieldName, byteVector, k);
}

private static Query getKnnFloatVectorQuery(String indexName, String fieldName, float[] floatVector, int k, Query filterQuery) {
if (filterQuery != null) {
log.debug(
String.format("Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)
String.format(
Locale.ROOT,
"Creating Lucene k-NN query with filters for index: %s \"\", field: %s \"\", k: %d",
indexName,
fieldName,
k
)
);
return new KnnFloatVectorQuery(fieldName, vector, k, filterQuery);
return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
}
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KnnFloatVectorQuery(fieldName, vector, k);
return new KnnFloatVectorQuery(fieldName, floatVector, k);
}

private static Query getFilterQuery(CreateQueryRequest createQueryRequest) {
Expand Down Expand Up @@ -126,6 +182,10 @@ static class CreateQueryRequest {
@Getter
private float[] vector;
@Getter
private byte[] byteVector;
@Getter
private VectorDataType vectorDataType;
@Getter
private int k;
// can be null in cases filter not passed with the knn query
private QueryBuilder filter;
Expand Down
111 changes: 71 additions & 40 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Floats;
import lombok.SneakyThrows;
import org.apache.commons.lang.math.RandomUtils;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.lucene.index.VectorSimilarityFunction;
Expand All @@ -34,8 +35,10 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

public class LuceneEngineIT extends KNNRestTestCase {

Expand Down Expand Up @@ -110,7 +113,7 @@ public void testQuery_innerProduct_notSupported() throws Exception {

public void testQuery_invalidVectorDimensionInQuery() throws Exception {

createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2);
createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT);
for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) {
addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]);
}
Expand All @@ -127,7 +130,7 @@ public void testQuery_documentsMissingField() throws Exception {

SpaceType spaceType = SpaceType.L2;

createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType);
createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType, VectorDataType.FLOAT);
for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) {
addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]);
}
Expand Down Expand Up @@ -224,35 +227,35 @@ public void testAddDoc() throws Exception {
Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);

refreshAllIndices();
refreshIndex(INDEX_NAME);
assertEquals(1, getDocCount(INDEX_NAME));
}

public void testUpdateDoc() throws Exception {
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2);
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT);
Float[] vector = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);

Float[] updatedVector = { 8.0f, 8.0f };
updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector);

refreshAllIndices();
refreshIndex(INDEX_NAME);
assertEquals(1, getDocCount(INDEX_NAME));
}

public void testDeleteDoc() throws Exception {
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2);
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT);
Float[] vector = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);

deleteKnnDoc(INDEX_NAME, DOC_ID);

refreshAllIndices();
refreshIndex(INDEX_NAME);
assertEquals(0, getDocCount(INDEX_NAME));
}

public void testQueryWithFilter() throws Exception {
createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2);
public void testQueryWithFilterUsingFloatVectorDataType() throws Exception {
createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT);

addKnnDocWithAttributes(
DOC_ID,
Expand All @@ -262,39 +265,28 @@ public void testQueryWithFilter() throws Exception {
addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));

refreshAllIndices();
refreshIndex(INDEX_NAME);

final float[] searchVector = { 6.0f, 6.0f, 4.1f };
int kGreaterThanFilterResult = 5;
List<String> expectedDocIds = Arrays.asList(DOC_ID, DOC_ID_3);
final Response response = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
kGreaterThanFilterResult
);
final String responseBody = EntityUtils.toString(response.getEntity());
final List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);
List<String> expectedDocIdsKGreaterThanFilterResult = Arrays.asList(DOC_ID, DOC_ID_3);
List<String> expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID);
validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult);
}

assertEquals(expectedDocIds.size(), knnResults.size());
assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds));
@SneakyThrows
public void testQueryWithFilterUsingByteVectorDataType() {
createKnnIndexMappingWithLuceneEngine(3, SpaceType.L2, VectorDataType.BYTE);

int kLimitsFilterResult = 1;
List<String> expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID);
final Response responseKLimitsFilterResult = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
kLimitsFilterResult
);
final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity());
final List<KNNResult> knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME);
addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.0f, 3.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));
addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.0f, 2.0f, 4.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.0f, 5.0f, 7.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));

assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size());
assertTrue(
knnResultsKLimitsFilterResult.stream()
.map(KNNResult::getDocId)
.collect(Collectors.toList())
.containsAll(expectedDocIdsKLimitsFilterResult)
);
refreshIndex(INDEX_NAME);

final float[] searchVector = { 6.0f, 6.0f, 4.0f };
List<String> expectedDocIdsKGreaterThanFilterResult = Arrays.asList(DOC_ID, DOC_ID_3);
List<String> expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID);
validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult);
}

public void testQuery_filterWithNonLuceneEngine() throws Exception {
Expand Down Expand Up @@ -337,7 +329,7 @@ public void testQuery_filterWithNonLuceneEngine() throws Exception {
}

public void testIndexReopening() throws Exception {
createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2);
createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT);

for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) {
addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]);
Expand All @@ -358,13 +350,14 @@ public void testIndexReopening() throws Exception {
assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray());
}

private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType) throws Exception {
private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, VectorDataType vectorDataType) throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD_NAME)
.startObject(FIELD_NAME)
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, dimension)
.field(VECTOR_DATA_TYPE_FIELD, vectorDataType)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, KNNEngine.LUCENE.getMethod(METHOD_HNSW).getMethodComponent().getName())
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
Expand All @@ -384,7 +377,7 @@ private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spac

private void baseQueryTest(SpaceType spaceType) throws Exception {

createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType);
createKnnIndexMappingWithLuceneEngine(DIMENSION, spaceType, VectorDataType.FLOAT);
for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) {
addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]);
}
Expand Down Expand Up @@ -419,4 +412,42 @@ private List<Float[]> queryResults(final float[] searchVector, final int k) thro
assertNotNull(knnResults);
return knnResults.stream().map(KNNResult::getVector).collect(Collectors.toUnmodifiableList());
}

@SneakyThrows
private void validateQueryResultsWithFilters(
float[] searchVector,
int kGreaterThanFilterResult,
int kLimitsFilterResult,
List<String> expectedDocIdsKGreaterThanFilterResult,
List<String> expectedDocIdsKLimitsFilterResult
) {
final Response response = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, kGreaterThanFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
kGreaterThanFilterResult
);
final String responseBody = EntityUtils.toString(response.getEntity());
final List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(expectedDocIdsKGreaterThanFilterResult.size(), knnResults.size());
assertTrue(
knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIdsKGreaterThanFilterResult)
);

final Response responseKLimitsFilterResult = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(FIELD_NAME, searchVector, kLimitsFilterResult, QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")),
kLimitsFilterResult
);
final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity());
final List<KNNResult> knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME);

assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size());
assertTrue(
knnResultsKLimitsFilterResult.stream()
.map(KNNResult::getDocId)
.collect(Collectors.toList())
.containsAll(expectedDocIdsKLimitsFilterResult)
);
}
}
Loading

0 comments on commit 77db3ab

Please sign in to comment.