From 758b2aea7f4fb1d72ae481c90757cbc2ec952146 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Mon, 1 May 2023 10:07:42 -0700 Subject: [PATCH] Add a setting to enable/disable local upload while registering model (#873) Signed-off-by: Bhavana Goud Ramaram --- .../ml/plugin/MachineLearningPlugin.java | 7 ++-- .../rest/RestMLRegisterModelMetaAction.java | 20 +++++++++- .../ml/rest/RestMLUploadModelChunkAction.java | 16 +++++++- .../ml/settings/MLCommonsSettings.java | 8 ++++ .../ml/rest/MLCommonsRestTestCase.java | 11 ++++++ .../RestMLRegisterModelMetaActionTests.java | 38 ++++++++++++++++--- .../RestMLUploadModelChunkActionTests.java | 36 +++++++++++++++++- 7 files changed, 123 insertions(+), 13 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 80b9fc0b42..87578868fa 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -388,8 +388,8 @@ public List getRestHandlers( RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings); RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction(); RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings); - RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(); - RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(); + RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings); + RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(clusterService, settings); return ImmutableList .of( @@ -507,7 +507,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL, MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE, MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES, - MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL + MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL, + MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java index 81e1b24fb9..53212b09c6 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelMetaAction.java @@ -7,12 +7,15 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD; import java.io.IOException; import java.util.List; import java.util.Locale; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaAction; import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; @@ -27,10 +30,19 @@ public class RestMLRegisterModelMetaAction extends BaseRestHandler { private static final String ML_REGISTER_MODEL_META_ACTION = "ml_register_model_meta_action"; + private volatile boolean isLocalFileUploadAllowed; + /** * Constructor + * @param clusterService cluster service + * @param settings settings */ - public RestMLRegisterModelMetaAction() {} + public RestMLRegisterModelMetaAction(ClusterService clusterService, Settings settings) { + isLocalFileUploadAllowed = ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, it -> isLocalFileUploadAllowed = it); + } @Override public String getName() { @@ -66,7 +78,11 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client @VisibleForTesting MLRegisterModelMetaRequest getRequest(RestRequest request) throws IOException { boolean hasContent = request.hasContent(); - if (!hasContent) { + if (!isLocalFileUploadAllowed) { + throw new IllegalArgumentException( + "To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models" + ); + } else if (!hasContent) { throw new IOException("Model meta request has empty body"); } XContentParser parser = request.contentParser(); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java index 0e3441ada7..82b3754433 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUploadModelChunkAction.java @@ -6,12 +6,15 @@ package org.opensearch.ml.rest; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD; import java.io.IOException; import java.util.List; import java.util.Locale; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkAction; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkRequest; @@ -24,11 +27,17 @@ public class RestMLUploadModelChunkAction extends BaseRestHandler { private static final String ML_UPLOAD_MODEL_CHUNK_ACTION = "ml_upload_model_chunk_action"; + private volatile boolean isLocalFileUploadAllowed; /** * Constructor */ - public RestMLUploadModelChunkAction() {} + public RestMLUploadModelChunkAction(ClusterService clusterService, Settings settings) { + isLocalFileUploadAllowed = ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, it -> isLocalFileUploadAllowed = it); + } @Override public String getName() { @@ -65,6 +74,11 @@ MLUploadModelChunkRequest getRequest(RestRequest request) throws IOException { final String modelId = request.param("model_id"); String chunk_number = request.param("chunk_number"); byte[] content = request.content().streamInput().readAllBytes(); + if (!isLocalFileUploadAllowed) { + throw new IllegalArgumentException( + "To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models." + ); + } MLUploadModelChunkInput mlInput = new MLUploadModelChunkInput(modelId, Integer.parseInt(chunk_number), content); return new MLUploadModelChunkRequest(mlInput); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 1e0a354b13..161dc228b4 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -91,4 +91,12 @@ private MLCommonsSettings() {} // This setting is to enable/disable model url in model register API. public static final Setting ML_COMMONS_ALLOW_MODEL_URL = Setting .boolSetting("plugins.ml_commons.allow_registering_model_via_url", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD = Setting + .boolSetting( + "plugins.ml_commons.allow_registering_model_via_local_file", + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 0309cbcb57..002f93418c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -131,6 +131,17 @@ public void setupSettings() throws IOException { ); assertEquals(200, response.getStatusLine().getStatusCode()); + response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.allow_registering_model_via_local_file\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + String jsonEntity = "{\n" + " \"persistent\" : {\n" + " \"plugins.ml_commons.native_memory_threshold\" : 100 \n" diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java index df8e35d49f..41d9bbb18a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java @@ -8,6 +8,8 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; import java.util.HashMap; @@ -20,7 +22,10 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; @@ -50,12 +55,21 @@ public class RestMLRegisterModelMetaActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private ClusterService clusterService; + + private Settings settings; + @Rule - public ExpectedException expectedEx = ExpectedException.none(); + public ExpectedException expectedException = ExpectedException.none(); @Before public void setup() { - restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(); + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), true).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); doAnswer(invocation -> { @@ -72,7 +86,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLRegisterModelMetaAction mlUploadModel = new RestMLRegisterModelMetaAction(); + RestMLRegisterModelMetaAction mlUploadModel = new RestMLRegisterModelMetaAction(clusterService, settings); assertNotNull(mlUploadModel); } @@ -112,12 +126,26 @@ public void testRegisterModelMetaRequest() throws Exception { assertEquals(Integer.valueOf(2), metaModelRequest.getTotalChunks()); } + public void testRegisterModelFileUploadNotAllowed() throws Exception { + settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), false).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings); + expectedException.expect(IllegalArgumentException.class); + expectedException + .expectMessage( + "To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models" + ); + RestRequest request = getRestRequest(); + restMLRegisterModelMetaAction.handleRequest(request, channel, client); + } + public void testRegisterModelMeta_NoContent() throws Exception { RestRequest.Method method = RestRequest.Method.POST; Map params = new HashMap<>(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withMethod(method).withParams(params).build(); - expectedEx.expect(IOException.class); - expectedEx.expectMessage("Model meta request has empty body"); + expectedException.expect(IOException.class); + expectedException.expectMessage("Model meta request has empty body"); restMLRegisterModelMetaAction.handleRequest(request, channel, client); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java index 842e168a65..cadefb50a9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUploadModelChunkActionTests.java @@ -8,6 +8,8 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.util.HashMap; import java.util.List; @@ -15,9 +17,14 @@ import org.junit.Before; import org.junit.Ignore; +import org.junit.Rule; +import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; @@ -43,10 +50,21 @@ public class RestMLUploadModelChunkActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private ClusterService clusterService; + + private Settings settings; + + @Rule + public ExpectedException expectedException = ExpectedException.none(); @Before public void setup() { - restChunkUploadAction = new RestMLUploadModelChunkAction(); + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), true).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + restChunkUploadAction = new RestMLUploadModelChunkAction(clusterService, settings); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); doAnswer(invocation -> { @@ -63,7 +81,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLUploadModelChunkAction mlUploadChunk = new RestMLUploadModelChunkAction(); + RestMLUploadModelChunkAction mlUploadChunk = new RestMLUploadModelChunkAction(clusterService, settings); assertNotNull(mlUploadChunk); } @@ -102,6 +120,20 @@ public void testUploadChunkRequest() throws Exception { assertEquals(Integer.valueOf(0), chunkRequest.getChunkNumber()); } + public void testRegisterModelFileUploadNotAllowed() throws Exception { + settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), false).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + restChunkUploadAction = new RestMLUploadModelChunkAction(clusterService, settings); + expectedException.expect(IllegalArgumentException.class); + expectedException + .expectMessage( + "To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models" + ); + RestRequest request = getRestRequest(); + restChunkUploadAction.handleRequest(request, channel, client); + } + private RestRequest getRestRequest() { RestRequest.Method method = RestRequest.Method.POST; BytesArray content = new BytesArray("12345678");