Skip to content

Commit

Permalink
Refactore AggregationResponse
Browse files Browse the repository at this point in the history
  • Loading branch information
vanjaftn committed Nov 17, 2023
1 parent fa1f343 commit 16e1780
Showing 1 changed file with 114 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import zio.json.ast.Json
import zio.json.ast.Json.Obj
import zio.json.{DeriveJsonDecoder, JsonDecoder, jsonField}

private[elasticsearch] sealed trait AggregationBucket

sealed trait AggregationResponse

object AggregationResponse {
Expand Down Expand Up @@ -107,6 +109,104 @@ private[elasticsearch] object AvgAggregationResponse {
implicit val decoder: JsonDecoder[AvgAggregationResponse] = DeriveJsonDecoder.gen[AvgAggregationResponse]
}

private[elasticsearch] sealed trait BucketDecoder extends JsonDecoderOps {
implicit class BucketDecoder(fields: Chunk[(String, Json)]) {
val allFields: Map[String, Any] = fields.flatMap { case (field, data) =>
field match {
case "key" =>
Some(field -> data.toString.replaceAll("\"", ""))
case "doc_count" =>
Some(field -> data.unsafeAs[Int])
case _ =>
val objFields = data.unsafeAs[Obj].fields.toMap

(field: @unchecked) match {
case str if str.contains("weighted_avg#") =>
Some(field -> WeightedAvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("avg#") =>
Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("cardinality#") =>
Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int]))
case str if str.contains("extended_stats#") =>
Some(
field -> ExtendedStatsAggregationResponse(
count = objFields("count").unsafeAs[Int],
min = objFields("min").unsafeAs[Double],
max = objFields("max").unsafeAs[Double],
avg = objFields("avg").unsafeAs[Double],
sum = objFields("sum").unsafeAs[Double],
sumOfSquares = objFields("sum_of_squares").unsafeAs[Double],
variance = objFields("variance").unsafeAs[Double],
variancePopulation = objFields("variance_population").unsafeAs[Double],
varianceSampling = objFields("variance_sampling").unsafeAs[Double],
stdDeviation = objFields("std_deviation").unsafeAs[Double],
stdDeviationPopulation = objFields("std_deviation_population").unsafeAs[Double],
stdDeviationSampling = objFields("std_deviation_sampling").unsafeAs[Double],
stdDeviationBoundsResponse = objFields("std_deviation_sampling").unsafeAs[StdDeviationBoundsResponse](
StdDeviationBoundsResponse.decoder
)
)
)
case str if str.contains("max#") =>
Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("min#") =>
Some(field -> MinAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("missing#") =>
Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int]))
case str if str.contains("percentiles#") =>
Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]))
case str if str.contains("stats#") =>
Some(
field -> StatsAggregationResponse(
count = objFields("count").unsafeAs[Int],
min = objFields("min").unsafeAs[Double],
max = objFields("max").unsafeAs[Double],
avg = objFields("avg").unsafeAs[Double],
sum = objFields("sum").unsafeAs[Double]
)
)
case str if str.contains("sum#") =>
Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("terms#") =>
Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder))
case str if str.contains("value_count#") =>
Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int]))
}
}
}.toMap

val subAggs: Map[String, AggregationResponse] = allFields.collect {
case (field, data) if field != "doc_count" && field != "key" =>
(field: @unchecked) match {
case str if str.contains("weighted_avg#") =>
(field.split("#")(1), data.asInstanceOf[WeightedAvgAggregationResponse])
case str if str.contains("avg#") =>
(field.split("#")(1), data.asInstanceOf[AvgAggregationResponse])
case str if str.contains("cardinality#") =>
(field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
case str if str.contains("extended_stats#") =>
(field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse])
case str if str.contains("max#") =>
(field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
case str if str.contains("min#") =>
(field.split("#")(1), data.asInstanceOf[MinAggregationResponse])
case str if str.contains("missing#") =>
(field.split("#")(1), data.asInstanceOf[MissingAggregationResponse])
case str if str.contains("percentiles#") =>
(field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse])
case str if str.contains("stats#") =>
(field.split("#")(1), data.asInstanceOf[StatsAggregationResponse])
case str if str.contains("sum#") =>
(field.split("#")(1), data.asInstanceOf[SumAggregationResponse])
case str if str.contains("terms#") =>
(field.split("#")(1), data.asInstanceOf[TermsAggregationResponse])
case str if str.contains("value_count#") =>
(field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse])
}
}
}
}

private[elasticsearch] final case class CardinalityAggregationResponse(value: Int) extends AggregationResponse

