diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java index a7af702a7..135799930 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java @@ -45,6 +45,7 @@ import org.opensearch.ad.transport.AnomalyResultRequest; import org.opensearch.ad.transport.AnomalyResultResponse; import org.opensearch.ad.transport.AnomalyResultTransportAction; +import org.opensearch.ad.util.SecurityUtil; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -53,6 +54,7 @@ import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.InjectSecurity; +import org.opensearch.commons.authuser.User; import org.opensearch.jobscheduler.spi.JobExecutionContext; import org.opensearch.jobscheduler.spi.LockModel; import org.opensearch.jobscheduler.spi.ScheduledJobParameter; @@ -62,7 +64,6 @@ import org.opensearch.threadpool.ThreadPool; import com.google.common.base.Throwables; -import com.google.common.collect.ImmutableList; /** * JobScheduler will call AD job runner to get anomaly result periodically @@ -145,49 +146,57 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont final LockService lockService = context.getLockService(); Runnable runnable = () -> { - nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId)); - return; - } - AnomalyDetector detector = detectorOptional.get(); - - if (jobParameter.getLockDurationSeconds() != null) { - lockService - .acquireLock( - jobParameter, - context, - ActionListener - .wrap( - lock -> runAdJob( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - recorder, - detector - ), - exception -> { - indexAnomalyResultException( + try { + nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { + if (!detectorOptional.isPresent()) { + log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId)); + return; + } + AnomalyDetector detector = detectorOptional.get(); + + if (jobParameter.getLockDurationSeconds() != null) { + lockService + .acquireLock( + jobParameter, + context, + ActionListener + .wrap( + lock -> runAdJob( jobParameter, lockService, - null, + lock, detectionStartTime, executionStartTime, - exception, - false, recorder, detector - ); - throw new IllegalStateException("Failed to acquire lock for AD job: " + detectorId); - } - ) - ); - } else { - log.warn("Can't get lock for AD job: " + detectorId); - } - }, e -> log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), e))); + ), + exception -> { + indexAnomalyResultException( + jobParameter, + lockService, + null, + detectionStartTime, + executionStartTime, + exception, + false, + recorder, + detector + ); + throw new IllegalStateException("Failed to acquire lock for AD job: " + detectorId); + } + ) + ); + } else { + log.warn("Can't get lock for AD job: " + detectorId); + } + + }, e -> log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), e))); + } catch (Exception e) { + // os log won't show anything if there is an exception happens (maybe due to running on a ExecutorService) + // we at least log the error. + log.error("Can't start AD job: " + detectorId, e); + throw e; + } }; ExecutorService executor = threadPool.executor(AD_THREAD_POOL_NAME); @@ -231,28 +240,11 @@ protected void runAdJob( } anomalyDetectionIndices.update(); - /* - * We need to handle 3 cases: - * 1. Detectors created by older versions and never updated. These detectors wont have User details in the - * detector object. `detector.user` will be null. Insert `all_access, AmazonES_all_access` role. - * 2. Detectors are created when security plugin is disabled, these will have empty User object. - * (`detector.user.name`, `detector.user.roles` are empty ) - * 3. Detectors are created when security plugin is enabled, these will have an User object. - * This will inject user role and check if the user role has permissions to call the execute - * Anomaly Result API. - */ - String user; - List roles; - if (jobParameter.getUser() == null) { - // It's possible that user create domain with security disabled, then enable security - // after upgrading. This is for BWC, for old detectors which created when security - // disabled, the user will be null. - user = ""; - roles = settings.getAsList("", ImmutableList.of("all_access", "AmazonES_all_access")); - } else { - user = jobParameter.getUser().getName(); - roles = jobParameter.getUser().getRoles(); - } + User userInfo = SecurityUtil.getUserFromJob(jobParameter, settings); + + String user = userInfo.getName(); + List roles = userInfo.getRoles(); + String resultIndex = jobParameter.getResultIndex(); if (resultIndex == null) { runAnomalyDetectionJob( @@ -302,7 +294,7 @@ private void runAnomalyDetectionJob( ExecuteADResultResponseRecorder recorder, AnomalyDetector detector ) { - + // using one thread in the write threadpool try (InjectSecurity injectSecurity = new InjectSecurity(detectorId, settings, client.threadPool().getThreadContext())) { // Injecting user role to verify if the user has permissions for our API. injectSecurity.inject(user, roles); diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java b/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java index 1ec27253a..0bc49541d 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java @@ -158,10 +158,7 @@ import org.opensearch.ad.transport.handler.AnomalyIndexHandler; import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; import org.opensearch.ad.transport.handler.MultiEntityResultHandler; -import org.opensearch.ad.util.ClientUtil; -import org.opensearch.ad.util.DiscoveryNodeFilterer; -import org.opensearch.ad.util.IndexUtils; -import org.opensearch.ad.util.Throttler; +import org.opensearch.ad.util.*; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodes; @@ -232,6 +229,7 @@ public class AnomalyDetectorPlugin extends Plugin implements ActionPlugin, Scrip private ThreadPool threadPool; private ADStats adStats; private ClientUtil clientUtil; + private SecurityClientUtil securityClientUtil; private DiscoveryNodeFilterer nodeFilter; private IndexUtils indexUtils; private ADTaskManager adTaskManager; @@ -344,11 +342,21 @@ public Collection createComponents( SingleFeatureLinearUniformInterpolator singleFeatureLinearUniformInterpolator = new IntegerSensitiveSingleFeatureLinearUniformInterpolator(); Interpolator interpolator = new LinearUniformInterpolator(singleFeatureLinearUniformInterpolator); + stateManager = new NodeStateManager( + client, + xContentRegistry, + settings, + clientUtil, + getClock(), + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + clusterService + ); + securityClientUtil = new SecurityClientUtil(stateManager, settings); SearchFeatureDao searchFeatureDao = new SearchFeatureDao( client, xContentRegistry, interpolator, - clientUtil, + securityClientUtil, settings, clusterService, AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE @@ -373,16 +381,6 @@ public Collection createComponents( adCircuitBreakerService ); - stateManager = new NodeStateManager( - client, - xContentRegistry, - settings, - clientUtil, - getClock(), - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - clusterService - ); - FeatureManager featureManager = new FeatureManager( searchFeatureDao, interpolator, @@ -731,6 +729,7 @@ public PooledObject wrap(LinkedBuffer obj) { threadPool, clusterService, client, + securityClientUtil, adCircuitBreakerService, featureManager, adTaskManager, diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java index bb555d0c9..5d1405a66 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java @@ -54,9 +54,7 @@ import org.opensearch.ad.transport.RCFPollingAction; import org.opensearch.ad.transport.RCFPollingRequest; import org.opensearch.ad.transport.RCFPollingResponse; -import org.opensearch.ad.util.DiscoveryNodeFilterer; -import org.opensearch.ad.util.ExceptionUtil; -import org.opensearch.ad.util.MultiResponsesDelegateActionListener; +import org.opensearch.ad.util.*; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -78,6 +76,7 @@ public class AnomalyDetectorProfileRunner extends AbstractProfileRunner { private final Logger logger = LogManager.getLogger(AnomalyDetectorProfileRunner.class); private Client client; + private SecurityClientUtil clientUtil; private NamedXContentRegistry xContentRegistry; private DiscoveryNodeFilterer nodeFilter; private final TransportService transportService; @@ -86,6 +85,7 @@ public class AnomalyDetectorProfileRunner extends AbstractProfileRunner { public AnomalyDetectorProfileRunner( Client client, + SecurityClientUtil clientUtil, NamedXContentRegistry xContentRegistry, DiscoveryNodeFilterer nodeFilter, long requiredSamples, @@ -94,6 +94,7 @@ public AnomalyDetectorProfileRunner( ) { super(requiredSamples); this.client = client; + this.clientUtil = clientUtil; this.xContentRegistry = xContentRegistry; this.nodeFilter = nodeFilter; if (requiredSamples <= 0) { @@ -293,7 +294,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener { + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { Map aggMap = searchResponse.getAggregations().asMap(); InternalCardinality totalEntities = (InternalCardinality) aggMap.get(CommonName.TOTAL_ENTITIES); long value = totalEntities.getValue(); @@ -303,7 +304,17 @@ private void profileEntityStats(MultiResponsesDelegateActionListener { logger.warn(CommonErrorMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getDetectorId()); listener.onFailure(searchException); - })); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + detector.getDetectorId(), + client, + searchResponseListener + ); } else { // Run a composite query and count the number of buckets to decide cardinality of multiple category fields AggregationBuilder bucketAggs = AggregationBuilders @@ -316,7 +327,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener { + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); Aggregations aggs = searchResponse.getAggregations(); if (aggs == null) { @@ -342,7 +353,17 @@ private void profileEntityStats(MultiResponsesDelegateActionListener { logger.warn(CommonErrorMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getDetectorId()); listener.onFailure(searchException); - })); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getDetectorId(), + client, + searchResponseListener + ); } } diff --git a/src/main/java/org/opensearch/ad/EntityProfileRunner.java b/src/main/java/org/opensearch/ad/EntityProfileRunner.java index 11c4015e4..e69b1b90b 100644 --- a/src/main/java/org/opensearch/ad/EntityProfileRunner.java +++ b/src/main/java/org/opensearch/ad/EntityProfileRunner.java @@ -26,6 +26,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.AnomalyDetector; @@ -43,6 +44,7 @@ import org.opensearch.ad.transport.EntityProfileResponse; import org.opensearch.ad.util.MultiResponsesDelegateActionListener; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.routing.Preference; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -64,11 +66,13 @@ public class EntityProfileRunner extends AbstractProfileRunner { static final String EMPTY_ENTITY_ATTRIBUTES = "Empty entity attributes"; static final String NO_ENTITY = "Cannot find entity"; private Client client; + private SecurityClientUtil clientUtil; private NamedXContentRegistry xContentRegistry; - public EntityProfileRunner(Client client, NamedXContentRegistry xContentRegistry, long requiredSamples) { + public EntityProfileRunner(Client client, SecurityClientUtil clientUtil, NamedXContentRegistry xContentRegistry, long requiredSamples) { super(requiredSamples); this.client = client; + this.clientUtil = clientUtil; this.xContentRegistry = xContentRegistry; } @@ -165,8 +169,7 @@ private void validateEntity( SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder) .preference(Preference.LOCAL.toString()); - - client.search(searchRequest, ActionListener.wrap(searchResponse -> { + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { try { if (searchResponse.getHits().getHits().length == 0) { listener.onFailure(new IllegalArgumentException(NO_ENTITY)); @@ -177,7 +180,17 @@ private void validateEntity( listener.onFailure(new IllegalArgumentException(NO_ENTITY)); return; } - }, e -> listener.onFailure(new IllegalArgumentException(NO_ENTITY)))); + }, e -> listener.onFailure(new IllegalArgumentException(NO_ENTITY))); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getDetectorId(), + client, + searchResponseListener + ); } diff --git a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java b/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java index 907b4603e..e6bae39df 100644 --- a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java +++ b/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java @@ -32,6 +32,7 @@ import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.Feature; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -67,6 +68,7 @@ public class CompositeRetriever extends AbstractRetriever { private final AnomalyDetector anomalyDetector; private final NamedXContentRegistry xContent; private final Client client; + private final SecurityClientUtil clientUtil; private int totalResults; // we can process at most maxEntities entities private int maxEntities; @@ -82,6 +84,7 @@ public CompositeRetriever( AnomalyDetector anomalyDetector, NamedXContentRegistry xContent, Client client, + SecurityClientUtil clientUtil, long expirationEpochMs, Clock clock, Settings settings, @@ -95,6 +98,7 @@ public CompositeRetriever( this.anomalyDetector = anomalyDetector; this.xContent = xContent; this.client = client; + this.clientUtil = clientUtil; this.totalResults = 0; this.maxEntities = maxEntitiesPerInterval; this.pageSize = pageSize; @@ -111,6 +115,7 @@ public CompositeRetriever( AnomalyDetector anomalyDetector, NamedXContentRegistry xContent, Client client, + SecurityClientUtil clientUtil, long expirationEpochMs, Settings settings, int maxEntitiesPerInterval, @@ -124,6 +129,7 @@ public CompositeRetriever( anomalyDetector, xContent, client, + clientUtil, expirationEpochMs, Clock.systemUTC(), settings, @@ -191,8 +197,11 @@ public PageIterator(SearchSourceBuilder source) { */ public void next(ActionListener listener) { iterations++; + + // inject user role while searching. + SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0]), source); - client.search(searchRequest, new ActionListener() { + final ActionListener searchResponseListener = new ActionListener() { @Override public void onResponse(SearchResponse response) { processResponse(response, () -> client.search(searchRequest, this), listener); @@ -202,7 +211,17 @@ public void onResponse(SearchResponse response) { public void onFailure(Exception e) { listener.onFailure(e); } - }); + }; + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + anomalyDetector.getDetectorId(), + client, + searchResponseListener + ); } private void processResponse(SearchResponse response, Runnable retry, ActionListener listener) { diff --git a/src/main/java/org/opensearch/ad/feature/FeatureManager.java b/src/main/java/org/opensearch/ad/feature/FeatureManager.java index cec226f55..f04f24926 100644 --- a/src/main/java/org/opensearch/ad/feature/FeatureManager.java +++ b/src/main/java/org/opensearch/ad/feature/FeatureManager.java @@ -278,32 +278,6 @@ private LongStream getFullShingleEndTimes(long endTime, long intervalMilli, int return LongStream.rangeClosed(1, shingleSize).map(i -> endTime - (shingleSize - i) * intervalMilli); } - /** - * Provides data for cold-start training. - * - * @deprecated use getColdStartData with listener instead. - * - * Training data starts with getting samples from (costly) search. - * Samples are increased in size via interpolation and then - * in dimension via shingling. - * - * @param detector contains data info (indices, documents, etc) - * @return data for cold-start training, or empty if unavailable - */ - @Deprecated - public Optional getColdStartData(AnomalyDetector detector) { - int shingleSize = detector.getShingleSize(); - return searchFeatureDao - .getLatestDataTime(detector) - .flatMap(latest -> searchFeatureDao.getFeaturesForSampledPeriods(detector, maxTrainSamples, maxSampleStride, latest)) - .map( - samples -> transpose( - interpolator.interpolate(transpose(samples.getKey()), samples.getValue() * (samples.getKey().length - 1) + 1) - ) - ) - .map(points -> batchShingle(points, shingleSize)); - } - /** * Returns to listener data for cold-start training. * @@ -488,7 +462,7 @@ public void getPreviewFeaturesForEntity( int stride = sampleRangeResults.getValue(); int shingleSize = detector.getShingleSize(); - getSamplesInRangesForEntity(detector, sampleRanges, entity, getFeatureSamplesListener(stride, shingleSize, listener)); + getPreviewSamplesInRangesForEntity(detector, sampleRanges, entity, getFeatureSamplesListener(stride, shingleSize, listener)); } private ActionListener>, double[][]>> getFeatureSamplesListener( @@ -564,7 +538,7 @@ private Entry>, Integer> getSampleRanges(AnomalyDetector * @param listener handle search results map: key is time ranges, value is corresponding search results * @throws IOException if a user gives wrong query input when defining a detector */ - void getSamplesInRangesForEntity( + void getPreviewSamplesInRangesForEntity( AnomalyDetector detector, List> sampleRanges, Entity entity, diff --git a/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java b/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java index 397f57f5a..04ce3025a 100644 --- a/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java +++ b/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java @@ -46,8 +46,8 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -68,7 +68,6 @@ import org.opensearch.search.aggregations.bucket.range.InternalDateRange; import org.opensearch.search.aggregations.bucket.range.InternalDateRange.Bucket; import org.opensearch.search.aggregations.bucket.terms.Terms; -import org.opensearch.search.aggregations.metrics.Max; import org.opensearch.search.aggregations.metrics.Min; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.FieldSortBuilder; @@ -88,7 +87,7 @@ public class SearchFeatureDao extends AbstractRetriever { private final Client client; private final NamedXContentRegistry xContent; private final Interpolator interpolator; - private final ClientUtil clientUtil; + private final SecurityClientUtil clientUtil; private volatile int maxEntitiesForPreview; private volatile int pageSize; private final int minimumDocCountForPreview; @@ -100,7 +99,7 @@ public SearchFeatureDao( Client client, NamedXContentRegistry xContent, Interpolator interpolator, - ClientUtil clientUtil, + SecurityClientUtil clientUtil, Settings settings, ClusterService clusterService, int minimumDocCount, @@ -142,7 +141,7 @@ public SearchFeatureDao( Client client, NamedXContentRegistry xContent, Interpolator interpolator, - ClientUtil clientUtil, + SecurityClientUtil clientUtil, Settings settings, ClusterService clusterService, int minimumDocCount @@ -162,28 +161,6 @@ public SearchFeatureDao( ); } - /** - * Returns epoch time of the latest data under the detector. - * - * @deprecated use getLatestDataTime with listener instead. - * - * @param detector info about the indices and documents - * @return epoch time of the latest data in milliseconds - */ - @Deprecated - public Optional getLatestDataTime(AnomalyDetector detector) { - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField())) - .size(0); - SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - return clientUtil - .timedRequest(searchRequest, logger, client::search) - .map(SearchResponse::getAggregations) - .map(aggs -> aggs.asMap()) - .map(map -> (Max) map.get(CommonName.AGG_NAME_MAX_TIME)) - .map(agg -> (long) agg.getValue()); - } - /** * Returns to listener the epoch time of the latset data under the detector. * @@ -195,10 +172,17 @@ public void getLatestDataTime(AnomalyDetector detector, ActionListener searchResponseListener = ActionListener + .wrap(response -> listener.onResponse(ParseUtils.getLatestDataTime(response)), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - ActionListener.wrap(response -> listener.onResponse(ParseUtils.getLatestDataTime(response)), listener::onFailure) + client::search, + detector.getDetectorId(), + client, + searchResponseListener ); } @@ -360,18 +344,24 @@ public void getHighestCountEntities( .trackTotalHits(false) .size(0); SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client - .search( + final ActionListener searchResponseListener = new TopEntitiesListener( + listener, + detector, + searchSourceBuilder, + // TODO: tune timeout for historical analysis based on performance test result + clock.millis() + previewTimeoutInMilliseconds, + maxEntitiesSize, + minimumDocCount + ); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - new TopEntitiesListener( - listener, - detector, - searchSourceBuilder, - // TODO: tune timeout for historical analysis based on performance test result - clock.millis() + previewTimeoutInMilliseconds, - maxEntitiesSize, - minimumDocCount - ) + client::search, + detector.getDetectorId(), + client, + searchResponseListener ); } @@ -455,9 +445,14 @@ public void onResponse(SearchResponse response) { } } else { updateSourceAfterKey(afterKey, searchSourceBuilder); - client - .search( + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder), + client::search, + detector.getDetectorId(), + client, this ); } @@ -493,10 +488,16 @@ public void getEntityMinDataTime(AnomalyDetector detector, Entity entity, Action .trackTotalHits(false) .size(0); SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client - .search( + final ActionListener searchResponseListener = ActionListener + .wrap(response -> { listener.onResponse(parseMinDataTime(response)); }, listener::onFailure); + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - ActionListener.wrap(response -> { listener.onResponse(parseMinDataTime(response)); }, listener::onFailure) + client::search, + detector.getDetectorId(), + client, + searchResponseListener ); } @@ -509,31 +510,6 @@ private Optional parseMinDataTime(SearchResponse searchResponse) { return mapOptional.map(map -> (Min) map.get(AGG_NAME_MIN)).map(agg -> (long) agg.getValue()); } - /** - * Gets features for the given time period. - * This function also adds given detector to negative cache before sending es request. - * Once response/exception is received within timeout, this request will be treated as complete - * and cleared from the negative cache. - * Otherwise this detector entry remain in the negative to reject further request. - * - * @deprecated use getFeaturesForPeriod with listener instead. - * - * @param detector info about indices, documents, feature query - * @param startTime epoch milliseconds at the beginning of the period - * @param endTime epoch milliseconds at the end of the period - * @throws IllegalStateException when unexpected failures happen - * @return features from search results, empty when no data found - */ - @Deprecated - public Optional getFeaturesForPeriod(AnomalyDetector detector, long startTime, long endTime) { - SearchRequest searchRequest = createFeatureSearchRequest(detector, startTime, endTime, Optional.empty()); - - // send throttled request: this request will clear the negative cache if the request finished within timeout - return clientUtil - .throttledTimedRequest(searchRequest, logger, client::search, detector) - .flatMap(resp -> parseResponse(resp, detector.getEnabledFeatureIds())); - } - /** * Returns to listener features for the given time period. * @@ -544,11 +520,17 @@ public Optional getFeaturesForPeriod(AnomalyDetector detector, long st */ public void getFeaturesForPeriod(AnomalyDetector detector, long startTime, long endTime, ActionListener> listener) { SearchRequest searchRequest = createFeatureSearchRequest(detector, startTime, endTime, Optional.empty()); - client - .search( + final ActionListener searchResponseListener = ActionListener + .wrap(response -> listener.onResponse(parseResponse(response, detector.getEnabledFeatureIds())), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - ActionListener - .wrap(response -> listener.onResponse(parseResponse(response, detector.getEnabledFeatureIds())), listener::onFailure) + client::search, + detector.getDetectorId(), + client, + searchResponseListener ); } @@ -563,14 +545,19 @@ public void getFeaturesForPeriodByBatch( logger.debug("Batch query for detector {}: {} ", detector.getDetectorId(), searchSourceBuilder); SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client - .search( + final ActionListener searchResponseListener = ActionListener + .wrap( + response -> { listener.onResponse(parseBucketAggregationResponse(response, detector.getEnabledFeatureIds())); }, + listener::onFailure + ); + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - ActionListener - .wrap( - response -> { listener.onResponse(parseBucketAggregationResponse(response, detector.getEnabledFeatureIds())); }, - listener::onFailure - ) + client::search, + detector.getDetectorId(), + client, + searchResponseListener ); } @@ -608,8 +595,7 @@ public void getFeatureSamplesForPeriods( ActionListener>> listener ) throws IOException { SearchRequest request = createPreviewSearchRequest(detector, ranges); - - client.search(request, ActionListener.wrap(response -> { + final ActionListener searchResponseListener = ActionListener.wrap(response -> { Aggregations aggs = response.getAggregations(); if (aggs == null) { listener.onResponse(Collections.emptyList()); @@ -626,105 +612,16 @@ public void getFeatureSamplesForPeriods( .map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds())) .collect(Collectors.toList()) ); - }, listener::onFailure)); - } - - /** - * Gets features for sampled periods. - * - * @deprecated use getFeaturesForSampledPeriods with listener instead. - * - * Sampling starts with the latest period and goes backwards in time until there are up to {@code maxSamples} samples. - * If the initial stride {@code maxStride} results into a low count of samples, the implementation - * may attempt with (exponentially) reduced strides and interpolate missing points. - * - * @param detector info about indices, documents, feature query - * @param maxSamples the maximum number of samples to return - * @param maxStride the maximum number of periods between samples - * @param endTime the end time of the latest period - * @return sampled features and stride, empty when no data found - */ - @Deprecated - public Optional> getFeaturesForSampledPeriods( - AnomalyDetector detector, - int maxSamples, - int maxStride, - long endTime - ) { - Map cache = new HashMap<>(); - int currentStride = maxStride; - Optional features = Optional.empty(); - logger.info(String.format(Locale.ROOT, "Getting features for detector %s starting %d", detector.getDetectorId(), endTime)); - while (currentStride >= 1) { - boolean isInterpolatable = currentStride < maxStride; - features = getFeaturesForSampledPeriods(detector, maxSamples, currentStride, endTime, cache, isInterpolatable); - - if (!features.isPresent() || features.get().length > maxSamples / 2 || currentStride == 1) { - logger - .info( - String - .format( - Locale.ROOT, - "Get features for detector %s finishes with features present %b, current stride %d", - detector.getDetectorId(), - features.isPresent(), - currentStride - ) - ); - break; - } else { - currentStride = currentStride / 2; - } - } - if (features.isPresent()) { - return Optional.of(new SimpleEntry<>(features.get(), currentStride)); - } else { - return Optional.empty(); - } - } - - private Optional getFeaturesForSampledPeriods( - AnomalyDetector detector, - int maxSamples, - int stride, - long endTime, - Map cache, - boolean isInterpolatable - ) { - ArrayDeque sampledFeatures = new ArrayDeque<>(maxSamples); - for (int i = 0; i < maxSamples; i++) { - long span = ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMillis(); - long end = endTime - span * stride * i; - if (cache.containsKey(end)) { - sampledFeatures.addFirst(cache.get(end)); - } else { - Optional features = getFeaturesForPeriod(detector, end - span, end); - if (features.isPresent()) { - cache.put(end, features.get()); - sampledFeatures.addFirst(features.get()); - } else if (isInterpolatable) { - Optional previous = Optional.ofNullable(cache.get(end - span * stride)); - Optional next = Optional.ofNullable(cache.get(end + span * stride)); - if (previous.isPresent() && next.isPresent()) { - double[] interpolants = getInterpolants(previous.get(), next.get()); - cache.put(end, interpolants); - sampledFeatures.addFirst(interpolants); - } else { - break; - } - } else { - break; - } - - } - } - Optional samples; - if (sampledFeatures.isEmpty()) { - samples = Optional.empty(); - } else { - samples = Optional.of(sampledFeatures.toArray(new double[0][0])); - } - return samples; + }, listener::onFailure); + // inject user role while searching + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + detector.getDetectorId(), + client, + searchResponseListener + ); } /** @@ -966,10 +863,9 @@ public void getColdStartSamplesForPeriods( Entity entity, boolean includesEmptyBucket, ActionListener>> listener - ) throws IOException { + ) { SearchRequest request = createColdStartFeatureSearchRequest(detector, ranges, entity); - - client.search(request, ActionListener.wrap(response -> { + final ActionListener searchResponseListener = ActionListener.wrap(response -> { Aggregations aggs = response.getAggregations(); if (aggs == null) { listener.onResponse(Collections.emptyList()); @@ -1001,7 +897,17 @@ public void getColdStartSamplesForPeriods( .map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds())) .collect(Collectors.toList()) ); - }, listener::onFailure)); + }, listener::onFailure); + + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + detector.getDetectorId(), + client, + searchResponseListener + ); } private SearchRequest createColdStartFeatureSearchRequest(AnomalyDetector detector, List> ranges, Entity entity) { diff --git a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java index ee59b55d1..251a065de 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java @@ -70,8 +70,10 @@ import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; import org.opensearch.ad.util.MultiResponsesDelegateActionListener; import org.opensearch.ad.util.RestHandlerUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.NamedXContentRegistry; @@ -143,6 +145,7 @@ public abstract class AbstractAnomalyDetectorActionHandler listener; @@ -152,12 +155,14 @@ public abstract class AbstractAnomalyDetectorActionHandler listener, AnomalyDetectionIndices anomalyDetectionIndices, @@ -201,10 +208,12 @@ public AbstractAnomalyDetectorActionHandler( SearchFeatureDao searchFeatureDao, String validationType, boolean isDryRun, - Clock clock + Clock clock, + Settings settings ) { this.clusterService = clusterService; this.client = client; + this.clientUtil = clientUtil; this.transportService = transportService; this.anomalyDetectionIndices = anomalyDetectionIndices; this.listener = listener; @@ -225,6 +234,7 @@ public AbstractAnomalyDetectorActionHandler( this.validationType = validationType; this.isDryRun = isDryRun; this.clock = clock; + this.settings = settings; } /** @@ -384,7 +394,7 @@ protected void validateTimeField(boolean indexingDryRun) { logger.error(message, error); listener.onFailure(new IllegalArgumentException(message)); }); - client.execute(GetFieldMappingsAction.INSTANCE, getMappingsRequest, mappingsListener); + clientUtil.executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, mappingsListener); } /** @@ -664,7 +674,7 @@ protected void validateCategoricalField(String detectorId, boolean indexingDryRu listener.onFailure(new IllegalArgumentException(message)); }); - client.execute(GetFieldMappingsAction.INSTANCE, getMappingsRequest, mappingsListener); + clientUtil.executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, mappingsListener); } protected void searchAdInputIndices(String detectorId, boolean indexingDryRun) { @@ -675,15 +685,13 @@ protected void searchAdInputIndices(String detectorId, boolean indexingDryRun) { SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client - .search( - searchRequest, - ActionListener - .wrap( - searchResponse -> onSearchAdInputIndicesResponse(searchResponse, detectorId, indexingDryRun), - exception -> listener.onFailure(exception) - ) + ActionListener searchResponseListener = ActionListener + .wrap( + searchResponse -> onSearchAdInputIndicesResponse(searchResponse, detectorId, indexingDryRun), + exception -> listener.onFailure(exception) ); + + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, searchResponseListener); } protected void onSearchAdInputIndicesResponse(SearchResponse response, String detectorId, boolean indexingDryRun) throws IOException { @@ -710,7 +718,6 @@ protected void checkADNameExists(String detectorId, boolean indexingDryRun) thro } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).timeout(requestTimeout); SearchRequest searchRequest = new SearchRequest(ANOMALY_DETECTORS_INDEX).source(searchSourceBuilder); - client .search( searchRequest, @@ -769,13 +776,16 @@ protected void finishDetectorValidationOrContinueToModelValidation() { ModelValidationActionHandler modelValidationActionHandler = new ModelValidationActionHandler( clusterService, client, + clientUtil, (ActionListener) listener, anomalyDetector, requestTimeout, xContentRegistry, searchFeatureDao, validationType, - clock + clock, + settings, + user ); modelValidationActionHandler.checkIfMultiEntityDetector(); } @@ -929,7 +939,7 @@ protected void validateAnomalyDetectorFeatures(String detectorId, boolean indexi ); ssb.aggregation(internalAgg.getAggregatorFactories().iterator().next()); SearchRequest searchRequest = new SearchRequest().indices(anomalyDetector.getIndices().toArray(new String[0])).source(ssb); - client.search(searchRequest, ActionListener.wrap(response -> { + ActionListener searchResponseListener = ActionListener.wrap(response -> { Optional aggFeatureResult = searchFeatureDao.parseResponse(response, Arrays.asList(feature.getId())); if (aggFeatureResult.isPresent()) { multiFeatureQueriesResponseListener @@ -950,7 +960,8 @@ protected void validateAnomalyDetectorFeatures(String detectorId, boolean indexi } logger.error(errorMessage, e); multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST, e)); - })); + }); + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, searchResponseListener); } } } diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java index 6291c6472..598c8285b 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java @@ -18,8 +18,10 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.commons.authuser.User; @@ -38,6 +40,7 @@ public class IndexAnomalyDetectorActionHandler extends AbstractAnomalyDetectorAc * * @param clusterService ClusterService * @param client ES node client that executes actions on the local node + * @param clientUtil AD client util * @param transportService ES transport service * @param listener ES channel used to construct bytes / builder based outputs, and send responses * @param anomalyDetectionIndices anomaly detector index manager @@ -55,10 +58,12 @@ public class IndexAnomalyDetectorActionHandler extends AbstractAnomalyDetectorAc * @param user User context * @param adTaskManager AD Task manager * @param searchFeatureDao Search feature dao + * @param settings Node settings */ public IndexAnomalyDetectorActionHandler( ClusterService clusterService, Client client, + SecurityClientUtil clientUtil, TransportService transportService, ActionListener listener, AnomalyDetectionIndices anomalyDetectionIndices, @@ -75,11 +80,13 @@ public IndexAnomalyDetectorActionHandler( NamedXContentRegistry xContentRegistry, User user, ADTaskManager adTaskManager, - SearchFeatureDao searchFeatureDao + SearchFeatureDao searchFeatureDao, + Settings settings ) { super( clusterService, client, + clientUtil, transportService, listener, anomalyDetectionIndices, @@ -99,7 +106,8 @@ public IndexAnomalyDetectorActionHandler( searchFeatureDao, null, false, - null + null, + settings ); } diff --git a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java index f99b44070..74ffa3fdd 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java @@ -49,10 +49,13 @@ import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; import org.opensearch.ad.util.MultiResponsesDelegateActionListener; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.commons.authuser.User; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -92,17 +95,21 @@ public class ModelValidationActionHandler { protected final TimeValue requestTimeout; protected final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); protected final Client client; + protected final SecurityClientUtil clientUtil; protected final NamedXContentRegistry xContentRegistry; protected final ActionListener listener; protected final SearchFeatureDao searchFeatureDao; protected final Clock clock; protected final String validationType; + protected final Settings settings; + protected final User user; /** * Constructor function. * * @param clusterService ClusterService * @param client ES node client that executes actions on the local node + * @param clientUtil AD client util * @param listener ES channel used to construct bytes / builder based outputs, and send responses * @param anomalyDetector anomaly detector instance * @param requestTimeout request time out configuration @@ -110,20 +117,26 @@ public class ModelValidationActionHandler { * @param searchFeatureDao Search feature DAO * @param validationType Specified type for validation * @param clock clock object to know when to timeout + * @param settings Node settings + * @param user User info */ public ModelValidationActionHandler( ClusterService clusterService, Client client, + SecurityClientUtil clientUtil, ActionListener listener, AnomalyDetector anomalyDetector, TimeValue requestTimeout, NamedXContentRegistry xContentRegistry, SearchFeatureDao searchFeatureDao, String validationType, - Clock clock + Clock clock, + Settings settings, + User user ) { this.clusterService = clusterService; this.client = client; + this.clientUtil = clientUtil; this.listener = listener; this.anomalyDetector = anomalyDetector; this.requestTimeout = requestTimeout; @@ -131,6 +144,8 @@ public ModelValidationActionHandler( this.searchFeatureDao = searchFeatureDao; this.validationType = validationType; this.clock = clock; + this.settings = settings; + this.user = user; } // Need to first check if multi entity detector or not before doing any sort of validation. @@ -195,7 +210,7 @@ private void getTopEntity(ActionListener> topEntityListener) SearchRequest searchRequest = new SearchRequest() .indices(anomalyDetector.getIndices().toArray(new String[0])) .source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap(response -> { + final ActionListener searchResponseListener = ActionListener.wrap(response -> { Aggregations aggs = response.getAggregations(); if (aggs == null) { topEntityListener.onResponse(Collections.emptyMap()); @@ -228,7 +243,17 @@ private void getTopEntity(ActionListener> topEntityListener) } } topEntityListener.onResponse(topKeys); - }, topEntityListener::onFailure)); + }, topEntityListener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); } private void getLatestDateForValidation(Map topEntity) { @@ -303,18 +328,25 @@ private void getBucketAggregates( listener.onFailure(exception); logger.error("Failed to get interval recommendation", exception); }); - client - .search( + final ActionListener searchResponseListener = + new ModelValidationActionHandler.DetectorIntervalRecommendationListener( + intervalListener, + searchRequest.source(), + (IntervalTimeConfiguration) anomalyDetector.getDetectionInterval(), + clock.millis() + TOP_VALIDATE_TIMEOUT_IN_MILLIS, + latestTime, + false, + MAX_TIMES_DECREASING_INTERVAL + ); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( searchRequest, - new ModelValidationActionHandler.DetectorIntervalRecommendationListener( - intervalListener, - searchRequest.source(), - (IntervalTimeConfiguration) anomalyDetector.getDetectionInterval(), - clock.millis() + TOP_VALIDATE_TIMEOUT_IN_MILLIS, - latestTime, - false, - MAX_TIMES_DECREASING_INTERVAL - ) + client::search, + user, + client, + searchResponseListener ); } @@ -421,11 +453,16 @@ public void onResponse(SearchResponse response) { searchSourceBuilder.query(), getBucketAggregation(this.latestTime, new IntervalTimeConfiguration(newIntervalMinute, ChronoUnit.MINUTES)) ); - client - .search( + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( new SearchRequest() .indices(anomalyDetector.getIndices().toArray(new String[0])) .source(updatedSearchSourceBuilder), + client::search, + user, + client, this ); // In this case decreasingInterval has to be true already, so we will stop @@ -452,9 +489,14 @@ private void searchWithDifferentInterval(long newIntervalMinuteValue) { searchSourceBuilder.query(), getBucketAggregation(this.latestTime, new IntervalTimeConfiguration(newIntervalMinuteValue, ChronoUnit.MINUTES)) ); - client - .search( + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( new SearchRequest().indices(anomalyDetector.getIndices().toArray(new String[0])).source(updatedSearchSourceBuilder), + client::search, + user, + client, this ); } @@ -524,7 +566,18 @@ private void checkRawDataSparsity(long latestTime) { ); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(aggregation).size(0).timeout(requestTimeout); SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap(response -> processRawDataResults(response, latestTime), listener::onFailure)); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processRawDataResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); } private Histogram checkBucketResultErrors(SearchResponse response) { @@ -580,7 +633,18 @@ private void checkDataFilterSparsity(long latestTime) { BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap(response -> processDataFilterResults(response, latestTime), listener::onFailure)); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processDataFilterResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); } private void processDataFilterResults(SearchResponse response, long latestTime) { @@ -634,7 +698,18 @@ private void checkCategoryFieldSparsity(Map topEntity, long late ); SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap(response -> processTopEntityResults(response, latestTime), listener::onFailure)); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processTopEntityResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); } private void processTopEntityResults(SearchResponse response, long latestTime) { @@ -695,7 +770,7 @@ private void checkFeatureQueryDelegate(long latestTime) throws IOException { SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])) .source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap(response -> { + final ActionListener searchResponseListener = ActionListener.wrap(response -> { Histogram aggregate = checkBucketResultErrors(response); if (aggregate == null) { return; @@ -718,7 +793,17 @@ private void checkFeatureQueryDelegate(long latestTime) throws IOException { logger.error(e); multiFeatureQueriesResponseListener .onFailure(new OpenSearchStatusException(CommonErrorMessages.FEATURE_QUERY_TOO_SPARSE, RestStatus.BAD_REQUEST, e)); - })); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); } } diff --git a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java index e394d0318..1a1a15626 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java @@ -18,8 +18,10 @@ import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.commons.authuser.User; @@ -36,6 +38,7 @@ public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetecto * * @param clusterService ClusterService * @param client ES node client that executes actions on the local node + * @param clientUtil AD client utility * @param listener ES channel used to construct bytes / builder based outputs, and send responses * @param anomalyDetectionIndices anomaly detector index manager * @param anomalyDetector anomaly detector instance @@ -49,10 +52,12 @@ public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetecto * @param searchFeatureDao Search feature DAO * @param validationType Specified type for validation * @param clock Clock object to know when to timeout + * @param settings Node settings */ public ValidateAnomalyDetectorActionHandler( ClusterService clusterService, Client client, + SecurityClientUtil clientUtil, ActionListener listener, AnomalyDetectionIndices anomalyDetectionIndices, AnomalyDetector anomalyDetector, @@ -65,11 +70,13 @@ public ValidateAnomalyDetectorActionHandler( User user, SearchFeatureDao searchFeatureDao, String validationType, - Clock clock + Clock clock, + Settings settings ) { super( clusterService, client, + clientUtil, null, listener, anomalyDetectionIndices, @@ -89,7 +96,8 @@ public ValidateAnomalyDetectorActionHandler( searchFeatureDao, validationType, true, - clock + clock, + settings ); } diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java index 23bc3848a..0c1516387 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java @@ -50,6 +50,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ThreadedActionListener; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.caching.PriorityTracker; @@ -89,6 +90,7 @@ import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -123,6 +125,7 @@ public class ADBatchTaskRunner { private Settings settings; private final ThreadPool threadPool; private final Client client; + private final SecurityClientUtil clientUtil; private final ADStats adStats; private final ClusterService clusterService; private final FeatureManager featureManager; @@ -151,6 +154,7 @@ public ADBatchTaskRunner( ThreadPool threadPool, ClusterService clusterService, Client client, + SecurityClientUtil clientUtil, ADCircuitBreakerService adCircuitBreakerService, FeatureManager featureManager, ADTaskManager adTaskManager, @@ -166,6 +170,7 @@ public ADBatchTaskRunner( this.threadPool = threadPool; this.clusterService = clusterService; this.client = client; + this.clientUtil = clientUtil; this.anomalyResultBulkIndexHandler = anomalyResultBulkIndexHandler; this.adStats = adStats; this.adCircuitBreakerService = adCircuitBreakerService; @@ -438,7 +443,7 @@ private void searchTopEntitiesForSingleCategoryHC( SearchRequest searchRequest = new SearchRequest(); searchRequest.source(sourceBuilder); searchRequest.indices(adTask.getDetector().getIndices().toArray(new String[0])); - client.search(searchRequest, ActionListener.wrap(r -> { + final ActionListener searchResponseListener = ActionListener.wrap(r -> { StringTerms stringTerms = r.getAggregations().get(topEntitiesAgg); List buckets = stringTerms.getBuckets(); List topEntities = new ArrayList<>(); @@ -473,7 +478,18 @@ private void searchTopEntitiesForSingleCategoryHC( }, e -> { logger.error("Failed to get top entities for detector " + adTask.getDetectorId(), e); internalHCListener.onFailure(e); - })); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + // user is the one who started historical detector. Read AnomalyDetectorJobTransportAction.doExecute. + adTask.getUser(), + client, + searchResponseListener + ); } /** @@ -949,8 +965,7 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons SearchRequest request = new SearchRequest() .indices(adTask.getDetector().getIndices().toArray(new String[0])) .source(searchSourceBuilder); - - client.search(request, ActionListener.wrap(r -> { + final ActionListener searchResponseListener = ActionListener.wrap(r -> { InternalMin minAgg = r.getAggregations().get(AGG_NAME_MIN_TIME); InternalMax maxAgg = r.getAggregations().get(AGG_NAME_MAX_TIME); double minValue = minAgg.getValue(); @@ -989,7 +1004,18 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons return; } consumer.accept(dataStartTime, dataEndTime); - }, e -> { internalListener.onFailure(e); })); + }, e -> { internalListener.onFailure(e); }); + + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + // user is the one who started historical detector. Read AnomalyDetectorJobTransportAction.doExecute. + adTask.getUser(), + client, + searchResponseListener + ); } private void getFeatureData( diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java index 921a5dc55..af859632c 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java @@ -73,6 +73,7 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.block.ClusterBlockLevel; @@ -125,6 +126,7 @@ public class AnomalyResultTransportAction extends HandledTransportAction onFeatureResponseForSingleEntityDete ) { return ActionListener.wrap(featureOptional -> { List featureInResponse = null; - if (featureOptional.getUnprocessedFeatures().isPresent()) { featureInResponse = ParseUtils.getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector); } diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java index 563153ec1..f42f700af 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java @@ -58,6 +58,7 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.util.DiscoveryNodeFilterer; import org.opensearch.ad.util.RestHandlerUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.CheckedConsumer; @@ -80,7 +81,7 @@ public class GetAnomalyDetectorTransportAction extends HandledTransportAction allProfileTypeStrs; private final Set allProfileTypes; private final Set defaultDetectorProfileTypes; @@ -100,6 +101,7 @@ public GetAnomalyDetectorTransportAction( ActionFilters actionFilters, ClusterService clusterService, Client client, + SecurityClientUtil clientUtil, Settings settings, NamedXContentRegistry xContentRegistry, ADTaskManager adTaskManager @@ -107,7 +109,7 @@ public GetAnomalyDetectorTransportAction( super(GetAnomalyDetectorAction.NAME, transportService, actionFilters, GetAnomalyDetectorRequest::new); this.clusterService = clusterService; this.client = client; - + this.clientUtil = clientUtil; List allProfiles = Arrays.asList(DetectorProfileName.values()); this.allProfileTypes = EnumSet.copyOf(allProfiles); this.allProfileTypeStrs = getProfileListStrs(allProfiles); @@ -165,6 +167,7 @@ protected void getExecute(GetAnomalyDetectorRequest request, ActionListener entityProfilesToCollect = getEntityProfilesToCollect(typesStr, all); EntityProfileRunner profileRunner = new EntityProfileRunner( client, + clientUtil, xContentRegistry, AnomalyDetectorSettings.NUM_MIN_SAMPLES ); @@ -203,6 +206,7 @@ protected void getExecute(GetAnomalyDetectorRequest request, ActionListener profilesToCollect = getProfilesToCollect(typesStr, all); AnomalyDetectorProfileRunner profileRunner = new AnomalyDetectorProfileRunner( client, + clientUtil, xContentRegistry, nodeFilter, AnomalyDetectorSettings.NUM_MIN_SAMPLES, diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java index d40bd295e..d41cf46d7 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java @@ -36,6 +36,7 @@ import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -53,6 +54,7 @@ public class IndexAnomalyDetectorTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(IndexAnomalyDetectorTransportAction.class); private final Client client; + private final SecurityClientUtil clientUtil; private final TransportService transportService; private final AnomalyDetectionIndices anomalyDetectionIndices; private final ClusterService clusterService; @@ -60,12 +62,14 @@ public class IndexAnomalyDetectorTransportAction extends HandledTransportAction< private final ADTaskManager adTaskManager; private volatile Boolean filterByEnabled; private final SearchFeatureDao searchFeatureDao; + private final Settings settings; @Inject public IndexAnomalyDetectorTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SecurityClientUtil clientUtil, ClusterService clusterService, Settings settings, AnomalyDetectionIndices anomalyDetectionIndices, @@ -75,6 +79,7 @@ public IndexAnomalyDetectorTransportAction( ) { super(IndexAnomalyDetectorAction.NAME, transportService, actionFilters, IndexAnomalyDetectorRequest::new); this.client = client; + this.clientUtil = clientUtil; this.transportService = transportService; this.clusterService = clusterService; this.anomalyDetectionIndices = anomalyDetectionIndices; @@ -83,6 +88,7 @@ public IndexAnomalyDetectorTransportAction( this.searchFeatureDao = searchFeatureDao; filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.settings = settings; } @Override @@ -157,6 +163,7 @@ protected void adExecute( IndexAnomalyDetectorActionHandler indexAnomalyDetectorActionHandler = new IndexAnomalyDetectorActionHandler( clusterService, client, + clientUtil, transportService, listener, anomalyDetectionIndices, @@ -173,7 +180,8 @@ protected void adExecute( xContentRegistry, detectorUser, adTaskManager, - searchFeatureDao + searchFeatureDao, + settings ); indexAnomalyDetectorActionHandler.start(); }, listener); diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java index 7a9bd66f3..a6c3d07a5 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java @@ -39,6 +39,7 @@ import org.opensearch.ad.rest.handler.AnomalyDetectorFunction; import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -58,16 +59,19 @@ public class ValidateAnomalyDetectorTransportAction extends private static final Logger logger = LogManager.getLogger(ValidateAnomalyDetectorTransportAction.class); private final Client client; + private final SecurityClientUtil clientUtil; private final ClusterService clusterService; private final NamedXContentRegistry xContentRegistry; private final AnomalyDetectionIndices anomalyDetectionIndices; private final SearchFeatureDao searchFeatureDao; private volatile Boolean filterByEnabled; private Clock clock; + private Settings settings; @Inject public ValidateAnomalyDetectorTransportAction( Client client, + SecurityClientUtil clientUtil, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Settings settings, @@ -78,6 +82,7 @@ public ValidateAnomalyDetectorTransportAction( ) { super(ValidateAnomalyDetectorAction.NAME, transportService, actionFilters, ValidateAnomalyDetectorRequest::new); this.client = client; + this.clientUtil = clientUtil; this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; this.anomalyDetectionIndices = anomalyDetectionIndices; @@ -85,6 +90,7 @@ public ValidateAnomalyDetectorTransportAction( clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); this.searchFeatureDao = searchFeatureDao; this.clock = Clock.systemUTC(); + this.settings = settings; } @Override @@ -143,6 +149,7 @@ private void validateExecute( ValidateAnomalyDetectorActionHandler handler = new ValidateAnomalyDetectorActionHandler( clusterService, client, + clientUtil, validateListener, anomalyDetectionIndices, detector, @@ -155,7 +162,8 @@ private void validateExecute( user, searchFeatureDao, request.getValidationType(), - clock + clock, + settings ); try { handler.start(); diff --git a/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java b/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java new file mode 100644 index 000000000..8fcd11ce0 --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.common.exception.EndRunException; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; + +public class ADSafeSecurityInjector extends SafeSecurityInjector { + private static final Logger LOG = LogManager.getLogger(ADSafeSecurityInjector.class); + private NodeStateManager nodeStateManager; + + public ADSafeSecurityInjector(String detectorId, Settings settings, ThreadContext tc, NodeStateManager stateManager) { + super(detectorId, settings, tc); + this.nodeStateManager = stateManager; + } + + public void injectUserRolesFromDetector(ActionListener injectListener) { + // if id is null, we cannot fetch a detector + if (Strings.isEmpty(id)) { + LOG.debug("Empty id"); + injectListener.onResponse(null); + return; + } + + // for example, if a user exists in thread context, we don't need to inject user/roles + if (!shouldInject()) { + LOG.debug("Don't need to inject"); + injectListener.onResponse(null); + return; + } + + ActionListener> getDetectorListener = ActionListener.wrap(detectorOp -> { + if (!detectorOp.isPresent()) { + injectListener.onFailure(new EndRunException(id, "AnomalyDetector is not available.", false)); + return; + } + AnomalyDetector detector = detectorOp.get(); + User userInfo = SecurityUtil.getUserFromDetector(detector, settings); + inject(userInfo.getName(), userInfo.getRoles()); + injectListener.onResponse(null); + }, injectListener::onFailure); + + // Since we are gonna read user from detector, make sure the anomaly detector exists and fetched from disk or cached memory + // We don't accept a passed-in AnomalyDetector because the caller might mistakenly not insert any user info in the + // constructed AnomalyDetector and thus poses risks. In the case, if the user is null, we will give admin role. + nodeStateManager.getAnomalyDetector(id, getDetectorListener); + } + + public void injectUserRoles(User user) { + if (user == null) { + LOG.debug("null user"); + return; + } + + if (shouldInject()) { + inject(user.getName(), user.getRoles()); + } + } +} diff --git a/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java b/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java index 3fc1468b7..8b18bf9c3 100644 --- a/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java +++ b/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java @@ -70,7 +70,7 @@ public void onResponse(T response) { @Override public void onFailure(Exception e) { - LOG.error(e); + LOG.error("Failure in response", e); try { this.exceptions.add(e.getMessage()); } finally { diff --git a/src/main/java/org/opensearch/ad/util/ParseUtils.java b/src/main/java/org/opensearch/ad/util/ParseUtils.java index 3df9aa9d5..4a96245ff 100644 --- a/src/main/java/org/opensearch/ad/util/ParseUtils.java +++ b/src/main/java/org/opensearch/ad/util/ParseUtils.java @@ -456,6 +456,14 @@ public static SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSou return searchSourceBuilder; } + /** + * Generates a user string formed by the username, backend roles, roles and requested tenants separated by '|' + * (e.g., john||own_index,testrole|__user__, no backend role so you see two verticle line after john.). + * This is the user string format used internally in the OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT and may be + * parsed using User.parse(string). + * @param client Client containing user info. A public API request will fill in the user info in the thread context. + * @return parsed user object + */ public static User getUserContext(Client client) { String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); logger.debug("Filtering result by " + userStr); diff --git a/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java b/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java new file mode 100644 index 000000000..612ea4d5c --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.List; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.InjectSecurity; + +public abstract class SafeSecurityInjector implements AutoCloseable { + private static final Logger LOG = LogManager.getLogger(SafeSecurityInjector.class); + // user header used by security plugin. As we cannot take security plugin as + // a compile dependency, we have to duplicate it here. + private static final String OPENDISTRO_SECURITY_USER = "_opendistro_security_user"; + + private InjectSecurity rolesInjectorHelper; + protected String id; + protected Settings settings; + protected ThreadContext tc; + + public SafeSecurityInjector(String id, Settings settings, ThreadContext tc) { + this.id = id; + this.settings = settings; + this.tc = tc; + this.rolesInjectorHelper = null; + } + + protected boolean shouldInject() { + if (id == null || settings == null || tc == null) { + LOG.debug(String.format(Locale.ROOT, "null value: id: %s, settings: %s, threadContext: %s", id, settings, tc)); + return false; + } + // user not null means the request comes from user (e.g., public restful API) + // we don't need to inject roles. + Object userIn = tc.getTransient(OPENDISTRO_SECURITY_USER); + if (userIn != null) { + LOG.debug(new ParameterizedMessage("User not empty in thread context: [{}]", userIn)); + return false; + } + userIn = tc.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + if (userIn != null) { + LOG.debug(new ParameterizedMessage("User not empty in thread context: [{}]", userIn)); + return false; + } + Object rolesin = tc.getTransient(ConfigConstants.OPENSEARCH_SECURITY_INJECTED_ROLES); + if (rolesin != null) { + LOG.warn(new ParameterizedMessage("Injected roles not empty in thread context: [{}]", rolesin)); + return false; + } + + return true; + } + + protected void inject(String user, List roles) { + if (roles == null) { + LOG.warn("Cannot inject empty roles in thread context"); + return; + } + if (rolesInjectorHelper == null) { + // lazy init + rolesInjectorHelper = new InjectSecurity(id, settings, tc); + } + rolesInjectorHelper.inject(user, roles); + } + + @Override + public void close() { + if (rolesInjectorHelper != null) { + rolesInjectorHelper.close(); + } + } +} diff --git a/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java b/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java new file mode 100644 index 000000000..8e9b97b57 --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java @@ -0,0 +1,130 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.function.BiConsumer; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.ActionType; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; + +public class SecurityClientUtil { + private static final String INJECTION_ID = "direct"; + private NodeStateManager nodeStateManager; + private Settings settings; + + @Inject + public SecurityClientUtil(NodeStateManager nodeStateManager, Settings settings) { + this.nodeStateManager = nodeStateManager; + this.settings = settings; + } + + /** + * Send an asynchronous request in the context of user role and handle response with the provided listener. The role + * is recorded in a detector config. + * @param ActionRequest + * @param ActionResponse + * @param request request body + * @param consumer request method, functional interface to operate as a client request like client::get + * @param detectorId Detector id + * @param client OpenSearch client + * @param listener needed to handle response + */ + public void asyncRequestWithInjectedSecurity( + Request request, + BiConsumer> consumer, + String detectorId, + Client client, + ActionListener listener + ) { + ThreadContext threadContext = client.threadPool().getThreadContext(); + try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(detectorId, settings, threadContext, nodeStateManager)) { + injectSecurity + .injectUserRolesFromDetector( + ActionListener + .wrap( + success -> consumer.accept(request, ActionListener.runBefore(listener, () -> injectSecurity.close())), + listener::onFailure + ) + ); + } + } + + /** + * Send an asynchronous request in the context of user role and handle response with the provided listener. The role + * is provided in the arguments. + * @param ActionRequest + * @param ActionResponse + * @param request request body + * @param consumer request method, functional interface to operate as a client request like client::get + * @param user User info + * @param client OpenSearch client + * @param listener needed to handle response + */ + public void asyncRequestWithInjectedSecurity( + Request request, + BiConsumer> consumer, + User user, + Client client, + ActionListener listener + ) { + ThreadContext threadContext = client.threadPool().getThreadContext(); + // use a hardcoded string as detector id that is only used in logging + // Question: + // Will the try-with-resources statement auto close injectSecurity? + // Here the injectSecurity is closed explicitly. So we don't need to put the injectSecurity inside try ? + // Explanation: + // There might be two threads: one thread covers try, inject, and triggers client.execute/client.search + // (this can be a thread in the write thread pool); another thread actually execute the logic of + // client.execute/client.search and handles the responses (this can be a thread in the search thread pool). + // Auto-close in try will restore the context in one thread; the explicit close injectSecurity will restore + // the context in another thread. So we still need to put the injectSecurity inside try. + try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(INJECTION_ID, settings, threadContext, nodeStateManager)) { + injectSecurity.injectUserRoles(user); + consumer.accept(request, ActionListener.runBefore(listener, () -> injectSecurity.close())); + } + } + + /** + * Execute a transport action in the context of user role and handle response with the provided listener. The role + * is provided in the arguments. + * @param ActionRequest + * @param ActionResponse + * @param action transport action + * @param request request body + * @param user User info + * @param client OpenSearch client + * @param listener needed to handle response + */ + public void executeWithInjectedSecurity( + ActionType action, + Request request, + User user, + Client client, + ActionListener listener + ) { + ThreadContext threadContext = client.threadPool().getThreadContext(); + + // use a hardcoded string as detector id that is only used in logging + try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(INJECTION_ID, settings, threadContext, nodeStateManager)) { + injectSecurity.injectUserRoles(user); + client.execute(action, request, ActionListener.runBefore(listener, () -> injectSecurity.close())); + } + } +} diff --git a/src/main/java/org/opensearch/ad/util/SecurityUtil.java b/src/main/java/org/opensearch/ad/util/SecurityUtil.java new file mode 100644 index 000000000..d72d345ab --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/SecurityUtil.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.Collections; +import java.util.List; + +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; + +import com.google.common.collect.ImmutableList; + +public class SecurityUtil { + /** + * @param userObj the last user who edited the detector config + * @param settings Node settings + * @return converted user for bwc if necessary + */ + private static User getAdjustedUserBWC(User userObj, Settings settings) { + /* + * We need to handle 3 cases: + * 1. Detectors created by older versions and never updated. These detectors wont have User details in the + * detector object. `detector.user` will be null. Insert `all_access, AmazonES_all_access` role. + * 2. Detectors are created when security plugin is disabled, these will have empty User object. + * (`detector.user.name`, `detector.user.roles` are empty ) + * 3. Detectors are created when security plugin is enabled, these will have an User object. + * This will inject user role and check if the user role has permissions to call the execute + * Anomaly Result API. + */ + String user; + List roles; + if (userObj == null) { + // It's possible that user create domain with security disabled, then enable security + // after upgrading. This is for BWC, for old detectors which created when security + // disabled, the user will be null. + // This is a huge promotion in privileges. To prevent a caller code from making a mistake and pass a null object, + // we make the method private and only allow fetching user object from detector or job configuration (see the public + // access methods with the same name). + user = ""; + roles = settings.getAsList("", ImmutableList.of("all_access", "AmazonES_all_access")); + return new User(user, Collections.emptyList(), roles, Collections.emptyList()); + } else { + return userObj; + } + } + + /** + * * + * @param detector Detector config + * @param settings Node settings + * @return user recorded by a detector. Made adjstument for BWC (backward-compatibility) if necessary. + */ + public static User getUserFromDetector(AnomalyDetector detector, Settings settings) { + return getAdjustedUserBWC(detector.getUser(), settings); + } + + /** + * * + * @param detectorJob Detector Job + * @param settings Node settings + * @return user recorded by a detector job + */ + public static User getUserFromJob(AnomalyDetectorJob detectorJob, Settings settings) { + return getAdjustedUserBWC(detectorJob.getUser(), settings); + } +} diff --git a/src/main/java/org/opensearch/ad/util/Throttler.java b/src/main/java/org/opensearch/ad/util/Throttler.java index bd2e3d66e..177b612a2 100644 --- a/src/main/java/org/opensearch/ad/util/Throttler.java +++ b/src/main/java/org/opensearch/ad/util/Throttler.java @@ -29,16 +29,20 @@ public class Throttler { private final ConcurrentHashMap> negativeCache; private final Clock clock; - /** - * Inject annotation required by Guice to instantiate EntityResultTransportAction (transitive dependency) - * (EntityResultTransportAction > ResultHandler > ClientUtil > Throttler) - * @param clock a UTC clock - */ public Throttler(Clock clock) { this.negativeCache = new ConcurrentHashMap<>(); this.clock = clock; } + /** + * This will be used when dependency injection directly/indirectly injects a Throttler object. Without this object, + * node start might fail due to not being able to find a Clock object. We removed Clock object association in + * https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/305 + */ + public Throttler() { + this(Clock.systemUTC()); + } + /** * Get negative cache value(ActionRequest, Instant) for given detector * @param detectorId AnomalyDetector ID diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java index 9feaeee45..94571f257 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java @@ -23,6 +23,7 @@ import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import java.io.IOException; +import java.time.Clock; import java.util.Arrays; import java.util.Locale; import java.util.concurrent.TimeUnit; @@ -45,6 +46,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.common.exception.ADValidationException; import org.opensearch.ad.constant.CommonName; @@ -54,6 +56,7 @@ import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -81,6 +84,7 @@ public class IndexAnomalyDetectorActionHandlerTests extends AbstractADTest { private IndexAnomalyDetectorActionHandler handler; private ClusterService clusterService; private NodeClient clientMock; + private SecurityClientUtil clientUtil; private TransportService transportService; private ActionListener channel; private AnomalyDetectionIndices anomalyDetectionIndices; @@ -97,6 +101,7 @@ public class IndexAnomalyDetectorActionHandlerTests extends AbstractADTest { private RestRequest.Method method; private ADTaskManager adTaskManager; private SearchFeatureDao searchFeatureDao; + private Clock clock; @BeforeClass public static void beforeClass() { @@ -118,6 +123,9 @@ public void setUp() throws Exception { settings = Settings.EMPTY; clusterService = mock(ClusterService.class); clientMock = spy(new NodeClient(settings, threadPool)); + clock = mock(Clock.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); transportService = mock(TransportService.class); channel = mock(ActionListener.class); @@ -151,6 +159,7 @@ public void setUp() throws Exception { handler = new IndexAnomalyDetectorActionHandler( clusterService, clientMock, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -167,7 +176,8 @@ public void setUp() throws Exception { xContentRegistry(), null, adTaskManager, - searchFeatureDao + searchFeatureDao, + Settings.EMPTY ); } @@ -194,10 +204,13 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { // we can also use spy to overstep the final methods NodeClient client = getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool); NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); handler = new IndexAnomalyDetectorActionHandler( clusterService, clientSpy, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -215,7 +228,8 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { xContentRegistry(), null, adTaskManager, - searchFeatureDao + searchFeatureDao, + Settings.EMPTY ); handler.start(); @@ -268,10 +282,13 @@ public void doE } } }; + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); handler = new IndexAnomalyDetectorActionHandler( clusterService, client, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -288,7 +305,8 @@ public void doE xContentRegistry(), null, adTaskManager, - searchFeatureDao + searchFeatureDao, + Settings.EMPTY ); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); @@ -351,10 +369,13 @@ public void doE }; NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); handler = new IndexAnomalyDetectorActionHandler( clusterService, clientSpy, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -371,7 +392,8 @@ public void doE xContentRegistry(), null, adTaskManager, - searchFeatureDao + searchFeatureDao, + Settings.EMPTY ); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); @@ -443,6 +465,8 @@ public void doE }; NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); ClusterName clusterName = new ClusterName("test"); ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); when(clusterService.state()).thenReturn(clusterState); @@ -450,6 +474,7 @@ public void doE handler = new IndexAnomalyDetectorActionHandler( clusterService, clientSpy, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -466,7 +491,8 @@ public void doE xContentRegistry(), null, adTaskManager, - searchFeatureDao + searchFeatureDao, + Settings.EMPTY ); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); @@ -548,10 +574,13 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { // we can also use spy to overstep the final methods NodeClient client = getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool); NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); handler = new IndexAnomalyDetectorActionHandler( clusterService, clientSpy, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -568,7 +597,8 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { xContentRegistry(), null, adTaskManager, - searchFeatureDao + searchFeatureDao, + Settings.EMPTY ); handler.start(); @@ -638,6 +668,7 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx handler = new IndexAnomalyDetectorActionHandler( clusterService, clientMock, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -654,7 +685,8 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx xContentRegistry(), null, adTaskManager, - searchFeatureDao + searchFeatureDao, + Settings.EMPTY ); handler.start(); @@ -720,6 +752,7 @@ public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOEx handler = new IndexAnomalyDetectorActionHandler( clusterService, clientMock, + clientUtil, transportService, channel, anomalyDetectionIndices, @@ -736,7 +769,8 @@ public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOEx xContentRegistry(), null, adTaskManager, - searchFeatureDao + searchFeatureDao, + Settings.EMPTY ); handler.start(); diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java index dfd03633d..d56ea2b51 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java @@ -33,6 +33,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.common.exception.ADValidationException; import org.opensearch.ad.feature.SearchFeatureDao; @@ -44,6 +45,7 @@ import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; @@ -141,10 +143,13 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc .getCustomNodeClient(detectorResponse, userIndexResponse, singleEntityDetector, threadPool); NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, settings); handler = new ValidateAnomalyDetectorActionHandler( clusterService, clientSpy, + clientUtil, channel, anomalyDetectionIndices, singleEntityDetector, @@ -157,7 +162,8 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc null, searchFeatureDao, ValidationAspect.DETECTOR.getName(), - clock + clock, + settings ); handler.start(); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); @@ -191,10 +197,13 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOExceptio NodeClient client = IndexAnomalyDetectorActionHandlerTests .getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool); NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, settings); handler = new ValidateAnomalyDetectorActionHandler( clusterService, clientSpy, + clientUtil, channel, anomalyDetectionIndices, detector, @@ -207,7 +216,8 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOExceptio null, searchFeatureDao, "", - clock + clock, + Settings.EMPTY ); handler.start(); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); diff --git a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java index 715425c2e..a47ec9968 100644 --- a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java @@ -19,12 +19,14 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.time.Clock; import java.util.Arrays; import java.util.HashSet; import java.util.Optional; import java.util.Set; import java.util.function.Consumer; +import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.opensearch.Version; @@ -33,7 +35,9 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultTests; import org.opensearch.ad.util.DiscoveryNodeFilterer; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -65,6 +69,7 @@ protected enum ErrorResultStatus { protected AnomalyDetectorProfileRunner runner; protected Client client; + protected SecurityClientUtil clientUtil; protected DiscoveryNodeFilterer nodeFilter; protected AnomalyDetector detector; protected ClusterService clusterService; @@ -141,6 +146,12 @@ public static void setUpOnce() { emptySet(), Version.CURRENT ); + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); } @SuppressWarnings("unchecked") @@ -149,6 +160,9 @@ public static void setUpOnce() { public void setUp() throws Exception { super.setUp(); client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + Clock clock = mock(Clock.class); + nodeFilter = mock(DiscoveryNodeFilterer.class); clusterService = mock(ClusterService.class); adTaskManager = mock(ADTaskManager.class); @@ -163,7 +177,6 @@ public void setUp() throws Exception { function.accept(Optional.of(TestHelpers.randomAdTask())); return null; }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); - runner = new AnomalyDetectorProfileRunner(client, xContentRegistry(), nodeFilter, requiredSamples, transportService, adTaskManager); detectorIntervalMin = 3; detectorGetReponse = mock(GetResponse.class); diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java index cc703da5d..57a6d11b4 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java @@ -15,8 +15,8 @@ import static java.util.Collections.emptySet; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static org.opensearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; @@ -59,9 +59,11 @@ import org.opensearch.ad.transport.ProfileResponse; import org.opensearch.ad.transport.RCFPollingAction; import org.opensearch.ad.transport.RCFPollingResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; import org.opensearch.index.IndexNotFoundException; import org.opensearch.transport.RemoteTransportException; @@ -96,6 +98,22 @@ private void setUpClientGet( ErrorResultStatus errorResultStatus ) throws IOException { detector = TestHelpers.randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(detectorIntervalMin, ChronoUnit.MINUTES)); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + runner = new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry(), + nodeFilter, + requiredSamples, + transportService, + adTaskManager + ); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -614,7 +632,7 @@ public void testInitNoIndex() throws IOException, InterruptedException { public void testInvalidRequiredSamples() { expectThrows( IllegalArgumentException.class, - () -> new AnomalyDetectorProfileRunner(client, xContentRegistry(), nodeFilter, 0, transportService, adTaskManager) + () -> new AnomalyDetectorProfileRunner(client, clientUtil, xContentRegistry(), nodeFilter, 0, transportService, adTaskManager) ); } diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java index b344cb3c2..9d7f54e8c 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java @@ -94,13 +94,16 @@ protected AnomalyDetector createRandomAnomalyDetector( if (indexName == null) { detector = TestHelpers.randomAnomalyDetector(uiMetadata, null, featureEnabled); + TestHelpers.createIndexWithTimeField(client(), detector.getIndices().get(0), detector.getTimeField()); TestHelpers .makeRequest( client, "POST", "/" + detector.getIndices().get(0) + "/_doc/" + randomAlphaOfLength(5) + "?refresh=true", ImmutableMap.of(), - TestHelpers.toHttpEntity("{\"name\": \"test\"}"), + // avoid validation error as validation API will check at least 1 document and the timestamp field + // exists in index mapping + TestHelpers.toHttpEntity("{\"name\": \"test\", \"" + detector.getTimeField() + "\" : \"1661386754000\"}"), null, false ); @@ -511,6 +514,54 @@ public Response createSearchRole(String role, String index) throws IOException { ); } + public Response createDlsRole(String role, String index) throws IOException { + return TestHelpers + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/roles/" + role, + null, + TestHelpers + .toHttpEntity( + "{\n" + + "\"cluster_permissions\": [\n" + + "unlimited\n" + + "],\n" + + "\"index_permissions\": [\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + index + + "\"\n" + + "],\n" + + "\"dls\": \"\"\"{ \"bool\": { \"must\": { \"match\": { \"foo\": \"bar\" }}}}\"\"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"unlimited\"\n" + + "]\n" + + "},\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + "*" + + "\"\n" + + "],\n" + + "\"dls\": \"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"unlimited\"\n" + + "]\n" + + "}\n" + + "],\n" + + "\"tenant_permissions\": []\n" + + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + public Response deleteUser(String user) throws IOException { return TestHelpers .makeRequest( diff --git a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java index dca8ac43e..42c7d3676 100644 --- a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java @@ -13,21 +13,20 @@ import static java.util.Collections.emptyMap; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.*; import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static org.opensearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; import java.io.IOException; import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; +import java.util.*; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; @@ -49,7 +48,9 @@ import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.ad.transport.EntityProfileAction; import org.opensearch.ad.transport.EntityProfileResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; @@ -62,6 +63,7 @@ public class EntityProfileRunnerTests extends AbstractADTest { private AnomalyDetector detector; private int detectorIntervalMin; private Client client; + private SecurityClientUtil clientUtil; private EntityProfileRunner runner; private Set state; private Set initNInfo; @@ -87,8 +89,19 @@ enum InittedEverResultStatus { NOT_INITTED, } + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyDetectorJobRunnerTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + @SuppressWarnings("unchecked") @Override + @Before public void setUp() throws Exception { super.setUp(); detectorIntervalMin = 3; @@ -106,15 +119,23 @@ public void setUp() throws Exception { detectorId = "A69pa3UBHuCbh-emo9oR"; entityValue = "app-0"; - requiredSamples = 128; - client = mock(Client.class); - - runner = new EntityProfileRunner(client, xContentRegistry(), requiredSamples); - categoryField = "a"; detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(categoryField)); job = TestHelpers.randomAnomalyDetectorJob(true); + requiredSamples = 128; + client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + + runner = new EntityProfileRunner(client, clientUtil, xContentRegistry(), requiredSamples); + doAnswer(invocation -> { Object[] args = invocation.getArguments(); GetRequest request = (GetRequest) args[0]; diff --git a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java index c989871fb..d0753c481 100644 --- a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java @@ -20,6 +20,7 @@ import static org.opensearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static org.opensearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; +import java.time.Clock; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -34,7 +35,9 @@ import java.util.concurrent.TimeUnit; import java.util.function.Consumer; +import org.junit.AfterClass; import org.junit.Before; +import org.junit.BeforeClass; import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.FailedNodeException; @@ -52,19 +55,22 @@ import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.model.DetectorState; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultTests; import org.opensearch.ad.transport.ProfileAction; import org.opensearch.ad.transport.ProfileNodeResponse; import org.opensearch.ad.transport.ProfileResponse; -import org.opensearch.ad.util.DiscoveryNodeFilterer; +import org.opensearch.ad.util.*; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; import org.opensearch.transport.TransportService; public class MultiEntityProfileRunnerTests extends AbstractADTest { private AnomalyDetectorProfileRunner runner; private Client client; + private SecurityClientUtil clientUtil; private DiscoveryNodeFilterer nodeFilter; private int requiredSamples; private AnomalyDetector detector; @@ -93,12 +99,25 @@ enum InittedEverResultStatus { NOT_INITTED, } + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + @SuppressWarnings("unchecked") @Before @Override public void setUp() throws Exception { super.setUp(); client = mock(Client.class); + Clock clock = mock(Clock.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); nodeFilter = mock(DiscoveryNodeFilterer.class); requiredSamples = 128; @@ -115,7 +134,15 @@ public void setUp() throws Exception { function.accept(Optional.of(TestHelpers.randomAdTask())); return null; }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); - runner = new AnomalyDetectorProfileRunner(client, xContentRegistry(), nodeFilter, requiredSamples, transportService, adTaskManager); + runner = new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry(), + nodeFilter, + requiredSamples, + transportService, + adTaskManager + ); doAnswer(invocation -> { Object[] args = invocation.getArguments(); diff --git a/src/test/java/org/opensearch/ad/ODFERestTestCase.java b/src/test/java/org/opensearch/ad/ODFERestTestCase.java index 6de204584..517041e12 100644 --- a/src/test/java/org/opensearch/ad/ODFERestTestCase.java +++ b/src/test/java/org/opensearch/ad/ODFERestTestCase.java @@ -18,11 +18,14 @@ import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH; import java.io.IOException; +import java.io.InputStreamReader; import java.net.URI; import java.net.URISyntaxException; +import java.nio.charset.Charset; import java.nio.file.Path; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -54,6 +57,10 @@ import org.opensearch.rest.RestStatus; import org.opensearch.test.rest.OpenSearchRestTestCase; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + /** * ODFE integration test base class to support both security disabled and enabled ODFE cluster. */ @@ -210,4 +217,46 @@ protected static void configureHttpsClient(RestClientBuilder builder, Settings s protected boolean preserveIndicesUponCompletion() { return true; } + + protected void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { + int maxWaitCycles = 3; + do { + Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); + request + .setJsonEntity( + String + .format( + Locale.ROOT, + "{\"query\": {" + + " \"match_all\": {}" + + " }," + + " \"size\": 1," + + " \"sort\": [" + + " {" + + " \"timestamp\": {" + + " \"order\": \"desc\"" + + " }" + + " }" + + " ]}" + ) + ); + // Make sure all of the test data has been ingested + // Expected response: + // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} + Response response = client.performRequest(request); + JsonObject json = JsonParser + .parseReader(new InputStreamReader(response.getEntity().getContent(), Charset.defaultCharset())) + .getAsJsonObject(); + JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); + if (hits != null + && hits.size() == 1 + && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { + break; + } else { + request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); + client.performRequest(request); + } + Thread.sleep(1_000); + } while (maxWaitCycles-- >= 0); + } } diff --git a/src/test/java/org/opensearch/ad/TestHelpers.java b/src/test/java/org/opensearch/ad/TestHelpers.java index aedc79b9a..04ae73ff2 100644 --- a/src/test/java/org/opensearch/ad/TestHelpers.java +++ b/src/test/java/org/opensearch/ad/TestHelpers.java @@ -1096,6 +1096,17 @@ public static void createIndexWithHCADFields(RestClient client, String indexName createIndexMapping(client, indexName, TestHelpers.toHttpEntity(indexMappings.toString())); } + public static void createEmptyIndexMapping(RestClient client, String indexName, Map fieldsAndTypes) throws IOException { + StringBuilder indexMappings = new StringBuilder(); + indexMappings.append("{\"properties\":{"); + for (Map.Entry entry : fieldsAndTypes.entrySet()) { + indexMappings.append("\"" + entry.getKey() + "\":{\"type\":\"" + entry.getValue() + "\"},"); + } + indexMappings.append("}}"); + createEmptyIndex(client, indexName); + createIndexMapping(client, indexName, TestHelpers.toHttpEntity(indexMappings.toString())); + } + public static void createEmptyAnomalyResultIndex(RestClient client) throws IOException { createEmptyIndex(client, CommonName.ANOMALY_RESULT_INDEX_ALIAS); createIndexMapping(client, CommonName.ANOMALY_RESULT_INDEX_ALIAS, toHttpEntity(AnomalyDetectionIndices.getAnomalyResultMappings())); diff --git a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java index 074b15552..8273032e7 100644 --- a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java +++ b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java @@ -183,6 +183,8 @@ private void indexTrainData(String datasetName, List data, int train String requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + " \"Feature1\": { \"type\": \"long\" }, \"Feature2\": { \"type\": \"long\" } } } }"; request.setJsonEntity(requestBody); + // a WarningFailureException on access system indices .opendistro_security will fail the test if this is not false. + setWarningHandler(request, false); client.performRequest(request); Thread.sleep(1_000); data.stream().limit(trainTestSplit).forEach(r -> { diff --git a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java index d1ef85077..64ca51936 100644 --- a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java @@ -13,7 +13,6 @@ import static java.util.Arrays.asList; import static java.util.Optional.empty; -import static java.util.Optional.ofNullable; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -23,7 +22,6 @@ import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -167,24 +165,6 @@ private Object[] getColdStartDataTestData() { new Object[] { null, null, 1, null }, }; } - @Test - @Parameters(method = "getColdStartDataTestData") - public void getColdStartData_returnExpected(Long latestTime, Entry data, int interpolants, double[][] expected) { - when(searchFeatureDao.getLatestDataTime(detector)).thenReturn(ofNullable(latestTime)); - if (latestTime != null) { - when(searchFeatureDao.getFeaturesForSampledPeriods(detector, maxTrainSamples, maxSampleStride, latestTime)) - .thenReturn(ofNullable(data)); - } - if (data != null) { - when(interpolator.interpolate(argThat(new ArrayEqMatcher<>(data.getKey())), eq(interpolants))).thenReturn(data.getKey()); - doReturn(data.getKey()).when(featureManager).batchShingle(argThat(new ArrayEqMatcher<>(data.getKey())), eq(shingleSize)); - } - - Optional results = featureManager.getColdStartData(detector); - - assertTrue(Arrays.deepEquals(expected, results.orElse(null))); - } - private Object[] getTrainDataTestData() { List> ranges = asList( entry(0L, 900_000L), diff --git a/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java b/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java index 0268b9d7b..67e73cdb5 100644 --- a/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java @@ -43,6 +43,8 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.BytesRef; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.invocation.InvocationOnMock; @@ -54,6 +56,7 @@ import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; import org.opensearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; @@ -62,7 +65,7 @@ import org.opensearch.ad.model.Feature; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.lease.Releasables; @@ -113,7 +116,7 @@ public class NoPowermockSearchFeatureDaoTests extends AbstractADTest { private Client client; private SearchFeatureDao searchFeatureDao; private LinearUniformInterpolator interpolator; - private ClientUtil clientUtil; + private SecurityClientUtil clientUtil; private Settings settings; private ClusterService clusterService; private Clock clock; @@ -121,6 +124,16 @@ public class NoPowermockSearchFeatureDaoTests extends AbstractADTest { private String detectorId; private Map attrs1, attrs2; + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(NoPowermockSearchFeatureDaoTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + @Override public void setUp() throws Exception { super.setUp(); @@ -139,11 +152,10 @@ public void setUp() throws Exception { when(detector.getFilterQuery()).thenReturn(QueryBuilders.matchAllQuery()); client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); interpolator = new LinearUniformInterpolator(new SingleFeatureLinearUniformInterpolator()); - clientUtil = mock(ClientUtil.class); - settings = Settings.EMPTY; ClusterSettings clusterSettings = new ClusterSettings( Settings.EMPTY, @@ -155,6 +167,13 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); clock = mock(Clock.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); searchFeatureDao = new SearchFeatureDao( client, diff --git a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java index d95fb0990..2b458d690 100644 --- a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java @@ -16,14 +16,11 @@ import static org.hamcrest.core.AnyOf.anyOf; import static org.hamcrest.core.IsInstanceOf.instanceOf; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Matchers.anyLong; -import static org.mockito.Matchers.anyObject; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -44,7 +41,6 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.ExecutorService; -import java.util.function.BiConsumer; import junitparams.JUnitParamsRunner; import junitparams.Parameters; @@ -54,7 +50,6 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; -import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionFuture; @@ -68,7 +63,6 @@ import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.common.exception.EndRunException; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.dataprocessor.Interpolator; import org.opensearch.ad.dataprocessor.LinearUniformInterpolator; @@ -77,8 +71,8 @@ import org.opensearch.ad.model.Entity; import org.opensearch.ad.model.IntervalTimeConfiguration; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.ad.util.ParseUtils; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -106,7 +100,6 @@ import org.opensearch.search.aggregations.metrics.Max; import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; import org.opensearch.search.aggregations.metrics.MinAggregationBuilder; -import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation; import org.opensearch.search.aggregations.metrics.Percentile; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.threadpool.ThreadPool; @@ -133,8 +126,7 @@ public class SearchFeatureDaoTests { private ScriptService scriptService; @Mock private NamedXContentRegistry xContent; - @Mock - private ClientUtil clientUtil; + private SecurityClientUtil clientUtil; @Mock private Factory factory; @@ -196,6 +188,14 @@ public void setup() throws Exception { settings = Settings.EMPTY; + when(client.threadPool()).thenReturn(threadPool); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); searchFeatureDao = spy( new SearchFeatureDao(client, xContent, interpolator, clientUtil, settings, null, AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) ); @@ -222,72 +222,30 @@ public void setup() throws Exception { SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); when(searchResponse.getHits()).thenReturn(hits); - doReturn(Optional.of(searchResponse)) - .when(clientUtil) - .timedRequest(eq(searchRequest), anyObject(), Matchers.>>anyObject()); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(eq(searchRequest), any()); when(searchResponse.getAggregations()).thenReturn(aggregations); - doReturn(Optional.of(searchResponse)) - .when(clientUtil) - .throttledTimedRequest( - eq(searchRequest), - anyObject(), - Matchers.>>anyObject(), - anyObject() - ); - multiSearchRequest = new MultiSearchRequest(); SearchRequest request = new SearchRequest(detector.getIndices().toArray(new String[0])); multiSearchRequest.add(request); - doReturn(Optional.of(multiSearchResponse)) - .when(clientUtil) - .timedRequest( - eq(multiSearchRequest), - anyObject(), - Matchers.>>anyObject() - ); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(multiSearchResponse); + return null; + }).when(client).multiSearch(eq(multiSearchRequest), any()); when(multiSearchResponse.getResponses()).thenReturn(new Item[] { multiSearchResponseItem }); when(multiSearchResponseItem.getResponse()).thenReturn(searchResponse); gson = PowerMockito.mock(Gson.class); } - @Test - public void test_getLatestDataTime_returnExpectedTime_givenData() { - // pre-conditions - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField())) - .size(0); - searchRequest.source(searchSourceBuilder); - - long epochTime = 100L; - aggsMap.put(CommonName.AGG_NAME_MAX_TIME, max); - when(max.getValue()).thenReturn((double) epochTime); - - // test - Optional result = searchFeatureDao.getLatestDataTime(detector); - - // verify - assertEquals(epochTime, result.get().longValue()); - } - - @Test - public void test_getLatestDataTime_returnEmpty_givenNoData() { - // pre-conditions - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField())) - .size(0); - searchRequest.source(searchSourceBuilder); - - when(searchResponse.getAggregations()).thenReturn(null); - - // test - Optional result = searchFeatureDao.getLatestDataTime(detector); - - // verify - assertFalse(result.isPresent()); - } - @Test @SuppressWarnings("unchecked") public void getLatestDataTime_returnExpectedToListener() { @@ -360,26 +318,6 @@ private Object[] getFeaturesForPeriodData() { new Object[] { asList(max, percentiles, missing), asList(maxName, percentileName, missingName), null }, }; } - @Test - @Parameters(method = "getFeaturesForPeriodData") - public void getFeaturesForPeriod_returnExpected_givenData(List aggs, List featureIds, double[] expected) - throws Exception { - - long start = 100L; - long end = 200L; - - // pre-conditions - when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); - when(searchResponse.getAggregations()).thenReturn(new Aggregations(aggs)); - when(detector.getEnabledFeatureIds()).thenReturn(featureIds); - - // test - Optional result = searchFeatureDao.getFeaturesForPeriod(detector, start, end); - - // verify - assertTrue(Arrays.equals(expected, result.orElse(null))); - } - @SuppressWarnings("unchecked") private Object[] getFeaturesForPeriodThrowIllegalStateData() { String aggName = "aggName"; @@ -398,16 +336,6 @@ private Object[] getFeaturesForPeriodThrowIllegalStateData() { new Object[] { asList(multiBucket), asList(aggName), null }, }; } - @Test(expected = EndRunException.class) - @Parameters(method = "getFeaturesForPeriodThrowIllegalStateData") - public void getFeaturesForPeriod_throwIllegalState_forUnknownAggregation( - List aggs, - List featureIds, - double[] expected - ) throws Exception { - getFeaturesForPeriod_returnExpected_givenData(aggs, featureIds, expected); - } - @Test @Parameters(method = "getFeaturesForPeriodData") @SuppressWarnings("unchecked") @@ -473,49 +401,6 @@ public void getFeaturesForPeriod_throwToListener_whenResponseParsingFails() thro verify(listener).onFailure(any(Exception.class)); } - @Test - public void test_getFeaturesForPeriod_returnEmpty_givenNoData() throws Exception { - long start = 100L; - long end = 200L; - - // pre-conditions - when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); - when(searchResponse.getAggregations()).thenReturn(null); - - // test - Optional result = searchFeatureDao.getFeaturesForPeriod(detector, start, end); - - // verify - assertFalse(result.isPresent()); - } - - @Test - public void getFeaturesForPeriod_returnNonEmpty_givenDefaultValue() throws Exception { - long start = 100L; - long end = 200L; - - // pre-conditions - when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); - when(searchResponse.getHits()).thenReturn(new SearchHits(new SearchHit[0], new TotalHits(0L, TotalHits.Relation.EQUAL_TO), 1f)); - - List aggList = new ArrayList<>(1); - - NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class); - when(agg.getName()).thenReturn("deny_max"); - when(agg.value()).thenReturn(0d); - - aggList.add(agg); - - Aggregations aggregations = new Aggregations(aggList); - when(searchResponse.getAggregations()).thenReturn(aggregations); - - // test - Optional result = searchFeatureDao.getFeaturesForPeriod(detector, start, end); - - // verify - assertTrue(result.isPresent()); - } - private Object[] getFeaturesForSampledPeriodsData() { long endTime = 300_000; int maxStride = 4; @@ -617,34 +502,6 @@ private Object[] getFeaturesForSampledPeriodsData() { Optional.of(new SimpleEntry<>(new double[][] { { 1, 2 }, { 3, 4 }, { 5, 6 }, { 7, 8 }, { 9, 10 } }, 1)) }, }; } - @Test - @Parameters(method = "getFeaturesForSampledPeriodsData") - public void getFeaturesForSampledPeriods_returnExpected( - Long[][] queryRanges, - double[][] queryResults, - long endTime, - int maxStride, - int maxSamples, - Optional> expected - ) { - - doReturn(Optional.empty()).when(searchFeatureDao).getFeaturesForPeriod(eq(detector), anyLong(), anyLong()); - for (int i = 0; i < queryRanges.length; i++) { - doReturn(Optional.of(queryResults[i])) - .when(searchFeatureDao) - .getFeaturesForPeriod(detector, queryRanges[i][0], queryRanges[i][1]); - } - - Optional> result = searchFeatureDao - .getFeaturesForSampledPeriods(detector, maxSamples, maxStride, endTime); - - assertEquals(expected.isPresent(), result.isPresent()); - if (expected.isPresent()) { - assertTrue(Arrays.deepEquals(expected.get().getKey(), result.get().getKey())); - assertEquals(expected.get().getValue(), result.get().getValue()); - } - } - @Test @Parameters(method = "getFeaturesForSampledPeriodsData") @SuppressWarnings("unchecked") diff --git a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java index 44536c9fd..68098b268 100644 --- a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java +++ b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java @@ -221,7 +221,6 @@ public void testUpdateApiFilterByEnabled() throws IOException { // User Fish has AD full access, and has "odfe" backend role which is one of Alice's backend role, so // Fish should be able to update detectors created by Alice. But the detector's backend role should // not be replaced as Fish's backend roles. - TestHelpers.createIndexWithTimeField(client(), newDetector.getIndices().get(0), newDetector.getTimeField()); Response response = updateAnomalyDetector(aliceDetector.getDetectorId(), newDetector, fishClient); Assert.assertEquals(response.getStatusLine().getStatusCode(), 200); AnomalyDetector anomalyDetector = getAnomalyDetector(aliceDetector.getDetectorId(), aliceClient); @@ -387,7 +386,11 @@ public void testPreviewAnomalyDetectorWithNoReadPermissionOfIndex() throws IOExc Exception.class, () -> { previewAnomalyDetector(aliceDetector.getDetectorId(), elkClient, input); } ); - Assert.assertTrue(exception.getMessage().contains("no permissions for [indices:data/read/search]")); + Assert + .assertTrue( + "actual msg: " + exception.getMessage(), + exception.getMessage().contains("no permissions for [indices:data/read/search]") + ); } public void testValidateAnomalyDetectorWithWriteAccess() throws IOException { diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java index 21ff96e18..da7eb0d17 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java @@ -92,6 +92,7 @@ import org.opensearch.ad.stats.StatNames; import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -137,6 +138,7 @@ public class AnomalyResultTests extends AbstractADTest { private FeatureManager featureQuery; private ModelManager normalModelManager; private Client client; + private SecurityClientUtil clientUtil; private AnomalyDetector detector; private HashRing hashRing; private IndexNameExpressionResolver indexNameResolver; @@ -278,6 +280,8 @@ public void setUp() throws Exception { return null; }).when(client).index(any(), any()); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)); @@ -355,6 +359,7 @@ public void testNormal() throws IOException { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -479,6 +484,7 @@ public void sendRequest( realTransportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -531,6 +537,7 @@ public void testInsufficientCapacityExceptionDuringColdStart() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -575,6 +582,7 @@ public void testInsufficientCapacityExceptionDuringRestoringModel() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -701,6 +709,7 @@ public void sendRequest( realTransportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -743,6 +752,7 @@ public void testCircuitBreaker() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -812,6 +822,7 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, exceptionTransportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -864,6 +875,7 @@ public void testMute() { transportService, settings, client, + clientUtil, muteStateManager, featureQuery, normalModelManager, @@ -905,6 +917,7 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1083,6 +1096,7 @@ public void testOnFailureNull() throws IOException { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1174,6 +1188,7 @@ public void testColdStartNoTrainingData() throws Exception { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1211,6 +1226,7 @@ public void testConcurrentColdStart() throws Exception { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1254,6 +1270,7 @@ public void testColdStartTimeoutPutCheckpoint() throws Exception { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1297,6 +1314,7 @@ public void testColdStartIllegalArgumentException() throws Exception { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1347,6 +1365,7 @@ public void featureTestTemplate(FeatureTestMode mode) throws IOException { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1435,6 +1454,7 @@ private void globalBlockTemplate(BlockType type, String errLogMsg, Settings inde transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1483,6 +1503,7 @@ public void testNullRCFResult() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1509,6 +1530,7 @@ public void testNormalRCFResult() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1542,6 +1564,7 @@ public void testNullPointerRCFResult() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1583,6 +1606,7 @@ public void testAllFeaturesDisabled() throws IOException { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1647,6 +1671,7 @@ public void testEndRunDueToNoTrainingData() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1699,6 +1724,7 @@ public void testColdStartEndRunException() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1742,6 +1768,7 @@ public void testColdStartEndRunExceptionNow() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1782,6 +1809,7 @@ public void testColdStartBecauseFailtoGetCheckpoint() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -1820,6 +1848,7 @@ public void testNoColdStartDueToUnknownException() { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java index e7a5d613e..d25979304 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.time.Clock; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -42,6 +43,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ad.AbstractADTest; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.constant.CommonErrorMessages; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; @@ -49,6 +51,8 @@ import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.util.DiscoveryNodeFilterer; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.bytes.BytesReference; @@ -64,6 +68,7 @@ public class GetAnomalyDetectorTests extends AbstractADTest { private DiscoveryNodeFilterer nodeFilter; private ActionFilters actionFilters; private Client client; + private SecurityClientUtil clientUtil; private GetAnomalyDetectorRequest request; private String detectorId = "yecrdnUBqurvo9uKU_d8"; private String entityValue = "app_0"; @@ -111,6 +116,12 @@ public void setUp() throws Exception { client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); + Clock clock = mock(Clock.class); + Throttler throttler = new Throttler(clock); + + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + adTaskManager = mock(ADTaskManager.class); action = new GetAnomalyDetectorTransportAction( @@ -119,6 +130,7 @@ public void setUp() throws Exception { actionFilters, clusterService, client, + clientUtil, Settings.EMPTY, xContentRegistry(), adTaskManager diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java index 544ad1ce0..b301c1b6c 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java @@ -20,13 +20,13 @@ import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.concurrent.TimeUnit; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; +import org.junit.*; import org.mockito.Mockito; import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.ADTask; @@ -37,8 +37,7 @@ import org.opensearch.ad.model.InitProgressProfile; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.util.DiscoveryNodeFilterer; -import org.opensearch.ad.util.RestHandlerUtils; +import org.opensearch.ad.util.*; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; @@ -51,12 +50,14 @@ import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableMap; public class GetAnomalyDetectorTransportActionTests extends OpenSearchSingleNodeTestCase { - + private static ThreadPool threadPool; private GetAnomalyDetectorTransportAction action; private Task task; private ActionListener response; @@ -65,6 +66,16 @@ public class GetAnomalyDetectorTransportActionTests extends OpenSearchSingleNode private String categoryField; private String categoryValue; + @BeforeClass + public static void beforeCLass() { + threadPool = new TestThreadPool("GetAnomalyDetectorTransportActionTests"); + } + + @AfterClass + public static void afterClass() { + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + @Override @Before public void setUp() throws Exception { @@ -76,12 +87,15 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); adTaskManager = mock(ADTaskManager.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); action = new GetAnomalyDetectorTransportAction( Mockito.mock(TransportService.class), Mockito.mock(DiscoveryNodeFilterer.class), Mockito.mock(ActionFilters.class), clusterService, client(), + clientUtil, Settings.EMPTY, xContentRegistry(), adTaskManager diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java index fd76d4bea..b76ba44bb 100644 --- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java @@ -35,12 +35,14 @@ import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.indices.AnomalyDetectionIndices; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -72,6 +74,7 @@ public class IndexAnomalyDetectorTransportActionTests extends OpenSearchIntegTes private ClusterSettings clusterSettings; private ADTaskManager adTaskManager; private Client client = mock(Client.class); + private SecurityClientUtil clientUtil; private SearchFeatureDao searchFeatureDao; @SuppressWarnings("unchecked") @@ -104,10 +107,13 @@ public void setUp() throws Exception { adTaskManager = mock(ADTaskManager.class); searchFeatureDao = mock(SearchFeatureDao.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); action = new IndexAnomalyDetectorTransportAction( mock(TransportService.class), mock(ActionFilters.class), client(), + clientUtil, clusterService, indexSettings(), mock(AnomalyDetectionIndices.class), @@ -206,6 +212,7 @@ public void testIndexTransportActionWithUserAndFilterOn() { mock(TransportService.class), mock(ActionFilters.class), client, + clientUtil, clusterService, settings, mock(AnomalyDetectionIndices.class), @@ -231,6 +238,7 @@ public void testIndexTransportActionWithUserAndFilterOff() { mock(TransportService.class), mock(ActionFilters.class), client, + clientUtil, clusterService, settings, mock(AnomalyDetectionIndices.class), diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java index aa3c182f3..e8a6e71c2 100644 --- a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java @@ -102,6 +102,8 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.util.Bwc; import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNode; @@ -147,6 +149,7 @@ public class MultiEntityResultTests extends AbstractADTest { private static Settings settings; private TransportService transportService; private Client client; + private SecurityClientUtil clientUtil; private FeatureManager featureQuery; private ModelManager normalModelManager; private HashRing hashRing; @@ -218,6 +221,7 @@ public void setUp() throws Exception { setUpADThreadPool(mockThreadPool); when(client.threadPool()).thenReturn(mockThreadPool); when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + clientUtil = new SecurityClientUtil(stateManager, settings); featureQuery = mock(FeatureManager.class); @@ -277,6 +281,7 @@ public void setUp() throws Exception { transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -424,34 +429,35 @@ private void setUpEntityResult(int nodeIndex) { @SuppressWarnings("unchecked") public void setUpNormlaStateManager() throws IOException { - ClientUtil clientUtil = mock(ClientUtil.class); - AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder .newInstance() .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(1); listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, AnomalyDetector.ANOMALY_DETECTORS_INDEX)); return null; - }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + }).when(client).get(any(GetRequest.class), any(ActionListener.class)); stateManager = new NodeStateManager( client, xContentRegistry(), settings, - clientUtil, + new ClientUtil(settings, client, new Throttler(mock(Clock.class)), threadPool), clock, AnomalyDetectorSettings.HOURLY_MAINTENANCE, clusterService ); + clientUtil = new SecurityClientUtil(stateManager, settings); + action = new AnomalyResultTransportAction( new ActionFilters(Collections.emptySet()), transportService, settings, client, + clientUtil, stateManager, featureQuery, normalModelManager, @@ -525,7 +531,6 @@ public void testIndexNotFound() throws InterruptedException, IOException { }).when(client).search(any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); - action.doExecute(null, request, listener); AnomalyResultResponse response = listener.actionGet(10000L); @@ -687,6 +692,7 @@ public void sendRequest( realTransportService, settings, client, + clientUtil, nodeStateManager, featureQuery, normalModelManager, @@ -1151,6 +1157,7 @@ public void testPageToString() { detector, xContentRegistry(), client, + clientUtil, 100, clock, settings, @@ -1177,6 +1184,7 @@ public void testEmptyPageToString() { detector, xContentRegistry(), client, + clientUtil, 100, clock, settings, diff --git a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java index af6aeef47..5e86089d3 100644 --- a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java +++ b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java @@ -11,8 +11,7 @@ package org.opensearch.search.aggregations.metrics; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -22,11 +21,7 @@ import java.io.IOException; import java.time.temporal.ChronoUnit; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; +import java.util.*; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -37,6 +32,8 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.AbstractProfileRunnerTests; +import org.opensearch.ad.AnomalyDetectorProfileRunner; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.TestHelpers; import org.opensearch.ad.constant.CommonName; import org.opensearch.ad.model.AnomalyDetector; @@ -45,7 +42,9 @@ import org.opensearch.ad.transport.ProfileAction; import org.opensearch.ad.transport.ProfileNodeResponse; import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.cluster.ClusterName; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.BigArrays; import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.InternalAggregations; @@ -73,6 +72,23 @@ private void setUpMultiEntityClientGet(DetectorStatus detectorStatus, JobStatus throws IOException { detector = TestHelpers .randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(detectorIntervalMin, ChronoUnit.MINUTES), true); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + runner = new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry(), + nodeFilter, + requiredSamples, + transportService, + adTaskManager + ); + doAnswer(invocation -> { Object[] args = invocation.getArguments(); GetRequest request = (GetRequest) args[0];