Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create model group automatically with first model version #1063

Merged
merged 5 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jacocoTestCoverageVerification {
}
limit {
counter = 'BRANCH'
minimum = 0.6 //TODO: add more test to meet the coverage bar 0.9
minimum = 0.5 //TODO: add more test to meet the coverage bar 0.9
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class MLRegisterModelGroupInput implements ToXContentObject, Writeable{
public static final String NAME_FIELD = "name"; //mandatory
public static final String DESCRIPTION_FIELD = "description"; //optional
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
public static final String MODEL_ACCESS_MODE = "model_access_mode"; //optional
public static final String MODEL_ACCESS_MODE = "access_mode"; //optional
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional

private String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class MLUpdateModelGroupInput implements ToXContentObject, Writeable {
public static final String NAME_FIELD = "name"; //optional
public static final String DESCRIPTION_FIELD = "description"; //optional
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
public static final String MODEL_ACCESS_MODE = "model_access_mode"; //optional
public static final String MODEL_ACCESS_MODE = "access_mode"; //optional
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; //optional


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,6 @@ public MLRegisterModelInput(FunctionName functionName,
if (modelName == null) {
throw new IllegalArgumentException("model name is null");
}
if (modelGroupId == null) {
throw new IllegalArgumentException("model group id is null");
}
if (functionName != FunctionName.REMOTE) {
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
Expand Down Expand Up @@ -131,7 +128,7 @@ public MLRegisterModelInput(FunctionName functionName,
public MLRegisterModelInput(StreamInput in) throws IOException {
this.functionName = in.readEnum(FunctionName.class);
this.modelName = in.readString();
this.modelGroupId = in.readString();
this.modelGroupId = in.readOptionalString();
this.version = in.readOptionalString();
this.description = in.readOptionalString();
this.url = in.readOptionalString();
Expand Down Expand Up @@ -161,7 +158,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(functionName);
out.writeString(modelName);
out.writeString(modelGroupId);
out.writeOptionalString(modelGroupId);
out.writeOptionalString(version);
out.writeOptionalString(description);
out.writeOptionalString(url);
Expand Down Expand Up @@ -207,8 +204,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.field(FUNCTION_NAME_FIELD, functionName);
builder.field(NAME_FIELD, modelName);
builder.field(VERSION_FIELD, version);
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
if (version != null) {
builder.field(VERSION_FIELD, version);
}
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

Expand All @@ -29,20 +33,26 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{

public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_NAME_FIELD = "name"; //mandatory
public static final String DESCRIPTION_FIELD = "description";
public static final String DESCRIPTION_FIELD = "description"; //optional

public static final String VERSION_FIELD = "version";
public static final String MODEL_FORMAT_FIELD = "model_format"; //mandatory
public static final String MODEL_STATE_FIELD = "model_state";
public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes";
public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value"; //mandatory
public static final String MODEL_CONFIG_FIELD = "model_config"; //mandatory
public static final String TOTAL_CHUNKS_FIELD = "total_chunks"; //mandatory
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //mandatory
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //optional
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see all access control fields are optional now. Per my understanding, if user doesn't provide any of these, the model group will be private, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. It will be created as private

public static final String ACCESS_MODE = "access_mode"; //optional
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional

private FunctionName functionName;
private String name;

private String modelGroupId;
private String description;
private String version;

private MLModelFormat modelFormat;

Expand All @@ -52,9 +62,14 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
private String modelContentHashValue;
private MLModelConfig modelConfig;
private Integer totalChunks;
private List<String> backendRoles;
private AccessMode accessMode;
private Boolean isAddAllBackendRoles;

@Builder(toBuilder = true)
public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks) {
public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List<String> backendRoles,
AccessMode accessMode,
Boolean isAddAllBackendRoles) {
if (name == null) {
throw new IllegalArgumentException("model name is null");
}
Expand All @@ -63,9 +78,6 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
} else {
this.functionName = functionName;
}
if (modelGroupId == null) {
throw new IllegalArgumentException("model group id is null");
}
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
}
Expand All @@ -80,19 +92,24 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
}
this.name = name;
this.modelGroupId = modelGroupId;
this.version = version;
this.description = description;
this.modelFormat = modelFormat;
this.modelState = modelState;
this.modelContentSizeInBytes = modelContentSizeInBytes;
this.modelContentHashValue = modelContentHashValue;
this.modelConfig = modelConfig;
this.totalChunks = totalChunks;
this.backendRoles = backendRoles;
this.accessMode = accessMode;
this.isAddAllBackendRoles = isAddAllBackendRoles;
}

public MLRegisterModelMetaInput(StreamInput in) throws IOException{
this.name = in.readString();
this.functionName = in.readEnum(FunctionName.class);
this.modelGroupId = in.readString();
this.modelGroupId = in.readOptionalString();
this.version = in.readOptionalString();
this.description = in.readOptionalString();
if (in.readBoolean()) {
modelFormat = in.readEnum(MLModelFormat.class);
Expand All @@ -106,13 +123,19 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{
modelConfig = new TextEmbeddingModelConfig(in);
}
this.totalChunks = in.readInt();
this.backendRoles = in.readOptionalStringList();
if (in.readBoolean()) {
accessMode = in.readEnum(AccessMode.class);
}
this.isAddAllBackendRoles = in.readOptionalBoolean();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(name);
out.writeEnum(functionName);
out.writeString(modelGroupId);
out.writeOptionalString(modelGroupId);
out.writeOptionalString(version);
out.writeOptionalString(description);
if (modelFormat != null) {
out.writeBoolean(true);
Expand All @@ -135,14 +158,32 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeInt(totalChunks);
if (backendRoles != null) {
out.writeBoolean(true);
out.writeStringCollection(backendRoles);
} else {
out.writeBoolean(false);
}
if (accessMode != null) {
out.writeBoolean(true);
out.writeEnum(accessMode);
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isAddAllBackendRoles);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(MODEL_NAME_FIELD, name);
builder.field(FUNCTION_NAME_FIELD, functionName);
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (version != null) {
builder.field(VERSION_FIELD, version);
}
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
Expand All @@ -156,13 +197,23 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
builder.field(MODEL_CONTENT_HASH_VALUE_FIELD, modelContentHashValue);
builder.field(MODEL_CONFIG_FIELD, modelConfig);
builder.field(TOTAL_CHUNKS_FIELD, totalChunks);
if (backendRoles != null && backendRoles.size() > 0) {
builder.field(BACKEND_ROLES_FIELD, backendRoles);
}
if (accessMode != null) {
builder.field(ACCESS_MODE, accessMode);
}
if (isAddAllBackendRoles != null) {
builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles);
}
builder.endObject();
return builder;
}

public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOException {
String name = null;
FunctionName functionName = null;
String modelGroupId = null;
String version = null;
String description = null;
MLModelFormat modelFormat = null;
Expand All @@ -171,6 +222,9 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
String modelContentHashValue = null;
MLModelConfig modelConfig = null;
Integer totalChunks = null;
List<String> backendRoles = null;
AccessMode accessMode = null;
Boolean isAddAllBackendRoles = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -184,6 +238,9 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
functionName = FunctionName.from(parser.text());
break;
case MODEL_GROUP_ID_FIELD:
modelGroupId = parser.text();
break;
case VERSION_FIELD:
version = parser.text();
break;
case DESCRIPTION_FIELD:
Expand All @@ -207,12 +264,25 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
case TOTAL_CHUNKS_FIELD:
totalChunks = parser.intValue(false);
break;
case BACKEND_ROLES_FIELD:
backendRoles = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
backendRoles.add(parser.text());
}
break;
case ACCESS_MODE:
accessMode = AccessMode.from(parser.text().toLowerCase(Locale.ROOT));
break;
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelMetaInput(name, functionName, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks);
return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,6 @@ public void constructor_NullModelName() {
.build();
}

@Test
public void constructor_NullModelGroupId() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("model group id is null");
MLRegisterModelInput.builder()
.functionName(functionName)
.modelName(modelName)
.modelGroupId(null)
.build();
}

@Test
public void constructor_NullModelFormat() {
exceptionRule.expect(IllegalArgumentException.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ public class MLRegisterModelMetaInputTest {
public void setup() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2);
mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null);
}

@Test
Expand Down Expand Up @@ -75,14 +75,14 @@ public void testToXContent() throws IOException {{
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
assertEquals(expected, mlModelContent);
}
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
assertEquals(expected, mlModelContent);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ public class MLRegisterModelMetaRequestTest {
public void setUp() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2);
mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null);
}

@Test
Expand Down
Loading