Skip to content

Commit

Permalink
Add Tests
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 7, 2023
1 parent f9bf4e3 commit ba8a641
Show file tree
Hide file tree
Showing 2 changed files with 366 additions and 0 deletions.
257 changes: 257 additions & 0 deletions src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,29 @@
package org.opensearch.knn.index;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.After;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.common.Strings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.rest.RestStatus;
import org.opensearch.script.Script;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
Expand All @@ -32,6 +45,7 @@ public class VectorDataTypeIT extends KNNRestTestCase {
// Mapping type name written into the TYPE_FIELD_NAME entry of the index mappings built by these tests.
private static final String KNN_VECTOR_TYPE = "knn_vector";
// NOTE(review): presumably the ef_construction / m method parameters used by the engine-specific
// mapping helper (not visible in this view) - confirm against createKnnIndexMappingWithCustomEngine.
private static final int EF_CONSTRUCTION = 128;
private static final int M = 16;
// Shared match-all query passed to the script-score search requests built in this class.
private static final QueryBuilder MATCH_ALL_QUERY_BUILDER = new MatchAllQueryBuilder();

@After
@SneakyThrows
Expand Down Expand Up @@ -176,6 +190,202 @@ public void testByteVectorDataTypeWithNmslibEngine() {
);
}

/**
 * Index creation must fail when the mapping declares the byte vector data_type but does not
 * configure a KnnMethodContext: in that case the LegacyFieldMapper falls back to the NMSLIB
 * engine, and the byte data_type is not supported for NMSLIB.
 */
@SneakyThrows
public void testByteVectorDataTypeWithLegacyFieldMapperKnnIndexSetting() {
    // Mapping with only type, dimension and data_type - deliberately no method definition.
    XContentBuilder mappingBuilder = XContentFactory.jsonBuilder();
    mappingBuilder.startObject();
    mappingBuilder.startObject(PROPERTIES_FIELD);
    mappingBuilder.startObject(FIELD_NAME);
    mappingBuilder.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE);
    mappingBuilder.field(DIMENSION, 2);
    mappingBuilder.field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue());
    mappingBuilder.endObject();
    mappingBuilder.endObject();
    mappingBuilder.endObject();

    final String mapping = Strings.toString(mappingBuilder);

    final String expectedMessage = String.format(
        Locale.ROOT,
        "[%s] field with value [%s] is only supported for [%s] engine",
        VECTOR_DATA_TYPE_FIELD,
        VectorDataType.BYTE.getValue(),
        LUCENE_NAME
    );

    ResponseException failure = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, mapping));
    assertTrue(failure.getMessage().contains(expectedMessage));
}

/**
 * Creates a Lucene-engine index with the byte vector data_type, ingests the byte test
 * documents and checks that an L2 script query with a byte query vector returns HTTP 200
 * and the expected document ordering.
 */
public void testDocValuesWithByteVectorDataTypeLuceneEngine() throws Exception {
    final String byteDataType = VectorDataType.BYTE.getValue();
    createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, byteDataType);
    ingestL2ByteTestData();

    final Byte[] query = new Byte[] { 1, 1 };
    final Request searchRequest = createScriptQueryRequest(query, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
    final Response searchResponse = client().performRequest(searchRequest);

    final RestStatus actualStatus = RestStatus.fromCode(searchResponse.getStatusLine().getStatusCode());
    assertEquals(searchRequest.getEndpoint() + ": failed", RestStatus.OK, actualStatus);

    validateL2SearchResults(searchResponse);
}

/**
 * Creates a Lucene-engine index with the float vector data_type, ingests the float test
 * documents and checks that an L2 script query returns HTTP 200 and the expected
 * document ordering.
 *
 * @throws Exception if index creation, ingestion or the search request fails
 */
