diff --git a/go/vt/vttablet/tabletmanager/vreplication/controller.go b/go/vt/vttablet/tabletmanager/vreplication/controller.go index 8776a37a131..b6cec1b4714 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/controller.go +++ b/go/vt/vttablet/tabletmanager/vreplication/controller.go @@ -199,7 +199,7 @@ func (ct *controller) runBlp(ctx context.Context) (err error) { return vterrors.Wrap(err, "can't connect to database") } for _, query := range withDDLInitialQueries { - if _, err := withDDL.Exec(ctx, query, dbClient.ExecuteFetch); err != nil { + if _, err := withDDL.Exec(ctx, query, dbClient.ExecuteFetch, dbClient.ExecuteFetch); err != nil { log.Errorf("cannot apply withDDL init query '%s': %v", query, err) } } diff --git a/go/vt/vttablet/tabletmanager/vreplication/engine.go b/go/vt/vttablet/tabletmanager/vreplication/engine.go index becca0a083e..93ca3c6a765 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/engine.go +++ b/go/vt/vttablet/tabletmanager/vreplication/engine.go @@ -359,13 +359,13 @@ func (vre *Engine) exec(query string, runAsAdmin bool) (*sqltypes.Result, error) // Change the database to ensure that these events don't get // replicated by another vreplication. This can happen when // we reverse replication. - if _, err := withDDL.Exec(vre.ctx, "use _vt", dbClient.ExecuteFetch); err != nil { + if _, err := withDDL.Exec(vre.ctx, "use _vt", dbClient.ExecuteFetch, dbClient.ExecuteFetch); err != nil { return nil, err } switch plan.opcode { case insertQuery: - qr, err := withDDL.Exec(vre.ctx, plan.query, dbClient.ExecuteFetch) + qr, err := withDDL.Exec(vre.ctx, plan.query, dbClient.ExecuteFetch, dbClient.ExecuteFetch) if err != nil { return nil, err } @@ -413,7 +413,7 @@ func (vre *Engine) exec(query string, runAsAdmin bool) (*sqltypes.Result, error) if err != nil { return nil, err } - qr, err := withDDL.Exec(vre.ctx, query, dbClient.ExecuteFetch) + qr, err := withDDL.Exec(vre.ctx, query, dbClient.ExecuteFetch, dbClient.ExecuteFetch) if err != nil { return nil, err } @@ -461,7 +461,7 @@ func (vre *Engine) exec(query string, runAsAdmin bool) (*sqltypes.Result, error) if err != nil { return nil, err } - qr, err := withDDL.Exec(vre.ctx, query, dbClient.ExecuteFetch) + qr, err := withDDL.Exec(vre.ctx, query, dbClient.ExecuteFetch, dbClient.ExecuteFetch) if err != nil { return nil, err } @@ -479,7 +479,7 @@ func (vre *Engine) exec(query string, runAsAdmin bool) (*sqltypes.Result, error) return qr, nil case selectQuery, reshardingJournalQuery: // select and resharding journal queries are passed through. - return withDDL.Exec(vre.ctx, plan.query, dbClient.ExecuteFetch) + return withDDL.Exec(vre.ctx, plan.query, dbClient.ExecuteFetch, dbClient.ExecuteFetch) } panic("unreachable") } @@ -638,7 +638,7 @@ func (vre *Engine) transitionJournal(je *journalEvent) { bls.Keyspace, bls.Shard = sgtid.Keyspace, sgtid.Shard ig := NewInsertGenerator(binlogplayer.BlpRunning, vre.dbName) ig.AddRow(params["workflow"], bls, sgtid.Gtid, params["cell"], params["tablet_types"]) - qr, err := withDDL.Exec(vre.ctx, ig.String(), dbClient.ExecuteFetch) + qr, err := withDDL.Exec(vre.ctx, ig.String(), dbClient.ExecuteFetch, dbClient.ExecuteFetch) if err != nil { log.Errorf("transitionJournal: %v", err) return @@ -648,7 +648,7 @@ func (vre *Engine) transitionJournal(je *journalEvent) { } for _, ks := range participants { id := je.participants[ks] - _, err := withDDL.Exec(vre.ctx, binlogplayer.DeleteVReplication(uint32(id)), dbClient.ExecuteFetch) + _, err := withDDL.Exec(vre.ctx, binlogplayer.DeleteVReplication(uint32(id)), dbClient.ExecuteFetch, dbClient.ExecuteFetch) if err != nil { log.Errorf("transitionJournal: %v", err) return diff --git a/go/vt/vttablet/tabletmanager/vreplication/utils.go b/go/vt/vttablet/tabletmanager/vreplication/utils.go index db4ff93cedc..b8e1e029c26 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/utils.go +++ b/go/vt/vttablet/tabletmanager/vreplication/utils.go @@ -74,7 +74,7 @@ const ( func getLastLog(dbClient *vdbClient, vreplID uint32) (id int64, typ, state, message string, err error) { var qr *sqltypes.Result query := fmt.Sprintf("select id, type, state, message from _vt.vreplication_log where vrepl_id = %d order by id desc limit 1", vreplID) - if qr, err = withDDL.Exec(context.Background(), query, dbClient.ExecuteFetch); err != nil { + if qr, err = withDDL.Exec(context.Background(), query, dbClient.ExecuteFetch, dbClient.ExecuteFetch); err != nil { return 0, "", "", "", err } if len(qr.Rows) != 1 { @@ -108,7 +108,7 @@ func insertLog(dbClient *vdbClient, typ string, vreplID uint32, state, message s strconv.Itoa(int(vreplID)), encodeString(typ), encodeString(state), encodeString(message)) query = buf.ParsedQuery().Query } - if _, err = withDDL.Exec(context.Background(), query, dbClient.ExecuteFetch); err != nil { + if _, err = withDDL.Exec(context.Background(), query, dbClient.ExecuteFetch, dbClient.ExecuteFetch); err != nil { return fmt.Errorf("could not insert into log table: %v: %v", query, err) } return nil diff --git a/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go b/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go index a993db1fd9d..1060b18d29e 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go @@ -321,7 +321,7 @@ func (vr *vreplicator) readSettings(ctx context.Context) (settings binlogplayer. } query := fmt.Sprintf("select count(*) from _vt.copy_state where vrepl_id=%d", vr.id) - qr, err := withDDL.Exec(ctx, query, vr.dbClient.ExecuteFetch) + qr, err := withDDL.Exec(ctx, query, vr.dbClient.ExecuteFetch, vr.dbClient.ExecuteFetch) if err != nil { return settings, numTablesToCopy, err } diff --git a/go/vt/vttablet/tabletserver/repltracker/writer.go b/go/vt/vttablet/tabletserver/repltracker/writer.go index ade3884155b..a95535801e0 100644 --- a/go/vt/vttablet/tabletserver/repltracker/writer.go +++ b/go/vt/vttablet/tabletserver/repltracker/writer.go @@ -67,10 +67,11 @@ type heartbeatWriter struct { now func() time.Time errorLog *logutil.ThrottledLogger - mu sync.Mutex - isOpen bool - pool *dbconnpool.ConnectionPool - ticks *timer.Timer + mu sync.Mutex + isOpen bool + appPool *dbconnpool.ConnectionPool + allPrivsPool *dbconnpool.ConnectionPool + ticks *timer.Timer } // newHeartbeatWriter creates a new heartbeatWriter. @@ -92,7 +93,8 @@ func newHeartbeatWriter(env tabletenv.Env, alias *topodatapb.TabletAlias) *heart errorLog: logutil.NewThrottledLogger("HeartbeatWriter", 60*time.Second), // We make this pool size 2; to prevent pool exhausted // stats from incrementing continually, and causing concern - pool: dbconnpool.NewConnectionPool("HeartbeatWritePool", 2, *mysqlctl.DbaIdleTimeout, *mysqlctl.PoolDynamicHostnameResolution), + appPool: dbconnpool.NewConnectionPool("HeartbeatWriteAppPool", 2, *mysqlctl.DbaIdleTimeout, *mysqlctl.PoolDynamicHostnameResolution), + allPrivsPool: dbconnpool.NewConnectionPool("HeartbeatWriteAllPrivsPool", 2, *mysqlctl.DbaIdleTimeout, *mysqlctl.PoolDynamicHostnameResolution), } } @@ -120,7 +122,8 @@ func (w *heartbeatWriter) Open() { // block this thread, and we could end up in a deadlock. // Instead, we try creating the database and table in each tick which runs in a go routine // keeping us safe from hanging the main thread. - w.pool.Open(w.env.Config().DB.AppWithDB()) + w.appPool.Open(w.env.Config().DB.AppWithDB()) + w.allPrivsPool.Open(w.env.Config().DB.AllPrivsWithDB()) w.enableWrites(true) w.isOpen = true } @@ -137,7 +140,8 @@ func (w *heartbeatWriter) Close() { } w.enableWrites(false) - w.pool.Close() + w.appPool.Close() + w.allPrivsPool.Close() w.isOpen = false log.Info("Hearbeat Writer: closed") } @@ -172,16 +176,22 @@ func (w *heartbeatWriter) write() error { defer w.env.LogError() ctx, cancel := context.WithDeadline(context.Background(), w.now().Add(w.interval)) defer cancel() + allPrivsConn, err := w.allPrivsPool.Get(ctx) + if err != nil { + return err + } + defer allPrivsConn.Recycle() + upsert, err := w.bindHeartbeatVars(sqlUpsertHeartbeat) if err != nil { return err } - conn, err := w.pool.Get(ctx) + appConn, err := w.appPool.Get(ctx) if err != nil { return err } - defer conn.Recycle() - _, err = withDDL.Exec(ctx, upsert, conn.ExecuteFetch) + defer appConn.Recycle() + _, err = withDDL.Exec(ctx, upsert, appConn.ExecuteFetch, allPrivsConn.ExecuteFetch) if err != nil { return err } diff --git a/go/vt/vttablet/tabletserver/repltracker/writer_test.go b/go/vt/vttablet/tabletserver/repltracker/writer_test.go index ff75710656c..f678381ec2b 100644 --- a/go/vt/vttablet/tabletserver/repltracker/writer_test.go +++ b/go/vt/vttablet/tabletserver/repltracker/writer_test.go @@ -105,7 +105,8 @@ func newTestWriter(db *fakesqldb.DB, nowFunc func() time.Time) *heartbeatWriter tw := newHeartbeatWriter(tabletenv.NewEnv(config, "WriterTest"), &topodatapb.TabletAlias{Cell: "test", Uid: 1111}) tw.keyspaceShard = "test:0" tw.now = nowFunc - tw.pool.Open(dbc.AppWithDB()) + tw.appPool.Open(dbc.AppWithDB()) + tw.allPrivsPool.Open(dbc.AllPrivsWithDB()) return tw } diff --git a/go/vt/vttablet/tabletserver/schema/tracker.go b/go/vt/vttablet/tabletserver/schema/tracker.go index 161e6be0300..51204e4db3f 100644 --- a/go/vt/vttablet/tabletserver/schema/tracker.go +++ b/go/vt/vttablet/tabletserver/schema/tracker.go @@ -192,7 +192,7 @@ func (tr *Tracker) isSchemaVersionTableEmpty(ctx context.Context) (bool, error) return false, err } defer conn.Recycle() - result, err := withDDL.Exec(ctx, "select id from _vt.schema_version limit 1", conn.Exec) + result, err := withDDL.Exec(ctx, "select id from _vt.schema_version limit 1", conn.Exec, conn.Exec) if err != nil { return false, err } @@ -258,7 +258,7 @@ func (tr *Tracker) saveCurrentSchemaToDb(ctx context.Context, gtid, ddl string, query := fmt.Sprintf("insert into _vt.schema_version "+ "(pos, ddl, schemax, time_updated) "+ "values (%v, %v, %v, %d)", encodeString(gtid), encodeString(ddl), encodeString(string(blob)), timestamp) - _, err = withDDL.Exec(ctx, query, conn.Exec) + _, err = withDDL.Exec(ctx, query, conn.Exec, conn.Exec) if err != nil { return err } diff --git a/go/vt/withddl/withddl.go b/go/vt/withddl/withddl.go index fe5fcfd9566..fb85143b52e 100644 --- a/go/vt/withddl/withddl.go +++ b/go/vt/withddl/withddl.go @@ -52,17 +52,24 @@ func (wd *WithDDL) DDLs() []string { // Exec executes the query using the supplied function. // If there are any schema errors, it applies the DDLs and retries. +// It takes 2 functions, one to run the query and the other to run the +// DDL commands. This is generally needed so that different users can be used +// to run the commands. i.e. AllPrivs user for DDLs and App user for query commands. // Funcs can be any of these types: // func(query string) (*sqltypes.Result, error) // func(query string, maxrows int) (*sqltypes.Result, error) // func(query string, maxrows int, wantfields bool) (*sqltypes.Result, error) // func(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) -func (wd *WithDDL) Exec(ctx context.Context, query string, f interface{}) (*sqltypes.Result, error) { - exec, err := wd.unify(ctx, f) +func (wd *WithDDL) Exec(ctx context.Context, query string, fQuery interface{}, fDDL interface{}) (*sqltypes.Result, error) { + execQuery, err := wd.unify(ctx, fQuery) if err != nil { return nil, err } - qr, err := exec(query) + execDDL, err := wd.unify(ctx, fDDL) + if err != nil { + return nil, err + } + qr, err := execQuery(query) if err == nil { return qr, nil } @@ -72,7 +79,7 @@ func (wd *WithDDL) Exec(ctx context.Context, query string, f interface{}) (*sqlt log.Infof("Updating schema for %v and retrying: %v", sqlparser.TruncateForUI(err.Error()), err) for _, applyQuery := range wd.ddls { - _, merr := exec(applyQuery) + _, merr := execDDL(applyQuery) if merr == nil { continue } @@ -83,7 +90,7 @@ func (wd *WithDDL) Exec(ctx context.Context, query string, f interface{}) (*sqlt // Return the original error. return nil, err } - return exec(query) + return execQuery(query) } // ExecIgnore executes the query using the supplied function. diff --git a/go/vt/withddl/withddl_test.go b/go/vt/withddl/withddl_test.go index 5d624721760..2d7fa221361 100644 --- a/go/vt/withddl/withddl_test.go +++ b/go/vt/withddl/withddl_test.go @@ -185,7 +185,7 @@ func TestExec(t *testing.T) { } wd := New(test.ddls) - qr, err := wd.Exec(ctx, test.query, fun.f) + qr, err := wd.Exec(ctx, test.query, fun.f, fun.f) if test.qr != nil { test.qr.StatusFlags = sqltypes.ServerStatusAutocommit } @@ -232,6 +232,34 @@ func TestExecIgnore(t *testing.T) { assert.Equal(t, 1, len(qr.Rows)) } +func TestDifferentExecFunctions(t *testing.T) { + ctx := context.Background() + conn, err := mysql.Connect(ctx, &connParams) + require.NoError(t, err) + defer conn.Close() + defer conn.ExecuteFetch("drop database t", 10000, true) // nolint:errcheck + + execconn, err := mysql.Connect(ctx, &connParams) + require.NoError(t, err) + defer execconn.Close() + + wd := New([]string{"create database t"}) + _, err = wd.Exec(ctx, "select * from a", func(query string) (*sqltypes.Result, error) { + return nil, mysql.NewSQLError(mysql.ERNoSuchTable, mysql.SSUnknownSQLState, "error in execution") + }, execconn.ExecuteFetch) + require.EqualError(t, err, "error in execution (errno 1146) (sqlstate HY000)") + + res, err := execconn.ExecuteFetch("show databases", 10000, false) + require.NoError(t, err) + foundDatabase := false + for _, row := range res.Rows { + if row[0].ToString() == "t" { + foundDatabase = true + } + } + require.True(t, foundDatabase, "database should be created since DDL should have executed") +} + func checkResult(t *testing.T, wantqr *sqltypes.Result, wanterr string, qr *sqltypes.Result, err error) { t.Helper()