diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 82d19bdb01..96f6c018ec 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -26,6 +26,8 @@ import java.util.Map; import java.util.Optional; import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; @@ -82,7 +84,7 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException description = parser.text(); break; case PROTOCOL_FIELD: - protocol = parser.text(); + this.protocol = parser.text(); break; case PARAMETERS_FIELD: Map map = parser.map(); @@ -251,6 +253,7 @@ public T createPredictPayload(Map parameters) { Optional predictAction = findPredictAction(); if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) { String payload = predictAction.get().getRequestBody(); + payload = fillNullParameters(parameters, payload); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); @@ -262,6 +265,30 @@ public T createPredictPayload(Map parameters) { return (T) parameters.get("http_body"); } + protected String fillNullParameters(Map parameters, String payload) { + List bodyParams = findStringParametersWithNullDefaultValue(payload); + String newPayload = payload; + for (String key : bodyParams) { + if (!parameters.containsKey(key) || parameters.get(key) == null) { + newPayload = newPayload.replace("\"${parameters." + key + ":-null}\"", "null"); + } + } + return newPayload; + } + + private List findStringParametersWithNullDefaultValue(String input) { + String regex = "\"\\$\\{parameters\\.(\\w+):-null}\""; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(input); + + List paramList = new ArrayList<>(); + while (matcher.find()) { + String parameterValue = matcher.group(1); + paramList.add(parameterValue); + } + return paramList; + } + @Override public void decrypt(Function function) { Map decrypted = new HashMap<>(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 1874791b8b..8e51a06c38 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -259,6 +259,14 @@ public void parseResponse_NonJsonString() throws IOException { Assert.assertEquals("test output", modelTensors.get(0).getDataAsMap().get("response")); } + @Test + public void fillNullParameters() { + HttpConnector connector = createHttpConnector(); + Map parameters = new HashMap<>(); + String output = connector.fillNullParameters(parameters, "{\"input1\": \"${parameters.input1:-null}\"}"); + Assert.assertEquals("{\"input1\": null}", output); + } + public static HttpConnector createHttpConnector() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; String method = "POST"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 588da7ccae..cd3038f49c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -99,7 +99,9 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto if (inputData.getParameters() != null) { Map newParameters = new HashMap<>(); inputData.getParameters().entrySet().forEach(entry -> { - if (StringUtils.isJson(entry.getValue())) { + if (entry.getValue() == null) { + newParameters.put(entry.getKey(), entry.getValue()); + } else if (StringUtils.isJson(entry.getValue())) { // no need to escape if it's already valid json newParameters.put(entry.getKey(), entry.getValue()); } else { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index bfe5023b9a..2a84e2fee1 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -96,8 +96,16 @@ public void processInput_RemoteInferenceInputDataSet_NotEscapeJsonString() { processInput_RemoteInferenceInputDataSet(input, input); } + @Test + public void processInput_RemoteInferenceInputDataSet_NullParam() { + String input = null; + processInput_RemoteInferenceInputDataSet(input, input); + } + private void processInput_RemoteInferenceInputDataSet(String input, String expectedInput) { - RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", input)).build(); + Map params = new HashMap<>(); + params.put("input", input); + RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(params).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build(); ConnectorAction predictAction = ConnectorAction.builder()