Improve HCAD cold start (#272)
* Improve HCAD cold start

ΔT is the length of the interval in minutes. A few tenets:

* If ΔT is small, then we can wait for real data to fill up to 32 + shingleSize samples.
* If ΔT is large, then either the user has a lot of data (and we can probe it slowly, or they have large clusters), or the only option is to wait ...
* If ΔT divides 1440, then a “day” has meaning; this is likely a common case.
* The current solution in production uses linear interpolation; changing that would take significant work. Given that, it makes sense to embrace linearity all the way for missing data as well.
* If we have 32 + shingleSize (hopefully recent) values, RCF can get up and running. It will be noisy (there is a reason the default size is 256 + shingleSize), but it may be more useful for people to start seeing some results.

We have two parameters: numberOfSamples and strideLength.

* We probe numberOfSamples + 1 values at a strideLength * ΔT gap.
* For each pair of consecutive values present, we interpolate into strideLength pieces. A 0 returned by the engine constitutes a valid answer; “null” is a missing answer. 0 may be meaningless in some cases, but it is also meaningful in others. The query defining the metric may be ill-formed, but that cannot be solved by the cold-start strategy of the AD plugin; if we attempt to do that, we will have issues with legitimate interpretations of 0.
* For the missing entries we use linear interpolation as well. Denote the samples S0, S1, ... in reverse order of time. Each [Si, Si−1] corresponds to a strideLength * ΔT gap. If we get samples for S0, S1, S4 (both S2 and S3 are missing), then we interpolate [S4, S1] into 3 * strideLength pieces (see the sketch after this list).
* If the above provides 32 + shingleSize points, then we have a model (note that if S0 is missing, or all Si for i greater than some N are missing, then we would miss a lot of points, but the points we do get are contiguous under this scheme). Otherwise we issue another round of queries; if there is any sample in the second round, then we would have 32 + shingleSize points. If there is no sample in the second round, then we should wait for real data.
* If there is no data, there is ultimately nothing that can be done.
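
A minimal sketch of the interpolation described above, assuming probes arrive in time order and a missing probe is represented by null; the class and method names (ColdStartInterpolation, interpolate) are illustrative and do not mirror the plugin's actual code:

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch only; the plugin's real cold-start code differs.
public class ColdStartInterpolation {

    /**
     * Linearly interpolates between consecutive present probes.
     *
     * @param probes       probed values in time order; null means the probe returned no data
     * @param strideLength number of interval-sized pieces between two adjacent probes
     * @return training points; a gap spanning k strides expands into k * strideLength pieces
     */
    public static List<Double> interpolate(Double[] probes, int strideLength) {
        List<Double> points = new ArrayList<>();
        Integer lastPresent = null;
        for (int i = 0; i < probes.length; i++) {
            if (probes[i] == null) {
                continue; // missing probe, bridged by the next present one
            }
            if (lastPresent == null) {
                points.add(probes[i]); // first present probe seeds the series
            } else {
                // the number of pieces grows with how many probes were skipped
                int pieces = (i - lastPresent) * strideLength;
                double start = probes[lastPresent];
                double end = probes[i];
                for (int p = 1; p <= pieces; p++) {
                    points.add(start + (end - start) * p / pieces);
                }
            }
            lastPresent = i;
        }
        return points;
    }
}
```

For example, probes {1.0, 2.0, null, null, 5.0} with strideLength = 2 yield {1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0}: the two missing probes are bridged by 3 * strideLength evenly spaced pieces, matching the S0/S1/S4 case above.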

How to set numberOfSamples and strideLength?
* Suppose ΔT ≤ 30 and ΔT divides 60. Then set numberOfSamples = ceil((shingleSize + 32) / 24) * 24 and strideLength = 60 / ΔT. Note that if there is enough data, we may have a lot more than shingleSize + 32 points, which is only good.
* Otherwise, set numberOfSamples = shingleSize + 32 and strideLength = 1. This should be an uncommon case, but if someone wants a 23-minute interval, and the system permits it, let's give it to them. Note the smallest ΔT that does not divide 60 is 7, which is already quite a long wait for one data point. A sketch of this selection follows the list.
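
A minimal sketch of this parameter selection, assuming ΔT is passed as the interval length in minutes; the class and field names here are hypothetical, not the plugin's actual settings code:

```java
// Hypothetical sketch only; not the plugin's actual settings code.
public final class ColdStartParams {
    final int numberOfSamples;
    final int strideLength;

    ColdStartParams(int numberOfSamples, int strideLength) {
        this.numberOfSamples = numberOfSamples;
        this.strideLength = strideLength;
    }

    static ColdStartParams choose(int intervalMinutes, int shingleSize) {
        int required = shingleSize + 32; // minimum points for RCF to get going
        if (intervalMinutes <= 30 && 60 % intervalMinutes == 0) {
            // probes land on hourly boundaries; round up to a multiple of 24 so they cover whole days
            int numberOfSamples = (int) Math.ceil(required / 24.0) * 24;
            return new ColdStartParams(numberOfSamples, 60 / intervalMinutes);
        }
        // uncommon intervals (e.g. 23 minutes): probe consecutive intervals
        return new ColdStartParams(required, 1);
    }
}
```

For example, ΔT = 10 and shingleSize = 8 give numberOfSamples = 48 and strideLength = 6, i.e. hourly probes spanning two days.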

Testing done:
1. Added precision tests and various unit tests to cover the changes.
2. Manually verified HCAD cold start does not break.
kaituo authored Nov 1, 2021
1 parent 7514cdb commit 2ce24a0
Showing 19 changed files with 1,199 additions and 303 deletions.
3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java
@@ -507,7 +507,8 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
featureManager,
settings,
AnomalyDetectorSettings.HOURLY_MAINTENANCE,
checkpointWriteQueue
checkpointWriteQueue,
AnomalyDetectorSettings.MAX_COLD_START_ROUNDS
);

EntityColdStartWorker coldstartQueue = new EntityColdStartWorker(
22 changes: 22 additions & 0 deletions src/main/java/org/opensearch/ad/NodeState.java
@@ -17,6 +17,7 @@
import java.util.Optional;

import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyDetectorJob;

/**
* Storing intermediate state during the execution of transport action
@@ -41,6 +42,8 @@ public class NodeState implements ExpiringState {
private final Clock clock;
// cold start running flag to prevent concurrent cold start
private boolean coldStartRunning;
// detector job
private AnomalyDetectorJob detectorJob;

public NodeState(String detectorId, Clock clock) {
this.detectorId = detectorId;
@@ -52,6 +55,7 @@ public NodeState(String detectorId, Clock clock) {
this.checkPointExists = false;
this.clock = clock;
this.coldStartRunning = false;
this.detectorJob = null;
}

public String getDetectorId() {
@@ -166,6 +170,24 @@ public void setColdStartRunning(boolean coldStartRunning) {
refreshLastUpdateTime();
}

/**
*
* @return Detector configuration object
*/
public AnomalyDetectorJob getDetectorJob() {
refreshLastUpdateTime();
return detectorJob;
}

/**
*
* @param detectorJob Detector job
*/
public void setDetectorJob(AnomalyDetectorJob detectorJob) {
this.detectorJob = detectorJob;
refreshLastUpdateTime();
}

/**
* refresh last access time.
*/
40 changes: 40 additions & 0 deletions src/main/java/org/opensearch/ad/NodeStateManager.java
@@ -25,6 +25,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.ActionListener;
import org.opensearch.action.get.GetRequest;
@@ -34,6 +35,7 @@
import org.opensearch.ad.constant.CommonName;
import org.opensearch.ad.ml.SingleStreamModelIdMapper;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyDetectorJob;
import org.opensearch.ad.transport.BackPressureRouting;
import org.opensearch.ad.util.ClientUtil;
import org.opensearch.ad.util.ExceptionUtil;
@@ -367,4 +369,42 @@ public Releasable markColdStartRunning(String adID) {
}
};
}

public void getAnomalyDetectorJob(String adID, ActionListener<Optional<AnomalyDetectorJob>> listener) {
NodeState state = states.get(adID);
if (state != null && state.getDetectorJob() != null) {
listener.onResponse(Optional.of(state.getDetectorJob()));
} else {
GetRequest request = new GetRequest(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX, adID);
clientUtil.<GetRequest, GetResponse>asyncRequest(request, client::get, onGetDetectorJobResponse(adID, listener));
}
}

private ActionListener<GetResponse> onGetDetectorJobResponse(String adID, ActionListener<Optional<AnomalyDetectorJob>> listener) {
return ActionListener.wrap(response -> {
if (response == null || !response.isExists()) {
listener.onResponse(Optional.empty());
return;
}

String xc = response.getSourceAsString();
LOG.debug("Fetched anomaly detector: {}", xc);

try (
XContentParser parser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, response.getSourceAsString())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser);
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setDetectorJob(job);

listener.onResponse(Optional.of(job));
} catch (Exception t) {
LOG.error(new ParameterizedMessage("Fail to parse job {}", adID), t);
listener.onResponse(Optional.empty());
}
}, listener::onFailure);
}
}
25 changes: 8 additions & 17 deletions src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java
@@ -20,8 +20,8 @@

import java.io.IOException;
import java.time.Clock;
import java.time.ZonedDateTime;
import java.util.AbstractMap.SimpleEntry;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
@@ -471,16 +471,12 @@ public void onFailure(Exception e) {
}

/**
* Get the entity's earliest and latest timestamps
* Get the entity's earliest timestamp
* @param detector detector config
* @param entity the entity's information
* @param listener listener to return back the requested timestamps
*/
public void getEntityMinMaxDataTime(
AnomalyDetector detector,
Entity entity,
ActionListener<Entry<Optional<Long>, Optional<Long>>> listener
) {
public void getEntityMinDataTime(AnomalyDetector detector, Entity entity, ActionListener<Optional<Long>> listener) {
BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery();

for (TermQueryBuilder term : entity.getTermQueryBuilders()) {
@@ -489,29 +485,24 @@ public void getEntityMinMaxDataTime(

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
.query(internalFilterQuery)
.aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField()))
.aggregation(AggregationBuilders.min(AGG_NAME_MIN).field(detector.getTimeField()))
.trackTotalHits(false)
.size(0);
SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder);
client
.search(
searchRequest,
ActionListener.wrap(response -> { listener.onResponse(parseMinMaxDataTime(response)); }, listener::onFailure)
ActionListener.wrap(response -> { listener.onResponse(parseMinDataTime(response)); }, listener::onFailure)
);
}

private Entry<Optional<Long>, Optional<Long>> parseMinMaxDataTime(SearchResponse searchResponse) {
private Optional<Long> parseMinDataTime(SearchResponse searchResponse) {
Optional<Map<String, Aggregation>> mapOptional = Optional
.ofNullable(searchResponse)
.map(SearchResponse::getAggregations)
.map(aggs -> aggs.asMap());

Optional<Long> latest = mapOptional.map(map -> (Max) map.get(CommonName.AGG_NAME_MAX_TIME)).map(agg -> (long) agg.getValue());

Optional<Long> earliest = mapOptional.map(map -> (Min) map.get(AGG_NAME_MIN)).map(agg -> (long) agg.getValue());

return new SimpleImmutableEntry<>(earliest, latest);
return mapOptional.map(map -> (Min) map.get(AGG_NAME_MIN)).map(agg -> (long) agg.getValue());
}

/**
@@ -1000,9 +991,9 @@ public void getColdStartSamplesForPeriods(
.stream()
.filter(InternalDateRange.class::isInstance)
.flatMap(agg -> ((InternalDateRange) agg).getBuckets().stream())
.filter(bucket -> bucket.getFrom() != null)
.filter(bucket -> bucket.getFrom() != null && bucket.getFrom() instanceof ZonedDateTime)
.filter(bucket -> bucket.getDocCount() > docCountThreshold)
.sorted(Comparator.comparing((Bucket bucket) -> Long.valueOf(bucket.getFromAsString())))
.sorted(Comparator.comparing((Bucket bucket) -> (ZonedDateTime) bucket.getFrom()))
.map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds()))
.collect(Collectors.toList())
);