diff --git a/plugin/build.gradle b/plugin/build.gradle index 341df2d037..7d3299642d 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -208,8 +208,7 @@ List jacocoExclusions = [ 'org.opensearch.ml.action.tasks.*', 'org.opensearch.ml.action.handler.*', 'org.opensearch.ml.constant.CommonValue', - 'org.opensearch.ml.plugin.*', - 'org.opensearch.ml.task.MLPredictTaskRunner', + 'org.opensearch.ml.plugin.MachineLearningPlugin*', 'org.opensearch.ml.rest.AbstractMLSearchAction*', 'org.opensearch.ml.rest.RestMLExecuteAction' //0.3 ] diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 9fd5acca50..f4f8b09449 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -162,6 +162,7 @@ public Collection createComponents( stats.put(StatNames.ML_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(StatNames.ML_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(StatNames.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); mlIndicesHandler = new MLIndicesHandler(clusterService, client); diff --git a/plugin/src/main/java/org/opensearch/ml/stats/StatNames.java b/plugin/src/main/java/org/opensearch/ml/stats/StatNames.java index fa6c61992e..0fb7bfc7d1 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/StatNames.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/StatNames.java @@ -17,6 +17,7 @@ public class StatNames { public static String ML_TOTAL_REQUEST_COUNT = "ml_total_request_count"; public static String ML_TOTAL_FAILURE_COUNT = "ml_total_failure_count"; public static String ML_TOTAL_MODEL_COUNT = "ml_total_model_count"; + public static String ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT = "ml_total_circuit_breaker_trigger_count"; public static String requestCountStat(FunctionName functionName, ActionName actionName) { return String.format(Locale.ROOT, "ml_%s_%s_request_count", functionName, actionName).toLowerCase(Locale.ROOT); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index 0eb4ebbab1..0db895a88c 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -6,6 +6,7 @@ package org.opensearch.ml.task; import static org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT; +import static org.opensearch.ml.stats.StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT; import java.util.HashMap; import java.util.Map; @@ -88,6 +89,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) { public void run(Request request, TransportService transportService, ActionListener listener) { if (mlCircuitBreakerService.isOpen()) { + mlStats.getStat(ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).increment(); throw new MLLimitExceededException("Circuit breaker is open"); } if (!request.isDispatchTask()) { diff --git a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java index c71a365533..7159b03e6f 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java @@ -16,6 +16,7 @@ import java.time.Instant; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import org.junit.Before; import org.junit.Rule; @@ -31,7 +32,10 @@ import org.opensearch.ml.common.parameter.MLTaskState; import org.opensearch.ml.common.parameter.MLTaskType; import org.opensearch.ml.common.transport.MLTaskRequest; +import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; +import org.opensearch.ml.stats.StatNames; +import org.opensearch.ml.stats.suppliers.CounterSupplier; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; @@ -40,7 +44,6 @@ public class TaskRunnerTests extends OpenSearchTestCase { @Mock MLTaskManager mlTaskManager; - @Mock MLStats mlStats; @Mock MLTaskDispatcher mlTaskDispatcher; @@ -57,6 +60,14 @@ public class TaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { + Map> stats = new ConcurrentHashMap<>(); + stats.put(StatNames.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); + mlStats = new MLStats(stats); + MockitoAnnotations.openMocks(this); mlTaskRunner = new MLTaskRunner(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService) { @Override @@ -113,11 +124,11 @@ public void testHandleAsyncMLTaskComplete_SyncTask() { } public void testRun_CircuitBreakerOpen() { - exceptionRule.expect(MLLimitExceededException.class); - exceptionRule.expectMessage("Circuit breaker is open"); when(mlCircuitBreakerService.isOpen()).thenReturn(true); TransportService transportService = mock(TransportService.class); ActionListener listener = mock(ActionListener.class); - mlTaskRunner.run(null, transportService, listener); + expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(null, transportService, listener)); + Long value = (Long) mlStats.getStat(StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue(); + assertEquals(1L, value.longValue()); } }