diff --git a/CHANGELOG.md b/CHANGELOG.md index 024a02311..5f6eca41c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features - Enable sorting and search_after features in Hybrid Search [#827](https://github.com/opensearch-project/neural-search/pull/827) ### Enhancements +- InferenceProcessor inherits from AbstractBatchingProcessor to support sub batching in processor [#820](https://github.com/opensearch-project/neural-search/pull/820) - Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/) - Enable '.' for nested field in text embedding processor ([#811](https://github.com/opensearch-project/neural-search/pull/811)) ### Bug Fixes diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index d9f9c7048..97f2f1837 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -29,7 +29,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; import org.opensearch.index.mapper.IndexFieldMapper; -import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.AbstractBatchingProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -46,7 +46,7 @@ * and set the target fields according to the field name map. */ @Log4j2 -public abstract class InferenceProcessor extends AbstractProcessor { +public abstract class InferenceProcessor extends AbstractBatchingProcessor { public static final String MODEL_ID_FIELD = "model_id"; public static final String FIELD_MAP_FIELD = "field_map"; @@ -69,6 +69,7 @@ public abstract class InferenceProcessor extends AbstractProcessor { public InferenceProcessor( String tag, String description, + int batchSize, String type, String listTypeNestedMapKey, String modelId, @@ -77,7 +78,7 @@ public InferenceProcessor( Environment environment, ClusterService clusterService ) { - super(tag, description); + super(tag, description, batchSize); this.type = type; if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it"); validateEmbeddingConfiguration(fieldMap); @@ -144,7 +145,7 @@ public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Ex abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException); @Override - public void batchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) { + public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) { if (CollectionUtils.isEmpty(ingestDocumentWrappers)) { handler.accept(Collections.emptyList()); return; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index e83bd8233..e01840fbb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -31,13 +31,14 @@ public final class SparseEncodingProcessor extends InferenceProcessor { public SparseEncodingProcessor( String tag, String description, + int batchSize, String modelId, Map<String, Object> fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService ) { - super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); + super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index f5b710530..c8f9f080d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -30,13 +30,14 @@ public final class TextEmbeddingProcessor extends InferenceProcessor { public TextEmbeddingProcessor( String tag, String description, + int batchSize, String modelId, Map<String, Object> fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService ) { - super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); + super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 8a294458a..46055df16 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -14,7 +14,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; -import org.opensearch.ingest.Processor; +import org.opensearch.ingest.AbstractBatchingProcessor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; @@ -24,27 +24,23 @@ * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. */ @Log4j2 -public class SparseEncodingProcessorFactory implements Processor.Factory { +public class SparseEncodingProcessorFactory extends AbstractBatchingProcessor.Factory { private final MLCommonsClientAccessor clientAccessor; private final Environment environment; private final ClusterService clusterService; public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) { + super(TYPE); this.clientAccessor = clientAccessor; this.environment = environment; this.clusterService = clusterService; } @Override - public SparseEncodingProcessor create( - Map<String, Processor.Factory> registry, - String processorTag, - String description, - Map<String, Object> config - ) throws Exception { - String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); - Map<String, Object> fieldMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); - - return new SparseEncodingProcessor(processorTag, description, modelId, fieldMap, clientAccessor, environment, clusterService); + protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) { + String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); + Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + + return new SparseEncodingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index d38bf21df..6b442b56c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -14,14 +14,14 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; -import org.opensearch.ingest.Processor; +import org.opensearch.ingest.AbstractBatchingProcessor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; /** * Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. */ -public class TextEmbeddingProcessorFactory implements Processor.Factory { +public final class TextEmbeddingProcessorFactory extends AbstractBatchingProcessor.Factory { private final MLCommonsClientAccessor clientAccessor; @@ -34,20 +34,16 @@ public TextEmbeddingProcessorFactory( final Environment environment, final ClusterService clusterService ) { + super(TYPE); this.clientAccessor = clientAccessor; this.environment = environment; this.clusterService = clusterService; } @Override - public TextEmbeddingProcessor create( - final Map<String, Processor.Factory> registry, - final String processorTag, - final String description, - final Map<String, Object> config - ) throws Exception { - String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); - Map<String, Object> filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); - return new TextEmbeddingProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment, clusterService); + protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) { + String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); + Map<String, Object> filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + return new TextEmbeddingProcessor(tag, description, batchSize, modelId, filedMap, clientAccessor, environment, clusterService); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java index b98f4fcc0..7250e9365 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java @@ -6,7 +6,6 @@ import static org.opensearch.ingest.ConfigurationUtils.readMap; import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; -import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.Factory; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.FIELD_MAP_FIELD; import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.MODEL_ID_FIELD; @@ -16,6 +15,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; +import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; @@ -25,31 +25,18 @@ * Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. */ @AllArgsConstructor -public class TextImageEmbeddingProcessorFactory implements Factory { +public class TextImageEmbeddingProcessorFactory implements Processor.Factory { private final MLCommonsClientAccessor clientAccessor; private final Environment environment; private final ClusterService clusterService; @Override - public TextImageEmbeddingProcessor create( - final Map<String, Factory> registry, - final String processorTag, - final String description, - final Map<String, Object> config - ) throws Exception { - String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); - String embedding = readStringProperty(TYPE, processorTag, config, EMBEDDING_FIELD); - Map<String, String> filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); - return new TextImageEmbeddingProcessor( - processorTag, - description, - modelId, - embedding, - filedMap, - clientAccessor, - environment, - clusterService - ); + public Processor create(Map<String, Processor.Factory> processorFactories, String tag, String description, Map<String, Object> config) + throws Exception { + String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); + String embedding = readStringProperty(TYPE, tag, config, EMBEDDING_FIELD); + Map<String, String> filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + return new TextImageEmbeddingProcessor(tag, description, modelId, embedding, filedMap, clientAccessor, environment, clusterService); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java index d08f6c3f1..cd2d0816a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -4,6 +4,7 @@ */ package org.opensearch.neuralsearch.processor; +import lombok.Getter; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.MockitoAnnotations; @@ -15,6 +16,7 @@ import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -42,6 +44,7 @@ public class InferenceProcessorTests extends InferenceProcessorTestCase { private static final String DESCRIPTION = "description"; private static final String MAP_KEY = "map_key"; private static final String MODEL_ID = "model_id"; + private static final int BATCH_SIZE = 10; private static final Map<String, Object> FIELD_MAP = Map.of("key1", "embedding_key1", "key2", "embedding_key2"); @Before @@ -54,7 +57,7 @@ public void setup() { } public void test_batchExecute_emptyInput() { - TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null); + TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null); Consumer resultHandler = mock(Consumer.class); processor.batchExecute(Collections.emptyList(), resultHandler); ArgumentCaptor<List<IngestDocumentWrapper>> captor = ArgumentCaptor.forClass(List.class); @@ -65,7 +68,7 @@ public void test_batchExecute_emptyInput() { public void test_batchExecute_allFailedValidation() { final int docCount = 2; - TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null); + TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null); List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount); wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1")); wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1")); @@ -83,7 +86,7 @@ public void test_batchExecute_allFailedValidation() { public void test_batchExecute_partialFailedValidation() { final int docCount = 2; - TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null); + TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null); List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount); wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1")); wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4")); @@ -105,7 +108,7 @@ public void test_batchExecute_partialFailedValidation() { public void test_batchExecute_happyCase() { final int docCount = 2; List<List<Float>> inferenceResults = createMockVectorWithLength(6); - TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, null); + TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, BATCH_SIZE, null); List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount); wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("value1", "value2")); wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4")); @@ -126,7 +129,7 @@ public void test_batchExecute_happyCase() { public void test_batchExecute_sort() { final int docCount = 2; List<List<Float>> inferenceResults = createMockVectorWithLength(100); - TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, null); + TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, BATCH_SIZE, null); List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount); wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("aaaaa", "bbb")); wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("cc", "ddd")); @@ -158,7 +161,7 @@ public void test_batchExecute_sort() { public void test_doBatchExecute_exception() { final int docCount = 2; List<List<Float>> inferenceResults = createMockVectorWithLength(6); - TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, new RuntimeException()); + TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, BATCH_SIZE, new RuntimeException()); List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount); wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("value1", "value2")); wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4")); @@ -174,12 +177,36 @@ public void test_doBatchExecute_exception() { verify(clientAccessor).inferenceSentences(anyString(), anyList(), any()); } + public void test_batchExecute_subBatches() { + final int docCount = 5; + List<List<Float>> inferenceResults = createMockVectorWithLength(6); + TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, 2, null); + List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount); + for (int i = 0; i < docCount; ++i) { + wrapperList.get(i).getIngestDocument().setFieldValue("key1", Collections.singletonList("value" + i)); + } + List<IngestDocumentWrapper> allResults = new ArrayList<>(); + processor.batchExecute(wrapperList, allResults::addAll); + for (int i = 0; i < docCount; ++i) { + assertEquals(allResults.get(i).getIngestDocument(), wrapperList.get(i).getIngestDocument()); + assertEquals(allResults.get(i).getSlot(), wrapperList.get(i).getSlot()); + assertEquals(allResults.get(i).getException(), wrapperList.get(i).getException()); + } + assertEquals(3, processor.getAllInferenceInputs().size()); + assertEquals(List.of("value0", "value1"), processor.getAllInferenceInputs().get(0)); + assertEquals(List.of("value2", "value3"), processor.getAllInferenceInputs().get(1)); + assertEquals(List.of("value4"), processor.getAllInferenceInputs().get(2)); + } + private class TestInferenceProcessor extends InferenceProcessor { List<?> vectors; Exception exception; - public TestInferenceProcessor(List<?> vectors, Exception exception) { - super(TAG, DESCRIPTION, TYPE, MAP_KEY, MODEL_ID, FIELD_MAP, clientAccessor, environment, clusterService); + @Getter + List<List<String>> allInferenceInputs = new ArrayList<>(); + + public TestInferenceProcessor(List<?> vectors, int batchSize, Exception exception) { + super(TAG, DESCRIPTION, batchSize, TYPE, MAP_KEY, MODEL_ID, FIELD_MAP, clientAccessor, environment, clusterService); this.vectors = vectors; this.exception = exception; } @@ -196,6 +223,7 @@ public void doExecute( void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) { // use to verify if doBatchExecute is called from InferenceProcessor clientAccessor.inferenceSentences(MODEL_ID, inferenceList, ActionListener.wrap(results -> {}, ex -> {})); + allInferenceInputs.add(inferenceList); if (this.exception != null) { onException.accept(this.exception); } else { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 7460390de..9486ee2ca 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -38,6 +38,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.index.mapper.IndexFieldMapper; +import org.opensearch.ingest.AbstractBatchingProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.ingest.Processor; @@ -76,7 +77,17 @@ private SparseEncodingProcessor createInstance() { Map<String, Object> config = new HashMap<>(); config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); - return sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + private SparseEncodingProcessor createInstance(int batchSize) { + Map<String, Processor.Factory> registry = new HashMap<>(); + Map<String, Object> config = new HashMap<>(); + config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + config.put(AbstractBatchingProcessor.BATCH_SIZE_FIELD, batchSize); + return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } public void testExecute_successful() { @@ -115,7 +126,12 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutAnyMap() { Map<String, Object> config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); - SparseEncodingProcessor processor = sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + SparseEncodingProcessor processor = (SparseEncodingProcessor) sparseEncodingProcessorFactory.create( + registry, + PROCESSOR_TAG, + DESCRIPTION, + config + ); doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -181,7 +197,12 @@ public void testExecute_withMapTypeInput_successful() { SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", Map.of("test1", "test1_knn"), "key2", Map.of("test4", "test4_knn")) ); - SparseEncodingProcessor processor = sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + SparseEncodingProcessor processor = (SparseEncodingProcessor) sparseEncodingProcessorFactory.create( + registry, + PROCESSOR_TAG, + DESCRIPTION, + config + ); List<Map<String, ?>> dataAsMapList = createMockMapResult(2); doAnswer(invocation -> { @@ -199,7 +220,7 @@ public void testExecute_withMapTypeInput_successful() { public void test_batchExecute_successful() { final int docCount = 5; List<IngestDocumentWrapper> ingestDocumentWrappers = createIngestDocumentWrappers(docCount); - SparseEncodingProcessor processor = createInstance(); + SparseEncodingProcessor processor = createInstance(docCount); List<Map<String, ?>> dataAsMapList = createMockMapResult(10); doAnswer(invocation -> { ActionListener<List<Map<String, ?>>> listener = invocation.getArgument(2); @@ -221,7 +242,7 @@ public void test_batchExecute_successful() { public void test_batchExecute_exception() { final int docCount = 5; List<IngestDocumentWrapper> ingestDocumentWrappers = createIngestDocumentWrappers(docCount); - SparseEncodingProcessor processor = createInstance(); + SparseEncodingProcessor processor = createInstance(docCount); doAnswer(invocation -> { ActionListener<List<Map<String, ?>>> listener = invocation.getArgument(2); listener.onFailure(new RuntimeException()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 26854dd2e..b8415e4d6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -6,11 +6,15 @@ import java.io.IOException; import java.net.URISyntaxException; +import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Set; import org.apache.http.HttpHeaders; import org.apache.http.message.BasicHeader; @@ -74,11 +78,11 @@ public void testTextEmbeddingProcessor_batch() throws Exception { loadModel(modelId); createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING); createTextEmbeddingIndex(); - ingestBatchDocumentWithBulk("batch_"); + ingestBatchDocumentWithBulk("batch_", 2, 2, Collections.emptySet(), Collections.emptySet()); assertEquals(2, getDocCount(INDEX_NAME)); - ingestDocument(INGEST_DOC1, "1"); - ingestDocument(INGEST_DOC2, "2"); + ingestDocument(String.format(LOCALE, INGEST_DOC1, "success"), "1"); + ingestDocument(String.format(LOCALE, INGEST_DOC2, "success"), "2"); assertEquals(getDocById(INDEX_NAME, "1").get("_source"), getDocById(INDEX_NAME, "batch_1").get("_source")); assertEquals(getDocById(INDEX_NAME, "2").get("_source"), getDocById(INDEX_NAME, "batch_2").get("_source")); @@ -147,6 +151,70 @@ public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws } } + public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Exception { + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + URL pipelineURLPath = classLoader.getResource("processor/PipelineConfigurationWithBatchSize.json"); + Objects.requireNonNull(pipelineURLPath); + String requestBody = Files.readString(Path.of(pipelineURLPath.toURI())); + createPipelineProcessor(requestBody, PIPELINE_NAME, modelId); + createTextEmbeddingIndex(); + int docCount = 5; + ingestBatchDocumentWithBulk("batch_", docCount, docCount, Collections.emptySet(), Collections.emptySet()); + assertEquals(5, getDocCount(INDEX_NAME)); + + for (int i = 0; i < docCount; ++i) { + String template = List.of(INGEST_DOC1, INGEST_DOC2).get(i % 2); + String payload = String.format(LOCALE, template, "success"); + ingestDocument(payload, String.valueOf(i + 1)); + } + + for (int i = 0; i < docCount; ++i) { + assertEquals( + getDocById(INDEX_NAME, String.valueOf(i + 1)).get("_source"), + getDocById(INDEX_NAME, "batch_" + (i + 1)).get("_source") + ); + + } + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + + public void testTextEmbeddingProcessor_withFailureAndSkip() throws Exception { + String modelId = null; + try { + modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + URL pipelineURLPath = classLoader.getResource("processor/PipelineConfigurationWithBatchSize.json"); + Objects.requireNonNull(pipelineURLPath); + String requestBody = Files.readString(Path.of(pipelineURLPath.toURI())); + createPipelineProcessor(requestBody, PIPELINE_NAME, modelId); + createTextEmbeddingIndex(); + int docCount = 5; + ingestBatchDocumentWithBulk("batch_", docCount, docCount, Set.of(0), Set.of(1)); + assertEquals(3, getDocCount(INDEX_NAME)); + + for (int i = 2; i < docCount; ++i) { + String template = List.of(INGEST_DOC1, INGEST_DOC2).get(i % 2); + String payload = String.format(LOCALE, template, "success"); + ingestDocument(payload, String.valueOf(i + 1)); + } + + for (int i = 2; i < docCount; ++i) { + assertEquals( + getDocById(INDEX_NAME, String.valueOf(i + 1)).get("_source"), + getDocById(INDEX_NAME, "batch_" + (i + 1)).get("_source") + ); + + } + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + private String uploadTextEmbeddingModel() throws Exception { String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); return registerModelGroupAndUploadModel(requestBody); @@ -183,23 +251,27 @@ private void ingestDocument(String doc, String id) throws Exception { assertEquals("created", map.get("result")); } - private void ingestBatchDocumentWithBulk(String idPrefix) throws Exception { - String doc1 = INGEST_DOC1.replace("\n", ""); - String doc2 = INGEST_DOC2.replace("\n", ""); - final String id1 = idPrefix + "1"; - final String id2 = idPrefix + "2"; - String item1 = BULK_ITEM_TEMPLATE.replace("{{index}}", INDEX_NAME) - .replace("{{id}}", id1) - .replace("{{doc}}", doc1) - .replace("{{comma}}", ","); - String item2 = BULK_ITEM_TEMPLATE.replace("{{index}}", INDEX_NAME) - .replace("{{id}}", id2) - .replace("{{doc}}", doc2) - .replace("{{comma}}", "\n"); - final String payload = item1 + item2; + private void ingestBatchDocumentWithBulk(String idPrefix, int docCount, int batchSize, Set<Integer> failedIds, Set<Integer> droppedIds) + throws Exception { + StringBuilder payloadBuilder = new StringBuilder(); + for (int i = 0; i < docCount; ++i) { + String docTemplate = List.of(INGEST_DOC1, INGEST_DOC2).get(i % 2); + if (failedIds.contains(i)) { + docTemplate = String.format(LOCALE, docTemplate, "fail"); + } else if (droppedIds.contains(i)) { + docTemplate = String.format(LOCALE, docTemplate, "drop"); + } else { + docTemplate = String.format(LOCALE, docTemplate, "success"); + } + String doc = docTemplate.replace("\n", ""); + final String id = idPrefix + (i + 1); + String item = BULK_ITEM_TEMPLATE.replace("{{index}}", INDEX_NAME).replace("{{id}}", id).replace("{{doc}}", doc); + payloadBuilder.append(item).append("\n"); + } + final String payload = payloadBuilder.toString(); Map<String, String> params = new HashMap<>(); params.put("refresh", "true"); - params.put("batch_size", "2"); + params.put("batch_size", String.valueOf(batchSize)); Response response = makeRequest( client(), "POST", @@ -213,7 +285,7 @@ private void ingestBatchDocumentWithBulk(String idPrefix) throws Exception { EntityUtils.toString(response.getEntity()), false ); - assertEquals(false, map.get("errors")); - assertEquals(2, ((List) map.get("items")).size()); + assertEquals(!failedIds.isEmpty(), map.get("errors")); + assertEquals(docCount, ((List) map.get("items")).size()); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 9a5e8aa76..95ae1a2de 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -40,6 +40,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.index.mapper.IndexFieldMapper; +import org.opensearch.ingest.AbstractBatchingProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.ingest.Processor; @@ -87,7 +88,7 @@ private TextEmbeddingProcessor createInstanceWithLevel2MapConfig() { TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", ImmutableMap.of("test1", "test1_knn"), "key2", ImmutableMap.of("test3", CHILD_LEVEL_2_KNN_FIELD)) ); - return textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @SneakyThrows @@ -96,7 +97,17 @@ private TextEmbeddingProcessor createInstanceWithLevel1MapConfig() { Map<String, Object> config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1_knn", "key2", "key2_knn")); - return textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + private TextEmbeddingProcessor createInstanceWithLevel1MapConfig(int batchSize) { + Map<String, Processor.Factory> registry = new HashMap<>(); + Map<String, Object> config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1_knn", "key2", "key2_knn")); + config.put(AbstractBatchingProcessor.BATCH_SIZE_FIELD, batchSize); + return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @SneakyThrows @@ -164,7 +175,12 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep Map<String, Object> config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); - TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( + registry, + PROCESSOR_TAG, + DESCRIPTION, + config + ); doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -187,7 +203,12 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() { Map<String, Object> config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); - TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( + registry, + PROCESSOR_TAG, + DESCRIPTION, + config + ); doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -314,7 +335,12 @@ public void testNestedFieldInMapping_withMapTypeInput_successful() { CHILD_LEVEL_2_KNN_FIELD ) ); - TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( + registry, + PROCESSOR_TAG, + DESCRIPTION, + config + ); List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { @@ -358,7 +384,12 @@ public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() { Map.of(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD) ) ); - TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( + registry, + PROCESSOR_TAG, + DESCRIPTION, + config + ); List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { @@ -639,7 +670,7 @@ public void test_doublyNestedList_withMapType_successful() { public void test_batchExecute_successful() { final int docCount = 5; List<IngestDocumentWrapper> ingestDocumentWrappers = createIngestDocumentWrappers(docCount); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(docCount); List<List<Float>> modelTensorList = createMockVectorWithLength(10); doAnswer(invocation -> { @@ -662,7 +693,7 @@ public void test_batchExecute_successful() { public void test_batchExecute_exception() { final int docCount = 5; List<IngestDocumentWrapper> ingestDocumentWrappers = createIngestDocumentWrappers(docCount); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(docCount); doAnswer(invocation -> { ActionListener<List<List<Float>>> listener = invocation.getArgument(2); listener.onFailure(new RuntimeException()); @@ -780,7 +811,7 @@ private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map<Stri Map<String, Object> config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, fieldMap); - return textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } private Map<String, Object> createPlainStringConfiguration() { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java index 89a42df80..8f0018f52 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -89,7 +89,7 @@ private TextImageEmbeddingProcessor createInstance() { TextImageEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") ); - return textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return (TextImageEmbeddingProcessor) textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @SneakyThrows @@ -223,7 +223,12 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep TextImageEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") ); - TextImageEmbeddingProcessor processor = textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + TextImageEmbeddingProcessor processor = (TextImageEmbeddingProcessor) textImageEmbeddingProcessorFactory.create( + registry, + PROCESSOR_TAG, + DESCRIPTION, + config + ); doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java index fa91d61a5..cfb0803a6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java @@ -42,7 +42,7 @@ public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() { config.put(MODEL_ID_FIELD, "1234567678"); config.put(EMBEDDING_FIELD, "embedding_field"); config.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "my_image_field")); - TextImageEmbeddingProcessor inferenceProcessor = textImageEmbeddingProcessorFactory.create( + TextImageEmbeddingProcessor inferenceProcessor = (TextImageEmbeddingProcessor) textImageEmbeddingProcessorFactory.create( processorFactories, tag, description, @@ -68,7 +68,7 @@ public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() { configOnlyTextField.put(MODEL_ID_FIELD, "1234567678"); configOnlyTextField.put(EMBEDDING_FIELD, "embedding_field"); configOnlyTextField.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field")); - TextImageEmbeddingProcessor processor = textImageEmbeddingProcessorFactory.create( + TextImageEmbeddingProcessor processor = (TextImageEmbeddingProcessor) textImageEmbeddingProcessorFactory.create( processorFactories, tag, description, @@ -81,7 +81,12 @@ public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() { configOnlyImageField.put(MODEL_ID_FIELD, "1234567678"); configOnlyImageField.put(EMBEDDING_FIELD, "embedding_field"); configOnlyImageField.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field")); - processor = textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configOnlyImageField); + processor = (TextImageEmbeddingProcessor) textImageEmbeddingProcessorFactory.create( + processorFactories, + tag, + description, + configOnlyImageField + ); assertNotNull(processor); assertEquals("text_image_embedding", processor.getType()); } diff --git a/src/test/resources/processor/PipelineConfigurationWithBatchSize.json b/src/test/resources/processor/PipelineConfigurationWithBatchSize.json new file mode 100644 index 000000000..953a419f1 --- /dev/null +++ b/src/test/resources/processor/PipelineConfigurationWithBatchSize.json @@ -0,0 +1,33 @@ +{ + "description": "text embedding pipeline for hybrid", + "processors": [ + { + "drop": { + "if": "ctx.text.contains('drop')" + } + }, + { + "fail": { + "if": "ctx.text.contains('fail')", + "message": "fail" + } + }, + { + "text_embedding": { + "model_id": "%s", + "batch_size": 2, + "field_map": { + "title": "title_knn", + "favor_list": "favor_list_knn", + "favorites": { + "game": "game_knn", + "movie": "movie_knn" + }, + "nested_passages": { + "text": "embedding" + } + } + } + } + ] +} diff --git a/src/test/resources/processor/bulk_item_template.json b/src/test/resources/processor/bulk_item_template.json index 79881b630..33b70523f 100644 --- a/src/test/resources/processor/bulk_item_template.json +++ b/src/test/resources/processor/bulk_item_template.json @@ -1,2 +1,2 @@ { "index": { "_index": "{{index}}", "_id": "{{id}}" } }, -{{doc}}{{comma}} +{{doc}} diff --git a/src/test/resources/processor/ingest_doc1.json b/src/test/resources/processor/ingest_doc1.json index e3302c75a..b1cc5392b 100644 --- a/src/test/resources/processor/ingest_doc1.json +++ b/src/test/resources/processor/ingest_doc1.json @@ -1,5 +1,6 @@ { "title": "This is a good day", + "text": "%s", "description": "daily logging", "favor_list": [ "test", diff --git a/src/test/resources/processor/ingest_doc2.json b/src/test/resources/processor/ingest_doc2.json index 400f9027a..cce93d4a1 100644 --- a/src/test/resources/processor/ingest_doc2.json +++ b/src/test/resources/processor/ingest_doc2.json @@ -1,5 +1,6 @@ { "title": "this is a second doc", + "text": "%s", "description": "the description is not very long", "favor_list": [ "favor"