Skip to content

Commit

Permalink
add more stats: connector count, connector/config index status; fix m…
Browse files Browse the repository at this point in the history
…odel count bug (#1181)

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored and zane-neo committed Sep 1, 2023
1 parent 81c79d5 commit 7ef8121
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
@Log4j2
public class EncryptorImpl implements Encryptor {

public static final String MASTER_KEY_NOT_READY_ERROR = "The ML encryption master key has not been initialized yet. Please retry after waiting for 10 seconds.";
private ClusterService clusterService;
private Client client;
private volatile String masterKey;
Expand Down Expand Up @@ -114,15 +115,15 @@ private void initMasterKey() {
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
this.masterKey = masterKey;
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
}
}, e -> {
log.error("Failed to get ML encryption master key", e);
exceptionRef.set(e);
}), latch));
}
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
latch.countDown();
}

Expand All @@ -141,7 +142,7 @@ private void initMasterKey() {
}
}
if (masterKey == null) {
throw new ResourceNotFoundException("ML encryption master key not initialized yet");
throw new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public static Optional<String> executePreprocessFunction(ScriptService scriptSer
}
return Optional.empty();
}

public static Optional<String> executePostprocessFunction(ScriptService scriptService,
String postProcessFunction,
String resultJson) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
import static org.opensearch.ml.engine.encryptor.EncryptorImpl.MASTER_KEY_NOT_READY_ERROR;

