diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 0b61a9e249..1f47d5b9c1 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -23,7 +23,7 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_MAX_ML_TASK_PER_NODE = Setting .intSetting("plugins.ml_commons.max_ml_task_per_node", 10, 0, 10000, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting ML_COMMONS_ONLY_RUN_ON_ML_NODE = Setting - .boolSetting("plugins.ml_commons.only_run_on_ml_node", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + .boolSetting("plugins.ml_commons.only_run_on_ml_node", true, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS = Setting .intSetting( "plugins.ml_commons.sync_up_job_interval_in_seconds", diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java index ad4db303b5..539baa9de1 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java @@ -26,6 +26,7 @@ import org.opensearch.ml.action.stats.MLStatsNodesAction; import org.opensearch.ml.action.stats.MLStatsNodesRequest; import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.stats.MLNodeLevelStat; import com.google.common.collect.ImmutableSet; @@ -165,6 +166,9 @@ private void dispatchTaskWithLeastLoad(ActionListener listener) { private void dispatchTaskWithRoundRobin(ActionListener listener) { DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(); + if (eligibleNodes == null || eligibleNodes.length == 0) { + throw new MLResourceNotFoundException("no eligible node found, ml node is required to run this request"); + } dispatchTaskWithRoundRobin(eligibleNodes, listener); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java index 45a69388e6..6294aaffe3 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action; import static org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams.ObjectiveType.LOGMULTICLASS; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE; import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; import static org.opensearch.ml.utils.TestData.TARGET_FIELD; import static org.opensearch.ml.utils.TestData.TIME_FIELD; @@ -24,6 +25,7 @@ import org.opensearch.action.support.WriteRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -409,4 +411,9 @@ public MLSyncUpNodesResponse syncUp_Clear() { MLSyncUpNodesResponse syncUpResponse = client().execute(MLSyncUpAction.INSTANCE, syncUpRequest).actionGet(5000); return syncUpResponse; } + + @Override + protected Settings nodeSettings(int ordinal) { + return Settings.builder().put(super.nodeSettings(ordinal)).put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), false).build(); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 0847990786..4c1675e420 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -49,6 +49,7 @@ import org.apache.http.ssl.SSLContextBuilder; import org.apache.http.util.EntityUtils; import org.junit.After; +import org.junit.Before; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.RestClient; @@ -103,6 +104,20 @@ protected boolean isHttps() { return isHttps; } + @Before + public void setupSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.only_run_on_ml_node\":false}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + @Override protected String getProtocol() { return isHttps() ? "https" : "http";