Skip to content

Commit

Permalink
change only run on ml node setting default value to true
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jan 10, 2023
1 parent bbe6ef8 commit 69ba830
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ private MLCommonsSettings() {}
public static final Setting<Integer> ML_COMMONS_MAX_ML_TASK_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_ml_task_per_node", 10, 0, 10000, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Boolean> ML_COMMONS_ONLY_RUN_ON_ML_NODE = Setting
.boolSetting("plugins.ml_commons.only_run_on_ml_node", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
.boolSetting("plugins.ml_commons.only_run_on_ml_node", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS = Setting
.intSetting(
"plugins.ml_commons.sync_up_job_interval_in_seconds",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.stats.MLNodeLevelStat;

import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -165,6 +166,9 @@ private void dispatchTaskWithLeastLoad(ActionListener<DiscoveryNode> listener) {

private void dispatchTaskWithRoundRobin(ActionListener<DiscoveryNode> listener) {
DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes();
if (eligibleNodes == null || eligibleNodes.length == 0) {
throw new MLResourceNotFoundException("no eligible node found, ml node is required to run this request");
}
dispatchTaskWithRoundRobin(eligibleNodes, listener);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.action;

import static org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams.ObjectiveType.LOGMULTICLASS;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE;
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;
import static org.opensearch.ml.utils.TestData.TARGET_FIELD;
import static org.opensearch.ml.utils.TestData.TIME_FIELD;
Expand All @@ -24,6 +25,7 @@
import org.opensearch.action.support.WriteRequest;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
Expand Down Expand Up @@ -409,4 +411,13 @@ public MLSyncUpNodesResponse syncUp_Clear() {
MLSyncUpNodesResponse syncUpResponse = client().execute(MLSyncUpAction.INSTANCE, syncUpRequest).actionGet(5000);
return syncUpResponse;
}

@Override
protected Settings nodeSettings(int ordinal) {
return Settings
.builder()
.put(super.nodeSettings(ordinal))
.put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), false)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.util.EntityUtils;
import org.junit.After;
import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.RestClient;
Expand Down Expand Up @@ -103,6 +104,20 @@ protected boolean isHttps() {
return isHttps;
}

@Before
public void setupSettings() throws IOException {
Response response = TestHelper
.makeRequest(
client(),
"PUT",
"_cluster/settings",
null,
"{\"persistent\":{\"plugins.ml_commons.only_run_on_ml_node\":false}}",
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
);
assertEquals(200, response.getStatusLine().getStatusCode());
}

@Override
protected String getProtocol() {
return isHttps() ? "https" : "http";
Expand Down

0 comments on commit 69ba830

Please sign in to comment.