From c9758cabfe145402470879f2e2963ed3428d96f0 Mon Sep 17 00:00:00 2001 From: Mingshi Liu <113382730+mingshl@users.noreply.github.com> Date: Mon, 29 Apr 2024 01:05:22 -0700 Subject: [PATCH] Initiate MLInferenceIngestProcessor (#2205) * Initiate MLModelIngestProcessor Signed-off-by: Mingshi Liu add more tests Signed-off-by: Mingshi Liu * add more tests Signed-off-by: Mingshi Liu add yaml tests and nested objects tests Signed-off-by: Mingshi Liu add IT tests Signed-off-by: Mingshi Liu * use GroupListener and add DEFAULT_MAX_PREDICTION_TASKS Signed-off-by: Mingshi Liu * add javadoc Signed-off-by: Mingshi Liu * avoid calling execute(IngestDocument ingestDocument)-s Signed-off-by: Mingshi Liu * not rewriting dotpath to json path Signed-off-by: Mingshi Liu * change mapping order, input_map-model input as key, output_map-document field as key Signed-off-by: Mingshi Liu * use StringUtils.toJson Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu --- .../common/output/model/MLResultDataType.java | 9 + plugin/build.gradle | 1 + .../ml/plugin/MachineLearningPlugin.java | 22 +- .../InferenceProcessorAttributes.java | 83 ++ .../processor/MLInferenceIngestProcessor.java | 447 +++++++ .../ml/processor/ModelExecutor.java | 252 ++++ .../InferenceProcessorAttributesTests.java | 58 + ...LInferenceIngestProcessorFactoryTests.java | 143 +++ .../MLInferenceIngestProcessorTests.java | 1105 +++++++++++++++++ .../ml/rest/MLCommonsRestTestCase.java | 5 +- .../ml/rest/RestMLGuardrailsIT.java | 8 - .../RestMLInferenceIngestProcessorIT.java | 263 ++++ .../ml/rest/RestMLRemoteInferenceIT.java | 24 +- .../plugin/PluginClientYamlTestSuiteIT.java | 25 + .../resources/rest-api-spec/test/10_basic.yml | 8 + .../test/20_inference_ingest_processor.yml | 24 + 16 files changed, 2450 insertions(+), 27 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java create mode 100644 
plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java create mode 100644 plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java create mode 100644 plugin/src/test/java/org/opensearch/ml/processor/InferenceProcessorAttributesTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java create mode 100644 plugin/src/yamlRestTest/java/org/opensearch/ml/plugin/PluginClientYamlTestSuiteIT.java create mode 100644 plugin/src/yamlRestTest/resources/rest-api-spec/test/10_basic.yml create mode 100644 plugin/src/yamlRestTest/resources/rest-api-spec/test/20_inference_ingest_processor.yml diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/MLResultDataType.java b/common/src/main/java/org/opensearch/ml/common/output/model/MLResultDataType.java index 1b187e3561..ee7cc1bde1 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/MLResultDataType.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/MLResultDataType.java @@ -60,4 +60,13 @@ public boolean isInteger() { public boolean isBoolean() { return format == Format.BOOLEAN; } + + /** + * Checks whether it is a String data type. 
+ * + * @return whether it is a String data type + */ + public boolean isString() { + return format == Format.STRING; + } } diff --git a/plugin/build.gradle b/plugin/build.gradle index 90ae574af6..c2ff6931bd 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -60,6 +60,7 @@ dependencies { implementation "org.apache.logging.log4j:log4j-slf4j-impl:2.19.0" testImplementation group: 'commons-io', name: 'commons-io', version: '2.15.1' implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' + implementation 'com.jayway.jsonpath:json-path:2.9.0' } publishing { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index b84001d50e..89b812b613 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -14,6 +14,7 @@ import java.nio.file.Path; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -202,6 +203,7 @@ import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.processor.MLInferenceIngestProcessor; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.rest.RestMLCreateConnectorAction; import org.opensearch.ml.rest.RestMLCreateControllerAction; @@ -275,6 +277,7 @@ import org.opensearch.monitor.os.OsService; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; +import org.opensearch.plugins.IngestPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; @@ -298,7 +301,13 @@ import lombok.SneakyThrows; 
-public class MachineLearningPlugin extends Plugin implements ActionPlugin, SearchPlugin, SearchPipelinePlugin, ExtensiblePlugin { +public class MachineLearningPlugin extends Plugin + implements + ActionPlugin, + SearchPlugin, + SearchPipelinePlugin, + ExtensiblePlugin, + IngestPlugin { public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons."; public static final String GENERAL_THREAD_POOL = "opensearch_ml_general"; public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute"; @@ -983,4 +992,15 @@ public void loadExtensions(ExtensionLoader loader) { } } } + + /** + * To get ingest processors + */ + @Override + public Map getProcessors(org.opensearch.ingest.Processor.Parameters parameters) { + Map processors = new HashMap<>(); + processors + .put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client)); + return Collections.unmodifiableMap(processors); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java b/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java new file mode 100644 index 0000000000..9a72d04577 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.processor; + +import java.util.List; +import java.util.Map; + +import lombok.Getter; +import lombok.Setter; + +@Setter +@Getter +public class InferenceProcessorAttributes { + + protected List> inputMaps; + + protected List> outputMaps; + + protected String modelId; + protected int maxPredictionTask; + + protected Map modelConfigMaps; + public static final String MODEL_ID = "model_id"; + /** + * The list of maps that support one or more prediction tasks mapping. + * The mappings also support JSON path for nested objects. 
+ * + * input_map is used to construct model inputs, where the keys represent the model input fields, + * and the values represent the corresponding document fields. + * + * Example input_map: + * + * "input_map": [ + * { + * "input": "book.title" + * }, + * { + * "input": "book.text" + * } + * ] + */ + public static final String INPUT_MAP = "input_map"; + /** + * output_map is used to construct document fields, where the keys represent the document fields, + * and the values represent the corresponding model output fields. + * + * Example output_map: + * + * "output_map": [ + * { + * "book.title_language": "response.language" + * }, + * { + * "book.text_language": "response.language" + * } + * ] + * + */ + public static final String OUTPUT_MAP = "output_map"; + public static final String MODEL_CONFIG = "model_config"; + public static final String MAX_PREDICTION_TASKS = "max_prediction_tasks"; + + /** + * Utility class containing shared parameters for MLModelIngest/SearchProcessor + * */ + + public InferenceProcessorAttributes( + String modelId, + List> inputMaps, + List> outputMaps, + Map modelConfigMaps, + int maxPredictionTask + ) { + this.modelId = modelId; + this.modelConfigMaps = modelConfigMaps; + this.inputMaps = inputMaps; + this.outputMaps = outputMaps; + this.maxPredictionTask = maxPredictionTask; + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java new file mode 100644 index 0000000000..0caecf3ab7 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -0,0 +1,447 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.processor; + +import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; 
+import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.Processor; +import org.opensearch.ingest.ValueSource; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.script.ScriptService; +import org.opensearch.script.TemplateScript; + +import com.jayway.jsonpath.Configuration; +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.Option; + +/** + * MLInferenceIngestProcessor requires a modelId string to call model inferences + * maps fields in document for model input, and maps model inference output to new document fields + * this processor also handles dot path notation for nested object( map of array) by rewriting json path accordingly + */ +public class MLInferenceIngestProcessor extends AbstractProcessor implements ModelExecutor { + + public static final String DOT_SYMBOL = "."; + private final InferenceProcessorAttributes inferenceProcessorAttributes; + private final boolean ignoreMissing; + private final boolean ignoreFailure; + private final ScriptService scriptService; + private static Client client; + public static final String TYPE = "ml_inference"; + public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results"; + // allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the + // prediction outcomes, return the whole 
prediction outcome by skipping filtering + public static final String IGNORE_MISSING = "ignore_missing"; + // At default, ml inference processor allows maximum 10 prediction tasks running in parallel + // it can be overwritten using maxPredictionTask when creating processor + public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; + + private Configuration suppressExceptionConfiguration = Configuration + .builder() + .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL, Option.ALWAYS_RETURN_LIST) + .build(); + + protected MLInferenceIngestProcessor( + String modelId, + List> inputMaps, + List> outputMaps, + Map modelConfigMaps, + int maxPredictionTask, + String tag, + String description, + boolean ignoreMissing, + boolean ignoreFailure, + ScriptService scriptService, + Client client + ) { + super(tag, description); + this.inferenceProcessorAttributes = new InferenceProcessorAttributes( + modelId, + inputMaps, + outputMaps, + modelConfigMaps, + maxPredictionTask + ); + this.ignoreMissing = ignoreMissing; + this.ignoreFailure = ignoreFailure; + this.scriptService = scriptService; + this.client = client; + } + + /** + * This method is used to execute inference asynchronously, + * supporting multiple predictions. + * + * @param ingestDocument The document to be processed. + * @param handler A consumer for handling the processing result or any exception occurred during processing. + */ + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + + List> processInputMap = inferenceProcessorAttributes.getInputMaps(); + List> processOutputMap = inferenceProcessorAttributes.getOutputMaps(); + int inputMapSize = (processInputMap != null) ? 
processInputMap.size() : 0; + + GroupedActionListener batchPredictionListener = new GroupedActionListener<>(new ActionListener>() { + @Override + public void onResponse(Collection voids) { + handler.accept(ingestDocument, null); + } + + @Override + public void onFailure(Exception e) { + if (ignoreFailure) { + handler.accept(ingestDocument, null); + } else { + handler.accept(null, e); + } + } + }, Math.max(inputMapSize, 1)); + + for (int inputMapIndex = 0; inputMapIndex < Math.max(inputMapSize, 1); inputMapIndex++) { + try { + processPredictions(ingestDocument, batchPredictionListener, processInputMap, processOutputMap, inputMapIndex, inputMapSize); + } catch (Exception e) { + batchPredictionListener.onFailure(e); + } + } + } + + /** + * This method was called previously within + * execute( IngestDocument ingestDocument, BiConsumer (IngestDocument, Exception) handler) + * in the ml_inference ingest processor, it's not called. + * + * @param ingestDocument + * @throws Exception + */ + @Override + public IngestDocument execute(IngestDocument ingestDocument) throws Exception { + throw new UnsupportedOperationException("this method should not get executed."); + } + + /** + * process predictions for one model for multiple rounds of predictions + * ingest documents after prediction rounds are completed, + * when no input mappings provided, default to add all fields to model input fields, + * when no output mapping provided, default to output as + * "inference_results" field (the same format as predict API) + * + * @param ingestDocument The IngestDocument object containing the data to be processed. + * @param batchPredictionListener The GroupedActionListener for batch prediction. + * @param processInputMap A list of maps containing input field mappings. + * @param processOutputMap A list of maps containing output field mappings. + * @param inputMapIndex The current index of the inputMap. + * @param inputMapSize The size of inputMap. 
+ */ + private void processPredictions( + IngestDocument ingestDocument, + GroupedActionListener batchPredictionListener, + List> processInputMap, + List> processOutputMap, + int inputMapIndex, + int inputMapSize + ) { + Map modelParameters = new HashMap<>(); + if (inferenceProcessorAttributes.getModelConfigMaps() != null) { + modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + } + // when no input mapping is provided, default to read all fields from documents as model input + if (inputMapSize == 0) { + Set documentFields = ingestDocument.getSourceAndMetadata().keySet(); + for (String field : documentFields) { + getMappedModelInputFromDocuments(ingestDocument, modelParameters, field, field); + } + + } else { + Map inputMapping = processInputMap.get(inputMapIndex); + for (Map.Entry entry : inputMapping.entrySet()) { + // model field as key, document field as value + String modelInputFieldName = entry.getKey(); + String documentFieldName = entry.getValue(); + getMappedModelInputFromDocuments(ingestDocument, modelParameters, documentFieldName, modelInputFieldName); + } + } + + ActionRequest request = getRemoteModelInferenceRequest(modelParameters, inferenceProcessorAttributes.getModelId()); + + client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() { + + @Override + public void onResponse(MLTaskResponse mlTaskResponse) { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput(); + if (processOutputMap == null || processOutputMap.isEmpty()) { + appendFieldValue(modelTensorOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument); + } else { + // outMapping serves as a filter to modelTensorOutput, the fields that are not specified + // in the outputMapping will not write to document + Map outputMapping = processOutputMap.get(inputMapIndex); + + for (Map.Entry entry : outputMapping.entrySet()) { + // document field as key, model field as value + String newDocumentFieldName = entry.getKey(); + 
String modelOutputFieldName = entry.getValue(); + if (ingestDocument.hasField(newDocumentFieldName)) { + throw new IllegalArgumentException( + "document already has field name " + + newDocumentFieldName + + ". Not allow to overwrite the same field name, please check output_map." + ); + } + appendFieldValue(modelTensorOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument); + } + } + batchPredictionListener.onResponse(null); + } + + @Override + public void onFailure(Exception e) { + batchPredictionListener.onFailure(e); + } + }); + + } + + /** + * Retrieves the mapped model input from the IngestDocument and updates the model parameters. + * + * @param ingestDocument The IngestDocument object containing the data. + * @param modelParameters The map to store the model parameters. + * @param documentFieldName The name of the field in the IngestDocument. + * @param modelInputFieldName The name of the model input field. + */ + private void getMappedModelInputFromDocuments( + IngestDocument ingestDocument, + Map modelParameters, + String documentFieldName, + String modelInputFieldName + ) { + // if users used standard dot path, try getFieldPath from document + String originalFieldPath = getFieldPath(ingestDocument, documentFieldName); + if (originalFieldPath != null) { + Object documentFieldValue = ingestDocument.getFieldValue(originalFieldPath, Object.class); + String documentFieldValueAsString = toString(documentFieldValue); + updateModelParameters(modelInputFieldName, documentFieldValueAsString, modelParameters); + } + // else when cannot find field path in document, try check for nested array using json path + else { + if (documentFieldName.contains(DOT_SYMBOL)) { + + Map sourceObject = ingestDocument.getSourceAndMetadata(); + ArrayList fieldValueList = JsonPath + .using(suppressExceptionConfiguration) + .parse(sourceObject) + .read(documentFieldName); + if (!fieldValueList.isEmpty()) { + updateModelParameters(modelInputFieldName, toString(fieldValueList), 
modelParameters); + } else if (!ignoreMissing) { + throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName); + } + } else if (!ignoreMissing) { + throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName); + } + } + } + + /** + * This method supports mapping multiple document fields to the same model input field. + * It checks if the given model input field name already exists in the modelParameters map. + * If it exists, the method appends the originalFieldValueAsString to the existing value, + * which is expected to be a list. If the model input field name does not exist in the map, + * it adds a new entry with the originalFieldValueAsString as the value. + * + * @param modelInputFieldName the name of the model input field + * @param originalFieldValueAsString the value of the document field to be mapped + * @param modelParameters a map containing the model input fields and their values + */ + private void updateModelParameters(String modelInputFieldName, String originalFieldValueAsString, Map modelParameters) { + + if (modelParameters.containsKey(modelInputFieldName)) { + Object existingValue = modelParameters.get(modelInputFieldName); + List updatedList = (List) existingValue; + updatedList.add(originalFieldValueAsString); + modelParameters.put(modelInputFieldName, toString(updatedList)); + } else { + modelParameters.put(modelInputFieldName, originalFieldValueAsString); + } + + } + + /** + * Retrieves the field path from the given IngestDocument for the specified documentFieldName. 
+ * + * @param ingestDocument the IngestDocument containing the field + * @param documentFieldName the name of the field to retrieve the path for + * @return the field path if the field exists, null otherwise + */ + private String getFieldPath(IngestDocument ingestDocument, String documentFieldName) { + if (Strings.isNullOrEmpty(documentFieldName) || !ingestDocument.hasField(documentFieldName, true)) { + return null; + } + return documentFieldName; + } + + /** + * Appends the model output value to the specified field in the IngestDocument without modifying the source. + * + * @param modelTensorOutput the ModelTensorOutput containing the model output + * @param modelOutputFieldName the name of the field in the model output + * @param newDocumentFieldName the name of the field in the IngestDocument to append the value to + * @param ingestDocument the IngestDocument to append the value to + */ + private void appendFieldValue( + ModelTensorOutput modelTensorOutput, + String modelOutputFieldName, + String newDocumentFieldName, + IngestDocument ingestDocument + ) { + Object modelOutputValue = null; + + if (modelTensorOutput.getMlModelOutputs() != null && modelTensorOutput.getMlModelOutputs().size() > 0) { + + modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing); + + List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocument.getSourceAndMetadata(), newDocumentFieldName); + + if (dotPathsInArray.size() == 1) { + ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService); + TemplateScript.Factory ingestField = ConfigurationUtils + .compileTemplate(TYPE, tag, newDocumentFieldName, newDocumentFieldName, scriptService); + ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); + } else { + if (!(modelOutputValue instanceof List)) { + throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents."); + } + List modelOutputValueArray = (List) 
modelOutputValue; + // check length of the prediction array to be the same of the document array + if (dotPathsInArray.size() != modelOutputValueArray.size()) { + throw new RuntimeException( + "the prediction field: " + + modelOutputFieldName + + " is an array in size of " + + modelOutputValueArray.size() + + " but the document field array from field " + + newDocumentFieldName + + " is in size of " + + dotPathsInArray.size() + ); + } + // Iterate over dotPathInArray + for (int i = 0; i < dotPathsInArray.size(); i++) { + String dotPathInArray = dotPathsInArray.get(i); + Object modelOutputValueInArray = modelOutputValueArray.get(i); + ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService); + TemplateScript.Factory ingestField = ConfigurationUtils + .compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService); + ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); + } + } + } else { + throw new RuntimeException("model inference output cannot be null"); + } + } + + @Override + public String getType() { + return TYPE; + } + + public static class Factory implements Processor.Factory { + + private final ScriptService scriptService; + private final Client client; + + /** + * Constructs a new instance of the Factory class. + * + * @param scriptService the ScriptService instance to be used by the Factory + * @param client the Client instance to be used by the Factory + */ + public Factory(ScriptService scriptService, Client client) { + this.scriptService = scriptService; + this.client = client; + } + + /** + * Creates a new instance of the MLInferenceIngestProcessor. 
+ * + * @param registry a map of registered processor factories + * @param processorTag a unique tag for the processor + * @param description a description of the processor + * @param config a map of configuration properties for the processor + * @return a new instance of the MLInferenceIngestProcessor + * @throws Exception if there is an error creating the processor + */ + @Override + public MLInferenceIngestProcessor create( + Map registry, + String processorTag, + String description, + Map config + ) throws Exception { + String modelId = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, MODEL_ID); + Map modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, processorTag, config, MODEL_CONFIG); + List> inputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, INPUT_MAP); + List> outputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, OUTPUT_MAP); + int maxPredictionTask = ConfigurationUtils + .readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS); + boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false); + boolean ignoreFailure = ConfigurationUtils + .readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false); + // convert model config user input data structure to Map + Map modelConfigMaps = null; + if (modelConfigInput != null) { + modelConfigMaps = StringUtils.getParameterMap(modelConfigInput); + } + // check if the number of prediction tasks exceeds max prediction tasks + if (inputMaps != null && inputMaps.size() > maxPredictionTask) { + throw new IllegalArgumentException( + "The number of prediction task setting in this process is " + + inputMaps.size() + + ". It exceeds the max_prediction_tasks of " + + maxPredictionTask + + ". Please reduce the size of input_map or increase max_prediction_tasks." 
+ ); + } + if (inputMaps != null && outputMaps != null && outputMaps.size() != inputMaps.size()) { + throw new IllegalArgumentException("The length of output_map and the length of input_map do no match."); + } + + return new MLInferenceIngestProcessor( + modelId, + inputMaps, + outputMaps, + modelConfigMaps, + maxPredictionTask, + processorTag, + description, + ignoreMissing, + ignoreFailure, + scriptService, + client + ); + } + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java new file mode 100644 index 0000000000..1abc770d07 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -0,0 +1,252 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.processor; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.opensearch.action.ActionRequest; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; + +import com.jayway.jsonpath.Configuration; +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.Option; + +/** + * General ModelExecutor interface. 
+ */ +public interface ModelExecutor { + + Configuration suppressExceptionConfiguration = Configuration + .builder() + .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL) + .build(); + + /** + * Creates an ActionRequest for remote model inference based on the provided parameters and model ID. + * + * @param the type parameter for the ActionRequest + * @param parameters a map of input parameters for the model inference + * @param modelId the ID of the model to be used for inference + * @return an ActionRequest instance for remote model inference + * @throws IllegalArgumentException if the input parameters are null + */ + default ActionRequest getRemoteModelInferenceRequest(Map parameters, String modelId) { + if (parameters == null) { + throw new IllegalArgumentException("wrong input. The model input cannot be empty."); + } + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); + + ActionRequest request = new MLPredictionTaskRequest(modelId, mlInput, null); + + return request; + + } + + /** + * Retrieves the model output value from the given ModelTensorOutput for the specified modelOutputFieldName. + * It handles cases where the output contains a single tensor or multiple tensors. 
+ * + * @param modelTensorOutput the ModelTensorOutput containing the model output + * @param modelOutputFieldName the name of the field in the model output to retrieve the value for + * @param ignoreMissing a flag indicating whether to ignore missing fields or throw an exception + * @return the model output value as an Object + * @throws RuntimeException if there is an error retrieving the model output value + */ + default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String modelOutputFieldName, boolean ignoreMissing) { + Object modelOutputValue; + try { + // getMlModelOutputs() returns a list or collection. + // Adding null check for modelTensorOutput + if (modelTensorOutput != null && !modelTensorOutput.getMlModelOutputs().isEmpty()) { + // getMlModelOutputs() returns a list of ModelTensors + // accessing the first element. + // TODO currently remote model only return single tensor, might need to processor multiple tensors later + ModelTensors output = modelTensorOutput.getMlModelOutputs().get(0); + // Adding null check for output + if (output != null && output.getMlModelTensors() != null && !output.getMlModelTensors().isEmpty()) { + // getMlModelTensors() returns a list of ModelTensor + if (output.getMlModelTensors().size() == 1) { + ModelTensor tensor = output.getMlModelTensors().get(0); + // try getDataAsMap first + Map tensorInDataAsMap = tensor.getDataAsMap(); + if (tensorInDataAsMap != null) { + modelOutputValue = getModelOutputField(tensorInDataAsMap, modelOutputFieldName, ignoreMissing); + } + // if dataAsMap is empty try getData + else { + // parse data type + modelOutputValue = parseDataInTensor(tensor); + } + } else { + + // for multiple tensors, initiate an array + ArrayList tensorArray = new ArrayList<>(); + for (int i = 0; i < output.getMlModelTensors().size(); i++) { + ModelTensor tensor = output.getMlModelTensors().get(i); + + // Adding null check for tensor + if (tensor != null) { + // Assuming getData() method may throw an 
exception + // if the data is not available or is in an invalid state. + try { + // try getDataAsMap first + Map tensorInDataAsMap = tensor.getDataAsMap(); + if (tensorInDataAsMap != null) { + tensorArray.add(getModelOutputField(tensorInDataAsMap, modelOutputFieldName, ignoreMissing)); + } + // if dataAsMap is empty try getData + else { + tensorArray.add(parseDataInTensor(tensor)); + } + } catch (Exception e) { + // Handle the exception accordingly + throw new RuntimeException("Error accessing tensor data: " + e.getMessage()); + } + } + } + modelOutputValue = tensorArray; + } + } else { + throw new RuntimeException("Output tensors are null or empty."); + } + } else { + throw new RuntimeException("Model outputs are null or empty."); + } + } catch (Exception e) { + throw new RuntimeException("An unexpected error occurred: " + e.getMessage()); + } + return modelOutputValue; + } + + /** + * Parses the data from the given ModelTensor and returns it as an Object. + * The method handles different data types (integer, floating-point, string, and boolean) + * and converts the data accordingly. 
+ * + * @param tensor the ModelTensor containing the data to be parsed + * @return the parsed data as an Object (typically a List) + * @throws RuntimeException if the data type is not supported + */ + static Object parseDataInTensor(ModelTensor tensor) { + Object modelOutputValue; + if (tensor.getDataType().isInteger()) { + modelOutputValue = Arrays.stream(tensor.getData()).map(Number::intValue).map(Integer::new).collect(Collectors.toList()); + } else if (tensor.getDataType().isFloating()) { + modelOutputValue = Arrays.stream(tensor.getData()).map(Number::floatValue).map(Float::new).collect(Collectors.toList()); + } else if (tensor.getDataType().isString()) { + modelOutputValue = Arrays.stream(tensor.getData()).map(String::valueOf).map(String::new).collect(Collectors.toList()); + } else if (tensor.getDataType().isBoolean()) { + modelOutputValue = Arrays + .stream(tensor.getData()) + .map(num -> num.intValue() != 0) + .map(Boolean::new) + .collect(Collectors.toList()); + } else { + throw new RuntimeException("unsupported data type in prediction data."); + } + return modelOutputValue; + } + + /** + * Retrieves the value of the specified field from the given model tensor output map. + * If the field name is null, it returns the entire map. + * If the field name is present in the map, it returns the corresponding value. + * If the field name is not present in the map, it attempts to retrieve the value using JsonPath. + * If the field is not found and ignoreMissing is true, it returns the entire map. + * If the field is not found and ignoreMissing is false, it throws an IOException. 
+ * + * @param modelTensorOutputMap the model tensor output map to retrieve the field value from + * @param fieldName the name of the field to retrieve the value for + * @param ignoreMissing a flag indicating whether to ignore missing fields or throw an exception + * @return the value of the specified field, or the entire map if the field name is null + * @throws IOException if the field is not found and ignoreMissing is false + */ + default Object getModelOutputField(Map modelTensorOutputMap, String fieldName, boolean ignoreMissing) throws IOException { + if (fieldName == null || modelTensorOutputMap == null) { + return modelTensorOutputMap; + } + if (modelTensorOutputMap.containsKey(fieldName)) { + return modelTensorOutputMap.get(fieldName); + } + try { + return JsonPath.parse(modelTensorOutputMap).read(fieldName); + } catch (Exception e) { + if (ignoreMissing) { + return modelTensorOutputMap; + } else { + throw new IllegalArgumentException("model inference output cannot find field name: " + fieldName, e); + } + } + } + + /** + * Converts the given Object to its JSON string representation using the Gson library. + * + * @param originalFieldValue the Object to be converted to JSON string + * @return the JSON string representation of the input Object + */ + + default String toString(Object originalFieldValue) { + return StringUtils.toJson(originalFieldValue); + } + + /** + * Writes a new dot path for a nested object within the given JSON object. + * This method is useful when dealing with arrays or nested objects in the JSON structure. + * for example foo.*.bar.*.quk to be [foo.0.bar.0.quk, foo.0.bar.1.quk..] 
+ * @param json the JSON object containing the nested object + * @param dotPath the dot path representing the location of the nested object + * @return a list of dot paths representing the new locations of the nested object + */ + default List writeNewDotPathForNestedObject(Object json, String dotPath) { + int lastDotIndex = dotPath.lastIndexOf('.'); + List dotPaths = new ArrayList<>(); + if (lastDotIndex != -1) { // Check if dot exists + String leadingDotPath = dotPath.substring(0, lastDotIndex); + String lastLeave = dotPath.substring(lastDotIndex + 1, dotPath.length()); + Configuration configuration = Configuration + .builder() + .options(Option.ALWAYS_RETURN_LIST, Option.AS_PATH_LIST, Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL) + .build(); + + List resultPaths = JsonPath.using(configuration).parse(json).read(leadingDotPath); + for (String path : resultPaths) { + dotPaths.add(convertToDotPath(path) + "." + lastLeave); + } + return dotPaths; + } else { + dotPaths.add(dotPath); + } + return dotPaths; + } + + /** + * Converts a JSONPath format string to a dot path notation format. + * For example, it converts "$['foo'][0]['bar']['quz'][0]" to "foo.0.bar.quiz.0". 
+ * + * @param path the JSONPath format string to be converted + * @return the converted dot path notation string + */ + default String convertToDotPath(String path) { + + return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", ""); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/processor/InferenceProcessorAttributesTests.java b/plugin/src/test/java/org/opensearch/ml/processor/InferenceProcessorAttributesTests.java new file mode 100644 index 0000000000..f1c4636257 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/processor/InferenceProcessorAttributesTests.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.processor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.ml.processor.MLInferenceIngestProcessor.DEFAULT_MAX_PREDICTION_TASKS; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Test; + +public class InferenceProcessorAttributesTests { + +    @Test +    public void testConstructor() { +        String modelId = "my_model"; +        List> inputMap = new ArrayList<>(); +        Map inputField = new HashMap<>(); +        inputField.put("model_input", "document_field"); +        inputMap.add(inputField); + +        List> outputMap = new ArrayList<>(); +        Map outputField = new HashMap<>(); +        outputField.put("new_document_field", "model_output"); +        outputMap.add(outputField); + +        Map modelConfig = new HashMap<>(); +        modelConfig.put("config_key", "config_value"); + +        InferenceProcessorAttributes mlModelUtil = new InferenceProcessorAttributes( +            modelId, +            inputMap, +            outputMap, +            modelConfig, +            DEFAULT_MAX_PREDICTION_TASKS +        ); + +        assertEquals(modelId, mlModelUtil.getModelId()); +        assertEquals(inputMap, mlModelUtil.getInputMaps()); +        assertEquals(outputMap, mlModelUtil.getOutputMaps()); + 
assertEquals(modelConfig, mlModelUtil.getModelConfigMaps()); + assertEquals(DEFAULT_MAX_PREDICTION_TASKS, mlModelUtil.getMaxPredictionTask()); + } + + @Test + public void testStaticFields() { + assertNotNull(InferenceProcessorAttributes.MODEL_ID); + assertNotNull(InferenceProcessorAttributes.INPUT_MAP); + assertNotNull(InferenceProcessorAttributes.OUTPUT_MAP); + assertNotNull(InferenceProcessorAttributes.MODEL_CONFIG); + assertNotNull(InferenceProcessorAttributes.MAX_PREDICTION_TASKS); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java new file mode 100644 index 0000000000..577e8b8693 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.processor; + +import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.Mock; +import org.opensearch.OpenSearchParseException; +import org.opensearch.client.Client; +import org.opensearch.ingest.Processor; +import org.opensearch.script.ScriptService; +import org.opensearch.test.OpenSearchTestCase; + +public class MLInferenceIngestProcessorFactoryTests extends OpenSearchTestCase { + private MLInferenceIngestProcessor.Factory factory; + @Mock + private Client client; + @Mock + private ScriptService scriptService; + + @Before + public void init() { + factory = new MLInferenceIngestProcessor.Factory(scriptService, client); + } + + public void testCreateRequiredFields() throws Exception { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + String processorTag = 
randomAlphaOfLength(10); + MLInferenceIngestProcessor mLInferenceIngestProcessor = factory.create(registry, processorTag, null, config); + assertNotNull(mLInferenceIngestProcessor); + assertEquals(mLInferenceIngestProcessor.getTag(), processorTag); + assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE); + } + + public void testCreateNoFieldPresent() throws Exception { + Map config = new HashMap<>(); + try { + factory.create(null, null, null, config); + fail("factory create should have failed"); + } catch (OpenSearchParseException e) { + assertEquals(e.getMessage(), ("[model_id] required property is missing")); + } + } + + public void testExceedMaxPredictionTasks() throws Exception { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(MODEL_ID, "model2"); + List> inputMap = new ArrayList<>(); + Map input0 = new HashMap<>(); + input0.put("inputs", "text"); + inputMap.add(input0); + Map input1 = new HashMap<>(); + input1.put("inputs", "hashtag"); + inputMap.add(input1); + Map input2 = new HashMap<>(); + input2.put("inputs", "timestamp"); + inputMap.add(input2); + config.put(INPUT_MAP, inputMap); + config.put(MAX_PREDICTION_TASKS, 2); + String processorTag = randomAlphaOfLength(10); + + try { + factory.create(registry, processorTag, null, config); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + ("The number of prediction task setting in this process is 3. It exceeds the max_prediction_tasks of 2. 
Please reduce the size of input_map or increase max_prediction_tasks.") +            ); +        } +    } + +    public void testOutputMapsExceedInputMaps() throws Exception { +        Map registry = new HashMap<>(); +        Map config = new HashMap<>(); +        config.put(MODEL_ID, "model2"); +        List> inputMap = new ArrayList<>(); +        Map input0 = new HashMap<>(); +        input0.put("inputs", "text"); +        inputMap.add(input0); +        Map input1 = new HashMap<>(); +        input1.put("inputs", "hashtag"); +        inputMap.add(input1); +        config.put(INPUT_MAP, inputMap); +        List> outputMap = new ArrayList<>(); +        Map output1 = new HashMap<>(); +        output1.put("text_embedding", "response"); +        outputMap.add(output1); +        Map output2 = new HashMap<>(); +        output2.put("hashtag_embedding", "response"); +        outputMap.add(output2); +        Map output3 = new HashMap<>(); +        output3.put("hashtag_embedding", "response"); +        outputMap.add(output3); +        config.put(OUTPUT_MAP, outputMap); +        config.put(MAX_PREDICTION_TASKS, 2); +        String processorTag = randomAlphaOfLength(10); + +        try { +            factory.create(registry, processorTag, null, config); +        } catch (IllegalArgumentException e) { +            assertEquals(e.getMessage(), ("The length of output_map and the length of input_map do no match.")); +        } +    } + +    public void testCreateOptionalFields() throws Exception { +        Map registry = new HashMap<>(); +        Map config = new HashMap<>(); +        config.put(MODEL_ID, "model2"); +        Map model_config = new HashMap<>(); +        model_config.put("hidden_size", 768); +        model_config.put("gradient_checkpointing", false); +        model_config.put("position_embedding_type", "absolute"); +        config.put(MODEL_CONFIG, model_config); +        List> inputMap = new ArrayList<>(); +        Map input = new HashMap<>(); +        input.put("inputs", "text"); +        inputMap.add(input); +        List> outputMap = new ArrayList<>(); +        Map output = new HashMap<>(); +        output.put("text_embedding", "response"); +        outputMap.add(output); +        config.put(INPUT_MAP, inputMap); +        config.put(OUTPUT_MAP, outputMap); +        config.put(MAX_PREDICTION_TASKS, 5); +        String processorTag = 
randomAlphaOfLength(10); + + MLInferenceIngestProcessor mLInferenceIngestProcessor = factory.create(registry, processorTag, null, config); + assertNotNull(mLInferenceIngestProcessor); + assertEquals(mLInferenceIngestProcessor.getTag(), processorTag); + assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java new file mode 100644 index 0000000000..54d4ef220b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java @@ -0,0 +1,1105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.processor; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.processor.MLInferenceIngestProcessor.DEFAULT_OUTPUT_FIELD_NAME; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import 
org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; +import org.opensearch.script.ScriptService; +import org.opensearch.test.OpenSearchTestCase; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.jayway.jsonpath.Configuration; +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.Option; + +public class MLInferenceIngestProcessorTests extends OpenSearchTestCase { + + @Mock + private Client client; + @Mock + private ScriptService scriptService; + @Mock + private BiConsumer handler; + private static final String PROCESSOR_TAG = "inference"; + private static final String DESCRIPTION = "inference_test"; + private IngestDocument ingestDocument; + private IngestDocument nestedObjectIngestDocument; + private ModelExecutor modelExecutor; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + Map nestedObjectSourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + nestedObjectIngestDocument = new IngestDocument(nestedObjectSourceAndMetadata, new HashMap<>()); + modelExecutor = new ModelExecutor() { + }; + + } + + private MLInferenceIngestProcessor createMLInferenceProcessor( + String model_id, + Map model_config, + List> input_map, + List> output_map, + boolean ignoreMissing, + boolean ignoreFailure + ) { + return new MLInferenceIngestProcessor( + model_id, + input_map, + output_map, + model_config, + RANDOM_MULTIPLIER, + PROCESSOR_TAG, + DESCRIPTION, + ignoreMissing, + ignoreFailure, + scriptService, + client + ); + } + + public void testExecute_Exception() throws Exception { + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + try { + IngestDocument document = 
processor.execute(ingestDocument); + } catch (UnsupportedOperationException e) { + assertEquals("this method should not get executed.", e.getMessage()); + } + + } + + /** + * test nested object document with array of Map + */ + public void testExecute_nestedObjectStringDocumentSuccess() { + + List> inputMap = getInputMapsForNestedObjectChunks("chunks.chunk"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + processor.execute(nestedObjectIngestDocument, handler); + // match output documents + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, ImmutableMap.of("response", Arrays.asList(1, 2, 3))); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(nestedObjectIngestDocument, ingestDocument1); + } + + /** + * test nested object document with array of Map, + * the value Object is a Map + */ + public void testExecute_nestedObjectMapDocumentSuccess() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", 
Arrays.asList(1, 2, 3))).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ArrayList childDocuments = new ArrayList<>(); + Map childDocument1Text = new HashMap<>(); + childDocument1Text.put("text", "this is first"); + Map childDocument1 = new HashMap<>(); + childDocument1.put("chunk", childDocument1Text); + + Map childDocument2 = new HashMap<>(); + Map childDocument2Text = new HashMap<>(); + childDocument2Text.put("text", "this is second"); + childDocument2.put("chunk", childDocument2Text); + + childDocuments.add(childDocument1); + childDocuments.add(childDocument2); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("chunks", childDocuments); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + processor.execute(nestedObjectIngestDocument, handler); + + // match input dataset + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + verify(client).execute(any(), argumentCaptor.capture(), any()); + + Map inputParameters = new HashMap<>(); + ArrayList embedding_text = new ArrayList<>(); + embedding_text.add("this is first"); + embedding_text.add("this is second"); + inputParameters.put("inputs", modelExecutor.toString(embedding_text)); + + MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor + .getRemoteModelInferenceRequest(inputParameters, "model1"); + MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); + + RemoteInferenceInputDataSet expectedRemoteInputDataset = 
(RemoteInferenceInputDataSet) expectedRequest + .getMlInput() + .getInputDataset(); + RemoteInferenceInputDataSet actualRemoteInputDataset = (RemoteInferenceInputDataSet) actualRequest.getMlInput().getInputDataset(); + + assertEquals(expectedRemoteInputDataset.getParameters().get("inputs"), actualRemoteInputDataset.getParameters().get("inputs")); + + // match document + sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, List.of(ImmutableMap.of("response", Arrays.asList(1, 2, 3)))); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(nestedObjectIngestDocument, ingestDocument1); + } + + public void testExecute_jsonPathWithMissingLeaves() { + + Map sourceObject = getNestedObjectWithAnotherNestedObjectSource(); + sourceObject.remove("chunks.1.chunk.text.0.context", "this is third"); + + Configuration suppressExceptionConfiguration = Configuration + .builder() + .options(Option.DEFAULT_PATH_LEAF_TO_NULL, Option.SUPPRESS_EXCEPTIONS) + .build(); + Object jsonObject = JsonPath.parse(sourceObject).json(); + JsonPath.parse(jsonObject).delete("$.chunks[1].chunk.text[0].context"); + ArrayList value = JsonPath.using(suppressExceptionConfiguration).parse(jsonObject).read("chunks.*.chunk.text.*.context"); + + assertEquals(value.size(), 4); + assertEquals(value.get(0), "this is first"); + assertEquals(value.get(1), "this is second"); + assertNull(value.get(2)); // confirm the missing leave is null + assertEquals(value.get(3), "this is fourth"); + } + + /** + * test nested object document with array of Map, + * the value Object is a also a nested object, + */ + public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); + ModelTensor modelTensor = 
ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4))).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + processor.execute(nestedObjectIngestDocument, handler); + + // match input dataset + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + verify(client).execute(any(), argumentCaptor.capture(), any()); + + Map inputParameters = new HashMap<>(); + ArrayList embedding_text = new ArrayList<>(); + embedding_text.add("this is first"); + embedding_text.add("this is second"); + embedding_text.add("this is third"); + embedding_text.add("this is fourth"); + inputParameters.put("inputs", modelExecutor.toString(embedding_text)); + + MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor + .getRemoteModelInferenceRequest(inputParameters, "model1"); + MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); + + RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest + .getMlInput() + .getInputDataset(); + RemoteInferenceInputDataSet actualRemoteInputDataset = (RemoteInferenceInputDataSet) actualRequest.getMlInput().getInputDataset(); + + assertEquals(expectedRemoteInputDataset.getParameters().get("inputs"), actualRemoteInputDataset.getParameters().get("inputs")); + + // match document + 
sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, List.of(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4)))); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(nestedObjectIngestDocument, ingestDocument1); + + } + + public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + ArrayList> modelPredictionOutput = new ArrayList<>(); + modelPredictionOutput.add(Arrays.asList(1)); + modelPredictionOutput.add(Arrays.asList(2)); + modelPredictionOutput.add(Arrays.asList(3)); + modelPredictionOutput.add(Arrays.asList(4)); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", modelPredictionOutput)).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + processor.execute(nestedObjectIngestDocument, handler); + + // match output dataset + assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.0.chunk.text.0.embedding", Object.class), Arrays.asList(1)); + 
assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.0.chunk.text.1.embedding", Object.class), Arrays.asList(2)); + assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text.0.embedding", Object.class), Arrays.asList(3)); + assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text.1.embedding", Object.class), Arrays.asList(4)); + } + + public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArrayMissingLeaveSuccess() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + ArrayList> modelPredictionOutput = new ArrayList<>(); + modelPredictionOutput.add(Arrays.asList(1)); + modelPredictionOutput.add(Arrays.asList(2)); + modelPredictionOutput.add(null); + modelPredictionOutput.add(Arrays.asList(4)); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", modelPredictionOutput)).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + Object jsonObject = JsonPath.parse(sourceAndMetadata).json(); + JsonPath.parse(jsonObject).delete("$.chunks[1].chunk.text[0].context"); + + IngestDocument nestedObjectIngestDocument = new IngestDocument((Map) jsonObject, new HashMap<>()); + processor.execute(nestedObjectIngestDocument, handler); + + // match output dataset + 
assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.0.chunk.text", ArrayList.class).size(), 2); + assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text", ArrayList.class).size(), 2); + assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.0.chunk.text.0.embedding", Object.class), Arrays.asList(1)); + assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.0.chunk.text.1.embedding", Object.class), Arrays.asList(2)); + assertNull(((ArrayList) nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text.0.embedding", Object.class))); // missing + assertEquals(nestedObjectIngestDocument.getFieldValue("chunks.1.chunk.text.1.embedding", Object.class), Arrays.asList(4)); + } + + public void testExecute_InferenceException() { + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + when(client.execute(any(), any())).thenThrow(new RuntimeException("Executing Model failed with exception")); + try { + processor.execute(ingestDocument, handler); + } catch (RuntimeException e) { + assertEquals("Executing Model failed with exception", e.getMessage()); + } + } + + public void testExecute_InferenceOnFailure() { + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + RuntimeException inferenceFailure = new RuntimeException("Executing Model failed with exception"); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(inferenceFailure); + return null; + }).when(client).execute(any(), any(), any()); + processor.execute(ingestDocument, handler); + + verify(handler).accept(eq(null), eq(inferenceFailure)); + + } + + public void testExecute_AppendFieldValueExceptionOnResponse() throws Exception { + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + String originalOutPutFieldName = "response1"; + output.put("text_embedding", originalOutPutFieldName); 
+        outputMap.add(output); +        MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + +        ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); +        ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); +        ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + +        doAnswer(invocation -> { +            ActionListener actionListener = invocation.getArgument(2); +            actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); +            return null; +        }).when(client).execute(any(), any(), any()); + +        try { +            processor.execute(ingestDocument, handler); + +        } catch (IllegalArgumentException e) { +            assertEquals("model inference output cannot find field name: " + originalOutPutFieldName, e.getMessage()); +        } + +    } + +    public void testExecute_whenInputFieldNotFound_ExceptionWithIgnoreMissingFalse() { +        List> inputMap = new ArrayList<>(); +        Map input = new HashMap<>(); +        String documentFieldPath = "text"; +        input.put("inputs", documentFieldPath); +        inputMap.add(input); +        List> outputMap = new ArrayList<>(); +        Map output = new HashMap<>(); +        output.put("text_embedding", "response"); +        outputMap.add(output); +        Map model_config = new HashMap<>(); +        model_config.put("position_embedding_type", "absolute"); + +        MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + +        try { +            processor.execute(ingestDocument, handler); +        } catch (IllegalArgumentException e) { +            assertEquals("field name in input_map: [" + documentFieldPath + "] doesn't exist", e.getMessage()); +        } + +    } + +    public void testExecute_whenInputFieldNotFound_SuccessWithIgnoreMissingTrue() { +        List> inputMap = new ArrayList<>(); +        Map input = new HashMap<>(); +        String documentFieldPath = "text"; + 
input.put("inputs", documentFieldPath); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "response"); + outputMap.add(output); + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + + processor.execute(ingestDocument, handler); + } + + public void testExecute_whenEmptyInputField_ExceptionWithIgnoreMissingFalse() { + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + String documentFieldPath = ""; // emptyInputField + input.put("inputs", documentFieldPath); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "response"); + outputMap.add(output); + Map model_config = new HashMap<>(); + model_config.put("position_embedding_type", "absolute"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + + try { + processor.execute(ingestDocument, handler); + } catch (IllegalArgumentException e) { + assertEquals("field name in input_map [ " + documentFieldPath + "] cannot be null nor empty", e.getMessage()); + } + } + + public void testExecute_whenEmptyInputField_ExceptionWithIgnoreMissingTrue() { + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + String documentFieldPath = ""; // emptyInputField + input.put("inputs", documentFieldPath); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "response"); + outputMap.add(output); + Map model_config = new HashMap<>(); + model_config.put("position_embedding_type", "absolute"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, true, false); + + processor.execute(ingestDocument, handler); + + } + + public void testExecute_IOExceptionWithIgnoreMissingFalse() throws 
JsonProcessingException { + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + String documentFieldPath = "text"; + input.put("inputs", documentFieldPath); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "response"); + outputMap.add(output); + Map model_config = new HashMap<>(); + model_config.put("position_embedding_type", "absolute"); + + ObjectMapper mapper = mock(ObjectMapper.class); + when(mapper.readValue(Mockito.anyString(), eq(Object.class))).thenThrow(JsonProcessingException.class); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + + try { + processor.execute(ingestDocument, handler); + } catch (IllegalArgumentException e) { + assertEquals("field name in input_map: [" + documentFieldPath + "] doesn't exist", e.getMessage()); + } + } + + public void testExecute_NoModelInput_Exception() { + MLInferenceIngestProcessor processorIgnoreMissingTrue = createMLInferenceProcessor("model1", null, null, null, true, false); + MLInferenceIngestProcessor processorIgnoreMissingFalse = createMLInferenceProcessor("model1", null, null, null, false, false); + + Map sourceAndMetadata = new HashMap<>(); + IngestDocument emptyIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + try { + processorIgnoreMissingTrue.execute(emptyIngestDocument, handler); + } catch (IllegalArgumentException e) { + assertEquals("wrong input. The model input cannot be empty.", e.getMessage()); + } + try { + processorIgnoreMissingFalse.execute(emptyIngestDocument, handler); + } catch (IllegalArgumentException e) { + assertEquals("wrong input. 
The model input cannot be empty.", e.getMessage()); + } + + } + + public void testExecute_AppendModelOutputSuccess() { + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + processor.execute(ingestDocument, handler); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, ImmutableMap.of("response", Arrays.asList(1, 2, 3))); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(ingestDocument, ingestDocument1); + } + + public void testExecute_SingleTensorInDataOutputSuccess() { + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + + Float[] value = new Float[] { 1.0f, 2.0f, 3.0f }; + List outputs = new ArrayList<>(); + ModelTensor tensor = ModelTensor + .builder() + .data(value) + .name("test") + .shape(new long[] { 1, 3 }) + .dataType(MLResultDataType.FLOAT32) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); + List mlModelTensors = Arrays.asList(tensor); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(mlModelTensors).build(); + outputs.add(modelTensors); + ModelTensorOutput modelTensorOutput = 
ModelTensorOutput.builder().mlModelOutputs(outputs).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + processor.execute(ingestDocument, handler); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, Arrays.asList(1.0f, 2.0f, 3.0f)); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(ingestDocument, ingestDocument1); + } + + public void testExecute_MultipleTensorInDataOutputSuccess() { + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + List outputs = new ArrayList<>(); + + Float[] value = new Float[] { 1.0f }; + ModelTensor tensor = ModelTensor + .builder() + .data(value) + .name("test") + .shape(new long[] { 1, 1 }) + .dataType(MLResultDataType.FLOAT32) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); + + Float[] value1 = new Float[] { 2.0f }; + ModelTensor tensor1 = ModelTensor + .builder() + .data(value1) + .name("test") + .shape(new long[] { 1, 1 }) + .dataType(MLResultDataType.FLOAT32) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); + + Float[] value2 = new Float[] { 3.0f }; + ModelTensor tensor2 = ModelTensor + .builder() + .data(value2) + .name("test") + .shape(new long[] { 1, 1 }) + .dataType(MLResultDataType.FLOAT32) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); + + List mlModelTensors = Arrays.asList(tensor, tensor1, tensor2); + + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(mlModelTensors).build(); + outputs.add(modelTensors); + ModelTensorOutput modelTensorOutput = 
ModelTensorOutput.builder().mlModelOutputs(outputs).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(modelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + processor.execute(ingestDocument, handler); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, Arrays.asList(List.of(1.0f), List.of(2.0f), List.of(3.0f))); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(ingestDocument, ingestDocument1); + } + + public void testExecute_getModelOutputFieldWithFieldNameSuccess() { + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("classification", "response"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + processor.execute(ingestDocument, handler); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + sourceAndMetadata.put("classification", 
ImmutableMap.of("language", "en", "score", "0.9876")); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(ingestDocument, ingestDocument1); + } + + public void testExecute_getModelOutputFieldWithDotPathSuccess() { + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("language_identification", "response.language"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", List.of("en", "en"), "score", "0.9876"))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + processor.execute(ingestDocument, handler); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + sourceAndMetadata.put("language_identification", List.of("en", "en")); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(ingestDocument, ingestDocument1); + } + + public void testExecute_getModelOutputFieldWithInvalidDotPathSuccess() { + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("language_identification", "response.lan"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = 
createMLInferenceProcessor("model1", null, null, outputMap, true, false); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata1 = new HashMap<>(); + sourceAndMetadata1.put("key1", "value1"); + sourceAndMetadata1.put("key2", "value2"); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata1, new HashMap<>()); + processor.execute(ingestDocument1, handler); + + verify(handler).accept(eq(ingestDocument1), isNull()); + assertNull(ingestDocument1.getIngestMetadata().get("language_identification")); + } + + public void testExecute_getModelOutputFieldWithInvalidDotPathException() { + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("response.lan", "language_identification"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + 
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata1 = new HashMap<>(); + sourceAndMetadata1.put("key1", "value1"); + sourceAndMetadata1.put("key2", "value2"); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata1, new HashMap<>()); + try { + processor.execute(ingestDocument1, handler); + } catch (IllegalArgumentException e) { + assertEquals("model inference output can not find field name: " + "response.lan", e.getMessage()); + } + ; + + } + + public void testExecute_getModelOutputFieldInNestedWithInvalidDotPathException() { + + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("chunks.*.chunk.text.*.context_embedding", "response.language1"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, false, false); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + processor.execute(nestedObjectIngestDocument, handler); + + verify(handler) + .accept( + eq(null), + argThat( + exception -> exception + .getMessage() + .equals("An unexpected error occurred: model inference output cannot find field name: response.language1") + ) + ); + ; + + } + + public void 
testExecute_getModelOutputFieldWithExistedFieldNameException() { + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("key1", "response"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + processor.execute(ingestDocument, handler); + verify(handler) + .accept( + eq(null), + argThat( + exception -> exception + .getMessage() + .equals( + "document already has field name key1. Not allow to overwrite the same field name, please check output_map." 
+ ) + ) + ); + } + + public void testExecute_documentNotExistedFieldNameException() { + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("inputs", "key99"); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("classification", "response"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, false, false); + + processor.execute(ingestDocument, handler); + verify(handler) + .accept(eq(null), argThat(exception -> exception.getMessage().equals("cannot find field name defined from input map: key99"))); + } + + public void testExecute_nestedDocumentNotExistedFieldNameException() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context1"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, false, false); + + processor.execute(ingestDocument, handler); + verify(handler) + .accept( + eq(null), + argThat( + exception -> exception + .getMessage() + .equals("cannot find field name defined from input map: chunks.*.chunk.text.*.context1") + ) + ); + } + + public void testExecute_getModelOutputFieldDifferentLengthException() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + ArrayList> modelPredictionOutput = new ArrayList<>(); + modelPredictionOutput.add(Arrays.asList(1)); + modelPredictionOutput.add(Arrays.asList(2)); + modelPredictionOutput.add(Arrays.asList(3)); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", modelPredictionOutput)).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput 
mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + try { + processor.execute(nestedObjectIngestDocument, handler); + } catch (RuntimeException e) { + assertEquals( + "the prediction field: response is an array in size of 3 but the document field array from field chunks.*.chunk.text.*.embedding is in size of 4", + e.getMessage() + ); + } + + } + + public void testExecute_getModelOutputFieldDifferentLengthIgnoreFailureSuccess() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, true); + ArrayList> modelPredictionOutput = new ArrayList<>(); + modelPredictionOutput.add(Arrays.asList(1)); + modelPredictionOutput.add(Arrays.asList(2)); + modelPredictionOutput.add(Arrays.asList(3)); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", modelPredictionOutput)).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + Map sourceAndMetadata = 
getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + verify(handler).accept(eq(nestedObjectIngestDocument), isNull()); + assertNull(nestedObjectIngestDocument.getIngestMetadata().get("response")); + } + + public void testExecute_getMlModelTensorsIsNull() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + + verify(handler) + .accept( + eq(null), + argThat(exception -> exception.getMessage().equals("An unexpected error occurred: Output tensors are null or empty.")) + ); + + } + + public void testExecute_getMlModelTensorsIsNullIgnoreFailure() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, true); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); + ModelTensorOutput 
mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + verify(handler).accept(eq(nestedObjectIngestDocument), isNull()); + } + + public void testExecute_modelTensorOutputIsNull() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + verify(handler).accept(eq(null), argThat(exception -> exception.getMessage().equals("model inference output cannot be null"))); + + } + + public void testExecute_modelTensorOutputIsNullIgnoreFailureSuccess() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", 
null, inputMap, outputMap, true, true); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + verify(handler).accept(eq(nestedObjectIngestDocument), isNull()); + } + + public void testParseGetDataInTensor_IntegerDataType() { + ModelTensor mockTensor = mock(ModelTensor.class); + when(mockTensor.getDataType()).thenReturn(MLResultDataType.INT8); + when(mockTensor.getData()).thenReturn(new Number[] { 1, 2, 3 }); + Object result = ModelExecutor.parseDataInTensor(mockTensor); + assertEquals(List.of(1, 2, 3), result); + } + + public void testParseGetDataInTensor_FloatDataType() { + ModelTensor mockTensor = mock(ModelTensor.class); + when(mockTensor.getDataType()).thenReturn(MLResultDataType.FLOAT32); + when(mockTensor.getData()).thenReturn(new Number[] { 1.1, 2.2, 3.3 }); + Object result = ModelExecutor.parseDataInTensor(mockTensor); + assertEquals(List.of(1.1f, 2.2f, 3.3f), result); + } + + public void testParseGetDataInTensor_BooleanDataType() { + ModelTensor mockTensor = mock(ModelTensor.class); + when(mockTensor.getDataType()).thenReturn(MLResultDataType.BOOLEAN); + when(mockTensor.getData()).thenReturn(new Number[] { 1, 0, 1 }); + Object result = ModelExecutor.parseDataInTensor(mockTensor); + assertEquals(List.of(true, false, true), result); + } + + private static Map getNestedObjectWithAnotherNestedObjectSource() { + ArrayList childDocuments = new ArrayList<>(); + + Map childDocument1Text = new HashMap<>(); + ArrayList grandChildDocuments1 = new 
ArrayList<>(); + Map grandChildDocument1Text = new HashMap<>(); + grandChildDocument1Text.put("context", "this is first"); + grandChildDocument1Text.put("chapter", "first chapter"); + Map grandChildDocument2Text = new HashMap<>(); + grandChildDocument2Text.put("context", "this is second"); + grandChildDocument2Text.put("chapter", "first chapter"); + grandChildDocuments1.add(grandChildDocument1Text); + grandChildDocuments1.add(grandChildDocument2Text); + childDocument1Text.put("text", grandChildDocuments1); + + Map childDocument1 = new HashMap<>(); + childDocument1.put("chunk", childDocument1Text); + + Map childDocument2 = new HashMap<>(); + Map childDocument2Text = new HashMap<>(); + ArrayList grandChildDocuments2 = new ArrayList<>(); + + Map grandChildDocument3Text = new HashMap<>(); + grandChildDocument3Text.put("context", "this is third"); + grandChildDocument3Text.put("chapter", "second chapter"); + Map grandChildDocument4Text = new HashMap<>(); + grandChildDocument4Text.put("context", "this is fourth"); + grandChildDocument4Text.put("chapter", "first chapter"); + grandChildDocuments2.add(grandChildDocument3Text); + grandChildDocuments2.add(grandChildDocument4Text); + + childDocument2Text.put("text", grandChildDocuments2); + childDocument2.put("chunk", childDocument2Text); + + childDocuments.add(childDocument1); + childDocuments.add(childDocument2); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("chunks", childDocuments); + return sourceAndMetadata; + } + + private static List> getOutputMapsForNestedObjectChunks() { + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + String modelOutputPath = "response"; + String documentFieldName = "chunks.*.chunk.text.*.embedding"; + output.put(documentFieldName, modelOutputPath); + outputMap.add(output); + return outputMap; + } + + private static List> getInputMapsForNestedObjectChunks(String documentFieldPath) { + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + 
input.put("inputs", documentFieldPath); + inputMap.add(input); + return inputMap; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 609b82f4ae..65767339d2 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -89,6 +89,7 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.utils.TestData; @@ -850,11 +851,11 @@ private String parseTaskIdFromResponse(Response response) throws IOException { return taskId; } - Map parseResponseToMap(Response response) throws IOException { + public static Map parseResponseToMap(Response response) throws IOException { HttpEntity entity = response.getEntity(); assertNotNull(response); String entityString = TestHelper.httpEntityToString(entity); - return gson.fromJson(entityString, Map.class); + return StringUtils.gson.fromJson(entityString, Map.class); } public Map getModelProfile(String modelId, Consumer verifyFunction) throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java index 7934ad67cd..c275d2263e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java @@ -9,7 +9,6 @@ import java.util.List; import java.util.Map; -import org.apache.hc.core5.http.HttpEntity; import org.apache.hc.core5.http.HttpHeaders; import 
org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; @@ -240,13 +239,6 @@ protected boolean checkThrottlingOpenAI(Map responseMap) { return message.equals("You exceeded your current quota, please check your plan and billing details."); } - protected Map parseResponseToMap(Response response) throws IOException { - HttpEntity entity = response.getEntity(); - assertNotNull(response); - String entityString = TestHelper.httpEntityToString(entity); - return gson.fromJson(entityString, Map.class); - } - protected void disableClusterConnectorAccessControl() throws IOException { Response response = TestHelper .makeRequest( diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java new file mode 100644 index 0000000000..3eec4da4e2 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java @@ -0,0 +1,263 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.utils.TestHelper.makeRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Assert; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; +import com.jayway.jsonpath.JsonPath; + +public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase { + private final String OPENAI_KEY = System.getenv("OPENAI_KEY"); + private String modelId; + private final String completionModelConnectorEntity = "{\n" + + " \"name\": \"OpenAI text embedding model Connector\",\n" + + " \"description\": \"The 
connector to public OpenAI text embedding model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"model\": \"text-embedding-ada-002\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + OPENAI_KEY + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/embeddings\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"input\\\": ${parameters.input}, \\\"model\\\": \\\"${parameters.model}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + + @Before + public void setup() throws IOException, InterruptedException { + RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); + Thread.sleep(20000); + + // create connectors for OPEN AI and register model + Response response = RestMLRemoteInferenceIT.createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String openAIConnectorId = (String) responseMap.get("connector_id"); + response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 chat model", openAIConnectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = RestMLRemoteInferenceIT.getTask(taskId); + responseMap = parseResponseToMap(response); + this.modelId = (String) responseMap.get("model_id"); + response = RestMLRemoteInferenceIT.deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + } + + public void testMLInferenceProcessorWithObjectFieldType() throws Exception { + + String createPipelineRequestBody = "{\n" + + " \"description\": \"test ml model ingest processor\",\n" + + " \"processors\": [\n" + + " {\n" + + " 
\"ml_inference\": {\n" + + " \"model_id\": \"" + + this.modelId + + "\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"diary\"\n" + + " }\n" + + " ],\n" + + " \n" + + " \"output_map\": [\n" + + " {\n" + + " \"diary_embedding\": \"data.*.embedding\"\n" + + " }\n" + + " ],\n" + + " \"model_config\": {\"model\":\"text-embedding-ada-002\"}\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String createIndexRequestBody = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"default_pipeline\": \"diary_embedding_pipeline\"\n" + + " }\n" + + " }\n" + + " }"; + String uploadDocumentRequestBody = "{\n" + + " \"id\": 1,\n" + + " \"diary\": [\"happy\",\"first day at school\"],\n" + + " \"weather\": \"rainy\"\n" + + " }"; + String index_name = "daily_index"; + createPipelineProcessor(createPipelineRequestBody, "diary_embedding_pipeline"); + createIndex(index_name, createIndexRequestBody); + uploadDocument(index_name, "1", uploadDocumentRequestBody); + Map document = getDocument(index_name, "1"); + List embeddingList = JsonPath.parse(document).read("_source.diary_embedding"); + Assert.assertEquals(2, embeddingList.size()); + + List embedding1 = JsonPath.parse(document).read("_source.diary_embedding[0]"); + Assert.assertEquals(1536, embedding1.size()); + Assert.assertEquals(-0.0118564125, (Double) embedding1.get(0), 0.00005); + + List embedding2 = JsonPath.parse(document).read("_source.diary_embedding[1]"); + Assert.assertEquals(1536, embedding2.size()); + Assert.assertEquals(-0.005518768, (Double) embedding2.get(0), 0.00005); + } + + public void testMLInferenceProcessorWithNestedFieldType() throws Exception { + + String createPipelineRequestBody = "{\n" + + " \"description\": \"ingest reviews and generate embedding\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"model_id\": \"" + + this.modelId + + "\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"book.*.chunk.text.*.context\"\n" + + " }\n" + + " ],\n" + + " \n" + 
+ " \"output_map\": [\n" + + " {\n" + + " \"book.*.chunk.text.*.context_embedding\": \"data.*.embedding\"\n" + + " }\n" + + " ],\n" + + " \"model_config\": {\"model\":\"text-embedding-ada-002\"}\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String createIndexRequestBody = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"default_pipeline\": \"embedding_pipeline\"\n" + + " }\n" + + " }\n" + + " }"; + String uploadDocumentRequestBody = "{\n" + + " \"book\": [\n" + + " {\n" + + " \"chunk\": {\n" + + " \"text\": [\n" + + " {\n" + + " \"chapter\": \"first chapter\",\n" + + " \"context\": \"this is the first part\"\n" + + " },\n" + + " {\n" + + " \"chapter\": \"first chapter\",\n" + + " \"context\": \"this is the second part\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " {\n" + + " \"chunk\": {\n" + + " \"text\": [\n" + + " {\n" + + " \"chapter\": \"second chapter\",\n" + + " \"context\": \"this is the third part\"\n" + + " },\n" + + " {\n" + + " \"chapter\": \"second chapter\",\n" + + " \"context\": \"this is the fourth part\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String index_name = "book_index"; + createPipelineProcessor(createPipelineRequestBody, "embedding_pipeline"); + createIndex(index_name, createIndexRequestBody); + uploadDocument(index_name, "1", uploadDocumentRequestBody); + Map document = getDocument(index_name, "1"); + + List embeddingList = JsonPath.parse(document).read("_source.book[*].chunk.text[*].context_embedding"); + Assert.assertEquals(4, embeddingList.size()); + + List embedding1 = JsonPath.parse(document).read("_source.book[0].chunk.text[0].context_embedding"); + Assert.assertEquals(1536, embedding1.size()); + Assert.assertEquals(0.023224998, (Double) embedding1.get(0), 0.00005); + + List embedding2 = JsonPath.parse(document).read("_source.book[0].chunk.text[1].context_embedding"); + Assert.assertEquals(1536, embedding2.size()); + Assert.assertEquals(0.016423305, (Double) embedding2.get(0), 
0.00005); + + List embedding3 = JsonPath.parse(document).read("_source.book[1].chunk.text[0].context_embedding"); + Assert.assertEquals(1536, embedding3.size()); + Assert.assertEquals(0.011252925, (Double) embedding3.get(0), 0.00005); + + List embedding4 = JsonPath.parse(document).read("_source.book[1].chunk.text[1].context_embedding"); + Assert.assertEquals(1536, embedding4.size()); + Assert.assertEquals(0.014352738, (Double) embedding4.get(0), 0.00005); + } + + protected void createPipelineProcessor(String requestBody, final String pipelineName) throws Exception { + Response pipelineCreateResponse = TestHelper + .makeRequest( + client(), + "PUT", + "/_ingest/pipeline/" + pipelineName, + null, + requestBody, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, pipelineCreateResponse.getStatusLine().getStatusCode()); + + } + + protected void createIndex(String indexName, String requestBody) throws Exception { + Response response = makeRequest( + client(), + "PUT", + indexName, + null, + requestBody, + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + protected void uploadDocument(final String index, final String docId, final String jsonBody) throws IOException { + Request request = new Request("PUT", "/" + index + "/_doc/" + docId + "?refresh=true"); + + request.setJsonEntity(jsonBody); + client().performRequest(request); + } + + protected Map getDocument(final String index, final String docId) throws Exception { + Response docResponse = TestHelper.makeRequest(client(), "GET", "/" + index + "/_doc/" + docId + "?refresh=true", null, "", null); + assertEquals(200, docResponse.getStatusLine().getStatusCode()); + + return parseResponseToMap(docResponse); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 
9dee5ee088..938ca8b0b1 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -11,7 +11,6 @@ import java.util.function.Consumer; import org.apache.commons.lang3.exception.ExceptionUtils; -import org.apache.hc.core5.http.HttpEntity; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; @@ -738,11 +737,11 @@ public void testCohereClassifyModel() throws IOException, InterruptedException { assertFalse(responseList.isEmpty()); } - protected Response createConnector(String input) throws IOException { + public static Response createConnector(String input) throws IOException { return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null); } - protected Response registerRemoteModel(String name, String connectorId) throws IOException { + public static Response registerRemoteModel(String name, String connectorId) throws IOException { String registerModelGroupEntity = "{\n" + " \"name\": \"remote_model_group\",\n" + " \"description\": \"This is an example description\"\n" @@ -778,15 +777,15 @@ protected Response registerRemoteModel(String name, String connectorId) throws I .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); } - protected Response deployRemoteModel(String modelId) throws IOException { + public static Response deployRemoteModel(String modelId) throws IOException { return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, "", null); } - protected Response predictRemoteModel(String modelId, String input) throws IOException { + public Response predictRemoteModel(String modelId, String input) throws IOException { return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", 
null, input, null); } - protected Response undeployRemoteModel(String modelId) throws IOException { + public Response undeployRemoteModel(String modelId) throws IOException { String undeployEntity = "{\n" + " \"SYqCMdsFTumUwoHZcsgiUg\": {\n" + " \"stats\": {\n" @@ -805,14 +804,7 @@ protected boolean checkThrottlingOpenAI(Map responseMap) { return message.equals("You exceeded your current quota, please check your plan and billing details."); } - protected Map parseResponseToMap(Response response) throws IOException { - HttpEntity entity = response.getEntity(); - assertNotNull(response); - String entityString = TestHelper.httpEntityToString(entity); - return gson.fromJson(entityString, Map.class); - } - - protected void disableClusterConnectorAccessControl() throws IOException { + public static void disableClusterConnectorAccessControl() throws IOException { Response response = TestHelper .makeRequest( client(), @@ -825,11 +817,11 @@ protected void disableClusterConnectorAccessControl() throws IOException { assertEquals(200, response.getStatusLine().getStatusCode()); } - protected Response getTask(String taskId) throws IOException { + public static Response getTask(String taskId) throws IOException { return TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null); } - private String registerRemoteModel() throws IOException { + public String registerRemoteModel() throws IOException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); diff --git a/plugin/src/yamlRestTest/java/org/opensearch/ml/plugin/PluginClientYamlTestSuiteIT.java b/plugin/src/yamlRestTest/java/org/opensearch/ml/plugin/PluginClientYamlTestSuiteIT.java new file mode 100644 index 0000000000..a7715dce2e --- /dev/null +++ b/plugin/src/yamlRestTest/java/org/opensearch/ml/plugin/PluginClientYamlTestSuiteIT.java @@ -0,0 +1,25 @@ +/* + * + * * 
Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ +package org.opensearch.ml.plugin; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.opensearch.test.rest.yaml.ClientYamlTestCandidate; +import org.opensearch.test.rest.yaml.OpenSearchClientYamlSuiteTestCase; + + +public class PluginClientYamlTestSuiteIT extends OpenSearchClientYamlSuiteTestCase { + + public PluginClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { + super(testCandidate); + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return OpenSearchClientYamlSuiteTestCase.createParameters(); + } +} diff --git a/plugin/src/yamlRestTest/resources/rest-api-spec/test/10_basic.yml b/plugin/src/yamlRestTest/resources/rest-api-spec/test/10_basic.yml new file mode 100644 index 0000000000..fd3c06631f --- /dev/null +++ b/plugin/src/yamlRestTest/resources/rest-api-spec/test/10_basic.yml @@ -0,0 +1,8 @@ +"c": + - do: + cat.plugins: + local: true + h: component + + - match: + $body: /^plugin\n$/ diff --git a/plugin/src/yamlRestTest/resources/rest-api-spec/test/20_inference_ingest_processor.yml b/plugin/src/yamlRestTest/resources/rest-api-spec/test/20_inference_ingest_processor.yml new file mode 100644 index 0000000000..43931b5d0a --- /dev/null +++ b/plugin/src/yamlRestTest/resources/rest-api-spec/test/20_inference_ingest_processor.yml @@ -0,0 +1,24 @@ +--- +teardown: + - do: + ingest.delete_pipeline: + id: "my_pipeline" + ignore: 404 + +--- +"Test ML Inference Processor": + - do: + ingest.put_pipeline: + id: "my_pipeline" + body: > + { + "description" : "pipeline with drop", + "processors" : [ + { + "ml_inference" : { + "model_id": "AGYioI4BK5nJfCdc0w1T" + } + } + ] + } + - match: { acknowledged: true } \ No newline at end of file