Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(dsl): Offer type-safe query API #39

Merged
merged 21 commits into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,38 @@ object ElasticQuery {
def toJson(implicit EP: ElasticPrimitive[A]): Json = EP.toJson(value)
}

def matches[A: ElasticPrimitive](
field: Field[_, A],
multiField: Option[String] = None,
value: A
): ElasticQuery[Match] =
MatchQuery(field.toString ++ multiField.map("." ++ _).getOrElse(""), value)

def matches[A: ElasticPrimitive](field: String, value: A): ElasticQuery[Match] =
MatchQuery(field, value)

def boolQuery(): BoolQuery = BoolQuery.empty

def exists(field: Field[_, _]): ElasticQuery[Exists] = ExistsQuery(field.toString)

def exists(field: String): ElasticQuery[Exists] = ExistsQuery(field)

def matchAll(): ElasticQuery[MatchAll] = MatchAllQuery()

def range(field: String): RangeQuery[Unbounded.type, Unbounded.type] = RangeQuery.empty(field)
def range[A](
field: Field[_, A],
multiField: Option[String] = None
): RangeQuery[A, Unbounded.type, Unbounded.type] =
RangeQuery.empty(field.toString ++ multiField.map("." ++ _).getOrElse(""))

def range(field: String): RangeQuery[Any, Unbounded.type, Unbounded.type] = RangeQuery.empty[Any](field)

def term[A: ElasticPrimitive](
field: Field[_, A],
multiField: Option[String] = None,
value: A
): ElasticQuery[Term[A]] =
TermQuery(field.toString ++ multiField.map("." ++ _).getOrElse(""), value)

def term[A: ElasticPrimitive](field: String, value: A): ElasticQuery[Term[A]] = TermQuery(field, value)

Expand Down Expand Up @@ -125,33 +147,38 @@ object ElasticQuery {
override def toJson: Option[(String, Json)] = None
}

