Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(dsl): Support Filter aggregation #349

Merged
merged 12 commits into from
Nov 17, 2023
36 changes: 36 additions & 0 deletions docs/overview/aggregations/elastic_aggregation_filter.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
---
id: elastic_aggregation_filter
title: "Filter Aggregation"
---

The `Filter` aggregation is a single bucket aggregation that narrows down the entire set of documents to a specific set that matches a [query](https://lambdaworks.github.io/zio-elasticsearch/overview/elastic_query).

In order to use the `Filter` aggregation, import the following:
```scala
import zio.elasticsearch.aggregation.FilterAggregation
import zio.elasticsearch.ElasticAggregation.filterAggregation
```

You can create a `Filter` aggregation using the `filterAggregation` method in the following manner:
```scala
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the import for term query.

import zio.elasticsearch.ElasticQuery.term

val aggregation: FilterAggregation = filterAggregation(name = "filterAggregation", query = term(field = Document.stringField, value = "test"))
```

If you want to add another aggregation (on the same level), you can use the `withAgg` method:
```scala
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

import zio.elasticsearch.ElasticQuery.term

val multipleAggregations: MultipleAggregations = filterAggregation(name = "filterAggregation", query = term(field = Document.stringField, value = "test")).withAgg(maxAggregation(name = "maxAggregation", field = Document.doubleField))
```

If you want to add another sub-aggregation, you can use `withSubAgg` method:
```scala
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, also add the import for maxAggregation.

import zio.elasticsearch.ElasticQuery.term
import zio.elasticsearch.ElasticAggregation.maxAggregation

val aggregationWithSubAgg: FilterAggregation = filterAggregation(name = "filterAggregation", query = term(field = Document.stringField, value = "test")).withSubAgg(maxAggregation(name = "maxAggregation", field = Document.intField))
```

You can find more information about `Filter` aggregation [here](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-bucket-filter-aggregation.html).
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import zio.elasticsearch.query.sort.SortOrder._
import zio.elasticsearch.query.sort.SourceType.NumberType
import zio.elasticsearch.query.{Distance, FunctionScoreBoostMode, FunctionScoreFunction, InnerHits}
import zio.elasticsearch.request.{CreationOutcome, DeletionOutcome}
import zio.elasticsearch.result.{Item, MaxAggregationResult, UpdateByQueryResult}
import zio.elasticsearch.result.{FilterAggregationResult, Item, MaxAggregationResult, UpdateByQueryResult}
import zio.elasticsearch.script.{Painless, Script}
import zio.json.ast.Json.{Arr, Str}
import zio.schema.codec.JsonCodec
Expand Down Expand Up @@ -146,6 +146,65 @@ object HttpExecutorSpec extends IntegrationSpec {
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
),
// Verifies that a filter aggregation narrows the document set before the
// sub-aggregation runs: only the two documents whose stringField matches
// the term query are counted, and the max is taken over their intFields.
test("aggregate using filter aggregation with max aggregation as a sub aggregation") {
  val expectedResult = (
    "aggregation",
    FilterAggregationResult(
      docCount = 2,
      subAggregations = Map(
        "subAggregation" -> MaxAggregationResult(value = 5.0)
      )
    )
  )
  checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
    (firstDocumentId, firstDocument, secondDocumentId, secondDocument, thirdDocumentId, thirdDocument) =>
      // One document falls outside the filter, two fall inside it;
      // the max of the matching intFields (3 and 5) is 5.
      val excludedDoc  = firstDocument.copy(stringField = "test", intField = 7)
      val matchingDoc1 = secondDocument.copy(stringField = "filterAggregation", intField = 3)
      val matchingDoc2 = thirdDocument.copy(stringField = "filterAggregation", intField = 5)
      for {
        _ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll))
        _ <- Executor.execute(
               ElasticRequest.upsert[TestDocument](firstSearchIndex, firstDocumentId, excludedDoc)
             )
        _ <- Executor.execute(
               ElasticRequest.upsert[TestDocument](firstSearchIndex, secondDocumentId, matchingDoc1)
             )
        // refreshTrue on the last upsert so all three writes are visible to the aggregation below
        _ <- Executor.execute(
               ElasticRequest
                 .upsert[TestDocument](firstSearchIndex, thirdDocumentId, matchingDoc2)
                 .refreshTrue
             )
        aggregation =
          filterAggregation(
            name = "aggregation",
            query = term(field = TestDocument.stringField, value = matchingDoc1.stringField.toLowerCase)
          ).withSubAgg(maxAggregation("subAggregation", TestDocument.intField))
        result <- Executor
                    .execute(ElasticRequest.aggregate(selectors = firstSearchIndex, aggregation = aggregation))
                    .aggregations
      } yield assert(result.head)(equalTo(expectedResult))
  }
} @@ around(
  Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
  Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
),
test("aggregate using max aggregation") {
val expectedResponse = ("aggregationInt", MaxAggregationResult(value = 20.0))
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package zio.elasticsearch

import zio.Chunk
import zio.elasticsearch.aggregation._
import zio.elasticsearch.query.ElasticQuery
import zio.elasticsearch.script.Script

object ElasticAggregation {
Expand Down Expand Up @@ -113,6 +114,20 @@ object ElasticAggregation {
final def cardinalityAggregation(name: String, field: String): CardinalityAggregation =
Cardinality(name = name, field = field, missing = None)

/**
 * Constructs an instance of [[zio.elasticsearch.aggregation.FilterAggregation]] using the specified parameters.
 *
 * @param name
 *   the name under which the aggregation's result is returned
 * @param query
 *   the query that documents must satisfy to fall into this aggregation's bucket
 * @return
 *   an instance of [[zio.elasticsearch.aggregation.FilterAggregation]] that represents filter aggregation to be
 *   performed.
 */
final def filterAggregation(name: String, query: ElasticQuery[_]): FilterAggregation =
  Filter(name, query, subAggregations = Chunk.empty)

/**
* Constructs a type-safe instance of [[zio.elasticsearch.aggregation.ExtendedStatsAggregation]] using the specified
* parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import zio.Chunk
import zio.elasticsearch.ElasticAggregation.multipleAggregations
import zio.elasticsearch.ElasticPrimitive.ElasticPrimitiveOps
import zio.elasticsearch.aggregation.options._
import zio.elasticsearch.query.ElasticQuery
import zio.elasticsearch.query.sort.Sort
import zio.elasticsearch.script.Script
import zio.json.ast.Json
Expand Down Expand Up @@ -186,6 +187,31 @@ private[elasticsearch] final case class ExtendedStats(
}
}

sealed trait FilterAggregation extends SingleElasticAggregation with WithAgg with WithSubAgg[FilterAggregation]

private[elasticsearch] final case class Filter(
  name: String,
  query: ElasticQuery[_],
  subAggregations: Chunk[SingleElasticAggregation]
) extends FilterAggregation { self =>

  // Lifts this aggregation and `agg` into a multi-aggregation at the same level.
  def withAgg(agg: SingleElasticAggregation): MultipleAggregations =
    multipleAggregations.aggregations(self, agg)

  // Returns a copy with `aggregation` nested under this filter bucket.
  def withSubAgg(aggregation: SingleElasticAggregation): FilterAggregation =
    self.copy(subAggregations = aggregation +: subAggregations)

  private[elasticsearch] def toJson: Json = {
    // "filter" carries the bucket query; sub-aggregations, when present,
    // are nested alongside it under "aggs".
    val filterJson = Obj("filter" -> query.toJson(fieldPath = None))
    val innerJson =
      if (subAggregations.isEmpty) filterJson
      else filterJson merge Obj("aggs" -> subAggregations.map(_.toJson).reduce(_ merge _))
    Obj(name -> innerJson)
  }
}

sealed trait MaxAggregation extends SingleElasticAggregation with HasMissing[MaxAggregation] with WithAgg

private[elasticsearch] final case class Max(name: String, field: String, missing: Option[Double])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import zio.json.ast.Json
import zio.json.ast.Json.Obj
import zio.json.{DeriveJsonDecoder, JsonDecoder, jsonField}

private[elasticsearch] sealed trait AggregationBucket

sealed trait AggregationResponse

object AggregationResponse {
Expand Down Expand Up @@ -68,6 +70,13 @@ object AggregationResponse {
lowerSampling = stdDeviationBoundsResponse.lowerSampling
)
)
case FilterAggregationResponse(docCount, subAggregations) =>
FilterAggregationResult(
docCount = docCount,
subAggregations = subAggregations.fold(Map[String, AggregationResult]())(_.map { case (key, response) =>
(key, toResult(response))
})
)
case MaxAggregationResponse(value) =>
MaxAggregationResult(value)
case MinAggregationResponse(value) =>
Expand Down Expand Up @@ -142,6 +151,123 @@ private[elasticsearch] object ExtendedStatsAggregationResponse {
DeriveJsonDecoder.gen[ExtendedStatsAggregationResponse]
}

// Raw decoder target for an Elasticsearch "filter" aggregation response:
// "doc_count" is the number of documents matching the filter query, and
// subAggregations holds any nested aggregation responses keyed by name
// (None when the bucket carries no sub-aggregations).
private[elasticsearch] final case class FilterAggregationResponse(
  @jsonField("doc_count")
  docCount: Int,
  subAggregations: Option[Map[String, AggregationResponse]] = None
) extends AggregationResponse

private[elasticsearch] object FilterAggregationResponse extends JsonDecoderOps {

  /**
   * Decodes a raw "filter" aggregation JSON object. Elasticsearch prefixes each
   * sub-aggregation field with its type (e.g. "max#subAggName"), so the decoder
   * dispatches on that prefix to pick the matching response decoder. More
   * specific prefixes ("weighted_avg#", "extended_stats#") are matched before
   * their shorter counterparts ("avg#", "stats#") because the check uses
   * `contains`. An unknown prefix raises a MatchError (the match is
   * deliberately non-exhaustive via @unchecked).
   */
  implicit val decoder: JsonDecoder[FilterAggregationResponse] = Obj.decoder.mapOrFail { case Obj(fields) =>
    val allFields = fields.flatMap { case (field, data) =>
      field match {
        case "doc_count" =>
          Some(field -> data.unsafeAs[Int])
        case _ =>
          val objFields = data.unsafeAs[Obj].fields.toMap

          (field: @unchecked) match {
            case str if str.contains("weighted_avg#") =>
              Some(field -> WeightedAvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
            case str if str.contains("avg#") =>
              Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
            case str if str.contains("cardinality#") =>
              Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int]))
            case str if str.contains("extended_stats#") =>
              Some(
                field -> ExtendedStatsAggregationResponse(
                  count = objFields("count").unsafeAs[Int],
                  min = objFields("min").unsafeAs[Double],
                  max = objFields("max").unsafeAs[Double],
                  avg = objFields("avg").unsafeAs[Double],
                  sum = objFields("sum").unsafeAs[Double],
                  sumOfSquares = objFields("sum_of_squares").unsafeAs[Double],
                  variance = objFields("variance").unsafeAs[Double],
                  variancePopulation = objFields("variance_population").unsafeAs[Double],
                  varianceSampling = objFields("variance_sampling").unsafeAs[Double],
                  stdDeviation = objFields("std_deviation").unsafeAs[Double],
                  stdDeviationPopulation = objFields("std_deviation_population").unsafeAs[Double],
                  stdDeviationSampling = objFields("std_deviation_sampling").unsafeAs[Double],
                  // Fix: the bounds object lives under "std_deviation_bounds";
                  // "std_deviation_sampling" is a plain number and cannot be
                  // decoded as StdDeviationBoundsResponse.
                  stdDeviationBoundsResponse = objFields("std_deviation_bounds").unsafeAs[StdDeviationBoundsResponse](
                    StdDeviationBoundsResponse.decoder
                  )
                )
              )
            case str if str.contains("filter#") =>
              Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder))
            case str if str.contains("max#") =>
              Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
            case str if str.contains("min#") =>
              Some(field -> MinAggregationResponse(value = objFields("value").unsafeAs[Double]))
            case str if str.contains("missing#") =>
              Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int]))
            case str if str.contains("percentiles#") =>
              Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]))
            case str if str.contains("stats#") =>
              Some(
                field -> StatsAggregationResponse(
                  count = objFields("count").unsafeAs[Int],
                  min = objFields("min").unsafeAs[Double],
                  max = objFields("max").unsafeAs[Double],
                  avg = objFields("avg").unsafeAs[Double],
                  sum = objFields("sum").unsafeAs[Double]
                )
              )
            case str if str.contains("sum#") =>
              Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double]))
            case str if str.contains("terms#") =>
              Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder))
            case str if str.contains("value_count#") =>
              Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int]))
          }
      }
    }.toMap

    val docCount = allFields("doc_count").asInstanceOf[Int]
    // Every entry other than "doc_count" is already a decoded AggregationResponse;
    // strip the "type#" prefix so sub-aggregations are keyed by their user-given
    // name. (Unknown prefixes cannot reach this point — they fail above.)
    val subAggs = allFields.collect {
      case (field, data) if field != "doc_count" =>
        (field.split("#")(1), data.asInstanceOf[AggregationResponse])
    }
    Right(FilterAggregationResponse(docCount, Option(subAggs).filter(_.nonEmpty)))
  }
}

