From f3995d87dd3e9a83fd7eddc4647529c5aed95f6c Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 14 Aug 2023 23:08:58 -0700 Subject: [PATCH] add more stats: connector count, connector/config index status; fix model count bug Signed-off-by: Yaliang Wu --- .../ml/engine/encryptor/EncryptorImpl.java | 7 +-- .../ml/engine/utils/ScriptUtils.java | 1 + .../engine/encryptor/EncryptorImplTest.java | 5 +- .../ml/model/MLModelGroupManager.java | 4 +- .../ml/plugin/MachineLearningPlugin.java | 8 ++- .../opensearch/ml/rest/RestMLStatsAction.java | 42 +++++++++++---- .../ml/stats/MLClusterLevelStat.java | 5 +- .../org/opensearch/ml/utils/IndexUtils.java | 38 ++++++++++++++ .../ml/model/MLModelGroupManagerTests.java | 4 +- .../ml/rest/RestMLStatsActionTests.java | 51 ++++++++++++++----- 10 files changed, 134 insertions(+), 31 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index ad0e266ba1..6a045d9aa1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -34,6 +34,7 @@ @Log4j2 public class EncryptorImpl implements Encryptor { + public static final String MASTER_KEY_NOT_READY_ERROR = "The ML encryption master key has not been initialized yet. 
Please retry after waiting for 10 seconds."; private ClusterService clusterService; private Client client; private volatile String masterKey; @@ -114,7 +115,7 @@ private void initMasterKey() { String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY); this.masterKey = masterKey; } else { - exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR)); } }, e -> { log.error("Failed to get ML encryption master key", e); @@ -122,7 +123,7 @@ private void initMasterKey() { }), latch)); } } else { - exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR)); latch.countDown(); } @@ -141,7 +142,7 @@ private void initMasterKey() { } } if (masterKey == null) { - throw new ResourceNotFoundException("ML encryption master key not initialized yet"); + throw new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR); } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java index a1159da08c..7c7e3a528f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java @@ -37,6 +37,7 @@ public static Optional executePreprocessFunction(ScriptService scriptSer } return Optional.empty(); } + public static Optional executePostprocessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java index e0eaab0c4a..281955eeaa 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java +++ 
b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java @@ -31,6 +31,7 @@ import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; +import static org.opensearch.ml.engine.encryptor.EncryptorImpl.MASTER_KEY_NOT_READY_ERROR; public class EncryptorImplTest { @Rule @@ -121,7 +122,7 @@ public void decrypt() { @Test public void encrypt_NullMasterKey_NullMasterKey_MasterKeyNotExistInIndex() { exceptionRule.expect(ResourceNotFoundException.class); - exceptionRule.expectMessage("ML encryption master key not initialized yet"); + exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -155,7 +156,7 @@ public void decrypt_NullMasterKey_GetMasterKey_Exception() { @Test public void decrypt_MLConfigIndexNotFound() { exceptionRule.expect(ResourceNotFoundException.class); - exceptionRule.expectMessage("ML encryption master key not initialized yet"); + exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR); Metadata metadata = new Metadata.Builder().indices(ImmutableMap.of()).build(); when(clusterState.metadata()).thenReturn(metadata); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 3a79a0c659..fac5db25c8 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -77,7 +77,9 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener createComponents( Map> stats = new ConcurrentHashMap<>(); // cluster level stats stats.put(MLClusterLevelStat.ML_MODEL_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_MODEL_INDEX))); + stats + .put(MLClusterLevelStat.ML_CONNECTOR_INDEX_STATUS, 
new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_CONNECTOR_INDEX))); + stats.put(MLClusterLevelStat.ML_CONFIG_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_CONFIG_INDEX))); stats.put(MLClusterLevelStat.ML_TASK_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_TASK_INDEX))); stats.put(MLClusterLevelStat.ML_MODEL_COUNT, new MLStat<>(true, new CounterSupplier())); + stats.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, new MLStat<>(true, new CounterSupplier())); // node level stats stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); @@ -452,7 +458,7 @@ public List getRestHandlers( IndexNameExpressionResolver indexNameExpressionResolver, Supplier nodesInCluster ) { - RestMLStatsAction restMLStatsAction = new RestMLStatsAction(mlStats, clusterService, indexUtils); + RestMLStatsAction restMLStatsAction = new RestMLStatsAction(mlStats, clusterService, indexUtils, xContentRegistry); RestMLTrainingAction restMLTrainingAction = new RestMLTrainingAction(); RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction(); RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java index 8ecb66eb49..4a946963a8 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.rest; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static 
org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.splitCommaSeparatedParam; @@ -27,6 +28,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -54,6 +56,9 @@ public class RestMLStatsAction extends BaseRestHandler { private MLStats mlStats; private ClusterService clusterService; private IndexUtils indexUtils; + private NamedXContentRegistry xContentRegistry; + private static final String QUERY_ALL_MODEL_META_DOC = + "{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}"; /** * Constructor @@ -61,10 +66,16 @@ public class RestMLStatsAction extends BaseRestHandler { * @param clusterService cluster service * @param indexUtils index util */ - public RestMLStatsAction(MLStats mlStats, ClusterService clusterService, IndexUtils indexUtils) { + public RestMLStatsAction( + MLStats mlStats, + ClusterService clusterService, + IndexUtils indexUtils, + NamedXContentRegistry xContentRegistry + ) { this.mlStats = mlStats; this.clusterService = clusterService; this.indexUtils = indexUtils; + this.xContentRegistry = xContentRegistry; } @Override @@ -109,14 +120,27 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli if (finalMlStatsInput.getTargetStatLevels().contains(MLStatLevel.CLUSTER) && (finalMlStatsInput.retrieveAllClusterLevelStats() || finalMlStatsInput.getClusterLevelStats().contains(MLClusterLevelStat.ML_MODEL_COUNT))) { - indexUtils.getNumberOfDocumentsInIndex(ML_MODEL_INDEX, ActionListener.wrap(count -> { - clusterStatsMap.put(MLClusterLevelStat.ML_MODEL_COUNT, count); - getNodeStats(finalMlStatsInput, clusterStatsMap, client, 
mlStatsNodesRequest, channel); - }, e -> { - String errorMessage = "Failed to get ML model count"; - log.error(errorMessage, e); - onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, e); - })); + indexUtils + .getNumberOfDocumentsInIndex( + ML_MODEL_INDEX, + QUERY_ALL_MODEL_META_DOC, + xContentRegistry, + ActionListener.wrap(modelCount -> { + clusterStatsMap.put(MLClusterLevelStat.ML_MODEL_COUNT, modelCount); + indexUtils.getNumberOfDocumentsInIndex(ML_CONNECTOR_INDEX, ActionListener.wrap(connectorCount -> { + clusterStatsMap.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, connectorCount); + getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel); + }, e -> { + String errorMessage = "Failed to get ML connector count"; + log.error(errorMessage, e); + onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, e); + })); + }, e -> { + String errorMessage = "Failed to get ML model count"; + log.error(errorMessage, e); + onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, e); + }) + ); } else { getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel); } diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java b/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java index 302995cad3..b918c3cd4c 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java @@ -11,8 +11,11 @@ */ public enum MLClusterLevelStat { ML_MODEL_INDEX_STATUS, + ML_CONNECTOR_INDEX_STATUS, + ML_CONFIG_INDEX_STATUS, ML_TASK_INDEX_STATUS, - ML_MODEL_COUNT; + ML_MODEL_COUNT, + ML_CONNECTOR_COUNT; public static MLClusterLevelStat from(String value) { try { diff --git a/plugin/src/main/java/org/opensearch/ml/utils/IndexUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/IndexUtils.java index 3bbfedb370..8e4a593916 100644 --- 
a/plugin/src/main/java/org/opensearch/ml/utils/IndexUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/IndexUtils.java @@ -8,13 +8,22 @@ import java.util.List; import java.util.Locale; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.admin.indices.stats.IndicesStatsRequest; +import org.opensearch.action.search.SearchRequest; import org.opensearch.client.Client; import org.opensearch.cluster.health.ClusterIndexHealth; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.builder.SearchSourceBuilder; public class IndexUtils { /** @@ -96,4 +105,33 @@ public void getNumberOfDocumentsInIndex(String indexName, ActionListener l } } + // Count docs in indexName that match searchQuery (listener gets 0 when the index does not exist) + public void getNumberOfDocumentsInIndex( + String indexName, + String searchQuery, + NamedXContentRegistry xContentRegistry, + ActionListener listener + ) { + if (clusterService.state().getRoutingTable().hasIndex(indexName)) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + SearchRequest searchRequest = new SearchRequest(); + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, searchQuery); + SearchSourceBuilder builder = SearchSourceBuilder.fromXContent(parser); + builder.fetchSource(false).size(0).trackTotalHits(true); + searchRequest.source(builder).indices(indexName); + + client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + 
long count = r.getHits().getTotalHits().value; + listener.onResponse(count); + }, e -> { listener.onFailure(e); }), () -> context.restore())); + } catch (Exception e) { + throw new OpenSearchStatusException("Failed to search index " + indexName, RestStatus.BAD_REQUEST, e); + } + } else { + listener.onResponse(0L); + } + } + + } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index 8c5f2f999f..9a6bc35237 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -129,7 +129,7 @@ public void test_SuccessAddAllBackendRolesTrue() { verify(actionListener).onResponse(argumentCaptor.capture()); } - public void test_ModelGroupNameNotUnique() throws IOException { + public void test_ModelGroupNameNotUnique() throws IOException {// SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -143,7 +143,7 @@ public void test_ModelGroupNameNotUnique() throws IOException { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( - "The name you provided is already being used by another model with ID: model_group_ID. Please provide a different name", + "The name you provided is already being used by another model with ID: model_group_ID. 
Please provide a different name or add \"model_group_id\": \"model_group_ID\" to request body", argumentCaptor.getValue().getMessage() ); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java index c9173bffbb..603aca698f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java @@ -49,6 +49,7 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.action.stats.MLStatsNodeResponse; import org.opensearch.ml.action.stats.MLStatsNodesAction; import org.opensearch.ml.action.stats.MLStatsNodesRequest; @@ -88,6 +89,8 @@ public class RestMLStatsActionTests extends OpenSearchTestCase { ClusterService clusterService; @Mock IndexUtils indexUtils; + @Mock + NamedXContentRegistry xContentRegistry; @Mock RestChannel channel; @@ -101,6 +104,7 @@ public class RestMLStatsActionTests extends OpenSearchTestCase { ClusterState testState; long mlModelCount = 10; + long mlConnectorCount = 2; long nodeTotalRequestCount = 100; long kmeansTrainRequestCount = 20; @@ -114,7 +118,7 @@ public void setup() throws IOException { mlStats = new MLStats(statMap); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restAction = new RestMLStatsAction(mlStats, clusterService, indexUtils); + restAction = new RestMLStatsAction(mlStats, clusterService, indexUtils, xContentRegistry); Set roleSet = new HashSet<>(); roleSet.add(DiscoveryNodeRole.DATA_ROLE); node = new DiscoveryNode( @@ -128,10 +132,17 @@ public void setup() throws IOException { when(clusterService.state()).thenReturn(testState); clusterName = new 
ClusterName(clusterNameStr); + doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(3); actionListener.onResponse(mlModelCount); return null; + }).when(indexUtils).getNumberOfDocumentsInIndex(anyString(), anyString(), any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(mlConnectorCount); + return null; }).when(indexUtils).getNumberOfDocumentsInIndex(anyString(), any()); when(channel.newBuilder()).thenReturn(XContentFactory.jsonBuilder()); @@ -168,7 +179,8 @@ public void testPrepareRequest_ClusterLevelStates() throws Exception { BytesRestResponse restResponse = argumentCaptor.getValue(); assertEquals(RestStatus.OK, restResponse.status()); BytesReference content = restResponse.content(); - assertEquals("{\"ml_model_count\":10}", content.utf8ToString()); + assertTrue(content.utf8ToString().contains("\"ml_connector_count\":2")); + assertTrue(content.utf8ToString().contains("\"ml_model_count\":10")); } public void testPrepareRequest_ClusterAndNodeLevelStates() throws Exception { @@ -191,9 +203,14 @@ public void testPrepareRequest_ClusterAndNodeLevelStates() throws Exception { BytesRestResponse restResponse = argumentCaptor.getValue(); assertEquals(RestStatus.OK, restResponse.status()); BytesReference content = restResponse.content(); - assertEquals( - "{\"ml_model_count\":10,\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", - content.utf8ToString() + assertTrue(content.utf8ToString().contains("\"ml_connector_count\":2")); + assertTrue(content.utf8ToString().contains("\"ml_model_count\":10")); + assertTrue( + content + .utf8ToString() + .contains( + "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + ) ); } @@ -276,9 +293,14 @@ public 
void testPrepareRequest_ClusterAndNodeLevelStates_NoRequestContent() thro BytesRestResponse restResponse = argumentCaptor.getValue(); assertEquals(RestStatus.OK, restResponse.status()); BytesReference content = restResponse.content(); - assertEquals( - "{\"ml_model_count\":10,\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", - content.utf8ToString() + assertTrue(content.utf8ToString().contains("\"ml_connector_count\":2")); + assertTrue(content.utf8ToString().contains("\"ml_model_count\":10")); + assertTrue( + content + .utf8ToString() + .contains( + "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + ) ); } @@ -306,9 +328,14 @@ public void testPrepareRequest_ClusterAndNodeLevelStates_RequestParams() throws BytesRestResponse restResponse = argumentCaptor.getValue(); assertEquals(RestStatus.OK, restResponse.status()); BytesReference content = restResponse.content(); - assertEquals( - "{\"ml_model_count\":10,\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}", - content.utf8ToString() + assertTrue(content.utf8ToString().contains("\"ml_connector_count\":2")); + assertTrue(content.utf8ToString().contains("\"ml_model_count\":10")); + assertTrue( + content + .utf8ToString() + .contains( + "\"nodes\":{\"node\":{\"ml_node_total_request_count\":100,\"algorithms\":{\"kmeans\":{\"train\":{\"ml_action_request_count\":20}}}}}}" + ) ); }