Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add integration tests for neural query #36

Merged
merged 8 commits into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ testClusters.integTest {
debugPort += 1
}
}

// Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due
// to ml-commons memory circuit breaker exception
jvmArgs("-Xms1g", "-Xmx1g")
}

// Remote Integration Tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,22 @@
package org.opensearch.neuralsearch.common;

import static org.apache.http.entity.ContentType.APPLICATION_JSON;
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.SneakyThrows;

import org.apache.commons.lang3.StringUtils;
import org.apache.http.Header;
Expand All @@ -28,14 +35,20 @@
import org.opensearch.client.Response;
import org.opensearch.client.RestClient;
import org.opensearch.client.WarningsHandler;
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.test.rest.OpenSearchRestTestCase;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.OpenSearchSecureRestTestCase;
import org.opensearch.rest.RestStatus;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is not part of the change you made, but can you change the base class of BaseNeuralSearchIT from OpenSearchRestTestCase to OpenSearchSecureRestTestCase? This will allow secure clusters to be used for testing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I can update.


import com.google.common.collect.ImmutableList;

public abstract class BaseNeuralSearchIT extends OpenSearchRestTestCase {
public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {

private static final Locale LOCALE = Locale.ROOT;

Expand All @@ -45,7 +58,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchRestTestCase {

protected final ClassLoader classLoader = this.getClass().getClassLoader();

public String uploadModel(String requestBody) throws Exception {
protected String uploadModel(String requestBody) throws Exception {
Response uploadResponse = makeRequest(
client(),
"POST",
Expand Down Expand Up @@ -74,7 +87,7 @@ public String uploadModel(String requestBody) throws Exception {
return modelId;
}

public void loadModel(String modelId) throws IOException, InterruptedException {
protected void loadModel(String modelId) throws IOException, InterruptedException {
Response uploadResponse = makeRequest(
client(),
"POST",
Expand All @@ -100,6 +113,56 @@ public void loadModel(String modelId) throws IOException, InterruptedException {
}
}

/**
 * Uploads the default model and then loads it.
 *
 * <p>The upload request body is read from the classpath resource {@code processor/UploadModelRequestBody.json}, handed to
 * {@code uploadModel}, and the id it returns is passed on to {@code loadModel}.
 *
 * @return id of the model that was uploaded and loaded
 */
@SneakyThrows
protected String prepareModel() {
    final Path uploadRequestPath = Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI());
    final String uploadedModelId = uploadModel(Files.readString(uploadRequestPath));
    loadModel(uploadedModelId);
    return uploadedModelId;
}

/**
 * Execute model inference on the provided query text.
 *
 * <p>The request body is assembled with an XContent JSON builder rather than string formatting, so quotes, backslashes
 * and control characters inside {@code queryText} are JSON-escaped instead of corrupting the payload.
 *
 * <p>The response is expected to hold exactly one entry under {@code inference_results}, whose {@code output} list holds
 * exactly one entry with a numeric {@code data} list; that list is returned as the embedding.
 *
 * @param modelId id of model to run inference
 * @param queryText text to be transformed into an embedding
 * @return text embedding
 */
@SuppressWarnings("unchecked")
@SneakyThrows
protected float[] runInference(String modelId, String queryText) {
    // Build the body through the JSON generator so that queryText is always escaped correctly.
    String requestBody = Strings.toString(
        XContentFactory.jsonBuilder()
            .startObject()
            .startArray("text_docs")
            .value(queryText)
            .endArray()
            .startArray("target_response")
            .value("sentence_embedding")
            .endArray()
            .endObject()
    );

    Response inferenceResponse = makeRequest(
        client(),
        "POST",
        String.format(LOCALE, "/_plugins/_ml/_predict/text_embedding/%s", modelId),
        null,
        toHttpEntity(requestBody),
        ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
    );

    Map<String, Object> inferenceResJson = XContentHelper.convertToMap(
        XContentFactory.xContent(XContentType.JSON),
        EntityUtils.toString(inferenceResponse.getEntity()),
        false
    );

    Object inferenceResults = inferenceResJson.get("inference_results");
    assertTrue(inferenceResults instanceof List);
    List<Object> inferenceResultList = (List<Object>) inferenceResults;
    assertEquals(1, inferenceResultList.size());

    Map<String, Object> result = (Map<String, Object>) inferenceResultList.get(0);
    List<Object> output = (List<Object>) result.get("output");
    assertEquals(1, output.size());

    Map<String, Object> outputEntry = (Map<String, Object>) output.get(0);
    // Read the components as Number: the JSON parser may surface integral values (e.g. 0) as Integer rather than Double.
    List<Float> data = ((List<Number>) outputEntry.get("data")).stream().map(Number::floatValue).collect(Collectors.toList());
    return vectorAsListToArray(data);
}

protected void createIndexWithConfiguration(String indexName, String indexConfiguration, String pipelineName) throws Exception {
if (StringUtils.isNotBlank(pipelineName)) {
indexConfiguration = String.format(LOCALE, indexConfiguration, pipelineName);
Expand All @@ -121,7 +184,7 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig
assertEquals(indexName, node.get("index").toString());
}

public void createPipelineProcessor(String modelId, String pipelineName) throws Exception {
protected void createPipelineProcessor(String modelId, String pipelineName) throws Exception {
Response pipelineCreateResponse = makeRequest(
client(),
"PUT",
Expand All @@ -144,7 +207,155 @@ public void createPipelineProcessor(String modelId, String pipelineName) throws
assertEquals("true", node.get("acknowledged").toString());
}

public Map<String, Object> getTaskQueryResponse(String taskId) throws IOException {
/**
 * Get the number of documents in a particular index.
 *
 * <p>Issues {@code GET /<indexName>/_count}, asserts that the response status is OK and returns the {@code count} field of
 * the response. The body is parsed with {@code XContentHelper.convertToMap}, matching the other helpers in this class,
 * so no parser is left open by this method.
 *
 * @param indexName name of index
 * @return number of documents indexed to that index
 */
@SneakyThrows
protected int getDocCount(String indexName) {
    Request request = new Request("GET", "/" + indexName + "/_count");
    Response response = client().performRequest(request);
    assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

    String responseBody = EntityUtils.toString(response.getEntity());
    Map<String, Object> responseMap = XContentHelper.convertToMap(XContentFactory.xContent(XContentType.JSON), responseBody, false);

    // Accept any numeric representation instead of assuming the parser always produces an Integer.
    Object count = responseMap.get("count");
    assertTrue(request.getEndpoint() + ": response has no numeric 'count' field", count instanceof Number);
    return ((Number) count).intValue();
}

/**
 * Run a search built from the given query builder, without any rescore query.
 *
 * <p>Delegates to the four-argument overload with a {@code null} rescorer.
 *
 * @param index Index to search against
 * @param queryBuilder queryBuilder to produce source of query
 * @param resultSize number of results to return in the search
 * @return Search results represented as a map
 */
protected Map<String, Object> search(String index, QueryBuilder queryBuilder, int resultSize) {
    final QueryBuilder noRescorer = null;
    return search(index, queryBuilder, noRescorer, resultSize);
}

/**
 * Run a search built from the given query builder, optionally adding a rescore section to the request.
 *
 * <p>When {@code rescorer} is non-null, the request body gains a {@code rescore.query} object with
 * {@code query_weight} set to {@code 0.0} and the rescorer serialized as {@code rescore_query}. The result size is sent
 * as the {@code size} request parameter, and the response status is asserted to be OK.
 *
 * @param index Index to search against
 * @param queryBuilder queryBuilder to produce source of query
 * @param rescorer query builder used for the rescore query; may be {@code null} to skip rescoring
 * @param resultSize number of results to return in the search
 * @return Search results represented as a map
 */
@SneakyThrows
protected Map<String, Object> search(String index, QueryBuilder queryBuilder, QueryBuilder rescorer, int resultSize) {
    final XContentBuilder body = XContentFactory.jsonBuilder();
    body.startObject();
    body.field("query");
    queryBuilder.toXContent(body, ToXContent.EMPTY_PARAMS);

    final boolean hasRescorer = rescorer != null;
    if (hasRescorer) {
        body.startObject("rescore");
        body.startObject("query");
        body.field("query_weight", 0.0f);
        body.field("rescore_query");
        rescorer.toXContent(body, ToXContent.EMPTY_PARAMS);
        body.endObject(); // closes "query"
        body.endObject(); // closes "rescore"
    }

    body.endObject();

    final Request searchRequest = new Request("POST", "/" + index + "/_search");
    searchRequest.addParameter("size", Integer.toString(resultSize));
    searchRequest.setJsonEntity(Strings.toString(body));

    final Response searchResponse = client().performRequest(searchRequest);
    final RestStatus actualStatus = RestStatus.fromCode(searchResponse.getStatusLine().getStatusCode());
    assertEquals(searchRequest.getEndpoint() + ": failed", RestStatus.OK, actualStatus);

    final String rawResponse = EntityUtils.toString(searchResponse.getEntity());
    return XContentHelper.convertToMap(XContentFactory.xContent(XContentType.JSON), rawResponse, false);
}

/**
 * Index a document that contains only knn vector fields.
 *
 * <p>Delegates to the six-argument overload with empty text field names and texts.
 *
 * @param index Name of the index
 * @param docId ID of document to be added
 * @param vectorFieldNames List of vector fields to be added
 * @param vectors List of vectors corresponding to those fields
 */
protected void addKnnDoc(String index, String docId, List<String> vectorFieldNames, List<Object[]> vectors) {
    final List<String> noTextFieldNames = Collections.emptyList();
    final List<String> noTexts = Collections.emptyList();
    addKnnDoc(index, docId, vectorFieldNames, vectors, noTextFieldNames, noTexts);
}

/**
 * Index a document that contains knn vector fields and text fields.
 *
 * <p>Field names and values are paired by position: the i-th vector is written under the i-th vector field name, and the
 * i-th text under the i-th text field name. The document is posted with {@code refresh=true} and the response status is
 * asserted to be CREATED.
 *
 * @param index Name of the index
 * @param docId ID of document to be added
 * @param vectorFieldNames List of vector fields to be added
 * @param vectors List of vectors corresponding to those fields
 * @param textFieldNames List of text fields to be added
 * @param texts List of text corresponding to those fields
 */
@SneakyThrows
protected void addKnnDoc(
    String index,
    String docId,
    List<String> vectorFieldNames,
    List<Object[]> vectors,
    List<String> textFieldNames,
    List<String> texts
) {
    final Request indexRequest = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");

    final XContentBuilder document = XContentFactory.jsonBuilder().startObject();

    int vectorPosition = 0;
    while (vectorPosition < vectorFieldNames.size()) {
        document.field(vectorFieldNames.get(vectorPosition), vectors.get(vectorPosition));
        vectorPosition++;
    }

    int textPosition = 0;
    while (textPosition < textFieldNames.size()) {
        document.field(textFieldNames.get(textPosition), texts.get(textPosition));
        textPosition++;
    }

    document.endObject();
    indexRequest.setJsonEntity(Strings.toString(document));

    final Response indexResponse = client().performRequest(indexRequest);
    final RestStatus actualStatus = RestStatus.fromCode(indexResponse.getStatusLine().getStatusCode());
    assertEquals(indexRequest.getEndpoint() + ": failed", RestStatus.CREATED, actualStatus);
}

/**
 * Extract the first hit from a search response.
 *
 * <p>Reads the outer {@code hits} object, then its nested {@code hits} list, asserts that the list is not empty and
 * returns its first element.
 *
 * @param searchResponseAsMap Complete search response as a map
 * @return Map of first internal hit from the search
 */
@SuppressWarnings("unchecked")
protected Map<String, Object> getFirstInnerHit(Map<String, Object> searchResponseAsMap) {
    final Map<String, Object> outerHits = (Map<String, Object>) searchResponseAsMap.get("hits");
    final List<Object> innerHits = (List<Object>) outerHits.get("hits");
    assertTrue(innerHits.size() > 0);
    final Object firstHit = innerHits.get(0);
    return (Map<String, Object>) firstHit;
}

/**
 * Create a k-NN index whose mapping is derived from a list of KNNFieldConfigs.
 *
 * <p>The index configuration is produced by {@code buildIndexConfiguration} and created with an empty pipeline name.
 *
 * @param indexName of index to be created
 * @param knnFieldConfigs list of configs specifying field
 */
@SneakyThrows
protected void prepareKnnIndex(String indexName, List<KNNFieldConfig> knnFieldConfigs) {
    final String indexConfiguration = buildIndexConfiguration(knnFieldConfigs);
    createIndexWithConfiguration(indexName, indexConfiguration, "");
}

/**
 * Compute the score expected between an index vector and a query text, without going through the neural query type.
 *
 * <p>The query text is first turned into a vector via {@code runInference}; that vector is then compared to
 * {@code indexVector} with the similarity function of the given space type.
 *
 * @param modelId ID of model to run inference
 * @param indexVector vector to compute score against
 * @param spaceType Space to measure distance
 * @param queryText Text to produce query vector from
 * @return Expected OpenSearch score for this indexVector
 */
protected float computeExpectedScore(String modelId, float[] indexVector, SpaceType spaceType, String queryText) {
    // Inference runs before the similarity function is looked up, preserving the original call order.
    final float[] embeddedQuery = runInference(modelId, queryText);
    return spaceType.getVectorSimilarityFunction().compare(embeddedQuery, indexVector);
}

protected Map<String, Object> getTaskQueryResponse(String taskId) throws IOException {
Response taskQueryResponse = makeRequest(
client(),
"GET",
Expand All @@ -160,12 +371,37 @@ public Map<String, Object> getTaskQueryResponse(String taskId) throws IOExceptio
);
}

public boolean checkComplete(Map<String, Object> node) {
/**
 * Decide whether a task response describes a finished task.
 *
 * @param node task response as a map
 * @return {@code true} when the map has a non-null {@code error} entry, or when the string form of its {@code state}
 *         entry equals {@code COMPLETED}; {@code false} otherwise
 */
protected boolean checkComplete(Map<String, Object> node) {
    if (node.get("error") != null) {
        return true;
    }
    final String state = String.valueOf(node.get("state"));
    return "COMPLETED".equals(state);
}

public static Response makeRequest(
/**
 * Build the JSON index configuration for a k-NN index.
 *
 * <p>The settings section sets {@code number_of_shards} to 3 and {@code index.knn} to true. Each config contributes one
 * {@code knn_vector} property named after the config, with its dimension, and a {@code method} object using the
 * {@code lucene} engine, the config's space type and the {@code hnsw} method name.
 *
 * @param knnFieldConfigs configs describing the knn_vector fields to map
 * @return index configuration serialized as a JSON string
 */
@SneakyThrows
private String buildIndexConfiguration(List<KNNFieldConfig> knnFieldConfigs) {
    final XContentBuilder config = XContentFactory.jsonBuilder();
    config.startObject();

    config.startObject("settings");
    config.field("number_of_shards", 3);
    config.field("index.knn", true);
    config.endObject();

    config.startObject("mappings");
    config.startObject("properties");

    for (final KNNFieldConfig fieldConfig : knnFieldConfigs) {
        config.startObject(fieldConfig.getName());
        config.field("type", "knn_vector");
        config.field("dimension", Integer.toString(fieldConfig.getDimension()));

        config.startObject("method");
        config.field("engine", "lucene");
        config.field("space_type", fieldConfig.getSpaceType().getValue());
        config.field("name", "hnsw");
        config.endObject();

        config.endObject();
    }

    config.endObject(); // closes "properties"
    config.endObject(); // closes "mappings"
    config.endObject(); // closes the root object
    return Strings.toString(config);
}

protected static Response makeRequest(
RestClient client,
String method,
String endpoint,
Expand All @@ -176,7 +412,7 @@ public static Response makeRequest(
return makeRequest(client, method, endpoint, params, entity, headers, false);
}

public static Response makeRequest(
protected static Response makeRequest(
RestClient client,
String method,
String endpoint,
Expand All @@ -203,8 +439,15 @@ public static Response makeRequest(
return client.performRequest(request);
}

public static HttpEntity toHttpEntity(String jsonString) {
/**
 * Wrap a JSON string in an HTTP entity with the {@code APPLICATION_JSON} content type.
 *
 * @param jsonString JSON text to send as the entity body
 * @return entity carrying {@code jsonString}
 */
protected static HttpEntity toHttpEntity(String jsonString) {
    final StringEntity jsonEntity = new StringEntity(jsonString, APPLICATION_JSON);
    return jsonEntity;
}

/**
 * Immutable description of one knn_vector field, consumed by {@code buildIndexConfiguration} when it generates the
 * index mapping. Lombok generates the all-arguments constructor (parameters in field declaration order) and a getter
 * for each field.
 */
@AllArgsConstructor
@Getter
protected static class KNNFieldConfig {
// Name of the field; used as the property name in the generated mapping.
private final String name;
// Vector dimension; written as the "dimension" value of the mapping.
private final Integer dimension;
// Space type; its getValue() result is written as the "space_type" of the method definition.
private final SpaceType spaceType;
}
}
Loading