Skip to content

Commit

Permalink
sqlsmith: make order-dependent aggregation functions deterministic
Browse files Browse the repository at this point in the history
Some aggregation functions (e.g. string_agg) have results that depend
on the order of input rows. To make sqlsmith more deterministic, add
ORDER BY clauses to these aggregation functions whenever nondeterministic
functions are disabled and their first argument is a column reference.
(When their argument is a constant, ordering doesn't matter.)

Fixes: #83024

Release note: None
  • Loading branch information
michae2 committed Jul 14, 2022
1 parent 784f20a commit f118f5d
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pkg/cmd/roachtest/tests/query_comparison_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func runOneRoundQueryComparison(

// Initialize a smither that generates only deterministic SELECT statements.
smither, err := sqlsmith.NewSmither(conn, rnd,
sqlsmith.DisableMutations(), sqlsmith.DisableImpureFns(), sqlsmith.DisableLimits(),
sqlsmith.DisableMutations(), sqlsmith.DisableNondeterministicFns(), sqlsmith.DisableLimits(),
sqlsmith.UnlikelyConstantPredicate(), sqlsmith.FavorCommonData(),
sqlsmith.UnlikelyRandomNulls(), sqlsmith.DisableCrossJoins(),
sqlsmith.DisableIndexHints(), sqlsmith.DisableWith(),
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/roachtest/tests/tlp.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func runOneTLP(

// Initialize a smither that will never generate mutations.
tlpSmither, err := sqlsmith.NewSmither(conn, rnd,
sqlsmith.DisableMutations(), sqlsmith.DisableImpureFns())
sqlsmith.DisableMutations(), sqlsmith.DisableNondeterministicFns())
if err != nil {
t.Fatal(err)
}
Expand Down
38 changes: 19 additions & 19 deletions pkg/cmd/smith/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,25 @@ var (
num = flags.Int("num", 1, "number of statements / expressions to generate")
url = flags.String("url", "", "database to fetch schema from")
smitherOptMap = map[string]sqlsmith.SmitherOption{
"DisableMutations": sqlsmith.DisableMutations(),
"DisableDDLs": sqlsmith.DisableDDLs(),
"OnlyNoDropDDLs": sqlsmith.OnlyNoDropDDLs(),
"MultiRegionDDLs": sqlsmith.MultiRegionDDLs(),
"DisableWith": sqlsmith.DisableWith(),
"DisableImpureFns": sqlsmith.DisableImpureFns(),
"DisableCRDBFns": sqlsmith.DisableCRDBFns(),
"SimpleDatums": sqlsmith.SimpleDatums(),
"MutationsOnly": sqlsmith.MutationsOnly(),
"InsUpdOnly": sqlsmith.InsUpdOnly(),
"DisableLimits": sqlsmith.DisableLimits(),
"AvoidConsts": sqlsmith.AvoidConsts(),
"DisableWindowFuncs": sqlsmith.DisableWindowFuncs(),
"OutputSort": sqlsmith.OutputSort(),
"UnlikelyConstantPredicate": sqlsmith.UnlikelyConstantPredicate(),
"FavorCommonData": sqlsmith.FavorCommonData(),
"UnlikelyRandomNulls": sqlsmith.UnlikelyRandomNulls(),
"DisableCrossJoins": sqlsmith.DisableCrossJoins(),
"DisableIndexHints": sqlsmith.DisableIndexHints(),
"DisableMutations": sqlsmith.DisableMutations(),
"DisableDDLs": sqlsmith.DisableDDLs(),
"OnlyNoDropDDLs": sqlsmith.OnlyNoDropDDLs(),
"MultiRegionDDLs": sqlsmith.MultiRegionDDLs(),
"DisableWith": sqlsmith.DisableWith(),
"DisableNondeterministicFns": sqlsmith.DisableNondeterministicFns(),
"DisableCRDBFns": sqlsmith.DisableCRDBFns(),
"SimpleDatums": sqlsmith.SimpleDatums(),
"MutationsOnly": sqlsmith.MutationsOnly(),
"InsUpdOnly": sqlsmith.InsUpdOnly(),
"DisableLimits": sqlsmith.DisableLimits(),
"AvoidConsts": sqlsmith.AvoidConsts(),
"DisableWindowFuncs": sqlsmith.DisableWindowFuncs(),
"OutputSort": sqlsmith.OutputSort(),
"UnlikelyConstantPredicate": sqlsmith.UnlikelyConstantPredicate(),
"FavorCommonData": sqlsmith.FavorCommonData(),
"UnlikelyRandomNulls": sqlsmith.UnlikelyRandomNulls(),
"DisableCrossJoins": sqlsmith.DisableCrossJoins(),
"DisableIndexHints": sqlsmith.DisableIndexHints(),
"LowProbabilityWhereClauseWithJoinTables": sqlsmith.LowProbabilityWhereClauseWithJoinTables(),
"DisableInsertSelect": sqlsmith.DisableInsertSelect(),
"CompareMode": sqlsmith.CompareMode(),
Expand Down
40 changes: 35 additions & 5 deletions pkg/internal/sqlsmith/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ func makeFunc(s *Smither, ctx Context, typ *types.T, refs colRefs) (tree.TypedEx
return nil, false
}
fn := fns[s.rnd.Intn(len(fns))]
if s.disableImpureFns && fn.overload.Volatility > volatility.Immutable {
if s.disableNondeterministicFns && fn.overload.Volatility > volatility.Immutable {
return nil, false
}
for _, ignore := range s.ignoreFNs {
Expand All @@ -410,6 +410,9 @@ func makeFunc(s *Smither, ctx Context, typ *types.T, refs colRefs) (tree.TypedEx
}
}

// Some aggregation functions benefit from an order by clause.
var orderExpr tree.Expr

args := make(tree.TypedExprs, 0)
for _, argTyp := range fn.overload.Types.Types() {
// Postgres is picky about having Int4 arguments instead of Int8.
Expand All @@ -421,6 +424,9 @@ func makeFunc(s *Smither, ctx Context, typ *types.T, refs colRefs) (tree.TypedEx
if class == tree.AggregateClass || class == tree.WindowClass {
var ok bool
arg, ok = makeColRef(s, argTyp, refs)
if ok && len(args) == 0 {
orderExpr = arg
}
if !ok {
// If we can't find a col ref for our aggregate function, just use a
// constant.
Expand Down Expand Up @@ -469,9 +475,7 @@ func makeFunc(s *Smither, ctx Context, typ *types.T, refs colRefs) (tree.TypedEx
}
}

// Cast the return and arguments to prevent ambiguity during function
// implementation choosing.
return castType(tree.NewTypedFuncExpr(
funcExpr := tree.NewTypedFuncExpr(
tree.ResolvableFunctionReference{FunctionReference: fn.def},
0, /* aggQualifier */
args,
Expand All @@ -480,7 +484,33 @@ func makeFunc(s *Smither, ctx Context, typ *types.T, refs colRefs) (tree.TypedEx
typ,
&fn.def.FunctionProperties,
fn.overload,
), typ), true
)

// Some aggregation functions need an order by clause to be deterministic.
if s.disableNondeterministicFns {
switch fn.def.Name {
case "array_agg",
"concat_agg",
"json_agg",
"json_object_agg",
"jsonb_agg",
"jsonb_object_agg",
"st_makeline",
"string_agg",
"xmlagg":
if orderExpr != nil {
funcExpr.AggType = tree.GeneralAgg
funcExpr.OrderBy = tree.OrderBy{{
Expr: orderExpr,
Direction: s.randDirection(),
}}
}
}
}

// Cast the return and arguments to prevent ambiguity during function
// implementation choosing.
return castType(funcExpr, typ), true
}

var windowFrameModes = []treewindow.WindowFrameMode{
Expand Down
10 changes: 5 additions & 5 deletions pkg/internal/sqlsmith/sqlsmith.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type Smither struct {
scalarExprSampler, boolExprSampler *scalarExprSampler

disableWith bool
disableImpureFns bool
disableNondeterministicFns bool
disableLimits bool
disableWindowFuncs bool
simpleDatums bool
Expand Down Expand Up @@ -318,9 +318,9 @@ var DisableWith = simpleOption("disable WITH", func(s *Smither) {
s.disableWith = true
})

// DisableImpureFns causes the Smither to disable impure functions.
var DisableImpureFns = simpleOption("disable impure funcs", func(s *Smither) {
s.disableImpureFns = true
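// DisableNondeterministicFns causes the Smither to disable nondeterministic
// functions and to add ORDER BY clauses to order-dependent aggregate functions
// (e.g. string_agg) whose first argument is a column reference.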
var DisableNondeterministicFns = simpleOption("disable nondeterministic funcs", func(s *Smither) {
s.disableNondeterministicFns = true
})

// DisableCRDBFns causes the Smither to disable crdb_internal functions.
Expand Down Expand Up @@ -435,7 +435,7 @@ var DisableInsertSelect = simpleOption("disable insert select", func(s *Smither)
var CompareMode = multiOption(
"compare mode",
DisableMutations(),
DisableImpureFns(),
DisableNondeterministicFns(),
DisableCRDBFns(),
IgnoreFNs("^version"),
DisableLimits(),
Expand Down

0 comments on commit f118f5d

Please sign in to comment.