diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java
index 9b4a35e8ae..58a6171e57 100644
--- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java
+++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java
@@ -32,6 +32,10 @@ public class CommonValue {
     public static final String ML_TASK_INDEX = ".plugins-ml-task";
     public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 3;
     public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
+
+    public static final String ML_MODEL_RELOAD_INDEX = ".plugins-ml-model-reload";
+    public static final String NODE_ID_FIELD = "node_id";
+    public static final String MODEL_LOAD_RETRY_TIMES_FIELD = "retry_times";
     public static final String USER_FIELD_MAPPING = " \""
         + CommonValue.USER
         + "\": {\n"
diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelAutoReloader.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelAutoReloader.java
new file mode 100644
index 0000000000..a5078dd88f
--- /dev/null
+++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelAutoReloader.java
@@ -0,0 +1,358 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.ml.model;
+
+import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
+import static org.opensearch.ml.common.CommonValue.ML_MODEL_RELOAD_INDEX;
+import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
+import static org.opensearch.ml.common.CommonValue.MODEL_LOAD_RETRY_TIMES_FIELD;
+import static org.opensearch.ml.common.CommonValue.NODE_ID_FIELD;
+import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL;
+import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE;
+import static org.opensearch.ml.settings.MLCommonsSettings.ML_MODEL_RELOAD_MAX_RETRY_TIMES;
+import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+
+import lombok.extern.log4j.Log4j2;
+
+import org.opensearch.action.ActionListener;
+import org.opensearch.action.StepListener;
+import org.opensearch.action.index.IndexAction;
+import org.opensearch.action.index.IndexRequestBuilder;
+import org.opensearch.action.index.IndexResponse;
+import org.opensearch.action.search.SearchAction;
+import org.opensearch.action.search.SearchRequestBuilder;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.action.support.WriteRequest;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.node.DiscoveryNode;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.util.CollectionUtils;
+import org.opensearch.common.xcontent.NamedXContentRegistry;
+import org.opensearch.common.xcontent.XContentParser;
+import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.ml.cluster.DiscoveryNodeHelper;
+import org.opensearch.ml.common.MLTask;
+import org.opensearch.ml.common.exception.MLException;
+import org.opensearch.ml.common.transport.load.MLLoadModelAction;
+import org.opensearch.ml.common.transport.load.MLLoadModelRequest;
+import org.opensearch.ml.utils.MLNodeUtils;
+import org.opensearch.rest.RestStatus;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.builder.SearchSourceBuilder;
+import org.opensearch.search.sort.FieldSortBuilder;
+import org.opensearch.search.sort.SortBuilder;
+import org.opensearch.search.sort.SortOrder;
+import org.opensearch.threadpool.ThreadPool;
+
+import com.google.common.annotations.VisibleForTesting;
+
+/**
+ * Manager class for ML models and nodes. It contains ML model auto reload operations etc.
+ */
+@Log4j2
+public class MLModelAutoReloader {
+
+    private final Client client;
+    private final ClusterService clusterService;
+    private final NamedXContentRegistry xContentRegistry;
+    private final DiscoveryNodeHelper nodeHelper;
+    private final ThreadPool threadPool;
+    private volatile Boolean enableAutoReloadModel;
+    private volatile Integer autoReloadMaxRetryTimes;
+
+    /**
+     * Creates the reloader and registers consumers that keep the two auto-reload settings up to date.
+     *
+     * @param clusterService   used to read the local node, the index metadata and the cluster settings
+     * @param threadPool       supplies the executor that runs the reload
+     * @param client           client used for the search, index and load-model requests
+     * @param xContentRegistry registry used to parse the stored task document
+     * @param nodeHelper       used to list the nodes that the load-model request targets
+     * @param settings         initial values of the auto-reload settings
+     */
+    public MLModelAutoReloader(
+        ClusterService clusterService,
+        ThreadPool threadPool,
+        Client client,
+        NamedXContentRegistry xContentRegistry,
+        DiscoveryNodeHelper nodeHelper,
+        Settings settings
+    ) {
+        this.clusterService = clusterService;
+        this.client = client;
+        this.xContentRegistry = xContentRegistry;
+        this.nodeHelper = nodeHelper;
+        this.threadPool = threadPool;
+
+        enableAutoReloadModel = ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE.get(settings);
+        autoReloadMaxRetryTimes = ML_MODEL_RELOAD_MAX_RETRY_TIMES.get(settings);
+        clusterService
+            .getClusterSettings()
+            .addSettingsUpdateConsumer(ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE, it -> enableAutoReloadModel = it);
+
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_MODEL_RELOAD_MAX_RETRY_TIMES, it -> autoReloadMaxRetryTimes = it);
+    }
+
+    /**
+     * Entry point: reloads the models of the local node when auto reload is enabled and the local node is an ML node.
+     * The reload itself is submitted to the {@code LOAD_THREAD_POOL} executor.
+     */
+    public void autoReloadModel() {
+        log.info("auto reload model enabled: {} ", enableAutoReloadModel);
+
+        // if we don't need to reload automatically, just return without doing anything
+        if (!enableAutoReloadModel) {
+            return;
+        }
+
+        // only ML nodes reload models; any other node returns without doing anything
+        if (!MLNodeUtils.isMLNode(clusterService.localNode())) {
+            return;
+        }
+
+        String localNodeId = clusterService.localNode().getId();
+        // auto reload all models of this local ml node
+        threadPool.executor(LOAD_THREAD_POOL).execute(() -> {
+            try {
+                autoReloadModelByNodeId(localNodeId);
+            } catch (ExecutionException | InterruptedException e) {
+                if (e instanceof InterruptedException) {
+                    // restore the interrupt flag so the owner of this thread can still observe it
+                    Thread.currentThread().interrupt();
+                }
+                // the message fills the "{}" placeholder, the trailing exception supplies the stack trace
+                log.error("the model auto-reloading has exception,and the root cause message is: {}", e.getMessage(), e);
+                throw new MLException(e);
+            }
+        });
+    }
+
+    /**
+     * Reloads the model named by the newest finished LOAD_MODEL task of the given node and stores the retry counter.
+     * The node must be an ML node.
+     *
+     * @param localNodeId node id
+     */
+    @VisibleForTesting
+    void autoReloadModelByNodeId(String localNodeId) throws ExecutionException, InterruptedException {
+        StepListener<SearchResponse> queryTaskStep = new StepListener<>();
+        StepListener<SearchResponse> getRetryTimesStep = new StepListener<>();
+        StepListener<IndexResponse> saveLatestRetryTimesStep = new StepListener<>();
+
+        if (!clusterService.state().metadata().indices().containsKey(ML_TASK_INDEX)) {
+            // ML_TASK_INDEX did not exist,do nothing
+            return;
+        }
+
+        queryTask(localNodeId, ActionListener.wrap(queryTaskStep::onResponse, queryTaskStep::onFailure));
+
+        getRetryTimes(localNodeId, ActionListener.wrap(getRetryTimesStep::onResponse, getRetryTimesStep::onFailure));
+
+        queryTaskStep.whenComplete(searchResponse -> {
+            SearchHit[] hits = searchResponse.getHits().getHits();
+            if (CollectionUtils.isEmpty(hits)) {
+                return;
+            }
+
+            getRetryTimesStep.whenComplete(getReTryTimesResponse -> {
+                int retryTimes = 0;
+                // a null response means no counter has been stored for this node yet (the index
+                // .plugins-ml-model-reload is missing or holds no document for it), so the counter stays at zero;
+                // otherwise the stored value of the field MODEL_LOAD_RETRY_TIMES_FIELD is used
+                if (getReTryTimesResponse != null) {
+                    Map<String, Object> sourceAsMap = getReTryTimesResponse.getHits().getHits()[0].getSourceAsMap();
+                    Object storedRetryTimes = sourceAsMap.get(MODEL_LOAD_RETRY_TIMES_FIELD);
+                    if (storedRetryTimes instanceof Number) {
+                        retryTimes = ((Number) storedRetryTimes).intValue();
+                    }
+                }
+
+                // once the stored counter exceeds the configured maximum, the node stops trying to reload
+                if (retryTimes > autoReloadMaxRetryTimes) {
+                    log.info("Node: {} has reached to the max retry limit, failed to load models", localNodeId);
+                    return;
+                }
+
+                try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, hits[0].getSourceRef())) {
+                    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
+                    MLTask mlTask = MLTask.parse(parser);
+
+                    autoReloadModelByNodeAndModelId(localNodeId, mlTask.getModelId());
+
+                    // if reload the model successfully,the number of unsuccessful reload should be reset to zero.
+                    retryTimes = 0;
+                } catch (MLException e) {
+                    retryTimes++;
+                    log.error("Can't auto reload model in node id {} ,has tried {} times\nThe reason is:{}", localNodeId, retryTimes, e);
+                }
+
+                // Store the latest value of the retryTimes and node id under the index ".plugins-ml-model-reload"
+                saveLatestRetryTimes(
+                    localNodeId,
+                    retryTimes,
+                    ActionListener.wrap(saveLatestRetryTimesStep::onResponse, saveLatestRetryTimesStep::onFailure)
+                );
+            }, e -> log.error("fail to get retry times of node " + localNodeId, e));
+        }, e -> log.error("fail to query load model tasks of node " + localNodeId, e));
+
+        saveLatestRetryTimesStep
+            .whenComplete(
+                response -> log.info("successfully complete all steps"),
+                e -> log.error("fail to save retry times of node " + localNodeId, e)
+            );
+    }
+
+    /**
+     * Sends a load-model request for one model to every ML node, always including the given node.
+     * NOTE(review): the MLException is thrown inside the failure callback of the load request; if the client runs that
+     * callback on another thread, the exception will not reach the caller of this method - confirm.
+     *
+     * @param localNodeId node id
+     * @param modelId     model id
+     * @throws MLException thrown by the failure callback when the load request fails
+     */
+    @VisibleForTesting
+    void autoReloadModelByNodeAndModelId(String localNodeId, String modelId) throws MLException {
+        List<String> allMLNodeIdList = Arrays
+            .stream(nodeHelper.getAllNodes())
+            .filter(MLNodeUtils::isMLNode)
+            .map(DiscoveryNode::getId)
+            .collect(Collectors.toList());
+
+        if (!allMLNodeIdList.contains(localNodeId)) {
+            allMLNodeIdList.add(localNodeId);
+        }
+        MLLoadModelRequest mlLoadModelRequest = new MLLoadModelRequest(modelId, allMLNodeIdList.toArray(new String[] {}), false, false);
+
+        client
+            .execute(
+                MLLoadModelAction.INSTANCE,
+                mlLoadModelRequest,
+                ActionListener
+                    .wrap(response -> log.info("the model {} is auto reloading under the node {} ", modelId, localNodeId), exception -> {
+                        log.error("fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: " + exception);
+                        throw new MLException(
+                            "fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: " + exception
+                        );
+                    })
+            );
+    }
+
+    /**
+     * Searches the task index for the newest task (by "create_time") whose "task_type" is "LOAD_MODEL", whose
+     * "worker_node" matches the node id and whose "state" is "COMPLETED" or "COMPLETED_WITH_ERROR".
+     *
+     * @param localNodeId                  node id the task must have run on
+     * @param searchResponseActionListener receives the search response; a search failure throws MLException instead
+     */
+    @VisibleForTesting
+    void queryTask(String localNodeId, ActionListener<SearchResponse> searchResponseActionListener) {
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().from(0).size(1);
+
+        QueryBuilder queryBuilder = QueryBuilders
+            .boolQuery()
+            .must(QueryBuilders.matchPhraseQuery("task_type", "LOAD_MODEL"))
+            .must(QueryBuilders.matchPhraseQuery("worker_node", localNodeId))
+            .must(
+                QueryBuilders
+                    .boolQuery()
+                    .should(QueryBuilders.matchPhraseQuery("state", "COMPLETED"))
+                    .should(QueryBuilders.matchPhraseQuery("state", "COMPLETED_WITH_ERROR"))
+            );
+        searchSourceBuilder.query(queryBuilder);
+
+        SortBuilder<FieldSortBuilder> sortBuilderOrder = new FieldSortBuilder("create_time").order(SortOrder.DESC);
+        searchSourceBuilder.sort(sortBuilderOrder);
+
+        SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
+            .setIndices(ML_TASK_INDEX)
+            .setSource(searchSourceBuilder);
+
+        searchRequestBuilder.execute(ActionListener.wrap(searchResponseActionListener::onResponse, exception -> {
+            log.error("index {} not found, the reason is {}", ML_TASK_INDEX, exception);
+            throw new MLException("index " + ML_TASK_INDEX + " not found");
+        }));
+    }
+
+    /**
+     * Reads the stored retry counter document of one ML node from the index ".plugins-ml-model-reload".
+     * The listener receives {@code null} when the index does not exist or holds no document for the node.
+     *
+     * @param localNodeId                  id of the document to look up
+     * @param searchResponseActionListener receives the search response, {@code null}, or the search failure
+     */
+    @VisibleForTesting
+    void getRetryTimes(String localNodeId, ActionListener<SearchResponse> searchResponseActionListener) {
+        if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_RELOAD_INDEX)) {
+            // ML_MODEL_RELOAD_INDEX did not exist, it means it is our first time to do model auto-reloading operation;
+            // return here so the missing index is not searched and the listener is not notified a second time
+            searchResponseActionListener.onResponse(null);
+            return;
+        }
+
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.fetchSource(new String[] { MODEL_LOAD_RETRY_TIMES_FIELD }, null);
+        QueryBuilder queryBuilder = QueryBuilders.idsQuery().addIds(localNodeId);
+        searchSourceBuilder.query(queryBuilder);
+
+        SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
+            .setIndices(ML_MODEL_RELOAD_INDEX)
+            .setSource(searchSourceBuilder);
+
+        searchRequestBuilder.execute(ActionListener.wrap(searchResponse -> {
+            SearchHit[] hits = searchResponse.getHits().getHits();
+            if (CollectionUtils.isEmpty(hits)) {
+                searchResponseActionListener.onResponse(null);
+                return;
+            }
+
+            searchResponseActionListener.onResponse(searchResponse);
+        }, searchResponseActionListener::onFailure));
+    }
+
+    /**
+     * Stores the retry counter of one node in the index ".plugins-ml-model-reload", using the node id as document id.
+     * The listener is always completed: with the response on CREATED/OK, with an MLException otherwise.
+     *
+     * @param localNodeId                 node id
+     * @param retryTimes                  actual retry times
+     * @param indexResponseActionListener receives the index response or the failure
+     */
+    @VisibleForTesting
+    void saveLatestRetryTimes(String localNodeId, int retryTimes, ActionListener<IndexResponse> indexResponseActionListener) {
+        Map<String, Object> content = new HashMap<>(2);
+        content.put(NODE_ID_FIELD, localNodeId);
+        content.put(MODEL_LOAD_RETRY_TIMES_FIELD, retryTimes);
+
+        IndexRequestBuilder indexRequestBuilder = new IndexRequestBuilder(client, IndexAction.INSTANCE, ML_MODEL_RELOAD_INDEX)
+            .setId(localNodeId)
+            .setSource(content)
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+
+        indexRequestBuilder.execute(ActionListener.wrap(indexResponse -> {
+            if (indexResponse.status() == RestStatus.CREATED || indexResponse.status() == RestStatus.OK) {
+                log.info("node id:{} insert retry times successfully", localNodeId);
+                indexResponseActionListener.onResponse(indexResponse);
+            } else {
+                // any other status still has to complete the listener, otherwise the caller waits forever
+                indexResponseActionListener.onFailure(
+                    new MLException("node id:" + localNodeId + " insert retry times unsuccessfully, status: " + indexResponse.status())
+                );
+            }
+        }, e -> {
+            log.error("node id:" + localNodeId + " insert retry times unsuccessfully", e);
+            indexResponseActionListener.onFailure(new MLException(e));
+        }));
+    }
+}
diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java
b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 2c44a7aa0d..97fe0baafb 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -99,6 +99,7 @@ import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.model.MLModelAutoReloader; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.rest.RestMLCreateModelMetaAction; @@ -175,6 +176,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin { private MLModelChunkUploader mlModelChunkUploader; private MLEngine mlEngine; + private MLModelAutoReloader mlModelAutoReloader; + private Client client; private ClusterService clusterService; private ThreadPool threadPool; @@ -352,6 +355,9 @@ public Collection createComponents( mlIndicesHandler ); + mlModelAutoReloader = new MLModelAutoReloader(clusterService, threadPool, client, xContentRegistry, nodeHelper, settings); + mlModelAutoReloader.autoReloadModel(); + return ImmutableList .of( mlEngine, @@ -373,7 +379,8 @@ public Collection createComponents( modelHelper, mlCommonsClusterEventListener, clusterManagerEventListener, - mlCircuitBreakerService + mlCircuitBreakerService, + mlModelAutoReloader ); } @@ -513,7 +520,9 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE, MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE, MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX, - MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD + MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD, + MLCommonsSettings.ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE, + MLCommonsSettings.ML_MODEL_RELOAD_MAX_RETRY_TIMES ); return settings; } diff --git 
a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 3bae6ea92a..bed1384fe3 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -56,4 +56,10 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_NATIVE_MEM_THRESHOLD = Setting .intSetting("plugins.ml_commons.native_memory_threshold", 90, 0, 100, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE = Setting + .boolSetting("plugins.ml_commons.model.autoreload.enable", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_MODEL_RELOAD_MAX_RETRY_TIMES = Setting + .intSetting("plugins.ml_commons.model.autoreload.retrytimes", 3, Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelAutoReloaderITTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelAutoReloaderITTests.java new file mode 100644 index 0000000000..c06809b166 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelAutoReloaderITTests.java @@ -0,0 +1,455 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.model; + +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_RELOAD_INDEX; +import 
static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.MODEL_LOAD_RETRY_TIMES_FIELD; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Base64; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.StepListener; +import org.opensearch.action.admin.indices.create.CreateIndexAction; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.index.IndexAction; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.ml.action.MLCommonsIntegTestCase; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import 
org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.test.OpenSearchIntegTestCase; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 3) +public class MLModelAutoReloaderITTests extends MLCommonsIntegTestCase { + private final Instant time = Instant.now(); + + @Mock + private DiscoveryNodeHelper nodeHelper; + private Settings settings; + private MLModelAutoReloader mlModelAutoReloader; + private String taskId; + private String modelId; + private String localNodeId; + @Mock + private MLModelManager modelManager; + private MLModel modelChunk0; + private MLModel modelChunk1; + + @Before + public void setup() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + taskId = "taskId1"; + modelId = "modelId1"; + + settings = Settings.builder().put(ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE.getKey(), true).build(); + + localNodeId = clusterService().localNode().getId(); + + AtomicInteger portGenerator = new AtomicInteger(); + Set roleSet = new HashSet<>(); + roleSet.add(TestHelper.ML_ROLE); + DiscoveryNode node = new DiscoveryNode( + localNodeId, + new TransportAddress(TransportAddress.META_ADDRESS, portGenerator.incrementAndGet()), + new HashMap<>(), + roleSet, + Version.CURRENT + ); + + nodeHelper = spy(new DiscoveryNodeHelper(clusterService(), settings)); + + mlModelAutoReloader = spy( + new MLModelAutoReloader(clusterService(), client().threadPool(), client(), xContentRegistry(), nodeHelper, settings) + ); + modelManager = mock(MLModelManager.class); + + 
when(nodeHelper.getEligibleNodes()).thenReturn(new DiscoveryNode[] { node }); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + + settings = null; + nodeHelper = null; + mlModelAutoReloader = null; + modelId = null; + localNodeId = null; + modelManager = null; + } + + public void testAutoReloadModel() { + mlModelAutoReloader.autoReloadModel(); + } + + public void testAutoReloadModel_setting_false() { + settings = Settings.builder().put(ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE.getKey(), false).build(); + + nodeHelper = spy(new DiscoveryNodeHelper(clusterService(), settings)); + + mlModelAutoReloader = spy( + new MLModelAutoReloader(clusterService(), client().threadPool(), client(), xContentRegistry(), nodeHelper, settings) + ); + + mlModelAutoReloader.autoReloadModel(); + } + + public void testAutoReloadModel_Is_Not_ML_Node() { + mlModelAutoReloader.autoReloadModel(); + } + + public void testAutoReloadModelByNodeId_Retry() throws IOException, ExecutionException, InterruptedException { + createIndex(ML_MODEL_RELOAD_INDEX); + initDataOfMlTask(localNodeId, modelId, MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + + mlModelAutoReloader.autoReloadModelByNodeId(localNodeId); + + StepListener getRetryTimesStep = new StepListener<>(); + mlModelAutoReloader.getRetryTimes(localNodeId, ActionListener.wrap(getRetryTimesStep::onResponse, getRetryTimesStep::onFailure)); + } + + public void testAutoReloadModelByNodeId_Max_ReTryTimes() throws IOException, ExecutionException, InterruptedException { + initDataOfMlTask(localNodeId, modelId, MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + StepListener saveLatestReTryTimesStep = new StepListener<>(); + mlModelAutoReloader + .saveLatestRetryTimes( + localNodeId, + 3, + ActionListener.wrap(saveLatestReTryTimesStep::onResponse, saveLatestReTryTimesStep::onFailure) + ); + + saveLatestReTryTimesStep.whenComplete(response -> { + 
inProgressLatch.countDown(); + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.status(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.status().getStatus(), anyOf(is(201), is(200))); + }, exception -> fail(exception.getMessage())); + + inProgressLatch.await(); + + mlModelAutoReloader.autoReloadModelByNodeId(localNodeId); + } + + public void testAutoReloadModelByNodeId_IndexNotFound() throws ExecutionException, InterruptedException { + mlModelAutoReloader.autoReloadModelByNodeId(localNodeId); + } + + public void testAutoReloadModelByNodeId_EmptyHits() throws ExecutionException, InterruptedException { + createIndex(); + + mlModelAutoReloader.autoReloadModelByNodeId(localNodeId); + } + + public void testAutoReloadModelByNodeAndModelId_Exception() { + Throwable exception = Assert + .assertThrows(RuntimeException.class, () -> mlModelAutoReloader.autoReloadModelByNodeAndModelId(localNodeId, modelId)); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + containsString("fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: ") + ); + } + + public void testQueryTask() throws IOException { + StepListener queryTaskStep = queryTask(localNodeId, modelId, MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + + queryTaskStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(1)); + }, exception -> fail(exception.getMessage())); + + queryTaskStep = queryTask(localNodeId, modelId, MLTaskType.UPLOAD_MODEL, MLTaskState.COMPLETED); + + queryTaskStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + 
org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(0)); + }, exception -> fail(exception.getMessage())); + + queryTaskStep = queryTask(localNodeId, modelId, MLTaskType.LOAD_MODEL, MLTaskState.RUNNING); + + queryTaskStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(0)); + }, exception -> fail(exception.getMessage())); + } + + public void testQueryTask_MultiDataInTaskIndex() throws IOException { + StepListener queryTaskStep = queryTask(localNodeId, "modelId1", MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + queryTaskStep.whenComplete(response -> {}, exception -> fail(exception.getMessage())); + + queryTaskStep = queryTask(localNodeId, "modelId2", MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + queryTaskStep.whenComplete(response -> {}, exception -> fail(exception.getMessage())); + + queryTaskStep = queryTask(localNodeId, "modelId3", MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + + queryTaskStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(1)); + + Map source = response.getHits().getHits()[0].getSourceAsMap(); + org.hamcrest.MatcherAssert.assertThat(source.get("model_id"), is("modelId3")); + }, exception -> fail(exception.getMessage())); + } + + public void 
testQueryTask_MultiDataInTaskIndex_TaskState_COMPLETED_WITH_ERROR() throws IOException { + createIndex(ML_MODEL_RELOAD_INDEX); + initDataOfMlTask(localNodeId, "modelId3", MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED_WITH_ERROR); + + StepListener queryTaskStep = queryTask( + localNodeId, + "modelId3", + MLTaskType.LOAD_MODEL, + MLTaskState.COMPLETED_WITH_ERROR + ); + + queryTaskStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(1)); + + Map source = response.getHits().getHits()[0].getSourceAsMap(); + org.hamcrest.MatcherAssert.assertThat(source.get("model_id"), is("modelId3")); + }, exception -> fail(exception.getMessage())); + } + + public void testQueryTask_IndexNotExisted() { + StepListener queryTaskStep = new StepListener<>(); + + Throwable exception = Assert + .assertThrows( + MLException.class, + () -> mlModelAutoReloader.queryTask(localNodeId, ActionListener.wrap(queryTaskStep::onResponse, queryTaskStep::onFailure)) + ); + org.hamcrest.MatcherAssert.assertThat(exception.getMessage(), containsString("index " + ML_TASK_INDEX + " not found")); + } + + public void testGetReTryTimes() throws InterruptedException { + assertLatestReTryTimes(1); + assertLatestReTryTimes(3); + } + + public void testGetReTryTimes_IndexNotExisted() { + StepListener getReTryTimesStep = new StepListener<>(); + + mlModelAutoReloader.getRetryTimes(localNodeId, ActionListener.wrap(getReTryTimesStep::onResponse, getReTryTimesStep::onFailure)); + } + + public void testGetReTryTimes_EmptyHits() { + createIndex(ML_MODEL_RELOAD_INDEX); + + StepListener getReTryTimesStep = new StepListener<>(); + mlModelAutoReloader.getRetryTimes(localNodeId, ActionListener.wrap(getReTryTimesStep::onResponse, 
getReTryTimesStep::onFailure)); + + getReTryTimesStep.whenComplete(response -> {}, exception -> { + org.hamcrest.MatcherAssert.assertThat(exception.getClass(), is(RuntimeException.class)); + org.hamcrest.MatcherAssert.assertThat(exception.getMessage(), containsString("can't get retryTimes from node")); + }); + } + + public void testSaveLatestReTryTimes() throws InterruptedException { + assertLatestReTryTimes(0); + assertLatestReTryTimes(1); + assertLatestReTryTimes(3); + } + + private void createIndex() { + CreateIndexRequest createIndexRequest = new CreateIndexRequest(ML_TASK_INDEX); + createIndexRequest.mapping(ML_TASK_INDEX_MAPPING); + + client().execute(CreateIndexAction.INSTANCE, createIndexRequest).actionGet(5000); + } + + private void initDataOfMlTask(String nodeId, String modelId, MLTaskType mlTaskType, MLTaskState mlTaskState) throws IOException { + MLTask mlTask = MLTask + .builder() + .taskId(taskId) + .modelId(modelId) + .taskType(mlTaskType) + .state(mlTaskState) + .workerNodes(List.of(nodeId)) + .progress(0.0f) + .outputIndex("test_index") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .async(true) + .lastUpdateTime(time) + .build(); + + IndexRequest indexRequest = new IndexRequest(ML_TASK_INDEX); + indexRequest.id(taskId); + indexRequest.version(); + indexRequest.source(mlTask.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client().execute(IndexAction.INSTANCE, indexRequest).actionGet(5000); + } + + private void assertLatestReTryTimes(int retryTimes) throws InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + StepListener saveLatestReTryTimesStep = new StepListener<>(); + mlModelAutoReloader + .saveLatestRetryTimes( + localNodeId, + retryTimes, + ActionListener.wrap(saveLatestReTryTimesStep::onResponse, saveLatestReTryTimesStep::onFailure) + ); + + 
saveLatestReTryTimesStep.whenComplete(response -> { + inProgressLatch.countDown(); + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.status(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.status().getStatus(), anyOf(is(201), is(200))); + }, exception -> fail(exception.getMessage())); + + inProgressLatch.await(); + + StepListener getReTryTimesStep = new StepListener<>(); + mlModelAutoReloader.getRetryTimes(localNodeId, ActionListener.wrap(getReTryTimesStep::onResponse, getReTryTimesStep::onFailure)); + + getReTryTimesStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(1)); + + Map sourceAsMap = response.getHits().getHits()[0].getSourceAsMap(); + int result = (Integer) sourceAsMap.get(MODEL_LOAD_RETRY_TIMES_FIELD); + org.hamcrest.MatcherAssert.assertThat(result, is(retryTimes)); + }, exception -> fail(exception.getMessage())); + } + + private StepListener queryTask(String localNodeId, String modelId, MLTaskType mlTaskType, MLTaskState mlTaskState) + throws IOException { + initDataOfMlTask(localNodeId, modelId, mlTaskType, mlTaskState); + + StepListener queryTaskStep = new StepListener<>(); + mlModelAutoReloader.queryTask(localNodeId, ActionListener.wrap(queryTaskStep::onResponse, queryTaskStep::onFailure)); + + return queryTaskStep; + } + + private void initDataOfMlModel(String modelId) throws IOException { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + + MLModel mlModel = MLModel + .builder() + .modelId(modelId) + .name("model_name") + 
.algorithm(FunctionName.TEXT_EMBEDDING) + .version("1.0.0") + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelState(MLModelState.LOADED) + .modelConfig(modelConfig) + .totalChunks(2) + .modelContentHash("c446f747520bcc6af053813cb1e8d34944a7c4686bbb405aeaa23883b5a806c8") + .modelContentSizeInBytes(1000L) + .createdTime(time.minus(1, ChronoUnit.MINUTES)) + .build(); + modelChunk0 = mlModel + .toBuilder() + .content(Base64.getEncoder().encodeToString("test chunk1".getBytes(StandardCharsets.UTF_8))) + .build(); + modelChunk1 = mlModel + .toBuilder() + .content(Base64.getEncoder().encodeToString("test chunk2".getBytes(StandardCharsets.UTF_8))) + .build(); + + setUpMock_GetModelChunks(mlModel); + + IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); + indexRequest.id(modelId); + indexRequest.version(); + indexRequest.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client().execute(IndexAction.INSTANCE, indexRequest).actionGet(5000); + } + + private void setUpMock_GetModelChunks(MLModel model) { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(model); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(modelChunk0); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(modelChunk1); + return null; + }).when(modelManager).getModel(any(), any()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelAutoReloaderTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelAutoReloaderTests.java new file mode 100644 index 0000000000..61c13a2c41 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelAutoReloaderTests.java @@ -0,0 +1,693 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.model; + +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_RELOAD_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.common.CommonValue.MODEL_LOAD_RETRY_TIMES_FIELD; +import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_MODEL_RELOAD_MAX_RETRY_TIMES; +import static org.opensearch.ml.utils.TestHelper.ML_ROLE; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Base64; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import 
java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.After; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.StepListener; +import org.opensearch.action.admin.indices.create.CreateIndexAction; +import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.index.IndexAction; +import org.opensearch.action.index.IndexRequestBuilder; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.collect.ImmutableOpenMap; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.ml.breaker.MLCircuitBreakerService; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.FunctionName; +import 
org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.stats.MLNodeLevelStat; +import org.opensearch.ml.stats.MLStat; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.ml.stats.suppliers.CounterSupplier; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +public class MLModelAutoReloaderTests extends OpenSearchTestCase { + + private final Instant time = Instant.now(); + private static final AtomicInteger portGenerator = new AtomicInteger(); + private ClusterState testState; + private DiscoveryNode node; + @Mock + private ClusterService clusterService; + @Mock + private Client client; + @Mock + private ThreadPool threadPool; + private NamedXContentRegistry xContentRegistry; + @Mock + private ModelHelper modelHelper; + private Settings settings; + @Mock + private MLCircuitBreakerService mlCircuitBreakerService; + @Mock + private MLIndicesHandler mlIndicesHandler; + @Mock + private MLTaskManager mlTaskManager; + private MLModelManager modelManager; + @Mock + private ExecutorService taskExecutorService; + private ThreadContext threadContext; + private String taskId; + private String modelId; + private String localNodeId; + private MLModel modelChunk0; + private MLModel modelChunk1; + @Mock + private MLModelCacheHelper modelCacheHelper; + @Mock + private DiscoveryNodeHelper nodeHelper; + private MLModelAutoReloader 
mlModelAutoReloader; + + @Before + public void setup() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + MLEngine mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10))); + settings = Settings + .builder() + .put(ML_COMMONS_MAX_MODELS_PER_NODE.getKey(), 10) + .put(ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE.getKey(), 10) + .put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), 10) + .put(ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE.getKey(), true) + .put(ML_MODEL_RELOAD_MAX_RETRY_TIMES.getKey(), 3) + .put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), true) + .put(ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE.getKey(), 1) + .build(); + DiscoveryNode localNode = setupTestMLNode(); + when(clusterService.localNode()).thenReturn(localNode); + testState = setupTestClusterState(localNode); + + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_MAX_MODELS_PER_NODE, + ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE, + ML_COMMONS_MONITORING_REQUEST_COUNT, + ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE, + ML_MODEL_RELOAD_MAX_RETRY_TIMES, + ML_COMMONS_ONLY_RUN_ON_ML_NODE, + ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.state()).thenReturn(testState); + + xContentRegistry = NamedXContentRegistry.EMPTY; + + taskId = "taskId1"; + String modelName = "model_name1"; + modelId = randomAlphaOfLength(10); + String modelContentHashValue = "c446f747520bcc6af053813cb1e8d34944a7c4686bbb405aeaa23883b5a806c8"; + String version = "1.0.0"; + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + + Map> stats = new ConcurrentHashMap<>(); + // node level stats + stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new 
MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + MLStats mlStats = spy(new MLStats(stats)); + + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(taskExecutorService).execute(any()); + + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(LOAD_THREAD_POOL)).thenReturn(taskExecutorService); + + modelManager = spy( + new MLModelManager( + clusterService, + client, + threadPool, + xContentRegistry, + modelHelper, + settings, + mlStats, + mlCircuitBreakerService, + mlIndicesHandler, + mlTaskManager, + modelCacheHelper, + mlEngine + ) + ); + nodeHelper = spy(new DiscoveryNodeHelper(clusterService, settings)); + + mlModelAutoReloader = spy( + new MLModelAutoReloader(clusterService, client.threadPool(), client, xContentRegistry, nodeHelper, settings) + ); + + Long modelContentSize = 1000L; + MLModel model = MLModel + .builder() + .modelId(modelId) + .modelState(MLModelState.UPLOADED) + .algorithm(FunctionName.TEXT_EMBEDDING) + .name(modelName) + .version(version) + .totalChunks(2) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .modelContentHash(modelContentHashValue) + .modelContentSizeInBytes(modelContentSize) + .build(); + modelChunk0 = model.toBuilder().content(Base64.getEncoder().encodeToString("test chunk1".getBytes(StandardCharsets.UTF_8))).build(); + modelChunk1 = model.toBuilder().content(Base64.getEncoder().encodeToString("test chunk2".getBytes(StandardCharsets.UTF_8))).build(); + + localNodeId = 
clusterService.localNode().getId(); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + + settings = null; + nodeHelper = null; + mlModelAutoReloader = null; + modelId = null; + localNodeId = null; + modelManager = null; + } + + public void testAutoReloadModel() { + mlModelAutoReloader.autoReloadModel(); + } + + public void testAutoReloadModel_setting_false() { + settings = Settings + .builder() + .put(ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE.getKey(), false) + .put(ML_MODEL_RELOAD_MAX_RETRY_TIMES.getKey(), 3) + .put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), true) + .build(); + + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE, + ML_MODEL_RELOAD_MAX_RETRY_TIMES, + ML_COMMONS_ONLY_RUN_ON_ML_NODE + ); + clusterService = spy(new ClusterService(settings, clusterSettings, null)); + nodeHelper = spy(new DiscoveryNodeHelper(clusterService, settings)); + mlModelAutoReloader = spy(new MLModelAutoReloader(clusterService, threadPool, client, xContentRegistry, nodeHelper, settings)); + + mlModelAutoReloader.autoReloadModel(); + } + + public void testAutoReloadModel_Is_Not_ML_Node() { + settings = Settings + .builder() + .put(ML_COMMONS_MAX_MODELS_PER_NODE.getKey(), 10) + .put(ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE.getKey(), 10) + .put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), 10) + .put(ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE.getKey(), true) + .put(ML_MODEL_RELOAD_MAX_RETRY_TIMES.getKey(), 3) + .put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), true) + .build(); + DiscoveryNode localNode = setupTestDataNode(); + when(clusterService.localNode()).thenReturn(localNode); + testState = setupTestClusterState(localNode); + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_MAX_MODELS_PER_NODE, + ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE, + ML_COMMONS_MONITORING_REQUEST_COUNT, + ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE, + ML_MODEL_RELOAD_MAX_RETRY_TIMES, + ML_COMMONS_ONLY_RUN_ON_ML_NODE + ); + 
when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.state()).thenReturn(testState); + + xContentRegistry = NamedXContentRegistry.EMPTY; + + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(taskExecutorService).execute(any()); + + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.generic()).thenReturn(taskExecutorService); + + nodeHelper = spy(new DiscoveryNodeHelper(clusterService, settings)); + + mlModelAutoReloader = spy( + new MLModelAutoReloader(clusterService, client.threadPool(), client, xContentRegistry, nodeHelper, settings) + ); + + mlModelAutoReloader.autoReloadModel(); + } + + public void testAutoReloadModelByNodeId() throws IOException, ExecutionException, InterruptedException { + initDataOfMlTask(localNodeId, modelId, MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + initDataOfMlModel(modelId); + + mlModelAutoReloader.autoReloadModelByNodeId(localNodeId); + + StepListener getReTryTimesStep = new StepListener<>(); + mlModelAutoReloader.getRetryTimes(localNodeId, ActionListener.wrap(getReTryTimesStep::onResponse, getReTryTimesStep::onFailure)); + } + + public void testAutoReloadModelByNodeId_ReTry() throws IOException, ExecutionException, InterruptedException { + initDataOfMlTask(localNodeId, modelId, MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + + mlModelAutoReloader.autoReloadModelByNodeId(localNodeId); + + StepListener getReTryTimesStep = new StepListener<>(); + mlModelAutoReloader.getRetryTimes(localNodeId, ActionListener.wrap(getReTryTimesStep::onResponse, getReTryTimesStep::onFailure)); + } + + public void testAutoReloadModelByNodeId_Max_ReTryTimes() throws IOException, ExecutionException, InterruptedException { + initDataOfMlTask(localNodeId, modelId, MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + + 
StepListener saveLatestReTryTimesStep = new StepListener<>(); + mlModelAutoReloader + .saveLatestRetryTimes( + localNodeId, + 3, + ActionListener.wrap(saveLatestReTryTimesStep::onResponse, saveLatestReTryTimesStep::onFailure) + ); + + saveLatestReTryTimesStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.status(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.status().getStatus(), anyOf(is(201), is(200))); + }, exception -> fail(exception.getMessage())); + + mlModelAutoReloader.autoReloadModelByNodeId(localNodeId); + } + + public void testAutoReloadModelByNodeId_IndexNotFound() throws ExecutionException, InterruptedException { + mlModelAutoReloader.autoReloadModelByNodeId(localNodeId); + } + + public void testAutoReloadModelByNodeId_EmptyHits() throws ExecutionException, InterruptedException { + createIndex(ML_TASK_INDEX); + + mlModelAutoReloader.autoReloadModelByNodeId(localNodeId); + } + + public void testAutoReloadModelByNodeAndModelId() throws IOException { + initDataOfMlTask(localNodeId, modelId, MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + initDataOfMlModel(modelId); + mlModelAutoReloader.autoReloadModelByNodeAndModelId(localNodeId, modelId); + } + + public void testAutoReloadModelByNodeAndModelId_Exception() { + mlModelAutoReloader.autoReloadModelByNodeAndModelId(localNodeId, modelId); + } + + public void testQueryTask() throws IOException { + StepListener queryTaskStep = queryTask(localNodeId, modelId, MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + + queryTaskStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(1)); + }, exception -> 
fail(exception.getMessage())); + + queryTaskStep = queryTask(localNodeId, modelId, MLTaskType.UPLOAD_MODEL, MLTaskState.COMPLETED); + + queryTaskStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(0)); + }, exception -> fail(exception.getMessage())); + + queryTaskStep = queryTask(localNodeId, modelId, MLTaskType.LOAD_MODEL, MLTaskState.RUNNING); + + queryTaskStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(0)); + }, exception -> fail(exception.getMessage())); + } + + public void testQueryTask_MultiDataInTaskIndex() throws IOException { + StepListener queryTaskStep = queryTask(localNodeId, "modelId1", MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + queryTaskStep.whenComplete(response -> {}, exception -> fail(exception.getMessage())); + + queryTaskStep = queryTask(localNodeId, "modelId2", MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + queryTaskStep.whenComplete(response -> {}, exception -> fail(exception.getMessage())); + + queryTaskStep = queryTask(localNodeId, "modelId3", MLTaskType.LOAD_MODEL, MLTaskState.COMPLETED); + + queryTaskStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(1)); + + Map source = 
response.getHits().getHits()[0].getSourceAsMap(); + org.hamcrest.MatcherAssert.assertThat(source.get("model_id"), is("modelId3")); + }, exception -> fail(exception.getMessage())); + } + + public void testQueryTask_MultiDataInTaskIndex_TaskState_COMPLETED_WITH_ERROR() throws IOException { + StepListener queryTaskStep = queryTask( + localNodeId, + "modelId3", + MLTaskType.LOAD_MODEL, + MLTaskState.COMPLETED_WITH_ERROR + ); + + queryTaskStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(1)); + + Map source = response.getHits().getHits()[0].getSourceAsMap(); + org.hamcrest.MatcherAssert.assertThat(source.get("model_id"), is("modelId3")); + }, exception -> fail(exception.getMessage())); + } + + public void testQueryTask_IndexNotExisted() { + StepListener queryTaskStep = new StepListener<>(); + + mlModelAutoReloader.queryTask(localNodeId, ActionListener.wrap(queryTaskStep::onResponse, queryTaskStep::onFailure)); + + queryTaskStep.whenComplete(response -> {}, exception -> { + org.hamcrest.MatcherAssert.assertThat(exception.getClass(), is(MLException.class)); + org.hamcrest.MatcherAssert.assertThat(exception.getMessage(), containsString("index " + ML_TASK_INDEX + " not found")); + }); + } + + public void testGetReTryTimes() { + assertLatestReTryTimes(1); + assertLatestReTryTimes(3); + } + + public void testGetReTryTimes_IndexNotExisted() { + StepListener getReTryTimesStep = new StepListener<>(); + + mlModelAutoReloader.getRetryTimes(localNodeId, ActionListener.wrap(getReTryTimesStep::onResponse, getReTryTimesStep::onFailure)); + } + + public void testGetReTryTimes_EmptyHits() { + createIndex(ML_MODEL_RELOAD_INDEX); + + StepListener getReTryTimesStep = new StepListener<>(); + 
mlModelAutoReloader.getRetryTimes(localNodeId, ActionListener.wrap(getReTryTimesStep::onResponse, getReTryTimesStep::onFailure)); + + getReTryTimesStep.whenComplete(response -> {}, exception -> { + org.hamcrest.MatcherAssert.assertThat(exception.getClass(), is(RuntimeException.class)); + org.hamcrest.MatcherAssert.assertThat(exception.getMessage(), containsString("can't get retryTimes from node")); + }); + } + + public void testSaveLatestReTryTimes() { + assertLatestReTryTimes(0); + assertLatestReTryTimes(1); + assertLatestReTryTimes(3); + } + + private void createIndex(String indexName) { + StepListener actionListener = new StepListener<>(); + CreateIndexRequestBuilder requestBuilder = new CreateIndexRequestBuilder(client, CreateIndexAction.INSTANCE, indexName); + + requestBuilder.execute(ActionListener.wrap(actionListener::onResponse, actionListener::onFailure)); + } + + private void setUpMock_GetModelChunks(MLModel model) { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(model); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(modelChunk0); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(modelChunk1); + return null; + }).when(modelManager).getModel(any(), any()); + } + + private DiscoveryNode setupTestMLNode() { + Set roleSet = new HashSet<>(); + roleSet.add(ML_ROLE); + node = new DiscoveryNode( + "node", + new TransportAddress(TransportAddress.META_ADDRESS, portGenerator.incrementAndGet()), + new HashMap<>(), + roleSet, + Version.CURRENT + ); + + return node; + } + + private DiscoveryNode setupTestDataNode() { + Set roleSet = new HashSet<>(); + roleSet.add(DiscoveryNodeRole.DATA_ROLE); + node = new DiscoveryNode( + "node", + new TransportAddress(TransportAddress.META_ADDRESS, portGenerator.incrementAndGet()), + new HashMap<>(), + roleSet, + Version.CURRENT + ); 
+ + return node; + } + + private ClusterState setupTestClusterState(DiscoveryNode discoveryNode) { + Metadata metadata = new Metadata.Builder() + .indices( + ImmutableOpenMap + .builder() + .fPut( + ML_MODEL_RELOAD_INDEX, + IndexMetadata + .builder("test") + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ) + .build() + ) + .build(); + return new ClusterState( + new ClusterName("clusterName"), + 123L, + "111111", + metadata, + null, + DiscoveryNodes.builder().add(discoveryNode).localNodeId(discoveryNode.getId()).build(), + null, + null, + 0, + false + ); + } + + private void assertLatestReTryTimes(int retryTimes) { + StepListener saveLatestReTryTimesStep = new StepListener<>(); + mlModelAutoReloader + .saveLatestRetryTimes( + localNodeId, + retryTimes, + ActionListener.wrap(saveLatestReTryTimesStep::onResponse, saveLatestReTryTimesStep::onFailure) + ); + + saveLatestReTryTimesStep.whenComplete(response -> { + // inProgressLatch.countDown(); + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.status(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.status().getStatus(), anyOf(is(201), is(200))); + }, exception -> fail(exception.getMessage())); + + StepListener getReTryTimesStep = new StepListener<>(); + mlModelAutoReloader.getRetryTimes(localNodeId, ActionListener.wrap(getReTryTimesStep::onResponse, getReTryTimesStep::onFailure)); + + getReTryTimesStep.whenComplete(response -> { + org.hamcrest.MatcherAssert.assertThat(response, notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits(), notNullValue()); + org.hamcrest.MatcherAssert.assertThat(response.getHits().getHits().length, is(1)); + + Map sourceAsMap = 
response.getHits().getHits()[0].getSourceAsMap(); + int result = (Integer) sourceAsMap.get(MODEL_LOAD_RETRY_TIMES_FIELD); + org.hamcrest.MatcherAssert.assertThat(result, is(retryTimes)); + }, exception -> fail(exception.getMessage())); + } + + private StepListener queryTask(String localNodeId, String modelId, MLTaskType mlTaskType, MLTaskState mlTaskState) + throws IOException { + initDataOfMlTask(localNodeId, modelId, mlTaskType, mlTaskState); + + StepListener queryTaskStep = new StepListener<>(); + mlModelAutoReloader.queryTask(localNodeId, ActionListener.wrap(queryTaskStep::onResponse, queryTaskStep::onFailure)); + + return queryTaskStep; + } + + private void initDataOfMlTask(String nodeId, String modelId, MLTaskType mlTaskType, MLTaskState mlTaskState) throws IOException { + MLTask mlTask = MLTask + .builder() + .taskId(taskId) + .modelId(modelId) + .taskType(mlTaskType) + .state(mlTaskState) + .workerNodes(List.of(nodeId)) + .progress(0.0f) + .outputIndex("test_index") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .async(true) + .lastUpdateTime(time) + .build(); + + StepListener actionListener = new StepListener<>(); + IndexRequestBuilder requestBuilder = new IndexRequestBuilder(client, IndexAction.INSTANCE, ML_TASK_INDEX); + requestBuilder.setId(taskId); + requestBuilder.setSource(mlTask.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + requestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + requestBuilder.execute(ActionListener.wrap(actionListener::onResponse, actionListener::onFailure)); + } + + private void initDataOfMlModel(String modelId) throws IOException { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(384) + .build(); + + MLModel mlModel = MLModel + .builder() + .modelId(modelId) + .name("model_name") + .algorithm(FunctionName.TEXT_EMBEDDING) 
+ .version("1.0.0") + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelState(MLModelState.LOADED) + .modelConfig(modelConfig) + .totalChunks(2) + .modelContentHash("c446f747520bcc6af053813cb1e8d34944a7c4686bbb405aeaa23883b5a806c8") + .modelContentSizeInBytes(1000L) + .createdTime(time.minus(1, ChronoUnit.MINUTES)) + .build(); + modelChunk0 = mlModel + .toBuilder() + .content(Base64.getEncoder().encodeToString("test chunk1".getBytes(StandardCharsets.UTF_8))) + .build(); + modelChunk1 = mlModel + .toBuilder() + .content(Base64.getEncoder().encodeToString("test chunk2".getBytes(StandardCharsets.UTF_8))) + .build(); + + setUpMock_GetModelChunks(mlModel); + + StepListener actionListener = new StepListener<>(); + IndexRequestBuilder requestBuilder = new IndexRequestBuilder(client, IndexAction.INSTANCE, ML_MODEL_INDEX); + requestBuilder.setId(modelId); + requestBuilder.setSource(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + requestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + requestBuilder.execute(ActionListener.wrap(actionListener::onResponse, actionListener::onFailure)); + } + +}