Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.9] remote inference: add unit test for create connector request/response #1082

Merged
merged 1 commit into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";
public static final String OWNER_FIELD = "owner";
public static final String ACCESS_MODE_FIELD = "access_mode";
public static final String DRY_RUN_FIELD = "dry_run";

public static final String DRY_RUN_CONNECTOR_NAME = "dryRunConnector";

Expand All @@ -52,6 +53,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
private List<String> backendRoles;
private Boolean addAllBackendRoles;
private AccessMode access;
private boolean dryRun = false;

@Builder(toBuilder = true)
public MLCreateConnectorInput(String name,
Expand All @@ -63,8 +65,20 @@ public MLCreateConnectorInput(String name,
List<ConnectorAction> actions,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode access
AccessMode access,
boolean dryRun
) {
if (!dryRun) {
if (name == null) {
throw new IllegalArgumentException("Connector name is null");
}
if (version == null) {
throw new IllegalArgumentException("Connector version is null");
}
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null");
}
}
this.name = name;
this.description = description;
this.version = version;
Expand All @@ -75,6 +89,7 @@ public MLCreateConnectorInput(String name,
this.backendRoles = backendRoles;
this.addAllBackendRoles = addAllBackendRoles;
this.access = access;
this.dryRun = dryRun;
}

public static MLCreateConnectorInput parse(XContentParser parser) throws IOException {
Expand All @@ -88,6 +103,7 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep
List<String> backendRoles = null;
Boolean addAllBackendRoles = null;
AccessMode access = null;
boolean dryRun = false;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -133,12 +149,15 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep
case ACCESS_MODE_FIELD:
access = AccessMode.from(parser.text());
break;
case DRY_RUN_FIELD:
dryRun = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access);
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun);
}

