diff --git a/data/test/vtexplain/insertsharded-output.json b/data/test/vtexplain/insertsharded-output.json index ca2b7a6b455..036a835d15f 100644 --- a/data/test/vtexplain/insertsharded-output.json +++ b/data/test/vtexplain/insertsharded-output.json @@ -320,14 +320,14 @@ } }, { - "Original": "insert ignore into user(id, name, nickname) values (:vtg1, :vtg2, :vtg2)", + "Original": "insert ignore into user(id, name, nickname) values (:vtg1, :vtg2, :vtg3)", "Instructions": { "Opcode": "InsertShardedIgnore", "Keyspace": { "Name": "ks_sharded", "Sharded": true }, - "Query": "insert ignore into user(id, name, nickname) values (:_id0, :_name0, :vtg2)", + "Query": "insert ignore into user(id, name, nickname) values (:_id0, :_name0, :vtg3)", "Values": [ [ ":vtg1" @@ -339,7 +339,7 @@ "Table": "user", "Prefix": "insert ignore into user(id, name, nickname) values ", "Mid": [ - "(:_id0, :_name0, :vtg2)" + "(:_id0, :_name0, :vtg3)" ] } } @@ -348,12 +348,13 @@ "ks_sharded/-80": { "TabletQueries": [ { - "SQL": "insert ignore into user(id, name, nickname) values (:_id0, :_name0, :vtg2) /* vtgate:: keyspace_id:06e7ea22ce92708f */", + "SQL": "insert ignore into user(id, name, nickname) values (:_id0, :_name0, :vtg3) /* vtgate:: keyspace_id:06e7ea22ce92708f */", "BindVars": { "_id0": "2", "_name0": "'bob'", "vtg1": "2", - "vtg2": "'bob'" + "vtg2": "'bob'", + "vtg3": "'bob'" } } ], @@ -562,14 +563,14 @@ } }, { - "Original": "insert into user(id, name, nickname) values (:vtg1, :vtg2, :vtg3) on duplicate key update nickname = :vtg3", + "Original": "insert into user(id, name, nickname) values (:vtg1, :vtg2, :vtg3) on duplicate key update nickname = :vtg4", "Instructions": { "Opcode": "InsertShardedIgnore", "Keyspace": { "Name": "ks_sharded", "Sharded": true }, - "Query": "insert into user(id, name, nickname) values (:_id0, :_name0, :vtg3) on duplicate key update nickname = :vtg3", + "Query": "insert into user(id, name, nickname) values (:_id0, :_name0, :vtg3) on duplicate key update nickname = :vtg4", "Values": [ [ ":vtg1" @@ -583,7 +584,7 @@ "Mid": [ "(:_id0, :_name0, :vtg3)" ], - "Suffix": " on duplicate key update nickname = :vtg3" + "Suffix": " on duplicate key update nickname = :vtg4" } } ], @@ -591,13 +592,14 @@ "ks_sharded/-80": { "TabletQueries": [ { - "SQL": "insert into user(id, name, nickname) values (:_id0, :_name0, :vtg3) on duplicate key update nickname = :vtg3 /* vtgate:: keyspace_id:06e7ea22ce92708f */", + "SQL": "insert into user(id, name, nickname) values (:_id0, :_name0, :vtg3) on duplicate key update nickname = :vtg4 /* vtgate:: keyspace_id:06e7ea22ce92708f */", "BindVars": { "_id0": "2", "_name0": "'bob'", "vtg1": "2", "vtg2": "'bob'", - "vtg3": "'bobby'" + "vtg3": "'bobby'", + "vtg4": "'bobby'" } } ], diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index 5ec2bbc1228..7daf8ca35ad 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -28,78 +28,150 @@ import ( // updates the bind vars to those values. The supplied prefix // is used to generate the bind var names. The function ensures // that there are no collisions with existing bind vars. +// Within Select constructs, bind vars are deduped. This allows +// us to identify vindex equality. Otherwise, every value is +// treated as distinct. func Normalize(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) { - reserved := GetBindvars(stmt) - // vals allows us to reuse bindvars for - // identical values. - counter := 1 - vals := make(map[string]string) - _ = Walk(func(node SQLNode) (kontinue bool, err error) { - switch node := node.(type) { - case *SQLVal: - // Make the bindvar - bval := sqlToBindvar(node) - if bval == nil { - // If unsuccessful continue. - return true, nil - } - // Check if there's a bindvar for that value already. - var key string - if bval.Type == sqltypes.VarBinary { - // Prefixing strings with "'" ensures that a string - // and number that have the same representation don't - // collide. - key = "'" + string(node.Val) - } else { - key = string(node.Val) - } - bvname, ok := vals[key] - if !ok { - // If there's no such bindvar, make a new one. - bvname, counter = newName(prefix, counter, reserved) - vals[key] = bvname - bindVars[bvname] = bval - } - // Modify the AST node to a bindvar. - node.Type = ValArg - node.Val = append([]byte(":"), bvname...) - case *ComparisonExpr: - switch node.Operator { - case InStr, NotInStr: - default: - return true, nil - } - // It's either IN or NOT IN. - tupleVals, ok := node.Right.(ValTuple) - if !ok { - return true, nil - } - // The RHS is a tuple of values. - // Make a list bindvar. - bvals := &querypb.BindVariable{ - Type: querypb.Type_TUPLE, - } - for _, val := range tupleVals { - bval := sqlToBindvar(val) - if bval == nil { - return true, nil - } - bvals.Values = append(bvals.Values, &querypb.Value{ - Type: bval.Type, - Value: bval.Value, - }) - } - var bvname string - bvname, counter = newName(prefix, counter, reserved) - bindVars[bvname] = bvals - // Modify RHS to be a list bindvar. - node.Right = ListArg(append([]byte("::"), bvname...)) + nz := newNormalizer(stmt, bindVars, prefix) + _ = Walk(nz.WalkStatement, stmt) +} + +type normalizer struct { + stmt Statement + bindVars map[string]*querypb.BindVariable + prefix string + reserved map[string]struct{} + counter int + vals map[string]string +} + +func newNormalizer(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) *normalizer { + return &normalizer{ + stmt: stmt, + bindVars: bindVars, + prefix: prefix, + reserved: GetBindvars(stmt), + counter: 1, + vals: make(map[string]string), + } +} + +// WalkStatement is the top level walk function. +// If it encounters a Select, it switches to a mode +// where variables are deduped. +func (nz *normalizer) WalkStatement(node SQLNode) (bool, error) { + switch node := node.(type) { + case *Select: + _ = Walk(nz.WalkSelect, node) + // Don't continue + return false, nil + case *SQLVal: + nz.convertSQLVal(node) + case *ComparisonExpr: + nz.convertComparison(node) + } + return true, nil +} + +// WalkSelect normalizes the AST in Select mode. +func (nz *normalizer) WalkSelect(node SQLNode) (bool, error) { + switch node := node.(type) { + case *SQLVal: + nz.convertSQLValDedup(node) + case *ComparisonExpr: + nz.convertComparison(node) + } + return true, nil +} + +func (nz *normalizer) convertSQLValDedup(node *SQLVal) { + // If value is too long, don't dedup. + // Such values are most likely not for vindexes. + // We save a lot of CPU because we avoid building + // the key for them. + if len(node.Val) > 256 { + nz.convertSQLVal(node) + return + } + + // Make the bindvar + bval := nz.sqlToBindvar(node) + if bval == nil { + return + } + + // Check if there's a bindvar for that value already. + var key string + if bval.Type == sqltypes.VarBinary { + // Prefixing strings with "'" ensures that a string + // and number that have the same representation don't + // collide. + key = "'" + string(node.Val) + } else { + key = string(node.Val) + } + bvname, ok := nz.vals[key] + if !ok { + // If there's no such bindvar, make a new one. + bvname = nz.newName() + nz.vals[key] = bvname + nz.bindVars[bvname] = bval + } + + // Modify the AST node to a bindvar. + node.Type = ValArg + node.Val = append([]byte(":"), bvname...) +} + +// convertSQLVal converts an SQLVal without the dedup. +func (nz *normalizer) convertSQLVal(node *SQLVal) { + bval := nz.sqlToBindvar(node) + if bval == nil { + return + } + + bvname := nz.newName() + nz.bindVars[bvname] = bval + + node.Type = ValArg + node.Val = append([]byte(":"), bvname...) +} + +// convertComparison attempts to convert IN clauses to +// use the list bind var construct. If it fails, it returns +// with no change made. The walk function will then continue +// and iterate on converting each individual value into separate +// bind vars. +func (nz *normalizer) convertComparison(node *ComparisonExpr) { + if node.Operator != InStr && node.Operator != NotInStr { + return + } + tupleVals, ok := node.Right.(ValTuple) + if !ok { + return + } + // The RHS is a tuple of values. + // Make a list bindvar. + bvals := &querypb.BindVariable{ + Type: querypb.Type_TUPLE, + } + for _, val := range tupleVals { + bval := nz.sqlToBindvar(val) + if bval == nil { + return } - return true, nil - }, stmt) + bvals.Values = append(bvals.Values, &querypb.Value{ + Type: bval.Type, + Value: bval.Value, + }) + } + bvname := nz.newName() + nz.bindVars[bvname] = bvals + // Modify RHS to be a list bindvar. + node.Right = ListArg(append([]byte("::"), bvname...)) } -func sqlToBindvar(node SQLNode) *querypb.BindVariable { +func (nz *normalizer) sqlToBindvar(node SQLNode) *querypb.BindVariable { if node, ok := node.(*SQLVal); ok { var v sqltypes.Value var err error @@ -121,14 +193,14 @@ func sqlToBindvar(node SQLNode) *querypb.BindVariable { return nil } -func newName(prefix string, counter int, reserved map[string]struct{}) (string, int) { +func (nz *normalizer) newName() string { for { - newName := fmt.Sprintf("%s%d", prefix, counter) - if _, ok := reserved[newName]; !ok { - reserved[newName] = struct{}{} - return newName, counter + 1 + newName := fmt.Sprintf("%s%d", nz.prefix, nz.counter) + if _, ok := nz.reserved[newName]; !ok { + nz.reserved[newName] = struct{}{} + return newName } - counter++ + nz.counter++ } } diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 07b68168ebd..9f6742c61e0 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -17,6 +17,7 @@ limitations under the License. package sqlparser import ( + "fmt" "reflect" "testing" @@ -81,6 +82,56 @@ func TestNormalize(t *testing.T) { "bv1": sqltypes.Int64BindVariable(1), "bv2": sqltypes.BytesBindVariable([]byte("1")), }, + }, { + // val should not be reused for non-select statements + in: "insert into a values(1, 1)", + outstmt: "insert into a values (:bv1, :bv2)", + outbv: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(1), + "bv2": sqltypes.Int64BindVariable(1), + }, + }, { + // val should be reused only in subqueries of DMLs + in: "update a set v1=(select 5 from t), v2=5, v3=(select 5 from t), v4=5", + outstmt: "update a set v1 = (select :bv1 from t), v2 = :bv2, v3 = (select :bv1 from t), v4 = :bv3", + outbv: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(5), + "bv2": sqltypes.Int64BindVariable(5), + "bv3": sqltypes.Int64BindVariable(5), + }, + }, { + // list vars should work for DMLs also + in: "update a set v1=5 where v2 in (1, 4, 5)", + outstmt: "update a set v1 = :bv1 where v2 in ::bv2", + outbv: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(5), + "bv2": sqltypes.TestBindVariable([]interface{}{1, 4, 5}), + }, + }, { + // Hex value does not convert + in: "select * from t where v1 = 0x1234", + outstmt: "select * from t where v1 = 0x1234", + outbv: map[string]*querypb.BindVariable{}, + }, { + // Hex value does not convert for DMLs + in: "update a set v1 = 0x1234", + outstmt: "update a set v1 = 0x1234", + outbv: map[string]*querypb.BindVariable{}, + }, { + // Values up to len 256 will reuse. + in: fmt.Sprintf("select * from t where v1 = '%256s' and v2 = '%256s'", "a", "a"), + outstmt: "select * from t where v1 = :bv1 and v2 = :bv1", + outbv: map[string]*querypb.BindVariable{ + "bv1": sqltypes.BytesBindVariable([]byte(fmt.Sprintf("%256s", "a"))), + }, + }, { + // Values greater than len 256 will not reuse. + in: fmt.Sprintf("select * from t where v1 = '%257s' and v2 = '%257s'", "b", "b"), + outstmt: "select * from t where v1 = :bv1 and v2 = :bv2", + outbv: map[string]*querypb.BindVariable{ + "bv1": sqltypes.BytesBindVariable([]byte(fmt.Sprintf("%257s", "b"))), + "bv2": sqltypes.BytesBindVariable([]byte(fmt.Sprintf("%257s", "b"))), + }, }, { // bad int in: "select * from t where v1 = 12345678901234567890",