Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

renaming metrics #1224

Merged
merged 11 commits into from
Aug 22, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
Expand Down Expand Up @@ -148,8 +147,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
if (!allowCustomDeploymentPlan && !deployToAllNodes) {
throw new IllegalArgumentException("Don't allow custom deployment plan");
}
// mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
DiscoveryNode[] allEligibleNodes = nodeFilter.getEligibleNodes(functionName);
Map<String, DiscoveryNode> nodeMapping = new HashMap<>();
for (DiscoveryNode node : allEligibleNodes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
Expand Down Expand Up @@ -234,12 +233,6 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
throw new IllegalArgumentException("URL can't match trusted url regex");
}
}
// mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove this line?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are already counting this in the registerMLModel function in MLModelManager class

// //TODO: track executing task; track register failures
// mlStats.createCounterStatIfAbsent(FunctionName.TEXT_EMBEDDING,
// ActionName.REGISTER,
// MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE;
MLTask mlTask = MLTask
.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRe
MLStatsInput mlStatsInput = mlStatsNodesRequest.getMlStatsInput();
// return node level stats
if (mlStatsInput.getTargetStatLevels().contains(MLStatLevel.NODE)) {
if (mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)) {
if (mlStatsInput.retrieveStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)) {
long heapUsedPercent = jvmService.stats().getMem().getHeapUsedPercent();
statValues.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, heapUsedPercent);
statValues.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, heapUsedPercent);
}

for (Enum statName : mlStats.getNodeStats().keySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,7 @@ protected MLUndeployModelNodeResponse nodeOperation(MLUndeployModelNodeRequest r
}

private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployModelNodesRequest MLUndeployModelNodesRequest) {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();

String[] modelIds = MLUndeployModelNodesRequest.getModelIds();

Expand All @@ -246,7 +245,7 @@ private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployMo
}

Map<String, String> modelUndeployStatus = mlModelManager.undeployModel(modelIds);
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
return new MLUndeployModelNodeResponse(clusterService.localNode(), modelUndeployStatus, modelWorkerNodesMap);
}
}
39 changes: 16 additions & 23 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ public MLModelManager(
public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener<String> listener) {
try {
FunctionName functionName = mlRegisterModelMetaInput.getFunctionName();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment();
String modelGroupId = mlRegisterModelMetaInput.getModelGroupId();
if (Strings.isBlank(modelGroupId)) {
Expand Down Expand Up @@ -322,9 +322,9 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa

checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode);
try {
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();

String modelGroupId = registerModelInput.getModelGroupId();
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
Expand Down Expand Up @@ -384,17 +384,14 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa
} catch (Exception e) {
handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
}
}

private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) {
String taskId = mlTask.getTaskId();
FunctionName functionName = mlTask.getFunctionName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is to track how many register requests on function level. By removing this, can we still track that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, because we are tracking this in the parent function registerMLModel

mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();

String modelName = registerModelInput.getModelName();
String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion;
Expand Down Expand Up @@ -443,8 +440,6 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml
} catch (Exception e) {
logException("Failed to upload model", e, log);
handleException(functionName, taskId, e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
}
}

Expand All @@ -462,9 +457,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas
String taskId = mlTask.getTaskId();
FunctionName functionName = mlTask.getFunctionName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
String modelName = registerModelInput.getModelName();
String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion;
String modelGroupId = registerModelInput.getModelGroupId();
Expand Down Expand Up @@ -509,8 +501,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas
} catch (Exception e) {
logException("Failed to register model", e, log);
handleException(functionName, taskId, e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
}
}

