Skip to content

Commit

Permalink
Refactore AggregationResponse
Browse files Browse the repository at this point in the history
  • Loading branch information
vanjaftn committed Nov 20, 2023
1 parent 0f3f151 commit 91f0ac2
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,9 @@ object HttpExecutorSpec extends IntegrationSpec {
.refreshTrue
)
aggregation =
termsAggregation(name = "aggregationString", field = TestDocument.stringField.keyword)
termsAggregation(name = "aggregationString", field = TestDocument.stringField.keyword).withSubAgg(
maxAggregation("subAggregation", TestDocument.intField)
)
aggsRes <-
Executor
.execute(ElasticRequest.aggregate(selectors = firstSearchIndex, aggregation = aggregation))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,51 +116,12 @@ private[elasticsearch] object AvgAggregationResponse {
implicit val decoder: JsonDecoder[AvgAggregationResponse] = DeriveJsonDecoder.gen[AvgAggregationResponse]
}

private[elasticsearch] final case class CardinalityAggregationResponse(value: Int) extends AggregationResponse

private[elasticsearch] object CardinalityAggregationResponse {
implicit val decoder: JsonDecoder[CardinalityAggregationResponse] =
DeriveJsonDecoder.gen[CardinalityAggregationResponse]
}

private[elasticsearch] final case class ExtendedStatsAggregationResponse(
count: Int,
min: Double,
max: Double,
avg: Double,
sum: Double,
@jsonField("sum_of_squares")
sumOfSquares: Double,
variance: Double,
@jsonField("variance_population")
variancePopulation: Double,
@jsonField("variance_sampling")
varianceSampling: Double,
@jsonField("std_deviation")
stdDeviation: Double,
@jsonField("std_deviation_population")
stdDeviationPopulation: Double,
@jsonField("std_deviation_sampling")
stdDeviationSampling: Double,
@jsonField("std_deviation_bounds")
stdDeviationBoundsResponse: StdDeviationBoundsResponse
) extends AggregationResponse

private[elasticsearch] object ExtendedStatsAggregationResponse {
implicit val decoder: JsonDecoder[ExtendedStatsAggregationResponse] =
DeriveJsonDecoder.gen[ExtendedStatsAggregationResponse]
}

private[elasticsearch] final case class FilterAggregationResponse(
@jsonField("doc_count")
docCount: Int,
subAggregations: Option[Map[String, AggregationResponse]] = None
) extends AggregationResponse

