Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(dsl): Offer type-safe query API #39

Merged
merged 21 commits into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,38 @@ object ElasticQuery {
def toJson(implicit EP: ElasticPrimitive[A]): Json = EP.toJson(value)
}

def matches[S, A: ElasticPrimitive](
field: Selection[S, A],
multiField: Option[String] = None,
value: A
): ElasticQuery[Match] =
MatchQuery(field.toString ++ multiField.map("." ++ _).getOrElse(""), value)

def matches[A: ElasticPrimitive](field: String, value: A): ElasticQuery[Match] =
MatchQuery(field, value)

def boolQuery(): BoolQuery = BoolQuery.empty

def exists(field: Selection[_, _], multiField: Option[String] = None): ElasticQuery[Exists] = ExistsQuery(
field.toString ++ multiField.map("." ++ _).getOrElse("")
)

def exists(field: String): ElasticQuery[Exists] = ExistsQuery(field)

def matchAll(): ElasticQuery[MatchAll] = MatchAllQuery()

def range(field: Selection[_, _], multiField: Option[String] = None): RangeQuery[Unbounded.type, Unbounded.type] =
RangeQuery.empty(field.toString ++ multiField.map("." ++ _).getOrElse(""))

def range(field: String): RangeQuery[Unbounded.type, Unbounded.type] = RangeQuery.empty(field)

def term[S, A: ElasticPrimitive](
field: Selection[S, A],
multiField: Option[String] = None,
value: A
): ElasticQuery[Term[A]] =
TermQuery(field.toString ++ multiField.map("." ++ _).getOrElse(""), value)

def term[A: ElasticPrimitive](field: String, value: A): ElasticQuery[Term[A]] = TermQuery(field, value)

private[elasticsearch] final case class BoolQuery(must: List[ElasticQuery[_]], should: List[ElasticQuery[_]])
Expand Down
49 changes: 49 additions & 0 deletions modules/library/src/main/scala/zio/elasticsearch/Selection.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package zio.elasticsearch

import zio.Chunk
import zio.schema.{AccessorBuilder, Schema}

import scala.annotation.tailrec

object Annotation {
final case class name(value: String) extends scala.annotation.Annotation

def maybeName(annotations: Chunk[Any]): Option[String] =
annotations.collect { case name(value) => value }.headOption
}

sealed trait Selection[-From, +To] { self =>

def /[To2](that: Selection[To, To2]): Selection[From, To2] = that match {
case Field(parent, key) => Field(parent.map(self / _).orElse(Some(self)), key)
mvelimir marked this conversation as resolved.
Show resolved Hide resolved
}

override def toString: String = {
@tailrec
def loop(selection: Selection[_, _], acc: List[String]): List[String] =
selection match {
case Field(None, name) => acc :+ s"$name"
case Field(Some(parent), name) => loop(parent, acc :+ s".$name")
}

loop(self, List.empty).reverse.mkString("")
mvelimir marked this conversation as resolved.
Show resolved Hide resolved
}
}

private[elasticsearch] final case class Field[From, To](parent: Option[Selection[From, _]], name: String)
extends Selection[From, To]

