Skip to content

Commit

Permalink
mysqlctl: Correctly encode database and table names (#13312)
Browse files Browse the repository at this point in the history
These need to be escaped as SQL strings for queries against the
`information_schema` or otherwise any names that need quotes will break.

Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink authored Jun 14, 2023
1 parent 5ded6ec commit 3f195a4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 14 deletions.
26 changes: 14 additions & 12 deletions go/vt/mysqlctl/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ func (mysqld *Mysqld) executeSchemaCommands(sql string) error {
return mysqld.executeMysqlScript(params, strings.NewReader(sql))
}

func encodeTableName(tableName string) string {
func encodeEntityName(name string) string {
var buf strings.Builder
sqltypes.NewVarChar(tableName).EncodeSQL(&buf)
sqltypes.NewVarChar(name).EncodeSQL(&buf)
return buf.String()
}

Expand All @@ -66,7 +66,7 @@ func tableListSQL(tables []string) (string, error) {

encodedTables := make([]string, len(tables))
for i, tableName := range tables {
encodedTables[i] = encodeTableName(tableName)
encodedTables[i] = encodeEntityName(tableName)
}

return "(" + strings.Join(encodedTables, ", ") + ")", nil
Expand Down Expand Up @@ -280,7 +280,7 @@ func ResolveTables(ctx context.Context, mysqld MysqlDaemon, dbName string, table
const (
GetColumnNamesQuery = `SELECT COLUMN_NAME as column_name
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = %s AND TABLE_NAME = '%s'
WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
ORDER BY ORDINAL_POSITION`
GetFieldsQuery = "SELECT %s FROM %s WHERE 1 != 1"
)
Expand All @@ -290,9 +290,9 @@ func GetColumnsList(dbName, tableName string, exec func(string, int, bool) (*sql
if dbName == "" {
dbName2 = "database()"
} else {
dbName2 = fmt.Sprintf("'%s'", dbName)
dbName2 = encodeEntityName(dbName)
}
query := fmt.Sprintf(GetColumnNamesQuery, dbName2, sqlescape.UnescapeID(tableName))
query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sqlescape.UnescapeID(tableName)))
qr, err := exec(query, -1, true)
if err != nil {
return "", err
Expand Down Expand Up @@ -377,9 +377,9 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t
sql := `
SELECT TABLE_NAME as table_name, COLUMN_NAME as column_name
FROM information_schema.STATISTICS
WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary'
WHERE TABLE_SCHEMA = %s AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary'
ORDER BY table_name, SEQ_IN_INDEX`
sql = fmt.Sprintf(sql, dbName, tableList)
sql = fmt.Sprintf(sql, encodeEntityName(dbName), tableList)
qr, err := conn.ExecuteFetch(sql, len(tables)*100, true)
if err != nil {
return nil, err
Expand Down Expand Up @@ -599,16 +599,18 @@ func (mysqld *Mysqld) GetPrimaryKeyEquivalentColumns(ctx context.Context, dbName
END
) AS type_cost, COUNT(stats.COLUMN_NAME) AS col_count FROM information_schema.STATISTICS AS stats INNER JOIN
information_schema.COLUMNS AS cols ON stats.TABLE_SCHEMA = cols.TABLE_SCHEMA AND stats.TABLE_NAME = cols.TABLE_NAME AND stats.COLUMN_NAME = cols.COLUMN_NAME
WHERE stats.TABLE_SCHEMA = '%s' AND stats.TABLE_NAME = '%s' AND stats.INDEX_NAME NOT IN
WHERE stats.TABLE_SCHEMA = %s AND stats.TABLE_NAME = %s AND stats.INDEX_NAME NOT IN
(
SELECT DISTINCT INDEX_NAME FROM information_schema.STATISTICS
WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME = '%s' AND (NON_UNIQUE = 1 OR NULLABLE = 'YES')
WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s AND (NON_UNIQUE = 1 OR NULLABLE = 'YES')
)
GROUP BY INDEX_NAME ORDER BY type_cost ASC, col_count ASC LIMIT 1
) AS pke ON index_cols.INDEX_NAME = pke.INDEX_NAME
WHERE index_cols.TABLE_SCHEMA = '%s' AND index_cols.TABLE_NAME = '%s' AND NON_UNIQUE = 0 AND NULLABLE != 'YES'
WHERE index_cols.TABLE_SCHEMA = %s AND index_cols.TABLE_NAME = %s AND NON_UNIQUE = 0 AND NULLABLE != 'YES'
ORDER BY SEQ_IN_INDEX ASC`
sql = fmt.Sprintf(sql, dbName, table, dbName, table, dbName, table)
encodedDbName := encodeEntityName(dbName)
encodedTable := encodeEntityName(table)
sql = fmt.Sprintf(sql, encodedDbName, encodedTable, encodedDbName, encodedTable, encodedDbName, encodedTable)
qr, err := conn.ExecuteFetch(sql, 1000, true)
if err != nil {
return nil, err
Expand Down
32 changes: 30 additions & 2 deletions go/vt/mysqlctl/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var queryMap map[string]*sqltypes.Result

func mockExec(query string, maxRows int, wantFields bool) (*sqltypes.Result, error) {
queryMap = make(map[string]*sqltypes.Result)
getColsQuery := fmt.Sprintf(GetColumnNamesQuery, "'test'", "t1")
getColsQuery := fmt.Sprintf(GetColumnNamesQuery, "'test'", "'t1'")
queryMap[getColsQuery] = &sqltypes.Result{
Fields: []*querypb.Field{{
Name: "column_name",
Expand All @@ -40,7 +40,7 @@ func mockExec(query string, maxRows int, wantFields bool) (*sqltypes.Result, err
Type: sqltypes.VarBinary,
}},
}
getColsQuery = fmt.Sprintf(GetColumnNamesQuery, "database()", "t2")
getColsQuery = fmt.Sprintf(GetColumnNamesQuery, "database()", "'t2'")
queryMap[getColsQuery] = &sqltypes.Result{
Fields: []*querypb.Field{{
Name: "column_name",
Expand All @@ -61,6 +61,29 @@ func mockExec(query string, maxRows int, wantFields bool) (*sqltypes.Result, err
if ok {
return result, nil
}

getColsQuery = fmt.Sprintf(GetColumnNamesQuery, "database()", "'with \\' quote'")
queryMap[getColsQuery] = &sqltypes.Result{
Fields: []*querypb.Field{{
Name: "column_name",
Type: sqltypes.VarChar,
}},
Rows: [][]sqltypes.Value{
{sqltypes.NewVarChar("col1")},
},
}

queryMap["SELECT `col1` FROM `with ' quote` WHERE 1 != 1"] = &sqltypes.Result{
Fields: []*querypb.Field{{
Name: "col1",
Type: sqltypes.VarChar,
}},
}
result, ok = queryMap[query]
if ok {
return result, nil
}

return nil, fmt.Errorf("query %s not found in mock setup", query)
}

Expand All @@ -74,4 +97,9 @@ func TestColumnList(t *testing.T) {
fields, _, err = GetColumns("", "t2", mockExec)
require.NoError(t, err)
require.Equal(t, `[name:"col1" type:VARCHAR]`, fmt.Sprintf("%+v", fields))

fields, _, err = GetColumns("", "with ' quote", mockExec)
require.NoError(t, err)
require.Equal(t, `[name:"col1" type:VARCHAR]`, fmt.Sprintf("%+v", fields))

}

0 comments on commit 3f195a4

Please sign in to comment.