Skip to content

Commit

Permalink
Merge pull request #7271 from planetscale/fix-7254-in-9
Browse files Browse the repository at this point in the history
[9.0] don't try to compare varchars in vtgate
  • Loading branch information
systay authored Jan 8, 2021
2 parents 283497f + 3503ef2 commit d86632c
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 73 deletions.
12 changes: 12 additions & 0 deletions go/test/endtoend/vtgate/lookup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,15 @@ func TestSelectNullLookup(t *testing.T) {
})
}
}

func TestUnicodeLooseMD5CaseInsensitive(t *testing.T) {
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
require.NoError(t, err)
defer conn.Close()

exec(t, conn, "insert into t4(id1, id2) values(1, 'test')")
defer exec(t, conn, "delete from t4")

assertMatches(t, conn, "SELECT id1, id2 from t4 where id2 = 'Test'", `[[INT64(1) VARCHAR("test")]]`)
}
4 changes: 2 additions & 2 deletions go/test/endtoend/vtgate/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ create table t4(
id1 bigint,
id2 varchar(10),
primary key(id1)
) Engine=InnoDB;
) ENGINE=InnoDB DEFAULT charset=utf8mb4 COLLATE=utf8mb4_general_ci;
create table t4_id2_idx(
id2 varchar(10),
id1 bigint,
keyspace_id varbinary(50),
primary key(id2, id1)
) Engine=InnoDB;
) Engine=InnoDB DEFAULT charset=utf8mb4 COLLATE=utf8mb4_general_ci;
create table t5_null_vindex(
id bigint not null,
Expand Down
6 changes: 3 additions & 3 deletions go/test/utils/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,17 @@ import (
// In Test*() function:
//
// mustMatch(t, want, got, "something doesn't match")
func MustMatchFn(allowUnexportedTypes []interface{}, ignoredFields []string, extraOpts ...cmp.Option) func(t *testing.T, want, got interface{}, errMsg string) {
func MustMatchFn(allowUnexportedTypes []interface{}, ignoredFields []string, extraOpts ...cmp.Option) func(t *testing.T, want, got interface{}, errMsg ...string) {
diffOpts := append([]cmp.Option{
cmp.AllowUnexported(allowUnexportedTypes...),
cmpIgnoreFields(ignoredFields...),
}, extraOpts...)
// Diffs want/got and fails with errMsg on any failure.
return func(t *testing.T, want, got interface{}, errMsg string) {
return func(t *testing.T, want, got interface{}, errMsg ...string) {
t.Helper()
diff := cmp.Diff(want, got, diffOpts...)
if diff != "" {
t.Fatalf("%s: (-want +got)\n%v", errMsg, diff)
t.Fatalf("%v: (-want +got)\n%v", errMsg, diff)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ select count(*) from user where id = 1 /* point aggregate */
select count(*) from user where name in ('alice','bob') /* scatter aggregate */

1 ks_sharded/40-80: select `name`, user_id from name_user_map where `name` in ('alice') limit 10001 /* scatter aggregate */
1 ks_sharded/c0-: select `name`, user_id from name_user_map where `name` in ('bob') limit 10001 /* scatter aggregate */
2 ks_sharded/-40: select count(*) from user where `name` in ('alice', 'bob') limit 10001 /* scatter aggregate */
2 ks_sharded/c0-: select `name`, user_id from name_user_map where `name` in ('bob') limit 10001 /* scatter aggregate */
3 ks_sharded/-40: select count(*) from user where `name` in ('alice', 'bob') limit 10001 /* scatter aggregate */

----------------------------------------------------------------------
select name, count(*) from user group by name /* scatter aggregate */
Expand Down
25 changes: 9 additions & 16 deletions go/vt/vtgate/executor_framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/require"

"github.com/stretchr/testify/assert"

"context"

"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -546,19 +550,14 @@ func testQueryLog(t *testing.T, logChan chan interface{}, method, stmtType, sql
t.Helper()

logStats := getQueryLog(logChan)
if logStats == nil {
t.Errorf("logstats: no querylog in channel, want sql %s", sql)
return nil
}
require.NotNil(t, logStats)

var log bytes.Buffer
streamlog.GetFormatter(QueryLogger)(&log, nil, logStats)
fields := strings.Split(log.String(), "\t")

// fields[0] is the method
if method != fields[0] {
t.Errorf("logstats: method want %q got %q", method, fields[0])
}
assert.Equal(t, method, fields[0], "logstats: method")

// fields[1] - fields[6] are the caller id, start/end times, etc

Expand Down Expand Up @@ -588,22 +587,16 @@ func testQueryLog(t *testing.T, logChan chan interface{}, method, stmtType, sql
}

// fields[11] is the statement type
if stmtType != fields[11] {
t.Errorf("logstats: stmtType want %q got %q", stmtType, fields[11])
}
assert.Equal(t, stmtType, fields[11], "logstats: stmtType")

// fields[12] is the original sql
wantSQL := fmt.Sprintf("%q", sql)
if wantSQL != fields[12] {
t.Errorf("logstats: SQL want %s got %s", wantSQL, fields[12])
}
assert.Equal(t, wantSQL, fields[12], "logstats: SQL")

// fields[13] contains the formatted bind vars

// fields[14] is the count of shard queries
if fmt.Sprintf("%v", shardQueries) != fields[14] {
t.Errorf("logstats: ShardQueries want %v got %v", shardQueries, fields[14])
}
assert.Equal(t, fmt.Sprintf("%v", shardQueries), fields[14], "logstats: ShardQueries")

return logStats
}
Expand Down
39 changes: 13 additions & 26 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,10 +568,8 @@ func TestSelectBindvars(t *testing.T) {
lookup.SetResults([]*sqltypes.Result{sqltypes.MakeTestResult(
sqltypes.MakeTestFields("b|a", "varbinary|varbinary"),
"foo1|1",
"foo2|1",
), sqltypes.MakeTestResult(
sqltypes.MakeTestFields("b|a", "varbinary|varbinary"),
"foo1|1",
"foo2|1",
)})

Expand All @@ -584,12 +582,8 @@ func TestSelectBindvars(t *testing.T) {
Sql: "select id from user where id = :id",
BindVariables: map[string]*querypb.BindVariable{"id": sqltypes.Int64BindVariable(1)},
}}
if !reflect.DeepEqual(sbc1.Queries, wantQueries) {
t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries)
}
if sbc2.Queries != nil {
t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries)
}
utils.MustMatch(t, sbc1.Queries, wantQueries)
assert.Empty(t, sbc2.Queries)
sbc1.Queries = nil
testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1)

Expand All @@ -608,11 +602,10 @@ func TestSelectBindvars(t *testing.T) {
"__vals": sqltypes.TestBindVariable([]interface{}{"foo1", "foo2"}),
},
}}
if !reflect.DeepEqual(sbc1.Queries, wantQueries) {
t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries)
}
utils.MustMatch(t, wantQueries, sbc1.Queries)
sbc1.Queries = nil
testQueryLog(t, logChan, "VindexLookup", "SELECT", "select name, user_id from name_user_map where name in ::name", 1)
testQueryLog(t, logChan, "VindexLookup", "SELECT", "select name, user_id from name_user_map where name in ::name", 1)
testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1)

