From 2b11947f5878c7b6ed11f6281374abeaf2cd5066 Mon Sep 17 00:00:00 2001 From: Kevin Chuang Date: Wed, 8 Feb 2023 02:57:34 -0500 Subject: [PATCH] Support boost in Range and Bool query --- .../main/scala/zio/elasticsearch/Boost.scala | 24 +++- .../zio/elasticsearch/ElasticQuery.scala | 51 +++++--- .../zio/elasticsearch/QueryDSLSpec.scala | 117 ++++++++++++++++++ 3 files changed, 170 insertions(+), 22 deletions(-) diff --git a/modules/library/src/main/scala/zio/elasticsearch/Boost.scala b/modules/library/src/main/scala/zio/elasticsearch/Boost.scala index 0f882daba..52a7eb853 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/Boost.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/Boost.scala @@ -16,8 +16,17 @@ package zio.elasticsearch -import zio.elasticsearch.ElasticQuery.{ElasticPrimitive, MatchAllQuery, TermQuery, WildcardQuery} -import zio.elasticsearch.ElasticQueryType.{MatchAll, Term, Wildcard} +import zio.elasticsearch.ElasticQuery.{ + BoolQuery, + ElasticPrimitive, + LowerBound, + MatchAllQuery, + RangeQuery, + TermQuery, + UpperBound, + WildcardQuery +} +import zio.elasticsearch.ElasticQueryType.{Bool, MatchAll, Range, Term, Wildcard} object Boost { @@ -26,11 +35,22 @@ object Boost { } object WithBoost { + implicit val boolWithBoost: WithBoost[Bool] = (query: ElasticQuery[Bool], value: Double) => + query match { + case q: BoolQuery => q.copy(boost = Some(value)) + } + implicit val matchAllWithBoost: WithBoost[MatchAll] = (query: ElasticQuery[MatchAll], value: Double) => query match { case q: MatchAllQuery => q.copy(boost = Some(value)) } + implicit def rangeWithBoost[A, LB <: LowerBound, UB <: UpperBound]: WithBoost[Range[A, LB, UB]] = + (query: ElasticQuery[Range[A, LB, UB]], value: Double) => + query match { + case q: RangeQuery[A, LB, UB] => q.copy(boost = Some(value)) + } + implicit def termWithBoost[A: ElasticPrimitive]: WithBoost[Term[A]] = (query: ElasticQuery[Term[A]], value: Double) => query match { diff --git a/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala b/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala index 417a1ac3e..ac995108e 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala @@ -123,7 +123,8 @@ object ElasticQuery { private[elasticsearch] final case class BoolQuery( filter: List[ElasticQuery[_]], must: List[ElasticQuery[_]], - should: List[ElasticQuery[_]] + should: List[ElasticQuery[_]], + boost: Option[Double] ) extends ElasticQuery[Bool] { self => def filter(queries: ElasticQuery[_]*): BoolQuery = @@ -132,20 +133,21 @@ object ElasticQuery { def must(queries: ElasticQuery[_]*): BoolQuery = self.copy(must = must ++ queries) - def paramsToJson: Json = Obj( - "bool" -> Obj( - "filter" -> Arr(filter.map(_.paramsToJson): _*), - "must" -> Arr(must.map(_.paramsToJson): _*), - "should" -> Arr(should.map(_.paramsToJson): _*) - ) - ) + def paramsToJson: Json = { + val boolFields = + Some("filter" -> Arr(filter.map(_.paramsToJson): _*)) ++ + Some("must" -> Arr(must.map(_.paramsToJson): _*)) ++ + Some("should" -> Arr(should.map(_.paramsToJson): _*)) ++ + boost.map("boost" -> Num(_)) + Obj("bool" -> Obj(boolFields.toList: _*)) + } def should(queries: ElasticQuery[_]*): BoolQuery = self.copy(should = should ++ queries) } private[elasticsearch] object BoolQuery { - def empty: BoolQuery = BoolQuery(Nil, Nil, Nil) + def empty: BoolQuery = BoolQuery(filter = Nil, must = Nil, should = Nil, boost = None) } private[elasticsearch] final case class ExistsQuery private (field: String) extends ElasticQuery[Exists] { @@ -192,8 +194,9 @@ object ElasticQuery { private[elasticsearch] final case class RangeQuery[A, LB <: LowerBound, UB <: UpperBound] private ( field: String, lower: LB, - upper: UB - ) extends ElasticQuery[Range] { self => + upper: UB, + boost: Option[Double] + ) extends ElasticQuery[Range[A, LB, UB]] { self => def gt[B <: A: ElasticPrimitive](value: B)(implicit @unused ev: LB =:= Unbounded.type @@ -215,12 +218,20 @@ object ElasticQuery { ): RangeQuery[B, LB, LessThanOrEqualTo[B]] = self.copy(upper = LessThanOrEqualTo(value)) - def paramsToJson: Json = Obj("range" -> Obj(field -> Obj(List(lower.toJson, upper.toJson).flatten: _*))) + def paramsToJson: Json = { + val rangeFields = Some(field -> Obj(List(lower.toJson, upper.toJson).flatten: _*)) ++ boost.map("boost" -> Num(_)) + Obj("range" -> Obj(rangeFields.toList: _*)) + } } private[elasticsearch] object RangeQuery { def empty[A](field: String): RangeQuery[A, Unbounded.type, Unbounded.type] = - RangeQuery[A, Unbounded.type, Unbounded.type](field, Unbounded, Unbounded) + RangeQuery[A, Unbounded.type, Unbounded.type]( + field = field, + lower = Unbounded, + upper = Unbounded, + boost = None + ) } private[elasticsearch] final case class TermQuery[A: ElasticPrimitive]( @@ -256,11 +267,11 @@ object ElasticQuery { sealed trait ElasticQueryType object ElasticQueryType { - sealed trait Bool extends ElasticQueryType - sealed trait Exists extends ElasticQueryType - sealed trait Match extends ElasticQueryType - sealed trait MatchAll extends ElasticQueryType - sealed trait Range extends ElasticQueryType - sealed trait Term[A] extends ElasticQueryType - sealed trait Wildcard extends ElasticQueryType + sealed trait Bool extends ElasticQueryType + sealed trait Exists extends ElasticQueryType + sealed trait Match extends ElasticQueryType + sealed trait MatchAll extends ElasticQueryType + sealed trait Range[A, LB, UB] extends ElasticQueryType + sealed trait Term[A] extends ElasticQueryType + sealed trait Wildcard extends ElasticQueryType } diff --git a/modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala b/modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala index 0755b4614..36fcd7ee1 100644 --- a/modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala +++ b/modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala @@ -62,6 +62,11 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(query)(equalTo(MatchQuery(field = "name.keyword", value = "Name"))) }, + test("successfully create Bool Query with boost") { + val query = boolQuery().boost(1.0) + + assert(query)(equalTo(BoolQuery(filter = Nil, must = Nil, should = Nil, boost = Some(1.0)))) + }, test("successfully create `Filter` query from two Match queries") { val query = boolQuery .filter( @@ -182,6 +187,24 @@ object QueryDSLSpec extends ZIOSpecDefault { ) ) }, + test("successfully create `Filter/Must/Should` mixed query with boost") { + val query = boolQuery() + .filter(matches(field = "customer_id", value = 1)) + .must(matches(field = "customer_age", value = 23)) + .should(matches(field = "day_of_week", value = "Monday")) + .boost(1.0) + + assert(query)( + equalTo( + BoolQuery( + filter = List(MatchQuery(field = "customer_id", value = 1)), + must = List(MatchQuery(field = "customer_age", value = 23)), + should = List(MatchQuery(field = "day_of_week", value = "Monday")), + boost = Some(1.0) + ) + ) + ) + }, test("successfully create Exists Query") { val query = exists(field = "day_of_week") @@ -453,6 +476,24 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(query.toJson)(equalTo(expected.toJson)) }, + test("properly encode Bool Query with boost") { + val query = boolQuery().boost(1.0) + val expected = + """ + |{ + | "query": { + | "bool": { + | "filter": [], + | "must": [], + | "should": [], + | "boost": 1.0 + | } + | } + |} + |""".stripMargin + + assert(query.toJsonBody)(equalTo(expected.toJson)) + }, test("properly encode Bool Query with Filter containing `Match` leaf query") { val query = boolQuery.filter(matches(field = "day_of_week", value = "Monday")) val expected = @@ -560,6 +601,46 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(query.toJson)(equalTo(expected.toJson)) }, + test("properly encode Bool Query with Filter, Must and Should containing `Match` leaf query and with boost") { + val query = boolQuery() + .filter(matches(field = "customer_age", value = 23)) + .must(matches(field = "customer_id", value = 1)) + .should(matches(field = "day_of_week", value = "Monday")) + .boost(1.0) + val expected = + """ + |{ + | "query": { + | "bool": { + | "filter": [ + | { + | "match": { + | "customer_age": 23 + | } + | } + | ], + | "must": [ + | { + | "match": { + | "customer_id": 1 + | } + | } + | ], + | "should": [ + | { + | "match": { + | "day_of_week": "Monday" + | } + | } + | ], + | "boost": 1.0 + | } + | } + |} + |""".stripMargin + + assert(query.toJsonBody)(equalTo(expected.toJson)) + }, test("properly encode Exists Query") { val query = exists(field = "day_of_week") val expected = @@ -619,6 +700,23 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(query.toJson)(equalTo(expected.toJson)) }, + test("properly encode Unbounded Range Query with boost") { + val query = range(field = "field").boost(1.0) + val expected = + """ + |{ + | "query": { + | "range": { + | "field": { + | }, + | "boost": 1.0 + | } + | } + |} + |""".stripMargin + + assert(query.toJsonBody)(equalTo(expected.toJson)) + }, test("properly encode Range Query with Lower Bound") { val query = range(field = "customer_age").gt(23) val expected = @@ -705,6 +803,25 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(query.toJson)(equalTo(expected.toJson)) }, + test("properly encode Range Query with both Upper and Lower Bound with boost") { + val query = range(field = "customer_age").gte(10).lt(100).boost(1.0) + val expected = + """ + |{ + | "query": { + | "range": { + | "customer_age": { + | "gte": 10, + | "lt": 100 + | }, + | "boost": 1.0 + | } + | } + |} + |""".stripMargin + + assert(query.toJsonBody)(equalTo(expected.toJson)) + }, test("properly encode Term query") { val query = term(field = "day_of_week", value = true) val expected =