From 854f484ca79d802612db914c19e648d453dbe911 Mon Sep 17 00:00:00 2001 From: Mark Sirek Date: Wed, 10 Aug 2022 16:30:24 -0700 Subject: [PATCH] opttester: support session settings in opt tests This commit adds the `set` opttest flag which can be used to set session flags via "set=flagname=value". Release note: none --- pkg/sql/opt/testutils/opttester/opt_tester.go | 17 +++++++++++ pkg/sql/opt/xform/testdata/rules/join | 26 ++++++++++++++++- pkg/sql/vars.go | 29 ++++++++++++++++++- 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/pkg/sql/opt/testutils/opttester/opt_tester.go b/pkg/sql/opt/testutils/opttester/opt_tester.go index e569b032b602..7911d4f9a83e 100644 --- a/pkg/sql/opt/testutils/opttester/opt_tester.go +++ b/pkg/sql/opt/testutils/opttester/opt_tester.go @@ -39,6 +39,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/schemaexpr" "github.com/cockroachdb/cockroach/pkg/sql/opt" @@ -265,6 +266,9 @@ type Flags struct { // SkipRace indicates that a test should be skipped if the race detector is // enabled. SkipRace bool + + // ot is a reference to the OptTester owning Flags. + ot *OptTester } // New constructs a new instance of the OptTester for the given SQL statement. @@ -279,6 +283,7 @@ func New(catalog cat.Catalog, sql string) *OptTester { semaCtx: tree.MakeSemaContext(), evalCtx: eval.MakeTestingEvalContext(cluster.MakeTestingClusterSettings()), } + ot.Flags.ot = ot ot.semaCtx.SearchPath = tree.EmptySearchPath ot.semaCtx.FunctionResolver = ot.catalog // To allow opttester tests to use now(), we hardcode a preset transaction @@ -897,6 +902,18 @@ func ruleNamesToRuleSet(args []string) (RuleSet, error) { // See OptTester.RunCommand for supported flags. func (f *Flags) Set(arg datadriven.CmdArg) error { switch arg.Key { + case "set": + for _, val := range arg.Vals { + s := strings.Split(val, "=") + if len(s) != 2 { + return errors.Errorf("Expected both session variable name and value for set command") + } + err := sql.SetSessionVariable(f.ot.ctx, f.ot.evalCtx, s[0], s[1]) + if err != nil { + return err + } + } + case "format": if len(arg.Vals) == 0 { return fmt.Errorf("format flag requires value(s)") diff --git a/pkg/sql/opt/xform/testdata/rules/join b/pkg/sql/opt/xform/testdata/rules/join index dcec492190ea..780562503f9d 100644 --- a/pkg/sql/opt/xform/testdata/rules/join +++ b/pkg/sql/opt/xform/testdata/rules/join @@ -10230,8 +10230,32 @@ inner-join (merge) │ └── columns: m:1 n:2 └── filters (true) +# The rule should be allowed to fire when the projection is from a join if the +# session flag disable_hoist_projection_in_join_limitation is true. +opt expect=HoistProjectFromInnerJoin set=disable_hoist_projection_in_join_limitation=true +SELECT * FROM (SELECT a, a+b FROM (SELECT tab1.* from abcd tab1, abcd tab2)) JOIN small ON a=m; +---- +project + ├── columns: a:1!null "?column?":13 m:14!null n:15 + ├── immutable + ├── fd: (1)==(14), (14)==(1) + ├── inner-join (cross) + │ ├── columns: tab1.a:1!null tab1.b:2 m:14!null n:15 + │ ├── fd: (1)==(14), (14)==(1) + │ ├── scan abcd@abcd_a_b_idx [as=tab2] + │ ├── inner-join (lookup abcd@abcd_a_b_idx [as=tab1]) + │ │ ├── columns: tab1.a:1!null tab1.b:2 m:14!null n:15 + │ │ ├── key columns: [14] = [1] + │ │ ├── fd: (1)==(14), (14)==(1) + │ │ ├── scan small + │ │ │ └── columns: m:14 n:15 + │ │ └── filters (true) + │ └── filters (true) + └── projections + └── tab1.a:1 + tab1.b:2 [as="?column?":13, outer=(1,2), immutable] + # The rule should not fire when the projection is from a join. -opt expect-not=HoistProjectFromInnerJoin +opt expect-not=HoistProjectFromInnerJoin set=disable_hoist_projection_in_join_limitation=false SELECT * FROM (SELECT a, a+b FROM (SELECT tab1.* from abcd tab1, abcd tab2)) JOIN small ON a=m ---- inner-join (hash) diff --git a/pkg/sql/vars.go b/pkg/sql/vars.go index f52e281cead5..0da75b1179bd 100644 --- a/pkg/sql/vars.go +++ b/pkg/sql/vars.go @@ -2119,6 +2119,7 @@ var varGen = map[string]sessionVar{ return formatFloatAsPostgresSetting(0) }, }, + // CockroachDB extension. `disable_hoist_projection_in_join_limitation`: { GetStringVal: makePostgresBoolGetStringValFn(`disable_hoist_projection_in_join_limitation`), @@ -2130,7 +2131,7 @@ var varGen = map[string]sessionVar{ m.SetDisableHoistProjectionInJoinLimitation(b) return nil }, - Get: func(evalCtx *extendedEvalContext) (string, error) { + Get: func(evalCtx *extendedEvalContext, _ *kv.Txn) (string, error) { return formatBoolAsPostgresSetting(evalCtx.SessionData().DisableHoistProjectionInJoinLimitation), nil }, GlobalDefault: globalFalse, @@ -2184,6 +2185,32 @@ func init() { }() } +// SetSessionVariable sets a new value for session setting `varName` is the +// session settings owned by `evalCtx`, returning an error if not successful. +func SetSessionVariable( + ctx context.Context, evalCtx eval.Context, varName, varValue string, +) (err error) { + err = CheckSessionVariableValueValid(ctx, evalCtx.Settings, varName, varValue) + if err != nil { + return err + } + sessionDataMutatorBase := sessionDataMutatorBase{ + defaults: make(map[string]string), + settings: evalCtx.Settings, + } + sessionDataMutator := sessionDataMutator{ + data: evalCtx.SessionData(), + sessionDataMutatorBase: sessionDataMutatorBase, + sessionDataMutatorCallbacks: sessionDataMutatorCallbacks{}, + } + _, sVar, err := getSessionVar(varName, false) + if err != nil { + return err + } + + return sVar.Set(ctx, sessionDataMutator, varValue) +} + // makePostgresBoolGetStringValFn returns a function that evaluates and returns // a string representation of the first argument value. func makePostgresBoolGetStringValFn(varName string) getStringValFn {