diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 3915886360..acdfb5ec09 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -54,7 +54,6 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.settings.MLFeatureEnabledSetting; -import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; @@ -148,8 +147,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener nodeMapping = new HashMap<>(); for (DiscoveryNode node : allEligibleNodes) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index c5d426e348..7c0e2d6b9d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -54,7 +54,6 @@ import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; -import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; @@ -234,12 +233,6 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen throw new IllegalArgumentException("URL can't match trusted url regex"); } } - // mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - // //TODO: track executing task; track register failures - // mlStats.createCounterStatIfAbsent(FunctionName.TEXT_EMBEDDING, - // ActionName.REGISTER, - // MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE; MLTask mlTask = MLTask .builder() diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java index e6c59bf771..f1809de9bb 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java @@ -97,9 +97,9 @@ MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRe MLStatsInput mlStatsInput = mlStatsNodesRequest.getMlStatsInput(); // return node level stats if (mlStatsInput.getTargetStatLevels().contains(MLStatLevel.NODE)) { - if (mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)) { + if (mlStatsInput.retrieveStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)) { long heapUsedPercent = jvmService.stats().getMem().getHeapUsedPercent(); - statValues.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, heapUsedPercent); + statValues.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, heapUsedPercent); } for (Enum statName : mlStats.getNodeStats().keySet()) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java index bbed0305e1..b8097c72e5 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java @@ -228,8 +228,7 @@ protected MLUndeployModelNodeResponse nodeOperation(MLUndeployModelNodeRequest r } private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployModelNodesRequest MLUndeployModelNodesRequest) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String[] modelIds = MLUndeployModelNodesRequest.getModelIds(); @@ -246,7 +245,7 @@ private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployMo } Map modelUndeployStatus = mlModelManager.undeployModel(modelIds); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); return new MLUndeployModelNodeResponse(clusterService.localNode(), modelUndeployStatus, modelWorkerNodesMap); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 06295ac138..56b0a743f9 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -212,7 +212,7 @@ public MLModelManager( public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener listener) { try { FunctionName functionName = mlRegisterModelMetaInput.getFunctionName(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); String modelGroupId = mlRegisterModelMetaInput.getModelGroupId(); if (Strings.isBlank(modelGroupId)) { @@ -322,9 +322,9 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode); try { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelGroupId = registerModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); @@ -384,7 +384,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa } catch (Exception e) { handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); } finally { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); } } @@ -392,9 +392,6 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); String modelName = registerModelInput.getModelName(); String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; @@ -443,8 +440,6 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml } catch (Exception e) { logException("Failed to upload model", e, log); handleException(functionName, taskId, e); - } finally { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); } } @@ -462,9 +457,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); String modelName = registerModelInput.getModelName(); String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; String modelGroupId = registerModelInput.getModelGroupId(); @@ -509,8 +501,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas } catch (Exception e) { logException("Failed to register model", e, log); handleException(functionName, taskId, e); - } finally { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); } } @@ -693,7 +683,7 @@ private void handleException(FunctionName functionName, String taskId, Exception && !(e instanceof MLResourceNotFoundException) && !(e instanceof IllegalArgumentException)) { mlStats.createCounterStatIfAbsent(functionName, REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } Map updated = ImmutableMap.of(ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED); mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true); @@ -718,7 +708,8 @@ public void deployModel( ActionListener listener ) { mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); List workerNodes = mlTask.getWorkerNodes(); if (modelCacheHelper.isModelDeployed(modelId)) { if (workerNodes != null && workerNodes.size() > 0) { @@ -800,7 +791,7 @@ public void deployModel( MLExecutable mlExecutable = mlEngine.deployExecute(mlModel, params); try { modelCacheHelper.setMLExecutor(modelId, mlExecutable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); listener.onResponse("successful"); } catch (Exception e) { @@ -813,7 +804,7 @@ public void deployModel( Predictable predictable = mlEngine.deploy(mlModel, params); try { modelCacheHelper.setPredictor(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes(); long contentSize = modelContentSizeInBytes == null @@ -837,6 +828,8 @@ public void deployModel( }))); } catch (Exception e) { handleDeployModelException(modelId, functionName, listener, e); + } finally { + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); } } @@ -846,7 +839,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam && !(e instanceof MLResourceNotFoundException) && !(e instanceof IllegalArgumentException)) { mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } removeModel(modelId); listener.onFailure(e); @@ -855,7 +848,7 @@ private void handleDeployModelException(String modelId, FunctionName functionNam private void setupPredictable(String modelId, MLModel mlModel, Map params) { Predictable predictable = mlEngine.deploy(mlModel, params); modelCacheHelper.setPredictor(modelId, predictable); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); } @@ -1056,8 +1049,8 @@ public synchronized Map undeployModel(String[] modelIds) { for (String modelId : modelIds) { if (modelCacheHelper.isModelDeployed(modelId)) { modelUndeployStatus.put(modelId, UNDEPLOYED); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT) .increment(); @@ -1070,7 +1063,7 @@ public synchronized Map undeployModel(String[] modelIds) { log.debug("undeploy all models {}", Arrays.toString(getLocalDeployedModels())); for (String modelId : getLocalDeployedModels()) { modelUndeployStatus.put(modelId, UNDEPLOYED); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).decrement(); mlStats.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT).increment(); removeModel(modelId); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index b432f0ec73..3a64f1263f 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -304,11 +304,11 @@ public Collection createComponents( stats.put(MLClusterLevelStat.ML_MODEL_COUNT, new MLStat<>(true, new CounterSupplier())); stats.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, new MLStat<>(true, new CounterSupplier())); // node level stats - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlIndicesHandler = new MLIndicesHandler(clusterService, client); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java index 4a946963a8..034b7b2054 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java @@ -21,6 +21,7 @@ import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import org.opensearch.client.node.NodeClient; @@ -60,6 +61,12 @@ public class RestMLStatsAction extends BaseRestHandler { private static final String QUERY_ALL_MODEL_META_DOC = "{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}"; + private static final Set ML_NODE_STAT_NAMES = EnumSet + .allOf(MLNodeLevelStat.class) + .stream() + .map(stat -> stat.name()) + .collect(Collectors.toSet()); + /** * Constructor * @param mlStats MLStats object @@ -148,6 +155,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) { + MLStatsInput mlStatsInput = new MLStatsInput(); Optional nodeIds = splitCommaSeparatedParam(request, "nodeId"); if (nodeIds.isPresent()) { @@ -158,7 +166,7 @@ MLStatsInput createMlStatsInputFromRequestParams(RestRequest request) { for (String state : stats.get()) { state = state.toUpperCase(Locale.ROOT); // only support cluster and node level stats for bwc - if (state.startsWith("ML_NODE")) { + if (ML_NODE_STAT_NAMES.contains(state)) { mlStatsInput.getNodeLevelStats().add(MLNodeLevelStat.from(state)); } else { mlStatsInput.getClusterLevelStats().add(MLClusterLevelStat.from(state)); diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java index d002c002bf..857721392a 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLNodeLevelStat.java @@ -10,12 +10,13 @@ * This enum represents node level stats. */ public enum MLNodeLevelStat { - ML_NODE_JVM_HEAP_USAGE, - ML_NODE_EXECUTING_TASK_COUNT, - ML_NODE_TOTAL_REQUEST_COUNT, - ML_NODE_TOTAL_FAILURE_COUNT, - ML_NODE_TOTAL_MODEL_COUNT, - ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT; + ML_JVM_HEAP_USAGE, + ML_EXECUTING_TASK_COUNT, // How many tasks are executing currently. If any task starts, then it will increase by 1, + // if the task finished then it will decrease by 0. + ML_REQUEST_COUNT, + ML_FAILURE_COUNT, + ML_DEPLOYED_MODEL_COUNT, + ML_CIRCUIT_BREAKER_TRIGGER_COUNT; public static MLNodeLevelStat from(String value) { try { diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index f5cb8a3846..1498b39d07 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -88,8 +88,8 @@ protected TransportResponseHandler getResponseHandler(Act protected void executeTask(MLExecuteTaskRequest request, ActionListener listener) { threadPool.executor(EXECUTE_THREAD_POOL).execute(() -> { try { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -113,7 +113,7 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -303,7 +303,7 @@ private void handlePredictFailure(MLTask mlTask, ActionListener mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) .increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } handleAsyncMLTaskFailure(mlTask, e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java index 84860005fc..89c8bfe2e2 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java @@ -106,15 +106,14 @@ private void dispatchTaskWithLeastLoad(String[] nodeIds, ActionListener listener) { MLStatsNodesRequest MLStatsNodesRequest = new MLStatsNodesRequest(nodes); - MLStatsNodesRequest - .addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)); + MLStatsNodesRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_JVM_HEAP_USAGE)); client.execute(MLStatsNodesAction.INSTANCE, MLStatsNodesRequest, ActionListener.wrap(mlStatsResponse -> { // Check JVM pressure List candidateNodeResponse = mlStatsResponse .getNodes() .stream() - .filter(stat -> (long) stat.getNodeLevelStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) + .filter(stat -> (long) stat.getNodeLevelStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) .collect(Collectors.toList()); if (candidateNodeResponse.size() == 0) { @@ -129,7 +128,7 @@ private void dispatchTaskWithLeastLoad(DiscoveryNode[] nodes, ActionListener (Long) stat.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT) < maxMLBatchTaskPerNode) + .filter(stat -> (Long) stat.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT) < maxMLBatchTaskPerNode) .collect(Collectors.toList()); if (candidateNodeResponse.size() == 0) { String errorMessage = "All nodes' executing ML task count reach limitation."; @@ -142,13 +141,13 @@ private void dispatchTaskWithLeastLoad(DiscoveryNode[] nodes, ActionListener targetNode = candidateNodeResponse .stream() .sorted((MLStatsNodeResponse r1, MLStatsNodeResponse r2) -> { - int result = ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)) - .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + int result = ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)) + .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); if (result == 0) { // if multiple nodes have same running task count, choose the one with least // JVM heap usage. - return ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)) - .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)); + return ((Long) r1.getNodeLevelStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)) + .compareTo((Long) r2.getNodeLevelStat(MLNodeLevelStat.ML_JVM_HEAP_USAGE)); } return result; }) diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index 05c6c5f1bb..b2c71d6ed8 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -96,7 +96,7 @@ public void run(FunctionName functionName, Request request, TransportService tra protected ActionListener wrappedCleanupListener(ActionListener listener, String taskId) { ActionListener internalListener = ActionListener.runAfter(listener, () -> { - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); mlTaskManager.remove(taskId); }); return internalListener; diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java index 09603d291e..9461c3adaf 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java @@ -128,8 +128,8 @@ protected void executeTask(MLTrainingTaskRequest request, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -161,7 +161,7 @@ private void handlePredictFailure(MLTask mlTask, ActionListener mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) .increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment(); } handleAsyncMLTaskFailure(mlTask, e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index 090afb445a..ca436e72f7 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -146,8 +146,8 @@ protected void executeTask(MLTrainingTaskRequest request, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache - mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN, MLActionLevelStat.ML_ACTION_REQUEST_COUNT) .increment(); @@ -180,7 +180,7 @@ private void train(MLTask mlTask, MLInput mlInput, ActionListener void parseField(XContentParser parser, Set set, Function { ActionListener listener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java index f68bec725a..0fa4be0b7d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeITTests.java @@ -5,7 +5,7 @@ package org.opensearch.ml.action.stats; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_EXECUTING_TASK_COUNT; import static org.opensearch.ml.utils.IntegTestUtils.TESTING_DATA; import static org.opensearch.ml.utils.IntegTestUtils.generateMLTestingData; import static org.opensearch.ml.utils.IntegTestUtils.verifyGeneratedTestingData; @@ -45,7 +45,7 @@ public void testGeneratedTestingData() throws ExecutionException, InterruptedExc public void testNormalCase() throws ExecutionException, InterruptedException { MLStatsNodesRequest request = new MLStatsNodesRequest(new String[0], new MLStatsInput()); - request.addNodeLevelStats(ImmutableSet.of(ML_NODE_EXECUTING_TASK_COUNT)); + request.addNodeLevelStats(ImmutableSet.of(ML_EXECUTING_TASK_COUNT)); ActionFuture future = client().execute(MLStatsNodesAction.INSTANCE, request); MLStatsNodesResponse response = future.get(); @@ -58,6 +58,6 @@ public void testNormalCase() throws ExecutionException, InterruptedException { MLStatsNodeResponse nodeResponse = responseList.get(0); assertEquals(1, nodeResponse.getNodeLevelStatSize()); - assertEquals(0l, nodeResponse.getNodeLevelStat(ML_NODE_EXECUTING_TASK_COUNT)); + assertEquals(0l, nodeResponse.getNodeLevelStat(ML_EXECUTING_TASK_COUNT)); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java index de6b599728..b475b033ff 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java @@ -39,14 +39,14 @@ public class MLStatsNodeResponseTests extends OpenSearchTestCase { public void setup() { node = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); Map statsToValues = new HashMap<>(); - statsToValues.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 100); + statsToValues.put(MLNodeLevelStat.ML_REQUEST_COUNT, 100); response = new MLStatsNodeResponse(node, statsToValues); } public void testSerializationDeserialization() throws IOException { DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); Map statsToValues = new HashMap<>(); - statsToValues.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 10l); + statsToValues.put(MLNodeLevelStat.ML_REQUEST_COUNT, 10l); MLStatsNodeResponse response = new MLStatsNodeResponse(localNode, statsToValues); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); @@ -60,7 +60,7 @@ public void testToXContent_NodeLevelStats() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject(); String taskContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"ml_node_total_request_count\":100}", taskContent); + assertEquals("{\"ml_request_count\":100}", taskContent); } public void testToXContent_AlgorithmStats() throws IOException { @@ -100,7 +100,7 @@ public void testToXContent_WithAlgoStats() throws IOException { builder.startObject(); DiscoveryNode node = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); Map statsToValues = new HashMap<>(); - statsToValues.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 100); + statsToValues.put(MLNodeLevelStat.ML_REQUEST_COUNT, 100); Map algoStats = new HashMap<>(); Map algoActionStats = new HashMap<>(); Map algoActionStatMap = new HashMap<>(); @@ -114,8 +114,8 @@ public void testToXContent_WithAlgoStats() throws IOException { String taskContent = TestHelper.xContentBuilderToString(builder); Set validResult = ImmutableSet .of( - "{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}}}", - "{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}}}" + "{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_failure_count\":22,\"ml_action_request_count\":111}}}}", + "{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":111,\"ml_action_failure_count\":22}}}}" ); assertTrue(validResult.contains(taskContent)); } @@ -124,8 +124,8 @@ public void testReadStats() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLStatsNodeResponse mlStatsNodeResponse = MLStatsNodeResponse.readStats(output.bytes().streamInput()); - Integer expectedValue = (Integer) response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT); - assertEquals(expectedValue, mlStatsNodeResponse.getNodeLevelStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + Integer expectedValue = (Integer) response.getNodeLevelStat(MLNodeLevelStat.ML_REQUEST_COUNT); + assertEquals(expectedValue, mlStatsNodeResponse.getNodeLevelStat(MLNodeLevelStat.ML_REQUEST_COUNT)); } public void testIsEmpty_NullNodeStats() { @@ -150,23 +150,23 @@ public void testIsEmpty_EmptyAlgoStats() { public void testIsEmpty_NonEmptyNodeAndAlgoStats() { MLStatsNodeResponse response = createResponseWithDefaultAlgoStats( - ImmutableMap.of(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, totalRequestCount) + ImmutableMap.of(MLNodeLevelStat.ML_REQUEST_COUNT, totalRequestCount) ); assertFalse(response.isEmpty()); } public void testGetNodeLevelStat_NonExistingStat() { - assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); assertEquals(1, response.getNodeLevelStatSize()); } public void testGetNodeLevelStat_NullOrEmptyNodeStats() { MLStatsNodeResponse response = new MLStatsNodeResponse(node, null); - assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); assertEquals(0, response.getNodeLevelStatSize()); response = new MLStatsNodeResponse(node, ImmutableMap.of()); - assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + assertNull(response.getNodeLevelStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); assertEquals(0, response.getNodeLevelStatSize()); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java index 4e8746811a..a59a2ca30c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java @@ -20,7 +20,7 @@ public class MLStatsNodesRequestTests extends OpenSearchTestCase { public void testSerializationDeserialization() throws IOException { MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { "testNodeId" }, new MLStatsInput()); - mlStatsNodesRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT)); + mlStatsNodesRequest.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT)); BytesStreamOutput output = new BytesStreamOutput(); MLStatsNodeRequest request = new MLStatsNodeRequest(mlStatsNodesRequest); request.writeTo(output); diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java index 832ddbb282..c801da467c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java @@ -44,12 +44,12 @@ public void testToXContent() throws IOException { DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT); Map nodeLevelStats1 = new HashMap<>(); - nodeLevelStats1.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 100); + nodeLevelStats1.put(MLNodeLevelStat.ML_REQUEST_COUNT, 100); nodes.add(new MLStatsNodeResponse(node1, nodeLevelStats1)); DiscoveryNode node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT); Map nodeLevelStats2 = new HashMap<>(); - nodeLevelStats2.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, 200); + nodeLevelStats2.put(MLNodeLevelStat.ML_REQUEST_COUNT, 200); nodes.add(new MLStatsNodeResponse(node2, nodeLevelStats2)); List failures = new ArrayList<>(); @@ -58,9 +58,6 @@ public void testToXContent() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject(); String taskContent = TestHelper.xContentBuilderToString(builder); - assertEquals( - "{\"nodes\":{\"node1\":{\"ml_node_total_request_count\":100},\"node2\":{\"ml_node_total_request_count\":200}}}", - taskContent - ); + assertEquals("{\"nodes\":{\"node1\":{\"ml_request_count\":100},\"node2\":{\"ml_request_count\":200}}}", taskContent); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java index fe0e962fd8..b54c9ca83a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java @@ -7,7 +7,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_JVM_HEAP_USAGE; import java.io.IOException; import java.util.EnumSet; @@ -56,13 +56,13 @@ public void setUp() throws Exception { super.setUp(); clusterStatName1 = MLClusterLevelStat.ML_MODEL_COUNT; - nodeStatName1 = MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT; + nodeStatName1 = MLNodeLevelStat.ML_EXECUTING_TASK_COUNT; statsMap = new HashMap<>() { { put(nodeStatName1, new MLStat<>(false, new CounterSupplier())); put(clusterStatName1, new MLStat<>(true, new CounterSupplier())); - put(ML_NODE_JVM_HEAP_USAGE, new MLStat<>(true, new SettableSupplier())); + put(ML_JVM_HEAP_USAGE, new MLStat<>(true, new SettableSupplier())); } }; @@ -119,14 +119,14 @@ public void testNodeOperationWithJvmHeapUsage() { String nodeId = clusterService().localNode().getId(); MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { nodeId }, new MLStatsInput()); - Set statsToBeRetrieved = ImmutableSet.of(ML_NODE_JVM_HEAP_USAGE); + Set statsToBeRetrieved = ImmutableSet.of(ML_JVM_HEAP_USAGE); mlStatsNodesRequest.addNodeLevelStats(statsToBeRetrieved); MLStatsNodeResponse response = action.nodeOperation(new MLStatsNodeRequest(mlStatsNodesRequest)); Assert.assertEquals(statsToBeRetrieved.size(), response.getNodeLevelStatSize()); - assertNotNull(response.getNodeLevelStat(ML_NODE_JVM_HEAP_USAGE)); + assertNotNull(response.getNodeLevelStat(ML_JVM_HEAP_USAGE)); } public void testNodeOperation_NoNodeLevelStat() { diff --git a/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java new file mode 100644 index 0000000000..38467caa10 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java @@ -0,0 +1,756 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.bwc; + +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH; +import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD; +import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; +import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_FAILURE_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_REQUEST_COUNT; +import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; +import static org.opensearch.ml.utils.TestData.trainModelDataJson; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.HttpHost; +import org.apache.http.auth.AuthScope; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.apache.http.client.CredentialsProvider; +import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.impl.client.BasicCredentialsProvider; +import org.apache.http.message.BasicHeader; +import org.apache.http.ssl.SSLContextBuilder; +import org.apache.http.util.EntityUtils; +import org.junit.After; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.common.io.PathUtils; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.rest.SecureRestClientBuilder; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.SearchQueryInputDataset; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.stats.ActionName; +import org.opensearch.ml.stats.MLActionLevelStat; +import org.opensearch.ml.utils.TestData; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.rest.OpenSearchRestTestCase; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.gson.Gson; +import com.google.gson.JsonArray; + +// TODO: Need to refactor this code in the future because the whole part of it is a copy of MLCommonsRestTestCase.java + +public class MLCommonsBackwardsCompatibilityRestTestCase extends OpenSearchRestTestCase { + protected Gson gson = new Gson(); + public static long CUSTOM_MODEL_TIMEOUT = 20_000; // 20 seconds + + protected boolean isHttps() { + boolean isHttps = Optional.ofNullable(System.getProperty("https")).map("true"::equalsIgnoreCase).orElse(false); + if (isHttps) { + // currently only external cluster is supported for security enabled testing + if (!Optional.ofNullable(System.getProperty("tests.rest.cluster")).isPresent()) { + throw new RuntimeException("cluster url should be provided for security enabled testing"); + } + } + + return isHttps; + } + + @Override + protected String getProtocol() { + return isHttps() ? "https" : "http"; + } + + @Override + protected Settings restAdminSettings() { + return Settings + .builder() + // disable the warning exception for admin client since it's only used for cleanup. + .put("strictDeprecationMode", false) + .put("http.port", 9200) + .put(OPENSEARCH_SECURITY_SSL_HTTP_ENABLED, isHttps()) + .put(OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH, "sample.pem") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH, "test-kirk.jks") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD, "changeit") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD, "changeit") + .build(); + } + + // Utility fn for deleting indices. Should only be used when not allowed in a regular context + // (e.g., deleting system indices) + protected static void deleteIndexWithAdminClient(String name) throws IOException { + Request request = new Request("DELETE", "/" + name); + adminClient().performRequest(request); + } + + // Utility fn for checking if an index exists. Should only be used when not allowed in a regular context + // (e.g., checking existence of system indices) + protected static boolean indexExistsWithAdminClient(String indexName) throws IOException { + Request request = new Request("HEAD", "/" + indexName); + Response response = adminClient().performRequest(request); + return RestStatus.OK.getStatus() == response.getStatusLine().getStatusCode(); + } + + @Override + protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException { + boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); + RestClientBuilder builder = RestClient.builder(hosts); + if (isHttps()) { + String keystore = settings.get(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH); + if (Objects.nonNull(keystore)) { + URI uri = null; + try { + uri = this.getClass().getClassLoader().getResource("security/sample.pem").toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + Path configPath = PathUtils.get(uri).getParent().toAbsolutePath(); + return new SecureRestClientBuilder(settings, configPath).build(); + } else { + configureHttpsClient(builder, settings); + builder.setStrictDeprecationMode(strictDeprecationMode); + return builder.build(); + } + + } else { + configureClient(builder, settings); + builder.setStrictDeprecationMode(strictDeprecationMode); + return builder.build(); + } + + } + + @SuppressWarnings("unchecked") + @After + protected void wipeAllODFEIndices() throws IOException { + Response response = adminClient().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all")); + MediaType mediaType = MediaType.fromMediaType(response.getEntity().getContentType().getValue()); + try ( + XContentParser parser = mediaType + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + response.getEntity().getContent() + ) + ) { + XContentParser.Token token = parser.nextToken(); + List> parserList = null; + if (token == XContentParser.Token.START_ARRAY) { + parserList = parser.listOrderedMap().stream().map(obj -> (Map) obj).collect(Collectors.toList()); + } else { + parserList = Collections.singletonList(parser.mapOrdered()); + } + + for (Map index : parserList) { + String indexName = (String) index.get("index"); + if (indexName != null && !".opendistro_security".equals(indexName)) { + adminClient().performRequest(new Request("DELETE", "/" + indexName)); + } + } + } + } + + protected static void configureHttpsClient(RestClientBuilder builder, Settings settings) throws IOException { + Map headers = ThreadContext.buildDefaultHeaders(settings); + Header[] defaultHeaders = new Header[headers.size()]; + int i = 0; + for (Map.Entry entry : headers.entrySet()) { + defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue()); + } + builder.setDefaultHeaders(defaultHeaders); + builder.setHttpClientConfigCallback(httpClientBuilder -> { + String userName = Optional + .ofNullable(System.getProperty("user")) + .orElseThrow(() -> new RuntimeException("user name is missing")); + String password = Optional + .ofNullable(System.getProperty("password")) + .orElseThrow(() -> new RuntimeException("password is missing")); + CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(userName, password)); + try { + return httpClientBuilder + .setDefaultCredentialsProvider(credentialsProvider) + // disable the certificate since our testing cluster just uses the default security configuration + .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE) + .setSSLContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + final String socketTimeoutString = settings.get(CLIENT_SOCKET_TIMEOUT); + final TimeValue socketTimeout = TimeValue + .parseTimeValue(socketTimeoutString == null ? "60s" : socketTimeoutString, CLIENT_SOCKET_TIMEOUT); + builder.setRequestConfigCallback(conf -> conf.setSocketTimeout(Math.toIntExact(socketTimeout.getMillis()))); + if (settings.hasValue(CLIENT_PATH_PREFIX)) { + builder.setPathPrefix(settings.get(CLIENT_PATH_PREFIX)); + } + } + + /** + * wipeAllIndices won't work since it cannot delete security index. Use wipeAllODFEIndices instead. + */ + @Override + protected boolean preserveIndicesUponCompletion() { + return true; + } + + protected Response ingestIrisData(String indexName) throws IOException { + String irisDataIndexMapping = ""; + TestHelper + .makeRequest( + client(), + "PUT", + indexName, + null, + TestHelper.toHttpEntity(irisDataIndexMapping), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + + Response statsResponse = TestHelper.makeRequest(client(), "GET", indexName, ImmutableMap.of(), "", null); + assertEquals(RestStatus.OK, TestHelper.restStatus(statsResponse)); + String result = EntityUtils.toString(statsResponse.getEntity()); + assertTrue(result.contains(indexName)); + + Response bulkResponse = TestHelper + .makeRequest( + client(), + "POST", + "_bulk?refresh=true", + null, + TestHelper.toHttpEntity(TestData.IRIS_DATA.replaceAll("iris_data", indexName)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(RestStatus.OK, TestHelper.restStatus(statsResponse)); + return bulkResponse; + } + + protected void validateStats( + FunctionName functionName, + ActionName actionName, + int expectedMinimumTotalFailureCount, + int expectedMinimumTotalAlgoFailureCount, + int expectedMinimumTotalRequestCount, + int expectedMinimumTotalAlgoRequestCount + ) throws IOException { + Response statsResponse = TestHelper.makeRequest(client(), "GET", "_plugins/_ml/stats", null, "", null); + Map map = parseResponseToMap(statsResponse); + int totalFailureCount = 0; + int totalAlgoFailureCount = 0; + int totalRequestCount = 0; + int totalAlgoRequestCount = 0; + Map allNodeStats = (Map) map.get("nodes"); + for (String key : allNodeStats.keySet()) { + Map nodeStatsMap = (Map) allNodeStats.get(key); + String statKey = ML_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); + if (nodeStatsMap.containsKey(statKey)) { + totalFailureCount += (Double) nodeStatsMap.get(statKey); + } + statKey = ML_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); + if (nodeStatsMap.containsKey(statKey)) { + totalRequestCount += (Double) nodeStatsMap.get(statKey); + } + Map allAlgoStats = (Map) nodeStatsMap.get("algorithms"); + statKey = functionName.name().toLowerCase(Locale.ROOT); + if (allAlgoStats.containsKey(statKey)) { + Map allActionStats = (Map) allAlgoStats.get(statKey); + String actionKey = actionName.name().toLowerCase(Locale.ROOT); + Map actionStats = (Map) allActionStats.get(actionKey); + + String actionStatKey = MLActionLevelStat.ML_ACTION_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); + if (actionStats.containsKey(actionStatKey)) { + totalAlgoFailureCount += (Double) actionStats.get(actionStatKey); + } + actionStatKey = MLActionLevelStat.ML_ACTION_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); + if (actionStats.containsKey(actionStatKey)) { + totalAlgoRequestCount += (Double) actionStats.get(actionStatKey); + } + } + } + assertTrue(totalFailureCount >= expectedMinimumTotalFailureCount); + assertTrue(totalAlgoFailureCount >= expectedMinimumTotalAlgoFailureCount); + assertTrue(totalRequestCount >= expectedMinimumTotalRequestCount); + assertTrue(totalAlgoRequestCount >= expectedMinimumTotalAlgoRequestCount); + } + + protected Response ingestModelData() throws IOException { + Response trainModelResponse = TestHelper + .makeRequest(client(), "POST", "_plugins/_ml/_train/sample_algo", null, TestHelper.toHttpEntity(trainModelDataJson()), null); + HttpEntity entity = trainModelResponse.getEntity(); + assertNotNull(trainModelResponse); + return trainModelResponse; + } + + public void trainAsyncWithSample(Consumer> consumer, boolean async) throws IOException, InterruptedException { + String endpoint = "/_plugins/_ml/_train/sample_algo"; + if (async) { + endpoint += "?async=true"; + } + Response response = TestHelper + .makeRequest(client(), "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(trainModelDataJson()), null); + TimeUnit.SECONDS.sleep(5); + verifyResponse(consumer, response); + } + + public Response createIndexRole(String role, String index) throws IOException { + return TestHelper + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/roles/" + role, + null, + TestHelper + .toHttpEntity( + "{\n" + + "\"cluster_permissions\": [\n" + + "],\n" + + "\"index_permissions\": [\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + index + + "\"\n" + + "],\n" + + "\"dls\": \"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"crud\",\n" + + "\"indices:admin/create\"\n" + + "]\n" + + "}\n" + + "],\n" + + "\"tenant_permissions\": []\n" + + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response createSearchRole(String role, String index) throws IOException { + return TestHelper + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/roles/" + role, + null, + TestHelper + .toHttpEntity( + "{\n" + + "\"cluster_permissions\": [\n" + + "],\n" + + "\"index_permissions\": [\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + index + + "\"\n" + + "],\n" + + "\"dls\": \"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"indices:data/read/search\"\n" + + "]\n" + + "}\n" + + "],\n" + + "\"tenant_permissions\": []\n" + + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response createUser(String name, String password, ArrayList backendRoles) throws IOException { + JsonArray backendRolesString = new JsonArray(); + for (int i = 0; i < backendRoles.size(); i++) { + backendRolesString.add(backendRoles.get(i)); + } + return TestHelper + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/internalusers/" + name, + null, + TestHelper + .toHttpEntity( + " {\n" + + "\"password\": \"" + + password + + "\",\n" + + "\"backend_roles\": " + + backendRolesString + + ",\n" + + "\"attributes\": {\n" + + "}} " + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response deleteUser(String user) throws IOException { + return TestHelper + .makeRequest( + client(), + "DELETE", + "/_opendistro/_security/api/internalusers/" + user, + null, + "", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response createRoleMapping(String role, ArrayList users) throws IOException { + JsonArray usersString = new JsonArray(); + for (int i = 0; i < users.size(); i++) { + usersString.add(users.get(i)); + } + return TestHelper + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/rolesmapping/" + role, + null, + TestHelper + .toHttpEntity( + "{\n" + " \"backend_roles\" : [ ],\n" + " \"hosts\" : [ ],\n" + " \"users\" : " + usersString + "\n" + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public void trainAndPredict( + RestClient client, + FunctionName functionName, + String indexName, + MLAlgoParams params, + SearchSourceBuilder searchSourceBuilder, + Consumer> function + ) throws IOException { + MLInputDataset inputData = SearchQueryInputDataset + .builder() + .indices(ImmutableList.of(indexName)) + .searchSourceBuilder(searchSourceBuilder) + .build(); + MLInput kmeansInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build(); + Response response = TestHelper + .makeRequest( + client, + "POST", + "/_plugins/_ml/_train_predict/" + functionName.name().toLowerCase(Locale.ROOT), + ImmutableMap.of(), + TestHelper.toHttpEntity(kmeansInput), + null + ); + Map map = parseResponseToMap(response); + Map predictionResult = (Map) map.get("prediction_result"); + if (function != null) { + function.accept(predictionResult); + } + } + + public void train( + RestClient client, + FunctionName functionName, + String indexName, + MLAlgoParams params, + SearchSourceBuilder searchSourceBuilder, + Consumer> function, + boolean async + ) throws IOException { + MLInputDataset inputData = SearchQueryInputDataset + .builder() + .indices(ImmutableList.of(indexName)) + .searchSourceBuilder(searchSourceBuilder) + .build(); + MLInput kmeansInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build(); + String endpoint = "/_plugins/_ml/_train/" + functionName.name().toLowerCase(Locale.ROOT); + if (async) { + endpoint += "?async=true"; + } + Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(kmeansInput), null); + verifyResponse(function, response); + } + + public void predict( + RestClient client, + FunctionName functionName, + String modelId, + String indexName, + MLAlgoParams params, + SearchSourceBuilder searchSourceBuilder, + Consumer> function + ) throws IOException { + MLInputDataset inputData = SearchQueryInputDataset + .builder() + .indices(ImmutableList.of(indexName)) + .searchSourceBuilder(searchSourceBuilder) + .build(); + MLInput kmeansInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build(); + String endpoint = "/_plugins/_ml/_predict/" + functionName.name().toLowerCase(Locale.ROOT) + "/" + modelId; + Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(kmeansInput), null); + verifyResponse(function, response); + } + + public void getModel(RestClient client, String modelId, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/models/" + modelId, null, "", null); + verifyResponse(function, response); + } + + public void getTask(RestClient client, String taskId, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null); + verifyResponse(function, response); + } + + public void deleteModel(RestClient client, String modelId, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/models/" + modelId, null, "", null); + verifyResponse(function, response); + } + + public void deleteTask(RestClient client, String taskId, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/tasks/" + taskId, null, "", null); + verifyResponse(function, response); + } + + public void searchModelsWithAlgoName(RestClient client, String algoName, Consumer> function) throws IOException { + String query = String.format(Locale.ROOT, "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"algorithm\":\"%s\"}}]}}}", algoName); + searchModels(client, query, function); + } + + public void searchModels(RestClient client, String query, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/models/_search", null, query, null); + verifyResponse(function, response); + } + + public void searchTasksWithAlgoName(RestClient client, String algoName, Consumer> function) throws IOException { + String query = String.format(Locale.ROOT, "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"function_name\":\"%s\"}}]}}}", algoName); + searchTasks(client, query, function); + } + + public void searchTasks(RestClient client, String query, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/tasks/_search", null, query, null); + verifyResponse(function, response); + } + + private void verifyResponse(Consumer> verificationConsumer, Response response) throws IOException { + Map map = parseResponseToMap(response); + if (verificationConsumer != null) { + verificationConsumer.accept(map); + } + } + + public MLRegisterModelInput createRegisterModelInput() { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(768) + .build(); + return MLRegisterModelInput + .builder() + .modelName("test_model_name") + .version("1.0.0") + .functionName(FunctionName.TEXT_EMBEDDING) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .url(SENTENCE_TRANSFORMER_MODEL_URL) + .deployModel(false) + .build(); + } + + public void registerModel(RestClient client, String input, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/models/_register", null, input, null); + verifyResponse(function, response); + } + + public String registerModel(String input) throws IOException { + Response response = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, input, null); + return parseTaskIdFromResponse(response); + } + + public void deployModel(RestClient client, MLRegisterModelInput registerModelInput, Consumer> function) + throws IOException, + InterruptedException { + String taskId = registerModel(TestHelper.toJsonString(registerModelInput)); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + String algorithm = (String) response.get(FUNCTION_NAME_FIELD); + assertEquals(registerModelInput.getFunctionName().name(), algorithm); + assertNotNull(response.get(MODEL_ID_FIELD)); + assertEquals(MLTaskState.COMPLETED.name(), response.get(STATE_FIELD)); + String modelId = (String) response.get(MODEL_ID_FIELD); + try { + // deploy model + deployModel(client, modelId, function); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + public void deployModel(RestClient client, String modelId, Consumer> function) throws IOException { + Response response = TestHelper + .makeRequest(client, "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, (String) null, null); + verifyResponse(function, response); + } + + public String deployModel(String modelId) throws IOException { + Response response = TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, (String) null, null); + return parseTaskIdFromResponse(response); + } + + private String parseTaskIdFromResponse(Response response) throws IOException { + Map map = parseResponseToMap(response); + String taskId = (String) map.get(TASK_ID_FIELD); + return taskId; + } + + private Map parseResponseToMap(Response response) throws IOException { + HttpEntity entity = response.getEntity(); + assertNotNull(response); + String entityString = TestHelper.httpEntityToString(entity); + return gson.fromJson(entityString, Map.class); + } + + public Map getModelProfile(String modelId, Consumer verifyFunction) throws IOException { + Response response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/profile/models/" + modelId, null, (String) null, null); + Map profile = parseResponseToMap(response); + Map nodeProfiles = (Map) profile.get("nodes"); + for (Map.Entry entry : nodeProfiles.entrySet()) { + Map modelProfiles = (Map) entry.getValue(); + assertNotNull(modelProfiles); + for (Map.Entry modelProfileEntry : modelProfiles.entrySet()) { + Map modelProfile = (Map) ((Map) modelProfileEntry.getValue()).get(modelId); + if (verifyFunction != null) { + verifyFunction.accept(modelProfile); + } + } + } + return profile; + } + + public MLInput createPredictTextEmbeddingInput() { + TextDocsInputDataSet textDocsInputDataSet = TextDocsInputDataSet + .builder() + .docs(Arrays.asList("today is sunny", "this is a happy dog")) + .build(); + return MLInput.builder().inputDataset(textDocsInputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + } + + public Map predictTextEmbedding(String modelId) throws IOException { + MLInput input = createPredictTextEmbeddingInput(); + Response response = TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, TestHelper.toJsonString(input), null); + Map result = parseResponseToMap(response); + List embeddings = (List) result.get("inference_results"); + assertEquals(2, embeddings.size()); + for (Object embedding : embeddings) { + Map embeddingMap = (Map) embedding; + List tensors = (List) embeddingMap.get("output"); + assertEquals(1, tensors.size()); + Map tensorMap = (Map) tensors.get(0); + assertEquals(4, tensorMap.size()); + assertEquals("sentence_embedding", tensorMap.get("name")); + assertEquals("FLOAT32", tensorMap.get("data_type")); + List shape = (List) tensorMap.get("shape"); + assertEquals(1, shape.size()); + assertEquals(768, ((Double) shape.get(0)).longValue()); + List data = (List) tensorMap.get("data"); + assertEquals(768, data.size()); + } + return result; + } + + public Consumer> verifyTextEmbeddingModelDeployed() { + return (modelProfile) -> { + if (modelProfile.containsKey("model_state")) { + assertEquals(MLModelState.DEPLOYED.name(), modelProfile.get("model_state")); + assertTrue( + ((String) modelProfile.get("predictor")) + .startsWith("org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel@") + ); + } + List workNodes = (List) modelProfile.get("worker_nodes"); + assertTrue(workNodes.size() > 0); + }; + } + + public Map undeployModel(String modelId) throws IOException { + Response response = TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, (String) null, null); + return parseResponseToMap(response); + } + + public String getTaskState(String taskId) throws IOException { + Response response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null); + Map task = parseResponseToMap(response); + return (String) task.get("state"); + } + + public void waitForTask(String taskId, MLTaskState targetState) throws InterruptedException { + AtomicBoolean taskDone = new AtomicBoolean(false); + waitUntil(() -> { + try { + String state = getTaskState(taskId); + if (targetState.name().equals(state)) { + taskDone.set(true); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + return taskDone.get(); + }, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS); + assertTrue(taskDone.get()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 5b59b8a629..7dc36f41ac 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -213,11 +213,11 @@ public void setup() throws URISyntaxException { Map> stats = new ConcurrentHashMap<>(); // node level stats - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = spy(new MLStats(stats)); mlTask = MLTask @@ -734,7 +734,7 @@ private void testDeployModel_FailedToRetrieveModelChunks(boolean lastChunk) { modelManager.deployModel(modelId, modelContentHashValue, functionName, true, mlTask, listener); verify(modelCacheHelper).removeModel(eq(modelId)); verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); - verify(mlStats).getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + verify(mlStats).getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT)); } private void mock_client_index_ModelChunkFailure(Client client, String modelId) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 31640a260f..d5994bd3c1 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -16,8 +16,8 @@ import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT; -import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_FAILURE_COUNT; +import static org.opensearch.ml.stats.MLNodeLevelStat.ML_REQUEST_COUNT; import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; import static org.opensearch.ml.utils.TestData.trainModelDataJson; @@ -349,11 +349,11 @@ protected void validateStats( Map allNodeStats = (Map) map.get("nodes"); for (String key : allNodeStats.keySet()) { Map nodeStatsMap = (Map) allNodeStats.get(key); - String statKey = ML_NODE_TOTAL_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); + String statKey = ML_FAILURE_COUNT.name().toLowerCase(Locale.ROOT); if (nodeStatsMap.containsKey(statKey)) { totalFailureCount += (Double) nodeStatsMap.get(statKey); } - statKey = ML_NODE_TOTAL_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); + statKey = ML_REQUEST_COUNT.name().toLowerCase(Locale.ROOT); if (nodeStatsMap.containsKey(statKey)) { totalRequestCount += (Double) nodeStatsMap.get(statKey); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java index 603aca698f..d0fae76be9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java @@ -113,7 +113,7 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); Map> statMap = ImmutableMap .>builder() - .put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())) + .put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())) .build(); mlStats = new MLStats(statMap); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); @@ -209,7 +209,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates() throws Exception { content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -218,7 +218,7 @@ private void prepareResponse() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); List nodes = new ArrayList<>(); - Map nodeStats = ImmutableMap.of(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, nodeTotalRequestCount); + Map nodeStats = ImmutableMap.of(MLNodeLevelStat.ML_REQUEST_COUNT, nodeTotalRequestCount); Map algoStats = new HashMap<>(); Map actionStats = ImmutableMap .of( @@ -299,7 +299,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_NoRequestContent() thro content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -309,7 +309,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws RestRequest request = getStatsRestRequest( node.getId(), - MLClusterLevelStat.ML_MODEL_COUNT + "," + MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT + MLClusterLevelStat.ML_MODEL_COUNT + "," + MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT ); restAction.handleRequest(request, channel, client); @@ -321,7 +321,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); assertEquals(1, input.getClusterLevelStats().size()); assertTrue(input.getClusterLevelStats().contains(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BytesRestResponse.class); verify(channel, times(1)).sendResponse(argumentCaptor.capture()); @@ -334,7 +334,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws content .utf8ToString() .contains( - "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + "\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" ) ); } @@ -342,7 +342,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevelStat() throws Exception { prepareResponse(); - RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT.name()); + RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT.name()); restAction.handleRequest(request, channel, client); ArgumentCaptor inputArgumentCaptor = ArgumentCaptor.forClass(MLStatsNodesRequest.class); @@ -352,7 +352,7 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevel assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); assertEquals(0, input.getClusterLevelStats().size()); assertEquals(1, input.getNodeLevelStats().size()); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BytesRestResponse.class); verify(channel, times(1)).sendResponse(argumentCaptor.capture()); @@ -360,17 +360,17 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams_NodeLevel assertEquals(RestStatus.OK, restResponse.status()); BytesReference content = restResponse.content(); assertEquals( - "{\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", + "{\"nodes\":{\"node\":{\"ml_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", content.utf8ToString() ); } public void testCreateMlStatsInputFromRequestParams_NodeStat() { - RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT.name().toLowerCase(Locale.ROOT)); + RestRequest request = getStatsRestRequest(node.getId(), MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT.name().toLowerCase(Locale.ROOT)); MLStatsInput input = restAction.createMlStatsInputFromRequestParams(request); assertEquals(1, input.getTargetStatLevels().size()); assertTrue(input.getTargetStatLevels().contains(MLStatLevel.NODE)); - assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT)); + assertTrue(input.getNodeLevelStats().contains(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT)); assertEquals(0, input.getClusterLevelStats().size()); } diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java index e8ee91397a..ddad31f58d 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsInputTests.java @@ -84,27 +84,27 @@ public void testRetrieveAll() { public void testShouldRetrieveStat() { assertTrue(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertTrue(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); MLStatsInput mlStatsInput = MLStatsInput.builder().build(); assertTrue(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertTrue(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); mlStatsInput = new MLStatsInput(); assertTrue(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertTrue(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertTrue(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); mlStatsInput = MLStatsInput .builder() .clusterLevelStats(EnumSet.of(MLClusterLevelStat.ML_TASK_INDEX_STATUS)) - .nodeLevelStats(EnumSet.of(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT)) + .nodeLevelStats(EnumSet.of(MLNodeLevelStat.ML_FAILURE_COUNT)) .actionLevelStats(EnumSet.of(MLActionLevelStat.ML_ACTION_FAILURE_COUNT)) .build(); assertFalse(mlStatsInput.retrieveStat(MLClusterLevelStat.ML_MODEL_COUNT)); - assertFalse(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); + assertFalse(mlStatsInput.retrieveStat(MLNodeLevelStat.ML_REQUEST_COUNT)); assertFalse(mlStatsInput.retrieveStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); } diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java index 83e8676646..51bde6bd3a 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java @@ -33,7 +33,7 @@ public class MLStatsTests extends OpenSearchTestCase { public void setup() { clusterStatName1 = MLClusterLevelStat.ML_MODEL_COUNT; - nodeStatName1 = MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT; + nodeStatName1 = MLNodeLevelStat.ML_EXECUTING_TASK_COUNT; statsMap = new HashMap>() { { @@ -70,14 +70,14 @@ public void testGetStat() { } public void testGetStatNoExisting() { - MLNodeLevelStat wrongStat = MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE; + MLNodeLevelStat wrongStat = MLNodeLevelStat.ML_JVM_HEAP_USAGE; expectedEx.expect(IllegalArgumentException.class); expectedEx.expectMessage("Stat \"" + wrongStat + "\" does not exist"); mlStats.getStat(wrongStat); } public void testCreateCounterStatIfAbsent() { - MLStat stat = mlStats.createCounterStatIfAbsent(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT); + MLStat stat = mlStats.createCounterStatIfAbsent(MLNodeLevelStat.ML_FAILURE_COUNT); stat.increment(); assertEquals(1L, stat.getValue()); } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 0a4decb21a..11e9bc3441 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -119,10 +119,10 @@ public void setup() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 7bd5d889ed..5f18d974a7 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -141,10 +141,10 @@ public void setup() throws IOException { }).when(executorService).execute(any(Runnable.class)); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); taskRunner = spy( diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java index 523cde8169..70373995ad 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java @@ -160,8 +160,8 @@ public void testGetEligibleNodes_MlAndDataNodes() { private MLStatsNodesResponse getMlStatsNodesResponse() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); - nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 50l); + nodeStats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, 5l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( @@ -173,7 +173,7 @@ private MLStatsNodesResponse getMlStatsNodesResponse() { private MLStatsNodesResponse getNodesResponse_NoTaskCounts() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 50l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( @@ -185,8 +185,8 @@ private MLStatsNodesResponse getNodesResponse_NoTaskCounts() { private MLStatsNodesResponse getNodesResponse_MemoryExceedLimits() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 90l); - nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 90l); + nodeStats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, 5l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( @@ -198,8 +198,8 @@ private MLStatsNodesResponse getNodesResponse_MemoryExceedLimits() { private MLStatsNodesResponse getNodesResponse_TaskCountExceedLimits() { Map nodeStats = new HashMap<>(); - nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); - nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 15l); + nodeStats.put(MLNodeLevelStat.ML_JVM_HEAP_USAGE, 50l); + nodeStats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, 15l); MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index 60f6af9d3e..a40c5c87cf 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -117,10 +117,10 @@ public void setup() { }).when(executorService).execute(any(Runnable.class)); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index e5e8f3281a..ae397067bc 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -125,10 +125,10 @@ public void setup() { }).when(executorService).execute(any(Runnable.class)); Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java index 7f81eebea6..9e2abccebb 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java @@ -69,11 +69,11 @@ public class TaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { Map> stats = new ConcurrentHashMap<>(); - stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); - stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); mlStats = new MLStats(stats); MockitoAnnotations.openMocks(this); @@ -140,7 +140,7 @@ public void testRun_CircuitBreakerOpen() { ActionListener listener = mock(ActionListener.class); MLTaskRequest request = new MLTaskRequest(false); expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener)); - Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue(); + Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue(); assertEquals(1L, value.longValue()); } }