From 3f195a4af31ddad07a2d7c021fb4a95bad50f30a Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 14 Jun 2023 12:07:51 +0200 Subject: [PATCH] mysqlctl: Correctly encode database and table names (#13312) These need to be escaped as SQL strings for queries against the `information_schema` or otherwise any names that need quotes will break. Signed-off-by: Dirkjan Bussink --- go/vt/mysqlctl/schema.go | 26 ++++++++++++++------------ go/vt/mysqlctl/schema_test.go | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index af81e2d088a..04fb3483bc7 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -52,9 +52,9 @@ func (mysqld *Mysqld) executeSchemaCommands(sql string) error { return mysqld.executeMysqlScript(params, strings.NewReader(sql)) } -func encodeTableName(tableName string) string { +func encodeEntityName(name string) string { var buf strings.Builder - sqltypes.NewVarChar(tableName).EncodeSQL(&buf) + sqltypes.NewVarChar(name).EncodeSQL(&buf) return buf.String() } @@ -66,7 +66,7 @@ func tableListSQL(tables []string) (string, error) { encodedTables := make([]string, len(tables)) for i, tableName := range tables { - encodedTables[i] = encodeTableName(tableName) + encodedTables[i] = encodeEntityName(tableName) } return "(" + strings.Join(encodedTables, ", ") + ")", nil @@ -280,7 +280,7 @@ func ResolveTables(ctx context.Context, mysqld MysqlDaemon, dbName string, table const ( GetColumnNamesQuery = `SELECT COLUMN_NAME as column_name FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_SCHEMA = %s AND TABLE_NAME = '%s' + WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s ORDER BY ORDINAL_POSITION` GetFieldsQuery = "SELECT %s FROM %s WHERE 1 != 1" ) @@ -290,9 +290,9 @@ func GetColumnsList(dbName, tableName string, exec func(string, int, bool) (*sql if dbName == "" { dbName2 = "database()" } else { - dbName2 = fmt.Sprintf("'%s'", dbName) + dbName2 = encodeEntityName(dbName) } - query := fmt.Sprintf(GetColumnNamesQuery, dbName2, sqlescape.UnescapeID(tableName)) + query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sqlescape.UnescapeID(tableName))) qr, err := exec(query, -1, true) if err != nil { return "", err @@ -377,9 +377,9 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t sql := ` SELECT TABLE_NAME as table_name, COLUMN_NAME as column_name FROM information_schema.STATISTICS - WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary' + WHERE TABLE_SCHEMA = %s AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary' ORDER BY table_name, SEQ_IN_INDEX` - sql = fmt.Sprintf(sql, dbName, tableList) + sql = fmt.Sprintf(sql, encodeEntityName(dbName), tableList) qr, err := conn.ExecuteFetch(sql, len(tables)*100, true) if err != nil { return nil, err @@ -599,16 +599,18 @@ func (mysqld *Mysqld) GetPrimaryKeyEquivalentColumns(ctx context.Context, dbName END ) AS type_cost, COUNT(stats.COLUMN_NAME) AS col_count FROM information_schema.STATISTICS AS stats INNER JOIN information_schema.COLUMNS AS cols ON stats.TABLE_SCHEMA = cols.TABLE_SCHEMA AND stats.TABLE_NAME = cols.TABLE_NAME AND stats.COLUMN_NAME = cols.COLUMN_NAME - WHERE stats.TABLE_SCHEMA = '%s' AND stats.TABLE_NAME = '%s' AND stats.INDEX_NAME NOT IN + WHERE stats.TABLE_SCHEMA = %s AND stats.TABLE_NAME = %s AND stats.INDEX_NAME NOT IN ( SELECT DISTINCT INDEX_NAME FROM information_schema.STATISTICS - WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME = '%s' AND (NON_UNIQUE = 1 OR NULLABLE = 'YES') + WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s AND (NON_UNIQUE = 1 OR NULLABLE = 'YES') ) GROUP BY INDEX_NAME ORDER BY type_cost ASC, col_count ASC LIMIT 1 ) AS pke ON index_cols.INDEX_NAME = pke.INDEX_NAME - WHERE index_cols.TABLE_SCHEMA = '%s' AND index_cols.TABLE_NAME = '%s' AND NON_UNIQUE = 0 AND NULLABLE != 'YES' + WHERE index_cols.TABLE_SCHEMA = %s AND index_cols.TABLE_NAME = %s AND NON_UNIQUE = 0 AND NULLABLE != 'YES' ORDER BY SEQ_IN_INDEX ASC` - sql = fmt.Sprintf(sql, dbName, table, dbName, table, dbName, table) + encodedDbName := encodeEntityName(dbName) + encodedTable := encodeEntityName(table) + sql = fmt.Sprintf(sql, encodedDbName, encodedTable, encodedDbName, encodedTable, encodedDbName, encodedTable) qr, err := conn.ExecuteFetch(sql, 1000, true) if err != nil { return nil, err diff --git a/go/vt/mysqlctl/schema_test.go b/go/vt/mysqlctl/schema_test.go index 5ec02be9960..fb64f8ca8ee 100644 --- a/go/vt/mysqlctl/schema_test.go +++ b/go/vt/mysqlctl/schema_test.go @@ -15,7 +15,7 @@ var queryMap map[string]*sqltypes.Result func mockExec(query string, maxRows int, wantFields bool) (*sqltypes.Result, error) { queryMap = make(map[string]*sqltypes.Result) - getColsQuery := fmt.Sprintf(GetColumnNamesQuery, "'test'", "t1") + getColsQuery := fmt.Sprintf(GetColumnNamesQuery, "'test'", "'t1'") queryMap[getColsQuery] = &sqltypes.Result{ Fields: []*querypb.Field{{ Name: "column_name", @@ -40,7 +40,7 @@ func mockExec(query string, maxRows int, wantFields bool) (*sqltypes.Result, err Type: sqltypes.VarBinary, }}, } - getColsQuery = fmt.Sprintf(GetColumnNamesQuery, "database()", "t2") + getColsQuery = fmt.Sprintf(GetColumnNamesQuery, "database()", "'t2'") queryMap[getColsQuery] = &sqltypes.Result{ Fields: []*querypb.Field{{ Name: "column_name", @@ -61,6 +61,29 @@ func mockExec(query string, maxRows int, wantFields bool) (*sqltypes.Result, err if ok { return result, nil } + + getColsQuery = fmt.Sprintf(GetColumnNamesQuery, "database()", "'with \\' quote'") + queryMap[getColsQuery] = &sqltypes.Result{ + Fields: []*querypb.Field{{ + Name: "column_name", + Type: sqltypes.VarChar, + }}, + Rows: [][]sqltypes.Value{ + {sqltypes.NewVarChar("col1")}, + }, + } + + queryMap["SELECT `col1` FROM `with ' quote` WHERE 1 != 1"] = &sqltypes.Result{ + Fields: []*querypb.Field{{ + Name: "col1", + Type: sqltypes.VarChar, + }}, + } + result, ok = queryMap[query] + if ok { + return result, nil + } + return nil, fmt.Errorf("query %s not found in mock setup", query) } @@ -74,4 +97,9 @@ func TestColumnList(t *testing.T) { fields, _, err = GetColumns("", "t2", mockExec) require.NoError(t, err) require.Equal(t, `[name:"col1" type:VARCHAR]`, fmt.Sprintf("%+v", fields)) + + fields, _, err = GetColumns("", "with ' quote", mockExec) + require.NoError(t, err) + require.Equal(t, `[name:"col1" type:VARCHAR]`, fmt.Sprintf("%+v", fields)) + }