From 6589103ed58013d57653fe6c13cdc0362182ae77 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 16 Jun 2020 11:54:08 +0530 Subject: [PATCH] backport #6308 to 6.0 release branch Signed-off-by: Harshit Gangal --- .../endtoend/vtgate/setstatement/udv_test.go | 7 +- go/vt/sqlparser/expression_rewriting.go | 90 ++++++++++--------- go/vt/sqlparser/expression_rewriting_test.go | 16 ++-- go/vt/vtgate/evalengine/expressions.go | 2 + go/vt/vtgate/executor.go | 13 ++- go/vt/vtgate/executor_select_test.go | 15 +++- 6 files changed, 88 insertions(+), 55 deletions(-) diff --git a/go/test/endtoend/vtgate/setstatement/udv_test.go b/go/test/endtoend/vtgate/setstatement/udv_test.go index 19ae64c7bb0..46dbf2e4e3f 100644 --- a/go/test/endtoend/vtgate/setstatement/udv_test.go +++ b/go/test/endtoend/vtgate/setstatement/udv_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/test/utils" "github.com/google/go-cmp/cmp" @@ -44,6 +46,9 @@ func TestSetUDV(t *testing.T) { } queries := []queriesWithExpectations{{ + query: "select @foo", + expectedRows: "[[NULL]]", rowsAffected: 1, + }, { query: "set @foo = 'abc', @bar = 42, @baz = 30.5, @tablet = concat('foo','bar')", expectedRows: "", rowsAffected: 0, }, { @@ -92,7 +97,7 @@ func TestSetUDV(t *testing.T) { t.Run(fmt.Sprintf("%d-%s", i, q.query), func(t *testing.T) { qr, err := exec(t, conn, q.query) require.NoError(t, err) - require.Equal(t, uint64(q.rowsAffected), qr.RowsAffected, "rows affected wrong for query: %s", q.query) + assert.Equal(t, uint64(q.rowsAffected), qr.RowsAffected, "rows affected wrong for query: %s", q.query) if q.expectedRows != "" { result := fmt.Sprintf("%v", qr.Rows) if diff := cmp.Diff(q.expectedRows, result); diff != "" { diff --git a/go/vt/sqlparser/expression_rewriting.go b/go/vt/sqlparser/expression_rewriting.go index 514cafb4636..46e2bdcecd3 100644 --- a/go/vt/sqlparser/expression_rewriting.go +++ b/go/vt/sqlparser/expression_rewriting.go @@ -17,6 +17,8 @@ limitations under the License. package sqlparser import ( + "strings" + "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/vtrpc" @@ -36,7 +38,7 @@ type BindVarNeeds struct { NeedLastInsertID bool NeedDatabase bool NeedFoundRows bool - NeedUserDefinedVariables bool + NeedUserDefinedVariables []string } // RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries @@ -55,19 +57,18 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) { r := &RewriteASTResult{ AST: out, } - if _, ok := er.bindVars[LastInsertIDName]; ok { - r.NeedLastInsertID = true - } - if _, ok := er.bindVars[DBVarName]; ok { - r.NeedDatabase = true - } - if _, ok := er.bindVars[FoundRowsName]; ok { - r.NeedFoundRows = true - } - if _, ok := er.bindVars[UserDefinedVariableName]; ok { - r.NeedUserDefinedVariables = true + for k := range er.bindVars { + switch k { + case LastInsertIDName: + r.NeedLastInsertID = true + case DBVarName: + r.NeedDatabase = true + case FoundRowsName: + r.NeedFoundRows = true + default: + r.NeedUserDefinedVariables = append(r.NeedUserDefinedVariables, k) + } } - return r, nil } @@ -145,41 +146,48 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool { return false } case *FuncExpr: - switch { - // last_insert_id() -> :__lastInsertId - case node.Name.EqualString("last_insert_id"): - if len(node.Exprs) > 0 { //last_insert_id(x) - er.err = vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported") - } else { - cursor.Replace(bindVarExpression(LastInsertIDName)) - er.needBindVarFor(LastInsertIDName) - } - case er.shouldRewriteDatabaseFunc && - (node.Name.EqualString("database") || - node.Name.EqualString("schema")): - if len(node.Exprs) > 0 { - er.err = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. %s() takes no arguments", node.Name.String()) - } else { - cursor.Replace(bindVarExpression(DBVarName)) - er.needBindVarFor(DBVarName) - } - case node.Name.EqualString("found_rows"): - if len(node.Exprs) > 0 { - er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Arguments to FOUND_ROWS() not supported") - } else { - cursor.Replace(bindVarExpression(FoundRowsName)) - er.needBindVarFor(FoundRowsName) - } - } + er.funcRewrite(cursor, node) case *ColName: if node.Name.at == SingleAt { - cursor.Replace(bindVarExpression(UserDefinedVariableName + node.Name.CompliantName())) - er.needBindVarFor(UserDefinedVariableName) + udv := strings.ToLower(node.Name.CompliantName()) + cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) + er.needBindVarFor(udv) } } return true } +func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { + switch { + // last_insert_id() -> :__lastInsertId + case node.Name.EqualString("last_insert_id"): + if len(node.Exprs) > 0 { //last_insert_id(x) + er.err = vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported") + } else { + cursor.Replace(bindVarExpression(LastInsertIDName)) + er.needBindVarFor(LastInsertIDName) + } + // database() -> :__vtdbname + case er.shouldRewriteDatabaseFunc && + (node.Name.EqualString("database") || + node.Name.EqualString("schema")): + if len(node.Exprs) > 0 { + er.err = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. %s() takes no arguments", node.Name.String()) + } else { + cursor.Replace(bindVarExpression(DBVarName)) + er.needBindVarFor(DBVarName) + } + // found_rows() -> :__vtfrows + case node.Name.EqualString("found_rows"): + if len(node.Exprs) > 0 { + er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Arguments to FOUND_ROWS() not supported") + } else { + cursor.Replace(bindVarExpression(FoundRowsName)) + er.needBindVarFor(FoundRowsName) + } + } +} + // instead of creating new objects, we'll reuse this one var token = struct{}{} diff --git a/go/vt/sqlparser/expression_rewriting_test.go b/go/vt/sqlparser/expression_rewriting_test.go index 4989675f926..2539354d46c 100644 --- a/go/vt/sqlparser/expression_rewriting_test.go +++ b/go/vt/sqlparser/expression_rewriting_test.go @@ -23,8 +23,9 @@ import ( ) type myTestCase struct { - in, expected string - liid, db, foundRows, udv bool + in, expected string + liid, db, foundRows bool + udv int } func TestRewrites(in *testing.T) { @@ -97,17 +98,17 @@ func TestRewrites(in *testing.T) { { in: "select @`x y`", expected: "select :__vtudvx_y as `@``x y``` from dual", - udv: true, + udv: 1, }, { - in: "select id from t where id = @x", - expected: "select id from t where id = :__vtudvx", - db: false, udv: true, + in: "select id from t where id = @x and val = @y", + expected: "select id from t where id = :__vtudvx and val = :__vtudvy", + db: false, udv: 2, }, { in: "insert into t(id) values(@xyx)", expected: "insert into t(id) values(:__vtudvxyx)", - db: false, udv: true, + db: false, udv: 1, }, } @@ -127,6 +128,7 @@ func TestRewrites(in *testing.T) { require.Equal(t, tc.liid, result.NeedLastInsertID, "should need last insert id") require.Equal(t, tc.db, result.NeedDatabase, "should need database name") require.Equal(t, tc.foundRows, result.NeedFoundRows, "should need found rows") + require.Equal(t, tc.udv, len(result.NeedUserDefinedVariables), "should need row count") }) } } diff --git a/go/vt/vtgate/evalengine/expressions.go b/go/vt/vtgate/evalengine/expressions.go index 9581dd9067e..9679a240312 100644 --- a/go/vt/vtgate/evalengine/expressions.go +++ b/go/vt/vtgate/evalengine/expressions.go @@ -287,6 +287,8 @@ func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { return evalResult{typ: sqltypes.Float64, fval: fval}, nil case sqltypes.VarChar, sqltypes.Text, sqltypes.VarBinary: return evalResult{typ: sqltypes.VarBinary, bytes: val.Value}, nil + case sqltypes.Null: + return evalResult{typ: sqltypes.Null}, nil } return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported: %s", val.Type.String()) } diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 6b74c39d77b..51d9ca0b6cc 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -242,11 +242,16 @@ func (e *Executor) addNeededBindVars(bindVarNeeds sqlparser.BindVarNeeds, bindVa bindVars[sqlparser.LastInsertIDName] = sqltypes.Uint64BindVariable(session.GetLastInsertId()) } - // todo: do we need to check this map for nil? - if bindVarNeeds.NeedUserDefinedVariables && session.UserDefinedVariables != nil { - for k, v := range session.UserDefinedVariables { - bindVars[sqlparser.UserDefinedVariableName+k] = v + udvMap := session.UserDefinedVariables + if udvMap == nil { + udvMap = map[string]*querypb.BindVariable{} + } + for _, udv := range bindVarNeeds.NeedUserDefinedVariables { + val := udvMap[udv] + if val == nil { + val = sqltypes.NullBindVariable } + bindVars[sqlparser.UserDefinedVariableName+udv] = val } if bindVarNeeds.NeedFoundRows { diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index c8b90bd2ce3..28f7af2fe5e 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -249,10 +249,22 @@ func TestSelectUserDefinedVariable(t *testing.T) { defer QueryLogger.Unsubscribe(logChan) sql := "select @foo" - masterSession = &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{"bar"})} result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "@foo", Type: sqltypes.Null}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NULL, + }}, + } + utils.MustMatch(t, result, wantResult, "Mismatch") + + masterSession = &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{"bar"})} + result, err = executorExec(executor, sql, map[string]*querypb.BindVariable{}) + require.NoError(t, err) + wantResult = &sqltypes.Result{ Fields: []*querypb.Field{ {Name: "@foo", Type: sqltypes.VarBinary}, }, @@ -260,7 +272,6 @@ func TestSelectUserDefinedVariable(t *testing.T) { sqltypes.NewVarBinary("bar"), }}, } - require.NoError(t, err) utils.MustMatch(t, result, wantResult, "Mismatch") }