diff --git a/go/vt/binlog/keyspace_id_resolver.go b/go/vt/binlog/keyspace_id_resolver.go index c433f1aaeeb..5c97794da26 100644 --- a/go/vt/binlog/keyspace_id_resolver.go +++ b/go/vt/binlog/keyspace_id_resolver.go @@ -134,18 +134,10 @@ func newKeyspaceIDResolverFactoryV3(ctx context.Context, ts *topo.Server, keyspa return -1, nil, fmt.Errorf("no vschema definition for table %v", table.Name) } - // The primary vindex is most likely the sharding key, - // and has to be unique. - if len(tableSchema.ColumnVindexes) == 0 { - return -1, nil, fmt.Errorf("no vindex definition for table %v", table.Name) - } - colVindex := tableSchema.ColumnVindexes[0] - if colVindex.Vindex.Cost() > 1 { - return -1, nil, fmt.Errorf("primary vindex cost is too high for table %v", table.Name) - } - if !colVindex.Vindex.IsUnique() { - // This is impossible, but just checking anyway. - return -1, nil, fmt.Errorf("primary vindex is not unique for table %v", table.Name) + // use the lowest cost unique vindex as the sharding key + colVindex, err := vindexes.FindVindexForSharding(table.Name.String(), tableSchema.ColumnVindexes) + if err != nil { + return -1, nil, err } // TODO @rafael - when rewriting the mapping function, this will need to change. diff --git a/go/vt/vtgate/vindexes/vschema.go b/go/vt/vtgate/vindexes/vschema.go index 96a2100fee7..b140933cb38 100644 --- a/go/vt/vtgate/vindexes/vschema.go +++ b/go/vt/vtgate/vindexes/vschema.go @@ -485,3 +485,24 @@ func LoadFormalKeyspace(filename string) (*vschemapb.Keyspace, error) { } return formal, nil } + +// FindVindexForSharding searches through the given slice +// to find the lowest cost unique vindex +// primary vindex is always unique +// if two have the same cost, use the one that occurs earlier in the definition +// if the final result is too expensive, return nil +func FindVindexForSharding(tableName string, colVindexes []*ColumnVindex) (*ColumnVindex, error) { + if len(colVindexes) == 0 { + return nil, fmt.Errorf("no vindex definition for table %v", tableName) + } + result := colVindexes[0] + for _, colVindex := range colVindexes { + if colVindex.Vindex.Cost() < result.Vindex.Cost() && colVindex.Vindex.IsUnique() { + result = colVindex + } + } + if result.Vindex.Cost() > 1 || !result.Vindex.IsUnique() { + return nil, fmt.Errorf("could not find a vindex to use for sharding table %v", tableName) + } + return result, nil +} diff --git a/go/vt/vtgate/vindexes/vschema_test.go b/go/vt/vtgate/vindexes/vschema_test.go index d1f04d1a600..724573cb98e 100644 --- a/go/vt/vtgate/vindexes/vschema_test.go +++ b/go/vt/vtgate/vindexes/vschema_test.go @@ -468,6 +468,123 @@ func TestShardedVSchemaOwned(t *testing.T) { wantjson, _ := json.Marshal(want) t.Errorf("BuildVSchema:\n%s, want\n%s", gotjson, wantjson) } + +} + +func TestFindVindexForSharding(t *testing.T) { + ks := &Keyspace{ + Name: "sharded", + Sharded: true, + } + vindex1 := &stFU{ + name: "stfu1", + Params: map[string]string{ + "stfu1": "1", + }, + } + vindex2 := &stLN{name: "stln1"} + t1 := &Table{ + Name: sqlparser.NewTableIdent("t1"), + Keyspace: ks, + ColumnVindexes: []*ColumnVindex{ + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c1")}, + Type: "stfu", + Name: "stfu1", + Vindex: vindex1, + }, + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c2")}, + Type: "stln", + Name: "stln1", + Owned: true, + Vindex: vindex2, + }, + }, + } + res, err := FindVindexForSharding(t1.Name.String(), t1.ColumnVindexes) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(res, t1.ColumnVindexes[0]) { + t.Errorf("FindVindexForSharding:\n got\n%v, want\n%v", res, t1.ColumnVindexes[0]) + } +} + +func TestFindVindexForShardingError(t *testing.T) { + ks := &Keyspace{ + Name: "sharded", + Sharded: true, + } + vindex1 := &stLU{name: "stlu1"} + vindex2 := &stLN{name: "stln1"} + t1 := &Table{ + Name: sqlparser.NewTableIdent("t1"), + Keyspace: ks, + ColumnVindexes: []*ColumnVindex{ + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c1")}, + Type: "stlu", + Name: "stlu1", + Vindex: vindex1, + }, + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c2")}, + Type: "stln", + Name: "stln1", + Owned: true, + Vindex: vindex2, + }, + }, + } + res, err := FindVindexForSharding(t1.Name.String(), t1.ColumnVindexes) + want := `could not find a vindex to use for sharding table t1` + if err == nil || err.Error() != want { + t.Errorf("FindVindexForSharding: %v, want %v", err, want) + } + if res != nil { + t.Errorf("FindVindexForSharding:\n got\n%v, want\n%v", res, nil) + } +} + +func TestFindVindexForSharding2(t *testing.T) { + ks := &Keyspace{ + Name: "sharded", + Sharded: true, + } + vindex1 := &stLU{name: "stlu1"} + vindex2 := &stFU{ + name: "stfu1", + Params: map[string]string{ + "stfu1": "1", + }, + } + t1 := &Table{ + Name: sqlparser.NewTableIdent("t1"), + Keyspace: ks, + ColumnVindexes: []*ColumnVindex{ + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c1")}, + Type: "stlu", + Name: "stlu1", + Vindex: vindex1, + }, + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c2")}, + Type: "stfu", + Name: "stfu1", + Owned: true, + Vindex: vindex2, + }, + }, + } + res, err := FindVindexForSharding(t1.Name.String(), t1.ColumnVindexes) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(res, t1.ColumnVindexes[1]) { + t.Errorf("FindVindexForSharding:\n got\n%v, want\n%v", res, t1.ColumnVindexes[1]) + } } func TestShardedVSchemaMultiColumnVindex(t *testing.T) { diff --git a/go/vt/worker/key_resolver.go b/go/vt/worker/key_resolver.go index 193d807dba4..a17cdd27e0c 100644 --- a/go/vt/worker/key_resolver.go +++ b/go/vt/worker/key_resolver.go @@ -107,18 +107,10 @@ func newV3ResolverFromTableDefinition(keyspaceSchema *vindexes.KeyspaceSchema, t if !ok { return nil, fmt.Errorf("no vschema definition for table %v", td.Name) } - // the primary vindex is most likely the sharding key, and has to - // be unique. - if len(tableSchema.ColumnVindexes) == 0 { - return nil, fmt.Errorf("no vindex definition for table %v", td.Name) - } - colVindex := tableSchema.ColumnVindexes[0] - if colVindex.Vindex.Cost() > 1 { - return nil, fmt.Errorf("primary vindex cost is too high for table %v", td.Name) - } - if !colVindex.Vindex.IsUnique() { - // This is impossible, but just checking anyway. - return nil, fmt.Errorf("primary vindex is not unique for table %v", td.Name) + // use the lowest cost unique vindex as the sharding key + colVindex, err := vindexes.FindVindexForSharding(td.Name, tableSchema.ColumnVindexes) + if err != nil { + return nil, err } // Find the sharding key column index. @@ -139,18 +131,10 @@ func newV3ResolverFromColumnList(keyspaceSchema *vindexes.KeyspaceSchema, name s if !ok { return nil, fmt.Errorf("no vschema definition for table %v", name) } - // the primary vindex is most likely the sharding key, and has to - // be unique. - if len(tableSchema.ColumnVindexes) == 0 { - return nil, fmt.Errorf("no vindex definition for table %v", name) - } - colVindex := tableSchema.ColumnVindexes[0] - if colVindex.Vindex.Cost() > 1 { - return nil, fmt.Errorf("primary vindex cost is too high for table %v", name) - } - if !colVindex.Vindex.IsUnique() { - // This is impossible, but just checking anyway. - return nil, fmt.Errorf("primary vindex is not unique for table %v", name) + // use the lowest cost unique vindex as the sharding key + colVindex, err := vindexes.FindVindexForSharding(name, tableSchema.ColumnVindexes) + if err != nil { + return nil, err } // Find the sharding key column index.