Skip to content

Commit

Permalink
add override
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jun 1, 2024
1 parent fd54211 commit e936adb
Showing 1 changed file with 53 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
private final String functionName;
private final boolean fullResponsePath;
private final boolean ignoreFailure;
private final boolean override;
private final String modelInput;
private final ScriptService scriptService;
private static Client client;
Expand All @@ -61,6 +62,7 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
// allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the
// prediction outcomes, return the whole prediction outcome by skipping filtering
public static final String IGNORE_MISSING = "ignore_missing";
public static final String OVERRIDE = "override";
public static final String FUNCTION_NAME = "function_name";
public static final String FULL_RESPONSE_PATH = "full_response_path";
public static final String MODEL_INPUT = "model_input";
Expand All @@ -86,6 +88,7 @@ protected MLInferenceIngestProcessor(
String functionName,
boolean fullResponsePath,
boolean ignoreFailure,
boolean override,
String modelInput,
ScriptService scriptService,
Client client,
Expand All @@ -103,6 +106,7 @@ protected MLInferenceIngestProcessor(
this.functionName = functionName;
this.fullResponsePath = fullResponsePath;
this.ignoreFailure = ignoreFailure;
this.override = override;
this.modelInput = modelInput;
this.scriptService = scriptService;
this.client = client;
Expand Down Expand Up @@ -190,6 +194,37 @@ private void processPredictions(
modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
}
Map<String, String> outputMapping = processOutputMap.get(inputMapIndex);

Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());

Map<String, List<String>> newOutputMapping = new HashMap<>();
for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
String newDocumentFieldName = entry.getKey();
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
newOutputMapping.put(newDocumentFieldName, dotPathsInArray);
}

for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
String newDocumentFieldName = entry.getKey();
List<String> dotPaths = newOutputMapping.get(newDocumentFieldName);

int existingFields = 0;
for (String path : dotPaths) {
if (ingestDocument.hasField(path)) {
existingFields++;
}
}
if (!override && existingFields == dotPaths.size()) {
newOutputMapping.remove(newDocumentFieldName);
}
}
if (newOutputMapping.size() == 0) {
batchPredictionListener.onResponse(null);
return;
}
// when no input mapping is provided, default to read all fields from documents as model input
if (inputMapSize == 0) {
Set<String> documentFields = ingestDocument.getSourceAndMetadata().keySet();
Expand Down Expand Up @@ -240,12 +275,8 @@ public void onResponse(MLTaskResponse mlTaskResponse) {
// document field as key, model field as value
String newDocumentFieldName = entry.getKey();
String modelOutputFieldName = entry.getValue();
if (ingestDocument.hasField(newDocumentFieldName)) {
throw new IllegalArgumentException(
"document already has field name "
+ newDocumentFieldName
+ ". Not allow to overwrite the same field name, please check output_map."
);
if (!newOutputMapping.containsKey(newDocumentFieldName)) {
continue;
}
appendFieldValue(mlOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
}
Expand Down Expand Up @@ -363,10 +394,13 @@ private void appendFieldValue(
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocument.getSourceAndMetadata(), newDocumentFieldName);

if (dotPathsInArray.size() == 1) {
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, newDocumentFieldName, newDocumentFieldName, scriptService);
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
if (!ingestDocument.hasField(dotPathsInArray.get(0)) || override) {
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);

ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
}
} else {
if (!(modelOutputValue instanceof List)) {
throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
Expand All @@ -388,11 +422,13 @@ private void appendFieldValue(
// Iterate over dotPathInArray
for (int i = 0; i < dotPathsInArray.size(); i++) {
String dotPathInArray = dotPathsInArray.get(i);
Object modelOutputValueInArray = modelOutputValueArray.get(i);
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
if (!ingestDocument.hasField(dotPathInArray) || override) {
Object modelOutputValueInArray = modelOutputValueArray.get(i);
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
}
}
}
} else {
Expand Down Expand Up @@ -500,6 +536,7 @@ public MLInferenceIngestProcessor create(
int maxPredictionTask = ConfigurationUtils
.readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS);
boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false);
String functionName = ConfigurationUtils
.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
String modelInput = ConfigurationUtils
Expand Down Expand Up @@ -538,6 +575,7 @@ public MLInferenceIngestProcessor create(
functionName,
fullResponsePath,
ignoreFailure,
override,
modelInput,
scriptService,
client,
Expand Down

0 comments on commit e936adb

Please sign in to comment.