Skip to content

Commit

Permalink
remote inference: add unit test for StringUtils and remote inference …
Browse files Browse the repository at this point in the history
…input (#1061)

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored and zane-neo committed Sep 1, 2023
1 parent eaa626d commit 9597fbf
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

import java.io.IOException;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;
import org.opensearch.ml.common.MLCommonsClassLoader;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
Expand All @@ -25,4 +27,9 @@ public abstract class MLInputDataset implements Writeable {
public void writeTo(StreamOutput streamOutput) throws IOException {
streamOutput.writeEnum(this.inputDataType);
}

public static MLInputDataset fromStream(StreamInput in) throws IOException {
MLInputDataType inputDataType = in.readEnum(MLInputDataType.class);
return MLCommonsClassLoader.initMLInstance(inputDataType, in, StreamInput.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,20 @@ public RemoteInferenceInputDataSet(Map<String, String> parameters) {

public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.REMOTE);
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
if (streamInput.readBoolean()) {
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
}
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
if (parameters != null) {
streamOutput.writeBoolean(true);
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
} else {
streamOutput.writeBoolean(false);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -98,8 +97,7 @@ public MLInput(StreamInput in) throws IOException {
this.parameters = MLCommonsClassLoader.initMLInstance(algorithm, in, StreamInput.class);
}
if (in.readBoolean()) {
MLInputDataType inputDataType = in.readEnum(MLInputDataType.class);
this.inputDataset = MLCommonsClassLoader.initMLInstance(inputDataType, in, StreamInput.class);
this.inputDataset = MLInputDataset.fromStream(in);
}
this.version = in.readInt();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

import java.io.IOException;
import java.security.AccessController;
Expand Down Expand Up @@ -38,39 +39,21 @@ public void writeTo(StreamOutput out) throws IOException {
public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) throws IOException {
super();
this.algorithm = functionName;
Map<String, ?> parameterObjs = new HashMap<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case PARAMETERS_FIELD:
parameterObjs = parser.map();
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
inputDataset = new RemoteInferenceInputDataSet(parameters);
break;
default:
parser.skipChildren();
break;
}
}
Map<String, String> parameters = new HashMap<>();
for (String key : parameterObjs.keySet()) {
Object value = parameterObjs.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
if (value instanceof String) {
parameters.put(key, (String)value);
} else {
parameters.put(key, gson.toJson(value));
}
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
inputDataset = new RemoteInferenceInputDataSet(parameters);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,6 @@ public static Map<String, Object> fromJson(String jsonStr, String defaultKey) {
return result;
}

public static Map<String, String> fromJson(String jsonStr) {
JsonElement jsonElement = JsonParser.parseString(jsonStr);
return gson.fromJson(jsonElement, Map.class);
}

public static String toJson(Map<String, String> map) {
return new JSONObject(map).toString();
}

public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs) {
Map<String, String> parameters = new HashMap<>();
for (String key : parameterObjs.keySet()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package org.opensearch.ml.common.dataset.remote;

import org.junit.Assert;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.ml.common.dataset.MLInputDataset;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.ml.common.dataset.MLInputDataType.REMOTE;

public class RemoteInferenceInputDataSetTest {

@Test
public void writeTo_NullParameter() throws IOException {
Map<String, String> parameters = null;
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();

BytesStreamOutput output = new BytesStreamOutput();
inputDataSet.writeTo(output);
StreamInput streamInput = output.bytes().streamInput();

RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput);
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType());
Assert.assertNull(inputDataSet2.getParameters());
}

@Test
public void writeTo() throws IOException {
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "test value1");
parameters.put("key2", "test value2");
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();

BytesStreamOutput output = new BytesStreamOutput();
inputDataSet.writeTo(output);
StreamInput streamInput = output.bytes().streamInput();

RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput);
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType());
Assert.assertEquals(2, inputDataSet2.getParameters().size());
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package org.opensearch.ml.common.input.remote;

import org.junit.Assert;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.search.SearchModule;

import java.io.IOException;
import java.util.Collections;

public class RemoteInferenceMLInputTest {

@Test
public void constructor_parser() throws IOException {
RemoteInferenceMLInput input = createRemoteInferenceMLInput();
Assert.assertNotNull(input.getInputDataset());
Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType());
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
Assert.assertEquals(1, inputDataSet.getParameters().size());
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
}

@Test
public void constructor_stream() throws IOException {
RemoteInferenceMLInput originalInput = createRemoteInferenceMLInput();
BytesStreamOutput output = new BytesStreamOutput();
originalInput.writeTo(output);

RemoteInferenceMLInput input = new RemoteInferenceMLInput(output.bytes().streamInput());
Assert.assertNotNull(input.getInputDataset());
Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType());
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
Assert.assertEquals(1, inputDataSet.getParameters().size());
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
}

private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException {
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
RemoteInferenceMLInput input = new RemoteInferenceMLInput(parser, FunctionName.REMOTE);
return input;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
Expand All @@ -30,12 +32,17 @@ public class ModelTensorTest {

@Before
public void setUp() {
Map<String, String> dataMap = new HashMap<>();
dataMap.put("key1", "test value1");
dataMap.put("key2", "test value2");
modelTensor = ModelTensor.builder()
.name("model_tensor")
.data(new Number[]{1, 2, 3})
.shape(new long[]{1, 2, 3,})
.dataType(MLResultDataType.INT32)
.byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1}))
.result("test result")
.dataAsMap(dataMap)
.build();
}

Expand All @@ -46,15 +53,21 @@ public void test_StreamInAndOut() throws IOException {

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
ModelTensor parsedTensor = new ModelTensor(streamInput);
// assertEquals(parsedTensor, modelTensor);
assertEquals(modelTensor, parsedTensor);
}

@Test
public void test_ModelTensorSuccess() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
modelTensor.toXContent(builder, EMPTY_PARAMS);
String modelTensorContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}", modelTensorContent);
assertEquals("{\"name\":\"model_tensor\"," +
"\"data_type\":\"INT32\"," +
"\"shape\":[1,2,3]," +
"\"data\":[1,2,3]," +
"\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}," +
"\"result\":\"test result\"," +
"\"dataAsMap\":{\"key1\":\"test value1\",\"key2\":\"test value2\"}}", modelTensorContent);
}

