diff --git a/app/Startup.scala b/app/Startup.scala index 72fa3d67ad5..d7f69403478 100755 --- a/app/Startup.scala +++ b/app/Startup.scala @@ -14,7 +14,8 @@ import oxalis.mail.{Mailer, MailerConfig} import oxalis.security.WkSilhouetteEnvironment import oxalis.telemetry.SlackNotificationService import play.api.inject.ApplicationLifecycle -import utils.{SQLClient, WkConf} +import utils.WkConf +import utils.sql.SQLClient import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future diff --git a/app/WebKnossosModule.scala b/app/WebKnossosModule.scala index 105bee3fa1e..084e0fc0b4c 100644 --- a/app/WebKnossosModule.scala +++ b/app/WebKnossosModule.scala @@ -11,7 +11,7 @@ import models.voxelytics.ElasticsearchClient import oxalis.files.TempFileService import oxalis.mail.MailchimpTicker import oxalis.telemetry.SlackNotificationService -import utils.SQLClient +import utils.sql.SQLClient class WebKnossosModule extends AbstractModule { override def configure(): Unit = { diff --git a/app/controllers/Application.scala b/app/controllers/Application.scala index 01d9e9a079f..5684f6b5035 100755 --- a/app/controllers/Application.scala +++ b/app/controllers/Application.scala @@ -12,8 +12,9 @@ import oxalis.security.WkEnv import play.api.libs.json.{JsObject, Json} import play.api.mvc.{Action, AnyContent, PlayBodyParsers} import slick.jdbc.PostgresProfile.api._ -import utils.{SQLClient, SimpleSQLDAO, StoreModules, WkConf} +import utils.{StoreModules, WkConf} import oxalis.mail.{DefaultMails, Send} +import utils.sql.{SQLClient, SimpleSQLDAO} import javax.inject.Inject import scala.concurrent.ExecutionContext diff --git a/app/controllers/MaintenanceController.scala b/app/controllers/MaintenanceController.scala index 5b8d32c8480..834591d6247 100644 --- a/app/controllers/MaintenanceController.scala +++ b/app/controllers/MaintenanceController.scala @@ -10,7 +10,7 @@ import oxalis.security.WkEnv import play.api.libs.json.Json import play.api.mvc.{Action, AnyContent} import slick.jdbc.PostgresProfile.api._ -import utils.{SQLClient, SimpleSQLDAO} +import utils.sql.{SQLClient, SimpleSQLDAO} import scala.concurrent.duration._ import scala.concurrent.ExecutionContext diff --git a/app/controllers/ReportController.scala b/app/controllers/ReportController.scala index 7b0bbd58ed1..425aa0f0db8 100644 --- a/app/controllers/ReportController.scala +++ b/app/controllers/ReportController.scala @@ -10,7 +10,8 @@ import oxalis.security.WkEnv import play.api.libs.json.{Json, OFormat} import play.api.mvc.{Action, AnyContent} import slick.jdbc.PostgresProfile.api._ -import utils.{ObjectId, SQLClient, SimpleSQLDAO} +import utils.sql.{SQLClient, SimpleSQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext diff --git a/app/models/annotation/Annotation.scala b/app/models/annotation/Annotation.scala index 2b6d1182159..f57bee548c2 100755 --- a/app/models/annotation/Annotation.scala +++ b/app/models/annotation/Annotation.scala @@ -16,7 +16,8 @@ import slick.jdbc.PostgresProfile.api._ import slick.jdbc.TransactionIsolation.Serializable import slick.lifted.Rep import slick.sql.SqlAction -import utils.{ObjectId, SQLClient, SQLDAO, SimpleSQLDAO} +import utils.sql.{SQLClient, SQLDAO, SimpleSQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext import scala.concurrent.duration.FiniteDuration diff --git a/app/models/annotation/AnnotationPrivateLink.scala b/app/models/annotation/AnnotationPrivateLink.scala index ec533296380..ba6f24653cc 100644 --- a/app/models/annotation/AnnotationPrivateLink.scala +++ b/app/models/annotation/AnnotationPrivateLink.scala @@ -3,14 +3,15 @@ package models.annotation import com.scalableminds.util.accesscontext.DBAccessContext import com.scalableminds.util.time.Instant import com.scalableminds.util.tools.Fox -import com.scalableminds.webknossos.schema.Tables.{AnnotationPrivatelinks, _} +import com.scalableminds.webknossos.schema.Tables._ import oxalis.security.RandomIDGenerator import javax.inject.Inject import play.api.libs.json.{JsValue, Json, OFormat} import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext diff --git a/app/models/annotation/TracingStore.scala b/app/models/annotation/TracingStore.scala index 44776df99ae..1029d1532b8 100644 --- a/app/models/annotation/TracingStore.scala +++ b/app/models/annotation/TracingStore.scala @@ -12,7 +12,7 @@ import play.api.libs.json.{JsObject, Json} import play.api.mvc.{Result, Results} import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} import scala.concurrent.{ExecutionContext, Future} diff --git a/app/models/binary/DataSet.scala b/app/models/binary/DataSet.scala index 03868cd5882..38830b8f793 100755 --- a/app/models/binary/DataSet.scala +++ b/app/models/binary/DataSet.scala @@ -24,7 +24,8 @@ import slick.jdbc.PostgresProfile.api._ import slick.jdbc.TransactionIsolation.Serializable import slick.lifted.Rep import slick.sql.SqlAction -import utils.{ObjectId, SQLClient, SQLDAO, SimpleSQLDAO} +import utils.sql.{SQLClient, SQLDAO, SimpleSQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext diff --git a/app/models/binary/DataStore.scala b/app/models/binary/DataStore.scala index 21dbc00a33d..fdf10e6a1b6 100644 --- a/app/models/binary/DataStore.scala +++ b/app/models/binary/DataStore.scala @@ -9,7 +9,8 @@ import play.api.libs.json.{Format, JsObject, Json} import play.api.mvc.{Result, Results} import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import scala.concurrent.{ExecutionContext, Future} diff --git a/app/models/binary/Publication.scala b/app/models/binary/Publication.scala index 10d124a3018..62521ec77d9 100644 --- a/app/models/binary/Publication.scala +++ b/app/models/binary/Publication.scala @@ -12,7 +12,8 @@ import play.api.libs.json.Format.GenericFormat import play.api.libs.json.{JsObject, Json} import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext diff --git a/app/models/folder/Folder.scala b/app/models/folder/Folder.scala index 656e75552cd..120a715643e 100644 --- a/app/models/folder/Folder.scala +++ b/app/models/folder/Folder.scala @@ -2,7 +2,7 @@ package models.folder import com.scalableminds.util.accesscontext.DBAccessContext import com.scalableminds.util.tools.Fox -import com.scalableminds.webknossos.schema.Tables.{Folders, _} +import com.scalableminds.webknossos.schema.Tables._ import models.organization.{Organization, OrganizationDAO} import models.team.{TeamDAO, TeamService} import models.user.User @@ -10,7 +10,8 @@ import play.api.libs.json.{JsObject, Json, OFormat} import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep import slick.sql.SqlAction -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import javax.inject.Inject import scala.concurrent.ExecutionContext diff --git a/app/models/job/Job.scala b/app/models/job/Job.scala index 8bd16732573..e0f76b9b11b 100644 --- a/app/models/job/Job.scala +++ b/app/models/job/Job.scala @@ -21,7 +21,8 @@ import play.api.libs.json.{JsObject, Json} import slick.jdbc.PostgresProfile.api._ import slick.jdbc.TransactionIsolation.Serializable import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO, WkConf} +import utils.sql.{SQLClient, SQLDAO} +import utils.{ObjectId, WkConf} import scala.concurrent.ExecutionContext import scala.concurrent.duration.FiniteDuration diff --git a/app/models/job/Worker.scala b/app/models/job/Worker.scala index 0fbe7ee49da..988a69b40de 100644 --- a/app/models/job/Worker.scala +++ b/app/models/job/Worker.scala @@ -14,7 +14,8 @@ import play.api.inject.ApplicationLifecycle import play.api.libs.json.{JsObject, Json} import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO, WkConf} +import utils.sql.{SQLClient, SQLDAO} +import utils.{ObjectId, WkConf} import javax.inject.Inject import scala.concurrent.ExecutionContext diff --git a/app/models/mesh/Mesh.scala b/app/models/mesh/Mesh.scala index 3f20ab9be3e..a68bdce67e9 100644 --- a/app/models/mesh/Mesh.scala +++ b/app/models/mesh/Mesh.scala @@ -12,7 +12,8 @@ import play.api.libs.json.Json._ import play.api.libs.json._ import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext diff --git a/app/models/organization/Organization.scala b/app/models/organization/Organization.scala index c663e892986..66da497d7f6 100755 --- a/app/models/organization/Organization.scala +++ b/app/models/organization/Organization.scala @@ -10,7 +10,8 @@ import models.team.PricingPlan import models.team.PricingPlan.PricingPlan import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext diff --git a/app/models/project/Project.scala b/app/models/project/Project.scala index b0a825cd57e..f8436fa6efe 100755 --- a/app/models/project/Project.scala +++ b/app/models/project/Project.scala @@ -13,10 +13,11 @@ import models.team.TeamDAO import models.user.{User, UserService} import net.liftweb.common.Full import play.api.libs.functional.syntax._ -import play.api.libs.json.{Json, _} +import play.api.libs.json._ import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import scala.concurrent.{ExecutionContext, Future} diff --git a/app/models/shortlinks/ShortLink.scala b/app/models/shortlinks/ShortLink.scala index 03b7d5dbd76..9ce6f1d4ff1 100644 --- a/app/models/shortlinks/ShortLink.scala +++ b/app/models/shortlinks/ShortLink.scala @@ -6,7 +6,8 @@ import com.scalableminds.webknossos.schema.Tables.{Shortlinks, ShortlinksRow} import play.api.libs.json.{Json, OFormat} import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import javax.inject.Inject import scala.concurrent.ExecutionContext diff --git a/app/models/task/Script.scala b/app/models/task/Script.scala index 0e540f7ed41..53498749b66 100644 --- a/app/models/task/Script.scala +++ b/app/models/task/Script.scala @@ -8,7 +8,8 @@ import models.user.{UserDAO, UserService} import play.api.libs.json._ import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import javax.inject.Inject import scala.concurrent.ExecutionContext diff --git a/app/models/task/Task.scala b/app/models/task/Task.scala index 99aad97d2c1..a3f0283047f 100755 --- a/app/models/task/Task.scala +++ b/app/models/task/Task.scala @@ -4,7 +4,7 @@ import com.scalableminds.util.accesscontext.DBAccessContext import com.scalableminds.util.geometry.{BoundingBox, Vec3Double, Vec3Int} import com.scalableminds.util.time.Instant import com.scalableminds.util.tools.Fox -import com.scalableminds.webknossos.schema.Tables.{profile, _} +import com.scalableminds.webknossos.schema.Tables._ import javax.inject.Inject import models.annotation._ @@ -12,7 +12,8 @@ import models.project.ProjectDAO import models.user.Experience import slick.jdbc.PostgresProfile.api._ import slick.jdbc.TransactionIsolation.Serializable -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext import scala.concurrent.duration.FiniteDuration diff --git a/app/models/task/TaskType.scala b/app/models/task/TaskType.scala index bda72f3a50a..f470c7d1807 100755 --- a/app/models/task/TaskType.scala +++ b/app/models/task/TaskType.scala @@ -12,7 +12,8 @@ import models.team.TeamDAO import play.api.libs.json._ import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import javax.inject.Inject import scala.concurrent.ExecutionContext diff --git a/app/models/team/Team.scala b/app/models/team/Team.scala index 0ddd1825d51..a877f2024c9 100755 --- a/app/models/team/Team.scala +++ b/app/models/team/Team.scala @@ -17,7 +17,8 @@ import play.api.libs.json._ import slick.jdbc.PostgresProfile.api._ import slick.jdbc.TransactionIsolation.Serializable import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext diff --git a/app/models/user/Invite.scala b/app/models/user/Invite.scala index 0a61c35aabf..ddf93da399d 100644 --- a/app/models/user/Invite.scala +++ b/app/models/user/Invite.scala @@ -13,7 +13,8 @@ import oxalis.mail.{DefaultMails, Send} import oxalis.security.RandomIDGenerator import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO, WkConf} +import utils.sql.{SQLClient, SQLDAO} +import utils.{ObjectId, WkConf} import scala.concurrent.{ExecutionContext, Future} diff --git a/app/models/user/MultiUser.scala b/app/models/user/MultiUser.scala index f04c2c726ff..313e13bc220 100644 --- a/app/models/user/MultiUser.scala +++ b/app/models/user/MultiUser.scala @@ -12,7 +12,8 @@ import models.user.Theme.Theme import play.api.libs.json.Format.GenericFormat import play.api.libs.json.{JsObject, Json} import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext diff --git a/app/models/user/User.scala b/app/models/user/User.scala index f91a416c67d..384a8eea6bd 100755 --- a/app/models/user/User.scala +++ b/app/models/user/User.scala @@ -15,7 +15,8 @@ import play.api.libs.json._ import slick.jdbc.PostgresProfile.api._ import slick.jdbc.TransactionIsolation.Serializable import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO, SimpleSQLDAO} +import utils.sql.{SQLClient, SQLDAO, SimpleSQLDAO} +import utils.ObjectId import scala.concurrent.ExecutionContext diff --git a/app/models/user/time/TimeSpan.scala b/app/models/user/time/TimeSpan.scala index 563ad876aaa..4d384946154 100755 --- a/app/models/user/time/TimeSpan.scala +++ b/app/models/user/time/TimeSpan.scala @@ -6,7 +6,8 @@ import com.scalableminds.webknossos.schema.Tables._ import play.api.libs.json.{JsValue, Json, OFormat} import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import javax.inject.Inject import scala.concurrent.ExecutionContext diff --git a/app/models/voxelytics/VoxelyticsDAO.scala b/app/models/voxelytics/VoxelyticsDAO.scala index 5cff4c23ca8..00027f6bfdb 100644 --- a/app/models/voxelytics/VoxelyticsDAO.scala +++ b/app/models/voxelytics/VoxelyticsDAO.scala @@ -4,8 +4,8 @@ import com.scalableminds.util.time.Instant import com.scalableminds.util.tools.Fox import models.user.User import play.api.libs.json._ -import slick.jdbc.PostgresProfile.api._ -import utils.{ObjectId, SQLClient, SimpleSQLDAO} +import utils.sql.{SQLClient, SimpleSQLDAO, SqlToken} +import utils.ObjectId import javax.inject.Inject import scala.concurrent.ExecutionContext @@ -15,7 +15,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def findArtifacts(taskIds: List[ObjectId]): Fox[List[ArtifactEntry]] = for { - r <- run(sql""" + r <- run(q""" SELECT a._id, a._task, @@ -28,7 +28,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex t.name AS taskName FROM webknossos.voxelytics_artifacts a JOIN webknossos.voxelytics_tasks t ON t._id = a._task - WHERE t."_id" IN #${writeEscapedTuple(taskIds.map(_.id))} + WHERE t."_id" IN ${SqlToken.tuple(taskIds)} """.as[(String, String, String, String, Long, Long, String, String, String)]) } yield r.toList.map( @@ -45,7 +45,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def findTasks(combinedTaskRuns: List[TaskRunEntry]): Fox[List[TaskEntry]] = for { - r <- run(sql""" + r <- run(q""" SELECT t._id, t._run, @@ -54,9 +54,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex t.config FROM webknossos.voxelytics_tasks t WHERE - ("_run", "name") IN (#${combinedTaskRuns - .map(t => s"(${escapeLiteral(t.runId.id)}, ${escapeLiteral(t.taskName)})") - .mkString(", ")}) + ("_run", "name") IN (${SqlToken.tupleList(combinedTaskRuns.map(t => List(t.runId, t.taskName)))}) """.as[(String, String, String, String, String)]) } yield r.toList.map(row => @@ -65,16 +63,16 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def findWorkflowsByHashAndOrganization(organizationId: ObjectId, workflowHashes: Set[String]): Fox[List[WorkflowEntry]] = for { - r <- run(sql""" + r <- run(q""" SELECT name, hash FROM webknossos.voxelytics_workflows - WHERE hash IN #${writeEscapedTuple(workflowHashes.toList)} AND _organization = $organizationId + WHERE hash IN ${SqlToken.tuple(workflowHashes.toList)} AND _organization = $organizationId """.as[(String, String)]) } yield r.toList.map(row => WorkflowEntry(row._1, row._2, organizationId)) def findWorkflowByHashAndOrganization(organizationId: ObjectId, workflowHash: String): Fox[WorkflowEntry] = for { - r <- run(sql""" + r <- run(q""" SELECT name, hash FROM webknossos.voxelytics_workflows WHERE hash = $workflowHash AND _organization = $organizationId @@ -84,7 +82,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def findWorkflowByHash(workflowHash: String): Fox[WorkflowEntry] = for { - r <- run(sql""" + r <- run(q""" SELECT name, hash, _organization FROM webknossos.voxelytics_workflows WHERE hash = $workflowHash @@ -94,7 +92,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def findTaskRuns(organizationId: ObjectId, runIds: List[ObjectId], staleTimeout: Duration): Fox[List[TaskRunEntry]] = for { - r <- run(sql""" + r <- run(q""" WITH latest_chunk_states AS ( SELECT DISTINCT ON (_chunk) _chunk, timestamp, state FROM webknossos.voxelytics_chunkStateChangeEvents @@ -106,11 +104,11 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex t._id AS taskId, t.name AS taskName, CASE - WHEN task_state.state = 'RUNNING' AND run_heartbeat.timestamp IS NOT NULL AND run_heartbeat.timestamp < NOW() - INTERVAL '#${staleTimeout.toSeconds} SECONDS' + WHEN task_state.state = 'RUNNING' AND run_heartbeat.timestamp IS NOT NULL AND run_heartbeat.timestamp < NOW() - $staleTimeout THEN 'STALE' ELSE task_state.state END AS state, task_begin.timestamp AS beginTime, CASE - WHEN task_state.state = 'RUNNING' AND run_heartbeat.timestamp IS NOT NULL AND run_heartbeat.timestamp < NOW() - INTERVAL '#${staleTimeout.toSeconds} SECONDS' + WHEN task_state.state = 'RUNNING' AND run_heartbeat.timestamp IS NOT NULL AND run_heartbeat.timestamp < NOW() - $staleTimeout THEN run_heartbeat.timestamp ELSE task_end.timestamp END AS endTime, exec.executionId AS currentExecutionId, COALESCE(chunks.total, 0) AS chunksTotal, @@ -169,7 +167,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex ) chunks ON chunks._task = t._id WHERE r._organization = $organizationId AND - r._id IN #${writeEscapedTuple(runIds.map(_.id))} + r._id IN ${SqlToken.tuple(runIds.map(_.id))} """.as[(String, String, String, String, String, Option[Instant], Option[Instant], Option[String], Long, Long)]) results <- Fox.combined( r.toList.map( @@ -198,12 +196,13 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex allowUnlisted: Boolean): Fox[List[RunEntry]] = { val organizationId = currentUser._organization val readAccessQ = - if (currentUser.isAdmin || allowUnlisted) "" else { s" AND (r._user = ${escapeLiteral(currentUser._id.id)})" } - val runIdsQ = runIds.map(runIds => s" AND r._id IN ${writeEscapedTuple(runIds.map(_.id))}").getOrElse("") + if (currentUser.isAdmin || allowUnlisted) SqlToken.empty + else { q" AND (r._user = ${currentUser._id})" } + val runIdsQ = runIds.map(runIds => q" AND r._id IN ${SqlToken.tuple(runIds)}").getOrElse(SqlToken.empty) val workflowHashQ = - workflowHash.map(workflowHash => s" AND r.workflow_hash = ${escapeLiteral(workflowHash)}").getOrElse("") + workflowHash.map(workflowHash => q" AND r.workflow_hash = $workflowHash").getOrElse(SqlToken.empty) for { - r <- run(sql""" + r <- run(q""" SELECT r._id, r.name, @@ -214,11 +213,11 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex r.workflow_yamlContent, r.workflow_config, CASE - WHEN run_state.state = 'RUNNING' AND run_heartbeat.timestamp IS NOT NULL AND run_heartbeat.timestamp < NOW() - INTERVAL '#${staleTimeout.toSeconds} SECONDS' + WHEN run_state.state = 'RUNNING' AND run_heartbeat.timestamp IS NOT NULL AND run_heartbeat.timestamp < NOW() - $staleTimeout THEN 'STALE' ELSE run_state.state END AS state, run_begin.timestamp AS beginTime, CASE - WHEN run_state.state = 'RUNNING' AND run_heartbeat.timestamp IS NOT NULL AND run_heartbeat.timestamp < NOW() - INTERVAL '#${staleTimeout.toSeconds} SECONDS' + WHEN run_state.state = 'RUNNING' AND run_heartbeat.timestamp IS NOT NULL AND run_heartbeat.timestamp < NOW() - $staleTimeout THEN run_heartbeat.timestamp ELSE run_end.timestamp END AS endTime FROM webknossos.voxelytics_runs r JOIN ( @@ -247,9 +246,9 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex ) run_heartbeat ON r._id = run_heartbeat._run WHERE r._organization = $organizationId - #$runIdsQ - #$workflowHashQ - #$readAccessQ + $runIdsQ + $workflowHashQ + $readAccessQ """.as[(String, String, String, String, String, String, String, String, String, Instant, Option[Instant])]) results <- Fox.combined( r.toList.map( @@ -277,7 +276,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def upsertArtifactChecksumEvent(artifactId: ObjectId, ev: ArtifactFileChecksumEvent): Fox[Unit] = for { _ <- run( - sqlu"""INSERT INTO webknossos.voxelytics_artifactFileChecksumEvents (_artifact, path, resolvedPath, checksumMethod, checksum, fileSize, lastModified, timestamp) + q"""INSERT INTO webknossos.voxelytics_artifactFileChecksumEvents (_artifact, path, resolvedPath, checksumMethod, checksum, fileSize, lastModified, timestamp) VALUES ($artifactId, ${ev.path}, ${ev.resolvedPath}, ${ev.checksumMethod}, ${ev.checksum}, ${ev.fileSize}, ${ev.lastModified}, ${ev.timestamp}) ON CONFLICT (_artifact, path, timestamp) DO UPDATE SET @@ -286,13 +285,13 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex checksum = EXCLUDED.checksum, fileSize = EXCLUDED.fileSize, lastModified = EXCLUDED.lastModified - """) + """.asUpdate) } yield () def upsertChunkProfilingEvent(chunkId: ObjectId, ev: ChunkProfilingEvent): Fox[Unit] = for { _ <- run( - sqlu"""INSERT INTO webknossos.voxelytics_chunkProfilingEvents (_chunk, hostname, pid, memory, cpuUser, cpuSystem, timestamp) + q"""INSERT INTO webknossos.voxelytics_chunkProfilingEvents (_chunk, hostname, pid, memory, cpuUser, cpuSystem, timestamp) VALUES ($chunkId, ${ev.hostname}, ${ev.pid}, ${ev.memory}, ${ev.cpuUser}, ${ev.cpuSystem}, ${ev.timestamp}) ON CONFLICT (_chunk, timestamp) DO UPDATE SET @@ -301,52 +300,52 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex memory = EXCLUDED.memory, cpuUser = EXCLUDED.cpuUser, cpuSystem = EXCLUDED.cpuSystem - """) + """.asUpdate) } yield () def upsertRunHeartbeatEvent(runId: ObjectId, ev: RunHeartbeatEvent): Fox[Unit] = for { - _ <- run(sqlu"""INSERT INTO webknossos.voxelytics_runHeartbeatEvents (_run, timestamp) + _ <- run(q"""INSERT INTO webknossos.voxelytics_runHeartbeatEvents (_run, timestamp) VALUES ($runId, ${ev.timestamp}) ON CONFLICT (_run) DO UPDATE SET timestamp = EXCLUDED.timestamp - """) + """.asUpdate) } yield () def upsertChunkStateChangeEvent(chunkId: ObjectId, ev: ChunkStateChangeEvent): Fox[Unit] = for { - _ <- run(sqlu"""INSERT INTO webknossos.voxelytics_chunkStateChangeEvents (_chunk, timestamp, state) + _ <- run(q"""INSERT INTO webknossos.voxelytics_chunkStateChangeEvents (_chunk, timestamp, state) VALUES ($chunkId, ${ev.timestamp}, ${ev.state.toString}::webknossos.VOXELYTICS_RUN_STATE) ON CONFLICT (_chunk, timestamp) DO UPDATE SET state = EXCLUDED.state - """) + """.asUpdate) } yield () def upsertTaskStateChangeEvent(taskId: ObjectId, ev: TaskStateChangeEvent): Fox[Unit] = for { - _ <- run(sqlu"""INSERT INTO webknossos.voxelytics_taskStateChangeEvents (_task, timestamp, state) + _ <- run(q"""INSERT INTO webknossos.voxelytics_taskStateChangeEvents (_task, timestamp, state) VALUES ($taskId, ${ev.timestamp}, ${ev.state.toString}::webknossos.VOXELYTICS_RUN_STATE) ON CONFLICT (_task, timestamp) DO UPDATE SET state = EXCLUDED.state - """) + """.asUpdate) } yield () def upsertRunStateChangeEvent(runId: ObjectId, ev: RunStateChangeEvent): Fox[Unit] = for { - _ <- run(sqlu"""INSERT INTO webknossos.voxelytics_runStateChangeEvents (_run, timestamp, state) + _ <- run(q"""INSERT INTO webknossos.voxelytics_runStateChangeEvents (_run, timestamp, state) VALUES ($runId, ${ev.timestamp}, ${ev.state.toString}::webknossos.VOXELYTICS_RUN_STATE) ON CONFLICT (_run, timestamp) DO UPDATE SET state = EXCLUDED.state - """) + """.asUpdate) } yield () def upsertWorkflow(hash: String, name: String, organizationId: ObjectId): Fox[Unit] = for { - _ <- run(sqlu"""INSERT INTO webknossos.voxelytics_workflows (hash, name, _organization) + _ <- run(q"""INSERT INTO webknossos.voxelytics_workflows (hash, name, _organization) VALUES ($hash, $name, $organizationId) ON CONFLICT (_organization, hash) DO UPDATE SET name = EXCLUDED.name - """) + """.asUpdate) } yield () def upsertRun(organizationId: ObjectId, @@ -360,9 +359,8 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex workflow_config: JsValue): Fox[ObjectId] = for { _ <- run( - sqlu"""INSERT INTO webknossos.voxelytics_runs (_id, _organization, _user, name, username, hostname, voxelyticsVersion, workflow_hash, workflow_yamlContent, workflow_config) - VALUES (${ObjectId.generate}, $organizationId, $userId, $name, $username, $hostname, $voxelyticsVersion, $workflow_hash, $workflow_yamlContent, ${Json - .stringify(workflow_config)}::JSONB) + q"""INSERT INTO webknossos.voxelytics_runs (_id, _organization, _user, name, username, hostname, voxelyticsVersion, workflow_hash, workflow_yamlContent, workflow_config) + VALUES (${ObjectId.generate}, $organizationId, $userId, $name, $username, $hostname, $voxelyticsVersion, $workflow_hash, $workflow_yamlContent, $workflow_config) ON CONFLICT (_organization, name) DO UPDATE SET _user = EXCLUDED._user, @@ -372,8 +370,8 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex workflow_hash = EXCLUDED.workflow_hash, workflow_yamlContent = EXCLUDED.workflow_yamlContent, workflow_config = EXCLUDED.workflow_config - """) - objectIdList <- run(sql"""SELECT _id + """.asUpdate) + objectIdList <- run(q"""SELECT _id FROM webknossos.voxelytics_runs WHERE _organization = $organizationId AND name = $name """.as[String]) @@ -382,14 +380,14 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def upsertTask(runId: ObjectId, name: String, task: String, config: JsValue): Fox[ObjectId] = for { - _ <- run(sqlu"""INSERT INTO webknossos.voxelytics_tasks (_id, _run, name, task, config) - VALUES (${ObjectId.generate}, $runId, $name, $task, ${Json.stringify(config)}::JSONB) + _ <- run(q"""INSERT INTO webknossos.voxelytics_tasks (_id, _run, name, task, config) + VALUES (${ObjectId.generate}, $runId, $name, $task, $config) ON CONFLICT (_run, name) DO UPDATE SET task = EXCLUDED.task, config = EXCLUDED.config - """) - objectIdList <- run(sql"""SELECT _id + """.asUpdate) + objectIdList <- run(q"""SELECT _id FROM webknossos.voxelytics_tasks WHERE _run = $runId AND name = $name """.as[String]) @@ -398,11 +396,11 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def upsertChunk(taskId: ObjectId, executionId: String, chunkName: String): Fox[ObjectId] = for { - _ <- run(sqlu"""INSERT INTO webknossos.voxelytics_chunks (_id, _task, executionId, chunkName) + _ <- run(q"""INSERT INTO webknossos.voxelytics_chunks (_id, _task, executionId, chunkName) VALUES (${ObjectId.generate}, $taskId, $executionId, $chunkName) ON CONFLICT (_task, executionId, chunkName) DO NOTHING - """) - objectIdList <- run(sql"""SELECT _id + """.asUpdate) + objectIdList <- run(q"""SELECT _id FROM webknossos.voxelytics_chunks WHERE _task = $taskId AND executionId = $executionId AND chunkName = $chunkName """.as[String]) @@ -418,9 +416,8 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex metadata: JsValue): Fox[ObjectId] = for { _ <- run( - sqlu"""INSERT INTO webknossos.voxelytics_artifacts (_id, _task, name, path, fileSize, inodeCount, version, metadata) - VALUES (${ObjectId.generate}, $taskId, $name, $path, $fileSize, $inodeCount, $version, ${Json.stringify( - metadata)}::JSONB) + q"""INSERT INTO webknossos.voxelytics_artifacts (_id, _task, name, path, fileSize, inodeCount, version, metadata) + VALUES (${ObjectId.generate}, $taskId, $name, $path, $fileSize, $inodeCount, $version, $metadata) ON CONFLICT (_task, name) DO UPDATE SET path = EXCLUDED.path, @@ -428,8 +425,8 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex inodeCount = EXCLUDED.inodeCount, version = EXCLUDED.version, metadata = EXCLUDED.metadata - """) - objectIdList <- run(sql"""SELECT _id + """.asUpdate) + objectIdList <- run(q"""SELECT _id FROM webknossos.voxelytics_artifacts WHERE _task = $taskId AND name = $name """.as[String]) @@ -438,7 +435,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def getRunIdByName(runName: String, organizationId: ObjectId): Fox[ObjectId] = for { - objectIdList <- run(sql""" + objectIdList <- run(q""" SELECT _id FROM webknossos.voxelytics_runs WHERE name = $runName AND _organization = $organizationId @@ -448,7 +445,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def getRunNameById(runId: ObjectId, organizationId: ObjectId): Fox[String] = for { - nameList <- run(sql"""SELECT name + nameList <- run(q"""SELECT name FROM webknossos.voxelytics_runs WHERE _id = $runId AND _organization = $organizationId """.as[String]) @@ -457,7 +454,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def getUserIdForRun(runId: ObjectId): Fox[ObjectId] = for { - userIdList <- run(sql""" + userIdList <- run(q""" SELECT _user FROM webknossos.voxelytics_runs WHERE _id = $runId @@ -467,7 +464,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def getUserIdForRunOpt(runName: String, organizationId: ObjectId): Fox[Option[ObjectId]] = for { - userId <- run(sql""" + userId <- run(q""" SELECT _user FROM webknossos.voxelytics_runs WHERE name = $runName AND _organization = $organizationId @@ -476,7 +473,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def getTaskIdByName(taskName: String, runId: ObjectId): Fox[ObjectId] = for { - objectIdList <- run(sql"""SELECT _id + objectIdList <- run(q"""SELECT _id FROM webknossos.voxelytics_tasks WHERE _run = $runId AND name = $taskName """.as[String]) @@ -485,7 +482,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def getChunkIdByName(taskId: ObjectId, executionId: String, chunkName: String): Fox[ObjectId] = for { - objectIdList <- run(sql"""SELECT _id + objectIdList <- run(q"""SELECT _id FROM webknossos.voxelytics_chunks WHERE _task = $taskId AND executionId = $executionId AND chunkName = $chunkName """.as[String]) @@ -494,7 +491,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def getArtifactIdByName(taskId: ObjectId, artifactName: String): Fox[ObjectId] = for { - objectIdList <- run(sql"""SELECT _id + objectIdList <- run(q"""SELECT _id FROM webknossos.voxelytics_artifacts WHERE _task = $taskId AND name = $artifactName """.as[String]) @@ -504,7 +501,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def getChunkStatistics(taskId: ObjectId): Fox[List[ChunkStatisticsEntry]] = { for { r <- run( - sql""" + q""" WITH latest_chunk_states AS ( SELECT DISTINCT ON (_chunk) _chunk, timestamp, state FROM webknossos.voxelytics_chunkStateChangeEvents @@ -643,7 +640,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex def getArtifactChecksums(taskId: ObjectId, artifactName: Option[String]): Fox[List[ArtifactChecksumEntry]] = for { - r <- run(sql""" + r <- run(q""" SELECT t.name AS taskName, a.name AS artifactName, @@ -663,7 +660,7 @@ class VoxelyticsDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContex JOIN webknossos.voxelytics_artifacts a ON a._id = af._artifact JOIN webknossos.voxelytics_tasks t ON t._id = a._task WHERE - a._task = $taskId #${artifactName.map(a => s"AND a.name = ${escapeLiteral(a)}").getOrElse("")} + a._task = $taskId ${artifactName.map(a => q"AND a.name = $a").getOrElse(SqlToken.empty)} ORDER BY af.path """.as[(String, String, String, String, Instant, String, String, Long, Instant)]) } yield diff --git a/app/oxalis/security/Token.scala b/app/oxalis/security/Token.scala index de9c44a53a1..518ea24d0f1 100644 --- a/app/oxalis/security/Token.scala +++ b/app/oxalis/security/Token.scala @@ -5,11 +5,12 @@ import com.mohiva.play.silhouette.impl.authenticators.BearerTokenAuthenticator import com.scalableminds.util.accesscontext.DBAccessContext import com.scalableminds.util.time.Instant import com.scalableminds.util.tools.Fox -import com.scalableminds.webknossos.schema.Tables.{Tokens, _} +import com.scalableminds.webknossos.schema.Tables._ import oxalis.security.TokenType.TokenType import slick.jdbc.PostgresProfile.api._ import slick.lifted.Rep -import utils.{ObjectId, SQLClient, SQLDAO} +import utils.sql.{SQLClient, SQLDAO} +import utils.ObjectId import javax.inject.Inject import scala.concurrent.ExecutionContext diff --git a/app/utils/SQLHelpers.scala b/app/utils/sql/SQLHelpers.scala similarity index 97% rename from app/utils/SQLHelpers.scala rename to app/utils/sql/SQLHelpers.scala index 5c1f9cb8e70..5a543a69db4 100644 --- a/app/utils/SQLHelpers.scala +++ b/app/utils/sql/SQLHelpers.scala @@ -1,4 +1,4 @@ -package utils +package utils.sql import com.scalableminds.util.accesscontext.DBAccessContext import com.scalableminds.util.time.Instant @@ -11,8 +11,10 @@ import oxalis.telemetry.SlackNotificationService import play.api.Configuration import slick.dbio.DBIOAction import slick.jdbc.PostgresProfile.api._ -import slick.jdbc.{GetResult, PositionedParameters, PositionedResult, PostgresProfile, SetParameter} +import slick.jdbc._ import slick.lifted.{AbstractTable, Rep, TableQuery} +import utils.ObjectId +import utils.sql.SqlInterpolation.sqlInterpolation import javax.inject.Inject import scala.annotation.nowarn @@ -54,47 +56,49 @@ trait SQLTypeImplicits { } } -class SimpleSQLDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContext) - extends FoxImplicits - with LazyLogging - with SQLTypeImplicits { - - protected lazy val transactionSerializationError = "could not serialize access" +trait Escaping { + protected def escapeLiteral(aString: String): String = { + // Ported from PostgreSQL 9.2.4 source code in src/interfaces/libpq/fe-exec.c + var hasBackslash = false + val escaped = new StringBuffer("'") - protected def run[R](query: DBIOAction[R, NoStream, Nothing], - retryCount: Int = 0, - retryIfErrorContains: List[String] = List()): Fox[R] = { - val foxFuture = sqlClient.db.run(query.asTry).map { result: Try[R] => - result match { - case Success(res) => - Fox.successful(res) - case Failure(e: Throwable) => - val msg = e.getMessage - if (retryIfErrorContains.exists(msg.contains(_)) && retryCount > 0) { - logger.debug(s"Retrying SQL Query ($retryCount remaining) due to $msg") - Thread.sleep(20) - run(query, retryCount - 1, retryIfErrorContains) - } else { - logError(e, query) - reportErrorToSlack(e, query) - Fox.failure("SQL Failure: " + e.getMessage) - } + aString.foreach { c => + if (c == '\'') { + escaped.append(c).append(c) + } else if (c == '\\') { + escaped.append(c).append(c) + hasBackslash = true + } else { + escaped.append(c) } } - foxFuture.toFox.flatten + escaped.append('\'') + + if (hasBackslash) { + "E" + escaped.toString + } else { + escaped.toString + } } - private def logError[R](ex: Throwable, query: DBIOAction[R, NoStream, Nothing]): Unit = { - logger.error("SQL Error: " + ex) - logger.debug("Caused by query:\n" + query.getDumpInfo.mainInfo) + protected def writeEscapedTuple(seq: List[String]): String = + "(" + seq.map(escapeLiteral).mkString(", ") + ")" + + protected def sanitize(aString: String): String = aString.replaceAll("'", "") + + // escape ' by doubling it, escape " with backslash, drop commas + protected def sanitizeInArrayTuple(aString: String): String = + aString.replaceAll("'", """''""").replaceAll(""""""", """\\"""").replaceAll(""",""", "") + + protected def desanitizeFromArrayTuple(aString: String): String = + aString.replaceAll("""\\"""", """"""").replaceAll("""\\,""", ",") + + protected def optionLiteral(aStringOpt: Option[String]): String = aStringOpt match { + case Some(aString) => "'" + aString + "'" + case None => "null" } - private def reportErrorToSlack[R](ex: Throwable, query: DBIOAction[R, NoStream, Nothing]): Unit = - sqlClient.getSlackNotificationService.warnWithException( - "SQL Error", - ex, - s"Causing query: ${query.getDumpInfo.mainInfo}" - ) + protected def optionLiteralSanitized(aStringOpt: Option[String]): String = optionLiteral(aStringOpt.map(sanitize)) protected def writeArrayTuple(elements: List[String]): String = { val commaSeparated = elements.map(sanitizeInArrayTuple).map(e => s""""$e"""").mkString(",") @@ -124,50 +128,52 @@ class SimpleSQLDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContext } } } +} - protected def escapeLiteral(aString: String): String = { - // Ported from PostgreSQL 9.2.4 source code in src/interfaces/libpq/fe-exec.c - var hasBackslash = false - val escaped = new StringBuffer("'") +class SimpleSQLDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContext) + extends FoxImplicits + with LazyLogging + with SQLTypeImplicits + with Escaping { - aString.foreach { c => - if (c == '\'') { - escaped.append(c).append(c) - } else if (c == '\\') { - escaped.append(c).append(c) - hasBackslash = true - } else { - escaped.append(c) - } - } - escaped.append('\'') + implicit protected def sqlInterpolationWrapper(s: StringContext): SqlInterpolator = sqlInterpolation(s) - if (hasBackslash) { - "E" + escaped.toString - } else { - escaped.toString + protected lazy val transactionSerializationError = "could not serialize access" + + protected def run[R](query: DBIOAction[R, NoStream, Nothing], + retryCount: Int = 0, + retryIfErrorContains: List[String] = List()): Fox[R] = { + val foxFuture = sqlClient.db.run(query.asTry).map { result: Try[R] => + result match { + case Success(res) => + Fox.successful(res) + case Failure(e: Throwable) => + val msg = e.getMessage + if (retryIfErrorContains.exists(msg.contains(_)) && retryCount > 0) { + logger.debug(s"Retrying SQL Query ($retryCount remaining) due to $msg") + Thread.sleep(20) + run(query, retryCount - 1, retryIfErrorContains) + } else { + logError(e, query) + reportErrorToSlack(e, query) + Fox.failure("SQL Failure: " + e.getMessage) + } + } } + foxFuture.toFox.flatten } - protected def writeEscapedTuple(seq: List[String]): String = - "(" + seq.map(escapeLiteral).mkString(", ") + ")" - - protected def sanitize(aString: String): String = aString.replaceAll("'", "") - - // escape ' by doubling it, escape " with backslash, drop commas - protected def sanitizeInArrayTuple(aString: String): String = - aString.replaceAll("'", """''""").replaceAll(""""""", """\\"""").replaceAll(""",""", "") - - protected def desanitizeFromArrayTuple(aString: String): String = - aString.replaceAll("""\\"""", """"""").replaceAll("""\\,""", ",") - - protected def optionLiteral(aStringOpt: Option[String]): String = aStringOpt match { - case Some(aString) => "'" + aString + "'" - case None => "null" + private def logError[R](ex: Throwable, query: DBIOAction[R, NoStream, Nothing]): Unit = { + logger.error("SQL Error: " + ex) + logger.debug("Caused by query:\n" + query.getDumpInfo.mainInfo) } - protected def optionLiteralSanitized(aStringOpt: Option[String]): String = optionLiteral(aStringOpt.map(sanitize)) - + private def reportErrorToSlack[R](ex: Throwable, query: DBIOAction[R, NoStream, Nothing]): Unit = + sqlClient.getSlackNotificationService.warnWithException( + "SQL Error", + ex, + s"Causing query: ${query.getDumpInfo.mainInfo}" + ) } abstract class SecuredSQLDAO @Inject()(sqlClient: SQLClient)(implicit ec: ExecutionContext) diff --git a/app/utils/sql/SqlInterpolation.scala b/app/utils/sql/SqlInterpolation.scala new file mode 100644 index 00000000000..0fa11f76490 --- /dev/null +++ b/app/utils/sql/SqlInterpolation.scala @@ -0,0 +1,248 @@ +package utils.sql + +import com.scalableminds.util.time.Instant +import play.api.libs.json.{JsValue, Json} +import slick.dbio.{Effect, NoStream} +import slick.jdbc._ +import slick.sql.{SqlAction, SqlStreamingAction} +import slick.util.DumpInfo +import utils.ObjectId + +import java.sql.{PreparedStatement, Types} +import scala.annotation.tailrec +import scala.collection.mutable +import scala.collection.mutable.ListBuffer +import scala.concurrent.duration +import scala.concurrent.duration.FiniteDuration + +class SqlInterpolator(val s: StringContext) extends AnyVal { + def q(param: Any*): SqlToken = { + val parts = s.parts.toList + val values = param.toList + + val outputSql = mutable.StringBuilder.newBuilder + val outputValues = ListBuffer[SqlValue]() + + assert(parts.length == values.length + 1) + for (i <- parts.indices) { + outputSql ++= parts(i) + + if (i < values.length) { + val value = values(i) + value match { + case x: SqlToken => + outputSql ++= x.sql + outputValues ++= x.values + case x => + val sqlValue = SqlValue.makeSqlValue(x) + outputSql ++= sqlValue.placeholder + outputValues += sqlValue + } + } + } + + SqlToken(sql = outputSql.toString, values = outputValues.toList) + } +} + +object SqlInterpolation { + implicit def sqlInterpolation(s: StringContext): SqlInterpolator = new SqlInterpolator(s) +} + +case class SqlToken(sql: String, values: List[SqlValue] = List()) { + def debugInfo: String = { + // The debugInfo should be pastable in an SQL client + val parts = sql.split("\\?", -1) + assert(parts.tail.length == values.length) + parts.tail.zip(values).foldLeft(parts.head)((acc, x) => acc + x._2.debugInfo + x._1) + } + + def as[R](implicit resultConverter: GetResult[R]): SqlStreamingAction[Vector[R], R, Effect] = + new StreamingInvokerAction[Vector[R], R, Effect] { + def statements: List[String] = List(sql) + + protected[this] def createInvoker(statements: Iterable[String]): StatementInvoker[R] = new StatementInvoker[R] { + val getStatement: String = statements.head + + protected def setParam(st: PreparedStatement): Unit = { + val pp = new PositionedParameters(st) + values.foreach(_.setParameter(pp)) + } + + protected def extractValue(rs: PositionedResult): R = resultConverter(rs) + } + + override def getDumpInfo = DumpInfo(DumpInfo.simpleNameFor(getClass), mainInfo = s"[$debugInfo]") + + protected[this] def createBuilder: mutable.Builder[R, Vector[R]] = Vector.newBuilder[R] + } + + def asUpdate: SqlAction[Int, NoStream, Effect] = as[Int](GetUpdateValue).head +} + +object SqlToken { + def join(values: List[Either[SqlValue, SqlToken]], sep: String): SqlToken = { + val outputSql = mutable.StringBuilder.newBuilder + val outputValues = ListBuffer[SqlValue]() + for (i <- values.indices) { + val value = values(i) + value match { + case Left(x) => + outputSql ++= x.placeholder + outputValues += x + case Right(x) => + outputSql ++= x.sql + outputValues ++= x.values + } + if (i < values.length - 1) { + outputSql ++= sep + } + } + SqlToken(sql = outputSql.toString, values = outputValues.toList) + } + + def tuple(values: Seq[Any]): SqlToken = { + val sqlValues = values.map(SqlValue.makeSqlValue) + SqlToken(sql = s"(${sqlValues.map(_.placeholder).mkString(", ")})", values = sqlValues.toList) + } + + def tupleList(values: Seq[Seq[Any]]): SqlToken = { + val sqlValueLists = values.map(list => list.map(SqlValue.makeSqlValue)) + SqlToken(sql = sqlValueLists.map(list => s"(${list.map(_.placeholder).mkString(", ")})").mkString(", "), + values = sqlValueLists.flatten.toList) + } + + def raw(s: String): SqlToken = SqlToken(s) + + def empty: SqlToken = raw("") + + def identifier(id: String): SqlToken = raw('"' + id + '"') +} + +trait SqlValue { + def setParameter(pp: PositionedParameters): Unit + + def placeholder: String = "?" + + def debugInfo: String +} + +object SqlValue { + + @tailrec + def makeSqlValue(p: Any): SqlValue = + p match { + case x: SqlValue => x + case x: String => StringValue(x) + case x: Option[_] => + x match { + case Some(y) => makeSqlValue(y) + case None => NoneValue() + } + case x: Short => ShortValue(x) + case x: Int => IntValue(x) + case x: Long => LongValue(x) + case x: Float => FloatValue(x) + case x: Double => DoubleValue(x) + case x: Boolean => BooleanValue(x) + case x: Instant => InstantValue(x) + case x: FiniteDuration => DurationValue(x) + case x: ObjectId => ObjectIdValue(x) + case x: JsValue => JsonValue(x) + } +} + +case class StringValue(v: String) extends SqlValue with Escaping { + override def setParameter(pp: PositionedParameters): Unit = pp.setString(v) + + override def debugInfo: String = escapeLiteral(v) +} + +case class ShortValue(v: Short) extends SqlValue { + override def setParameter(pp: PositionedParameters): Unit = pp.setShort(v) + + override def debugInfo: String = s"$v" +} + +case class IntValue(v: Int) extends SqlValue { + override def setParameter(pp: PositionedParameters): Unit = pp.setInt(v) + + override def debugInfo: String = s"$v" +} + +case class LongValue(v: Long) extends SqlValue { + override def setParameter(pp: PositionedParameters): Unit = pp.setLong(v) + + override def debugInfo: String = s"$v" +} + +case class FloatValue(v: Float) extends SqlValue { + override def setParameter(pp: PositionedParameters): Unit = pp.setFloat(v) + + override def debugInfo: String = s"$v" +} + +case class DoubleValue(v: Double) extends SqlValue { + override def setParameter(pp: PositionedParameters): Unit = pp.setDouble(v) + + override def debugInfo: String = s"$v" +} + +case class BooleanValue(v: Boolean) extends SqlValue { + override def setParameter(pp: PositionedParameters): Unit = pp.setBoolean(v) + + override def debugInfo: String = s"$v" +} + +case class InstantValue(v: Instant) extends SqlValue with Escaping { + override def setParameter(pp: PositionedParameters): Unit = pp.setTimestamp(v.toSql) + + override def placeholder: String = "?::TIMESTAMPTZ" + + override def debugInfo: String = escapeLiteral(v.toString) +} + +case class DurationValue(v: FiniteDuration) extends SqlValue with Escaping { + + private def stringifyDuration = v.unit match { + case duration.NANOSECONDS => s"${v.length.toDouble / 1000.0} MICROSECONDS" + case duration.MICROSECONDS => s"${v.length} MICROSECONDS" + case duration.MILLISECONDS => s"${v.length} MILLISECONDS" + case duration.SECONDS => s"${v.length} SECONDS" + case duration.MINUTES => s"${v.length} MINUTES" + case duration.HOURS => s"${v.length} HOURS" + case duration.DAYS => s"${v.length} DAYS" + } + + override def setParameter(pp: PositionedParameters): Unit = + pp.setString(stringifyDuration) + + override def placeholder: String = "?::INTERVAL" + + override def debugInfo: String = escapeLiteral(stringifyDuration) +} + +case class ObjectIdValue(v: ObjectId) extends SqlValue with Escaping { + override def setParameter(pp: PositionedParameters): Unit = pp.setString(v.id) + + override def debugInfo: String = escapeLiteral(v.id) +} + +case class JsonValue(v: JsValue) extends SqlValue with Escaping { + override def setParameter(pp: PositionedParameters): Unit = pp.setString(Json.stringify(v)) + + override def placeholder: String = "?::JSONB" + + override def debugInfo: String = escapeLiteral(Json.stringify(v)) +} + +case class NoneValue() extends SqlValue { + override def setParameter(pp: PositionedParameters): Unit = pp.setNull(Types.BOOLEAN) + + override def debugInfo: String = "NULL" +} + +private object GetUpdateValue extends GetResult[Int] { + def apply(pr: PositionedResult) = + throw new Exception("Update statements should not return a ResultSet") +} diff --git a/test/backend/SqlInterpolationTestSuite.scala b/test/backend/SqlInterpolationTestSuite.scala new file mode 100644 index 00000000000..876816ffd2a --- /dev/null +++ b/test/backend/SqlInterpolationTestSuite.scala @@ -0,0 +1,145 @@ +package backend + +import com.scalableminds.util.time.Instant +import org.scalatestplus.play.PlaySpec +import play.api.libs.json.Json +import utils.ObjectId +import utils.sql.SqlInterpolation.sqlInterpolation +import utils.sql._ + +import scala.concurrent.duration.DurationInt + +class SqlInterpolationTestSuite extends PlaySpec { + "SQL query creation" should { + "construct an SQLToken with null value" in { + val sql = q"""SELECT $None""" + assert(sql == SqlToken("SELECT ?", List(NoneValue()))) + assert(sql.debugInfo == "SELECT NULL") + } + "construct an SQLToken with boolean" in { + val sql = q"""SELECT ${Some(true)}""" + assert(sql == SqlToken("SELECT ?", List(BooleanValue(true)))) + } + "construct an SQLToken with string" in { + val sql = q"""SELECT * FROM test WHERE name = ${"Amy"}""" + assert(sql == SqlToken("SELECT * FROM test WHERE name = ?", List(StringValue("Amy")))) + } + "construct an SQLToken with escaped string" in { + val sql = q"""SELECT * FROM test WHERE name = ${"'; DROP TABLE test; --"}""" + assert(sql == SqlToken("SELECT * FROM test WHERE name = ?", List(StringValue("'; DROP TABLE test; --")))) + assert(sql.debugInfo == "SELECT * FROM test WHERE name = '''; DROP TABLE test; --'") + } + "construct an SQLToken with numbers" in { + val sql0 = q"""SELECT * FROM test WHERE age = ${3.shortValue}""" + assert(sql0 == SqlToken("SELECT * FROM test WHERE age = ?", List(ShortValue(3)))) + val sql1 = q"""SELECT * FROM test WHERE age = ${3}""" + assert(sql1 == SqlToken("SELECT * FROM test WHERE age = ?", List(IntValue(3)))) + val sql2 = q"""SELECT * FROM test WHERE age = ${3L}""" + assert(sql2 == SqlToken("SELECT * FROM test WHERE age = ?", List(LongValue(3L)))) + val sql3 = q"""SELECT * FROM test WHERE age = ${3.0f}""" + assert(sql3 == SqlToken("SELECT * FROM test WHERE age = ?", List(FloatValue(3.0f)))) + val sql4 = q"""SELECT * FROM test WHERE age = ${3.0}""" + assert(sql4 == SqlToken("SELECT * FROM test WHERE age = ?", List(DoubleValue(3.0)))) + } + "construct an SQLToken with json" in { + val json = Json.obj("street" -> "Market St") + val sql = q"""SELECT * FROM test WHERE address = $json""" + assert(sql == SqlToken("SELECT * FROM test WHERE address = ?::JSONB", List(JsonValue(json)))) + assert(sql.debugInfo == """SELECT * FROM test WHERE address = '{"street":"Market St"}'::JSONB""") + } + "construct an SQLToken with object id" in { + val id = ObjectId.generate + val sql = q"""SELECT * FROM test WHERE _id = $id""" + assert(sql == SqlToken("SELECT * FROM test WHERE _id = ?", List(ObjectIdValue(id)))) + } + "construct an SQLToken with date" in { + val time = Instant(1671885060000L) + val sql = q"""SELECT * FROM test WHERE created < $time""" + assert(sql == SqlToken("SELECT * FROM test WHERE created < ?::TIMESTAMPTZ", List(InstantValue(time)))) + assert(sql.debugInfo == "SELECT * FROM test WHERE created < '2022-12-24T12:31:00Z'::TIMESTAMPTZ") + } + "construct an SQLToken with duration" in { + val duration0 = 12 nanos + val sql0 = q"""SELECT $duration0""" + assert(sql0 == SqlToken("SELECT ?::INTERVAL", List(DurationValue(duration0)))) + assert(sql0.debugInfo == "SELECT '0.012 MICROSECONDS'::INTERVAL") + + val duration1 = 12 micros + val sql1 = q"""SELECT $duration1""" + assert(sql1 == SqlToken("SELECT ?::INTERVAL", List(DurationValue(duration1)))) + assert(sql1.debugInfo == "SELECT '12 MICROSECONDS'::INTERVAL") + + val duration2 = 12 millis + val sql2 = q"""SELECT $duration2""" + assert(sql2 == SqlToken("SELECT ?::INTERVAL", List(DurationValue(duration2)))) + assert(sql2.debugInfo == "SELECT '12 MILLISECONDS'::INTERVAL") + + val duration3 = 12 seconds + val sql3 = q"""SELECT $duration3""" + assert(sql3 == SqlToken("SELECT ?::INTERVAL", List(DurationValue(duration3)))) + assert(sql3.debugInfo == "SELECT '12 SECONDS'::INTERVAL") + + val duration4 = 12 minutes + val sql4 = q"""SELECT $duration4""" + assert(sql4 == SqlToken("SELECT ?::INTERVAL", List(DurationValue(duration4)))) + assert(sql4.debugInfo == "SELECT '12 MINUTES'::INTERVAL") + + val duration5 = 12 hours + val sql5 = q"""SELECT $duration5""" + assert(sql5 == SqlToken("SELECT ?::INTERVAL", List(DurationValue(duration5)))) + assert(sql5.debugInfo == "SELECT '12 HOURS'::INTERVAL") + + val duration6 = 12 days + val sql6 = q"""SELECT $duration6""" + assert(sql6 == SqlToken("SELECT ?::INTERVAL", List(DurationValue(duration6)))) + assert(sql6.debugInfo == "SELECT '12 DAYS'::INTERVAL") + } + "construct an SQLToken with multiple values" in { + val sql = q"""SELECT * FROM test WHERE age = ${3} AND name = ${"Amy"}""" + assert(sql == SqlToken("SELECT * FROM test WHERE age = ? AND name = ?", List(IntValue(3), StringValue("Amy")))) + } + "construct an SQLToken with identifiers" in { + val sql = q"""SELECT * FROM ${SqlToken.identifier("test")}""" + assert(sql == SqlToken("SELECT * FROM \"test\"")) + } + "construct an SQLToken with raw SQL" in { + val sql = q"""SELECT * FROM test WHERE ${SqlToken.raw("TRUE")}""" + assert(sql == SqlToken("SELECT * FROM test WHERE TRUE")) + } + "construct an SQLToken with empty SQL" in { + val sql = q"""SELECT * FROM test ${SqlToken.empty}""" + assert(sql == SqlToken("SELECT * FROM test ")) + } + "construct an SQLToken with nested SQL" in { + val accessQ = q"""isAdmin = ${true}""" + val sql = q"""SELECT * FROM test WHERE $accessQ""" + assert(sql == SqlToken("SELECT * FROM test WHERE isAdmin = ?", List(BooleanValue(true)))) + } + "construct an SQLToken with tuple" in { + val list = List(3, 5) + val sql = q"""SELECT * FROM test WHERE age IN ${SqlToken.tuple(list)}""" + assert(sql == SqlToken("SELECT * FROM test WHERE age IN (?, ?)", List(IntValue(3), IntValue(5)))) + } + "construct an SQLToken with tuple lists" in { + val list = List(List("Bob", 5), List("Amy", 3)) + val sql = q"""INSERT INTO test(name, age) VALUES ${SqlToken.tupleList(list)}""" + assert( + sql == SqlToken("INSERT INTO test(name, age) VALUES (?, ?), (?, ?)", + List(StringValue("Bob"), IntValue(5), StringValue("Amy"), IntValue(3)))) + } + "construct an SQLToken with nested-joined SQL" in { + val fields = List("name", "age") + val values = List("Bob", 5) + val sql = + q"""INSERT INTO test(${SqlToken.join(fields.map(x => Right(SqlToken.identifier(x))), ", ")}) VALUES ${SqlToken + .tupleList(List(values))}""" + + assert( + sql == SqlToken("""INSERT INTO test("name", "age") VALUES (?, ?)""", List(StringValue("Bob"), IntValue(5)))) + } + "create debugInfo from SQLToken" in { + val sql = q"""SELECT * FROM test WHERE age = ${3} AND name = ${"Amy"}""" + assert(sql.debugInfo == "SELECT * FROM test WHERE age = 3 AND name = 'Amy'") + } + } +}