From c149ff994cb0c475ebd359ab9f796ee29a728224 Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Tue, 3 Dec 2024 18:05:24 -0800 Subject: [PATCH] Fix exceptions in IntervalCalculation and ResultIndexingHandler (#1379) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix race condition in PageListener This PR - Introduced an `AtomicInteger` called `pagesInFlight` to track the number of pages currently being processed.  - Incremented `pagesInFlight` before processing each page and decremented it after processing is complete - Adjusted the condition in `scheduleImputeHCTask` to check both `pagesInFlight.get() == 0` (all pages have been processed) and `sentOutPages.get() == receivedPages.get()` (all responses have been received) before scheduling the `imputeHC` task.  - Removed the previous final check in `onResponse` that decided when to schedule `imputeHC`, relying instead on the updated counters for accurate synchronization. These changes address the race condition where `sentOutPages` might not have been incremented in time before checking whether to schedule the `imputeHC` task. By accurately tracking the number of in-flight pages and sent pages, we ensure that `imputeHC` is executed only after all pages have been fully processed and all responses have been received. Testing done: 1. Reproduced the race condition by starting two detectors with imputation. This causes an out of order illegal argument exception from RCF due to this race condition. Also verified the change fixed the problem. 2. added an IT for the above scenario. Signed-off-by: Kaituo Li * Fix exceptions in IntervalCalculation and ResultIndexingHandler - **IntervalCalculation**: Prevent an `ArrayIndexOutOfBoundsException` by returning early when there are fewer than two timestamps. Previously, the code assumed at least two timestamps, causing an exception when only one was present. - **ResultIndexingHandler**: Handle exceptions from asynchronous calls by logging error messages instead of throwing exceptions. Since the caller does not wait for these asynchronous operations, throwing exceptions had no effect and could lead to unhandled exceptions. Logging provides visibility without disrupting the caller's flow. Testing done: 1. added UT and ITs. Signed-off-by: Kaituo Li --------- Signed-off-by: Kaituo Li --- build.gradle | 3 - ...nomaly-detection.release-notes-2.18.0.0.md | 1 + .../ForecastResultBulkTransportAction.java | 2 +- .../rest/handler/IntervalCalculation.java | 5 +- .../handler/ResultIndexingHandler.java | 131 +- .../handler/AnomalyResultHandlerTests.java | 225 +++- .../AbstractForecastSyntheticDataTest.java | 14 + .../forecast/rest/ForecastRestApiIT.java | 1149 ++++++++++++++++- .../forecast/rest/SecureForecastRestIT.java | 10 - .../timeseries/AbstractSyntheticDataTest.java | 37 + .../opensearch/timeseries/TestHelpers.java | 64 + ...orecastResultBulkTransportActionTests.java | 150 +++ 12 files changed, 1663 insertions(+), 128 deletions(-) create mode 100644 src/test/java/org/opensearch/timeseries/transport/ForecastResultBulkTransportActionTests.java diff --git a/build.gradle b/build.gradle index 3ccb40abb..9dcbb1abf 100644 --- a/build.gradle +++ b/build.gradle @@ -699,9 +699,6 @@ List jacocoExclusions = [ // TODO: add test coverage (kaituo) 'org.opensearch.forecast.*', - 'org.opensearch.timeseries.transport.ResultBulkTransportAction', - 'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler', - 'org.opensearch.timeseries.transport.handler.ResultIndexingHandler', 'org.opensearch.timeseries.ml.Sample', 'org.opensearch.timeseries.ratelimit.FeatureRequest', 'org.opensearch.ad.transport.ADHCImputeNodeRequest', diff --git a/release-notes/opensearch-anomaly-detection.release-notes-2.18.0.0.md b/release-notes/opensearch-anomaly-detection.release-notes-2.18.0.0.md index fee194023..b62a3ab10 100644 --- a/release-notes/opensearch-anomaly-detection.release-notes-2.18.0.0.md +++ b/release-notes/opensearch-anomaly-detection.release-notes-2.18.0.0.md @@ -7,6 +7,7 @@ Compatible with OpenSearch 2.18.0 ### Bug Fixes * Bump RCF Version and Fix Default Rules Bug in AnomalyDetector ([#1334](https://github.com/opensearch-project/anomaly-detection/pull/1334)) +* Fix race condition in PageListener ([#1351](https://github.com/opensearch-project/anomaly-detection/pull/1351)) ### Infrastructure * forward port flaky test fix and add forecasting security tests ([#1329](https://github.com/opensearch-project/anomaly-detection/pull/1329)) diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java index dcb792fba..dcdd0680a 100644 --- a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java @@ -58,7 +58,7 @@ public ForecastResultBulkTransportAction( } @Override - protected BulkRequest prepareBulkRequest(float indexingPressurePercent, ForecastResultBulkRequest request) { + public BulkRequest prepareBulkRequest(float indexingPressurePercent, ForecastResultBulkRequest request) { BulkRequest bulkRequest = new BulkRequest(); List results = request.getResults(); diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java b/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java index ab17f91cf..53c503626 100644 --- a/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java +++ b/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java @@ -252,8 +252,9 @@ private void findMinimumInterval(LongBounds timeStampBounds, ActionListener searchResponseListener = ActionListener.wrap(response -> { List timestamps = aggregationPrep.getTimestamps(response); - if (timestamps.isEmpty()) { - logger.warn("empty data, return one minute by default"); + if (timestamps.size() < 2) { + // to calculate the difference we need at least 2 timestamps + logger.warn("not enough data, return one minute by default"); listener.onResponse(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)); return; } diff --git a/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java index d7beb64a8..1647fa01a 100644 --- a/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java +++ b/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java @@ -34,7 +34,6 @@ import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.indices.IndexManagement; import org.opensearch.timeseries.indices.TimeSeriesIndex; @@ -109,22 +108,28 @@ public void setFixedDoc(boolean fixedDoc) { } // TODO: check if user has permission to index. - public void index(ResultType toSave, String detectorId, String indexOrAliasName) { - try { - if (indexOrAliasName != null) { - if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, indexOrAliasName)) { - LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, detectorId)); - return; - } - // We create custom result index when creating a detector. Custom result index can be rolled over and thus we may need to - // create a new one. - if (!timeSeriesIndices.doesIndexExist(indexOrAliasName) && !timeSeriesIndices.doesAliasExist(indexOrAliasName)) { - timeSeriesIndices.initCustomResultIndexDirectly(indexOrAliasName, ActionListener.wrap(response -> { - if (response.isAcknowledged()) { - save(toSave, detectorId, indexOrAliasName); - } else { - throw new TimeSeriesException( - detectorId, + /** + * Run async index operation. Cannot guarantee index is done after finishing executing the function as several calls + * in the method are asynchronous. + * @param toSave Result to save + * @param configId config id + * @param indexOrAliasName custom index or alias name + */ + public void index(ResultType toSave, String configId, String indexOrAliasName) { + if (indexOrAliasName != null) { + if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, indexOrAliasName)) { + LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, configId)); + return; + } + // We create custom result index when creating a detector. Custom result index can be rolled over and thus we may need to + // create a new one. + if (!timeSeriesIndices.doesIndexExist(indexOrAliasName) && !timeSeriesIndices.doesAliasExist(indexOrAliasName)) { + timeSeriesIndices.initCustomResultIndexDirectly(indexOrAliasName, ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + save(toSave, configId, indexOrAliasName); + } else { + LOG + .error( String .format( Locale.ROOT, @@ -132,65 +137,49 @@ public void index(ResultType toSave, String detectorId, String indexOrAliasName) indexOrAliasName ) ); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - save(toSave, detectorId, indexOrAliasName); - } else { - throw new TimeSeriesException( - detectorId, - String.format(Locale.ROOT, "cannot create result index %s", indexOrAliasName), - exception - ); - } - })); - } else { - timeSeriesIndices.validateResultIndexMapping(indexOrAliasName, ActionListener.wrap(valid -> { - if (!valid) { - throw new EndRunException(detectorId, "wrong index mapping of custom AD result index", true); - } else { - save(toSave, detectorId, indexOrAliasName); - } - }, exception -> { - throw new TimeSeriesException( - detectorId, - String.format(Locale.ROOT, "cannot validate result index %s", indexOrAliasName), - exception - ); - })); - } + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + save(toSave, configId, indexOrAliasName); + } else { + LOG.error(String.format(Locale.ROOT, "cannot create result index %s", indexOrAliasName), exception); + } + })); } else { - if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.defaultResultIndexName)) { - LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, detectorId)); - return; - } - if (!timeSeriesIndices.doesDefaultResultIndexExist()) { - timeSeriesIndices - .initDefaultResultIndexDirectly( - ActionListener.wrap(initResponse -> onCreateIndexResponse(initResponse, toSave, detectorId), exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - save(toSave, detectorId); - } else { - throw new TimeSeriesException( - detectorId, + timeSeriesIndices.validateResultIndexMapping(indexOrAliasName, ActionListener.wrap(valid -> { + if (!valid) { + LOG.error("wrong index mapping of custom result index"); + } else { + save(toSave, configId, indexOrAliasName); + } + }, exception -> { LOG.error(String.format(Locale.ROOT, "cannot validate result index %s", indexOrAliasName), exception); }) + ); + } + } else { + if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.defaultResultIndexName)) { + LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, configId)); + return; + } + if (!timeSeriesIndices.doesDefaultResultIndexExist()) { + timeSeriesIndices + .initDefaultResultIndexDirectly( + ActionListener.wrap(initResponse -> onCreateIndexResponse(initResponse, toSave, configId), exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + save(toSave, configId); + } else { + LOG + .error( String.format(Locale.ROOT, "Unexpected error creating index %s", defaultResultIndexName), exception ); - } - }) - ); - } else { - save(toSave, detectorId); - } + } + }) + ); + } else { + save(toSave, configId); } - } catch (Exception e) { - throw new TimeSeriesException( - detectorId, - String.format(Locale.ROOT, "Error in saving %s for detector %s", defaultResultIndexName, detectorId), - e - ); } } diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java index 73efdf4a1..8db056c0c 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java +++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java @@ -12,11 +12,14 @@ package org.opensearch.ad.transport.handler; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.IOException; import java.time.Clock; @@ -31,6 +34,8 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentMatchers; import org.mockito.Mock; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.ad.constant.ADCommonName; @@ -44,7 +49,6 @@ import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; -import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.transport.handler.ResultIndexingHandler; public class AnomalyResultHandlerTests extends AbstractIndexHandlerTest { @@ -181,9 +185,6 @@ public void testAdResultIndexExist() throws IOException { @Test public void testAdResultIndexOtherException() throws IOException { - expectedEx.expect(TimeSeriesException.class); - expectedEx.expectMessage("Error in saving .opendistro-anomaly-results for detector " + detectorId); - setUpSavingAnomalyResultIndex(false, IndexCreation.RUNTIME_EXCEPTION); ResultIndexingHandler handler = new ResultIndexingHandler<>( client, @@ -199,6 +200,7 @@ public void testAdResultIndexOtherException() throws IOException { ); handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); verify(client, never()).index(any(), any()); + assertTrue(testAppender.containsMessage(String.format(Locale.ROOT, "Unexpected error creating index .opendistro-anomaly-results"))); } /** @@ -212,7 +214,6 @@ public void testAdResultIndexOtherException() throws IOException { * @throws InterruptedException if thread execution is interrupted * @throws IOException if IO failures */ - @SuppressWarnings("unchecked") private void savingFailureTemplate(boolean throwOpenSearchRejectedExecutionException, int latchCount, boolean adResultIndexExists) throws InterruptedException, IOException { @@ -262,4 +263,218 @@ private void savingFailureTemplate(boolean throwOpenSearchRejectedExecutionExcep backoffLatch.await(1, TimeUnit.MINUTES); } + + @Test + public void testCustomIndexCreate() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new CreateIndexResponse(true, true, testIndex)); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + verify(client, times(1)).index(any(), any()); + } + + @Test + public void testCustomIndexCreateNotAcked() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new CreateIndexResponse(false, false, testIndex)); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + + assertTrue( + testAppender + .containsMessage( + String.format(Locale.ROOT, "Creating custom result index %s with mappings call not acknowledged", testIndex) + ) + ); + } + + @Test + public void testCustomIndexCreateExists() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new ResourceAlreadyExistsException("index already exists")); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + verify(client, times(1)).index(any(), any()); + } + + @Test + public void testCustomIndexOtherException() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(false); + + Exception testException = new OpenSearchRejectedExecutionException("Test exception"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(testException); + return null; + }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + + assertTrue(testAppender.containsMessage(String.format(Locale.ROOT, "cannot create result index %s", testIndex))); + } + + @Test + public void testInvalid() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(false); + return null; + }).when(anomalyDetectionIndices).validateResultIndexMapping(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + + assertTrue(testAppender.containsMessage("wrong index mapping of custom result index", false)); + } + + @Test + public void testValid() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(true); + return null; + }).when(anomalyDetectionIndices).validateResultIndexMapping(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + verify(client, times(1)).index(any(), any()); + } + + @Test + public void testValidationException() { + String testIndex = "test"; + setWriteBlockAdResultIndex(false); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + when(anomalyDetectionIndices.doesAliasExist(anyString())).thenReturn(true); + + Exception testException = new OpenSearchRejectedExecutionException("Test exception"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(testException); + return null; + }).when(anomalyDetectionIndices).validateResultIndexMapping(eq(testIndex), any()); + + ResultIndexingHandler handler = new ResultIndexingHandler<>( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, testIndex); + assertTrue(testAppender.containsMessage(String.format(Locale.ROOT, "cannot validate result index %s", testIndex), false)); + } } diff --git a/src/test/java/org/opensearch/forecast/AbstractForecastSyntheticDataTest.java b/src/test/java/org/opensearch/forecast/AbstractForecastSyntheticDataTest.java index 89e1a51b2..5b4163047 100644 --- a/src/test/java/org/opensearch/forecast/AbstractForecastSyntheticDataTest.java +++ b/src/test/java/org/opensearch/forecast/AbstractForecastSyntheticDataTest.java @@ -21,18 +21,22 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Set; import org.apache.http.ParseException; import org.apache.http.util.EntityUtils; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Response; import org.opensearch.client.RestClient; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.forecast.constant.ForecastCommonName; import org.opensearch.forecast.model.ForecastTaskProfile; import org.opensearch.forecast.model.Forecaster; +import org.opensearch.search.SearchHit; import org.opensearch.timeseries.AbstractSyntheticDataTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; @@ -158,4 +162,14 @@ protected List waitUntilTaskReachState(String forecasterId, Set return results; } + protected List toHits(Response response) throws UnsupportedOperationException, IOException { + SearchResponse searchResponse = SearchResponse + .fromXContent(createParser(JsonXContent.jsonXContent, response.getEntity().getContent())); + long total = searchResponse.getHits().getTotalHits().value; + if (total == 0) { + return new ArrayList<>(); + } + return Arrays.asList(searchResponse.getHits().getHits()); + } + } diff --git a/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java b/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java index 911ee8484..ebc5b9547 100644 --- a/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java +++ b/src/test/java/org/opensearch/forecast/rest/ForecastRestApiIT.java @@ -11,9 +11,12 @@ import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -27,6 +30,7 @@ import org.opensearch.forecast.AbstractForecastSyntheticDataTest; import org.opensearch.forecast.model.ForecastTaskProfile; import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.search.SearchHit; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.model.EntityTaskProfile; @@ -39,16 +43,22 @@ /** * Test the following Restful API: - * - Suggest - * - Validate - * - Create - * - run once + * - top forecast * - start * - stop + * - Create + * - run once + * - Validate + * - Suggest * - update */ public class ForecastRestApiIT extends AbstractForecastSyntheticDataTest { public static final int MAX_RETRY_TIMES = 200; + private static final String CITY_NAME = "cityName"; + private static final String CONFIDENCE_INTERVAL_WIDTH = "confidence_interval_width"; + private static final String FORECAST_VALUE = "forecast_value"; + private static final String MIN_CONFIDENCE_INTERVAL = "MIN_CONFIDENCE_INTERVAL_WIDTH"; + private static final String MAX_CONFIDENCE_INTERVAL = "MAX_CONFIDENCE_INTERVAL_WIDTH"; @Override @Before @@ -85,10 +95,10 @@ private static Instant loadSparseCategoryData(int trainTestSplit) throws Excepti JsonObject row = data.get(i); // Get the value of the "cityName" field - String cityName = row.get("cityName").getAsString(); + String cityName = row.get(CITY_NAME).getAsString(); // Replace the field based on the value of "cityName" - row.remove("cityName"); // Remove the original "cityName" field + row.remove(CITY_NAME); // Remove the original "cityName" field if ("Phoenix".equals(cityName)) { if (phonenixIndex % 2 == 0) { @@ -539,7 +549,7 @@ public void testSuggestSparseData() throws Exception { */ public void testFailToSuggest() throws Exception { int trainTestSplit = 100; - String categoricalField = "cityName"; + String categoricalField = CITY_NAME; GenData dataGenerated = genUniformSingleFeatureData( 70, trainTestSplit, @@ -1931,7 +1941,7 @@ public void testCreate() throws Exception { ); MatcherAssert.assertThat(ex.getMessage(), containsString("Can't create more than 1 feature(s)")); - // case 2: create forecaster with custom index + // Case 2: users cannot specify forecaster id when creating a forecaster forecasterDef = "{\n" + " \"name\": \"Second-Test-Forecaster-4\",\n" + " \"description\": \"ok rate\",\n" @@ -1946,28 +1956,11 @@ public void testCreate() throws Exception { + " \"feature_enabled\": true,\n" + " \"importance\": 1,\n" + " \"aggregation_query\": {\n" - + " \"filtered_max_1\": {\n" - + " \"filter\": {\n" - + " \"bool\": {\n" - + " \"must\": [\n" - + " {\n" - + " \"range\": {\n" - + " \"timestamp\": {\n" - + " \"lt\": %d\n" - + " }\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + " },\n" - + " \"aggregations\": {\n" + " \"max1\": {\n" + " \"max\": {\n" + " \"field\": \"visitCount\"\n" + " }\n" + " }\n" - + " }\n" - + " }\n" + " }\n" + " }\n" + " ],\n" @@ -1989,26 +1982,25 @@ public void testCreate() throws Exception { + " \"interval\": 10,\n" + " \"unit\": \"MINUTES\"\n" + " }\n" - + " },\n" - + " \"result_index\": \"opensearch-forecast-result-b\"\n" + + " }\n" + "}"; // +1 to make sure it is big enough windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes() + 1; - // we have 100 timestamps (2 entities per timestamp). Timestamps are 10 minutes apart. If we subtract 70 * 10 = 700 minutes, we have - // sparse data. - String formattedForecaster2 = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, filterTimestamp, windowDelayMinutes); + final String formattedForecasterId = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes); + String blahId = "__blah__"; Response response = TestHelpers .makeRequest( client(), "POST", String.format(Locale.ROOT, CREATE_FORECASTER), - ImmutableMap.of(), - TestHelpers.toHttpEntity(formattedForecaster2), + ImmutableMap.of(RestHandlerUtils.FORECASTER_ID, blahId), + TestHelpers.toHttpEntity(formattedForecasterId), null ); Map responseMap = entityAsMap(response); - assertEquals("opensearch-forecast-result-b", ((Map) responseMap.get("forecaster")).get("result_index")); + String forecasterId = (String) responseMap.get("_id"); + assertNotEquals("response is missing Id", blahId, forecasterId); } public void testRunOnce() throws Exception { @@ -2054,12 +2046,14 @@ public void testRunOnce() throws Exception { + " \"interval\": 10,\n" + " \"unit\": \"MINUTES\"\n" + " }\n" - + " }\n" + + " },\n" + + " \"result_index\": \"opensearch-forecast-result-b\",\n" + + " \"category_field\": [\"%s\"]\n" + "}"; // +1 to make sure it is big enough long windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes() + 1; - final String formattedForecaster = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes); + final String formattedForecaster = String.format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes, CITY_NAME); Response response = TestHelpers .makeRequest( client(), @@ -2071,6 +2065,7 @@ public void testRunOnce() throws Exception { ); Map responseMap = entityAsMap(response); String forecasterId = (String) responseMap.get("_id"); + assertEquals("opensearch-forecast-result-b", ((Map) responseMap.get("forecaster")).get("result_index")); // run once response = TestHelpers @@ -2100,6 +2095,30 @@ public void testRunOnce() throws Exception { int total = (int) (((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); assertTrue("actual: " + total, total > 40); + List hits = toHits(response); + long forecastFrom = -1; + for (SearchHit hit : hits) { + Map source = hit.getSourceAsMap(); + if (source.get("forecast_value") != null) { + forecastFrom = (long) (source.get("data_end_time")); + break; + } + } + assertTrue(forecastFrom != -1); + + // top forecast verification + minConfidenceIntervalVerification(forecasterId, forecastFrom); + maxConfidenceIntervalVerification(forecasterId, forecastFrom); + minForecastValueVerification(forecasterId, forecastFrom); + maxForecastValueVerification(forecasterId, forecastFrom); + distanceToThresholdGreaterThan(forecasterId, forecastFrom); + distanceToThresholdGreaterThanEqual(forecasterId, forecastFrom); + distanceToThresholdLessThan(forecasterId, forecastFrom); + distanceToThresholdLessThanEqual(forecasterId, forecastFrom); + customMaxForecastValue(forecasterId, forecastFrom); + customMinForecastValue(forecasterId, forecastFrom); + topForecastSizeVerification(forecasterId, forecastFrom); + // case 2: cannot run once while forecaster is started response = TestHelpers .makeRequest( @@ -2144,6 +2163,442 @@ public void testRunOnce() throws Exception { assertEquals(forecasterId, responseMap.get("_id")); } + private void maxForecastValueVerification(String forecasterId, long forecastFrom) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousValue; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"MAX_VALUE_WITHIN_THE_HORIZON\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true\n" + + "}", + CITY_NAME, + forecastFrom + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + previousValue = Double.MAX_VALUE; // Initialize to positive infinity + double largestValue = Double.MIN_VALUE; + + largestValue = isDesc(parsedBuckets, previousValue, largestValue, "MAX_VALUE_WITHIN_THE_HORIZON"); + + String maxValueRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"desc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + forecastFrom, + FORECAST_VALUE + ); + + Response maxValueResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxValueRequest), null); + List maxValueHits = toHits(maxValueResponse); + assertEquals("actual: " + maxValueHits, 1, maxValueHits.size()); + double maxValue = (double) (maxValueHits.get(0).getSourceAsMap().get(FORECAST_VALUE)); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxValue, largestValue), maxValue, largestValue, 0.001); + } + + private void minForecastValueVerification(String forecasterId, long forecastFrom) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + Set cities; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"MIN_VALUE_WITHIN_THE_HORIZON\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true\n" + + "}", + CITY_NAME, + forecastFrom + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + double previousValue = -Double.MAX_VALUE; // Initialize to negative infinity + double smallestValue = Double.MAX_VALUE; + cities = new HashSet<>(); + + smallestValue = isAsc(parsedBuckets, cities, previousValue, smallestValue, "MIN_VALUE_WITHIN_THE_HORIZON"); + + String minValueRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"asc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + forecastFrom, + FORECAST_VALUE + ); + + Response minValueResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(minValueRequest), null); + List minValueHits = toHits(minValueResponse); + assertEquals("actual: " + minValueHits, 1, minValueHits.size()); + double minValue = (double) (minValueHits.get(0).getSourceAsMap().get("forecast_value")); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", minValue, smallestValue), minValue, smallestValue, 0.001); + } + + private double isAsc(List parsedBuckets, Set cities, double previousValue, double smallestValue, String valueKey) { + for (Object obj : parsedBuckets) { + assertTrue("Each element in the list must be a Map.", obj instanceof Map); + + @SuppressWarnings("unchecked") + Map bucket = (Map) obj; + + // Extract value using keys + Object valueObj = bucket.get(valueKey); + assertTrue("actual: " + valueObj, valueObj instanceof Number); + + double value = ((Number) valueObj).doubleValue(); + if (smallestValue > value) { + smallestValue = value; + } + + // Check ascending order + assertTrue(String.format(Locale.ROOT, "value %f previousValue %f", value, previousValue), value >= previousValue); + + previousValue = value; + + // Extract the key + Object keyObj = bucket.get("key"); + assertTrue("actual: " + keyObj, keyObj instanceof Map); + + @SuppressWarnings("unchecked") + Map keyMap = (Map) keyObj; + String cityName = (String) keyMap.get(CITY_NAME); + + assertTrue("cityName is null", cityName != null); + + // Check that service is either "Phoenix" or "Scottsdale" + assertTrue("cityName is " + cityName, cityName.equals("Phoenix") || cityName.equals("Scottsdale")); + + // Check for unique services + assertTrue("Duplicate city found: " + cityName, cities.add(cityName)); + } + return smallestValue; + } + + private void maxConfidenceIntervalVerification(String forecasterId, long forecastFrom) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousWidth; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"MAX_CONFIDENCE_INTERVAL_WIDTH\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true\n" + + "}", + CITY_NAME, + forecastFrom + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + previousWidth = Double.MAX_VALUE; // Initialize to positive infinity + double largestWidth = Double.MIN_VALUE; + + largestWidth = isDesc(parsedBuckets, previousWidth, largestWidth, MAX_CONFIDENCE_INTERVAL); + + String maxConfidenceIntervalRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"horizon_index\": 24\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"desc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + forecastFrom, + CONFIDENCE_INTERVAL_WIDTH + ); + + Response maxConfidenceIntervalResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxConfidenceIntervalRequest), null); + List maxConfidenceIntervalHits = toHits(maxConfidenceIntervalResponse); + assertEquals("actual: " + maxConfidenceIntervalHits, 1, maxConfidenceIntervalHits.size()); + double maxWidth = (double) (maxConfidenceIntervalHits.get(0).getSourceAsMap().get(CONFIDENCE_INTERVAL_WIDTH)); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxWidth, largestWidth), maxWidth, largestWidth, 0.001); + } + + private void validateKeyValue( + Map keyMap, + String keyName, + String valueDescription, + Set expectedValues, + Set uniqueValuesSet + ) { + // Extract the value from the keyMap using the keyName + String value = (String) keyMap.get(keyName); + + // Ensure the value is not null + assertTrue(valueDescription + " is null", value != null); + + // Check that the value is one of the expected values + assertTrue(valueDescription + " is " + value, expectedValues.contains(value)); + + // Check for uniqueness in the provided set + assertTrue("Duplicate " + valueDescription + " found: " + value, uniqueValuesSet.add(value)); + } + + private double isDesc( + List parsedBuckets, + double previousWidth, + Set uniqueValuesSet, + double largestWidth, + String valueKey, + String keyName, + String valueDescription, + Set expectedValues + ) { + for (Object obj : parsedBuckets) { + assertTrue("Each element in the list must be a Map.", obj instanceof Map); + + @SuppressWarnings("unchecked") + Map bucket = (Map) obj; + + // Extract valueKey + Object widthObj = bucket.get(valueKey); + assertTrue("actual: " + widthObj, widthObj instanceof Number); + + double width = ((Number) widthObj).doubleValue(); + if (largestWidth < width) { + largestWidth = width; + } + + // Check descending order + assertTrue(String.format(Locale.ROOT, "width %f previousWidth %f", width, previousWidth), width <= previousWidth); + + previousWidth = width; + + // Extract the key + Object keyObj = bucket.get("key"); + assertTrue("actual: " + keyObj, keyObj instanceof Map); + + @SuppressWarnings("unchecked") + Map keyMap = (Map) keyObj; + + // Use the helper method for validation + validateKeyValue(keyMap, keyName, valueDescription, expectedValues, uniqueValuesSet); + } + return largestWidth; + } + + private double isDesc(List parsedBuckets, double previousWidth, double largestWidth, String valueKey) { + Set cities = new HashSet<>(); + Set expectedCities = new HashSet<>(Arrays.asList("Phoenix", "Scottsdale")); + return isDesc(parsedBuckets, previousWidth, cities, largestWidth, valueKey, CITY_NAME, "cityName", expectedCities); + } + + private double isDescTwoCategorical(List parsedBuckets, double previousWidth, double largestWidth, String valueKey) { + Set regions = new HashSet<>(); + Set expectedRegions = new HashSet<>(Arrays.asList("pdx", "iad")); + return isDesc(parsedBuckets, previousWidth, regions, largestWidth, valueKey, "region", "regionName", expectedRegions); + } + + private void minConfidenceIntervalVerification(String forecasterId, long forecastFrom) throws IOException { + Response response; + Map responseMap; + String topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"MIN_CONFIDENCE_INTERVAL_WIDTH\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true\n" + + "}", + CITY_NAME, + forecastFrom + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + List parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + double previousWidth = -Double.MAX_VALUE; // Initialize to negative infinity + double smallestWidth = Double.MAX_VALUE; + Set cities = new HashSet<>(); + + smallestWidth = isAsc(parsedBuckets, cities, previousWidth, smallestWidth, MIN_CONFIDENCE_INTERVAL); + + String minConfidenceIntervalRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"horizon_index\": 24\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"asc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + forecastFrom, + CONFIDENCE_INTERVAL_WIDTH + ); + + Response minConfidenceIntervalResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(minConfidenceIntervalRequest), null); + List minConfidenceIntervalHits = toHits(minConfidenceIntervalResponse); + assertEquals("actual: " + minConfidenceIntervalHits, 1, minConfidenceIntervalHits.size()); + double minWidth = (double) (minConfidenceIntervalHits.get(0).getSourceAsMap().get(CONFIDENCE_INTERVAL_WIDTH)); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", minWidth, smallestWidth), minWidth, smallestWidth, 0.001); + } + public Response searchTaskResult(String taskId) throws IOException { Response response = TestHelpers .makeRequest( @@ -2153,7 +2608,9 @@ public Response searchTaskResult(String taskId) throws IOException { ImmutableMap.of(), TestHelpers .toHttpEntity( - "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"task_id\":\"" + taskId + "\"}}]}},\"track_total_hits\":true}" + "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"task_id\":\"" + + taskId + + "\"}}]}},\"track_total_hits\":true,\"size\":10000}" ), null ); @@ -2454,4 +2911,624 @@ public void testUpdateDetector() throws Exception { responseMap = entityAsMap(response); assertEquals(responseMap.get("last_update_time"), responseMap.get("last_ui_breaking_change_time")); } + + private void distanceToThresholdGreaterThan(String forecasterId, long forecastFrom) throws IOException { + distanceToThresholdGreaterTemplate(forecasterId, forecastFrom, false); + } + + private void distanceToThresholdGreaterTemplate(String forecasterId, long forecastFrom, boolean equal) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousWidth; + int threshold = 4587; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"DISTANCE_TO_THRESHOLD_VALUE\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true,\n" + + " \"threshold\": %d,\n" + + " \"relation_to_threshold\": \"%s\"" + + "}", + CITY_NAME, + forecastFrom, + threshold, + equal ? "GREATER_THAN_OR_EQUAL_TO" : "GREATER_THAN" + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + previousWidth = Double.MAX_VALUE; // Initialize to positive infinity + double largestValue = Double.MIN_VALUE; + + largestValue = isDesc(parsedBuckets, previousWidth, largestValue, "DISTANCE_TO_THRESHOLD_VALUE"); + + String maxDistanceToThresholdRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"forecast_value\": {\n" + + " \"%s\": " + + threshold + + "\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"forecast_value\": {\n" + + " \"order\": \"desc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + equal ? "gte" : "gt", + forecastFrom + ); + + Response maxDistanceResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxDistanceToThresholdRequest), null); + List maxDistanceHits = toHits(maxDistanceResponse); + assertEquals("actual: " + maxDistanceHits, 1, maxDistanceHits.size()); + double maxValue = (double) (maxDistanceHits.get(0).getSourceAsMap().get("forecast_value")); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxValue, largestValue), maxValue, largestValue, 0.001); + } + + private void distanceToThresholdGreaterThanEqual(String forecasterId, long forecastFrom) throws IOException { + distanceToThresholdGreaterTemplate(forecasterId, forecastFrom, true); + } + + private void distanceToThresholdLessThan(String forecasterId, long forecastFrom) throws IOException { + distanceToThresholdLessTemplate(forecasterId, forecastFrom, false); + } + + private void distanceToThresholdLessTemplate(String forecasterId, long forecastFrom, boolean equal) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousWidth; + Set cities; + int threshold = 7000; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"DISTANCE_TO_THRESHOLD_VALUE\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true,\n" + + " \"threshold\": %d,\n" + + " \"relation_to_threshold\": \"%s\"" + + "}", + CITY_NAME, + forecastFrom, + threshold, + equal ? "LESS_THAN_OR_EQUAL_TO" : "LESS_THAN" + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 2); + + previousWidth = Double.MIN_VALUE; // Initialize to negative infinity + double smallestValue = Double.MAX_VALUE; + cities = new HashSet<>(); + + smallestValue = isAsc(parsedBuckets, cities, previousWidth, smallestValue, "DISTANCE_TO_THRESHOLD_VALUE"); + + String maxDistanceToThresholdRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"filter\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"forecast_value\": {\n" + + " \"%s\": " + + threshold + + "\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"sort\": [\n" + + " {\n" + + " \"forecast_value\": {\n" + + " \"order\": \"asc\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}", + equal ? "lte" : "lt", + forecastFrom + ); + + Response maxDistanceResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxDistanceToThresholdRequest), null); + List maxDistanceHits = toHits(maxDistanceResponse); + assertEquals("actual: " + maxDistanceHits, 1, maxDistanceHits.size()); + double maxValue = (double) (maxDistanceHits.get(0).getSourceAsMap().get("forecast_value")); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxValue, smallestValue), maxValue, smallestValue, 0.001); + } + + private void distanceToThresholdLessThanEqual(String forecasterId, long forecastFrom) throws IOException { + distanceToThresholdLessTemplate(forecasterId, forecastFrom, true); + } + + private void customMaxForecastValue(String forecasterId, long forecastFrom) throws IOException { + customForecastValueTemplate(forecasterId, forecastFrom, true); + } + + private void customMinForecastValue(String forecasterId, long forecastFrom) throws IOException { + customForecastValueTemplate(forecasterId, forecastFrom, false); + } + + private void customForecastValueTemplate(String forecasterId, long forecastFrom, boolean max) throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousValue; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"forecast_from\": %d,\n" + + " \"filter_by\": \"CUSTOM_QUERY\",\n" + + " \"filter_query\": {\n" + + " \"nested\": {\n" + + " \"path\": \"entity\",\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"entity.name\": \"%s\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"wildcard\": {\n" + + " \"entity.value\": \"S*\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"subaggregations\": [\n" + + " {\n" + + " \"aggregation_query\": {\n" + + " \"forecast_value_max\": {\n" + + " \"%s\": {\n" + + " \"field\": \"forecast_value\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"order\": \"DESC\"\n" + + " }\n" + + " ],\n" + + " \"run_once\": true\n" + + "}", + forecastFrom, + CITY_NAME, + max ? "max" : "min" + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 1); + + previousValue = Double.MAX_VALUE; // Initialize to positive infinity + double largestValue = Double.MIN_VALUE; + + largestValue = isDesc(parsedBuckets, previousValue, largestValue, "forecast_value_max"); + + String maxValueRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"%s\"\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"nested\": {\n" + + " \"path\": \"entity\",\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"entity.name\": \"%s\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"wildcard\": {\n" + + " \"entity.value\": \"S*\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + "}", + FORECAST_VALUE, // First %s + max ? "desc" : "asc", // Second %s + CITY_NAME, // Third %s + forecastFrom // %d + ); + + Response maxValueResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxValueRequest), null); + List maxValueHits = toHits(maxValueResponse); + assertEquals("actual: " + maxValueHits, 1, maxValueHits.size()); + double maxValue = (double) (maxValueHits.get(0).getSourceAsMap().get(FORECAST_VALUE)); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxValue, largestValue), maxValue, largestValue, 0.001); + } + + public void testTopForecast() throws Exception { + Instant trainTime = loadTwoCategoricalFieldData(200); + // case 1: happy case + String forecasterDef = "{\n" + + " \"name\": \"Second-Test-Forecaster-4\",\n" + + " \"description\": \"ok rate\",\n" + + " \"time_field\": \"timestamp\",\n" + + " \"indices\": [\n" + + " \"%s\"\n" + + " ],\n" + + " \"feature_attributes\": [\n" + + " {\n" + + " \"feature_id\": \"max1\",\n" + + " \"feature_name\": \"max1\",\n" + + " \"feature_enabled\": true,\n" + + " \"importance\": 1,\n" + + " \"aggregation_query\": {\n" + + " \"max1\": {\n" + + " \"max\": {\n" + + " \"field\": \"visitCount\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"window_delay\": {\n" + + " \"period\": {\n" + + " \"interval\": %d,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"ui_metadata\": {\n" + + " \"aabb\": {\n" + + " \"ab\": \"bb\"\n" + + " }\n" + + " },\n" + + " \"schema_version\": 2,\n" + + " \"horizon\": 24,\n" + + " \"forecast_interval\": {\n" + + " \"period\": {\n" + + " \"interval\": 10,\n" + + " \"unit\": \"MINUTES\"\n" + + " }\n" + + " },\n" + + " \"result_index\": \"opensearch-forecast-result-b\",\n" + + " \"category_field\": [%s]\n" + + "}"; + + // +1 to make sure it is big enough + long windowDelayMinutes = Duration.between(trainTime, Instant.now()).toMinutes() + 1; + final String formattedForecaster = String + .format(Locale.ROOT, forecasterDef, RULE_DATASET_NAME, windowDelayMinutes, "\"account\",\"region\""); + Response response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, CREATE_FORECASTER), + ImmutableMap.of(), + TestHelpers.toHttpEntity(formattedForecaster), + null + ); + Map responseMap = entityAsMap(response); + String forecasterId = (String) responseMap.get("_id"); + assertEquals("opensearch-forecast-result-b", ((Map) responseMap.get("forecaster")).get("result_index")); + + // run once + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, RUN_ONCE_FORECASTER, forecasterId), + ImmutableMap.of(), + (HttpEntity) null, + null + ); + + ForecastTaskProfile forecastTaskProfile = (ForecastTaskProfile) waitUntilTaskReachState( + forecasterId, + ImmutableSet.of(TaskState.TEST_COMPLETE.name()), + client() + ).get(0); + assertTrue(forecastTaskProfile != null); + assertTrue(forecastTaskProfile.getTask().isLatest()); + + responseMap = entityAsMap(response); + String taskId = (String) responseMap.get(EntityTaskProfile.TASK_ID_FIELD); + assertEquals(taskId, forecastTaskProfile.getTaskId()); + + response = searchTaskResult(taskId); + responseMap = entityAsMap(response); + int total = (int) (((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + assertTrue("actual: " + total, total > 40); + + List hits = toHits(response); + long forecastFrom = -1; + for (SearchHit hit : hits) { + Map source = hit.getSourceAsMap(); + if (source.get("forecast_value") != null) { + forecastFrom = (long) (source.get("data_end_time")); + break; + } + } + assertTrue(forecastFrom != -1); + + // top forecast verification + customForecastValueDoubleCategories(forecasterId, forecastFrom, true, taskId); + customForecastValueDoubleCategories(forecasterId, forecastFrom, false, taskId); + } + + private void topForecastSizeVerification(String forecasterId, long forecastFrom) throws IOException { + Response response; + Map responseMap; + String topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"split_by\": \"%s\",\n" + + " \"filter_by\": \"BUILD_IN_QUERY\",\n" + + " \"build_in_query\": \"MIN_CONFIDENCE_INTERVAL_WIDTH\",\n" + + " \"forecast_from\": %d,\n" + + " \"run_once\": true\n" + + "}", + CITY_NAME, + forecastFrom + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + List parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 1); + } + + private void customForecastValueDoubleCategories(String forecasterId, long forecastFrom, boolean max, String taskId) + throws IOException { + Response response; + Map responseMap; + String topForcastRequest; + List parsedBuckets; + double previousValue; + topForcastRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"forecast_from\": %d,\n" + + " \"filter_by\": \"CUSTOM_QUERY\",\n" + + " \"filter_query\": {\n" + + " \"nested\": {\n" + + " \"path\": \"entity\",\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"entity.name\": \"%s\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"wildcard\": {\n" + + " \"entity.value\": \"i*\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"subaggregations\": [\n" + + " {\n" + + " \"aggregation_query\": {\n" + + " \"forecast_value_max\": {\n" + + " \"%s\": {\n" + + " \"field\": \"forecast_value\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"order\": \"DESC\"\n" + + " }\n" + + " ],\n" + + " \"run_once\": true,\n" + + " \"task_id\": \"%s\"\n" + + "}", + forecastFrom, + "region", + max ? "max" : "min", + taskId + ); + + response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TOP_FORECASTER, forecasterId), + ImmutableMap.of(), + TestHelpers.toHttpEntity(topForcastRequest), + null + ); + responseMap = entityAsMap(response); + parsedBuckets = (List) responseMap.get("buckets"); + assertTrue("actual content: " + parsedBuckets, parsedBuckets.size() == 1); + + previousValue = Double.MAX_VALUE; // Initialize to positive infinity + double largestValue = Double.MIN_VALUE; + + largestValue = isDescTwoCategorical(parsedBuckets, previousValue, largestValue, "forecast_value_max"); + + String maxValueRequest = String + .format( + Locale.ROOT, + "{\n" + + " \"size\": 1,\n" + + " \"sort\": [\n" + + " {\n" + + " \"%s\": {\n" + + " \"order\": \"%s\"\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"nested\": {\n" + + " \"path\": \"entity\",\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"entity.name\": \"%s\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"wildcard\": {\n" + + " \"entity.value\": \"i*\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"range\": {\n" + + " \"data_end_time\": {\n" + + " \"from\": %d,\n" + + " \"to\": null,\n" + + " \"include_lower\": true\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + "}", + FORECAST_VALUE, // First %s + max ? "desc" : "asc", // Second %s + "region", // Third %s + forecastFrom // %d + ); + + Response maxValueResponse = TestHelpers + .makeRequest(client(), "GET", SEARCH_RESULTS, ImmutableMap.of(), TestHelpers.toHttpEntity(maxValueRequest), null); + List maxValueHits = toHits(maxValueResponse); + assertEquals("actual: " + maxValueHits, 1, maxValueHits.size()); + double maxValue = (double) (maxValueHits.get(0).getSourceAsMap().get(FORECAST_VALUE)); + assertEquals(String.format(Locale.ROOT, "actual: %f, expect: %f", maxValue, largestValue), maxValue, largestValue, 0.001); + } } diff --git a/src/test/java/org/opensearch/forecast/rest/SecureForecastRestIT.java b/src/test/java/org/opensearch/forecast/rest/SecureForecastRestIT.java index ae838080d..be81599ab 100644 --- a/src/test/java/org/opensearch/forecast/rest/SecureForecastRestIT.java +++ b/src/test/java/org/opensearch/forecast/rest/SecureForecastRestIT.java @@ -588,16 +588,6 @@ protected List waitUntilResultAvailable(RestClient client) throws Int return hits; } - private List toHits(Response response) throws UnsupportedOperationException, IOException { - SearchResponse searchResponse = SearchResponse - .fromXContent(createParser(JsonXContent.jsonXContent, response.getEntity().getContent())); - long total = searchResponse.getHits().getTotalHits().value; - if (total == 0) { - return new ArrayList<>(); - } - return Arrays.asList(searchResponse.getHits().getHits()); - } - private Response enableFilterBy() throws IOException { return TestHelpers .makeRequest( diff --git a/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java b/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java index a4ba921f6..175a971de 100644 --- a/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java +++ b/src/test/java/org/opensearch/timeseries/AbstractSyntheticDataTest.java @@ -279,6 +279,43 @@ protected static Instant loadRuleData(int trainTestSplit) throws Exception { return loadData(RULE_DATASET_NAME, trainTestSplit, RULE_DATA_MAPPING); } + // convert 1 categorical field (cityName) rule data with two categorical field (account and region) rule data + protected static Instant loadTwoCategoricalFieldData(int trainTestSplit) throws Exception { + RestClient client = client(); + + String dataFileName = String.format(Locale.ROOT, "org/opensearch/ad/e2e/data/%s.data", RULE_DATASET_NAME); + + List data = readJsonArrayWithLimit(dataFileName, trainTestSplit); + + for (int i = 0; i < trainTestSplit && i < data.size(); i++) { + JsonObject jsonObject = data.get(i); + String city = jsonObject.get("cityName").getAsString(); + if (city.equals("Phoenix")) { + jsonObject.addProperty("account", "1234"); + jsonObject.addProperty("region", "iad"); + } else if (city.equals("Scottsdale")) { + jsonObject.addProperty("account", "5678"); + jsonObject.addProperty("region", "pdx"); + } + } + + String mapping = "{ \"mappings\": { \"properties\": { " + + "\"timestamp\": { \"type\": \"date\" }, " + + "\"visitCount\": { \"type\": \"integer\" }, " + + "\"cityName\": { \"type\": \"keyword\" }, " + + "\"account\": { \"type\": \"keyword\" }, " + + "\"region\": { \"type\": \"keyword\" } " + + "} } }"; + + bulkIndexTrainData(RULE_DATASET_NAME, data, trainTestSplit, client, mapping); + String trainTimeStr = data.get(trainTestSplit - 1).get("timestamp").getAsString(); + if (canBeParsedAsLong(trainTimeStr)) { + return Instant.ofEpochMilli(Long.parseLong(trainTimeStr)); + } else { + return Instant.parse(trainTimeStr); + } + } + public static boolean canBeParsedAsLong(String str) { if (str == null || str.isEmpty()) { return false; // Handle null or empty strings as not parsable diff --git a/src/test/java/org/opensearch/timeseries/TestHelpers.java b/src/test/java/org/opensearch/timeseries/TestHelpers.java index 7e4a0b7d0..cc0036ce5 100644 --- a/src/test/java/org/opensearch/timeseries/TestHelpers.java +++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java @@ -118,6 +118,7 @@ import org.opensearch.forecast.model.ForecastResult; import org.opensearch.forecast.model.ForecastTask; import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; import org.opensearch.index.get.GetResult; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -1585,6 +1586,16 @@ public static Entity randomEntity(Config config) { return entity; } + private static Entity randomEntity() { + String name = randomAlphaOfLength(10); + List values = new ArrayList<>(); + int size = random.nextInt(3) + 1; // At least one value + for (int i = 0; i < size; i++) { + values.add(randomAlphaOfLength(10)); + } + return Entity.createEntityByReordering(ImmutableMap.of(name, values)); + } + public static HttpEntity toHttpEntity(ToXContentObject object) throws IOException { return new StringEntity(toJsonString(object), APPLICATION_JSON); } @@ -2224,4 +2235,57 @@ public Job build() { } } + public static ForecastResultWriteRequest randomForecastResultWriteRequest() { + // Generate random values for required fields + long expirationEpochMs = Instant.now().plusSeconds(random.nextInt(3600)).toEpochMilli(); // Expire within the next hour + String forecasterId = randomAlphaOfLength(10); + RequestPriority priority = RequestPriority.MEDIUM; // Use NORMAL priority for testing + ForecastResult result = randomForecastResult(forecasterId); + String resultIndex = random.nextBoolean() ? randomAlphaOfLength(10) : null; // Randomly decide to set resultIndex or not + + return new ForecastResultWriteRequest(expirationEpochMs, forecasterId, priority, result, resultIndex); + } + + public static ForecastResult randomForecastResult(String forecasterId) { + String taskId = randomAlphaOfLength(10); + Double dataQuality = random.nextDouble(); + List featureData = ImmutableList.of(randomFeatureData()); + Instant dataStartTime = Instant.now().minusSeconds(random.nextInt(3600)); + Instant dataEndTime = Instant.now(); + Instant executionStartTime = Instant.now().minusSeconds(random.nextInt(3600)); + Instant executionEndTime = Instant.now(); + String error = random.nextBoolean() ? randomAlphaOfLength(20) : null; + Optional entity = random.nextBoolean() ? Optional.of(randomEntity()) : Optional.empty(); + User user = random.nextBoolean() ? randomUser() : null; + Integer schemaVersion = random.nextInt(10); + String featureId = randomAlphaOfLength(10); + Float forecastValue = random.nextFloat(); + Float lowerBound = forecastValue - random.nextFloat(); + Float upperBound = forecastValue + random.nextFloat(); + Instant forecastDataStartTime = dataEndTime.plusSeconds(random.nextInt(3600)); + Instant forecastDataEndTime = forecastDataStartTime.plusSeconds(random.nextInt(3600)); + Integer horizonIndex = random.nextInt(100); + + return new ForecastResult( + forecasterId, + taskId, + dataQuality, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + featureId, + forecastValue, + lowerBound, + upperBound, + forecastDataStartTime, + forecastDataEndTime, + horizonIndex + ); + } } diff --git a/src/test/java/org/opensearch/timeseries/transport/ForecastResultBulkTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/ForecastResultBulkTransportActionTests.java new file mode 100644 index 000000000..f5cead05e --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/transport/ForecastResultBulkTransportActionTests.java @@ -0,0 +1,150 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.timeseries.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.transport.ForecastResultBulkRequest; +import org.opensearch.forecast.transport.ForecastResultBulkTransportAction; +import org.opensearch.index.IndexingPressure; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.transport.TransportService; + +public class ForecastResultBulkTransportActionTests extends AbstractTimeSeriesTest { + + private ForecastResultBulkTransportAction resultBulk; + private TransportService transportService; + private ClusterService clusterService; + private IndexingPressure indexingPressure; + private Client client; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(ForecastResultBulkTransportActionTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + Settings settings = Settings + .builder() + .put(IndexingPressure.MAX_INDEXING_BYTES.getKey(), "1KB") + .put("forecast.index_pressure.soft_limit", 0.8) + .build(); + + // Setup test nodes and services + setupTestNodes(ForecastSettings.FORECAST_INDEX_PRESSURE_SOFT_LIMIT, ForecastSettings.FORECAST_INDEX_PRESSURE_HARD_LIMIT); + transportService = testNodes[0].transportService; + clusterService = testNodes[0].clusterService; + + ActionFilters actionFilters = mock(ActionFilters.class); + indexingPressure = mock(IndexingPressure.class); + + client = mock(Client.class); + + resultBulk = new ForecastResultBulkTransportAction( + transportService, + actionFilters, + indexingPressure, + settings, + clusterService, + client + ); + } + + @Override + @After + public final void tearDown() throws Exception { + tearDownTestNodes(); + super.tearDown(); + } + + @SuppressWarnings("unchecked") + public void testBulkIndexingFailure() throws IOException { + // Set indexing pressure below soft limit to ensure requests are processed + when(indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes()).thenReturn(0L); + when(indexingPressure.getCurrentReplicaBytes()).thenReturn(0L); + + // Create a ForecastResultBulkRequest with some results + ForecastResultBulkRequest originalRequest = new ForecastResultBulkRequest(); + originalRequest.add(TestHelpers.randomForecastResultWriteRequest()); + originalRequest.add(TestHelpers.randomForecastResultWriteRequest()); + + // Mock client.execute to throw an exception + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + listener.onFailure(new RuntimeException("Simulated bulk indexing failure")); + return null; + }).when(client).execute(any(), any(), any()); + + // Execute the action + PlainActionFuture future = PlainActionFuture.newFuture(); + resultBulk.doExecute(null, originalRequest, future); + + // Verify that the exception is propagated to the listener + Exception exception = expectThrows(Exception.class, () -> future.actionGet()); + assertTrue(exception.getMessage().contains("Simulated bulk indexing failure")); + } + + public void testPrepareBulkRequestFailure() throws IOException { + // Set indexing pressure below soft limit to ensure requests are processed + when(indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes()).thenReturn(0L); + when(indexingPressure.getCurrentReplicaBytes()).thenReturn(0L); + + // Create a ForecastResultWriteRequest with a result that throws IOException when toXContent is called + ForecastResultWriteRequest faultyWriteRequest = mock(ForecastResultWriteRequest.class); + ForecastResult faultyResult = mock(ForecastResult.class); + + when(faultyWriteRequest.getResult()).thenReturn(faultyResult); + when(faultyWriteRequest.getResultIndex()).thenReturn(null); + + // Mock the toXContent method to throw IOException + doThrow(new IOException("Simulated IOException in toXContent")).when(faultyResult).toXContent(any(XContentBuilder.class), any()); + + // Create a ForecastResultBulkRequest with the faulty write request + ForecastResultBulkRequest originalRequest = new ForecastResultBulkRequest(); + originalRequest.add(faultyWriteRequest); + + // Execute the prepareBulkRequest method directly + BulkRequest bulkRequest = resultBulk.prepareBulkRequest(0.5f, originalRequest); + + // Since the exception is caught inside addResult, bulkRequest should have zero actions + assertEquals(0, bulkRequest.numberOfActions()); + } +}