@Test
Expand All @@ -74,7 +87,7 @@ public void test_StreamInAndOut_NullValue() throws IOException {

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
ModelTensor parsedTensor = new ModelTensor(streamInput);
// assertEquals(parsedTensor, tensor);
assertEquals(tensor, parsedTensor);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package org.opensearch.ml.common.utils;

import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class StringUtilsTest {

@Test
public void isJson_True() {
Assert.assertTrue(StringUtils.isJson("{}"));
Assert.assertTrue(StringUtils.isJson("[]"));
Assert.assertTrue(StringUtils.isJson("{\"key\": \"value\"}"));
Assert.assertTrue(StringUtils.isJson("{\"key\": 123}"));
Assert.assertTrue(StringUtils.isJson("[1, 2, 3]"));
Assert.assertTrue(StringUtils.isJson("[\"a\", \"b\"]"));
Assert.assertTrue(StringUtils.isJson("[1, \"a\"]"));
}

@Test
public void isJson_False() {
Assert.assertFalse(StringUtils.isJson("{"));
Assert.assertFalse(StringUtils.isJson("["));
Assert.assertFalse(StringUtils.isJson("{\"key\": \"value}"));
Assert.assertFalse(StringUtils.isJson("{\"key\": \"value\", \"key\": 123}"));
Assert.assertFalse(StringUtils.isJson("[1, \"a]"));
}

@Test
public void toUTF8() {
String rawString = "\uD83D\uDE00\uD83D\uDE0D\uD83D\uDE1C";
String utf8 = StringUtils.toUTF8(rawString);
Assert.assertNotNull(utf8);
}

@Test
public void fromJson_SimpleMap() {
Map<String, Object> response = StringUtils.fromJson("{\"key\": \"value\"}", "response");
Assert.assertEquals(1, response.size());
Assert.assertEquals("value", response.get("key"));
}

@Test
public void fromJson_NestedMap() {
Map<String, Object> response = StringUtils.fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response");
Assert.assertEquals(1, response.size());
Assert.assertTrue(response.get("key") instanceof Map);
Map nestedMap = (Map)response.get("key");
Assert.assertEquals("nested_value", nestedMap.get("nested_key"));
List list = (List)nestedMap.get("nested_array");
Assert.assertEquals(2, list.size());
Assert.assertEquals(1.0, list.get(0));
Assert.assertEquals("a", list.get(1));
}

@Test
public void fromJson_SimpleList() {
Map<String, Object> response = StringUtils.fromJson("[1, \"a\"]", "response");
Assert.assertEquals(1, response.size());
Assert.assertTrue(response.get("response") instanceof List);
List list = (List)response.get("response");
Assert.assertEquals(1.0, list.get(0));
Assert.assertEquals("a", list.get(1));
}

@Test
public void fromJson_NestedList() {
Map<String, Object> response = StringUtils.fromJson("[1, \"a\", [2, 3], {\"key\": \"value\"}]", "response");
Assert.assertEquals(1, response.size());
Assert.assertTrue(response.get("response") instanceof List);
List list = (List)response.get("response");
Assert.assertEquals(1.0, list.get(0));
Assert.assertEquals("a", list.get(1));
Assert.assertTrue(list.get(2) instanceof List);
Assert.assertTrue(list.get(3) instanceof Map);
}

@Test
public void getParameterMap() {
Map<String, Object> parameters = new HashMap<>();
parameters.put("key1", "value1");
parameters.put("key2", 2);
parameters.put("key3", 2.1);
parameters.put("key4", new int[]{10, 20});
parameters.put("key5", new Object[]{1.01, "abc"});
Map<String, String> parameterMap = StringUtils.getParameterMap(parameters);
System.out.println(parameterMap);
Assert.assertEquals(5, parameterMap.size());
Assert.assertEquals("value1", parameterMap.get("key1"));
Assert.assertEquals("2", parameterMap.get("key2"));
Assert.assertEquals("2.1", parameterMap.get("key3"));
Assert.assertEquals("[10,20]", parameterMap.get("key4"));
Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5"));
}
}

0 comments on commit 9597fbf

Please sign in to comment.