Skip to content

Commit

Permalink
escape input data (#1970)
Browse files Browse the repository at this point in the history
* escape input data

Signed-off-by: Yaliang Wu <[email protected]>

* add unit test

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Jan 31, 2024
1 parent 714d315 commit 72d494c
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) {
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", processTextDocs(inputData).get(0)));
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0)));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) {
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("texts", processTextDocs(inputData)));
Map<String, Object> processedResult = Map.of("parameters", Map.of("texts", inputData.getDocs()));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

import static org.opensearch.ml.common.utils.StringUtils.gson;

@Log4j2
public abstract class ConnectorPreProcessFunction implements Function<MLInput, RemoteInferenceInputDataSet> {

Expand All @@ -38,21 +34,6 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) {

public abstract RemoteInferenceInputDataSet process(MLInput mlInput);

List<String> processTextDocs(TextDocsInputDataSet inputDataSet) {
List<String> docs = new ArrayList<>();
for (String doc : inputDataSet.getDocs()) {
if (doc != null) {
String gsonString = gson.toJson(doc);
// in 2.9, user will add " before and after string
// gson.toString(string) will add extra " before after string, so need to remove
docs.add(gsonString.substring(1, gsonString.length() - 1));
} else {
docs.add(null);
}
}
return docs;
}

public void validateTextDocsInput(MLInput mlInput) {
if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) {
throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) {
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("input", processTextDocs(inputData)));
Map<String, Object> processedResult = Map.of("parameters", Map.of("input", inputData.getDocs()));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -120,4 +121,23 @@ public static Map<String, String> convertScriptStringToJsonString(Map<String, Ob
}
return parameterStringMap;
}

public static List<String> processTextDocs(List<String> inputDocs) {
List<String> docs = new ArrayList<>();
for (String doc : inputDocs) {
docs.add(processTextDoc(doc));
}
return docs;
}

public static String processTextDoc(String doc) {
if (doc != null) {
String gsonString = gson.toJson(doc);
// in 2.9, user will add " before and after string
// gson.toString(string) will add extra " before after string, so need to remove
return gsonString.substring(1, gsonString.length() - 1);
} else {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.opensearch.script.ScriptService;

import java.util.Arrays;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.junit.Assert;
import org.junit.Test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -99,4 +100,13 @@ public void getParameterMap() {
Assert.assertEquals("[10,20]", parameterMap.get("key4"));
Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5"));
}

@Test
public void processTextDocs() {
List<String> processedDocs = StringUtils.processTextDocs(Arrays.asList("abc \n\n123\"4", null, "[1.01,\"abc\"]"));
Assert.assertEquals(3, processedDocs.size());
Assert.assertEquals("abc \\n\\n123\\\"4", processedDocs.get(0));
Assert.assertNull(processedDocs.get(1));
Assert.assertEquals("[1.01,\\\"abc\\\"]", processedDocs.get(2));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING;
import static org.opensearch.ml.common.connector.MLPreProcessFunction.PROCESS_REMOTE_INFERENCE_INPUT;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.processTextDoc;
import static org.opensearch.ml.common.utils.StringUtils.processTextDocs;
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction;

import java.io.IOException;
Expand All @@ -29,6 +31,7 @@
import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
Expand Down Expand Up @@ -68,20 +71,7 @@ public static RemoteInferenceInputDataSet processInput(
throw new IllegalArgumentException("no predict action found");
}
RemoteInferenceInputDataSet inputData = processMLInput(mlInput, connector, parameters, scriptService);
if (inputData.getParameters() != null) {
Map<String, String> newParameters = new HashMap<>();
inputData.getParameters().forEach((key, value) -> {
if (value == null) {
newParameters.put(key, null);
} else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) {
// no need to escape if it's already valid json
newParameters.put(key, value);
} else {
newParameters.put(key, escapeJson(value));
}
});
inputData.setParameters(newParameters);
}
escapeRemoteInferenceInputData(inputData);
return inputData;
}

Expand Down Expand Up @@ -112,6 +102,7 @@ private static RemoteInferenceInputDataSet processMLInput(
return (RemoteInferenceInputDataSet) mlInput.getInputDataset();
}
} else {
MLInput newInput = escapeMLInput(mlInput);
boolean convertInputToJsonString = parameters.containsKey(CONVERT_INPUT_TO_JSON_STRING)
&& Boolean.parseBoolean(parameters.get(CONVERT_INPUT_TO_JSON_STRING));
DefaultPreProcessFunction function = DefaultPreProcessFunction
Expand All @@ -120,11 +111,51 @@ private static RemoteInferenceInputDataSet processMLInput(
.preProcessFunction(preProcessFunction)
.convertInputToJsonString(convertInputToJsonString)
.build();
return function.apply(mlInput);
return function.apply(newInput);
}
}
}

private static MLInput escapeMLInput(MLInput mlInput) {
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
List<String> newDocs = processTextDocs(docs);
TextDocsInputDataSet newInputData = ((TextDocsInputDataSet) mlInput.getInputDataset()).toBuilder().docs(newDocs).build();
return mlInput.toBuilder().inputDataset(newInputData).build();
}

if (mlInput.getInputDataset() instanceof TextSimilarityInputDataSet) {
String query = ((TextSimilarityInputDataSet) mlInput.getInputDataset()).getQueryText();
String newQuery = processTextDoc(query);
List<String> docs = ((TextSimilarityInputDataSet) mlInput.getInputDataset()).getTextDocs();
List<String> newDocs = processTextDocs(docs);
TextSimilarityInputDataSet newInputData = ((TextSimilarityInputDataSet) mlInput.getInputDataset())
.toBuilder()
.queryText(newQuery)
.textDocs(newDocs)
.build();
return mlInput.toBuilder().inputDataset(newInputData).build();
}
return mlInput;
}

public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet inputData) {
Map<String, String> newParameters = new HashMap<>();
if (inputData.getParameters() != null) {
inputData.getParameters().forEach((key, value) -> {
if (value == null) {
newParameters.put(key, null);
} else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) {
// no need to escape if it's already valid json
newParameters.put(key, value);
} else {
newParameters.put(key, escapeJson(value));
}
});
inputData.setParameters(newParameters);
}
}

private static String getPreprocessFunction(MLInput mlInput, Connector connector) {
Optional<ConnectorAction> predictAction = connector.findPredictAction();
String preProcessFunction = predictAction.get().getPreProcessFunction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;

import java.util.ArrayList;
Expand Down Expand Up @@ -108,6 +109,7 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List<ModelTenso
MLInputDataset inputDataset = mlInput.getInputDataset();
Map<String, String> inputParameters = new HashMap<>();
if (inputDataset instanceof RemoteInferenceInputDataSet && ((RemoteInferenceInputDataSet) inputDataset).getParameters() != null) {
escapeRemoteInferenceInputData((RemoteInferenceInputDataSet) inputDataset);
inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters());
}
parameters.putAll(inputParameters);
Expand Down

0 comments on commit 72d494c

Please sign in to comment.