public void testDocValuesWithFloatVectorDataTypeLuceneEngine() throws Exception {
    createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
    ingestL2FloatTestData();

    // Use a Float[] query vector so the float variant exercises the Float[] request builder,
    // matching testL2ScriptScoreWithFloatVectorDataType (previously a Byte[] copied from the byte test).
    Float[] queryVector = { 1.0f, 1.0f };
    Request request = createScriptQueryRequest(queryVector, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
    Response response = client().performRequest(request);
    assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

    validateL2SearchResults(response);
}

/**
 * Creates an index through the scripting mapping helper with the byte vector data_type,
 * ingests the byte test documents and checks that an L2 script query with a byte query
 * vector returns HTTP 200 and the expected document ordering.
 */
public void testL2ScriptScoreWithByteVectorDataType() throws Exception {
    final String byteDataType = VectorDataType.BYTE.getValue();
    createKnnIndexMappingForScripting(2, byteDataType);
    ingestL2ByteTestData();

    final Byte[] query = new Byte[] { 1, 1 };
    final Request searchRequest = createScriptQueryRequest(query, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
    final Response searchResponse = client().performRequest(searchRequest);

    final RestStatus actualStatus = RestStatus.fromCode(searchResponse.getStatusLine().getStatusCode());
    assertEquals(searchRequest.getEndpoint() + ": failed", RestStatus.OK, actualStatus);

    validateL2SearchResults(searchResponse);
}

/**
 * Creates an index through the scripting mapping helper with the float vector data_type,
 * ingests the float test documents and checks that an L2 script query with a float query
 * vector returns HTTP 200 and the expected document ordering.
 */
public void testL2ScriptScoreWithFloatVectorDataType() throws Exception {
    final String floatDataType = VectorDataType.FLOAT.getValue();
    createKnnIndexMappingForScripting(2, floatDataType);
    ingestL2FloatTestData();

    final Float[] query = new Float[] { 1.0f, 1.0f };
    final Request searchRequest = createScriptQueryRequest(query, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);
    final Response searchResponse = client().performRequest(searchRequest);

    final RestStatus actualStatus = RestStatus.fromCode(searchResponse.getStatusLine().getStatusCode());
    assertEquals(searchRequest.getEndpoint() + ": failed", RestStatus.OK, actualStatus);

    validateL2SearchResults(searchResponse);
}

/**
 * Creates an index through the scripting mapping helper with the byte vector data_type,
 * ingests the byte test documents and runs a script-score search whose script source calls
 * {@code l2Squared} against the vector field; expects HTTP 200 and the expected ordering.
 *
 * @throws Exception if index creation, ingestion or the search request fails
 */
public void testL2PainlessScriptingWithByteVectorDataType() throws Exception {
    createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue());
    ingestL2ByteTestData();

    // Locale.ROOT keeps formatting locale-independent, consistent with the other String.format calls in this class.
    String source = String.format(Locale.ROOT, "1/(1 + l2Squared([1, 1], doc['%s']))", FIELD_NAME);
    Request request = constructScriptScoreContextSearchRequest(
        INDEX_NAME,
        MATCH_ALL_QUERY_BUILDER,
        Collections.emptyMap(),
        Script.DEFAULT_SCRIPT_LANG,
        source,
        4
    );

    Response response = client().performRequest(request);
    assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

    validateL2SearchResults(response);
}

/**
 * Creates an index through the scripting mapping helper with the float vector data_type,
 * ingests the float test documents and runs a script-score search whose script source calls
 * {@code l2Squared} against the vector field; expects HTTP 200 and the expected ordering.
 *
 * @throws Exception if index creation, ingestion or the search request fails
 */
public void testL2PainlessScriptingWithFloatVectorDataType() throws Exception {
    createKnnIndexMappingForScripting(2, VectorDataType.FLOAT.getValue());
    ingestL2FloatTestData();

    // Locale.ROOT keeps formatting locale-independent, consistent with the other String.format calls in this class.
    String source = String.format(Locale.ROOT, "1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
    Request request = constructScriptScoreContextSearchRequest(
        INDEX_NAME,
        MATCH_ALL_QUERY_BUILDER,
        Collections.emptyMap(),
        Script.DEFAULT_SCRIPT_LANG,
        source,
        4
    );

    Response response = client().performRequest(request);
    assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

    validateL2SearchResults(response);
}

/**
 * Creating the scripting index with an unrecognised data_type value must be rejected with a
 * message that names the field and lists the supported values.
 */
public void testKNNScriptScoreWithInvalidVectorDataType() {
    final String expectedMessage = String.format(
        Locale.ROOT,
        "Invalid value provided for [%s] field. Supported values are [%s]",
        VECTOR_DATA_TYPE_FIELD,
        SUPPORTED_VECTOR_DATA_TYPES
    );

    ResponseException failure = expectThrows(
        ResponseException.class,
        () -> createKnnIndexMappingForScripting(2, "invalid_data_type")
    );

    assertTrue(failure.getMessage().contains(expectedMessage));
}

/**
 * With a byte data_type index holding two documents, a scoring script query whose query
 * vector contains decimal values must fail with a message stating that floats were supplied
 * where byte integers are required.
 */
public void testKNNScriptScoreWithInvalidByteQueryVector() throws Exception {
    createKnnIndexMappingForScripting(2, VectorDataType.BYTE.getValue());

    final Byte[] firstDoc = new Byte[] { 6, 6 };
    final Byte[] secondDoc = new Byte[] { 2, 2 };
    addKnnDoc(INDEX_NAME, "1", FIELD_NAME, firstDoc);
    addKnnDoc(INDEX_NAME, "2", FIELD_NAME, secondDoc);

    // Query vector deliberately carries decimal values, which are invalid for a byte field.
    final Float[] decimalQuery = new Float[] { 10.67f, 19.78f };
    final Request searchRequest = createScriptQueryRequest(decimalQuery, SpaceType.L2.getValue(), MATCH_ALL_QUERY_BUILDER);

    final String expectedMessage = String.format(
        Locale.ROOT,
        "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
        VECTOR_DATA_TYPE_FIELD,
        VectorDataType.BYTE.getValue()
    );

    ResponseException failure = expectThrows(ResponseException.class, () -> client().performRequest(searchRequest));
    assertTrue(failure.getMessage().contains(expectedMessage));
}

/**
 * Indexes four two-dimensional byte vectors into INDEX_NAME under FIELD_NAME, using
 * document ids "1" through "4" in the order the vectors are listed.
 */
@SneakyThrows
private void ingestL2ByteTestData() {
    final Byte[][] vectors = { { 6, 6 }, { 2, 2 }, { 4, 4 }, { 3, 3 } };
    for (int i = 0; i < vectors.length; i++) {
        // Document ids are 1-based.
        addKnnDoc(INDEX_NAME, String.valueOf(i + 1), FIELD_NAME, vectors[i]);
    }
}

/**
 * Indexes four two-dimensional float vectors into INDEX_NAME under FIELD_NAME, using
 * document ids "1" through "4" in the order the vectors are listed.
 */
@SneakyThrows
private void ingestL2FloatTestData() {
    final Float[][] vectors = { { 6.0f, 6.0f }, { 2.0f, 2.0f }, { 4.0f, 4.0f }, { 3.0f, 3.0f } };
    for (int i = 0; i < vectors.length; i++) {
        // Document ids are 1-based.
        addKnnDoc(INDEX_NAME, String.valueOf(i + 1), FIELD_NAME, vectors[i]);
    }
}

/**
 * Delegates to createKnnIndexMappingWithCustomEngine, passing the NMSLIB engine name.
 *
 * @param dimension      vector dimension forwarded to the delegate
 * @param spaceType      space type forwarded to the delegate
 * @param vectorDataType data_type value forwarded to the delegate
 */
private void createKnnIndexMappingWithNmslibEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception {
    final String nmslibEngineName = KNNEngine.NMSLIB.getName();
    createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, nmslibEngineName);
}
Expand Down Expand Up @@ -209,4 +419,51 @@ private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spac
String mapping = Strings.toString(builder);
createKnnIndex(INDEX_NAME, mapping);
}

