Skip to content

Commit

Permalink
backport #6308 to 6.0 release branch
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal committed Jun 16, 2020
1 parent f7fa695 commit 6589103
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 55 deletions.
7 changes: 6 additions & 1 deletion go/test/endtoend/vtgate/setstatement/udv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"

"vitess.io/vitess/go/test/utils"

"github.com/google/go-cmp/cmp"
Expand All @@ -44,6 +46,9 @@ func TestSetUDV(t *testing.T) {
}

queries := []queriesWithExpectations{{
query: "select @foo",
expectedRows: "[[NULL]]", rowsAffected: 1,
}, {
query: "set @foo = 'abc', @bar = 42, @baz = 30.5, @tablet = concat('foo','bar')",
expectedRows: "", rowsAffected: 0,
}, {
Expand Down Expand Up @@ -92,7 +97,7 @@ func TestSetUDV(t *testing.T) {
t.Run(fmt.Sprintf("%d-%s", i, q.query), func(t *testing.T) {
qr, err := exec(t, conn, q.query)
require.NoError(t, err)
require.Equal(t, uint64(q.rowsAffected), qr.RowsAffected, "rows affected wrong for query: %s", q.query)
assert.Equal(t, uint64(q.rowsAffected), qr.RowsAffected, "rows affected wrong for query: %s", q.query)
if q.expectedRows != "" {
result := fmt.Sprintf("%v", qr.Rows)
if diff := cmp.Diff(q.expectedRows, result); diff != "" {
Expand Down
90 changes: 49 additions & 41 deletions go/vt/sqlparser/expression_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package sqlparser

import (
"strings"

"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
Expand All @@ -36,7 +38,7 @@ type BindVarNeeds struct {
NeedLastInsertID bool
NeedDatabase bool
NeedFoundRows bool
NeedUserDefinedVariables bool
NeedUserDefinedVariables []string
}

// RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries
Expand All @@ -55,19 +57,18 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) {
r := &RewriteASTResult{
AST: out,
}
if _, ok := er.bindVars[LastInsertIDName]; ok {
r.NeedLastInsertID = true
}
if _, ok := er.bindVars[DBVarName]; ok {
r.NeedDatabase = true
}
if _, ok := er.bindVars[FoundRowsName]; ok {
r.NeedFoundRows = true
}
if _, ok := er.bindVars[UserDefinedVariableName]; ok {
r.NeedUserDefinedVariables = true
for k := range er.bindVars {
switch k {
case LastInsertIDName:
r.NeedLastInsertID = true
case DBVarName:
r.NeedDatabase = true
case FoundRowsName:
r.NeedFoundRows = true
default:
r.NeedUserDefinedVariables = append(r.NeedUserDefinedVariables, k)
}
}

return r, nil
}

Expand Down Expand Up @@ -145,41 +146,48 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
return false
}
case *FuncExpr:
switch {
// last_insert_id() -> :__lastInsertId
case node.Name.EqualString("last_insert_id"):
if len(node.Exprs) > 0 { //last_insert_id(x)
er.err = vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported")
} else {
cursor.Replace(bindVarExpression(LastInsertIDName))
er.needBindVarFor(LastInsertIDName)
}
case er.shouldRewriteDatabaseFunc &&
(node.Name.EqualString("database") ||
node.Name.EqualString("schema")):
if len(node.Exprs) > 0 {
er.err = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. %s() takes no arguments", node.Name.String())
} else {
cursor.Replace(bindVarExpression(DBVarName))
er.needBindVarFor(DBVarName)
}
case node.Name.EqualString("found_rows"):
if len(node.Exprs) > 0 {
er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Arguments to FOUND_ROWS() not supported")
} else {
cursor.Replace(bindVarExpression(FoundRowsName))
er.needBindVarFor(FoundRowsName)
}
}
er.funcRewrite(cursor, node)
case *ColName:
if node.Name.at == SingleAt {
cursor.Replace(bindVarExpression(UserDefinedVariableName + node.Name.CompliantName()))
er.needBindVarFor(UserDefinedVariableName)
udv := strings.ToLower(node.Name.CompliantName())
cursor.Replace(bindVarExpression(UserDefinedVariableName + udv))
er.needBindVarFor(udv)
}
}
return true
}

func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) {
switch {
// last_insert_id() -> :__lastInsertId
case node.Name.EqualString("last_insert_id"):
if len(node.Exprs) > 0 { //last_insert_id(x)
er.err = vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported")
} else {
cursor.Replace(bindVarExpression(LastInsertIDName))
er.needBindVarFor(LastInsertIDName)
}
// database() -> :__vtdbname
case er.shouldRewriteDatabaseFunc &&
(node.Name.EqualString("database") ||
node.Name.EqualString("schema")):
if len(node.Exprs) > 0 {
er.err = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. %s() takes no arguments", node.Name.String())
} else {
cursor.Replace(bindVarExpression(DBVarName))
er.needBindVarFor(DBVarName)
}
// found_rows() -> :__vtfrows
case node.Name.EqualString("found_rows"):
if len(node.Exprs) > 0 {
er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Arguments to FOUND_ROWS() not supported")
} else {
cursor.Replace(bindVarExpression(FoundRowsName))
er.needBindVarFor(FoundRowsName)
}
}
}

// instead of creating new objects, we'll reuse this one
var token = struct{}{}

Expand Down
16 changes: 9 additions & 7 deletions go/vt/sqlparser/expression_rewriting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
)

type myTestCase struct {
in, expected string
liid, db, foundRows, udv bool
in, expected string
liid, db, foundRows bool
udv int
}

func TestRewrites(in *testing.T) {
Expand Down Expand Up @@ -97,17 +98,17 @@ func TestRewrites(in *testing.T) {
{
in: "select @`x y`",
expected: "select :__vtudvx_y as `@``x y``` from dual",
udv: true,
udv: 1,
},
{
in: "select id from t where id = @x",
expected: "select id from t where id = :__vtudvx",
db: false, udv: true,
in: "select id from t where id = @x and val = @y",
expected: "select id from t where id = :__vtudvx and val = :__vtudvy",
db: false, udv: 2,
},
{
in: "insert into t(id) values(@xyx)",
expected: "insert into t(id) values(:__vtudvxyx)",
db: false, udv: true,
db: false, udv: 1,
},
}

Expand All @@ -127,6 +128,7 @@ func TestRewrites(in *testing.T) {
require.Equal(t, tc.liid, result.NeedLastInsertID, "should need last insert id")
require.Equal(t, tc.db, result.NeedDatabase, "should need database name")
require.Equal(t, tc.foundRows, result.NeedFoundRows, "should need found rows")
require.Equal(t, tc.udv, len(result.NeedUserDefinedVariables), "should need row count")
})
}
}
2 changes: 2 additions & 0 deletions go/vt/vtgate/evalengine/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ func evaluateByType(val *querypb.BindVariable) (EvalResult, error) {
return evalResult{typ: sqltypes.Float64, fval: fval}, nil
case sqltypes.VarChar, sqltypes.Text, sqltypes.VarBinary:
return evalResult{typ: sqltypes.VarBinary, bytes: val.Value}, nil
case sqltypes.Null:
return evalResult{typ: sqltypes.Null}, nil
}
return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported: %s", val.Type.String())
}
13 changes: 9 additions & 4 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,16 @@ func (e *Executor) addNeededBindVars(bindVarNeeds sqlparser.BindVarNeeds, bindVa
bindVars[sqlparser.LastInsertIDName] = sqltypes.Uint64BindVariable(session.GetLastInsertId())
}

