Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fill null parameters in connector body template #1192

Merged
merged 1 commit into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP;
Expand Down Expand Up @@ -82,7 +84,7 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException
description = parser.text();
break;
case PROTOCOL_FIELD:
protocol = parser.text();
this.protocol = parser.text();
break;
case PARAMETERS_FIELD:
Map<String, Object> map = parser.map();
Expand Down Expand Up @@ -251,9 +253,9 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) {
String payload = predictAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid JSON in payload");
}
Expand All @@ -262,6 +264,30 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
return (T) parameters.get("http_body");
}

protected String fillNullParameters(Map<String, String> parameters, String payload) {
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
String newPayload = payload;
for (String key : bodyParams) {
if (!parameters.containsKey(key) || parameters.get(key) == null) {
newPayload = newPayload.replace("\"${parameters." + key + ":-null}\"", "null");
}
}
return newPayload;
}

private List<String> findStringParametersWithNullDefaultValue(String input) {
String regex = "\"\\$\\{parameters\\.(\\w+):-null}\"";
Pattern pattern = Pattern.compile(regex);
Matcher matcher = pattern.matcher(input);

List<String> paramList = new ArrayList<>();
while (matcher.find()) {
String parameterValue = matcher.group(1);
paramList.add(parameterValue);
}
return paramList;
}

@Override
public void decrypt(Function<String, String> function) {
Map<String, String> decrypted = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,14 @@ public void parseResponse_NonJsonString() throws IOException {
Assert.assertEquals("test output", modelTensors.get(0).getDataAsMap().get("response"));
}

@Test
public void fillNullParameters() {
HttpConnector connector = createHttpConnector();
Map<String, String> parameters = new HashMap<>();
String output = connector.fillNullParameters(parameters, "{\"input1\": \"${parameters.input1:-null}\"}");
Assert.assertEquals("{\"input1\": null}", output);
}

public static HttpConnector createHttpConnector() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "POST";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
if (inputData.getParameters() != null) {
Map<String, String> newParameters = new HashMap<>();
inputData.getParameters().entrySet().forEach(entry -> {
if (StringUtils.isJson(entry.getValue())) {
if (entry.getValue() == null) {
newParameters.put(entry.getKey(), entry.getValue());
} else if (StringUtils.isJson(entry.getValue())) {
// no need to escape if it's already valid json
newParameters.put(entry.getKey(), entry.getValue());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,16 @@ public void processInput_RemoteInferenceInputDataSet_NotEscapeJsonString() {
processInput_RemoteInferenceInputDataSet(input, input);
}

@Test
public void processInput_RemoteInferenceInputDataSet_NullParam() {
String input = null;
processInput_RemoteInferenceInputDataSet(input, input);
}

private void processInput_RemoteInferenceInputDataSet(String input, String expectedInput) {
RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", input)).build();
Map<String, String> params = new HashMap<>();
params.put("input", input);
RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(params).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build();

ConnectorAction predictAction = ConnectorAction.builder()
Expand Down