diff --git a/go/test/endtoend/vtgate/lookup_test.go b/go/test/endtoend/vtgate/lookup_test.go index 489bb814227..796ed695ac2 100644 --- a/go/test/endtoend/vtgate/lookup_test.go +++ b/go/test/endtoend/vtgate/lookup_test.go @@ -515,3 +515,15 @@ func TestSelectNullLookup(t *testing.T) { }) } } + +func TestUnicodeLooseMD5CaseInsensitive(t *testing.T) { + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + require.NoError(t, err) + defer conn.Close() + + exec(t, conn, "insert into t4(id1, id2) values(1, 'test')") + defer exec(t, conn, "delete from t4") + + assertMatches(t, conn, "SELECT id1, id2 from t4 where id2 = 'Test'", `[[INT64(1) VARCHAR("test")]]`) +} diff --git a/go/test/endtoend/vtgate/main_test.go b/go/test/endtoend/vtgate/main_test.go index 5f4422984e9..e3c4817a800 100644 --- a/go/test/endtoend/vtgate/main_test.go +++ b/go/test/endtoend/vtgate/main_test.go @@ -87,14 +87,14 @@ create table t4( id1 bigint, id2 varchar(10), primary key(id1) -) Engine=InnoDB; +) ENGINE=InnoDB DEFAULT charset=utf8mb4 COLLATE=utf8mb4_general_ci; create table t4_id2_idx( id2 varchar(10), id1 bigint, keyspace_id varbinary(50), primary key(id2, id1) -) Engine=InnoDB; +) Engine=InnoDB DEFAULT charset=utf8mb4 COLLATE=utf8mb4_general_ci; create table t5_null_vindex( id bigint not null, diff --git a/go/test/utils/diff.go b/go/test/utils/diff.go index 05a1f9eb66d..293d279e8c4 100644 --- a/go/test/utils/diff.go +++ b/go/test/utils/diff.go @@ -43,17 +43,17 @@ import ( // In Test*() function: // // mustMatch(t, want, got, "something doesn't match") -func MustMatchFn(allowUnexportedTypes []interface{}, ignoredFields []string, extraOpts ...cmp.Option) func(t *testing.T, want, got interface{}, errMsg string) { +func MustMatchFn(allowUnexportedTypes []interface{}, ignoredFields []string, extraOpts ...cmp.Option) func(t *testing.T, want, got interface{}, errMsg ...string) { diffOpts := append([]cmp.Option{ cmp.AllowUnexported(allowUnexportedTypes...), cmpIgnoreFields(ignoredFields...), }, extraOpts...) // Diffs want/got and fails with errMsg on any failure. - return func(t *testing.T, want, got interface{}, errMsg string) { + return func(t *testing.T, want, got interface{}, errMsg ...string) { t.Helper() diff := cmp.Diff(want, got, diffOpts...) if diff != "" { - t.Fatalf("%s: (-want +got)\n%v", errMsg, diff) + t.Fatalf("%v: (-want +got)\n%v", errMsg, diff) } } } diff --git a/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt b/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt index 69a10e8198c..ccad1126651 100644 --- a/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt +++ b/go/vt/vtexplain/testdata/multi-output/selectsharded-output.txt @@ -60,8 +60,8 @@ select count(*) from user where id = 1 /* point aggregate */ select count(*) from user where name in ('alice','bob') /* scatter aggregate */ 1 ks_sharded/40-80: select `name`, user_id from name_user_map where `name` in ('alice') limit 10001 /* scatter aggregate */ -1 ks_sharded/c0-: select `name`, user_id from name_user_map where `name` in ('bob') limit 10001 /* scatter aggregate */ -2 ks_sharded/-40: select count(*) from user where `name` in ('alice', 'bob') limit 10001 /* scatter aggregate */ +2 ks_sharded/c0-: select `name`, user_id from name_user_map where `name` in ('bob') limit 10001 /* scatter aggregate */ +3 ks_sharded/-40: select count(*) from user where `name` in ('alice', 'bob') limit 10001 /* scatter aggregate */ ---------------------------------------------------------------------- select name, count(*) from user group by name /* scatter aggregate */ diff --git a/go/vt/vtgate/executor_framework_test.go b/go/vt/vtgate/executor_framework_test.go index 71e6528673d..fcc24f67b08 100644 --- a/go/vt/vtgate/executor_framework_test.go +++ b/go/vt/vtgate/executor_framework_test.go @@ -24,6 +24,10 @@ import ( "strings" "testing" + "github.com/stretchr/testify/require" + + "github.com/stretchr/testify/assert" + "context" "vitess.io/vitess/go/sqltypes" @@ -546,19 +550,14 @@ func testQueryLog(t *testing.T, logChan chan interface{}, method, stmtType, sql t.Helper() logStats := getQueryLog(logChan) - if logStats == nil { - t.Errorf("logstats: no querylog in channel, want sql %s", sql) - return nil - } + require.NotNil(t, logStats) var log bytes.Buffer streamlog.GetFormatter(QueryLogger)(&log, nil, logStats) fields := strings.Split(log.String(), "\t") // fields[0] is the method - if method != fields[0] { - t.Errorf("logstats: method want %q got %q", method, fields[0]) - } + assert.Equal(t, method, fields[0], "logstats: method") // fields[1] - fields[6] are the caller id, start/end times, etc @@ -588,22 +587,16 @@ func testQueryLog(t *testing.T, logChan chan interface{}, method, stmtType, sql } // fields[11] is the statement type - if stmtType != fields[11] { - t.Errorf("logstats: stmtType want %q got %q", stmtType, fields[11]) - } + assert.Equal(t, stmtType, fields[11], "logstats: stmtType") // fields[12] is the original sql wantSQL := fmt.Sprintf("%q", sql) - if wantSQL != fields[12] { - t.Errorf("logstats: SQL want %s got %s", wantSQL, fields[12]) - } + assert.Equal(t, wantSQL, fields[12], "logstats: SQL") // fields[13] contains the formatted bind vars // fields[14] is the count of shard queries - if fmt.Sprintf("%v", shardQueries) != fields[14] { - t.Errorf("logstats: ShardQueries want %v got %v", shardQueries, fields[14]) - } + assert.Equal(t, fmt.Sprintf("%v", shardQueries), fields[14], "logstats: ShardQueries") return logStats } diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 74c72baf988..d5783b070ce 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -568,10 +568,8 @@ func TestSelectBindvars(t *testing.T) { lookup.SetResults([]*sqltypes.Result{sqltypes.MakeTestResult( sqltypes.MakeTestFields("b|a", "varbinary|varbinary"), "foo1|1", - "foo2|1", ), sqltypes.MakeTestResult( sqltypes.MakeTestFields("b|a", "varbinary|varbinary"), - "foo1|1", "foo2|1", )}) @@ -584,12 +582,8 @@ func TestSelectBindvars(t *testing.T) { Sql: "select id from user where id = :id", BindVariables: map[string]*querypb.BindVariable{"id": sqltypes.Int64BindVariable(1)}, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } - if sbc2.Queries != nil { - t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) - } + utils.MustMatch(t, sbc1.Queries, wantQueries) + assert.Empty(t, sbc2.Queries) sbc1.Queries = nil testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) @@ -608,11 +602,10 @@ func TestSelectBindvars(t *testing.T) { "__vals": sqltypes.TestBindVariable([]interface{}{"foo1", "foo2"}), }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) sbc1.Queries = nil testQueryLog(t, logChan, "VindexLookup", "SELECT", "select name, user_id from name_user_map where name in ::name", 1) + testQueryLog(t, logChan, "VindexLookup", "SELECT", "select name, user_id from name_user_map where name in ::name", 1) testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) // Test with BytesBindVariable @@ -623,17 +616,14 @@ func TestSelectBindvars(t *testing.T) { }) require.NoError(t, err) wantQueries = []*querypb.BoundQuery{{ - Sql: "select id from user where `name` in ::__vals", + Sql: "select id from user where 1 != 1", BindVariables: map[string]*querypb.BindVariable{ - "name1": sqltypes.BytesBindVariable([]byte("foo1")), - "name2": sqltypes.BytesBindVariable([]byte("foo2")), - "__vals": sqltypes.TestBindVariable([]interface{}{[]byte("foo1"), []byte("foo2")}), + "name1": sqltypes.BytesBindVariable([]byte("foo1")), + "name2": sqltypes.BytesBindVariable([]byte("foo2")), }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } - + utils.MustMatch(t, wantQueries, sbc1.Queries) + testQueryLog(t, logChan, "VindexLookup", "SELECT", "select name, user_id from name_user_map where name in ::name", 1) testQueryLog(t, logChan, "VindexLookup", "SELECT", "select name, user_id from name_user_map where name in ::name", 1) testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) @@ -662,9 +652,7 @@ func TestSelectBindvars(t *testing.T) { "name": sqltypes.StringBindVariable("nonexistent"), }, }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } + utils.MustMatch(t, wantQueries, sbc1.Queries) vars, err := sqltypes.BuildBindVariable([]interface{}{sqltypes.NewVarBinary("nonexistent")}) require.NoError(t, err) @@ -674,9 +662,8 @@ func TestSelectBindvars(t *testing.T) { "name": vars, }, }} - if !reflect.DeepEqual(lookup.Queries, wantLookupQueries) { - t.Errorf("lookup.Queries: %+v, want %+v\n", lookup.Queries, wantLookupQueries) - } + + utils.MustMatch(t, wantLookupQueries, lookup.Queries) testQueryLog(t, logChan, "VindexLookup", "SELECT", "select name, user_id from name_user_map where name in ::name", 1) testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) @@ -829,7 +816,7 @@ func TestSelectNormalize(t *testing.T) { }, }} require.Empty(t, sbc1.Queries) - utils.MustMatch(t, sbc2.Queries, wantQueries, "sbc2.Queries") + utils.MustMatch(t, wantQueries, sbc2.Queries, "sbc2.Queries") sbc2.Queries = nil masterSession.TargetString = "" } diff --git a/go/vt/vtgate/vindexes/lookup_internal.go b/go/vt/vtgate/vindexes/lookup_internal.go index 346c7219f84..e14162e8efd 100644 --- a/go/vt/vtgate/vindexes/lookup_internal.go +++ b/go/vt/vtgate/vindexes/lookup_internal.go @@ -80,30 +80,7 @@ func (lkp *lookupInternal) Lookup(vcursor VCursor, ids []sqltypes.Value, co vtga if vcursor.InTransactionAndIsDML() { sel = sel + " for update" } - if !ids[0].IsIntegral() && !ids[0].IsBinary() { - // for non integral and binary type, fallback to send query per id - for _, id := range ids { - vars, err := sqltypes.BuildBindVariable([]interface{}{id}) - if err != nil { - return nil, fmt.Errorf("lookup.Map: %v", err) - } - bindVars := map[string]*querypb.BindVariable{ - lkp.FromColumns[0]: vars, - } - var result *sqltypes.Result - result, err = vcursor.Execute("VindexLookup", sel, bindVars, false /* rollbackOnError */, co) - if err != nil { - return nil, fmt.Errorf("lookup.Map: %v", err) - } - rows := make([][]sqltypes.Value, 0, len(result.Rows)) - for _, row := range result.Rows { - rows = append(rows, []sqltypes.Value{row[1]}) - } - results = append(results, &sqltypes.Result{ - Rows: rows, - }) - } - } else { + if ids[0].IsIntegral() { // for integral or binary type, batch query all ids and then map them back to the input order vars, err := sqltypes.BuildBindVariable(ids) if err != nil { @@ -126,6 +103,29 @@ func (lkp *lookupInternal) Lookup(vcursor VCursor, ids []sqltypes.Value, co vtga Rows: resultMap[id.ToString()], }) } + } else { + // for non integral and binary type, fallback to send query per id + for _, id := range ids { + vars, err := sqltypes.BuildBindVariable([]interface{}{id}) + if err != nil { + return nil, fmt.Errorf("lookup.Map: %v", err) + } + bindVars := map[string]*querypb.BindVariable{ + lkp.FromColumns[0]: vars, + } + var result *sqltypes.Result + result, err = vcursor.Execute("VindexLookup", sel, bindVars, false /* rollbackOnError */, co) + if err != nil { + return nil, fmt.Errorf("lookup.Map: %v", err) + } + rows := make([][]sqltypes.Value, 0, len(result.Rows)) + for _, row := range result.Rows { + rows = append(rows, []sqltypes.Value{row[1]}) + } + results = append(results, &sqltypes.Result{ + Rows: rows, + }) + } } return results, nil }