// todo: do we need to check this map for nil?
if bindVarNeeds.NeedUserDefinedVariables && session.UserDefinedVariables != nil {
for k, v := range session.UserDefinedVariables {
bindVars[sqlparser.UserDefinedVariableName+k] = v
udvMap := session.UserDefinedVariables
if udvMap == nil {
udvMap = map[string]*querypb.BindVariable{}
}
for _, udv := range bindVarNeeds.NeedUserDefinedVariables {
val := udvMap[udv]
if val == nil {
val = sqltypes.NullBindVariable
}
bindVars[sqlparser.UserDefinedVariableName+udv] = val
}

if bindVarNeeds.NeedFoundRows {
Expand Down
15 changes: 13 additions & 2 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,18 +249,29 @@ func TestSelectUserDefinedVariable(t *testing.T) {
defer QueryLogger.Unsubscribe(logChan)

sql := "select @foo"
masterSession = &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{"bar"})}
result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{})
require.NoError(t, err)
wantResult := &sqltypes.Result{
Fields: []*querypb.Field{
{Name: "@foo", Type: sqltypes.Null},
},
Rows: [][]sqltypes.Value{{
sqltypes.NULL,
}},
}
utils.MustMatch(t, result, wantResult, "Mismatch")

masterSession = &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{"bar"})}
result, err = executorExec(executor, sql, map[string]*querypb.BindVariable{})
require.NoError(t, err)
wantResult = &sqltypes.Result{
Fields: []*querypb.Field{
{Name: "@foo", Type: sqltypes.VarBinary},
},
Rows: [][]sqltypes.Value{{
sqltypes.NewVarBinary("bar"),
}},
}
require.NoError(t, err)
utils.MustMatch(t, result, wantResult, "Mismatch")
}

Expand Down

0 comments on commit 6589103

Please sign in to comment.