// Mixin providing an "unsafe" decode helper: `unsafeAs` assumes decoding
// succeeds and extracts the Right value directly. A failed decode (Left) is
// not matched and therefore surfaces as a MatchError at runtime — the
// `@unchecked` annotation suppresses the non-exhaustiveness warning on purpose.
private[elasticsearch] sealed trait JsonDecoderOps {
  // NOTE(review): this implicit class shadows the enclosing trait's name;
  // consider renaming one of the two for clarity.
  implicit class JsonDecoderOps(json: Json) {
    def unsafeAs[A](implicit decoder: JsonDecoder[A]): A =
      (json.as[A]: @unchecked) match {
        case Right(decoded) => decoded
      }
  }
}
dbulaja98 marked this conversation as resolved.
Show resolved Hide resolved

private[elasticsearch] final case class MaxAggregationResponse(value: Double) extends AggregationResponse

private[elasticsearch] object MaxAggregationResponse {
Expand Down Expand Up @@ -217,16 +343,14 @@ private[elasticsearch] object TermsAggregationResponse {
implicit val decoder: JsonDecoder[TermsAggregationResponse] = DeriveJsonDecoder.gen[TermsAggregationResponse]
}

private[elasticsearch] sealed trait AggregationBucket

// One bucket of a "terms" aggregation response: the bucket's key, the number
// of documents in it ("doc_count"), and optionally any nested sub-aggregation
// responses keyed by name.
private[elasticsearch] final case class TermsAggregationBucket(
  key: String,
  @jsonField("doc_count")
  docCount: Int,
  subAggregations: Option[Map[String, AggregationResponse]] = None
) extends AggregationBucket

private[elasticsearch] object TermsAggregationBucket {
private[elasticsearch] object TermsAggregationBucket extends JsonDecoderOps {
implicit val decoder: JsonDecoder[TermsAggregationBucket] = Obj.decoder.mapOrFail { case Obj(fields) =>
val allFields = fields.flatMap { case (field, data) =>
field match {
Expand Down Expand Up @@ -264,6 +388,8 @@ private[elasticsearch] object TermsAggregationBucket {
)
)
)
case str if str.contains("filter#") =>
Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder))
case str if str.contains("max#") =>
Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("min#") =>
Expand All @@ -285,15 +411,7 @@ private[elasticsearch] object TermsAggregationBucket {
case str if str.contains("sum#") =>
Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("terms#") =>
Some(
field -> TermsAggregationResponse(
docErrorCount = objFields("doc_count_error_upper_bound").unsafeAs[Int],
sumOtherDocCount = objFields("sum_other_doc_count").unsafeAs[Int],
buckets = objFields("buckets")
.unsafeAs[Chunk[Json]]
.map(_.unsafeAs[TermsAggregationBucket](TermsAggregationBucket.decoder))
)
)
Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder))
case str if str.contains("value_count#") =>
Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int]))
}
Expand All @@ -313,6 +431,8 @@ private[elasticsearch] object TermsAggregationBucket {
(field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
case str if str.contains("extended_stats#") =>
(field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse])
case str if str.contains("filter#") =>
(field.split("#")(1), data.asInstanceOf[FilterAggregationResponse])
case str if str.contains("max#") =>
(field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
case str if str.contains("min#") =>
Expand All @@ -331,16 +451,8 @@ private[elasticsearch] object TermsAggregationBucket {
(field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse])
}
}

Right(TermsAggregationBucket.apply(key, docCount, Option(subAggs).filter(_.nonEmpty)))
}

final implicit class JsonDecoderOps(json: Json) {
def unsafeAs[A](implicit decoder: JsonDecoder[A]): A =
(json.as[A]: @unchecked) match {
case Right(decoded) => decoded
}
}
}

private[elasticsearch] final case class ValueCountAggregationResponse(value: Int) extends AggregationResponse
Expand Down
Loading