diff --git a/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala b/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala index f5c52093b..5275d9c04 100644 --- a/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala +++ b/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala @@ -104,6 +104,34 @@ object HttpExecutorSpec extends IntegrationSpec { Executor.execute(ElasticRequest.createIndex(firstSearchIndex)), Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie ), + test("aggregate using filter aggregation") { + val expectedResponse = ("aggregation", MaxAggregationResult(value = 5.0)) + checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument) { + (firstDocumentId, firstDocument, secondDocumentId, secondDocument) => + for { + _ <- Executor.execute( + ElasticRequest.upsert[TestDocument](firstSearchIndex, firstDocumentId, firstDocument.copy(stringField = "test", intField = 5)) + ) + _ <- Executor.execute( + ElasticRequest + .upsert[TestDocument](firstSearchIndex, secondDocumentId, secondDocument.copy(stringField = "test1", intField = 7)) + ) + aggregation = + filterAggregation(name = "aggregation", field = "test").withSubAgg( + maxAggregation("subAggregation", TestDocument.intField) + ) + + aggsRes <- + Executor + .execute(ElasticRequest.aggregate(selectors = firstSearchIndex, aggregation = aggregation)) + .aggregations + + } yield assert(aggsRes.head)(equalTo(expectedResponse)) + } + } @@ around( + Executor.execute(ElasticRequest.createIndex(firstSearchIndex)), + Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie + ), test("aggregate using max aggregation") { val expectedResponse = ("aggregationInt", MaxAggregationResult(value = 20.0)) checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument) { diff --git a/modules/library/src/main/scala/zio/elasticsearch/ElasticAggregation.scala b/modules/library/src/main/scala/zio/elasticsearch/ElasticAggregation.scala index 012bd18d7..e3c01e8fc 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/ElasticAggregation.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/ElasticAggregation.scala @@ -113,6 +113,36 @@ object ElasticAggregation { final def cardinalityAggregation(name: String, field: String): CardinalityAggregation = Cardinality(name = name, field = field, missing = None) + //Scala doc + /** + * Constructs a type-safe instance of [[zio.elasticsearch.aggregation.MaxAggregation]] using the specified parameters. + * + * @param name + * aggregation name + * @param field + * the type-safe field for which max aggregation will be executed + * @tparam A + * expected number type + * @return + * an instance of [[zio.elasticsearch.aggregation.MaxAggregation]] that represents max aggregation to be performed. + */ + final def filterAggregation(name: String, field: Field[_, String]): FilterAggregation = + Filter(name = name, field = field.toString, subAggregations = Chunk.empty) + + //Scala doc + /** + * Constructs an instance of [[zio.elasticsearch.aggregation.MaxAggregation]] using the specified parameters. + * + * @param name + * aggregation name + * @param field + * the field for which max aggregation will be executed + * @return + * an instance of [[zio.elasticsearch.aggregation.MaxAggregation]] that represents max aggregation to be performed. + */ + final def filterAggregation(name: String, field: String): FilterAggregation = + Filter(name = name, field = field, subAggregations = Chunk.empty) + /** * Constructs a type-safe instance of [[zio.elasticsearch.aggregation.MaxAggregation]] using the specified parameters. * diff --git a/modules/library/src/main/scala/zio/elasticsearch/aggregation/Aggregations.scala b/modules/library/src/main/scala/zio/elasticsearch/aggregation/Aggregations.scala index 68058df13..45453d78c 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/aggregation/Aggregations.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/aggregation/Aggregations.scala @@ -143,6 +143,30 @@ private[elasticsearch] final case class Cardinality(name: String, field: String, } } +sealed trait FilterAggregation + extends SingleElasticAggregation + with WithSubAgg[FilterAggregation] + with WithAgg + +private[elasticsearch] final case class Filter(name: String, field: String, subAggregations: Chunk[SingleElasticAggregation]) + extends FilterAggregation { self => + + def withAgg(agg: SingleElasticAggregation): MultipleAggregations = + multipleAggregations.aggregations(self, agg) + + def withSubAgg(aggregation: SingleElasticAggregation): FilterAggregation = + self.copy(subAggregations = aggregation +: subAggregations) + + val subAggsJson = + if (self.subAggregations.nonEmpty) + Obj("aggs" -> self.subAggregations.map(_.toJson).reduce(_ merge _)) + else + Obj() + private[elasticsearch] def toJson: Json = { + Obj(name -> (Obj("filter" -> Obj("term" -> Obj("type" -> self.field.toJson))) merge subAggsJson)) + } +} + sealed trait MaxAggregation extends SingleElasticAggregation with HasMissing[MaxAggregation] with WithAgg private[elasticsearch] final case class Max(name: String, field: String, missing: Option[Double]) diff --git a/modules/library/src/main/scala/zio/elasticsearch/executor/HttpExecutor.scala b/modules/library/src/main/scala/zio/elasticsearch/executor/HttpExecutor.scala index cd0246e07..370f2f83f 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/executor/HttpExecutor.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/executor/HttpExecutor.scala @@ -122,12 +122,15 @@ private[elasticsearch] final class HttpExecutor private (esConfig: ElasticConfig ).flatMap { response => response.code match { case HttpOk => + println(response.body) response.body.fold( e => ZIO.fail(new ElasticException(s"Exception occurred: ${e.getMessage}")), - value => + value => { ZIO.succeed(new AggregateResult(value.aggs.map { case (key, response) => (key, toResult(response)) })) + } + ) case _ => ZIO.fail(handleFailuresFromCustomResponse(response)) @@ -601,7 +604,7 @@ private[elasticsearch] final class HttpExecutor private (esConfig: ElasticConfig UnauthorizedException case _ => new ElasticException( - s"Unexpected response from Elasticsearch. Response body: ${response.body.fold(body => body, _ => "")}" + s"Unexpected response from Elasticsearch. Response body: ${response.body.fold(body => body, _ => "")}" ) } diff --git a/modules/library/src/main/scala/zio/elasticsearch/executor/response/AggregationResponse.scala b/modules/library/src/main/scala/zio/elasticsearch/executor/response/AggregationResponse.scala index 976183621..1eef7b9fe 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/executor/response/AggregationResponse.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/executor/response/AggregationResponse.scala @@ -31,6 +31,18 @@ object AggregationResponse { AvgAggregationResult(value) case CardinalityAggregationResponse(value) => CardinalityAggregationResult(value) + case FilterAggregationResponse(docCount, buckets) => + FilterAggregationResult( + docCount = docCount, + buckets = buckets.map(b => + FilterAggregationBucketResult( + docCount = b.docCount, + subAggregations = b.subAggregations.fold(Map[String, AggregationResult]())(_.map { case (key, response) => + (key, toResult(response)) + }) + ) + ) + ) case MaxAggregationResponse(value) => MaxAggregationResult(value) case MinAggregationResponse(value) => @@ -105,6 +117,109 @@ private[elasticsearch] object SumAggregationResponse { } +private[elasticsearch] sealed trait AggregationBucket + +private[elasticsearch] final case class FilterAggregationResponse( + @jsonField("doc_count") + docCount: Int, + buckets: Chunk[FilterAggregationBucket] +) extends AggregationResponse + +private[elasticsearch] object FilterAggregationResponse { + implicit val decoder: JsonDecoder[FilterAggregationResponse] = + DeriveJsonDecoder.gen[FilterAggregationResponse] +} + +private[elasticsearch] final case class FilterAggregationBucket( + @jsonField("doc_count") + docCount: Int, + subAggregations: Option[Map[String, AggregationResponse]] = None +) extends AggregationBucket + +private[elasticsearch] object FilterAggregationBucket { + implicit val decoder: JsonDecoder[FilterAggregationBucket] = Obj.decoder.mapOrFail { case Obj(fields) => + val allFields = fields.flatMap { case (field, data) => + field match { + case "doc_count" => + Some(field -> data.unsafeAs[Int]) + case _ => + val objFields = data.unsafeAs[Obj].fields.toMap + + (field: @unchecked) match { + case str if str.contains("avg#") => + Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double])) + case str if str.contains("cardinality#") => + Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int])) + case str if str.contains("filter#") => + Some( + field -> FilterAggregationResponse( + docCount = objFields("doc_count").unsafeAs[Int], + buckets = objFields("buckets") + .unsafeAs[Chunk[Json]] + .map(_.unsafeAs[FilterAggregationBucket](FilterAggregationBucket.decoder)) + ) + ) + case str if str.contains("max#") => + Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double])) + case str if str.contains("min#") => + Some(field -> MinAggregationResponse(value = objFields("value").unsafeAs[Double])) + case str if str.contains("missing#") => + Some(field -> MissingAggregationResponse(docCount = objFields("doc_count").unsafeAs[Int])) + case str if str.contains("percentiles#") => + Some(field -> PercentilesAggregationResponse(values = objFields("values").unsafeAs[Map[String, Double]])) + case str if str.contains("sum#") => + Some(field -> SumAggregationResponse(value = objFields("value").unsafeAs[Double])) + case str if str.contains("terms#") => + Some( + field -> TermsAggregationResponse( + docErrorCount = objFields("doc_count_error_upper_bound").unsafeAs[Int], + sumOtherDocCount = objFields("sum_other_doc_count").unsafeAs[Int], + buckets = objFields("buckets") + .unsafeAs[Chunk[Json]] + .map(_.unsafeAs[TermsAggregationBucket](TermsAggregationBucket.decoder)) + ) + ) + } + } + }.toMap + + val docCount = allFields("doc_count").asInstanceOf[Int] + val subAggs = allFields.collect { + case (field, data) if field != "doc_count" => + (field: @unchecked) match { + case str if str.contains("avg#") => + (field.split("#")(1), data.asInstanceOf[AvgAggregationResponse]) + case str if str.contains("cardinality#") => + (field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse]) + case str if str.contains("filter#") => + (field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse]) + case str if str.contains("max#") => + (field.split("#")(1), data.asInstanceOf[MaxAggregationResponse]) + case str if str.contains("min#") => + (field.split("#")(1), data.asInstanceOf[MinAggregationResponse]) + case str if str.contains("missing#") => + (field.split("#")(1), data.asInstanceOf[MissingAggregationResponse]) + case str if str.contains("percentiles#") => + (field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse]) + case str if str.contains("sum#") => + (field.split("#")(1), data.asInstanceOf[SumAggregationResponse]) + case str if str.contains("terms#") => + (field.split("#")(1), data.asInstanceOf[TermsAggregationResponse]) + } + } + + Right(FilterAggregationBucket.apply(docCount, Option(subAggs).filter(_.nonEmpty))) + } + + final implicit class JsonDecoderOps(json: Json) { + def unsafeAs[A](implicit decoder: JsonDecoder[A]): A = + (json.as[A]: @unchecked) match { + case Right(decoded) => decoded + } + } + +} + private[elasticsearch] final case class TermsAggregationResponse( @jsonField("doc_count_error_upper_bound") docErrorCount: Int, @@ -117,8 +232,6 @@ private[elasticsearch] object TermsAggregationResponse { implicit val decoder: JsonDecoder[TermsAggregationResponse] = DeriveJsonDecoder.gen[TermsAggregationResponse] } -private[elasticsearch] sealed trait AggregationBucket - private[elasticsearch] final case class TermsAggregationBucket( key: String, @jsonField("doc_count") @@ -142,6 +255,15 @@ private[elasticsearch] object TermsAggregationBucket { Some(field -> AvgAggregationResponse(value = objFields("value").unsafeAs[Double])) case str if str.contains("cardinality#") => Some(field -> CardinalityAggregationResponse(value = objFields("value").unsafeAs[Int])) + case str if str.contains("filter#") => + Some( + field -> FilterAggregationResponse( + docCount = objFields("doc_count_error_upper_bound").unsafeAs[Int], + buckets = objFields("buckets") + .unsafeAs[Chunk[Json]] + .map(_.unsafeAs[FilterAggregationBucket](FilterAggregationBucket.decoder)) + ) + ) case str if str.contains("max#") => Some(field -> MaxAggregationResponse(value = objFields("value").unsafeAs[Double])) case str if str.contains("min#") => @@ -175,6 +297,8 @@ private[elasticsearch] object TermsAggregationBucket { (field.split("#")(1), data.asInstanceOf[AvgAggregationResponse]) case str if str.contains("cardinality#") => (field.split("#")(1), data.asInstanceOf[CardinalityAggregationResponse]) + case str if str.contains("filter#") => + (field.split("#")(1), data.asInstanceOf[PercentilesAggregationResponse]) case str if str.contains("max#") => (field.split("#")(1), data.asInstanceOf[MaxAggregationResponse]) case str if str.contains("min#") => diff --git a/modules/library/src/main/scala/zio/elasticsearch/executor/response/SearchWithAggregationsResponse.scala b/modules/library/src/main/scala/zio/elasticsearch/executor/response/SearchWithAggregationsResponse.scala index 105f4b76b..1e663f76f 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/executor/response/SearchWithAggregationsResponse.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/executor/response/SearchWithAggregationsResponse.scala @@ -74,6 +74,8 @@ private[elasticsearch] final case class SearchWithAggregationsResponse( (field: @unchecked) match { case str if str.contains("avg#") => AvgAggregationResponse.decoder.decodeJson(data.toString).map(field.split("#")(1) -> _) + case str if str.contains("filter#") => + FilterAggregationResponse.decoder.decodeJson(data.toString).map(field.split("#")(1) -> _) case str if str.contains("max#") => MaxAggregationResponse.decoder.decodeJson(data.toString).map(field.split("#")(1) -> _) case str if str.contains("min#") => diff --git a/modules/library/src/main/scala/zio/elasticsearch/result/AggregationResult.scala b/modules/library/src/main/scala/zio/elasticsearch/result/AggregationResult.scala index 98cb0c7c4..36d074051 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/result/AggregationResult.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/result/AggregationResult.scala @@ -60,3 +60,26 @@ final case class TermsAggregationBucketResult private[elasticsearch] ( Right(None) } } + +final case class FilterAggregationResult private[elasticsearch] ( + docCount: Int, + buckets: Chunk[FilterAggregationBucketResult] +) extends AggregationResult + +final case class FilterAggregationBucketResult private[elasticsearch] ( + docCount: Int, + subAggregations: Map[String, AggregationResult] +) extends AggregationResult { + + def subAggregationAs[A <: AggregationResult](aggName: String): Either[DecodingException, Option[A]] = + subAggregations.get(aggName) match { + case Some(aggRes) => + Try(aggRes.asInstanceOf[A]) match { + case Failure(_) => Left(DecodingException(s"Aggregation with name $aggName was not of type you provided.")) + case Success(agg) => Right(Some(agg)) + } + case None => + Right(None) + } + +} diff --git a/modules/library/src/test/scala/zio/elasticsearch/ElasticAggregationSpec.scala b/modules/library/src/test/scala/zio/elasticsearch/ElasticAggregationSpec.scala index 86442107d..1f632f6f0 100644 --- a/modules/library/src/test/scala/zio/elasticsearch/ElasticAggregationSpec.scala +++ b/modules/library/src/test/scala/zio/elasticsearch/ElasticAggregationSpec.scala @@ -128,6 +128,39 @@ object ElasticAggregationSpec extends ZIOSpecDefault { equalTo(Cardinality(name = "aggregation", field = "intField", missing = Some(20))) ) }, + test("filter") { + val aggregation = filterAggregation("aggregation", "testField") + val aggregationTs = filterAggregation("aggregation", TestSubDocument.stringField) + val aggregationTsRaw = filterAggregation("aggregation", TestSubDocument.stringField.raw) + + assert(aggregation)( + equalTo( + Filter( + name = "aggregation", + field = "testField", + subAggregations = Chunk.empty + ) + ) + ) && + assert(aggregationTs)( + equalTo( + Filter( + name = "aggregation", + field = "stringField", + subAggregations = Chunk.empty, + ) + ) + ) && + assert(aggregationTsRaw)( + equalTo( + Filter( + name = "aggregation", + field = "stringField.raw", + subAggregations = Chunk.empty, + ) + ) + ) + }, test("max") { val aggregation = maxAggregation("aggregation", "testField") val aggregationTs = maxAggregation("aggregation", TestSubDocument.intField) @@ -625,6 +658,63 @@ object ElasticAggregationSpec extends ZIOSpecDefault { assert(aggregationTs.toJson)(equalTo(expectedTs.toJson)) && assert(aggregationWithMissing.toJson)(equalTo(expectedWithMissing.toJson)) }, + test("filter") { + val aggregation = filterAggregation("aggregation", "testField") + val aggregationTs = filterAggregation("aggregation", TestDocument.stringField) + val aggregationWithSubAggregation = filterAggregation("aggregation", TestDocument.stringField).withSubAgg( + minAggregation("subAggregation", TestDocument.intField) + ) + + val expected = + """ + |{ + | "aggregation": { + | "filter": { + | "term": { + | "type" : "testField" + | } + | } + | } + |} + |""".stripMargin + + val expectedTs = + """ + |{ + | "aggregation": { + | "filter": { + | "term": { + | "type" : "stringField" + | } + | } + | } + |} + |""".stripMargin + + val expectedWithSubAggregation = + """ + |{ + | "aggregation": { + | "filter": { + | "term": { + | "type" : "stringField" + | } + | }, + | "aggs": { + | "subAggregation": { + | "min": { + | "field": "intField" + | } + | } + | } + | } + |} + |""".stripMargin + + assert(aggregation.toJson)(equalTo(expected.toJson)) && + assert(aggregationTs.toJson)(equalTo(expectedTs.toJson)) && + assert(aggregationWithSubAggregation.toJson)(equalTo(expectedWithSubAggregation.toJson)) + }, test("max") { val aggregation = maxAggregation("aggregation", "testField") val aggregationTs = maxAggregation("aggregation", TestDocument.intField)