Skip to content

Commit

Permalink
fix cannot specify model access control parameters error (#1068)
Browse files Browse the repository at this point in the history
* fix cannot specify model access control parameters error

Signed-off-by: Yaliang Wu <[email protected]>

* add unit test for model group class and fix some bug

Signed-off-by: Yaliang Wu <[email protected]>

* fix failed ut

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored and zane-neo committed Sep 1, 2023
1 parent 7ba6c3f commit 67a8497
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 10 deletions.
23 changes: 18 additions & 5 deletions common/src/main/java/org/opensearch/ml/common/MLModelGroup.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
Expand Down Expand Up @@ -57,6 +58,9 @@ public MLModelGroup(String name, String description, int latestVersion,
String modelGroupId,
Instant createdTime,
Instant lastUpdatedTime) {
if (name == null) {
throw new IllegalArgumentException("model group name is null");
}
this.name = name;
this.description = description;
this.latestVersion = latestVersion;
Expand All @@ -73,7 +77,9 @@ public MLModelGroup(StreamInput input) throws IOException{
name = input.readString();
description = input.readOptionalString();
latestVersion = input.readInt();
backendRoles = input.readOptionalStringList();
if (input.readBoolean()) {
backendRoles = input.readStringList();
}
if (input.readBoolean()) {
this.owner = new User(input);
} else {
Expand All @@ -89,8 +95,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(name);
out.writeOptionalString(description);
out.writeInt(latestVersion);
out.writeStringCollection(backendRoles);
if (!CollectionUtils.isEmpty(backendRoles)) {
out.writeBoolean(true);
out.writeStringCollection(backendRoles);
} else {
out.writeBoolean(false);
}
if (owner != null) {
out.writeBoolean(true);
owner.writeTo(out);
} else {
out.writeBoolean(false);
Expand All @@ -109,7 +121,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
if (backendRoles != null) {
if (!CollectionUtils.isEmpty(backendRoles)) {
builder.field(BACKEND_ROLES_FIELD, backendRoles);
}
if (owner != null) {
Expand All @@ -134,8 +146,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
public static MLModelGroup parse(XContentParser parser) throws IOException {
String name = null;
String description = null;
List<String> backendRoles = new ArrayList<>();
Integer latestVersion = null;
List<String> backendRoles = null;
int latestVersion = 0;
User owner = null;
String access = null;
String modelGroupId = null;
Expand All @@ -155,6 +167,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException {
description = parser.text();
break;
case BACKEND_ROLES_FIELD:
backendRoles = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
backendRoles.add(parser.text());
Expand Down
132 changes: 132 additions & 0 deletions common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchModule;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;

public class MLModelGroupTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Test
public void toXContent_NullName() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("model group name is null");

MLModelGroup.builder().build();
}

@Test
public void toXContent_Empty() throws IOException {
MLModelGroup modelGroup = MLModelGroup.builder().name("test").build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert.assertEquals("{\"name\":\"test\",\"latest_version\":0}", content);
}

@Test
public void toXContent() throws IOException {
MLModelGroup modelGroup = MLModelGroup.builder()
.name("test")
.description("this is test group")
.latestVersion(1)
.backendRoles(Arrays.asList("role1", "role2"))
.owner(new User())
.access(AccessMode.PUBLIC.name())
.build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert.assertEquals("{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," +
"\"backend_roles\":[\"role1\",\"role2\"]," +
"\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," +
"\"access\":\"PUBLIC\"}", content);
}

@Test
public void parse() throws IOException {
String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," +
"\"backend_roles\":[\"role1\",\"role2\"]," +
"\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," +
"\"access\":\"PUBLIC\"}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
MLModelGroup modelGroup = MLModelGroup.parse(parser);
Assert.assertEquals("test", modelGroup.getName());
Assert.assertEquals("this is test group", modelGroup.getDescription());
Assert.assertEquals("PUBLIC", modelGroup.getAccess());
Assert.assertEquals(2, modelGroup.getBackendRoles().size());
Assert.assertEquals("role1", modelGroup.getBackendRoles().get(0));
Assert.assertEquals("role2", modelGroup.getBackendRoles().get(1));
}

@Test
public void parse_Empty() throws IOException {
String jsonStr = "{\"name\":\"test\"}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
MLModelGroup modelGroup = MLModelGroup.parse(parser);
Assert.assertEquals("test", modelGroup.getName());
Assert.assertNull(modelGroup.getBackendRoles());
Assert.assertNull(modelGroup.getAccess());
Assert.assertNull(modelGroup.getOwner());
}

@Test
public void writeTo() throws IOException {
MLModelGroup originalModelGroup = MLModelGroup.builder()
.name("test")
.description("this is test group")
.latestVersion(1)
.backendRoles(Arrays.asList("role1", "role2"))
.owner(new User())
.access(AccessMode.PUBLIC.name())
.build();

BytesStreamOutput output = new BytesStreamOutput();
originalModelGroup.writeTo(output);
MLModelGroup modelGroup = new MLModelGroup(output.bytes().streamInput());
Assert.assertEquals("test", modelGroup.getName());
Assert.assertEquals("this is test group", modelGroup.getDescription());
Assert.assertEquals("PUBLIC", modelGroup.getAccess());
Assert.assertEquals(2, modelGroup.getBackendRoles().size());
Assert.assertEquals("role1", modelGroup.getBackendRoles().get(0));
Assert.assertEquals("role2", modelGroup.getBackendRoles().get(1));
}

@Test
public void writeTo_Empty() throws IOException {
MLModelGroup originalModelGroup = MLModelGroup.builder().name("test").build();

BytesStreamOutput output = new BytesStreamOutput();
originalModelGroup.writeTo(output);
MLModelGroup modelGroup = new MLModelGroup(output.bytes().streamInput());
Assert.assertEquals("test", modelGroup.getName());
Assert.assertNull(modelGroup.getBackendRoles());
Assert.assertNull(modelGroup.getAccess());
Assert.assertNull(modelGroup.getOwner());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ public void validateUniqueModelGroupName(String name, ActionListener<Boolean> li
}

private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) {
if (input.getModelAccessMode() != null || input.getIsAddAllBackendRoles() != null || input.getBackendRoles() != null) {
if (input.getModelAccessMode() != null
|| input.getIsAddAllBackendRoles() != null
|| !CollectionUtils.isEmpty(input.getBackendRoles())) {
throw new IllegalArgumentException(
"You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster."
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,12 @@ public void test_IsOwner() {
public void test_IsUserHasBackendRole() {
User user = User.parse("owner|IT,HR|all_access");
MLModelGroupBuilder builder = MLModelGroup.builder();
assertTrue(modelAccessControlHelper.isUserHasBackendRole(null, builder.access(AccessMode.PUBLIC.getValue()).build()));
assertFalse(modelAccessControlHelper.isUserHasBackendRole(null, builder.access(AccessMode.PRIVATE.getValue()).build()));
assertTrue(
modelAccessControlHelper.isUserHasBackendRole(null, builder.name("test_group").access(AccessMode.PUBLIC.getValue()).build())
);
assertFalse(
modelAccessControlHelper.isUserHasBackendRole(null, builder.name("test_group").access(AccessMode.PRIVATE.getValue()).build())
);
assertTrue(
modelAccessControlHelper
.isUserHasBackendRole(
Expand All @@ -218,9 +222,13 @@ public void test_IsOwnerStillHasPermission() {
User userLostAccess = User.parse("owner|Finance|myTenant");
assertTrue(modelAccessControlHelper.isOwnerStillHasPermission(null, null));
MLModelGroupBuilder builder = MLModelGroup.builder();
assertTrue(modelAccessControlHelper.isOwnerStillHasPermission(user, builder.access(AccessMode.PUBLIC.getValue()).build()));
assertTrue(
modelAccessControlHelper.isOwnerStillHasPermission(user, builder.access(AccessMode.PRIVATE.getValue()).owner(owner).build())
modelAccessControlHelper
.isOwnerStillHasPermission(user, builder.name("test_group").access(AccessMode.PUBLIC.getValue()).build())
);
assertTrue(
modelAccessControlHelper
.isOwnerStillHasPermission(user, builder.name("test_group").access(AccessMode.PRIVATE.getValue()).owner(owner).build())
);
assertFalse(
modelAccessControlHelper
Expand Down

0 comments on commit 67a8497

Please sign in to comment.