Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix update connector API #1484

Merged
merged 2 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,7 @@
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;
Expand Down Expand Up @@ -80,6 +70,7 @@ public interface Connector extends ToXContentObject, Writeable {

void writeTo(StreamOutput out) throws IOException;

void update(MLCreateConnectorInput updateContent, Function<String, String> function);

<T> void parseResponse(T orElse, List<ModelTensor> modelTensors, boolean b) throws IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;

@Log4j2
@NoArgsConstructor
Expand Down Expand Up @@ -248,6 +249,38 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

@Override
public void update(MLCreateConnectorInput updateContent, Function<String, String> function) {
if (updateContent.getName() != null) {
this.name = updateContent.getName();
}
if (updateContent.getDescription() != null) {
this.description = updateContent.getDescription();
}
if (updateContent.getVersion() != null) {
this.version = updateContent.getVersion();
}
if (updateContent.getProtocol() != null) {
this.protocol = updateContent.getProtocol();
}
if (updateContent.getParameters() != null && updateContent.getParameters().size() > 0) {
this.parameters = updateContent.getParameters();
}
if (updateContent.getCredential() != null && updateContent.getCredential().size() > 0) {
this.credential = updateContent.getCredential();
encrypt(function);
}
if (updateContent.getActions() != null) {
this.actions = updateContent.getActions();
}
if (updateContent.getBackendRoles() != null) {
this.backendRoles = updateContent.getBackendRoles();
}
if (updateContent.getAccess() != null) {
this.access = updateContent.getAccess();
}
}

@Override
public <T> T createPredictPayload(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
private Boolean addAllBackendRoles;
private AccessMode access;
private boolean dryRun = false;
private boolean updateConnector = false;

@Builder(toBuilder = true)
public MLCreateConnectorInput(String name,
Expand All @@ -68,9 +69,10 @@ public MLCreateConnectorInput(String name,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode access,
boolean dryRun
boolean dryRun,
boolean updateConnector
) {
if (!dryRun) {
if (!dryRun && !updateConnector) {
if (name == null) {
throw new IllegalArgumentException("Connector name is null");
}
Expand All @@ -92,9 +94,14 @@ public MLCreateConnectorInput(String name,
this.addAllBackendRoles = addAllBackendRoles;
this.access = access;
this.dryRun = dryRun;
this.updateConnector = updateConnector;
}

public static MLCreateConnectorInput parse(XContentParser parser) throws IOException {
return parse(parser, false);
}

public static MLCreateConnectorInput parse(XContentParser parser, boolean updateConnector) throws IOException {
String name = null;
String description = null;
String version = null;
Expand Down Expand Up @@ -159,7 +166,7 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep
break;
}
}
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun);
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun, updateConnector);
}

@Override
Expand Down Expand Up @@ -201,10 +208,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

@Override
public void writeTo(StreamOutput output) throws IOException {
output.writeString(name);
output.writeOptionalString(name);
output.writeOptionalString(description);
output.writeString(version);
output.writeString(protocol);
output.writeOptionalString(version);
output.writeOptionalString(protocol);
if (parameters != null) {
output.writeBoolean(true);
output.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
Expand Down Expand Up @@ -240,13 +247,14 @@ public void writeTo(StreamOutput output) throws IOException {
output.writeBoolean(false);
}
output.writeBoolean(dryRun);
output.writeBoolean(updateConnector);
}

public MLCreateConnectorInput(StreamInput input) throws IOException {
name = input.readString();
name = input.readOptionalString();
description = input.readOptionalString();
version = input.readString();
protocol = input.readString();
version = input.readOptionalString();
protocol = input.readOptionalString();
if (input.readBoolean()) {
parameters = input.readMap(s -> s.readString(), s -> s.readString());
}
Expand All @@ -268,5 +276,6 @@ public MLCreateConnectorInput(StreamInput input) throws IOException {
this.access = input.readEnum(AccessMode.class);
}
dryRun = input.readBoolean();
updateConnector = input.readBoolean();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,31 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
public class MLUpdateConnectorRequest extends ActionRequest {
String connectorId;
Map<String, Object> updateContent;
MLCreateConnectorInput updateContent;

@Builder
public MLUpdateConnectorRequest(String connectorId, Map<String, Object> updateContent) {
public MLUpdateConnectorRequest(String connectorId, MLCreateConnectorInput updateContent) {
this.connectorId = connectorId;
this.updateContent = updateContent;
}

public MLUpdateConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readString();
this.updateContent = in.readMap();
this.updateContent = new MLCreateConnectorInput(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.connectorId);
out.writeMap(this.getUpdateContent());
this.updateContent.writeTo(out);
}

@Override
Expand All @@ -55,14 +54,17 @@ public ActionRequestValidationException validate() {
exception = addValidationError("ML connector id can't be null", exception);
}

if (updateContent == null) {
exception = addValidationError("Update connector content can't be null", exception);
}

return exception;
}

public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException {
Map<String, Object> dataAsMap = null;
dataAsMap = parser.map();
MLCreateConnectorInput updateContent = MLCreateConnectorInput.parse(parser, true);

return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build();
return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(updateContent).build();
}

public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,37 @@

import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.rest.RestRequest;
import org.opensearch.search.SearchModule;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.Collections;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.when;
import static org.junit.Assert.assertTrue;

public class MLUpdateConnectorRequestTests {
private String connectorId;
private Map<String, Object> updateContent;
private MLCreateConnectorInput updateContent;
private MLUpdateConnectorRequest mlUpdateConnectorRequest;

@Mock
XContentParser parser;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
this.connectorId = "test-connector_id";
this.updateContent = Map.of("description", "new description");
this.updateContent = MLCreateConnectorInput.builder().description("new description").updateConnector(true).build();
mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder()
.connectorId(connectorId)
.updateContent(updateContent)
Expand All @@ -64,18 +63,20 @@ public void validate_Exception_NullConnectorId() {
MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build();
Exception exception = updateConnectorRequest.validate();

assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage());
assertEquals("Validation Failed: 1: ML connector id can't be null;2: Update connector content can't be null;", exception.getMessage());
}

@Test
public void parse_success() throws IOException {
RestRequest.Method method = RestRequest.Method.POST;
final Map<String, Object> updatefields = Map.of("version", "new version", "description", "new description");
when(parser.map()).thenReturn(updatefields);

String jsonStr = "{\"version\":\"new version\",\"description\":\"new description\"}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId);
assertEquals(updateConnectorRequest.getConnectorId(), connectorId);
assertEquals(updateConnectorRequest.getUpdateContent(), updatefields);
assertTrue(updateConnectorRequest.getUpdateContent().isUpdateConnector());
assertEquals("new version", updateConnectorRequest.getUpdateContent().getVersion());
assertEquals("new description", updateConnectorRequest.getUpdateContent().getDescription());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.delete.DeleteRequest;
Expand Down Expand Up @@ -77,11 +81,16 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
.error(
searchHits.length + " models are still using this connector, please delete or update the models first!"
);
List<String> modelIds = new ArrayList<>();
for (SearchHit hit : searchHits) {
modelIds.add(hit.getId());
}
actionListener
.onFailure(
new MLValidationException(
searchHits.length
+ " models are still using this connector, please delete or update the models first!"
+ " models are still using this connector, please delete or update the models first: "
+ Arrays.toString(modelIds.toArray(new String[0]))
)
);
}
Expand Down
Loading
Loading