Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remote inference: add unit test for StringUtils and remote inference input #1061

Merged
merged 1 commit into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

import java.io.IOException;

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;
import org.opensearch.ml.common.MLCommonsClassLoader;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
Expand All @@ -25,4 +27,9 @@ public abstract class MLInputDataset implements Writeable {
public void writeTo(StreamOutput streamOutput) throws IOException {
streamOutput.writeEnum(this.inputDataType);
}

public static MLInputDataset fromStream(StreamInput in) throws IOException {
MLInputDataType inputDataType = in.readEnum(MLInputDataType.class);
return MLCommonsClassLoader.initMLInstance(inputDataType, in, StreamInput.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,20 @@ public RemoteInferenceInputDataSet(Map<String, String> parameters) {

public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.REMOTE);
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
if (streamInput.readBoolean()) {
parameters = streamInput.readMap(s -> s.readString(), s-> s.readString());
}
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
if (parameters != null) {
streamOutput.writeBoolean(true);
streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
} else {
streamOutput.writeBoolean(false);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -98,8 +97,7 @@ public MLInput(StreamInput in) throws IOException {
this.parameters = MLCommonsClassLoader.initMLInstance(algorithm, in, StreamInput.class);
}
if (in.readBoolean()) {
MLInputDataType inputDataType = in.readEnum(MLInputDataType.class);
this.inputDataset = MLCommonsClassLoader.initMLInstance(inputDataType, in, StreamInput.class);
this.inputDataset = MLInputDataset.fromStream(in);
}
this.version = in.readInt();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

import java.io.IOException;
import java.security.AccessController;
Expand Down Expand Up @@ -38,39 +39,21 @@ public void writeTo(StreamOutput out) throws IOException {
public RemoteInferenceMLInput(XContentParser parser, FunctionName functionName) throws IOException {
super();
this.algorithm = functionName;
Map<String, ?> parameterObjs = new HashMap<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case PARAMETERS_FIELD:
parameterObjs = parser.map();
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
inputDataset = new RemoteInferenceInputDataSet(parameters);
break;
default:
parser.skipChildren();
break;
}
}
Map<String, String> parameters = new HashMap<>();
for (String key : parameterObjs.keySet()) {
Object value = parameterObjs.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
if (value instanceof String) {
parameters.put(key, (String)value);
} else {
parameters.put(key, gson.toJson(value));
}
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
inputDataset = new RemoteInferenceInputDataSet(parameters);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,6 @@ public static Map<String, Object> fromJson(String jsonStr, String defaultKey) {
return result;
}

public static Map<String, String> fromJson(String jsonStr) {
JsonElement jsonElement = JsonParser.parseString(jsonStr);
return gson.fromJson(jsonElement, Map.class);
}

public static String toJson(Map<String, String> map) {
return new JSONObject(map).toString();
}

public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs) {
Map<String, String> parameters = new HashMap<>();
for (String key : parameterObjs.keySet()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package org.opensearch.ml.common.dataset.remote;

import org.junit.Assert;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.ml.common.dataset.MLInputDataset;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.ml.common.dataset.MLInputDataType.REMOTE;

public class RemoteInferenceInputDataSetTest {

@Test
public void writeTo_NullParameter() throws IOException {
Map<String, String> parameters = null;
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();

BytesStreamOutput output = new BytesStreamOutput();
inputDataSet.writeTo(output);
StreamInput streamInput = output.bytes().streamInput();

RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput);
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType());
Assert.assertNull(inputDataSet2.getParameters());
}

@Test
public void writeTo() throws IOException {
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "test value1");
parameters.put("key2", "test value2");
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();

BytesStreamOutput output = new BytesStreamOutput();
inputDataSet.writeTo(output);
StreamInput streamInput = output.bytes().streamInput();

RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput);
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType());
Assert.assertEquals(2, inputDataSet2.getParameters().size());
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1"));
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package org.opensearch.ml.common.input.remote;

import org.junit.Assert;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.search.SearchModule;

import java.io.IOException;
import java.util.Collections;

public class RemoteInferenceMLInputTest {

@Test
public void constructor_parser() throws IOException {
RemoteInferenceMLInput input = createRemoteInferenceMLInput();
Assert.assertNotNull(input.getInputDataset());
Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType());
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
Assert.assertEquals(1, inputDataSet.getParameters().size());
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
}

@Test
public void constructor_stream() throws IOException {
RemoteInferenceMLInput originalInput = createRemoteInferenceMLInput();
BytesStreamOutput output = new BytesStreamOutput();
originalInput.writeTo(output);

RemoteInferenceMLInput input = new RemoteInferenceMLInput(output.bytes().streamInput());
Assert.assertNotNull(input.getInputDataset());
Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType());
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset();
Assert.assertEquals(1, inputDataSet.getParameters().size());
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt"));
}

private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException {
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
RemoteInferenceMLInput input = new RemoteInferenceMLInput(parser, FunctionName.REMOTE);
return input;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
Expand All @@ -30,12 +32,17 @@ public class ModelTensorTest {

@Before
public void setUp() {
Map<String, String> dataMap = new HashMap<>();
dataMap.put("key1", "test value1");
dataMap.put("key2", "test value2");
modelTensor = ModelTensor.builder()
.name("model_tensor")
.data(new Number[]{1, 2, 3})
.shape(new long[]{1, 2, 3,})
.dataType(MLResultDataType.INT32)
.byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1}))
.result("test result")
.dataAsMap(dataMap)
.build();
}

Expand All @@ -46,15 +53,21 @@ public void test_StreamInAndOut() throws IOException {

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
ModelTensor parsedTensor = new ModelTensor(streamInput);
// assertEquals(parsedTensor, modelTensor);
assertEquals(modelTensor, parsedTensor);
}

@Test
public void test_ModelTensorSuccess() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
modelTensor.toXContent(builder, EMPTY_PARAMS);
String modelTensorContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}", modelTensorContent);
assertEquals("{\"name\":\"model_tensor\"," +
"\"data_type\":\"INT32\"," +
"\"shape\":[1,2,3]," +
"\"data\":[1,2,3]," +
"\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}," +
"\"result\":\"test result\"," +
"\"dataAsMap\":{\"key1\":\"test value1\",\"key2\":\"test value2\"}}", modelTensorContent);
}

