-
Notifications
You must be signed in to change notification settings - Fork 138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
the dev of [FEATURE]Auto reload model when cluster rebooted/node rejoin #711
Changes from 24 commits
77a1ffb
9f50458
ddfb117
0c235a6
35a3703
350faed
5126aba
c22980c
a0f1df5
f62ad71
eaf794d
3f710da
da42086
699e06a
0365674
d779f8c
9fa1025
beef20f
5f3c2cc
facc4a1
7356a82
ddab41c
76fb7f0
c0b575e
78ae922
f169ec1
5436e6d
71e5645
38bf342
467250e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,329 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.ml.model; | ||
|
||
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.common.CommonValue.ML_MODEL_RELOAD_INDEX; | ||
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; | ||
import static org.opensearch.ml.common.CommonValue.MODEL_LOAD_RETRY_TIMES_FIELD; | ||
import static org.opensearch.ml.common.CommonValue.NODE_ID_FIELD; | ||
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE; | ||
import static org.opensearch.ml.settings.MLCommonsSettings.ML_MODEL_RELOAD_MAX_RETRY_TIMES; | ||
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.ExecutionException; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
import org.apache.commons.lang3.exception.ExceptionUtils; | ||
import org.opensearch.action.ActionListener; | ||
import org.opensearch.action.StepListener; | ||
import org.opensearch.action.index.IndexAction; | ||
import org.opensearch.action.index.IndexRequestBuilder; | ||
import org.opensearch.action.index.IndexResponse; | ||
import org.opensearch.action.search.SearchAction; | ||
import org.opensearch.action.search.SearchRequestBuilder; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.action.support.WriteRequest; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.common.util.CollectionUtils; | ||
import org.opensearch.common.xcontent.NamedXContentRegistry; | ||
import org.opensearch.common.xcontent.XContentParser; | ||
import org.opensearch.index.IndexNotFoundException; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.ml.cluster.DiscoveryNodeHelper; | ||
import org.opensearch.ml.common.MLTask; | ||
import org.opensearch.ml.common.exception.MLException; | ||
import org.opensearch.ml.common.transport.load.MLLoadModelAction; | ||
import org.opensearch.ml.common.transport.load.MLLoadModelRequest; | ||
import org.opensearch.ml.utils.MLNodeUtils; | ||
import org.opensearch.rest.RestStatus; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
import org.opensearch.search.sort.FieldSortBuilder; | ||
import org.opensearch.search.sort.SortBuilder; | ||
import org.opensearch.search.sort.SortOrder; | ||
import org.opensearch.threadpool.ThreadPool; | ||
|
||
import com.google.common.annotations.VisibleForTesting; | ||
|
||
/** | ||
* Manager class for ML models and nodes. It contains ML model auto reload operations etc. | ||
*/ | ||
@Log4j2 | ||
public class MLModelAutoReloader { | ||
|
||
private final Client client; | ||
private final ClusterService clusterService; | ||
private final NamedXContentRegistry xContentRegistry; | ||
private final DiscoveryNodeHelper nodeHelper; | ||
private final ThreadPool threadPool; | ||
private volatile Boolean enableAutoReloadModel; | ||
private volatile Integer autoReloadMaxRetryTimes; | ||
|
||
/** | ||
* constructor method, init all the params necessary for model auto reloading | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This |
||
* | ||
* @param clusterService clusterService | ||
* @param threadPool threadPool | ||
* @param client client | ||
* @param xContentRegistry xContentRegistry | ||
* @param nodeHelper nodeHelper | ||
* @param settings settings | ||
*/ | ||
public MLModelAutoReloader( | ||
ClusterService clusterService, | ||
ThreadPool threadPool, | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
DiscoveryNodeHelper nodeHelper, | ||
Settings settings | ||
) { | ||
this.clusterService = clusterService; | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
this.nodeHelper = nodeHelper; | ||
this.threadPool = threadPool; | ||
|
||
enableAutoReloadModel = ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE.get(settings); | ||
autoReloadMaxRetryTimes = ML_MODEL_RELOAD_MAX_RETRY_TIMES.get(settings); | ||
clusterService | ||
.getClusterSettings() | ||
.addSettingsUpdateConsumer(ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE, it -> enableAutoReloadModel = it); | ||
|
||
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_MODEL_RELOAD_MAX_RETRY_TIMES, it -> autoReloadMaxRetryTimes = it); | ||
} | ||
|
||
/** | ||
* the main method: model auto reloading | ||
*/ | ||
public void autoReloadModel() { | ||
log.info("auto reload model enabled: {} ", enableAutoReloadModel); | ||
|
||
// if we don't need to reload automatically, just return without doing anything | ||
if (!enableAutoReloadModel) { | ||
return; | ||
} | ||
|
||
// At opensearch startup, get local node id, if not ml node,we ignored, just return without doing anything | ||
if (!MLNodeUtils.isMLNode(clusterService.localNode())) { | ||
return; | ||
} | ||
|
||
String localNodeId = clusterService.localNode().getId(); | ||
// auto reload all models of this local ml node | ||
threadPool.generic().submit(() -> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please change this threadpool to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. en ,yaliang have said it~ I have committed the latest code |
||
try { | ||
autoReloadModelByNodeId(localNodeId); | ||
} catch (ExecutionException | InterruptedException e) { | ||
log | ||
.error( | ||
"the model auto-reloading has exception,and the root cause message is: {}", | ||
ExceptionUtils.getRootCauseMessage(e) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about print the full exception stack trace here? Just print out the root cause seems not easy to debug. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can I use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think for a while,and changed it to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Print the entire exception stack is useful and convenient to locate and debug issues, you can change to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cool~I will modify it according to what you said |
||
); | ||
throw new MLException(e); | ||
} | ||
}); | ||
} | ||
|
||
/** | ||
* auto reload all the models under the node id<br> the node must be a ml node<br> | ||
* | ||
* @param localNodeId node id | ||
*/ | ||
@VisibleForTesting | ||
void autoReloadModelByNodeId(String localNodeId) throws ExecutionException, InterruptedException { | ||
StepListener<SearchResponse> queryTaskStep = new StepListener<>(); | ||
StepListener<SearchResponse> getRetryTimesStep = new StepListener<>(); | ||
StepListener<IndexResponse> saveLatestRetryTimesStep = new StepListener<>(); | ||
|
||
if (!clusterService.state().metadata().indices().containsKey(ML_TASK_INDEX)) { | ||
// ML_TASK_INDEX did not exist,do nothing | ||
return; | ||
} | ||
|
||
queryTask(localNodeId, ActionListener.wrap(queryTaskStep::onResponse, queryTaskStep::onFailure)); | ||
|
||
getRetryTimes(localNodeId, ActionListener.wrap(getRetryTimesStep::onResponse, getRetryTimesStep::onFailure)); | ||
|
||
queryTaskStep.whenComplete(searchResponse -> { | ||
SearchHit[] hits = searchResponse.getHits().getHits(); | ||
if (CollectionUtils.isEmpty(hits)) { | ||
return; | ||
} | ||
|
||
getRetryTimesStep.whenComplete(getReTryTimesResponse -> { | ||
int retryTimes = 0; | ||
// if getReTryTimesResponse is null,it means we get retryTimes at the first time,and the index | ||
// .plugins-ml-model-reload doesn't exist,so we should let retryTimes be zero(init value) | ||
// we don't do anything | ||
// if getReTryTimesResponse is not null,it means we have saved the value of retryTimes into the index | ||
// .plugins-ml-model-reload,so we get the value of the field MODEL_LOAD_RETRY_TIMES_FIELD | ||
if (getReTryTimesResponse != null) { | ||
Map<String, Object> sourceAsMap = getReTryTimesResponse.getHits().getHits()[0].getSourceAsMap(); | ||
retryTimes = (Integer) sourceAsMap.get(MODEL_LOAD_RETRY_TIMES_FIELD); | ||
} | ||
|
||
// According to the node id to get retry times, if more than the max retry times, don't need to retry | ||
// that the number of unsuccessful reload has reached the maximum number of times, do not need to reload | ||
if (retryTimes > autoReloadMaxRetryTimes) { | ||
log.info("Node: {} has reached to the max retry limit, failed to load models", localNodeId); | ||
return; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Before return, should we check how long the node has been in the max retry status and reset to 0 after a substantial time? It looks to me the node will never reload forever once reached maximum retry times. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in the first comment of yours, I found ml_task index and ml_model index are both definition in ml_task index, so if I add 2 new fields in ml_task,the ml_model index will have these 2 fields,too. It sounds that give ml_model index redundant attributes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cool. Let's keep this logic then. But we should try to define a new type of ML Task for auto reload, and reuse MLTask to store the max_retry field, etc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks~ |
||
} | ||
|
||
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, hits[0].getSourceRef())) { | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); | ||
MLTask mlTask = MLTask.parse(parser); | ||
|
||
autoReloadModelByNodeAndModelId(localNodeId, mlTask.getModelId()); | ||
|
||
// if reload the model successfully,the number of unsuccessful reload should be reset to zero. | ||
retryTimes = 0; | ||
} catch (MLException e) { | ||
retryTimes++; | ||
log.error("Can't auto reload model in node id {} ,has tried {} times\nThe reason is:{}", localNodeId, retryTimes, e); | ||
} | ||
|
||
// Store the latest value of the retryTimes and node id under the index ".plugins-ml-model-reload" | ||
saveLatestRetryTimes( | ||
localNodeId, | ||
retryTimes, | ||
ActionListener.wrap(saveLatestRetryTimesStep::onResponse, saveLatestRetryTimesStep::onFailure) | ||
); | ||
}, getRetryTimesStep::onFailure); | ||
}, queryTaskStep::onFailure); | ||
|
||
saveLatestRetryTimesStep.whenComplete(response -> log.info("successfully complete all steps"), saveLatestRetryTimesStep::onFailure); | ||
} | ||
|
||
/** | ||
* auto reload 1 model under the node id | ||
* | ||
* @param localNodeId node id | ||
* @param modelId model id | ||
*/ | ||
@VisibleForTesting | ||
void autoReloadModelByNodeAndModelId(String localNodeId, String modelId) throws MLException { | ||
String[] allNodeIds = nodeHelper.getAllNodeIds(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, I will modify it. Let the collection just have all ids of ml node. |
||
List<String> allNodeIdList = new ArrayList<>(List.of(allNodeIds)); | ||
if (!allNodeIdList.contains(localNodeId)) { | ||
allNodeIdList.add(localNodeId); | ||
} | ||
MLLoadModelRequest mlLoadModelRequest = new MLLoadModelRequest(modelId, allNodeIdList.toArray(new String[] {}), false, false); | ||
|
||
client | ||
.execute( | ||
MLLoadModelAction.INSTANCE, | ||
mlLoadModelRequest, | ||
ActionListener | ||
.wrap(response -> log.info("the model {} is auto reloading under the node {} ", modelId, localNodeId), exception -> { | ||
log.error("fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: " + exception); | ||
throw new MLException( | ||
"fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: " + exception | ||
); | ||
}) | ||
); | ||
} | ||
|
||
/** | ||
* query task index, and get the result of "task_type"="LOAD_MODEL" and "state"="COMPLETED" and | ||
* "worker_node" match nodeId | ||
* | ||
* @param localNodeId one of query condition | ||
*/ | ||
@VisibleForTesting | ||
void queryTask(String localNodeId, ActionListener<SearchResponse> searchResponseActionListener) { | ||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().from(0).size(1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This query only return the latest load model task. If user have 3 models, this query only return 1 latest load model task for 1 model, the other 2 models' tasks won't be returned. So we can't reload all 3 models, just reload 1 model. Is that correct? |
||
|
||
QueryBuilder queryBuilder = QueryBuilders | ||
.boolQuery() | ||
.must(QueryBuilders.matchPhraseQuery("task_type", "LOAD_MODEL")) | ||
.must(QueryBuilders.matchPhraseQuery("worker_node", localNodeId)) | ||
.must( | ||
QueryBuilders | ||
.boolQuery() | ||
.should(QueryBuilders.matchPhraseQuery("state", "COMPLETED")) | ||
.should(QueryBuilders.matchPhraseQuery("state", "COMPLETED_WITH_ERROR")) | ||
); | ||
searchSourceBuilder.query(queryBuilder); | ||
|
||
SortBuilder<FieldSortBuilder> sortBuilderOrder = new FieldSortBuilder("create_time").order(SortOrder.DESC); | ||
searchSourceBuilder.sort(sortBuilderOrder); | ||
|
||
SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) | ||
.setIndices(ML_TASK_INDEX) | ||
.setSource(searchSourceBuilder); | ||
|
||
searchRequestBuilder.execute(ActionListener.wrap(searchResponseActionListener::onResponse, exception -> { | ||
log.error("index {} not found, the reason is {}", ML_TASK_INDEX, exception); | ||
throw new IndexNotFoundException("index " + ML_TASK_INDEX + " not found"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can't confirm this is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, I have changed it~ |
||
})); | ||
} | ||
|
||
/** | ||
* get retry times from the index ".plugins-ml-model-reload" by 1 ml node | ||
* | ||
* @param localNodeId the filter condition to query | ||
*/ | ||
@VisibleForTesting | ||
void getRetryTimes(String localNodeId, ActionListener<SearchResponse> searchResponseActionListener) { | ||
if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_RELOAD_INDEX)) { | ||
// ML_MODEL_RELOAD_INDEX did not exist, it means it is our first time to do model auto-reloading operation | ||
searchResponseActionListener.onResponse(null); | ||
} | ||
|
||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | ||
searchSourceBuilder.fetchSource(new String[] { MODEL_LOAD_RETRY_TIMES_FIELD }, null); | ||
QueryBuilder queryBuilder = QueryBuilders.idsQuery().addIds(localNodeId); | ||
searchSourceBuilder.query(queryBuilder); | ||
|
||
SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) | ||
.setIndices(ML_MODEL_RELOAD_INDEX) | ||
.setSource(searchSourceBuilder); | ||
|
||
searchRequestBuilder.execute(ActionListener.wrap(searchResponse -> { | ||
SearchHit[] hits = searchResponse.getHits().getHits(); | ||
if (CollectionUtils.isEmpty(hits)) { | ||
searchResponseActionListener.onResponse(null); | ||
return; | ||
} | ||
|
||
searchResponseActionListener.onResponse(searchResponse); | ||
}, searchResponseActionListener::onFailure)); | ||
} | ||
|
||
/** | ||
* save retry times | ||
* @param localNodeId node id | ||
* @param retryTimes actual retry times | ||
*/ | ||
@VisibleForTesting | ||
void saveLatestRetryTimes(String localNodeId, int retryTimes, ActionListener<IndexResponse> indexResponseActionListener) { | ||
Map<String, Object> content = new HashMap<>(2); | ||
content.put(NODE_ID_FIELD, localNodeId); | ||
content.put(MODEL_LOAD_RETRY_TIMES_FIELD, retryTimes); | ||
|
||
IndexRequestBuilder indexRequestBuilder = new IndexRequestBuilder(client, IndexAction.INSTANCE, ML_MODEL_RELOAD_INDEX) | ||
.setId(localNodeId) | ||
.setSource(content) | ||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); | ||
|
||
indexRequestBuilder.execute(ActionListener.wrap(indexResponse -> { | ||
if (indexResponse.status() == RestStatus.CREATED || indexResponse.status() == RestStatus.OK) { | ||
log.info("node id:{} insert retry times successfully", localNodeId); | ||
indexResponseActionListener.onResponse(indexResponse); | ||
return; | ||
} | ||
indexResponseActionListener.onFailure(new MLException("node id:" + localNodeId + " insert retry times unsuccessfully")); | ||
}, indexResponseActionListener::onFailure)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add logs here when receiving There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done~ |
||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A general question: Is it possible to avoid using a new index to achieve auto reload? Can we just query the Task index and find out all the loaded models in the current node and reload them all after OS started? I may missed some discussion earlier, but it looks like the retry number and search results can be stored locally in the memory.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this ml node has been restarted for some unknown reason, I can still use the persistent
retryTimes
value to know how many times the models on this node have been auto-reloaded before, then decide whether to do auto-reload this time. but if it is placed in cache, I can't get this info and have to auto-reload again. both are compared. The former may have some performance improvementThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, but this is a trade off for performance improvement by using a lot more resources. Is it possible to define this auto_reload as a ml_task and reuse the ml_task index to store the retry_times? Adding 2 new fields in ml_task may be much cheaper than using a new index. Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After the communication with you and charlie, we will elaborate