Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.11] if model version fails to register, update model group accordingly #1464

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String ACCESS_MODE_FIELD = "access_mode";
public static final String BACKEND_ROLES_FIELD = "backend_roles";
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";

public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";
private FunctionName functionName;
private String modelName;
private String modelGroupId;
Expand All @@ -73,6 +73,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
private List<String> backendRoles;
private Boolean addAllBackendRoles;
private AccessMode accessMode;
private Boolean doesVersionCreateModelGroup;

@Builder(toBuilder = true)
public MLRegisterModelInput(FunctionName functionName,
Expand All @@ -90,7 +91,8 @@ public MLRegisterModelInput(FunctionName functionName,
String connectorId,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode accessMode
AccessMode accessMode,
Boolean doesVersionCreateModelGroup
) {
if (functionName == null) {
this.functionName = FunctionName.TEXT_EMBEDDING;
Expand Down Expand Up @@ -123,6 +125,7 @@ public MLRegisterModelInput(FunctionName functionName,
this.backendRoles = backendRoles;
this.addAllBackendRoles = addAllBackendRoles;
this.accessMode = accessMode;
this.doesVersionCreateModelGroup = doesVersionCreateModelGroup;
}


Expand Down Expand Up @@ -157,6 +160,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.accessMode = in.readEnum(AccessMode.class);
}
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}

@Override
Expand Down Expand Up @@ -202,6 +206,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}

@Override
Expand Down Expand Up @@ -249,6 +254,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (accessMode != null) {
builder.field(ACCESS_MODE_FIELD, accessMode);
}
if (doesVersionCreateModelGroup != null) {
builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup);
}
builder.endObject();
return builder;
}
Expand All @@ -267,6 +275,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
List<String> backendRoles = new ArrayList<>();
Boolean addAllBackendRoles = null;
AccessMode accessMode = null;
Boolean doesVersionCreateModelGroup = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -318,12 +327,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
break;
case DOES_VERSION_CREATE_MODEL_GROUP:
doesVersionCreateModelGroup = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode);
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup);
}

public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException {
Expand All @@ -342,6 +354,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
List<String> backendRoles = new ArrayList<>();
AccessMode accessMode = null;
Boolean addAllBackendRoles = null;
Boolean doesVersionCreateModelGroup = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -400,11 +413,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
break;
case DOES_VERSION_CREATE_MODEL_GROUP:
doesVersionCreateModelGroup = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode);
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
public static final String ACCESS_MODE = "access_mode"; //optional
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";


private FunctionName functionName;
private String name;
Expand All @@ -65,11 +67,13 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
private List<String> backendRoles;
private AccessMode accessMode;
private Boolean isAddAllBackendRoles;
private Boolean doesVersionCreateModelGroup;

@Builder(toBuilder = true)
public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List<String> backendRoles,
AccessMode accessMode,
Boolean isAddAllBackendRoles) {
Boolean isAddAllBackendRoles,
Boolean doesVersionCreateModelGroup) {
if (name == null) {
throw new IllegalArgumentException("model name is null");
}
Expand Down Expand Up @@ -103,6 +107,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
this.backendRoles = backendRoles;
this.accessMode = accessMode;
this.isAddAllBackendRoles = isAddAllBackendRoles;
this.doesVersionCreateModelGroup = doesVersionCreateModelGroup;
}

public MLRegisterModelMetaInput(StreamInput in) throws IOException{
Expand All @@ -128,6 +133,7 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{
accessMode = in.readEnum(AccessMode.class);
}
this.isAddAllBackendRoles = in.readOptionalBoolean();
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}

@Override
Expand Down Expand Up @@ -171,6 +177,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isAddAllBackendRoles);
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}

@Override
Expand Down Expand Up @@ -206,6 +213,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (isAddAllBackendRoles != null) {
builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles);
}
if (doesVersionCreateModelGroup != null) {
builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup);
}
builder.endObject();
return builder;
}
Expand All @@ -225,6 +235,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
List<String> backendRoles = null;
AccessMode accessMode = null;
Boolean isAddAllBackendRoles = null;
Boolean doesVersionCreateModelGroup = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -277,12 +288,15 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = parser.booleanValue();
break;
case DOES_VERSION_CREATE_MODEL_GROUP:
doesVersionCreateModelGroup = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles);
return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void setup() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null);
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public void setUp() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null);
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,14 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(registerModelInput);
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
registerModelInput.setModelGroupId(modelGroupId);
registerModelInput.setDoesVersionCreateModelGroup(true);
registerModel(registerModelInput, listener);
}, e -> {
logException("Failed to create Model Group", e, log);
listener.onFailure(e);
}));
} else {
registerModelInput.setDoesVersionCreateModelGroup(false);
registerModel(registerModelInput, listener);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,14 @@ private void createModelGroup(MLRegisterModelMetaInput mlUploadInput, ActionList
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput);
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
mlUploadInput.setModelGroupId(modelGroupId);
mlUploadInput.setDoesVersionCreateModelGroup(true);
registerModelMeta(mlUploadInput, listener);
}, e -> {
logException("Failed to create Model Group", e, log);
listener.onFailure(e);
}));
} else {
mlUploadInput.setDoesVersionCreateModelGroup(false);
registerModelMeta(mlUploadInput, listener);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,13 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us
AccessMode modelAccessMode = input.getModelAccessMode();
Boolean isAddAllBackendRoles = input.getIsAddAllBackendRoles();
if (modelAccessMode == null) {
if (modelAccessMode == null) {
if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) {
throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time.");
} else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) {
input.setModelAccessMode(AccessMode.RESTRICTED);
modelAccessMode = AccessMode.RESTRICTED;
} else {
input.setModelAccessMode(AccessMode.PRIVATE);
}
if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) {
throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time.");
} else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) {
input.setModelAccessMode(AccessMode.RESTRICTED);
modelAccessMode = AccessMode.RESTRICTED;
} else {
input.setModelAccessMode(AccessMode.PRIVATE);
}
}
if ((AccessMode.PUBLIC == modelAccessMode || AccessMode.PRIVATE == modelAccessMode)
Expand Down Expand Up @@ -184,20 +182,29 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us
}

public void validateUniqueModelGroupName(String name, ActionListener<SearchResponse> listener) throws IllegalArgumentException {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder);

client.search(searchRequest, ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> {
if (e instanceof IndexNotFoundException) {
listener.onResponse(null);
} else {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}
}));
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder);

client
.search(
searchRequest,
ActionListener.runBefore(ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> {
if (e instanceof IndexNotFoundException) {
listener.onResponse(null);
} else {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}
}), () -> context.restore())
);
} catch (Exception e) {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}
}

private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) {
Expand Down
Loading
Loading