Skip to content

Commit

Permalink
(dsl): Offer type-safe query API (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvelimir authored Jan 30, 2023
1 parent 7623b29 commit 6cf7012
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,38 @@ object ElasticQuery {
def toJson(implicit EP: ElasticPrimitive[A]): Json = EP.toJson(value)
}

def matches[A: ElasticPrimitive](
field: Field[_, A],
multiField: Option[String] = None,
value: A
): ElasticQuery[Match] =
MatchQuery(field.toString ++ multiField.map("." ++ _).getOrElse(""), value)

def matches[A: ElasticPrimitive](field: String, value: A): ElasticQuery[Match] =
MatchQuery(field, value)

def boolQuery(): BoolQuery = BoolQuery.empty

def exists(field: Field[_, _]): ElasticQuery[Exists] = ExistsQuery(field.toString)

def exists(field: String): ElasticQuery[Exists] = ExistsQuery(field)

def matchAll(): ElasticQuery[MatchAll] = MatchAllQuery()

def range(field: String): RangeQuery[Unbounded.type, Unbounded.type] = RangeQuery.empty(field)
def range[A](
field: Field[_, A],
multiField: Option[String] = None
): RangeQuery[A, Unbounded.type, Unbounded.type] =
RangeQuery.empty(field.toString ++ multiField.map("." ++ _).getOrElse(""))

def range(field: String): RangeQuery[Any, Unbounded.type, Unbounded.type] = RangeQuery.empty[Any](field)

def term[A: ElasticPrimitive](
field: Field[_, A],
multiField: Option[String] = None,
value: A
): ElasticQuery[Term[A]] =
TermQuery(field.toString ++ multiField.map("." ++ _).getOrElse(""), value)

def term[A: ElasticPrimitive](field: String, value: A): ElasticQuery[Term[A]] = TermQuery(field, value)

Expand Down Expand Up @@ -125,33 +147,38 @@ object ElasticQuery {
override def toJson: Option[(String, Json)] = None
}

private[elasticsearch] final case class RangeQuery[LB <: LowerBound, UB <: UpperBound] private (
private[elasticsearch] final case class RangeQuery[A, LB <: LowerBound, UB <: UpperBound] private (
field: String,
lower: LB,
upper: UB
) extends ElasticQuery[Range] { self =>

def gt[A: ElasticPrimitive](value: A)(implicit @unused ev: LB =:= Unbounded.type): RangeQuery[GreaterThan[A], UB] =
def gt[B <: A: ElasticPrimitive](value: B)(implicit
@unused ev: LB =:= Unbounded.type
): RangeQuery[B, GreaterThan[B], UB] =
self.copy(lower = GreaterThan(value))

def gte[A: ElasticPrimitive](value: A)(implicit
def gte[B <: A: ElasticPrimitive](value: B)(implicit
@unused ev: LB =:= Unbounded.type
): RangeQuery[GreaterThanOrEqualTo[A], UB] =
): RangeQuery[B, GreaterThanOrEqualTo[B], UB] =
self.copy(lower = GreaterThanOrEqualTo(value))

def lt[A: ElasticPrimitive](value: A)(implicit @unused ev: UB =:= Unbounded.type): RangeQuery[LB, LessThan[A]] =
def lt[B <: A: ElasticPrimitive](value: B)(implicit
@unused ev: UB =:= Unbounded.type
): RangeQuery[B, LB, LessThan[B]] =
self.copy(upper = LessThan(value))

def lte[A: ElasticPrimitive](value: A)(implicit
def lte[B <: A: ElasticPrimitive](value: B)(implicit
@unused ev: UB =:= Unbounded.type
): RangeQuery[LB, LessThanOrEqualTo[A]] =
): RangeQuery[B, LB, LessThanOrEqualTo[B]] =
self.copy(upper = LessThanOrEqualTo(value))

override def toJson: Json = Obj("range" -> Obj(field -> Obj(List(lower.toJson, upper.toJson).flatten: _*)))
}

private[elasticsearch] object RangeQuery {
def empty(field: String): RangeQuery[Unbounded.type, Unbounded.type] = RangeQuery(field, Unbounded, Unbounded)
def empty[A](field: String): RangeQuery[A, Unbounded.type, Unbounded.type] =
RangeQuery[A, Unbounded.type, Unbounded.type](field, Unbounded, Unbounded)
}

private[elasticsearch] final case class TermQuery[A: ElasticPrimitive](
Expand Down
34 changes: 34 additions & 0 deletions modules/library/src/main/scala/zio/elasticsearch/Field.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package zio.elasticsearch

import zio.schema.{AccessorBuilder, Schema}

import scala.annotation.tailrec

private[elasticsearch] final case class Field[S, A](parent: Option[Field[S, _]], name: String) { self =>

def /[B](that: Field[A, B]): Field[S, B] =
Field(that.parent.map(self / _).orElse(Some(self)), that.name)

override def toString: String = {
@tailrec
def loop(field: Field[_, _], acc: List[String]): List[String] = field match {
case Field(None, name) => s"$name" +: acc
case Field(Some(parent), name) => loop(parent, s".$name" +: acc)
}

loop(self, Nil).mkString
}
}

object ElasticQueryAccessorBuilder extends AccessorBuilder {
override type Lens[_, S, A] = Field[S, A]
override type Prism[_, S, A] = Unit
override type Traversal[S, A] = Unit

override def makeLens[F, S, A](product: Schema.Record[S], term: Schema.Field[S, A]): Lens[_, S, A] =
Field[S, A](None, term.name)

override def makePrism[F, S, A](sum: Schema.Enum[S], term: Schema.Case[S, A]): Prism[_, S, A] = ()

override def makeTraversal[S, A](collection: Schema.Collection[S, A], element: Schema[A]): Traversal[S, A] = ()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package zio.elasticsearch

import zio.schema.{DeriveSchema, Schema}
import zio.test.{Spec, TestEnvironment, ZIOSpecDefault, assertTrue}

object FieldDSLSpec extends ZIOSpecDefault {

final case class Address(street: String, number: Int)

object Address {

implicit val schema: Schema.CaseClass2[String, Int, Address] = DeriveSchema.gen[Address]

val (street, number) = schema.makeAccessors(ElasticQueryAccessorBuilder)
}

final case class Student(name: String, address: Address)

object Student {

implicit val schema: Schema.CaseClass2[String, Address, Student] = DeriveSchema.gen[Student]

val (name, address) = schema.makeAccessors(ElasticQueryAccessorBuilder)
}

override def spec: Spec[TestEnvironment, Any] =
suite("Field DSL")(
test("properly encode single field path")(
assertTrue(Field(None, "name").toString == "name")
),
test("properly encode single field path using accessor")(
assertTrue(Student.name.toString == "name")
),
test("properly encode nested field path")(
assertTrue(Field[Nothing, Nothing](Some(Field(None, "address")), "number").toString == "address.number")
),
test("properly encode nested field path using accessors")(
assertTrue((Student.address / Address.number).toString == "address.number")
)
)
}
119 changes: 112 additions & 7 deletions modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ object QueryDSLSpec extends ZIOSpecDefault {
final case class UserDocument(id: String, name: String, address: String, balance: Double, age: Int)

object UserDocument {
implicit val schema: Schema[UserDocument] = DeriveSchema.gen[UserDocument]

implicit val schema: Schema.CaseClass5[String, String, String, Double, Int, UserDocument] =
DeriveSchema.gen[UserDocument]

val (id, name, address, balance, age) = schema.makeAccessors(ElasticQueryAccessorBuilder)
}

override def spec: Spec[Environment with TestEnvironment with Scope, Any] =
Expand All @@ -30,6 +34,18 @@ object QueryDSLSpec extends ZIOSpecDefault {
assert(queryBool)(equalTo(MatchQuery(field = "day_of_week", value = true))) &&
assert(queryLong)(equalTo(MatchQuery(field = "day_of_week", value = 1)))
},
test("successfully create type-safe Match query using `matches` method") {
val queryString = matches(field = UserDocument.name, value = "Name")
val queryInt = matches(field = UserDocument.age, value = 39)

assert(queryString)(equalTo(MatchQuery(field = "name", value = "Name"))) &&
assert(queryInt)(equalTo(MatchQuery(field = "age", value = 39)))
},
test("successfully create type-safe Match query with multi-field using `matches` method") {
val query = matches(field = UserDocument.name, multiField = Some("keyword"), value = "Name")

assert(query)(equalTo(MatchQuery(field = "name.keyword", value = "Name")))
},
test("successfully create `Must` query from two Match queries") {
val query = boolQuery()
.must(matches(field = "day_of_week", value = "Monday"), matches(field = "customer_gender", value = "MALE"))
Expand Down Expand Up @@ -107,6 +123,11 @@ object QueryDSLSpec extends ZIOSpecDefault {

assert(query)(equalTo(ExistsQuery(field = "day_of_week")))
},
test("successfully create Exists Query with accessor") {
val query = exists(field = UserDocument.name)

assert(query)(equalTo(ExistsQuery(field = "name")))
},
test("successfully create MatchAll Query") {
val query = matchAll()

Expand All @@ -120,35 +141,107 @@ object QueryDSLSpec extends ZIOSpecDefault {
test("successfully create empty Range Query") {
val query = range(field = "customer_age")

assert(query)(equalTo(RangeQuery(field = "customer_age", lower = Unbounded, upper = Unbounded)))
assert(query)(
equalTo(
RangeQuery[Any, Unbounded.type, Unbounded.type](
field = "customer_age",
lower = Unbounded,
upper = Unbounded
)
)
)
},
test("successfully create empty type-safe Range Query") {
val queryString = range(field = UserDocument.name)
val queryInt = range(field = UserDocument.age)

assert(queryString)(
equalTo(
RangeQuery[String, Unbounded.type, Unbounded.type](field = "name", lower = Unbounded, upper = Unbounded)
)
) &&
assert(queryInt)(
equalTo(
RangeQuery[Int, Unbounded.type, Unbounded.type](field = "age", lower = Unbounded, upper = Unbounded)
)
)
},
test("successfully create empty type-safe Range Query with multi-field") {
val query = range(field = UserDocument.name, multiField = Some("keyword"))

assert(query)(
equalTo(
RangeQuery[String, Unbounded.type, Unbounded.type](
field = "name.keyword",
lower = Unbounded,
upper = Unbounded
)
)
)
},
test("successfully create Range Query with upper bound") {
val query = range(field = "customer_age").lt(23)

assert(query)(equalTo(RangeQuery(field = "customer_age", lower = Unbounded, upper = LessThan(23))))
assert(query)(
equalTo(
RangeQuery[Int, Unbounded.type, LessThan[Int]](
field = "customer_age",
lower = Unbounded,
upper = LessThan(23)
)
)
)
},
test("successfully create Range Query with lower bound") {
val query = range(field = "customer_age").gt(23)

assert(query)(equalTo(RangeQuery(field = "customer_age", lower = GreaterThan(23), upper = Unbounded)))
assert(query)(
equalTo(
RangeQuery[Int, GreaterThan[Int], Unbounded.type](
field = "customer_age",
lower = GreaterThan(23),
upper = Unbounded
)
)
)
},
test("successfully create Range Query with inclusive upper bound") {
val query = range(field = "customer_age").lte(23)

assert(query)(equalTo(RangeQuery(field = "customer_age", lower = Unbounded, upper = LessThanOrEqualTo(23))))
assert(query)(
equalTo(
RangeQuery[Int, Unbounded.type, LessThanOrEqualTo[Int]](
field = "customer_age",
lower = Unbounded,
upper = LessThanOrEqualTo(23)
)
)
)
},
test("successfully create Range Query with inclusive lower bound") {
val query = range(field = "customer_age").gte(23)

assert(query)(
equalTo(RangeQuery(field = "customer_age", lower = GreaterThanOrEqualTo(23), upper = Unbounded))
equalTo(
RangeQuery[Int, GreaterThanOrEqualTo[Int], Unbounded.type](
field = "customer_age",
lower = GreaterThanOrEqualTo(23),
upper = Unbounded
)
)
)
},
test("successfully create Range Query with both upper and lower bound") {
val query = range(field = "customer_age").gte(23).lt(50)

assert(query)(
equalTo(RangeQuery(field = "customer_age", lower = GreaterThanOrEqualTo(23), upper = LessThan(50)))
equalTo(
RangeQuery[Int, GreaterThanOrEqualTo[Int], LessThan[Int]](
field = "customer_age",
lower = GreaterThanOrEqualTo(23),
upper = LessThan(50)
)
)
)
},
test("successfully create Term Query") {
Expand All @@ -162,6 +255,18 @@ object QueryDSLSpec extends ZIOSpecDefault {
assert(queryBool)(equalTo(TermQuery(field = "day_of_week", value = true))) &&
assert(queryLong)(equalTo(TermQuery(field = "day_of_week", value = 1L)))
},
test("successfully create type-safe Term Query") {
val queryString = term(field = UserDocument.name, value = "Name")
val queryInt = term(field = UserDocument.age, value = 39)

assert(queryString)(equalTo(TermQuery(field = "name", value = "Name"))) &&
assert(queryInt)(equalTo(TermQuery(field = "age", value = 39)))
},
test("successfully create type-safe Term Query with multi-field") {
val query = term(field = UserDocument.name, multiField = Some("keyword"), value = "Name")

assert(query)(equalTo(TermQuery(field = "name.keyword", value = "Name")))
},
test("successfully create Term Query with boost") {
val queryInt = term(field = "day_of_week", value = 1).boost(1.0)
val queryString = term(field = "day_of_week", value = "Monday").boost(1.0)
Expand Down

0 comments on commit 6cf7012

Please sign in to comment.