From 6a6771cefa48533418aba6b46708a2b1efabc4be Mon Sep 17 00:00:00 2001 From: Kaituo Li Date: Thu, 2 Feb 2023 08:31:11 -0800 Subject: [PATCH] fixing dls/fls logic around numeric aggregations Signed-off-by: Kaituo Li --- .../ad/AnomalyDetectorJobRunner.java | 31 +- .../opensearch/ad/AnomalyDetectorPlugin.java | 25 +- .../ad/AnomalyDetectorProfileRunner.java | 32 +- .../opensearch/ad/EntityProfileRunner.java | 21 +- .../ad/feature/CompositeRetriever.java | 30 +- .../opensearch/ad/feature/FeatureManager.java | 30 +- .../ad/feature/SearchFeatureDao.java | 280 +++++-------- .../AbstractAnomalyDetectorActionHandler.java | 27 +- .../IndexAnomalyDetectorActionHandler.java | 4 + .../ValidateAnomalyDetectorActionHandler.java | 4 + .../opensearch/ad/task/ADBatchTaskRunner.java | 36 +- .../AnomalyResultTransportAction.java | 7 +- .../GetAnomalyDetectorTransportAction.java | 8 +- .../IndexAnomalyDetectorTransportAction.java | 7 + ...alidateAnomalyDetectorTransportAction.java | 5 + .../ad/util/ADSafeSecurityInjector.java | 78 ++++ .../MultiResponsesDelegateActionListener.java | 2 +- .../org/opensearch/ad/util/ParseUtils.java | 8 + .../ad/util/SafeSecurityInjector.java | 87 ++++ .../ad/util/SecurityClientUtil.java | 130 ++++++ .../org/opensearch/ad/util/SecurityUtil.java | 77 ++++ .../org/opensearch/ad/util/Throttler.java | 14 +- ...ndexAnomalyDetectorActionHandlerTests.java | 21 + ...dateAnomalyDetectorActionHandlerTests.java | 10 + .../ad/AbstractProfileRunnerTests.java | 15 +- .../ad/AnomalyDetectorProfileRunnerTests.java | 24 +- .../ad/AnomalyDetectorRestTestCase.java | 53 ++- .../ad/EntityProfileRunnerTests.java | 43 +- .../ad/MultiEntityProfileRunnerTests.java | 31 +- .../org/opensearch/ad/ODFERestTestCase.java | 49 +++ .../java/org/opensearch/ad/TestHelpers.java | 23 +- .../ad/e2e/SingleStreamModelPerfIT.java | 1 + .../ad/feature/FeatureManagerTests.java | 20 - .../NoPowermockSearchFeatureDaoTests.java | 221 ++++++++-- .../ad/feature/SearchFeatureDaoTests.java | 394 ++---------------- .../opensearch/ad/rest/SecureADRestIT.java | 6 +- .../ad/transport/AnomalyResultTests.java | 29 ++ .../ad/transport/GetAnomalyDetectorTests.java | 12 + ...etAnomalyDetectorTransportActionTests.java | 26 +- ...exAnomalyDetectorTransportActionTests.java | 8 + .../ad/transport/MultiEntityResultTests.java | 21 +- .../metrics/CardinalityProfileTests.java | 30 +- 42 files changed, 1231 insertions(+), 749 deletions(-) create mode 100644 src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java create mode 100644 src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java create mode 100644 src/main/java/org/opensearch/ad/util/SecurityClientUtil.java create mode 100644 src/main/java/org/opensearch/ad/util/SecurityUtil.java diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java index 655267dcf..e0bdcf9c4 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java @@ -57,6 +57,7 @@ import org.opensearch.ad.transport.ProfileRequest; import org.opensearch.ad.transport.handler.AnomalyIndexHandler; import org.opensearch.ad.util.DiscoveryNodeFilterer; +import org.opensearch.ad.util.SecurityUtil; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.settings.Settings; @@ -76,7 +77,6 @@ import org.opensearch.threadpool.ThreadPool; import com.google.common.base.Throwables; -import com.google.common.collect.ImmutableList; /** * JobScheduler will call AD job runner to get anomaly result periodically @@ -218,28 +218,11 @@ protected void runAdJob( } anomalyDetectionIndices.update(); - /* - * We need to handle 3 cases: - * 1. Detectors created by older versions and never updated. These detectors wont have User details in the - * detector object. `detector.user` will be null. Insert `all_access, AmazonES_all_access` role. - * 2. Detectors are created when security plugin is disabled, these will have empty User object. - * (`detector.user.name`, `detector.user.roles` are empty ) - * 3. Detectors are created when security plugin is enabled, these will have an User object. - * This will inject user role and check if the user role has permissions to call the execute - * Anomaly Result API. - */ - String user; - List roles; - if (jobParameter.getUser() == null) { - // It's possible that user create domain with security disabled, then enable security - // after upgrading. This is for BWC, for old detectors which created when security - // disabled, the user will be null. - user = ""; - roles = settings.getAsList("", ImmutableList.of("all_access", "AmazonES_all_access")); - } else { - user = jobParameter.getUser().getName(); - roles = jobParameter.getUser().getRoles(); - } + User userInfo = SecurityUtil.getUserFromJob(jobParameter, settings); + + String user = userInfo.getName(); + List roles = userInfo.getRoles(); + String resultIndex = jobParameter.getResultIndex(); if (resultIndex == null) { runAnomalyDetectionJob(jobParameter, lockService, lock, detectionStartTime, executionStartTime, detectorId, user, roles); @@ -265,7 +248,7 @@ private void runAnomalyDetectionJob( String user, List roles ) { - + // using one thread in the write threadpool try (InjectSecurity injectSecurity = new InjectSecurity(detectorId, settings, client.threadPool().getThreadContext())) { // Injecting user role to verify if the user has permissions for our API. injectSecurity.inject(user, roles); diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java b/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java index 92552f27d..624e0ec04 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java @@ -159,6 +159,7 @@ import org.opensearch.ad.util.ClientUtil; import org.opensearch.ad.util.DiscoveryNodeFilterer; import org.opensearch.ad.util.IndexUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; @@ -230,6 +231,7 @@ public class AnomalyDetectorPlugin extends Plugin implements ActionPlugin, Scrip private ThreadPool threadPool; private ADStats adStats; private ClientUtil clientUtil; + private SecurityClientUtil securityClientUtil; private DiscoveryNodeFilterer nodeFilter; private IndexUtils indexUtils; private ADTaskCacheManager adTaskCacheManager; @@ -352,11 +354,21 @@ public Collection createComponents( SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator = new IntegerSensitiveSingleFeatureLinearUniformInterpolator(); Interpolator interpolator = new LinearUniformInterpolator(singleFeatureLinearUniformInterpolator); + NodeStateManager stateManager = new NodeStateManager( + client, + xContentRegistry, + settings, + clientUtil, + getClock(), + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + clusterService + ); + securityClientUtil = new SecurityClientUtil(stateManager, settings); SearchFeatureDao searchFeatureDao = new SearchFeatureDao( client, xContentRegistry, interpolator, - clientUtil, + securityClientUtil, settings, clusterService, AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE @@ -381,16 +393,6 @@ public Collection createComponents( adCircuitBreakerService ); - NodeStateManager stateManager = new NodeStateManager( - client, - xContentRegistry, - settings, - clientUtil, - getClock(), - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - clusterService - ); - FeatureManager featureManager = new FeatureManager( searchFeatureDao, interpolator, @@ -699,6 +701,7 @@ public PooledObject wrap(LinkedBuffer obj) { threadPool, clusterService, client, + securityClientUtil, adCircuitBreakerService, featureManager, adTaskManager, diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java index d47494e5a..f386a5338 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java @@ -58,6 +58,7 @@ import org.opensearch.ad.util.DiscoveryNodeFilterer; import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.ad.util.MultiResponsesDelegateActionListener; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -81,6 +82,7 @@ public class AnomalyDetectorProfileRunner extends AbstractProfileRunner { private final Logger logger = LogManager.getLogger(AnomalyDetectorProfileRunner.class); private Client client; + private SecurityClientUtil clientUtil; private NamedXContentRegistry xContentRegistry; private DiscoveryNodeFilterer nodeFilter; private final TransportService transportService; @@ -89,6 +91,7 @@ public class AnomalyDetectorProfileRunner extends AbstractProfileRunner { public AnomalyDetectorProfileRunner( Client client, + SecurityClientUtil clientUtil, NamedXContentRegistry xContentRegistry, DiscoveryNodeFilterer nodeFilter, long requiredSamples, @@ -97,6 +100,7 @@ public AnomalyDetectorProfileRunner( ) { super(requiredSamples); this.client = client; + this.clientUtil = clientUtil; this.xContentRegistry = xContentRegistry; this.nodeFilter = nodeFilter; if (requiredSamples <= 0) { @@ -296,7 +300,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener { + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { Map aggMap = searchResponse.getAggregations().asMap(); InternalCardinality totalEntities = (InternalCardinality) aggMap.get(CommonName.TOTAL_ENTITIES); long value = totalEntities.getValue(); @@ -306,7 +310,17 @@ private void profileEntityStats(MultiResponsesDelegateActionListener { logger.warn(CommonErrorMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getDetectorId()); listener.onFailure(searchException); - })); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + detector.getDetectorId(), + client, + searchResponseListener + ); } else { // Run a composite query and count the number of buckets to decide cardinality of multiple category fields AggregationBuilder bucketAggs = AggregationBuilders @@ -319,7 +333,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener { + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); Aggregations aggs = searchResponse.getAggregations(); if (aggs == null) { @@ -345,7 +359,17 @@ private void profileEntityStats(MultiResponsesDelegateActionListener { logger.warn(CommonErrorMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getDetectorId()); listener.onFailure(searchException); - })); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getDetectorId(), + client, + searchResponseListener + ); } } diff --git a/src/main/java/org/opensearch/ad/EntityProfileRunner.java b/src/main/java/org/opensearch/ad/EntityProfileRunner.java index 11c4015e4..e69b1b90b 100644 --- a/src/main/java/org/opensearch/ad/EntityProfileRunner.java +++ b/src/main/java/org/opensearch/ad/EntityProfileRunner.java @@ -26,6 +26,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.AnomalyDetector; @@ -43,6 +44,7 @@ import org.opensearch.ad.transport.EntityProfileResponse; import org.opensearch.ad.util.MultiResponsesDelegateActionListener; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.routing.Preference; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -64,11 +66,13 @@ public class EntityProfileRunner extends AbstractProfileRunner { static final String EMPTY_ENTITY_ATTRIBUTES = "Empty entity attributes"; static final String NO_ENTITY = "Cannot find entity"; private Client client; + private SecurityClientUtil clientUtil; private NamedXContentRegistry xContentRegistry; - public EntityProfileRunner(Client client, NamedXContentRegistry xContentRegistry, long requiredSamples) { + public EntityProfileRunner(Client client, SecurityClientUtil clientUtil, NamedXContentRegistry xContentRegistry, long requiredSamples) { super(requiredSamples); this.client = client; + this.clientUtil = clientUtil; this.xContentRegistry = xContentRegistry; } @@ -165,8 +169,7 @@ private void validateEntity( SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder) .preference(Preference.LOCAL.toString()); - - client.search(searchRequest, ActionListener.wrap(searchResponse -> { + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { try { if (searchResponse.getHits().getHits().length == 0) { listener.onFailure(new IllegalArgumentException(NO_ENTITY)); @@ -177,7 +180,17 @@ private void validateEntity( listener.onFailure(new IllegalArgumentException(NO_ENTITY)); return; } - }, e -> listener.onFailure(new IllegalArgumentException(NO_ENTITY)))); + }, e -> listener.onFailure(new IllegalArgumentException(NO_ENTITY))); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getDetectorId(), + client, + searchResponseListener + ); } diff --git a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java b/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java index bb9446a29..036c2773c 100644 --- a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java +++ b/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java @@ -29,6 +29,7 @@ import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.Feature; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.NamedXContentRegistry; @@ -61,6 +62,7 @@ public class CompositeRetriever extends AbstractRetriever { private final AnomalyDetector anomalyDetector; private final NamedXContentRegistry xContent; private final Client client; + private final SecurityClientUtil clientUtil; private int totalResults; private int maxEntities; private final int pageSize; @@ -73,6 +75,7 @@ public CompositeRetriever( AnomalyDetector anomalyDetector, NamedXContentRegistry xContent, Client client, + SecurityClientUtil clientUtil, long expirationEpochMs, Clock clock, Settings settings, @@ -84,6 +87,7 @@ public CompositeRetriever( this.anomalyDetector = anomalyDetector; this.xContent = xContent; this.client = client; + this.clientUtil = clientUtil; this.totalResults = 0; this.maxEntities = maxEntitiesPerInterval; this.pageSize = pageSize; @@ -98,6 +102,7 @@ public CompositeRetriever( AnomalyDetector anomalyDetector, NamedXContentRegistry xContent, Client client, + SecurityClientUtil clientUtil, long expirationEpochMs, Settings settings, int maxEntitiesPerInterval, @@ -109,6 +114,7 @@ public CompositeRetriever( anomalyDetector, xContent, client, + clientUtil, expirationEpochMs, Clock.systemUTC(), settings, @@ -157,10 +163,13 @@ public class PageIterator { private SearchSourceBuilder source; // a map from categorical field name to values (type: java.lang.Comparable) Map afterKey; + // number of iterations so far + private int iterations; public PageIterator(SearchSourceBuilder source) { this.source = source; this.afterKey = null; + this.iterations = 0; } /** @@ -168,8 +177,11 @@ public PageIterator(SearchSourceBuilder source) { * @param listener Listener to return results */ public void next(ActionListener listener) { + iterations++; + + // inject user role while searching. SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0]), source); - client.search(searchRequest, new ActionListener() { + final ActionListener searchResponseListener = new ActionListener() { @Override public void onResponse(SearchResponse response) { processResponse(response, () -> client.search(searchRequest, this), listener); @@ -179,7 +191,17 @@ public void onResponse(SearchResponse response) { public void onFailure(Exception e) { listener.onFailure(e); } - }); + }; + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + anomalyDetector.getDetectorId(), + client, + searchResponseListener + ); } private void processResponse(SearchResponse response, Runnable retry, ActionListener listener) { @@ -301,12 +323,12 @@ Optional getComposite(SearchResponse response) { /** * Whether next page exists. Conditions are: - * 1) we haven't fetched any page yet (totalResults == 0) or afterKey is not null + * 1)this is the first time we query (iterations == 0) or afterKey is not null * 2) next detection interval has not started * @return true if the iteration has more pages. */ public boolean hasNext() { - return (totalResults == 0 || (totalResults > 0 && afterKey != null)) && expirationEpochMs > clock.millis(); + return (iterations == 0 || (totalResults > 0 && afterKey != null)) && expirationEpochMs > clock.millis(); } @Override diff --git a/src/main/java/org/opensearch/ad/feature/FeatureManager.java b/src/main/java/org/opensearch/ad/feature/FeatureManager.java index cec226f55..f04f24926 100644 --- a/src/main/java/org/opensearch/ad/feature/FeatureManager.java +++ b/src/main/java/org/opensearch/ad/feature/FeatureManager.java @@ -278,32 +278,6 @@ private LongStream getFullShingleEndTimes(long endTime, long intervalMilli, int return LongStream.rangeClosed(1, shingleSize).map(i -> endTime - (shingleSize - i) * intervalMilli); } - /** - * Provides data for cold-start training. - * - * @deprecated use getColdStartData with listener instead. - * - * Training data starts with getting samples from (costly) search. - * Samples are increased in size via interpolation and then - * in dimension via shingling. - * - * @param detector contains data info (indices, documents, etc) - * @return data for cold-start training, or empty if unavailable - */ - @Deprecated - public Optional getColdStartData(AnomalyDetector detector) { - int shingleSize = detector.getShingleSize(); - return searchFeatureDao - .getLatestDataTime(detector) - .flatMap(latest -> searchFeatureDao.getFeaturesForSampledPeriods(detector, maxTrainSamples, maxSampleStride, latest)) - .map( - samples -> transpose( - interpolator.interpolate(transpose(samples.getKey()), samples.getValue() * (samples.getKey().length - 1) + 1) - ) - ) - .map(points -> batchShingle(points, shingleSize)); - } - /** * Returns to listener data for cold-start training. * @@ -488,7 +462,7 @@ public void getPreviewFeaturesForEntity( int stride = sampleRangeResults.getValue(); int shingleSize = detector.getShingleSize(); - getSamplesInRangesForEntity(detector, sampleRanges, entity, getFeatureSamplesListener(stride, shingleSize, listener)); + getPreviewSamplesInRangesForEntity(detector, sampleRanges, entity, getFeatureSamplesListener(stride, shingleSize, listener)); } private ActionListener>, double[][]>> getFeatureSamplesListener( @@ -564,7 +538,7 @@ private Entry>, Integer> getSampleRanges(AnomalyDetector * @param listener handle search results map: key is time ranges, value is corresponding search results * @throws IOException if a user gives wrong query input when defining a detector */ - void getSamplesInRangesForEntity( + void getPreviewSamplesInRangesForEntity( AnomalyDetector detector, List> sampleRanges, Entity entity, diff --git a/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java b/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java index 599d6aacd..ae19930de 100644 --- a/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java +++ b/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java @@ -46,8 +46,8 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -68,7 +68,6 @@ import org.opensearch.search.aggregations.bucket.range.InternalDateRange; import org.opensearch.search.aggregations.bucket.range.InternalDateRange.Bucket; import org.opensearch.search.aggregations.bucket.terms.Terms; -import org.opensearch.search.aggregations.metrics.Max; import org.opensearch.search.aggregations.metrics.Min; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.FieldSortBuilder; @@ -88,7 +87,7 @@ public class SearchFeatureDao extends AbstractRetriever { private final Client client; private final NamedXContentRegistry xContent; private final Interpolator interpolator; - private final ClientUtil clientUtil; + private final SecurityClientUtil clientUtil; private volatile int maxEntitiesForPreview; private volatile int pageSize; private final int minimumDocCountForPreview; @@ -100,7 +99,7 @@ public SearchFeatureDao( Client client, NamedXContentRegistry xContent, Interpolator interpolator, - ClientUtil clientUtil, + SecurityClientUtil clientUtil, Settings settings, ClusterService clusterService, int minimumDocCount, @@ -138,7 +137,7 @@ public SearchFeatureDao( Client client, NamedXContentRegistry xContent, Interpolator interpolator, - ClientUtil clientUtil, + SecurityClientUtil clientUtil, Settings settings, ClusterService clusterService, int minimumDocCount @@ -158,28 +157,6 @@ public SearchFeatureDao( ); } - /** - * Returns epoch time of the latest data under the detector. - * - * @deprecated use getLatestDataTime with listener instead. - * - * @param detector info about the indices and documents - * @return epoch time of the latest data in milliseconds - */ - @Deprecated - public Optional getLatestDataTime(AnomalyDetector detector) { - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField())) - .size(0); - SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - return clientUtil - .timedRequest(searchRequest, logger, client::search) - .map(SearchResponse::getAggregations) - .map(aggs -> aggs.asMap()) - .map(map -> (Max) map.get(CommonName.AGG_NAME_MAX_TIME)) - .map(agg -> (long) agg.getValue()); - } - /** * Returns to listener the epoch time of the latset data under the detector. * @@ -191,10 +168,17 @@ public void getLatestDataTime(AnomalyDetector detector, ActionListener searchResponseListener = ActionListener + .wrap(response -> listener.onResponse(ParseUtils.getLatestDataTime(response)), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - ActionListener.wrap(response -> listener.onResponse(ParseUtils.getLatestDataTime(response)), listener::onFailure) + client::search, + detector.getDetectorId(), + client, + searchResponseListener ); } @@ -356,18 +340,24 @@ public void getHighestCountEntities( .trackTotalHits(false) .size(0); SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client - .search( + final ActionListener searchResponseListener = new TopEntitiesListener( + listener, + detector, + searchSourceBuilder, + // TODO: tune timeout for historical analysis based on performance test result + clock.millis() + previewTimeoutInMilliseconds, + maxEntitiesSize, + minimumDocCount + ); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - new TopEntitiesListener( - listener, - detector, - searchSourceBuilder, - // TODO: tune timeout for historical analysis based on performance test result - clock.millis() + previewTimeoutInMilliseconds, - maxEntitiesSize, - minimumDocCount - ) + client::search, + detector.getDetectorId(), + client, + searchResponseListener ); } @@ -451,9 +441,14 @@ public void onResponse(SearchResponse response) { } } else { updateSourceAfterKey(afterKey, searchSourceBuilder); - client - .search( + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder), + client::search, + detector.getDetectorId(), + client, this ); } @@ -489,10 +484,16 @@ public void getEntityMinDataTime(AnomalyDetector detector, Entity entity, Action .trackTotalHits(false) .size(0); SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client - .search( + final ActionListener searchResponseListener = ActionListener + .wrap(response -> { listener.onResponse(parseMinDataTime(response)); }, listener::onFailure); + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - ActionListener.wrap(response -> { listener.onResponse(parseMinDataTime(response)); }, listener::onFailure) + client::search, + detector.getDetectorId(), + client, + searchResponseListener ); } @@ -505,31 +506,6 @@ private Optional parseMinDataTime(SearchResponse searchResponse) { return mapOptional.map(map -> (Min) map.get(AGG_NAME_MIN)).map(agg -> (long) agg.getValue()); } - /** - * Gets features for the given time period. - * This function also adds given detector to negative cache before sending es request. - * Once response/exception is received within timeout, this request will be treated as complete - * and cleared from the negative cache. - * Otherwise this detector entry remain in the negative to reject further request. - * - * @deprecated use getFeaturesForPeriod with listener instead. - * - * @param detector info about indices, documents, feature query - * @param startTime epoch milliseconds at the beginning of the period - * @param endTime epoch milliseconds at the end of the period - * @throws IllegalStateException when unexpected failures happen - * @return features from search results, empty when no data found - */ - @Deprecated - public Optional getFeaturesForPeriod(AnomalyDetector detector, long startTime, long endTime) { - SearchRequest searchRequest = createFeatureSearchRequest(detector, startTime, endTime, Optional.empty()); - - // send throttled request: this request will clear the negative cache if the request finished within timeout - return clientUtil - .throttledTimedRequest(searchRequest, logger, client::search, detector) - .flatMap(resp -> parseResponse(resp, detector.getEnabledFeatureIds())); - } - /** * Returns to listener features for the given time period. * @@ -540,11 +516,17 @@ public Optional getFeaturesForPeriod(AnomalyDetector detector, long st */ public void getFeaturesForPeriod(AnomalyDetector detector, long startTime, long endTime, ActionListener> listener) { SearchRequest searchRequest = createFeatureSearchRequest(detector, startTime, endTime, Optional.empty()); - client - .search( + final ActionListener searchResponseListener = ActionListener + .wrap(response -> listener.onResponse(parseResponse(response, detector.getEnabledFeatureIds())), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - ActionListener - .wrap(response -> listener.onResponse(parseResponse(response, detector.getEnabledFeatureIds())), listener::onFailure) + client::search, + detector.getDetectorId(), + client, + searchResponseListener ); } @@ -559,14 +541,19 @@ public void getFeaturesForPeriodByBatch( logger.debug("Batch query for detector {}: {} ", detector.getDetectorId(), searchSourceBuilder); SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client - .search( + final ActionListener searchResponseListener = ActionListener + .wrap( + response -> { listener.onResponse(parseBucketAggregationResponse(response, detector.getEnabledFeatureIds())); }, + listener::onFailure + ); + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - ActionListener - .wrap( - response -> { listener.onResponse(parseBucketAggregationResponse(response, detector.getEnabledFeatureIds())); }, - listener::onFailure - ) + client::search, + detector.getDetectorId(), + client, + searchResponseListener ); } @@ -604,8 +591,7 @@ public void getFeatureSamplesForPeriods( ActionListener>> listener ) throws IOException { SearchRequest request = createPreviewSearchRequest(detector, ranges); - - client.search(request, ActionListener.wrap(response -> { + final ActionListener searchResponseListener = ActionListener.wrap(response -> { Aggregations aggs = response.getAggregations(); if (aggs == null) { listener.onResponse(Collections.emptyList()); @@ -622,105 +608,16 @@ public void getFeatureSamplesForPeriods( .map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds())) .collect(Collectors.toList()) ); - }, listener::onFailure)); - } - - /** - * Gets features for sampled periods. - * - * @deprecated use getFeaturesForSampledPeriods with listener instead. - * - * Sampling starts with the latest period and goes backwards in time until there are up to {@code maxSamples} samples. - * If the initial stride {@code maxStride} results into a low count of samples, the implementation - * may attempt with (exponentially) reduced strides and interpolate missing points. - * - * @param detector info about indices, documents, feature query - * @param maxSamples the maximum number of samples to return - * @param maxStride the maximum number of periods between samples - * @param endTime the end time of the latest period - * @return sampled features and stride, empty when no data found - */ - @Deprecated - public Optional> getFeaturesForSampledPeriods( - AnomalyDetector detector, - int maxSamples, - int maxStride, - long endTime - ) { - Map cache = new HashMap<>(); - int currentStride = maxStride; - Optional features = Optional.empty(); - logger.info(String.format(Locale.ROOT, "Getting features for detector %s starting %d", detector.getDetectorId(), endTime)); - while (currentStride >= 1) { - boolean isInterpolatable = currentStride < maxStride; - features = getFeaturesForSampledPeriods(detector, maxSamples, currentStride, endTime, cache, isInterpolatable); - - if (!features.isPresent() || features.get().length > maxSamples / 2 || currentStride == 1) { - logger - .info( - String - .format( - Locale.ROOT, - "Get features for detector %s finishes with features present %b, current stride %d", - detector.getDetectorId(), - features.isPresent(), - currentStride - ) - ); - break; - } else { - currentStride = currentStride / 2; - } - } - if (features.isPresent()) { - return Optional.of(new SimpleEntry<>(features.get(), currentStride)); - } else { - return Optional.empty(); - } - } - - private Optional getFeaturesForSampledPeriods( - AnomalyDetector detector, - int maxSamples, - int stride, - long endTime, - Map cache, - boolean isInterpolatable - ) { - ArrayDeque sampledFeatures = new ArrayDeque<>(maxSamples); - for (int i = 0; i < maxSamples; i++) { - long span = ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMillis(); - long end = endTime - span * stride * i; - if (cache.containsKey(end)) { - sampledFeatures.addFirst(cache.get(end)); - } else { - Optional features = getFeaturesForPeriod(detector, end - span, end); - if (features.isPresent()) { - cache.put(end, features.get()); - sampledFeatures.addFirst(features.get()); - } else if (isInterpolatable) { - Optional previous = Optional.ofNullable(cache.get(end - span * stride)); - Optional next = Optional.ofNullable(cache.get(end + span * stride)); - if (previous.isPresent() && next.isPresent()) { - double[] interpolants = getInterpolants(previous.get(), next.get()); - cache.put(end, interpolants); - sampledFeatures.addFirst(interpolants); - } else { - break; - } - } else { - break; - } - - } - } - Optional samples; - if (sampledFeatures.isEmpty()) { - samples = Optional.empty(); - } else { - samples = Optional.of(sampledFeatures.toArray(new double[0][0])); - } - return samples; + }, listener::onFailure); + // inject user role while searching + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + detector.getDetectorId(), + client, + searchResponseListener + ); } /** @@ -962,10 +859,9 @@ public void getColdStartSamplesForPeriods( Entity entity, boolean includesEmptyBucket, ActionListener>> listener - ) throws IOException { + ) { SearchRequest request = createColdStartFeatureSearchRequest(detector, ranges, entity); - - client.search(request, ActionListener.wrap(response -> { + final ActionListener searchResponseListener = ActionListener.wrap(response -> { Aggregations aggs = response.getAggregations(); if (aggs == null) { listener.onResponse(Collections.emptyList()); @@ -997,7 +893,17 @@ public void getColdStartSamplesForPeriods( .map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds())) .collect(Collectors.toList()) ); - }, listener::onFailure)); + }, listener::onFailure); + + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + detector.getDetectorId(), + client, + searchResponseListener + ); } private SearchRequest createColdStartFeatureSearchRequest(AnomalyDetector detector, List> ranges, Entity entity) { diff --git a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java index 11afab837..ffb2a49f2 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java @@ -65,6 +65,7 @@ import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; import org.opensearch.ad.util.MultiResponsesDelegateActionListener; import org.opensearch.ad.util.RestHandlerUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.unit.TimeValue; @@ -139,6 +140,7 @@ public abstract class AbstractAnomalyDetectorActionHandler listener; @@ -152,6 +154,7 @@ public abstract class AbstractAnomalyDetectorActionHandler listener, AnomalyDetectionIndices anomalyDetectionIndices, @@ -195,6 +199,7 @@ public AbstractAnomalyDetectorActionHandler( ) { this.clusterService = clusterService; this.client = client; + this.clientUtil = clientUtil; this.transportService = transportService; this.anomalyDetectionIndices = anomalyDetectionIndices; this.listener = listener; @@ -594,7 +599,7 @@ protected void validateCategoricalField(String detectorId, boolean indexingDryRu listener.onFailure(new IllegalArgumentException(message)); }); - client.execute(GetFieldMappingsAction.INSTANCE, getMappingsRequest, mappingsListener); + clientUtil.executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, mappingsListener); } protected void searchAdInputIndices(String detectorId, boolean indexingDryRun) { @@ -605,15 +610,13 @@ protected void searchAdInputIndices(String detectorId, boolean indexingDryRun) { SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client - .search( - searchRequest, - ActionListener - .wrap( - searchResponse -> onSearchAdInputIndicesResponse(searchResponse, detectorId, indexingDryRun), - exception -> listener.onFailure(exception) - ) + ActionListener searchResponseListener = ActionListener + .wrap( + searchResponse -> onSearchAdInputIndicesResponse(searchResponse, detectorId, indexingDryRun), + exception -> listener.onFailure(exception) ); + + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, searchResponseListener); } protected void onSearchAdInputIndicesResponse(SearchResponse response, String detectorId, boolean indexingDryRun) throws IOException { @@ -640,7 +643,6 @@ protected void checkADNameExists(String detectorId, boolean indexingDryRun) thro } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).timeout(requestTimeout); SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder); - client .search( searchRequest, @@ -830,7 +832,7 @@ protected void validateAnomalyDetectorFeatures(String detectorId, boolean indexi ); ssb.aggregation(internalAgg.getAggregatorFactories().iterator().next()); SearchRequest searchRequest = new SearchRequest().indices(anomalyDetector.getIndices().toArray(new String[0])).source(ssb); - client.search(searchRequest, ActionListener.wrap(response -> { + ActionListener searchResponseListener = ActionListener.wrap(response -> { Optional aggFeatureResult = searchFeatureDao.parseResponse(response, Arrays.asList(feature.getId())); if (aggFeatureResult.isPresent()) { multiFeatureQueriesResponseListener @@ -852,7 +854,8 @@ protected void validateAnomalyDetectorFeatures(String detectorId, boolean indexi logger.error(errorMessage, e); multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST, e)); - })); + }); + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, searchResponseListener); } } } diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java index e3dc3661c..cdf9ca7eb 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java @@ -18,6 +18,7 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.unit.TimeValue; @@ -38,6 +39,7 @@ public class IndexAnomalyDetectorActionHandler extends AbstractAnomalyDetectorAc * * @param clusterService ClusterService * @param client ES node client that executes actions on the local node + * @param clientUtil AD client util * @param transportService ES transport service * @param listener ES channel used to construct bytes / builder based outputs, and send responses * @param anomalyDetectionIndices anomaly detector index manager @@ -59,6 +61,7 @@ public class IndexAnomalyDetectorActionHandler extends AbstractAnomalyDetectorAc public IndexAnomalyDetectorActionHandler( ClusterService clusterService, Client client, + SecurityClientUtil clientUtil, TransportService transportService, ActionListener listener, AnomalyDetectionIndices anomalyDetectionIndices, @@ -80,6 +83,7 @@ public IndexAnomalyDetectorActionHandler( super( clusterService, client, + clientUtil, transportService, listener, anomalyDetectionIndices, diff --git a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java index a88c40eff..aa97b4fe0 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java @@ -23,6 +23,7 @@ import org.opensearch.ad.model.ValidationAspect; import org.opensearch.ad.rest.RestValidateAnomalyDetectorAction; import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.unit.TimeValue; @@ -46,6 +47,7 @@ public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetecto * * @param clusterService ClusterService * @param client ES node client that executes actions on the local node + * @param clientUtil AD client utility * @param listener ES channel used to construct bytes / builder based outputs, and send responses * @param anomalyDetectionIndices anomaly detector index manager * @param anomalyDetector anomaly detector instance @@ -62,6 +64,7 @@ public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetecto public ValidateAnomalyDetectorActionHandler( ClusterService clusterService, Client client, + SecurityClientUtil clientUtil, ActionListener listener, AnomalyDetectionIndices anomalyDetectionIndices, AnomalyDetector anomalyDetector, @@ -78,6 +81,7 @@ public ValidateAnomalyDetectorActionHandler( super( clusterService, client, + clientUtil, null, listener, anomalyDetectionIndices, diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java index 48a8ab37b..b080b4358 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java @@ -50,6 +50,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ThreadedActionListener; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.caching.PriorityTracker; @@ -89,6 +90,7 @@ import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -123,6 +125,7 @@ public class ADBatchTaskRunner { private Settings settings; private final ThreadPool threadPool; private final Client client; + private final SecurityClientUtil clientUtil; private final ADStats adStats; private final ClusterService clusterService; private final FeatureManager featureManager; @@ -151,6 +154,7 @@ public ADBatchTaskRunner( ThreadPool threadPool, ClusterService clusterService, Client client, + SecurityClientUtil clientUtil, ADCircuitBreakerService adCircuitBreakerService, FeatureManager featureManager, ADTaskManager adTaskManager, @@ -166,6 +170,7 @@ public ADBatchTaskRunner( this.threadPool = threadPool; this.clusterService = clusterService; this.client = client; + this.clientUtil = clientUtil; this.anomalyResultBulkIndexHandler = anomalyResultBulkIndexHandler; this.adStats = adStats; this.adCircuitBreakerService = adCircuitBreakerService; @@ -438,7 +443,7 @@ private void searchTopEntitiesForSingleCategoryHC( SearchRequest searchRequest = new SearchRequest(); searchRequest.source(sourceBuilder); searchRequest.indices(adTask.getDetector().getIndices().toArray(new String[0])); - client.search(searchRequest, ActionListener.wrap(r -> { + final ActionListener searchResponseListener = ActionListener.wrap(r -> { StringTerms stringTerms = r.getAggregations().get(topEntitiesAgg); List buckets = stringTerms.getBuckets(); List topEntities = new ArrayList<>(); @@ -473,7 +478,18 @@ private void searchTopEntitiesForSingleCategoryHC( }, e -> { logger.error("Failed to get top entities for detector " + adTask.getDetectorId(), e); internalHCListener.onFailure(e); - })); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + // user is the one who started historical detector. Read AnomalyDetectorJobTransportAction.doExecute. + adTask.getUser(), + client, + searchResponseListener + ); } /** @@ -949,8 +965,7 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons SearchRequest request = new SearchRequest() .indices(adTask.getDetector().getIndices().toArray(new String[0])) .source(searchSourceBuilder); - - client.search(request, ActionListener.wrap(r -> { + final ActionListener searchResponseListener = ActionListener.wrap(r -> { InternalMin minAgg = r.getAggregations().get(AGG_NAME_MIN_TIME); InternalMax maxAgg = r.getAggregations().get(AGG_NAME_MAX_TIME); double minValue = minAgg.getValue(); @@ -989,7 +1004,18 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons return; } consumer.accept(dataStartTime, dataEndTime); - }, e -> { internalListener.onFailure(e); })); + }, e -> { internalListener.onFailure(e); }); + + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + // user is the one who started historical detector. Read AnomalyDetectorJobTransportAction.doExecute. + adTask.getUser(), + client, + searchResponseListener + ); } private void getFeatureData( diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java index 42142a286..53a6ce979 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java @@ -74,6 +74,7 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.block.ClusterBlockLevel; @@ -126,6 +127,7 @@ public class AnomalyResultTransportAction extends HandledTransportAction onFeatureResponseForSingleEntityDete ) { return ActionListener.wrap(featureOptional -> { List featureInResponse = null; - if (featureOptional.getUnprocessedFeatures().isPresent()) { featureInResponse = ParseUtils.getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector); } diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java index 563153ec1..f42f700af 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java @@ -58,6 +58,7 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.util.DiscoveryNodeFilterer; import org.opensearch.ad.util.RestHandlerUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.CheckedConsumer; @@ -80,7 +81,7 @@ public class GetAnomalyDetectorTransportAction extends HandledTransportAction allProfileTypeStrs; private final Set allProfileTypes; private final Set defaultDetectorProfileTypes; @@ -100,6 +101,7 @@ public GetAnomalyDetectorTransportAction( ActionFilters actionFilters, ClusterService clusterService, Client client, + SecurityClientUtil clientUtil, Settings settings, NamedXContentRegistry xContentRegistry, ADTaskManager adTaskManager @@ -107,7 +109,7 @@ public GetAnomalyDetectorTransportAction( super(GetAnomalyDetectorAction.NAME, transportService, actionFilters, GetAnomalyDetectorRequest::new); this.clusterService = clusterService; this.client = client; - + this.clientUtil = clientUtil; List allProfiles = Arrays.asList(DetectorProfileName.values()); this.allProfileTypes = EnumSet.copyOf(allProfiles); this.allProfileTypeStrs = getProfileListStrs(allProfiles); @@ -165,6 +167,7 @@ protected void getExecute(GetAnomalyDetectorRequest request, ActionListener entityProfilesToCollect = getEntityProfilesToCollect(typesStr, all); EntityProfileRunner profileRunner = new EntityProfileRunner( client, + clientUtil, xContentRegistry, AnomalyDetectorSettings.NUM_MIN_SAMPLES ); @@ -203,6 +206,7 @@ protected void getExecute(GetAnomalyDetectorRequest request, ActionListener profilesToCollect = getProfilesToCollect(typesStr, all); AnomalyDetectorProfileRunner profileRunner = new AnomalyDetectorProfileRunner( client, + clientUtil, xContentRegistry, nodeFilter, AnomalyDetectorSettings.NUM_MIN_SAMPLES, diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java index d40bd295e..a0b9751e2 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java @@ -36,6 +36,7 @@ import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -53,6 +54,7 @@ public class IndexAnomalyDetectorTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(IndexAnomalyDetectorTransportAction.class); private final Client client; + private final SecurityClientUtil clientUtil; private final TransportService transportService; private final AnomalyDetectionIndices anomalyDetectionIndices; private final ClusterService clusterService; @@ -60,12 +62,14 @@ public class IndexAnomalyDetectorTransportAction extends HandledTransportAction< private final ADTaskManager adTaskManager; private volatile Boolean filterByEnabled; private final SearchFeatureDao searchFeatureDao; + private final Settings settings; @Inject public IndexAnomalyDetectorTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SecurityClientUtil clientUtil, ClusterService clusterService, Settings settings, AnomalyDetectionIndices anomalyDetectionIndices, @@ -75,6 +79,7 @@ public IndexAnomalyDetectorTransportAction( ) { super(IndexAnomalyDetectorAction.NAME, transportService, actionFilters, IndexAnomalyDetectorRequest::new); this.client = client; + this.clientUtil = clientUtil; this.transportService = transportService; this.clusterService = clusterService; this.anomalyDetectionIndices = anomalyDetectionIndices; @@ -83,6 +88,7 @@ public IndexAnomalyDetectorTransportAction( this.searchFeatureDao = searchFeatureDao; filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.settings = settings; } @Override @@ -157,6 +163,7 @@ protected void adExecute( IndexAnomalyDetectorActionHandler indexAnomalyDetectorActionHandler = new IndexAnomalyDetectorActionHandler( clusterService, client, + clientUtil, transportService, listener, anomalyDetectionIndices, diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java index 2194982a1..9f1b1857e 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java @@ -37,6 +37,7 @@ import org.opensearch.ad.rest.handler.AnomalyDetectorFunction; import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -56,6 +57,7 @@ public class ValidateAnomalyDetectorTransportAction extends private static final Logger logger = LogManager.getLogger(ValidateAnomalyDetectorTransportAction.class); private final Client client; + private final SecurityClientUtil clientUtil; private final ClusterService clusterService; private final NamedXContentRegistry xContentRegistry; private final AnomalyDetectionIndices anomalyDetectionIndices; @@ -65,6 +67,7 @@ public class ValidateAnomalyDetectorTransportAction extends @Inject public ValidateAnomalyDetectorTransportAction( Client client, + SecurityClientUtil clientUtil, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Settings settings, @@ -75,6 +78,7 @@ public ValidateAnomalyDetectorTransportAction( ) { super(ValidateAnomalyDetectorAction.NAME, transportService, actionFilters, ValidateAnomalyDetectorRequest::new); this.client = client; + this.clientUtil = clientUtil; this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; this.anomalyDetectionIndices = anomalyDetectionIndices; @@ -139,6 +143,7 @@ private void validateExecute( ValidateAnomalyDetectorActionHandler handler = new ValidateAnomalyDetectorActionHandler( clusterService, client, + clientUtil, validateListener, anomalyDetectionIndices, detector, diff --git a/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java b/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java new file mode 100644 index 000000000..8fcd11ce0 --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.common.exception.EndRunException; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; + +public class ADSafeSecurityInjector extends SafeSecurityInjector { + private static final Logger LOG = LogManager.getLogger(ADSafeSecurityInjector.class); + private NodeStateManager nodeStateManager; + + public ADSafeSecurityInjector(String detectorId, Settings settings, ThreadContext tc, NodeStateManager stateManager) { + super(detectorId, settings, tc); + this.nodeStateManager = stateManager; + } + + public void injectUserRolesFromDetector(ActionListener injectListener) { + // if id is null, we cannot fetch a detector + if (Strings.isEmpty(id)) { + LOG.debug("Empty id"); + injectListener.onResponse(null); + return; + } + + // for example, if a user exists in thread context, we don't need to inject user/roles + if (!shouldInject()) { + LOG.debug("Don't need to inject"); + injectListener.onResponse(null); + return; + } + + ActionListener> getDetectorListener = ActionListener.wrap(detectorOp -> { + if (!detectorOp.isPresent()) { + injectListener.onFailure(new EndRunException(id, "AnomalyDetector is not available.", false)); + return; + } + AnomalyDetector detector = detectorOp.get(); + User userInfo = SecurityUtil.getUserFromDetector(detector, settings); + inject(userInfo.getName(), userInfo.getRoles()); + injectListener.onResponse(null); + }, injectListener::onFailure); + + // Since we are gonna read user from detector, make sure the anomaly detector exists and fetched from disk or cached memory + // We don't accept a passed-in AnomalyDetector because the caller might mistakenly not insert any user info in the + // constructed AnomalyDetector and thus poses risks. In the case, if the user is null, we will give admin role. + nodeStateManager.getAnomalyDetector(id, getDetectorListener); + } + + public void injectUserRoles(User user) { + if (user == null) { + LOG.debug("null user"); + return; + } + + if (shouldInject()) { + inject(user.getName(), user.getRoles()); + } + } +} diff --git a/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java b/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java index 3fc1468b7..8b18bf9c3 100644 --- a/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java +++ b/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java @@ -70,7 +70,7 @@ public void onResponse(T response) { @Override public void onFailure(Exception e) { - LOG.error(e); + LOG.error("Failure in response", e); try { this.exceptions.add(e.getMessage()); } finally { diff --git a/src/main/java/org/opensearch/ad/util/ParseUtils.java b/src/main/java/org/opensearch/ad/util/ParseUtils.java index 12b6e0744..2de60f556 100644 --- a/src/main/java/org/opensearch/ad/util/ParseUtils.java +++ b/src/main/java/org/opensearch/ad/util/ParseUtils.java @@ -456,6 +456,14 @@ public static SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSou return searchSourceBuilder; } + /** + * Generates a user string formed by the username, backend roles, roles and requested tenants separated by '|' + * (e.g., john||own_index,testrole|__user__, no backend role so you see two verticle line after john.). + * This is the user string format used internally in the OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT and may be + * parsed using User.parse(string). + * @param client Client containing user info. A public API request will fill in the user info in the thread context. + * @return parsed user object + */ public static User getUserContext(Client client) { String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); logger.debug("Filtering result by " + userStr); diff --git a/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java b/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java new file mode 100644 index 000000000..612ea4d5c --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.List; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.InjectSecurity; + +public abstract class SafeSecurityInjector implements AutoCloseable { + private static final Logger LOG = LogManager.getLogger(SafeSecurityInjector.class); + // user header used by security plugin. As we cannot take security plugin as + // a compile dependency, we have to duplicate it here. + private static final String OPENDISTRO_SECURITY_USER = "_opendistro_security_user"; + + private InjectSecurity rolesInjectorHelper; + protected String id; + protected Settings settings; + protected ThreadContext tc; + + public SafeSecurityInjector(String id, Settings settings, ThreadContext tc) { + this.id = id; + this.settings = settings; + this.tc = tc; + this.rolesInjectorHelper = null; + } + + protected boolean shouldInject() { + if (id == null || settings == null || tc == null) { + LOG.debug(String.format(Locale.ROOT, "null value: id: %s, settings: %s, threadContext: %s", id, settings, tc)); + return false; + } + // user not null means the request comes from user (e.g., public restful API) + // we don't need to inject roles. + Object userIn = tc.getTransient(OPENDISTRO_SECURITY_USER); + if (userIn != null) { + LOG.debug(new ParameterizedMessage("User not empty in thread context: [{}]", userIn)); + return false; + } + userIn = tc.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + if (userIn != null) { + LOG.debug(new ParameterizedMessage("User not empty in thread context: [{}]", userIn)); + return false; + } + Object rolesin = tc.getTransient(ConfigConstants.OPENSEARCH_SECURITY_INJECTED_ROLES); + if (rolesin != null) { + LOG.warn(new ParameterizedMessage("Injected roles not empty in thread context: [{}]", rolesin)); + return false; + } + + return true; + } + + protected void inject(String user, List roles) { + if (roles == null) { + LOG.warn("Cannot inject empty roles in thread context"); + return; + } + if (rolesInjectorHelper == null) { + // lazy init + rolesInjectorHelper = new InjectSecurity(id, settings, tc); + } + rolesInjectorHelper.inject(user, roles); + } + + @Override + public void close() { + if (rolesInjectorHelper != null) { + rolesInjectorHelper.close(); + } + } +} diff --git a/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java b/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java new file mode 100644 index 000000000..8e9b97b57 --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java @@ -0,0 +1,130 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.function.BiConsumer; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.ActionType; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; + +public class SecurityClientUtil { + private static final String INJECTION_ID = "direct"; + private NodeStateManager nodeStateManager; + private Settings settings; + + @Inject + public SecurityClientUtil(NodeStateManager nodeStateManager, Settings settings) { + this.nodeStateManager = nodeStateManager; + this.settings = settings; + } + + /** + * Send an asynchronous request in the context of user role and handle response with the provided listener. The role + * is recorded in a detector config. + * @param ActionRequest + * @param ActionResponse + * @param request request body + * @param consumer request method, functional interface to operate as a client request like client::get + * @param detectorId Detector id + * @param client OpenSearch client + * @param listener needed to handle response + */ + public void asyncRequestWithInjectedSecurity( + Request request, + BiConsumer> consumer, + String detectorId, + Client client, + ActionListener listener + ) { + ThreadContext threadContext = client.threadPool().getThreadContext(); + try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(detectorId, settings, threadContext, nodeStateManager)) { + injectSecurity + .injectUserRolesFromDetector( + ActionListener + .wrap( + success -> consumer.accept(request, ActionListener.runBefore(listener, () -> injectSecurity.close())), + listener::onFailure + ) + ); + } + } + + /** + * Send an asynchronous request in the context of user role and handle response with the provided listener. The role + * is provided in the arguments. + * @param ActionRequest + * @param ActionResponse + * @param request request body + * @param consumer request method, functional interface to operate as a client request like client::get + * @param user User info + * @param client OpenSearch client + * @param listener needed to handle response + */ + public void asyncRequestWithInjectedSecurity( + Request request, + BiConsumer> consumer, + User user, + Client client, + ActionListener listener + ) { + ThreadContext threadContext = client.threadPool().getThreadContext(); + // use a hardcoded string as detector id that is only used in logging + // Question: + // Will the try-with-resources statement auto close injectSecurity? + // Here the injectSecurity is closed explicitly. So we don't need to put the injectSecurity inside try ? + // Explanation: + // There might be two threads: one thread covers try, inject, and triggers client.execute/client.search + // (this can be a thread in the write thread pool); another thread actually execute the logic of + // client.execute/client.search and handles the responses (this can be a thread in the search thread pool). + // Auto-close in try will restore the context in one thread; the explicit close injectSecurity will restore + // the context in another thread. So we still need to put the injectSecurity inside try. + try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(INJECTION_ID, settings, threadContext, nodeStateManager)) { + injectSecurity.injectUserRoles(user); + consumer.accept(request, ActionListener.runBefore(listener, () -> injectSecurity.close())); + } + } + + /** + * Execute a transport action in the context of user role and handle response with the provided listener. The role + * is provided in the arguments. + * @param ActionRequest + * @param ActionResponse + * @param action transport action + * @param request request body + * @param user User info + * @param client OpenSearch client + * @param listener needed to handle response + */ + public void executeWithInjectedSecurity( + ActionType action, + Request request, + User user, + Client client, + ActionListener listener + ) { + ThreadContext threadContext = client.threadPool().getThreadContext(); + + // use a hardcoded string as detector id that is only used in logging + try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(INJECTION_ID, settings, threadContext, nodeStateManager)) { + injectSecurity.injectUserRoles(user); + client.execute(action, request, ActionListener.runBefore(listener, () -> injectSecurity.close())); + } + } +} diff --git a/src/main/java/org/opensearch/ad/util/SecurityUtil.java b/src/main/java/org/opensearch/ad/util/SecurityUtil.java new file mode 100644 index 000000000..d72d345ab --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/SecurityUtil.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.Collections; +import java.util.List; + +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; + +import com.google.common.collect.ImmutableList; + +public class SecurityUtil { + /** + * @param userObj the last user who edited the detector config + * @param settings Node settings + * @return converted user for bwc if necessary + */ + private static User getAdjustedUserBWC(User userObj, Settings settings) { + /* + * We need to handle 3 cases: + * 1. Detectors created by older versions and never updated. These detectors wont have User details in the + * detector object. `detector.user` will be null. Insert `all_access, AmazonES_all_access` role. + * 2. Detectors are created when security plugin is disabled, these will have empty User object. + * (`detector.user.name`, `detector.user.roles` are empty ) + * 3. Detectors are created when security plugin is enabled, these will have an User object. + * This will inject user role and check if the user role has permissions to call the execute + * Anomaly Result API. + */ + String user; + List roles; + if (userObj == null) { + // It's possible that user create domain with security disabled, then enable security + // after upgrading. This is for BWC, for old detectors which created when security + // disabled, the user will be null. + // This is a huge promotion in privileges. To prevent a caller code from making a mistake and pass a null object, + // we make the method private and only allow fetching user object from detector or job configuration (see the public + // access methods with the same name). + user = ""; + roles = settings.getAsList("", ImmutableList.of("all_access", "AmazonES_all_access")); + return new User(user, Collections.emptyList(), roles, Collections.emptyList()); + } else { + return userObj; + } + } + + /** + * * + * @param detector Detector config + * @param settings Node settings + * @return user recorded by a detector. Made adjstument for BWC (backward-compatibility) if necessary. + */ + public static User getUserFromDetector(AnomalyDetector detector, Settings settings) { + return getAdjustedUserBWC(detector.getUser(), settings); + } + + /** + * * + * @param detectorJob Detector Job + * @param settings Node settings + * @return user recorded by a detector job + */ + public static User getUserFromJob(AnomalyDetectorJob detectorJob, Settings settings) { + return getAdjustedUserBWC(detectorJob.getUser(), settings); + } +} diff --git a/src/main/java/org/opensearch/ad/util/Throttler.java b/src/main/java/org/opensearch/ad/util/Throttler.java index bd2e3d66e..177b612a2 100644 --- a/src/main/java/org/opensearch/ad/util/Throttler.java +++ b/src/main/java/org/opensearch/ad/util/Throttler.java @@ -29,16 +29,20 @@ public class Throttler { private final ConcurrentHashMap> negativeCache; private final Clock clock; - /** - * Inject annotation required by Guice to instantiate EntityResultTransportAction (transitive dependency) - * (EntityResultTransportAction > ResultHandler > ClientUtil > Throttler) - * @param clock a UTC clock - */ public Throttler(Clock clock) { this.negativeCache = new ConcurrentHashMap<>(); this.clock = clock; } + /** + * This will be used when dependency injection directly/indirectly injects a Throttler object. Without this object, + * node start might fail due to not being able to find a Clock object. We removed Clock object association in + * https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/305 + */ + public Throttler() { + this(Clock.systemUTC()); + } + /** * Get negative cache value(ActionRequest, Instant) for given detector * @param detectorId AnomalyDetector ID diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java index a5bdd5aa3..e8038fedf 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java @@ -23,6 +23,7 @@ import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import java.io.IOException; +import java.time.Clock; import java.util.Arrays; import java.util.Locale; import java.util.concurrent.TimeUnit; @@ -44,6 +45,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.common.exception.ADValidationException; import org.opensearch.ad.constant.CommonName; @@ -53,6 +55,7 @@ import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -80,6 +83,7 @@ public class IndexAnomalyDetectorActionHandlerTests extends AbstractADTest { private IndexAnomalyDetectorActionHandler handler; private ClusterService clusterService; private NodeClient clientMock; + private SecurityClientUtil clientUtil; private TransportService transportService; private ActionListener channel; private AnomalyDetectionIndices anomalyDetectionIndices; @@ -96,6 +100,7 @@ public class IndexAnomalyDetectorActionHandlerTests extends AbstractADTest { private RestRequest.Method method; private ADTaskManager adTaskManager; private SearchFeatureDao searchFeatureDao; + private Clock clock; /** * Mockito does not allow mock final methods. Make my own delegates and mock them. @@ -137,6 +142,9 @@ public void setUp() throws Exception { settings = Settings.EMPTY; clusterService = mock(ClusterService.class); clientMock = spy(new NodeClient(settings, threadPool)); + clock = mock(Clock.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); transportService = mock(TransportService.class); channel = mock(ActionListener.class); @@ -170,6 +178,7 @@ public void setUp() throws Exception { handler = new IndexAnomalyDetectorActionHandler( clusterService, clientMock, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -219,6 +228,7 @@ public void testNoCategoricalField() throws IOException { handler = new IndexAnomalyDetectorActionHandler( clusterService, clientMock, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -287,10 +297,13 @@ public void doE } } }; + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); handler = new IndexAnomalyDetectorActionHandler( clusterService, client, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -365,10 +378,13 @@ public void doE }; NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); handler = new IndexAnomalyDetectorActionHandler( clusterService, clientSpy, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -457,6 +473,8 @@ public void doE }; NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); ClusterName clusterName = new ClusterName("test"); ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); when(clusterService.state()).thenReturn(clusterState); @@ -464,6 +482,7 @@ public void doE handler = new IndexAnomalyDetectorActionHandler( clusterService, clientSpy, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -597,6 +616,7 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx handler = new IndexAnomalyDetectorActionHandler( clusterService, clientMock, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -673,6 +693,7 @@ public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOEx handler = new IndexAnomalyDetectorActionHandler( clusterService, clientMock, + clientUtil, transportService, channel, anomalyDetectionIndices, diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java index ebb3f4a00..73c9cd8f5 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java @@ -34,6 +34,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.common.exception.ADValidationException; import org.opensearch.ad.feature.SearchFeatureDao; @@ -45,6 +46,7 @@ import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -136,9 +138,13 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc return null; }).when(clientMock).search(any(SearchRequest.class), any()); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, settings); + handler = new ValidateAnomalyDetectorActionHandler( clusterService, clientMock, + clientUtil, channel, anomalyDetectionIndices, TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null, true), @@ -190,9 +196,13 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOExceptio return null; }).when(clientMock).search(any(SearchRequest.class), any()); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, settings); + handler = new ValidateAnomalyDetectorActionHandler( clusterService, clientMock, + clientUtil, channel, anomalyDetectionIndices, detector, diff --git a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java index 715425c2e..a47ec9968 100644 --- a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java @@ -19,12 +19,14 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.time.Clock; import java.util.Arrays; import java.util.HashSet; import java.util.Optional; import java.util.Set; import java.util.function.Consumer; +import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.opensearch.Version; @@ -33,7 +35,9 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultTests; import org.opensearch.ad.util.DiscoveryNodeFilterer; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -65,6 +69,7 @@ protected enum ErrorResultStatus { protected AnomalyDetectorProfileRunner runner; protected Client client; + protected SecurityClientUtil clientUtil; protected DiscoveryNodeFilterer nodeFilter; protected AnomalyDetector detector; protected ClusterService clusterService; @@ -141,6 +146,12 @@ public static void setUpOnce() { emptySet(), Version.CURRENT ); + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); } @SuppressWarnings("unchecked") @@ -149,6 +160,9 @@ public static void setUpOnce() { public void setUp() throws Exception { super.setUp(); client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + Clock clock = mock(Clock.class); + nodeFilter = mock(DiscoveryNodeFilterer.class); clusterService = mock(ClusterService.class); adTaskManager = mock(ADTaskManager.class); @@ -163,7 +177,6 @@ public void setUp() throws Exception { function.accept(Optional.of(TestHelpers.randomAdTask())); return null; }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); - runner = new AnomalyDetectorProfileRunner(client, xContentRegistry(), nodeFilter, requiredSamples, transportService, adTaskManager); detectorIntervalMin = 3; detectorGetReponse = mock(GetResponse.class); diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java index ed4362336..85855b037 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java @@ -15,8 +15,8 @@ import static java.util.Collections.emptySet; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static org.opensearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; @@ -59,9 +59,11 @@ import org.opensearch.ad.transport.ProfileResponse; import org.opensearch.ad.transport.RCFPollingAction; import org.opensearch.ad.transport.RCFPollingResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; import org.opensearch.index.IndexNotFoundException; import org.opensearch.transport.RemoteTransportException; @@ -96,6 +98,22 @@ private void setUpClientGet( ErrorResultStatus errorResultStatus ) throws IOException { detector = TestHelpers.randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(detectorIntervalMin, ChronoUnit.MINUTES)); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + runner = new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry(), + nodeFilter, + requiredSamples, + transportService, + adTaskManager + ); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -614,7 +632,7 @@ public void testInitNoIndex() throws IOException, InterruptedException { public void testInvalidRequiredSamples() { expectThrows( IllegalArgumentException.class, - () -> new AnomalyDetectorProfileRunner(client, xContentRegistry(), nodeFilter, 0, transportService, adTaskManager) + () -> new AnomalyDetectorProfileRunner(client, clientUtil, xContentRegistry(), nodeFilter, 0, transportService, adTaskManager) ); } diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java index 474005e0d..afe1b615d 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java @@ -92,13 +92,16 @@ protected AnomalyDetector createRandomAnomalyDetector( if (indexName == null) { detector = TestHelpers.randomAnomalyDetector(uiMetadata, null, featureEnabled); + TestHelpers.createIndexWithTimeField(client(), detector.getIndices().get(0), detector.getTimeField()); TestHelpers .makeRequest( client, "POST", "/" + detector.getIndices().get(0) + "/_doc/" + randomAlphaOfLength(5) + "?refresh=true", ImmutableMap.of(), - TestHelpers.toHttpEntity("{\"name\": \"test\"}"), + // avoid validation error as validation API will check at least 1 document and the timestamp field + // exists in index mapping + TestHelpers.toHttpEntity("{\"name\": \"test\", \"" + detector.getTimeField() + "\" : \"1661386754000\"}"), null, false ); @@ -482,6 +485,54 @@ public Response createSearchRole(String role, String index) throws IOException { ); } + public Response createDlsRole(String role, String index) throws IOException { + return TestHelpers + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/roles/" + role, + null, + TestHelpers + .toHttpEntity( + "{\n" + + "\"cluster_permissions\": [\n" + + "unlimited\n" + + "],\n" + + "\"index_permissions\": [\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + index + + "\"\n" + + "],\n" + + "\"dls\": \"\"\"{ \"bool\": { \"must\": { \"match\": { \"foo\": \"bar\" }}}}\"\"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"unlimited\"\n" + + "]\n" + + "},\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + "*" + + "\"\n" + + "],\n" + + "\"dls\": \"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"unlimited\"\n" + + "]\n" + + "}\n" + + "],\n" + + "\"tenant_permissions\": []\n" + + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + public Response deleteUser(String user) throws IOException { return TestHelpers .makeRequest( diff --git a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java index 8e98f2b4f..18cbb2d5d 100644 --- a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java @@ -13,21 +13,20 @@ import static java.util.Collections.emptyMap; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.*; import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static org.opensearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; import java.io.IOException; import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; +import java.util.*; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; @@ -49,7 +48,9 @@ import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.ad.transport.EntityProfileAction; import org.opensearch.ad.transport.EntityProfileResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; import org.opensearch.common.text.Text; import org.opensearch.index.IndexNotFoundException; import org.opensearch.search.DocValueFormat; @@ -63,6 +64,7 @@ public class EntityProfileRunnerTests extends AbstractADTest { private AnomalyDetector detector; private int detectorIntervalMin; private Client client; + private SecurityClientUtil clientUtil; private EntityProfileRunner runner; private Set state; private Set initNInfo; @@ -88,8 +90,19 @@ enum InittedEverResultStatus { NOT_INITTED, } + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyDetectorJobRunnerTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + @SuppressWarnings("unchecked") @Override + @Before public void setUp() throws Exception { super.setUp(); detectorIntervalMin = 3; @@ -107,15 +120,23 @@ public void setUp() throws Exception { detectorId = "A69pa3UBHuCbh-emo9oR"; entityValue = "app-0"; - requiredSamples = 128; - client = mock(Client.class); - - runner = new EntityProfileRunner(client, xContentRegistry(), requiredSamples); - categoryField = "a"; detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(categoryField)); job = TestHelpers.randomAnomalyDetectorJob(true); + requiredSamples = 128; + client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + + runner = new EntityProfileRunner(client, clientUtil, xContentRegistry(), requiredSamples); + doAnswer(invocation -> { Object[] args = invocation.getArguments(); GetRequest request = (GetRequest) args[0]; diff --git a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java index c989871fb..d0753c481 100644 --- a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java @@ -20,6 +20,7 @@ import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static org.opensearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; +import java.time.Clock; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -34,7 +35,9 @@ import java.util.concurrent.TimeUnit; import java.util.function.Consumer; +import org.junit.AfterClass; import org.junit.Before; +import org.junit.BeforeClass; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.FailedNodeException; @@ -52,19 +55,22 @@ import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.model.DetectorState; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultTests; import org.opensearch.ad.transport.ProfileAction; import org.opensearch.ad.transport.ProfileNodeResponse; import org.opensearch.ad.transport.ProfileResponse; -import org.opensearch.ad.util.DiscoveryNodeFilterer; +import org.opensearch.ad.util.*; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; import org.opensearch.transport.TransportService; public class MultiEntityProfileRunnerTests extends AbstractADTest { private AnomalyDetectorProfileRunner runner; private Client client; + private SecurityClientUtil clientUtil; private DiscoveryNodeFilterer nodeFilter; private int requiredSamples; private AnomalyDetector detector; @@ -93,12 +99,25 @@ enum InittedEverResultStatus { NOT_INITTED, } + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + @SuppressWarnings("unchecked") @Before @Override public void setUp() throws Exception { super.setUp(); client = mock(Client.class); + Clock clock = mock(Clock.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); nodeFilter = mock(DiscoveryNodeFilterer.class); requiredSamples = 128; @@ -115,7 +134,15 @@ public void setUp() throws Exception { function.accept(Optional.of(TestHelpers.randomAdTask())); return null; }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); - runner = new AnomalyDetectorProfileRunner(client, xContentRegistry(), nodeFilter, requiredSamples, transportService, adTaskManager); + runner = new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry(), + nodeFilter, + requiredSamples, + transportService, + adTaskManager + ); doAnswer(invocation -> { Object[] args = invocation.getArguments(); diff --git a/src/test/java/org/opensearch/ad/ODFERestTestCase.java b/src/test/java/org/opensearch/ad/ODFERestTestCase.java index 8c48a04cd..46e901e79 100644 --- a/src/test/java/org/opensearch/ad/ODFERestTestCase.java +++ b/src/test/java/org/opensearch/ad/ODFERestTestCase.java @@ -18,11 +18,14 @@ import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH; import java.io.IOException; +import java.io.InputStreamReader; import java.net.URI; import java.net.URISyntaxException; +import java.nio.charset.Charset; import java.nio.file.Path; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -53,6 +56,10 @@ import org.opensearch.commons.rest.SecureRestClientBuilder; import org.opensearch.test.rest.OpenSearchRestTestCase; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + /** * ODFE integration test base class to support both security disabled and enabled ODFE cluster. */ @@ -199,4 +206,46 @@ protected static void configureHttpsClient(RestClientBuilder builder, Settings s protected boolean preserveIndicesUponCompletion() { return true; } + + protected void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { + int maxWaitCycles = 3; + do { + Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); + request + .setJsonEntity( + String + .format( + Locale.ROOT, + "{\"query\": {" + + " \"match_all\": {}" + + " }," + + " \"size\": 1," + + " \"sort\": [" + + " {" + + " \"timestamp\": {" + + " \"order\": \"desc\"" + + " }" + + " }" + + " ]}" + ) + ); + // Make sure all of the test data has been ingested + // Expected response: + // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} + Response response = client.performRequest(request); + JsonObject json = JsonParser + .parseReader(new InputStreamReader(response.getEntity().getContent(), Charset.defaultCharset())) + .getAsJsonObject(); + JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); + if (hits != null + && hits.size() == 1 + && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { + break; + } else { + request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); + client.performRequest(request); + } + Thread.sleep(1_000); + } while (maxWaitCycles-- >= 0); + } } diff --git a/src/test/java/org/opensearch/ad/TestHelpers.java b/src/test/java/org/opensearch/ad/TestHelpers.java index abe97011a..0016e8df9 100644 --- a/src/test/java/org/opensearch/ad/TestHelpers.java +++ b/src/test/java/org/opensearch/ad/TestHelpers.java @@ -31,15 +31,7 @@ import java.nio.ByteBuffer; import java.time.Instant; import java.time.temporal.ChronoUnit; -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.Set; +import java.util.*; import java.util.concurrent.Callable; import java.util.function.Consumer; import java.util.stream.IntStream; @@ -1001,6 +993,19 @@ public static void createIndex(RestClient client, String indexName, HttpEntity d ); } + public static void createIndexWithTimeField(RestClient client, String indexName, String timeField) throws IOException { + StringBuilder indexMappings = new StringBuilder(); + indexMappings.append("{\"properties\":{"); + indexMappings.append("\"" + timeField + "\":{\"type\":\"date\"}"); + indexMappings.append("}}"); + createIndex(client, indexName.toLowerCase(Locale.ROOT), TestHelpers.toHttpEntity("{\"name\": \"test\"}")); + createIndexMapping(client, indexName.toLowerCase(Locale.ROOT), TestHelpers.toHttpEntity(indexMappings.toString())); + } + + public static void createIndexMapping(RestClient client, String indexName, HttpEntity mappings) throws IOException { + TestHelpers.makeRequest(client, "POST", "/" + indexName + "/_mapping", ImmutableMap.of(), mappings, null); + } + public static GetResponse createGetResponse(ToXContentObject o, String id, String indexName) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); return new GetResponse( diff --git a/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java index f21eb9bbb..15e4d1c6c 100644 --- a/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java +++ b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java @@ -270,6 +270,7 @@ private void bulkIndexTestData(List data, String datasetName, int tr ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); Thread.sleep(1_000); + waitAllSyncheticDataIngested(data.size(), datasetName, client); } private void setWarningHandler(Request request, boolean strictDeprecationMode) { diff --git a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java index d1ef85077..64ca51936 100644 --- a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java @@ -13,7 +13,6 @@ import static java.util.Arrays.asList; import static java.util.Optional.empty; -import static java.util.Optional.ofNullable; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -23,7 +22,6 @@ import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -167,24 +165,6 @@ private Object[] getColdStartDataTestData() { new Object[] { null, null, 1, null }, }; } - @Test - @Parameters(method = "getColdStartDataTestData") - public void getColdStartData_returnExpected(Long latestTime, Entry data, int interpolants, double[][] expected) { - when(searchFeatureDao.getLatestDataTime(detector)).thenReturn(ofNullable(latestTime)); - if (latestTime != null) { - when(searchFeatureDao.getFeaturesForSampledPeriods(detector, maxTrainSamples, maxSampleStride, latestTime)) - .thenReturn(ofNullable(data)); - } - if (data != null) { - when(interpolator.interpolate(argThat(new ArrayEqMatcher<>(data.getKey())), eq(interpolants))).thenReturn(data.getKey()); - doReturn(data.getKey()).when(featureManager).batchShingle(argThat(new ArrayEqMatcher<>(data.getKey())), eq(shingleSize)); - } - - Optional results = featureManager.getColdStartData(detector); - - assertTrue(Arrays.deepEquals(expected, results.orElse(null))); - } - private Object[] getTrainDataTestData() { List> ranges = asList( entry(0L, 900_000L), diff --git a/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java b/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java index cd7a08a9a..b92fe19aa 100644 --- a/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java @@ -11,29 +11,22 @@ package org.opensearch.ad.feature; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.AnyOf.anyOf; import static org.hamcrest.core.IsInstanceOf.instanceOf; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.*; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.time.Clock; +import java.time.ZoneId; import java.time.ZoneOffset; import java.time.temporal.ChronoUnit; +import java.util.*; import java.util.AbstractMap.SimpleImmutableEntry; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; import java.util.Map.Entry; -import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -41,6 +34,8 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.BytesRef; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.invocation.InvocationOnMock; @@ -52,7 +47,9 @@ import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; import org.opensearch.ad.model.AnomalyDetector; @@ -60,7 +57,7 @@ import org.opensearch.ad.model.Feature; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -71,16 +68,12 @@ import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; -import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.AggregatorFactories; -import org.opensearch.search.aggregations.BucketOrder; -import org.opensearch.search.aggregations.InternalAggregations; -import org.opensearch.search.aggregations.InternalOrder; +import org.opensearch.search.aggregations.*; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; import org.opensearch.search.aggregations.bucket.range.InternalDateRange; import org.opensearch.search.aggregations.bucket.terms.StringTerms; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.metrics.*; import org.opensearch.search.aggregations.metrics.InternalMax; import org.opensearch.search.aggregations.metrics.SumAggregationBuilder; import org.opensearch.search.internal.InternalSearchResponse; @@ -88,7 +81,8 @@ import com.google.common.collect.ImmutableList; /** - * SearchFeatureDaoTests uses Powermock and has strange log4j related errors. + * SearchFeatureDaoTests uses Powermock and has strange log4j related errors + * (e.g., TEST_INSTANCES_ARE_REUSED). * Create a new class for new tests related to SearchFeatureDao. * */ @@ -99,7 +93,7 @@ public class NoPowermockSearchFeatureDaoTests extends AbstractADTest { private Client client; private SearchFeatureDao searchFeatureDao; private LinearUniformInterpolator interpolator; - private ClientUtil clientUtil; + private SecurityClientUtil clientUtil; private Settings settings; private ClusterService clusterService; private Clock clock; @@ -107,7 +101,18 @@ public class NoPowermockSearchFeatureDaoTests extends AbstractADTest { private String detectorId; private Map attrs1, attrs2; + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(NoPowermockSearchFeatureDaoTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + @Override + @SuppressWarnings("unchecked") public void setUp() throws Exception { super.setUp(); serviceField = "service"; @@ -125,11 +130,10 @@ public void setUp() throws Exception { when(detector.getFilterQuery()).thenReturn(QueryBuilders.matchAllQuery()); client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); interpolator = new LinearUniformInterpolator(new SingleFeatureLinearUniformInterpolator()); - clientUtil = mock(ClientUtil.class); - settings = Settings.EMPTY; ClusterSettings clusterSettings = new ClusterSettings( Settings.EMPTY, @@ -141,6 +145,13 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); clock = mock(Clock.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); searchFeatureDao = new SearchFeatureDao( client, @@ -571,4 +582,164 @@ public void testGetColdStartSamplesForPeriodsDefaultFormat() throws IOException, public void testGetColdStartSamplesForPeriodsRawFormat() throws IOException, InterruptedException { getColdStartSamplesForPeriodsTemplate(DocValueFormat.RAW); } + + @SuppressWarnings("unchecked") + public void testGetFeaturesForPeriod_throwToListener_whenSearchFails() throws Exception { + + long start = 100L; + long end = 200L; + // when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); + + ActionListener> listener = mock(ActionListener.class); + searchFeatureDao.getFeaturesForPeriod(detector, start, end, listener); + + verify(listener).onFailure(any(Exception.class)); + } + + @SuppressWarnings("unchecked") + public void testGetEntityMinDataTime() { + // simulate response {"took":11,"timed_out":false,"_shards":{"total":1, + // "successful":1,"skipped":0,"failed":0},"hits":{"max_score":null,"hits":[]}, + // "aggregations":{"min_timefield":{"value":1.602211285E12, + // "value_as_string":"2020-10-09T02:41:25.000Z"}, + // "max_timefield":{"value":1.602348325E12,"value_as_string":"2020-10-10T16:45:25.000Z"}}} + DocValueFormat dateFormat = new DocValueFormat.DateTime( + DateFormatter.forPattern("strict_date_optional_time||epoch_millis"), + ZoneId.of("UTC"), + DateFieldMapper.Resolution.MILLISECONDS + ); + double earliest = 1.602211285E12; + InternalMin minInternal = new InternalMin("min_timefield", earliest, dateFormat, new HashMap<>()); + InternalAggregations internalAggregations = InternalAggregations.from(Arrays.asList(minInternal)); + SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, null, 1); + + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + doAnswer(invocation -> { + SearchRequest request = invocation.getArgument(0); + assertEquals(1, request.indices().length); + assertTrue(detector.getIndices().contains(request.indices()[0])); + AggregatorFactories.Builder aggs = request.source().aggregations(); + assertEquals(1, aggs.count()); + Collection factory = aggs.getAggregatorFactories(); + assertTrue(!factory.isEmpty()); + Iterator iterator = factory.iterator(); + while (iterator.hasNext()) { + assertThat(iterator.next(), anyOf(instanceOf(MaxAggregationBuilder.class), instanceOf(MinAggregationBuilder.class))); + } + + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); + + ActionListener> listener = mock(ActionListener.class); + Entity entity = Entity.createSingleAttributeEntity("field", "app_1"); + searchFeatureDao.getEntityMinDataTime(detector, entity, listener); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(captor.capture()); + Optional result = captor.getValue(); + assertEquals((long) earliest, result.get().longValue()); + } + + @SuppressWarnings("unchecked") + public void testGetFeaturesForPeriod_throwToListener_whenResponseParsingFails() throws Exception { + + long start = 100L; + long end = 200L; + + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); + when(searchResponse.getHits()).thenReturn(hits); + + // when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); + when(detector.getEnabledFeatureIds()).thenReturn(null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); + + ActionListener> listener = mock(ActionListener.class); + searchFeatureDao.getFeaturesForPeriod(detector, start, end, listener); + + verify(listener).onResponse(eq(Optional.empty())); + } + + @SuppressWarnings("unchecked") + public void testGetFeaturesForSampledPeriods_throwToListener_whenSamplingFail() { + SearchFeatureDao spySearchFeatureDao = spy(searchFeatureDao); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onFailure(new RuntimeException()); + return null; + }).when(spySearchFeatureDao).getFeaturesForPeriod(any(), anyLong(), anyLong(), any(ActionListener.class)); + + ActionListener>> listener = mock(ActionListener.class); + spySearchFeatureDao.getFeaturesForSampledPeriods(detector, 1, 1, 0, listener); + + verify(listener).onFailure(any(Exception.class)); + } + + @SuppressWarnings("unchecked") + public void testGetLatestDataTime_returnExpectedToListener() { + long epochTime = 100L; + + // simulate response {"took":11,"timed_out":false,"_shards":{"total":1, + // "successful":1,"skipped":0,"failed":0},"hits":{"max_score":null,"hits":[]}, + // "aggregations":{"min_timefield":{"value":1.602211285E12, + // "value_as_string":"2020-10-09T02:41:25.000Z"}, + // "max_timefield":{"value":1.602348325E12,"value_as_string":"2020-10-10T16:45:25.000Z"}}} + DocValueFormat dateFormat = new DocValueFormat.DateTime( + DateFormatter.forPattern("strict_date_optional_time||epoch_millis"), + ZoneId.of("UTC"), + DateFieldMapper.Resolution.MILLISECONDS + ); + InternalMax minInternal = new InternalMax(CommonName.AGG_NAME_MAX_TIME, epochTime, dateFormat, new HashMap<>()); + InternalAggregations internalAggregations = InternalAggregations.from(Arrays.asList(minInternal)); + SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, null, 1); + + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); + + // when(ParseUtils.getLatestDataTime(eq(searchResponse))).thenReturn(Optional.of(epochTime)); + ActionListener> listener = mock(ActionListener.class); + searchFeatureDao.getLatestDataTime(detector, listener); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(captor.capture()); + Optional result = captor.getValue(); + assertEquals(epochTime, result.get().longValue()); + } } diff --git a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java index afa58b581..ecf31e1e8 100644 --- a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java @@ -12,30 +12,21 @@ package org.opensearch.ad.feature; import static java.util.Arrays.asList; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.core.AnyOf.anyOf; -import static org.hamcrest.core.IsInstanceOf.instanceOf; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Matchers.anyLong; -import static org.mockito.Matchers.anyObject; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.time.Clock; -import java.time.ZoneId; import java.time.temporal.ChronoUnit; import java.util.AbstractMap.SimpleEntry; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -45,7 +36,6 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.ExecutorService; -import java.util.function.BiConsumer; import junitparams.JUnitParamsRunner; import junitparams.Parameters; @@ -55,60 +45,39 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; -import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionListener; import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.MultiSearchResponse; import org.opensearch.action.search.MultiSearchResponse.Item; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.SearchResponseSections; -import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.common.exception.EndRunException; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.dataprocessor.Interpolator; import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; -import org.opensearch.common.time.DateFormatter; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.script.ScriptService; -import org.opensearch.script.TemplateScript; -import org.opensearch.script.TemplateScript.Factory; -import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.Aggregation; -import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.AggregatorFactories; -import org.opensearch.search.aggregations.InternalAggregations; -import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; -import org.opensearch.search.aggregations.metrics.InternalMin; import org.opensearch.search.aggregations.metrics.InternalTDigestPercentiles; import org.opensearch.search.aggregations.metrics.Max; -import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; -import org.opensearch.search.aggregations.metrics.MinAggregationBuilder; -import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation; import org.opensearch.search.aggregations.metrics.Percentile; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.threadpool.ThreadPool; @@ -118,12 +87,10 @@ import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunnerDelegate; -import com.google.gson.Gson; - @PowerMockIgnore("javax.management.*") @RunWith(PowerMockRunner.class) @PowerMockRunnerDelegate(JUnitParamsRunner.class) -@PrepareForTest({ ParseUtils.class, Gson.class }) +@PrepareForTest({ ParseUtils.class }) public class SearchFeatureDaoTests { // private final Logger LOG = LogManager.getLogger(SearchFeatureDaoTests.class); @@ -131,21 +98,12 @@ public class SearchFeatureDaoTests { @Mock private Client client; - @Mock - private ScriptService scriptService; + @Mock private NamedXContentRegistry xContent; - @Mock - private ClientUtil clientUtil; - @Mock - private Factory factory; - @Mock - private TemplateScript templateScript; - @Mock - private ActionFuture searchResponseFuture; - @Mock - private ActionFuture multiSearchResponseFuture; + private SecurityClientUtil clientUtil; + @Mock private SearchResponse searchResponse; @Mock @@ -153,11 +111,7 @@ public class SearchFeatureDaoTests { @Mock private Item multiSearchResponseItem; @Mock - private Aggregations aggs; - @Mock private Max max; - @Mock - private NodeStateManager stateManager; @Mock private AnomalyDetector detector; @@ -168,20 +122,17 @@ public class SearchFeatureDaoTests { @Mock private ClusterService clusterService; - @Mock - private Clock clock; - private SearchRequest searchRequest; private SearchSourceBuilder searchSourceBuilder; private MultiSearchRequest multiSearchRequest; private Map aggsMap; private IntervalTimeConfiguration detectionInterval; private String detectorId; - private Gson gson; private Interpolator interpolator; private Settings settings; @Before + @SuppressWarnings("unchecked") public void setup() throws Exception { MockitoAnnotations.initMocks(this); PowerMockito.mockStatic(ParseUtils.class); @@ -206,6 +157,14 @@ public void setup() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(client.threadPool()).thenReturn(threadPool); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); searchFeatureDao = spy( new SearchFeatureDao( client, @@ -240,96 +199,26 @@ public void setup() throws Exception { SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); when(searchResponse.getHits()).thenReturn(hits); - doReturn(Optional.of(searchResponse)) - .when(clientUtil) - .timedRequest(eq(searchRequest), anyObject(), Matchers.>>anyObject()); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(eq(searchRequest), any()); when(searchResponse.getAggregations()).thenReturn(aggregations); - doReturn(Optional.of(searchResponse)) - .when(clientUtil) - .throttledTimedRequest( - eq(searchRequest), - anyObject(), - Matchers.>>anyObject(), - anyObject() - ); - multiSearchRequest = new MultiSearchRequest(); SearchRequest request = new SearchRequest(detector.getIndices().toArray(new String[0])); multiSearchRequest.add(request); - doReturn(Optional.of(multiSearchResponse)) - .when(clientUtil) - .timedRequest( - eq(multiSearchRequest), - anyObject(), - Matchers.>>anyObject() - ); - when(multiSearchResponse.getResponses()).thenReturn(new Item[] { multiSearchResponseItem }); - when(multiSearchResponseItem.getResponse()).thenReturn(searchResponse); - gson = PowerMockito.mock(Gson.class); - } - - @Test - public void test_getLatestDataTime_returnExpectedTime_givenData() { - // pre-conditions - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField())) - .size(0); - searchRequest.source(searchSourceBuilder); - - long epochTime = 100L; - aggsMap.put(CommonName.AGG_NAME_MAX_TIME, max); - when(max.getValue()).thenReturn((double) epochTime); - - // test - Optional result = searchFeatureDao.getLatestDataTime(detector); - - // verify - assertEquals(epochTime, result.get().longValue()); - } - - @Test - public void test_getLatestDataTime_returnEmpty_givenNoData() { - // pre-conditions - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField())) - .size(0); - searchRequest.source(searchSourceBuilder); - - when(searchResponse.getAggregations()).thenReturn(null); - - // test - Optional result = searchFeatureDao.getLatestDataTime(detector); - - // verify - assertFalse(result.isPresent()); - } - - @Test - @SuppressWarnings("unchecked") - public void getLatestDataTime_returnExpectedToListener() { - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField())) - .size(0); - searchRequest.source(searchSourceBuilder); - long epochTime = 100L; - aggsMap.put(CommonName.AGG_NAME_MAX_TIME, max); - when(max.getValue()).thenReturn((double) epochTime); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(multiSearchResponse); return null; - }).when(client).search(eq(searchRequest), any(ActionListener.class)); - - when(ParseUtils.getLatestDataTime(eq(searchResponse))).thenReturn(Optional.of(epochTime)); - ActionListener> listener = mock(ActionListener.class); - searchFeatureDao.getLatestDataTime(detector, listener); - - ArgumentCaptor> captor = ArgumentCaptor.forClass(Optional.class); - verify(listener).onResponse(captor.capture()); - Optional result = captor.getValue(); - assertEquals(epochTime, result.get().longValue()); + }).when(client).multiSearch(eq(multiSearchRequest), any()); + when(multiSearchResponse.getResponses()).thenReturn(new Item[] { multiSearchResponseItem }); + when(multiSearchResponseItem.getResponse()).thenReturn(searchResponse); } @SuppressWarnings("unchecked") @@ -378,54 +267,6 @@ private Object[] getFeaturesForPeriodData() { new Object[] { asList(max, percentiles, missing), asList(maxName, percentileName, missingName), null }, }; } - @Test - @Parameters(method = "getFeaturesForPeriodData") - public void getFeaturesForPeriod_returnExpected_givenData(List aggs, List featureIds, double[] expected) - throws Exception { - - long start = 100L; - long end = 200L; - - // pre-conditions - when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); - when(searchResponse.getAggregations()).thenReturn(new Aggregations(aggs)); - when(detector.getEnabledFeatureIds()).thenReturn(featureIds); - - // test - Optional result = searchFeatureDao.getFeaturesForPeriod(detector, start, end); - - // verify - assertTrue(Arrays.equals(expected, result.orElse(null))); - } - - @SuppressWarnings("unchecked") - private Object[] getFeaturesForPeriodThrowIllegalStateData() { - String aggName = "aggName"; - - InternalTDigestPercentiles empty = mock(InternalTDigestPercentiles.class); - Iterator emptyIterator = mock(Iterator.class); - when(empty.iterator()).thenReturn(emptyIterator); - when(emptyIterator.hasNext()).thenReturn(false); - when(empty.getName()).thenReturn(aggName); - - MultiBucketsAggregation multiBucket = mock(MultiBucketsAggregation.class); - when(multiBucket.getName()).thenReturn(aggName); - - return new Object[] { - new Object[] { asList(empty), asList(aggName), null }, - new Object[] { asList(multiBucket), asList(aggName), null }, }; - } - - @Test(expected = EndRunException.class) - @Parameters(method = "getFeaturesForPeriodThrowIllegalStateData") - public void getFeaturesForPeriod_throwIllegalState_forUnknownAggregation( - List aggs, - List featureIds, - double[] expected - ) throws Exception { - getFeaturesForPeriod_returnExpected_givenData(aggs, featureIds, expected); - } - @Test @Parameters(method = "getFeaturesForPeriodData") @SuppressWarnings("unchecked") @@ -452,88 +293,6 @@ public void getFeaturesForPeriod_returnExpectedToListener(List aggs assertTrue(Arrays.equals(expected, result.orElse(null))); } - @Test - @SuppressWarnings("unchecked") - public void getFeaturesForPeriod_throwToListener_whenSearchFails() throws Exception { - - long start = 100L; - long end = 200L; - when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException()); - return null; - }).when(client).search(eq(searchRequest), any(ActionListener.class)); - - ActionListener> listener = mock(ActionListener.class); - searchFeatureDao.getFeaturesForPeriod(detector, start, end, listener); - - verify(listener).onFailure(any(Exception.class)); - } - - @Test - @SuppressWarnings("unchecked") - public void getFeaturesForPeriod_throwToListener_whenResponseParsingFails() throws Exception { - - long start = 100L; - long end = 200L; - when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); - when(detector.getEnabledFeatureIds()).thenReturn(null); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(eq(searchRequest), any(ActionListener.class)); - - ActionListener> listener = mock(ActionListener.class); - searchFeatureDao.getFeaturesForPeriod(detector, start, end, listener); - - verify(listener).onFailure(any(Exception.class)); - } - - @Test - public void test_getFeaturesForPeriod_returnEmpty_givenNoData() throws Exception { - long start = 100L; - long end = 200L; - - // pre-conditions - when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); - when(searchResponse.getAggregations()).thenReturn(null); - - // test - Optional result = searchFeatureDao.getFeaturesForPeriod(detector, start, end); - - // verify - assertFalse(result.isPresent()); - } - - @Test - public void getFeaturesForPeriod_returnNonEmpty_givenDefaultValue() throws Exception { - long start = 100L; - long end = 200L; - - // pre-conditions - when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); - when(searchResponse.getHits()).thenReturn(new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), 1f)); - - List aggList = new ArrayList<>(1); - - NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class); - when(agg.getName()).thenReturn("deny_max"); - when(agg.value()).thenReturn(0d); - - aggList.add(agg); - - Aggregations aggregations = new Aggregations(aggList); - when(searchResponse.getAggregations()).thenReturn(aggregations); - - // test - Optional result = searchFeatureDao.getFeaturesForPeriod(detector, start, end); - - // verify - assertTrue(result.isPresent()); - } - private Object[] getFeaturesForSampledPeriodsData() { long endTime = 300_000; int maxStride = 4; @@ -635,34 +394,6 @@ private Object[] getFeaturesForSampledPeriodsData() { Optional.of(new SimpleEntry<>(new double[][] { { 1, 2 }, { 3, 4 }, { 5, 6 }, { 7, 8 }, { 9, 10 } }, 1)) }, }; } - @Test - @Parameters(method = "getFeaturesForSampledPeriodsData") - public void getFeaturesForSampledPeriods_returnExpected( - Long[][] queryRanges, - double[][] queryResults, - long endTime, - int maxStride, - int maxSamples, - Optional> expected - ) { - - doReturn(Optional.empty()).when(searchFeatureDao).getFeaturesForPeriod(eq(detector), anyLong(), anyLong()); - for (int i = 0; i < queryRanges.length; i++) { - doReturn(Optional.of(queryResults[i])) - .when(searchFeatureDao) - .getFeaturesForPeriod(detector, queryRanges[i][0], queryRanges[i][1]); - } - - Optional> result = searchFeatureDao - .getFeaturesForSampledPeriods(detector, maxSamples, maxStride, endTime); - - assertEquals(expected.isPresent(), result.isPresent()); - if (expected.isPresent()) { - assertTrue(Arrays.deepEquals(expected.get().getKey(), result.get().getKey())); - assertEquals(expected.get().getValue(), result.get().getValue()); - } - } - @Test @Parameters(method = "getFeaturesForSampledPeriodsData") @SuppressWarnings("unchecked") @@ -703,80 +434,7 @@ public void getFeaturesForSampledPeriods_returnExpectedToListener( } } - @Test - @SuppressWarnings("unchecked") - public void getFeaturesForSampledPeriods_throwToListener_whenSamplingFail() { - doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(3); - listener.onFailure(new RuntimeException()); - return null; - }).when(searchFeatureDao).getFeaturesForPeriod(any(), anyLong(), anyLong(), any(ActionListener.class)); - - ActionListener>> listener = mock(ActionListener.class); - searchFeatureDao.getFeaturesForSampledPeriods(detector, 1, 1, 0, listener); - - verify(listener).onFailure(any(Exception.class)); - } - private Entry pair(K key, V value) { return new SimpleEntry<>(key, value); } - - @SuppressWarnings("unchecked") - @Test - public void testGetEntityMinDataTime() { - // simulate response {"took":11,"timed_out":false,"_shards":{"total":1, - // "successful":1,"skipped":0,"failed":0},"hits":{"max_score":null,"hits":[]}, - // "aggregations":{"min_timefield":{"value":1.602211285E12, - // "value_as_string":"2020-10-09T02:41:25.000Z"}, - // "max_timefield":{"value":1.602348325E12,"value_as_string":"2020-10-10T16:45:25.000Z"}}} - DocValueFormat dateFormat = new DocValueFormat.DateTime( - DateFormatter.forPattern("strict_date_optional_time||epoch_millis"), - ZoneId.of("UTC"), - DateFieldMapper.Resolution.MILLISECONDS - ); - double earliest = 1.602211285E12; - InternalMin minInternal = new InternalMin("min_timefield", earliest, dateFormat, new HashMap<>()); - InternalAggregations internalAggregations = InternalAggregations.from(Arrays.asList(minInternal)); - SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN); - SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, null, 1); - - SearchResponse searchResponse = new SearchResponse( - searchSections, - null, - 1, - 1, - 0, - 11, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - - doAnswer(invocation -> { - SearchRequest request = invocation.getArgument(0); - assertEquals(1, request.indices().length); - assertTrue(detector.getIndices().contains(request.indices()[0])); - AggregatorFactories.Builder aggs = request.source().aggregations(); - assertEquals(1, aggs.count()); - Collection factory = aggs.getAggregatorFactories(); - assertTrue(!factory.isEmpty()); - Iterator iterator = factory.iterator(); - while (iterator.hasNext()) { - assertThat(iterator.next(), anyOf(instanceOf(MaxAggregationBuilder.class), instanceOf(MinAggregationBuilder.class))); - } - - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); - - ActionListener> listener = mock(ActionListener.class); - Entity entity = Entity.createSingleAttributeEntity("field", "app_1"); - searchFeatureDao.getEntityMinDataTime(detector, entity, listener); - - ArgumentCaptor> captor = ArgumentCaptor.forClass(Optional.class); - verify(listener).onResponse(captor.capture()); - Optional result = captor.getValue(); - assertEquals((long) earliest, result.get().longValue()); - } } diff --git a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java index e45bc4663..d50a07358 100644 --- a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java +++ b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java @@ -376,6 +376,10 @@ public void testPreviewAnomalyDetectorWithNoReadPermissionOfIndex() throws IOExc Exception.class, () -> { previewAnomalyDetector(aliceDetector.getDetectorId(), elkClient, input); } ); - Assert.assertTrue(exception.getMessage().contains("no permissions for [indices:data/read/search]")); + Assert + .assertTrue( + "actual msg: " + exception.getMessage(), + exception.getMessage().contains("no permissions for [indices:data/read/search]") + ); } } diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java index 83e8db73f..be296cdd3 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java @@ -96,6 +96,7 @@ import org.opensearch.ad.stats.StatNames; import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -141,6 +142,7 @@ public class AnomalyResultTests extends AbstractADTest { private FeatureManager featureQuery; private ModelManager normalModelManager; private Client client; + private SecurityClientUtil clientUtil; private AnomalyDetector detector; private HashRing hashRing; private IndexNameExpressionResolver indexNameResolver; @@ -279,6 +281,8 @@ public void setUp() throws Exception { return null; }).when(client).index(any(), any()); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)); @@ -355,6 +359,7 @@ public void testNormal() throws IOException { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -478,6 +483,7 @@ public void sendRequest( realTransportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -529,6 +535,7 @@ public void testInsufficientCapacityExceptionDuringColdStart() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -572,6 +579,7 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -691,6 +699,7 @@ public void sendRequest( realTransportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -732,6 +741,7 @@ public void testCircuitBreaker() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -800,6 +810,7 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, exceptionTransportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -852,6 +863,7 @@ public void testMute() { transportService, settings, client, + clientUtil, muteStateManager, featureQuery, normalModelManager, @@ -892,6 +904,7 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1070,6 +1083,7 @@ public void testOnFailureNull() throws IOException { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1161,6 +1175,7 @@ public void testColdStartNoTrainingData() throws Exception { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1198,6 +1213,7 @@ public void testConcurrentColdStart() throws Exception { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1241,6 +1257,7 @@ public void testColdStartTimeoutPutCheckpoint() throws Exception { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1284,6 +1301,7 @@ public void testColdStartIllegalArgumentException() throws Exception { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1334,6 +1352,7 @@ public void featureTestTemplate(FeatureTestMode mode) throws IOException { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1421,6 +1440,7 @@ private void globalBlockTemplate(BlockType type, String errLogMsg, Settings inde transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1469,6 +1489,7 @@ public void testNullRCFResult() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1495,6 +1516,7 @@ public void testNormalRCFResult() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1528,6 +1550,7 @@ public void testNullPointerRCFResult() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1569,6 +1592,7 @@ public void testAllFeaturesDisabled() throws IOException { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1632,6 +1656,7 @@ public void testEndRunDueToNoTrainingData() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1684,6 +1709,7 @@ public void testColdStartEndRunException() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1727,6 +1753,7 @@ public void testColdStartEndRunExceptionNow() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1767,6 +1794,7 @@ public void testColdStartBecauseFailtoGetCheckpoint() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1805,6 +1833,7 @@ public void testNoColdStartDueToUnknownException() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java index 4e0afda86..cd01b529d 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java @@ -18,6 +18,7 @@ import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import java.io.IOException; +import java.time.Clock; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -31,11 +32,14 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.ad.model.Entity; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.util.DiscoveryNodeFilterer; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -49,6 +53,7 @@ public class GetAnomalyDetectorTests extends AbstractADTest { private DiscoveryNodeFilterer nodeFilter; private ActionFilters actionFilters; private Client client; + private SecurityClientUtil clientUtil; private GetAnomalyDetectorRequest request; private String detectorId = "yecrdnUBqurvo9uKU_d8"; private String entityValue = "app_0"; @@ -96,6 +101,12 @@ public void setUp() throws Exception { client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); + Clock clock = mock(Clock.class); + Throttler throttler = new Throttler(clock); + + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + adTaskManager = mock(ADTaskManager.class); action = new GetAnomalyDetectorTransportAction( @@ -104,6 +115,7 @@ public void setUp() throws Exception { actionFilters, clusterService, client, + clientUtil, Settings.EMPTY, xContentRegistry(), adTaskManager diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java index 793665051..9c98be99b 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java @@ -20,13 +20,13 @@ import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.concurrent.TimeUnit; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; +import org.junit.*; import org.mockito.Mockito; import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.ADTask; @@ -37,8 +37,7 @@ import org.opensearch.ad.model.InitProgressProfile; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.util.DiscoveryNodeFilterer; -import org.opensearch.ad.util.RestHandlerUtils; +import org.opensearch.ad.util.*; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; @@ -51,12 +50,14 @@ import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableMap; public class GetAnomalyDetectorTransportActionTests extends OpenSearchSingleNodeTestCase { - + private static ThreadPool threadPool; private GetAnomalyDetectorTransportAction action; private Task task; private ActionListener response; @@ -65,6 +66,16 @@ public class GetAnomalyDetectorTransportActionTests extends OpenSearchSingleNode private String categoryField; private String categoryValue; + @BeforeClass + public static void beforeCLass() { + threadPool = new TestThreadPool("GetAnomalyDetectorTransportActionTests"); + } + + @AfterClass + public static void afterClass() { + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + @Override @Before public void setUp() throws Exception { @@ -76,12 +87,15 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); adTaskManager = mock(ADTaskManager.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); action = new GetAnomalyDetectorTransportAction( Mockito.mock(TransportService.class), Mockito.mock(DiscoveryNodeFilterer.class), Mockito.mock(ActionFilters.class), clusterService, client(), + clientUtil, Settings.EMPTY, xContentRegistry(), adTaskManager diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java index 9447b3dbd..a5fca1ba8 100644 --- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java @@ -34,12 +34,14 @@ import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -71,6 +73,7 @@ public class IndexAnomalyDetectorTransportActionTests extends OpenSearchIntegTes private ClusterSettings clusterSettings; private ADTaskManager adTaskManager; private Client client = mock(Client.class); + private SecurityClientUtil clientUtil; private SearchFeatureDao searchFeatureDao; @SuppressWarnings("unchecked") @@ -103,10 +106,13 @@ public void setUp() throws Exception { adTaskManager = mock(ADTaskManager.class); searchFeatureDao = mock(SearchFeatureDao.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); action = new IndexAnomalyDetectorTransportAction( mock(TransportService.class), mock(ActionFilters.class), client(), + clientUtil, clusterService, indexSettings(), mock(AnomalyDetectionIndices.class), @@ -199,6 +205,7 @@ public void testIndexTransportActionWithUserAndFilterOn() { mock(TransportService.class), mock(ActionFilters.class), client, + clientUtil, clusterService, settings, mock(AnomalyDetectionIndices.class), @@ -224,6 +231,7 @@ public void testIndexTransportActionWithUserAndFilterOff() { mock(TransportService.class), mock(ActionFilters.class), client, + clientUtil, clusterService, settings, mock(AnomalyDetectionIndices.class), diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java index e2c232f23..b08ef6b1f 100644 --- a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java @@ -100,6 +100,8 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.util.Bwc; import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNode; @@ -145,6 +147,7 @@ public class MultiEntityResultTests extends AbstractADTest { private static Settings settings; private TransportService transportService; private Client client; + private SecurityClientUtil securityClientUtil; private FeatureManager featureQuery; private ModelManager normalModelManager; private HashRing hashRing; @@ -215,6 +218,7 @@ public void setUp() throws Exception { setUpADThreadPool(mockThreadPool); when(client.threadPool()).thenReturn(mockThreadPool); when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + securityClientUtil = new SecurityClientUtil(stateManager, settings); featureQuery = mock(FeatureManager.class); @@ -273,6 +277,7 @@ public void setUp() throws Exception { transportService, settings, client, + securityClientUtil, stateManager, featureQuery, normalModelManager, @@ -417,34 +422,35 @@ private void setUpEntityResult(int nodeIndex) { @SuppressWarnings("unchecked") public void setUpNormlaStateManager() throws IOException { - ClientUtil clientUtil = mock(ClientUtil.class); - AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder .newInstance() .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(1); listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); return null; - }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + }).when(client).get(any(GetRequest.class), any(ActionListener.class)); stateManager = new NodeStateManager( client, xContentRegistry(), settings, - clientUtil, + new ClientUtil(settings, client, new Throttler(mock(Clock.class)), threadPool), clock, AnomalyDetectorSettings.HOURLY_MAINTENANCE, clusterService ); + securityClientUtil = new SecurityClientUtil(stateManager, settings); + action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, settings, client, + securityClientUtil, stateManager, featureQuery, normalModelManager, @@ -518,7 +524,6 @@ public void testIndexNotFound() throws InterruptedException, IOException { }).when(client).search(any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); - action.doExecute(null, request, listener); AnomalyResultResponse response = listener.actionGet(10000L); @@ -684,6 +689,7 @@ public void sendRequest( realTransportService, settings, client, + securityClientUtil, stateManager, featureQuery, normalModelManager, @@ -743,6 +749,7 @@ public void testCircuitBreakerOpen() throws InterruptedException, IOException { transportService, settings, client, + securityClientUtil, stateManager, featureQuery, normalModelManager, @@ -1094,6 +1101,7 @@ public void testPageToString() { detector, xContentRegistry(), client, + securityClientUtil, 100, clock, settings, @@ -1118,6 +1126,7 @@ public void testEmptyPageToString() { detector, xContentRegistry(), client, + securityClientUtil, 100, clock, settings, diff --git a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java index af6aeef47..5e86089d3 100644 --- a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java +++ b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java @@ -11,8 +11,7 @@ package org.opensearch.search.aggregations.metrics; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -22,11 +21,7 @@ import java.io.IOException; import java.time.temporal.ChronoUnit; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; +import java.util.*; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -37,6 +32,8 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.AbstractProfileRunnerTests; +import org.opensearch.ad.AnomalyDetectorProfileRunner; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.AnomalyDetector; @@ -45,7 +42,9 @@ import org.opensearch.ad.transport.ProfileAction; import org.opensearch.ad.transport.ProfileNodeResponse; import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.cluster.ClusterName; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.BigArrays; import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.InternalAggregations; @@ -73,6 +72,23 @@ private void setUpMultiEntityClientGet(DetectorStatus detectorStatus, JobStatus throws IOException { detector = TestHelpers .randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(detectorIntervalMin, ChronoUnit.MINUTES), true); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + runner = new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry(), + nodeFilter, + requiredSamples, + transportService, + adTaskManager + ); + doAnswer(invocation -> { Object[] args = invocation.getArguments(); GetRequest request = (GetRequest) args[0];