Skip to content

Commit

Permalink
Disable interpolation in HCAD cold start (#575)
Browse files Browse the repository at this point in the history
* Disable interpolation in HCAD cold start

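Previously, we used interpolation in HCAD cold start for efficiency, but this hurt model accuracy. This PR disables interpolation in the cold start step by default; it can be turned back on through the dynamic setting `plugins.anomaly_detection.hcad_cold_start_interpolation.enabled`.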

Testing done:
1. Added unit tests to verify that precision and recall improve when interpolation is disabled.

Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo authored Jun 22, 2022
1 parent 7f3820a commit 03e04d7
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 20 deletions.
37 changes: 24 additions & 13 deletions src/main/java/org/opensearch/ad/ml/EntityColdStarter.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.opensearch.ad.ratelimit.CheckpointWriteWorker;
import org.opensearch.ad.ratelimit.RequestPriority;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.ad.settings.EnabledSetting;
import org.opensearch.ad.util.ExceptionUtil;
import org.opensearch.common.settings.Settings;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -575,7 +576,8 @@ private int calculateColdStartDataSize(List<double[][]> coldStartData) {

/**
* Select strideLength and numberOfSamples, where stride is the number of intervals
* between two samples and numberOfSamples is the number of training samples to fetch.
* If interpolation is disabled, strideLength is 1 and numberOfSamples is shingleSize + numMinSamples.
*
* Algorithm:
*
Expand All @@ -584,24 +586,33 @@ private int calculateColdStartDataSize(List<double[][]> coldStartData) {
* 1. Suppose delta ≤ 30 and divides 60. Then set numberOfSamples = ceil ( (shingleSize + 32)/ 24 )*24
* and strideLength = 60/delta. Note that if there is enough data — we may have lot more than shingleSize+32
* points — which is only good. This step tries to match data with hourly pattern.
* 2. Otherwise, set numberOfSamples = (shingleSize + 32) and strideLength = 1.
* This should be an uncommon case, as we assume most users think in terms of multiples of 5 minutes
* (say 10 or 30 minutes). But if someone wants a 23-minute interval -- and the system permits it --
* we give it to them. In this case we do not interpolate, because interpolation is meant to follow the
* hourly pattern (that is why 60 is the dividend in case 1) and a 23-minute interval does not fit that pattern.
* Note the smallest delta that does not divide 60 is 7, which is quite large to wait for one data point.
* @return the chosen strideLength and numberOfSamples
*/
private Pair<Integer, Integer> selectRangeParam(AnomalyDetector detector) {
long delta = detector.getDetectorIntervalInMinutes();
int shingleSize = detector.getShingleSize();
int strideLength = defaulStrideLength;
int numberOfSamples = defaultNumberOfSamples;
if (delta <= 30 && 60 % delta == 0) {
strideLength = (int) (60 / delta);
numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24;
if (EnabledSetting.isInterpolationInColdStartEnabled()) {
long delta = detector.getDetectorIntervalInMinutes();

int strideLength = defaulStrideLength;
int numberOfSamples = defaultNumberOfSamples;
if (delta <= 30 && 60 % delta == 0) {
strideLength = (int) (60 / delta);
numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24;
} else {
strideLength = 1;
numberOfSamples = shingleSize + numMinSamples;
}
return Pair.of(strideLength, numberOfSamples);
} else {
strideLength = 1;
numberOfSamples = shingleSize + numMinSamples;
return Pair.of(1, shingleSize + numMinSamples);
}
return Pair.of(strideLength, numberOfSamples);

}

/**
Expand Down
9 changes: 9 additions & 0 deletions src/main/java/org/opensearch/ad/settings/AbstractSetting.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ public <T> T getSettingValue(String key) {
return (T) latestSettings.getOrDefault(key, getSetting(key).getDefault(Settings.EMPTY));
}

/**
 * Stores {@code newVal} under {@code key} in the latest-settings map,
 * replacing whatever value was recorded for that key before.
 *
 * @param key    name of the setting to overwrite
 * @param newVal value to record for the setting
 */
public void setSettingValue(String key, Object newVal) {
    this.latestSettings.put(key, newVal);
}

private Setting<?> getSetting(String key) {
if (settings.containsKey(key)) {
return settings.get(key);
Expand Down
21 changes: 20 additions & 1 deletion src/main/java/org/opensearch/ad/settings/EnabledSetting.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ public class EnabledSetting extends AbstractSetting {

public static final String LEGACY_OPENDISTRO_AD_BREAKER_ENABLED = "opendistro.anomaly_detection.breaker.enabled";

private static final Map<String, Setting<?>> settings = unmodifiableMap(new HashMap<String, Setting<?>>() {
/**
 * Key of the boolean setting that controls whether HCAD cold start uses interpolation.
 */
public static final String INTERPOLATION_IN_HCAD_COLD_START_ENABLED =
    "plugins.anomaly_detection.hcad_cold_start_interpolation.enabled";

public static final Map<String, Setting<?>> settings = unmodifiableMap(new HashMap<String, Setting<?>>() {
{
Setting LegacyADPluginEnabledSetting = Setting
.boolSetting(LEGACY_OPENDISTRO_AD_PLUGIN_ENABLED, true, NodeScope, Dynamic, Deprecated);
Expand All @@ -64,6 +67,14 @@ public class EnabledSetting extends AbstractSetting {
* AD breaker enable/disable setting
*/
put(AD_BREAKER_ENABLED, Setting.boolSetting(AD_BREAKER_ENABLED, LegacyADBreakerEnabledSetting, NodeScope, Dynamic));

/**
* Whether interpolation in HCAD cold start is enabled or not
*/
put(
INTERPOLATION_IN_HCAD_COLD_START_ENABLED,
Setting.boolSetting(INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false, NodeScope, Dynamic)
);
}
});

Expand Down Expand Up @@ -93,4 +104,12 @@ public static boolean isADPluginEnabled() {
public static boolean isADBreakerEnabled() {
    // Look up the current value of the AD breaker switch on the shared instance.
    final boolean breakerEnabled = EnabledSetting.getInstance().getSettingValue(AD_BREAKER_ENABLED);
    return breakerEnabled;
}

/**
 * Tells whether HCAD cold start trains models from samples plus interpolation.
 *
 * @return whether interpolation in HCAD cold start is enabled or not.
 */
public static boolean isInterpolationInColdStartEnabled() {
    final boolean interpolationEnabled = EnabledSetting
        .getInstance()
        .getSettingValue(INTERPOLATION_IN_HCAD_COLD_START_ENABLED);
    return interpolationEnabled;
}
}
83 changes: 77 additions & 6 deletions src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.opensearch.Version;
import org.opensearch.action.ActionListener;
import org.opensearch.action.get.GetRequest;
Expand All @@ -63,6 +66,7 @@
import org.opensearch.ad.model.IntervalTimeConfiguration;
import org.opensearch.ad.ratelimit.CheckpointWriteWorker;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.ad.settings.EnabledSetting;
import org.opensearch.ad.util.ClientUtil;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
Expand Down Expand Up @@ -113,6 +117,23 @@ public class EntityColdStarterTests extends AbstractADTest {
ModelManager modelManager;
ClientUtil clientUtil;

@BeforeClass
public static void initOnce() {
    // Build cluster settings that know about every setting EnabledSetting exposes.
    Set<Setting<?>> registeredSettings = EnabledSetting.settings.values().stream().collect(Collectors.toSet());
    ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, registeredSettings);

    // A mocked cluster service is enough: only getClusterSettings() is stubbed here.
    ClusterService mockedClusterService = mock(ClusterService.class);
    when(mockedClusterService.getClusterSettings()).thenReturn(clusterSettings);

    EnabledSetting.getInstance().init(mockedClusterService);
}

@AfterClass
public static void clearOnce() {
    // Put the interpolation flag back to its default (disabled) once all tests in this class have run.
    EnabledSetting
        .getInstance()
        .setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.FALSE);
}

@SuppressWarnings("unchecked")
@Override
public void setUp() throws Exception {
Expand Down Expand Up @@ -217,6 +238,7 @@ public void setUp() throws Exception {
rcfSeed,
AnomalyDetectorSettings.MAX_COLD_START_ROUNDS
);
EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.TRUE);

detectorId = "123";
modelId = "123_entity_abc";
Expand Down Expand Up @@ -250,6 +272,12 @@ public void setUp() throws Exception {
);
}

@Override
public void tearDown() throws Exception {
    // Undo the per-test override of the interpolation flag before the parent class cleans up.
    EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false);
    super.tearDown();
}

private void checkSemaphoreRelease() throws InterruptedException {
assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS));
assertTrue(released.get());
Expand Down Expand Up @@ -700,7 +728,7 @@ public void testTrainModelFromExistingSamplesNotEnoughSamples() {
}

@SuppressWarnings("unchecked")
private void accuracyTemplate(int detectorIntervalMins) throws Exception {
private void accuracyTemplate(int detectorIntervalMins, float precisionThreshold, float recallThreshold) throws Exception {
int baseDimension = 2;
int dataSize = 20 * AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE;
int trainTestSplit = 300;
Expand Down Expand Up @@ -818,13 +846,13 @@ public int compare(Entry<Long, Long> p1, Entry<Long, Long> p2) {
}

// there are randomness involved; keep trying for a limited times
if (prec >= 0.5 && recall >= 0.5) {
if (prec >= precisionThreshold && recall >= recallThreshold) {
break;
}
}

assertTrue("precision is " + prec, prec >= 0.5);
assertTrue("recall is " + recall, recall >= 0.5);
assertTrue("precision is " + prec, prec >= precisionThreshold);
assertTrue("recall is " + recall, recall >= recallThreshold);
}

public int searchInsert(long[] timestamps, long target) {
Expand All @@ -842,11 +870,54 @@ public int searchInsert(long[] timestamps, long target) {
}

public void testAccuracyTenMinuteInterval() throws Exception {
    // 10 minutes divides 60, so with interpolation enabled this exercises the strided sampling path.
    accuracyTemplate(10, 0.5f, 0.5f);
}

public void testAccuracyThirteenMinuteInterval() throws Exception {
    // 13 minutes does not divide 60, so this exercises the stride-of-one sampling path.
    accuracyTemplate(13, 0.5f, 0.5f);
}

public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception {
    // for one minute interval, we need to disable interpolation to achieve good results
    EnabledSetting.getInstance().setSettingValue(EnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.FALSE);

    // Rebuild the cold starter used by this test.
    entityColdStarter = new EntityColdStarter(
        clock,
        threadPool,
        stateManager,
        AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE,
        AnomalyDetectorSettings.NUM_TREES,
        AnomalyDetectorSettings.TIME_DECAY,
        numMinSamples,
        AnomalyDetectorSettings.MAX_SAMPLE_STRIDE,
        AnomalyDetectorSettings.MAX_TRAIN_SAMPLE,
        interpolator,
        searchFeatureDao,
        AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE,
        featureManager,
        settings,
        AnomalyDetectorSettings.HOURLY_MAINTENANCE,
        checkpointWriteQueue,
        rcfSeed,
        AnomalyDetectorSettings.MAX_COLD_START_ROUNDS
    );

    // Rebuild the model manager around the fresh cold starter; its other collaborators are mocks.
    modelManager = new ModelManager(
        mock(CheckpointDao.class),
        mock(Clock.class),
        AnomalyDetectorSettings.NUM_TREES,
        AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE,
        AnomalyDetectorSettings.TIME_DECAY,
        AnomalyDetectorSettings.NUM_MIN_SAMPLES,
        AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE,
        AnomalyDetectorSettings.MIN_PREVIEW_SIZE,
        AnomalyDetectorSettings.HOURLY_MAINTENANCE,
        AnomalyDetectorSettings.HOURLY_MAINTENANCE,
        entityColdStarter,
        mock(FeatureManager.class),
        mock(MemoryTracker.class)
    );

    // Precision and recall must both reach 0.6 for the one-minute interval.
    final float minPrecision = 0.6f;
    final float minRecall = 0.6f;
    accuracyTemplate(1, minPrecision, minRecall);
}

private ModelState<EntityModel> createStateForCacheRelease() {
Expand Down

0 comments on commit 03e04d7

Please sign in to comment.