From 6de5096d708bc1b9e7fa556ec8229b04073a19e8 Mon Sep 17 00:00:00 2001 From: Kevin Date: Wed, 8 Feb 2023 15:18:25 -0500 Subject: [PATCH] Support boost in Range and Bool query (#88) Closes #66 --- .../main/scala/zio/elasticsearch/Boost.scala | 15 +- .../zio/elasticsearch/ElasticQuery.scala | 51 +++--- .../zio/elasticsearch/QueryDSLSpec.scala | 170 ++++++++++++++++-- 3 files changed, 199 insertions(+), 37 deletions(-) diff --git a/modules/library/src/main/scala/zio/elasticsearch/Boost.scala b/modules/library/src/main/scala/zio/elasticsearch/Boost.scala index 0f882daba..132e1ec27 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/Boost.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/Boost.scala @@ -16,8 +16,8 @@ package zio.elasticsearch -import zio.elasticsearch.ElasticQuery.{ElasticPrimitive, MatchAllQuery, TermQuery, WildcardQuery} -import zio.elasticsearch.ElasticQueryType.{MatchAll, Term, Wildcard} +import zio.elasticsearch.ElasticQuery._ +import zio.elasticsearch.ElasticQueryType.{Bool, MatchAll, Range, Term, Wildcard} object Boost { @@ -26,11 +26,22 @@ object Boost { } object WithBoost { + implicit val boolWithBoost: WithBoost[Bool] = (query: ElasticQuery[Bool], value: Double) => + query match { + case q: BoolQuery => q.copy(boost = Some(value)) + } + implicit val matchAllWithBoost: WithBoost[MatchAll] = (query: ElasticQuery[MatchAll], value: Double) => query match { case q: MatchAllQuery => q.copy(boost = Some(value)) } + implicit def rangeWithBoost[A, LB <: LowerBound, UB <: UpperBound]: WithBoost[Range[A, LB, UB]] = + (query: ElasticQuery[Range[A, LB, UB]], value: Double) => + query match { + case q: RangeQuery[A, LB, UB] => q.copy(boost = Some(value)) + } + implicit def termWithBoost[A: ElasticPrimitive]: WithBoost[Term[A]] = (query: ElasticQuery[Term[A]], value: Double) => query match { diff --git a/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala b/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala index 417a1ac3e..ac995108e 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala @@ -123,7 +123,8 @@ object ElasticQuery { private[elasticsearch] final case class BoolQuery( filter: List[ElasticQuery[_]], must: List[ElasticQuery[_]], - should: List[ElasticQuery[_]] + should: List[ElasticQuery[_]], + boost: Option[Double] ) extends ElasticQuery[Bool] { self => def filter(queries: ElasticQuery[_]*): BoolQuery = @@ -132,20 +133,21 @@ object ElasticQuery { def must(queries: ElasticQuery[_]*): BoolQuery = self.copy(must = must ++ queries) - def paramsToJson: Json = Obj( - "bool" -> Obj( - "filter" -> Arr(filter.map(_.paramsToJson): _*), - "must" -> Arr(must.map(_.paramsToJson): _*), - "should" -> Arr(should.map(_.paramsToJson): _*) - ) - ) + def paramsToJson: Json = { + val boolFields = + Some("filter" -> Arr(filter.map(_.paramsToJson): _*)) ++ + Some("must" -> Arr(must.map(_.paramsToJson): _*)) ++ + Some("should" -> Arr(should.map(_.paramsToJson): _*)) ++ + boost.map("boost" -> Num(_)) + Obj("bool" -> Obj(boolFields.toList: _*)) + } def should(queries: ElasticQuery[_]*): BoolQuery = self.copy(should = should ++ queries) } private[elasticsearch] object BoolQuery { - def empty: BoolQuery = BoolQuery(Nil, Nil, Nil) + def empty: BoolQuery = BoolQuery(filter = Nil, must = Nil, should = Nil, boost = None) } private[elasticsearch] final case class ExistsQuery private (field: String) extends ElasticQuery[Exists] { @@ -192,8 +194,9 @@ object ElasticQuery { private[elasticsearch] final case class RangeQuery[A, LB <: LowerBound, UB <: UpperBound] private ( field: String, lower: LB, - upper: UB - ) extends ElasticQuery[Range] { self => + upper: UB, + boost: Option[Double] + ) extends ElasticQuery[Range[A, LB, UB]] { self => def gt[B <: A: ElasticPrimitive](value: B)(implicit @unused ev: LB =:= Unbounded.type @@ -215,12 +218,20 @@ object ElasticQuery { ): RangeQuery[B, LB, LessThanOrEqualTo[B]] = self.copy(upper = LessThanOrEqualTo(value)) - def paramsToJson: Json = Obj("range" -> Obj(field -> Obj(List(lower.toJson, upper.toJson).flatten: _*))) + def paramsToJson: Json = { + val rangeFields = Some(field -> Obj(List(lower.toJson, upper.toJson).flatten: _*)) ++ boost.map("boost" -> Num(_)) + Obj("range" -> Obj(rangeFields.toList: _*)) + } } private[elasticsearch] object RangeQuery { def empty[A](field: String): RangeQuery[A, Unbounded.type, Unbounded.type] = - RangeQuery[A, Unbounded.type, Unbounded.type](field, Unbounded, Unbounded) + RangeQuery[A, Unbounded.type, Unbounded.type]( + field = field, + lower = Unbounded, + upper = Unbounded, + boost = None + ) } private[elasticsearch] final case class TermQuery[A: ElasticPrimitive]( @@ -256,11 +267,11 @@ object ElasticQuery { sealed trait ElasticQueryType object ElasticQueryType { - sealed trait Bool extends ElasticQueryType - sealed trait Exists extends ElasticQueryType - sealed trait Match extends ElasticQueryType - sealed trait MatchAll extends ElasticQueryType - sealed trait Range extends ElasticQueryType - sealed trait Term[A] extends ElasticQueryType - sealed trait Wildcard extends ElasticQueryType + sealed trait Bool extends ElasticQueryType + sealed trait Exists extends ElasticQueryType + sealed trait Match extends ElasticQueryType + sealed trait MatchAll extends ElasticQueryType + sealed trait Range[A, LB, UB] extends ElasticQueryType + sealed trait Term[A] extends ElasticQueryType + sealed trait Wildcard extends ElasticQueryType } diff --git a/modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala b/modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala index 0755b4614..7f177dc1d 100644 --- a/modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala +++ b/modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala @@ -62,6 +62,11 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(query)(equalTo(MatchQuery(field = "name.keyword", value = "Name"))) }, + test("successfully create Bool Query with boost") { + val query = boolQuery.boost(1.0) + + assert(query)(equalTo(BoolQuery(filter = Nil, must = Nil, should = Nil, boost = Some(1.0)))) + }, test("successfully create `Filter` query from two Match queries") { val query = boolQuery .filter( @@ -77,7 +82,8 @@ object QueryDSLSpec extends ZIOSpecDefault { MatchQuery(field = "customer_gender", value = "MALE") ), must = Nil, - should = Nil + should = Nil, + boost = None ) ) ) @@ -94,7 +100,8 @@ object QueryDSLSpec extends ZIOSpecDefault { MatchQuery(field = "day_of_week", value = "Monday"), MatchQuery(field = "customer_gender", value = "MALE") ), - should = Nil + should = Nil, + boost = None ) ) ) @@ -114,7 +121,8 @@ object QueryDSLSpec extends ZIOSpecDefault { should = List( MatchQuery(field = "day_of_week", value = "Monday"), MatchQuery(field = "customer_gender", value = "MALE") - ) + ), + boost = None ) ) ) @@ -136,7 +144,8 @@ object QueryDSLSpec extends ZIOSpecDefault { MatchQuery(field = "customer_gender", value = "MALE") ), must = List(MatchQuery(field = "customer_age", value = 23)), - should = List(MatchQuery(field = "customer_id", value = 1)) + should = List(MatchQuery(field = "customer_id", value = 1)), + boost = None ) ) ) @@ -155,7 +164,8 @@ object QueryDSLSpec extends ZIOSpecDefault { MatchQuery(field = "day_of_week", value = "Monday"), MatchQuery(field = "customer_gender", value = "MALE") ), - should = List(MatchQuery(field = "customer_age", value = 23)) + should = List(MatchQuery(field = "customer_age", value = 23)), + boost = None ) ) ) @@ -177,7 +187,26 @@ object QueryDSLSpec extends ZIOSpecDefault { should = List( MatchQuery(field = "day_of_week", value = "Monday"), MatchQuery(field = "customer_gender", value = "MALE") - ) + ), + boost = None + ) + ) + ) + }, + test("successfully create `Filter/Must/Should` mixed query with boost") { + val query = boolQuery + .filter(matches(field = "customer_id", value = 1)) + .must(matches(field = "customer_age", value = 23)) + .should(matches(field = "day_of_week", value = "Monday")) + .boost(1.0) + + assert(query)( + equalTo( + BoolQuery( + filter = List(MatchQuery(field = "customer_id", value = 1)), + must = List(MatchQuery(field = "customer_age", value = 23)), + should = List(MatchQuery(field = "day_of_week", value = "Monday")), + boost = Some(1.0) ) ) ) @@ -210,7 +239,8 @@ object QueryDSLSpec extends ZIOSpecDefault { RangeQuery[Any, Unbounded.type, Unbounded.type]( field = "customer_age", lower = Unbounded, - upper = Unbounded + upper = Unbounded, + boost = None ) ) ) @@ -221,12 +251,22 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(queryString)( equalTo( - RangeQuery[String, Unbounded.type, Unbounded.type](field = "name", lower = Unbounded, upper = Unbounded) + RangeQuery[String, Unbounded.type, Unbounded.type]( + field = "name", + lower = Unbounded, + upper = Unbounded, + boost = None + ) ) ) && assert(queryInt)( equalTo( - RangeQuery[Int, Unbounded.type, Unbounded.type](field = "age", lower = Unbounded, upper = Unbounded) + RangeQuery[Int, Unbounded.type, Unbounded.type]( + field = "age", + lower = Unbounded, + upper = Unbounded, + boost = None + ) ) ) }, @@ -238,7 +278,8 @@ object QueryDSLSpec extends ZIOSpecDefault { RangeQuery[String, Unbounded.type, Unbounded.type]( field = "name.keyword", lower = Unbounded, - upper = Unbounded + upper = Unbounded, + boost = None ) ) ) @@ -251,7 +292,8 @@ object QueryDSLSpec extends ZIOSpecDefault { RangeQuery[Int, Unbounded.type, LessThan[Int]]( field = "customer_age", lower = Unbounded, - upper = LessThan(23) + upper = LessThan(23), + boost = None ) ) ) @@ -264,7 +306,8 @@ object QueryDSLSpec extends ZIOSpecDefault { RangeQuery[Int, GreaterThan[Int], Unbounded.type]( field = "customer_age", lower = GreaterThan(23), - upper = Unbounded + upper = Unbounded, + boost = None ) ) ) @@ -277,7 +320,8 @@ object QueryDSLSpec extends ZIOSpecDefault { RangeQuery[Int, Unbounded.type, LessThanOrEqualTo[Int]]( field = "customer_age", lower = Unbounded, - upper = LessThanOrEqualTo(23) + upper = LessThanOrEqualTo(23), + boost = None ) ) ) @@ -290,7 +334,8 @@ object QueryDSLSpec extends ZIOSpecDefault { RangeQuery[Int, GreaterThanOrEqualTo[Int], Unbounded.type]( field = "customer_age", lower = GreaterThanOrEqualTo(23), - upper = Unbounded + upper = Unbounded, + boost = None ) ) ) @@ -303,7 +348,8 @@ object QueryDSLSpec extends ZIOSpecDefault { RangeQuery[Int, GreaterThanOrEqualTo[Int], LessThan[Int]]( field = "customer_age", lower = GreaterThanOrEqualTo(23), - upper = LessThan(50) + upper = LessThan(50), + boost = None ) ) ) @@ -453,6 +499,24 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(query.toJson)(equalTo(expected.toJson)) }, + test("properly encode Bool Query with boost") { + val query = boolQuery.boost(1.0) + val expected = + """ + |{ + | "query": { + | "bool": { + | "filter": [], + | "must": [], + | "should": [], + | "boost": 1.0 + | } + | } + |} + |""".stripMargin + + assert(query.toJson)(equalTo(expected.toJson)) + }, test("properly encode Bool Query with Filter containing `Match` leaf query") { val query = boolQuery.filter(matches(field = "day_of_week", value = "Monday")) val expected = @@ -560,6 +624,46 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(query.toJson)(equalTo(expected.toJson)) }, + test("properly encode Bool Query with Filter, Must and Should containing `Match` leaf query and with boost") { + val query = boolQuery + .filter(matches(field = "customer_age", value = 23)) + .must(matches(field = "customer_id", value = 1)) + .should(matches(field = "day_of_week", value = "Monday")) + .boost(1.0) + val expected = + """ + |{ + | "query": { + | "bool": { + | "filter": [ + | { + | "match": { + | "customer_age": 23 + | } + | } + | ], + | "must": [ + | { + | "match": { + | "customer_id": 1 + | } + | } + | ], + | "should": [ + | { + | "match": { + | "day_of_week": "Monday" + | } + | } + | ], + | "boost": 1.0 + | } + | } + |} + |""".stripMargin + + assert(query.toJson)(equalTo(expected.toJson)) + }, test("properly encode Exists Query") { val query = exists(field = "day_of_week") val expected = @@ -619,6 +723,23 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(query.toJson)(equalTo(expected.toJson)) }, + test("properly encode Unbounded Range Query with boost") { + val query = range(field = "field").boost(1.0) + val expected = + """ + |{ + | "query": { + | "range": { + | "field": { + | }, + | "boost": 1.0 + | } + | } + |} + |""".stripMargin + + assert(query.toJson)(equalTo(expected.toJson)) + }, test("properly encode Range Query with Lower Bound") { val query = range(field = "customer_age").gt(23) val expected = @@ -705,6 +826,25 @@ object QueryDSLSpec extends ZIOSpecDefault { assert(query.toJson)(equalTo(expected.toJson)) }, + test("properly encode Range Query with both Upper and Lower Bound with boost") { + val query = range(field = "customer_age").gte(10).lt(100).boost(1.0) + val expected = + """ + |{ + | "query": { + | "range": { + | "customer_age": { + | "gte": 10, + | "lt": 100 + | }, + | "boost": 1.0 + | } + | } + |} + |""".stripMargin + + assert(query.toJson)(equalTo(expected.toJson)) + }, test("properly encode Term query") { val query = term(field = "day_of_week", value = true) val expected =