Skip to content

Commit

Permalink
Initiate MLInferencelngestProcessor (#2205)
Browse files Browse the repository at this point in the history
* Initiate MLModelIngestProcessor

Signed-off-by: Mingshi Liu <[email protected]>

add more tests

Signed-off-by: Mingshi Liu <[email protected]>

* add more tests

Signed-off-by: Mingshi Liu <[email protected]>

add yaml tests and nested objects tests

Signed-off-by: Mingshi Liu <[email protected]>

 add IT tests

Signed-off-by: Mingshi Liu <[email protected]>

* use GroupListener and add DEFAULT_MAX_PREDICTION_TASKS

Signed-off-by: Mingshi Liu <[email protected]>

* add javadoc

Signed-off-by: Mingshi Liu <[email protected]>

* avoid calling execute(IngestDocument ingestDocument)-s

Signed-off-by: Mingshi Liu <[email protected]>

* not rewriting dotpath to json path

Signed-off-by: Mingshi Liu <[email protected]>

* change mapping order, input_map-model input as key, output_map-document field as key

Signed-off-by: Mingshi Liu <[email protected]>

* use StringUtils.toJson

Signed-off-by: Mingshi Liu <[email protected]>

---------

Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl authored Apr 29, 2024
1 parent caf1d65 commit c9758ca
Show file tree
Hide file tree
Showing 16 changed files with 2,450 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,13 @@ public boolean isInteger() {
public boolean isBoolean() {
return format == Format.BOOLEAN;
}

/**
* Checks whether it is a String data type.
*
* @return whether it is a String data type
*/
public boolean isString() {
return format == Format.STRING;
}
}
1 change: 1 addition & 0 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ dependencies {
implementation "org.apache.logging.log4j:log4j-slf4j-impl:2.19.0"
testImplementation group: 'commons-io', name: 'commons-io', version: '2.15.1'
implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
implementation 'com.jayway.jsonpath:json-path:2.9.0'
}

publishing {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -202,6 +203,7 @@
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.processor.MLInferenceIngestProcessor;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.rest.RestMLCreateConnectorAction;
import org.opensearch.ml.rest.RestMLCreateControllerAction;
Expand Down Expand Up @@ -275,6 +277,7 @@
import org.opensearch.monitor.os.OsService;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
Expand All @@ -298,7 +301,13 @@

import lombok.SneakyThrows;

public class MachineLearningPlugin extends Plugin implements ActionPlugin, SearchPlugin, SearchPipelinePlugin, ExtensiblePlugin {
public class MachineLearningPlugin extends Plugin
implements
ActionPlugin,
SearchPlugin,
SearchPipelinePlugin,
ExtensiblePlugin,
IngestPlugin {
public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons.";
public static final String GENERAL_THREAD_POOL = "opensearch_ml_general";
public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute";
Expand Down Expand Up @@ -983,4 +992,15 @@ public void loadExtensions(ExtensionLoader loader) {
}
}
}

/**
* To get ingest processors
*/
@Override
public Map<String, org.opensearch.ingest.Processor.Factory> getProcessors(org.opensearch.ingest.Processor.Parameters parameters) {
Map<String, org.opensearch.ingest.Processor.Factory> processors = new HashMap<>();
processors
.put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client));
return Collections.unmodifiableMap(processors);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.ml.processor;

import java.util.List;
import java.util.Map;

import lombok.Getter;
import lombok.Setter;

@Setter
@Getter
public class InferenceProcessorAttributes {

protected List<Map<String, String>> inputMaps;

protected List<Map<String, String>> outputMaps;

protected String modelId;
protected int maxPredictionTask;

protected Map<String, String> modelConfigMaps;
public static final String MODEL_ID = "model_id";
/**
* The list of maps that support one or more prediction tasks mapping.
* The mappings also support JSON path for nested objects.
*
* input_map is used to construct model inputs, where the keys represent the model input fields,
* and the values represent the corresponding document fields.
*
* Example input_map:
*
* "input_map": [
* {
* "input": "book.title"
* },
* {
* "input": "book.text"
* }
* ]
*/
public static final String INPUT_MAP = "input_map";
/**
* output_map is used to construct document fields, where the keys represent the document fields,
* and the values represent the corresponding model output fields.
*
* Example output_map:
*
* "output_map": [
* {
* "book.title_language": "response.language"
* },
* {
* "book.text_language": "response.language"
* }
* ]
*
*/
public static final String OUTPUT_MAP = "output_map";
public static final String MODEL_CONFIG = "model_config";
public static final String MAX_PREDICTION_TASKS = "max_prediction_tasks";

/**
* Utility class containing shared parameters for MLModelIngest/SearchProcessor
* */

public InferenceProcessorAttributes(
String modelId,
List<Map<String, String>> inputMaps,
List<Map<String, String>> outputMaps,
Map<String, String> modelConfigMaps,
int maxPredictionTask
) {
this.modelId = modelId;
this.modelConfigMaps = modelConfigMaps;
this.inputMaps = inputMaps;
this.outputMaps = outputMaps;
this.maxPredictionTask = maxPredictionTask;
}

}
Loading

0 comments on commit c9758ca

Please sign in to comment.