diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java index cde37a685f..2c3514530f 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java @@ -7,6 +7,7 @@ import java.io.IOException; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -14,6 +15,7 @@ import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.experimental.FieldDefaults; +import org.opensearch.ml.common.MLCommonsClassLoader; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -25,4 +27,9 @@ public abstract class MLInputDataset implements Writeable { public void writeTo(StreamOutput streamOutput) throws IOException { streamOutput.writeEnum(this.inputDataType); } + + public static MLInputDataset fromStream(StreamInput in) throws IOException { + MLInputDataType inputDataType = in.readEnum(MLInputDataType.class); + return MLCommonsClassLoader.initMLInstance(inputDataType, in, StreamInput.class); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java index 035722ca5d..f7ac619e1b 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java @@ -32,13 +32,20 @@ public RemoteInferenceInputDataSet(Map parameters) { public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException { super(MLInputDataType.REMOTE); - parameters = streamInput.readMap(s -> s.readString(), s-> s.readString()); + if (streamInput.readBoolean()) { + parameters = streamInput.readMap(s -> s.readString(), s-> s.readString()); + } } @Override public void writeTo(StreamOutput streamOutput) throws IOException { super.writeTo(streamOutput); - streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); + if (parameters != null) { + streamOutput.writeBoolean(true); + streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); + } else { + streamOutput.writeBoolean(false); + } } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java index ee293374df..8bbd9630b2 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java @@ -17,7 +17,6 @@ import org.opensearch.ml.common.dataframe.DefaultDataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.output.model.ModelResultFilter; -import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; import org.opensearch.ml.common.FunctionName; @@ -98,8 +97,7 @@ public MLInput(StreamInput in) throws IOException { this.parameters = MLCommonsClassLoader.initMLInstance(algorithm, in, StreamInput.class); } if (in.readBoolean()) { - MLInputDataType inputDataType = in.readEnum(MLInputDataType.class); - this.inputDataset = MLCommonsClassLoader.initMLInstance(inputDataType, in, StreamInput.class); + this.inputDataset = MLInputDataset.fromStream(in); } this.version = in.readInt(); } diff --git a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java index 9cbab818f5..5992e77a24 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java @@ -11,6 +11,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.utils.StringUtils; import java.io.IOException; import java.security.AccessController; @@ -38,8 +39,6 @@ public void writeTo(StreamOutput out) throws IOException { public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) throws IOException { super(); this.algorithm = functionName; - Map parameterObjs = new HashMap<>(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -47,30 +46,14 @@ public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) switch (fieldName) { case PARAMETERS_FIELD: - parameterObjs = parser.map(); + Map parameters = StringUtils.getParameterMap(parser.map()); + inputDataset = new RemoteInferenceInputDataSet(parameters); break; default: parser.skipChildren(); break; } } - Map parameters = new HashMap<>(); - for (String key : parameterObjs.keySet()) { - Object value = parameterObjs.get(key); - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - if (value instanceof String) { - parameters.put(key, (String)value); - } else { - parameters.put(key, gson.toJson(value)); - } - return null; - }); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } - } - inputDataset = new RemoteInferenceInputDataSet(parameters); } } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 8ff8fb0961..968cda1575 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -63,15 +63,6 @@ public static Map fromJson(String jsonStr, String defaultKey) { return result; } - public static Map fromJson(String jsonStr) { - JsonElement jsonElement = JsonParser.parseString(jsonStr); - return gson.fromJson(jsonElement, Map.class); - } - - public static String toJson(Map map) { - return new JSONObject(map).toString(); - } - public static Map getParameterMap(Map parameterObjs) { Map parameters = new HashMap<>(); for (String key : parameterObjs.keySet()) { diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java new file mode 100644 index 0000000000..618bc336a7 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java @@ -0,0 +1,48 @@ +package org.opensearch.ml.common.dataset.remote; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.ml.common.dataset.MLInputDataset; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.ml.common.dataset.MLInputDataType.REMOTE; + +public class RemoteInferenceInputDataSetTest { + + @Test + public void writeTo_NullParameter() throws IOException { + Map parameters = null; + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + + BytesStreamOutput output = new BytesStreamOutput(); + inputDataSet.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + + RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput); + Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType()); + Assert.assertNull(inputDataSet2.getParameters()); + } + + @Test + public void writeTo() throws IOException { + Map parameters = new HashMap<>(); + parameters.put("key1", "test value1"); + parameters.put("key2", "test value2"); + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + + BytesStreamOutput output = new BytesStreamOutput(); + inputDataSet.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + + RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput); + Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType()); + Assert.assertEquals(2, inputDataSet2.getParameters().size()); + Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1")); + Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java new file mode 100644 index 0000000000..a01a955e7a --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java @@ -0,0 +1,52 @@ +package org.opensearch.ml.common.input.remote; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataType; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Collections; + +public class RemoteInferenceMLInputTest { + + @Test + public void constructor_parser() throws IOException { + RemoteInferenceMLInput input = createRemoteInferenceMLInput(); + Assert.assertNotNull(input.getInputDataset()); + Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType()); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset(); + Assert.assertEquals(1, inputDataSet.getParameters().size()); + Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt")); + } + + @Test + public void constructor_stream() throws IOException { + RemoteInferenceMLInput originalInput = createRemoteInferenceMLInput(); + BytesStreamOutput output = new BytesStreamOutput(); + originalInput.writeTo(output); + + RemoteInferenceMLInput input = new RemoteInferenceMLInput(output.bytes().streamInput()); + Assert.assertNotNull(input.getInputDataset()); + Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType()); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset(); + Assert.assertEquals(1, inputDataSet.getParameters().size()); + Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt")); + } + + private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException { + String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + RemoteInferenceMLInput input = new RemoteInferenceMLInput(parser, FunctionName.REMOTE); + return input; + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java index e4c7741e7e..68904cb390 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java @@ -17,6 +17,8 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; import static org.junit.Assert.assertEquals; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; @@ -30,12 +32,17 @@ public class ModelTensorTest { @Before public void setUp() { + Map dataMap = new HashMap<>(); + dataMap.put("key1", "test value1"); + dataMap.put("key2", "test value2"); modelTensor = ModelTensor.builder() .name("model_tensor") .data(new Number[]{1, 2, 3}) .shape(new long[]{1, 2, 3,}) .dataType(MLResultDataType.INT32) .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) + .result("test result") + .dataAsMap(dataMap) .build(); } @@ -46,7 +53,7 @@ public void test_StreamInAndOut() throws IOException { StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); ModelTensor parsedTensor = new ModelTensor(streamInput); -// assertEquals(parsedTensor, modelTensor); + assertEquals(modelTensor, parsedTensor); } @Test @@ -54,7 +61,13 @@ public void test_ModelTensorSuccess() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); modelTensor.toXContent(builder, EMPTY_PARAMS); String modelTensorContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}", modelTensorContent); + assertEquals("{\"name\":\"model_tensor\"," + + "\"data_type\":\"INT32\"," + + "\"shape\":[1,2,3]," + + "\"data\":[1,2,3]," + + "\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}," + + "\"result\":\"test result\"," + + "\"dataAsMap\":{\"key1\":\"test value1\",\"key2\":\"test value2\"}}", modelTensorContent); } @Test @@ -74,7 +87,7 @@ public void test_StreamInAndOut_NullValue() throws IOException { StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); ModelTensor parsedTensor = new ModelTensor(streamInput); -// assertEquals(parsedTensor, tensor); + assertEquals(tensor, parsedTensor); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java new file mode 100644 index 0000000000..a4b34d75b5 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -0,0 +1,98 @@ +package org.opensearch.ml.common.utils; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class StringUtilsTest { + + @Test + public void isJson_True() { + Assert.assertTrue(StringUtils.isJson("{}")); + Assert.assertTrue(StringUtils.isJson("[]")); + Assert.assertTrue(StringUtils.isJson("{\"key\": \"value\"}")); + Assert.assertTrue(StringUtils.isJson("{\"key\": 123}")); + Assert.assertTrue(StringUtils.isJson("[1, 2, 3]")); + Assert.assertTrue(StringUtils.isJson("[\"a\", \"b\"]")); + Assert.assertTrue(StringUtils.isJson("[1, \"a\"]")); + } + + @Test + public void isJson_False() { + Assert.assertFalse(StringUtils.isJson("{")); + Assert.assertFalse(StringUtils.isJson("[")); + Assert.assertFalse(StringUtils.isJson("{\"key\": \"value}")); + Assert.assertFalse(StringUtils.isJson("{\"key\": \"value\", \"key\": 123}")); + Assert.assertFalse(StringUtils.isJson("[1, \"a]")); + } + + @Test + public void toUTF8() { + String rawString = "\uD83D\uDE00\uD83D\uDE0D\uD83D\uDE1C"; + String utf8 = StringUtils.toUTF8(rawString); + Assert.assertNotNull(utf8); + } + + @Test + public void fromJson_SimpleMap() { + Map response = StringUtils.fromJson("{\"key\": \"value\"}", "response"); + Assert.assertEquals(1, response.size()); + Assert.assertEquals("value", response.get("key")); + } + + @Test + public void fromJson_NestedMap() { + Map response = StringUtils.fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response"); + Assert.assertEquals(1, response.size()); + Assert.assertTrue(response.get("key") instanceof Map); + Map nestedMap = (Map)response.get("key"); + Assert.assertEquals("nested_value", nestedMap.get("nested_key")); + List list = (List)nestedMap.get("nested_array"); + Assert.assertEquals(2, list.size()); + Assert.assertEquals(1.0, list.get(0)); + Assert.assertEquals("a", list.get(1)); + } + + @Test + public void fromJson_SimpleList() { + Map response = StringUtils.fromJson("[1, \"a\"]", "response"); + Assert.assertEquals(1, response.size()); + Assert.assertTrue(response.get("response") instanceof List); + List list = (List)response.get("response"); + Assert.assertEquals(1.0, list.get(0)); + Assert.assertEquals("a", list.get(1)); + } + + @Test + public void fromJson_NestedList() { + Map response = StringUtils.fromJson("[1, \"a\", [2, 3], {\"key\": \"value\"}]", "response"); + Assert.assertEquals(1, response.size()); + Assert.assertTrue(response.get("response") instanceof List); + List list = (List)response.get("response"); + Assert.assertEquals(1.0, list.get(0)); + Assert.assertEquals("a", list.get(1)); + Assert.assertTrue(list.get(2) instanceof List); + Assert.assertTrue(list.get(3) instanceof Map); + } + + @Test + public void getParameterMap() { + Map parameters = new HashMap<>(); + parameters.put("key1", "value1"); + parameters.put("key2", 2); + parameters.put("key3", 2.1); + parameters.put("key4", new int[]{10, 20}); + parameters.put("key5", new Object[]{1.01, "abc"}); + Map parameterMap = StringUtils.getParameterMap(parameters); + System.out.println(parameterMap); + Assert.assertEquals(5, parameterMap.size()); + Assert.assertEquals("value1", parameterMap.get("key1")); + Assert.assertEquals("2", parameterMap.get("key2")); + Assert.assertEquals("2.1", parameterMap.get("key3")); + Assert.assertEquals("[10,20]", parameterMap.get("key4")); + Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5")); + } +}