Skip to content

Commit

Permalink
Move allow model setting from rest to transport (#1961)
Browse files Browse the repository at this point in the history
* Backport multiple PRs to main from 2.x (#1652)

* fix parameter name in preprocess function; fix remote model function … (#1362)

* fix parameter name in preprocess function; fix remote model function name

Signed-off-by: Yaliang Wu <[email protected]>

* fix failed unit test

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>

* throw exception when model group not found during update request (#1447)

Signed-off-by: Bhavana Ramaram <[email protected]>

* add status code to model tensor (#1443) (#1453)

Signed-off-by: Yaliang Wu <[email protected]>

* register new versions to a model group based on the name provided (#1452)

Signed-off-by: Bhavana Ramaram <[email protected]>

* fixing metrics correlation algorithm (#1448)

* fixing metrics correlation algorithm

Signed-off-by: Dhrubo Saha <[email protected]>

* if model version fails to register, update model group accordingly (#1463)

* if model version fails to register, update model group accordingly

Signed-off-by: Bhavana Ramaram <[email protected]>

* Update Model API (#1350)

* Update Model API POC

Signed-off-by: Sicheng Song <[email protected]>

* Using GetRequest to get model

Signed-off-by: Sicheng Song <[email protected]>

* Finalize model update API

Signed-off-by: Sicheng Song <[email protected]>

* Fix compile

Signed-off-by: Sicheng Song <[email protected]>

* Fix compileTest

Signed-off-by: Sicheng Song <[email protected]>

* Add Unit Test Cases for Update Model API

Signed-off-by: Sicheng Song <[email protected]>

* Tune back test coverage thereshold

Signed-off-by: Sicheng Song <[email protected]>

* Add more unit tests on Update model API

Signed-off-by: Sicheng Song <[email protected]>

* Add unit test for TransportUpdateModelAction class

Signed-off-by: Sicheng Song <[email protected]>

* Fix a test error

Signed-off-by: Sicheng Song <[email protected]>

* Change exception thrown to failure response

Signed-off-by: Sicheng Song <[email protected]>

* Move the function judgement to the outter block

Signed-off-by: Sicheng Song <[email protected]>

* Check if model is undeployed before update model

Signed-off-by: Sicheng Song <[email protected]>

* Add more unit test for update model API

Signed-off-by: Sicheng Song <[email protected]>

* Fix unit test due to blocking java 11 CI workflow

Signed-off-by: Sicheng Song <[email protected]>

* Enabling auto bumping model version during registering to a new model group and address reviewers' other concern

Signed-off-by: Sicheng Song <[email protected]>

* Autobump new model groups' latest version when register to a new model

Signed-off-by: Sicheng Song <[email protected]>

* Change the REST API method from POST to PUT

Signed-off-by: Sicheng Song <[email protected]>

* Change the update REST API endpoint

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Sicheng Song <[email protected]>

* Add a setting to control the update connector API (#1465)

* Add a setting to control the update connector API

Signed-off-by: Sicheng Song <[email protected]>

* Enabling the update connnector setting in unit test

Signed-off-by: Sicheng Song <[email protected]>

* Enabling the update connnector setting in corresponding unit test

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Sicheng Song <[email protected]>

* fix update connector API (#1484)

* fix update connector API

Signed-off-by: Yaliang Wu <[email protected]>

* Performance enhacement for predict action by caching model info (#1472) (#1508)

* Performance enhacement for predict action by caching model info

Signed-off-by: zane-neo <[email protected]>

* Add context.restore() to avoid missing info

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
(cherry picked from commit a985f6e)

Co-authored-by: zane-neo <[email protected]>

* fix failed ut from PR 1472 (#1479) (#1510)

* fix failed ut from PR 1472

Signed-off-by: Yaliang Wu <[email protected]>

* exclude class for low coverage

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
(cherry picked from commit da5d829)

Co-authored-by: Yaliang Wu <[email protected]>

* [Backport to 2.11] throw exception if remote model doesn't return 2xx status code; fix p… (#1477) (#1509)

* throw exception if remote model doesn't return 2xx status code; fix p… (#1473)

* throw exception if remote model doesn't return 2xx status code; fix predict runner

Signed-off-by: Yaliang Wu <[email protected]>

* fix kmeans model deploy bug

Signed-off-by: Yaliang Wu <[email protected]>

* support multiple docs for remote embedding model

Signed-off-by: Yaliang Wu <[email protected]>

* fix ut

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>

* fix wrong class

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
(cherry picked from commit 201c8a8)

Co-authored-by: Yaliang Wu <[email protected]>

* fix no worker node exception for remote embedding model (#1482) (#1511)

* fix no worker node exception for remote embedding model

Signed-off-by: Yaliang Wu <[email protected]>

* only add model info to cache if model cache exist

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
(cherry picked from commit 6f83b9f)

Co-authored-by: Yaliang Wu <[email protected]>

* fix for delete model group API throwing incorrect error when model index not created (#1485) (#1486) (#1512)

* fix for delete model group API throwing incorrect error when model index not created

Signed-off-by: Bhavana Ramaram <[email protected]>
(cherry picked from commit 60ef0fd)

Co-authored-by: Bhavana Ramaram <[email protected]>
(cherry picked from commit 5544681)

Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com>

* fix no worker node error on multi-node cluster (#1487) (#1513)

Signed-off-by: Yaliang Wu <[email protected]>
(cherry picked from commit cea1cd6)

Co-authored-by: Yaliang Wu <[email protected]>

* add prefix to show the error is from remote service (#1499) (#1515)

Signed-off-by: Yaliang Wu <[email protected]>
(cherry picked from commit 3897ad1)

Co-authored-by: Yaliang Wu <[email protected]>

* fix multiple docs support (#1516)

Signed-off-by: Yaliang Wu <[email protected]>

* adding another fix issue to the release note (#1498) (#1514)

Signed-off-by: Dhrubo Saha <[email protected]>
(cherry picked from commit 440155c)

Co-authored-by: Dhrubo Saha <[email protected]>

* add bedrockURL to trusted connector regex list (#1461)

Signed-off-by: Bhavana Ramaram <[email protected]>

* return parsing exception 400 for parsing errors

Signed-off-by: Xun Zhang <[email protected]>

* add more ut in restupdateconnector

Signed-off-by: Xun Zhang <[email protected]>

* fix format violations

Signed-off-by: Bhavana Ramaram <[email protected]>

* Fix model/connector update API to address security concern (#1595)

* Fix model/connector update API to address appsec concern

Signed-off-by: Sicheng Song <[email protected]>

* Fix compile and build failure

Signed-off-by: Sicheng Song <[email protected]>

* Improve unit test coverage

Signed-off-by: Sicheng Song <[email protected]>

* Fix spotless

Signed-off-by: Sicheng Song <[email protected]>

* Merge update connector feature flag to remote inference feature flag

Signed-off-by: Sicheng Song <[email protected]>

* Fix compile

Signed-off-by: Sicheng Song <[email protected]>

* Fix exception status

Signed-off-by: Sicheng Song <[email protected]>

* Keep fixing exception status

Signed-off-by: Sicheng Song <[email protected]>

* Spotless fix

Signed-off-by: Sicheng Song <[email protected]>

* Add UT on parsing exception

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Sicheng Song <[email protected]>

* change XContentFactory to MediaTypeRegistry builder in MLRegisterModelInputTest class

Signed-off-by: Bhavana Ramaram <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Bhavana Ramaram <[email protected]>
Signed-off-by: Dhrubo Saha <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>
Signed-off-by: Xun Zhang <[email protected]>
Co-authored-by: Yaliang Wu <[email protected]>
Co-authored-by: Dhrubo Saha <[email protected]>
Co-authored-by: Sicheng Song <[email protected]>
Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com>
Co-authored-by: zane-neo <[email protected]>
Co-authored-by: Xun Zhang <[email protected]>

* Allow model setting to transport from rest

Signed-off-by: Owais Kazi <[email protected]>

* Added test

Signed-off-by: Owais Kazi <[email protected]>

* Fixed integ test

Signed-off-by: Owais Kazi <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Bhavana Ramaram <[email protected]>
Signed-off-by: Dhrubo Saha <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>
Signed-off-by: Xun Zhang <[email protected]>
Signed-off-by: Owais Kazi <[email protected]>
Co-authored-by: Bhavana Ramaram <[email protected]>
Co-authored-by: Yaliang Wu <[email protected]>
Co-authored-by: Dhrubo Saha <[email protected]>
Co-authored-by: Sicheng Song <[email protected]>
Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com>
Co-authored-by: zane-neo <[email protected]>
Co-authored-by: Xun Zhang <[email protected]>
(cherry picked from commit 6bff20d)
  • Loading branch information
owaiskazi19 authored and github-actions[bot] committed Feb 1, 2024
1 parent 737630d commit b147fe3
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
Expand Down Expand Up @@ -89,6 +90,7 @@ public class TransportRegisterModelAction extends HandledTransportAction<ActionR
private List<String> trustedConnectorEndpointsRegex;

ModelAccessControlHelper modelAccessControlHelper;
private volatile boolean isModelUrlAllowed;

ConnectorAccessControlHelper connectorAccessControlHelper;
MLModelGroupManager mlModelGroupManager;
Expand Down Expand Up @@ -132,6 +134,9 @@ public TransportRegisterModelAction(
trustedUrlRegex = ML_COMMONS_TRUSTED_URL_REGEX.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_URL_REGEX, it -> trustedUrlRegex = it);

isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it);

trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings);
clusterService
.getClusterSettings()
Expand All @@ -142,6 +147,11 @@ public TransportRegisterModelAction(
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterModelResponse> listener) {
MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request);
MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput();
if (registerModelInput.getUrl() != null && !isModelUrlAllowed) {
throw new IllegalArgumentException(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models."
);
}
registerModelInput.setIsHidden(RestActionUtils.isSuperAdminUser(clusterService, client));
if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) {
mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
Expand Down Expand Up @@ -35,7 +34,6 @@

public class RestMLRegisterModelAction extends BaseRestHandler {
private static final String ML_REGISTER_MODEL_ACTION = "ml_register_model_action";
private volatile boolean isModelUrlAllowed;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
Expand All @@ -51,8 +49,6 @@ public RestMLRegisterModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting
* @param settings settings
*/
public RestMLRegisterModelAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it);
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

Expand Down Expand Up @@ -103,11 +99,6 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException {
if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
if (mlInput.getUrl() != null && !isModelUrlAllowed) {
throw new IllegalArgumentException(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models."
);
}
return new MLRegisterModelRequest(mlInput);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

package org.opensearch.ml.action.models;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.action.MLCommonsIntegTestCase;
Expand Down Expand Up @@ -187,4 +190,9 @@ private void test_matchPhrase_search() {
assertEquals(1, response.getHits().getTotalHits().value);
}

@Override
protected Settings nodeSettings(int ordinal) {
return Settings.builder().put(super.nodeSettings(ordinal)).put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true).build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;
Expand Down Expand Up @@ -155,12 +156,14 @@ public void setup() throws IOException {
settings = Settings
.builder()
.put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex)
.put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true)
.putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES)
.build();
threadContext = new ThreadContext(settings);
ClusterSettings clusterSettings = clusterSetting(
settings,
ML_COMMONS_TRUSTED_URL_REGEX,
ML_COMMONS_ALLOW_MODEL_URL,
ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX
);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
Expand Down Expand Up @@ -294,6 +297,50 @@ public void testDoExecute_invalidURL() {
assertEquals("URL can't match trusted url regex", argumentCaptor.getValue().getMessage());
}

public void testRegisterModelUrlNotAllowed() throws Exception {
Settings settings = Settings
.builder()
.put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex)
.put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false)
.putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES)
.build();
ClusterSettings clusterSettings = clusterSetting(
settings,
ML_COMMONS_TRUSTED_URL_REGEX,
ML_COMMONS_ALLOW_MODEL_URL,
ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX
);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
when(clusterService.getSettings()).thenReturn(settings);
transportRegisterModelAction = new TransportRegisterModelAction(
transportService,
actionFilters,
modelHelper,
mlIndicesHandler,
mlModelManager,
mlTaskManager,
clusterService,
settings,
threadPool,
client,
nodeFilter,
mlTaskDispatcher,
mlStats,
modelAccessControlHelper,
connectorAccessControlHelper,
mlModelGroupManager
);

IllegalArgumentException e = assertThrows(
IllegalArgumentException.class,
() -> transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener)
);
assertEquals(
e.getMessage(),
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models."
);
}

public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() {
when(node1.getId()).thenReturn("NodeId1");
when(node2.getId()).thenReturn("NodeId2");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,6 @@ public void testRegisterModelRequestRemoteInferenceDisabled() throws Exception {
restMLRegisterModelAction.handleRequest(request, channel, client);
}

public void testRegisterModelUrlNotAllowed() throws Exception {
settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting);
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule
.expectMessage(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models."
);
RestRequest request = getRestRequest();
restMLRegisterModelAction.handleRequest(request, channel, client);
}

public void testRegisterModelRequestWithNullUrlAndUrlNotAllowed() throws Exception {
settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL);
Expand Down

0 comments on commit b147fe3

Please sign in to comment.