Skip to content

Commit

Permalink
Update InputService for Bucket-Level Alerting (#152)
Browse files Browse the repository at this point in the history
Signed-off-by: Mohammad Qureshi <[email protected]>

Co-authored-by: Rishabh Maurya <[email protected]>
  • Loading branch information
qreshi and rishabhmaurya authored Aug 26, 2021
1 parent e8c474f commit d31f0a1
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 2 deletions.
14 changes: 12 additions & 2 deletions alerting/src/main/kotlin/org/opensearch/alerting/InputService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.opensearch.alerting.elasticapi.convertToMap
import org.opensearch.alerting.elasticapi.suspendUntil
import org.opensearch.alerting.model.InputRunResults
import org.opensearch.alerting.model.Monitor
import org.opensearch.alerting.util.AggregationQueryRewriter
import org.opensearch.alerting.util.addUserBackendRolesFilter
import org.opensearch.client.Client
import org.opensearch.common.xcontent.LoggingDeprecationHandler
Expand All @@ -40,15 +41,23 @@ class InputService(

private val logger = LogManager.getLogger(InputService::class.java)

suspend fun collectInputResults(monitor: Monitor, periodStart: Instant, periodEnd: Instant): InputRunResults {
suspend fun collectInputResults(
monitor: Monitor,
periodStart: Instant,
periodEnd: Instant,
prevResult: InputRunResults? = null
): InputRunResults {
return try {
val results = mutableListOf<Map<String, Any>>()
val aggTriggerAfterKeys: MutableMap<String, Map<String, Any>?> = mutableMapOf()

monitor.inputs.forEach { input ->
when (input) {
is SearchInput -> {
// TODO: Figure out a way to use SearchTemplateRequest without bringing in the entire TransportClient
val searchParams = mapOf("period_start" to periodStart.toEpochMilli(),
"period_end" to periodEnd.toEpochMilli())
AggregationQueryRewriter.rewriteQuery(input.query, prevResult, monitor.triggers)
val searchSource = scriptService.compile(Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG,
input.query.toString(), searchParams), TemplateScript.CONTEXT)
.newInstance(searchParams)
Expand All @@ -59,14 +68,15 @@ class InputService(
searchRequest.source(SearchSourceBuilder.fromXContent(it))
}
val searchResponse: SearchResponse = client.suspendUntil { client.search(searchRequest, it) }
aggTriggerAfterKeys += AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, monitor.triggers)
results += searchResponse.convertToMap()
}
else -> {
throw IllegalArgumentException("Unsupported input type: ${input.name()}.")
}
}
}
InputRunResults(results.toList())
InputRunResults(results.toList(), aggTriggersAfterKey = aggTriggerAfterKeys)
} catch (e: Exception) {
logger.info("Error collecting inputs for monitor: ${monitor.id}", e)
InputRunResults(emptyList(), e)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.alerting.util

import org.opensearch.alerting.model.BucketLevelTrigger
import org.opensearch.alerting.model.InputRunResults
import org.opensearch.alerting.model.Trigger
import org.opensearch.action.search.SearchResponse
import org.opensearch.search.aggregations.AggregationBuilder
import org.opensearch.search.aggregations.AggregatorFactories
import org.opensearch.search.aggregations.bucket.SingleBucketAggregation
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder
import org.opensearch.search.aggregations.support.AggregationPath
import org.opensearch.search.builder.SearchSourceBuilder

class AggregationQueryRewriter {

companion object {
/**
* Add the bucket selector conditions for each trigger in input query. It also adds afterKeys from previous result
* for each trigger.
*/
fun rewriteQuery(query: SearchSourceBuilder, prevResult: InputRunResults?, triggers: List<Trigger>) {
triggers.forEach { trigger ->
if (trigger is BucketLevelTrigger) {
// add bucket selector pipeline aggregation for each trigger in query
query.aggregation(trigger.bucketSelector)
// if this request is processing the subsequent pages of input query result, then add after key
if (prevResult?.aggTriggersAfterKey?.get(trigger.id) != null) {
val parentBucketPath = AggregationPath.parse(trigger.bucketSelector.parentBucketPath)
var aggBuilders = (query.aggregations() as AggregatorFactories.Builder).aggregatorFactories
var factory: AggregationBuilder? = null
for (i in 0 until parentBucketPath.pathElements.size) {
factory = null
for (aggFactory in aggBuilders) {
if (aggFactory.name.equals(parentBucketPath.pathElements[i].name)) {
aggBuilders = aggFactory.subAggregations
factory = aggFactory
break
}
}
if (factory == null) {
throw IllegalArgumentException("ParentBucketPath: $parentBucketPath not found in input query results")
}
}
if (factory is CompositeAggregationBuilder) {
// if the afterKey from previous result is null, what does it signify?
// A) result set exhausted OR B) first page ?
val afterKey = prevResult.aggTriggersAfterKey[trigger.id]
factory.aggregateAfter(afterKey)
} else {
throw IllegalStateException("AfterKeys are not expected to be present in non CompositeAggregationBuilder")
}
}
}
}
}

/**
* For each trigger, returns the after keys if present in query result.
*/
fun getAfterKeysFromSearchResponse(searchResponse: SearchResponse, triggers: List<Trigger>): Map<String, Map<String, Any>?> {
val aggTriggerAfterKeys = mutableMapOf<String, Map<String, Any>?>()
triggers.forEach { trigger ->
if (trigger is BucketLevelTrigger) {
val parentBucketPath = AggregationPath.parse(trigger.bucketSelector.parentBucketPath)
var aggs = searchResponse.aggregations
// assuming all intermediate aggregations as SingleBucketAggregation
for (i in 0 until parentBucketPath.pathElements.size - 1) {
aggs = (aggs.asMap()[parentBucketPath.pathElements[i].name] as SingleBucketAggregation).aggregations
}
val lastAgg = aggs.asMap[parentBucketPath.pathElements.last().name]
// if leaf is CompositeAggregation, then fetch afterKey if present
if (lastAgg is CompositeAggregation) {
aggTriggerAfterKeys[trigger.id] = lastAgg.afterKey()
}
}
}
return aggTriggerAfterKeys
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.alerting.util

import org.opensearch.alerting.model.InputRunResults
import org.opensearch.alerting.model.Trigger
import org.opensearch.alerting.randomBucketLevelTrigger
import org.opensearch.alerting.randomQueryLevelTrigger
import org.opensearch.action.search.SearchResponse
import org.opensearch.cluster.ClusterModule
import org.opensearch.common.CheckedFunction
import org.opensearch.common.ParseField
import org.opensearch.common.xcontent.NamedXContentRegistry
import org.opensearch.common.xcontent.XContentParser
import org.opensearch.common.xcontent.json.JsonXContent
import org.opensearch.search.aggregations.Aggregation
import org.opensearch.search.aggregations.AggregationBuilder
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder
import org.opensearch.search.aggregations.bucket.composite.ParsedComposite
import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder
import org.opensearch.search.builder.SearchSourceBuilder
import org.opensearch.test.OpenSearchTestCase
import org.junit.Assert
import java.io.IOException

class AggregationQueryRewriterTests : OpenSearchTestCase() {

fun `test RewriteQuery empty previous result`() {
val triggers: MutableList<Trigger> = mutableListOf()
for (i in 0 until 10) {
triggers.add(randomBucketLevelTrigger())
}
val queryBuilder = SearchSourceBuilder()
val termAgg: AggregationBuilder = TermsAggregationBuilder("testPath").field("sports")
queryBuilder.aggregation(termAgg)
val prevResult = null
AggregationQueryRewriter.rewriteQuery(queryBuilder, prevResult, triggers)
Assert.assertEquals(queryBuilder.aggregations().pipelineAggregatorFactories.size, 10)
}

fun `skip test RewriteQuery with non-empty previous result`() {
val triggers: MutableList<Trigger> = mutableListOf()
for (i in 0 until 10) {
triggers.add(randomBucketLevelTrigger())
}
val queryBuilder = SearchSourceBuilder()
val termAgg: AggregationBuilder = CompositeAggregationBuilder(
"testPath",
listOf(TermsValuesSourceBuilder("k1"), TermsValuesSourceBuilder("k2"))
)
queryBuilder.aggregation(termAgg)
val aggTriggersAfterKey = mutableMapOf<String, Map<String, Any>?>()
for (trigger in triggers) {
aggTriggersAfterKey[trigger.id] = hashMapOf(Pair("k1", "v1"), Pair("k2", "v2"))
}
val prevResult = InputRunResults(emptyList(), null, aggTriggersAfterKey)
AggregationQueryRewriter.rewriteQuery(queryBuilder, prevResult, triggers)
Assert.assertEquals(queryBuilder.aggregations().pipelineAggregatorFactories.size, 10)
queryBuilder.aggregations().aggregatorFactories.forEach {
if (it.name.equals("testPath")) {
// val compAgg = it as CompositeAggregationBuilder
// TODO: This is calling forbidden API and causing build failures, need to find an alternative
// instead of trying to access private member variables
// val afterField = CompositeAggregationBuilder::class.java.getDeclaredField("after")
// afterField.isAccessible = true
// Assert.assertEquals(afterField.get(compAgg), hashMapOf(Pair("k1", "v1"), Pair("k2", "v2")))
}
}
}

fun `test RewriteQuery with non aggregation trigger`() {
val triggers: MutableList<Trigger> = mutableListOf()
for (i in 0 until 10) {
triggers.add(randomQueryLevelTrigger())
}
val queryBuilder = SearchSourceBuilder()
val termAgg: AggregationBuilder = TermsAggregationBuilder("testPath").field("sports")
queryBuilder.aggregation(termAgg)
val prevResult = null
AggregationQueryRewriter.rewriteQuery(queryBuilder, prevResult, triggers)
Assert.assertEquals(queryBuilder.aggregations().pipelineAggregatorFactories.size, 0)
}

fun `test after keys from search response`() {
val responseContent = """
{
"took" : 97,
"timed_out" : false,
"_shards" : {
"total" : 3,
"successful" : 3,
"skipped" : 0,
"failed" : 0
},
"hits" : {
"total" : {
"value" : 20,
"relation" : "eq"
},
"max_score" : null,
"hits" : [ ]
},
"aggregations" : {
"composite#testPath" : {
"after_key" : {
"sport" : "Basketball"
},
"buckets" : [
{
"key" : {
"sport" : "Basketball"
},
"doc_count" : 5
}
]
}
}
}
""".trimIndent()

val aggTriggers: MutableList<Trigger> = mutableListOf(randomBucketLevelTrigger())
val tradTriggers: MutableList<Trigger> = mutableListOf(randomQueryLevelTrigger())

val searchResponse = SearchResponse.fromXContent(createParser(JsonXContent.jsonXContent, responseContent))
val afterKeys = AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, aggTriggers)
Assert.assertEquals(afterKeys[aggTriggers[0].id], hashMapOf(Pair("sport", "Basketball")))

val afterKeys2 = AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, tradTriggers)
Assert.assertEquals(afterKeys2.size, 0)
}

override fun xContentRegistry(): NamedXContentRegistry {
val entries = ClusterModule.getNamedXWriteables()
entries.add(
NamedXContentRegistry.Entry(
Aggregation::class.java, ParseField(CompositeAggregationBuilder.NAME),
CheckedFunction<XContentParser, ParsedComposite, IOException> { parser: XContentParser? ->
ParsedComposite.fromXContent(
parser, "testPath"
)
}
)
)
return NamedXContentRegistry(entries)
}
}

0 comments on commit d31f0a1

Please sign in to comment.