diff --git a/src/expression/expression-builder.ts b/src/expression/expression-builder.ts index d5cb8b596..5ff308322 100644 --- a/src/expression/expression-builder.ts +++ b/src/expression/expression-builder.ts @@ -18,8 +18,10 @@ import { } from '../query-builder/function-module.js' import { ExtractTypeFromReferenceExpression, + parseReferenceExpression, parseStringReference, ReferenceExpression, + SimpleReferenceExpression, StringReference, } from '../parser/reference-parser.js' import { QueryExecutor } from '../query-executor/query-executor.js' @@ -47,6 +49,9 @@ import { } from '../parser/value-parser.js' import { NOOP_QUERY_EXECUTOR } from '../query-executor/noop-query-executor.js' import { ValueNode } from '../operation-node/value-node.js' +import { CaseBuilder } from '../query-builder/case-builder.js' +import { CaseNode } from '../operation-node/case-node.js' +import { isUndefined } from '../util/object-utils.js' export interface ExpressionBuilder { /** @@ -150,6 +155,65 @@ export interface ExpressionBuilder { from: TE ): SelectQueryBuilder, FromTables, {}> + /** + * Creates a `case` statement/operator. + * + * ### Examples + * + * Kitchen sink example with 2 flavors of `case` operator: + * + * ```ts + * import { sql } from 'kysely' + * + * const { title, name } = await db + * .selectFrom('person') + * .where('id', '=', '123') + * .select((eb) => [ + * eb.fn.coalesce('last_name', 'first_name').as('name'), + * eb + * .case() + * .when('gender', '=', 'male') + * .then('Mr.') + * .when('gender', '=', 'female') + * .then( + * eb + * .case('martialStatus') + * .when('single') + * .then('Ms.') + * .else('Mrs.') + * .end() + * ) + * .end() + * .as('title'), + * ]) + * .executeTakeFirstOrThrow() + * ``` + * + * The generated SQL (PostgreSQL): + * + * ```sql + * select + * coalesce("last_name", "first_name") as "name", + * case + * when "gender" = $1 then $2 + * when "gender" = $3 then + * case "martialStatus" + * when $4 then $5 + * else $6 + * end + * end as "title" + * from "person" + * where "id" = $7 + * ``` + */ + case(): CaseBuilder + + case>( + column: C + ): CaseBuilder> + + case(expression: Expression): CaseBuilder + /** * This can be used to reference columns. * @@ -505,6 +569,18 @@ export function createExpressionBuilder( }) }, + case>( + reference?: RE + ): CaseBuilder> { + return new CaseBuilder({ + node: CaseNode.create( + isUndefined(reference) + ? undefined + : parseReferenceExpression(reference) + ), + }) + }, + ref>( reference: RE ): ExpressionWrapper> { diff --git a/src/index.ts b/src/index.ts index dfefe329b..1fa85560c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,6 +6,7 @@ export { ExpressionBuilder, expressionBuilder, } from './expression/expression-builder.js' +export * from './expression/expression-wrapper.js' export * from './query-builder/where-interface.js' export * from './query-builder/returning-interface.js' @@ -22,6 +23,7 @@ export * from './query-builder/delete-result.js' export * from './query-builder/update-result.js' export * from './query-builder/on-conflict-builder.js' export * from './query-builder/aggregate-function-builder.js' +export * from './query-builder/case-builder.js' export * from './raw-builder/raw-builder.js' export * from './raw-builder/sql.js' @@ -102,6 +104,7 @@ export * from './operation-node/alias-node.js' export * from './operation-node/alter-column-node.js' export * from './operation-node/alter-table-node.js' export * from './operation-node/and-node.js' +export * from './operation-node/case-node.js' export * from './operation-node/check-constraint-node.js' export * from './operation-node/column-definition-node.js' export * from './operation-node/column-node.js' @@ -167,6 +170,7 @@ export * from './operation-node/update-query-node.js' export * from './operation-node/value-list-node.js' export * from './operation-node/value-node.js' export * from './operation-node/values-node.js' +export * from './operation-node/when-node.js' export * from './operation-node/where-node.js' export * from './operation-node/with-node.js' export * from './operation-node/explain-node.js' diff --git a/src/kysely.ts b/src/kysely.ts index 6be742d16..33101ccc0 100644 --- a/src/kysely.ts +++ b/src/kysely.ts @@ -7,7 +7,7 @@ import { QueryCreator, QueryCreatorProps } from './query-creator.js' import { KyselyPlugin } from './plugin/kysely-plugin.js' import { DefaultQueryExecutor } from './query-executor/default-query-executor.js' import { DatabaseIntrospector } from './dialect/database-introspector.js' -import { freeze, isObject } from './util/object-utils.js' +import { freeze, isObject, isUndefined } from './util/object-utils.js' import { RuntimeDriver } from './driver/runtime-driver.js' import { SingleConnectionProvider } from './driver/single-connection-provider.js' import { @@ -27,6 +27,10 @@ import { QueryResult } from './driver/database-connection.js' import { CompiledQuery } from './query-compiler/compiled-query.js' import { createQueryId, QueryId } from './util/query-id.js' import { Compilable, isCompilable } from './util/compilable.js' +import { CaseBuilder } from './query-builder/case-builder.js' +import { CaseNode } from './operation-node/case-node.js' +import { parseExpression } from './parser/expression-parser.js' +import { Expression } from './expression/expression.js' import { WithSchemaPlugin } from './plugin/with-schema/with-schema-plugin.js' /** @@ -142,6 +146,23 @@ export class Kysely return this.#props.dialect.createIntrospector(this.withoutPlugins()) } + /** + * Creates a `case` statement/operator. + * + * See {@link ExpressionBuilder.case} for more information. + */ + case(): CaseBuilder + + case(value: Expression): CaseBuilder + + case(value?: Expression): any { + return new CaseBuilder({ + node: CaseNode.create( + isUndefined(value) ? undefined : parseExpression(value) + ), + }) + } + /** * Returns a {@link FunctionModule} that can be used to write type safe function * calls. diff --git a/src/operation-node/case-node.ts b/src/operation-node/case-node.ts new file mode 100644 index 000000000..a56c0400e --- /dev/null +++ b/src/operation-node/case-node.ts @@ -0,0 +1,59 @@ +import { freeze } from '../util/object-utils.js' +import { OperationNode } from './operation-node.js' +import { WhenNode } from './when-node.js' + +export interface CaseNode extends OperationNode { + readonly kind: 'CaseNode' + readonly value?: OperationNode + readonly when?: ReadonlyArray + readonly else?: OperationNode + readonly isStatement?: boolean +} + +/** + * @internal + */ +export const CaseNode = freeze({ + is(node: OperationNode): node is CaseNode { + return node.kind === 'CaseNode' + }, + + create(value?: OperationNode): CaseNode { + return freeze({ + kind: 'CaseNode', + value, + }) + }, + + cloneWithWhen(caseNode: CaseNode, when: WhenNode): CaseNode { + return freeze({ + ...caseNode, + when: freeze(caseNode.when ? [...caseNode.when, when] : [when]), + }) + }, + + cloneWithThen(caseNode: CaseNode, then: OperationNode): CaseNode { + return freeze({ + ...caseNode, + when: caseNode.when + ? freeze([ + ...caseNode.when.slice(0, -1), + WhenNode.cloneWithResult( + caseNode.when[caseNode.when.length - 1], + then + ), + ]) + : undefined, + }) + }, + + cloneWith( + caseNode: CaseNode, + props: Partial> + ): CaseNode { + return freeze({ + ...caseNode, + ...props, + }) + }, +}) diff --git a/src/operation-node/operation-node-transformer.ts b/src/operation-node/operation-node-transformer.ts index 60710756d..9f737f52c 100644 --- a/src/operation-node/operation-node-transformer.ts +++ b/src/operation-node/operation-node-transformer.ts @@ -80,6 +80,8 @@ import { BinaryOperationNode } from './binary-operation-node.js' import { UnaryOperationNode } from './unary-operation-node.js' import { UsingNode } from './using-node.js' import { FunctionNode } from './function-node.js' +import { CaseNode } from './case-node.js' +import { WhenNode } from './when-node.js' /** * Transforms an operation node tree into another one. @@ -194,6 +196,8 @@ export class OperationNodeTransformer { UnaryOperationNode: this.transformUnaryOperation.bind(this), UsingNode: this.transformUsing.bind(this), FunctionNode: this.transformFunction.bind(this), + CaseNode: this.transformCase.bind(this), + WhenNode: this.transformWhen.bind(this), }) transformNode(node: T): T { @@ -898,6 +902,24 @@ export class OperationNodeTransformer { }) } + protected transformCase(node: CaseNode): CaseNode { + return requireAllProps({ + kind: 'CaseNode', + value: this.transformNode(node.value), + when: this.transformNodeList(node.when), + else: this.transformNode(node.else), + isStatement: node.isStatement, + }) + } + + protected transformWhen(node: WhenNode): WhenNode { + return requireAllProps({ + kind: 'WhenNode', + condition: this.transformNode(node.condition), + result: this.transformNode(node.result), + }) + } + protected transformDataType(node: DataTypeNode): DataTypeNode { // An Object.freezed leaf node. No need to clone. return node diff --git a/src/operation-node/operation-node-visitor.ts b/src/operation-node/operation-node-visitor.ts index 4500eefbb..5a9871b4f 100644 --- a/src/operation-node/operation-node-visitor.ts +++ b/src/operation-node/operation-node-visitor.ts @@ -82,6 +82,8 @@ import { BinaryOperationNode } from './binary-operation-node.js' import { UnaryOperationNode } from './unary-operation-node.js' import { UsingNode } from './using-node.js' import { FunctionNode } from './function-node.js' +import { WhenNode } from './when-node.js' +import { CaseNode } from './case-node.js' export abstract class OperationNodeVisitor { protected readonly nodeStack: OperationNode[] = [] @@ -171,6 +173,8 @@ export abstract class OperationNodeVisitor { UnaryOperationNode: this.visitUnaryOperation.bind(this), UsingNode: this.visitUsing.bind(this), FunctionNode: this.visitFunction.bind(this), + CaseNode: this.visitCase.bind(this), + WhenNode: this.visitWhen.bind(this), }) protected readonly visitNode = (node: OperationNode): void => { @@ -268,4 +272,6 @@ export abstract class OperationNodeVisitor { protected abstract visitUnaryOperation(node: UnaryOperationNode): void protected abstract visitUsing(node: UsingNode): void protected abstract visitFunction(node: FunctionNode): void + protected abstract visitCase(node: CaseNode): void + protected abstract visitWhen(node: WhenNode): void } diff --git a/src/operation-node/operation-node.ts b/src/operation-node/operation-node.ts index 06214fcb4..04bb9122e 100644 --- a/src/operation-node/operation-node.ts +++ b/src/operation-node/operation-node.ts @@ -78,6 +78,8 @@ export type OperationNodeKind = | 'UnaryOperationNode' | 'UsingNode' | 'FunctionNode' + | 'CaseNode' + | 'WhenNode' export interface OperationNode { readonly kind: OperationNodeKind diff --git a/src/operation-node/when-node.ts b/src/operation-node/when-node.ts new file mode 100644 index 000000000..4575a2316 --- /dev/null +++ b/src/operation-node/when-node.ts @@ -0,0 +1,31 @@ +import { freeze } from '../util/object-utils.js' +import { OperationNode } from './operation-node.js' + +export interface WhenNode extends OperationNode { + readonly kind: 'WhenNode' + readonly condition: OperationNode + readonly result?: OperationNode +} + +/** + * @internal + */ +export const WhenNode = freeze({ + is(node: OperationNode): node is WhenNode { + return node.kind === 'WhenNode' + }, + + create(condition: OperationNode): WhenNode { + return freeze({ + kind: 'WhenNode', + condition, + }) + }, + + cloneWithResult(whenNode: WhenNode, result: OperationNode): WhenNode { + return freeze({ + ...whenNode, + result, + }) + }, +}) diff --git a/src/parser/binary-operation-parser.ts b/src/parser/binary-operation-parser.ts index fe01de1aa..840c5aaa5 100644 --- a/src/parser/binary-operation-parser.ts +++ b/src/parser/binary-operation-parser.ts @@ -39,6 +39,8 @@ import { Expression } from '../expression/expression.js' import { SelectQueryNode } from '../operation-node/select-query-node.js' import { JoinNode } from '../operation-node/join-node.js' import { expressionBuilder } from '../expression/expression-builder.js' +import { UnaryOperationNode } from '../operation-node/unary-operation-node.js' +import { CaseNode } from '../operation-node/case-node.js' export type OperandValueExpression< DB, @@ -67,7 +69,7 @@ export type ArithmeticOperatorExpression = | ArithmeticOperator | Expression -type FilterExpressionType = 'where' | 'having' | 'on' +type FilterExpressionType = 'where' | 'having' | 'on' | 'when' export function parseValueBinaryOperation( leftOperand: ReferenceExpression, @@ -141,6 +143,10 @@ export function parseOn(args: any[]): OperationNode { return parseFilter('on', args) } +export function parseWhen(args: any[]): OperationNode { + return parseFilter('when', args) +} + function parseFilter(type: FilterExpressionType, args: any[]): OperationNode { if (args.length === 3) { return parseValueComparison(args[0], args[1], args[2]) @@ -192,13 +198,25 @@ function parseOneArgFilterExpression( arg: any ): OperationNode { if (isFunction(arg)) { + if (type === 'when') { + throw new Error(`when method doesn't accept a callback as an argument`) + } + return CALLBACK_PARSERS[type](arg) } else if (isOperationNodeSource(arg)) { const node = arg.toOperationNode() - if (RawNode.is(node)) { + if ( + RawNode.is(node) || + BinaryOperationNode.is(node) || + UnaryOperationNode.is(node) || + ParensNode.is(node) || + CaseNode.is(node) + ) { return node } + } else if (type === 'when') { + return ValueNode.create(arg) } throw createFilterExpressionError(type, arg) diff --git a/src/query-builder/case-builder.ts b/src/query-builder/case-builder.ts new file mode 100644 index 000000000..399bc4d6d --- /dev/null +++ b/src/query-builder/case-builder.ts @@ -0,0 +1,202 @@ +import { Expression } from '../expression/expression.js' +import { ExpressionWrapper } from '../expression/expression-wrapper.js' +import { freeze } from '../util/object-utils.js' +import { ReferenceExpression } from '../parser/reference-parser.js' +import { CaseNode } from '../operation-node/case-node.js' +import { WhenNode } from '../operation-node/when-node.js' +import { + ComparisonOperatorExpression, + OperandValueExpressionOrList, + parseWhen, +} from '../parser/binary-operation-parser.js' +import { parseValueExpression } from '../parser/value-parser.js' +import { KyselyTypeError } from '../util/type-error.js' + +export class CaseBuilder + implements Whenable +{ + readonly #props: CaseBuilderProps + + constructor(props: CaseBuilderProps) { + this.#props = freeze(props) + } + + when>( + lhs: unknown extends W + ? RE + : KyselyTypeError<'when(lhs, op, rhs) is not supported when using case(value)'>, + op: ComparisonOperatorExpression, + rhs: OperandValueExpressionOrList + ): CaseThenBuilder + when(expression: Expression): CaseThenBuilder + when( + value: unknown extends W + ? KyselyTypeError<'when(value) is only supported when using case(value)'> + : W + ): CaseThenBuilder + + when(...args: any[]): any { + return new CaseThenBuilder({ + ...this.#props, + node: CaseNode.cloneWithWhen( + this.#props.node, + WhenNode.create(parseWhen(args)) + ), + }) + } +} + +interface CaseBuilderProps { + readonly node: CaseNode +} + +export class CaseThenBuilder { + readonly #props: CaseBuilderProps + + constructor(props: CaseBuilderProps) { + this.#props = freeze(props) + } + + /** + * Adds a `then` clause to the `case` statement. + * + * A `then` call can be followed by {@link Whenable.when}, {@link CaseWhenBuilder.else}, + * {@link CaseWhenBuilder.end} or {@link CaseWhenBuilder.endCase} call. + */ + then(expression: Expression): CaseWhenBuilder + then(value: V): CaseWhenBuilder + + then(valueExpression: any): any { + return new CaseWhenBuilder({ + ...this.#props, + node: CaseNode.cloneWithThen( + this.#props.node, + parseValueExpression(valueExpression) + ), + }) + } +} + +export class CaseWhenBuilder + implements Whenable, Endable +{ + readonly #props: CaseBuilderProps + + constructor(props: CaseBuilderProps) { + this.#props = freeze(props) + } + + when>( + lhs: unknown extends W + ? RE + : KyselyTypeError<'when(lhs, op, rhs) is not supported when using case(value)'>, + op: ComparisonOperatorExpression, + rhs: OperandValueExpressionOrList + ): CaseThenBuilder + when(expression: Expression): CaseThenBuilder + when( + value: unknown extends W + ? KyselyTypeError<'when(value) is only supported when using case(value)'> + : W + ): CaseThenBuilder + + when(...args: any[]): any { + return new CaseThenBuilder({ + ...this.#props, + node: CaseNode.cloneWithWhen( + this.#props.node, + WhenNode.create(parseWhen(args)) + ), + }) + } + + /** + * Adds an `else` clause to the `case` statement. + * + * An `else` call must be followed by an {@link Endable.end} or {@link Endable.endCase} call. + */ + else(expression: Expression): CaseEndBuilder + else(value: V): CaseEndBuilder + + else(valueExpression: any): any { + return new CaseEndBuilder({ + ...this.#props, + node: CaseNode.cloneWith(this.#props.node, { + else: parseValueExpression(valueExpression), + }), + }) + } + + end(): ExpressionWrapper { + return new ExpressionWrapper( + CaseNode.cloneWith(this.#props.node, { isStatement: false }) + ) + } + + endCase(): ExpressionWrapper { + return new ExpressionWrapper( + CaseNode.cloneWith(this.#props.node, { isStatement: true }) + ) + } +} + +export class CaseEndBuilder implements Endable { + readonly #props: CaseBuilderProps + + constructor(props: CaseBuilderProps) { + this.#props = freeze(props) + } + + end(): ExpressionWrapper { + return new ExpressionWrapper( + CaseNode.cloneWith(this.#props.node, { isStatement: false }) + ) + } + + endCase(): ExpressionWrapper { + return new ExpressionWrapper( + CaseNode.cloneWith(this.#props.node, { isStatement: true }) + ) + } +} + +interface Whenable { + /** + * Adds a `when` clause to the case statement. + * + * A `when` call must be followed by a {@link CaseThenBuilder.then} call. + */ + when>( + lhs: unknown extends W + ? RE + : KyselyTypeError<'when(lhs, op, rhs) is not supported when using case(value)'>, + op: ComparisonOperatorExpression, + rhs: OperandValueExpressionOrList + ): CaseThenBuilder + + when(expression: Expression): CaseThenBuilder + + when( + value: unknown extends W + ? KyselyTypeError<'when(value) is only supported when using case(value)'> + : W + ): CaseThenBuilder +} + +interface Endable { + /** + * Adds an `end` keyword to the case operator. + * + * `case` operators can only be used as part of a query. + * For a `case` statement used as part of a stored program, use {@link endCase} instead. + */ + end(): ExpressionWrapper + + /** + * Adds `end case` keywords to the case statement. + * + * `case` statements can only be used for flow control in stored programs. + * For a `case` operator used as part of a query, use {@link end} instead. + */ + endCase(): ExpressionWrapper +} diff --git a/src/query-compiler/default-query-compiler.ts b/src/query-compiler/default-query-compiler.ts index 69b42b215..16c33de17 100644 --- a/src/query-compiler/default-query-compiler.ts +++ b/src/query-compiler/default-query-compiler.ts @@ -94,6 +94,8 @@ import { BinaryOperationNode } from '../operation-node/binary-operation-node.js' import { UnaryOperationNode } from '../operation-node/unary-operation-node.js' import { UsingNode } from '../operation-node/using-node.js' import { FunctionNode } from '../operation-node/function-node.js' +import { CaseNode } from '../operation-node/case-node.js' +import { WhenNode } from '../operation-node/when-node.js' export class DefaultQueryCompiler extends OperationNodeVisitor @@ -1289,6 +1291,42 @@ export class DefaultQueryCompiler this.append(')') } + protected override visitCase(node: CaseNode): void { + this.append('case') + + if (node.value) { + this.append(' ') + this.visitNode(node.value) + } + + if (node.when) { + this.append(' ') + this.compileList(node.when, ' ') + } + + if (node.else) { + this.append(' else ') + this.visitNode(node.else) + } + + this.append(' end') + + if (node.isStatement) { + this.append(' case') + } + } + + protected override visitWhen(node: WhenNode): void { + this.append('when ') + + this.visitNode(node.condition) + + if (node.result) { + this.append(' then ') + this.visitNode(node.result) + } + } + protected append(str: string): void { this.#sql += str } diff --git a/test/node/src/case.test.ts b/test/node/src/case.test.ts new file mode 100644 index 000000000..1d18ea20e --- /dev/null +++ b/test/node/src/case.test.ts @@ -0,0 +1,278 @@ +import { sql } from '../../..' +import { + DIALECTS, + TestContext, + clearDatabase, + destroyTest, + initTest, + insertDefaultDataSet, + testSql, +} from './test-setup.js' + +for (const dialect of DIALECTS) { + describe(`${dialect}: case`, () => { + let ctx: TestContext + + before(async function () { + ctx = await initTest(this, dialect) + }) + + beforeEach(async () => { + await insertDefaultDataSet(ctx) + }) + + afterEach(async () => { + await clearDatabase(ctx) + }) + + after(async () => { + await destroyTest(ctx) + }) + + it('should execute a query with a case...when...then...end operator', async () => { + const query = ctx.db + .selectFrom('person') + .select((eb) => + eb.case().when('gender', '=', 'male').then('Mr.').end().as('title') + ) + + testSql(query, dialect, { + postgres: { + sql: `select case when "gender" = $1 then $2 end as "title" from "person"`, + parameters: ['male', 'Mr.'], + }, + mysql: { + sql: 'select case when `gender` = ? then ? end as `title` from `person`', + parameters: ['male', 'Mr.'], + }, + sqlite: { + sql: `select case when "gender" = ? then ? end as "title" from "person"`, + parameters: ['male', 'Mr.'], + }, + }) + + await query.execute() + }) + + it('should execute a query with a case...value...when...then...end operator', async () => { + const query = ctx.db + .selectFrom('person') + .select((eb) => + eb.case('gender').when('male').then('Mr.').end().as('title') + ) + + testSql(query, dialect, { + postgres: { + sql: `select case "gender" when $1 then $2 end as "title" from "person"`, + parameters: ['male', 'Mr.'], + }, + mysql: { + sql: 'select case `gender` when ? then ? end as `title` from `person`', + parameters: ['male', 'Mr.'], + }, + sqlite: { + sql: `select case "gender" when ? then ? end as "title" from "person"`, + parameters: ['male', 'Mr.'], + }, + }) + + await query.execute() + }) + + it('should execute a query with a case...when...then...when...then...end operator', async () => { + const query = ctx.db + .selectFrom('person') + .select((eb) => + eb + .case() + .when(eb.cmpr('gender', '=', 'male')) + .then(sql.lit('Mr.')) + .when(eb.cmpr('gender', '=', 'female')) + .then(sql.lit('Mrs.')) + .end() + .as('title') + ) + + testSql(query, dialect, { + postgres: { + sql: [ + `select case when "gender" = $1 then 'Mr.'`, + `when "gender" = $2 then 'Mrs.'`, + `end as "title" from "person"`, + ], + parameters: ['male', 'female'], + }, + mysql: { + sql: [ + "select case when `gender` = ? then 'Mr.'", + "when `gender` = ? then 'Mrs.'", + 'end as `title` from `person`', + ], + parameters: ['male', 'female'], + }, + sqlite: { + sql: [ + `select case when "gender" = ? then 'Mr.'`, + `when "gender" = ? then 'Mrs.'`, + `end as "title" from "person"`, + ], + parameters: ['male', 'female'], + }, + }) + + await query.execute() + }) + + it('should execute a query with a case...value...when...then...when...then...end operator', async () => { + const query = ctx.db + .selectFrom('person') + .select((eb) => + eb + .case('gender') + .when(sql.lit('male')) + .then(sql.lit('Mr.')) + .when(sql.lit('female')) + .then(sql.lit('Mrs.')) + .end() + .as('title') + ) + + testSql(query, dialect, { + postgres: { + sql: [ + `select case "gender" when 'male' then 'Mr.'`, + `when 'female' then 'Mrs.'`, + `end as "title" from "person"`, + ], + parameters: [], + }, + mysql: { + sql: [ + "select case `gender` when 'male' then 'Mr.'", + "when 'female' then 'Mrs.'", + 'end as `title` from `person`', + ], + parameters: [], + }, + sqlite: { + sql: [ + `select case "gender" when 'male' then 'Mr.'`, + `when 'female' then 'Mrs.'`, + `end as "title" from "person"`, + ], + parameters: [], + }, + }) + + await query.execute() + }) + + it('should execute a query with a case...when...then...when...then...else...end operator', async () => { + const query = ctx.db + .selectFrom('person') + .select((eb) => + eb + .case() + .when('gender', '=', 'male') + .then('Mr.') + .when('gender', '=', 'female') + .then('Mrs.') + .else(null) + .end() + .as('title') + ) + + testSql(query, dialect, { + postgres: { + sql: [ + `select case when "gender" = $1 then $2`, + `when "gender" = $3 then $4`, + `else $5 end as "title" from "person"`, + ], + parameters: ['male', 'Mr.', 'female', 'Mrs.', null], + }, + mysql: { + sql: [ + 'select case when `gender` = ? then ?', + 'when `gender` = ? then ?', + 'else ? end as `title` from `person`', + ], + parameters: ['male', 'Mr.', 'female', 'Mrs.', null], + }, + sqlite: { + sql: [ + `select case when "gender" = ? then ?`, + `when "gender" = ? then ?`, + `else ? end as "title" from "person"`, + ], + parameters: ['male', 'Mr.', 'female', 'Mrs.', null], + }, + }) + + await query.execute() + }) + + it('should execute a query with a case...value...when...then...when...then...(case...when...then...else...end)...end operator', async () => { + const query = ctx.db.selectFrom('person').select((eb) => + eb + .case('gender') + .when('male') + .then('Mr.') + .when('female') + .then( + eb + .case() + .when( + eb.or([ + eb.cmpr('marital_status', '=', 'single'), + eb.cmpr('marital_status', 'is', null), + ]) + ) + .then('Ms.') + .else('Mrs.') + .end() + ) + .end() + .as('title') + ) + + testSql(query, dialect, { + postgres: { + sql: [ + 'select case "gender" when $1 then $2', + 'when $3 then', + 'case when ("marital_status" = $4 or', + '"marital_status" is null) then $5', + 'else $6 end', + 'end as "title" from "person"', + ], + parameters: ['male', 'Mr.', 'female', 'single', 'Ms.', 'Mrs.'], + }, + mysql: { + sql: [ + 'select case `gender` when ? then ?', + 'when ? then', + 'case when (`marital_status` = ? or', + '`marital_status` is null) then ?', + 'else ? end', + 'end as `title` from `person`', + ], + parameters: ['male', 'Mr.', 'female', 'single', 'Ms.', 'Mrs.'], + }, + sqlite: { + sql: [ + 'select case "gender" when ? then ?', + 'when ? then', + 'case when ("marital_status" = ? or', + '"marital_status" is null) then ?', + 'else ? end', + 'end as "title" from "person"', + ], + parameters: ['male', 'Mr.', 'female', 'single', 'Ms.', 'Mrs.'], + }, + }) + + await query.execute() + }) + }) +} diff --git a/test/node/src/introspect.test.ts b/test/node/src/introspect.test.ts index 7e9ff5deb..56a8bfe2f 100644 --- a/test/node/src/introspect.test.ts +++ b/test/node/src/introspect.test.ts @@ -93,6 +93,14 @@ for (const dialect of DIALECTS) { isAutoIncrementing: false, hasDefaultValue: false, }, + { + name: 'marital_status', + dataType: 'varchar', + dataTypeSchema: 'pg_catalog', + isNullable: true, + isAutoIncrementing: false, + hasDefaultValue: false, + }, ], }, { @@ -255,6 +263,13 @@ for (const dialect of DIALECTS) { isAutoIncrementing: false, hasDefaultValue: false, }, + { + name: 'marital_status', + dataType: 'varchar', + isNullable: true, + isAutoIncrementing: false, + hasDefaultValue: false, + }, ], }, { @@ -384,6 +399,13 @@ for (const dialect of DIALECTS) { isAutoIncrementing: false, hasDefaultValue: false, }, + { + name: 'marital_status', + dataType: 'varchar(50)', + isNullable: true, + isAutoIncrementing: false, + hasDefaultValue: false, + }, ], }, { diff --git a/test/node/src/test-setup.ts b/test/node/src/test-setup.ts index bf9522cab..6dee0c6c0 100644 --- a/test/node/src/test-setup.ts +++ b/test/node/src/test-setup.ts @@ -39,6 +39,7 @@ export interface Person { middle_name: ColumnType last_name: string | null gender: 'male' | 'female' | 'other' + marital_status: 'single' | 'married' | 'divorced' | 'widowed' | null } export interface Pet { @@ -197,18 +198,21 @@ export const DEFAULT_DATA_SET: PersonInsertParams[] = [ last_name: 'Aniston', gender: 'female', pets: [{ name: 'Catto', species: 'cat' }], + marital_status: 'divorced', }, { first_name: 'Arnold', last_name: 'Schwarzenegger', gender: 'male', pets: [{ name: 'Doggo', species: 'dog' }], + marital_status: 'divorced', }, { first_name: 'Sylvester', last_name: 'Stallone', gender: 'male', pets: [{ name: 'Hammo', species: 'hamster' }], + marital_status: 'married', }, ] @@ -248,6 +252,7 @@ async function createDatabase( .addColumn('middle_name', 'varchar(255)') .addColumn('last_name', 'varchar(255)') .addColumn('gender', 'varchar(50)', (col) => col.notNull()) + .addColumn('marital_status', 'varchar(50)') .execute() await createTableWithId(db.schema, dialect, 'pet') diff --git a/test/typings/shared.d.ts b/test/typings/shared.d.ts index 0675b5b3f..47fcd4c77 100644 --- a/test/typings/shared.d.ts +++ b/test/typings/shared.d.ts @@ -38,6 +38,7 @@ export interface Person { last_name: string | null age: number gender: 'male' | 'female' | 'other' + marital_status: 'single' | 'married' | 'divorced' | 'widowed' | null // A Column that is generated by the DB and which // we never want the user to be able to insert or // update. diff --git a/test/typings/test-d/case.test-d.ts b/test/typings/test-d/case.test-d.ts new file mode 100644 index 000000000..ab7ae5a7b --- /dev/null +++ b/test/typings/test-d/case.test-d.ts @@ -0,0 +1,144 @@ +import { expectError, expectType } from 'tsd' +import { ExpressionBuilder, ExpressionWrapper, sql } from '..' +import { Database } from '../shared' + +async function testCase(eb: ExpressionBuilder) { + // case...when...then...when...then...end + expectType>( + eb + .case() + .when('gender', '=', 'male') + .then('Mr.') + .when('gender', '=', 'female') + .then(12) + .end() + ) + + // case...when...then...when...then...end (as const) + expectType>( + eb + .case() + .when('gender', '=', 'male') + .then('Mr.' as const) + .when('gender', '=', 'female') + .then(12 as const) + .end() + ) + + // case...when...then...when...then...else...end + expectType>( + eb + .case() + .when('gender', '=', 'male') + .then('Mr.') + .when('gender', '=', 'female') + .then(12) + .else(true) + .end() + ) + + // case...when...then...when...then...else...end (as const) + expectType>( + eb + .case() + .when('gender', '=', 'male') + .then('Mr.' as const) + .when('gender', '=', 'female') + .then(12 as const) + .else(true as const) + .end() + ) + + // nested case + expectType>( + eb + .case() + .when('gender', '=', 'male') + .then('Mr.' as const) + .when('gender', '=', 'female') + .then( + eb + .case() + .when('marital_status', '=', 'single') + .then('Ms.' as const) + .else('Mrs.' as const) + .end() + ) + .end() + ) + + // references + expectType>( + eb + .case() + .when('gender', '=', 'male') + .then(eb.ref('first_name')) + .else(eb.ref('age')) + .end() + ) + + // expressions + expectType>( + eb + .case() + .when('gender', '=', 'male') + .then( + eb.fn<`Mr. ${string}`>('concat', [ + eb.val('Mr.'), + sql.lit(' '), + eb.ref('last_name'), + ]) + ) + .end() + ) + + // errors + + expectError(eb.case().when('no_such_column', '=', 'male').then('Mr.').end()) + expectError(eb.case().when('gender', '??', 'male').then('Mr.').end()) + expectError(eb.case().when('gender', '=', 42).then('Mr.').end()) + expectError(eb.case().when('male').then('Mr.').end()) +} + +function testCaseValue(eb: ExpressionBuilder) { + // case...value...when...then...when...then...end + expectType>( + eb.case('gender').when('male').then('Mr.').when('female').then(12).end() + ) + + // case...value...when...then...when...then...else...end + expectType>( + eb + .case('gender') + .when('male') + .then('Mr.') + .when('female') + .then(12) + .else(true) + .end() + ) + + // nested case + expectType>( + eb + .case('gender') + .when('male') + .then('Mr.' as const) + .when('female') + .then( + eb + .case('marital_status') + .when('single') + .then('Ms.' as const) + .else('Mrs.' as const) + .end() + ) + .end() + ) + + // errors + + expectError(eb.case('no_such_column').when('male').then('Mr.').end()) + expectError(eb.case('gender').when('robot').then('Mr.').end()) + expectError(eb.case('gender').when('gender', '=', 'male').then('Mr.').end()) +} diff --git a/test/typings/test-d/delete-query-builder.test-d.ts b/test/typings/test-d/delete-query-builder.test-d.ts index c2e6944f2..bf33c2331 100644 --- a/test/typings/test-d/delete-query-builder.test-d.ts +++ b/test/typings/test-d/delete-query-builder.test-d.ts @@ -113,6 +113,7 @@ async function testDelete(db: Kysely) { age: number gender: 'male' | 'female' | 'other' modified_at: Date + marital_status: 'single' | 'married' | 'divorced' | 'widowed' | null name: string owner_id: number @@ -137,6 +138,7 @@ async function testDelete(db: Kysely) { age: number gender: 'male' | 'female' | 'other' modified_at: Date + marital_status: 'single' | 'married' | 'divorced' | 'widowed' | null name: string owner_id: number @@ -174,6 +176,7 @@ async function testDelete(db: Kysely) { age: number gender: 'male' | 'female' | 'other' modified_at: Date + marital_status: 'single' | 'married' | 'divorced' | 'widowed' | null name: string owner_id: number diff --git a/test/typings/test-d/join.test-d.ts b/test/typings/test-d/join.test-d.ts index d2e5985b4..06eb40a71 100644 --- a/test/typings/test-d/join.test-d.ts +++ b/test/typings/test-d/join.test-d.ts @@ -18,6 +18,7 @@ async function testJoin(db: Kysely) { age: number gender: 'male' | 'female' | 'other' modified_at: Date + marital_status: 'single' | 'married' | 'divorced' | 'widowed' | null name: string species: 'cat' | 'dog' @@ -80,6 +81,7 @@ async function testJoin(db: Kysely) { age: number gender: 'male' | 'female' | 'other' modified_at: Date + marital_status: 'single' | 'married' | 'divorced' | 'widowed' | null // All Pet columns should be nullable because of the left join name: string | null @@ -106,6 +108,7 @@ async function testJoin(db: Kysely) { age: number | null gender: 'male' | 'female' | 'other' | null modified_at: Date | null + marital_status: 'single' | 'married' | 'divorced' | 'widowed' | null // All Pet columns should also be nullable because there's another // right join after the Pet join. @@ -134,6 +137,7 @@ async function testJoin(db: Kysely) { age: number | null gender: 'male' | 'female' | 'other' | null modified_at: Date | null + marital_status: 'single' | 'married' | 'divorced' | 'widowed' | null name: string | null species: 'dog' | 'cat' | null diff --git a/test/typings/test-d/select-from.test-d.ts b/test/typings/test-d/select-from.test-d.ts index ec59f6758..cf6d44b28 100644 --- a/test/typings/test-d/select-from.test-d.ts +++ b/test/typings/test-d/select-from.test-d.ts @@ -13,6 +13,7 @@ async function testFromSingle(db: Kysely) { age: number gender: 'male' | 'female' | 'other' modified_at: Date + marital_status: 'single' | 'married' | 'divorced' | 'widowed' | null }>(r1) // Table with alias