object ElasticQueryAccessorBuilder extends AccessorBuilder {
override type Lens[_, From, To] = Selection[From, To]
override type Prism[_, From, To] = Unit
override type Traversal[From, To] = Unit

override def makeLens[F, S, A](product: Schema.Record[S], term: Schema.Field[S, A]): Lens[F, S, A] = {
val label = Annotation.maybeName(term.annotations).getOrElse(term.name)
Field[S, A](None, label)
}

override def makePrism[F, S, A](sum: Schema.Enum[S], term: Schema.Case[S, A]): Prism[F, S, A] = ()

override def makeTraversal[S, A](collection: Schema.Collection[S, A], element: Schema[A]): Traversal[S, A] = ()
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
package zio.elasticsearch

import zio.Scope
import zio.elasticsearch.Annotation.name
import zio.elasticsearch.ElasticQuery._
import zio.elasticsearch.utils._
import zio.schema.{DeriveSchema, Schema}
import zio.test.Assertion.equalTo
import zio.test._

object QueryDSLSpec extends ZIOSpecDefault {

final case class Student(
name: String,
age: Int,
@name("is_employed")
isEmployed: Boolean
)

object Student {

implicit val schema: Schema.CaseClass3[String, Int, Boolean, Student] = DeriveSchema.gen[Student]

val (name, age, isEmployed) = schema.makeAccessors(ElasticQueryAccessorBuilder)
}

override def spec: Spec[Environment with TestEnvironment with Scope, Any] =
suite("Query DSL")(
suite("creating ElasticQuery")(
Expand All @@ -19,6 +36,20 @@ object QueryDSLSpec extends ZIOSpecDefault {
assert(queryBool)(equalTo(MatchQuery(field = "day_of_week", value = true)))
assert(queryLong)(equalTo(MatchQuery(field = "day_of_week", value = 1)))
},
test("successfully create type-safe Match query using `matches` method") {
val queryString = matches(field = Student.name, value = "James")
val queryInt = matches(field = Student.age, value = 20)
val queryBool = matches(field = Student.isEmployed, value = true)

assert(queryString)(equalTo(MatchQuery(field = "name", value = "Monday")))
mvelimir marked this conversation as resolved.
Show resolved Hide resolved
assert(queryInt)(equalTo(MatchQuery(field = "age", value = 20)))
assert(queryBool)(equalTo(MatchQuery(field = "is_employed", value = true)))
},
test("successfully create type-safe Match query with multi-field using `matches` method") {
val query = matches(field = Student.name, multiField = Some("keyword"), value = "James")

assert(query)(equalTo(MatchQuery(field = "name.keyword", value = "James")))
},
test("successfully create `Must` query from two Match queries") {
val query = boolQuery()
.must(matches(field = "day_of_week", value = "Monday"), matches(field = "customer_gender", value = "MALE"))
Expand Down Expand Up @@ -96,6 +127,16 @@ object QueryDSLSpec extends ZIOSpecDefault {

assert(query)(equalTo(ExistsQuery(field = "day_of_week")))
},
test("successfully create Exists Query with accessor") {
val query = exists(field = Student.name)

assert(query)(equalTo(ExistsQuery(field = "name")))
},
test("successfully create Exists Query with accessor and multi-field") {
val query = exists(field = Student.name, multiField = Some("keyword"))

assert(query)(equalTo(ExistsQuery(field = "name.keyword")))
},
test("successfully create MatchAll Query") {
val query = matchAll()

Expand All @@ -111,6 +152,20 @@ object QueryDSLSpec extends ZIOSpecDefault {

assert(query)(equalTo(RangeQuery(field = "customer_age", lower = Unbounded, upper = Unbounded)))
},
test("successfully create empty type-safe Range Query") {
val queryString = range(field = Student.name)
val queryInt = range(field = Student.age)
val queryBool = range(field = Student.isEmployed)

assert(queryString)(equalTo(RangeQuery(field = "name", lower = Unbounded, upper = Unbounded)))
assert(queryInt)(equalTo(RangeQuery(field = "age", lower = Unbounded, upper = Unbounded)))
assert(queryBool)(equalTo(RangeQuery(field = "is_employed", lower = Unbounded, upper = Unbounded)))
},
test("successfully create empty type-safe Range Query with multi-field") {
val query = range(field = Student.name, multiField = Some("keyword"))

assert(query)(equalTo(RangeQuery(field = "name.keyword", lower = Unbounded, upper = Unbounded)))
},
test("successfully create Range Query with upper bound") {
val query = range(field = "customer_age").lt(23)

Expand Down Expand Up @@ -151,6 +206,20 @@ object QueryDSLSpec extends ZIOSpecDefault {
assert(queryBool)(equalTo(TermQuery(field = "day_of_week", value = true)))
assert(queryLong)(equalTo(TermQuery(field = "day_of_week", value = 1L)))
},
test("successfully create type-safe Term Query") {
val queryString = term(field = Student.name, value = "James")
val queryInt = term(field = Student.age, value = 20)
val queryBool = term(field = Student.isEmployed, value = true)

assert(queryString)(equalTo(TermQuery(field = "name", value = "Monday")))
mvelimir marked this conversation as resolved.
Show resolved Hide resolved
assert(queryInt)(equalTo(TermQuery(field = "age", value = 20)))
assert(queryBool)(equalTo(TermQuery(field = "is_employed", value = true)))
},
test("successfully create type-safe Term Query with multi-field") {
val query = term(field = Student.name, multiField = Some("keyword"), value = "James")

assert(query)(equalTo(TermQuery(field = "name.keyword", value = "James")))
},
test("successfully create Term Query with boost") {
val queryInt = term(field = "day_of_week", value = 1).boost(1.0)
val queryString = term(field = "day_of_week", value = "Monday").boost(1.0)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package zio.elasticsearch

import zio.schema.{DeriveSchema, Schema}
import zio.test.{Spec, TestEnvironment, ZIOSpecDefault, assertTrue}

object SelectionDSLSpec extends ZIOSpecDefault {

final case class Address(street: String, number: Int)

object Address {

implicit val schema: Schema.CaseClass2[String, Int, Address] = DeriveSchema.gen[Address]

val (street, number) = schema.makeAccessors(ElasticQueryAccessorBuilder)
}

final case class Student(name: String, address: Address)

object Student {

implicit val schema: Schema.CaseClass2[String, Address, Student] = DeriveSchema.gen[Student]

val (name, address) = schema.makeAccessors(ElasticQueryAccessorBuilder)
}

override def spec: Spec[TestEnvironment, Any] =
suite("Selection DSL")(
test("properly encode single field path")(
assertTrue(Field(None, "name").toString == "name")
),
test("properly encode single field path using accessor")(
assertTrue(Student.name.toString == "name")
),
test("properly encode nested field path")(
assertTrue(Field(Some(Field(None, "address")), "number").toString == "address.number")
),
test("properly encode nested field path using accessors")(
assertTrue((Student.address / Address.number).toString == "address.number")
)
)
}