Skip to content

Commit

Permalink
(dsl): Support Filter aggregation (#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
vanjaftn authored Nov 17, 2023
1 parent c50fc01 commit 8c55381
Show file tree
Hide file tree
Showing 9 changed files with 414 additions and 26 deletions.
36 changes: 36 additions & 0 deletions docs/overview/aggregations/elastic_aggregation_filter.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
---
id: elastic_aggregation_filter
title: "Filter Aggregation"
---

The `Filter` aggregation is a single bucket aggregation that narrows down the entire set of documents to a specific set that matches a [query](https://lambdaworks.github.io/zio-elasticsearch/overview/elastic_query).

In order to use the `Filter` aggregation, import the following:
```scala
import zio.elasticsearch.aggregation.FilterAggregation
import zio.elasticsearch.ElasticAggregation.filterAggregation
```

You can create a `Filter` aggregation using the `filterAggregation` method in the following manner:
```scala
import zio.elasticsearch.ElasticQuery.term

val aggregation: FilterAggregation = filterAggregation(name = "filterAggregation", query = term(field = Document.stringField, value = "test"))
```

If you want to add another aggregation (on the same level), you can use the `withAgg` method:
```scala
import zio.elasticsearch.ElasticQuery.term
import zio.elasticsearch.ElasticAggregation.maxAggregation

val multipleAggregations: MultipleAggregations = filterAggregation(name = "filterAggregation", query = term(field = Document.stringField, value = "test")).withAgg(maxAggregation(name = "maxAggregation", field = Document.doubleField))
```

If you want to add another sub-aggregation, you can use the `withSubAgg` method:
```scala
import zio.elasticsearch.ElasticQuery.term
import zio.elasticsearch.ElasticAggregation.maxAggregation

val aggregationWithSubAgg: FilterAggregation = filterAggregation(name = "filterAggregation", query = term(field = Document.stringField, value = "test")).withSubAgg(maxAggregation(name = "maxAggregation", field = Document.intField))
```

You can find more information about `Filter` aggregation [here](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-bucket-filter-aggregation.html).
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import zio.elasticsearch.query.sort.SortOrder._
import zio.elasticsearch.query.sort.SourceType.NumberType
import zio.elasticsearch.query.{Distance, FunctionScoreBoostMode, FunctionScoreFunction, InnerHits}
import zio.elasticsearch.request.{CreationOutcome, DeletionOutcome}
import zio.elasticsearch.result.{Item, MaxAggregationResult, UpdateByQueryResult}
import zio.elasticsearch.result.{FilterAggregationResult, Item, MaxAggregationResult, UpdateByQueryResult}
import zio.elasticsearch.script.{Painless, Script}
import zio.json.ast.Json.{Arr, Str}
import zio.schema.codec.JsonCodec
Expand Down Expand Up @@ -146,6 +146,65 @@ object HttpExecutorSpec extends IntegrationSpec {
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
),
// Integration test: a `filter` aggregation scoped by a term query, with a `max`
// sub-aggregation computed only over the documents matching the filter.
test("aggregate using filter aggregation with max aggregation as a sub aggregation") {
  // The term query is expected to match the second and third documents (both get
  // stringField = "filterAggregation"), so doc_count is 2 and the max of their
  // intField values (3 and 5) is 5.0.
  val expectedResult = (
    "aggregation",
    FilterAggregationResult(
      docCount = 2,
      subAggregations = Map(
        "subAggregation" -> MaxAggregationResult(value = 5.0)
      )
    )
  )
  checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
    (firstDocumentId, firstDocument, secondDocumentId, secondDocument, thirdDocumentId, thirdDocument) =>
      for {
        // Start from a clean index so only the three upserted documents can match.
        _ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll))
        // First document must NOT match the filter; the other two must.
        firstDocumentUpdated = firstDocument.copy(stringField = "test", intField = 7)
        secondDocumentUpdated =
          secondDocument.copy(stringField = "filterAggregation", intField = 3)
        thirdDocumentUpdated =
          thirdDocument.copy(stringField = "filterAggregation", intField = 5)
        _ <- Executor.execute(
               ElasticRequest.upsert[TestDocument](
                 firstSearchIndex,
                 firstDocumentId,
                 firstDocumentUpdated
               )
             )
        _ <- Executor.execute(
               ElasticRequest
                 .upsert[TestDocument](
                   firstSearchIndex,
                   secondDocumentId,
                   secondDocumentUpdated
                 )
             )
        // `refreshTrue` only on the last write so all three documents are visible
        // to the subsequent aggregation request.
        _ <- Executor.execute(
               ElasticRequest
                 .upsert[TestDocument](
                   firstSearchIndex,
                   thirdDocumentId,
                   thirdDocumentUpdated
                 )
                 .refreshTrue
             )
        // NOTE(review): lower-cased value — presumably to match the analyzed form
        // of the text field; confirm against the index mapping.
        query = term(field = TestDocument.stringField, value = secondDocumentUpdated.stringField.toLowerCase)
        aggregation =
          filterAggregation(name = "aggregation", query = query).withSubAgg(
            maxAggregation("subAggregation", TestDocument.intField)
          )
        aggsRes <-
          Executor
            .execute(ElasticRequest.aggregate(selectors = firstSearchIndex, aggregation = aggregation))
            .aggregations

      } yield assert(aggsRes.head)(equalTo(expectedResult))
  }
} @@ around(
  Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
  Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
),
test("aggregate using max aggregation") {
val expectedResponse = ("aggregationInt", MaxAggregationResult(value = 20.0))
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package zio.elasticsearch

import zio.Chunk
import zio.elasticsearch.aggregation._
import zio.elasticsearch.query.ElasticQuery
import zio.elasticsearch.script.Script

object ElasticAggregation {
Expand Down Expand Up @@ -113,6 +114,20 @@ object ElasticAggregation {
final def cardinalityAggregation(name: String, field: String): CardinalityAggregation =
Cardinality(name = name, field = field, missing = None)

/**
 * Creates an instance of [[zio.elasticsearch.aggregation.FilterAggregation]] from the given parameters.
 *
 * The resulting aggregation starts with no sub-aggregations; they can be attached later via `withSubAgg`.
 *
 * @param name
 *   aggregation name
 * @param query
 *   a query which the documents must match
 * @return
 *   an instance of [[zio.elasticsearch.aggregation.FilterAggregation]] that represents filter aggregation to be
 *   performed.
 */
final def filterAggregation(name: String, query: ElasticQuery[_]): FilterAggregation =
  Filter(name, query, subAggregations = Chunk.empty)

/**
* Constructs a type-safe instance of [[zio.elasticsearch.aggregation.ExtendedStatsAggregation]] using the specified
* parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import zio.Chunk
import zio.elasticsearch.ElasticAggregation.multipleAggregations
import zio.elasticsearch.ElasticPrimitive.ElasticPrimitiveOps
import zio.elasticsearch.aggregation.options._
import zio.elasticsearch.query.ElasticQuery
import zio.elasticsearch.query.sort.Sort
import zio.elasticsearch.script.Script
import zio.json.ast.Json
Expand Down Expand Up @@ -186,6 +187,31 @@ private[elasticsearch] final case class ExtendedStats(
}
}

sealed trait FilterAggregation extends SingleElasticAggregation with WithAgg with WithSubAgg[FilterAggregation]

private[elasticsearch] final case class Filter(
  name: String,
  query: ElasticQuery[_],
  subAggregations: Chunk[SingleElasticAggregation]
) extends FilterAggregation { self =>

  // Combines this aggregation with another top-level aggregation.
  def withAgg(agg: SingleElasticAggregation): MultipleAggregations =
    multipleAggregations.aggregations(self, agg)

  // Registers one more sub-aggregation, keeping the previously added ones.
  def withSubAgg(aggregation: SingleElasticAggregation): FilterAggregation =
    self.copy(subAggregations = aggregation +: subAggregations)

  // Serializes to `{ "<name>": { "filter": <query>, "aggs": { ... } } }`,
  // omitting the "aggs" object entirely when there are no sub-aggregations.
  private[elasticsearch] def toJson: Json = {
    val filterJson = Obj("filter" -> query.toJson(fieldPath = None))
    val aggsJson =
      if (subAggregations.isEmpty) Obj()
      else Obj("aggs" -> subAggregations.map(_.toJson).reduce(_ merge _))

    Obj(name -> (filterJson merge aggsJson))
  }
}

sealed trait MaxAggregation extends SingleElasticAggregation with HasMissing[MaxAggregation] with WithAgg

private[elasticsearch] final case class Max(name: String, field: String, missing: Option[Double])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import zio.json.ast.Json
import zio.json.ast.Json.Obj
import zio.json.{DeriveJsonDecoder, JsonDecoder, jsonField}

// Marker trait for bucket types produced by bucketing aggregations
// (e.g. `TermsAggregationBucket`).
private[elasticsearch] sealed trait AggregationBucket

// Root of the hierarchy of raw aggregation responses decoded from Elasticsearch JSON.
sealed trait AggregationResponse

object AggregationResponse {
Expand Down Expand Up @@ -68,6 +70,13 @@ object AggregationResponse {
lowerSampling = stdDeviationBoundsResponse.lowerSampling
)
)
case FilterAggregationResponse(docCount, subAggregations) =>
FilterAggregationResult(
docCount = docCount,
subAggregations = subAggregations.fold(Map[String, AggregationResult]())(_.map { case (key, response) =>
(key, toResult(response))
})
)
case MaxAggregationResponse(value) =>
MaxAggregationResult(value)
case MinAggregationResponse(value) =>
Expand Down Expand Up @@ -142,6 +151,123 @@ private[elasticsearch] object ExtendedStatsAggregationResponse {
DeriveJsonDecoder.gen[ExtendedStatsAggregationResponse]
}

// Raw JSON shape of a `filter` aggregation response: the number of documents
// matching the filter plus any decoded sub-aggregation responses, keyed by the
// sub-aggregation name.
private[elasticsearch] final case class FilterAggregationResponse(
  @jsonField("doc_count")
  docCount: Int,
  subAggregations: Option[Map[String, AggregationResponse]] = None
) extends AggregationResponse

/**
 * Decoder for the raw JSON of a `filter` aggregation response.
 *
 * Elasticsearch returns sub-aggregations as sibling fields of `doc_count`, each keyed as
 * `"<type>#<name>"` (e.g. `"max#subAggregation"`). The decoder walks every field, dispatches
 * on the type prefix, and collects the decoded responses under their bare names.
 *
 * NOTE: the `contains` checks are order-sensitive — `weighted_avg#` must be tested before
 * `avg#`, and `extended_stats#` before `stats#`.
 */
private[elasticsearch] object FilterAggregationResponse extends JsonDecoderOps {
  implicit val decoder: JsonDecoder[FilterAggregationResponse] = Obj.decoder.mapOrFail { case Obj(fields) =>
    // First pass: decode every field into either the raw doc count or a typed
    // sub-aggregation response, keyed by the original "<type>#<name>" field name.
    val allFields = fields.flatMap { case (field, data) =>
      field match {
        case "doc_count" =>
          Some(field -> data.unsafeAs[Int])
        case _ =>
          val objFields = data.unsafeAs[Obj].fields.toMap

          (field: @unchecked) match {
            case str if str.contains("weighted_avg#") =>
              Some(field -> WeightedAvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
            case str if str.contains("avg#") =>
              Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
            case str if str.contains("cardinality#") =>
              Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int]))
            case str if str.contains("extended_stats#") =>
              Some(
                field -> ExtendedStatsAggregationResponse(
                  count = objFields("count").unsafeAs[Int],
                  min = objFields("min").unsafeAs[Double],
                  max = objFields("max").unsafeAs[Double],
                  avg = objFields("avg").unsafeAs[Double],
                  sum = objFields("sum").unsafeAs[Double],
                  sumOfSquares = objFields("sum_of_squares").unsafeAs[Double],
                  variance = objFields("variance").unsafeAs[Double],
                  variancePopulation = objFields("variance_population").unsafeAs[Double],
                  varianceSampling = objFields("variance_sampling").unsafeAs[Double],
                  stdDeviation = objFields("std_deviation").unsafeAs[Double],
                  stdDeviationPopulation = objFields("std_deviation_population").unsafeAs[Double],
                  stdDeviationSampling = objFields("std_deviation_sampling").unsafeAs[Double],
                  // Fixed: the bounds object lives under "std_deviation_bounds";
                  // "std_deviation_sampling" is a plain number and cannot decode as
                  // StdDeviationBoundsResponse.
                  stdDeviationBoundsResponse = objFields("std_deviation_bounds").unsafeAs[StdDeviationBoundsResponse](
                    StdDeviationBoundsResponse.decoder
                  )
                )
              )
            case str if str.contains("filter#") =>
              // Nested filter aggregations decode recursively.
              Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder))
            case str if str.contains("max#") =>
              Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
            case str if str.contains("min#") =>
              Some(field -> MinAggregationResponse(value = objFields("value").unsafeAs[Double]))
            case str if str.contains("missing#") =>
              Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int]))
            case str if str.contains("percentiles#") =>
              Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]))
            case str if str.contains("stats#") =>
              Some(
                field -> StatsAggregationResponse(
                  count = objFields("count").unsafeAs[Int],
                  min = objFields("min").unsafeAs[Double],
                  max = objFields("max").unsafeAs[Double],
                  avg = objFields("avg").unsafeAs[Double],
                  sum = objFields("sum").unsafeAs[Double]
                )
              )
            case str if str.contains("sum#") =>
              Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double]))
            case str if str.contains("terms#") =>
              Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder))
            case str if str.contains("value_count#") =>
              Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int]))
          }
      }
    }.toMap

    // "doc_count" is always present in a filter aggregation response.
    val docCount = allFields("doc_count").asInstanceOf[Int]
    // Second pass: strip the "<type>#" prefix from every sub-aggregation key and
    // narrow the stored value back to its concrete response type.
    val subAggs = allFields.collect {
      case (field, data) if field != "doc_count" =>
        (field: @unchecked) match {
          case str if str.contains("weighted_avg#") =>
            (field.split("#")(1), data.asInstanceOf[WeightedAvgAggregationResponse])
          case str if str.contains("avg#") =>
            (field.split("#")(1), data.asInstanceOf[AvgAggregationResponse])
          case str if str.contains("cardinality#") =>
            (field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
          case str if str.contains("extended_stats#") =>
            (field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse])
          case str if str.contains("filter#") =>
            (field.split("#")(1), data.asInstanceOf[FilterAggregationResponse])
          case str if str.contains("max#") =>
            (field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
          case str if str.contains("min#") =>
            (field.split("#")(1), data.asInstanceOf[MinAggregationResponse])
          case str if str.contains("missing#") =>
            (field.split("#")(1), data.asInstanceOf[MissingAggregationResponse])
          case str if str.contains("percentiles#") =>
            (field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse])
          case str if str.contains("stats#") =>
            (field.split("#")(1), data.asInstanceOf[StatsAggregationResponse])
          case str if str.contains("sum#") =>
            (field.split("#")(1), data.asInstanceOf[SumAggregationResponse])
          case str if str.contains("terms#") =>
            (field.split("#")(1), data.asInstanceOf[TermsAggregationResponse])
          case str if str.contains("value_count#") =>
            (field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse])
        }
    }
    Right(FilterAggregationResponse.apply(docCount, Option(subAggs).filter(_.nonEmpty)))
  }
}

