diff --git a/src/kysely.ts b/src/kysely.ts index 26e77c1b6..f3af1db13 100644 --- a/src/kysely.ts +++ b/src/kysely.ts @@ -37,6 +37,10 @@ import { Expression } from './expression/expression.js' import { WithSchemaPlugin } from './plugin/with-schema/with-schema-plugin.js' import { DrainOuterGeneric } from './util/type-utils.js' import { QueryCompiler } from './query-compiler/query-compiler.js' +import { + ReleaseSavepoint, + RollbackToSavepoint, +} from './parser/savepoint-parser.js' /** * The main Kysely class. @@ -648,9 +652,14 @@ preventAwait( "don't await ControlledTransactionBuilder instances directly. To execute the transaction you need to call the `execute` method", ) -export class ControlledTransaction extends Transaction { +export class ControlledTransaction< + DB, + S extends string[] = [], +> extends Transaction { readonly #props: ControlledTransactionProps readonly #compileQuery: QueryCompiler['compileQuery'] + #isCommitted: boolean + #isRolledBack: boolean constructor(props: ControlledTransactionProps) { const { @@ -663,52 +672,98 @@ export class ControlledTransaction extends Transaction { const queryId = createQueryId() this.#compileQuery = (node) => props.executor.compileQuery(node, queryId) + + this.#isCommitted = false + this.#isRolledBack = false } commit(): Command { + this.#assertNotCommittedOrRolledBack() + return new Command(async () => { await this.#props.driver.commitTransaction(this.#props.connection) + this.#isCommitted = true await this.#releaseConnectionIfNecessary() }) } rollback(): Command { + this.#assertNotCommittedOrRolledBack() + return new Command(async () => { await this.#props.driver.rollbackTransaction(this.#props.connection) + this.#isRolledBack = true await this.#releaseConnectionIfNecessary() }) } - savepoint(savepointName: string): Command { + savepoint( + savepointName: SN extends S ? never : SN, + ): Command> { + this.#assertNotCommittedOrRolledBack() + return new Command(async () => { await this.#props.driver.savepoint?.( this.#props.connection, savepointName, this.#compileQuery, ) + + return new ControlledTransaction({ + ...this.#props, + connection: this.#props.connection, + }) }) } - rollbackToSavepoint(savepointName: string): Command { + rollbackToSavepoint( + savepointName: SN, + ): Command>> { + this.#assertNotCommittedOrRolledBack() + return new Command(async () => { await this.#props.driver.rollbackToSavepoint?.( this.#props.connection, savepointName, this.#compileQuery, ) + + return new ControlledTransaction({ + ...this.#props, + connection: this.#props.connection, + }) }) } - releaseSavepoint(savepointName: string): Command { + releaseSavepoint( + savepointName: SN, + ): Command>> { + this.#assertNotCommittedOrRolledBack() + return new Command(async () => { await this.#props.driver.releaseSavepoint?.( this.#props.connection, savepointName, this.#compileQuery, ) + + return new ControlledTransaction({ + ...this.#props, + connection: this.#props.connection, + }) }) } + #assertNotCommittedOrRolledBack(): void { + if (this.#isCommitted) { + throw new Error('Transaction is already committed') + } + + if (this.#isRolledBack) { + throw new Error('Transaction is already rolled back') + } + } + async #releaseConnectionIfNecessary(): Promise { if (this.#props.releaseConnectedWhenDone !== false) { await this.#props.driver.releaseConnection(this.#props.connection) diff --git a/src/parser/savepoint-parser.ts b/src/parser/savepoint-parser.ts index fbbebd9cd..8163bb935 100644 --- a/src/parser/savepoint-parser.ts +++ b/src/parser/savepoint-parser.ts @@ -1,6 +1,24 @@ import { IdentifierNode } from '../operation-node/identifier-node.js' import { RawNode } from '../operation-node/raw-node.js' +export type RollbackToSavepoint< + S extends string[], + SN extends S[number], +> = S extends [...infer L extends string[], infer R] + ? R extends SN + ? S + : RollbackToSavepoint + : never + +export type ReleaseSavepoint< + S extends string[], + SN extends S[number], +> = S extends [...infer L extends string[], infer R] + ? R extends SN + ? L + : ReleaseSavepoint + : never + export function parseSavepointCommand( command: string, savepointName: string,