public class EncryptorImplTest {
@Rule
Expand Down Expand Up @@ -121,7 +122,7 @@ public void decrypt() {
@Test
public void encrypt_NullMasterKey_NullMasterKey_MasterKeyNotExistInIndex() {
exceptionRule.expect(ResourceNotFoundException.class);
exceptionRule.expectMessage("ML encryption master key not initialized yet");
exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR);

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
Expand Down Expand Up @@ -155,7 +156,7 @@ public void decrypt_NullMasterKey_GetMasterKey_Exception() {
@Test
public void decrypt_MLConfigIndexNotFound() {
exceptionRule.expect(ResourceNotFoundException.class);
exceptionRule.expectMessage("ML encryption master key not initialized yet");
exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR);

Metadata metadata = new Metadata.Builder().indices(ImmutableMap.of()).build();
when(clusterState.metadata()).thenReturn(metadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
new IllegalArgumentException(
"The name you provided is already being used by another model with ID: "
+ id
+ ". Please provide a different name"
+ ". Please provide a different name or add \"model_group_id\": \""
+ id
+ "\" to request body"
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.plugin;

import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;

Expand Down Expand Up @@ -293,8 +295,12 @@ public Collection<Object> createComponents(
Map<Enum, MLStat<?>> stats = new ConcurrentHashMap<>();
// cluster level stats
stats.put(MLClusterLevelStat.ML_MODEL_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_MODEL_INDEX)));
stats
.put(MLClusterLevelStat.ML_CONNECTOR_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_CONNECTOR_INDEX)));
stats.put(MLClusterLevelStat.ML_CONFIG_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_CONFIG_INDEX)));
stats.put(MLClusterLevelStat.ML_TASK_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_TASK_INDEX)));
stats.put(MLClusterLevelStat.ML_MODEL_COUNT, new MLStat<>(true, new CounterSupplier()));
stats.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, new MLStat<>(true, new CounterSupplier()));
// node level stats
stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
Expand Down Expand Up @@ -452,7 +458,7 @@ public List<RestHandler> getRestHandlers(
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<DiscoveryNodes> nodesInCluster
) {
RestMLStatsAction restMLStatsAction = new RestMLStatsAction(mlStats, clusterService, indexUtils);
RestMLStatsAction restMLStatsAction = new RestMLStatsAction(mlStats, clusterService, indexUtils, xContentRegistry);
RestMLTrainingAction restMLTrainingAction = new RestMLTrainingAction();
RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction();
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager);
Expand Down
42 changes: 33 additions & 9 deletions plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.rest;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.RestActionUtils.splitCommaSeparatedParam;
Expand All @@ -27,6 +28,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
Expand Down Expand Up @@ -54,17 +56,26 @@ public class RestMLStatsAction extends BaseRestHandler {
private MLStats mlStats;
private ClusterService clusterService;
private IndexUtils indexUtils;
private NamedXContentRegistry xContentRegistry;
private static final String QUERY_ALL_MODEL_META_DOC =
"{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}";

/**
* Constructor
* @param mlStats MLStats object
* @param clusterService cluster service
* @param indexUtils index util
*/
public RestMLStatsAction(MLStats mlStats, ClusterService clusterService, IndexUtils indexUtils) {
public RestMLStatsAction(
MLStats mlStats,
ClusterService clusterService,
IndexUtils indexUtils,
NamedXContentRegistry xContentRegistry
) {
this.mlStats = mlStats;
this.clusterService = clusterService;
this.indexUtils = indexUtils;
this.xContentRegistry = xContentRegistry;
}

@Override
Expand Down Expand Up @@ -109,14 +120,27 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
if (finalMlStatsInput.getTargetStatLevels().contains(MLStatLevel.CLUSTER)
&& (finalMlStatsInput.retrieveAllClusterLevelStats()
|| finalMlStatsInput.getClusterLevelStats().contains(MLClusterLevelStat.ML_MODEL_COUNT))) {
indexUtils.getNumberOfDocumentsInIndex(ML_MODEL_INDEX, ActionListener.wrap(count -> {
clusterStatsMap.put(MLClusterLevelStat.ML_MODEL_COUNT, count);
getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
}, e -> {
String errorMessage = "Failed to get ML model count";
log.error(errorMessage, e);
onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, e);
}));
indexUtils
.getNumberOfDocumentsInIndex(
ML_MODEL_INDEX,
QUERY_ALL_MODEL_META_DOC,
xContentRegistry,
ActionListener.wrap(modelCount -> {
clusterStatsMap.put(MLClusterLevelStat.ML_MODEL_COUNT, modelCount);
indexUtils.getNumberOfDocumentsInIndex(ML_CONNECTOR_INDEX, ActionListener.wrap(connectorCount -> {
clusterStatsMap.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, connectorCount);
getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
}, e -> {
String errorMessage = "Failed to get ML model count";
log.error(errorMessage, e);
onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, e);
}));
}, e -> {
String errorMessage = "Failed to get ML model count";
log.error(errorMessage, e);
onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, e);
})
);
} else {
getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
*/
public enum MLClusterLevelStat {
ML_MODEL_INDEX_STATUS,
ML_CONNECTOR_INDEX_STATUS,
ML_CONFIG_INDEX_STATUS,
ML_TASK_INDEX_STATUS,
ML_MODEL_COUNT;
ML_MODEL_COUNT,
ML_CONNECTOR_COUNT;

public static MLClusterLevelStat from(String value) {
try {
Expand Down
38 changes: 38 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/utils/IndexUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,22 @@
import java.util.List;
import java.util.Locale;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.admin.indices.stats.IndicesStatsRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.health.ClusterIndexHealth;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.builder.SearchSourceBuilder;

public class IndexUtils {
/**
Expand Down Expand Up @@ -96,4 +105,33 @@ public void getNumberOfDocumentsInIndex(String indexName, ActionListener<Long> l
}
}

// TODO: add connector count stats
public void getNumberOfDocumentsInIndex(
String indexName,
String searchQuery,
NamedXContentRegistry xContentRegistry,
ActionListener<Long> listener
) {
if (clusterService.state().getRoutingTable().hasIndex(indexName)) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
SearchRequest searchRequest = new SearchRequest();
XContentParser parser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, searchQuery);
SearchSourceBuilder builder = SearchSourceBuilder.fromXContent(parser);
builder.fetchSource(false);
searchRequest.source(builder).indices(indexName);

client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
long count = r.getHits().getTotalHits().value;
listener.onResponse(count);
}, e -> { listener.onFailure(e); }), () -> context.restore()));
} catch (Exception e) {
throw new OpenSearchStatusException("Failed to search index " + indexName, RestStatus.BAD_REQUEST);
}
} else {
listener.onResponse(0L);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public void test_SuccessAddAllBackendRolesTrue() {
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void test_ModelGroupNameNotUnique() throws IOException {
public void test_ModelGroupNameNotUnique() throws IOException {//
SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
Expand All @@ -143,7 +143,7 @@ public void test_ModelGroupNameNotUnique() throws IOException {
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"The name you provided is already being used by another model with ID: model_group_ID. Please provide a different name",
"The name you provided is already being used by another model with ID: model_group_ID. Please provide a different name or add \"model_group_id\": \"model_group_ID\" to request body",
argumentCaptor.getValue().getMessage()
);

Expand Down
Loading

0 comments on commit 7ef8121

Please sign in to comment.