Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in paginating multiple bucket paths for Bucket-Level Monitor #163

Merged
merged 2 commits into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.opensearch.alerting.elasticapi.convertToMap
import org.opensearch.alerting.elasticapi.suspendUntil
import org.opensearch.alerting.model.InputRunResults
import org.opensearch.alerting.model.Monitor
import org.opensearch.alerting.model.TriggerAfterKey
import org.opensearch.alerting.util.AggregationQueryRewriter
import org.opensearch.alerting.util.addUserBackendRolesFilter
import org.opensearch.client.Client
Expand Down Expand Up @@ -49,8 +50,10 @@ class InputService(
): InputRunResults {
return try {
val results = mutableListOf<Map<String, Any>>()
val aggTriggerAfterKeys: MutableMap<String, Map<String, Any>?> = mutableMapOf()
val aggTriggerAfterKey: MutableMap<String, TriggerAfterKey> = mutableMapOf()

// TODO: If/when multiple input queries are supported for Bucket-Level Monitor execution, aggTriggerAfterKeys will
// need to be updated to account for it
monitor.inputs.forEach { input ->
when (input) {
is SearchInput -> {
Expand All @@ -75,15 +78,19 @@ class InputService(
searchRequest.source(SearchSourceBuilder.fromXContent(it))
}
val searchResponse: SearchResponse = client.suspendUntil { client.search(searchRequest, it) }
aggTriggerAfterKeys += AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, monitor.triggers)
aggTriggerAfterKey += AggregationQueryRewriter.getAfterKeysFromSearchResponse(
searchResponse,
monitor.triggers,
prevResult?.aggTriggersAfterKey
)
results += searchResponse.convertToMap()
}
else -> {
throw IllegalArgumentException("Unsupported input type: ${input.name()}.")
}
}
}
InputRunResults(results.toList(), aggTriggersAfterKey = aggTriggerAfterKeys)
InputRunResults(results.toList(), aggTriggersAfterKey = aggTriggerAfterKey)
} catch (e: Exception) {
logger.info("Error collecting inputs for monitor: ${monitor.id}", e)
InputRunResults(emptyList(), e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ data class MonitorRunResult<TriggerResult : TriggerRunResult>(
data class InputRunResults(
val results: List<Map<String, Any>> = listOf(),
val error: Exception? = null,
val aggTriggersAfterKey: MutableMap<String, Map<String, Any>?>? = null
val aggTriggersAfterKey: MutableMap<String, TriggerAfterKey>? = null
) : Writeable, ToXContent {

override fun toXContent(builder: XContentBuilder, params: ToXContent.Params): XContentBuilder {
Expand Down Expand Up @@ -152,14 +152,16 @@ data class InputRunResults(

fun afterKeysPresent(): Boolean {
aggTriggersAfterKey?.forEach {
if (it.value != null) {
if (it.value.afterKey != null && !it.value.lastPage) {
return true
}
}
return false
}
}

data class TriggerAfterKey(val afterKey: Map<String, Any>?, val lastPage: Boolean)

data class ActionRunResult(
val actionId: String,
val actionName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import org.opensearch.action.search.SearchResponse
import org.opensearch.alerting.model.BucketLevelTrigger
import org.opensearch.alerting.model.InputRunResults
import org.opensearch.alerting.model.Trigger
import org.opensearch.alerting.model.TriggerAfterKey
import org.opensearch.search.aggregations.AggregationBuilder
import org.opensearch.search.aggregations.AggregatorFactories
import org.opensearch.search.aggregations.bucket.SingleBucketAggregation
Expand Down Expand Up @@ -56,7 +57,7 @@ class AggregationQueryRewriter {
if (factory is CompositeAggregationBuilder) {
// if the afterKey from previous result is null, what does it signify?
// A) result set exhausted OR B) first page ?
val afterKey = prevResult.aggTriggersAfterKey[trigger.id]
val afterKey = prevResult.aggTriggersAfterKey[trigger.id]!!.afterKey
factory.aggregateAfter(afterKey)
} else {
throw IllegalStateException("AfterKeys are not expected to be present in non CompositeAggregationBuilder")
Expand All @@ -69,8 +70,12 @@ class AggregationQueryRewriter {
/**
* For each trigger, returns the after keys if present in query result.
*/
fun getAfterKeysFromSearchResponse(searchResponse: SearchResponse, triggers: List<Trigger>): Map<String, Map<String, Any>?> {
val aggTriggerAfterKeys = mutableMapOf<String, Map<String, Any>?>()
fun getAfterKeysFromSearchResponse(
searchResponse: SearchResponse,
triggers: List<Trigger>,
prevBucketLevelTriggerAfterKeys: Map<String, TriggerAfterKey>?
): Map<String, TriggerAfterKey> {
val bucketLevelTriggerAfterKeys = mutableMapOf<String, TriggerAfterKey>()
triggers.forEach { trigger ->
if (trigger is BucketLevelTrigger) {
val parentBucketPath = AggregationPath.parse(trigger.bucketSelector.parentBucketPath)
Expand All @@ -82,11 +87,32 @@ class AggregationQueryRewriter {
val lastAgg = aggs.asMap[parentBucketPath.pathElements.last().name]
// if leaf is CompositeAggregation, then fetch afterKey if present
if (lastAgg is CompositeAggregation) {
aggTriggerAfterKeys[trigger.id] = lastAgg.afterKey()
/*
* Bucket-Level Triggers can have different parent bucket paths that they are tracking for condition evaluation.
* These different bucket paths could have different page sizes, meaning one could be exhausted while another
* bucket path still has pages to iterate in the query responses.
*
* To ensure that these can be exhausted and tracked independently, the after key that led to the last page (which
* should be an empty result for the bucket path) will be saved when the last page is hit and will be continued
* to be passed on for that bucket path if there are still other bucket paths being paginated.
*/
val afterKey = lastAgg.afterKey()
val prevTriggerAfterKey = prevBucketLevelTriggerAfterKeys?.get(trigger.id)
bucketLevelTriggerAfterKeys[trigger.id] = when {
// If the previous TriggerAfterKey was null, this should be the first page
prevTriggerAfterKey == null -> TriggerAfterKey(afterKey, afterKey == null)
// If the previous TriggerAfterKey already hit the last page, pass along the after key it used to get there
prevTriggerAfterKey.lastPage -> prevTriggerAfterKey
// If the previous TriggerAfterKey had not reached the last page and the after key for the current result
// is null, then the last page has been reached so the after key that was used to get there is stored
afterKey == null -> TriggerAfterKey(prevTriggerAfterKey.afterKey, true)
// Otherwise, update the after key to the current one
else -> TriggerAfterKey(afterKey, false)
}
}
}
}
return aggTriggerAfterKeys
return bucketLevelTriggerAfterKeys
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import org.junit.Assert
import org.opensearch.action.search.SearchResponse
import org.opensearch.alerting.model.InputRunResults
import org.opensearch.alerting.model.Trigger
import org.opensearch.alerting.model.TriggerAfterKey
import org.opensearch.alerting.randomBucketLevelTrigger
import org.opensearch.alerting.randomBucketSelectorExtAggregationBuilder
import org.opensearch.alerting.randomQueryLevelTrigger
import org.opensearch.cluster.ClusterModule
import org.opensearch.common.CheckedFunction
Expand Down Expand Up @@ -59,9 +61,9 @@ class AggregationQueryRewriterTests : OpenSearchTestCase() {
listOf(TermsValuesSourceBuilder("k1"), TermsValuesSourceBuilder("k2"))
)
queryBuilder.aggregation(termAgg)
val aggTriggersAfterKey = mutableMapOf<String, Map<String, Any>?>()
val aggTriggersAfterKey = mutableMapOf<String, TriggerAfterKey>()
for (trigger in triggers) {
aggTriggersAfterKey[trigger.id] = hashMapOf(Pair("k1", "v1"), Pair("k2", "v2"))
aggTriggersAfterKey[trigger.id] = TriggerAfterKey(hashMapOf(Pair("k1", "v1"), Pair("k2", "v2")), false)
}
val prevResult = InputRunResults(emptyList(), null, aggTriggersAfterKey)
AggregationQueryRewriter.rewriteQuery(queryBuilder, prevResult, triggers)
Expand Down Expand Up @@ -132,13 +134,170 @@ class AggregationQueryRewriterTests : OpenSearchTestCase() {
val tradTriggers: MutableList<Trigger> = mutableListOf(randomQueryLevelTrigger())

val searchResponse = SearchResponse.fromXContent(createParser(JsonXContent.jsonXContent, responseContent))
val afterKeys = AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, aggTriggers)
Assert.assertEquals(afterKeys[aggTriggers[0].id], hashMapOf(Pair("sport", "Basketball")))
val afterKeys = AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, aggTriggers, null)
Assert.assertEquals(afterKeys[aggTriggers[0].id]?.afterKey, hashMapOf(Pair("sport", "Basketball")))

val afterKeys2 = AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, tradTriggers)
val afterKeys2 = AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, tradTriggers, null)
Assert.assertEquals(afterKeys2.size, 0)
}

fun `test after keys from search responses for multiple bucket paths and different page counts`() {
val firstResponseContent = """
{
"took" : 0,
"timed_out" : false,
"_shards" : {
"total" : 1,
"successful" : 1,
"skipped" : 0,
"failed" : 0
},
"hits" : {
"total" : {
"value" : 4675,
"relation" : "eq"
},
"max_score" : null,
"hits" : [ ]
},
"aggregations" : {
"composite2#smallerResults" : {
"after_key" : {
"category" : "Women's Shoes"
},
"buckets" : [
{
"key" : {
"category" : "Women's Shoes"
},
"doc_count" : 1136
}
]
},
"composite3#largerResults" : {
"after_key" : {
"user" : "abigail"
},
"buckets" : [
{
"key" : {
"user" : "abd"
},
"doc_count" : 188
},
{
"key" : {
"user" : "abigail"
},
"doc_count" : 128
}
]
}
}
}
""".trimIndent()

val secondResponseContent = """
{
"took" : 0,
"timed_out" : false,
"_shards" : {
"total" : 1,
"successful" : 1,
"skipped" : 0,
"failed" : 0
},
"hits" : {
"total" : {
"value" : 4675,
"relation" : "eq"
},
"max_score" : null,
"hits" : [ ]
},
"aggregations" : {
"composite2#smallerResults" : {
"buckets" : [ ]
},
"composite3#largerResults" : {
"after_key" : {
"user" : "boris"
},
"buckets" : [
{
"key" : {
"user" : "betty"
},
"doc_count" : 148
},
{
"key" : {
"user" : "boris"
},
"doc_count" : 74
}
]
}
}
}
""".trimIndent()

val thirdResponseContent = """
{
"took" : 0,
"timed_out" : false,
"_shards" : {
"total" : 1,
"successful" : 1,
"skipped" : 0,
"failed" : 0
},
"hits" : {
"total" : {
"value" : 4675,
"relation" : "eq"
},
"max_score" : null,
"hits" : [ ]
},
"aggregations" : {
"composite2#smallerResults" : {
"buckets" : [ ]
},
"composite3#largerResults" : {
"buckets" : [ ]
}
}
}
""".trimIndent()

val bucketLevelTriggers: MutableList<Trigger> = mutableListOf(
randomBucketLevelTrigger(bucketSelector = randomBucketSelectorExtAggregationBuilder(parentBucketPath = "smallerResults")),
randomBucketLevelTrigger(bucketSelector = randomBucketSelectorExtAggregationBuilder(parentBucketPath = "largerResults"))
)

var searchResponse = SearchResponse.fromXContent(createParser(JsonXContent.jsonXContent, firstResponseContent))
val afterKeys = AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, bucketLevelTriggers, null)
assertEquals(hashMapOf(Pair("category", "Women's Shoes")), afterKeys[bucketLevelTriggers[0].id]?.afterKey)
assertEquals(false, afterKeys[bucketLevelTriggers[0].id]?.lastPage)
assertEquals(hashMapOf(Pair("user", "abigail")), afterKeys[bucketLevelTriggers[1].id]?.afterKey)
assertEquals(false, afterKeys[bucketLevelTriggers[1].id]?.lastPage)

searchResponse = SearchResponse.fromXContent(createParser(JsonXContent.jsonXContent, secondResponseContent))
val afterKeys2 = AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, bucketLevelTriggers, afterKeys)
assertEquals(hashMapOf(Pair("category", "Women's Shoes")), afterKeys2[bucketLevelTriggers[0].id]?.afterKey)
assertEquals(true, afterKeys2[bucketLevelTriggers[0].id]?.lastPage)
assertEquals(hashMapOf(Pair("user", "boris")), afterKeys2[bucketLevelTriggers[1].id]?.afterKey)
assertEquals(false, afterKeys2[bucketLevelTriggers[1].id]?.lastPage)

searchResponse = SearchResponse.fromXContent(createParser(JsonXContent.jsonXContent, thirdResponseContent))
val afterKeys3 = AggregationQueryRewriter.getAfterKeysFromSearchResponse(searchResponse, bucketLevelTriggers, afterKeys2)
assertEquals(hashMapOf(Pair("category", "Women's Shoes")), afterKeys3[bucketLevelTriggers[0].id]?.afterKey)
assertEquals(true, afterKeys3[bucketLevelTriggers[0].id]?.lastPage)
assertEquals(hashMapOf(Pair("user", "boris")), afterKeys3[bucketLevelTriggers[1].id]?.afterKey)
assertEquals(true, afterKeys3[bucketLevelTriggers[1].id]?.lastPage)
}

override fun xContentRegistry(): NamedXContentRegistry {
val entries = ClusterModule.getNamedXWriteables()
entries.add(
Expand All @@ -151,6 +310,26 @@ class AggregationQueryRewriterTests : OpenSearchTestCase() {
}
)
)
entries.add(
NamedXContentRegistry.Entry(
Aggregation::class.java, ParseField(CompositeAggregationBuilder.NAME + "2"),
CheckedFunction<XContentParser, ParsedComposite, IOException> { parser: XContentParser? ->
ParsedComposite.fromXContent(
parser, "smallerResults"
)
}
)
)
entries.add(
NamedXContentRegistry.Entry(
Aggregation::class.java, ParseField(CompositeAggregationBuilder.NAME + "3"),
CheckedFunction<XContentParser, ParsedComposite, IOException> { parser: XContentParser? ->
ParsedComposite.fromXContent(
parser, "largerResults"
)
}
)
)
return NamedXContentRegistry(entries)
}
}