From 3b0ccd0a02765b1fc8184fd346891bd86b5e2a3b Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Mon, 20 Mar 2023 10:35:10 +0200 Subject: [PATCH] Fix bug in vtexplain around JOINs (#12383) Signed-off-by: Andres Taylor Co-authored-by: Andres Taylor --- .../multi-output/unsharded-output.txt | 5 + go/vt/vtexplain/testdata/test-schema.sql | 6 + go/vt/vtexplain/testdata/test-vschema.json | 1 + .../vtexplain/testdata/unsharded-queries.sql | 1 + go/vt/vtexplain/vtexplain_vttablet.go | 353 ++++++++++-------- 5 files changed, 205 insertions(+), 161 deletions(-) diff --git a/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt b/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt index 8bb3ecef970..0adc5661077 100644 --- a/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt +++ b/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt @@ -45,3 +45,8 @@ select ID from t1 1 ks_unsharded/-: select ID from t1 limit 10001 ---------------------------------------------------------------------- +select t1.id, t2.c2 from t1 join t2 on t1.id = t2.t1_id where t2.c2 in (1) + +1 ks_unsharded/-: select t1.id, t2.c2 from t1 join t2 on t1.id = t2.t1_id where t2.c2 in (1) limit 10001 + +---------------------------------------------------------------------- diff --git a/go/vt/vtexplain/testdata/test-schema.sql b/go/vt/vtexplain/testdata/test-schema.sql index 5b65334796f..d5f9fbad56f 100644 --- a/go/vt/vtexplain/testdata/test-schema.sql +++ b/go/vt/vtexplain/testdata/test-schema.sql @@ -4,6 +4,12 @@ create table t1 ( floatval float not null default 0, primary key (id) ); +create table t2 ( + id bigint(20) unsigned not null, + t1_id bigint(20) unsigned not null default 0, + c2 bigint(20) null, + primary key (id) +); create table user ( id bigint, diff --git a/go/vt/vtexplain/testdata/test-vschema.json b/go/vt/vtexplain/testdata/test-vschema.json index a50e11e92ae..f4350efa56d 100644 --- a/go/vt/vtexplain/testdata/test-vschema.json +++ b/go/vt/vtexplain/testdata/test-vschema.json @@ -3,6 +3,7 @@ "sharded": false, "tables": { "t1": {}, + "t2": {}, "table_not_in_schema": {} } }, diff --git a/go/vt/vtexplain/testdata/unsharded-queries.sql b/go/vt/vtexplain/testdata/unsharded-queries.sql index f0147ac5d6e..712245f3338 100644 --- a/go/vt/vtexplain/testdata/unsharded-queries.sql +++ b/go/vt/vtexplain/testdata/unsharded-queries.sql @@ -6,3 +6,4 @@ update t1 set floatval = 9.99; delete from t1 where id = 100; insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update intval=3, floatval=3.14; select ID from t1; +select t1.id, t2.c2 from t1 join t2 on t1.id = t2.t1_id where t2.c2 in (1); \ No newline at end of file diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index ee94946e5c1..05932f785a0 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -504,199 +504,230 @@ func (t *explainTablet) HandleQuery(c *mysql.Conn, query string, callback func(* } switch sqlparser.Preview(query) { case sqlparser.StmtSelect: - // Parse the select statement to figure out the table and columns - // that were referenced so that the synthetic response has the - // expected field names and types. - stmt, err := sqlparser.Parse(query) + var err error + result, err = t.handleSelect(query) if err != nil { return err } - - var selStmt *sqlparser.Select - switch stmt := stmt.(type) { - case *sqlparser.Select: - selStmt = stmt - case *sqlparser.Union: - selStmt = sqlparser.GetFirstSelect(stmt) - default: - return fmt.Errorf("vtexplain: unsupported statement type +%v", reflect.TypeOf(stmt)) + case sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtSet, + sqlparser.StmtSavepoint, sqlparser.StmtSRollback, sqlparser.StmtRelease: + result = &sqltypes.Result{} + case sqlparser.StmtShow: + result = &sqltypes.Result{Fields: sqltypes.MakeTestFields("", "")} + case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete: + result = &sqltypes.Result{ + RowsAffected: 1, } + default: + return fmt.Errorf("unsupported query %s", query) + } - // Gen4 supports more complex queries so we now need to - // handle multiple FROM clauses - tables := make([]*sqlparser.AliasedTableExpr, len(selStmt.From)) - for _, from := range selStmt.From { - tables = append(tables, getTables(from)...) - } + return callback(result) +} - tableColumnMap := map[sqlparser.IdentifierCS]map[string]querypb.Type{} - for _, table := range tables { - if table == nil { - continue - } +func (t *explainTablet) handleSelect(query string) (*sqltypes.Result, error) { + // Parse the select statement to figure out the table and columns + // that were referenced so that the synthetic response has the + // expected field names and types. + stmt, err := sqlparser.Parse(query) + if err != nil { + return nil, err + } - tableName := sqlparser.String(sqlparser.GetTableName(table.Expr)) - columns, exists := t.vte.getGlobalTabletEnv().tableColumns[tableName] - if !exists && tableName != "" && tableName != "dual" { - return fmt.Errorf("unable to resolve table name %s", tableName) - } + var selStmt *sqlparser.Select + switch stmt := stmt.(type) { + case *sqlparser.Select: + selStmt = stmt + case *sqlparser.Union: + selStmt = sqlparser.GetFirstSelect(stmt) + default: + return nil, fmt.Errorf("vtexplain: unsupported statement type +%v", reflect.TypeOf(stmt)) + } - colTypeMap := map[string]querypb.Type{} + // Gen4 supports more complex queries so we now need to + // handle multiple FROM clauses + tables := make([]*sqlparser.AliasedTableExpr, len(selStmt.From)) + for _, from := range selStmt.From { + tables = append(tables, getTables(from)...) + } - if table.As.IsEmpty() { - tableColumnMap[sqlparser.GetTableName(table.Expr)] = colTypeMap - } else { - tableColumnMap[table.As] = colTypeMap - } + tableColumnMap := map[sqlparser.IdentifierCS]map[string]querypb.Type{} + for _, table := range tables { + if table == nil { + continue + } - for k, v := range columns { - if colType, exists := colTypeMap[k]; exists { - if colType != v { - return fmt.Errorf("column type mismatch for column : %s, types: %d vs %d", k, colType, v) - } - continue - } - colTypeMap[k] = v - } + tableName := sqlparser.String(sqlparser.GetTableName(table.Expr)) + columns, exists := t.vte.getGlobalTabletEnv().tableColumns[tableName] + if !exists && tableName != "" && tableName != "dual" { + return nil, fmt.Errorf("unable to resolve table name %s", tableName) + } + + colTypeMap := map[string]querypb.Type{} + if table.As.IsEmpty() { + tableColumnMap[sqlparser.GetTableName(table.Expr)] = colTypeMap + } else { + tableColumnMap[table.As] = colTypeMap } - colNames := make([]string, 0, 4) - colTypes := make([]querypb.Type, 0, 4) - for _, node := range selStmt.SelectExprs { - switch node := node.(type) { - case *sqlparser.AliasedExpr: - colNames, colTypes = inferColTypeFromExpr(node.Expr, tableColumnMap, colNames, colTypes) - case *sqlparser.StarExpr: - if node.TableName.Name.IsEmpty() { - // SELECT * - for _, colTypeMap := range tableColumnMap { - for col, colType := range colTypeMap { - colNames = append(colNames, col) - colTypes = append(colTypes, colType) - } - } - } else { - // SELECT tableName.* - colTypeMap := tableColumnMap[node.TableName.Name] - for col, colType := range colTypeMap { - colNames = append(colNames, col) - colTypes = append(colTypes, colType) - } + for k, v := range columns { + if colType, exists := colTypeMap[k]; exists { + if colType != v { + return nil, fmt.Errorf("column type mismatch for column : %s, types: %d vs %d", k, colType, v) } + continue } + colTypeMap[k] = v } - // the query against lookup table is in-query, handle it specifically - var inColName string - inVal := make([]sqltypes.Value, 0, 10) - - rowCount := 1 - if selStmt.Where != nil { - switch v := selStmt.Where.Expr.(type) { - case *sqlparser.ComparisonExpr: - if v.Operator == sqlparser.InOp { - switch c := v.Left.(type) { - case *sqlparser.ColName: - colName := strings.ToLower(c.Name.String()) - colType := tableColumnMap[sqlparser.GetTableName(selStmt.From[0].(*sqlparser.AliasedTableExpr).Expr)][colName] - - switch values := v.Right.(type) { - case sqlparser.ValTuple: - for _, val := range values { - switch v := val.(type) { - case *sqlparser.Literal: - value, err := evalengine.LiteralToValue(v) - if err != nil { - return err - } - - // Cast the value in the tuple to the expected value of the column - castedValue, err := evalengine.Cast(value, colType) - if err != nil { - return err - } - - // Check if we have a duplicate value - isNewValue := true - for _, v := range inVal { - result, err := evalengine.NullsafeCompare(v, value, collations.Default()) - if err != nil { - return err - } - - if result == 0 { - isNewValue = false - break - } - } - - if isNewValue { - inVal = append(inVal, castedValue) - } - } - } - rowCount = len(inVal) - } - inColName = strings.ToLower(c.Name.String()) - } - } - } + } + + colNames, colTypes := t.analyzeExpressions(selStmt, tableColumnMap) + + inColName, inVal, rowCount, s, err := t.analyzeWhere(selStmt, tableColumnMap) + if err != nil { + return s, err + } + + fields := make([]*querypb.Field, len(colNames)) + rows := make([][]sqltypes.Value, 0, rowCount) + for i, col := range colNames { + colType := colTypes[i] + fields[i] = &querypb.Field{ + Name: col, + Type: colType, } + } - fields := make([]*querypb.Field, len(colNames)) - rows := make([][]sqltypes.Value, 0, rowCount) + for j := 0; j < rowCount; j++ { + values := make([]sqltypes.Value, len(colNames)) for i, col := range colNames { + // Generate a fake value for the given column. For the column in the IN clause, + // use the provided values in the query, For numeric types, + // use the column index. For all other types, just shortcut to using + // a string type that encodes the column name + index. colType := colTypes[i] - fields[i] = &querypb.Field{ - Name: col, - Type: colType, + if len(inVal) > j && col == inColName { + values[i], _ = sqltypes.NewValue(querypb.Type_VARBINARY, inVal[j].Raw()) + } else if sqltypes.IsIntegral(colType) { + values[i] = sqltypes.NewInt32(int32(i + 1)) + } else if sqltypes.IsFloat(colType) { + values[i] = sqltypes.NewFloat64(1.0 + float64(i)) + } else { + values[i] = sqltypes.NewVarChar(fmt.Sprintf("%s_val_%d", col, i+1)) } } + rows = append(rows, values) + } + result := &sqltypes.Result{ + Fields: fields, + InsertID: 0, + Rows: rows, + } - for j := 0; j < rowCount; j++ { - values := make([]sqltypes.Value, len(colNames)) - for i, col := range colNames { - // Generate a fake value for the given column. For the column in the IN clause, - // use the provided values in the query, For numeric types, - // use the column index. For all other types, just shortcut to using - // a string type that encodes the column name + index. - colType := colTypes[i] - if len(inVal) > j && col == inColName { - values[i], _ = sqltypes.NewValue(querypb.Type_VARBINARY, inVal[j].Raw()) - } else if sqltypes.IsIntegral(colType) { - values[i] = sqltypes.NewInt32(int32(i + 1)) - } else if sqltypes.IsFloat(colType) { - values[i] = sqltypes.NewFloat64(1.0 + float64(i)) - } else { - values[i] = sqltypes.NewVarChar(fmt.Sprintf("%s_val_%d", col, i+1)) - } + resultJSON, _ := json.MarshalIndent(result, "", " ") + log.V(100).Infof("query %s result %s\n", query, string(resultJSON)) + return result, nil +} + +func (t *explainTablet) analyzeWhere(selStmt *sqlparser.Select, tableColumnMap map[sqlparser.IdentifierCS]map[string]querypb.Type) (inColName string, inVal []sqltypes.Value, rowCount int, result *sqltypes.Result, err error) { + // the query against lookup table is in-query, handle it specifically + rowCount = 1 + if selStmt.Where == nil { + return + } + v, ok := selStmt.Where.Expr.(*sqlparser.ComparisonExpr) + if !ok || v.Operator != sqlparser.InOp { + return + } + c, ok := v.Left.(*sqlparser.ColName) + if !ok { + return + } + colName := strings.ToLower(c.Name.String()) + colType := querypb.Type_VARCHAR + tableExpr := selStmt.From[0] + expr, ok := tableExpr.(*sqlparser.AliasedTableExpr) + if ok { + m := tableColumnMap[sqlparser.GetTableName(expr.Expr)] + if m != nil { + t, found := m[colName] + if found { + colType = t } - rows = append(rows, values) } - result = &sqltypes.Result{ - Fields: fields, - InsertID: 0, - Rows: rows, + } + + values, ok := v.Right.(sqlparser.ValTuple) + if !ok { + return + } + for _, val := range values { + lit, ok := val.(*sqlparser.Literal) + if !ok { + continue + } + value, err := evalengine.LiteralToValue(lit) + if err != nil { + return "", nil, 0, nil, err } - resultJSON, _ := json.MarshalIndent(result, "", " ") - log.V(100).Infof("query %s result %s\n", query, string(resultJSON)) + // Cast the value in the tuple to the expected value of the column + castedValue, err := evalengine.Cast(value, colType) + if err != nil { + return "", nil, 0, nil, err + } - case sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtSet, - sqlparser.StmtSavepoint, sqlparser.StmtSRollback, sqlparser.StmtRelease: - result = &sqltypes.Result{} - case sqlparser.StmtShow: - result = &sqltypes.Result{Fields: sqltypes.MakeTestFields("", "")} - case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete: - result = &sqltypes.Result{ - RowsAffected: 1, + // Check if we have a duplicate value + isNewValue := true + for _, v := range inVal { + result, err := evalengine.NullsafeCompare(v, value, collations.Default()) + if err != nil { + return "", nil, 0, nil, err + } + + if result == 0 { + isNewValue = false + break + } + } + + if isNewValue { + inVal = append(inVal, castedValue) } - default: - return fmt.Errorf("unsupported query %s", query) } + inColName = strings.ToLower(c.Name.String()) + return inColName, inVal, rowCount, nil, nil +} - return callback(result) +func (t *explainTablet) analyzeExpressions(selStmt *sqlparser.Select, tableColumnMap map[sqlparser.IdentifierCS]map[string]querypb.Type) ([]string, []querypb.Type) { + colNames := make([]string, 0, 4) + colTypes := make([]querypb.Type, 0, 4) + for _, node := range selStmt.SelectExprs { + switch node := node.(type) { + case *sqlparser.AliasedExpr: + colNames, colTypes = inferColTypeFromExpr(node.Expr, tableColumnMap, colNames, colTypes) + case *sqlparser.StarExpr: + if node.TableName.Name.IsEmpty() { + // SELECT * + for _, colTypeMap := range tableColumnMap { + for col, colType := range colTypeMap { + colNames = append(colNames, col) + colTypes = append(colTypes, colType) + } + } + } else { + // SELECT tableName.* + colTypeMap := tableColumnMap[node.TableName.Name] + for col, colType := range colTypeMap { + colNames = append(colNames, col) + colTypes = append(colTypes, colType) + } + } + } + } + return colNames, colTypes } func getTables(node sqlparser.SQLNode) []*sqlparser.AliasedTableExpr {