diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 07720d45f8..c59aa6e8e1 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -94,6 +94,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.model.MLModelState; @@ -318,9 +319,10 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput * @param mlTask ML task */ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTask) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode); try { + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment(); @@ -380,7 +382,6 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); } } catch (Exception e) { - mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); } finally { mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); @@ -392,9 +393,9 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); + String modelName = registerModelInput.getModelName(); String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; Instant now = Instant.now(); @@ -462,7 +463,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); String modelName = registerModelInput.getModelName(); @@ -689,7 +689,12 @@ private void deleteModel(String modelId) { } private void handleException(FunctionName functionName, String taskId, Exception e) { - mlStats.createCounterStatIfAbsent(functionName, REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); + if (!(e instanceof MLLimitExceededException) + && !(e instanceof MLResourceNotFoundException) + && !(e instanceof IllegalArgumentException)) { + mlStats.createCounterStatIfAbsent(functionName, REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + } Map updated = ImmutableMap.of(ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED); mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true); } @@ -713,6 +718,7 @@ public void deployModel( ActionListener listener ) { mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); List workerNodes = mlTask.getWorkerNodes(); if (modelCacheHelper.isModelDeployed(modelId)) { if (workerNodes != null && workerNodes.size() > 0) { @@ -835,7 +841,13 @@ public void deployModel( } private void handleDeployModelException(String modelId, FunctionName functionName, ActionListener listener, Exception e) { - mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); + + if (!(e instanceof MLLimitExceededException) + && !(e instanceof MLResourceNotFoundException) + && !(e instanceof IllegalArgumentException)) { + mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment(); + } removeModel(modelId); listener.onFailure(e); } @@ -858,7 +870,7 @@ public void getModel(String modelId, ActionListener listener) { } /** - * Get model from model index with includes/exludes filter. + * Get model from model index with includes/excludes filter. * * @param modelId model id * @param includes fields included @@ -1045,6 +1057,7 @@ public synchronized Map undeployModel(String[] modelIds) { if (modelCacheHelper.isModelDeployed(modelId)) { modelUndeployStatus.put(modelId, UNDEPLOYED); mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement(); + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); mlStats .createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT) .increment(); diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index de2c713d5e..e831388b86 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -44,9 +44,11 @@ public static String toJsonString(Map nodeErrors) throws IOExcep public static void logException(String errorMessage, Exception e, Logger log) { Throwable rootCause = ExceptionUtils.getRootCause(e); - if (e instanceof MLLimitExceededException || e instanceof MLResourceNotFoundException) { + if (e instanceof MLLimitExceededException || e instanceof MLResourceNotFoundException || e instanceof IllegalArgumentException) { log.warn(e.getMessage()); - } else if (rootCause instanceof MLLimitExceededException || rootCause instanceof MLResourceNotFoundException) { + } else if (rootCause instanceof MLLimitExceededException + || rootCause instanceof MLResourceNotFoundException + || rootCause instanceof IllegalArgumentException) { log.warn(rootCause.getMessage()); } else { log.error(errorMessage, e); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 9f5cb8c441..cc0b66a3fe 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -733,7 +733,8 @@ private void testDeployModel_FailedToRetrieveModelChunks(boolean lastChunk) { modelManager.deployModel(modelId, modelContentHashValue, functionName, true, mlTask, listener); verify(modelCacheHelper).removeModel(eq(modelId)); - verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_FAILURE_COUNT)); + verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); + verify(mlStats).getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT)); } private void mock_client_index_ModelChunkFailure(Client client, String modelId) {