diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index a9d70d5e190..3e482c64aab 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -48,7 +48,7 @@ func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*quer type normalizer struct { bindVars map[string]*querypb.BindVariable reserved *ReservedVars - vals map[string]string + vals map[Literal]string err error inDerived bool } @@ -57,7 +57,7 @@ func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVari return &normalizer{ bindVars: bindVars, reserved: reserved, - vals: make(map[string]string), + vals: make(map[Literal]string), } } @@ -200,12 +200,11 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { } // Check if there's a bindvar for that value already. - key := keyFor(bval, node) - bvname, ok := nz.vals[key] + bvname, ok := nz.vals[*node] if !ok { // If there's no such bindvar, make a new one. bvname = nz.reserved.nextUnusedVar() - nz.vals[key] = bvname + nz.vals[*node] = bvname nz.bindVars[bvname] = bval } @@ -213,17 +212,6 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { cursor.Replace(NewTypedArgument(bvname, node.SQLType())) } -func keyFor(bval *querypb.BindVariable, lit *Literal) string { - if bval.Type != sqltypes.VarBinary && bval.Type != sqltypes.VarChar { - return lit.Val - } - - // Prefixing strings with "'" ensures that a string - // and number that have the same representation don't - // collide. - return "'" + lit.Val -} - // convertLiteral converts an Literal without the dedup. func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) { err := validateLiteral(node) @@ -281,15 +269,14 @@ func (nz *normalizer) parameterize(left, right Expr) Expr { if bval == nil { return nil } - key := keyFor(bval, lit) - bvname := nz.decideBindVarName(key, lit, col, bval) + bvname := nz.decideBindVarName(lit, col, bval) return NewTypedArgument(bvname, lit.SQLType()) } -func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName, bval *querypb.BindVariable) string { +func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *querypb.BindVariable) string { if len(lit.Val) <= 256 { // first we check if we already have a bindvar for this value. if we do, we re-use that bindvar name - bvname, ok := nz.vals[key] + bvname, ok := nz.vals[*lit] if ok { return bvname } @@ -299,7 +286,7 @@ func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName, // Big values are most likely not for vindexes. // We save a lot of CPU because we avoid building bvname := nz.reserved.ReserveColName(col) - nz.vals[key] = bvname + nz.vals[*lit] = bvname nz.bindVars[bvname] = bval return bvname diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 2b0a4b52122..58580a78701 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -371,6 +371,14 @@ func TestNormalize(t *testing.T) { in: `select * from (select 12) as t`, outstmt: `select * from (select 12 from dual) as t`, outbv: map[string]*querypb.BindVariable{}, + }, { + // HexVal and Int should not share a bindvar just because they have the same value + in: `select * from t where v1 = x'31' and v2 = 31`, + outstmt: `select * from t where v1 = :v1 /* HEXVAL */ and v2 = :v2 /* INT64 */`, + outbv: map[string]*querypb.BindVariable{ + "v1": sqltypes.HexValBindVariable([]byte("x'31'")), + "v2": sqltypes.Int64BindVariable(31), + }, }} for _, tc := range testcases { t.Run(tc.in, func(t *testing.T) {