From 03e04d70556a2785782bbfd1603f3c19fbd8c895 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Wed, 22 Jun 2022 11:55:00 -0700 Subject: [PATCH] Disable interpolation in HCAD cold start (#575) * Disable interpolation in HCAD cold start Previously, we used interpolation in HCAD cold start for the purpose of efficiency. This caused problems for model accuracy. This PR removes interpolation in the cold start step. Testing done: 1. added unit tests to verify precision boosted. Signed-off-by: Kaituo Li --- .../opensearch/ad/ml/EntityColdStarter.java | 37 ++++++--- .../ad/settings/AbstractSetting.java | 9 ++ .../ad/settings/EnabledSetting.java | 21 ++++- .../ad/ml/EntityColdStarterTests.java | 83 +++++++++++++++++-- 4 files changed, 130 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java index 9c55c17bf..2c9095eb8 100644 --- a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java +++ b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java @@ -53,6 +53,7 @@ import org.opensearch.ad.ratelimit.CheckpointWriteWorker; import org.opensearch.ad.ratelimit.RequestPriority; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.settings.EnabledSetting; import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; @@ -575,7 +576,8 @@ private int calculateColdStartDataSize(List coldStartData) { /** * Select strideLength and numberOfSamples, where stride is the number of intervals - * between two samples and trainSamples is training samples to fetch. + * between two samples and trainSamples is training samples to fetch. If we disable + * interpolation, strideLength is 1 and numberOfSamples is shingleSize + numMinSamples; * * Algorithm: * @@ -584,24 +586,33 @@ private int calculateColdStartDataSize(List coldStartData) { * 1. Suppose delta ≤ 30 and divides 60. Then set numberOfSamples = ceil ( (shingleSize + 32)/ 24 )*24 * and strideLength = 60/delta. Note that if there is enough data — we may have lot more than shingleSize+32 * points — which is only good. This step tries to match data with hourly pattern. - * 2. Set numberOfSamples = (shingleSize + 32) and strideLength = 1. This should be an uncommon case, - * but if someone wants 23 minutes interval — and the system permits lets give it to them. Note the - * smallest delta that does not divide 60 is 7 which is quite large to wait for one data point. + * 2. otherwise, set numberOfSamples = (shingleSize + 32) and strideLength = 1. + * This should be an uncommon case as we are assuming most users think in terms of multiple of 5 minutes + *(say 10 or 30 minutes). But if someone wants a 23 minutes interval —- and the system permits -- + * we give it to them. In this case, we disable interpolation as we want to interpolate based on the hourly pattern. + * That's why we use 60 as a dividend in case 1. The 23 minute case does not fit that pattern. + * Note the smallest delta that does not divide 60 is 7 which is quite large to wait for one data point. * @return the chosen strideLength and numberOfSamples */ private Pair selectRangeParam(AnomalyDetector detector) { - long delta = detector.getDetectorIntervalInMinutes(); int shingleSize = detector.getShingleSize(); - int strideLength = defaulStrideLength; - int numberOfSamples = defaultNumberOfSamples; - if (delta <= 30 && 60 % delta == 0) { - strideLength = (int) (60 / delta); - numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24; + if (EnabledSetting.isInterpolationInColdStartEnabled()) { + long delta = detector.getDetectorIntervalInMinutes(); + + int strideLength = defaulStrideLength; + int numberOfSamples = defaultNumberOfSamples; + if (delta <= 30 && 60 % delta == 0) { + strideLength = (int) (60 / delta); + numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24; + } else { + strideLength = 1; + numberOfSamples = shingleSize + numMinSamples; + } + return Pair.of(strideLength, numberOfSamples); } else { - strideLength = 1; - numberOfSamples = shingleSize + numMinSamples; + return Pair.of(1, shingleSize + numMinSamples); } - return Pair.of(strideLength, numberOfSamples); + } /** diff --git a/src/main/java/org/opensearch/ad/settings/AbstractSetting.java b/src/main/java/org/opensearch/ad/settings/AbstractSetting.java index e80fcbde9..f3cf7a9b5 100644 --- a/src/main/java/org/opensearch/ad/settings/AbstractSetting.java +++ b/src/main/java/org/opensearch/ad/settings/AbstractSetting.java @@ -69,6 +69,15 @@ public T getSettingValue(String key) { return (T) latestSettings.getOrDefault(key, getSetting(key).getDefault(Settings.EMPTY)); } + /** + * Override existing value. + * @param key Key + * @param newVal New value + */ + public void setSettingValue(String key, Object newVal) { + latestSettings.put(key, newVal); + } + private Setting getSetting(String key) { if (settings.containsKey(key)) { return settings.get(key); diff --git a/src/main/java/org/opensearch/ad/settings/EnabledSetting.java b/src/main/java/org/opensearch/ad/settings/EnabledSetting.java index 49e0f0d09..6a22b769b 100644 --- a/src/main/java/org/opensearch/ad/settings/EnabledSetting.java +++ b/src/main/java/org/opensearch/ad/settings/EnabledSetting.java @@ -39,7 +39,10 @@ public class EnabledSetting extends AbstractSetting { public static final String LEGACY_OPENDISTRO_AD_BREAKER_ENABLED = "opendistro.anomaly_detection.breaker.enabled"; - private static final Map> settings = unmodifiableMap(new HashMap>() { + public static final String INTERPOLATION_IN_HCAD_COLD_START_ENABLED = + "plugins.anomaly_detection.hcad_cold_start_interpolation.enabled";; + + public static final Map> settings = unmodifiableMap(new HashMap>() { { Setting LegacyADPluginEnabledSetting = Setting .boolSetting(LEGACY_OPENDISTRO_AD_PLUGIN_ENABLED, true, NodeScope, Dynamic, Deprecated); @@ -64,6 +67,14 @@ public class EnabledSetting extends AbstractSetting { * AD breaker enable/disable setting */ put(AD_BREAKER_ENABLED, Setting.boolSetting(AD_BREAKER_ENABLED, LegacyADBreakerEnabledSetting, NodeScope, Dynamic)); + + /** + * Whether interpolation in HCAD cold start is enabled or not + */ + put( + INTERPOLATION_IN_HCAD_COLD_START_ENABLED, + Setting.boolSetting(INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false, NodeScope, Dynamic) + ); } }); @@ -93,4 +104,12 @@ public static boolean isADPluginEnabled() { public static boolean isADBreakerEnabled() { return EnabledSetting.getInstance().getSettingValue(EnabledSetting.AD_BREAKER_ENABLED); } + + /** + * If enabled, we use samples plus interpolation to train models. + * @return wWhether interpolation in HCAD cold start is enabled or not. + */ + public static boolean isInterpolationInColdStartEnabled() { + return EnabledSetting.getInstance().getSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED); + } } diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index 64669b222..215c15304 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -40,7 +40,10 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; @@ -63,6 +66,7 @@ import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.ratelimit.CheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.settings.EnabledSetting; import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; @@ -113,6 +117,23 @@ public class EntityColdStarterTests extends AbstractADTest { ModelManager modelManager; ClientUtil clientUtil; + @BeforeClass + public static void initOnce() { + ClusterService clusterService = mock(ClusterService.class); + + Set> settingSet = EnabledSetting.settings.values().stream().collect(Collectors.toSet()); + + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, settingSet)); + + EnabledSetting.getInstance().init(clusterService); + } + + @AfterClass + public static void clearOnce() { + // restore to default value + EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false); + } + @SuppressWarnings("unchecked") @Override public void setUp() throws Exception { @@ -217,6 +238,7 @@ public void setUp() throws Exception { rcfSeed, AnomalyDetectorSettings.MAX_COLD_START_ROUNDS ); + EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.TRUE); detectorId = "123"; modelId = "123_entity_abc"; @@ -250,6 +272,12 @@ public void setUp() throws Exception { ); } + @Override + public void tearDown() throws Exception { + EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.FALSE); + super.tearDown(); + } + private void checkSemaphoreRelease() throws InterruptedException { assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); assertTrue(released.get()); @@ -700,7 +728,7 @@ public void testTrainModelFromExistingSamplesNotEnoughSamples() { } @SuppressWarnings("unchecked") - private void accuracyTemplate(int detectorIntervalMins) throws Exception { + private void accuracyTemplate(int detectorIntervalMins, float precisionThreshold, float recallThreshold) throws Exception { int baseDimension = 2; int dataSize = 20 * AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE; int trainTestSplit = 300; @@ -818,13 +846,13 @@ public int compare(Entry p1, Entry p2) { } // there are randomness involved; keep trying for a limited times - if (prec >= 0.5 && recall >= 0.5) { + if (prec >= precisionThreshold && recall >= recallThreshold) { break; } } - assertTrue("precision is " + prec, prec >= 0.5); - assertTrue("recall is " + recall, recall >= 0.5); + assertTrue("precision is " + prec, prec >= precisionThreshold); + assertTrue("recall is " + recall, recall >= recallThreshold); } public int searchInsert(long[] timestamps, long target) { @@ -842,11 +870,54 @@ public int searchInsert(long[] timestamps, long target) { } public void testAccuracyTenMinuteInterval() throws Exception { - accuracyTemplate(10); + accuracyTemplate(10, 0.5f, 0.5f); } public void testAccuracyThirteenMinuteInterval() throws Exception { - accuracyTemplate(13); + accuracyTemplate(13, 0.5f, 0.5f); + } + + public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception { + EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false); + // for one minute interval, we need to disable interpolation to achieve good results + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + interpolator, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class) + ); + + accuracyTemplate(1, 0.6f, 0.6f); } private ModelState createStateForCacheRelease() {