diff --git a/build.gradle b/build.gradle index 853aa85e7..335157549 100644 --- a/build.gradle +++ b/build.gradle @@ -151,6 +151,7 @@ dependencies { runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12' runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA' runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}" + runtimeOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' runtimeOnly group: 'org.json', name: 'json', version: '20230227' } diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 6f8b790bb..1c09f5996 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -5,6 +5,9 @@ package org.opensearch.neuralsearch.ml; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT; + import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -113,13 +116,30 @@ public void inferenceSentencesWithMapResult( retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); } + /** + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the + * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent + * using the actionListener which will have a list of floats in the order of inputText. + * + * @param modelId {@link String} + * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen + * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. + */ + public void inferenceSentences( + @NonNull final String modelId, + @NonNull final Map inputObjects, + @NonNull final ActionListener> listener + ) { + retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); + } + private void retryableInferenceSentencesWithMapResult( final String modelId, final List inputText, final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLInput(null, inputText); + MLInput mlInput = createMLTextInput(null, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> result = buildMapResultFromResponse(mlOutput); listener.onResponse(result); @@ -140,7 +160,7 @@ private void retryableInferenceSentencesWithVectorResult( final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLInput(targetResponseFilters, inputText); + MLInput mlInput = createMLTextInput(targetResponseFilters, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> vector = buildVectorFromResponse(mlOutput); listener.onResponse(vector); @@ -154,7 +174,7 @@ private void retryableInferenceSentencesWithVectorResult( })); } - private MLInput createMLInput(final List targetResponseFilters, List inputText) { + private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); @@ -191,4 +211,41 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return resultMaps; } + private List buildSingleVectorFromResponse(final MLOutput mlOutput) { + final List> vector = buildVectorFromResponse(mlOutput); + return vector.isEmpty() ? new ArrayList<>() : vector.get(0); + } + + private void retryableInferenceSentencesWithSingleVectorResult( + final List targetResponseFilters, + final String modelId, + final Map inputObjects, + final int retryTime, + final ActionListener> listener + ) { + MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List vector = buildSingleVectorFromResponse(mlOutput); + log.debug("Inference Response for input sentence is : {} ", vector); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + } + + private MLInput createMLMultimodalInput(final List targetResponseFilters, final Map input) { + List inputText = new ArrayList<>(); + inputText.add(input.get(INPUT_TEXT)); + if (input.containsKey(INPUT_IMAGE)) { + inputText.add(input.get(INPUT_IMAGE)); + } + final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); + final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); + return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index dd7dfee49..7bff137fd 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -33,11 +33,13 @@ import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; @@ -106,7 +108,9 @@ public Map getProcessors(Processor.Parameters paramet TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), SparseEncodingProcessor.TYPE, - new SparseEncodingProcessorFactory(clientAccessor, parameters.env) + new SparseEncodingProcessorFactory(clientAccessor, parameters.env), + TextImageEmbeddingProcessor.TYPE, + new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 4ac63d419..958b8a8be 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -249,7 +249,7 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map buildNLPResult(Map processorMap, List results, Map sourceAndMetadataMap) { - NLPProcessor.IndexWrapper indexWrapper = new NLPProcessor.IndexWrapper(0); + IndexWrapper indexWrapper = new IndexWrapper(0); Map result = new LinkedHashMap<>(); for (Map.Entry knnMapEntry : processorMap.entrySet()) { String knnKey = knnMapEntry.getKey(); @@ -270,7 +270,7 @@ private void putNLPResultToSourceMapForMapType( String processorKey, Object sourceValue, List results, - NLPProcessor.IndexWrapper indexWrapper, + IndexWrapper indexWrapper, Map sourceAndMetadataMap ) { if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; @@ -291,11 +291,7 @@ private void putNLPResultToSourceMapForMapType( } } - private List> buildNLPResultForListType( - List sourceValue, - List results, - NLPProcessor.IndexWrapper indexWrapper - ) { + private List> buildNLPResultForListType(List sourceValue, List results, IndexWrapper indexWrapper) { List> keyToResult = new ArrayList<>(); IntStream.range(0, sourceValue.size()) .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java new file mode 100644 index 000000000..a0d9606e9 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -0,0 +1,238 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.IndexFieldMapper; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import com.google.common.annotations.VisibleForTesting; + +/** + * This processor is used for user input data text and image embedding processing, model_id can be used to indicate which model user use, + * and field_map can be used to indicate which fields needs embedding and the corresponding keys for the embedding results. + */ +@Log4j2 +public class TextImageEmbeddingProcessor extends AbstractProcessor { + + public static final String TYPE = "text_image_embedding"; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String EMBEDDING_FIELD = "embedding"; + public static final String FIELD_MAP_FIELD = "field_map"; + public static final String TEXT_FIELD_NAME = "text"; + public static final String IMAGE_FIELD_NAME = "image"; + public static final String INPUT_TEXT = "inputText"; + public static final String INPUT_IMAGE = "inputImage"; + private static final Set VALID_FIELD_NAMES = Set.of(TEXT_FIELD_NAME, IMAGE_FIELD_NAME); + + private final String modelId; + private final String embedding; + private final Map fieldMap; + + private final MLCommonsClientAccessor mlCommonsClientAccessor; + private final Environment environment; + private final ClusterService clusterService; + + public TextImageEmbeddingProcessor( + final String tag, + final String description, + final String modelId, + final String embedding, + final Map fieldMap, + final MLCommonsClientAccessor clientAccessor, + final Environment environment, + final ClusterService clusterService + ) { + super(tag, description); + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); + validateEmbeddingConfiguration(fieldMap); + + this.modelId = modelId; + this.embedding = embedding; + this.fieldMap = fieldMap; + this.mlCommonsClientAccessor = clientAccessor; + this.environment = environment; + this.clusterService = clusterService; + } + + private void validateEmbeddingConfiguration(final Map fieldMap) { + if (fieldMap == null + || fieldMap.isEmpty() + || fieldMap.entrySet().stream().anyMatch(x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()))) { + throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has invalid key or value"); + } + + if (fieldMap.entrySet().stream().anyMatch(entry -> !VALID_FIELD_NAMES.contains(entry.getKey()))) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Unable to create the TextImageEmbedding processor with provided field name(s). Following names are supported [%s]", + String.join(",", VALID_FIELD_NAMES) + ) + ); + } + } + + @Override + public IngestDocument execute(IngestDocument ingestDocument) { + return ingestDocument; + } + + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. + * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. + */ + @Override + public void execute(final IngestDocument ingestDocument, final BiConsumer handler) { + try { + validateEmbeddingFieldsValue(ingestDocument); + Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map inferenceMap = createInferences(knnMap); + if (inferenceMap.isEmpty()) { + handler.accept(ingestDocument, null); + } else { + mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceMap, ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); + } + } catch (Exception e) { + handler.accept(null, e); + } + + } + + private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List vectors) { + Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); + log.debug("Text embedding result fetched, starting build vector output!"); + Map textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors); + textEmbeddingResult.forEach(ingestDocument::setFieldValue); + } + + @SuppressWarnings({ "unchecked" }) + private Map createInferences(final Map knnKeyMap) { + Map texts = new HashMap<>(); + if (fieldMap.containsKey(TEXT_FIELD_NAME) && knnKeyMap.containsKey(fieldMap.get(TEXT_FIELD_NAME))) { + texts.put(INPUT_TEXT, knnKeyMap.get(fieldMap.get(TEXT_FIELD_NAME))); + } + if (fieldMap.containsKey(IMAGE_FIELD_NAME) && knnKeyMap.containsKey(fieldMap.get(IMAGE_FIELD_NAME))) { + texts.put(INPUT_IMAGE, knnKeyMap.get(fieldMap.get(IMAGE_FIELD_NAME))); + } + return texts; + } + + @VisibleForTesting + Map buildMapWithKnnKeyAndOriginalValue(final IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + Map mapWithKnnKeys = new LinkedHashMap<>(); + for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { + String originalKey = fieldMapEntry.getValue(); // field from ingest document that we need to sent as model input, part of + // processor definition + + if (!sourceAndMetadataMap.containsKey(originalKey)) { + continue; + } + if (!(sourceAndMetadataMap.get(originalKey) instanceof String)) { + throw new IllegalArgumentException("Unsupported format of the field in the document, value must be a string"); + } + mapWithKnnKeys.put(originalKey, (String) sourceAndMetadataMap.get(originalKey)); + } + return mapWithKnnKeys; + } + + @SuppressWarnings({ "unchecked" }) + @VisibleForTesting + Map buildTextEmbeddingResult(final String knnKey, List modelTensorList) { + Map result = new LinkedHashMap<>(); + result.put(knnKey, modelTensorList); + return result; + } + + private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + String mappedSourceKey = embeddingFieldsEntry.getValue(); + Object sourceValue = sourceAndMetadataMap.get(mappedSourceKey); + if (Objects.isNull(sourceValue)) { + continue; + } + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); + validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1, indexName); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + mappedSourceKey + "] is neither string nor nested type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + mappedSourceKey + "] has empty string value, can not process it"); + } + + } + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue( + final String sourceKey, + final Object sourceValue, + final Supplier maxDepthSupplier, + final String indexName + ) { + int maxDepth = maxDepthSupplier.get(); + Settings indexSettings = clusterService.state().metadata().index(indexName).getSettings(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings)) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, (List) sourceValue); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1, indexName)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); + } + } + + @SuppressWarnings({ "rawtypes" }) + private static void validateListTypeValue(final String sourceKey, final List sourceValue) { + for (Object value : sourceValue) { + if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); + } else if (StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); + } + } + } + + @Override + public String getType() { + return TYPE; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index 0c9a6fa2c..adf6f6d21 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -25,17 +25,17 @@ public class TextEmbeddingProcessorFactory implements Processor.Factory { private final Environment environment; - public TextEmbeddingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + public TextEmbeddingProcessorFactory(final MLCommonsClientAccessor clientAccessor, final Environment environment) { this.clientAccessor = clientAccessor; this.environment = environment; } @Override public TextEmbeddingProcessor create( - Map registry, - String processorTag, - String description, - Map config + final Map registry, + final String processorTag, + final String description, + final Map config ) throws Exception { String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java new file mode 100644 index 000000000..c18ec6fb3 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.opensearch.ingest.ConfigurationUtils.readMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.Factory; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TYPE; + +import java.util.Map; + +import lombok.AllArgsConstructor; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; + +/** + * Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. + */ +@AllArgsConstructor +public class TextImageEmbeddingProcessorFactory implements Factory { + + private final MLCommonsClientAccessor clientAccessor; + private final Environment environment; + private final ClusterService clusterService; + + @Override + public TextImageEmbeddingProcessor create( + final Map registry, + final String processorTag, + final String description, + final Map config + ) throws Exception { + String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); + String embedding = readStringProperty(TYPE, processorTag, config, EMBEDDING_FIELD); + Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); + return new TextImageEmbeddingProcessor( + processorTag, + description, + modelId, + embedding, + filedMap, + clientAccessor, + environment, + clusterService + ); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 7b78be269..edb9aace0 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -7,8 +7,12 @@ import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.function.Supplier; import lombok.AccessLevel; @@ -19,6 +23,7 @@ import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.lucene.search.Query; @@ -61,6 +66,9 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder @VisibleForTesting static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text"); + @VisibleForTesting + static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image"); + @VisibleForTesting static final ParseField MODEL_ID_FIELD = new ParseField("model_id"); @@ -77,6 +85,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) { private String fieldName; private String queryText; + private String queryImage; private String modelId; private int k = DEFAULT_K; @VisibleForTesting @@ -177,7 +186,9 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx + "]" ); } - requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query"); + if (StringUtils.isBlank(neuralQueryBuilder.queryText()) && StringUtils.isBlank(neuralQueryBuilder.queryImage())) { + throw new IllegalArgumentException("Either query text or image text must be provided for neural query"); + } requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query"); if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query"); @@ -194,6 +205,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n } else if (token.isValue()) { if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { neuralQueryBuilder.queryText(parser.text()); + } else if (QUERY_IMAGE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + neuralQueryBuilder.queryImage(parser.text()); } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { neuralQueryBuilder.modelId(parser.text()); } else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { @@ -237,13 +250,20 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { } SetOnce vectorSetOnce = new SetOnce<>(); + Map inferenceInput = new HashMap<>(); + if (StringUtils.isNotBlank(queryText())) { + inferenceInput.put(INPUT_TEXT, queryText()); + } + if (StringUtils.isNotBlank(queryImage())) { + inferenceInput.put(INPUT_IMAGE, queryImage()); + } queryRewriteContext.registerAsyncAction( - ((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> { + ((client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, ActionListener.wrap(floatList -> { vectorSetOnce.set(vectorAsListToArray(floatList)); actionListener.onResponse(null); }, actionListener::onFailure))) ); - return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get, filter()); + return new NeuralQueryBuilder(fieldName(), queryText(), queryImage(), modelId(), k(), vectorSetOnce::get, filter()); } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 9c24e81fd..33cdff9a0 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -72,17 +72,20 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { protected static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean"; protected static final String PARAM_NAME_WEIGHTS = "weights"; - protected String PIPELINE_CONFIGURATION_NAME = "processor/PipelineConfiguration.json"; + protected static final Map PIPELINE_CONFIGS_BY_TYPE = Map.of( + ProcessorType.TEXT_EMBEDDING, + "processor/PipelineConfiguration.json", + ProcessorType.SPARSE_ENCODING, + "processor/SparseEncodingPipelineConfiguration.json", + ProcessorType.TEXT_IMAGE_EMBEDDING, + "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json" + ); protected final ClassLoader classLoader = this.getClass().getClassLoader(); protected ThreadPool threadPool; protected ClusterService clusterService; - protected void setPipelineConfigurationName(String pipelineConfigurationName) { - this.PIPELINE_CONFIGURATION_NAME = pipelineConfigurationName; - } - @Before public void setupSettings() { threadPool = setUpThreadPool(); @@ -263,13 +266,21 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig } protected void createPipelineProcessor(String modelId, String pipelineName) throws Exception { + createPipelineProcessor(modelId, pipelineName, ProcessorType.TEXT_EMBEDDING); + } + + protected void createPipelineProcessor(String modelId, String pipelineName, ProcessorType processorType) throws Exception { Response pipelineCreateResponse = makeRequest( client(), "PUT", "/_ingest/pipeline/" + pipelineName, null, toHttpEntity( - String.format(LOCALE, Files.readString(Path.of(classLoader.getResource(PIPELINE_CONFIGURATION_NAME).toURI())), modelId) + String.format( + LOCALE, + Files.readString(Path.of(classLoader.getResource(PIPELINE_CONFIGS_BY_TYPE.get(processorType)).toURI())), + modelId + ) ), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); @@ -760,4 +771,14 @@ private String registerModelGroup() { assertNotNull(modelGroupId); return modelGroupId; } + + /** + * Enumeration for types of pipeline processors, used to lookup resources like create + * processor request as those are type specific + */ + protected enum ProcessorType { + TEXT_EMBEDDING, + TEXT_IMAGE_EMBEDDING, + SPARSE_ENCODING + } } diff --git a/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java b/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java index 2776a53e6..185934b07 100644 --- a/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java +++ b/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java @@ -6,6 +6,7 @@ package org.opensearch.neuralsearch.constants; import java.util.List; +import java.util.Map; import lombok.AccessLevel; import lombok.NoArgsConstructor; @@ -16,4 +17,5 @@ public class TestCommonConstants { public static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); public static final Float[] PREDICT_VECTOR_ARRAY = new Float[] { 2.0f, 3.0f }; public static final List SENTENCES_LIST = List.of("TEXT"); + public static final Map SENTENCES_MAP = Map.of("inputText", "Text query", "inputImage", "base641234567890"); } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 295daa948..68d5f79eb 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -278,6 +278,54 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa Mockito.verify(resultListener).onFailure(illegalStateException); } + public void testInferenceMultimodal_whenValidInput_thenSuccess() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { + final RuntimeException exception = new RuntimeException(); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(exception); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(exception); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenRetryThreeTimes() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(nodeNodeConnectedException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client, times(4)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index 8cae15678..69791681e 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -11,6 +11,7 @@ import java.util.Map; import java.util.Optional; +import org.opensearch.ingest.IngestService; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; @@ -56,7 +57,17 @@ public void testQueryPhaseSearcher() { public void testProcessors() { NeuralSearch plugin = new NeuralSearch(); - Processor.Parameters processorParams = mock(Processor.Parameters.class); + Processor.Parameters processorParams = new Processor.Parameters( + null, + null, + null, + null, + null, + null, + mock(IngestService.class), + null, + null + ); Map processors = plugin.getProcessors(processorParams); assertNotNull(processors); assertNotNull(processors.get(TextEmbeddingProcessor.TYPE)); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 3cd71e5a1..86e75f736 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -94,7 +94,15 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); String modelId = getDeployedModelId(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_DOC_TEXT1, + "", + modelId, + 5, + null, + null + ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -129,7 +137,15 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu createSearchPipelineWithDefaultResultsPostProcessor(SEARCH_PIPELINE); String modelId = getDeployedModelId(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_DOC_TEXT1, + "", + modelId, + 5, + null, + null + ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -153,7 +169,15 @@ public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful String modelId = getDeployedModelId(); int totalExpectedDocQty = 6; - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 6, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_DOC_TEXT1, + "", + modelId, + 6, + null, + null + ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index 03b77549a..4993df7fb 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -213,7 +213,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); - hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapDefaultNorm = search( @@ -236,7 +236,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); - hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapL2Norm = search( @@ -279,7 +279,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); - hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapDefaultNorm = search( @@ -302,7 +302,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); - hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapL2Norm = search( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 6f98e8d5e..aa133c44d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -96,7 +96,9 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); - hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderArithmeticMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapArithmeticMean = search( @@ -119,7 +121,9 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); - hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderHarmonicMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapHarmonicMean = search( @@ -142,7 +146,9 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); - hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderGeometricMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapGeometricMean = search( @@ -185,7 +191,9 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); - hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderArithmeticMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapArithmeticMean = search( @@ -208,7 +216,9 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); - hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderHarmonicMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapHarmonicMean = search( @@ -231,7 +241,9 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); - hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderGeometricMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapGeometricMean = search( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java index 0312eaef7..893b47ede 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -15,7 +15,6 @@ import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; import org.junit.After; -import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; @@ -40,14 +39,9 @@ public void tearDown() { findDeployedModels().forEach(this::deleteModel); } - @Before - public void setPipelineName() { - this.setPipelineConfigurationName("processor/SparseEncodingPipelineConfiguration.json"); - } - public void testSparseEncodingProcessor() throws Exception { String modelId = prepareModel(); - createPipelineProcessor(modelId, PIPELINE_NAME); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.SPARSE_ENCODING); createSparseEncodingIndex(); ingestDocument(); assertEquals(1, getDocCount(INDEX_NAME)); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java new file mode 100644 index 000000000..055344352 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.After; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; + +import com.google.common.collect.ImmutableList; + +/** + * Testing text_and_image_embedding ingest processor. We can only test text in integ tests, none of pre-built models + * supports both text and image. + */ +public class TextImageEmbeddingProcessorIT extends BaseNeuralSearchIT { + + private static final String INDEX_NAME = "text_image_embedding_index"; + private static final String PIPELINE_NAME = "ingest-pipeline"; + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + findDeployedModels().forEach(this::deleteModel); + } + + public void testEmbeddingProcessor_whenIngestingDocumentWithSourceMatchingTextMapping_thenSuccessful() throws Exception { + String modelId = uploadModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING); + createTextImageEmbeddingIndex(); + ingestDocumentWithTextMappedToEmbeddingField(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + public void testEmbeddingProcessor_whenIngestingDocumentWithSourceWithoutMatchingInMapping_thenSuccessful() throws Exception { + String modelId = uploadModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING); + createTextImageEmbeddingIndex(); + ingestDocumentWithoutMappedFields(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + private String uploadModel() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); + return uploadModel(requestBody); + } + + private void createTextImageEmbeddingIndex() throws Exception { + createIndexWithConfiguration( + INDEX_NAME, + Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())), + PIPELINE_NAME + ); + } + + private void ingestDocumentWithTextMappedToEmbeddingField() throws Exception { + String ingestDocumentBody = "{\n" + + " \"title\": \"This is a good day\",\n" + + " \"description\": \"daily logging\",\n" + + " \"passage_text\": \"A very nice day today\",\n" + + " \"favorites\": {\n" + + " \"game\": \"overwatch\",\n" + + " \"movie\": null\n" + + " }\n" + + "}\n"; + ingestDocument(ingestDocumentBody); + } + + private void ingestDocumentWithoutMappedFields() throws Exception { + String ingestDocumentBody = "{\n" + + " \"title\": \"This is a good day\",\n" + + " \"description\": \"daily logging\",\n" + + " \"some_random_field\": \"Today is a sunny weather\"\n" + + "}\n"; + ingestDocument(ingestDocumentBody); + } + + private void ingestDocument(final String ingestDocument) throws Exception { + Response response = makeRequest( + client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(ingestDocument), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Map map = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(response.getEntity()), + false + ); + assertEquals("created", map.get("result")); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java new file mode 100644 index 000000000..bae336d4a --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -0,0 +1,383 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.IMAGE_FIELD_NAME; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TEXT_FIELD_NAME; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import lombok.SneakyThrows; + +import org.junit.Before; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.IndexFieldMapper; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class TextImageEmbeddingProcessorTests extends OpenSearchTestCase { + + @Mock + private MLCommonsClientAccessor mlCommonsClientAccessor; + @Mock + private Environment env; + @Mock + private ClusterService clusterService; + @Mock + private ClusterState clusterState; + @Mock + private Metadata metadata; + @Mock + private IndexMetadata indexMetadata; + + @InjectMocks + private TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory; + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); + when(env.settings()).thenReturn(settings); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.index(anyString())).thenReturn(indexMetadata); + when(indexMetadata.getSettings()).thenReturn(settings); + } + + @SneakyThrows + private TextImageEmbeddingProcessor createInstance() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextImageEmbeddingProcessor.EMBEDDING_FIELD, "my_embedding_field"); + config.put( + TextImageEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") + ); + return textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + try { + textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } catch (OpenSearchParseException e) { + assertEquals("[embedding] required property is missing", e.getMessage()); + } + } + + @SneakyThrows + public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_throwIllegalArgumentException() { + boolean ignoreFailure = false; + String modelId = "mockModelId"; + String embeddingField = "my_embedding_field"; + + // create with null type mapping + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new TextImageEmbeddingProcessor( + PROCESSOR_TAG, + DESCRIPTION, + modelId, + embeddingField, + null, + mlCommonsClientAccessor, + env, + clusterService + ) + ); + assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); + + // type mapping has empty key + exception = expectThrows( + IllegalArgumentException.class, + () -> new TextImageEmbeddingProcessor( + PROCESSOR_TAG, + DESCRIPTION, + modelId, + embeddingField, + Map.of("", "my_field"), + mlCommonsClientAccessor, + env, + clusterService + ) + ); + assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); + + // type mapping has empty value + // use vanila java syntax because it allows null values + Map typeMapping = new HashMap<>(); + typeMapping.put("my_field", null); + + exception = expectThrows( + IllegalArgumentException.class, + () -> new TextImageEmbeddingProcessor( + PROCESSOR_TAG, + DESCRIPTION, + modelId, + embeddingField, + typeMapping, + mlCommonsClientAccessor, + env, + clusterService + ) + ); + assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); + } + + @SneakyThrows + public void testTextEmbeddingProcessConstructor_whenEmptyModelId_throwIllegalArgumentException() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, ""); + config.put(TextImageEmbeddingProcessor.EMBEDDING_FIELD, "my_embedding_field"); + config.put( + TextImageEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") + ); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("model_id is null or empty, can not process it", exception.getMessage()); + } + + public void testExecute_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("my_text_field", "value2"); + sourceAndMetadata.put("key3", "value3"); + sourceAndMetadata.put("image_field", "base64_of_image_1234567890"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + @SneakyThrows + public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_text_field", "value1"); + sourceAndMetadata.put("another_text_field", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + Map registry = new HashMap<>(); + MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + accessor, + env, + clusterService + ); + + Map config = new HashMap<>(); + config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextImageEmbeddingProcessor.EMBEDDING_FIELD, "my_embedding_field"); + config.put( + TextImageEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") + ); + TextImageEmbeddingProcessor processor = textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(RuntimeException.class)); + } + + public void testExecute_withListTypeInput_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_text_field", "value1"); + sourceAndMetadata.put("another_text_field", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { + Map ret = createMaxDepthLimitExceedMap(() -> 1); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "hello world"); + sourceAndMetadata.put("my_text_field", ret); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_MLClientAccessorThrowFail_handlerFailure() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_text_field", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("illegal argument")); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { + Map map1 = ImmutableMap.of("test1", "test2"); + Map map2 = ImmutableMap.of("test3", 209.3D); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", map1); + sourceAndMetadata.put("my_text_field", map2); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { + Map map1 = ImmutableMap.of("test1", "test2"); + Map map2 = ImmutableMap.of("test3", " "); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", map1); + sourceAndMetadata.put("my_text_field", map2); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_hybridTypeInput_successful() throws Exception { + List list1 = ImmutableList.of("test1", "test2"); + Map> map1 = ImmutableMap.of("test3", list1); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key2", map1); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey("key2"); + } + + public void testExecute_whenInferencesAreEmpty_thenSuccessful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_field", "value1"); + sourceAndMetadata.put("another_text_field", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + private List> createMockVectorResult() { + List> modelTensorList = new ArrayList<>(); + List number1 = ImmutableList.of(1.234f, 2.354f); + List number2 = ImmutableList.of(3.234f, 4.354f); + List number3 = ImmutableList.of(5.234f, 6.354f); + List number4 = ImmutableList.of(7.234f, 8.354f); + List number5 = ImmutableList.of(9.234f, 10.354f); + List number6 = ImmutableList.of(11.234f, 12.354f); + List number7 = ImmutableList.of(13.234f, 14.354f); + modelTensorList.add(number1); + modelTensorList.add(number2); + modelTensorList.add(number3); + modelTensorList.add(number4); + modelTensorList.add(number5); + modelTensorList.add(number6); + modelTensorList.add(number7); + return modelTensorList; + } + + private List> createMockVectorWithLength(int size) { + float suffix = .234f; + List> result = new ArrayList<>(); + for (int i = 0; i < size * 2;) { + List number = new ArrayList<>(); + number.add(i++ + suffix); + number.add(i++ + suffix); + result.add(number); + } + return result; + } + + private Map createMaxDepthLimitExceedMap(Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > 21) { + return null; + } + Map innerMap = new HashMap<>(); + Map ret = createMaxDepthLimitExceedMap(() -> maxDepth + 1); + if (ret == null) return innerMap; + innerMap.put("hello", ret); + return innerMap; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java new file mode 100644 index 000000000..cbf53b8fc --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java @@ -0,0 +1,130 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.IMAGE_FIELD_NAME; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TEXT_FIELD_NAME; + +import java.util.HashMap; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; +import org.opensearch.test.OpenSearchTestCase; + +public class TextImageEmbeddingProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() { + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class), + mock(ClusterService.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, "1234567678"); + config.put(EMBEDDING_FIELD, "embedding_field"); + config.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "my_image_field")); + TextImageEmbeddingProcessor inferenceProcessor = textImageEmbeddingProcessorFactory.create( + processorFactories, + tag, + description, + config + ); + assertNotNull(inferenceProcessor); + assertEquals("text_image_embedding", inferenceProcessor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() { + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class), + mock(ClusterService.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map configOnlyTextField = new HashMap<>(); + configOnlyTextField.put(MODEL_ID_FIELD, "1234567678"); + configOnlyTextField.put(EMBEDDING_FIELD, "embedding_field"); + configOnlyTextField.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field")); + TextImageEmbeddingProcessor processor = textImageEmbeddingProcessorFactory.create( + processorFactories, + tag, + description, + configOnlyTextField + ); + assertNotNull(processor); + assertEquals("text_image_embedding", processor.getType()); + + Map configOnlyImageField = new HashMap<>(); + configOnlyImageField.put(MODEL_ID_FIELD, "1234567678"); + configOnlyImageField.put(EMBEDDING_FIELD, "embedding_field"); + configOnlyImageField.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field")); + processor = textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configOnlyImageField); + assertNotNull(processor); + assertEquals("text_image_embedding", processor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenMixOfParamsOrEmptyParams_thenFail() { + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class), + mock(ClusterService.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map configMixOfFields = new HashMap<>(); + configMixOfFields.put(MODEL_ID_FIELD, "1234567678"); + configMixOfFields.put(EMBEDDING_FIELD, "embedding_field"); + configMixOfFields.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field", "random_field_name", "random_field")); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configMixOfFields) + ); + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + allOf( + containsString( + "Unable to create the TextImageEmbedding processor with provided field name(s). Following names are supported [" + ), + containsString("image"), + containsString("text") + ) + ); + Map configNoFields = new HashMap<>(); + configNoFields.put(MODEL_ID_FIELD, "1234567678"); + configNoFields.put(EMBEDDING_FIELD, "embedding_field"); + configNoFields.put(FIELD_MAP_FIELD, Map.of()); + exception = expectThrows( + IllegalArgumentException.class, + () -> textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configNoFields) + ); + assertEquals(exception.getMessage(), "Unable to create the TextImageEmbedding processor as field_map has invalid key or value"); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 681c1247d..94665d879 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -6,6 +6,7 @@ package org.opensearch.neuralsearch.query; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; @@ -16,6 +17,7 @@ import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.NAME; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_IMAGE_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_TEXT_FIELD; import java.io.IOException; @@ -49,6 +51,7 @@ import org.opensearch.index.query.MatchNoneQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.neuralsearch.common.VectorUtil; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -60,6 +63,7 @@ public class NeuralQueryBuilderTests extends OpenSearchTestCase { private static final String FIELD_NAME = "testField"; private static final String QUERY_TEXT = "Hello world!"; + private static final String IMAGE_TEXT = "base641234567890"; private static final String MODEL_ID = "mfgfgdsfgfdgsde"; private static final int K = 10; private static final float BOOST = 1.8f; @@ -74,6 +78,7 @@ public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() { { "VECTOR_FIELD": { "query_text": "string", + "query_image": "string", "model_id": "string", "k": int } @@ -117,6 +122,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { .startObject() .startObject(FIELD_NAME) .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(QUERY_IMAGE_FIELD.getPreferredName(), IMAGE_TEXT) .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) .field(K_FIELD.getPreferredName(), K) .field(BOOST_FIELD.getPreferredName(), BOOST) @@ -130,6 +136,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(IMAGE_TEXT, neuralQueryBuilder.queryImage()); assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); assertEquals(K, neuralQueryBuilder.k()); assertEquals(BOOST, neuralQueryBuilder.boost(), 0.0); @@ -269,6 +276,33 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { expectThrows(IOException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); } + @SneakyThrows + public void testFromXContent_whenNoQueryField_thenFail() { + /* + { + "VECTOR_FIELD": { + "model_id": "string", + "model_id": "string", + "k": int, + "k": int + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IOException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); + } + @SneakyThrows public void testFromXContent_whenBuiltWithInvalidFilter_thenFail() { /* @@ -352,6 +386,7 @@ private void testStreams() { NeuralQueryBuilder original = new NeuralQueryBuilder(); original.fieldName(FIELD_NAME); original.queryText(QUERY_TEXT); + original.queryImage(IMAGE_TEXT); original.modelId(MODEL_ID); original.k(K); original.boost(BOOST); @@ -377,6 +412,7 @@ public void testHashAndEquals() { String fieldName2 = "field 2"; String queryText1 = "query text 1"; String queryText2 = "query text 2"; + String imageText1 = "query image 1"; String modelId1 = "model-1"; String modelId2 = "model-2"; float boost1 = 1.8f; @@ -391,6 +427,7 @@ public void testHashAndEquals() { NeuralQueryBuilder neuralQueryBuilder_baseline = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) + .queryImage(imageText1) .modelId(modelId1) .k(k1) .boost(boost1) @@ -527,7 +564,43 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { ActionListener> listener = invocation.getArgument(2); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentence(any(), any(), any()); + }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); + NeuralQueryBuilder.initialize(mlCommonsClientAccessor); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + doAnswer(invocation -> { + BiConsumer> biConsumer = invocation.getArgument(0); + biConsumer.accept( + null, + ActionListener.wrap( + response -> inProgressLatch.countDown(), + err -> fail("Failed to set vector supplier: " + err.getMessage()) + ) + ); + return null; + }).when(queryRewriteContext).registerAsyncAction(any()); + + NeuralQueryBuilder queryBuilder = (NeuralQueryBuilder) neuralQueryBuilder.doRewrite(queryRewriteContext); + assertNotNull(queryBuilder.vectorSupplier()); + assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); + assertArrayEquals(VectorUtil.vectorAsListToArray(expectedVector), queryBuilder.vectorSupplier().get(), 0.0f); + } + + @SneakyThrows + public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSetVectorSupplier() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .queryImage(IMAGE_TEXT) + .modelId(MODEL_ID) + .k(K); + List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(expectedVector); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -554,6 +627,7 @@ public void testRewrite_whenVectorNull_thenReturnCopy() { Supplier nullSupplier = () -> null; NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) + .queryImage(IMAGE_TEXT) .modelId(MODEL_ID) .k(K) .vectorSupplier(nullSupplier); @@ -564,6 +638,7 @@ public void testRewrite_whenVectorNull_thenReturnCopy() { public void testRewrite_whenVectorSupplierAndVectorSet_thenReturnKNNQueryBuilder() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) + .queryImage(IMAGE_TEXT) .modelId(MODEL_ID) .k(K) .vectorSupplier(TEST_VECTOR_SUPPLIER); @@ -588,6 +663,21 @@ public void testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() { assertEquals(neuralQueryBuilder.filter(), knnQueryBuilder.getFilter()); } + public void testQueryCreation_whenCreateQueryWithDoToQuery_thenFail() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER); + QueryShardContext queryShardContext = mock(QueryShardContext.class); + UnsupportedOperationException exception = expectThrows( + UnsupportedOperationException.class, + () -> neuralQueryBuilder.doToQuery(queryShardContext) + ); + assertEquals("Query cannot be created by NeuralQueryBuilder directly", exception.getMessage()); + } + private void setUpClusterService(Version version) { ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version); NeuralSearchClusterUtil.instance().initialize(clusterService); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index afe3226c8..6e089500b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -31,6 +31,7 @@ public class NeuralQueryIT extends BaseNeuralSearchIT { private static final String TEST_NESTED_INDEX_NAME = "test-neural-nested-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; private static final String TEST_QUERY_TEXT = "Hello world"; + private static final String TEST_IMAGE_TEXT = "/9j/4AAQSkZJRgABAQAASABIAAD"; private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field"; @@ -79,6 +80,7 @@ public void testBasicQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -114,6 +116,7 @@ public void testBoostQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -158,6 +161,7 @@ public void testRescoreQuery() { NeuralQueryBuilder rescoreNeuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -206,6 +210,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -214,6 +219,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_2, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -262,6 +268,7 @@ public void testBooleanQuery_withNeuralAndBM25Queries() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -306,6 +313,7 @@ public void testNestedQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_NESTED, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -348,6 +356,7 @@ public void testFilterQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -361,6 +370,42 @@ public void testFilterQuery() { assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } + /** + * Tests basic query for multimodal: + * { + * "query": { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "query_image": "base64_1234567890", + * "model_id": "dcsdcasd", + * "k": 1 + * } + * } + * } + * } + */ + @SneakyThrows + public void testMultimodalQuery() { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + TEST_IMAGE_TEXT, + modelId, + 1, + null, + null + ); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { diff --git a/src/test/resources/processor/PipelineForTextImageEmbeddingProcessorConfiguration.json b/src/test/resources/processor/PipelineForTextImageEmbeddingProcessorConfiguration.json new file mode 100644 index 000000000..60d5dc051 --- /dev/null +++ b/src/test/resources/processor/PipelineForTextImageEmbeddingProcessorConfiguration.json @@ -0,0 +1,15 @@ +{ + "description": "text image embedding pipeline", + "processors": [ + { + "text_image_embedding": { + "model_id": "%s", + "embedding": "passage_embedding", + "field_map": { + "text": "passage_text", + "image": "passage_image" + } + } + } + ] +}