From 4e441b3abe6029166db607a97064871870168c51 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Fri, 13 May 2022 15:54:34 -0700 Subject: [PATCH] Use current time as training data end time The bug happens because we use job enabled time as training data end time. But if the historical data before that time is deleted or does not exist at all, cold start might never finish. This PR uses current time as the training data end time so that cold start has a chance to succeed later. This PR also removes the code that combines cold start data and existing samples in EntityColdStartWorker because we don't add samples until cold start succeeds. Combining cold start data and existing samples is thus unnecessary. Testing done: 1. manually verified the bug is fixed. 2. fixed all related unit tests. Signed-off-by: Kaituo Li --- .../opensearch/ad/ml/EntityColdStarter.java | 88 ++++++++----------- .../ad/ratelimit/CheckpointReadWorker.java | 10 +-- .../opensearch/ad/NodeStateManagerTests.java | 63 +++++++++++++ .../ad/ml/EntityColdStarterTests.java | 72 +++++---------- .../ratelimit/CheckpointReadWorkerTests.java | 7 ++ 5 files changed, 132 insertions(+), 108 deletions(-) diff --git a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java index 521f98971..4f10c8dc7 100644 --- a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java +++ b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -46,7 +47,6 @@ import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.ratelimit.CheckpointWriteWorker; @@ -220,6 +220,22 @@ private void coldStart( ) { logger.debug("Trigger cold start for {}", modelId); + if (modelState == null || entity == null) { + listener + .onFailure( + new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Cannot have empty model state or entity: model state [%b], entity [%b]", + modelState == null, + entity == null + ) + ) + ); + return; + } + if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { listener.onResponse(null); return; @@ -252,7 +268,7 @@ private void coldStart( try { if (trainingData.isPresent()) { List dataPoints = trainingData.get(); - combineTrainSamples(dataPoints, modelId, modelState); + extractTrainSamples(dataPoints, modelId, modelState); Queue samples = modelState.getModel().getSamples(); // only train models if we have enough samples if (samples.size() >= numMinSamples) { @@ -272,7 +288,6 @@ private void coldStart( } catch (Exception e) { listener.onFailure(e); } - }, exception -> { try { logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception); @@ -404,42 +419,23 @@ private void getEntityColdStartData(String detectorId, Entity entity, ActionList ActionListener> minTimeListener = ActionListener.wrap(earliest -> { if (earliest.isPresent()) { long startTimeMs = earliest.get().longValue(); - nodeStateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(jobOp -> { - if (!jobOp.isPresent()) { - listener.onFailure(new EndRunException(detectorId, "AnomalyDetector job is not available.", false)); - return; - } - AnomalyDetectorJob job = jobOp.get(); - // End time uses milliseconds as start time is assumed to be in milliseconds. - // Opensearch uses a set of preconfigured formats to recognize and parse these strings into a long value - // representing milliseconds-since-the-epoch in UTC. - // More on https://tinyurl.com/wub4fk92 - - // Existing samples either predates or coincide with cold start data. In either case, - // combining them without reordering based on time stamps is not ok. We might introduce - // anomalies in the process. - // An ideal solution would be to record time stamps of data points and combine existing - // samples and cold start samples and do interpolation afterwards. Recording time stamps - // requires changes across the board like bwc in checkpoints. A pragmatic solution is to use - // job enabled time as the end time of cold start period as it is easier to combine - // existing samples with cold start data. We just need to appends existing samples after - // cold start data as existing samples all happen after job enabled time. There might - // be some gaps in between the last cold start sample and the first accumulated sample. - // We will need to accept that precision loss in current solution. - long endTimeMs = job.getEnabledTime().toEpochMilli(); - Pair params = selectRangeParam(detector); - int stride = params.getLeft(); - int numberOfSamples = params.getRight(); - - // we start with round 0 - getFeatures(listener, 0, coldStartData, detector, entity, stride, numberOfSamples, startTimeMs, endTimeMs); - - }, listener::onFailure)); + // End time uses milliseconds as start time is assumed to be in milliseconds. + // Opensearch uses a set of preconfigured formats to recognize and parse these + // strings into a long value + // representing milliseconds-since-the-epoch in UTC. + // More on https://tinyurl.com/wub4fk92 + + long endTimeMs = clock.millis(); + Pair params = selectRangeParam(detector); + int stride = params.getLeft(); + int numberOfSamples = params.getRight(); + + // we start with round 0 + getFeatures(listener, 0, coldStartData, detector, entity, stride, numberOfSamples, startTimeMs, endTimeMs); } else { listener.onResponse(Optional.empty()); } - }, listener::onFailure); searchFeatureDao @@ -694,40 +690,30 @@ public void trainModelFromExistingSamples(ModelState modelState, in } /** - * Precondition: we don't have enough training data. - * Combine training data with existing sample data. - * Existing samples either predates or coincide with cold start data. In either case, - * combining them without reordering based on time stamps is not ok. We might introduce - * anomalies in the process. - * An ideal solution would be to record time stamps of data points and combine existing - * samples and cold start samples and do interpolation afterwards. Recording time stamps - * requires changes across the board like bwc in checkpoints. A pragmatic solution is to use - * job enabled time as the end time of cold start period as it is easier to combine - * existing samples with cold start data. We just need to appends existing samples after - * cold start data as existing samples all happen after job enabled time. There might - * be some gaps in between the last cold start sample and the first accumulated sample. - * We will need to accept that precision loss in current solution. + * Extract training data and put them into ModelState * * @param coldstartDatapoints training data generated from cold start * @param modelId model Id * @param entityState entity State */ - private void combineTrainSamples(List coldstartDatapoints, String modelId, ModelState entityState) { - if (coldstartDatapoints == null || coldstartDatapoints.size() == 0) { + private void extractTrainSamples(List coldstartDatapoints, String modelId, ModelState entityState) { + if (coldstartDatapoints == null || coldstartDatapoints.size() == 0 || entityState == null) { return; } EntityModel model = entityState.getModel(); if (model == null) { model = new EntityModel(null, new ArrayDeque<>(), null); + entityState.setModel(model); } + Queue newSamples = new ArrayDeque<>(); for (double[][] consecutivePoints : coldstartDatapoints) { for (int i = 0; i < consecutivePoints.length; i++) { newSamples.add(consecutivePoints[i]); } } - newSamples.addAll(model.getSamples()); + model.setSamples(newSamples); } diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java b/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java index 2da83ad8e..1d2f14079 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java @@ -354,14 +354,8 @@ private ActionListener> onGetDetector( ModelState modelState = modelManager .processEntityCheckpoint(checkpoint, entity, modelId, detectorId, detector.getShingleSize()); - EntityModel entityModel = modelState.getModel(); - - ThresholdingResult result = null; - if (entityModel.getTrcf().isPresent()) { - result = modelManager.score(origRequest.getCurrentFeature(), modelId, modelState); - } else { - entityModel.addSample(origRequest.getCurrentFeature()); - } + ThresholdingResult result = modelManager + .getAnomalyResultForEntity(origRequest.getCurrentFeature(), modelState, modelId, entity, detector.getShingleSize()); if (result != null && result.getRcfScore() > 0) { AnomalyResult resultToSave = result diff --git a/src/test/java/org/opensearch/ad/NodeStateManagerTests.java b/src/test/java/org/opensearch/ad/NodeStateManagerTests.java index 464105268..2e9f604d3 100644 --- a/src/test/java/org/opensearch/ad/NodeStateManagerTests.java +++ b/src/test/java/org/opensearch/ad/NodeStateManagerTests.java @@ -43,6 +43,7 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.transport.AnomalyResultTests; import org.opensearch.ad.util.ClientUtil; @@ -79,6 +80,7 @@ public class NodeStateManagerTests extends AbstractADTest { private GetResponse checkpointResponse; private ClusterService clusterService; private ClusterSettings clusterSettings; + private AnomalyDetectorJob jobToCheck; @Override protected NamedXContentRegistry xContentRegistry() { @@ -129,6 +131,7 @@ public void setUp() throws Exception { stateManager = new NodeStateManager(client, xContentRegistry(), settings, clientUtil, clock, duration, clusterService); checkpointResponse = mock(GetResponse.class); + jobToCheck = TestHelpers.randomAnomalyDetectorJob(true, Instant.ofEpochMilli(1602401500000L), null); } @Override @@ -381,4 +384,64 @@ public void testSettingUpdateBackOffMin() { when(clock.millis()).thenReturn(62000L); assertTrue(!stateManager.isMuted(nodeId, adId)); } + + @SuppressWarnings("unchecked") + private String setupJob() throws IOException { + String detectorId = jobToCheck.getName(); + + doAnswer(invocation -> { + GetRequest request = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(1); + if (request.index().equals(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX)) { + listener.onResponse(TestHelpers.createGetResponse(jobToCheck, detectorId, AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX)); + } + return null; + }).when(client).get(any(), any(ActionListener.class)); + + return detectorId; + } + + public void testGetAnomalyJob() throws IOException, InterruptedException { + String detectorId = setupJob(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + stateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(asDetector -> { + assertEquals(jobToCheck, asDetector.get()); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(false); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + /** + * Test that we caches anomaly detector job definition after the first call + * @throws IOException if client throws exception + * @throws InterruptedException if the current thread is interrupted while waiting + */ + @SuppressWarnings("unchecked") + public void testRepeatedGetAnomalyJob() throws IOException, InterruptedException { + String detectorId = setupJob(); + final CountDownLatch inProgressLatch = new CountDownLatch(2); + + stateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(asDetector -> { + assertEquals(jobToCheck, asDetector.get()); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(false); + inProgressLatch.countDown(); + })); + + stateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(asDetector -> { + assertEquals(jobToCheck, asDetector.get()); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(false); + inProgressLatch.countDown(); + })); + + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + + verify(client, times(1)).get(any(), any(ActionListener.class)); + } } diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index 1150cf860..40ced3682 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -59,7 +59,6 @@ import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.ratelimit.CheckpointWriteWorker; @@ -110,7 +109,6 @@ public class EntityColdStarterTests extends AbstractADTest { CheckpointWriteWorker checkpointWriteQueue; Entity entity; AnomalyDetector detector; - AnomalyDetectorJob job; long rcfSeed; ModelManager modelManager; ClientUtil clientUtil; @@ -137,15 +135,13 @@ public void setUp() throws Exception { .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) .build(); - job = TestHelpers.randomAnomalyDetectorJob(true, Instant.ofEpochMilli(1602401500000L), null); + when(clock.millis()).thenReturn(1602401500000L); doAnswer(invocation -> { GetRequest request = invocation.getArgument(0); ActionListener listener = invocation.getArgument(2); - if (request.index().equals(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX)) { - listener.onResponse(TestHelpers.createGetResponse(job, detectorId, AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX)); - } else { - listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); - } + + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); + return null; }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); @@ -306,7 +302,7 @@ public void testColdStart() throws InterruptedException, IOException { ThresholdedRandomCutForest ercf = model.getTrcf().get(); // 1 round: stride * (samples - 1) + 1 = 60 * 2 + 1 = 121 // plus 1 existing sample - assertEquals(122, ercf.getForest().getTotalUpdates()); + assertEquals(121, ercf.getForest().getTotalUpdates()); assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); checkSemaphoreRelease(); @@ -331,8 +327,7 @@ public void testColdStart() throws InterruptedException, IOException { expectedColdStartData.addAll(convertToFeatures(interval1, 60)); double[][] interval2 = interpolator.interpolate(new double[][] { new double[] { sample2[0], sample3[0] } }, 61); expectedColdStartData.addAll(convertToFeatures(interval2, 61)); - expectedColdStartData.add(savedSample); - assertEquals(122, expectedColdStartData.size()); + assertEquals(121, expectedColdStartData.size()); diffTesting(modelState, expectedColdStartData); } @@ -456,8 +451,7 @@ public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOExc // 1 round: stride * (samples - 1) + 1 = 60 * 4 + 1 = 241 // if 241 < shingle size + numMinSamples, then another round is performed - // plus 1 existing sample - assertEquals(242, modelState.getModel().getTrcf().get().getForest().getTotalUpdates()); + assertEquals(241, modelState.getModel().getTrcf().get().getForest().getTotalUpdates()); checkSemaphoreRelease(); List expectedColdStartData = new ArrayList<>(); @@ -471,9 +465,8 @@ public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOExc expectedColdStartData.addAll(convertToFeatures(interval2, 60)); double[][] interval3 = interpolator.interpolate(new double[][] { new double[] { sample3[0], sample5[0] } }, 121); expectedColdStartData.addAll(convertToFeatures(interval3, 121)); - expectedColdStartData.add(savedSample); assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); - assertEquals(242, expectedColdStartData.size()); + assertEquals(241, expectedColdStartData.size()); diffTesting(modelState, expectedColdStartData); } @@ -514,8 +507,7 @@ public void testTwoSegments() throws InterruptedException, IOException { assertTrue(model.getTrcf().isPresent()); ThresholdedRandomCutForest ercf = model.getTrcf().get(); // 1 rounds: stride * (samples - 1) + 1 = 60 * 5 + 1 = 301 - // plus 1 existing sample - assertEquals(302, ercf.getForest().getTotalUpdates()); + assertEquals(301, ercf.getForest().getTotalUpdates()); checkSemaphoreRelease(); List expectedColdStartData = new ArrayList<>(); @@ -531,8 +523,7 @@ public void testTwoSegments() throws InterruptedException, IOException { expectedColdStartData.addAll(convertToFeatures(interval3, 120)); double[][] interval4 = interpolator.interpolate(new double[][] { new double[] { sample5[0], sample6[0] } }, 61); expectedColdStartData.addAll(convertToFeatures(interval4, 61)); - expectedColdStartData.add(savedSample); - assertEquals(302, expectedColdStartData.size()); + assertEquals(301, expectedColdStartData.size()); assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); diffTesting(modelState, expectedColdStartData); } @@ -588,11 +579,8 @@ public void testNotEnoughSamples() throws InterruptedException, IOException { doAnswer(invocation -> { GetRequest request = invocation.getArgument(0); ActionListener listener = invocation.getArgument(2); - if (request.index().equals(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX)) { - listener.onResponse(TestHelpers.createGetResponse(job, detectorId, AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX)); - } else { - listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); - } + + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); return null; }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); @@ -618,16 +606,17 @@ public void testNotEnoughSamples() throws InterruptedException, IOException { // 1st round we add 57 and 1. // 2nd round we add 57 and 1. Queue currentSamples = model.getSamples(); - assertEquals("real sample size is " + currentSamples.size(), 5, currentSamples.size()); + assertEquals("real sample size is " + currentSamples.size(), 4, currentSamples.size()); int j = 0; - while (currentSamples.isEmpty()) { + while (!currentSamples.isEmpty()) { double[] element = currentSamples.poll(); assertEquals(1, element.length); - if (j == 0 || j == 1) { + if (j == 0 || j == 2) { assertEquals(57, element[0], 1e-10); } else { assertEquals(1, element[0], 1e-10); } + j++; } } @@ -638,20 +627,13 @@ public void testEmptyDataRange() throws InterruptedException { modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); // the min-max range 894056973000L~894057860000L is too small and thus no data range can be found - job = TestHelpers.randomAnomalyDetectorJob(true, Instant.ofEpochMilli(894057860000L), null); + when(clock.millis()).thenReturn(894057860000L); doAnswer(invocation -> { GetRequest request = invocation.getArgument(0); ActionListener listener = invocation.getArgument(2); - if (request.index().equals(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX)) { - listener - .onResponse( - TestHelpers.createGetResponse(job, detector.getDetectorId(), AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX) - ); - } else { - listener - .onResponse(TestHelpers.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX)); - } + + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX)); return null; }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); @@ -745,23 +727,15 @@ private void accuracyTemplate(int detectorIntervalMins) throws Exception { .getMultiDimData(dataSize + detector.getShingleSize() - 1, 50, 100, 5, seed, baseDimension, false, trainTestSplit, delta); long[] timestamps = dataWithKeys.timestampsMs; double[][] data = dataWithKeys.data; - job = TestHelpers.randomAnomalyDetectorJob(true, Instant.ofEpochMilli(timestamps[trainTestSplit - 1]), null); + when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]); // training data ranges from timestamps[0] ~ timestamps[trainTestSplit-1] doAnswer(invocation -> { GetRequest request = invocation.getArgument(0); ActionListener listener = invocation.getArgument(2); - if (request.index().equals(AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX)) { - listener - .onResponse( - TestHelpers.createGetResponse(job, detector.getDetectorId(), AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX) - ); - } else { - listener - .onResponse( - TestHelpers.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX) - ); - } + + listener + .onResponse(TestHelpers.createGetResponse(detector, detector.getDetectorId(), AnomalyDetector.ANOMALY_DETECTORS_INDEX)); return null; }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java index cfc7a06e4..ae51d5d6d 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java @@ -227,6 +227,13 @@ private void regularTestSetUp(RegularSetUpConfig config) { state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(config.fullModel).build()); when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state); + if (config.fullModel) { + when(modelManager.getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt())) + .thenReturn(new ThresholdingResult(0, 1, 1)); + } else { + when(modelManager.getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt())) + .thenReturn(new ThresholdingResult(0, 0, 0)); + } List requests = new ArrayList<>(); requests.add(request);