private[elasticsearch] object FilterAggregationResponse extends JsonDecoderOps {
implicit val decoder: JsonDecoder[FilterAggregationResponse] = Obj.decoder.mapOrFail { case Obj(fields) =>
val allFields = fields.flatMap { case (field, data) =>
private[elasticsearch] sealed trait BucketDecoder extends JsonDecoderOps {
implicit class BucketDecoder(fields: Chunk[(String, Json)]) {
val allFields: Map[String, Any] = fields.flatMap { case (field, data) =>
field match {
case "key" =>
Some(field -> data.toString.replaceAll("\"", ""))
case "doc_count" =>
Some(field -> data.unsafeAs[Int])
case _ =>
Expand Down Expand Up @@ -193,8 +154,6 @@ private[elasticsearch] object FilterAggregationResponse extends JsonDecoderOps {
)
)
)
case str if str.contains("filter#") =>
Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder))
case str if str.contains("max#") =>
Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("min#") =>
Expand Down Expand Up @@ -223,9 +182,8 @@ private[elasticsearch] object FilterAggregationResponse extends JsonDecoderOps {
}
}.toMap

val docCount = allFields("doc_count").asInstanceOf[Int]
val subAggs = allFields.collect {
case (field, data) if field != "doc_count" =>
val subAggs: Map[String, AggregationResponse] = allFields.collect {
case (field, data) if field != "doc_count" && field != "key" =>
(field: @unchecked) match {
case str if str.contains("weighted_avg#") =>
(field.split("#")(1), data.asInstanceOf[WeightedAvgAggregationResponse])
Expand All @@ -235,8 +193,6 @@ private[elasticsearch] object FilterAggregationResponse extends JsonDecoderOps {
(field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
case str if str.contains("extended_stats#") =>
(field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse])
case str if str.contains("filter#") =>
(field.split("#")(1), data.asInstanceOf[FilterAggregationResponse])
case str if str.contains("max#") =>
(field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
case str if str.contains("min#") =>
Expand All @@ -255,6 +211,56 @@ private[elasticsearch] object FilterAggregationResponse extends JsonDecoderOps {
(field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse])
}
}
}
}

private[elasticsearch] final case class CardinalityAggregationResponse(value: Int) extends AggregationResponse

private[elasticsearch] object CardinalityAggregationResponse {
implicit val decoder: JsonDecoder[CardinalityAggregationResponse] =
DeriveJsonDecoder.gen[CardinalityAggregationResponse]
}

private[elasticsearch] final case class ExtendedStatsAggregationResponse(
count: Int,
min: Double,
max: Double,
avg: Double,
sum: Double,
@jsonField("sum_of_squares")
sumOfSquares: Double,
variance: Double,
@jsonField("variance_population")
variancePopulation: Double,
@jsonField("variance_sampling")
varianceSampling: Double,
@jsonField("std_deviation")
stdDeviation: Double,
@jsonField("std_deviation_population")
stdDeviationPopulation: Double,
@jsonField("std_deviation_sampling")
stdDeviationSampling: Double,
@jsonField("std_deviation_bounds")
stdDeviationBoundsResponse: StdDeviationBoundsResponse
) extends AggregationResponse

private[elasticsearch] object ExtendedStatsAggregationResponse {
implicit val decoder: JsonDecoder[ExtendedStatsAggregationResponse] =
DeriveJsonDecoder.gen[ExtendedStatsAggregationResponse]
}

private[elasticsearch] final case class FilterAggregationResponse(
@jsonField("doc_count")
docCount: Int,
subAggregations: Option[Map[String, AggregationResponse]] = None
) extends AggregationResponse

private[elasticsearch] object FilterAggregationResponse extends BucketDecoder {
implicit val decoder: JsonDecoder[FilterAggregationResponse] = Obj.decoder.mapOrFail { case Obj(fields) =>
val allFields = BucketDecoder(fields).allFields
val docCount = allFields("doc_count").asInstanceOf[Int]
val subAggs = BucketDecoder(fields).subAggs

Right(FilterAggregationResponse.apply(docCount, Option(subAggs).filter(_.nonEmpty)))
}
}
Expand Down Expand Up @@ -350,107 +356,13 @@ private[elasticsearch] final case class TermsAggregationBucket(
subAggregations: Option[Map[String, AggregationResponse]] = None
) extends AggregationBucket

private[elasticsearch] object TermsAggregationBucket extends JsonDecoderOps {
private[elasticsearch] object TermsAggregationBucket extends BucketDecoder {
implicit val decoder: JsonDecoder[TermsAggregationBucket] = Obj.decoder.mapOrFail { case Obj(fields) =>
val allFields = fields.flatMap { case (field, data) =>
field match {
case "key" =>
Some(field -> data.toString.replaceAll("\"", ""))
case "doc_count" =>
Some(field -> data.unsafeAs[Int])
case _ =>
val objFields = data.unsafeAs[Obj].fields.toMap
val allFields = BucketDecoder(fields).allFields
val key = allFields("key").asInstanceOf[String]
val docCount = allFields("doc_count").asInstanceOf[Int]
val subAggs = BucketDecoder(fields).subAggs

(field: @unchecked) match {
case str if str.contains("weighted_avg#") =>
Some(field -> WeightedAvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("avg#") =>
Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("cardinality#") =>
Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int]))
case str if str.contains("extended_stats#") =>
Some(
field -> ExtendedStatsAggregationResponse(
count = objFields("count").unsafeAs[Int],
min = objFields("min").unsafeAs[Double],
max = objFields("max").unsafeAs[Double],
avg = objFields("avg").unsafeAs[Double],
sum = objFields("sum").unsafeAs[Double],
sumOfSquares = objFields("sum_of_squares").unsafeAs[Double],
variance = objFields("variance").unsafeAs[Double],
variancePopulation = objFields("variance_population").unsafeAs[Double],
varianceSampling = objFields("variance_sampling").unsafeAs[Double],
stdDeviation = objFields("std_deviation").unsafeAs[Double],
stdDeviationPopulation = objFields("std_deviation_population").unsafeAs[Double],
stdDeviationSampling = objFields("std_deviation_sampling").unsafeAs[Double],
stdDeviationBoundsResponse = objFields("std_deviation_sampling").unsafeAs[StdDeviationBoundsResponse](
StdDeviationBoundsResponse.decoder
)
)
)
case str if str.contains("filter#") =>
Some(field -> data.unsafeAs[FilterAggregationResponse](FilterAggregationResponse.decoder))
case str if str.contains("max#") =>
Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("min#") =>
Some(field -> MinAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("missing#") =>
Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int]))
case str if str.contains("percentiles#") =>
Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]]))
case str if str.contains("stats#") =>
Some(
field -> StatsAggregationResponse(
count = objFields("count").unsafeAs[Int],
min = objFields("min").unsafeAs[Double],
max = objFields("max").unsafeAs[Double],
avg = objFields("avg").unsafeAs[Double],
sum = objFields("sum").unsafeAs[Double]
)
)
case str if str.contains("sum#") =>
Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double]))
case str if str.contains("terms#") =>
Some(field -> data.unsafeAs[TermsAggregationResponse](TermsAggregationResponse.decoder))
case str if str.contains("value_count#") =>
Some(field -> ValueCountAggregationResponse(value = objFields("value").unsafeAs[Int]))
}
}
}.toMap