private[elasticsearch] object CardinalityAggregationResponse {
Expand Down Expand Up @@ -142,6 +242,15 @@ private[elasticsearch] object ExtendedStatsAggregationResponse {
DeriveJsonDecoder.gen[ExtendedStatsAggregationResponse]
}

private[elasticsearch] sealed trait JsonDecoderOps {
implicit class JsonDecoderOps(json: Json) {
def unsafeAs[A](implicit decoder: JsonDecoder[A]): A =
(json.as[A]: @unchecked) match {
case Right(decoded) => decoded
}
}
}

private[elasticsearch] final case class MaxAggregationResponse(value: Double) extends AggregationResponse

private[elasticsearch] object MaxAggregationResponse {
Expand Down Expand Up @@ -217,130 +326,22 @@ private[elasticsearch] object TermsAggregationResponse {
implicit val decoder: JsonDecoder[TermsAggregationResponse] = DeriveJsonDecoder.gen[TermsAggregationResponse]
}

private[elasticsearch] sealed trait AggregationBucket

private[elasticsearch] final case class TermsAggregationBucket(
key: String,
@jsonField("doc_count")
docCount: Int,
subAggregations: Option[Map[String, AggregationResponse]] = None
) extends AggregationBucket

private[elasticsearch] object TermsAggregationBucket {
private[elasticsearch] object TermsAggregationBucket extends BucketDecoder {
implicit val decoder: JsonDecoder[TermsAggregationBucket] = Obj.decoder.mapOrFail { case Obj(fields) =>
val allFields = fields.flatMap { case (field, data) =>
field match {
case "key" =>
Some(field -> data.toString.replaceAll("\"", ""))
case "doc_count" =>
Some(field -> data.unsafeAs[Int])
case _ =>
val objFields = data.unsafeAs[Obj].fields.toMap

(field: @unchecked) match {
case str if str.contains("weighted_avg#") =>
Some(field -> WeightedAvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("avg#") =>
Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("cardinality#") =>
Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int]))
case str if str.contains("extended_stats#") =>
Some(
field -> ExtendedStatsAggregationResponse(
count = objFields("count").unsafeAs[Int],
min = objFields("min").unsafeAs[Double],
max = objFields("max").unsafeAs[Double],
avg = objFields("avg").unsafeAs[Double],
sum = objFields("sum").unsafeAs[Double],
sumOfSquares = objFields("sum_of_squares").unsafeAs[Double],
variance = objFields("variance").unsafeAs[Double],
variancePopulation = objFields("variance_population").unsafeAs[Double],
varianceSampling = objFields("variance_sampling").unsafeAs[Double],
stdDeviation = objFields("std_deviation").unsafeAs[Double],
stdDeviationPopulation = objFields("std_deviation_population").unsafeAs[Double],
stdDeviationSampling = objFields("std_deviation_sampling").unsafeAs[Double],
stdDeviationBoundsResponse = objFields("std_deviation_sampling").unsafeAs[StdDeviationBoundsResponse](
StdDeviationBoundsResponse.decoder
)
)
)
case str if str.contains("max#") =>
Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("min#") =>
Some(field -> MinAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("missing#") =>
Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int]))
case str if str.contains("percentiles#") =>
Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]))
case str if str.contains("stats#") =>
Some(
field -> StatsAggregationResponse(
count = objFields("count").unsafeAs[Int],
min = objFields("min").unsafeAs[Double],
max = objFields("max").unsafeAs[Double],
avg = objFields("avg").unsafeAs[Double],
sum = objFields("sum").unsafeAs[Double]
)
)
case str if str.contains("sum#") =>
Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("terms#") =>
Some(
field -> TermsAggregationResponse(
docErrorCount = objFields("doc_count_error_upper_bound").unsafeAs[Int],
sumOtherDocCount = objFields("sum_other_doc_count").unsafeAs[Int],
buckets = objFields("buckets")
.unsafeAs[Chunk[Json]]
.map(_.unsafeAs[TermsAggregationBucket](TermsAggregationBucket.decoder))
)
)
case str if str.contains("value_count#") =>
Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int]))
}
}
}.toMap

val key = allFields("key").asInstanceOf[String]
val docCount = allFields("doc_count").asInstanceOf[Int]
val subAggs = allFields.collect {
case (field, data) if field != "key" && field != "doc_count" =>
(field: @unchecked) match {
case str if str.contains("weighted_avg#") =>
(field.split("#")(1), data.asInstanceOf[WeightedAvgAggregationResponse])
case str if str.contains("avg#") =>
(field.split("#")(1), data.asInstanceOf[AvgAggregationResponse])
case str if str.contains("cardinality#") =>
(field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
case str if str.contains("extended_stats#") =>
(field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse])
case str if str.contains("max#") =>
(field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
case str if str.contains("min#") =>
(field.split("#")(1), data.asInstanceOf[MinAggregationResponse])
case str if str.contains("missing#") =>
(field.split("#")(1), data.asInstanceOf[MissingAggregationResponse])
case str if str.contains("percentiles#") =>
(field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse])
case str if str.contains("stats#") =>
(field.split("#")(1), data.asInstanceOf[StatsAggregationResponse])
case str if str.contains("sum#") =>
(field.split("#")(1), data.asInstanceOf[SumAggregationResponse])
case str if str.contains("terms#") =>
(field.split("#")(1), data.asInstanceOf[TermsAggregationResponse])
case str if str.contains("value_count#") =>
(field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse])
}
}
val allFields = BucketDecoder(fields).allFields
val key = allFields("key").asInstanceOf[String]
val docCount = allFields("doc_count").asInstanceOf[Int]
val subAggs = BucketDecoder(fields).subAggs

Right(TermsAggregationBucket.apply(key, docCount, Option(subAggs).filter(_.nonEmpty)))
}

final implicit class JsonDecoderOps(json: Json) {
def unsafeAs[A](implicit decoder: JsonDecoder[A]): A =
(json.as[A]: @unchecked) match {
case Right(decoded) => decoded
}
}
}

private[elasticsearch] final case class ValueCountAggregationResponse(value: Int) extends AggregationResponse
Expand Down

0 comments on commit 16e1780

Please sign in to comment.