Skip to content

Commit

Permalink
Add tracking of latest SQL string and ExecutionInfo in ZIO
Browse files Browse the repository at this point in the history
  • Loading branch information
deusaquilus committed Oct 30, 2024
1 parent 73feec6 commit f3c136f
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 17 deletions.
11 changes: 11 additions & 0 deletions quill-jdbc-zio/src/main/scala/io/getquill/Diagnostic.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.getquill

import zio.ZIO
import zio.UIO
import io.getquill.context.{ExecutionInfo, ZioQuillLog}

def getLastExecutedQuery(): UIO[Option[String]] =
ZioQuillLog.latestSqlQuery.get

def getLastExecutionInfo(): UIO[Option[ExecutionInfo]] =
ZioQuillLog.latestExecutionInfo.get
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ import javax.sql.DataSource
import zio.Scope

object ZioJdbc {
val SqlAnnotationKey = "quill.sql.latest"

type QIO[T] = ZIO[DataSource, SQLException, T]
type QStream[T] = ZStream[DataSource, SQLException, T]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
package io.getquill.context

import zio.FiberRef
import zio.*

object ZioQuillLog {
val latestExecutionInfo: FiberRef[Option[ExecutionInfo]] =
zio.Unsafe.unsafe {
FiberRef.unsafe.make[Option[ExecutionInfo]](None)
}

val currentExecutionInfo: FiberRef[Option[ExecutionInfo]] =
FiberRef.unsafe.make[Option[ExecutionInfo]](None)(Unsafe.unsafe)
final class ExecutionInfoAware(val executionInfo: () => ExecutionInfo) { self =>
def apply[R, E, A](zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
latestExecutionInfo.set(Some(executionInfo())) *> zio
}

def withExecutionInfo(info: => ExecutionInfo): ExecutionInfoAware =
new ExecutionInfoAware(() => info)

final class ExecutionInfoInformed(val executionInfo: () => ExecutionInfo) { self =>

val latestSqlQuery: FiberRef[Option[String]] =
zio.Unsafe.unsafe {
FiberRef.unsafe.make[Option[String]](None)
}

final class SqlQueryAware(val sqlQuery: () => String) {
self =>
def apply[R, E, A](zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
currentExecutionInfo.locallyWith(_ => Some(executionInfo()))(zio)
latestSqlQuery.set(Some(sqlQuery())) *> zio
}

def withExecutionInfo(info: => ExecutionInfo): ExecutionInfoInformed =
new ExecutionInfoInformed(() => info)
def withSqlQuery(sqlQuery: => String): SqlQueryAware =
new SqlQueryAware(() => sqlQuery)
}

Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ abstract class ZioJdbcUnderlyingContext[+Dialect <: SqlIdiom, +Naming <: NamingS
@targetName("runBatchActionReturningDefault")
inline def run[I, T, A <: Action[I] & QAC[I, T]](inline quoted: Quoted[BatchAction[A]]): ZIO[Connection, SQLException, List[T]] = InternalApi.runBatchActionReturning(quoted, 1)

protected def annotate[R, E, A](zio: ZIO[R, E, A], sql: String, info: ExecutionInfo): ZIO[R, E, A] =
ZioQuillLog.withExecutionInfo(info)(ZIO.logAnnotate(ZioJdbc.SqlAnnotationKey, sql)(zio))
protected def annotate[R, E, A](zio: ZIO[R, E, A], sql: => String, info: => ExecutionInfo): ZIO[R, E, A] =
ZioQuillLog.withExecutionInfo(info)(ZioQuillLog.withSqlQuery(sql)(zio))

// Need explicit return-type annotations due to scala/bug#8356. Otherwise macro system will not understand Result[Long]=Task[Long] etc...
override def executeAction(sql: String, prepare: Prepare = identityPrepare)(info: ExecutionInfo, dc: Runner): QCIO[Long] =
Expand All @@ -74,19 +74,22 @@ abstract class ZioJdbcUnderlyingContext[+Dialect <: SqlIdiom, +Naming <: NamingS
override def executeActionReturningMany[O](sql: String, prepare: Prepare = identityPrepare, extractor: Extractor[O], returningBehavior: ReturnAction)(info: ExecutionInfo, dc: Runner): QCIO[List[O]] =
annotate(super.executeActionReturningMany(sql, prepare, extractor, returningBehavior)(info, dc), sql, info)
override def executeBatchAction(groups: List[BatchGroup])(info: ExecutionInfo, dc: Runner): QCIO[List[Long]] =
annotate(super.executeBatchAction(groups)(info, dc), sql, info)
annotate(super.executeBatchAction(groups)(info, dc), concatQueries(groups), info)
override def executeBatchActionReturning[T](groups: List[BatchGroupReturning], extractor: Extractor[T])(info: ExecutionInfo, dc: Runner): QCIO[List[T]] =
annotate(super.executeBatchActionReturning(groups, extractor)(info, dc), sql, info)
annotate(super.executeBatchActionReturning(groups, extractor)(info, dc), concatQueriesRet(groups), info)
override def prepareQuery(sql: String, prepare: Prepare)(info: ExecutionInfo, dc: Runner): QCIO[PreparedStatement] =
annotate(super.prepareQuery(sql, prepare)(info, dc), sql, info)
override def prepareAction(sql: String, prepare: Prepare)(info: ExecutionInfo, dc: Runner): QCIO[PreparedStatement] =
annotate(super.prepareAction(sql, prepare)(info, dc), sql, info)
override def prepareBatchAction(groups: List[BatchGroup])(info: ExecutionInfo, dc: Runner): QCIO[List[PreparedStatement]] =
annotate(super.prepareBatchAction(groups)(info, dc), sql, info)
annotate(super.prepareBatchAction(groups)(info, dc), concatQueries(groups), info)
override def translateQueryEndpoint[T](statement: String, prepare: Prepare = identityPrepare, extractor: Extractor[T] = identityExtractor, prettyPrint: Boolean = false)(info: ExecutionInfo, dc: Runner): QCIO[String] =
annotate(super.translateQueryEndpoint(statement, prepare, extractor, prettyPrint)(info, dc), sql, info)
annotate(super.translateQueryEndpoint(statement, prepare, extractor, prettyPrint)(info, dc), statement, info)
override def translateBatchQueryEndpoint(groups: List[BatchGroup], prettyPrint: Boolean = false)(info: ExecutionInfo, dc: Runner): QCIO[List[String]] =
annotate(super.translateBatchQueryEndpoint(groups, prettyPrint)(info, dc), sql, info)
annotate(super.translateBatchQueryEndpoint(groups, prettyPrint)(info, dc), concatQueries(groups), info)

protected def concatQueries(groups: List[BatchGroup]): String = groups.map(_.string).distinct.mkString(",")
protected def concatQueriesRet(groups: List[BatchGroupReturning]): String = groups.map(_.string).distinct.mkString(",")

/** ZIO Contexts do not managed DB connections so this is a no-op */
override def close(): Unit = ()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.getquill.postgres

import io.getquill.*
import io.getquill.ast.{Filter, ReturningGenerated}
import io.getquill.context.AstSplicing
import io.getquill.context.sql.ProductSpec

class DiagnosticDataSpec extends ProductSpec with ZioSpec {

val context = testContextSplicing
import testContextSplicing.*

override def beforeAll() = {
super.beforeAll()
testContextSplicing.run(quote(query[Product].delete)).runSyncUnsafe()
()
}

"Diagnostic data" - {
"Should expose last executed query" in {
val (insertQuery, selectQuery, insertInfo, selectInfo) = (for {
_ <- testContextSplicing.run {
liftQuery(productEntries).foreach(e => productInsert(e))
}
insertQuery <- getLastExecutedQuery()
insertInfo <- getLastExecutionInfo()

_ <- testContextSplicing.run(query[Product].filter(p => p.description == "Notebook"))
selectQuery <- getLastExecutedQuery()
selectInfo <- getLastExecutionInfo()

} yield (insertQuery, selectQuery, insertInfo, selectInfo)).runSyncUnsafe()

insertQuery.get mustBe "INSERT INTO Product (description,sku) VALUES (?, ?) RETURNING id"
selectQuery.get mustBe "SELECT p.id, p.description, p.sku FROM Product p WHERE p.description = 'Notebook'"

insertInfo.get.ast mustBe a[ReturningGenerated]
selectInfo.get.ast mustBe a[Filter]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package io.getquill

import io.getquill.context.qzio.ImplicitSyntax.Implicit
import io.getquill.ZioSpec.runLayerUnsafe
import io.getquill.context.AstSplicing
import io.getquill.jdbczio.Quill

package object postgres {
val pool = runLayerUnsafe(Quill.DataSource.fromPrefix("testPostgresDB"))
object testContext extends Quill.Postgres(Literal, pool) with TestEntities
object testContextSplicing extends Quill.Postgres(Literal, pool) with TestEntities with AstSplicing
}
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,15 @@ object QueryExecutionBatch {
)($traceConfigExpr)
}

val spliceAsts = TypeRepr.of[Ctx] <:< TypeRepr.of[AstSplicing]
val executionInfo =
if (spliceAsts)
'{ ExecutionInfo(ExecutionType.Static, ${ Lifter(state.ast) }, ${ Lifter.quat(topLevelQuat) }) }
else
'{ ExecutionInfo(ExecutionType.Unknown, io.getquill.ast.NullValue, Quat.Unknown) }

'{
$batchContextOperation.execute(ContextOperation.BatchArgument($batchGroups, $extractor, ExecutionInfo(ExecutionType.Static, ${ Lifter(state.ast) }, ${ Lifter.quat(topLevelQuat) }), None))
$batchContextOperation.execute(ContextOperation.BatchArgument($batchGroups, $extractor, $executionInfo, None))
}

case None =>
Expand Down

0 comments on commit f3c136f

Please sign in to comment.