Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature/Extensions] Profile detector #882

Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,13 @@ List<String> jacocoExclusions = [
'org.opensearch.ad.ratelimit.ResultWriteRequest',
'org.opensearch.ad.AnomalyDetectorJobRunner.*',
'org.opensearch.ad.util.RestHandlerUtils',
'org.opensearch.ad.transport.SearchAnomalyDetectorInfoTransportAction.*'
'org.opensearch.ad.transport.SearchAnomalyDetectorInfoTransportAction.*',
'org.opensearch.ad.transport.RCFPollingAction',
'org.opensearch.ad.transport.RCFPollingRequest',
'org.opensearch.ad.transport.RCFPollingTransportAction',
'org.opensearch.ad.transport.RCFPollingTransportAction.*',
'org.opensearch.ad.transport.RCFPollingResponse',

]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@
import org.opensearch.ad.transport.PreviewAnomalyDetectorTransportAction;
import org.opensearch.ad.transport.ProfileAction;
import org.opensearch.ad.transport.ProfileTransportAction;
import org.opensearch.ad.transport.RCFPollingAction;
import org.opensearch.ad.transport.RCFPollingTransportAction;
import org.opensearch.ad.transport.RCFResultAction;
import org.opensearch.ad.transport.RCFResultTransportAction;
import org.opensearch.ad.transport.SearchADTasksAction;
Expand Down Expand Up @@ -796,7 +798,8 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class),
new ActionHandler<>(ForwardADTaskAction.INSTANCE, ForwardADTaskTransportAction.class),
new ActionHandler<>(ADBatchAnomalyResultAction.INSTANCE, ADBatchAnomalyResultTransportAction.class),
new ActionHandler<>(ADCancelTaskAction.INSTANCE, ADCancelTaskTransportAction.class)
new ActionHandler<>(ADCancelTaskAction.INSTANCE, ADCancelTaskTransportAction.class),
new ActionHandler<>(RCFPollingAction.INSTANCE, RCFPollingTransportAction.class)
);
}

Expand Down
23 changes: 13 additions & 10 deletions src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.ad.settings.NumericSetting;
import org.opensearch.ad.task.ADTaskManager;
import org.opensearch.ad.transport.ProfileAction;
import org.opensearch.ad.transport.ProfileRequest;
import org.opensearch.ad.transport.ProfileResponse;
import org.opensearch.ad.transport.RCFPollingAction;
import org.opensearch.ad.transport.RCFPollingRequest;
import org.opensearch.ad.transport.RCFPollingResponse;
import org.opensearch.ad.util.DiscoveryNodeFilterer;
Expand Down Expand Up @@ -78,23 +80,23 @@

