From ea3c6f000e2fd620d36deee1d9c8bf86d06091ca Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Fri, 13 Sep 2024 16:18:53 -0700 Subject: [PATCH] Bump BWC Version to 2.18 and Fix Bugs (#1311) This PR includes the following updates and bug fixes: * Bump BWC Version to 2.18: Updated BWC version to 2.18 since the 2.17 branch has been cut. * Fix Confidence Value Exceeding 1 in RCF: Addressed a bug in RCF where the confidence value could exceed 1. Implemented a check to cap the confidence value at 1, preventing invalid confidence scores. * Correct Parameter Assignment in GetAnomalyDetectorTransportAction: Fixed an issue where parameter assignments within a method did not affect external variables due to Java's pass-by-value nature. * Fixed a bug in ResultProcessor where we were supposed to check whether the number of sent messages equals the number of received messages before starting imputation. However, the sent message count was mistakenly based on the number of pages rather than the actual number of messages. * Fixed a bug where we mistakenly used the total reserved memory bytes as the memory size per entity in PriorityCache. Testing done: * added test cases for the buggy scenarios * manual e2e testing Signed-off-by: Kaituo Li --- build.gradle | 4 +- .../opensearch/ad/model/AnomalyResult.java | 2 +- .../GetAnomalyDetectorTransportAction.java | 5 +- .../forecast/model/ForecastResult.java | 4 +- .../timeseries/caching/PriorityCache.java | 17 +-- .../BaseGetConfigTransportAction.java | 8 +- .../timeseries/transport/ResultProcessor.java | 13 +- .../ad/HistoricalAnalysisRestTestCase.java | 23 ++++ .../ad/caching/PriorityCacheTests.java | 46 +++++++ .../ad/e2e/AbstractRuleTestCase.java | 35 ++++-- .../org/opensearch/ad/e2e/PreviewRuleIT.java | 2 +- .../org/opensearch/ad/e2e/RealTimeRuleIT.java | 13 +- .../ad/ml/EntityColdStarterTests.java | 21 ++++ .../ad/model/AnomalyResultTests.java | 67 ++++++++++ ...etAnomalyDetectorTransportActionTests.java | 95 ++++++++++++++ .../ad/rest/HistoricalAnalysisRestApiIT.java | 11 +- .../ADHCImputeNodesResponseTests.java | 118 ++++++++++++++++++ .../GetForecasterTransportActionTests.java | 95 ++++++++++++++ 18 files changed, 540 insertions(+), 39 deletions(-) create mode 100644 src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java create mode 100644 src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java create mode 100644 src/test/java/org/opensearch/forecast/transport/GetForecasterTransportActionTests.java diff --git a/build.gradle b/build.gradle index b4e2e9f4b..cbfdfbe53 100644 --- a/build.gradle +++ b/build.gradle @@ -696,9 +696,8 @@ List jacocoExclusions = [ // TODO: add test coverage (kaituo) 'org.opensearch.forecast.*', + 'org.opensearch.ad.transport.ADHCImputeNodeResponse', 'org.opensearch.ad.transport.GetAnomalyDetectorTransportAction', - 'org.opensearch.ad.ml.ADColdStart', - 'org.opensearch.ad.transport.ADHCImputeNodesResponse', 'org.opensearch.timeseries.transport.BooleanNodeResponse', 'org.opensearch.timeseries.ml.TimeSeriesSingleStreamCheckpointDao', 'org.opensearch.timeseries.transport.JobRequest', @@ -713,7 +712,6 @@ List jacocoExclusions = [ 'org.opensearch.timeseries.transport.ResultBulkTransportAction', 'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler', 'org.opensearch.timeseries.transport.handler.ResultIndexingHandler', - 'org.opensearch.ad.transport.ADHCImputeNodeResponse', 'org.opensearch.timeseries.ml.Sample', 'org.opensearch.timeseries.ratelimit.FeatureRequest', 'org.opensearch.ad.transport.ADHCImputeNodeRequest', diff --git a/src/main/java/org/opensearch/ad/model/AnomalyResult.java b/src/main/java/org/opensearch/ad/model/AnomalyResult.java index 868317bab..f52fe7439 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyResult.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyResult.java @@ -446,7 +446,7 @@ public static AnomalyResult fromRawTRCFResult( taskId, rcfScore, Math.max(0, grade), - confidence, + Math.min(1, confidence), featureData, dataStartTime, dataEndTime, diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java index 0bae4a5ce..46e246191 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java @@ -86,10 +86,11 @@ public GetAnomalyDetectorTransportAction( } @Override - protected void fillInHistoricalTaskforBwc(Map tasks, Optional historicalAdTask) { + protected Optional fillInHistoricalTaskforBwc(Map tasks) { if (tasks.containsKey(ADTaskType.HISTORICAL.name())) { - historicalAdTask = Optional.ofNullable(tasks.get(ADTaskType.HISTORICAL.name())); + return Optional.ofNullable(tasks.get(ADTaskType.HISTORICAL.name())); } + return Optional.empty(); } @Override diff --git a/src/main/java/org/opensearch/forecast/model/ForecastResult.java b/src/main/java/org/opensearch/forecast/model/ForecastResult.java index 69dd4d6ea..49f51e2e9 100644 --- a/src/main/java/org/opensearch/forecast/model/ForecastResult.java +++ b/src/main/java/org/opensearch/forecast/model/ForecastResult.java @@ -188,7 +188,7 @@ public static List fromRawRCFCasterResult( new ForecastResult( forecasterId, taskId, - dataQuality, + Math.min(1, dataQuality), featureData, dataStartTime, dataEndTime, @@ -218,7 +218,7 @@ public static List fromRawRCFCasterResult( new ForecastResult( forecasterId, taskId, - dataQuality, + Math.min(1, dataQuality), null, dataStartTime, dataEndTime, diff --git a/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java index 043c197cf..1d984be46 100644 --- a/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java +++ b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java @@ -463,15 +463,16 @@ public Pair, List> selectUpdateCandidate(Collection return Pair.of(hotEntities, coldEntities); } - private CacheBufferType computeBufferIfAbsent(Config config, String configId) { + public CacheBufferType computeBufferIfAbsent(Config config, String configId) { CacheBufferType buffer = activeEnities.get(configId); if (buffer == null) { - long requiredBytes = getRequiredMemory(config, config.isHighCardinality() ? hcDedicatedCacheSize : 1); + long bytesPerEntityModel = getRequiredMemoryPerEntity(config, memoryTracker, numberOfTrees); + long requiredBytes = bytesPerEntityModel * (config.isHighCardinality() ? hcDedicatedCacheSize : 1); if (memoryTracker.canAllocateReserved(requiredBytes)) { memoryTracker.consumeMemory(requiredBytes, true, origin); buffer = createEmptyCacheBuffer( config, - requiredBytes, + bytesPerEntityModel, priorityTrackerMap .getOrDefault( configId, @@ -496,16 +497,6 @@ private CacheBufferType computeBufferIfAbsent(Config config, String configId) { return buffer; } - /** - * - * @param config Detector config accessor - * @param numberOfEntity number of entities - * @return Memory in bytes required for hosting numberOfEntity entities - */ - private long getRequiredMemory(Config config, int numberOfEntity) { - return numberOfEntity * getRequiredMemoryPerEntity(config, memoryTracker, numberOfTrees); - } - /** * Whether the candidate entity can replace any entity in the shared cache. * We can have race conditions when multiple threads try to evaluate this diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java index 3b6ad29d9..b803a4851 100644 --- a/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java @@ -180,7 +180,7 @@ public void doExecute(Task task, ActionRequest request, ActionListener tasks, Optional historicalAdTask) {} + protected Optional fillInHistoricalTaskforBwc(Map tasks) { + return Optional.empty(); + } protected void getExecuteProfile( GetConfigRequest request, diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java index b244ee4ac..f412ce84e 100644 --- a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java +++ b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java @@ -229,8 +229,6 @@ public void onResponse(CompositeRetriever.Page entityFeatures) { pageIterator.next(this); } if (entityFeatures != null && false == entityFeatures.isEmpty()) { - sentOutPages.incrementAndGet(); - LOG .info( "Sending an HC request to process data from timestamp {} to {} for config {}", @@ -285,6 +283,7 @@ public void onResponse(CompositeRetriever.Page entityFeatures) { final AtomicReference failure = new AtomicReference<>(); node2Entities.stream().forEach(nodeEntity -> { + sentOutPages.incrementAndGet(); DiscoveryNode node = nodeEntity.getKey(); transportService .sendRequest( @@ -370,7 +369,15 @@ public void run() { cancellable.get().cancel(); } } else if (Instant.now().toEpochMilli() >= timeoutMillis) { - LOG.warn("Scheduled impute HC task is cancelled due to timeout"); + LOG + .warn( + "Scheduled impute HC task is cancelled due to timeout, current epoch {}, timeout epoch {}, dataEndTime {}, sent out {}, receive {}", + Instant.now().toEpochMilli(), + timeoutMillis, + dataEndTime, + sentOutPages.get(), + receivedPages.get() + ); if (cancellable != null) { cancellable.get().cancel(); } diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java index 770ba254a..c661371c2 100644 --- a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java @@ -266,6 +266,29 @@ protected List waitUntilTaskReachState(String detectorId, Set ta return results; } + protected List waitUntilTaskReachNumberOfEntities(String detectorId, int categoricalValuesCount) throws InterruptedException { + List results = new ArrayList<>(); + int i = 0; + ADTaskProfile adTaskProfile = null; + // Increase retryTimes if some task can't reach done state + while ((adTaskProfile == null + || adTaskProfile.getTotalEntitiesCount() == null + || adTaskProfile.getTotalEntitiesCount().intValue() != categoricalValuesCount) && i < MAX_RETRY_TIMES) { + try { + adTaskProfile = getADTaskProfile(detectorId); + } catch (Exception e) { + logger.error("failed to get ADTaskProfile", e); + } finally { + Thread.sleep(1000); + } + i++; + } + assertNotNull(adTaskProfile); + results.add(adTaskProfile); + results.add(i); + return results; + } + protected List waitUntilEntityCountAvailable(String detectorId) throws InterruptedException { List results = new ArrayList<>(); int i = 0; diff --git a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java index 7c09709f0..3da08575d 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java @@ -26,6 +26,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.IOException; import java.time.Duration; import java.time.Instant; import java.util.ArrayDeque; @@ -62,6 +63,7 @@ import org.opensearch.threadpool.Scheduler.ScheduledCancellable; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.TimeSeriesException; @@ -788,4 +790,48 @@ public void testGetTotalUpdates_orElseGetBranchWithNullSamples() { // Assert that the result is 0L assertEquals(0L, result); } + + public void testAllocation() throws IOException { + JvmService jvmService = mock(JvmService.class); + JvmInfo info = mock(JvmInfo.class); + + when(jvmService.info()).thenReturn(info); + + Mem mem = mock(Mem.class); + when(mem.getHeapMax()).thenReturn(new ByteSizeValue(800_000_000L)); + when(info.getMem()).thenReturn(mem); + + CircuitBreakerService circuitBreaker = mock(CircuitBreakerService.class); + when(circuitBreaker.isOpen()).thenReturn(false); + MemoryTracker tracker = new MemoryTracker(jvmService, 0.1, clusterService, circuitBreaker); + + dedicatedCacheSize = 10; + ADPriorityCache cache = new ADPriorityCache( + checkpoint, + dedicatedCacheSize, + AnomalyDetectorSettings.AD_CHECKPOINT_TTL, + AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, + tracker, + TimeSeriesSettings.NUM_TREES, + clock, + clusterService, + TimeSeriesSettings.HOURLY_MAINTENANCE, + threadPool, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + Settings.EMPTY, + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + checkpointWriteQueue, + checkpointMaintainQueue + ); + + List categoryFields = Arrays.asList("category_field_1", "category_field_2"); + AnomalyDetector anomalyDetector = TestHelpers.AnomalyDetectorBuilder + .newInstance(5) + .setShingleSize(8) + .setCategoryFields(categoryFields) + .build(); + ADCacheBuffer buffer = cache.computeBufferIfAbsent(anomalyDetector, anomalyDetector.getId()); + assertEquals(698336, buffer.getMemoryConsumptionPerModel()); + assertEquals(698336 * dedicatedCacheSize, tracker.getTotalMemoryBytes()); + } } diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java b/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java index 030f98d20..d1e7a100a 100644 --- a/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java +++ b/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java @@ -47,7 +47,7 @@ protected TrainResult ingestTrainDataAndCreateDetector( int trainTestSplit, boolean useDateNanos ) throws Exception { - return ingestTrainDataAndCreateDetector(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, useDateNanos, -1); + return ingestTrainDataAndCreateDetector(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, useDateNanos, -1, true); } protected TrainResult ingestTrainDataAndCreateDetector( @@ -56,7 +56,8 @@ protected TrainResult ingestTrainDataAndCreateDetector( int numberOfEntities, int trainTestSplit, boolean useDateNanos, - int ingestDataSize + int ingestDataSize, + boolean relative ) throws Exception { TrainResult trainResult = ingestTrainData( datasetName, @@ -67,7 +68,7 @@ protected TrainResult ingestTrainDataAndCreateDetector( ingestDataSize ); - String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult); + String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult, relative); String detectorId = createDetector(client(), detector); LOG.info("Created detector {}", detectorId); trainResult.detectorId = detectorId; @@ -75,7 +76,22 @@ protected TrainResult ingestTrainDataAndCreateDetector( return trainResult; } - protected String genDetector(String datasetName, int intervalMinutes, int trainTestSplit, TrainResult trainResult) { + protected String genDetector(String datasetName, int intervalMinutes, int trainTestSplit, TrainResult trainResult, boolean relative) { + // Determine threshold types and values based on the 'relative' parameter + String thresholdType1; + String thresholdType2; + double value; + if (relative) { + thresholdType1 = "actual_over_expected_ratio"; + thresholdType2 = "expected_over_actual_ratio"; + value = 0.3; + } else { + thresholdType1 = "actual_over_expected_margin"; + thresholdType2 = "expected_over_actual_margin"; + value = 3000.0; + } + + // Generate the detector JSON string with the appropriate threshold types and values String detector = String .format( Locale.ROOT, @@ -87,15 +103,20 @@ protected String genDetector(String datasetName, int intervalMinutes, int trainT + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + "\"history\": %d," + "\"schema_version\": 0," - + "\"rules\": [{\"action\": \"ignore_anomaly\", \"conditions\": [{\"feature_name\": \"feature 1\", \"threshold_type\": \"actual_over_expected_ratio\", \"operator\": \"lte\", \"value\": 0.3}, " - + "{\"feature_name\": \"feature 1\", \"threshold_type\": \"expected_over_actual_ratio\", \"operator\": \"lte\", \"value\": 0.3}" + + "\"rules\": [{\"action\": \"ignore_anomaly\", \"conditions\": [" + + "{ \"feature_name\": \"feature 1\", \"threshold_type\": \"%s\", \"operator\": \"lte\", \"value\": %f }, " + + "{ \"feature_name\": \"feature 1\", \"threshold_type\": \"%s\", \"operator\": \"lte\", \"value\": %f }" + "]}]" + "}", datasetName, intervalMinutes, categoricalField, trainResult.windowDelay.toMinutes(), - trainTestSplit - 1 + trainTestSplit - 1, + thresholdType1, + value, + thresholdType2, + value ); return detector; } diff --git a/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java b/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java index 8b481e3c9..e5194bb63 100644 --- a/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java +++ b/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java @@ -32,7 +32,7 @@ public void testRule() throws Exception { (trainTestSplit + 1) * numberOfEntities ); - String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult); + String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult, true); Map result = preview(detector, trainResult.firstDataTime, trainResult.finalDataTime, client()); List results = (List) XContentMapValues.extractValue(result, "anomaly_result"); assertTrue(results.size() > 100); diff --git a/src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java b/src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java index 650fec64d..a4d4a855b 100644 --- a/src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java +++ b/src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java @@ -15,7 +15,7 @@ import com.google.gson.JsonObject; public class RealTimeRuleIT extends AbstractRuleTestCase { - public void testRuleWithDateNanos() throws Exception { + private void template(boolean reltive) throws Exception { // TODO: this test case will run for a much longer time and timeout with security enabled if (!isHttps()) { disableResourceNotFoundFaultTolerence(); @@ -32,7 +32,8 @@ public void testRuleWithDateNanos() throws Exception { trainTestSplit, true, // ingest just enough for finish the test - (trainTestSplit + 1) * numberOfEntities + (trainTestSplit + 1) * numberOfEntities, + reltive ); startRealTimeDetector(trainResult, numberOfEntities, intervalMinutes, false); @@ -90,4 +91,12 @@ public void testRuleWithDateNanos() throws Exception { } } } + + public void testRelativeRule() throws Exception { + template(true); + } + + public void testAbsoluateRule() throws Exception { + template(false); + } } diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index 3feccd298..d89e03128 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -1275,4 +1275,25 @@ public void testNotEnoughTrainingData() throws IOException, InterruptedException checkSemaphoreRelease(); assertTrue(modelState.getModel().isEmpty()); } + + public void testTrainModelFromInvalidSamplesNotEnoughSamples() { + Deque samples = new ArrayDeque<>(); + // we have at least numMinSamples samples before executing the null check of trainModelFromDataSegments + for (int i = 0; i < numMinSamples; i++) { + samples.add(new Sample()); + } + + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + Optional.of(entity), + samples + ); + entityColdStarter.trainModelFromExistingSamples(modelState, detector, "123"); + assertTrue(modelState.getModel().isEmpty()); + } } diff --git a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java index 28245aa31..cac5e8a4d 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java @@ -12,10 +12,15 @@ package org.opensearch.ad.model; import java.io.IOException; +import java.time.Instant; import java.util.Collection; +import java.util.Collections; +import java.util.List; import java.util.Locale; +import java.util.Optional; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.ToXContent; @@ -24,6 +29,8 @@ import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; import com.google.common.base.Objects; @@ -152,4 +159,64 @@ public void testSerializeAnomalyResultWithEntity() throws IOException { AnomalyResult parsedDetectResult = new AnomalyResult(input); assertTrue(parsedDetectResult.equals(detectResult)); } + + public void testFromRawTRCFResultWithHighConfidence() { + // Set up test parameters + String detectorId = "test-detector-id"; + long intervalMillis = 60000; // Example interval + String taskId = "test-task-id"; + Double rcfScore = 0.5; + Double grade = 0.0; // Non-anomalous + Double confidence = 1.03; // Confidence greater than 1 + List featureData = Collections.emptyList(); // Assuming empty for simplicity + Instant dataStartTime = Instant.now(); + Instant dataEndTime = dataStartTime.plusMillis(intervalMillis); + Instant executionStartTime = Instant.now(); + Instant executionEndTime = executionStartTime.plusMillis(500); + String error = null; + Optional entity = Optional.empty(); + User user = null; // Replace with actual user if needed + Integer schemaVersion = 1; + String modelId = "test-model-id"; + double[] relevantAttribution = null; + Integer relativeIndex = null; + double[] pastValues = null; + double[][] expectedValuesList = null; + double[] likelihoodOfValues = null; + Double threshold = null; + double[] currentData = null; + boolean[] featureImputed = null; + + // Invoke the method under test + AnomalyResult result = AnomalyResult + .fromRawTRCFResult( + detectorId, + intervalMillis, + taskId, + rcfScore, + grade, + confidence, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + modelId, + relevantAttribution, + relativeIndex, + pastValues, + expectedValuesList, + likelihoodOfValues, + threshold, + currentData, + featureImputed + ); + + // Assert that the confidence is capped at 1.0 + assertEquals("Confidence should be capped at 1.0", 1.0, result.getConfidence(), 0.00001); + } } diff --git a/src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java new file mode 100644 index 000000000..64295e4e2 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.ad.transport.GetAnomalyDetectorTransportAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.transport.TransportService; + +public class GetAnomalyDetectorTransportActionTests extends AbstractTimeSeriesTest { + @SuppressWarnings("unchecked") + public void testRealtimeTaskAssignedWithSingleStreamRealTimeTaskName() throws Exception { + // Arrange + String configID = "test-config-id"; + + // Create a task with singleStreamRealTimeTaskName + Map tasks = new HashMap<>(); + ADTask adTask = ADTask.builder().taskType(ADTaskType.HISTORICAL.name()).build(); + tasks.put(ADTaskType.HISTORICAL.name(), adTask); + + // Mock taskManager to return the tasks + ADTaskManager taskManager = mock(ADTaskManager.class); + doAnswer(invocation -> { + List taskList = new ArrayList<>(tasks.values()); + ((Consumer>) invocation.getArguments()[4]).accept(taskList); + return null; + }).when(taskManager).getAndExecuteOnLatestTasks(anyString(), any(), any(), any(), any(), any(), anyBoolean(), anyInt(), any()); + + // Mock listener + ActionListener listener = mock(ActionListener.class); + + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings settings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(settings); + GetAnomalyDetectorTransportAction getForecaster = spy( + new GetAnomalyDetectorTransportAction( + mock(TransportService.class), + null, + mock(ActionFilters.class), + clusterService, + null, + null, + Settings.EMPTY, + null, + taskManager, + null + ) + ); + + // Act + GetConfigRequest request = new GetConfigRequest(configID, 0L, true, true, "", "", true, null); + getForecaster.getExecute(request, listener); + + // Assert + // Verify that realtimeTask is assigned using singleStreamRealTimeTaskName + // This can be checked by verifying interactions or internal state + // For this example, we'll verify that the correct task is passed to getConfigAndJob + verify(getForecaster).getConfigAndJob(eq(configID), anyBoolean(), anyBoolean(), any(), eq(Optional.of(adTask)), eq(listener)); + } +} diff --git a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java index c42ecb2cc..f428aea16 100644 --- a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java @@ -104,14 +104,17 @@ public void testHistoricalAnalysisForMultiCategoryHC() throws Exception { } private void checkIfTaskCanFinishCorrectly(String detectorId, String taskId, Set states) throws InterruptedException { - List results = waitUntilTaskDone(detectorId); + List results = waitUntilTaskReachState(detectorId, states); TaskProfile endTaskProfile = (TaskProfile) results.get(0); Integer retryCount = (Integer) results.get(1); ADTask stoppedAdTask = endTaskProfile.getTask(); assertEquals(taskId, stoppedAdTask.getTaskId()); if (retryCount < MAX_RETRY_TIMES) { // It's possible that historical analysis still running after max retry times - assertTrue(states.contains(stoppedAdTask.getState())); + assertTrue( + "expect: " + stoppedAdTask.getState() + ", but got " + stoppedAdTask.getState(), + states.contains(stoppedAdTask.getState()) + ); } } @@ -133,6 +136,10 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul if (!TaskState.RUNNING.name().equals(adTaskProfile.getTask().getState())) { adTaskProfile = (ADTaskProfile) waitUntilTaskReachState(detectorId, ImmutableSet.of(TaskState.RUNNING.name())).get(0); } + if (adTaskProfile == null + || (int) Math.pow(categoryFieldDocCount, categoryFieldSize) != adTaskProfile.getTotalEntitiesCount().intValue()) { + adTaskProfile = (ADTaskProfile) waitUntilTaskReachNumberOfEntities(detectorId, categoryFieldDocCount).get(0); + } assertEquals((int) Math.pow(categoryFieldDocCount, categoryFieldSize), adTaskProfile.getTotalEntitiesCount().intValue()); assertTrue(adTaskProfile.getPendingEntitiesCount() > 0); assertTrue(adTaskProfile.getRunningEntitiesCount() > 0); diff --git a/src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java b/src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java new file mode 100644 index 000000000..f2657f21d --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.test.OpenSearchTestCase; + +public class ADHCImputeNodesResponseTests extends OpenSearchTestCase { + + public void testADHCImputeNodesResponseSerialization() throws IOException { + // Arrange + DiscoveryNode node = new DiscoveryNode( + "nodeId", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + Exception previousException = new Exception("Test exception message"); + + ADHCImputeNodeResponse nodeResponse = new ADHCImputeNodeResponse(node, previousException); + List nodes = Collections.singletonList(nodeResponse); + List failures = Collections.emptyList(); + ClusterName clusterName = new ClusterName("test-cluster"); + + ADHCImputeNodesResponse response = new ADHCImputeNodesResponse(clusterName, nodes, failures); + + // Act: Serialize the response + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + + // Deserialize the response + StreamInput input = output.bytes().streamInput(); + ADHCImputeNodesResponse deserializedResponse = new ADHCImputeNodesResponse(input); + + // Assert + assertEquals(clusterName, deserializedResponse.getClusterName()); + assertEquals(response.getNodes().size(), deserializedResponse.getNodes().size()); + assertEquals(response.failures().size(), deserializedResponse.failures().size()); + + // Check the node response + ADHCImputeNodeResponse deserializedNodeResponse = deserializedResponse.getNodes().get(0); + assertEquals(node, deserializedNodeResponse.getNode()); + assertNotNull(deserializedNodeResponse.getPreviousException()); + assertEquals("exception: " + previousException.getMessage(), deserializedNodeResponse.getPreviousException().getMessage()); + } + + public void testReadNodesFromAndWriteNodesTo() throws IOException { + // Arrange + DiscoveryNode node = new DiscoveryNode( + "nodeId", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + Exception previousException = new Exception("Test exception message"); + + ADHCImputeNodeResponse nodeResponse = new ADHCImputeNodeResponse(node, previousException); + List nodes = Collections.singletonList(nodeResponse); + ClusterName clusterName = new ClusterName("test-cluster"); + ADHCImputeNodesResponse response = new ADHCImputeNodesResponse(clusterName, nodes, Collections.emptyList()); + + // Act: Write nodes to output + BytesStreamOutput output = new BytesStreamOutput(); + response.writeNodesTo(output, nodes); + + // Read nodes from input + StreamInput input = output.bytes().streamInput(); + List readNodes = response.readNodesFrom(input); + + // Assert + assertEquals(nodes.size(), readNodes.size()); + ADHCImputeNodeResponse readNodeResponse = readNodes.get(0); + assertEquals(node, readNodeResponse.getNode()); + assertNotNull(readNodeResponse.getPreviousException()); + assertEquals("exception: " + previousException.getMessage(), readNodeResponse.getPreviousException().getMessage()); + } + + public void testADHCImputeNodeResponseSerialization() throws IOException { + // Arrange + DiscoveryNode node = new DiscoveryNode( + "nodeId", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + Exception previousException = new Exception("Test exception message"); + + ADHCImputeNodeResponse nodeResponse = new ADHCImputeNodeResponse(node, previousException); + + // Act: Serialize the node response + BytesStreamOutput output = new BytesStreamOutput(); + nodeResponse.writeTo(output); + + // Deserialize the node response + StreamInput input = output.bytes().streamInput(); + ADHCImputeNodeResponse deserializedNodeResponse = new ADHCImputeNodeResponse(input); + + // Assert + assertEquals(node, deserializedNodeResponse.getNode()); + assertNotNull(deserializedNodeResponse.getPreviousException()); + assertEquals("exception: " + previousException.getMessage(), deserializedNodeResponse.getPreviousException().getMessage()); + } +} diff --git a/src/test/java/org/opensearch/forecast/transport/GetForecasterTransportActionTests.java b/src/test/java/org/opensearch/forecast/transport/GetForecasterTransportActionTests.java new file mode 100644 index 000000000..c71470231 --- /dev/null +++ b/src/test/java/org/opensearch/forecast/transport/GetForecasterTransportActionTests.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.transport.TransportService; + +public class GetForecasterTransportActionTests extends AbstractTimeSeriesTest { + @SuppressWarnings("unchecked") + public void testRealtimeTaskAssignedWithSingleStreamRealTimeTaskName() throws Exception { + // Arrange + String configID = "test-config-id"; + + // Create a task with singleStreamRealTimeTaskName + Map tasks = new HashMap<>(); + ForecastTask forecastTask = ForecastTask.builder().taskType(ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM.name()).build(); + tasks.put(ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM.name(), forecastTask); + + // Mock taskManager to return the tasks + ForecastTaskManager taskManager = mock(ForecastTaskManager.class); + doAnswer(invocation -> { + List taskList = new ArrayList<>(tasks.values()); + ((Consumer>) invocation.getArguments()[4]).accept(taskList); + return null; + }).when(taskManager).getAndExecuteOnLatestTasks(anyString(), any(), any(), any(), any(), any(), anyBoolean(), anyInt(), any()); + + // Mock listener + ActionListener listener = mock(ActionListener.class); + + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings settings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(settings); + GetForecasterTransportAction getForecaster = spy( + new GetForecasterTransportAction( + mock(TransportService.class), + null, + mock(ActionFilters.class), + clusterService, + null, + null, + Settings.EMPTY, + null, + taskManager, + null + ) + ); + + // Act + GetConfigRequest request = new GetConfigRequest(configID, 0L, true, true, "", "", true, null); + getForecaster.getExecute(request, listener); + + // Assert + // Verify that realtimeTask is assigned using singleStreamRealTimeTaskName + // This can be checked by verifying interactions or internal state + // For this example, we'll verify that the correct task is passed to getConfigAndJob + verify(getForecaster).getConfigAndJob(eq(configID), anyBoolean(), anyBoolean(), eq(Optional.of(forecastTask)), any(), eq(listener)); + } +}