@Override
Expand Down Expand Up @@ -181,7 +200,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
@Override
public void writeTo(StreamOutput output) throws IOException {
output.writeString(name);
output.writeString(description);
output.writeOptionalString(description);
output.writeString(version);
output.writeString(protocol);
if (parameters != null) {
Expand Down Expand Up @@ -211,20 +230,19 @@ public void writeTo(StreamOutput output) throws IOException {
} else {
output.writeBoolean(false);
}
if (addAllBackendRoles != null) {
output.writeBoolean(addAllBackendRoles);
}
output.writeOptionalBoolean(addAllBackendRoles);
if (access != null) {
output.writeBoolean(true);
output.writeEnum(access);
} else {
output.writeBoolean(false);
}
output.writeBoolean(dryRun);
}

public MLCreateConnectorInput(StreamInput input) throws IOException {
name = input.readString();
description = input.readString();
description = input.readOptionalString();
version = input.readString();
protocol = input.readString();
if (input.readBoolean()) {
Expand All @@ -247,5 +265,6 @@ public MLCreateConnectorInput(StreamInput input) throws IOException {
if (input.readBoolean()) {
this.access = input.readEnum(AccessMode.class);
}
dryRun = input.readBoolean();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;

import java.io.IOException;
import java.io.UncheckedIOException;

public class MLCreateConnectorRequestTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Test
public void validate_nullInput() {
MLCreateConnectorRequest request = new MLCreateConnectorRequest((MLCreateConnectorInput)null);
ActionRequestValidationException exception = request.validate();
Assert.assertTrue(exception.getMessage().contains("ML Connector input can't be null"));
}

@Test
public void readFromStream() throws IOException {
MLCreateConnectorInput input = MLCreateConnectorInput.builder()
.name("test_connector")
.protocol("http")
.version("1")
.description("test")
.build();
MLCreateConnectorRequest request = new MLCreateConnectorRequest(input);
BytesStreamOutput output = new BytesStreamOutput();
request.writeTo(output);
MLCreateConnectorRequest request2 = new MLCreateConnectorRequest(output.bytes().streamInput());
Assert.assertEquals("test_connector", request2.getMlCreateConnectorInput().getName());
Assert.assertEquals("http", request2.getMlCreateConnectorInput().getProtocol());
Assert.assertEquals("1", request2.getMlCreateConnectorInput().getVersion());
Assert.assertEquals("test", request2.getMlCreateConnectorInput().getDescription());
}

@Test
public void fromActionRequest() {
MLCreateConnectorInput input = MLCreateConnectorInput.builder()
.name("test_connector")
.protocol("http")
.version("1")
.description("test")
.build();
ActionRequest request = new MLCreateConnectorRequest(input);
MLCreateConnectorRequest request2 = MLCreateConnectorRequest.fromActionRequest(request);
Assert.assertEquals("test_connector", request2.getMlCreateConnectorInput().getName());
Assert.assertEquals("http", request2.getMlCreateConnectorInput().getProtocol());
Assert.assertEquals("1", request2.getMlCreateConnectorInput().getVersion());
Assert.assertEquals("test", request2.getMlCreateConnectorInput().getDescription());
}

@Test
public void fromActionRequest_Exception() {
exceptionRule.expect(UncheckedIOException.class);
exceptionRule.expectMessage("Failed to parse ActionRequest into MLCreateConnectorRequest");
ActionRequest request = new MLConnectorGetRequest("test_id", true);
MLCreateConnectorRequest.fromActionRequest(request);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.junit.Assert;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;

public class MLCreateConnectorResponseTest {

@Test
public void toXContent() throws IOException {
MLCreateConnectorResponse response = new MLCreateConnectorResponse("test_id");
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert.assertEquals("{\"connector_id\":\"test_id\"}", content);
}

@Test
public void readFromStream() throws IOException {
MLCreateConnectorResponse response = new MLCreateConnectorResponse("test_id");
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);

MLCreateConnectorResponse response2 = new MLCreateConnectorResponse(output.bytes().streamInput());
Assert.assertEquals("test_id", response2.getConnectorId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public TransportCreateConnectorAction(
protected void doExecute(Task task, ActionRequest request, ActionListener<MLCreateConnectorResponse> listener) {
MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.fromActionRequest(request);
MLCreateConnectorInput mlCreateConnectorInput = mlCreateConnectorRequest.getMlCreateConnectorInput();
if (MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME.equals(mlCreateConnectorInput.getName())) {
if (mlCreateConnectorInput.isDryRun()) {
MLCreateConnectorResponse response = new MLCreateConnectorResponse(MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME);
listener.onResponse(response);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<
log.error(e.getMessage(), e);
listener.onFailure(e);
});
MLCreateConnectorRequest mlCreateConnectorRequest = createConnectorRequest();
MLCreateConnectorRequest mlCreateConnectorRequest = createDryRunConnectorRequest();
client.execute(MLCreateConnectorAction.INSTANCE, mlCreateConnectorRequest, dryRunResultListener);
}
} else {
Expand All @@ -207,8 +207,8 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis
}
}

private MLCreateConnectorRequest createConnectorRequest() {
MLCreateConnectorInput createConnectorInput = MLCreateConnectorInput.builder().name("dryRunConnector").build();
private MLCreateConnectorRequest createDryRunConnectorRequest() {
MLCreateConnectorInput createConnectorInput = MLCreateConnectorInput.builder().dryRun(true).build();
return new MLCreateConnectorRequest(createConnectorInput);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ public void setup() {
Map<String, String> credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret");
input = MLCreateConnectorInput
.builder()
.name("test_name")
.version("1")
.actions(actions)
.parameters(parameters)
.protocol(ConnectorProtocols.HTTP)
Expand Down Expand Up @@ -430,6 +432,7 @@ public void test_execute_dryRun_connector_creation() {

MLCreateConnectorInput mlCreateConnectorInput = mock(MLCreateConnectorInput.class);
when(mlCreateConnectorInput.getName()).thenReturn(MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME);
when(mlCreateConnectorInput.isDryRun()).thenReturn(true);
MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput);
action.doExecute(task, request, actionListener);
verify(actionListener).onResponse(any(MLCreateConnectorResponse.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() {
MLRegisterModelInput input = mock(MLRegisterModelInput.class);
when(request.getRegisterModelInput()).thenReturn(input);
when(input.getModelName()).thenReturn("Test Model");
when(input.getVersion()).thenReturn("1");
when(input.getModelGroupId()).thenReturn("modelGroupID");
when(input.getFunctionName()).thenReturn(FunctionName.REMOTE);
Connector connector = mock(Connector.class);
Expand Down