// Test with BytesBindVariable
Expand All @@ -623,17 +616,14 @@ func TestSelectBindvars(t *testing.T) {
})
require.NoError(t, err)
wantQueries = []*querypb.BoundQuery{{
Sql: "select id from user where `name` in ::__vals",
Sql: "select id from user where 1 != 1",
BindVariables: map[string]*querypb.BindVariable{
"name1": sqltypes.BytesBindVariable([]byte("foo1")),
"name2": sqltypes.BytesBindVariable([]byte("foo2")),
"__vals": sqltypes.TestBindVariable([]interface{}{[]byte("foo1"), []byte("foo2")}),
"name1": sqltypes.BytesBindVariable([]byte("foo1")),
"name2": sqltypes.BytesBindVariable([]byte("foo2")),
},
}}
if !reflect.DeepEqual(sbc1.Queries, wantQueries) {
t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries)
}

utils.MustMatch(t, wantQueries, sbc1.Queries)
testQueryLog(t, logChan, "VindexLookup", "SELECT", "select name, user_id from name_user_map where name in ::name", 1)
testQueryLog(t, logChan, "VindexLookup", "SELECT", "select name, user_id from name_user_map where name in ::name", 1)
testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1)

Expand Down Expand Up @@ -662,9 +652,7 @@ func TestSelectBindvars(t *testing.T) {
"name": sqltypes.StringBindVariable("nonexistent"),
},
}}
if !reflect.DeepEqual(sbc1.Queries, wantQueries) {
t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries)
}
utils.MustMatch(t, wantQueries, sbc1.Queries)

