Skip to content

Commit

Permalink
Improvements to filtering support with !==
Browse files Browse the repository at this point in the history
  • Loading branch information
darkfrog26 committed Aug 23, 2024
1 parent a212773 commit 01e2ea9
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 2 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ val developerURL: String = "https://matthicks.com"

name := projectName
ThisBuild / organization := org
ThisBuild / version := "0.12.0"
ThisBuild / version := "0.12.1-SNAPSHOT"
ThisBuild / scalaVersion := scala213
ThisBuild / crossScalaVersions := allScalaVersions
ThisBuild / scalacOptions ++= Seq("-unchecked", "-deprecation")
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/lightdb/Field.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ sealed class Field[Doc, V](val name: String,

override def is(value: V): Filter[Doc] = Filter.Equals(this, value)

override def !==(value: V): Filter[Doc] = Filter.NotEquals(this, value)

override protected def rangeLong(from: Option[Long], to: Option[Long]): Filter[Doc] =
Filter.RangeLong(this.asInstanceOf[Field[Doc, Long]], from, to)

Expand Down
4 changes: 4 additions & 0 deletions core/src/main/scala/lightdb/aggregate/AggregateFilter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ object AggregateFilter {
def getJson: Json = field.rw.read(value)
}

case class NotEquals[Doc, F](name: String, field: Field[Doc, F], value: F) extends AggregateFilter[Doc] {
def getJson: Json = field.rw.read(value)
}

case class In[Doc, F](name: String, field: Field[Doc, F], values: Seq[F]) extends AggregateFilter[Doc] {
def getJson: List[Json] = values.toList.map(field.rw.read)
}
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/lightdb/aggregate/AggregateFunction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ case class AggregateFunction[T, V, Doc](name: String, field: Field[Doc, V], `typ

override def is(value: V): AggregateFilter[Doc] = AggregateFilter.Equals(name, field, value)

override def !==(value: V): AggregateFilter[Doc] = AggregateFilter.NotEquals(name, field, value)

override protected def rangeLong(from: Option[Long], to: Option[Long]): AggregateFilter[Doc] =
AggregateFilter.RangeLong(name, field.asInstanceOf[Field[Doc, Long]], from, to)

Expand Down
6 changes: 6 additions & 0 deletions core/src/main/scala/lightdb/filter/Filter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ object Filter {
override lazy val fields: List[Field[Doc, _]] = List(field)
}

case class NotEquals[Doc, F](field: Field[Doc, F], value: F) extends Filter[Doc] {
def getJson: Json = field.rw.read(value)

override lazy val fields: List[Field[Doc, _]] = List(field)
}

case class In[Doc, F](field: Field[Doc, F], values: Seq[F]) extends Filter[Doc] {
def getJson: List[Json] = values.toList.map(field.rw.read)

Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/lightdb/filter/FilterSupport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ trait FilterSupport[F, Doc, Filter] {
def ===(value: F): Filter = is(value)
def is(value: F): Filter

def !==(value: F): Filter

def >(value: F)(implicit num: Numeric[F]): Filter = range(Some(value), None, includeFrom = false)
def >=(value: F)(implicit num: Numeric[F]): Filter = range(Some(value), None)
def <(value: F)(implicit num: Numeric[F]): Filter = range(None, Some(value), includeTo = false)
Expand Down
6 changes: 6 additions & 0 deletions core/src/test/scala/spec/AbstractBasicSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ abstract class AbstractBasicSpec extends AnyWordSpec with Matchers { spec =>
ids.toSet should be(Set(adam._id, nancy._id, oscar._id, uba._id))
}
}
"search excluding age 30" in {
db.people.transaction { implicit transaction =>
val names = db.people.query.filter(_.age !== 30).toList.map(_.name).toSet
names should be(Set("Linda", "Ruth", "Nancy", "Jenna", "Hanna", "Diana", "Ian", "Zoey", "Quintin", "Uba", "Oscar", "Kevin", "Penny", "Charlie", "Evan", "Sam", "Mike", "Brenda", "Adam", "Xena", "Fiona", "Greg", "Veronica"))
}
}
"sort by age" in {
db.people.transaction { implicit transaction =>
val people = db.people.query.sort(Sort.ByField(Person.age).descending).search.docs.list
Expand Down
5 changes: 5 additions & 0 deletions lucene/src/main/scala/lightdb/lucene/LuceneStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ class LuceneStore[Doc <: Document[Doc], Model <: DocumentModel[Doc]](directory:
private def filter2Lucene(filter: Option[Filter[Doc]]): LuceneQuery = filter match {
case Some(f) => f match {
case f: Filter.Equals[Doc, _] => exactQuery(f.field, f.getJson)
case f: Filter.NotEquals[Doc, _] =>
val b = new BooleanQuery.Builder
b.add(new MatchAllDocsQuery, BooleanClause.Occur.MUST)
b.add(exactQuery(f.field, f.getJson), BooleanClause.Occur.MUST_NOT)
b.build()
case f: Filter.In[Doc, _] =>
val queries = f.getJson.map(json => exactQuery(f.field, json))
val b = new BooleanQuery.Builder
Expand Down
3 changes: 3 additions & 0 deletions sql/src/main/scala/lightdb/sql/SQLArg.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ object SQLArg {
case l: Long => ps.setLong(index, l)
case f: Float => ps.setFloat(index, f)
case d: Double => ps.setDouble(index, d)
case bd: BigDecimal => ps.setDouble(index, bd.toDouble)
case json: Json => ps.setString(index, JsonFormatter.Compact(json))
case point: GeoPoint => ps.setString(index, s"POINT(${point.longitude} ${point.latitude})")
case _ =>
Expand All @@ -44,6 +45,8 @@ object SQLArg {
}

override def set(ps: PreparedStatement, index: Int): Unit = setInternal(ps, index, value)

override def toString: String = s"FieldArg(field = ${field.name}, value = $value (${value.getClass.getName}))"
}

object FieldArg {
Expand Down
15 changes: 14 additions & 1 deletion sql/src/main/scala/lightdb/sql/SQLStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import fabric._
import fabric.define.DefType
import fabric.io.{JsonFormatter, JsonParser}
import fabric.rw._
import lightdb.aggregate.{AggregateFilter, AggregateQuery, AggregateType}
import lightdb.aggregate.{AggregateFilter, AggregateFunction, AggregateQuery, AggregateType}
import lightdb.collection.Collection
import lightdb.distance.Distance
import lightdb.doc.{Document, DocumentModel, JsonConversion}
Expand Down Expand Up @@ -398,6 +398,9 @@ abstract class SQLStore[Doc <: Document[Doc], Model <: DocumentModel[Doc]] exten
case Sort.ByField(field, direction) =>
val dir = if (direction == SortDirection.Descending) "DESC" else "ASC"
SQLPart(s"${field.name} $dir", Nil)
case (AggregateFunction(name, _, _), direction: SortDirection) =>
val dir = if (direction == SortDirection.Descending) "DESC" else "ASC"
SQLPart(s"$name $dir", Nil)
}
SQLQueryBuilder(
collection = collection,
Expand Down Expand Up @@ -477,6 +480,15 @@ abstract class SQLStore[Doc <: Document[Doc], Model <: DocumentModel[Doc]] exten
SQLPart.merge(parts: _*)
case f: Filter.Equals[Doc, _] if f.value == null | f.value == None => SQLPart(s"${f.field.name} IS NULL")
case f: Filter.Equals[Doc, _] => SQLPart(s"${f.field.name} = ?", List(SQLArg.FieldArg(f.field, f.value)))
case f: Filter.NotEquals[Doc, _] if f.field.isArr =>
val values = f.getJson.asVector
val parts = values.map { json =>
val jsonString = JsonFormatter.Compact(json)
SQLPart(s"${f.field.name} NOT LIKE ?", List(SQLArg.StringArg(s"%$jsonString%")))
}
SQLPart.merge(parts: _*)
case f: Filter.NotEquals[Doc, _] if f.value == null | f.value == None => SQLPart(s"${f.field.name} IS NOT NULL")
case f: Filter.NotEquals[Doc, _] => SQLPart(s"${f.field.name} != ?", List(SQLArg.FieldArg(f.field, f.value)))
case f: Filter.In[Doc, _] => SQLPart(s"${f.field.name} IN (${f.values.map(_ => "?").mkString(", ")})", f.values.toList.map(v => SQLArg.FieldArg(f.field, v)))
case f: Filter.RangeLong[Doc] => (f.from, f.to) match {
case (Some(from), Some(to)) => SQLPart(s"${f.field.name} BETWEEN ? AND ?", List(SQLArg.LongArg(from), SQLArg.LongArg(to)))
Expand Down Expand Up @@ -524,6 +536,7 @@ abstract class SQLStore[Doc <: Document[Doc], Model <: DocumentModel[Doc]] exten

private def af2Part(f: AggregateFilter[Doc]): SQLPart = f match {
case f: AggregateFilter.Equals[Doc, _] => SQLPart(s"${f.name} = ?", List(SQLArg.FieldArg(f.field, f.value)))
case f: AggregateFilter.NotEquals[Doc, _] => SQLPart(s"${f.name} != ?", List(SQLArg.FieldArg(f.field, f.value)))
case f: AggregateFilter.In[Doc, _] => SQLPart(s"${f.name} IN (${f.values.map(_ => "?").mkString(", ")})", f.values.toList.map(v => SQLArg.FieldArg(f.field, v)))
case f: AggregateFilter.Combined[Doc] =>
val parts = f.filters.map(f => af2Part(f))
Expand Down

0 comments on commit 01e2ea9

Please sign in to comment.