/**
 * Provides the `unsafeAs` extension used by the hand-written aggregation decoders to read a
 * JSON value as `A`, assuming decoding cannot fail for well-formed Elasticsearch responses.
 */
private[elasticsearch] sealed trait JsonDecoderOps {
  implicit class JsonDecoderOps(json: Json) {
    // Fixed: the previous non-exhaustive `@unchecked` match threw a bare `MatchError`
    // with no context on a decoding failure; now the thrown error carries the decoder
    // message, which makes malformed-response bugs diagnosable.
    def unsafeAs[A](implicit decoder: JsonDecoder[A]): A =
      json.as[A] match {
        case Right(decoded) => decoded
        case Left(error)    => throw new IllegalStateException(s"Failed to decode aggregation response: $error")
      }
  }
}

private[elasticsearch] final case class MaxAggregationResponse(value: Double) extends AggregationResponse

private[elasticsearch] object MaxAggregationResponse {
Expand Down Expand Up @@ -217,16 +343,14 @@ private[elasticsearch] object TermsAggregationResponse {
implicit val decoder: JsonDecoder[TermsAggregationResponse] = DeriveJsonDecoder.gen[TermsAggregationResponse]
}

private[elasticsearch] sealed trait AggregationBucket

// A single bucket of a `terms` aggregation response: the bucket key, the number of
// documents in the bucket, and any decoded sub-aggregation responses keyed by name.
private[elasticsearch] final case class TermsAggregationBucket(
  key: String,
  @jsonField("doc_count")
  docCount: Int,
  subAggregations: Option[Map[String, AggregationResponse]] = None
) extends AggregationBucket

private[elasticsearch] object TermsAggregationBucket {
private[elasticsearch] object TermsAggregationBucket extends JsonDecoderOps {
implicit val decoder: JsonDecoder[TermsAggregationBucket] = Obj.decoder.mapOrFail { case Obj(fields) =>
val allFields = fields.flatMap { case (field, data) =>
field match {
Expand Down Expand Up @@ -264,6 +388,8 @@ private[elasticsearch] object TermsAggregationBucket {
)
)
)
case str if str.contains("filter#") =>
Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder))
case str if str.contains("max#") =>
Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("min#") =>
Expand All @@ -285,15 +411,7 @@ private[elasticsearch] object TermsAggregationBucket {
case str if str.contains("sum#") =>
Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("terms#") =>
Some(
field -> TermsAggregationResponse(
docErrorCount = objFields("doc_count_error_upper_bound").unsafeAs[Int],
sumOtherDocCount = objFields("sum_other_doc_count").unsafeAs[Int],
buckets = objFields("buckets")
.unsafeAs[Chunk[Json]]
.map(_.unsafeAs[TermsAggregationBucket](TermsAggregationBucket.decoder))
)
)
Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder))
case str if str.contains("value_count#") =>
Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int]))
}
Expand All @@ -313,6 +431,8 @@ private[elasticsearch] object TermsAggregationBucket {
(field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
case str if str.contains("extended_stats#") =>
(field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse])
case str if str.contains("filter#") =>
(field.split("#")(1), data.asInstanceOf[FilterAggregationResponse])
case str if str.contains("max#") =>
(field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
case str if str.contains("min#") =>
Expand All @@ -331,16 +451,8 @@ private[elasticsearch] object TermsAggregationBucket {
(field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse])
}
}

Right(TermsAggregationBucket.apply(key, docCount, Option(subAggs).filter(_.nonEmpty)))
}

final implicit class JsonDecoderOps(json: Json) {
def unsafeAs[A](implicit decoder: JsonDecoder[A]): A =
(json.as[A]: @unchecked) match {
case Right(decoded) => decoded
}
}
}

private[elasticsearch] final case class ValueCountAggregationResponse(value: Int) extends AggregationResponse
Expand Down
Loading

0 comments on commit 8c55381

Please sign in to comment.