public class AnomalyDetectorProfileRunner extends AbstractProfileRunner {
private final Logger logger = LogManager.getLogger(AnomalyDetectorProfileRunner.class);
private SDKRestClient client;
private SDKRestClient sdkRestClient;
private SDKNamedXContentRegistry xContentRegistry;
private DiscoveryNodeFilterer nodeFilter;
private final TransportService transportService;
private final ADTaskManager adTaskManager;
private final int maxTotalEntitiesToTrack;

public AnomalyDetectorProfileRunner(
SDKRestClient client,
SDKRestClient sdkRestClient,
SDKNamedXContentRegistry xContentRegistry,
DiscoveryNodeFilterer nodeFilter,
long requiredSamples,
TransportService transportService,
ADTaskManager adTaskManager
) {
super(requiredSamples);
this.client = client;
this.sdkRestClient = sdkRestClient;
this.xContentRegistry = xContentRegistry;
this.nodeFilter = nodeFilter;
if (requiredSamples <= 0) {
Expand All @@ -119,7 +121,7 @@ private void calculateTotalResponsesToWait(
ActionListener<DetectorProfile> listener
) {
GetRequest getDetectorRequest = new GetRequest(ANOMALY_DETECTORS_INDEX, detectorId);
client.get(getDetectorRequest, ActionListener.wrap(getDetectorResponse -> {
sdkRestClient.get(getDetectorRequest, ActionListener.wrap(getDetectorResponse -> {
if (getDetectorResponse != null && getDetectorResponse.isExists()) {
try (
XContentParser xContentParser = XContentType.JSON
Expand Down Expand Up @@ -153,7 +155,7 @@ private void prepareProfile(
) {
String detectorId = detector.getDetectorId();
GetRequest getRequest = new GetRequest(ANOMALY_DETECTOR_JOB_INDEX, detectorId);
client.get(getRequest, ActionListener.wrap(getResponse -> {
sdkRestClient.get(getRequest, ActionListener.wrap(getResponse -> {
if (getResponse != null && getResponse.isExists()) {
try (
XContentParser parser = XContentType.JSON
Expand Down Expand Up @@ -298,7 +300,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
searchSourceBuilder.aggregation(aggBuilder);

SearchRequest request = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder);
client.search(request, ActionListener.wrap(searchResponse -> {
sdkRestClient.search(request, ActionListener.wrap(searchResponse -> {
Map<String, Aggregation> aggMap = searchResponse.getAggregations().asMap();
InternalCardinality totalEntities = (InternalCardinality) aggMap.get(CommonName.TOTAL_ENTITIES);
long value = totalEntities.getValue();
Expand All @@ -321,7 +323,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
SearchRequest searchRequest = new SearchRequest()
.indices(detector.getIndices().toArray(new String[0]))
.source(searchSourceBuilder);
client.search(searchRequest, ActionListener.wrap(searchResponse -> {
sdkRestClient.search(searchRequest, ActionListener.wrap(searchResponse -> {
DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder();
Aggregations aggs = searchResponse.getAggregations();
if (aggs == null) {
Expand Down Expand Up @@ -383,7 +385,7 @@ private void profileStateRelated(
) {
if (enabled) {
RCFPollingRequest request = new RCFPollingRequest(detector.getDetectorId());
// client.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(detector, profilesToCollect, listener));
sdkRestClient.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(detector, profilesToCollect, listener));
} else {
DetectorProfile.Builder builder = new DetectorProfile.Builder();
if (profilesToCollect.contains(DetectorProfileName.STATE)) {
Expand All @@ -402,7 +404,8 @@ private void profileModels(
) {
DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes();
ProfileRequest profileRequest = new ProfileRequest(detector.getDetectorId(), profiles, forMultiEntityDetector, dataNodes);
// client.execute(ProfileAction.INSTANCE, profileRequest, onModelResponse(detector, profiles, job, listener));// get init progress
sdkRestClient.execute(ProfileAction.INSTANCE, profileRequest, onModelResponse(detector, profiles, job, listener));// get init
// progress
}

private ActionListener<ProfileResponse> onModelResponse(
Expand Down Expand Up @@ -482,7 +485,7 @@ private void confirmMultiEntityDetectorInitStatus(
MultiResponsesDelegateActionListener<DetectorProfile> listener
) {
SearchRequest searchLatestResult = createInittedEverRequest(detector.getDetectorId(), enabledTime, detector.getResultIndex());
client.search(searchLatestResult, onInittedEver(enabledTime, profile, profilesToCollect, detector, totalUpdates, listener));
sdkRestClient.search(searchLatestResult, onInittedEver(enabledTime, profile, profilesToCollect, detector, totalUpdates, listener));
}

private ActionListener<SearchResponse> onInittedEver(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.ad.rest;

import static org.opensearch.ad.util.RestHandlerUtils.DETECTOR_ID;
import static org.opensearch.ad.util.RestHandlerUtils.PROFILE;
import static org.opensearch.ad.util.RestHandlerUtils.TYPE;

import java.io.IOException;
Expand Down Expand Up @@ -112,7 +113,34 @@ protected ExtensionRestResponse prepareRequest(RestRequest request) throws IOExc
public List<ReplacedRouteHandler> replacedRouteHandlers() {
String path = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorExtension.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID);
String newPath = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorExtension.AD_BASE_DETECTORS_URI, DETECTOR_ID);
return ImmutableList.of(new ReplacedRouteHandler(RestRequest.Method.GET, newPath, RestRequest.Method.GET, path, handleRequest));
return ImmutableList
.of(
new ReplacedRouteHandler(RestRequest.Method.GET, newPath, RestRequest.Method.GET, path, handleRequest),
new ReplacedRouteHandler(RestRequest.Method.HEAD, newPath, RestRequest.Method.HEAD, path, handleRequest),
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
new ReplacedRouteHandler(
RestRequest.Method.GET,
String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorExtension.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE),
RestRequest.Method.GET,
String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorExtension.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, PROFILE),
handleRequest
),
new ReplacedRouteHandler(
RestRequest.Method.GET,
String
.format(Locale.ROOT, "%s/{%s}/%s/{%s}", AnomalyDetectorExtension.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE),
RestRequest.Method.GET,
String
.format(
Locale.ROOT,
"%s/{%s}/%s/{%s}",
AnomalyDetectorExtension.LEGACY_OPENDISTRO_AD_BASE_URI,
DETECTOR_ID,
PROFILE,
TYPE
),
handleRequest
)
);
}

private Function<RestRequest, ExtensionRestResponse> handleRequest = (request) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,23 @@ public class ProfileTransportAction extends TransportAction<ProfileRequest, Prof
private ModelManager modelManager;
private FeatureManager featureManager;
private CacheProvider cacheProvider;
private SDKClusterService clusterService;
private SDKClusterService sdkClusterService;
// the number of models to return. Defaults to 10.
private volatile int numModelsToReturn;

/**
* Constructor
*
* @param threadPool ThreadPool to use
* @param clusterService ClusterService
* @param transportService TransportService
* @param actionFilters Action Filters
* @param modelManager model manager object
* @param featureManager feature manager object
* @param cacheProvider cache provider
* @param settings Node settings accessor
*/
@Inject
public ProfileTransportAction(
ExtensionsRunner extensionsRunner,
ActionFilters actionFilters,
TaskManager taskManager,
SDKClusterService clusterService,
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
ModelManager modelManager,
FeatureManager featureManager,
CacheProvider cacheProvider
Expand All @@ -75,14 +70,14 @@ public ProfileTransportAction(
this.modelManager = modelManager;
this.featureManager = featureManager;
this.cacheProvider = cacheProvider;
this.clusterService = clusterService;
this.sdkClusterService = extensionsRunner.getSdkClusterService();
Settings settings = extensionsRunner.getEnvironmentSettings();
this.numModelsToReturn = MAX_MODEL_SIZE_PER_NODE.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it);
this.sdkClusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it);
}

private ProfileResponse newResponse(ProfileRequest request, List<ProfileNodeResponse> responses, List<FailedNodeException> failures) {
return new ProfileResponse(clusterService.state().getClusterName(), responses, failures);
return new ProfileResponse(sdkClusterService.state().getClusterName(), responses, failures);
}

@Override
Expand Down Expand Up @@ -133,7 +128,7 @@ protected void doExecute(Task task, ProfileRequest request, ActionListener<Profi
List
.of(
new ProfileNodeResponse(
clusterService.localNode(),
sdkClusterService.localNode(),
modelSize,
shingleSize,
activeEntity,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,59 +19,63 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.ad.cluster.HashRing;
import org.opensearch.action.support.TransportAction;
import org.opensearch.ad.common.exception.AnomalyDetectionException;
import org.opensearch.ad.ml.ModelManager;
import org.opensearch.ad.ml.SingleStreamModelIdMapper;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
import org.opensearch.sdk.ExtensionsRunner;
import org.opensearch.sdk.SDKClusterService;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskManager;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportException;
import org.opensearch.transport.TransportRequestOptions;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

import com.google.inject.Inject;

/**
* Transport action to get total rcf updates from hosted models or checkpoint
*
*/
public class RCFPollingTransportAction extends HandledTransportAction<RCFPollingRequest, RCFPollingResponse> {
public class RCFPollingTransportAction extends TransportAction<RCFPollingRequest, RCFPollingResponse> {

private static final Logger LOG = LogManager.getLogger(RCFPollingTransportAction.class);
static final String NO_NODE_FOUND_MSG = "Cannot find model hosting node";
static final String FAIL_TO_GET_RCF_UPDATE_MSG = "Cannot find hosted model or related checkpoint";

private final TransportService transportService;
private final ModelManager modelManager;
private final HashRing hashRing;
// private final HashRing hashRing;
private final TransportRequestOptions option;
private final ClusterService clusterService;
private final SDKClusterService sdkClusterService;
// private final DiscoveryNode discoveryNode;

@Inject
public RCFPollingTransportAction(
ActionFilters actionFilters,
TransportService transportService,
Settings settings,
// Settings settings,
ModelManager modelManager,
HashRing hashRing,
ClusterService clusterService
joshpalis marked this conversation as resolved.
Show resolved Hide resolved
// HashRing hashRing,
TaskManager taskManager,
ExtensionsRunner extensionsRunner
) {
super(RCFPollingAction.NAME, transportService, actionFilters, RCFPollingRequest::new);
this.transportService = transportService;
super(RCFPollingAction.NAME, actionFilters, taskManager);
this.modelManager = modelManager;
this.hashRing = hashRing;
// this.hashRing = hashRing;
this.option = TransportRequestOptions
.builder()
.withType(TransportRequestOptions.Type.REG)
.withTimeout(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings))
.withTimeout(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(extensionsRunner.getEnvironmentSettings()))
.build();
this.clusterService = clusterService;
this.sdkClusterService = extensionsRunner.getSdkClusterService();
this.transportService = extensionsRunner.getSdkTransportService().getTransportService();
// Adding the below piece of code as we are not supporting multinode. We need one node which we can fetch from clusterService.
// this.discoveryNode=extensionsRunner.getSdkClusterService().localNode();
}

@Override
Expand All @@ -81,15 +85,18 @@ protected void doExecute(Task task, RCFPollingRequest request, ActionListener<RC

String rcfModelID = SingleStreamModelIdMapper.getRcfModelId(adID, 0);

Optional<DiscoveryNode> rcfNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID);
/* Commenting the below piece of code as we do not have support for multinode */
// Optional<DiscoveryNode> rcfNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID);
joshpalis marked this conversation as resolved.
Show resolved Hide resolved

Optional<DiscoveryNode> rcfNode = Optional.ofNullable(sdkClusterService.localNode());
if (!rcfNode.isPresent()) {
listener.onFailure(new AnomalyDetectionException(adID, NO_NODE_FOUND_MSG));
return;
}

String rcfNodeId = rcfNode.get().getId();

DiscoveryNode localNode = clusterService.localNode();
DiscoveryNode localNode = sdkClusterService.localNode();

if (localNode.getId().equals(rcfNodeId)) {
modelManager
Expand Down
Loading