Skip to content

Commit

Permalink
fix failed ut
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jul 9, 2023
1 parent 7da9d75 commit c30752a
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ public class ConnectorProtocols {
public static final Set<String> VALID_PROTOCOLS = Set.of(HTTP, AWS_SIGV4);

public static void validateProtocol(String protocol) {
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null. Please use one of [" + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))+ "]");
}
if (!VALID_PROTOCOLS.contains(protocol)) {
throw new IllegalArgumentException("Unsupported connector protocol. Please use one of [" + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))+ "]");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
AwsConnector.awsConnectorBuilder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
AwsConnector.awsConnectorBuilder().name("test connector").protocol("http").version("1").actions(Arrays.asList(predictAction)).build();
}

@Test
Expand All @@ -99,7 +99,7 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

Expand All @@ -124,7 +124,7 @@ public void executePredict_RemoteInferenceInput() throws IOException {
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
}

Expand Down Expand Up @@ -104,7 +104,7 @@ public void processOutput_NoPostprocessFunction() throws IOException {
.build();
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
ModelTensors tensors = ConnectorUtils.processOutput("{\"response\": \"test response\"}", connector, scriptService, ImmutableMap.of());
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName());
Expand All @@ -126,7 +126,7 @@ public void processOutput_PostprocessFunction() throws IOException {
.build();
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of());
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Expand All @@ -153,7 +153,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request
.build();
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
Assert.assertNotNull(remoteInferenceInputDataSet.getParameters());
Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void invokeRemoteModel_WrongHttpMethod() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor.invokeRemoteModel(null, null, null, null);
}
Expand All @@ -79,7 +79,7 @@ public void executePredict_RemoteInferenceInput() throws IOException {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
Expand All @@ -103,7 +103,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Expand All @@ -129,7 +129,7 @@ public void executePredict_TextDocsInput() throws IOException {
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
Expand Down

0 comments on commit c30752a

Please sign in to comment.