forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
the dev of [FEATURE]Auto reload model when cluster rebooted/node rejo…
…in (opensearch-project#711) * [wjunshen] #N/A feat: fix after the latest rebase Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: fix after rebase Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: fix after rebase Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: fix after rebase Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: fix after the latest rebase Signed-off-by: wujunshen <[email protected]> * Increment version to 2.6.0-SNAPSHOT (opensearch-project#671) Signed-off-by: opensearch-ci-bot <[email protected]> Signed-off-by: opensearch-ci-bot <[email protected]> Co-authored-by: opensearch-ci-bot <[email protected]> Signed-off-by: wujunshen <[email protected]> * fix profile API in example doc (opensearch-project#712) Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: wujunshen <[email protected]> * change model url to public repo in text embedding model example doc (opensearch-project#713) Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: wujunshen <[email protected]> * Enhance profile API to add model centric result controlled by view parameter (opensearch-project#714) * Enhance profile API to add model centric result controled by view paramter Signed-off-by: Zan Niu <[email protected]> * Enhance profile API to add model centric result controled by view parameter Signed-off-by: Zan Niu <[email protected]> * Enhance profile API to add model centric result controled by view parameter Signed-off-by: Zan Niu <[email protected]> --------- Signed-off-by: Zan Niu <[email protected]> Signed-off-by: wujunshen <[email protected]> * add planning work nodes to model (opensearch-project#715) * add planning work nodes to model Signed-off-by: Yaliang Wu <[email protected]> * add test Signed-off-by: Yaliang Wu <[email protected]> --------- Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: wujunshen <[email protected]> * skip running syncup job if no model index (opensearch-project#717) Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: wujunshen <[email protected]> * refactor: add DL model class (opensearch-project#722) * refactor: add DL model class Signed-off-by: Yaliang Wu <[email protected]> * fix model url in example doc Signed-off-by: Yaliang Wu <[email protected]> * address comments Signed-off-by: Yaliang Wu <[email protected]> * fix failed ut Signed-off-by: Yaliang Wu <[email protected]> --------- Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: wujunshen <[email protected]> * tune model config: change pooling mode to optional (opensearch-project#724) Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: make the log readable Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: add error log Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: Refer to PR opensearch-project#717,just checking if index exists in cluster metadata Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: change RunTimeException to MLException Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: also consider COMPLETED_WITH_ERROR Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: remove ML_MODEL_RELOAD_MAX_RETRY_TIMES in CommonValue.java Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: remove Result class Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: change "reload" and "retry" to a full word Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: change log info sentence Signed-off-by: wujunshen <[email protected]> * [wjunshen] #N/A feat: code format Signed-off-by: wujunshen <[email protected]> * [Signed-off-by: wjunshen<[email protected]>] #N/A feat: 1. Let the collection just have all ids of ml node 2. print out full exception stack trace Signed-off-by: wujunshen <[email protected]> * [Signed-off-by: wjunshen<[email protected]>] #N/A feat: use LOAD_THREAD_POOL to replace generic Signed-off-by: wujunshen <[email protected]> * [Signed-off-by: wjunshen<[email protected]>] #N/A feat: print the whole exception stack Signed-off-by: wujunshen <[email protected]> * [Signed-off-by: wjunshen<[email protected]>] #N/A feat: change the IndexNotFoundException to MLException Signed-off-by: wujunshen <[email protected]> * [Signed-off-by: wjunshen<[email protected]>] #N/A feat: add logs when receiving indexRequestBuilder.execute exception Signed-off-by: wujunshen <[email protected]> * [Signed-off-by: wjunshen<[email protected]>] #N/A feat: change the test code after code review Signed-off-by: wujunshen <[email protected]> --------- Signed-off-by: wujunshen <[email protected]> Signed-off-by: opensearch-ci-bot <[email protected]> Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: Zan Niu <[email protected]> Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Co-authored-by: opensearch-ci-bot <[email protected]> Co-authored-by: Yaliang Wu <[email protected]> Co-authored-by: zane-neo <[email protected]>
- Loading branch information
1 parent
2ce521c
commit 7a51dcc
Showing
6 changed files
with
1,500 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
331 changes: 331 additions & 0 deletions
331
plugin/src/main/java/org/opensearch/ml/model/MLModelAutoReloader.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,331 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.ml.model; | ||
|
||
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.common.CommonValue.ML_MODEL_RELOAD_INDEX; | ||
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; | ||
import static org.opensearch.ml.common.CommonValue.MODEL_LOAD_RETRY_TIMES_FIELD; | ||
import static org.opensearch.ml.common.CommonValue.NODE_ID_FIELD; | ||
import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL; | ||
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE; | ||
import static org.opensearch.ml.settings.MLCommonsSettings.ML_MODEL_RELOAD_MAX_RETRY_TIMES; | ||
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; | ||
|
||
import java.util.Arrays; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.ExecutionException; | ||
import java.util.stream.Collectors; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
import org.opensearch.action.ActionListener; | ||
import org.opensearch.action.StepListener; | ||
import org.opensearch.action.index.IndexAction; | ||
import org.opensearch.action.index.IndexRequestBuilder; | ||
import org.opensearch.action.index.IndexResponse; | ||
import org.opensearch.action.search.SearchAction; | ||
import org.opensearch.action.search.SearchRequestBuilder; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.action.support.WriteRequest; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.node.DiscoveryNode; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.common.util.CollectionUtils; | ||
import org.opensearch.common.xcontent.NamedXContentRegistry; | ||
import org.opensearch.common.xcontent.XContentParser; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.ml.cluster.DiscoveryNodeHelper; | ||
import org.opensearch.ml.common.MLTask; | ||
import org.opensearch.ml.common.exception.MLException; | ||
import org.opensearch.ml.common.transport.load.MLLoadModelAction; | ||
import org.opensearch.ml.common.transport.load.MLLoadModelRequest; | ||
import org.opensearch.ml.utils.MLNodeUtils; | ||
import org.opensearch.rest.RestStatus; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
import org.opensearch.search.sort.FieldSortBuilder; | ||
import org.opensearch.search.sort.SortBuilder; | ||
import org.opensearch.search.sort.SortOrder; | ||
import org.opensearch.threadpool.ThreadPool; | ||
|
||
import com.google.common.annotations.VisibleForTesting; | ||
|
||
/** | ||
* Manager class for ML models and nodes. It contains ML model auto reload operations etc. | ||
*/ | ||
@Log4j2 | ||
public class MLModelAutoReloader { | ||
|
||
private final Client client; | ||
private final ClusterService clusterService; | ||
private final NamedXContentRegistry xContentRegistry; | ||
private final DiscoveryNodeHelper nodeHelper; | ||
private final ThreadPool threadPool; | ||
private volatile Boolean enableAutoReloadModel; | ||
private volatile Integer autoReloadMaxRetryTimes; | ||
|
||
/** | ||
* constructor method, init all the params necessary for model auto reloading | ||
* | ||
* @param clusterService clusterService | ||
* @param threadPool threadPool | ||
* @param client client | ||
* @param xContentRegistry xContentRegistry | ||
* @param nodeHelper nodeHelper | ||
* @param settings settings | ||
*/ | ||
public MLModelAutoReloader( | ||
ClusterService clusterService, | ||
ThreadPool threadPool, | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
DiscoveryNodeHelper nodeHelper, | ||
Settings settings | ||
) { | ||
this.clusterService = clusterService; | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
this.nodeHelper = nodeHelper; | ||
this.threadPool = threadPool; | ||
|
||
enableAutoReloadModel = ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE.get(settings); | ||
autoReloadMaxRetryTimes = ML_MODEL_RELOAD_MAX_RETRY_TIMES.get(settings); | ||
clusterService | ||
.getClusterSettings() | ||
.addSettingsUpdateConsumer(ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE, it -> enableAutoReloadModel = it); | ||
|
||
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_MODEL_RELOAD_MAX_RETRY_TIMES, it -> autoReloadMaxRetryTimes = it); | ||
} | ||
|
||
/** | ||
* the main method: model auto reloading | ||
*/ | ||
public void autoReloadModel() { | ||
log.info("auto reload model enabled: {} ", enableAutoReloadModel); | ||
|
||
// if we don't need to reload automatically, just return without doing anything | ||
if (!enableAutoReloadModel) { | ||
return; | ||
} | ||
|
||
// At opensearch startup, get local node id, if not ml node,we ignored, just return without doing anything | ||
if (!MLNodeUtils.isMLNode(clusterService.localNode())) { | ||
return; | ||
} | ||
|
||
String localNodeId = clusterService.localNode().getId(); | ||
// auto reload all models of this local ml node | ||
threadPool.executor(LOAD_THREAD_POOL).execute(() -> { | ||
try { | ||
autoReloadModelByNodeId(localNodeId); | ||
} catch (ExecutionException | InterruptedException e) { | ||
log.error("the model auto-reloading has exception,and the root cause message is: {}", e); | ||
throw new MLException(e); | ||
} | ||
}); | ||
} | ||
|
||
/** | ||
* auto reload all the models under the node id<br> the node must be a ml node<br> | ||
* | ||
* @param localNodeId node id | ||
*/ | ||
@VisibleForTesting | ||
void autoReloadModelByNodeId(String localNodeId) throws ExecutionException, InterruptedException { | ||
StepListener<SearchResponse> queryTaskStep = new StepListener<>(); | ||
StepListener<SearchResponse> getRetryTimesStep = new StepListener<>(); | ||
StepListener<IndexResponse> saveLatestRetryTimesStep = new StepListener<>(); | ||
|
||
if (!clusterService.state().metadata().indices().containsKey(ML_TASK_INDEX)) { | ||
// ML_TASK_INDEX did not exist,do nothing | ||
return; | ||
} | ||
|
||
queryTask(localNodeId, ActionListener.wrap(queryTaskStep::onResponse, queryTaskStep::onFailure)); | ||
|
||
getRetryTimes(localNodeId, ActionListener.wrap(getRetryTimesStep::onResponse, getRetryTimesStep::onFailure)); | ||
|
||
queryTaskStep.whenComplete(searchResponse -> { | ||
SearchHit[] hits = searchResponse.getHits().getHits(); | ||
if (CollectionUtils.isEmpty(hits)) { | ||
return; | ||
} | ||
|
||
getRetryTimesStep.whenComplete(getReTryTimesResponse -> { | ||
int retryTimes = 0; | ||
// if getReTryTimesResponse is null,it means we get retryTimes at the first time,and the index | ||
// .plugins-ml-model-reload doesn't exist,so we should let retryTimes be zero(init value) | ||
// we don't do anything | ||
// if getReTryTimesResponse is not null,it means we have saved the value of retryTimes into the index | ||
// .plugins-ml-model-reload,so we get the value of the field MODEL_LOAD_RETRY_TIMES_FIELD | ||
if (getReTryTimesResponse != null) { | ||
Map<String, Object> sourceAsMap = getReTryTimesResponse.getHits().getHits()[0].getSourceAsMap(); | ||
retryTimes = (Integer) sourceAsMap.get(MODEL_LOAD_RETRY_TIMES_FIELD); | ||
} | ||
|
||
// According to the node id to get retry times, if more than the max retry times, don't need to retry | ||
// that the number of unsuccessful reload has reached the maximum number of times, do not need to reload | ||
if (retryTimes > autoReloadMaxRetryTimes) { | ||
log.info("Node: {} has reached to the max retry limit, failed to load models", localNodeId); | ||
return; | ||
} | ||
|
||
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, hits[0].getSourceRef())) { | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); | ||
MLTask mlTask = MLTask.parse(parser); | ||
|
||
autoReloadModelByNodeAndModelId(localNodeId, mlTask.getModelId()); | ||
|
||
// if reload the model successfully,the number of unsuccessful reload should be reset to zero. | ||
retryTimes = 0; | ||
} catch (MLException e) { | ||
retryTimes++; | ||
log.error("Can't auto reload model in node id {} ,has tried {} times\nThe reason is:{}", localNodeId, retryTimes, e); | ||
} | ||
|
||
// Store the latest value of the retryTimes and node id under the index ".plugins-ml-model-reload" | ||
saveLatestRetryTimes( | ||
localNodeId, | ||
retryTimes, | ||
ActionListener.wrap(saveLatestRetryTimesStep::onResponse, saveLatestRetryTimesStep::onFailure) | ||
); | ||
}, getRetryTimesStep::onFailure); | ||
}, queryTaskStep::onFailure); | ||
|
||
saveLatestRetryTimesStep.whenComplete(response -> log.info("successfully complete all steps"), saveLatestRetryTimesStep::onFailure); | ||
} | ||
|
||
/** | ||
* auto reload 1 model under the node id | ||
* | ||
* @param localNodeId node id | ||
* @param modelId model id | ||
*/ | ||
@VisibleForTesting | ||
void autoReloadModelByNodeAndModelId(String localNodeId, String modelId) throws MLException { | ||
List<String> allMLNodeIdList = Arrays | ||
.stream(nodeHelper.getAllNodes()) | ||
.filter(MLNodeUtils::isMLNode) | ||
.map(DiscoveryNode::getId) | ||
.collect(Collectors.toList()); | ||
|
||
if (!allMLNodeIdList.contains(localNodeId)) { | ||
allMLNodeIdList.add(localNodeId); | ||
} | ||
MLLoadModelRequest mlLoadModelRequest = new MLLoadModelRequest(modelId, allMLNodeIdList.toArray(new String[] {}), false, false); | ||
|
||
client | ||
.execute( | ||
MLLoadModelAction.INSTANCE, | ||
mlLoadModelRequest, | ||
ActionListener | ||
.wrap(response -> log.info("the model {} is auto reloading under the node {} ", modelId, localNodeId), exception -> { | ||
log.error("fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: " + exception); | ||
throw new MLException( | ||
"fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: " + exception | ||
); | ||
}) | ||
); | ||
} | ||
|
||
/** | ||
* query task index, and get the result of "task_type"="LOAD_MODEL" and "state"="COMPLETED" and | ||
* "worker_node" match nodeId | ||
* | ||
* @param localNodeId one of query condition | ||
*/ | ||
@VisibleForTesting | ||
void queryTask(String localNodeId, ActionListener<SearchResponse> searchResponseActionListener) { | ||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().from(0).size(1); | ||
|
||
QueryBuilder queryBuilder = QueryBuilders | ||
.boolQuery() | ||
.must(QueryBuilders.matchPhraseQuery("task_type", "LOAD_MODEL")) | ||
.must(QueryBuilders.matchPhraseQuery("worker_node", localNodeId)) | ||
.must( | ||
QueryBuilders | ||
.boolQuery() | ||
.should(QueryBuilders.matchPhraseQuery("state", "COMPLETED")) | ||
.should(QueryBuilders.matchPhraseQuery("state", "COMPLETED_WITH_ERROR")) | ||
); | ||
searchSourceBuilder.query(queryBuilder); | ||
|
||
SortBuilder<FieldSortBuilder> sortBuilderOrder = new FieldSortBuilder("create_time").order(SortOrder.DESC); | ||
searchSourceBuilder.sort(sortBuilderOrder); | ||
|
||
SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) | ||
.setIndices(ML_TASK_INDEX) | ||
.setSource(searchSourceBuilder); | ||
|
||
searchRequestBuilder.execute(ActionListener.wrap(searchResponseActionListener::onResponse, exception -> { | ||
log.error("index {} not found, the reason is {}", ML_TASK_INDEX, exception); | ||
throw new MLException("index " + ML_TASK_INDEX + " not found"); | ||
})); | ||
} | ||
|
||
/** | ||
* get retry times from the index ".plugins-ml-model-reload" by 1 ml node | ||
* | ||
* @param localNodeId the filter condition to query | ||
*/ | ||
@VisibleForTesting | ||
void getRetryTimes(String localNodeId, ActionListener<SearchResponse> searchResponseActionListener) { | ||
if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_RELOAD_INDEX)) { | ||
// ML_MODEL_RELOAD_INDEX did not exist, it means it is our first time to do model auto-reloading operation | ||
searchResponseActionListener.onResponse(null); | ||
} | ||
|
||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | ||
searchSourceBuilder.fetchSource(new String[] { MODEL_LOAD_RETRY_TIMES_FIELD }, null); | ||
QueryBuilder queryBuilder = QueryBuilders.idsQuery().addIds(localNodeId); | ||
searchSourceBuilder.query(queryBuilder); | ||
|
||
SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) | ||
.setIndices(ML_MODEL_RELOAD_INDEX) | ||
.setSource(searchSourceBuilder); | ||
|
||
searchRequestBuilder.execute(ActionListener.wrap(searchResponse -> { | ||
SearchHit[] hits = searchResponse.getHits().getHits(); | ||
if (CollectionUtils.isEmpty(hits)) { | ||
searchResponseActionListener.onResponse(null); | ||
return; | ||
} | ||
|
||
searchResponseActionListener.onResponse(searchResponse); | ||
}, searchResponseActionListener::onFailure)); | ||
} | ||
|
||
/** | ||
* save retry times | ||
* @param localNodeId node id | ||
* @param retryTimes actual retry times | ||
*/ | ||
@VisibleForTesting | ||
void saveLatestRetryTimes(String localNodeId, int retryTimes, ActionListener<IndexResponse> indexResponseActionListener) { | ||
Map<String, Object> content = new HashMap<>(2); | ||
content.put(NODE_ID_FIELD, localNodeId); | ||
content.put(MODEL_LOAD_RETRY_TIMES_FIELD, retryTimes); | ||
|
||
IndexRequestBuilder indexRequestBuilder = new IndexRequestBuilder(client, IndexAction.INSTANCE, ML_MODEL_RELOAD_INDEX) | ||
.setId(localNodeId) | ||
.setSource(content) | ||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); | ||
|
||
indexRequestBuilder.execute(ActionListener.wrap(indexResponse -> { | ||
if (indexResponse.status() == RestStatus.CREATED || indexResponse.status() == RestStatus.OK) { | ||
log.info("node id:{} insert retry times successfully", localNodeId); | ||
indexResponseActionListener.onResponse(indexResponse); | ||
} | ||
}, e -> { | ||
log.error("node id:" + localNodeId + " insert retry times unsuccessfully", e); | ||
indexResponseActionListener.onFailure(new MLException(e)); | ||
})); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.