Skip to content

Commit

Permalink
fixing metrics (#1194) (#1220)
Browse files Browse the repository at this point in the history
* fixing metrics



* addressing comments



* addressing comments



* updating test



* added IllegalArgumentException in the if statement



* addressing comments



* fixing spotless



---------

Signed-off-by: Dhrubo Saha <[email protected]>
Co-authored-by: Dhrubo Saha <[email protected]>
  • Loading branch information
ylwu-amzn and dhrubo-os authored Aug 18, 2023
1 parent 6db14b1 commit c6c2ad8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 10 deletions.
27 changes: 20 additions & 7 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.model.MLModelState;
Expand Down Expand Up @@ -318,9 +319,10 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput
* @param mlTask ML task
*/
public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTask) {
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();

checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode);
try {
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment();

Expand Down Expand Up @@ -380,7 +382,6 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa
handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e);
}
} catch (Exception e) {
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
Expand All @@ -392,9 +393,9 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml
FunctionName functionName = mlTask.getFunctionName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();

mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();

String modelName = registerModelInput.getModelName();
String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion;
Instant now = Instant.now();
Expand Down Expand Up @@ -462,7 +463,6 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas
FunctionName functionName = mlTask.getFunctionName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();

mlStats.createCounterStatIfAbsent(functionName, REGISTER, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
String modelName = registerModelInput.getModelName();
Expand Down Expand Up @@ -689,7 +689,12 @@ private void deleteModel(String modelId) {
}

private void handleException(FunctionName functionName, String taskId, Exception e) {
mlStats.createCounterStatIfAbsent(functionName, REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
if (!(e instanceof MLLimitExceededException)
&& !(e instanceof MLResourceNotFoundException)
&& !(e instanceof IllegalArgumentException)) {
mlStats.createCounterStatIfAbsent(functionName, REGISTER, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment();
}
Map<String, Object> updated = ImmutableMap.of(ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED);
mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true);
}
Expand All @@ -713,6 +718,7 @@ public void deployModel(
ActionListener<String> listener
) {
mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
List<String> workerNodes = mlTask.getWorkerNodes();
if (modelCacheHelper.isModelDeployed(modelId)) {
if (workerNodes != null && workerNodes.size() > 0) {
Expand Down Expand Up @@ -835,7 +841,13 @@ public void deployModel(
}

private void handleDeployModelException(String modelId, FunctionName functionName, ActionListener<String> listener, Exception e) {
mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();

if (!(e instanceof MLLimitExceededException)
&& !(e instanceof MLResourceNotFoundException)
&& !(e instanceof IllegalArgumentException)) {
mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment();
}
removeModel(modelId);
listener.onFailure(e);
}
Expand All @@ -858,7 +870,7 @@ public void getModel(String modelId, ActionListener<MLModel> listener) {
}

/**
* Get model from model index with includes/exludes filter.
* Get model from model index with includes/excludes filter.
*
* @param modelId model id
* @param includes fields included
Expand Down Expand Up @@ -1045,6 +1057,7 @@ public synchronized Map<String, String> undeployModel(String[] modelIds) {
if (modelCacheHelper.isModelDeployed(modelId)) {
modelUndeployStatus.put(modelId, UNDEPLOYED);
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).decrement();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
mlStats
.createCounterStatIfAbsent(getModelFunctionName(modelId), ActionName.UNDEPLOY, ML_ACTION_REQUEST_COUNT)
.increment();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ public static String toJsonString(Map<String, String> nodeErrors) throws IOExcep

public static void logException(String errorMessage, Exception e, Logger log) {
Throwable rootCause = ExceptionUtils.getRootCause(e);
if (e instanceof MLLimitExceededException || e instanceof MLResourceNotFoundException) {
if (e instanceof MLLimitExceededException || e instanceof MLResourceNotFoundException || e instanceof IllegalArgumentException) {
log.warn(e.getMessage());
} else if (rootCause instanceof MLLimitExceededException || rootCause instanceof MLResourceNotFoundException) {
} else if (rootCause instanceof MLLimitExceededException
|| rootCause instanceof MLResourceNotFoundException
|| rootCause instanceof IllegalArgumentException) {
log.warn(rootCause.getMessage());
} else {
log.error(errorMessage, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,8 @@ private void testDeployModel_FailedToRetrieveModelChunks(boolean lastChunk) {

modelManager.deployModel(modelId, modelContentHashValue, functionName, true, mlTask, listener);
verify(modelCacheHelper).removeModel(eq(modelId));
verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_FAILURE_COUNT));
verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_REQUEST_COUNT));
verify(mlStats).getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT));
}

private void mock_client_index_ModelChunkFailure(Client client, String modelId) {
Expand Down

0 comments on commit c6c2ad8

Please sign in to comment.