Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport to main]Add a setting to enable/disable local upload while registering model … #1236

Merged
merged 1 commit into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ public List<RestHandler> getRestHandlers(
RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings);
RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction();
RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings);
RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction();
RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction();
RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(clusterService, settings);

return ImmutableList
.of(
Expand Down Expand Up @@ -507,7 +507,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES,
MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL
MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL,
MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaAction;
import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput;
Expand All @@ -27,10 +30,19 @@
public class RestMLRegisterModelMetaAction extends BaseRestHandler {
private static final String ML_REGISTER_MODEL_META_ACTION = "ml_register_model_meta_action";

private volatile boolean isLocalFileUploadAllowed;

/**
* Constructor
* @param clusterService cluster service
* @param settings settings
*/
public RestMLRegisterModelMetaAction() {}
public RestMLRegisterModelMetaAction(ClusterService clusterService, Settings settings) {
isLocalFileUploadAllowed = ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, it -> isLocalFileUploadAllowed = it);
}

@Override
public String getName() {
Expand Down Expand Up @@ -66,7 +78,11 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
@VisibleForTesting
MLRegisterModelMetaRequest getRequest(RestRequest request) throws IOException {
boolean hasContent = request.hasContent();
if (!hasContent) {
if (!isLocalFileUploadAllowed) {
throw new IllegalArgumentException(
"To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models"
);
} else if (!hasContent) {
throw new IOException("Model meta request has empty body");
}
XContentParser parser = request.contentParser();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
package org.opensearch.ml.rest;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkAction;
import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkInput;
import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkRequest;
Expand All @@ -24,11 +27,17 @@

public class RestMLUploadModelChunkAction extends BaseRestHandler {
private static final String ML_UPLOAD_MODEL_CHUNK_ACTION = "ml_upload_model_chunk_action";
private volatile boolean isLocalFileUploadAllowed;

/**
* Constructor
*/
public RestMLUploadModelChunkAction() {}
public RestMLUploadModelChunkAction(ClusterService clusterService, Settings settings) {
isLocalFileUploadAllowed = ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, it -> isLocalFileUploadAllowed = it);
}

@Override
public String getName() {
Expand Down Expand Up @@ -65,6 +74,11 @@ MLUploadModelChunkRequest getRequest(RestRequest request) throws IOException {
final String modelId = request.param("model_id");
String chunk_number = request.param("chunk_number");
byte[] content = request.content().streamInput().readAllBytes();
if (!isLocalFileUploadAllowed) {
throw new IllegalArgumentException(
"To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models."
);
}
MLUploadModelChunkInput mlInput = new MLUploadModelChunkInput(modelId, Integer.parseInt(chunk_number), content);
return new MLUploadModelChunkRequest(mlInput);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,12 @@ private MLCommonsSettings() {}
// This setting is to enable/disable model url in model register API.
public static final Setting<Boolean> ML_COMMONS_ALLOW_MODEL_URL = Setting
.boolSetting("plugins.ml_commons.allow_registering_model_via_url", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD = Setting
.boolSetting(
"plugins.ml_commons.allow_registering_model_via_local_file",
false,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,17 @@ public void setupSettings() throws IOException {
);
assertEquals(200, response.getStatusLine().getStatusCode());

response = TestHelper
.makeRequest(
client(),
"PUT",
"_cluster/settings",
null,
"{\"persistent\":{\"plugins.ml_commons.allow_registering_model_via_local_file\":true}}",
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
);
assertEquals(200, response.getStatusLine().getStatusCode());

String jsonEntity = "{\n"
+ " \"persistent\" : {\n"
+ " \"plugins.ml_commons.native_memory_threshold\" : 100 \n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.io.IOException;
import java.util.HashMap;
Expand All @@ -20,7 +22,10 @@
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -50,12 +55,21 @@ public class RestMLRegisterModelMetaActionTests extends OpenSearchTestCase {
@Mock
RestChannel channel;

@Mock
private ClusterService clusterService;

private Settings settings;

@Rule
public ExpectedException expectedEx = ExpectedException.none();
public ExpectedException expectedException = ExpectedException.none();

@Before
public void setup() {
restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction();
MockitoAnnotations.openMocks(this);
settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), true).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
client = spy(new NodeClient(Settings.EMPTY, threadPool));
doAnswer(invocation -> {
Expand All @@ -72,7 +86,7 @@ public void tearDown() throws Exception {
}

public void testConstructor() {
RestMLRegisterModelMetaAction mlUploadModel = new RestMLRegisterModelMetaAction();
RestMLRegisterModelMetaAction mlUploadModel = new RestMLRegisterModelMetaAction(clusterService, settings);
assertNotNull(mlUploadModel);
}

Expand Down Expand Up @@ -112,12 +126,26 @@ public void testRegisterModelMetaRequest() throws Exception {
assertEquals(Integer.valueOf(2), metaModelRequest.getTotalChunks());
}

public void testRegisterModelFileUploadNotAllowed() throws Exception {
settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), false).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
expectedException.expect(IllegalArgumentException.class);
expectedException
.expectMessage(
"To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models"
);
RestRequest request = getRestRequest();
restMLRegisterModelMetaAction.handleRequest(request, channel, client);
}

public void testRegisterModelMeta_NoContent() throws Exception {
RestRequest.Method method = RestRequest.Method.POST;
Map<String, String> params = new HashMap<>();
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withMethod(method).withParams(params).build();
expectedEx.expect(IOException.class);
expectedEx.expectMessage("Model meta request has empty body");
expectedException.expect(IOException.class);
expectedException.expectMessage("Model meta request has empty body");
restMLRegisterModelMetaAction.handleRequest(request, channel, client);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,23 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
Expand All @@ -43,10 +50,21 @@ public class RestMLUploadModelChunkActionTests extends OpenSearchTestCase {

@Mock
RestChannel channel;
@Mock
private ClusterService clusterService;

private Settings settings;

@Rule
public ExpectedException expectedException = ExpectedException.none();

@Before
public void setup() {
restChunkUploadAction = new RestMLUploadModelChunkAction();
MockitoAnnotations.openMocks(this);
settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), true).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restChunkUploadAction = new RestMLUploadModelChunkAction(clusterService, settings);
threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
client = spy(new NodeClient(Settings.EMPTY, threadPool));
doAnswer(invocation -> {
Expand All @@ -63,7 +81,7 @@ public void tearDown() throws Exception {
}

public void testConstructor() {
RestMLUploadModelChunkAction mlUploadChunk = new RestMLUploadModelChunkAction();
RestMLUploadModelChunkAction mlUploadChunk = new RestMLUploadModelChunkAction(clusterService, settings);
assertNotNull(mlUploadChunk);
}

Expand Down Expand Up @@ -102,6 +120,20 @@ public void testUploadChunkRequest() throws Exception {
assertEquals(Integer.valueOf(0), chunkRequest.getChunkNumber());
}

public void testRegisterModelFileUploadNotAllowed() throws Exception {
settings = Settings.builder().put(ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD.getKey(), false).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restChunkUploadAction = new RestMLUploadModelChunkAction(clusterService, settings);
expectedException.expect(IllegalArgumentException.class);
expectedException
.expectMessage(
"To upload custom model from local file, user needs to enable allow_registering_model_via_local_file settings. Otherwise please use opensearch pre-trained models"
);
RestRequest request = getRestRequest();
restChunkUploadAction.handleRequest(request, channel, client);
}

private RestRequest getRestRequest() {
RestRequest.Method method = RestRequest.Method.POST;
BytesArray content = new BytesArray("12345678");
Expand Down