Skip to content

Commit

Permalink
fill null parameters in connector body template (#1192) (#1219)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Aug 18, 2023
1 parent c93005f commit 6db14b1
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
Expand Down Expand Up @@ -81,7 +84,7 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException
description = parser.text();
break;
case PROTOCOL_FIELD:
protocol = parser.text();
this.protocol = parser.text();
break;
case PARAMETERS_FIELD:
Map<String, Object> map = parser.map();
Expand Down Expand Up @@ -250,6 +253,7 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) {
String payload = predictAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

Expand All @@ -261,6 +265,30 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
return (T) parameters.get("http_body");
}

protected String fillNullParameters(Map<String, String> parameters, String payload) {
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
String newPayload = payload;
for (String key : bodyParams) {
if (!parameters.containsKey(key) || parameters.get(key) == null) {
newPayload = newPayload.replace("\"${parameters." + key + ":-null}\"", "null");
}
}
return newPayload;
}

private List<String> findStringParametersWithNullDefaultValue(String input) {
String regex = "\"\\$\\{parameters\\.(\\w+):-null}\"";
Pattern pattern = Pattern.compile(regex);
Matcher matcher = pattern.matcher(input);

List<String> paramList = new ArrayList<>();
while (matcher.find()) {
String parameterValue = matcher.group(1);
paramList.add(parameterValue);
}
return paramList;
}

@Override
public void decrypt(Function<String, String> function) {
Map<String, String> decrypted = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,14 @@ public void parseResponse_NonJsonString() throws IOException {
Assert.assertEquals("test output", modelTensors.get(0).getDataAsMap().get("response"));
}

@Test
public void fillNullParameters() {
HttpConnector connector = createHttpConnector();
Map<String, String> parameters = new HashMap<>();
String output = connector.fillNullParameters(parameters, "{\"input1\": \"${parameters.input1:-null}\"}");
Assert.assertEquals("{\"input1\": null}", output);
}

public static HttpConnector createHttpConnector() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "POST";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
if (inputData.getParameters() != null) {
Map<String, String> newParameters = new HashMap<>();
inputData.getParameters().entrySet().forEach(entry -> {
if (StringUtils.isJson(entry.getValue())) {
if (entry.getValue() == null) {
newParameters.put(entry.getKey(), entry.getValue());
} else if (StringUtils.isJson(entry.getValue())) {
// no need to escape if it's already valid json
newParameters.put(entry.getKey(), entry.getValue());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,16 @@ public void processInput_RemoteInferenceInputDataSet_NotEscapeJsonString() {
processInput_RemoteInferenceInputDataSet(input, input);
}

@Test
public void processInput_RemoteInferenceInputDataSet_NullParam() {
String input = null;
processInput_RemoteInferenceInputDataSet(input, input);
}

private void processInput_RemoteInferenceInputDataSet(String input, String expectedInput) {
RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", input)).build();
Map<String, String> params = new HashMap<>();
params.put("input", input);
RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(params).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build();

ConnectorAction predictAction = ConnectorAction.builder()
Expand Down

0 comments on commit 6db14b1

Please sign in to comment.