/**
 * Creates INDEX_NAME with empty index settings and a mapping for FIELD_NAME that contains
 * only the knn_vector type, the dimension and the vector data_type.
 *
 * @param dimension      value written to the DIMENSION mapping entry
 * @param vectorDataType value written to the VECTOR_DATA_TYPE_FIELD mapping entry
 */
private void createKnnIndexMappingForScripting(int dimension, String vectorDataType) throws Exception {
    XContentBuilder mappingBuilder = XContentFactory.jsonBuilder();
    mappingBuilder.startObject();
    mappingBuilder.startObject(PROPERTIES_FIELD);
    mappingBuilder.startObject(FIELD_NAME);
    mappingBuilder.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE);
    mappingBuilder.field(DIMENSION, dimension);
    mappingBuilder.field(VECTOR_DATA_TYPE_FIELD, vectorDataType);
    mappingBuilder.endObject();
    mappingBuilder.endObject();
    mappingBuilder.endObject();

    createKnnIndex(INDEX_NAME, Settings.EMPTY, Strings.toString(mappingBuilder));
}

/**
 * Builds a KNN script query request against INDEX_NAME for a byte query vector.
 *
 * @param queryVector value stored under the "query_value" script parameter
 * @param spaceType   value stored under the "space_type" script parameter
 * @param qb          query builder handed to constructKNNScriptQueryRequest
 * @return the request produced by constructKNNScriptQueryRequest
 */
