Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport to 2.x] add feature flags for remote inference #1223

Merged
merged 1 commit into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;

import java.time.Instant;
import java.util.ArrayList;
Expand Down Expand Up @@ -52,6 +53,7 @@
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
Expand Down Expand Up @@ -83,6 +85,7 @@ public class TransportDeployModelAction extends HandledTransportAction<ActionReq

private volatile boolean allowCustomDeploymentPlan;
private ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportDeployModelAction(
Expand All @@ -99,7 +102,8 @@ public TransportDeployModelAction(
MLModelManager mlModelManager,
MLStats mlStats,
Settings settings,
ModelAccessControlHelper modelAccessControlHelper
ModelAccessControlHelper modelAccessControlHelper,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLDeployModelAction.NAME, transportService, actionFilters, MLDeployModelRequest::new);
this.transportService = transportService;
Expand All @@ -114,6 +118,7 @@ public TransportDeployModelAction(
this.mlModelManager = mlModelManager;
this.mlStats = mlStats;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
allowCustomDeploymentPlan = ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.get(settings);
clusterService
.getClusterSettings()
Expand All @@ -130,6 +135,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
listener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
import org.opensearch.ml.rest.RestMLUpdateModelGroupAction;
import org.opensearch.ml.rest.RestMLUploadModelChunkAction;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLClusterLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStat;
Expand Down Expand Up @@ -221,6 +222,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin {

private ConnectorAccessControlHelper connectorAccessControlHelper;

private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return ImmutableList
Expand Down Expand Up @@ -331,6 +334,8 @@ public Collection<Object> createComponents(
mlInputDatasetHandler = new MLInputDatasetHandler(client);
modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings);
connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings);
mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings);

mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry, modelAccessControlHelper);

MLTaskDispatcher mlTaskDispatcher = new MLTaskDispatcher(clusterService, client, settings, nodeHelper);
Expand Down Expand Up @@ -437,6 +442,7 @@ public Collection<Object> createComponents(
mlExecuteTaskRunner,
modelAccessControlHelper,
connectorAccessControlHelper,
mlFeatureEnabledSetting,
mlSearchHandler,
mlTaskDispatcher,
mlModelChunkUploader,
Expand All @@ -461,7 +467,7 @@ public List<RestHandler> getRestHandlers(
RestMLStatsAction restMLStatsAction = new RestMLStatsAction(mlStats, clusterService, indexUtils, xContentRegistry);
RestMLTrainingAction restMLTrainingAction = new RestMLTrainingAction();
RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction();
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager);
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager, mlFeatureEnabledSetting);
RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction();
RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction();
RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction();
Expand All @@ -470,7 +476,11 @@ public List<RestHandler> getRestHandlers(
RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction();
RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction();
RestMLProfileAction restMLProfileAction = new RestMLProfileAction(clusterService);
RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings);
RestMLRegisterModelAction restMLRegisterModelAction = new RestMLRegisterModelAction(
clusterService,
settings,
mlFeatureEnabledSetting
);
RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction();
RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings);
RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
Expand All @@ -479,7 +489,7 @@ public List<RestHandler> getRestHandlers(
RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction();
RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction();
RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction();
RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction();
RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting);
RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction();
RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction();
RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction();
Expand Down Expand Up @@ -614,7 +624,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED,
MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES,
MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES
MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES,
MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;

import java.io.IOException;
import java.util.List;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -26,11 +28,15 @@

public class RestMLCreateConnectorAction extends BaseRestHandler {
private static final String ML_CREATE_CONNECTOR_ACTION = "ml_create_connector_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor *
* Constructor
* @param mlFeatureEnabledSetting
*/
public RestMLCreateConnectorAction() {}
public RestMLCreateConnectorAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand All @@ -56,6 +62,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
*/
@VisibleForTesting
MLCreateConnectorRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
if (!request.hasContent()) {
throw new IOException("Create Connector request has empty body");
}
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.parse(parser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;
Expand All @@ -29,6 +30,7 @@
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
Expand All @@ -45,11 +47,14 @@ public class RestMLPredictionAction extends BaseRestHandler {

private MLModelManager modelManager;

private MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLPredictionAction(MLModelManager modelManager) {
public RestMLPredictionAction(MLModelManager modelManager, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.modelManager = modelManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand Down Expand Up @@ -117,6 +122,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
*/
@VisibleForTesting
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLInput mlInput = MLInput.parse(parser, algorithm);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_VERSION;
Expand All @@ -20,9 +21,11 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -33,20 +36,24 @@
public class RestMLRegisterModelAction extends BaseRestHandler {
private static final String ML_REGISTER_MODEL_ACTION = "ml_register_model_action";
private volatile boolean isModelUrlAllowed;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLRegisterModelAction() {}
public RestMLRegisterModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

/**
* Constructor
* @param clusterService cluster service
* @param settings settings
*/
public RestMLRegisterModelAction(ClusterService clusterService, Settings settings) {
public RestMLRegisterModelAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it);
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand Down Expand Up @@ -93,6 +100,9 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLRegisterModelInput mlInput = MLRegisterModelInput.parse(parser, loadModel);
if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
if (mlInput.getUrl() != null && !isModelUrlAllowed) {
throw new IllegalArgumentException(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ private MLCommonsSettings() {}
Setting.Property.Dynamic
);

// This setting is to enable/disable Create Connector API and Register/Deploy/Predict Model APIs for remote models
public static final Setting<Boolean> ML_COMMONS_REMOTE_INFERENCE_ENABLED = Setting
.boolSetting("plugins.ml_commons.remote_inference.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting
.boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.settings;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;

public class MLFeatureEnabledSetting {

private volatile Boolean isRemoteInferenceEnabled;

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, it -> isRemoteInferenceEnabled = it);
}

/**
* Whether the remote inference feature is enabled. If disabled, APIs in ml-commons will block remote inference.
* @return whether Remote Inference is enabled.
*/
public boolean isRemoteInferenceEnabled() {
return isRemoteInferenceEnabled;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
public class MLExceptionUtils {

public static final String NOT_SERIALIZABLE_EXCEPTION_WRAPPER = "NotSerializableExceptionWrapper: ";
public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG =
"Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true.";

public static String getRootCauseMessage(final Throwable throwable) {
String message = ExceptionUtils.getRootCauseMessage(throwable);
Expand Down
Loading