diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 696bbba5da..72a9cfd889 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -54,7 +54,6 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.settings.MLFeatureEnabledSetting; -import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; @@ -148,8 +147,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener nodeMapping = new HashMap<>(); for (DiscoveryNode node : allEligibleNodes) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index bb511108b7..5cdba7f6b6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -54,7 +54,6 @@ import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; -import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; @@ -234,12 +233,6 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen throw new IllegalArgumentException("URL can't match trusted url regex"); } } - // mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - // //TODO: track executing task; track register failures - // mlStats.createCounterStatIfAbsent(FunctionName.TEXT_EMBEDDING, - // ActionName.REGISTER, - // MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE; MLTask mlTask = MLTask .builder() diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java index 479ccdcbf1..61889ad254 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java @@ -97,9 +97,9 @@ MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRe MLStatsInput mlStatsInput = mlStatsNodesRequest.getMlStatsInput(); // return node level stats if (mlStatsInput.getTargetStatLevels().contains(MLStatLevel.NODE)) { - if (mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)) { + if (mlStatsInput.retrieveStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)) { long heapUsedPercent = jvmService.stats().getMem().getHeapUsedPercent(); - statValues.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, heapUsedPercent); + statValues.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, heapUsedPercent); } for (Enum statName : mlStats.getNodeStats().keySet()) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java index 99728cad1c..1a9b375e55 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java @@ -228,8 +228,7 @@ protected MLUndeployModelNodeResponse nodeOperation(MLUndeployModelNodeRequest r } private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployModelNodesRequest MLUndeployModelNodesRequest) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String[] modelIds = MLUndeployModelNodesRequest.getModelIds(); @@ -246,7 +245,7 @@ private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployMo } Map modelUndeployStatus = mlModelManager.undeployModel(modelIds); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); return new MLUndeployModelNodeResponse(clusterService.localNode(), modelUndeployStatus, modelWorkerNodesMap); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index ad9b687c07..bc7edde75b 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -212,7 +212,7 @@ public MLModelManager( public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener listener) { try { FunctionName functionName = mlRegisterModelMetaInput.getFunctionName(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); String modelGroupId = mlRegisterModelMetaInput.getModelGroupId(); if (Strings.isBlank(modelGroupId)) { @@ -322,9 +322,9 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode); try { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelGroupId = registerModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); @@ -384,7 +384,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa } catch (Exception e) { handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); } finally { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); } } @@ -392,9 +392,6 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); String modelName = registerModelInput.getModelName(); String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; @@ -443,8 +440,6 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml } catch (Exception e) { logException("Failed to upload model", e, log); handleException(functionName, taskId, e); - } finally { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); } } @@ -462,9 +457,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); String modelName = registerModelInput.getModelName(); String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; String modelGroupId = registerModelInput.getModelGroupId(); @@ -509,8 +501,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas } catch (Exception e) { logException("Failed to register model", e, log); handleException(functionName, taskId, e); - } finally { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); } } @@ -693,7 +683,7 @@ private void handleException(FunctionName functionName, String taskId, Exception && !(e instanceof MLResourceNotFoundException) && !(e instanceof IllegalArgumentException)) { mlStats.createCounterStatIfAbsent(functionName, REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } Map updated = ImmutableMap.of(ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED); mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true); @@ -718,7 +708,8 @@ public void deployModel( ActionListener listener ) { mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); List workerNodes = mlTask.getWorkerNodes(); if (modelCacheHelper.isModelDeployed(modelId)) { if (workerNodes != null && workerNodes.size() > 0) { @@ -800,7 +791,7 @@ public void deployModel( MLExecutable mlExecutable = mlEngine.deployExecute(mlModel, params); try { modelCacheHelper.setMLExecutor(modelId, mlExecutable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); listener.onResponse("successful"); } catch (Exception e) { @@ -813,7 +804,7 @@ public void deployModel( Predictable predictable = mlEngine.deploy(mlModel, params); try { modelCacheHelper.setPredictor(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes(); long contentSize = modelContentSizeInBytes == null @@ -837,6 +828,8 @@ public void deployModel( }))); } catch (Exception e) { handleDeployModelException(modelId, functionName, listener, e); + } finally { + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); } } @@ -846,7 +839,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam && !(e instanceof MLResourceNotFoundException) && !(e instanceof IllegalArgumentException)) { mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } removeModel(modelId); listener.onFailure(e); @@ -855,7 +848,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam private void setupPredictable(String modelId, MLModel mlModel, Map params) { Predictable predictable = mlEngine.deploy(mlModel, params); modelCacheHelper.setPredictor(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); } @@ -1056,8 +1049,8 @@ public synchronized Map undeployModel(String[] modelIds) { for (String modelId : modelIds) { if (modelCacheHelper.isModelDeployed(modelId)) { modelUndeployStatus.put(modelId, UNDEPLOYED); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT) .increment(); @@ -1070,7 +1063,7 @@ public synchronized Map undeployModel(String[] modelIds) { log.debug("undeploy all models {}", Arrays.toString(getLocalDeployedModels())); for (String modelId : getLocalDeployedModels()) { modelUndeployStatus.put(modelId, UNDEPLOYED); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement(); mlStats.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT).increment(); removeModel(modelId); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 0179c4fb6c..94ec866df8 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -305,11 +305,11 @@ public Collection createComponents( stats.put(MLClusterLevelStat.ML_MODEL_COUNT, new MLStat<>(true, new CounterSupplier())); stats.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, new MLStat<>(true, new CounterSupplier())); // node level stats - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlIndicesHandler = new MLIndicesHandler(clusterService, client); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java index 829b86ebb8..14d7bb1280 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java @@ -21,6 +21,7 @@ import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import org.opensearch.action.ActionListener; @@ -60,6 +61,12 @@ public class RestMLStatsAction extends BaseRestHandler { private static final String QUERY_ALL_MODEL_META_DOC = "{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}"; + private static final Set ML_NODE_STAT_NAMES = EnumSet + .allOf(MLNodeLevelStat.class) + .stream() + .map(stat -> stat.name()) + .collect(Collectors.toSet()); + /** * Constructor * @param mlStats MLStats object @@ -148,6 +155,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) { + MLStatsInput mlStatsInput = new MLStatsInput(); Optional nodeIds = splitCommaSeparatedParam(request, "nodeId"); if (nodeIds.isPresent()) { @@ -158,7 +166,7 @@ MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) { for (String state : stats.get()) { state = state.toUpperCase(Locale.ROOT); // only support cluster and node level stats for bwc - if (state.startsWith("ML_NODE")) { + if (ML_NODE_STAT_NAMES.contains(state)) { mlStatsInput.getNodeLevelStats().add(MLNodeLevelStat.from(state)); } else { mlStatsInput.getClusterLevelStats().add(MLClusterLevelStat.from(state)); diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java index d002c002bf..857721392a 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java @@ -10,12 +10,13 @@ * This enum represents node level stats. */ public enum MLNodeLevelStat { - ML_NODE_JVM_HEAP_USAGE, - ML_NODE_EXECUTING_TASK_COUNT, - ML_NODE_TOTAL_REQUEST_COUNT, - ML_NODE_TOTAL_FAILURE_COUNT, - ML_NODE_TOTAL_MODEL_COUNT, - ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT; + ML_JVM_HEAP_USAGE, + ML_EXECUTING_TASK_COUNT, // How many tasks are executing currently. If any task starts, then it will increase by 1, + // if the task finished then it will decrease by 0. + ML_REQUEST_COUNT, + ML_FAILURE_COUNT, + ML_DEPLOYED_MODEL_COUNT, + ML_CIRCUIT_BREAKER_TRIGGER_COUNT; public static MLNodeLevelStat from(String value) { try { diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index bc032efc0b..245ce4db43 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -88,8 +88,8 @@ protected TransportResponseHandler getResponseHandler(Act protected void executeTask(MLExecuteTaskRequest request, ActionListener listener) { threadPool.executor(EXECUTE_THREAD_POOL).execute(() -> { try { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -113,7 +113,7 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -303,7 +303,7 @@ private void handlePredictFailure(MLTask mlTask, ActionListener mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) .increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } handleAsyncMLTaskFailure(mlTask, e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java index 43263e4601..7151d29bf8 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java @@ -106,15 +106,14 @@ private void dispatchTaskWithLeastLoad(String[] nodeIds, ActionListener listener) { MLStatsNodesRequest MLStatsNodesRequest = new MLStatsNodesRequest(nodes); - MLStatsNodesRequest - .addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)); + MLStatsNodesRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_JVM_HEAP_USAGE)); client.execute(MLStatsNodesAction.INSTANCE, MLStatsNodesRequest, ActionListener.wrap(mlStatsResponse -> { // Check JVM pressure List candidateNodeResponse = mlStatsResponse .getNodes() .stream() - .filter(stat -> (long) stat.getNodeLevelStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) + .filter(stat -> (long) stat.getNodeLevelStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) .collect(Collectors.toList()); if (candidateNodeResponse.size() == 0) { @@ -129,7 +128,7 @@ private void dispatchTaskWithLeastLoad(DiscoveryNode[] nodes, ActionListener (Long) stat.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT) < maxMLBatchTaskPerNode) + .filter(stat -> (Long) stat.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT) < maxMLBatchTaskPerNode) .collect(Collectors.toList()); if (candidateNodeResponse.size() == 0) { String errorMessage = "All nodes' executing ML task count reach limitation."; @@ -142,13 +141,13 @@ private void dispatchTaskWithLeastLoad(DiscoveryNode[] nodes, ActionListener targetNode = candidateNodeResponse .stream() .sorted((MLStatsNodeResponse r1, MLStatsNodeResponse r2) -> { - int result = ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)) - .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + int result = ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)) + .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); if (result == 0) { // if multiple nodes have same running task count, choose the one with least // JVM heap usage. - return ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)) - .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)); + return ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)) + .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)); } return result; }) diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index 1d29560cca..534549c422 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -96,7 +96,7 @@ public void run(FunctionName functionName, Request request, TransportService tra protected ActionListener wrappedCleanupListener(ActionListener listener, String taskId) { ActionListener internalListener = ActionListener.runAfter(listener, () -> { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); mlTaskManager.remove(taskId); }); return internalListener; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java index 30983171fa..1e75008ac2 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java @@ -128,8 +128,8 @@ protected void executeTask(MLTrainingTaskRequest request, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -161,7 +161,7 @@ private void handlePredictFailure(MLTask mlTask, ActionListener mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) .increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } handleAsyncMLTaskFailure(mlTask, e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index f8937f9a0a..ee4edfe2bc 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -146,8 +146,8 @@ protected void executeTask(MLTrainingTaskRequest request, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -180,7 +180,7 @@ private void train(MLTask mlTask, MLInput mlInput, ActionListener void parseField(XContentParser parser, Set set, Function { ActionListener listener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java index 9140619dc1..601af4df0f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java @@ -5,7 +5,7 @@ package org.opensearch.ml.action.stats; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_EXECUTING_TASK_COUNT; import static org.opensearch.ml.utils.IntegTestUtils.TESTING_DATA; import static org.opensearch.ml.utils.IntegTestUtils.generateMLTestingData; import static org.opensearch.ml.utils.IntegTestUtils.verifyGeneratedTestingData; @@ -45,7 +45,7 @@ public void testGeneratedTestingData() throws ExecutionException, InterruptedExc public void testNormalCase() throws ExecutionException, InterruptedException { MLStatsNodesRequest request = new MLStatsNodesRequest(new String[0], new MLStatsInput()); - request.addNodeLevelStats(ImmutableSet.of(ML_NODE_EXECUTING_TASK_COUNT)); + request.addNodeLevelStats(ImmutableSet.of(ML_EXECUTING_TASK_COUNT)); ActionFuture future = client().execute(MLStatsNodesAction.INSTANCE, request); MLStatsNodesResponse response = future.get(); @@ -58,6 +58,6 @@ public void testNormalCase() throws ExecutionException, InterruptedException { MLStatsNodeResponse nodeResponse = responseList.get(0); assertEquals(1, nodeResponse.getNodeLevelStatSize()); - assertEquals(0l, nodeResponse.getNodeLevelStat(ML_NODE_EXECUTING_TASK_COUNT)); + assertEquals(0l, nodeResponse.getNodeLevelStat(ML_EXECUTING_TASK_COUNT)); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java index de6b599728..b475b033ff 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java @@ -39,14 +39,14 @@ public class MLStatsNodeResponseTests extends OpenSearchTestCase { public void setup() { node = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); Map statsToValues = new HashMap<>(); - statsToValues.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 100); + statsToValues.put(MLNodeLevelStat.ML_REQUEST_COUNT, 100); response = new MLStatsNodeResponse(node, statsToValues); } public void testSerializationDeserialization() throws IOException { DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); Map statsToValues = new HashMap<>(); - statsToValues.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 10l); + statsToValues.put(MLNodeLevelStat.ML_REQUEST_COUNT, 10l); MLStatsNodeResponse response = new MLStatsNodeResponse(localNode, statsToValues); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); @@ -60,7 +60,7 @@ public void testToXContent_NodeLevelStats() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject(); String taskContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"ml_node_total_request_count\":100}", taskContent); + assertEquals("{\"ml_request_count\":100}", taskContent); } public void testToXContent_AlgorithmStats() throws IOException { @@ -100,7 +100,7 @@ public void testToXContent_WithAlgoStats() throws IOException { builder.startObject(); DiscoveryNode node = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); Map statsToValues = new HashMap<>(); - statsToValues.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 100); + statsToValues.put(MLNodeLevelStat.ML_REQUEST_COUNT, 100); Map algoStats = new HashMap<>(); Map algoActionStats = new HashMap<>(); Map algoActionStatMap = new HashMap<>(); @@ -114,8 +114,8 @@ public void testToXContent_WithAlgoStats() throws IOException { String taskContent = TestHelper.xContentBuilderToString(builder); Set validResult = ImmutableSet .of( - "{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}}}", - "{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}}}" + "{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}}}", + "{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}}}" ); assertTrue(validResult.contains(taskContent)); } @@ -124,8 +124,8 @@ public void testReadStats() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLStatsNodeResponse mlStatsNodeResponse = MLStatsNodeResponse.readStats(output.bytes().streamInput()); - Integer expectedValue = (Integer) response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT); - assertEquals(expectedValue, mlStatsNodeResponse.getNodeLevelStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + Integer expectedValue = (Integer) response.getNodeLevelStat(MLNodeLevelStat.ML_REQUEST_COUNT); + assertEquals(expectedValue, mlStatsNodeResponse.getNodeLevelStat(MLNodeLevelStat.ML_REQUEST_COUNT)); } public void testIsEmpty_NullNodeStats() { @@ -150,23 +150,23 @@ public void testIsEmpty_EmptyAlgoStats() { public void testIsEmpty_NonEmptyNodeAndAlgoStats() { MLStatsNodeResponse response = createResponseWithDefaultAlgoStats( - ImmutableMap.of(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, totalRequestCount) + ImmutableMap.of(MLNodeLevelStat.ML_REQUEST_COUNT, totalRequestCount) ); assertFalse(response.isEmpty()); } public void testGetNodeLevelStat_NonExistingStat() { - assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); assertEquals(1, response.getNodeLevelStatSize()); } public void testGetNodeLevelStat_NullOrEmptyNodeStats() { MLStatsNodeResponse response = new MLStatsNodeResponse(node, null); - assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); assertEquals(0, response.getNodeLevelStatSize()); response = new MLStatsNodeResponse(node, ImmutableMap.of()); - assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); assertEquals(0, response.getNodeLevelStatSize()); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java index 4e8746811a..a59a2ca30c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java @@ -20,7 +20,7 @@ public class MLStatsNodesRequestTests extends OpenSearchTestCase { public void testSerializationDeserialization() throws IOException { MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { "testNodeId" }, new MLStatsInput()); - mlStatsNodesRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + mlStatsNodesRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); BytesStreamOutput output = new BytesStreamOutput(); MLStatsNodeRequest request = new MLStatsNodeRequest(mlStatsNodesRequest); request.writeTo(output); diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java index 832ddbb282..c801da467c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java @@ -44,12 +44,12 @@ public void testToXContent() throws IOException { DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT); Map nodeLevelStats1 = new HashMap<>(); - nodeLevelStats1.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 100); + nodeLevelStats1.put(MLNodeLevelStat.ML_REQUEST_COUNT, 100); nodes.add(new MLStatsNodeResponse(node1, nodeLevelStats1)); DiscoveryNode node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT); Map nodeLevelStats2 = new HashMap<>(); - nodeLevelStats2.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 200); + nodeLevelStats2.put(MLNodeLevelStat.ML_REQUEST_COUNT, 200); nodes.add(new MLStatsNodeResponse(node2, nodeLevelStats2)); List failures = new ArrayList<>(); @@ -58,9 +58,6 @@ public void testToXContent() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject(); String taskContent = TestHelper.xContentBuilderToString(builder); - assertEquals( - "{\"nodes\":{\"node1\":{\"ml_node_total_request_count\":100},\"node2\":{\"ml_node_total_request_count\":200}}}", - taskContent - ); + assertEquals("{\"nodes\":{\"node1\":{\"ml_request_count\":100},\"node2\":{\"ml_request_count\":200}}}", taskContent); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java index cef0e3a9c2..72b07055e3 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java @@ -7,7 +7,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_JVM_HEAP_USAGE; import java.io.IOException; import java.util.EnumSet; @@ -56,13 +56,13 @@ public void setUp() throws Exception { super.setUp(); clusterStatName1 = MLClusterLevelStat.ML_MODEL_COUNT; - nodeStatName1 = MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT; + nodeStatName1 = MLNodeLevelStat.ML_EXECUTING_TASK_COUNT; statsMap = new HashMap<>() { { put(nodeStatName1, new MLStat<>(false, new CounterSupplier())); put(clusterStatName1, new MLStat<>(true, new CounterSupplier())); - put(ML_NODE_JVM_HEAP_USAGE, new MLStat<>(true, new SettableSupplier())); + put(ML_JVM_HEAP_USAGE, new MLStat<>(true, new SettableSupplier())); } }; @@ -119,14 +119,14 @@ public void testNodeOperationWithJvmHeapUsage() { String nodeId = clusterService().localNode().getId(); MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { nodeId }, new MLStatsInput()); - Set statsToBeRetrieved = ImmutableSet.of(ML_NODE_JVM_HEAP_USAGE); + Set statsToBeRetrieved = ImmutableSet.of(ML_JVM_HEAP_USAGE); mlStatsNodesRequest.addNodeLevelStats(statsToBeRetrieved); MLStatsNodeResponse response = action.nodeOperation(new MLStatsNodeRequest(mlStatsNodesRequest)); Assert.assertEquals(statsToBeRetrieved.size(), response.getNodeLevelStatSize()); - assertNotNull(response.getNodeLevelStat(ML_NODE_JVM_HEAP_USAGE)); + assertNotNull(response.getNodeLevelStat(ML_JVM_HEAP_USAGE)); } public void testNodeOperation_NoNodeLevelStat() { diff --git a/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java index b4278e4ae3..36af1cde91 100644 --- a/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java @@ -14,8 +14,8 @@ import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_FAILURE_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_REQUEST_COUNT; import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; import static org.opensearch.ml.utils.TestData.trainModelDataJson; @@ -292,11 +292,11 @@ protected void validateStats( Map allNodeStats = (Map) map.get("nodes"); for (String key : allNodeStats.keySet()) { Map nodeStatsMap = (Map) allNodeStats.get(key); - String statKey = ML_NODE_TOTAL_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); + String statKey = ML_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); if (nodeStatsMap.containsKey(statKey)) { totalFailureCount += (Double) nodeStatsMap.get(statKey); } - statKey = ML_NODE_TOTAL_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); + statKey = ML_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); if (nodeStatsMap.containsKey(statKey)) { totalRequestCount += (Double) nodeStatsMap.get(statKey); } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 434b343ad8..aba789c9b0 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -213,11 +213,11 @@ public void setup() throws URISyntaxException { Map> stats = new ConcurrentHashMap<>(); // node level stats - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = spy(new MLStats(stats)); mlTask = MLTask @@ -734,7 +734,7 @@ private void testDeployModel_FailedToRetrieveModelChunks(boolean lastChunk) { modelManager.deployModel(modelId, modelContentHashValue, functionName, true, mlTask, listener); verify(modelCacheHelper).removeModel(eq(modelId)); verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); - verify(mlStats).getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + verify(mlStats).getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT)); } private void mock_client_index_ModelChunkFailure(Client client, String modelId) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 963fd91f76..b0caba313a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -14,8 +14,8 @@ import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_FAILURE_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_REQUEST_COUNT; import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; import static org.opensearch.ml.utils.TestData.trainModelDataJson; @@ -338,11 +338,11 @@ protected void validateStats( Map allNodeStats = (Map) map.get("nodes"); for (String key : allNodeStats.keySet()) { Map nodeStatsMap = (Map) allNodeStats.get(key); - String statKey = ML_NODE_TOTAL_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); + String statKey = ML_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); if (nodeStatsMap.containsKey(statKey)) { totalFailureCount += (Double) nodeStatsMap.get(statKey); } - statKey = ML_NODE_TOTAL_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); + statKey = ML_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); if (nodeStatsMap.containsKey(statKey)) { totalRequestCount += (Double) nodeStatsMap.get(statKey); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java index 7c755d0fd4..2bac46ff8e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java @@ -113,7 +113,7 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); Map> statMap = ImmutableMap .>builder() - .put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())) + .put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())) .build(); mlStats = new MLStats(statMap); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); @@ -209,7 +209,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates() throws Exception { content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -218,7 +218,7 @@ private void prepareResponse() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); List nodes = new ArrayList<>(); - Map nodeStats = ImmutableMap.of(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, nodeTotalRequestCount); + Map nodeStats = ImmutableMap.of(MLNodeLevelStat.ML_REQUEST_COUNT, nodeTotalRequestCount); Map algoStats = new HashMap<>(); Map actionStats = ImmutableMap .of( @@ -299,7 +299,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_NoRequestContent() thro content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -309,7 +309,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws RestRequest request = getStatsRestRequest( node.getId(), - MLClusterLevelStat.ML_MODEL_COUNT + "," + MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT + MLClusterLevelStat.ML_MODEL_COUNT + "," + MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT ); restAction.handleRequest(request, channel, client); @@ -321,7 +321,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); assertEquals(1, input.getClusterLevelStats().size()); assertTrue(input.getClusterLevelStats().contains(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BytesRestResponse.class); verify(channel, times(1)).sendResponse(argumentCaptor.capture()); @@ -334,7 +334,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -342,7 +342,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevelStat() throws Exception { prepareResponse(); - RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT.name()); + RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT.name()); restAction.handleRequest(request, channel, client); ArgumentCaptor inputArgumentCaptor = ArgumentCaptor.forClass(MLStatsNodesRequest.class); @@ -352,7 +352,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevel assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); assertEquals(0, input.getClusterLevelStats().size()); assertEquals(1, input.getNodeLevelStats().size()); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BytesRestResponse.class); verify(channel, times(1)).sendResponse(argumentCaptor.capture()); @@ -360,17 +360,17 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevel assertEquals(RestStatus.OK, restResponse.status()); BytesReference content = restResponse.content(); assertEquals( - "{\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", + "{\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", content.utf8ToString() ); } public void testCreateMlStatsInputFromRequestParams_NodeStat() { - RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT.name().toLowerCase(Locale.ROOT)); + RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT.name().toLowerCase(Locale.ROOT)); MLStatsInput input = restAction.createMlStatsInputFromRequestParams(request); assertEquals(1, input.getTargetStatLevels().size()); assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT)); assertEquals(0, input.getClusterLevelStats().size()); } diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java index e8ee91397a..ddad31f58d 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java @@ -84,27 +84,27 @@ public void testRetrieveAll() { public void testShouldRetrieveStat() { assertTrue(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertTrue(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); MLStatsInput mlStatsInput = MLStatsInput.builder().build(); assertTrue(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertTrue(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); mlStatsInput = new MLStatsInput(); assertTrue(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertTrue(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); mlStatsInput = MLStatsInput .builder() .clusterLevelStats(EnumSet.of(MLClusterLevelStat.ML_TASK_INDEX_STATUS)) - .nodeLevelStats(EnumSet.of(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT)) + .nodeLevelStats(EnumSet.of(MLNodeLevelStat.ML_FAILURE_COUNT)) .actionLevelStats(EnumSet.of(MLActionLevelStat.ML_ACTION_FAILURE_COUNT)) .build(); assertFalse(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertFalse(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertFalse(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertFalse(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); } diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java index 83e8676646..51bde6bd3a 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java @@ -33,7 +33,7 @@ public class MLStatsTests extends OpenSearchTestCase { public void setup() { clusterStatName1 = MLClusterLevelStat.ML_MODEL_COUNT; - nodeStatName1 = MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT; + nodeStatName1 = MLNodeLevelStat.ML_EXECUTING_TASK_COUNT; statsMap = new HashMap>() { { @@ -70,14 +70,14 @@ public void testGetStat() { } public void testGetStatNoExisting() { - MLNodeLevelStat wrongStat = MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE; + MLNodeLevelStat wrongStat = MLNodeLevelStat.ML_JVM_HEAP_USAGE; expectedEx.expect(IllegalArgumentException.class); expectedEx.expectMessage("Stat \"" + wrongStat + "\" does not exist"); mlStats.getStat(wrongStat); } public void testCreateCounterStatIfAbsent() { - MLStat stat = mlStats.createCounterStatIfAbsent(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT); + MLStat stat = mlStats.createCounterStatIfAbsent(MLNodeLevelStat.ML_FAILURE_COUNT); stat.increment(); assertEquals(1L, stat.getValue()); } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 37679e7820..f2fad0f753 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -119,10 +119,10 @@ public void setup() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 452ab49811..26c19394d2 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -141,10 +141,10 @@ public void setup() throws IOException { }).when(executorService).execute(any(Runnable.class)); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); taskRunner = spy( diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java index 9629ded5d7..0324c33c1f 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java @@ -160,8 +160,8 @@ public void testGetEligibleNodes_MlAndDataNodes() { private MLStatsNodesResponse getMlStatsNodesResponse() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); - nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 50l); + nodeStats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, 5l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( @@ -173,7 +173,7 @@ private MLStatsNodesResponse getMlStatsNodesResponse() { private MLStatsNodesResponse getNodesResponse_NoTaskCounts() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 50l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( @@ -185,8 +185,8 @@ private MLStatsNodesResponse getNodesResponse_NoTaskCounts() { private MLStatsNodesResponse getNodesResponse_MemoryExceedLimits() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 90l); - nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 90l); + nodeStats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, 5l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( @@ -198,8 +198,8 @@ private MLStatsNodesResponse getNodesResponse_MemoryExceedLimits() { private MLStatsNodesResponse getNodesResponse_TaskCountExceedLimits() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); - nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 15l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 50l); + nodeStats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, 15l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index ffb94563ff..5811372b61 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -117,10 +117,10 @@ public void setup() { }).when(executorService).execute(any(Runnable.class)); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index ffdd33cae4..5ca625e71d 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -125,10 +125,10 @@ public void setup() { }).when(executorService).execute(any(Runnable.class)); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java index bfce831a14..bbebe4060e 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java @@ -69,11 +69,11 @@ public class TaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); mlStats = new MLStats(stats); MockitoAnnotations.openMocks(this); @@ -140,7 +140,7 @@ public void testRun_CircuitBreakerOpen() { ActionListener listener = mock(ActionListener.class); MLTaskRequest request = new MLTaskRequest(false); expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener)); - Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue(); + Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue(); assertEquals(1L, value.longValue()); } }