Skip to content

Commit

Permalink
Use AbstractBatchingProcessor for InferenceProcessor (#820) (#832)
Browse files Browse the repository at this point in the history
* Use AbstractBatchingProcessor for InferenceProcessor

Signed-off-by: Liyun Xiu <[email protected]>

* Add changelog

Signed-off-by: Liyun Xiu <[email protected]>

---------

Signed-off-by: Liyun Xiu <[email protected]>
(cherry picked from commit bf2fd5a)

Co-authored-by: Liyun Xiu <[email protected]>
Co-authored-by: zhichao-aws <[email protected]>
  • Loading branch information
3 people authored Jul 19, 2024
1 parent a581311 commit 342aef5
Show file tree
Hide file tree
Showing 17 changed files with 278 additions and 98 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
- Enable sorting and search_after features in Hybrid Search [#827](https://github.com/opensearch-project/neural-search/pull/827)
### Enhancements
- InferenceProcessor inherits from AbstractBatchingProcessor to support sub batching in processor [#820](https://github.com/opensearch-project/neural-search/pull/820)
- Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
- Enable '.' for nested field in text embedding processor ([#811](https://github.com/opensearch-project/neural-search/pull/811))
### Bug Fixes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
Expand All @@ -46,7 +46,7 @@
* and set the target fields according to the field name map.
*/
@Log4j2
public abstract class InferenceProcessor extends AbstractProcessor {
public abstract class InferenceProcessor extends AbstractBatchingProcessor {

public static final String MODEL_ID_FIELD = "model_id";
public static final String FIELD_MAP_FIELD = "field_map";
Expand All @@ -69,6 +69,7 @@ public abstract class InferenceProcessor extends AbstractProcessor {
public InferenceProcessor(
String tag,
String description,
int batchSize,
String type,
String listTypeNestedMapKey,
String modelId,
Expand All @@ -77,7 +78,7 @@ public InferenceProcessor(
Environment environment,
ClusterService clusterService
) {
super(tag, description);
super(tag, description, batchSize);
this.type = type;
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it");
validateEmbeddingConfiguration(fieldMap);
Expand Down Expand Up @@ -144,7 +145,7 @@ public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Ex
abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);

@Override
public void batchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
handler.accept(Collections.emptyList());
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ public final class SparseEncodingProcessor extends InferenceProcessor {
/**
 * Builds a sparse encoding ingest processor. All heavy lifting is delegated to the
 * {@link InferenceProcessor} base class; this constructor only fixes the processor
 * {@code TYPE} and the nested-map key used for sparse results.
 *
 * @param tag            processor tag from the pipeline definition
 * @param description    processor description from the pipeline definition
 * @param batchSize      max number of documents handed to one inference call
 *                       (forwarded to AbstractBatchingProcessor via the super chain)
 * @param modelId        id of the ML model used to produce sparse encodings
 * @param fieldMap       source-field to target-field mapping for embeddings
 * @param clientAccessor client used to invoke ML Commons inference
 * @param environment    node environment
 * @param clusterService cluster service, presumably used for index settings lookups — see base class
 */
public SparseEncodingProcessor(
String tag,
String description,
int batchSize,
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment,
ClusterService clusterService
) {
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {
/**
 * Builds a text embedding ingest processor. Behavior lives in the
 * {@link InferenceProcessor} base class; this constructor only supplies the processor
 * {@code TYPE} and the nested-map key used for dense embedding results.
 *
 * @param tag            processor tag from the pipeline definition
 * @param description    processor description from the pipeline definition
 * @param batchSize      max number of documents handed to one inference call
 *                       (forwarded to AbstractBatchingProcessor via the super chain)
 * @param modelId        id of the ML model used to produce text embeddings
 * @param fieldMap       source-field to target-field mapping for embeddings
 * @param clientAccessor client used to invoke ML Commons inference
 * @param environment    node environment
 * @param clusterService cluster service, presumably used for index settings lookups — see base class
 */
public TextEmbeddingProcessor(
String tag,
String description,
int batchSize,
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment,
ClusterService clusterService
) {
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;

Expand All @@ -24,27 +24,23 @@
* Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
@Log4j2
public class SparseEncodingProcessorFactory implements Processor.Factory {
public class SparseEncodingProcessorFactory extends AbstractBatchingProcessor.Factory {
private final MLCommonsClientAccessor clientAccessor;
private final Environment environment;
private final ClusterService clusterService;

public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) {
super(TYPE);
this.clientAccessor = clientAccessor;
this.environment = environment;
this.clusterService = clusterService;
}

@Override
public SparseEncodingProcessor create(
Map<String, Processor.Factory> registry,
String processorTag,
String description,
Map<String, Object> config
) throws Exception {
String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD);
Map<String, Object> fieldMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD);

return new SparseEncodingProcessor(processorTag, description, modelId, fieldMap, clientAccessor, environment, clusterService);
/**
 * Instantiates a {@link SparseEncodingProcessor} from a pipeline processor config.
 * Batch-size parsing is handled by {@code AbstractBatchingProcessor.Factory}; this
 * method only extracts the sparse-encoding specific settings.
 *
 * @param tag         processor tag
 * @param description processor description
 * @param batchSize   sub-batch size already parsed by the base factory
 * @param config      remaining user-supplied processor configuration
 * @return the configured sparse encoding processor
 */
protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
    final String model = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
    final Map<String, Object> mappings = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
    return new SparseEncodingProcessor(tag, description, batchSize, model, mappings, clientAccessor, environment, clusterService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;

/**
* Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
public class TextEmbeddingProcessorFactory implements Processor.Factory {
public final class TextEmbeddingProcessorFactory extends AbstractBatchingProcessor.Factory {

private final MLCommonsClientAccessor clientAccessor;

Expand All @@ -34,20 +34,16 @@ public TextEmbeddingProcessorFactory(
final Environment environment,
final ClusterService clusterService
) {
super(TYPE);
this.clientAccessor = clientAccessor;
this.environment = environment;
this.clusterService = clusterService;
}

@Override
public TextEmbeddingProcessor create(
final Map<String, Processor.Factory> registry,
final String processorTag,
final String description,
final Map<String, Object> config
) throws Exception {
String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD);
Map<String, Object> filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD);
return new TextEmbeddingProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment, clusterService);
/**
 * Instantiates a {@link TextEmbeddingProcessor} from a pipeline processor config.
 * Batch-size parsing is handled by {@code AbstractBatchingProcessor.Factory}; this
 * method only extracts the text-embedding specific settings.
 *
 * @param tag         processor tag
 * @param description processor description
 * @param batchSize   sub-batch size already parsed by the base factory
 * @param config      remaining user-supplied processor configuration
 * @return the configured text embedding processor
 */
protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
    String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
    // renamed from "filedMap": fixes the typo and matches SparseEncodingProcessorFactory
    Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
    return new TextEmbeddingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.Factory;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.FIELD_MAP_FIELD;
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.MODEL_ID_FIELD;
Expand All @@ -16,6 +15,7 @@

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;

Expand All @@ -25,31 +25,18 @@
* Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
@AllArgsConstructor
public class TextImageEmbeddingProcessorFactory implements Factory {
public class TextImageEmbeddingProcessorFactory implements Processor.Factory {

private final MLCommonsClientAccessor clientAccessor;
private final Environment environment;
private final ClusterService clusterService;

@Override
public TextImageEmbeddingProcessor create(
final Map<String, Factory> registry,
final String processorTag,
final String description,
final Map<String, Object> config
) throws Exception {
String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD);
String embedding = readStringProperty(TYPE, processorTag, config, EMBEDDING_FIELD);
Map<String, String> filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD);
return new TextImageEmbeddingProcessor(
processorTag,
description,
modelId,
embedding,
filedMap,
clientAccessor,
environment,
clusterService
);
/**
 * Instantiates a {@link TextImageEmbeddingProcessor} from a pipeline processor config.
 * Unlike the text/sparse factories this one implements {@code Processor.Factory}
 * directly and therefore has no batch-size handling.
 *
 * @param processorFactories registry of all processor factories (unused here)
 * @param tag                processor tag
 * @param description        processor description
 * @param config             user-supplied processor configuration
 * @return the configured text/image embedding processor
 * @throws Exception if a required property is missing or malformed
 */
public Processor create(Map<String, Processor.Factory> processorFactories, String tag, String description, Map<String, Object> config)
    throws Exception {
    String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
    String embedding = readStringProperty(TYPE, tag, config, EMBEDDING_FIELD);
    // renamed from "filedMap": fixes the typo and matches the other embedding factories
    Map<String, String> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
    return new TextImageEmbeddingProcessor(tag, description, modelId, embedding, fieldMap, clientAccessor, environment, clusterService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.neuralsearch.processor;

import lombok.Getter;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.MockitoAnnotations;
Expand All @@ -15,6 +16,7 @@
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -42,6 +44,7 @@ public class InferenceProcessorTests extends InferenceProcessorTestCase {
private static final String DESCRIPTION = "description";
private static final String MAP_KEY = "map_key";
private static final String MODEL_ID = "model_id";
private static final int BATCH_SIZE = 10;
private static final Map<String, Object> FIELD_MAP = Map.of("key1", "embedding_key1", "key2", "embedding_key2");

@Before
Expand All @@ -54,7 +57,7 @@ public void setup() {
}

public void test_batchExecute_emptyInput() {
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null);
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
Consumer resultHandler = mock(Consumer.class);
processor.batchExecute(Collections.emptyList(), resultHandler);
ArgumentCaptor<List<IngestDocumentWrapper>> captor = ArgumentCaptor.forClass(List.class);
Expand All @@ -65,7 +68,7 @@ public void test_batchExecute_emptyInput() {

public void test_batchExecute_allFailedValidation() {
final int docCount = 2;
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null);
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
Expand All @@ -83,7 +86,7 @@ public void test_batchExecute_allFailedValidation() {

public void test_batchExecute_partialFailedValidation() {
final int docCount = 2;
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), null);
TestInferenceProcessor processor = new TestInferenceProcessor(createMockVectorResult(), BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("", "value1"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4"));
Expand All @@ -105,7 +108,7 @@ public void test_batchExecute_partialFailedValidation() {
public void test_batchExecute_happyCase() {
final int docCount = 2;
List<List<Float>> inferenceResults = createMockVectorWithLength(6);
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, null);
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("value1", "value2"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4"));
Expand All @@ -126,7 +129,7 @@ public void test_batchExecute_happyCase() {
public void test_batchExecute_sort() {
final int docCount = 2;
List<List<Float>> inferenceResults = createMockVectorWithLength(100);
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, null);
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, BATCH_SIZE, null);
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("aaaaa", "bbb"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("cc", "ddd"));
Expand Down Expand Up @@ -158,7 +161,7 @@ public void test_batchExecute_sort() {
public void test_doBatchExecute_exception() {
final int docCount = 2;
List<List<Float>> inferenceResults = createMockVectorWithLength(6);
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, new RuntimeException());
TestInferenceProcessor processor = new TestInferenceProcessor(inferenceResults, BATCH_SIZE, new RuntimeException());
List<IngestDocumentWrapper> wrapperList = createIngestDocumentWrappers(docCount);
wrapperList.get(0).getIngestDocument().setFieldValue("key1", Arrays.asList("value1", "value2"));
wrapperList.get(1).getIngestDocument().setFieldValue("key1", Arrays.asList("value3", "value4"));
Expand All @@ -174,12 +177,36 @@ public void test_doBatchExecute_exception() {
verify(clientAccessor).inferenceSentences(anyString(), anyList(), any());
}

public void test_batchExecute_subBatches() {
    final int numDocs = 5;
    final List<List<Float>> vectors = createMockVectorWithLength(6);
    // batch size 2 forces the 5 documents to be split into sub-batches of 2, 2 and 1
    final TestInferenceProcessor processor = new TestInferenceProcessor(vectors, 2, null);
    final List<IngestDocumentWrapper> wrappers = createIngestDocumentWrappers(numDocs);
    for (int doc = 0; doc < numDocs; doc++) {
        wrappers.get(doc).getIngestDocument().setFieldValue("key1", Collections.singletonList("value" + doc));
    }

    final List<IngestDocumentWrapper> collected = new ArrayList<>();
    processor.batchExecute(wrappers, collected::addAll);

    // every wrapper comes back at its original position, unchanged
    for (int doc = 0; doc < numDocs; doc++) {
        final IngestDocumentWrapper actual = collected.get(doc);
        final IngestDocumentWrapper expected = wrappers.get(doc);
        assertEquals(actual.getIngestDocument(), expected.getIngestDocument());
        assertEquals(actual.getSlot(), expected.getSlot());
        assertEquals(actual.getException(), expected.getException());
    }

    // inference ran once per sub-batch, in order: [0,1], [2,3], [4]
    final List<List<String>> inferenceInputs = processor.getAllInferenceInputs();
    assertEquals(3, inferenceInputs.size());
    assertEquals(List.of("value0", "value1"), inferenceInputs.get(0));
    assertEquals(List.of("value2", "value3"), inferenceInputs.get(1));
    assertEquals(List.of("value4"), inferenceInputs.get(2));
}

private class TestInferenceProcessor extends InferenceProcessor {
List<?> vectors;
Exception exception;

public TestInferenceProcessor(List<?> vectors, Exception exception) {
super(TAG, DESCRIPTION, TYPE, MAP_KEY, MODEL_ID, FIELD_MAP, clientAccessor, environment, clusterService);
@Getter
List<List<String>> allInferenceInputs = new ArrayList<>();

/**
 * Test double for {@link InferenceProcessor} backed by canned data instead of a real model.
 *
 * @param vectors   canned inference results stored for the stubbed doBatchExecute
 *                  (exact use is in the success branch — see the method below)
 * @param batchSize sub-batch size forwarded to the InferenceProcessor super constructor
 * @param exception when non-null, doBatchExecute reports this exception instead of results
 */
public TestInferenceProcessor(List<?> vectors, int batchSize, Exception exception) {
super(TAG, DESCRIPTION, batchSize, TYPE, MAP_KEY, MODEL_ID, FIELD_MAP, clientAccessor, environment, clusterService);
this.vectors = vectors;
this.exception = exception;
}
Expand All @@ -196,6 +223,7 @@ public void doExecute(
void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
// use to verify if doBatchExecute is called from InferenceProcessor
clientAccessor.inferenceSentences(MODEL_ID, inferenceList, ActionListener.wrap(results -> {}, ex -> {}));
allInferenceInputs.add(inferenceList);
if (this.exception != null) {
onException.accept(this.exception);
} else {
Expand Down
Loading

0 comments on commit 342aef5

Please sign in to comment.