vars, err := sqltypes.BuildBindVariable([]interface{}{sqltypes.NewVarBinary("nonexistent")})
require.NoError(t, err)
Expand All @@ -674,9 +662,8 @@ func TestSelectBindvars(t *testing.T) {
"name": vars,
},
}}
if !reflect.DeepEqual(lookup.Queries, wantLookupQueries) {
t.Errorf("lookup.Queries: %+v, want %+v\n", lookup.Queries, wantLookupQueries)
}

utils.MustMatch(t, wantLookupQueries, lookup.Queries)

testQueryLog(t, logChan, "VindexLookup", "SELECT", "select name, user_id from name_user_map where name in ::name", 1)
testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1)
Expand Down Expand Up @@ -829,7 +816,7 @@ func TestSelectNormalize(t *testing.T) {
},
}}
require.Empty(t, sbc1.Queries)
utils.MustMatch(t, sbc2.Queries, wantQueries, "sbc2.Queries")
utils.MustMatch(t, wantQueries, sbc2.Queries, "sbc2.Queries")
sbc2.Queries = nil
masterSession.TargetString = ""
}
Expand Down
48 changes: 24 additions & 24 deletions go/vt/vtgate/vindexes/lookup_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,30 +80,7 @@ func (lkp *lookupInternal) Lookup(vcursor VCursor, ids []sqltypes.Value, co vtga
if vcursor.InTransactionAndIsDML() {
sel = sel + " for update"
}
if !ids[0].IsIntegral() && !ids[0].IsBinary() {
// for non integral and binary type, fallback to send query per id
for _, id := range ids {
vars, err := sqltypes.BuildBindVariable([]interface{}{id})
if err != nil {
return nil, fmt.Errorf("lookup.Map: %v", err)
}
bindVars := map[string]*querypb.BindVariable{
lkp.FromColumns[0]: vars,
}
var result *sqltypes.Result
result, err = vcursor.Execute("VindexLookup", sel, bindVars, false /* rollbackOnError */, co)
if err != nil {
return nil, fmt.Errorf("lookup.Map: %v", err)
}
rows := make([][]sqltypes.Value, 0, len(result.Rows))
for _, row := range result.Rows {
rows = append(rows, []sqltypes.Value{row[1]})
}
results = append(results, &sqltypes.Result{
Rows: rows,
})
}
} else {
if ids[0].IsIntegral() {
// for integral or binary type, batch query all ids and then map them back to the input order
vars, err := sqltypes.BuildBindVariable(ids)
if err != nil {
Expand All @@ -126,6 +103,29 @@ func (lkp *lookupInternal) Lookup(vcursor VCursor, ids []sqltypes.Value, co vtga
Rows: resultMap[id.ToString()],
})
}
} else {
// for non integral and binary type, fallback to send query per id
for _, id := range ids {
vars, err := sqltypes.BuildBindVariable([]interface{}{id})
if err != nil {
return nil, fmt.Errorf("lookup.Map: %v", err)
}
bindVars := map[string]*querypb.BindVariable{
lkp.FromColumns[0]: vars,
}
var result *sqltypes.Result
result, err = vcursor.Execute("VindexLookup", sel, bindVars, false /* rollbackOnError */, co)
if err != nil {
return nil, fmt.Errorf("lookup.Map: %v", err)
}
rows := make([][]sqltypes.Value, 0, len(result.Rows))
for _, row := range result.Rows {
rows = append(rows, []sqltypes.Value{row[1]})
}
results = append(results, &sqltypes.Result{
Rows: rows,
})
}
}
return results, nil
}
Expand Down

0 comments on commit d86632c

Please sign in to comment.