Skip to content

Commit

Permalink
Support boost in Range and Bool query
Browse files Browse the repository at this point in the history
  • Loading branch information
kevchuang committed Feb 8, 2023
1 parent 30bcae8 commit 2b11947
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 22 deletions.
24 changes: 22 additions & 2 deletions modules/library/src/main/scala/zio/elasticsearch/Boost.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,17 @@

package zio.elasticsearch

import zio.elasticsearch.ElasticQuery.{ElasticPrimitive, MatchAllQuery, TermQuery, WildcardQuery}
import zio.elasticsearch.ElasticQueryType.{MatchAll, Term, Wildcard}
import zio.elasticsearch.ElasticQuery.{
BoolQuery,
ElasticPrimitive,
LowerBound,
MatchAllQuery,
RangeQuery,
TermQuery,
UpperBound,
WildcardQuery
}
import zio.elasticsearch.ElasticQueryType.{Bool, MatchAll, Range, Term, Wildcard}

object Boost {

Expand All @@ -26,11 +35,22 @@ object Boost {
}

object WithBoost {
implicit val boolWithBoost: WithBoost[Bool] = (query: ElasticQuery[Bool], value: Double) =>
query match {
case q: BoolQuery => q.copy(boost = Some(value))
}

implicit val matchAllWithBoost: WithBoost[MatchAll] = (query: ElasticQuery[MatchAll], value: Double) =>
query match {
case q: MatchAllQuery => q.copy(boost = Some(value))
}

implicit def rangeWithBoost[A, LB <: LowerBound, UB <: UpperBound]: WithBoost[Range[A, LB, UB]] =
(query: ElasticQuery[Range[A, LB, UB]], value: Double) =>
query match {
case q: RangeQuery[A, LB, UB] => q.copy(boost = Some(value))
}

implicit def termWithBoost[A: ElasticPrimitive]: WithBoost[Term[A]] =
(query: ElasticQuery[Term[A]], value: Double) =>
query match {
Expand Down
51 changes: 31 additions & 20 deletions modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ object ElasticQuery {
private[elasticsearch] final case class BoolQuery(
filter: List[ElasticQuery[_]],
must: List[ElasticQuery[_]],
should: List[ElasticQuery[_]]
should: List[ElasticQuery[_]],
boost: Option[Double]
) extends ElasticQuery[Bool] { self =>

def filter(queries: ElasticQuery[_]*): BoolQuery =
Expand All @@ -132,20 +133,21 @@ object ElasticQuery {
def must(queries: ElasticQuery[_]*): BoolQuery =
self.copy(must = must ++ queries)

def paramsToJson: Json = Obj(
"bool" -> Obj(
"filter" -> Arr(filter.map(_.paramsToJson): _*),
"must" -> Arr(must.map(_.paramsToJson): _*),
"should" -> Arr(should.map(_.paramsToJson): _*)
)
)
def paramsToJson: Json = {
val boolFields =
Some("filter" -> Arr(filter.map(_.paramsToJson): _*)) ++
Some("must" -> Arr(must.map(_.paramsToJson): _*)) ++
Some("should" -> Arr(should.map(_.paramsToJson): _*)) ++
boost.map("boost" -> Num(_))
Obj("bool" -> Obj(boolFields.toList: _*))
}

def should(queries: ElasticQuery[_]*): BoolQuery =
self.copy(should = should ++ queries)
}

private[elasticsearch] object BoolQuery {
def empty: BoolQuery = BoolQuery(Nil, Nil, Nil)
def empty: BoolQuery = BoolQuery(filter = Nil, must = Nil, should = Nil, boost = None)
}

private[elasticsearch] final case class ExistsQuery private (field: String) extends ElasticQuery[Exists] {
Expand Down Expand Up @@ -192,8 +194,9 @@ object ElasticQuery {
private[elasticsearch] final case class RangeQuery[A, LB <: LowerBound, UB <: UpperBound] private (
field: String,
lower: LB,
upper: UB
) extends ElasticQuery[Range] { self =>
upper: UB,
boost: Option[Double]
) extends ElasticQuery[Range[A, LB, UB]] { self =>

def gt[B <: A: ElasticPrimitive](value: B)(implicit
@unused ev: LB =:= Unbounded.type
Expand All @@ -215,12 +218,20 @@ object ElasticQuery {
): RangeQuery[B, LB, LessThanOrEqualTo[B]] =
self.copy(upper = LessThanOrEqualTo(value))

def paramsToJson: Json = Obj("range" -> Obj(field -> Obj(List(lower.toJson, upper.toJson).flatten: _*)))
def paramsToJson: Json = {
val rangeFields = Some(field -> Obj(List(lower.toJson, upper.toJson).flatten: _*)) ++ boost.map("boost" -> Num(_))
Obj("range" -> Obj(rangeFields.toList: _*))
}
}

private[elasticsearch] object RangeQuery {
def empty[A](field: String): RangeQuery[A, Unbounded.type, Unbounded.type] =
RangeQuery[A, Unbounded.type, Unbounded.type](field, Unbounded, Unbounded)
RangeQuery[A, Unbounded.type, Unbounded.type](
field = field,
lower = Unbounded,
upper = Unbounded,
boost = None
)
}

private[elasticsearch] final case class TermQuery[A: ElasticPrimitive](
Expand Down Expand Up @@ -256,11 +267,11 @@ object ElasticQuery {
sealed trait ElasticQueryType

object ElasticQueryType {
sealed trait Bool extends ElasticQueryType
sealed trait Exists extends ElasticQueryType
sealed trait Match extends ElasticQueryType
sealed trait MatchAll extends ElasticQueryType
sealed trait Range extends ElasticQueryType
sealed trait Term[A] extends ElasticQueryType
sealed trait Wildcard extends ElasticQueryType
sealed trait Bool extends ElasticQueryType
sealed trait Exists extends ElasticQueryType
sealed trait Match extends ElasticQueryType
sealed trait MatchAll extends ElasticQueryType
sealed trait Range[A, LB, UB] extends ElasticQueryType
sealed trait Term[A] extends ElasticQueryType
sealed trait Wildcard extends ElasticQueryType
}
117 changes: 117 additions & 0 deletions modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ object QueryDSLSpec extends ZIOSpecDefault {

assert(query)(equalTo(MatchQuery(field = "name.keyword", value = "Name")))
},
test("successfully create Bool Query with boost") {
val query = boolQuery().boost(1.0)

assert(query)(equalTo(BoolQuery(filter = Nil, must = Nil, should = Nil, boost = Some(1.0))))
},
test("successfully create `Filter` query from two Match queries") {
val query = boolQuery
.filter(
Expand Down Expand Up @@ -182,6 +187,24 @@ object QueryDSLSpec extends ZIOSpecDefault {
)
)
},
test("successfully create `Filter/Must/Should` mixed query with boost") {
val query = boolQuery()
.filter(matches(field = "customer_id", value = 1))
.must(matches(field = "customer_age", value = 23))
.should(matches(field = "day_of_week", value = "Monday"))
.boost(1.0)

assert(query)(
equalTo(
BoolQuery(
filter = List(MatchQuery(field = "customer_id", value = 1)),
must = List(MatchQuery(field = "customer_age", value = 23)),
should = List(MatchQuery(field = "day_of_week", value = "Monday")),
boost = Some(1.0)
)
)
)
},
test("successfully create Exists Query") {
val query = exists(field = "day_of_week")

Expand Down Expand Up @@ -453,6 +476,24 @@ object QueryDSLSpec extends ZIOSpecDefault {

assert(query.toJson)(equalTo(expected.toJson))
},
test("properly encode Bool Query with boost") {
val query = boolQuery().boost(1.0)
val expected =
"""
|{
| "query": {
| "bool": {
| "filter": [],
| "must": [],
| "should": [],
| "boost": 1.0
| }
| }
|}
|""".stripMargin

assert(query.toJsonBody)(equalTo(expected.toJson))
},
test("properly encode Bool Query with Filter containing `Match` leaf query") {
val query = boolQuery.filter(matches(field = "day_of_week", value = "Monday"))
val expected =
Expand Down Expand Up @@ -560,6 +601,46 @@ object QueryDSLSpec extends ZIOSpecDefault {

assert(query.toJson)(equalTo(expected.toJson))
},
test("properly encode Bool Query with Filter, Must and Should containing `Match` leaf query and with boost") {
val query = boolQuery()
.filter(matches(field = "customer_age", value = 23))
.must(matches(field = "customer_id", value = 1))
.should(matches(field = "day_of_week", value = "Monday"))
.boost(1.0)
val expected =
"""
|{
| "query": {
| "bool": {
| "filter": [
| {
| "match": {
| "customer_age": 23
| }
| }
| ],
| "must": [
| {
| "match": {
| "customer_id": 1
| }
| }
| ],
| "should": [
| {
| "match": {
| "day_of_week": "Monday"
| }
| }
| ],
| "boost": 1.0
| }
| }
|}
|""".stripMargin

assert(query.toJsonBody)(equalTo(expected.toJson))
},
test("properly encode Exists Query") {
val query = exists(field = "day_of_week")
val expected =
Expand Down Expand Up @@ -619,6 +700,23 @@ object QueryDSLSpec extends ZIOSpecDefault {

assert(query.toJson)(equalTo(expected.toJson))
},
test("properly encode Unbounded Range Query with boost") {
val query = range(field = "field").boost(1.0)
val expected =
"""
|{
| "query": {
| "range": {
| "field": {
| },
| "boost": 1.0
| }
| }
|}
|""".stripMargin

assert(query.toJsonBody)(equalTo(expected.toJson))
},
test("properly encode Range Query with Lower Bound") {
val query = range(field = "customer_age").gt(23)
val expected =
Expand Down Expand Up @@ -705,6 +803,25 @@ object QueryDSLSpec extends ZIOSpecDefault {

assert(query.toJson)(equalTo(expected.toJson))
},
test("properly encode Range Query with both Upper and Lower Bound with boost") {
val query = range(field = "customer_age").gte(10).lt(100).boost(1.0)
val expected =
"""
|{
| "query": {
| "range": {
| "customer_age": {
| "gte": 10,
| "lt": 100
| },
| "boost": 1.0
| }
| }
|}
|""".stripMargin

assert(query.toJsonBody)(equalTo(expected.toJson))
},
test("properly encode Term query") {
val query = term(field = "day_of_week", value = true)
val expected =
Expand Down

0 comments on commit 2b11947

Please sign in to comment.