@SneakyThrows
private Request createScriptQueryRequest(Byte[] queryVector, String spaceType, QueryBuilder qb) {
    final Map<String, Object> scriptParams = new HashMap<>();
    scriptParams.put("field", FIELD_NAME);
    scriptParams.put("query_value", queryVector);
    scriptParams.put("space_type", spaceType);

    final Request scriptRequest = constructKNNScriptQueryRequest(INDEX_NAME, qb, scriptParams);
    return scriptRequest;
}

/**
 * Builds a KNN script query request against INDEX_NAME for a float query vector.
 *
 * @param queryVector value stored under the "query_value" script parameter
 * @param spaceType   value stored under the "space_type" script parameter
 * @param qb          query builder handed to constructKNNScriptQueryRequest
 * @return the request produced by constructKNNScriptQueryRequest
 */
@SneakyThrows
private Request createScriptQueryRequest(Float[] queryVector, String spaceType, QueryBuilder qb) {
    final Map<String, Object> scriptParams = new HashMap<>();
    scriptParams.put("field", FIELD_NAME);
    scriptParams.put("query_value", queryVector);
    scriptParams.put("space_type", spaceType);

    final Request scriptRequest = constructKNNScriptQueryRequest(INDEX_NAME, qb, scriptParams);
    return scriptRequest;
}

/**
 * Parses the search response and asserts that exactly four results came back with document
 * ids in the order "2", "4", "3", "1".
 *
 * @param response search response whose entity body is parsed for FIELD_NAME results
 */
@SneakyThrows
private void validateL2SearchResults(Response response) {
    final String responseBody = EntityUtils.toString(response.getEntity());
    final List<KNNResult> hits = parseSearchResponse(responseBody, FIELD_NAME);

    assertEquals(4, hits.size());

    final String[] expectedDocIDs = { "2", "4", "3", "1" };
    int position = 0;
    for (KNNResult hit : hits) {
        assertEquals(expectedDocIDs[position], hit.getDocId());
        position++;
    }
}
}
Loading

0 comments on commit ba8a641

Please sign in to comment.