Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move allow model setting from rest to transport #1961

Merged
merged 8 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
Expand Down Expand Up @@ -89,6 +90,7 @@ public class TransportRegisterModelAction extends HandledTransportAction<ActionR
private List<String> trustedConnectorEndpointsRegex;

ModelAccessControlHelper modelAccessControlHelper;
private volatile boolean isModelUrlAllowed;

ConnectorAccessControlHelper connectorAccessControlHelper;
MLModelGroupManager mlModelGroupManager;
Expand Down Expand Up @@ -132,6 +134,9 @@ public TransportRegisterModelAction(
trustedUrlRegex = ML_COMMONS_TRUSTED_URL_REGEX.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_URL_REGEX, it -> trustedUrlRegex = it);

isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it);

trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings);
clusterService
.getClusterSettings()
Expand All @@ -142,6 +147,11 @@ public TransportRegisterModelAction(
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterModelResponse> listener) {
MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request);
MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput();
if (registerModelInput.getUrl() != null && !isModelUrlAllowed) {
throw new IllegalArgumentException(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models."
);
}
registerModelInput.setIsHidden(RestActionUtils.isSuperAdminUser(clusterService, client));
if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) {
mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
Expand Down Expand Up @@ -35,7 +34,6 @@

public class RestMLRegisterModelAction extends BaseRestHandler {
private static final String ML_REGISTER_MODEL_ACTION = "ml_register_model_action";
private volatile boolean isModelUrlAllowed;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
Expand All @@ -51,8 +49,6 @@ public RestMLRegisterModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting
* @param settings settings
*/
public RestMLRegisterModelAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it);
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

Expand Down Expand Up @@ -103,11 +99,6 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException {
if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
if (mlInput.getUrl() != null && !isModelUrlAllowed) {
throw new IllegalArgumentException(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models."
);
}
return new MLRegisterModelRequest(mlInput);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

package org.opensearch.ml.action.models;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.action.MLCommonsIntegTestCase;
Expand Down Expand Up @@ -187,4 +190,9 @@ private void test_matchPhrase_search() {
assertEquals(1, response.getHits().getTotalHits().value);
}

@Override
protected Settings nodeSettings(int ordinal) {
return Settings.builder().put(super.nodeSettings(ordinal)).put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true).build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;
Expand Down Expand Up @@ -155,12 +156,14 @@ public void setup() throws IOException {
settings = Settings
.builder()
.put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex)
.put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true)
.putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES)
.build();
threadContext = new ThreadContext(settings);
ClusterSettings clusterSettings = clusterSetting(
settings,
ML_COMMONS_TRUSTED_URL_REGEX,
ML_COMMONS_ALLOW_MODEL_URL,
ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX
);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
Expand Down Expand Up @@ -294,6 +297,50 @@ public void testDoExecute_invalidURL() {
assertEquals("URL can't match trusted url regex", argumentCaptor.getValue().getMessage());
}

public void testRegisterModelUrlNotAllowed() throws Exception {
Settings settings = Settings
.builder()
.put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex)
.put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false)
.putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES)
.build();
ClusterSettings clusterSettings = clusterSetting(
settings,
ML_COMMONS_TRUSTED_URL_REGEX,
ML_COMMONS_ALLOW_MODEL_URL,
ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX
);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
when(clusterService.getSettings()).thenReturn(settings);
transportRegisterModelAction = new TransportRegisterModelAction(
transportService,
actionFilters,
modelHelper,
mlIndicesHandler,
mlModelManager,
mlTaskManager,
clusterService,
settings,
threadPool,
client,
nodeFilter,
mlTaskDispatcher,
mlStats,
modelAccessControlHelper,
connectorAccessControlHelper,
mlModelGroupManager
);

IllegalArgumentException e = assertThrows(
IllegalArgumentException.class,
() -> transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener)
);
assertEquals(
e.getMessage(),
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models."
);
}

public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() {
when(node1.getId()).thenReturn("NodeId1");
when(node2.getId()).thenReturn("NodeId2");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,6 @@ public void testRegisterModelRequestRemoteInferenceDisabled() throws Exception {
restMLRegisterModelAction.handleRequest(request, channel, client);
}

public void testRegisterModelUrlNotAllowed() throws Exception {
settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting);
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule
.expectMessage(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models."
);
RestRequest request = getRestRequest();
restMLRegisterModelAction.handleRequest(request, channel, client);
}

public void testRegisterModelRequestWithNullUrlAndUrlNotAllowed() throws Exception {
settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL);
Expand Down
Loading