Skip to content

Commit

Permalink
fill null parameters in connector body template (#1192)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Aug 8, 2023
1 parent fffec84 commit 4ac47f6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP;
Expand Down Expand Up @@ -82,7 +84,7 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException
description = parser.text();
break;
case PROTOCOL_FIELD:
protocol = parser.text();
this.protocol = parser.text();
break;
case PARAMETERS_FIELD:
Map<String, Object> map = parser.map();
Expand Down Expand Up @@ -251,9 +253,9 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) {
String payload = predictAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid JSON in payload");
}
Expand All @@ -262,6 +264,30 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
return (T) parameters.get("http_body");
}

protected String fillNullParameters(Map<String, String> parameters, String payload) {
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
String newPayload = payload;
for (String key : bodyParams) {
if (!parameters.containsKey(key) || parameters.get(key) == null) {
newPayload = newPayload.replace("\"${parameters." + key + ":-null}\"", "null");
}
}
return newPayload;
}

private List<String> findStringParametersWithNullDefaultValue(String input) {
String regex = "\"\\$\\{parameters\\.(\\w+):-null}\"";
Pattern pattern = Pattern.compile(regex);
Matcher matcher = pattern.matcher(input);

List<String> paramList = new ArrayList<>();
while (matcher.find()) {
String parameterValue = matcher.group(1);
paramList.add(parameterValue);
}
return paramList;
}

@Override
public void decrypt(Function<String, String> function) {
Map<String, String> decrypted = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,14 @@ public void parseResponse_NonJsonString() throws IOException {
Assert.assertEquals("test output", modelTensors.get(0).getDataAsMap().get("response"));
}

@Test
public void fillNullParameters() {
HttpConnector connector = createHttpConnector();
Map<String, String> parameters = new HashMap<>();
String output = connector.fillNullParameters(parameters, "{\"input1\": \"${parameters.input1:-null}\"}");
Assert.assertEquals("{\"input1\": null}", output);
}

public static HttpConnector createHttpConnector() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "POST";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
if (inputData.getParameters() != null) {
Map<String, String> newParameters = new HashMap<>();
inputData.getParameters().entrySet().forEach(entry -> {
if (StringUtils.isJson(entry.getValue())) {
if (entry.getValue() == null) {
newParameters.put(entry.getKey(), entry.getValue());
} else if (StringUtils.isJson(entry.getValue())) {
// no need to escape if it's already valid json
newParameters.put(entry.getKey(), entry.getValue());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,16 @@ public void processInput_RemoteInferenceInputDataSet_NotEscapeJsonString() {
processInput_RemoteInferenceInputDataSet(input, input);
}

@Test
public void processInput_RemoteInferenceInputDataSet_NullParam() {
String input = null;
processInput_RemoteInferenceInputDataSet(input, input);
}

private void processInput_RemoteInferenceInputDataSet(String input, String expectedInput) {
RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", input)).build();
Map<String, String> params = new HashMap<>();
params.put("input", input);
RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(params).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build();

ConnectorAction predictAction = ConnectorAction.builder()
Expand Down

0 comments on commit 4ac47f6

Please sign in to comment.