Skip to content

Commit

Permalink
Update Model API (#1350)
Browse files Browse the repository at this point in the history
* Update Model API POC

Signed-off-by: Sicheng Song <[email protected]>

* Using GetRequest to get model

Signed-off-by: Sicheng Song <[email protected]>

* Finalize model update API

Signed-off-by: Sicheng Song <[email protected]>

* Fix compile

Signed-off-by: Sicheng Song <[email protected]>

* Fix compileTest

Signed-off-by: Sicheng Song <[email protected]>

* Add Unit Test Cases for Update Model API

Signed-off-by: Sicheng Song <[email protected]>

* Tune back test coverage thereshold

Signed-off-by: Sicheng Song <[email protected]>

* Add more unit tests on Update model API

Signed-off-by: Sicheng Song <[email protected]>

* Add unit test for TransportUpdateModelAction class

Signed-off-by: Sicheng Song <[email protected]>

* Fix a test error

Signed-off-by: Sicheng Song <[email protected]>

* Change exception thrown to failure response

Signed-off-by: Sicheng Song <[email protected]>

* Move the function judgement to the outter block

Signed-off-by: Sicheng Song <[email protected]>

* Check if model is undeployed before update model

Signed-off-by: Sicheng Song <[email protected]>

* Add more unit test for update model API

Signed-off-by: Sicheng Song <[email protected]>

* Fix unit test due to blocking java 11 CI workflow

Signed-off-by: Sicheng Song <[email protected]>

* Enabling auto bumping model version during registering to a new model group and address reviewers' other concern

Signed-off-by: Sicheng Song <[email protected]>

* Autobump new model groups' latest version when register to a new model

Signed-off-by: Sicheng Song <[email protected]>

* Change the REST API method from POST to PUT

Signed-off-by: Sicheng Song <[email protected]>

* Change the update REST API endpoint

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
b4sjoo authored and ylwu-amzn committed Nov 20, 2023
1 parent d7ae636 commit 62d3e74
Show file tree
Hide file tree
Showing 14 changed files with 2,007 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class MLModelGroup implements ToXContentObject {
@Setter
private String name;
private String description;
@Setter
private int latestVersion;
private List<String> backendRoles;
private User owner;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import org.opensearch.action.ActionType;
import org.opensearch.action.update.UpdateResponse;

public class MLUpdateModelAction extends ActionType<UpdateResponse> {
public static MLUpdateModelAction INSTANCE = new MLUpdateModelAction();
public static final String NAME = "cluster:admin/opensearch/ml/models/update";

private MLUpdateModelAction() {
super(NAME, UpdateResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import lombok.Data;
import lombok.Builder;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

import java.io.IOException;
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.connector.Connector.createConnector;

@Data
public class MLUpdateModelInput implements ToXContentObject, Writeable {

public static final String MODEL_ID_FIELD = "model_id"; // mandatory
public static final String DESCRIPTION_FIELD = "description"; // optional
public static final String MODEL_VERSION_FIELD = "model_version"; // optional
public static final String MODEL_NAME_FIELD = "name"; // optional
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional
public static final String MODEL_CONFIG_FIELD = "model_config"; // optional
public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional

@Getter
private String modelId;
private String description;
private String version;
private String name;
private String modelGroupId;
private MLModelConfig modelConfig;
private String connectorId;

@Builder(toBuilder = true)
public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, MLModelConfig modelConfig, String connectorId) {
this.modelId = modelId;
this.description = description;
this.version = version;
this.name = name;
this.modelGroupId = modelGroupId;
this.modelConfig = modelConfig;
this.connectorId = connectorId;
}

public MLUpdateModelInput(StreamInput in) throws IOException {
this.modelId = in.readString();
this.description = in.readOptionalString();
this.version = in.readOptionalString();
this.name = in.readOptionalString();
this.modelGroupId = in.readOptionalString();
if (in.readBoolean()) {
modelConfig = new TextEmbeddingModelConfig(in);
}
this.connectorId = in.readOptionalString();
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID_FIELD, modelId);
if (name != null) {
builder.field(MODEL_NAME_FIELD, name);
}
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
if (version != null) {
builder.field(MODEL_VERSION_FIELD, version);
}
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (modelConfig != null) {
builder.field(MODEL_CONFIG_FIELD, modelConfig);
}
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeOptionalString(description);
out.writeOptionalString(version);
out.writeOptionalString(name);
out.writeOptionalString(modelGroupId);
if (modelConfig != null) {
out.writeBoolean(true);
modelConfig.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(connectorId);
}

public static MLUpdateModelInput parse(XContentParser parser) throws IOException {
String modelId = null;
String description = null;
String version = null;
String name = null;
String modelGroupId = null;
MLModelConfig modelConfig = null;
String connectorId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
switch (fieldName) {
case MODEL_ID_FIELD:
modelId = parser.text();
break;
case DESCRIPTION_FIELD:
description = parser.text();
break;
case MODEL_NAME_FIELD:
name = parser.text();
break;
case MODEL_VERSION_FIELD:
version = parser.text();
break;
case MODEL_GROUP_ID_FIELD:
modelGroupId = parser.text();
break;
case MODEL_CONFIG_FIELD:
modelConfig = TextEmbeddingModelConfig.parse(parser);
break;
case CONNECTOR_ID_FIELD:
connectorId = parser.text();
break;
default:
parser.skipChildren();
break;
}
}
// Model ID can only be set through RestRequest. Model version can only be set automatically.
return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connectorId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLUpdateModelRequest extends ActionRequest {

MLUpdateModelInput updateModelInput;

@Builder
public MLUpdateModelRequest(MLUpdateModelInput updateModelInput) {
this.updateModelInput = updateModelInput;
}

public MLUpdateModelRequest(StreamInput in) throws IOException {
super(in);
updateModelInput = new MLUpdateModelInput(in);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (updateModelInput == null) {
exception = addValidationError("Update Model Input can't be null", exception);
}

return exception;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
this.updateModelInput.writeTo(out);
}

public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest){
if (actionRequest instanceof MLUpdateModelRequest) {
return (MLUpdateModelRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput in = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLUpdateModelRequest(in);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionRequest into MLUpdateModelRequest", e);
}
}
}
Loading

0 comments on commit 62d3e74

Please sign in to comment.