val key = allFields("key").asInstanceOf[String]
val docCount = allFields("doc_count").asInstanceOf[Int]
val subAggs = allFields.collect {
case (field, data) if field != "key" && field != "doc_count" =>
(field: @unchecked) match {
case str if str.contains("weighted_avg#") =>
(field.split("#")(1), data.asInstanceOf[WeightedAvgAggregationResponse])
case str if str.contains("avg#") =>
(field.split("#")(1), data.asInstanceOf[AvgAggregationResponse])
case str if str.contains("cardinality#") =>
(field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse])
case str if str.contains("extended_stats#") =>
(field.split("#")(1), data.asInstanceOf[ExtendedStatsAggregationResponse])
case str if str.contains("filter#") =>
(field.split("#")(1), data.asInstanceOf[FilterAggregationResponse])
case str if str.contains("max#") =>
(field.split("#")(1), data.asInstanceOf[MaxAggregationResponse])
case str if str.contains("min#") =>
(field.split("#")(1), data.asInstanceOf[MinAggregationResponse])
case str if str.contains("missing#") =>
(field.split("#")(1), data.asInstanceOf[MissingAggregationResponse])
case str if str.contains("percentiles#") =>
(field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse])
case str if str.contains("stats#") =>
(field.split("#")(1), data.asInstanceOf[StatsAggregationResponse])
case str if str.contains("sum#") =>
(field.split("#")(1), data.asInstanceOf[SumAggregationResponse])
case str if str.contains("terms#") =>
(field.split("#")(1), data.asInstanceOf[TermsAggregationResponse])
case str if str.contains("value_count#") =>
(field.split("#")(1), data.asInstanceOf[ValueCountAggregationResponse])
}
}
Right(TermsAggregationBucket.apply(key, docCount, Option(subAggs).filter(_.nonEmpty)))
}
}
Expand Down

0 comments on commit 91f0ac2

Please sign in to comment.