Skip to content

Commit

Permalink
allow arbitrary expressions in count and sum
Browse files Browse the repository at this point in the history
  • Loading branch information
koskimas committed May 28, 2023
1 parent 746cd7b commit bf0cafc
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 24 deletions.
30 changes: 6 additions & 24 deletions src/query-builder/function-module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,7 @@ export interface FunctionModule<DB, TB extends keyof DB> {
*/
avg<
O extends number | string | null = number | string,
C extends SimpleReferenceExpression<DB, TB> = SimpleReferenceExpression<
DB,
TB
>
C extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>
>(
column: C
): AggregateFunctionBuilder<DB, TB, O>
Expand Down Expand Up @@ -316,10 +313,7 @@ export interface FunctionModule<DB, TB extends keyof DB> {
*/
count<
O extends number | string | bigint,
C extends SimpleReferenceExpression<DB, TB> = SimpleReferenceExpression<
DB,
TB
>
C extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>
>(
column: C
): AggregateFunctionBuilder<DB, TB, O>
Expand Down Expand Up @@ -597,10 +591,7 @@ export interface FunctionModule<DB, TB extends keyof DB> {
*/
sum<
O extends number | string | bigint | null = number | string | bigint,
C extends SimpleReferenceExpression<DB, TB> = SimpleReferenceExpression<
DB,
TB
>
C extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>
>(
column: C
): AggregateFunctionBuilder<DB, TB, O>
Expand Down Expand Up @@ -636,10 +627,7 @@ export function createFunctionModule<DB, TB extends keyof DB>(): FunctionModule<

avg<
O extends number | string | null = number | string,
C extends SimpleReferenceExpression<DB, TB> = SimpleReferenceExpression<
DB,
TB
>
C extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>
>(column: C): AggregateFunctionBuilder<DB, TB, O> {
return agg('avg', [column])
},
Expand All @@ -656,10 +644,7 @@ export function createFunctionModule<DB, TB extends keyof DB>(): FunctionModule<

count<
O extends number | string | bigint,
C extends SimpleReferenceExpression<DB, TB> = SimpleReferenceExpression<
DB,
TB
>
C extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>
>(column: C): AggregateFunctionBuilder<DB, TB, O> {
return agg('count', [column])
},
Expand Down Expand Up @@ -693,10 +678,7 @@ export function createFunctionModule<DB, TB extends keyof DB>(): FunctionModule<

sum<
O extends number | string | bigint | null = number | string | bigint,
C extends SimpleReferenceExpression<DB, TB> = SimpleReferenceExpression<
DB,
TB
>
C extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>
>(column: C): AggregateFunctionBuilder<DB, TB, O> {
return agg('sum', [column])
},
Expand Down
67 changes: 67 additions & 0 deletions test/node/src/select.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,73 @@ for (const dialect of DIALECTS) {
expect(persons).to.eql([{ pet_name: 'Catto' }])
})

it('should select the count of jennifers using count', async () => {
const query = ctx.db
.selectFrom('person')
.select((eb) =>
eb.fn
.count(eb.case().when('first_name', '=', 'Jennifer').then(1).end())
.as('num_jennifers')
)

testSql(query, dialect, {
postgres: {
sql: 'select count(case when "first_name" = $1 then $2 end) as "num_jennifers" from "person"',
parameters: ['Jennifer', 1],
},
mysql: {
sql: 'select count(case when `first_name` = ? then ? end) as `num_jennifers` from `person`',
parameters: ['Jennifer', 1],
},
sqlite: {
sql: 'select count(case when "first_name" = ? then ? end) as "num_jennifers" from "person"',
parameters: ['Jennifer', 1],
},
})

const counts = await query.execute()

expect(counts).to.have.length(1)
if (dialect === 'postgres' || dialect === 'mysql') {
expect(counts[0]).to.eql({ num_jennifers: '1' })
} else {
expect(counts[0]).to.eql({ num_jennifers: 1 })
}
})

if (dialect === 'postgres') {
it('should select the count of jennifers using sum', async () => {
const query = ctx.db
.selectFrom('person')
.select((eb) =>
eb.fn
.sum(
eb
.case()
.when('first_name', '=', 'Jennifer')
.then(sql.lit(1))
.else(sql.lit(0))
.end()
)
.as('num_jennifers')
)

testSql(query, dialect, {
postgres: {
sql: 'select sum(case when "first_name" = $1 then 1 else 0 end) as "num_jennifers" from "person"',
parameters: ['Jennifer'],
},
mysql: NOT_SUPPORTED,
sqlite: NOT_SUPPORTED,
})

const counts = await query.execute()

expect(counts).to.have.length(1)
expect(counts[0]).to.eql({ num_jennifers: '1' })
})
}

// Raw exrpessions are of course supported on all dialects, but we use an
// expression that's only valid on postgres.
if (dialect === 'postgres') {
Expand Down

0 comments on commit bf0cafc

Please sign in to comment.