@Test
Expand All @@ -74,7 +87,7 @@ public void test_StreamInAndOut_NullValue() throws IOException {

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
ModelTensor parsedTensor = new ModelTensor(streamInput);
// assertEquals(parsedTensor, tensor);
assertEquals(tensor, parsedTensor);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package org.opensearch.ml.common.utils;

import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class StringUtilsTest {

@Test
public void isJson_True() {
Assert.assertTrue(StringUtils.isJson("{}"));
Assert.assertTrue(StringUtils.isJson("[]"));
Assert.assertTrue(StringUtils.isJson("{\"key\": \"value\"}"));
Assert.assertTrue(StringUtils.isJson("{\"key\": 123}"));
Assert.assertTrue(StringUtils.isJson("[1, 2, 3]"));
Assert.assertTrue(StringUtils.isJson("[\"a\", \"b\"]"));
Assert.assertTrue(StringUtils.isJson("[1, \"a\"]"));
}

@Test
public void isJson_False() {
Assert.assertFalse(StringUtils.isJson("{"));
Assert.assertFalse(StringUtils.isJson("["));
Assert.assertFalse(StringUtils.isJson("{\"key\": \"value}"));
Assert.assertFalse(StringUtils.isJson("{\"key\": \"value\", \"key\": 123}"));
Assert.assertFalse(StringUtils.isJson("[1, \"a]"));
}

@Test
public void toUTF8() {
String rawString = "\uD83D\uDE00\uD83D\uDE0D\uD83D\uDE1C";
String utf8 = StringUtils.toUTF8(rawString);
Assert.assertNotNull(utf8);
}

@Test
public void fromJson_SimpleMap() {
Map<String, Object> response = StringUtils.fromJson("{\"key\": \"value\"}", "response");
Assert.assertEquals(1, response.size());
Assert.assertEquals("value", response.get("key"));
}

@Test
public void fromJson_NestedMap() {
Map<String, Object> response = StringUtils.fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response");
Assert.assertEquals(1, response.size());
Assert.assertTrue(response.get("key") instanceof Map);
Map nestedMap = (Map)response.get("key");
Assert.assertEquals("nested_value", nestedMap.get("nested_key"));
List list = (List)nestedMap.get("nested_array");
Assert.assertEquals(2, list.size());
Assert.assertEquals(1.0, list.get(0));
Assert.assertEquals("a", list.get(1));
}

@Test
public void fromJson_SimpleList() {
Map<String, Object> response = StringUtils.fromJson("[1, \"a\"]", "response");
Assert.assertEquals(1, response.size());
Assert.assertTrue(response.get("response") instanceof List);
List list = (List)response.get("response");
Assert.assertEquals(1.0, list.get(0));
Assert.assertEquals("a", list.get(1));
}

@Test
public void fromJson_NestedList() {
Map<String, Object> response = StringUtils.fromJson("[1, \"a\", [2, 3], {\"key\": \"value\"}]", "response");
Assert.assertEquals(1, response.size());
Assert.assertTrue(response.get("response") instanceof List);
List list = (List)response.get("response");
Assert.assertEquals(1.0, list.get(0));
Assert.assertEquals("a", list.get(1));
Assert.assertTrue(list.get(2) instanceof List);
Assert.assertTrue(list.get(3) instanceof Map);
}

@Test
public void getParameterMap() {
Map<String, Object> parameters = new HashMap<>();
parameters.put("key1", "value1");
parameters.put("key2", 2);
parameters.put("key3", 2.1);
parameters.put("key4", new int[]{10, 20});
parameters.put("key5", new Object[]{1.01, "abc"});
Map<String, String> parameterMap = StringUtils.getParameterMap(parameters);
System.out.println(parameterMap);
Assert.assertEquals(5, parameterMap.size());
Assert.assertEquals("value1", parameterMap.get("key1"));
Assert.assertEquals("2", parameterMap.get("key2"));
Assert.assertEquals("2.1", parameterMap.get("key3"));
Assert.assertEquals("[10,20]", parameterMap.get("key4"));
Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5"));
}
}