private[elasticsearch] final case class RangeQuery[LB <: LowerBound, UB <: UpperBound] private (
private[elasticsearch] final case class RangeQuery[A, LB <: LowerBound, UB <: UpperBound] private (
field: String,
lower: LB,
upper: UB
) extends ElasticQuery[Range] { self =>

def gt[A: ElasticPrimitive](value: A)(implicit @unused ev: LB =:= Unbounded.type): RangeQuery[GreaterThan[A], UB] =
def gt[B <: A: ElasticPrimitive](value: B)(implicit
@unused ev: LB =:= Unbounded.type
): RangeQuery[B, GreaterThan[B], UB] =
self.copy(lower = GreaterThan(value))

def gte[A: ElasticPrimitive](value: A)(implicit
def gte[B <: A: ElasticPrimitive](value: B)(implicit
@unused ev: LB =:= Unbounded.type
): RangeQuery[GreaterThanOrEqualTo[A], UB] =
): RangeQuery[B, GreaterThanOrEqualTo[B], UB] =
self.copy(lower = GreaterThanOrEqualTo(value))

def lt[A: ElasticPrimitive](value: A)(implicit @unused ev: UB =:= Unbounded.type): RangeQuery[LB, LessThan[A]] =
def lt[B <: A: ElasticPrimitive](value: B)(implicit
@unused ev: UB =:= Unbounded.type
): RangeQuery[B, LB, LessThan[B]] =
self.copy(upper = LessThan(value))

def lte[A: ElasticPrimitive](value: A)(implicit
def lte[B <: A: ElasticPrimitive](value: B)(implicit
@unused ev: UB =:= Unbounded.type
): RangeQuery[LB, LessThanOrEqualTo[A]] =
): RangeQuery[B, LB, LessThanOrEqualTo[B]] =
self.copy(upper = LessThanOrEqualTo(value))

override def toJson: Json = Obj("range" -> Obj(field -> Obj(List(lower.toJson, upper.toJson).flatten: _*)))
}

private[elasticsearch] object RangeQuery {
def empty(field: String): RangeQuery[Unbounded.type, Unbounded.type] = RangeQuery(field, Unbounded, Unbounded)
def empty[A](field: String): RangeQuery[A, Unbounded.type, Unbounded.type] =
RangeQuery[A, Unbounded.type, Unbounded.type](field, Unbounded, Unbounded)
}

private[elasticsearch] final case class TermQuery[A: ElasticPrimitive](
Expand Down
44 changes: 44 additions & 0 deletions modules/library/src/main/scala/zio/elasticsearch/Field.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package zio.elasticsearch

import zio.Chunk
import zio.schema.{AccessorBuilder, Schema}

import scala.annotation.tailrec

object Annotation {
final case class name(value: String) extends scala.annotation.Annotation

def maybeName(annotations: Chunk[Any]): Option[String] =
annotations.collect { case name(value) => value }.headOption
}

private[elasticsearch] final case class Field[From, To](parent: Option[Field[From, _]], name: String) { self =>

def /[To2](that: Field[To, To2]): Field[From, To2] =
Field(that.parent.map(self / _).orElse(Some(self)), that.name)

override def toString: String = {
@tailrec
def loop(field: Field[_, _], acc: List[String]): List[String] = field match {
case Field(None, name) => s"$name" +: acc
case Field(Some(parent), name) => loop(parent, s".$name" +: acc)
}

loop(self, List.empty).mkString("")
mvelimir marked this conversation as resolved.
Show resolved Hide resolved
}
}

object ElasticQueryAccessorBuilder extends AccessorBuilder {
override type Lens[_, From, To] = Field[From, To]
override type Prism[_, From, To] = Unit
override type Traversal[From, To] = Unit

override def makeLens[F, S, A](product: Schema.Record[S], term: Schema.Field[S, A]): Lens[F, S, A] = {
val label = Annotation.maybeName(term.annotations).getOrElse(term.name)
Field[S, A](None, label)
}

override def makePrism[F, S, A](sum: Schema.Enum[S], term: Schema.Case[S, A]): Prism[F, S, A] = ()

override def makeTraversal[S, A](collection: Schema.Collection[S, A], element: Schema[A]): Traversal[S, A] = ()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package zio.elasticsearch

import zio.schema.{DeriveSchema, Schema}
import zio.test.{Spec, TestEnvironment, ZIOSpecDefault, assertTrue}

object FieldDSLSpec extends ZIOSpecDefault {

final case class Address(street: String, number: Int)

object Address {

implicit val schema: Schema.CaseClass2[String, Int, Address] = DeriveSchema.gen[Address]

val (street, number) = schema.makeAccessors(ElasticQueryAccessorBuilder)
}

final case class Student(name: String, address: Address)

object Student {

implicit val schema: Schema.CaseClass2[String, Address, Student] = DeriveSchema.gen[Student]

val (name, address) = schema.makeAccessors(ElasticQueryAccessorBuilder)
}

override def spec: Spec[TestEnvironment, Any] =
suite("Field DSL")(
test("properly encode single field path")(
assertTrue(Field(None, "name").toString == "name")
),
test("properly encode single field path using accessor")(
assertTrue(Student.name.toString == "name")
),
test("properly encode nested field path")(
assertTrue(Field[Nothing, Nothing](Some(Field(None, "address")), "number").toString == "address.number")
),
test("properly encode nested field path using accessors")(
assertTrue((Student.address / Address.number).toString == "address.number")
)
)
}
141 changes: 135 additions & 6 deletions modules/library/src/test/scala/zio/elasticsearch/QueryDSLSpec.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zio.elasticsearch

import zio.Scope
import zio.elasticsearch.Annotation.name
import zio.elasticsearch.ElasticQuery._
import zio.elasticsearch.ElasticRequest.BulkRequest
import zio.elasticsearch.utils._
Expand All @@ -12,6 +13,19 @@ import zio.test._

object QueryDSLSpec extends ZIOSpecDefault {

final case class Student(
name: String,
age: Int,
@name("is_employed")
isEmployed: Boolean
)

object Student {

implicit val schema: Schema.CaseClass3[String, Int, Boolean, Student] = DeriveSchema.gen[Student]

val (name, age, isEmployed) = schema.makeAccessors(ElasticQueryAccessorBuilder)
}
final case class UserDocument(id: String, name: String, address: String, balance: Double, age: Int)

object UserDocument {
Expand All @@ -30,6 +44,20 @@ object QueryDSLSpec extends ZIOSpecDefault {
assert(queryBool)(equalTo(MatchQuery(field = "day_of_week", value = true))) &&
assert(queryLong)(equalTo(MatchQuery(field = "day_of_week", value = 1)))
},
test("successfully create type-safe Match query using `matches` method") {
val queryString = matches(field = Student.name, value = "James")
val queryInt = matches(field = Student.age, value = 20)
val queryBool = matches(field = Student.isEmployed, value = true)

assert(queryString)(equalTo(MatchQuery(field = "name", value = "Monday")))
mvelimir marked this conversation as resolved.
Show resolved Hide resolved
assert(queryInt)(equalTo(MatchQuery(field = "age", value = 20)))
assert(queryBool)(equalTo(MatchQuery(field = "is_employed", value = true)))
},
test("successfully create type-safe Match query with multi-field using `matches` method") {
val query = matches(field = Student.name, multiField = Some("keyword"), value = "James")

assert(query)(equalTo(MatchQuery(field = "name.keyword", value = "James")))
},
test("successfully create `Must` query from two Match queries") {
val query = boolQuery()
.must(matches(field = "day_of_week", value = "Monday"), matches(field = "customer_gender", value = "MALE"))
Expand Down Expand Up @@ -107,6 +135,11 @@ object QueryDSLSpec extends ZIOSpecDefault {

assert(query)(equalTo(ExistsQuery(field = "day_of_week")))
},
test("successfully create Exists Query with accessor") {
val query = exists(field = Student.name)

assert(query)(equalTo(ExistsQuery(field = "name")))
},
test("successfully create MatchAll Query") {
val query = matchAll()

Expand All @@ -120,35 +153,117 @@ object QueryDSLSpec extends ZIOSpecDefault {
test("successfully create empty Range Query") {
val query = range(field = "customer_age")

assert(query)(equalTo(RangeQuery(field = "customer_age", lower = Unbounded, upper = Unbounded)))
assert(query)(
equalTo(
RangeQuery[Any, Unbounded.type, Unbounded.type](
field = "customer_age",
lower = Unbounded,
upper = Unbounded
)
)
)
},
test("successfully create empty type-safe Range Query") {
val queryString = range(field = Student.name)
val queryInt = range(field = Student.age)
val queryBool = range(field = Student.isEmployed)

assert(queryString)(
equalTo(
RangeQuery[String, Unbounded.type, Unbounded.type](field = "name", lower = Unbounded, upper = Unbounded)
)
)
mvelimir marked this conversation as resolved.
Show resolved Hide resolved
assert(queryInt)(
equalTo(
RangeQuery[Int, Unbounded.type, Unbounded.type](field = "age", lower = Unbounded, upper = Unbounded)
)
)
assert(queryBool)(
equalTo(
RangeQuery[Boolean, Unbounded.type, Unbounded.type](
field = "is_employed",
lower = Unbounded,
upper = Unbounded
)
)
)
},
test("successfully create empty type-safe Range Query with multi-field") {
val query = range(field = Student.name, multiField = Some("keyword"))

assert(query)(
equalTo(
RangeQuery[String, Unbounded.type, Unbounded.type](
field = "name.keyword",
lower = Unbounded,
upper = Unbounded
)
)
)
},
test("successfully create Range Query with upper bound") {
val query = range(field = "customer_age").lt(23)

assert(query)(equalTo(RangeQuery(field = "customer_age", lower = Unbounded, upper = LessThan(23))))
assert(query)(
equalTo(
RangeQuery[Int, Unbounded.type, LessThan[Int]](
field = "customer_age",
lower = Unbounded,
upper = LessThan(23)
)
)
)
},
test("successfully create Range Query with lower bound") {
val query = range(field = "customer_age").gt(23)

assert(query)(equalTo(RangeQuery(field = "customer_age", lower = GreaterThan(23), upper = Unbounded)))
assert(query)(
equalTo(
RangeQuery[Int, GreaterThan[Int], Unbounded.type](
field = "customer_age",
lower = GreaterThan(23),
upper = Unbounded
)
)
)
},
test("successfully create Range Query with inclusive upper bound") {
val query = range(field = "customer_age").lte(23)

assert(query)(equalTo(RangeQuery(field = "customer_age", lower = Unbounded, upper = LessThanOrEqualTo(23))))
assert(query)(
equalTo(
RangeQuery[Int, Unbounded.type, LessThanOrEqualTo[Int]](
field = "customer_age",
lower = Unbounded,
upper = LessThanOrEqualTo(23)
)
)
)
},
test("successfully create Range Query with inclusive lower bound") {
val query = range(field = "customer_age").gte(23)

assert(query)(
equalTo(RangeQuery(field = "customer_age", lower = GreaterThanOrEqualTo(23), upper = Unbounded))
equalTo(
RangeQuery[Int, GreaterThanOrEqualTo[Int], Unbounded.type](
field = "customer_age",
lower = GreaterThanOrEqualTo(23),
upper = Unbounded
)
)
)
},
test("successfully create Range Query with both upper and lower bound") {
val query = range(field = "customer_age").gte(23).lt(50)

assert(query)(
equalTo(RangeQuery(field = "customer_age", lower = GreaterThanOrEqualTo(23), upper = LessThan(50)))
equalTo(
RangeQuery[Int, GreaterThanOrEqualTo[Int], LessThan[Int]](
field = "customer_age",
lower = GreaterThanOrEqualTo(23),
upper = LessThan(50)
)
)
)
},
test("successfully create Term Query") {
Expand All @@ -162,6 +277,20 @@ object QueryDSLSpec extends ZIOSpecDefault {
assert(queryBool)(equalTo(TermQuery(field = "day_of_week", value = true))) &&
assert(queryLong)(equalTo(TermQuery(field = "day_of_week", value = 1L)))
},
test("successfully create type-safe Term Query") {
val queryString = term(field = Student.name, value = "James")
val queryInt = term(field = Student.age, value = 20)
val queryBool = term(field = Student.isEmployed, value = true)

assert(queryString)(equalTo(TermQuery(field = "name", value = "Monday")))
mvelimir marked this conversation as resolved.
Show resolved Hide resolved
assert(queryInt)(equalTo(TermQuery(field = "age", value = 20)))
assert(queryBool)(equalTo(TermQuery(field = "is_employed", value = true)))
},
test("successfully create type-safe Term Query with multi-field") {
val query = term(field = Student.name, multiField = Some("keyword"), value = "James")

assert(query)(equalTo(TermQuery(field = "name.keyword", value = "James")))
},
test("successfully create Term Query with boost") {
val queryInt = term(field = "day_of_week", value = 1).boost(1.0)
val queryString = term(field = "day_of_week", value = "Monday").boost(1.0)
Expand Down