Improved aggregation to support AggregateQuery
darkfrog26 committed Jun 10, 2024
1 parent 53c4dc8 commit 5bdd5b1
Showing 7 changed files with 34 additions and 26 deletions.
6 changes: 5 additions & 1 deletion all/src/test/scala/spec/AggregationSpec.scala
@@ -63,6 +63,7 @@ class AggregationSpec extends AsyncWordSpec with AsyncIOSpec with Matchers {
       Person.query
         .filter(Person.age <=> (5, 16))
         .aggregate(max, min, avg, sum, count)
+        .stream
         .compile
         .toList
         .map { list =>
@@ -98,7 +99,10 @@ class AggregationSpec extends AsyncWordSpec with AsyncIOSpec with Matchers {
       // TODO: HAVING ageCount > 1
       // TODO: ORDER BY ageCount DESC
       Person.withSearchContext { implicit context =>
-        Person.query.aggregate(ids, names, age, count).compile.toList.map { list =>
+        Person.query
+          .aggregate(ids, names, age, count)
+          .toList
+          .map { list =>
           // list.map(_(ids)) should be(Nil)
           // list.map(_(names)) should be(Nil)
           // list.map(_(age)) should be(Nil)
2 changes: 1 addition & 1 deletion all/src/test/scala/spec/SimpleHaloAndSQLiteSpec.scala
@@ -196,7 +196,7 @@ class SimpleHaloAndSQLiteSpec extends AsyncWordSpec with AsyncIOSpec with Matchers {
         maxAge,
         avgAge,
         sumAge
-      ).compile.toList.map { list =>
+      ).stream.compile.toList.map { list =>
         list.map(m => m(minAge)).toSet should be(Set(19))
         list.map(m => m(maxAge)).toSet should be(Set(21))
         list.map(m => m(avgAge)).toSet should be(Set(20.0))
15 changes: 15 additions & 0 deletions core/src/main/scala/lightdb/aggregate/AggregateQuery.scala
@@ -0,0 +1,15 @@
+package lightdb.aggregate
+
+import cats.effect.IO
+import lightdb.Document
+import lightdb.index.Materialized
+import lightdb.query.{Query, SearchContext}
+
+case class AggregateQuery[D <: Document[D]](query: Query[D, _],
+                                            functions: List[AggregateFunction[_, D]]) {
+  def stream(implicit context: SearchContext[D]): fs2.Stream[IO, Materialized[D]] = query.indexSupport.aggregate(this)
+
+  def toList: IO[List[Materialized[D]]] = query.indexSupport.withSearchContext { implicit context =>
+    stream.compile.toList
+  }
+}
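
The new type splits aggregation into construction and execution: Query.aggregate(...) now only builds an AggregateQuery, and the caller chooses stream (explicit SearchContext) or toList (context managed internally). A usage sketch based on the updated tests in this commit, reusing their Person collection and aggregate-function values:

// Streaming form: the caller supplies the SearchContext and compiles the stream.
Person.withSearchContext { implicit context =>
  Person.query
    .filter(Person.age <=> (5, 16))
    .aggregate(max, min, avg, sum, count)
    .stream   // fs2.Stream[IO, Materialized[Person]]
    .compile
    .toList
}

// Convenience form: toList opens and closes a SearchContext internally.
Person.query
  .aggregate(ids, names, age, count)
  .toList     // IO[List[Materialized[Person]]]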
6 changes: 2 additions & 4 deletions core/src/main/scala/lightdb/index/IndexSupport.scala
@@ -4,7 +4,7 @@ import cats.effect.IO
 import lightdb.model.{AbstractCollection, Collection, DocumentAction, DocumentListener, DocumentModel}
 import lightdb.query.{Filter, PagedResults, Query, SearchContext}
 import lightdb.Document
-import lightdb.aggregate.AggregateFunction
+import lightdb.aggregate.{AggregateFunction, AggregateQuery}
 import lightdb.spatial.GeoPoint
 import squants.space.Length

@@ -52,9 +52,7 @@ trait IndexSupport[D <: Document[D]] extends DocumentModel[D] {
                           limit: Option[Int],
                           after: Option[PagedResults[D, V]]): IO[PagedResults[D, V]]
 
-  def aggregate[V](query: Query[D, V],
-                   functions: List[AggregateFunction[_, D]],
-                   context: SearchContext[D]): fs2.Stream[IO, Materialized[D]]
+  def aggregate(query: AggregateQuery[D])(implicit context: SearchContext[D]): fs2.Stream[IO, Materialized[D]]
 
   protected def indexDoc(doc: D, fields: List[Index[_, D]]): IO[Unit]
 }
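
For backend implementors the contract is now a single AggregateQuery parameter plus an implicit context. A minimal sketch of conforming to the new signature, using hypothetical names that are not part of this commit:

import cats.effect.IO
import lightdb.Document
import lightdb.aggregate.AggregateQuery
import lightdb.index.{IndexSupport, Materialized}
import lightdb.query.SearchContext

// Hypothetical backend mixin: a real implementation would translate
// query.functions and query.query.filter into its native aggregation.
trait NoAggregationSupport[D <: Document[D]] extends IndexSupport[D] {
  override def aggregate(query: AggregateQuery[D])
                        (implicit context: SearchContext[D]): fs2.Stream[IO, Materialized[D]] =
    fs2.Stream.empty
}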
9 changes: 2 additions & 7 deletions core/src/main/scala/lightdb/query/Query.scala
@@ -2,7 +2,7 @@ package lightdb.query
 
 import cats.Eq
 import cats.effect.IO
-import lightdb.aggregate.AggregateFunction
+import lightdb.aggregate.{AggregateFunction, AggregateQuery}
 import lightdb.index.{Index, IndexSupport, Materialized}
 import lightdb.model.AbstractCollection
 import lightdb.spatial.GeoPoint
@@ -105,12 +105,7 @@ case class Query[D <: Document[D], V](indexSupport: IndexSupport[D],
     copy(materializedIndexes = indexes.toList).pageStream.flatMap(_.materializedStream)
   }
 
-  def aggregate(functions: AggregateFunction[_, D]*)
-               (implicit context: SearchContext[D]): fs2.Stream[IO, Materialized[D]] = indexSupport.aggregate(
-    query = this,
-    functions = functions.toList,
-    context = context
-  )
+  def aggregate(functions: AggregateFunction[_, D]*): AggregateQuery[D] = AggregateQuery[D](this, functions.toList)
 
   def stream(implicit context: SearchContext[D]): fs2.Stream[IO, V] = pageStream.flatMap(_.stream)
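
The call-site consequence: aggregate(...) is now pure construction and no longer needs an implicit SearchContext; the context is required only when the result is materialized. A hedged before/after sketch using the tests' Person model:

// Before this commit: a SearchContext had to be in scope just to call aggregate.
// Person.withSearchContext { implicit context =>
//   Person.query.aggregate(max, min).compile.toList
// }

// After: construction is context-free; toList manages the context itself.
val aggregated: AggregateQuery[Person] = Person.query.aggregate(max, min)
aggregated.toList // IO[List[Materialized[Person]]]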
6 changes: 2 additions & 4 deletions lucene/src/main/scala/lightdb/lucene/LuceneSupport.scala
@@ -3,7 +3,7 @@ package lightdb.lucene
 import cats.effect.IO
 import fabric.define.DefType
 import lightdb._
-import lightdb.aggregate.AggregateFunction
+import lightdb.aggregate.{AggregateFunction, AggregateQuery}
 import lightdb.index.{Index, IndexSupport, Materialized}
 import lightdb.model.AbstractCollection
 import lightdb.query.{Filter, PageContext, PagedResults, Query, SearchContext, Sort, SortDirection}
@@ -78,9 +78,7 @@ trait LuceneSupport[D <: Document[D]] extends IndexSupport[D] {
       _ = index.addDoc(doc._id, fields)
     } yield ()
 
-  override def aggregate[V](query: Query[D, V],
-                            functions: List[AggregateFunction[_, D]],
-                            context: SearchContext[D]): fs2.Stream[IO, Materialized[D]] =
+  override def aggregate(query: AggregateQuery[D])(implicit context: SearchContext[D]): fs2.Stream[IO, Materialized[D]] =
     throw new UnsupportedOperationException("Aggregate functions not supported in Lucene currently")
 
   override protected[lightdb] def initModel(collection: AbstractCollection[D]): Unit = {
16 changes: 7 additions & 9 deletions sql/src/main/scala/lightdb/sql/SQLSupport.scala
@@ -3,7 +3,7 @@ package lightdb.sql
 import cats.effect.IO
 import fabric._
 import fabric.io.JsonFormatter
-import lightdb.aggregate.{AggregateFunction, AggregateType}
+import lightdb.aggregate.{AggregateFunction, AggregateQuery, AggregateType}
 import lightdb.{Document, Id}
 import lightdb.index.{Index, IndexSupport, Materialized}
 import lightdb.model.AbstractCollection
@@ -206,27 +206,25 @@ trait SQLSupport[D <: Document[D]] extends IndexSupport[D] {
   override protected def indexDoc(doc: D, fields: List[Index[_, D]]): IO[Unit] =
     backlog.enqueue(doc._id, doc).map(_ => ())
 
-  override def aggregate[V](query: Query[D, V],
-                            functions: List[AggregateFunction[_, D]],
-                            context: SearchContext[D]): fs2.Stream[IO, Materialized[D]] = {
+  override def aggregate(query: AggregateQuery[D])(implicit context: SearchContext[D]): fs2.Stream[IO, Materialized[D]] = {
     val io = IO.blocking {
       var params = List.empty[Json]
-      val filters = query.filter match {
+      val filters = query.query.filter match {
         case Some(f) =>
           val filter = f.asInstanceOf[SQLPart]
           params = params ::: filter.args
           s"WHERE\n  ${filter.sql}"
         case None => ""
       }
-      val sort = query.sort.collect {
+      val sort = query.query.sort.collect {
         case Sort.ByField(field, direction) =>
           val dir = if (direction == SortDirection.Descending) "DESC" else "ASC"
           s"${field.fieldName} $dir"
       } match {
         case Nil => ""
         case list => list.mkString("ORDER BY ", ", ", "")
       }
-      val fieldNames = functions.map { f =>
+      val fieldNames = query.functions.map { f =>
         val af = f.`type` match {
           case AggregateType.Max => Some("MAX")
           case AggregateType.Min => Some("MIN")
@@ -242,7 +240,7 @@ trait SQLSupport[D <: Document[D]] extends IndexSupport[D] {
         }
         s"$fieldName AS ${f.name}"
       }.mkString(", ")
-      val group = functions.filter(_.`type` == AggregateType.Group).map(_.fieldName).distinct match {
+      val group = query.functions.filter(_.`type` == AggregateType.Group).map(_.fieldName).distinct match {
         case Nil => ""
         case list =>
           s"""GROUP BY
@@ -260,7 +258,7 @@ trait SQLSupport[D <: Document[D]] extends IndexSupport[D] {
       scribe.info(s"SQL: $sql")
       val ps = prepare(sql, params)
       val rs = ps.executeQuery()
-      val iterator = materializedIterator(rs, functions.map(_.name))
+      val iterator = materializedIterator(rs, query.functions.map(_.name))
       fs2.Stream.fromBlockingIterator[IO](iterator, 512)
     }
     fs2.Stream.force(io)
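
To make the statement assembly above concrete, here is a standalone, simplified sketch of how the SELECT list and GROUP BY clause fall out of an AggregateQuery's functions. The placeholder types, the Count case, and the person table name are assumptions for illustration, not LightDB's actual API:

// Simplified, self-contained model of the SELECT/GROUP BY assembly in aggregate(...).
object AggregateSqlSketch {
  sealed trait AggType
  case object Max extends AggType
  case object Min extends AggType
  case object Avg extends AggType
  case object Sum extends AggType
  case object Count extends AggType
  case object Group extends AggType

  final case class Func(name: String, fieldName: String, tpe: AggType)

  def sql(functions: List[Func], where: Option[String]): String = {
    // Each function becomes "AGG(field) AS name"; Group entries pass the field through.
    val fieldNames = functions.map { f =>
      val expr = f.tpe match {
        case Max   => s"MAX(${f.fieldName})"
        case Min   => s"MIN(${f.fieldName})"
        case Avg   => s"AVG(${f.fieldName})"
        case Sum   => s"SUM(${f.fieldName})"
        case Count => s"COUNT(${f.fieldName})"
        case Group => f.fieldName
      }
      s"$expr AS ${f.name}"
    }.mkString(", ")
    // Distinct Group fields drive the GROUP BY clause, as in the diff above.
    val group = functions.collect { case f if f.tpe == Group => f.fieldName }.distinct match {
      case Nil  => ""
      case list => list.mkString("GROUP BY ", ", ", "")
    }
    List(s"SELECT $fieldNames", "FROM person", where.fold("")("WHERE " + _), group)
      .filter(_.nonEmpty)
      .mkString("\n")
  }

  def main(args: Array[String]): Unit =
    println(sql(
      List(Func("minAge", "age", Min), Func("maxAge", "age", Max), Func("avgAge", "age", Avg), Func("sumAge", "age", Sum)),
      Some("age BETWEEN ? AND ?")
    ))
  // Prints roughly:
  //   SELECT MIN(age) AS minAge, MAX(age) AS maxAge, AVG(age) AS avgAge, SUM(age) AS sumAge
  //   FROM person
  //   WHERE age BETWEEN ? AND ?
}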
