diff --git a/go/vt/sqlparser/comments.go b/go/vt/sqlparser/comments.go index f64c4fd5e7e..24955b5150b 100644 --- a/go/vt/sqlparser/comments.go +++ b/go/vt/sqlparser/comments.go @@ -36,6 +36,8 @@ const ( DirectiveIgnoreMaxPayloadSize = "IGNORE_MAX_PAYLOAD_SIZE" // DirectiveIgnoreMaxMemoryRows skips memory row validation when set. DirectiveIgnoreMaxMemoryRows = "IGNORE_MAX_MEMORY_ROWS" + // DirectiveAllowScatter lets scatter plans pass through even when they are turned off by the `no_scatter` command line flag. + DirectiveAllowScatter = "ALLOW_SCATTER" ) func isNonSpace(r rune) bool { @@ -347,3 +349,21 @@ func IgnoreMaxMaxMemoryRowsDirective(stmt Statement) bool { return false } } + +// AllowScatterDirective returns true if the allow scatter override is set to true +func AllowScatterDirective(stmt Statement) bool { + var directives CommentDirectives + switch stmt := stmt.(type) { + case *Select: + directives = ExtractCommentDirectives(stmt.Comments) + case *Insert: + directives = ExtractCommentDirectives(stmt.Comments) + case *Update: + directives = ExtractCommentDirectives(stmt.Comments) + case *Delete: + directives = ExtractCommentDirectives(stmt.Comments) + default: + return false + } + return directives.IsSet(DirectiveAllowScatter) +} diff --git a/go/vt/vtexplain/vtexplain_vtgate.go b/go/vt/vtexplain/vtexplain_vtgate.go index 625b3bceb5c..3e6cf48a535 100644 --- a/go/vt/vtexplain/vtexplain_vtgate.go +++ b/go/vt/vtexplain/vtexplain_vtgate.go @@ -70,7 +70,7 @@ func initVtgateExecutor(vSchemaStr, ksShardMapStr string, opts *Options) error { streamSize := 10 var schemaTracker vtgate.SchemaInfo // no schema tracker for these tests - vtgateExecutor = vtgate.NewExecutor(context.Background(), explainTopo, vtexplainCell, resolver, opts.Normalize, false /*do not warn for sharded only*/, streamSize, cache.DefaultConfig, schemaTracker) + vtgateExecutor = vtgate.NewExecutor(context.Background(), explainTopo, vtexplainCell, resolver, opts.Normalize, false /*do not warn for sharded only*/, 
streamSize, cache.DefaultConfig, schemaTracker, false /*no-scatter*/) return nil } diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 3f483081175..401e5ab94f7 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -100,6 +100,9 @@ type Executor struct { vm *VSchemaManager schemaTracker SchemaInfo + + // allowScatter controls whether scatter plans are permitted: when false, planning fails for any plan that contains a scatter query, unless the statement carries the ALLOW_SCATTER directive + allowScatter bool } var executorOnce sync.Once @@ -118,6 +121,7 @@ func NewExecutor( streamSize int, cacheCfg *cache.Config, schemaTracker SchemaInfo, + noScatter bool, ) *Executor { e := &Executor{ serv: serv, @@ -130,6 +134,7 @@ func NewExecutor( warnShardedOnly: warnOnShardedOnly, streamSize: streamSize, schemaTracker: schemaTracker, + allowScatter: !noScatter, } vschemaacl.Init() @@ -208,7 +213,7 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st } func (e *Executor) legacyExecute(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *LogStats) (sqlparser.StatementType, *sqltypes.Result, error) { - //Start an implicit transaction if necessary. + // Start an implicit transaction if necessary. 
if !safeSession.Autocommit && !safeSession.InTransaction() { if err := e.txConn.Begin(ctx, safeSession); err != nil { return 0, nil, err @@ -398,7 +403,7 @@ func (e *Executor) handleCommit(ctx context.Context, safeSession *SafeSession, l return &sqltypes.Result{}, err } -//Commit commits the existing transactions +// Commit commits the existing transactions func (e *Executor) Commit(ctx context.Context, safeSession *SafeSession) error { return e.txConn.Commit(ctx, safeSession) } @@ -552,7 +557,7 @@ func getValueFor(expr *sqlparser.SetExpr) (interface{}, error) { } func (e *Executor) handleSetVitessMetadata(ctx context.Context, name, value string) (*sqltypes.Result, error) { - //TODO(kalfonso): move to its own acl check and consolidate into an acl component that can handle multiple operations (vschema, metadata) + // TODO(kalfonso): move to its own acl check and consolidate into an acl component that can handle multiple operations (vschema, metadata) user := callerid.ImmediateCallerIDFromContext(ctx) allowed := vschemaacl.Authorized(user) if !allowed { @@ -1240,7 +1245,8 @@ func (e *Executor) getPlan(vcursor *vcursorImpl, sql string, comments sqlparser. 
if !skipQueryPlanCache && !sqlparser.SkipQueryPlanCacheDirective(statement) && sqlparser.CachePlan(statement) { e.plans.Set(planKey, plan) } - return plan, nil + + return e.checkThatPlanIsValid(stmt, plan) } // skipQueryPlanCache extracts SkipQueryPlanCache from session @@ -1457,7 +1463,7 @@ func (e *Executor) handlePrepare(ctx context.Context, safeSession *SafeSession, var errCount uint64 if err != nil { logStats.Error = err - errCount = 1 //nolint + errCount = 1 // nolint return nil, err } logStats.RowsAffected = qr.RowsAffected @@ -1515,3 +1521,23 @@ func (e *Executor) startVStream(ctx context.Context, rss []*srvtopo.ResolvedShar vs.stream(ctx) return nil } + +func (e *Executor) checkThatPlanIsValid(stmt sqlparser.Statement, plan *engine.Plan) (*engine.Plan, error) { + if e.allowScatter || sqlparser.AllowScatterDirective(stmt) { + return plan, nil + } + // we go over all the primitives in the plan, searching for a route that is of SelectScatter opcode + badPrimitive := engine.Find(func(node engine.Primitive) bool { + router, ok := node.(*engine.Route) + if !ok { + return false + } + return router.Opcode == engine.SelectScatter + }, plan.Instructions) + + if badPrimitive == nil { + return plan, nil + } + + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "plan includes scatter, which is disallowed using the `no_scatter` command line argument") +} diff --git a/go/vt/vtgate/executor_framework_test.go b/go/vt/vtgate/executor_framework_test.go index a3c4100d998..7e58c94dbab 100644 --- a/go/vt/vtgate/executor_framework_test.go +++ b/go/vt/vtgate/executor_framework_test.go @@ -24,6 +24,8 @@ import ( "strings" "testing" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/assert" @@ -398,7 +400,7 @@ func createLegacyExecutorEnv() (executor *Executor, sbc1, sbc2, sbclookup *sandb bad.VSchema = badVSchema getSandbox(KsTestUnsharded).VSchema = unshardedVSchema - executor = 
NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil) + executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil, false) key.AnyShardPicker = DestinationAnyShardPickerFirstShard{} return executor, sbc1, sbc2, sbclookup @@ -433,7 +435,7 @@ func createExecutorEnv() (executor *Executor, sbc1, sbc2, sbclookup *sandboxconn bad.VSchema = badVSchema getSandbox(KsTestUnsharded).VSchema = unshardedVSchema - executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil) + executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil, false) key.AnyShardPicker = DestinationAnyShardPickerFirstShard{} return executor, sbc1, sbc2, sbclookup @@ -453,19 +455,23 @@ func createCustomExecutor(vschema string) (executor *Executor, sbc1, sbc2, sbclo sbclookup = hc.AddTestTablet(cell, "0", 1, KsTestUnsharded, "0", topodatapb.TabletType_MASTER, true, 1, nil) getSandbox(KsTestUnsharded).VSchema = unshardedVSchema - executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil) + executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil, false) return executor, sbc1, sbc2, sbclookup } -func executorExec(executor *Executor, sql string, bv map[string]*querypb.BindVariable) (*sqltypes.Result, error) { +func executorExecSession(executor *Executor, sql string, bv map[string]*querypb.BindVariable, session *vtgatepb.Session) (*sqltypes.Result, error) { return executor.Execute( context.Background(), "TestExecute", - NewSafeSession(masterSession), + NewSafeSession(session), sql, bv) } +func executorExec(executor *Executor, sql string, bv map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + return executorExecSession(executor, 
sql, bv, masterSession) +} + func executorPrepare(executor *Executor, sql string, bv map[string]*querypb.BindVariable) ([]*querypb.Field, error) { return executor.Prepare( context.Background(), diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 371df7c2acb..05c1ba5af7f 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -1049,7 +1049,7 @@ func TestStreamSelectIN(t *testing.T) { } func createExecutor(serv *sandboxTopo, cell string, resolver *Resolver) *Executor { - return NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil) + return NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil, false) } func TestSelectScatter(t *testing.T) { @@ -2540,7 +2540,7 @@ func TestStreamOrderByLimitWithMultipleResults(t *testing.T) { count++ } - executor := NewExecutor(context.Background(), serv, cell, resolver, true, false, testBufferSize, cache.DefaultConfig, nil) + executor := NewExecutor(context.Background(), serv, cell, resolver, true, false, testBufferSize, cache.DefaultConfig, nil, false) before := runtime.NumGoroutine() query := "select id, col from user order by id limit 2" @@ -2553,3 +2553,44 @@ func TestStreamOrderByLimitWithMultipleResults(t *testing.T) { time.Sleep(100 * time.Millisecond) assert.GreaterOrEqual(t, before, runtime.NumGoroutine(), "left open goroutines lingering") } + +func TestSelectScatterFails(t *testing.T) { + sess := &vtgatepb.Session{} + cell := "aa" + hc := discovery.NewFakeHealthCheck() + s := createSandbox("TestExecutor") + s.VSchema = executorVSchema + getSandbox(KsTestUnsharded).VSchema = unshardedVSchema + serv := new(sandboxTopo) + resolver := newTestResolver(hc, serv, cell) + + shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} + for i, shard := range shards { + sbc := hc.AddTestTablet(cell, shard, 1, 
"TestExecutor", shard, topodatapb.TabletType_MASTER, true, 1, nil) + sbc.SetResults([]*sqltypes.Result{{ + Fields: []*querypb.Field{ + {Name: "col1", Type: sqltypes.Int32}, + {Name: "col2", Type: sqltypes.Int32}, + {Name: "weight_string(col2)", Type: sqltypes.VarBinary}, + }, + InsertID: 0, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt32(1), + sqltypes.NewInt32(int32(i % 4)), + sqltypes.NULL, + }}, + }}) + } + + executor := createExecutor(serv, cell, resolver) + executor.allowScatter = false + logChan := QueryLogger.Subscribe("Test") + defer QueryLogger.Unsubscribe(logChan) + + _, err := executorExecSession(executor, "select id from user", nil, sess) + require.Error(t, err) + assert.Contains(t, err.Error(), "scatter") + + _, err = executorExecSession(executor, "select /*vt+ ALLOW_SCATTER */ id from user", nil, sess) + require.NoError(t, err) +} diff --git a/go/vt/vtgate/executor_stream_test.go b/go/vt/vtgate/executor_stream_test.go index 8ca2a9ab1c9..bff5a744539 100644 --- a/go/vt/vtgate/executor_stream_test.go +++ b/go/vt/vtgate/executor_stream_test.go @@ -60,7 +60,7 @@ func TestStreamSQLSharded(t *testing.T) { for _, shard := range shards { _ = hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_MASTER, true, 1, nil) } - executor := NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil) + executor := NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil, false) sql := "stream * from sharded_user_msgs" result, err := executorStreamMessages(executor, sql) diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index e3aaf49fd18..b5d139797f2 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -68,6 +68,7 @@ var ( warnMemoryRows = flag.Int("warn_memory_rows", 30000, "Warning threshold for in-memory results. 
A row count higher than this amount will cause the VtGateWarnings.ResultsExceeded counter to be incremented.") defaultDDLStrategy = flag.String("ddl_strategy", string(schema.DDLStrategyDirect), "Set default strategy for DDL statements. Override with @@ddl_strategy session variable") dbDDLPlugin = flag.String("dbddl_plugin", "fail", "controls how to handle CREATE/DROP DATABASE. use it if you are using your own database provisioning service") + noScatter = flag.Bool("no_scatter", false, "when set to true, the planner will fail instead of producing a plan that includes scatter queries") // TODO(deepthi): change these two vars to unexported and move to healthcheck.go when LegacyHealthcheck is removed @@ -214,7 +215,7 @@ func Init(ctx context.Context, serv srvtopo.Server, cell string, tabletTypesToWa LFU: *queryPlanCacheLFU, } - executor := NewExecutor(ctx, serv, cell, resolver, *normalizeQueries, *warnShardedOnly, *streamBufferSize, cacheCfg, si) + executor := NewExecutor(ctx, serv, cell, resolver, *normalizeQueries, *warnShardedOnly, *streamBufferSize, cacheCfg, si, *noScatter) // connect the schema tracker with the vschema manager if *enableSchemaChangeSignal { @@ -618,7 +619,7 @@ func LegacyInit(ctx context.Context, hc discovery.LegacyHealthCheck, serv srvtop } rpcVTGate = &VTGate{ - executor: NewExecutor(ctx, serv, cell, resolver, *normalizeQueries, *warnShardedOnly, *streamBufferSize, cacheCfg, nil), + executor: NewExecutor(ctx, serv, cell, resolver, *normalizeQueries, *warnShardedOnly, *streamBufferSize, cacheCfg, nil, *noScatter), resolver: resolver, vsm: vsm, txConn: tc,