Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more stats: connector count, connector/config index status; fix model count bug #1180

Merged
merged 2 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
@Log4j2
public class EncryptorImpl implements Encryptor {

public static final String MASTER_KEY_NOT_READY_ERROR = "ML encryption master key not initialized yet. Please retry after 10 seconds.";
ylwu-amzn marked this conversation as resolved.
Show resolved Hide resolved
private ClusterService clusterService;
private Client client;
private volatile String masterKey;
Expand Down Expand Up @@ -114,15 +115,15 @@ private void initMasterKey() {
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
this.masterKey = masterKey;
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
}
}, e -> {
log.error("Failed to get ML encryption master key", e);
exceptionRef.set(e);
}), latch));
}
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
latch.countDown();
}

Expand All @@ -141,7 +142,7 @@ private void initMasterKey() {
}
}
if (masterKey == null) {
throw new ResourceNotFoundException("ML encryption master key not initialized yet");
throw new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.engine.utils;

import com.google.gson.Gson;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.utils.StringUtils;
Expand All @@ -18,6 +19,7 @@
import java.util.Map;
import java.util.Optional;

@Log4j2
public class ScriptUtils {

public static final Gson gson;
Expand All @@ -37,6 +39,7 @@ public static Optional<String> executePreprocessFunction(ScriptService scriptSer
}
return Optional.empty();
}

public static Optional<String> executePostprocessFunction(ScriptService scriptService,
String postProcessFunction,
String resultJson) {
Expand All @@ -51,8 +54,12 @@ public static Optional<String> executePostprocessFunction(ScriptService scriptSe
}

public static String executeScript(ScriptService scriptService, String painlessScript, Map<String, Object> params) {
long start = System.nanoTime();
Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap());
TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params);
long end = System.nanoTime();
double durationInMs = (end - start) / 1e6;
log.info("----------------- painless script execution time: {} ms", durationInMs);
return templateScript.execute();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
import static org.opensearch.ml.engine.encryptor.EncryptorImpl.MASTER_KEY_NOT_READY_ERROR;

public class EncryptorImplTest {
@Rule
Expand Down Expand Up @@ -121,7 +122,7 @@ public void decrypt() {
@Test
public void encrypt_NullMasterKey_NullMasterKey_MasterKeyNotExistInIndex() {
exceptionRule.expect(ResourceNotFoundException.class);
exceptionRule.expectMessage("ML encryption master key not initialized yet");
exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR);

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
Expand Down Expand Up @@ -155,7 +156,7 @@ public void decrypt_NullMasterKey_GetMasterKey_Exception() {
@Test
public void decrypt_MLConfigIndexNotFound() {
exceptionRule.expect(ResourceNotFoundException.class);
exceptionRule.expectMessage("ML encryption master key not initialized yet");
exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR);

Metadata metadata = new Metadata.Builder().indices(ImmutableMap.of()).build();
when(clusterState.metadata()).thenReturn(metadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
new IllegalArgumentException(
"The name you provided is already being used by another model with ID: "
+ id
+ ". Please provide a different name"
+ ". Please provide a different name or add \"model_group_id\": \"lMPmr4kB4eSCtCCDmCDm\" to request body"
ylwu-amzn marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

@dhrubo-os dhrubo-os Aug 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we update to Please provide a different model group name. We can see people started having confusion: #1179

about: \"model_group_id\": \"lMPmr4kB4eSCtCCDmCDm\" : Is this any static group id we are adding?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed hard coded model id .
For Please provide a different model group name , that looks also confusing as user just provide model name in the request body, not model group name,

POST _plugins/_ml/models/_upload
{
  "name": "huggingface/sentence-transformers/all-MiniLM-L12-v2",
  "version": "1.0.1",
  "model_format": "TORCH_SCRIPT"
}

)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.plugin;

import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;

Expand Down Expand Up @@ -293,8 +295,12 @@ public Collection<Object> createComponents(
Map<Enum, MLStat<?>> stats = new ConcurrentHashMap<>();
// cluster level stats
stats.put(MLClusterLevelStat.ML_MODEL_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_MODEL_INDEX)));
stats
.put(MLClusterLevelStat.ML_CONNECTOR_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_CONNECTOR_INDEX)));
stats.put(MLClusterLevelStat.ML_CONFIG_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_CONFIG_INDEX)));
stats.put(MLClusterLevelStat.ML_TASK_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_TASK_INDEX)));
stats.put(MLClusterLevelStat.ML_MODEL_COUNT, new MLStat<>(true, new CounterSupplier()));
stats.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, new MLStat<>(true, new CounterSupplier()));
// node level stats
stats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier()));
stats.put(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
Expand Down Expand Up @@ -452,7 +458,7 @@ public List<RestHandler> getRestHandlers(
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<DiscoveryNodes> nodesInCluster
) {
RestMLStatsAction restMLStatsAction = new RestMLStatsAction(mlStats, clusterService, indexUtils);
RestMLStatsAction restMLStatsAction = new RestMLStatsAction(mlStats, clusterService, indexUtils, xContentRegistry);
RestMLTrainingAction restMLTrainingAction = new RestMLTrainingAction();
RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction();
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager);
Expand Down
42 changes: 33 additions & 9 deletions plugin/src/main/java/org/opensearch/ml/rest/RestMLStatsAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.rest;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.RestActionUtils.splitCommaSeparatedParam;
Expand All @@ -26,6 +27,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
Expand Down Expand Up @@ -54,17 +56,26 @@ public class RestMLStatsAction extends BaseRestHandler {
private MLStats mlStats;
private ClusterService clusterService;
private IndexUtils indexUtils;
private NamedXContentRegistry xContentRegistry;
private static final String QUERY_ALL_MODEL_META_DOC =
"{\"query\":{\"bool\":{\"must_not\":{\"exists\":{\"field\":\"chunk_number\"}}}}}";

/**
* Constructor
* @param mlStats MLStats object
* @param clusterService cluster service
* @param indexUtils index util
*/
public RestMLStatsAction(MLStats mlStats, ClusterService clusterService, IndexUtils indexUtils) {
public RestMLStatsAction(
MLStats mlStats,
ClusterService clusterService,
IndexUtils indexUtils,
NamedXContentRegistry xContentRegistry
) {
this.mlStats = mlStats;
this.clusterService = clusterService;
this.indexUtils = indexUtils;
this.xContentRegistry = xContentRegistry;
}

@Override
Expand Down Expand Up @@ -109,14 +120,27 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
if (finalMlStatsInput.getTargetStatLevels().contains(MLStatLevel.CLUSTER)
&& (finalMlStatsInput.retrieveAllClusterLevelStats()
|| finalMlStatsInput.getClusterLevelStats().contains(MLClusterLevelStat.ML_MODEL_COUNT))) {
indexUtils.getNumberOfDocumentsInIndex(ML_MODEL_INDEX, ActionListener.wrap(count -> {
clusterStatsMap.put(MLClusterLevelStat.ML_MODEL_COUNT, count);
getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
}, e -> {
String errorMessage = "Failed to get ML model count";
log.error(errorMessage, e);
onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, e);
}));
indexUtils
.getNumberOfDocumentsInIndex(
ML_MODEL_INDEX,
QUERY_ALL_MODEL_META_DOC,
xContentRegistry,
ActionListener.wrap(modelCount -> {
clusterStatsMap.put(MLClusterLevelStat.ML_MODEL_COUNT, modelCount);
indexUtils.getNumberOfDocumentsInIndex(ML_CONNECTOR_INDEX, ActionListener.wrap(connectorCount -> {
clusterStatsMap.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, connectorCount);
getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
}, e -> {
String errorMessage = "Failed to get ML model count";
log.error(errorMessage, e);
onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, e);
}));
}, e -> {
String errorMessage = "Failed to get ML model count";
log.error(errorMessage, e);
onFailure(channel, RestStatus.INTERNAL_SERVER_ERROR, errorMessage, e);
})
);
} else {
getNodeStats(finalMlStatsInput, clusterStatsMap, client, mlStatsNodesRequest, channel);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
*/
public enum MLClusterLevelStat {
ML_MODEL_INDEX_STATUS,
ML_CONNECTOR_INDEX_STATUS,
ML_CONFIG_INDEX_STATUS,
ML_TASK_INDEX_STATUS,
ML_MODEL_COUNT;
ML_MODEL_COUNT,
ML_CONNECTOR_COUNT;

public static MLClusterLevelStat from(String value) {
try {
Expand Down
38 changes: 38 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/utils/IndexUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,22 @@
import java.util.List;
import java.util.Locale;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.admin.indices.stats.IndicesStatsRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.health.ClusterIndexHealth;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.builder.SearchSourceBuilder;

public class IndexUtils {
/**
Expand Down Expand Up @@ -96,4 +105,33 @@ public void getNumberOfDocumentsInIndex(String indexName, ActionListener<Long> l
}
}

// TODO: add connector count stats
public void getNumberOfDocumentsInIndex(
String indexName,
String searchQuery,
NamedXContentRegistry xContentRegistry,
ActionListener<Long> listener
) {
if (clusterService.state().getRoutingTable().hasIndex(indexName)) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
SearchRequest searchRequest = new SearchRequest();
XContentParser parser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, searchQuery);
SearchSourceBuilder builder = SearchSourceBuilder.fromXContent(parser);
builder.fetchSource(false);
searchRequest.source(builder).indices(indexName);

client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
long count = r.getHits().getTotalHits().value;
listener.onResponse(count);
}, e -> { listener.onFailure(e); }), () -> context.restore()));
} catch (Exception e) {
throw new OpenSearchStatusException("Failed to search index " + indexName, RestStatus.BAD_REQUEST);
}
} else {
listener.onResponse(0L);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public void test_SuccessAddAllBackendRolesTrue() {
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void test_ModelGroupNameNotUnique() throws IOException {
public void test_ModelGroupNameNotUnique() throws IOException {//
SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
Expand All @@ -143,7 +143,7 @@ public void test_ModelGroupNameNotUnique() throws IOException {
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"The name you provided is already being used by another model with ID: model_group_ID. Please provide a different name",
"The name you provided is already being used by another model with ID: model_group_ID. Please provide a different name or add \"model_group_id\": \"lMPmr4kB4eSCtCCDmCDm\" to request body",
argumentCaptor.getValue().getMessage()
);

Expand Down
Loading