diff --git a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java index 6fad6433ff..3a060910e0 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java @@ -10,6 +10,7 @@ import lombok.Setter; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.util.CollectionUtils; import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; @@ -57,6 +58,9 @@ public MLModelGroup(String name, String description, int latestVersion, String modelGroupId, Instant createdTime, Instant lastUpdatedTime) { + if (name == null) { + throw new IllegalArgumentException("model group name is null"); + } this.name = name; this.description = description; this.latestVersion = latestVersion; @@ -73,7 +77,9 @@ public MLModelGroup(StreamInput input) throws IOException{ name = input.readString(); description = input.readOptionalString(); latestVersion = input.readInt(); - backendRoles = input.readOptionalStringList(); + if (input.readBoolean()) { + backendRoles = input.readStringList(); + } if (input.readBoolean()) { this.owner = new User(input); } else { @@ -89,8 +95,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(name); out.writeOptionalString(description); out.writeInt(latestVersion); - out.writeStringCollection(backendRoles); + if (!CollectionUtils.isEmpty(backendRoles)) { + out.writeBoolean(true); + out.writeStringCollection(backendRoles); + } else { + out.writeBoolean(false); + } if (owner != null) { + out.writeBoolean(true); owner.writeTo(out); } else { out.writeBoolean(false); @@ -109,7 +121,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (description != null) { builder.field(DESCRIPTION_FIELD, description); } - if (backendRoles != null) { + if (!CollectionUtils.isEmpty(backendRoles)) { builder.field(BACKEND_ROLES_FIELD, backendRoles); } if (owner != null) { @@ -134,8 +146,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static MLModelGroup parse(XContentParser parser) throws IOException { String name = null; String description = null; - List backendRoles = new ArrayList<>(); - Integer latestVersion = null; + List backendRoles = null; + int latestVersion = 0; User owner = null; String access = null; String modelGroupId = null; @@ -155,6 +167,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { description = parser.text(); break; case BACKEND_ROLES_FIELD: + backendRoles = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { backendRoles.add(parser.text()); diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java new file mode 100644 index 0000000000..da8048b1cc --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + +public class MLModelGroupTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void toXContent_NullName() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("model group name is null"); + + MLModelGroup.builder().build(); + } + + @Test + public void toXContent_Empty() throws IOException { + MLModelGroup modelGroup = MLModelGroup.builder().name("test").build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + Assert.assertEquals("{\"name\":\"test\",\"latest_version\":0}", content); + } + + @Test + public void toXContent() throws IOException { + MLModelGroup modelGroup = MLModelGroup.builder() + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + Assert.assertEquals("{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"access\":\"PUBLIC\"}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"access\":\"PUBLIC\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + MLModelGroup modelGroup = MLModelGroup.parse(parser); + Assert.assertEquals("test", modelGroup.getName()); + Assert.assertEquals("this is test group", modelGroup.getDescription()); + Assert.assertEquals("PUBLIC", modelGroup.getAccess()); + Assert.assertEquals(2, modelGroup.getBackendRoles().size()); + Assert.assertEquals("role1", modelGroup.getBackendRoles().get(0)); + Assert.assertEquals("role2", modelGroup.getBackendRoles().get(1)); + } + + @Test + public void parse_Empty() throws IOException { + String jsonStr = "{\"name\":\"test\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + MLModelGroup modelGroup = MLModelGroup.parse(parser); + Assert.assertEquals("test", modelGroup.getName()); + Assert.assertNull(modelGroup.getBackendRoles()); + Assert.assertNull(modelGroup.getAccess()); + Assert.assertNull(modelGroup.getOwner()); + } + + @Test + public void writeTo() throws IOException { + MLModelGroup originalModelGroup = MLModelGroup.builder() + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + + BytesStreamOutput output = new BytesStreamOutput(); + originalModelGroup.writeTo(output); + MLModelGroup modelGroup = new MLModelGroup(output.bytes().streamInput()); + Assert.assertEquals("test", modelGroup.getName()); + Assert.assertEquals("this is test group", modelGroup.getDescription()); + Assert.assertEquals("PUBLIC", modelGroup.getAccess()); + Assert.assertEquals(2, modelGroup.getBackendRoles().size()); + Assert.assertEquals("role1", modelGroup.getBackendRoles().get(0)); + Assert.assertEquals("role2", modelGroup.getBackendRoles().get(1)); + } + + @Test + public void writeTo_Empty() throws IOException { + MLModelGroup originalModelGroup = MLModelGroup.builder().name("test").build(); + + BytesStreamOutput output = new BytesStreamOutput(); + originalModelGroup.writeTo(output); + MLModelGroup modelGroup = new MLModelGroup(output.bytes().streamInput()); + Assert.assertEquals("test", modelGroup.getName()); + Assert.assertNull(modelGroup.getBackendRoles()); + Assert.assertNull(modelGroup.getAccess()); + Assert.assertNull(modelGroup.getOwner()); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index eeb00c1fd0..a55ef27ead 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -193,7 +193,9 @@ public void validateUniqueModelGroupName(String name, ActionListener li } private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) { - if (input.getModelAccessMode() != null || input.getIsAddAllBackendRoles() != null || input.getBackendRoles() != null) { + if (input.getModelAccessMode() != null + || input.getIsAddAllBackendRoles() != null + || !CollectionUtils.isEmpty(input.getBackendRoles())) { throw new IllegalArgumentException( "You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster." ); diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java index 17b8725620..bf73addf04 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java @@ -199,8 +199,12 @@ public void test_IsOwner() { public void test_IsUserHasBackendRole() { User user = User.parse("owner|IT,HR|all_access"); MLModelGroupBuilder builder = MLModelGroup.builder(); - assertTrue(modelAccessControlHelper.isUserHasBackendRole(null, builder.access(AccessMode.PUBLIC.getValue()).build())); - assertFalse(modelAccessControlHelper.isUserHasBackendRole(null, builder.access(AccessMode.PRIVATE.getValue()).build())); + assertTrue( + modelAccessControlHelper.isUserHasBackendRole(null, builder.name("test_group").access(AccessMode.PUBLIC.getValue()).build()) + ); + assertFalse( + modelAccessControlHelper.isUserHasBackendRole(null, builder.name("test_group").access(AccessMode.PRIVATE.getValue()).build()) + ); assertTrue( modelAccessControlHelper .isUserHasBackendRole( @@ -218,9 +222,13 @@ public void test_IsOwnerStillHasPermission() { User userLostAccess = User.parse("owner|Finance|myTenant"); assertTrue(modelAccessControlHelper.isOwnerStillHasPermission(null, null)); MLModelGroupBuilder builder = MLModelGroup.builder(); - assertTrue(modelAccessControlHelper.isOwnerStillHasPermission(user, builder.access(AccessMode.PUBLIC.getValue()).build())); assertTrue( - modelAccessControlHelper.isOwnerStillHasPermission(user, builder.access(AccessMode.PRIVATE.getValue()).owner(owner).build()) + modelAccessControlHelper + .isOwnerStillHasPermission(user, builder.name("test_group").access(AccessMode.PUBLIC.getValue()).build()) + ); + assertTrue( + modelAccessControlHelper + .isOwnerStillHasPermission(user, builder.name("test_group").access(AccessMode.PRIVATE.getValue()).owner(owner).build()) ); assertFalse( modelAccessControlHelper