Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite Stream queries that query by schema_name #11090

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,14 @@ func New(t testing.TB) *DB {
return db
}

// Name returns the name of the DB.
func (db *DB) Name() string {
db.mu.Lock()
defer db.mu.Unlock()

return db.name
}

// SetName sets the name of the DB. to differentiate them in tests if needed.
func (db *DB) SetName(name string) *DB {
db.mu.Lock()
Expand Down
10 changes: 10 additions & 0 deletions go/test/endtoend/vtgate/gen4/system_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,21 @@ func TestDbNameOverride(t *testing.T) {
conn, err := mysql.Connect(ctx, &vtParams)
require.Nil(t, err)
defer conn.Close()

// Test query in OLTP workload (default).
qr, err := conn.ExecuteFetch("SELECT distinct database() FROM information_schema.tables WHERE table_schema = database()", 1000, true)

require.Nil(t, err)
assert.Equal(t, 1, len(qr.Rows), "did not get enough rows back")
assert.Equal(t, "vt_ks", qr.Rows[0][0].ToString())

// Test again in OLAP workload (default).
utils.Exec(t, conn, "SET workload=OLAP")
qr, err = conn.ExecuteFetch("SELECT distinct database() FROM information_schema.tables WHERE table_schema = database()", 1000, true)

require.Nil(t, err)
assert.Equal(t, 1, len(qr.Rows), "did not get enough rows back")
assert.Equal(t, "vt_ks", qr.Rows[0][0].ToString())
}

func TestInformationSchemaQuery(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions go/vt/vttablet/tabletserver/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,13 @@ func (qre *QueryExecutor) Stream(callback StreamCallback) error {
return err
}

switch qre.plan.PlanID {
case p.PlanSelectStream:
if qre.bindVars[sqltypes.BvReplaceSchemaName] != nil {
qre.bindVars[sqltypes.BvSchemaName] = sqltypes.StringBindVariable(qre.tsv.config.DB.DBName)
}
}

sql, sqlWithoutComments, err := qre.generateFinalSQL(qre.plan.FullQuery, qre.bindVars)
if err != nil {
return err
Expand Down
75 changes: 75 additions & 0 deletions go/vt/vttablet/tabletserver/query_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,64 @@ func TestQueryExecutorDenyListQRRetry(t *testing.T) {
}
}

func TestReplaceSchemaName(t *testing.T) {
db := setUpQueryExecutorTest(t)
defer db.Close()

queryFmt := "select * from information_schema.schema_name where schema_name = %s"
inQuery := fmt.Sprintf(queryFmt, ":"+sqltypes.BvSchemaName)
wantQuery := fmt.Sprintf(queryFmt, fmt.Sprintf(
"'%s' limit %d",
db.Name(),
10001,
))
wantQueryStream := fmt.Sprintf(queryFmt, fmt.Sprintf(
"'%s'",
db.Name(),
))

ctx := context.Background()
tsv := newTestTabletServer(ctx, noFlags, db)
defer tsv.StopService()

db.AddQuery(wantQuery, &sqltypes.Result{
Fields: getTestTableFields(),
})

db.AddQuery(wantQueryStream, &sqltypes.Result{
Fields: getTestTableFields(),
})

// Test non streaming execute.
{
qre := newTestQueryExecutor(ctx, tsv, inQuery, 0)
assert.Equal(t, planbuilder.PlanSelect, qre.plan.PlanID)
// Any value other than nil should cause QueryExecutor to replace the
// schema name.
qre.bindVars[sqltypes.BvReplaceSchemaName] = sqltypes.NullBindVariable
_, err := qre.Execute()
require.NoError(t, err)
_, ok := qre.bindVars[sqltypes.BvSchemaName]
require.True(t, ok)
}

// Test streaming execute.
{
qre := newTestQueryExecutorStreaming(ctx, tsv, inQuery, 0)
// Stream only replaces schema name when plan is PlanSelectStream.
assert.Equal(t, planbuilder.PlanSelectStream, qre.plan.PlanID)
// Any value other than nil should cause QueryExecutor to replace the
// schema name.
qre.bindVars[sqltypes.BvReplaceSchemaName] = sqltypes.NullBindVariable
err := qre.Stream(func(_ *sqltypes.Result) error {
_, ok := qre.bindVars[sqltypes.BvSchemaName]
require.True(t, ok)
return nil
})
require.NoError(t, err)
}
}

type executorFlags int64

const (
Expand Down Expand Up @@ -1288,6 +1346,23 @@ func newTestQueryExecutor(ctx context.Context, tsv *TabletServer, sql string, tx
}
}

func newTestQueryExecutorStreaming(ctx context.Context, tsv *TabletServer, sql string, txID int64) *QueryExecutor {
logStats := tabletenv.NewLogStats(ctx, "TestQueryExecutorStreaming")
plan, err := tsv.qe.GetStreamPlan(sql)
if err != nil {
panic(err)
}
return &QueryExecutor{
ctx: ctx,
query: sql,
bindVars: make(map[string]*querypb.BindVariable),
connID: txID,
plan: plan,
logStats: logStats,
tsv: tsv,
}
}

func setUpQueryExecutorTest(t *testing.T) *fakesqldb.DB {
db := fakesqldb.New(t)
initQueryExecutorTestDB(db)
Expand Down