Skip to content

Commit

Permalink
add circuit breaker trigger count stat (#274) (#322)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored May 13, 2022
1 parent 5e64948 commit 0dc2386
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 6 deletions.
3 changes: 1 addition & 2 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.action.tasks.*',
'org.opensearch.ml.action.handler.*',
'org.opensearch.ml.constant.CommonValue',
'org.opensearch.ml.plugin.*',
'org.opensearch.ml.task.MLPredictTaskRunner',
'org.opensearch.ml.plugin.MachineLearningPlugin*',
'org.opensearch.ml.rest.AbstractMLSearchAction*',
'org.opensearch.ml.rest.RestMLExecuteAction' //0.3
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ public Collection<Object> createComponents(
stats.put(StatNames.ML_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier()));
this.mlStats = new MLStats(stats);

mlIndicesHandler = new MLIndicesHandler(clusterService, client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public class StatNames {
public static String ML_TOTAL_REQUEST_COUNT = "ml_total_request_count";
public static String ML_TOTAL_FAILURE_COUNT = "ml_total_failure_count";
public static String ML_TOTAL_MODEL_COUNT = "ml_total_model_count";
public static String ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT = "ml_total_circuit_breaker_trigger_count";

public static String requestCountStat(FunctionName functionName, ActionName actionName) {
return String.format(Locale.ROOT, "ml_%s_%s_request_count", functionName, actionName).toLowerCase(Locale.ROOT);
Expand Down
2 changes: 2 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.task;

import static org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT;
import static org.opensearch.ml.stats.StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT;

import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -88,6 +89,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {

public void run(Request request, TransportService transportService, ActionListener<Response> listener) {
if (mlCircuitBreakerService.isOpen()) {
mlStats.getStat(ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).increment();
throw new MLLimitExceededException("Circuit breaker is open");
}
if (!request.isDispatchTask()) {
Expand Down
19 changes: 15 additions & 4 deletions plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.junit.Before;
import org.junit.Rule;
Expand All @@ -31,7 +32,10 @@
import org.opensearch.ml.common.parameter.MLTaskState;
import org.opensearch.ml.common.parameter.MLTaskType;
import org.opensearch.ml.common.transport.MLTaskRequest;
import org.opensearch.ml.stats.MLStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.StatNames;
import org.opensearch.ml.stats.suppliers.CounterSupplier;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;
Expand All @@ -40,7 +44,6 @@ public class TaskRunnerTests extends OpenSearchTestCase {

@Mock
MLTaskManager mlTaskManager;
@Mock
MLStats mlStats;
@Mock
MLTaskDispatcher mlTaskDispatcher;
Expand All @@ -57,6 +60,14 @@ public class TaskRunnerTests extends OpenSearchTestCase {

@Before
public void setup() {
Map<String, MLStat<?>> stats = new ConcurrentHashMap<>();
stats.put(StatNames.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier()));
mlStats = new MLStats(stats);

MockitoAnnotations.openMocks(this);
mlTaskRunner = new MLTaskRunner(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService) {
@Override
Expand Down Expand Up @@ -113,11 +124,11 @@ public void testHandleAsyncMLTaskComplete_SyncTask() {
}

public void testRun_CircuitBreakerOpen() {
exceptionRule.expect(MLLimitExceededException.class);
exceptionRule.expectMessage("Circuit breaker is open");
when(mlCircuitBreakerService.isOpen()).thenReturn(true);
TransportService transportService = mock(TransportService.class);
ActionListener listener = mock(ActionListener.class);
mlTaskRunner.run(null, transportService, listener);
expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(null, transportService, listener));
Long value = (Long) mlStats.getStat(StatNames.ML_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
assertEquals(1L, value.longValue());
}
}

0 comments on commit 0dc2386

Please sign in to comment.