-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
remote inference: add unit test for StringUtils and remote inference …
…input (#1061) Signed-off-by: Yaliang Wu <[email protected]>
- Loading branch information
Showing
9 changed files
with
234 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
48 changes: 48 additions & 0 deletions
48
...rc/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package org.opensearch.ml.common.dataset.remote; | ||
|
||
import org.junit.Assert; | ||
import org.junit.Test; | ||
import org.opensearch.common.io.stream.BytesStreamOutput; | ||
import org.opensearch.common.io.stream.StreamInput; | ||
import org.opensearch.ml.common.dataset.MLInputDataset; | ||
|
||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import static org.opensearch.ml.common.dataset.MLInputDataType.REMOTE; | ||
|
||
public class RemoteInferenceInputDataSetTest { | ||
|
||
@Test | ||
public void writeTo_NullParameter() throws IOException { | ||
Map<String, String> parameters = null; | ||
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); | ||
|
||
BytesStreamOutput output = new BytesStreamOutput(); | ||
inputDataSet.writeTo(output); | ||
StreamInput streamInput = output.bytes().streamInput(); | ||
|
||
RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput); | ||
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType()); | ||
Assert.assertNull(inputDataSet2.getParameters()); | ||
} | ||
|
||
@Test | ||
public void writeTo() throws IOException { | ||
Map<String, String> parameters = new HashMap<>(); | ||
parameters.put("key1", "test value1"); | ||
parameters.put("key2", "test value2"); | ||
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); | ||
|
||
BytesStreamOutput output = new BytesStreamOutput(); | ||
inputDataSet.writeTo(output); | ||
StreamInput streamInput = output.bytes().streamInput(); | ||
|
||
RemoteInferenceInputDataSet inputDataSet2 = (RemoteInferenceInputDataSet) MLInputDataset.fromStream(streamInput); | ||
Assert.assertEquals(REMOTE, inputDataSet2.getInputDataType()); | ||
Assert.assertEquals(2, inputDataSet2.getParameters().size()); | ||
Assert.assertEquals("test value1", inputDataSet2.getParameters().get("key1")); | ||
Assert.assertEquals("test value2", inputDataSet2.getParameters().get("key2")); | ||
} | ||
} |
52 changes: 52 additions & 0 deletions
52
common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
package org.opensearch.ml.common.input.remote; | ||
|
||
import org.junit.Assert; | ||
import org.junit.Test; | ||
import org.opensearch.common.io.stream.BytesStreamOutput; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.MLInputDataType; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.search.SearchModule; | ||
|
||
import java.io.IOException; | ||
import java.util.Collections; | ||
|
||
public class RemoteInferenceMLInputTest { | ||
|
||
@Test | ||
public void constructor_parser() throws IOException { | ||
RemoteInferenceMLInput input = createRemoteInferenceMLInput(); | ||
Assert.assertNotNull(input.getInputDataset()); | ||
Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType()); | ||
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset(); | ||
Assert.assertEquals(1, inputDataSet.getParameters().size()); | ||
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt")); | ||
} | ||
|
||
@Test | ||
public void constructor_stream() throws IOException { | ||
RemoteInferenceMLInput originalInput = createRemoteInferenceMLInput(); | ||
BytesStreamOutput output = new BytesStreamOutput(); | ||
originalInput.writeTo(output); | ||
|
||
RemoteInferenceMLInput input = new RemoteInferenceMLInput(output.bytes().streamInput()); | ||
Assert.assertNotNull(input.getInputDataset()); | ||
Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType()); | ||
RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset(); | ||
Assert.assertEquals(1, inputDataSet.getParameters().size()); | ||
Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt")); | ||
} | ||
|
||
private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException { | ||
String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }"; | ||
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, | ||
Collections.emptyList()).getNamedXContents()), null, jsonStr); | ||
parser.nextToken(); | ||
RemoteInferenceMLInput input = new RemoteInferenceMLInput(parser, FunctionName.REMOTE); | ||
return input; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
98 changes: 98 additions & 0 deletions
98
common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package org.opensearch.ml.common.utils; | ||
|
||
import org.junit.Assert; | ||
import org.junit.Test; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
public class StringUtilsTest { | ||
|
||
@Test | ||
public void isJson_True() { | ||
Assert.assertTrue(StringUtils.isJson("{}")); | ||
Assert.assertTrue(StringUtils.isJson("[]")); | ||
Assert.assertTrue(StringUtils.isJson("{\"key\": \"value\"}")); | ||
Assert.assertTrue(StringUtils.isJson("{\"key\": 123}")); | ||
Assert.assertTrue(StringUtils.isJson("[1, 2, 3]")); | ||
Assert.assertTrue(StringUtils.isJson("[\"a\", \"b\"]")); | ||
Assert.assertTrue(StringUtils.isJson("[1, \"a\"]")); | ||
} | ||
|
||
@Test | ||
public void isJson_False() { | ||
Assert.assertFalse(StringUtils.isJson("{")); | ||
Assert.assertFalse(StringUtils.isJson("[")); | ||
Assert.assertFalse(StringUtils.isJson("{\"key\": \"value}")); | ||
Assert.assertFalse(StringUtils.isJson("{\"key\": \"value\", \"key\": 123}")); | ||
Assert.assertFalse(StringUtils.isJson("[1, \"a]")); | ||
} | ||
|
||
@Test | ||
public void toUTF8() { | ||
String rawString = "\uD83D\uDE00\uD83D\uDE0D\uD83D\uDE1C"; | ||
String utf8 = StringUtils.toUTF8(rawString); | ||
Assert.assertNotNull(utf8); | ||
} | ||
|
||
@Test | ||
public void fromJson_SimpleMap() { | ||
Map<String, Object> response = StringUtils.fromJson("{\"key\": \"value\"}", "response"); | ||
Assert.assertEquals(1, response.size()); | ||
Assert.assertEquals("value", response.get("key")); | ||
} | ||
|
||
@Test | ||
public void fromJson_NestedMap() { | ||
Map<String, Object> response = StringUtils.fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response"); | ||
Assert.assertEquals(1, response.size()); | ||
Assert.assertTrue(response.get("key") instanceof Map); | ||
Map nestedMap = (Map)response.get("key"); | ||
Assert.assertEquals("nested_value", nestedMap.get("nested_key")); | ||
List list = (List)nestedMap.get("nested_array"); | ||
Assert.assertEquals(2, list.size()); | ||
Assert.assertEquals(1.0, list.get(0)); | ||
Assert.assertEquals("a", list.get(1)); | ||
} | ||
|
||
@Test | ||
public void fromJson_SimpleList() { | ||
Map<String, Object> response = StringUtils.fromJson("[1, \"a\"]", "response"); | ||
Assert.assertEquals(1, response.size()); | ||
Assert.assertTrue(response.get("response") instanceof List); | ||
List list = (List)response.get("response"); | ||
Assert.assertEquals(1.0, list.get(0)); | ||
Assert.assertEquals("a", list.get(1)); | ||
} | ||
|
||
@Test | ||
public void fromJson_NestedList() { | ||
Map<String, Object> response = StringUtils.fromJson("[1, \"a\", [2, 3], {\"key\": \"value\"}]", "response"); | ||
Assert.assertEquals(1, response.size()); | ||
Assert.assertTrue(response.get("response") instanceof List); | ||
List list = (List)response.get("response"); | ||
Assert.assertEquals(1.0, list.get(0)); | ||
Assert.assertEquals("a", list.get(1)); | ||
Assert.assertTrue(list.get(2) instanceof List); | ||
Assert.assertTrue(list.get(3) instanceof Map); | ||
} | ||
|
||
@Test | ||
public void getParameterMap() { | ||
Map<String, Object> parameters = new HashMap<>(); | ||
parameters.put("key1", "value1"); | ||
parameters.put("key2", 2); | ||
parameters.put("key3", 2.1); | ||
parameters.put("key4", new int[]{10, 20}); | ||
parameters.put("key5", new Object[]{1.01, "abc"}); | ||
Map<String, String> parameterMap = StringUtils.getParameterMap(parameters); | ||
System.out.println(parameterMap); | ||
Assert.assertEquals(5, parameterMap.size()); | ||
Assert.assertEquals("value1", parameterMap.get("key1")); | ||
Assert.assertEquals("2", parameterMap.get("key2")); | ||
Assert.assertEquals("2.1", parameterMap.get("key3")); | ||
Assert.assertEquals("[10,20]", parameterMap.get("key4")); | ||
Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5")); | ||
} | ||
} |