From 119e84608dac33204dd880c1cbccaa1d3558ca8b Mon Sep 17 00:00:00 2001 From: igalklebanov Date: Wed, 24 Apr 2024 00:20:58 +0300 Subject: [PATCH] savepoints. --- src/dialect/mssql/mssql-driver.ts | 30 ++++++++++ src/dialect/mysql/mysql-driver.ts | 32 ++++++++++ src/dialect/postgres/postgres-driver.ts | 37 ++++++++++++ src/dialect/sqlite/sqlite-driver.ts | 32 ++++++++++ src/driver/driver.ts | 28 +++++++++ src/driver/runtime-driver.ts | 45 +++++++++++++++ src/kysely.ts | 77 +++++++++++++++++++++---- src/parser/savepoint-parser.ts | 12 ++++ 8 files changed, 282 insertions(+), 11 deletions(-) create mode 100644 src/parser/savepoint-parser.ts diff --git a/src/dialect/mssql/mssql-driver.ts b/src/dialect/mssql/mssql-driver.ts index 5d9b9cb12..79a327392 100644 --- a/src/dialect/mssql/mssql-driver.ts +++ b/src/dialect/mssql/mssql-driver.ts @@ -30,6 +30,8 @@ import { CompiledQuery } from '../../query-compiler/compiled-query.js' import { extendStackTrace } from '../../util/stack-trace-utils.js' import { randomString } from '../../util/random-string.js' import { Deferred } from '../../util/deferred.js' +import { parseSavepointCommand } from '../../parser/savepoint-parser.js' +import { QueryCompiler } from '../../query-compiler/query-compiler.js' const PRIVATE_RELEASE_METHOD = Symbol() const PRIVATE_DESTROY_METHOD = Symbol() @@ -87,6 +89,34 @@ export class MssqlDriver implements Driver { await connection.rollbackTransaction() } + async savepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('save transaction', savepointName)), + ) + } + + async rollbackToSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery( + parseSavepointCommand('rollback transaction', savepointName), + ), + ) + } + + async releaseSavepoint(): Promise { + throw new Error( + 'MS SQL Server (mssql) does not support releasing savepoints', + ) + } + async releaseConnection(connection: MssqlConnection): Promise { await connection[PRIVATE_RELEASE_METHOD]() this.#pool.release(connection) diff --git a/src/dialect/mysql/mysql-driver.ts b/src/dialect/mysql/mysql-driver.ts index f03067957..08e1e2cb0 100644 --- a/src/dialect/mysql/mysql-driver.ts +++ b/src/dialect/mysql/mysql-driver.ts @@ -3,7 +3,9 @@ import { QueryResult, } from '../../driver/database-connection.js' import { Driver, TransactionSettings } from '../../driver/driver.js' +import { parseSavepointCommand } from '../../parser/savepoint-parser.js' import { CompiledQuery } from '../../query-compiler/compiled-query.js' +import { QueryCompiler } from '../../query-compiler/query-compiler.js' import { isFunction, isObject, freeze } from '../../util/object-utils.js' import { extendStackTrace } from '../../util/stack-trace-utils.js' import { @@ -86,6 +88,36 @@ export class MysqlDriver implements Driver { await connection.executeQuery(CompiledQuery.raw('rollback')) } + async savepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('savepoint', savepointName)), + ) + } + + async rollbackToSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('rollback to', savepointName)), + ) + } + + async releaseSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('release savepoint', savepointName)), + ) + } + async releaseConnection(connection: MysqlConnection): Promise { connection[PRIVATE_RELEASE_METHOD]() } diff --git a/src/dialect/postgres/postgres-driver.ts b/src/dialect/postgres/postgres-driver.ts index 70d8a4bfb..ea3b1cdf7 100644 --- a/src/dialect/postgres/postgres-driver.ts +++ b/src/dialect/postgres/postgres-driver.ts @@ -3,7 +3,14 @@ import { QueryResult, } from '../../driver/database-connection.js' import { Driver, TransactionSettings } from '../../driver/driver.js' +import { IdentifierNode } from '../../operation-node/identifier-node.js' +import { RawNode } from '../../operation-node/raw-node.js' +import { parseSavepointCommand } from '../../parser/savepoint-parser.js' import { CompiledQuery } from '../../query-compiler/compiled-query.js' +import { + QueryCompiler, + RootOperationNode, +} from '../../query-compiler/query-compiler.js' import { isFunction, freeze } from '../../util/object-utils.js' import { extendStackTrace } from '../../util/stack-trace-utils.js' import { @@ -74,6 +81,36 @@ export class PostgresDriver implements Driver { await connection.executeQuery(CompiledQuery.raw('rollback')) } + async savepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('savepoint', savepointName)), + ) + } + + async rollbackToSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('rollback to', savepointName)), + ) + } + + async releaseSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('release', savepointName)), + ) + } + async releaseConnection(connection: PostgresConnection): Promise { connection[PRIVATE_RELEASE_METHOD]() } diff --git a/src/dialect/sqlite/sqlite-driver.ts b/src/dialect/sqlite/sqlite-driver.ts index dcfbca2e6..5aefb32ef 100644 --- a/src/dialect/sqlite/sqlite-driver.ts +++ b/src/dialect/sqlite/sqlite-driver.ts @@ -4,7 +4,9 @@ import { } from '../../driver/database-connection.js' import { Driver } from '../../driver/driver.js' import { SelectQueryNode } from '../../operation-node/select-query-node.js' +import { parseSavepointCommand } from '../../parser/savepoint-parser.js' import { CompiledQuery } from '../../query-compiler/compiled-query.js' +import { QueryCompiler } from '../../query-compiler/query-compiler.js' import { freeze, isFunction } from '../../util/object-utils.js' import { SqliteDatabase, SqliteDialectConfig } from './sqlite-dialect-config.js' @@ -50,6 +52,36 @@ export class SqliteDriver implements Driver { await connection.executeQuery(CompiledQuery.raw('rollback')) } + async savepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('savepoint', savepointName)), + ) + } + + async rollbackToSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('rollback to', savepointName)), + ) + } + + async releaseSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('release', savepointName)), + ) + } + async releaseConnection(): Promise { this.#connectionMutex.unlock() } diff --git a/src/driver/driver.ts b/src/driver/driver.ts index ef9214d17..849cd0129 100644 --- a/src/driver/driver.ts +++ b/src/driver/driver.ts @@ -1,3 +1,4 @@ +import { QueryCompiler } from '../query-compiler/query-compiler.js' import { ArrayItemType } from '../util/type-utils.js' import { DatabaseConnection } from './database-connection.js' @@ -37,6 +38,33 @@ export interface Driver { */ rollbackTransaction(connection: DatabaseConnection): Promise + /** + * Establishses a new savepoint within a transaction. + */ + savepoint?( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise + + /** + * Rolls back to a savepoint within a transaction. + */ + rollbackToSavepoint?( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise + + /** + * Releases a savepoint within a transaction. + */ + releaseSavepoint?( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise + /** * Releases a connection back to the pool. */ diff --git a/src/driver/runtime-driver.ts b/src/driver/runtime-driver.ts index a7ba8d771..cf7563412 100644 --- a/src/driver/runtime-driver.ts +++ b/src/driver/runtime-driver.ts @@ -1,4 +1,5 @@ import { CompiledQuery } from '../query-compiler/compiled-query.js' +import { QueryCompiler } from '../query-compiler/query-compiler.js' import { Log } from '../util/log.js' import { performanceNow } from '../util/performance-now.js' import { DatabaseConnection, QueryResult } from './database-connection.js' @@ -85,6 +86,50 @@ export class RuntimeDriver implements Driver { return this.#driver.rollbackTransaction(connection) } + savepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + if (this.#driver.savepoint) { + return this.#driver.savepoint(connection, savepointName, compileQuery) + } + + throw new Error('savepoints are not supported by this driver') + } + + rollbackToSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + if (this.#driver.rollbackToSavepoint) { + return this.#driver.rollbackToSavepoint( + connection, + savepointName, + compileQuery, + ) + } + + throw new Error('savepoints are not supported by this driver') + } + + releaseSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + if (this.#driver.releaseSavepoint) { + return this.#driver.releaseSavepoint( + connection, + savepointName, + compileQuery, + ) + } + + throw new Error('savepoints are not supported by this driver') + } + async destroy(): Promise { if (!this.#initPromise) { return diff --git a/src/kysely.ts b/src/kysely.ts index 87ca88de0..26e77c1b6 100644 --- a/src/kysely.ts +++ b/src/kysely.ts @@ -36,6 +36,7 @@ import { parseExpression } from './parser/expression-parser.js' import { Expression } from './expression/expression.js' import { WithSchemaPlugin } from './plugin/with-schema/with-schema-plugin.js' import { DrainOuterGeneric } from './util/type-utils.js' +import { QueryCompiler } from './query-compiler/query-compiler.js' /** * The main Kysely class. @@ -606,9 +607,9 @@ function validateTransactionSettings(settings: TransactionSettings): void { } export class ControlledTransactionBuilder { - readonly #props: TransactionBuilderProps + readonly #props: ControlledTransactionBuilderProps - constructor(props: TransactionBuilderProps) { + constructor(props: ControlledTransactionBuilderProps) { this.#props = freeze(props) } @@ -622,7 +623,7 @@ export class ControlledTransactionBuilder { } async execute(): Promise> { - const { isolationLevel, ...kyselyProps } = this.#props + const { isolationLevel, ...props } = this.#props const settings = { isolationLevel } validateTransactionSettings(settings) @@ -632,12 +633,16 @@ export class ControlledTransactionBuilder { await this.#props.driver.beginTransaction(connection, settings) return new ControlledTransaction({ - ...kyselyProps, + ...props, connection, }) } } +interface ControlledTransactionBuilderProps extends TransactionBuilderProps { + readonly releaseConnection?: boolean +} + preventAwait( ControlledTransactionBuilder, "don't await ControlledTransactionBuilder instances directly. To execute the transaction you need to call the `execute` method", @@ -645,28 +650,75 @@ preventAwait( export class ControlledTransaction extends Transaction { readonly #props: ControlledTransactionProps + readonly #compileQuery: QueryCompiler['compileQuery'] constructor(props: ControlledTransactionProps) { - const { connection, ...transactionProps } = props + const { + connection, + releaseConnectedWhenDone: releaseConnection, + ...transactionProps + } = props super(transactionProps) this.#props = props + + const queryId = createQueryId() + this.#compileQuery = (node) => props.executor.compileQuery(node, queryId) } commit(): Command { - return new Command(() => - this.#props.driver.commitTransaction(this.#props.connection), - ) + return new Command(async () => { + await this.#props.driver.commitTransaction(this.#props.connection) + await this.#releaseConnectionIfNecessary() + }) } rollback(): Command { - return new Command(() => - this.#props.driver.rollbackTransaction(this.#props.connection), - ) + return new Command(async () => { + await this.#props.driver.rollbackTransaction(this.#props.connection) + await this.#releaseConnectionIfNecessary() + }) + } + + savepoint(savepointName: string): Command { + return new Command(async () => { + await this.#props.driver.savepoint?.( + this.#props.connection, + savepointName, + this.#compileQuery, + ) + }) + } + + rollbackToSavepoint(savepointName: string): Command { + return new Command(async () => { + await this.#props.driver.rollbackToSavepoint?.( + this.#props.connection, + savepointName, + this.#compileQuery, + ) + }) + } + + releaseSavepoint(savepointName: string): Command { + return new Command(async () => { + await this.#props.driver.releaseSavepoint?.( + this.#props.connection, + savepointName, + this.#compileQuery, + ) + }) + } + + async #releaseConnectionIfNecessary(): Promise { + if (this.#props.releaseConnectedWhenDone !== false) { + await this.#props.driver.releaseConnection(this.#props.connection) + } } } interface ControlledTransactionProps extends KyselyProps { readonly connection: DatabaseConnection + readonly releaseConnectedWhenDone?: boolean } preventAwait( @@ -681,6 +733,9 @@ export class Command { this.#cb = cb } + /** + * Executes the command. + */ async execute(): Promise { return await this.#cb() } diff --git a/src/parser/savepoint-parser.ts b/src/parser/savepoint-parser.ts new file mode 100644 index 000000000..fbbebd9cd --- /dev/null +++ b/src/parser/savepoint-parser.ts @@ -0,0 +1,12 @@ +import { IdentifierNode } from '../operation-node/identifier-node.js' +import { RawNode } from '../operation-node/raw-node.js' + +export function parseSavepointCommand( + command: string, + savepointName: string, +): RawNode { + return RawNode.createWithChildren([ + RawNode.createWithSql(`${command} `), + IdentifierNode.create(savepointName), // ensures savepointName gets sanitized + ]) +}