Skip to content

Commit

Permalink
Implement filter aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
vanjaftn committed Nov 1, 2023
1 parent 3c7beb5 commit 7ebeec7
Show file tree
Hide file tree
Showing 8 changed files with 328 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,34 @@ object HttpExecutorSpec extends IntegrationSpec {
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
),
test("aggregate using filter aggregation") {
val expectedResponse = ("aggregation", MaxAggregationResult(value = 5.0))
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
(firstDocumentId, firstDocument, secondDocumentId, secondDocument) =>
for {
_ <- Executor.execute(
ElasticRequest.upsert[TestDocument](firstSearchIndex, firstDocumentId, firstDocument.copy(stringField = "test", intField = 5))
)
_ <- Executor.execute(
ElasticRequest
.upsert[TestDocument](firstSearchIndex, secondDocumentId, secondDocument.copy(stringField = "test1", intField = 7))
)
aggregation =
filterAggregation(name = "aggregation", field = "test").withSubAgg(
maxAggregation("subAggregation", TestDocument.intField)
)

aggsRes <-
Executor
.execute(ElasticRequest.aggregate(selectors = firstSearchIndex, aggregation = aggregation))
.aggregations

} yield assert(aggsRes.head)(equalTo(expectedResponse))
}
} @@ around(
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
),
test("aggregate using max aggregation") {
val expectedResponse = ("aggregationInt", MaxAggregationResult(value = 20.0))
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,36 @@ object ElasticAggregation {
final def cardinalityAggregation(name: String, field: String): CardinalityAggregation =
Cardinality(name = name, field = field, missing = None)

//Scala doc
/**
* Constructs a type-safe instance of [[zio.elasticsearch.aggregation.MaxAggregation]] using the specified parameters.
*
* @param name
* aggregation name
* @param field
* the type-safe field for which max aggregation will be executed
* @tparam A
* expected number type
* @return
* an instance of [[zio.elasticsearch.aggregation.MaxAggregation]] that represents max aggregation to be performed.
*/
final def filterAggregation(name: String, field: Field[_, String]): FilterAggregation =
Filter(name = name, field = field.toString, subAggregations = Chunk.empty)

//Scala doc
/**
* Constructs an instance of [[zio.elasticsearch.aggregation.MaxAggregation]] using the specified parameters.
*
* @param name
* aggregation name
* @param field
* the field for which max aggregation will be executed
* @return
* an instance of [[zio.elasticsearch.aggregation.MaxAggregation]] that represents max aggregation to be performed.
*/
final def filterAggregation(name: String, field: String): FilterAggregation =
Filter(name = name, field = field, subAggregations = Chunk.empty)

/**
* Constructs a type-safe instance of [[zio.elasticsearch.aggregation.MaxAggregation]] using the specified parameters.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,30 @@ private[elasticsearch] final case class Cardinality(name: String, field: String,
}
}

sealed trait FilterAggregation
extends SingleElasticAggregation
with WithSubAgg[FilterAggregation]
with WithAgg

private[elasticsearch] final case class Filter(name: String, field: String, subAggregations: Chunk[SingleElasticAggregation])
extends FilterAggregation { self =>

def withAgg(agg: SingleElasticAggregation): MultipleAggregations =
multipleAggregations.aggregations(self, agg)

def withSubAgg(aggregation: SingleElasticAggregation): FilterAggregation =
self.copy(subAggregations = aggregation +: subAggregations)

val subAggsJson =
if (self.subAggregations.nonEmpty)
Obj("aggs" -> self.subAggregations.map(_.toJson).reduce(_ merge _))
else
Obj()
private[elasticsearch] def toJson: Json = {
Obj(name -> (Obj("filter" -> Obj("term" -> Obj("type" -> self.field.toJson))) merge subAggsJson))
}
}

sealed trait MaxAggregation extends SingleElasticAggregation with HasMissing[MaxAggregation] with WithAgg

private[elasticsearch] final case class Max(name: String, field: String, missing: Option[Double])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,15 @@ private[elasticsearch] final class HttpExecutor private (esConfig: ElasticConfig
).flatMap { response =>
response.code match {
case HttpOk =>
println(response.body)
response.body.fold(
e => ZIO.fail(new ElasticException(s"Exception occurred: ${e.getMessage}")),
value =>
value => {
ZIO.succeed(new AggregateResult(value.aggs.map { case (key, response) =>
(key, toResult(response))
}))
}

)
case _ =>
ZIO.fail(handleFailuresFromCustomResponse(response))
Expand Down Expand Up @@ -601,7 +604,7 @@ private[elasticsearch] final class HttpExecutor private (esConfig: ElasticConfig
UnauthorizedException
case _ =>
new ElasticException(
s"Unexpected response from Elasticsearch. Response body: ${response.body.fold(body => body, _ => "")}"
s"Unexpected response from Elasticsearch. Response body: ${response.body.fold(body => body, _ => "")}"
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ object AggregationResponse {
AvgAggregationResult(value)
case CardinalityAggregationResponse(value) =>
CardinalityAggregationResult(value)
case FilterAggregationResponse(docCount, buckets) =>
FilterAggregationResult(
docCount = docCount,
buckets = buckets.map(b =>
FilterAggregationBucketResult(
docCount = b.docCount,
subAggregations = b.subAggregations.fold(Map[String, AggregationResult]())(_.map { case (key, response) =>
(key, toResult(response))
})
)
)
)
case MaxAggregationResponse(value) =>
MaxAggregationResult(value)
case MinAggregationResponse(value) =>
Expand Down Expand Up @@ -105,6 +117,109 @@ private[elasticsearch] object SumAggregationResponse {

}

private[elasticsearch] sealed trait AggregationBucket

private[elasticsearch] final case class FilterAggregationResponse(
@jsonField("doc_count")
docCount: Int,
buckets: Chunk[FilterAggregationBucket]
) extends AggregationResponse

private[elasticsearch] object FilterAggregationResponse {
implicit val decoder: JsonDecoder[FilterAggregationResponse] =
DeriveJsonDecoder.gen[FilterAggregationResponse]
}

private[elasticsearch] final case class FilterAggregationBucket(
@jsonField("doc_count")
docCount: Int,
subAggregations: Option[Map[String, AggregationResponse]] = None
) extends AggregationBucket

private[elasticsearch] object FilterAggregationBucket {
implicit val decoder: JsonDecoder[FilterAggregationBucket] = Obj.decoder.mapOrFail { case Obj(fields) =>
val allFields = fields.flatMap { case (field, data) =>
field match {
case "doc_count" =>
Some(field -> data.unsafeAs[Int])
case _ =>
val objFields = data.unsafeAs[Obj].fields.toMap

(field: @unchecked) match {
case str if str.contains("avg#") =>
Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("cardinality#") =>
Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int]))
case str if str.contains("filter#") =>
Some(
field -> FilterAggregationResponse(
docCount = objFields("doc_count").unsafeAs[Int],
buckets = objFields("buckets")
.unsafeAs[Chunk[Json]]
.map(_.unsafeAs[FilterAggregationBucket](FilterAggregationBucket.decoder))
)
)
case str if str.contains("max#") =>
Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("min#") =>
Some(field -> MinAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("missing#") =>
Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int]))
case str if str.contains("percentiles#") =>
Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]))
case str if str.contains("sum#") =>
Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("terms#") =>
Some(
field -> TermsAggregationResponse(
docErrorCount = objFields("doc_count_error_upper_bound").unsafeAs[Int],
sumOtherDocCount = objFields("sum_other_doc_count").unsafeAs[Int],
buckets = objFields("buckets")
.unsafeAs[Chunk[Json]]
.map(_.unsafeAs[TermsAggregationBucket](TermsAggregationBucket.decoder))
)
)
}
}
}.toMap

val docCount = allFields("doc_count").asInstanceOf[Int]
val subAggs = allFields.collect {
case (field, data) if field != "doc_count" =>
(field: @unchecked) match {
case str if str.contains("avg#") =>
(field.split("#")(1), data.asInstanceOf[AvgAggregationResponse])
case str if str.contains("cardinality#") =>
(field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
case str if str.contains("filter#") =>
(field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse])
case str if str.contains("max#") =>
(field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
case str if str.contains("min#") =>
(field.split("#")(1), data.asInstanceOf[MinAggregationResponse])
case str if str.contains("missing#") =>
(field.split("#")(1), data.asInstanceOf[MissingAggregationResponse])
case str if str.contains("percentiles#") =>
(field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse])
case str if str.contains("sum#") =>
(field.split("#")(1), data.asInstanceOf[SumAggregationResponse])
case str if str.contains("terms#") =>
(field.split("#")(1), data.asInstanceOf[TermsAggregationResponse])
}
}

Right(FilterAggregationBucket.apply(docCount, Option(subAggs).filter(_.nonEmpty)))
}

final implicit class JsonDecoderOps(json: Json) {
def unsafeAs[A](implicit decoder: JsonDecoder[A]): A =
(json.as[A]: @unchecked) match {
case Right(decoded) => decoded
}
}

}

private[elasticsearch] final case class TermsAggregationResponse(
@jsonField("doc_count_error_upper_bound")
docErrorCount: Int,
Expand All @@ -117,8 +232,6 @@ private[elasticsearch] object TermsAggregationResponse {
implicit val decoder: JsonDecoder[TermsAggregationResponse] = DeriveJsonDecoder.gen[TermsAggregationResponse]
}

private[elasticsearch] sealed trait AggregationBucket

private[elasticsearch] final case class TermsAggregationBucket(
key: String,
@jsonField("doc_count")
Expand All @@ -142,6 +255,15 @@ private[elasticsearch] object TermsAggregationBucket {
Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("cardinality#") =>
Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int]))
case str if str.contains("filter#") =>
Some(
field -> FilterAggregationResponse(
docCount = objFields("doc_count_error_upper_bound").unsafeAs[Int],
buckets = objFields("buckets")
.unsafeAs[Chunk[Json]]
.map(_.unsafeAs[FilterAggregationBucket](FilterAggregationBucket.decoder))
)
)
case str if str.contains("max#") =>
Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("min#") =>
Expand Down Expand Up @@ -175,6 +297,8 @@ private[elasticsearch] object TermsAggregationBucket {
(field.split("#")(1), data.asInstanceOf[AvgAggregationResponse])
case str if str.contains("cardinality#") =>
(field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
case str if str.contains("filter#") =>
(field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse])
case str if str.contains("max#") =>
(field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
case str if str.contains("min#") =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ private[elasticsearch] final case class SearchWithAggregationsResponse(
(field: @unchecked) match {
case str if str.contains("avg#") =>
AvgAggregationResponse.decoder.decodeJson(data.toString).map(field.split("#")(1) -> _)
case str if str.contains("filter#") =>
FilterAggregationResponse.decoder.decodeJson(data.toString).map(field.split("#")(1) -> _)
case str if str.contains("max#") =>
MaxAggregationResponse.decoder.decodeJson(data.toString).map(field.split("#")(1) -> _)
case str if str.contains("min#") =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,26 @@ final case class TermsAggregationBucketResult private[elasticsearch] (
Right(None)
}
}

final case class FilterAggregationResult private[elasticsearch] (
docCount: Int,
buckets: Chunk[FilterAggregationBucketResult]
) extends AggregationResult

final case class FilterAggregationBucketResult private[elasticsearch] (
docCount: Int,
subAggregations: Map[String, AggregationResult]
) extends AggregationResult {

def subAggregationAs[A <: AggregationResult](aggName: String): Either[DecodingException, Option[A]] =
subAggregations.get(aggName) match {
case Some(aggRes) =>
Try(aggRes.asInstanceOf[A]) match {
case Failure(_) => Left(DecodingException(s"Aggregation with name $aggName was not of type you provided."))
case Success(agg) => Right(Some(agg))
}
case None =>
Right(None)
}

}
Loading

0 comments on commit 7ebeec7

Please sign in to comment.