Expand Down Expand Up @@ -693,7 +683,7 @@ private void handleException(FunctionName functionName, String taskId, Exception
&& !(e instanceof MLResourceNotFoundException)
&& !(e instanceof IllegalArgumentException)) {
mlStats.createCounterStatIfAbsent(functionName, REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment();
}
Map<String, Object> updated = ImmutableMap.of(ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED);
mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true);
Expand All @@ -718,7 +708,8 @@ public void deployModel(
ActionListener<String> listener
) {
mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
List<String> workerNodes = mlTask.getWorkerNodes();
if (modelCacheHelper.isModelDeployed(modelId)) {
if (workerNodes != null && workerNodes.size() > 0) {
Expand Down Expand Up @@ -800,7 +791,7 @@ public void deployModel(
MLExecutable mlExecutable = mlEngine.deployExecute(mlModel, params);
try {
modelCacheHelper.setMLExecutor(modelId, mlExecutable);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
listener.onResponse("successful");
} catch (Exception e) {
Expand All @@ -813,7 +804,7 @@ public void deployModel(
Predictable predictable = mlEngine.deploy(mlModel, params);
try {
modelCacheHelper.setPredictor(modelId, predictable);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes();
long contentSize = modelContentSizeInBytes == null
Expand All @@ -837,6 +828,8 @@ public void deployModel(
})));
} catch (Exception e) {
handleDeployModelException(modelId, functionName, listener, e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
}
}

Expand All @@ -846,7 +839,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam
&& !(e instanceof MLResourceNotFoundException)
&& !(e instanceof IllegalArgumentException)) {
mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment();
}
removeModel(modelId);
listener.onFailure(e);
Expand All @@ -855,7 +848,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam
private void setupPredictable(String modelId, MLModel mlModel, Map<String, Object> params) {
Predictable predictable = mlEngine.deploy(mlModel, params);
modelCacheHelper.setPredictor(modelId, predictable);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
}

Expand Down Expand Up @@ -1056,8 +1049,8 @@ public synchronized Map<String, String> undeployModel(String[] modelIds) {
for (String modelId : modelIds) {
if (modelCacheHelper.isModelDeployed(modelId)) {
modelUndeployStatus.put(modelId, UNDEPLOYED);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats
.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT)
.increment();
Expand All @@ -1070,7 +1063,7 @@ public synchronized Map<String, String> undeployModel(String[] modelIds) {
log.debug("undeploy all models {}", Arrays.toString(getLocalDeployedModels()));
for (String modelId : getLocalDeployedModels()) {
modelUndeployStatus.put(modelId, UNDEPLOYED);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement();
mlStats.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT).increment();
removeModel(modelId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,11 @@ public Collection<Object> createComponents(
stats.put(MLClusterLevelStat.ML_MODEL_COUNT, new MLStat<>(true, new CounterSupplier()));
stats.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, new MLStat<>(true, new CounterSupplier()));
// node level stats
stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier()));
this.mlStats = new MLStats(stats);

mlIndicesHandler = new MLIndicesHandler(clusterService, client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import org.opensearch.client.node.NodeClient;
Expand Down Expand Up @@ -60,6 +61,12 @@ public class RestMLStatsAction extends BaseRestHandler {
private static final String QUERY_ALL_MODEL_META_DOC =
"{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}";

private static final Set<String> ML_NODE_STAT_NAMES = EnumSet
.allOf(MLNodeLevelStat.class)
.stream()
.map(stat -> stat.name())
.collect(Collectors.toSet());

/**
* Constructor
* @param mlStats MLStats object
Expand Down Expand Up @@ -148,6 +155,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}

MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) {

MLStatsInput mlStatsInput = new MLStatsInput();
Optional<String[]> nodeIds = splitCommaSeparatedParam(request, "nodeId");
if (nodeIds.isPresent()) {
Expand All @@ -158,7 +166,7 @@ MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) {
for (String state : stats.get()) {
state = state.toUpperCase(Locale.ROOT);
// only support cluster and node level stats for bwc
if (state.startsWith("ML_NODE")) {
if (ML_NODE_STAT_NAMES.contains(state)) {
mlStatsInput.getNodeLevelStats().add(MLNodeLevelStat.from(state));
} else {
mlStatsInput.getClusterLevelStats().add(MLClusterLevelStat.from(state));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
* This enum represents node level stats.
*/
public enum MLNodeLevelStat {
ML_NODE_JVM_HEAP_USAGE,
ML_NODE_EXECUTING_TASK_COUNT,
ML_NODE_TOTAL_REQUEST_COUNT,
ML_NODE_TOTAL_FAILURE_COUNT,
ML_NODE_TOTAL_MODEL_COUNT,
ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT;
ML_JVM_HEAP_USAGE,
ML_EXECUTING_TASK_COUNT, // How many tasks are executing currently. If any task starts, then it will be 1, if the task finished then it
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then it will be 1 -> then it will increase by 1

will get back to 0 -> will decrease by 1?

// will get back to 0.
ML_REQUEST_COUNT,
ML_FAILURE_COUNT,
ML_DEPLOYED_MODEL_COUNT,
ML_CIRCUIT_BREAKER_TRIGGER_COUNT;

public static MLNodeLevelStat from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ protected TransportResponseHandler<MLExecuteTaskResponse> getResponseHandler(Act
protected void executeTask(MLExecuteTaskRequest request, ActionListener<MLExecuteTaskResponse> listener) {
threadPool.executor(EXECUTE_THREAD_POOL).execute(() -> {
try {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats
.createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_REQUEST_COUNT)
.increment();
Expand All @@ -113,7 +113,7 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener<MLExecut
.increment();
listener.onFailure(e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
}
});
}
Expand Down
Loading