Skip to content

Commit

Permalink
Neural Sparse Search BWC tests
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed Jan 8, 2024
1 parent 8ec2ba8 commit 98a9805
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.bwc;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import org.opensearch.neuralsearch.TestUtils;
import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;

public class NeuralSparseSearchIT extends AbstractRestartUpgradeRestTestCase {
private static final String PIPELINE_NAME = "nlp-ingest-pipeline-sparse";
private static final String TEST_FIELD = "passage_text";
private static final String TEXT = "Hello world";
private static final String TEXT_1 = "Hi planet";
private static final String query = "Hi world";

public void testNeuralSparseSearch_E2EFlow() throws Exception {
waitForClusterHealthGreen(NODES_BWC_CLUSTER);
if (isRunningAgainstOldCluster()) {
String modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME);
createIndexWithConfiguration(
getIndexNameForTest(),
Files.readString(Path.of(classLoader.getResource("processor/SparseIndexMappings.json").toURI())),
PIPELINE_NAME
);
addDocument(getIndexNameForTest(), "0", TEST_FIELD, TEXT, null, null);
} else {
Map<String, Object> pipeline = getIngestionPipeline(PIPELINE_NAME);
assertNotNull(pipeline);
String modelId = TestUtils.getModelId(pipeline, SPARSE_ENCODING_PROCESSOR);
loadModel(modelId);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_1, null, null);
validateTestIndex(modelId);
deletePipeline(PIPELINE_NAME);
deleteModel(modelId);
deleteIndex(getIndexNameForTest());
}

}

private void validateTestIndex(String modelId) throws Exception {
int docCount = getDocCount(getIndexNameForTest());
assertEquals(2, docCount);
loadModel(modelId);
NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder();
neuralSparseQueryBuilder.fieldName("passage_embedding");
neuralSparseQueryBuilder.queryText(query);
neuralSparseQueryBuilder.modelId(modelId);
Map<String, Object> response = search(getIndexNameForTest(), neuralSparseQueryBuilder, 1);
assertNotNull(response);
int hits = getHitCount(response);
assertEquals(2, hits);
}

private String uploadTextEmbeddingModel() throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
return registerModelGroupAndGetModelId(requestBody);
}

private String registerModelGroupAndGetModelId(String requestBody) throws Exception {
String modelGroupRegisterRequestBody = Files.readString(
Path.of(classLoader.getResource("processor/CreateModelGroupRequestBody.json").toURI())
);
String modelGroupId = registerModelGroup(
String.format(LOCALE, modelGroupRegisterRequestBody, "public_model_" + RandomizedTest.randomAsciiAlphanumOfLength(8))
);
return uploadModel(String.format(LOCALE, requestBody, modelGroupId));
}

private void createPipelineProcessor(String modelId, String pipelineName) throws Exception {
String requestBody = Files.readString(
Path.of(classLoader.getResource("processor/PipelineForSparseEncodingProcessorConfiguration.json").toURI())
);
createPipelineProcessor(requestBody, pipelineName, modelId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"description": "An sparse encoding ingest pipeline",
"processors": [
{
"sparse_encoding": {
"model_id": "%s",
"field_map": {
"passage_text": "passage_embedding"
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"settings": {
"default_pipeline": "%s",
"number_of_shards": 3,
"number_of_replicas": 1
},
"mappings": {
"properties": {
"passage_embedding": {
"type": "knn_vector"
},
"passage_text": {
"type": "text"
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.bwc;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import org.opensearch.neuralsearch.TestUtils;
import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;

public class NeuralSparseSearchIT extends AbstractRollingUpgradeTestCase {
private static final String PIPELINE_NAME = "nlp-ingest-pipeline-sparse";
private static final String TEST_FIELD = "passage_text";
private static final String TEXT = "Hello world";
private static final String TEXT_MIXED = "Hi planet";
private static final String TEXT_UPGRADED = "Hi earth";
private static final int NUM_DOCS_PER_ROUND = 1;
private static final String query = "Hi world";

public void testNeuralSparseSearch_E2EFlow() throws Exception {
waitForClusterHealthGreen(NODES_BWC_CLUSTER);
switch (getClusterType()) {
case OLD:
String modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME);
createIndexWithConfiguration(
getIndexNameForTest(),
Files.readString(Path.of(classLoader.getResource("processor/SparseIndexMappings.json").toURI())),
PIPELINE_NAME
);
addDocument(getIndexNameForTest(), "0", TEST_FIELD, TEXT, null, null);
break;
case MIXED:
modelId = getModelId(PIPELINE_NAME);
int totalDocsCountMixed;
if (isFirstMixedRound()) {
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
} else {
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId);
}
break;
case UPGRADED:
modelId = getModelId(PIPELINE_NAME);
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId);
deletePipeline(PIPELINE_NAME);
deleteModel(modelId);
deleteIndex(getIndexNameForTest());
break;
}
}

private void validateTestIndexOnUpgrade(int numberOfDocs, String modelId) throws Exception {
int docCount = getDocCount(getIndexNameForTest());
assertEquals(numberOfDocs, docCount);
loadModel(modelId);
NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder();
neuralSparseQueryBuilder.fieldName("passage_embedding");
neuralSparseQueryBuilder.queryText(query);
neuralSparseQueryBuilder.modelId(modelId);
Map<String, Object> response = search(getIndexNameForTest(), neuralSparseQueryBuilder, 1);
assertNotNull(response);
}

private String uploadTextEmbeddingModel() throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
return registerModelGroupAndGetModelId(requestBody);
}

private String registerModelGroupAndGetModelId(String requestBody) throws Exception {
String modelGroupRegisterRequestBody = Files.readString(
Path.of(classLoader.getResource("processor/CreateModelGroupRequestBody.json").toURI())
);
String modelGroupId = registerModelGroup(
String.format(LOCALE, modelGroupRegisterRequestBody, "public_model_" + RandomizedTest.randomAsciiAlphanumOfLength(8))
);
return uploadModel(String.format(LOCALE, requestBody, modelGroupId));
}

private void createPipelineProcessor(String modelId, String pipelineName) throws Exception {
String requestBody = Files.readString(
Path.of(classLoader.getResource("processor/PipelineForSparseEncodingProcessorConfiguration.json").toURI())
);
createPipelineProcessor(requestBody, pipelineName, modelId);
}

private String getModelId(String pipelineName) {
Map<String, Object> pipeline = getIngestionPipeline(pipelineName);
assertNotNull(pipeline);
return TestUtils.getModelId(pipeline, SPARSE_ENCODING_PROCESSOR);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public class TestUtils {
public static final String DEFAULT_NORMALIZATION_METHOD = "min_max";
public static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
public static final String SPARSE_ENCODING_PROCESSOR = "sparse_encoding";

/**
* Convert an xContentBuilder to a map
Expand Down

0 comments on commit 98a9805

Please sign in to comment.