From 29d6398ec76e9f4575988b5238d1993631ea3c13 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Fri, 18 Aug 2023 18:51:38 -0700 Subject: [PATCH 01/11] renaming metrics Signed-off-by: Dhrubo Saha --- .../deploy/TransportDeployModelAction.java | 2 +- .../TransportRegisterModelAction.java | 4 +- .../stats/MLStatsNodesTransportAction.java | 4 +- .../TransportUndeployModelAction.java | 6 +-- .../opensearch/ml/model/MLModelManager.java | 38 +++++++++---------- .../ml/plugin/MachineLearningPlugin.java | 10 ++--- .../opensearch/ml/stats/MLNodeLevelStat.java | 12 +++--- .../ml/task/MLExecuteTaskRunner.java | 6 +-- .../ml/task/MLPredictTaskRunner.java | 6 +-- .../opensearch/ml/task/MLTaskDispatcher.java | 14 +++---- .../org/opensearch/ml/task/MLTaskRunner.java | 2 +- .../ml/task/MLTrainAndPredictTaskRunner.java | 6 +-- .../ml/task/MLTrainingTaskRunner.java | 6 +-- .../org/opensearch/ml/utils/MLNodeUtils.java | 2 +- .../TransportDeployModelActionTests.java | 2 +- .../TransportRegisterModelActionTests.java | 2 +- .../ml/action/stats/MLStatsNodeITTests.java | 6 +-- .../stats/MLStatsNodeResponseTests.java | 24 ++++++------ .../stats/MLStatsNodesRequestTests.java | 2 +- .../stats/MLStatsNodesResponseTests.java | 9 ++--- .../MLStatsNodesTransportActionTests.java | 10 ++--- ...onsBackwardsCompatibilityRestTestCase.java | 8 ++-- .../ml/model/MLModelManagerTests.java | 12 +++--- .../ml/rest/MLCommonsRestTestCase.java | 8 ++-- .../ml/rest/RestMLStatsActionTests.java | 24 ++++++------ .../ml/stats/MLStatsInputTests.java | 10 ++--- .../org/opensearch/ml/stats/MLStatsTests.java | 6 +-- .../ml/task/MLExecuteTaskRunnerTests.java | 8 ++-- .../ml/task/MLPredictTaskRunnerTests.java | 8 ++-- .../ml/task/MLTaskDispatcherTests.java | 14 +++---- .../MLTrainAndPredictTaskRunnerTests.java | 8 ++-- .../ml/task/MLTrainingTaskRunnerTests.java | 8 ++-- .../opensearch/ml/task/TaskRunnerTests.java | 12 +++--- 33 files changed, 148 insertions(+), 151 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 3915886360..d931ab80ef 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -149,7 +149,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener nodeMapping = new HashMap<>(); for (DiscoveryNode node : allEligibleNodes) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index c5d426e348..963d7450e2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -234,8 +234,8 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen throw new IllegalArgumentException("URL can't match trusted url regex"); } } - // mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + // mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); // //TODO: track executing task; track register failures // mlStats.createCounterStatIfAbsent(FunctionName.TEXT_EMBEDDING, // ActionName.REGISTER, diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java index e6c59bf771..f1809de9bb 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java @@ -97,9 +97,9 @@ MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRe MLStatsInput mlStatsInput = mlStatsNodesRequest.getMlStatsInput(); // return node level stats if (mlStatsInput.getTargetStatLevels().contains(MLStatLevel.NODE)) { - if (mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)) { + if (mlStatsInput.retrieveStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)) { long heapUsedPercent = jvmService.stats().getMem().getHeapUsedPercent(); - statValues.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, heapUsedPercent); + statValues.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, heapUsedPercent); } for (Enum statName : mlStats.getNodeStats().keySet()) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java index bbed0305e1..4ca04aa609 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java @@ -228,8 +228,8 @@ protected MLUndeployModelNodeResponse nodeOperation(MLUndeployModelNodeRequest r } private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployModelNodesRequest MLUndeployModelNodesRequest) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); String[] modelIds = MLUndeployModelNodesRequest.getModelIds(); @@ -246,7 +246,7 @@ private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployMo } Map modelUndeployStatus = mlModelManager.undeployModel(modelIds); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); return new MLUndeployModelNodeResponse(clusterService.localNode(), modelUndeployStatus, modelWorkerNodesMap); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 06295ac138..07658dae1a 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -212,7 +212,7 @@ public MLModelManager( public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener listener) { try { FunctionName functionName = mlRegisterModelMetaInput.getFunctionName(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); String modelGroupId = mlRegisterModelMetaInput.getModelGroupId(); if (Strings.isBlank(modelGroupId)) { @@ -322,8 +322,8 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode); try { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment(); String modelGroupId = registerModelInput.getModelGroupId(); @@ -384,7 +384,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa } catch (Exception e) { handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); } finally { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); } } @@ -392,9 +392,9 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelName = registerModelInput.getModelName(); String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; @@ -444,7 +444,7 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml logException("Failed to upload model", e, log); handleException(functionName, taskId, e); } finally { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); } } @@ -462,9 +462,9 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelName = registerModelInput.getModelName(); String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; String modelGroupId = registerModelInput.getModelGroupId(); @@ -510,7 +510,7 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas logException("Failed to register model", e, log); handleException(functionName, taskId, e); } finally { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); } } @@ -693,7 +693,7 @@ private void handleException(FunctionName functionName, String taskId, Exception && !(e instanceof MLResourceNotFoundException) && !(e instanceof IllegalArgumentException)) { mlStats.createCounterStatIfAbsent(functionName, REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } Map updated = ImmutableMap.of(ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED); mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true); @@ -718,7 +718,7 @@ public void deployModel( ActionListener listener ) { mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); List workerNodes = mlTask.getWorkerNodes(); if (modelCacheHelper.isModelDeployed(modelId)) { if (workerNodes != null && workerNodes.size() > 0) { @@ -800,7 +800,7 @@ public void deployModel( MLExecutable mlExecutable = mlEngine.deployExecute(mlModel, params); try { modelCacheHelper.setMLExecutor(modelId, mlExecutable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); listener.onResponse("successful"); } catch (Exception e) { @@ -813,7 +813,7 @@ public void deployModel( Predictable predictable = mlEngine.deploy(mlModel, params); try { modelCacheHelper.setPredictor(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes(); long contentSize = modelContentSizeInBytes == null @@ -846,7 +846,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam && !(e instanceof MLResourceNotFoundException) && !(e instanceof IllegalArgumentException)) { mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } removeModel(modelId); listener.onFailure(e); @@ -855,7 +855,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam private void setupPredictable(String modelId, MLModel mlModel, Map params) { Predictable predictable = mlEngine.deploy(mlModel, params); modelCacheHelper.setPredictor(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); } @@ -1056,8 +1056,8 @@ public synchronized Map undeployModel(String[] modelIds) { for (String modelId : modelIds) { if (modelCacheHelper.isModelDeployed(modelId)) { modelUndeployStatus.put(modelId, UNDEPLOYED); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT) .increment(); @@ -1070,7 +1070,7 @@ public synchronized Map undeployModel(String[] modelIds) { log.debug("undeploy all models {}", Arrays.toString(getLocalDeployedModels())); for (String modelId : getLocalDeployedModels()) { modelUndeployStatus.put(modelId, UNDEPLOYED); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT).decrement(); mlStats.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT).increment(); removeModel(modelId); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 2e68ef68b4..a30ebf2b97 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -305,11 +305,11 @@ public Collection createComponents( stats.put(MLClusterLevelStat.ML_MODEL_COUNT, new MLStat<>(true, new CounterSupplier())); stats.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, new MLStat<>(true, new CounterSupplier())); // node level stats - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlIndicesHandler = new MLIndicesHandler(clusterService, client); diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java index d002c002bf..f230e9c49e 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java @@ -10,12 +10,12 @@ * This enum represents node level stats. */ public enum MLNodeLevelStat { - ML_NODE_JVM_HEAP_USAGE, - ML_NODE_EXECUTING_TASK_COUNT, - ML_NODE_TOTAL_REQUEST_COUNT, - ML_NODE_TOTAL_FAILURE_COUNT, - ML_NODE_TOTAL_MODEL_COUNT, - ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT; + ML_JVM_HEAP_USAGE, + ML_EXECUTING_TASK_COUNT, + ML_REQUEST_COUNT, + ML_FAILURE_COUNT, + ML_TOTAL_MODEL_COUNT, + ML_CIRCUIT_BREAKER_TRIGGER_COUNT; public static MLNodeLevelStat from(String value) { try { diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index f5cb8a3846..1498b39d07 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -88,8 +88,8 @@ protected TransportResponseHandler getResponseHandler(Act protected void executeTask(MLExecuteTaskRequest request, ActionListener listener) { threadPool.executor(EXECUTE_THREAD_POOL).execute(() -> { try { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -113,7 +113,7 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -303,7 +303,7 @@ private void handlePredictFailure(MLTask mlTask, ActionListener mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) .increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } handleAsyncMLTaskFailure(mlTask, e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java index 84860005fc..f25285d5b3 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java @@ -107,14 +107,14 @@ private void dispatchTaskWithLeastLoad(String[] nodeIds, ActionListener listener) { MLStatsNodesRequest MLStatsNodesRequest = new MLStatsNodesRequest(nodes); MLStatsNodesRequest - .addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)); + .addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_JVM_HEAP_USAGE)); client.execute(MLStatsNodesAction.INSTANCE, MLStatsNodesRequest, ActionListener.wrap(mlStatsResponse -> { // Check JVM pressure List candidateNodeResponse = mlStatsResponse .getNodes() .stream() - .filter(stat -> (long) stat.getNodeLevelStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) + .filter(stat -> (long) stat.getNodeLevelStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) .collect(Collectors.toList()); if (candidateNodeResponse.size() == 0) { @@ -129,7 +129,7 @@ private void dispatchTaskWithLeastLoad(DiscoveryNode[] nodes, ActionListener (Long) stat.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT) < maxMLBatchTaskPerNode) + .filter(stat -> (Long) stat.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT) < maxMLBatchTaskPerNode) .collect(Collectors.toList()); if (candidateNodeResponse.size() == 0) { String errorMessage = "All nodes' executing ML task count reach limitation."; @@ -142,13 +142,13 @@ private void dispatchTaskWithLeastLoad(DiscoveryNode[] nodes, ActionListener targetNode = candidateNodeResponse .stream() .sorted((MLStatsNodeResponse r1, MLStatsNodeResponse r2) -> { - int result = ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)) - .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + int result = ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)) + .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); if (result == 0) { // if multiple nodes have same running task count, choose the one with least // JVM heap usage. - return ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)) - .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)); + return ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)) + .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)); } return result; }) diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index 05c6c5f1bb..b2c71d6ed8 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -96,7 +96,7 @@ public void run(FunctionName functionName, Request request, TransportService tra protected ActionListener wrappedCleanupListener(ActionListener listener, String taskId) { ActionListener internalListener = ActionListener.runAfter(listener, () -> { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); mlTaskManager.remove(taskId); }); return internalListener; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java index 09603d291e..9461c3adaf 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java @@ -128,8 +128,8 @@ protected void executeTask(MLTrainingTaskRequest request, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -161,7 +161,7 @@ private void handlePredictFailure(MLTask mlTask, ActionListener mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) .increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } handleAsyncMLTaskFailure(mlTask, e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index 090afb445a..ca436e72f7 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -146,8 +146,8 @@ protected void executeTask(MLTrainingTaskRequest request, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -180,7 +180,7 @@ private void train(MLTask mlTask, MLInput mlInput, ActionListener void parseField(XContentParser parser, Set set, Function { ActionListener listener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java index f68bec725a..0fa4be0b7d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java @@ -5,7 +5,7 @@ package org.opensearch.ml.action.stats; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_EXECUTING_TASK_COUNT; import static org.opensearch.ml.utils.IntegTestUtils.TESTING_DATA; import static org.opensearch.ml.utils.IntegTestUtils.generateMLTestingData; import static org.opensearch.ml.utils.IntegTestUtils.verifyGeneratedTestingData; @@ -45,7 +45,7 @@ public void testGeneratedTestingData() throws ExecutionException, InterruptedExc public void testNormalCase() throws ExecutionException, InterruptedException { MLStatsNodesRequest request = new MLStatsNodesRequest(new String[0], new MLStatsInput()); - request.addNodeLevelStats(ImmutableSet.of(ML_NODE_EXECUTING_TASK_COUNT)); + request.addNodeLevelStats(ImmutableSet.of(ML_EXECUTING_TASK_COUNT)); ActionFuture future = client().execute(MLStatsNodesAction.INSTANCE, request); MLStatsNodesResponse response = future.get(); @@ -58,6 +58,6 @@ public void testNormalCase() throws ExecutionException, InterruptedException { MLStatsNodeResponse nodeResponse = responseList.get(0); assertEquals(1, nodeResponse.getNodeLevelStatSize()); - assertEquals(0l, nodeResponse.getNodeLevelStat(ML_NODE_EXECUTING_TASK_COUNT)); + assertEquals(0l, nodeResponse.getNodeLevelStat(ML_EXECUTING_TASK_COUNT)); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java index de6b599728..a9c5dabcf1 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java @@ -39,14 +39,14 @@ public class MLStatsNodeResponseTests extends OpenSearchTestCase { public void setup() { node = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); Map statsToValues = new HashMap<>(); - statsToValues.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 100); + statsToValues.put(MLNodeLevelStat.ML_REQUEST_COUNT, 100); response = new MLStatsNodeResponse(node, statsToValues); } public void testSerializationDeserialization() throws IOException { DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); Map statsToValues = new HashMap<>(); - statsToValues.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 10l); + statsToValues.put(MLNodeLevelStat.ML_REQUEST_COUNT, 10l); MLStatsNodeResponse response = new MLStatsNodeResponse(localNode, statsToValues); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); @@ -60,7 +60,7 @@ public void testToXContent_NodeLevelStats() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject(); String taskContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"ml_node_total_request_count\":100}", taskContent); + assertEquals("{\"ML_REQUEST_COUNT\":100}", taskContent); } public void testToXContent_AlgorithmStats() throws IOException { @@ -100,7 +100,7 @@ public void testToXContent_WithAlgoStats() throws IOException { builder.startObject(); DiscoveryNode node = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); Map statsToValues = new HashMap<>(); - statsToValues.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 100); + statsToValues.put(MLNodeLevelStat.ML_REQUEST_COUNT, 100); Map algoStats = new HashMap<>(); Map algoActionStats = new HashMap<>(); Map algoActionStatMap = new HashMap<>(); @@ -114,8 +114,8 @@ public void testToXContent_WithAlgoStats() throws IOException { String taskContent = TestHelper.xContentBuilderToString(builder); Set validResult = ImmutableSet .of( - "{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}}}", - "{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}}}" + "{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}}}", + "{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}}}" ); assertTrue(validResult.contains(taskContent)); } @@ -124,8 +124,8 @@ public void testReadStats() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLStatsNodeResponse mlStatsNodeResponse = MLStatsNodeResponse.readStats(output.bytes().streamInput()); - Integer expectedValue = (Integer) response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT); - assertEquals(expectedValue, mlStatsNodeResponse.getNodeLevelStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + Integer expectedValue = (Integer) response.getNodeLevelStat(MLNodeLevelStat.ML_REQUEST_COUNT); + assertEquals(expectedValue, mlStatsNodeResponse.getNodeLevelStat(MLNodeLevelStat.ML_REQUEST_COUNT)); } public void testIsEmpty_NullNodeStats() { @@ -150,23 +150,23 @@ public void testIsEmpty_EmptyAlgoStats() { public void testIsEmpty_NonEmptyNodeAndAlgoStats() { MLStatsNodeResponse response = createResponseWithDefaultAlgoStats( - ImmutableMap.of(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, totalRequestCount) + ImmutableMap.of(MLNodeLevelStat.ML_REQUEST_COUNT, totalRequestCount) ); assertFalse(response.isEmpty()); } public void testGetNodeLevelStat_NonExistingStat() { - assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); assertEquals(1, response.getNodeLevelStatSize()); } public void testGetNodeLevelStat_NullOrEmptyNodeStats() { MLStatsNodeResponse response = new MLStatsNodeResponse(node, null); - assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); assertEquals(0, response.getNodeLevelStatSize()); response = new MLStatsNodeResponse(node, ImmutableMap.of()); - assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); assertEquals(0, response.getNodeLevelStatSize()); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java index 4e8746811a..a59a2ca30c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java @@ -20,7 +20,7 @@ public class MLStatsNodesRequestTests extends OpenSearchTestCase { public void testSerializationDeserialization() throws IOException { MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { "testNodeId" }, new MLStatsInput()); - mlStatsNodesRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + mlStatsNodesRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); BytesStreamOutput output = new BytesStreamOutput(); MLStatsNodeRequest request = new MLStatsNodeRequest(mlStatsNodesRequest); request.writeTo(output); diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java index 832ddbb282..865d207e96 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java @@ -44,12 +44,12 @@ public void testToXContent() throws IOException { DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT); Map nodeLevelStats1 = new HashMap<>(); - nodeLevelStats1.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 100); + nodeLevelStats1.put(MLNodeLevelStat.ML_REQUEST_COUNT, 100); nodes.add(new MLStatsNodeResponse(node1, nodeLevelStats1)); DiscoveryNode node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT); Map nodeLevelStats2 = new HashMap<>(); - nodeLevelStats2.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 200); + nodeLevelStats2.put(MLNodeLevelStat.ML_REQUEST_COUNT, 200); nodes.add(new MLStatsNodeResponse(node2, nodeLevelStats2)); List failures = new ArrayList<>(); @@ -58,9 +58,6 @@ public void testToXContent() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject(); String taskContent = TestHelper.xContentBuilderToString(builder); - assertEquals( - "{\"nodes\":{\"node1\":{\"ml_node_total_request_count\":100},\"node2\":{\"ml_node_total_request_count\":200}}}", - taskContent - ); + assertEquals("{\"nodes\":{\"node1\":{\"ML_REQUEST_COUNT\":100},\"node2\":{\"ML_REQUEST_COUNT\":200}}}", taskContent); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java index fe0e962fd8..b54c9ca83a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java @@ -7,7 +7,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_JVM_HEAP_USAGE; import java.io.IOException; import java.util.EnumSet; @@ -56,13 +56,13 @@ public void setUp() throws Exception { super.setUp(); clusterStatName1 = MLClusterLevelStat.ML_MODEL_COUNT; - nodeStatName1 = MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT; + nodeStatName1 = MLNodeLevelStat.ML_EXECUTING_TASK_COUNT; statsMap = new HashMap<>() { { put(nodeStatName1, new MLStat<>(false, new CounterSupplier())); put(clusterStatName1, new MLStat<>(true, new CounterSupplier())); - put(ML_NODE_JVM_HEAP_USAGE, new MLStat<>(true, new SettableSupplier())); + put(ML_JVM_HEAP_USAGE, new MLStat<>(true, new SettableSupplier())); } }; @@ -119,14 +119,14 @@ public void testNodeOperationWithJvmHeapUsage() { String nodeId = clusterService().localNode().getId(); MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { nodeId }, new MLStatsInput()); - Set statsToBeRetrieved = ImmutableSet.of(ML_NODE_JVM_HEAP_USAGE); + Set statsToBeRetrieved = ImmutableSet.of(ML_JVM_HEAP_USAGE); mlStatsNodesRequest.addNodeLevelStats(statsToBeRetrieved); MLStatsNodeResponse response = action.nodeOperation(new MLStatsNodeRequest(mlStatsNodesRequest)); Assert.assertEquals(statsToBeRetrieved.size(), response.getNodeLevelStatSize()); - assertNotNull(response.getNodeLevelStat(ML_NODE_JVM_HEAP_USAGE)); + assertNotNull(response.getNodeLevelStat(ML_JVM_HEAP_USAGE)); } public void testNodeOperation_NoNodeLevelStat() { diff --git a/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java index fcc746b9e0..38467caa10 100644 --- a/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java @@ -14,8 +14,8 @@ import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_FAILURE_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_REQUEST_COUNT; import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; import static org.opensearch.ml.utils.TestData.trainModelDataJson; @@ -292,11 +292,11 @@ protected void validateStats( Map allNodeStats = (Map) map.get("nodes"); for (String key : allNodeStats.keySet()) { Map nodeStatsMap = (Map) allNodeStats.get(key); - String statKey = ML_NODE_TOTAL_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); + String statKey = ML_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); if (nodeStatsMap.containsKey(statKey)) { totalFailureCount += (Double) nodeStatsMap.get(statKey); } - statKey = ML_NODE_TOTAL_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); + statKey = ML_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); if (nodeStatsMap.containsKey(statKey)) { totalRequestCount += (Double) nodeStatsMap.get(statKey); } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 5b59b8a629..248b3e3aa3 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -213,11 +213,11 @@ public void setup() throws URISyntaxException { Map> stats = new ConcurrentHashMap<>(); // node level stats - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = spy(new MLStats(stats)); mlTask = MLTask @@ -734,7 +734,7 @@ private void testDeployModel_FailedToRetrieveModelChunks(boolean lastChunk) { modelManager.deployModel(modelId, modelContentHashValue, functionName, true, mlTask, listener); verify(modelCacheHelper).removeModel(eq(modelId)); verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); - verify(mlStats).getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + verify(mlStats).getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT)); } private void mock_client_index_ModelChunkFailure(Client client, String modelId) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 559fba3797..e943b8e42f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -14,8 +14,8 @@ import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_FAILURE_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_REQUEST_COUNT; import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; import static org.opensearch.ml.utils.TestData.trainModelDataJson; @@ -338,11 +338,11 @@ protected void validateStats( Map allNodeStats = (Map) map.get("nodes"); for (String key : allNodeStats.keySet()) { Map nodeStatsMap = (Map) allNodeStats.get(key); - String statKey = ML_NODE_TOTAL_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); + String statKey = ML_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); if (nodeStatsMap.containsKey(statKey)) { totalFailureCount += (Double) nodeStatsMap.get(statKey); } - statKey = ML_NODE_TOTAL_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); + statKey = ML_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); if (nodeStatsMap.containsKey(statKey)) { totalRequestCount += (Double) nodeStatsMap.get(statKey); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java index 603aca698f..b3bb3e8d1a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java @@ -113,7 +113,7 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); Map> statMap = ImmutableMap .>builder() - .put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())) + .put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())) .build(); mlStats = new MLStats(statMap); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); @@ -209,7 +209,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates() throws Exception { content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -218,7 +218,7 @@ private void prepareResponse() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); List nodes = new ArrayList<>(); - Map nodeStats = ImmutableMap.of(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, nodeTotalRequestCount); + Map nodeStats = ImmutableMap.of(MLNodeLevelStat.ML_REQUEST_COUNT, nodeTotalRequestCount); Map algoStats = new HashMap<>(); Map actionStats = ImmutableMap .of( @@ -299,7 +299,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_NoRequestContent() thro content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -309,7 +309,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws RestRequest request = getStatsRestRequest( node.getId(), - MLClusterLevelStat.ML_MODEL_COUNT + "," + MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT + MLClusterLevelStat.ML_MODEL_COUNT + "," + MLNodeLevelStat.ML_TOTAL_MODEL_COUNT ); restAction.handleRequest(request, channel, client); @@ -321,7 +321,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); assertEquals(1, input.getClusterLevelStats().size()); assertTrue(input.getClusterLevelStats().contains(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BytesRestResponse.class); verify(channel, times(1)).sendResponse(argumentCaptor.capture()); @@ -334,7 +334,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -342,7 +342,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevelStat() throws Exception { prepareResponse(); - RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT.name()); + RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_TOTAL_MODEL_COUNT.name()); restAction.handleRequest(request, channel, client); ArgumentCaptor inputArgumentCaptor = ArgumentCaptor.forClass(MLStatsNodesRequest.class); @@ -352,7 +352,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevel assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); assertEquals(0, input.getClusterLevelStats().size()); assertEquals(1, input.getNodeLevelStats().size()); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BytesRestResponse.class); verify(channel, times(1)).sendResponse(argumentCaptor.capture()); @@ -360,17 +360,17 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevel assertEquals(RestStatus.OK, restResponse.status()); BytesReference content = restResponse.content(); assertEquals( - "{\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", + "{\"nodes\":{\"node\":{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", content.utf8ToString() ); } public void testCreateMlStatsInputFromRequestParams_NodeStat() { - RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT.name().toLowerCase(Locale.ROOT)); + RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_TOTAL_MODEL_COUNT.name().toLowerCase(Locale.ROOT)); MLStatsInput input = restAction.createMlStatsInputFromRequestParams(request); assertEquals(1, input.getTargetStatLevels().size()); assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT)); assertEquals(0, input.getClusterLevelStats().size()); } diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java index e8ee91397a..ddad31f58d 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java @@ -84,27 +84,27 @@ public void testRetrieveAll() { public void testShouldRetrieveStat() { assertTrue(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertTrue(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); MLStatsInput mlStatsInput = MLStatsInput.builder().build(); assertTrue(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertTrue(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); mlStatsInput = new MLStatsInput(); assertTrue(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertTrue(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); mlStatsInput = MLStatsInput .builder() .clusterLevelStats(EnumSet.of(MLClusterLevelStat.ML_TASK_INDEX_STATUS)) - .nodeLevelStats(EnumSet.of(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT)) + .nodeLevelStats(EnumSet.of(MLNodeLevelStat.ML_FAILURE_COUNT)) .actionLevelStats(EnumSet.of(MLActionLevelStat.ML_ACTION_FAILURE_COUNT)) .build(); assertFalse(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertFalse(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertFalse(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertFalse(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); } diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java index 83e8676646..51bde6bd3a 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java @@ -33,7 +33,7 @@ public class MLStatsTests extends OpenSearchTestCase { public void setup() { clusterStatName1 = MLClusterLevelStat.ML_MODEL_COUNT; - nodeStatName1 = MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT; + nodeStatName1 = MLNodeLevelStat.ML_EXECUTING_TASK_COUNT; statsMap = new HashMap>() { { @@ -70,14 +70,14 @@ public void testGetStat() { } public void testGetStatNoExisting() { - MLNodeLevelStat wrongStat = MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE; + MLNodeLevelStat wrongStat = MLNodeLevelStat.ML_JVM_HEAP_USAGE; expectedEx.expect(IllegalArgumentException.class); expectedEx.expectMessage("Stat \"" + wrongStat + "\" does not exist"); mlStats.getStat(wrongStat); } public void testCreateCounterStatIfAbsent() { - MLStat stat = mlStats.createCounterStatIfAbsent(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT); + MLStat stat = mlStats.createCounterStatIfAbsent(MLNodeLevelStat.ML_FAILURE_COUNT); stat.increment(); assertEquals(1L, stat.getValue()); } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 0a4decb21a..47d5edda5c 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -119,10 +119,10 @@ public void setup() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 7bd5d889ed..4310ebad66 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -141,10 +141,10 @@ public void setup() throws IOException { }).when(executorService).execute(any(Runnable.class)); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); taskRunner = spy( diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java index 523cde8169..70373995ad 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java @@ -160,8 +160,8 @@ public void testGetEligibleNodes_MlAndDataNodes() { private MLStatsNodesResponse getMlStatsNodesResponse() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); - nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 50l); + nodeStats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, 5l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( @@ -173,7 +173,7 @@ private MLStatsNodesResponse getMlStatsNodesResponse() { private MLStatsNodesResponse getNodesResponse_NoTaskCounts() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 50l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( @@ -185,8 +185,8 @@ private MLStatsNodesResponse getNodesResponse_NoTaskCounts() { private MLStatsNodesResponse getNodesResponse_MemoryExceedLimits() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 90l); - nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 90l); + nodeStats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, 5l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( @@ -198,8 +198,8 @@ private MLStatsNodesResponse getNodesResponse_MemoryExceedLimits() { private MLStatsNodesResponse getNodesResponse_TaskCountExceedLimits() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); - nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 15l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 50l); + nodeStats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, 15l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index 60f6af9d3e..0a18058a51 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -117,10 +117,10 @@ public void setup() { }).when(executorService).execute(any(Runnable.class)); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index e5e8f3281a..724586568a 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -125,10 +125,10 @@ public void setup() { }).when(executorService).execute(any(Runnable.class)); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java index 7f81eebea6..04c25b867e 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java @@ -69,11 +69,11 @@ public class TaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); mlStats = new MLStats(stats); MockitoAnnotations.openMocks(this); @@ -140,7 +140,7 @@ public void testRun_CircuitBreakerOpen() { ActionListener listener = mock(ActionListener.class); MLTaskRequest request = new MLTaskRequest(false); expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener)); - Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue(); + Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue(); assertEquals(1L, value.longValue()); } } From 1b703f9a3e9865e847872cd9e860fa0da6a84b83 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Mon, 21 Aug 2023 10:30:18 -0700 Subject: [PATCH 02/11] updating tests Signed-off-by: Dhrubo Saha --- .../ml/action/stats/MLStatsNodeResponseTests.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java index a9c5dabcf1..b475b033ff 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java @@ -60,7 +60,7 @@ public void testToXContent_NodeLevelStats() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject(); String taskContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"ML_REQUEST_COUNT\":100}", taskContent); + assertEquals("{\"ml_request_count\":100}", taskContent); } public void testToXContent_AlgorithmStats() throws IOException { @@ -114,8 +114,8 @@ public void testToXContent_WithAlgoStats() throws IOException { String taskContent = TestHelper.xContentBuilderToString(builder); Set validResult = ImmutableSet .of( - "{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}}}", - "{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}}}" + "{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}}}", + "{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}}}" ); assertTrue(validResult.contains(taskContent)); } From 8439f45b3acaf344f121c5acb7b5b512921f63f4 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Mon, 21 Aug 2023 10:55:35 -0700 Subject: [PATCH 03/11] updating test cases Signed-off-by: Dhrubo Saha --- .../ml/action/stats/MLStatsNodesResponseTests.java | 2 +- .../org/opensearch/ml/rest/RestMLStatsActionTests.java | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java index 865d207e96..c801da467c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java @@ -58,6 +58,6 @@ public void testToXContent() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject(); String taskContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"nodes\":{\"node1\":{\"ML_REQUEST_COUNT\":100},\"node2\":{\"ML_REQUEST_COUNT\":200}}}", taskContent); + assertEquals("{\"nodes\":{\"node1\":{\"ml_request_count\":100},\"node2\":{\"ml_request_count\":200}}}", taskContent); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java index b3bb3e8d1a..26b5e0ffd9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java @@ -209,7 +209,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates() throws Exception { content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -299,7 +299,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_NoRequestContent() thro content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -334,7 +334,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -360,7 +360,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevel assertEquals(RestStatus.OK, restResponse.status()); BytesReference content = restResponse.content(); assertEquals( - "{\"nodes\":{\"node\":{\"ML_REQUEST_COUNT\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", + "{\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", content.utf8ToString() ); } From 30b1761ceccec1b59b9ea8eccdb3a07dbbfcda3b Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Mon, 21 Aug 2023 12:12:24 -0700 Subject: [PATCH 04/11] removing the ML_NODE checking for node level stats Signed-off-by: Dhrubo Saha --- .../java/org/opensearch/ml/rest/RestMLStatsAction.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java index 4a946963a8..b6b006dfe6 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java @@ -21,6 +21,7 @@ import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import org.opensearch.client.node.NodeClient; @@ -148,6 +149,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) { + + Set mlNodeStatNames = EnumSet.allOf(MLNodeLevelStat.class).stream() + .map(stat -> stat.name()) + .collect(Collectors.toSet()); MLStatsInput mlStatsInput = new MLStatsInput(); Optional nodeIds = splitCommaSeparatedParam(request, "nodeId"); if (nodeIds.isPresent()) { @@ -158,7 +163,7 @@ MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) { for (String state : stats.get()) { state = state.toUpperCase(Locale.ROOT); // only support cluster and node level stats for bwc - if (state.startsWith("ML_NODE")) { + if (mlNodeStatNames.contains(state)) { mlStatsInput.getNodeLevelStats().add(MLNodeLevelStat.from(state)); } else { mlStatsInput.getClusterLevelStats().add(MLClusterLevelStat.from(state)); From 1a057eb61f42ee7a5d6ac411396228f59f7254ca Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Mon, 21 Aug 2023 13:00:54 -0700 Subject: [PATCH 05/11] updating constructing new set Signed-off-by: Dhrubo Saha --- .../java/org/opensearch/ml/rest/RestMLStatsAction.java | 8 ++++---- .../java/org/opensearch/ml/task/MLTaskDispatcher.java | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java index b6b006dfe6..47ede08fad 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java @@ -61,6 +61,9 @@ public class RestMLStatsAction extends BaseRestHandler { private static final String QUERY_ALL_MODEL_META_DOC = "{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}"; + private static final Set ML_NODE_STAT_NAMES = EnumSet.allOf(MLNodeLevelStat.class).stream(). + map(stat -> stat.name()).collect(Collectors.toSet()); + /** * Constructor * @param mlStats MLStats object @@ -150,9 +153,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) { - Set mlNodeStatNames = EnumSet.allOf(MLNodeLevelStat.class).stream() - .map(stat -> stat.name()) - .collect(Collectors.toSet()); MLStatsInput mlStatsInput = new MLStatsInput(); Optional nodeIds = splitCommaSeparatedParam(request, "nodeId"); if (nodeIds.isPresent()) { @@ -163,7 +163,7 @@ MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) { for (String state : stats.get()) { state = state.toUpperCase(Locale.ROOT); // only support cluster and node level stats for bwc - if (mlNodeStatNames.contains(state)) { + if (ML_NODE_STAT_NAMES.contains(state)) { mlStatsInput.getNodeLevelStats().add(MLNodeLevelStat.from(state)); } else { mlStatsInput.getClusterLevelStats().add(MLClusterLevelStat.from(state)); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java index f25285d5b3..89c8bfe2e2 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java @@ -106,8 +106,7 @@ private void dispatchTaskWithLeastLoad(String[] nodeIds, ActionListener listener) { MLStatsNodesRequest MLStatsNodesRequest = new MLStatsNodesRequest(nodes); - MLStatsNodesRequest - .addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_JVM_HEAP_USAGE)); + MLStatsNodesRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_JVM_HEAP_USAGE)); client.execute(MLStatsNodesAction.INSTANCE, MLStatsNodesRequest, ActionListener.wrap(mlStatsResponse -> { // Check JVM pressure From 7e3b4de906b97b08541894ffabcfe0c43031fd52 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Mon, 21 Aug 2023 13:12:59 -0700 Subject: [PATCH 06/11] spotless Apply Signed-off-by: Dhrubo Saha --- .../java/org/opensearch/ml/rest/RestMLStatsAction.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java index 47ede08fad..034b7b2054 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java @@ -61,8 +61,11 @@ public class RestMLStatsAction extends BaseRestHandler { private static final String QUERY_ALL_MODEL_META_DOC = "{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}"; - private static final Set ML_NODE_STAT_NAMES = EnumSet.allOf(MLNodeLevelStat.class).stream(). - map(stat -> stat.name()).collect(Collectors.toSet()); + private static final Set ML_NODE_STAT_NAMES = EnumSet + .allOf(MLNodeLevelStat.class) + .stream() + .map(stat -> stat.name()) + .collect(Collectors.toSet()); /** * Constructor From 42f2c8c719138a05ab57c4379110fc9894989083 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Mon, 21 Aug 2023 14:24:31 -0700 Subject: [PATCH 07/11] updating ML_NODE_TOTAL_MODEL_COUNT to ML_DEPLOYED_MODEL_COUNT Signed-off-by: Dhrubo Saha --- .../java/org/opensearch/ml/model/MLModelManager.java | 10 +++++----- .../opensearch/ml/plugin/MachineLearningPlugin.java | 2 +- .../org/opensearch/ml/stats/MLNodeLevelStat.java | 2 +- .../org/opensearch/ml/model/MLModelManagerTests.java | 2 +- .../opensearch/ml/rest/RestMLStatsActionTests.java | 12 ++++++------ .../opensearch/ml/task/MLExecuteTaskRunnerTests.java | 2 +- .../opensearch/ml/task/MLPredictTaskRunnerTests.java | 2 +- .../ml/task/MLTrainAndPredictTaskRunnerTests.java | 2 +- .../ml/task/MLTrainingTaskRunnerTests.java | 2 +- .../java/org/opensearch/ml/task/TaskRunnerTests.java | 2 +- 10 files changed, 19 insertions(+), 19 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 07658dae1a..2356405155 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -800,7 +800,7 @@ public void deployModel( MLExecutable mlExecutable = mlEngine.deployExecute(mlModel, params); try { modelCacheHelper.setMLExecutor(modelId, mlExecutable); - mlStats.getStat(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); listener.onResponse("successful"); } catch (Exception e) { @@ -813,7 +813,7 @@ public void deployModel( Predictable predictable = mlEngine.deploy(mlModel, params); try { modelCacheHelper.setPredictor(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes(); long contentSize = modelContentSizeInBytes == null @@ -855,7 +855,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam private void setupPredictable(String modelId, MLModel mlModel, Map params) { Predictable predictable = mlEngine.deploy(mlModel, params); modelCacheHelper.setPredictor(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); } @@ -1056,7 +1056,7 @@ public synchronized Map undeployModel(String[] modelIds) { for (String modelId : modelIds) { if (modelCacheHelper.isModelDeployed(modelId)) { modelUndeployStatus.put(modelId, UNDEPLOYED); - mlStats.getStat(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement(); mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT) @@ -1070,7 +1070,7 @@ public synchronized Map undeployModel(String[] modelIds) { log.debug("undeploy all models {}", Arrays.toString(getLocalDeployedModels())); for (String modelId : getLocalDeployedModels()) { modelUndeployStatus.put(modelId, UNDEPLOYED); - mlStats.getStat(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement(); mlStats.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT).increment(); removeModel(modelId); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index a30ebf2b97..6ee5c42bd4 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -308,7 +308,7 @@ public Collection createComponents( stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java index f230e9c49e..7c4acdf91f 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java @@ -14,7 +14,7 @@ public enum MLNodeLevelStat { ML_EXECUTING_TASK_COUNT, ML_REQUEST_COUNT, ML_FAILURE_COUNT, - ML_TOTAL_MODEL_COUNT, + ML_DEPLOYED_MODEL_COUNT, ML_CIRCUIT_BREAKER_TRIGGER_COUNT; public static MLNodeLevelStat from(String value) { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 248b3e3aa3..7dc36f41ac 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -216,7 +216,7 @@ public void setup() throws URISyntaxException { stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = spy(new MLStats(stats)); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java index 26b5e0ffd9..d0fae76be9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java @@ -309,7 +309,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws RestRequest request = getStatsRestRequest( node.getId(), - MLClusterLevelStat.ML_MODEL_COUNT + "," + MLNodeLevelStat.ML_TOTAL_MODEL_COUNT + MLClusterLevelStat.ML_MODEL_COUNT + "," + MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT ); restAction.handleRequest(request, channel, client); @@ -321,7 +321,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); assertEquals(1, input.getClusterLevelStats().size()); assertTrue(input.getClusterLevelStats().contains(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BytesRestResponse.class); verify(channel, times(1)).sendResponse(argumentCaptor.capture()); @@ -342,7 +342,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevelStat() throws Exception { prepareResponse(); - RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_TOTAL_MODEL_COUNT.name()); + RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT.name()); restAction.handleRequest(request, channel, client); ArgumentCaptor inputArgumentCaptor = ArgumentCaptor.forClass(MLStatsNodesRequest.class); @@ -352,7 +352,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevel assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); assertEquals(0, input.getClusterLevelStats().size()); assertEquals(1, input.getNodeLevelStats().size()); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BytesRestResponse.class); verify(channel, times(1)).sendResponse(argumentCaptor.capture()); @@ -366,11 +366,11 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevel } public void testCreateMlStatsInputFromRequestParams_NodeStat() { - RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_TOTAL_MODEL_COUNT.name().toLowerCase(Locale.ROOT)); + RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT.name().toLowerCase(Locale.ROOT)); MLStatsInput input = restAction.createMlStatsInputFromRequestParams(request); assertEquals(1, input.getTargetStatLevels().size()); assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT)); assertEquals(0, input.getClusterLevelStats().size()); } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 47d5edda5c..11e9bc3441 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -122,7 +122,7 @@ public void setup() { stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 4310ebad66..5f18d974a7 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -144,7 +144,7 @@ public void setup() throws IOException { stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); taskRunner = spy( diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index 0a18058a51..a40c5c87cf 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -120,7 +120,7 @@ public void setup() { stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index 724586568a..ae397067bc 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -128,7 +128,7 @@ public void setup() { stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java index 04c25b867e..9e2abccebb 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java @@ -72,7 +72,7 @@ public void setup() { stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); mlStats = new MLStats(stats); From 04e2106cb73f316085064078f3698c67d7977b70 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Mon, 21 Aug 2023 17:12:18 -0700 Subject: [PATCH 08/11] fixing metrics count Signed-off-by: Dhrubo Saha --- .../ml/action/deploy/TransportDeployModelAction.java | 2 -- .../register/TransportRegisterModelAction.java | 6 ------ .../undeploy/TransportUndeployModelAction.java | 2 -- .../java/org/opensearch/ml/model/MLModelManager.java | 12 +----------- 4 files changed, 1 insertion(+), 21 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index d931ab80ef..2873771c9f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -148,8 +148,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener nodeMapping = new HashMap<>(); for (DiscoveryNode node : allEligibleNodes) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 963d7450e2..af288ae45f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -234,12 +234,6 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen throw new IllegalArgumentException("URL can't match trusted url regex"); } } - // mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); - // //TODO: track executing task; track register failures - // mlStats.createCounterStatIfAbsent(FunctionName.TEXT_EMBEDDING, - // ActionName.REGISTER, - // MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE; MLTask mlTask = MLTask .builder() diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java index 4ca04aa609..f0bfe71258 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java @@ -229,7 +229,6 @@ protected MLUndeployModelNodeResponse nodeOperation(MLUndeployModelNodeRequest r private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployModelNodesRequest MLUndeployModelNodesRequest) { mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); String[] modelIds = MLUndeployModelNodesRequest.getModelIds(); @@ -246,7 +245,6 @@ private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployMo } Map modelUndeployStatus = mlModelManager.undeployModel(modelIds); - mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); return new MLUndeployModelNodeResponse(clusterService.localNode(), modelUndeployStatus, modelWorkerNodesMap); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 2356405155..094632f8cb 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -323,7 +323,6 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode); try { mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment(); String modelGroupId = registerModelInput.getModelGroupId(); @@ -392,9 +391,6 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); - mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelName = registerModelInput.getModelName(); String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; @@ -443,8 +439,6 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml } catch (Exception e) { logException("Failed to upload model", e, log); handleException(functionName, taskId, e); - } finally { - mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); } } @@ -462,9 +456,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); - mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelName = registerModelInput.getModelName(); String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; String modelGroupId = registerModelInput.getModelGroupId(); @@ -509,8 +500,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas } catch (Exception e) { logException("Failed to register model", e, log); handleException(functionName, taskId, e); - } finally { - mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); } } @@ -718,6 +707,7 @@ public void deployModel( ActionListener listener ) { mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); List workerNodes = mlTask.getWorkerNodes(); if (modelCacheHelper.isModelDeployed(modelId)) { From 569d8fc93ba3531996ab3e26902fa0647774c39d Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Mon, 21 Aug 2023 17:19:44 -0700 Subject: [PATCH 09/11] spotless Signed-off-by: Dhrubo Saha --- .../opensearch/ml/action/deploy/TransportDeployModelAction.java | 1 - .../ml/action/register/TransportRegisterModelAction.java | 1 - 2 files changed, 2 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 2873771c9f..acdfb5ec09 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -54,7 +54,6 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.settings.MLFeatureEnabledSetting; -import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index af288ae45f..7c0e2d6b9d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -54,7 +54,6 @@ import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; -import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; From 62b0e7bd98984b49274a833a3afbc21a013a8fe2 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Mon, 21 Aug 2023 18:02:09 -0700 Subject: [PATCH 10/11] fixing executing task Signed-off-by: Dhrubo Saha --- .../ml/action/undeploy/TransportUndeployModelAction.java | 1 + .../main/java/org/opensearch/ml/model/MLModelManager.java | 5 ++++- .../main/java/org/opensearch/ml/stats/MLNodeLevelStat.java | 3 ++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java index f0bfe71258..b8097c72e5 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java @@ -245,6 +245,7 @@ private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployMo } Map modelUndeployStatus = mlModelManager.undeployModel(modelIds); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); return new MLUndeployModelNodeResponse(clusterService.localNode(), modelUndeployStatus, modelWorkerNodesMap); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 094632f8cb..56b0a743f9 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -324,6 +324,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa try { mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelGroupId = registerModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); @@ -383,7 +384,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa } catch (Exception e) { handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); } finally { - mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); } } @@ -827,6 +828,8 @@ public void deployModel( }))); } catch (Exception e) { handleDeployModelException(modelId, functionName, listener, e); + } finally { + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); } } diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java index 7c4acdf91f..f78be1d4af 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java @@ -11,7 +11,8 @@ */ public enum MLNodeLevelStat { ML_JVM_HEAP_USAGE, - ML_EXECUTING_TASK_COUNT, + ML_EXECUTING_TASK_COUNT, // How many tasks are executing currently. If any task starts, then it will be 1, if the task finished then it + // will get back to 0. ML_REQUEST_COUNT, ML_FAILURE_COUNT, ML_DEPLOYED_MODEL_COUNT, From 0b61450fa51a1e9ee67d1816cacb6a1534f0601f Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Mon, 21 Aug 2023 18:11:38 -0700 Subject: [PATCH 11/11] updating comment Signed-off-by: Dhrubo Saha --- .../main/java/org/opensearch/ml/stats/MLNodeLevelStat.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java index f78be1d4af..857721392a 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java @@ -11,8 +11,8 @@ */ public enum MLNodeLevelStat { ML_JVM_HEAP_USAGE, - ML_EXECUTING_TASK_COUNT, // How many tasks are executing currently. If any task starts, then it will be 1, if the task finished then it - // will get back to 0. + ML_EXECUTING_TASK_COUNT, // How many tasks are executing currently. If any task starts, then it will increase by 1, + // if the task finished then it will decrease by 0. ML_REQUEST_COUNT, ML_FAILURE_COUNT, ML_DEPLOYED_MODEL_COUNT,