Skip to content

Commit

Permalink
type-safe savepoints.
Browse files Browse the repository at this point in the history
  • Loading branch information
igalklebanov committed Apr 26, 2024
1 parent 6abc7ad commit 20220be
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 4 deletions.
63 changes: 59 additions & 4 deletions src/kysely.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ import { Expression } from './expression/expression.js'
import { WithSchemaPlugin } from './plugin/with-schema/with-schema-plugin.js'
import { DrainOuterGeneric } from './util/type-utils.js'
import { QueryCompiler } from './query-compiler/query-compiler.js'
import {
ReleaseSavepoint,
RollbackToSavepoint,
} from './parser/savepoint-parser.js'

/**
* The main Kysely class.
Expand Down Expand Up @@ -648,9 +652,14 @@ preventAwait(
"don't await ControlledTransactionBuilder instances directly. To execute the transaction you need to call the `execute` method",
)

export class ControlledTransaction<DB> extends Transaction<DB> {
export class ControlledTransaction<
DB,
S extends string[] = [],
> extends Transaction<DB> {
readonly #props: ControlledTransactionProps
readonly #compileQuery: QueryCompiler['compileQuery']
#isCommitted: boolean
#isRolledBack: boolean

constructor(props: ControlledTransactionProps) {
const {
Expand All @@ -663,52 +672,98 @@ export class ControlledTransaction<DB> extends Transaction<DB> {

const queryId = createQueryId()
this.#compileQuery = (node) => props.executor.compileQuery(node, queryId)

this.#isCommitted = false
this.#isRolledBack = false
}

commit(): Command<void> {
this.#assertNotCommittedOrRolledBack()

return new Command(async () => {
await this.#props.driver.commitTransaction(this.#props.connection)
this.#isCommitted = true
await this.#releaseConnectionIfNecessary()
})
}

rollback(): Command<void> {
this.#assertNotCommittedOrRolledBack()

return new Command(async () => {
await this.#props.driver.rollbackTransaction(this.#props.connection)
this.#isRolledBack = true
await this.#releaseConnectionIfNecessary()
})
}

savepoint(savepointName: string): Command<void> {
savepoint<SN extends string>(
savepointName: SN extends S ? never : SN,
): Command<ControlledTransaction<DB, [...S, SN]>> {
this.#assertNotCommittedOrRolledBack()

return new Command(async () => {
await this.#props.driver.savepoint?.(
this.#props.connection,
savepointName,
this.#compileQuery,
)

return new ControlledTransaction({
...this.#props,
connection: this.#props.connection,
})
})
}

rollbackToSavepoint(savepointName: string): Command<void> {
rollbackToSavepoint<SN extends S[number]>(
savepointName: SN,
): Command<ControlledTransaction<DB, RollbackToSavepoint<S, SN>>> {
this.#assertNotCommittedOrRolledBack()

return new Command(async () => {
await this.#props.driver.rollbackToSavepoint?.(
this.#props.connection,
savepointName,
this.#compileQuery,
)

return new ControlledTransaction({
...this.#props,
connection: this.#props.connection,
})
})
}

releaseSavepoint(savepointName: string): Command<void> {
releaseSavepoint<SN extends S[number]>(
savepointName: SN,
): Command<ControlledTransaction<DB, ReleaseSavepoint<S, SN>>> {
this.#assertNotCommittedOrRolledBack()

return new Command(async () => {
await this.#props.driver.releaseSavepoint?.(
this.#props.connection,
savepointName,
this.#compileQuery,
)

return new ControlledTransaction({
...this.#props,
connection: this.#props.connection,
})
})
}

#assertNotCommittedOrRolledBack(): void {
if (this.#isCommitted) {
throw new Error('Transaction is already committed')
}

if (this.#isRolledBack) {
throw new Error('Transaction is already rolled back')
}
}

async #releaseConnectionIfNecessary(): Promise<void> {
if (this.#props.releaseConnectedWhenDone !== false) {
await this.#props.driver.releaseConnection(this.#props.connection)
Expand Down
18 changes: 18 additions & 0 deletions src/parser/savepoint-parser.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
import { IdentifierNode } from '../operation-node/identifier-node.js'
import { RawNode } from '../operation-node/raw-node.js'

export type RollbackToSavepoint<
S extends string[],
SN extends S[number],
> = S extends [...infer L extends string[], infer R]
? R extends SN
? S
: RollbackToSavepoint<L, SN>
: never

export type ReleaseSavepoint<
S extends string[],
SN extends S[number],
> = S extends [...infer L extends string[], infer R]
? R extends SN
? L
: ReleaseSavepoint<L, SN>
: never

export function parseSavepointCommand(
command: string,
savepointName: string,
Expand Down

0 comments on commit 20220be

Please sign in to comment.