diff --git a/bootstrap.sh b/bootstrap.sh index 34ba0c30e90..70108e7d37c 100755 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -307,7 +307,7 @@ fi case "$MYSQL_FLAVOR" in "MySQL56") myversion="$("$VT_MYSQL_ROOT/bin/mysql" --version)" - [[ "$myversion" =~ Distrib\ 5\.[67] ]] || fail "Couldn't find MySQL 5.6+ in $VT_MYSQL_ROOT. Set VT_MYSQL_ROOT to override search location." + [[ "$myversion" =~ Distrib\ 5\.[67] || "$myversion" =~ Ver\ 8\. ]] || fail "Couldn't find MySQL 5.6+ in $VT_MYSQL_ROOT. Set VT_MYSQL_ROOT to override search location." echo "Found MySQL 5.6+ installation in $VT_MYSQL_ROOT." ;; diff --git a/doc/GettingStarted.md b/doc/GettingStarted.md index 65745febb7b..f65f812a699 100644 --- a/doc/GettingStarted.md +++ b/doc/GettingStarted.md @@ -271,7 +271,7 @@ In addition, Vitess requires the software and libraries listed below. ``` sh # Remaining commands to build Vitess - . ./dev.env + source ./dev.env make build ``` diff --git a/doc/ServerConfiguration.md b/doc/ServerConfiguration.md index 0124e470182..07020310175 100644 --- a/doc/ServerConfiguration.md +++ b/doc/ServerConfiguration.md @@ -552,6 +552,7 @@ Load-balancer in front of vtgate to scale up (not covered by Vitess). Stateless, ### Parameters * **cells_to_watch**: which cell vtgate is in and will monitor tablets from. Cross-cell master access needs multiple cells here. +* **keyspaces_to_watch**: Specifies that a vtgate will only be able to perform queries against or view the topology of these keyspaces * **tablet_types_to_wait**: VTGate waits for at least one serving tablet per tablet type specified here during startup, before listening to the serving port. So VTGate does not serve error. It should match the available tablet types VTGate connects to (master, replica, rdonly). * **discovery_low_replication_lag**: when replication lags of all VTTablet in a particular shard and tablet type are less than or equal the flag (in seconds), VTGate does not filter them by replication lag and uses all to balance traffic. * **degraded_threshold (30s)**: a tablet will publish itself as degraded if replication lag exceeds this threshold. This will cause VTGates to choose more up-to-date servers over this one. If all servers are degraded, VTGate resorts to serving from all of them. diff --git a/doc/TabletRouting.md b/doc/TabletRouting.md index 6d32f5f8ae1..167975e6469 100644 --- a/doc/TabletRouting.md +++ b/doc/TabletRouting.md @@ -93,36 +93,7 @@ There are two implementations of the Gateway interface: discovery section, one per cell) as a source of tablets, a HealthCheck module to watch their health, and a TabletStatsCache to collect all the health information. Based on this data, it can find the best tablet to use. -* l2VTGateGateway: It keeps a map of l2vtgate processes to send queries to. See - next section for more details. -## l2vtgate - -As we started increasing the number of tablets in a cell, it became clear that a -bottleneck of the system was going to be how many tablets a single vtgate is -connecting to. Since vtgate maintains a streaming health check connection per -tablet, the number of these connections can grow to large numbers. It is common -for vtgate to watch tablets in other cells, to be able to find the master -tablet. - -So l2vtgate came to exist, based on very similar concepts and interfaces: - -* l2vtgate is an extra hop between a vtgate pool and tablets. -* A l2vtgate pool connects to a subset of tablets, therefore it can have a - reasonable number of streaming health connections. Externally, it exposes the - QueryService RPC interface (that has the Target for the query, keyspace / - shard / tablet type). Internally, it uses a discoveryGateway, as usual. -* vtgate connects to l2vtgate pools (using the l2VTGateGateway instead of the - discoveryGateway). It has a map of which keyspace / shard / tablet type needs - to go to wich l2vtgate pool. At this point, vtgate doesn't maintain any health - information about the tablets, it lets l2vtgate handle it. - -Note l2vtgate is not an ideal solution as it is now. For instance, if there are -two cells, and the master for a shard can be in either, l2vtgate still has to -watch the tablets in both cells, to know where the master is. Ideally, we'd want -l2vtgate to be collocated with the tablets in a given cell, and not go -cross-cell. - # Extensions, work in progress ## Regions, cross-cell targeting @@ -169,31 +140,6 @@ between vtgate and l2vtgate: This would also be a good time to merge the vtgate code that uses the VSchema with the code that doesn't for SrvKeyspace access. -## Hybrid Gateway - -It would be nice to re-organize the code a bit inside vtgate to allow for an -hybrid gateway, and get rid of l2vtgate alltogether: - -* vtgate would use the discoveryGateway to watch the tablets in the current cell - (and optionally to any other cell we still want to consider local). -* vtgate would use l2vtgateGateway to watch the tablets in a different cell. -* vtgate would expose the RPC APIs currently exposed by the l2vtgate process. - -So vtgate would watch the tablets in the local cell only, but also know what -healthy tablets are in the other cells, and be able to send query to them -through their vtgate. The extra hop to the other cell vtgate should be a small -latency price to pay, compared to going cross-cell already. - -So queries would go one of two routes: - -* client(cell1) -> vtgate(cell1) -> tablet(cell1) -* client(cell1) -> vtgate(cell1) -> vtgate(cell2) -> tablet(cell2) - -If the number of tablets in a given cell is still too high for the local vtgate -pool, two or more pools can still be created, each of them knowing about a -subset of the tablets. And they would just forward queries to each others when -addressing the other tablet set. - ## Config-based routing Another possible extension would be to group all routing options for vtgate in a diff --git a/docker/lite/Dockerfile.alpine b/docker/lite/Dockerfile.alpine index a4621546caa..6bc9ccf0492 100644 --- a/docker/lite/Dockerfile.alpine +++ b/docker/lite/Dockerfile.alpine @@ -4,8 +4,8 @@ FROM alpine:3.8 AS staging RUN mkdir -p /vt/vtdataroot/ && mkdir -p /vt/bin && mkdir -p /vt/src/vitess.io/vitess/web/vtctld2 -COPY --from=builder /vt/src/vitess.io/vitess/web/vtctld /vt/src/vitess.io/web/vtctld -COPY --from=builder /vt/src/vitess.io/vitess/web/vtctld2/app /vt/src/vitess.io/web/vtctld2/app +COPY --from=builder /vt/src/vitess.io/vitess/web/vtctld /vt/src/vitess.io/vitess/web/vtctld +COPY --from=builder /vt/src/vitess.io/vitess/web/vtctld2/app /vt/src/vitess.io/vitess/web/vtctld2/app COPY --from=builder /vt/src/vitess.io/vitess/config /vt/config COPY --from=builder /vt/bin/mysqlctld /vt/bin/ COPY --from=builder /vt/bin/vtctld /vt/bin/ diff --git a/examples/local/vtgate-up.sh b/examples/local/vtgate-up.sh index 77483499e71..45bb080e822 100755 --- a/examples/local/vtgate-up.sh +++ b/examples/local/vtgate-up.sh @@ -55,10 +55,11 @@ then fi optional_auth_args='-mysql_auth_server_impl none' +optional_grpc_auth_args='' if [ "$1" = "--enable-grpc-static-auth" ]; then echo "Enabling Auth with static authentication in grpc" - optional_auth_args='-grpc_auth_static_client_creds ./grpc_static_client_auth.json' + optional_grpc_auth_args='-grpc_auth_static_client_creds ./grpc_static_client_auth.json' fi if [ "$1" = "--enable-mysql-static-auth" ]; @@ -84,6 +85,7 @@ $VTROOT/bin/vtgate \ -service_map 'grpc-vtgateservice' \ -pid_file $VTDATAROOT/tmp/vtgate.pid \ $optional_auth_args \ + $optional_grpc_auth_args \ $optional_tls_args \ > $VTDATAROOT/tmp/vtgate.out 2>&1 & diff --git a/go/cmd/topo2topo/topo2topo.go b/go/cmd/topo2topo/topo2topo.go index 7295877587b..78ab89d984b 100644 --- a/go/cmd/topo2topo/topo2topo.go +++ b/go/cmd/topo2topo/topo2topo.go @@ -18,6 +18,8 @@ package main import ( "flag" + "fmt" + "os" "golang.org/x/net/context" "vitess.io/vitess/go/exit" @@ -36,6 +38,7 @@ var ( toServerAddress = flag.String("to_server", "", "topology server address to copy data to") toRoot = flag.String("to_root", "", "topology server root to copy data to") + compare = flag.Bool("compare", false, "compares data between topologies") doKeyspaces = flag.Bool("do-keyspaces", false, "copies the keyspace information") doShards = flag.Bool("do-shards", false, "copies the shard information") doShardReplications = flag.Bool("do-shard-replications", false, "copies the shard replication information") @@ -64,6 +67,14 @@ func main() { ctx := context.Background() + if *compare { + compareTopos(ctx, fromTS, toTS) + return + } + copyTopos(ctx, fromTS, toTS) +} + +func copyTopos(ctx context.Context, fromTS, toTS *topo.Server) { if *doKeyspaces { helpers.CopyKeyspaces(ctx, fromTS, toTS) } @@ -76,4 +87,37 @@ func main() { if *doTablets { helpers.CopyTablets(ctx, fromTS, toTS) } + +} + +func compareTopos(ctx context.Context, fromTS, toTS *topo.Server) { + var err error + if *doKeyspaces { + err = helpers.CompareKeyspaces(ctx, fromTS, toTS) + if err != nil { + log.Exitf("Compare keyspaces failed: %v", err) + } + } + if *doShards { + err = helpers.CompareShards(ctx, fromTS, toTS) + if err != nil { + log.Exitf("Compare shards failed: %v", err) + } + } + if *doShardReplications { + err = helpers.CompareShardReplications(ctx, fromTS, toTS) + if err != nil { + log.Exitf("Compare shard replications failed: %v", err) + } + } + if *doTablets { + err = helpers.CompareTablets(ctx, fromTS, toTS) + if err != nil { + log.Exitf("Compare tablets failed: %v", err) + } + } + if err == nil { + fmt.Println("Topologies are in sync") + os.Exit(0) + } } diff --git a/go/cmd/vtgate/plugin_grpcqueryservice.go b/go/cmd/vtgate/plugin_grpcqueryservice.go deleted file mode 100644 index 16c163d095c..00000000000 --- a/go/cmd/vtgate/plugin_grpcqueryservice.go +++ /dev/null @@ -1,34 +0,0 @@ -/* -Copyright 2017 Google Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -// Imports and register the gRPC queryservice server - -import ( - "vitess.io/vitess/go/vt/servenv" - "vitess.io/vitess/go/vt/vtgate" - "vitess.io/vitess/go/vt/vttablet/grpcqueryservice" - "vitess.io/vitess/go/vt/vttablet/queryservice" -) - -func init() { - vtgate.RegisterL2VTGates = append(vtgate.RegisterL2VTGates, func(qs queryservice.QueryService) { - if servenv.GRPCCheckServiceMap("queryservice") { - grpcqueryservice.Register(servenv.GRPCServer, qs) - } - }) -} diff --git a/go/mysql/schema.go b/go/mysql/schema.go index 90bbe3c12df..3fc122c29b3 100644 --- a/go/mysql/schema.go +++ b/go/mysql/schema.go @@ -308,8 +308,8 @@ func ShowIndexFromTableRow(table string, unique bool, keyName string, seqInIndex sqltypes.MakeTrusted(sqltypes.VarChar, []byte(columnName)), sqltypes.MakeTrusted(sqltypes.VarChar, []byte("A")), // Collation sqltypes.MakeTrusted(sqltypes.Int64, []byte("0")), // Cardinality - sqltypes.NULL, // Sub_part - sqltypes.NULL, // Packed + sqltypes.NULL, // Sub_part + sqltypes.NULL, // Packed sqltypes.MakeTrusted(sqltypes.VarChar, []byte(nullableStr)), sqltypes.MakeTrusted(sqltypes.VarChar, []byte("BTREE")), // Index_type sqltypes.MakeTrusted(sqltypes.VarChar, []byte("")), // Comment diff --git a/go/mysql/server.go b/go/mysql/server.go index 0dd9d4e427a..f7a6099ab2d 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "net" + "strings" "time" "vitess.io/vitess/go/netutil" @@ -273,7 +274,9 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti // First build and send the server handshake packet. salt, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, l.TLSConfig != nil) if err != nil { - log.Errorf("Cannot send HandshakeV10 packet to %s: %v", c, err) + if err != io.EOF { + log.Errorf("Cannot send HandshakeV10 packet to %s: %v", c, err) + } return } @@ -547,6 +550,12 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en } if err := c.writeEphemeralPacket(); err != nil { + if strings.HasSuffix(err.Error(), "write: connection reset by peer") { + return nil, io.EOF + } + if strings.HasSuffix(err.Error(), "write: broken pipe") { + return nil, io.EOF + } return nil, err } diff --git a/go/sqltypes/query_response.go b/go/sqltypes/query_response.go index eeca6cd6c55..5ce8e58b765 100644 --- a/go/sqltypes/query_response.go +++ b/go/sqltypes/query_response.go @@ -16,7 +16,9 @@ limitations under the License. package sqltypes -import "reflect" +import ( + "vitess.io/vitess/go/vt/vterrors" +) // QueryResponse represents a query response for ExecuteBatch. type QueryResponse struct { @@ -34,7 +36,7 @@ func QueryResponsesEqual(r1, r2 []QueryResponse) bool { if !r.QueryResult.Equal(r2[i].QueryResult) { return false } - if !reflect.DeepEqual(r.QueryError, r2[i].QueryError) { + if !vterrors.Equals(r.QueryError, r2[i].QueryError) { return false } } diff --git a/go/vt/automation/scheduler.go b/go/vt/automation/scheduler.go index 08d712189a1..e43c959fc77 100644 --- a/go/vt/automation/scheduler.go +++ b/go/vt/automation/scheduler.go @@ -82,11 +82,11 @@ func NewScheduler() (*Scheduler, error) { registeredClusterOperations: defaultClusterOperations, idGenerator: IDGenerator{}, toBeScheduledClusterOperations: make(chan ClusterOperationInstance, 10), - state: stateNotRunning, - taskCreator: defaultTaskCreator, - pendingOpsWg: &sync.WaitGroup{}, - activeClusterOperations: make(map[string]ClusterOperationInstance), - finishedClusterOperations: make(map[string]ClusterOperationInstance), + state: stateNotRunning, + taskCreator: defaultTaskCreator, + pendingOpsWg: &sync.WaitGroup{}, + activeClusterOperations: make(map[string]ClusterOperationInstance), + finishedClusterOperations: make(map[string]ClusterOperationInstance), } return s, nil diff --git a/go/vt/binlog/binlogplayer/binlog_player_test.go b/go/vt/binlog/binlogplayer/binlog_player_test.go index 27ad35c5712..8f4376f8212 100644 --- a/go/vt/binlog/binlogplayer/binlog_player_test.go +++ b/go/vt/binlog/binlogplayer/binlog_player_test.go @@ -37,8 +37,8 @@ var ( InsertID: 0, Rows: [][]sqltypes.Value{ { - sqltypes.NewVarBinary("MariaDB/0-1-1083"), // pos - sqltypes.NULL, // stop_pos + sqltypes.NewVarBinary("MariaDB/0-1-1083"), // pos + sqltypes.NULL, // stop_pos sqltypes.NewVarBinary("9223372036854775807"), // max_tps sqltypes.NewVarBinary("9223372036854775807"), // max_replication_lag }, diff --git a/go/vt/binlog/keyspace_id_resolver.go b/go/vt/binlog/keyspace_id_resolver.go index c433f1aaeeb..5c97794da26 100644 --- a/go/vt/binlog/keyspace_id_resolver.go +++ b/go/vt/binlog/keyspace_id_resolver.go @@ -134,18 +134,10 @@ func newKeyspaceIDResolverFactoryV3(ctx context.Context, ts *topo.Server, keyspa return -1, nil, fmt.Errorf("no vschema definition for table %v", table.Name) } - // The primary vindex is most likely the sharding key, - // and has to be unique. - if len(tableSchema.ColumnVindexes) == 0 { - return -1, nil, fmt.Errorf("no vindex definition for table %v", table.Name) - } - colVindex := tableSchema.ColumnVindexes[0] - if colVindex.Vindex.Cost() > 1 { - return -1, nil, fmt.Errorf("primary vindex cost is too high for table %v", table.Name) - } - if !colVindex.Vindex.IsUnique() { - // This is impossible, but just checking anyway. - return -1, nil, fmt.Errorf("primary vindex is not unique for table %v", table.Name) + // use the lowest cost unique vindex as the sharding key + colVindex, err := vindexes.FindVindexForSharding(table.Name.String(), tableSchema.ColumnVindexes) + if err != nil { + return -1, nil, err } // TODO @rafael - when rewriting the mapping function, this will need to change. diff --git a/go/vt/callinfo/fakecallinfo/fakecallinfo.go b/go/vt/callinfo/fakecallinfo/fakecallinfo.go index d7eedf2fc5c..84a6df9f5d3 100644 --- a/go/vt/callinfo/fakecallinfo/fakecallinfo.go +++ b/go/vt/callinfo/fakecallinfo/fakecallinfo.go @@ -16,13 +16,16 @@ limitations under the License. package fakecallinfo -import "html/template" +import ( + "fmt" + "html/template" +) // FakeCallInfo gives a fake Callinfo usable in callinfo type FakeCallInfo struct { Remote string + Method string User string - Txt string Html string } @@ -38,7 +41,7 @@ func (fci *FakeCallInfo) Username() string { // Text returns the text. func (fci *FakeCallInfo) Text() string { - return fci.Txt + return fmt.Sprintf("%s:%s(fakeRPC)", fci.Remote, fci.Method) } // HTML returns the html. diff --git a/go/vt/dbconfigs/dbconfigs.go b/go/vt/dbconfigs/dbconfigs.go index 5df9aa2d7da..c8a86106b16 100644 --- a/go/vt/dbconfigs/dbconfigs.go +++ b/go/vt/dbconfigs/dbconfigs.go @@ -100,6 +100,8 @@ func registerBaseFlags() { flag.StringVar(&baseConfig.SslCaPath, "db_ssl_ca_path", "", "connection ssl ca path") flag.StringVar(&baseConfig.SslCert, "db_ssl_cert", "", "connection ssl certificate") flag.StringVar(&baseConfig.SslKey, "db_ssl_key", "", "connection ssl key") + flag.StringVar(&baseConfig.ServerName, "db_server_name", "", "server name of the DB we are connecting to.") + } // The flags will change the global singleton @@ -124,6 +126,7 @@ func registerPerUserFlags(dbc *userConfig, userKey string) { flag.StringVar(&dbc.param.SslCaPath, "db-config-"+userKey+"-ssl-ca-path", "", "deprecated: use db_ssl_ca_path") flag.StringVar(&dbc.param.SslCert, "db-config-"+userKey+"-ssl-cert", "", "deprecated: use db_ssl_cert") flag.StringVar(&dbc.param.SslKey, "db-config-"+userKey+"-ssl-key", "", "deprecated: use db_ssl_key") + flag.StringVar(&dbc.param.ServerName, "db-config-"+userKey+"-server_name", "", "deprecated: use db_server_name") flag.StringVar(&dbc.param.DeprecatedDBName, "db-config-"+userKey+"-dbname", "", "deprecated: dbname does not need to be explicitly configured") @@ -246,6 +249,7 @@ func Init(defaultSocketFile string) (*DBConfigs, error) { uc.param.SslCaPath = baseConfig.SslCaPath uc.param.SslCert = baseConfig.SslCert uc.param.SslKey = baseConfig.SslKey + uc.param.ServerName = baseConfig.ServerName } } } else { diff --git a/go/vt/discovery/healthcheck_test.go b/go/vt/discovery/healthcheck_test.go index 616a15dfdff..980eb90cadd 100644 --- a/go/vt/discovery/healthcheck_test.go +++ b/go/vt/discovery/healthcheck_test.go @@ -85,18 +85,18 @@ func TestHealthCheck(t *testing.T) { // one tablet after receiving a StreamHealthResponse shr := &querypb.StreamHealthResponse{ - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, - Serving: true, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, + Serving: true, TabletExternallyReparentedTimestamp: 10, RealtimeStats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, } want = &TabletStats{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, - Up: true, - Serving: true, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, + Up: true, + Serving: true, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, TabletExternallyReparentedTimestamp: 10, } input <- shr @@ -116,12 +116,12 @@ func TestHealthCheck(t *testing.T) { Cell: "cell", Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, TabletsStats: TabletStatsList{{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, - Up: true, - Serving: true, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, + Up: true, + Serving: true, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, TabletExternallyReparentedTimestamp: 10, }}, }} @@ -132,20 +132,20 @@ func TestHealthCheck(t *testing.T) { // TabletType changed, should get both old and new event shr = &querypb.StreamHealthResponse{ - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Serving: true, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Serving: true, TabletExternallyReparentedTimestamp: 0, RealtimeStats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.5}, } input <- shr t.Logf(`input <- {{Keyspace: "k", Shard: "s", TabletType: REPLICA}, Serving: true, TabletExternallyReparentedTimestamp: 0, {SecondsBehindMaster: 1, CpuUsage: 0.5}}`) want = &TabletStats{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, - Up: false, - Serving: true, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, + Up: false, + Serving: true, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, TabletExternallyReparentedTimestamp: 10, } res = <-l.output @@ -153,12 +153,12 @@ func TestHealthCheck(t *testing.T) { t.Errorf(`<-l.output: %+v; want %+v`, res, want) } want = &TabletStats{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Up: true, - Serving: true, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.5}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Up: true, + Serving: true, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.5}, TabletExternallyReparentedTimestamp: 0, } res = <-l.output @@ -173,18 +173,18 @@ func TestHealthCheck(t *testing.T) { // Serving & RealtimeStats changed shr = &querypb.StreamHealthResponse{ - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Serving: false, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Serving: false, TabletExternallyReparentedTimestamp: 0, RealtimeStats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.3}, } want = &TabletStats{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Up: true, - Serving: false, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.3}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Up: true, + Serving: false, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.3}, TabletExternallyReparentedTimestamp: 0, } input <- shr @@ -197,18 +197,18 @@ func TestHealthCheck(t *testing.T) { // HealthError shr = &querypb.StreamHealthResponse{ - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Serving: true, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Serving: true, TabletExternallyReparentedTimestamp: 0, RealtimeStats: &querypb.RealtimeStats{HealthError: "some error", SecondsBehindMaster: 1, CpuUsage: 0.3}, } want = &TabletStats{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Up: true, - Serving: false, - Stats: &querypb.RealtimeStats{HealthError: "some error", SecondsBehindMaster: 1, CpuUsage: 0.3}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Up: true, + Serving: false, + Stats: &querypb.RealtimeStats{HealthError: "some error", SecondsBehindMaster: 1, CpuUsage: 0.3}, TabletExternallyReparentedTimestamp: 0, LastError: fmt.Errorf("vttablet error: some error"), } @@ -224,12 +224,12 @@ func TestHealthCheck(t *testing.T) { hc.deleteConn(tablet) t.Logf(`hc.RemoveTablet({Host: "a", PortMap: {"vt": 1}})`) want = &TabletStats{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Up: false, - Serving: false, - Stats: &querypb.RealtimeStats{HealthError: "some error", SecondsBehindMaster: 1, CpuUsage: 0.3}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Up: false, + Serving: false, + Stats: &querypb.RealtimeStats{HealthError: "some error", SecondsBehindMaster: 1, CpuUsage: 0.3}, TabletExternallyReparentedTimestamp: 0, LastError: context.Canceled, } @@ -271,18 +271,18 @@ func TestHealthCheckStreamError(t *testing.T) { // one tablet after receiving a StreamHealthResponse shr := &querypb.StreamHealthResponse{ - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Serving: true, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Serving: true, TabletExternallyReparentedTimestamp: 0, RealtimeStats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, } want = &TabletStats{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Up: true, - Serving: true, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Up: true, + Serving: true, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, TabletExternallyReparentedTimestamp: 0, } input <- shr @@ -295,12 +295,12 @@ func TestHealthCheckStreamError(t *testing.T) { // Stream error fc.errCh <- fmt.Errorf("some stream error") want = &TabletStats{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Up: true, - Serving: false, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Up: true, + Serving: false, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, TabletExternallyReparentedTimestamp: 0, LastError: fmt.Errorf("some stream error"), } @@ -342,9 +342,9 @@ func TestHealthCheckVerifiesTabletAlias(t *testing.T) { } input <- &querypb.StreamHealthResponse{ - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, - TabletAlias: &topodatapb.TabletAlias{Uid: 20, Cell: "cellb"}, - Serving: true, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, + TabletAlias: &topodatapb.TabletAlias{Uid: 20, Cell: "cellb"}, + Serving: true, TabletExternallyReparentedTimestamp: 10, RealtimeStats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, } @@ -360,9 +360,9 @@ func TestHealthCheckVerifiesTabletAlias(t *testing.T) { } input <- &querypb.StreamHealthResponse{ - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, - TabletAlias: &topodatapb.TabletAlias{Uid: 1, Cell: "cell"}, - Serving: true, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, + TabletAlias: &topodatapb.TabletAlias{Uid: 1, Cell: "cell"}, + Serving: true, TabletExternallyReparentedTimestamp: 10, RealtimeStats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, } @@ -409,18 +409,18 @@ func TestHealthCheckCloseWaitsForGoRoutines(t *testing.T) { // Verify that the listener works in general. shr := &querypb.StreamHealthResponse{ - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, - Serving: true, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, + Serving: true, TabletExternallyReparentedTimestamp: 10, RealtimeStats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, } want = &TabletStats{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, - Up: true, - Serving: true, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, + Up: true, + Serving: true, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, TabletExternallyReparentedTimestamp: 10, } input <- shr @@ -504,18 +504,18 @@ func TestHealthCheckTimeout(t *testing.T) { // one tablet after receiving a StreamHealthResponse shr := &querypb.StreamHealthResponse{ - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, - Serving: true, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, + Serving: true, TabletExternallyReparentedTimestamp: 10, RealtimeStats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, } want = &TabletStats{ - Key: "a,vt:1", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, - Up: true, - Serving: true, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, + Key: "a,vt:1", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_MASTER}, + Up: true, + Serving: true, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.2}, TabletExternallyReparentedTimestamp: 10, } input <- shr @@ -580,12 +580,12 @@ func TestTemplate(t *testing.T) { tablet := topo.NewTablet(0, "cell", "a") ts := []*TabletStats{ { - Key: "a", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Up: true, - Serving: false, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.3}, + Key: "a", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Up: true, + Serving: false, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.3}, TabletExternallyReparentedTimestamp: 0, }, } @@ -612,12 +612,12 @@ func TestDebugURLFormatting(t *testing.T) { tablet := topo.NewTablet(0, "cell", "host.dc.domain") ts := []*TabletStats{ { - Key: "a", - Tablet: tablet, - Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, - Up: true, - Serving: false, - Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.3}, + Key: "a", + Tablet: tablet, + Target: &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA}, + Up: true, + Serving: false, + Stats: &querypb.RealtimeStats{SecondsBehindMaster: 1, CpuUsage: 0.3}, TabletExternallyReparentedTimestamp: 0, }, } diff --git a/go/vt/discovery/tablet_stats_cache.go b/go/vt/discovery/tablet_stats_cache.go index 720d8089dc5..7fa1809e5ce 100644 --- a/go/vt/discovery/tablet_stats_cache.go +++ b/go/vt/discovery/tablet_stats_cache.go @@ -68,10 +68,8 @@ type tabletStatsCacheEntry struct { all map[string]*TabletStats // healthy only has the healthy ones. healthy []*TabletStats - // aggregates has the per-cell aggregates. + // aggregates has the per-region aggregates. aggregates map[string]*querypb.AggregateStats - // aggregatesPerRegion has the per-region aggregates. - aggregatesPerRegion map[string]*querypb.AggregateStats } func (e *tabletStatsCacheEntry) updateHealthyMapForMaster(ts *TabletStats) { @@ -143,7 +141,6 @@ func newTabletStatsCache(hc HealthCheck, ts *topo.Server, cell string, setListen // upon type change. hc.SetListener(tc, true /*sendDownEvents*/) } - go tc.broadcastAggregateStats() return tc } @@ -268,21 +265,18 @@ func (tc *TabletStatsCache) StatsUpdate(ts *TabletStats) { tc.updateAggregateMap(ts.Target.Keyspace, ts.Target.Shard, ts.Target.TabletType, e, allArray) } -// makeAggregateMap takes a list of TabletStats and builds a per-cell +// makeAggregateMap takes a list of TabletStats and builds a per-region // AggregateStats map. -func (tc *TabletStatsCache) makeAggregateMap(stats []*TabletStats, buildForRegion bool) map[string]*querypb.AggregateStats { +func (tc *TabletStatsCache) makeAggregateMap(stats []*TabletStats) map[string]*querypb.AggregateStats { result := make(map[string]*querypb.AggregateStats) for _, ts := range stats { - cellOrRegion := ts.Tablet.Alias.Cell - if buildForRegion { - cellOrRegion = tc.getRegionByCell(cellOrRegion) - } - agg, ok := result[cellOrRegion] + region := tc.getRegionByCell(ts.Tablet.Alias.Cell) + agg, ok := result[region] if !ok { agg = &querypb.AggregateStats{ SecondsBehindMasterMin: math.MaxUint32, } - result[cellOrRegion] = agg + result[region] = agg } if ts.Serving && ts.LastError == nil { @@ -300,102 +294,12 @@ func (tc *TabletStatsCache) makeAggregateMap(stats []*TabletStats, buildForRegio return result } -// makeAggregateMapDiff computes the entries that need to be broadcast -// when the map goes from oldMap to newMap. -func makeAggregateMapDiff(keyspace, shard string, tabletType topodatapb.TabletType, ter int64, oldMap map[string]*querypb.AggregateStats, newMap map[string]*querypb.AggregateStats) []*srvtopo.TargetStatsEntry { - var result []*srvtopo.TargetStatsEntry - for cell, oldValue := range oldMap { - newValue, ok := newMap[cell] - if ok { - // We have both an old and a new value. If equal, - // skip it. - if oldValue.HealthyTabletCount == newValue.HealthyTabletCount && - oldValue.UnhealthyTabletCount == newValue.UnhealthyTabletCount && - oldValue.SecondsBehindMasterMin == newValue.SecondsBehindMasterMin && - oldValue.SecondsBehindMasterMax == newValue.SecondsBehindMasterMax { - continue - } - // The new value is different, send it. - result = append(result, &srvtopo.TargetStatsEntry{ - Target: &querypb.Target{ - Keyspace: keyspace, - Shard: shard, - TabletType: tabletType, - Cell: cell, - }, - Stats: newValue, - TabletExternallyReparentedTimestamp: ter, - }) - } else { - // We only have the old value, send an empty - // record to clear it. - result = append(result, &srvtopo.TargetStatsEntry{ - Target: &querypb.Target{ - Keyspace: keyspace, - Shard: shard, - TabletType: tabletType, - Cell: cell, - }, - }) - } - } - - for cell, newValue := range newMap { - if _, ok := oldMap[cell]; ok { - continue - } - // New value, no old value, just send it. - result = append(result, &srvtopo.TargetStatsEntry{ - Target: &querypb.Target{ - Keyspace: keyspace, - Shard: shard, - TabletType: tabletType, - Cell: cell, - }, - Stats: newValue, - TabletExternallyReparentedTimestamp: ter, - }) - } - return result -} - // updateAggregateMap will update the aggregate map for the // tabletStatsCacheEntry. It may broadcast the changes too if we have listeners. // e.mu needs to be locked. func (tc *TabletStatsCache) updateAggregateMap(keyspace, shard string, tabletType topodatapb.TabletType, e *tabletStatsCacheEntry, stats []*TabletStats) { // Save the new value - oldAgg := e.aggregates - newAgg := tc.makeAggregateMap(stats /* buildForRegion */, false) - e.aggregates = newAgg - e.aggregatesPerRegion = tc.makeAggregateMap(stats /* buildForRegion */, true) - - // And broadcast the change in the background, if we need to. - tc.mu.RLock() - if !tc.tsm.HasSubscribers() { - // Shortcut: no subscriber, we can be done. - tc.mu.RUnlock() - return - } - tc.mu.RUnlock() - - var ter int64 - if len(stats) > 0 { - ter = stats[0].TabletExternallyReparentedTimestamp - } - diffs := makeAggregateMapDiff(keyspace, shard, tabletType, ter, oldAgg, newAgg) - tc.aggregatesChan <- diffs -} - -// broadcastAggregateStats is called in the background to send aggregate stats -// in the right order to our subscribers. -func (tc *TabletStatsCache) broadcastAggregateStats() { - for diffs := range tc.aggregatesChan { - tc.mu.RLock() - for _, d := range diffs { - tc.tsm.Broadcast(d) - } - tc.mu.RUnlock() - } + e.aggregates = tc.makeAggregateMap(stats) } // GetTabletStats returns the full list of available targets. @@ -442,51 +346,6 @@ func (tc *TabletStatsCache) ResetForTesting() { tc.entries = make(map[string]map[string]map[topodatapb.TabletType]*tabletStatsCacheEntry) } -// Subscribe is part of the TargetStatsListener interface. -func (tc *TabletStatsCache) Subscribe() (int, []srvtopo.TargetStatsEntry, <-chan (*srvtopo.TargetStatsEntry), error) { - var allTS []srvtopo.TargetStatsEntry - - // Make sure the map cannot change. Also blocks any update from - // propagating. - tc.mu.Lock() - defer tc.mu.Unlock() - for keyspace, shardMap := range tc.entries { - for shard, typeMap := range shardMap { - for tabletType, e := range typeMap { - e.mu.RLock() - var ter int64 - if len(e.healthy) > 0 { - ter = e.healthy[0].TabletExternallyReparentedTimestamp - } - for cell, agg := range e.aggregates { - allTS = append(allTS, srvtopo.TargetStatsEntry{ - Target: &querypb.Target{ - Keyspace: keyspace, - Shard: shard, - TabletType: tabletType, - Cell: cell, - }, - Stats: agg, - TabletExternallyReparentedTimestamp: ter, - }) - } - e.mu.RUnlock() - } - } - } - - // Now create the listener, add it to our list. - id, c := tc.tsm.Subscribe() - return id, allTS, c, nil -} - -// Unsubscribe is part of the TargetStatsListener interface. -func (tc *TabletStatsCache) Unsubscribe(i int) error { - tc.mu.Lock() - defer tc.mu.Unlock() - return tc.tsm.Unsubscribe(i) -} - // GetAggregateStats is part of the TargetStatsListener interface. func (tc *TabletStatsCache) GetAggregateStats(target *querypb.Target) (*querypb.AggregateStats, error) { e := tc.getEntry(target.Keyspace, target.Shard, target.TabletType) @@ -505,7 +364,7 @@ func (tc *TabletStatsCache) GetAggregateStats(target *querypb.Target) (*querypb. } } targetRegion := tc.getRegionByCell(target.Cell) - agg, ok := e.aggregatesPerRegion[targetRegion] + agg, ok := e.aggregates[targetRegion] if !ok { return nil, topo.NewError(topo.NoNode, topotools.TargetIdent(target)) } @@ -537,4 +396,3 @@ func (tc *TabletStatsCache) GetMasterCell(keyspace, shard string) (cell string, // Compile-time interface check. var _ HealthCheckStatsListener = (*TabletStatsCache)(nil) -var _ srvtopo.TargetStatsListener = (*TabletStatsCache)(nil) diff --git a/go/vt/logutil/console_logger.go b/go/vt/logutil/console_logger.go index ba9f752f234..b66def0c6a4 100644 --- a/go/vt/logutil/console_logger.go +++ b/go/vt/logutil/console_logger.go @@ -50,6 +50,16 @@ func (cl *ConsoleLogger) Errorf(format string, v ...interface{}) { cl.ErrorDepth(1, fmt.Sprintf(format, v...)) } +// Errorf2 is part of the Logger interface +func (cl *ConsoleLogger) Errorf2(err error, format string, v ...interface{}) { + cl.ErrorDepth(1, fmt.Sprintf(format+": %+v", append(v, err))) +} + +// Error is part of the Logger interface +func (cl *ConsoleLogger) Error(err error) { + cl.ErrorDepth(1, fmt.Sprintf("%+v", err)) +} + // Printf is part of the Logger interface func (cl *ConsoleLogger) Printf(format string, v ...interface{}) { fmt.Printf(format, v...) diff --git a/go/vt/logutil/logger.go b/go/vt/logutil/logger.go index d524584cc2d..7acb086d698 100644 --- a/go/vt/logutil/logger.go +++ b/go/vt/logutil/logger.go @@ -38,7 +38,10 @@ type Logger interface { Warningf(format string, v ...interface{}) // Errorf logs at ERROR level. A newline is appended if missing. Errorf(format string, v ...interface{}) + // Errorf2 logs an error with stack traces at ERROR level. A newline is appended if missing. + Errorf2(e error, message string, v ...interface{}) + Error(e error) // Printf will just display information on stdout when possible. // No newline is appended. Printf(format string, v ...interface{}) @@ -181,6 +184,16 @@ func (cl *CallbackLogger) Errorf(format string, v ...interface{}) { cl.ErrorDepth(1, fmt.Sprintf(format, v...)) } +// Errorf2 is part of the Logger interface +func (cl *CallbackLogger) Errorf2(err error, format string, v ...interface{}) { + cl.ErrorDepth(1, fmt.Sprintf(format+": %+v", append(v, err))) +} + +// Error is part of the Logger interface +func (cl *CallbackLogger) Error(err error) { + cl.ErrorDepth(1, fmt.Sprintf("%+v", err)) +} + // Printf is part of the Logger interface. func (cl *CallbackLogger) Printf(format string, v ...interface{}) { file, line := fileAndLine(2) @@ -321,6 +334,16 @@ func (tl *TeeLogger) Errorf(format string, v ...interface{}) { tl.ErrorDepth(1, fmt.Sprintf(format, v...)) } +// Errorf2 is part of the Logger interface +func (tl *TeeLogger) Errorf2(err error, format string, v ...interface{}) { + tl.ErrorDepth(1, fmt.Sprintf(format+": %+v", append(v, err))) +} + +// Error is part of the Logger interface +func (tl *TeeLogger) Error(err error) { + tl.ErrorDepth(1, fmt.Sprintf("%+v", err)) +} + // Printf is part of the Logger interface func (tl *TeeLogger) Printf(format string, v ...interface{}) { tl.One.Printf(format, v...) diff --git a/go/vt/mysqlctl/backup.go b/go/vt/mysqlctl/backup.go index 7120e9f6c7a..ccecaee9431 100644 --- a/go/vt/mysqlctl/backup.go +++ b/go/vt/mysqlctl/backup.go @@ -284,7 +284,7 @@ func Backup(ctx context.Context, cnf *Mycnf, mysqld MysqlDaemon, logger logutil. if usable { finishErr = bh.EndBackup(ctx) } else { - logger.Errorf("backup is not usable, aborting it: %v", err) + logger.Errorf2(err, "backup is not usable, aborting it") finishErr = bh.AbortBackup(ctx) } if err != nil { @@ -292,7 +292,7 @@ func Backup(ctx context.Context, cnf *Mycnf, mysqld MysqlDaemon, logger logutil. // We have a backup error, and we also failed // to finish the backup: just log the backup // finish error, return the backup error. - logger.Errorf("failed to finish backup: %v", finishErr) + logger.Errorf2(finishErr, "failed to finish backup: %v") } return err } @@ -493,7 +493,7 @@ func backupFile(ctx context.Context, cnf *Mycnf, mysqld MysqlDaemon, logger logu if rerr := wc.Close(); rerr != nil { if err != nil { // We already have an error, just log this one. - logger.Errorf("failed to close file %v: %v", name, rerr) + logger.Errorf2(rerr, "failed to close file %v", name) } else { err = rerr } @@ -797,7 +797,7 @@ func Restore( } if !ok { logger.Infof("Auto-restore is enabled, but mysqld already contains data. Assuming vttablet was just restarted.") - if err = populateMetadataTables(mysqld, localMetadata); err == nil { + if err = PopulateMetadataTables(mysqld, localMetadata); err == nil { err = ErrExistingDB } return mysql.Position{}, err @@ -820,7 +820,7 @@ func Restore( if len(bhs) == 0 { // There are no backups (not even broken/incomplete ones). logger.Errorf("No backup to restore on BackupStorage for directory %v. Starting up empty.", dir) - if err = populateMetadataTables(mysqld, localMetadata); err == nil { + if err = PopulateMetadataTables(mysqld, localMetadata); err == nil { err = ErrNoBackup } return mysql.Position{}, err @@ -901,7 +901,7 @@ func Restore( // Populate local_metadata before starting without --skip-networking, // so it's there before we start announcing ourselves. logger.Infof("Restore: populating local_metadata") - err = populateMetadataTables(mysqld, localMetadata) + err = PopulateMetadataTables(mysqld, localMetadata) if err != nil { return mysql.Position{}, err } diff --git a/go/vt/mysqlctl/metadata_tables.go b/go/vt/mysqlctl/metadata_tables.go index ac911ccfdbc..5fd4aee1134 100644 --- a/go/vt/mysqlctl/metadata_tables.go +++ b/go/vt/mysqlctl/metadata_tables.go @@ -36,7 +36,7 @@ const sqlCreateShardMetadataTable = `CREATE TABLE IF NOT EXISTS _vt.shard_metada PRIMARY KEY (name) ) ENGINE=InnoDB` -// populateMetadataTables creates and fills the _vt.local_metadata table and +// PopulateMetadataTables creates and fills the _vt.local_metadata table and // creates _vt.shard_metadata table. _vt.local_metadata table is // a per-tablet table that is never replicated. This allows queries // against local_metadata to return different values on different tablets, @@ -46,7 +46,7 @@ const sqlCreateShardMetadataTable = `CREATE TABLE IF NOT EXISTS _vt.shard_metada // created here to make it easier to create it on databases that were running // old version of Vitess, or databases that are getting converted to run under // Vitess. -func populateMetadataTables(mysqld MysqlDaemon, localMetadata map[string]string) error { +func PopulateMetadataTables(mysqld MysqlDaemon, localMetadata map[string]string) error { log.Infof("Populating _vt.local_metadata table...") // Get a non-pooled DBA connection. diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 98e7aa64553..df3dfd75dd8 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -3691,7 +3691,7 @@ mustEscape: } func compliantName(in string) string { - var buf bytes.Buffer + var buf strings.Builder for i, c := range in { if !isLetter(uint16(c)) { if i == 0 || !isDigit(uint16(c)) { diff --git a/go/vt/srvtopo/keyspace_filtering_server.go b/go/vt/srvtopo/keyspace_filtering_server.go new file mode 100644 index 00000000000..32b264001cc --- /dev/null +++ b/go/vt/srvtopo/keyspace_filtering_server.go @@ -0,0 +1,116 @@ +/* +Copyright 2018 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package srvtopo + +import ( + "fmt" + + "golang.org/x/net/context" + + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vschemapb "vitess.io/vitess/go/vt/proto/vschema" + "vitess.io/vitess/go/vt/topo" +) + +var ( + // ErrNilUnderlyingServer is returned when attempting to create a new keyspace + // filtering server if a nil underlying server implementation is provided. + ErrNilUnderlyingServer = fmt.Errorf("Unable to construct filtering server without an underlying server") + + // ErrTopoServerNotAvailable is returned if a caller tries to access the + // topo.Server supporting this srvtopo.Server. + ErrTopoServerNotAvailable = fmt.Errorf("Cannot access underlying topology server when keyspace filtering is enabled") +) + +// NewKeyspaceFilteringServer constructs a new server based on the provided +// implementation that prevents the specified keyspaces from being exposed +// to consumers of the new Server. +// +// A filtering server will not allow access to the topo.Server to prevent +// updates that may corrupt the global VSchema keyspace. +func NewKeyspaceFilteringServer(underlying Server, selectedKeyspaces []string) (Server, error) { + if underlying == nil { + return nil, ErrNilUnderlyingServer + } + + keyspaces := map[string]bool{} + for _, ks := range selectedKeyspaces { + keyspaces[ks] = true + } + + return keyspaceFilteringServer{ + server: underlying, + selectKeyspaces: keyspaces, + }, nil +} + +type keyspaceFilteringServer struct { + server Server + selectKeyspaces map[string]bool +} + +// GetTopoServer returns an error; filtering srvtopo.Server consumers may not +// access the underlying topo.Server. +func (ksf keyspaceFilteringServer) GetTopoServer() (*topo.Server, error) { + return nil, ErrTopoServerNotAvailable +} + +func (ksf keyspaceFilteringServer) GetSrvKeyspaceNames( + ctx context.Context, + cell string, +) ([]string, error) { + keyspaces, err := ksf.server.GetSrvKeyspaceNames(ctx, cell) + ret := make([]string, 0, len(keyspaces)) + for _, ks := range keyspaces { + if ksf.selectKeyspaces[ks] { + ret = append(ret, ks) + } + } + return ret, err +} + +func (ksf keyspaceFilteringServer) GetSrvKeyspace( + ctx context.Context, + cell, + keyspace string, +) (*topodatapb.SrvKeyspace, error) { + if !ksf.selectKeyspaces[keyspace] { + return nil, topo.NewError(topo.NoNode, keyspace) + } + + return ksf.server.GetSrvKeyspace(ctx, cell, keyspace) +} + +func (ksf keyspaceFilteringServer) WatchSrvVSchema( + ctx context.Context, + cell string, + callback func(*vschemapb.SrvVSchema, error), +) { + filteringCallback := func(schema *vschemapb.SrvVSchema, err error) { + if schema != nil { + for ks := range schema.Keyspaces { + if !ksf.selectKeyspaces[ks] { + delete(schema.Keyspaces, ks) + } + } + } + + callback(schema, err) + } + + ksf.server.WatchSrvVSchema(ctx, cell, filteringCallback) +} diff --git a/go/vt/srvtopo/keyspace_filtering_server_test.go b/go/vt/srvtopo/keyspace_filtering_server_test.go new file mode 100644 index 00000000000..8c0560dfbe0 --- /dev/null +++ b/go/vt/srvtopo/keyspace_filtering_server_test.go @@ -0,0 +1,229 @@ +/* +Copyright 2018 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package srvtopo + +import ( + "fmt" + "reflect" + "sync" + "testing" + + "golang.org/x/net/context" + + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vschemapb "vitess.io/vitess/go/vt/proto/vschema" + "vitess.io/vitess/go/vt/srvtopo/srvtopotest" + "vitess.io/vitess/go/vt/topo" + "vitess.io/vitess/go/vt/topo/memorytopo" +) + +var ( + stockCell = "some-cell" + stockCtx = context.Background() + stockFilters = []string{"bar", "baz"} + stockKeyspaces = map[string]*topodatapb.SrvKeyspace{ + "foo": {ShardingColumnName: "foo"}, + "bar": {ShardingColumnName: "bar"}, + "baz": {ShardingColumnName: "baz"}, + } + stockVSchema = &vschemapb.SrvVSchema{ + Keyspaces: map[string]*vschemapb.Keyspace{ + "foo": {Sharded: true}, + "bar": {Sharded: true}, + "baz": {Sharded: false}, + }, + } +) + +func newFiltering(filter []string) (*topo.Server, *srvtopotest.PassthroughSrvTopoServer, Server) { + testServer := srvtopotest.NewPassthroughSrvTopoServer() + + testServer.TopoServer = memorytopo.NewServer(stockCell) + testServer.SrvKeyspaceNames = []string{"foo", "bar", "baz"} + testServer.SrvKeyspace = &topodatapb.SrvKeyspace{ShardingColumnName: "test-column"} + testServer.WatchedSrvVSchema = stockVSchema + + filtering, _ := NewKeyspaceFilteringServer(testServer, filter) + return testServer.TopoServer, testServer, filtering +} + +func TestFilteringServerHandlesNilUnderlying(t *testing.T) { + got, gotErr := NewKeyspaceFilteringServer(nil, []string{}) + if got != nil { + t.Errorf("got: %v, wanted: nil server", got) + } + if gotErr != ErrNilUnderlyingServer { + t.Errorf("Bad error returned: got %v wanted %v", gotErr, ErrNilUnderlyingServer) + } +} + +func TestFilteringServerReturnsUnderlyingServer(t *testing.T) { + _, _, f := newFiltering(nil) + got, gotErr := f.GetTopoServer() + if got != nil { + t.Errorf("Got non-nil topo.Server from FilteringServer") + } + if gotErr != ErrTopoServerNotAvailable { + t.Errorf("Unexpected error from GetTopoServer; wanted %v but got %v", ErrTopoServerNotAvailable, gotErr) + } +} + +func doTestGetSrvKeyspaceNames( + t *testing.T, + f Server, + cell string, + want []string, + wantErr error, +) { + got, gotErr := f.GetSrvKeyspaceNames(stockCtx, cell) + + if got == nil { + t.Errorf("GetSrvKeyspaceNames failed: should not return nil") + } + if !reflect.DeepEqual(got, want) { + t.Errorf("GetSrvKeyspaceNames failed: want %v, got %v", want, got) + } + if wantErr != gotErr { + t.Errorf("GetSrvKeyspaceNames returned incorrect error: want %v, got %v", wantErr, gotErr) + } +} + +func TestFilteringServerGetSrvKeyspameNamesFiltersEverythingOut(t *testing.T) { + _, _, f := newFiltering(nil) + doTestGetSrvKeyspaceNames(t, f, stockCell, []string{}, nil) +} + +func TestFilteringServerGetSrvKeyspaceNamesFiltersKeyspaces(t *testing.T) { + _, _, f := newFiltering(stockFilters) + doTestGetSrvKeyspaceNames(t, f, stockCell, stockFilters, nil) +} + +func TestFilteringServerGetSrvKeyspaceNamesPassesThroughErrors(t *testing.T) { + _, mock, f := newFiltering(stockFilters) + wantErr := fmt.Errorf("some badcell error") + mock.SrvKeyspaceNamesError = wantErr + doTestGetSrvKeyspaceNames(t, f, "badcell", stockFilters, wantErr) +} + +func doTestGetSrvKeyspace( + t *testing.T, + f Server, + cell, + ksName string, + want *topodatapb.SrvKeyspace, + wantErr error, +) { + got, gotErr := f.GetSrvKeyspace(stockCtx, cell, ksName) + + gotColumnName := "" + wantColumnName := "" + if got != nil { + gotColumnName = got.ShardingColumnName + } + if want != nil { + wantColumnName = want.ShardingColumnName + } + + // a different pointer comes back so compare the expected return by proxy + // of a field we know the value of + if gotColumnName != wantColumnName { + t.Errorf("keyspace incorrect: got %v, want %v", got, want) + } + + if wantErr != gotErr { + t.Errorf("returned error incorrect: got %v, want %v", gotErr, wantErr) + } +} + +func TestFilteringServerGetSrvKeyspaceReturnsSelectedKeyspaces(t *testing.T) { + _, mock, f := newFiltering(stockFilters) + mock.SrvKeyspace = stockKeyspaces["bar"] + doTestGetSrvKeyspace(t, f, stockCell, "bar", stockKeyspaces["bar"], nil) +} + +func TestFilteringServerGetSrvKeyspaceErrorPassthrough(t *testing.T) { + wantErr := fmt.Errorf("some error") + _, mock, f := newFiltering(stockFilters) + mock.SrvKeyspace = stockKeyspaces["bar"] + mock.SrvKeyspaceError = wantErr + doTestGetSrvKeyspace(t, f, "badcell", "bar", stockKeyspaces["bar"], wantErr) +} + +func TestFilteringServerGetSrvKeyspaceFilters(t *testing.T) { + wantErr := topo.NewError(topo.NoNode, "foo") + _, mock, f := newFiltering(stockFilters) + mock.SrvKeyspaceError = wantErr + doTestGetSrvKeyspace(t, f, stockCell, "foo", nil, wantErr) +} + +func TestFilteringServerWatchSrvVSchemaFiltersPassthroughSrvVSchema(t *testing.T) { + _, mock, f := newFiltering(stockFilters) + + allowed := map[string]bool{} + for _, ks := range stockFilters { + allowed[ks] = true + } + + // we need to verify that the nested callback actually gets called + wg := sync.WaitGroup{} + wg.Add(1) + + cb := func(gotSchema *vschemapb.SrvVSchema, gotErr error) { + // ensure that only selected keyspaces made it into the callback + for name, ks := range gotSchema.Keyspaces { + if !allowed[name] { + t.Errorf("Unexpected keyspace found in callback: %v", ks) + } + wantKS := mock.WatchedSrvVSchema.Keyspaces[name] + if !reflect.DeepEqual(ks, wantKS) { + t.Errorf( + "Expected keyspace to be passed through unmodified: want %#v got %#v", + wantKS, + ks, + ) + } + } + wg.Done() + } + + f.WatchSrvVSchema(stockCtx, stockCell, cb) + wg.Wait() +} + +func TestFilteringServerWatchSrvVSchemaHandlesNilSchema(t *testing.T) { + wantErr := fmt.Errorf("some err") + _, mock, f := newFiltering(stockFilters) + mock.WatchedSrvVSchema = nil + mock.WatchedSrvVSchemaError = wantErr + + // we need to verify that the nested callback actually gets called + wg := sync.WaitGroup{} + wg.Add(1) + + cb := func(gotSchema *vschemapb.SrvVSchema, gotErr error) { + if gotSchema != nil { + t.Errorf("Expected nil gotSchema: got %#v", gotSchema) + } + if gotErr != wantErr { + t.Errorf("Unexpected error: want %v got %v", wantErr, gotErr) + } + wg.Done() + } + + f.WatchSrvVSchema(stockCtx, "other-cell", cb) + wg.Wait() +} diff --git a/go/vt/srvtopo/resilient_server.go b/go/vt/srvtopo/resilient_server.go index adcd2f0751c..6340bfe6551 100644 --- a/go/vt/srvtopo/resilient_server.go +++ b/go/vt/srvtopo/resilient_server.go @@ -223,8 +223,8 @@ func NewResilientServer(base *topo.Server, counterPrefix string) *ResilientServe } // GetTopoServer returns the topo.Server that backs the resilient server. -func (server *ResilientServer) GetTopoServer() *topo.Server { - return server.topoServer +func (server *ResilientServer) GetTopoServer() (*topo.Server, error) { + return server.topoServer, nil } // GetSrvKeyspaceNames returns all keyspace names for the given cell. diff --git a/go/vt/srvtopo/server.go b/go/vt/srvtopo/server.go index bb7b0cc6be1..b783c78d3f8 100644 --- a/go/vt/srvtopo/server.go +++ b/go/vt/srvtopo/server.go @@ -32,8 +32,8 @@ import ( // the serving graph read-only calls used by clients to resolve // serving addresses, and to get VSchema. type Server interface { - // GetTopoServer returns the full topo.Server instance - GetTopoServer() *topo.Server + // GetTopoServer returns the full topo.Server instance. + GetTopoServer() (*topo.Server, error) // GetSrvKeyspaceNames returns the list of keyspaces served in // the provided cell. diff --git a/go/vt/srvtopo/srvtopotest/passthrough.go b/go/vt/srvtopo/srvtopotest/passthrough.go new file mode 100644 index 00000000000..09b0bae2fa7 --- /dev/null +++ b/go/vt/srvtopo/srvtopotest/passthrough.go @@ -0,0 +1,65 @@ +/* +Copyright 2018 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package srvtopotest + +import ( + "golang.org/x/net/context" + + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vschemapb "vitess.io/vitess/go/vt/proto/vschema" + "vitess.io/vitess/go/vt/topo" +) + +// PassthroughSrvTopoServer is a bare implementation of srvtopo.Server for use in tests +type PassthroughSrvTopoServer struct { + TopoServer *topo.Server + TopoServerError error + + SrvKeyspaceNames []string + SrvKeyspaceNamesError error + + SrvKeyspace *topodatapb.SrvKeyspace + SrvKeyspaceError error + + WatchedSrvVSchema *vschemapb.SrvVSchema + WatchedSrvVSchemaError error +} + +// NewPassthroughSrvTopoServer returns a new, unconfigured test PassthroughSrvTopoServer +func NewPassthroughSrvTopoServer() *PassthroughSrvTopoServer { + return &PassthroughSrvTopoServer{} +} + +// GetTopoServer implements srvtopo.Server +func (srv *PassthroughSrvTopoServer) GetTopoServer() (*topo.Server, error) { + return srv.TopoServer, srv.TopoServerError +} + +// GetSrvKeyspaceNames implements srvtopo.Server +func (srv *PassthroughSrvTopoServer) GetSrvKeyspaceNames(ctx context.Context, cell string) ([]string, error) { + return srv.SrvKeyspaceNames, srv.SrvKeyspaceNamesError +} + +// GetSrvKeyspace implements srvtopo.Server +func (srv *PassthroughSrvTopoServer) GetSrvKeyspace(ctx context.Context, cell, keyspace string) (*topodatapb.SrvKeyspace, error) { + return srv.SrvKeyspace, srv.SrvKeyspaceError +} + +// WatchSrvVSchema implements srvtopo.Server +func (srv *PassthroughSrvTopoServer) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error)) { + callback(srv.WatchedSrvVSchema, srv.WatchedSrvVSchemaError) +} diff --git a/go/vt/srvtopo/target_stats.go b/go/vt/srvtopo/target_stats.go index ed49e41f02f..6af44eedb81 100644 --- a/go/vt/srvtopo/target_stats.go +++ b/go/vt/srvtopo/target_stats.go @@ -27,9 +27,6 @@ import ( // routing of queries. // - discovery.TabletStatsCache will implement the discovery part of the // interface, and discoverygateway will have the QueryService. -// - hybridgateway will also implement this interface: for each l2vtgate pool, -// it will establish a StreamHealth connection, and store the returned -// health stats. type TargetStats interface { // GetAggregateStats returns the aggregate stats for the given Target. // The srvtopo module will use that information to route queries @@ -45,23 +42,6 @@ type TargetStats interface { GetMasterCell(keyspace, shard string) (cell string, qs queryservice.QueryService, err error) } -// TargetStatsListener is an interface used to propagate TargetStats changes. -// - discovery.TabletStatsCache will implement this interface. -// - the StreamHealth method in l2vtgate will use this interface to surface -// the health of its targets. -type TargetStatsListener interface { - // Subscribe will return the current full state of the TargetStats, - // and a channel that will receive subsequent updates. The int returned - // is the channel id, and can be sent to unsubscribe to stop - // notifications. - Subscribe() (int, []TargetStatsEntry, <-chan (*TargetStatsEntry), error) - - // Unsubscribe stops sending updates to the channel returned - // by Subscribe. The channel still needs to be drained to - // avoid deadlocks. - Unsubscribe(int) error -} - // TargetStatsEntry has the updated information for a Target. type TargetStatsEntry struct { // Target is what this entry applies to. diff --git a/go/vt/topo/helpers/compare.go b/go/vt/topo/helpers/compare.go new file mode 100644 index 00000000000..a14f3f83a03 --- /dev/null +++ b/go/vt/topo/helpers/compare.go @@ -0,0 +1,183 @@ +/* +Copyright 2018 The Vitess Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package helpers contains a few utility classes to handle topo.Server +// objects, and transitions from one topo implementation to another. +package helpers + +import ( + "fmt" + "reflect" + + "golang.org/x/net/context" + "vitess.io/vitess/go/vt/topo" +) + +// CompareKeyspaces will compare the keyspaces in the destination topo. +func CompareKeyspaces(ctx context.Context, fromTS, toTS *topo.Server) error { + keyspaces, err := fromTS.GetKeyspaces(ctx) + if err != nil { + return fmt.Errorf("GetKeyspace(%v): %v", keyspaces, err) + } + + for _, keyspace := range keyspaces { + + fromKs, err := fromTS.GetKeyspace(ctx, keyspace) + if err != nil { + return fmt.Errorf("GetKeyspace(%v): %v", keyspace, err) + } + + toKs, err := toTS.GetKeyspace(ctx, keyspace) + if err != nil { + return fmt.Errorf("GetKeyspace(%v): %v", keyspace, err) + } + + if !reflect.DeepEqual(fromKs.Keyspace, toKs.Keyspace) { + return fmt.Errorf("Keyspace: %v does not match between from and to topology", keyspace) + } + + fromVs, err := fromTS.GetVSchema(ctx, keyspace) + switch { + case err == nil: + // Nothing to do. + case topo.IsErrType(err, topo.NoNode): + // Nothing to do. + default: + return fmt.Errorf("GetVSchema(%v): %v", keyspace, err) + } + + toVs, err := toTS.GetVSchema(ctx, keyspace) + switch { + case err == nil: + // Nothing to do. + case topo.IsErrType(err, topo.NoNode): + // Nothing to do. + default: + return fmt.Errorf("GetVSchema(%v): %v", keyspace, err) + } + + if !reflect.DeepEqual(fromVs, toVs) { + return fmt.Errorf("Vschema for keyspace: %v does not match between from and to topology", keyspace) + } + } + return nil +} + +// CompareShards will compare the shards in the destination topo. +func CompareShards(ctx context.Context, fromTS, toTS *topo.Server) error { + keyspaces, err := fromTS.GetKeyspaces(ctx) + if err != nil { + return fmt.Errorf("fromTS.GetKeyspaces: %v", err) + } + + for _, keyspace := range keyspaces { + shards, err := fromTS.GetShardNames(ctx, keyspace) + if err != nil { + return fmt.Errorf("GetShardNames(%v): %v", keyspace, err) + } + + for _, shard := range shards { + fromSi, err := fromTS.GetShard(ctx, keyspace, shard) + if err != nil { + return fmt.Errorf("GetShard(%v, %v): %v", keyspace, shard, err) + } + toSi, err := toTS.GetShard(ctx, keyspace, shard) + if err != nil { + return fmt.Errorf("GetShard(%v, %v): %v", keyspace, shard, err) + } + + if !reflect.DeepEqual(fromSi.Shard, toSi.Shard) { + return fmt.Errorf("Shard %v for keyspace: %v does not match between from and to topology", shard, keyspace) + } + } + } + return nil +} + +// CompareTablets will compare the tablets in the destination topo. +func CompareTablets(ctx context.Context, fromTS, toTS *topo.Server) error { + cells, err := fromTS.GetKnownCells(ctx) + if err != nil { + return fmt.Errorf("fromTS.GetKnownCells: %v", err) + } + + for _, cell := range cells { + tabletAliases, err := fromTS.GetTabletsByCell(ctx, cell) + if err != nil { + return fmt.Errorf("GetTabletsByCell(%v): %v", cell, err) + } + for _, tabletAlias := range tabletAliases { + + // read the source tablet + fromTi, err := fromTS.GetTablet(ctx, tabletAlias) + if err != nil { + return fmt.Errorf("GetTablet(%v): %v", tabletAlias, err) + } + toTi, err := toTS.GetTablet(ctx, tabletAlias) + if err != nil { + return fmt.Errorf("GetTablet(%v): %v", tabletAlias, err) + } + if !reflect.DeepEqual(fromTi.Tablet, toTi.Tablet) { + return fmt.Errorf("Tablet %v: does not match between from and to topology", tabletAlias) + } + } + } + return nil +} + +// CompareShardReplications will compare the ShardReplication objects in +// the destination topo. +func CompareShardReplications(ctx context.Context, fromTS, toTS *topo.Server) error { + keyspaces, err := fromTS.GetKeyspaces(ctx) + if err != nil { + return fmt.Errorf("fromTS.GetKeyspaces: %v", err) + } + + for _, keyspace := range keyspaces { + shards, err := fromTS.GetShardNames(ctx, keyspace) + if err != nil { + return fmt.Errorf("GetShardNames(%v): %v", keyspace, err) + } + + for _, shard := range shards { + + // read the source shard to get the cells + si, err := fromTS.GetShard(ctx, keyspace, shard) + if err != nil { + return fmt.Errorf("GetShard(%v, %v): %v", keyspace, shard, err) + } + + for _, cell := range si.Shard.Cells { + fromSRi, err := fromTS.GetShardReplication(ctx, cell, keyspace, shard) + if err != nil { + return fmt.Errorf("GetShardReplication(%v, %v, %v): %v", cell, keyspace, shard, err) + } + toSRi, err := toTS.GetShardReplication(ctx, cell, keyspace, shard) + if err != nil { + return fmt.Errorf("GetShardReplication(%v, %v, %v): %v", cell, keyspace, shard, err) + } + if !reflect.DeepEqual(fromSRi.ShardReplication, toSRi.ShardReplication) { + return fmt.Errorf( + "Shard Replication in cell %v, keyspace %v, shard %v: does not match between from and to topology", + cell, + keyspace, + shard) + } + } + } + } + return nil +} diff --git a/go/vt/topo/helpers/compare_test.go b/go/vt/topo/helpers/compare_test.go new file mode 100644 index 00000000000..cda8f325a8d --- /dev/null +++ b/go/vt/topo/helpers/compare_test.go @@ -0,0 +1,80 @@ +/* +Copyright 2017 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package helpers + +import ( + "testing" + + "golang.org/x/net/context" +) + +func TestBasicCompare(t *testing.T) { + ctx := context.Background() + fromTS, toTS := createSetup(ctx, t) + + // check compare keyspace compare + err := CompareKeyspaces(ctx, fromTS, toTS) + if err == nil { + t.Fatalf("Compare keyspaces is not failing when topos are not in sync") + } + + CopyKeyspaces(ctx, fromTS, toTS) + + err = CompareKeyspaces(ctx, fromTS, toTS) + if err != nil { + t.Fatalf("Compare keyspaces failed: %v", err) + } + + // check shard copy + err = CompareShards(ctx, fromTS, toTS) + if err == nil { + t.Fatalf("Compare shards is not failing when topos are not in sync") + } + + CopyShards(ctx, fromTS, toTS) + + err = CompareShards(ctx, fromTS, toTS) + if err != nil { + t.Fatalf("Compare shards failed: %v", err) + } + + // check ShardReplication compare + err = CompareShardReplications(ctx, fromTS, toTS) + if err == nil { + t.Fatalf("Compare shard replications is not failing when topos are not in sync") + } + + CopyShardReplications(ctx, fromTS, toTS) + + err = CompareShardReplications(ctx, fromTS, toTS) + if err != nil { + t.Fatalf("Compare shard replications failed: %v", err) + } + + // check tablet compare + err = CompareTablets(ctx, fromTS, toTS) + if err == nil { + t.Fatalf("Compare tablets is not failing when topos are not in sync") + } + + CopyTablets(ctx, fromTS, toTS) + + err = CompareTablets(ctx, fromTS, toTS) + if err != nil { + t.Fatalf("Compare tablets failed: %v", err) + } +} diff --git a/go/vt/topo/helpers/copy.go b/go/vt/topo/helpers/copy.go index bc314ed34c6..4b19870b9e1 100644 --- a/go/vt/topo/helpers/copy.go +++ b/go/vt/topo/helpers/copy.go @@ -19,11 +19,7 @@ limitations under the License. package helpers import ( - "fmt" - "sync" - "golang.org/x/net/context" - "vitess.io/vitess/go/vt/concurrency" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/topo" @@ -37,44 +33,32 @@ func CopyKeyspaces(ctx context.Context, fromTS, toTS *topo.Server) { log.Fatalf("GetKeyspaces: %v", err) } - wg := sync.WaitGroup{} - rec := concurrency.AllErrorRecorder{} for _, keyspace := range keyspaces { - wg.Add(1) - go func(keyspace string) { - defer wg.Done() - ki, err := fromTS.GetKeyspace(ctx, keyspace) - if err != nil { - rec.RecordError(fmt.Errorf("GetKeyspace(%v): %v", keyspace, err)) - return - } + ki, err := fromTS.GetKeyspace(ctx, keyspace) + if err != nil { + log.Fatalf("GetKeyspace(%v): %v", keyspace, err) + } - if err := toTS.CreateKeyspace(ctx, keyspace, ki.Keyspace); err != nil { - if topo.IsErrType(err, topo.NodeExists) { - log.Warningf("keyspace %v already exists", keyspace) - } else { - rec.RecordError(fmt.Errorf("CreateKeyspace(%v): %v", keyspace, err)) - } + if err := toTS.CreateKeyspace(ctx, keyspace, ki.Keyspace); err != nil { + if topo.IsErrType(err, topo.NodeExists) { + log.Warningf("keyspace %v already exists", keyspace) + } else { + log.Errorf("CreateKeyspace(%v): %v", keyspace, err) } + } - vs, err := fromTS.GetVSchema(ctx, keyspace) - switch { - case err == nil: - if err := toTS.SaveVSchema(ctx, keyspace, vs); err != nil { - rec.RecordError(fmt.Errorf("SaveVSchema(%v): %v", keyspace, err)) - } - case topo.IsErrType(err, topo.NoNode): - // Nothing to do. - default: - rec.RecordError(fmt.Errorf("GetVSchema(%v): %v", keyspace, err)) - return + vs, err := fromTS.GetVSchema(ctx, keyspace) + switch { + case err == nil: + if err := toTS.SaveVSchema(ctx, keyspace, vs); err != nil { + log.Errorf("SaveVSchema(%v): %v", keyspace, err) } - }(keyspace) - } - wg.Wait() - if rec.HasErrors() { - log.Fatalf("copyKeyspaces failed: %v", rec.Error()) + case topo.IsErrType(err, topo.NoNode): + // Nothing to do. + default: + log.Errorf("GetVSchema(%v): %v", keyspace, err) + } } } @@ -85,51 +69,34 @@ func CopyShards(ctx context.Context, fromTS, toTS *topo.Server) { log.Fatalf("fromTS.GetKeyspaces: %v", err) } - wg := sync.WaitGroup{} - rec := concurrency.AllErrorRecorder{} for _, keyspace := range keyspaces { - wg.Add(1) - go func(keyspace string) { - defer wg.Done() - shards, err := fromTS.GetShardNames(ctx, keyspace) + shards, err := fromTS.GetShardNames(ctx, keyspace) + if err != nil { + log.Fatalf("GetShardNames(%v): %v", keyspace, err) + return + } + + for _, shard := range shards { + + si, err := fromTS.GetShard(ctx, keyspace, shard) if err != nil { - rec.RecordError(fmt.Errorf("GetShardNames(%v): %v", keyspace, err)) - return + log.Fatalf("GetShard(%v, %v): %v", keyspace, shard, err) } - for _, shard := range shards { - wg.Add(1) - go func(keyspace, shard string) { - defer wg.Done() - - si, err := fromTS.GetShard(ctx, keyspace, shard) - if err != nil { - rec.RecordError(fmt.Errorf("GetShard(%v, %v): %v", keyspace, shard, err)) - return - } - - if err := toTS.CreateShard(ctx, keyspace, shard); err != nil { - if topo.IsErrType(err, topo.NodeExists) { - log.Warningf("shard %v/%v already exists", keyspace, shard) - } else { - rec.RecordError(fmt.Errorf("CreateShard(%v, %v): %v", keyspace, shard, err)) - return - } - } - if _, err := toTS.UpdateShardFields(ctx, keyspace, shard, func(toSI *topo.ShardInfo) error { - *toSI.Shard = *si.Shard - return nil - }); err != nil { - rec.RecordError(fmt.Errorf("UpdateShardFields(%v, %v): %v", keyspace, shard, err)) - return - } - }(keyspace, shard) + if err := toTS.CreateShard(ctx, keyspace, shard); err != nil { + if topo.IsErrType(err, topo.NodeExists) { + log.Warningf("shard %v/%v already exists", keyspace, shard) + } else { + log.Fatalf("CreateShard(%v, %v): %v", keyspace, shard, err) + } } - }(keyspace) - } - wg.Wait() - if rec.HasErrors() { - log.Fatalf("copyShards failed: %v", rec.Error()) + if _, err := toTS.UpdateShardFields(ctx, keyspace, shard, func(toSI *topo.ShardInfo) error { + *toSI.Shard = *si.Shard + return nil + }); err != nil { + log.Fatalf("UpdateShardFields(%v, %v): %v", keyspace, shard, err) + } + } } } @@ -140,50 +107,34 @@ func CopyTablets(ctx context.Context, fromTS, toTS *topo.Server) { log.Fatalf("fromTS.GetKnownCells: %v", err) } - wg := sync.WaitGroup{} - rec := concurrency.AllErrorRecorder{} for _, cell := range cells { - wg.Add(1) - go func(cell string) { - defer wg.Done() - tabletAliases, err := fromTS.GetTabletsByCell(ctx, cell) - if err != nil { - rec.RecordError(fmt.Errorf("GetTabletsByCell(%v): %v", cell, err)) - } else { - for _, tabletAlias := range tabletAliases { - wg.Add(1) - go func(tabletAlias *topodatapb.TabletAlias) { - defer wg.Done() - - // read the source tablet - ti, err := fromTS.GetTablet(ctx, tabletAlias) - if err != nil { - rec.RecordError(fmt.Errorf("GetTablet(%v): %v", tabletAlias, err)) - return - } - - // try to create the destination - err = toTS.CreateTablet(ctx, ti.Tablet) - if topo.IsErrType(err, topo.NodeExists) { - // update the destination tablet - log.Warningf("tablet %v already exists, updating it", tabletAlias) - _, err = toTS.UpdateTabletFields(ctx, tabletAlias, func(t *topodatapb.Tablet) error { - *t = *ti.Tablet - return nil - }) - } - if err != nil { - rec.RecordError(fmt.Errorf("CreateTablet(%v): %v", tabletAlias, err)) - return - } - }(tabletAlias) + tabletAliases, err := fromTS.GetTabletsByCell(ctx, cell) + if err != nil { + log.Fatalf("GetTabletsByCell(%v): %v", cell, err) + } else { + for _, tabletAlias := range tabletAliases { + + // read the source tablet + ti, err := fromTS.GetTablet(ctx, tabletAlias) + if err != nil { + log.Fatalf("GetTablet(%v): %v", tabletAlias, err) + } + + // try to create the destination + err = toTS.CreateTablet(ctx, ti.Tablet) + if topo.IsErrType(err, topo.NodeExists) { + // update the destination tablet + log.Warningf("tablet %v already exists, updating it", tabletAlias) + _, err = toTS.UpdateTabletFields(ctx, tabletAlias, func(t *topodatapb.Tablet) error { + *t = *ti.Tablet + return nil + }) + } + if err != nil { + log.Fatalf("CreateTablet(%v): %v", tabletAlias, err) } } - }(cell) - } - wg.Wait() - if rec.HasErrors() { - log.Fatalf("copyTablets failed: %v", rec.Error()) + } } } @@ -195,50 +146,34 @@ func CopyShardReplications(ctx context.Context, fromTS, toTS *topo.Server) { log.Fatalf("fromTS.GetKeyspaces: %v", err) } - wg := sync.WaitGroup{} - rec := concurrency.AllErrorRecorder{} for _, keyspace := range keyspaces { - wg.Add(1) - go func(keyspace string) { - defer wg.Done() - shards, err := fromTS.GetShardNames(ctx, keyspace) + shards, err := fromTS.GetShardNames(ctx, keyspace) + if err != nil { + log.Fatalf("GetShardNames(%v): %v", keyspace, err) + } + + for _, shard := range shards { + + // read the source shard to get the cells + si, err := fromTS.GetShard(ctx, keyspace, shard) if err != nil { - rec.RecordError(fmt.Errorf("GetShardNames(%v): %v", keyspace, err)) - return + log.Fatalf("GetShard(%v, %v): %v", keyspace, shard, err) } - for _, shard := range shards { - wg.Add(1) - go func(keyspace, shard string) { - defer wg.Done() - - // read the source shard to get the cells - si, err := fromTS.GetShard(ctx, keyspace, shard) - if err != nil { - rec.RecordError(fmt.Errorf("GetShard(%v, %v): %v", keyspace, shard, err)) - return - } - - for _, cell := range si.Shard.Cells { - sri, err := fromTS.GetShardReplication(ctx, cell, keyspace, shard) - if err != nil { - rec.RecordError(fmt.Errorf("GetShardReplication(%v, %v, %v): %v", cell, keyspace, shard, err)) - continue - } - - if err := toTS.UpdateShardReplicationFields(ctx, cell, keyspace, shard, func(oldSR *topodatapb.ShardReplication) error { - *oldSR = *sri.ShardReplication - return nil - }); err != nil { - rec.RecordError(fmt.Errorf("UpdateShardReplicationFields(%v, %v, %v): %v", cell, keyspace, shard, err)) - } - } - }(keyspace, shard) + for _, cell := range si.Shard.Cells { + sri, err := fromTS.GetShardReplication(ctx, cell, keyspace, shard) + if err != nil { + log.Fatalf("GetShardReplication(%v, %v, %v): %v", cell, keyspace, shard, err) + continue + } + + if err := toTS.UpdateShardReplicationFields(ctx, cell, keyspace, shard, func(oldSR *topodatapb.ShardReplication) error { + *oldSR = *sri.ShardReplication + return nil + }); err != nil { + log.Warningf("UpdateShardReplicationFields(%v, %v, %v): %v", cell, keyspace, shard, err) + } } - }(keyspace) - } - wg.Wait() - if rec.HasErrors() { - log.Fatalf("copyShards failed: %v", rec.Error()) + } } } diff --git a/go/vt/topo/topoproto/destination.go b/go/vt/topo/topoproto/destination.go index 07a66947418..f4886fdd6ef 100644 --- a/go/vt/topo/topoproto/destination.go +++ b/go/vt/topo/topoproto/destination.go @@ -17,6 +17,7 @@ limitations under the License. package topoproto import ( + "encoding/hex" "strings" "vitess.io/vitess/go/vt/key" @@ -45,23 +46,32 @@ func ParseDestination(targetString string, defaultTabletType topodatapb.TabletTy dest = key.DestinationShard(targetString[last+1:]) targetString = targetString[:last] } - // Try to parse it as a range + // Try to parse it as a keyspace id or range last = strings.LastIndexAny(targetString, "[") if last != -1 { rangeEnd := strings.LastIndexAny(targetString, "]") if rangeEnd == -1 { return keyspace, tabletType, dest, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid key range provided. Couldn't find range end ']'") - } rangeString := targetString[last+1 : rangeEnd] - keyRange, err := key.ParseShardingSpec(rangeString) - if err != nil { - return keyspace, tabletType, dest, err - } - if len(keyRange) != 1 { - return keyspace, tabletType, dest, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "single keyrange expected in %s", rangeString) + if strings.Contains(rangeString, "-") { + // Parse as range + keyRange, err := key.ParseShardingSpec(rangeString) + if err != nil { + return keyspace, tabletType, dest, err + } + if len(keyRange) != 1 { + return keyspace, tabletType, dest, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "single keyrange expected in %s", rangeString) + } + dest = key.DestinationExactKeyRange{KeyRange: keyRange[0]} + } else { + // Parse as keyspace id + destBytes, err := hex.DecodeString(rangeString) + if err != nil { + return keyspace, tabletType, dest, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "expected valid hex in keyspace id %s", rangeString) + } + dest = key.DestinationKeyspaceID(destBytes) } - dest = key.DestinationExactKeyRange{KeyRange: keyRange[0]} targetString = targetString[:last] } keyspace = targetString diff --git a/go/vt/topo/topoproto/destination_test.go b/go/vt/topo/topoproto/destination_test.go index 121277ea013..3fdddaf0a01 100644 --- a/go/vt/topo/topoproto/destination_test.go +++ b/go/vt/topo/topoproto/destination_test.go @@ -45,6 +45,11 @@ func TestParseDestination(t *testing.T) { keyspace: "ks", tabletType: topodatapb.TabletType_MASTER, dest: key.DestinationExactKeyRange{KeyRange: &topodatapb.KeyRange{}}, + }, { + targetString: "ks[deadbeef]@master", + keyspace: "ks", + tabletType: topodatapb.TabletType_MASTER, + dest: key.DestinationKeyspaceID([]byte("\xde\xad\xbe\xef")), }, { targetString: "ks[10-]@master", keyspace: "ks", @@ -109,4 +114,10 @@ func TestParseDestination(t *testing.T) { if err == nil || err.Error() != want { t.Errorf("executorExec error: %v, want %s", err, want) } + + _, _, _, err = ParseDestination("ks[qrnqorrs]@master", topodatapb.TabletType_MASTER) + want = "expected valid hex in keyspace id qrnqorrs" + if err == nil || err.Error() != want { + t.Errorf("executorExec error: %v, want %s", err, want) + } } diff --git a/go/vt/topo/zk2topo/zk_conn.go b/go/vt/topo/zk2topo/zk_conn.go index 661068c7854..c7539192500 100644 --- a/go/vt/topo/zk2topo/zk_conn.go +++ b/go/vt/topo/zk2topo/zk_conn.go @@ -31,7 +31,6 @@ import ( "github.com/samuel/go-zookeeper/zk" "golang.org/x/net/context" - "vitess.io/vitess/go/netutil" "vitess.io/vitess/go/sync2" "vitess.io/vitess/go/vt/log" ) @@ -307,11 +306,7 @@ func (c *ZkConn) handleSessionEvents(conn *zk.Conn, session <-chan zk.Event) { // dialZk dials the server, and waits until connection. func dialZk(ctx context.Context, addr string) (*zk.Conn, <-chan zk.Event, error) { - servers, err := resolveZkAddr(addr) - if err != nil { - return nil, nil, err - } - + servers := strings.Split(addr, ",") options := zk.WithDialer(net.DialTimeout) // If TLS is enabled use a TLS enabled dialer option if *certPath != "" && *keyPath != "" { @@ -376,26 +371,3 @@ func dialZk(ctx context.Context, addr string) (*zk.Conn, <-chan zk.Event, error) } } } - -// resolveZkAddr takes a comma-separated list of host:port addresses, -// and resolves the host to replace it with the IP address. -// If a resolution fails, the host is skipped. -// If no host can be resolved, an error is returned. -// This is different from the Zookeeper library, that insists on resolving -// *all* hosts successfully before it starts. -func resolveZkAddr(zkAddr string) ([]string, error) { - parts := strings.Split(zkAddr, ",") - resolved := make([]string, 0, len(parts)) - for _, part := range parts { - // The Zookeeper client cannot handle IPv6 addresses before version 3.4.x. - if r, err := netutil.ResolveIPv4Addrs(part); err != nil { - log.Warningf("cannot resolve %v, will not use it: %v", part, err) - } else { - resolved = append(resolved, r...) - } - } - if len(resolved) == 0 { - return nil, fmt.Errorf("no valid address found in %v", zkAddr) - } - return resolved, nil -} diff --git a/go/vt/topotools/rebuild_vschema.go b/go/vt/topotools/rebuild_vschema.go index e601b059789..1b356319361 100644 --- a/go/vt/topotools/rebuild_vschema.go +++ b/go/vt/topotools/rebuild_vschema.go @@ -67,7 +67,7 @@ func RebuildVSchema(ctx context.Context, log logutil.Logger, ts *topo.Server, ce mu.Lock() defer mu.Unlock() if err != nil { - log.Errorf("GetVSchema(%v) failed: %v", keyspace, err) + log.Errorf2(err, "GetVSchema(%v) failed", keyspace) finalErr = err return } @@ -85,7 +85,7 @@ func RebuildVSchema(ctx context.Context, log logutil.Logger, ts *topo.Server, ce go func(cell string) { defer wg.Done() if err := ts.UpdateSrvVSchema(ctx, cell, srvVSchema); err != nil { - log.Errorf("UpdateSrvVSchema(%v) failed: %v", cell, err) + log.Errorf2(err, "UpdateSrvVSchema(%v) failed", cell) mu.Lock() finalErr = err mu.Unlock() diff --git a/go/vt/vtctl/query.go b/go/vt/vtctl/query.go index 5be996302a7..00944a2e97e 100644 --- a/go/vt/vtctl/query.go +++ b/go/vt/vtctl/query.go @@ -617,7 +617,7 @@ func commandVtTabletStreamHealth(ctx context.Context, wr *wrangler.Wrangler, sub err = conn.StreamHealth(ctx, func(shr *querypb.StreamHealthResponse) error { data, err := json.Marshal(shr) if err != nil { - wr.Logger().Errorf("cannot json-marshal structure: %v", err) + wr.Logger().Errorf2(err, "cannot json-marshal structure") } else { wr.Logger().Printf("%v\n", string(data)) } @@ -672,7 +672,7 @@ func commandVtTabletUpdateStream(ctx context.Context, wr *wrangler.Wrangler, sub }, *position, int64(*timestamp), func(se *querypb.StreamEvent) error { data, err := json.Marshal(se) if err != nil { - wr.Logger().Errorf("cannot json-marshal structure: %v", err) + wr.Logger().Errorf2(err, "cannot json-marshal structure") } else { wr.Logger().Printf("%v\n", string(data)) } diff --git a/go/vt/vtctl/vtctl.go b/go/vt/vtctl/vtctl.go index d79902547ed..50b106ca977 100644 --- a/go/vt/vtctl/vtctl.go +++ b/go/vt/vtctl/vtctl.go @@ -2218,7 +2218,7 @@ func commandApplyVSchema(ctx context.Context, wr *wrangler.Wrangler, subFlags *f b, err := json2.MarshalIndentPB(vs, " ") if err != nil { - wr.Logger().Errorf("Failed to marshal VSchema for display: %v", err) + wr.Logger().Errorf2(err, "Failed to marshal VSchema for display") } else { wr.Logger().Printf("New VSchema object:\n%s\nIf this is not what you expected, check the input data (as JSON parsing will skip unexpected fields).\n", b) } diff --git a/go/vt/vterrors/LICENSE b/go/vt/vterrors/LICENSE new file mode 100644 index 00000000000..835ba3e755c --- /dev/null +++ b/go/vt/vterrors/LICENSE @@ -0,0 +1,23 @@ +Copyright (c) 2015, Dave Cheney +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/go/vt/vterrors/aggregate_test.go b/go/vt/vterrors/aggregate_test.go index 4847215a715..6f7841282ac 100644 --- a/go/vt/vterrors/aggregate_test.go +++ b/go/vt/vterrors/aggregate_test.go @@ -19,7 +19,6 @@ package vterrors import ( "errors" "fmt" - "reflect" "testing" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -96,7 +95,7 @@ func TestAggregateVtGateErrors(t *testing.T) { } for _, tc := range testcases { out := Aggregate(tc.input) - if !reflect.DeepEqual(out, tc.expected) { + if !Equals(out, tc.expected) { t.Errorf("AggregateVtGateErrors(%+v) = %+v \nwant: %+v", tc.input, out, tc.expected) } diff --git a/go/vt/vterrors/doc.go b/go/vt/vterrors/doc.go deleted file mode 100644 index a704cc6568b..00000000000 --- a/go/vt/vterrors/doc.go +++ /dev/null @@ -1,44 +0,0 @@ -/* -Copyright 2017 Google Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreedto in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package vterrors provides helpers for propagating internal errors -// through the Vitess system (including across RPC boundaries) in a -// structured way. -package vterrors - -/* - -Vitess uses canonical error codes for error reporting. This is based -on years of industry experience with error reporting. This idea is -that errors should be classified into a small set of errors (10 or so) -with very specific meaning. Each error has a code, and a message. When -errors are passed around (even through RPCs), the code is -propagated. To handle errors, only the code should be looked at (and -not string-matching on the error message). - -Vitess defines the error codes in /proto/vtrpc.proto. Along with an -RPCError message that can be used to transmit errors through RPCs, in -the message payloads. These codes match the names and numbers defined -by gRPC. - -Vitess also defines a standardized error implementation that allows -you to build an error with an associated canonical code. - -While sending an error through gRPC, these codes are transmitted -using gRPC's error propagation mechanism and decoded back to -the original code on the other end. - -*/ diff --git a/go/vt/vterrors/errors_test.go b/go/vt/vterrors/errors_test.go new file mode 100644 index 00000000000..138bc6ed387 --- /dev/null +++ b/go/vt/vterrors/errors_test.go @@ -0,0 +1,205 @@ +package vterrors + +import ( + "errors" + "fmt" + "io" + "reflect" + "strings" + "testing" + + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" +) + +func TestWrapNil(t *testing.T) { + got := Wrap(nil, "no error") + if got != nil { + t.Errorf("Wrap(nil, \"no error\"): got %#v, expected nil", got) + } +} + +func TestWrap(t *testing.T) { + tests := []struct { + err error + message string + wantMessage string + wantCode vtrpcpb.Code + }{ + {io.EOF, "read error", "read error: EOF", vtrpcpb.Code_UNKNOWN}, + {New(vtrpcpb.Code_ALREADY_EXISTS, "oops"), "client error", "client error: oops", vtrpcpb.Code_ALREADY_EXISTS}, + } + + for _, tt := range tests { + got := Wrap(tt.err, tt.message) + if got.Error() != tt.wantMessage { + t.Errorf("Wrap(%v, %q): got: [%v], want [%v]", tt.err, tt.message, got, tt.wantMessage) + } + if Code(got) != tt.wantCode { + t.Errorf("Wrap(%v, %v): got: [%v], want [%v]", tt.err, tt, Code(got), tt.wantCode) + } + } +} + +type nilError struct{} + +func (nilError) Error() string { return "nil error" } + +func TestRootCause(t *testing.T) { + x := New(vtrpcpb.Code_FAILED_PRECONDITION, "error") + tests := []struct { + err error + want error + }{{ + // nil error is nil + err: nil, + want: nil, + }, { + // explicit nil error is nil + err: (error)(nil), + want: nil, + }, { + // typed nil is nil + err: (*nilError)(nil), + want: (*nilError)(nil), + }, { + // uncaused error is unaffected + err: io.EOF, + want: io.EOF, + }, { + // caused error returns cause + err: Wrap(io.EOF, "ignored"), + want: io.EOF, + }, { + err: x, // return from errors.New + want: x, + }} + + for i, tt := range tests { + got := RootCause(tt.err) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("test %d: got %#v, want %#v", i+1, got, tt.want) + } + } +} + +func TestCause(t *testing.T) { + x := New(vtrpcpb.Code_FAILED_PRECONDITION, "error") + tests := []struct { + err error + want error + }{{ + // nil error is nil + err: nil, + want: nil, + }, { + // uncaused error is nil + err: io.EOF, + want: nil, + }, { + // caused error returns cause + err: Wrap(io.EOF, "ignored"), + want: io.EOF, + }, { + err: x, // return from errors.New + want: nil, + }} + + for i, tt := range tests { + got := Cause(tt.err) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("test %d: got %#v, want %#v", i+1, got, tt.want) + } + } +} + +func TestWrapfNil(t *testing.T) { + got := Wrapf(nil, "no error") + if got != nil { + t.Errorf("Wrapf(nil, \"no error\"): got %#v, expected nil", got) + } +} + +func TestWrapf(t *testing.T) { + tests := []struct { + err error + message string + want string + }{ + {io.EOF, "read error", "read error: EOF"}, + {Wrapf(io.EOF, "read error without format specifiers"), "client error", "client error: read error without format specifiers: EOF"}, + {Wrapf(io.EOF, "read error with %d format specifier", 1), "client error", "client error: read error with 1 format specifier: EOF"}, + } + + for _, tt := range tests { + got := Wrapf(tt.err, tt.message).Error() + if got != tt.want { + t.Errorf("Wrapf(%v, %q): got: %v, want %v", tt.err, tt.message, got, tt.want) + } + } +} + +func TestErrorf(t *testing.T) { + tests := []struct { + err error + want string + }{ + {Errorf(vtrpcpb.Code_DATA_LOSS, "read error without format specifiers"), "read error without format specifiers"}, + {Errorf(vtrpcpb.Code_DATA_LOSS, "read error with %d format specifier", 1), "read error with 1 format specifier"}, + } + + for _, tt := range tests { + got := tt.err.Error() + if got != tt.want { + t.Errorf("Errorf(%v): got: %q, want %q", tt.err, got, tt.want) + } + } +} + +func innerMost() error { + return Wrap(io.ErrNoProgress, "oh noes") +} + +func middle() error { + return innerMost() +} + +func outer() error { + return middle() +} + +func TestStackFormat(t *testing.T) { + err := outer() + got := fmt.Sprintf("%+v", err) + + assertStringContains(t, got, "innerMost") + assertStringContains(t, got, "middle") + assertStringContains(t, got, "outer") +} + +func assertStringContains(t *testing.T, s, substring string) { + if !strings.Contains(s, substring) { + t.Errorf("string did not contain `%v`: \n %v", substring, s) + } +} + +// errors.New, etc values are not expected to be compared by value +// but the change in errors#27 made them incomparable. Assert that +// various kinds of errors have a functional equality operator, even +// if the result of that equality is always false. +func TestErrorEquality(t *testing.T) { + vals := []error{ + nil, + io.EOF, + errors.New("EOF"), + New(vtrpcpb.Code_ALREADY_EXISTS, "EOF"), + Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "EOF"), + Wrap(io.EOF, "EOF"), + Wrapf(io.EOF, "EOF%d", 2), + } + + for i := range vals { + for j := range vals { + _ = vals[i] == vals[j] // mustn't panic + } + } +} diff --git a/go/vt/vterrors/proto3_test.go b/go/vt/vterrors/proto3_test.go index 34d89649456..c2b8395969e 100644 --- a/go/vt/vterrors/proto3_test.go +++ b/go/vt/vterrors/proto3_test.go @@ -17,7 +17,6 @@ limitations under the License. package vterrors import ( - "reflect" "testing" "github.com/golang/protobuf/proto" @@ -54,8 +53,8 @@ func TestFromVtRPCError(t *testing.T) { }} for _, tcase := range testcases { got := FromVTRPC(tcase.in) - if !reflect.DeepEqual(got, tcase.want) { - t.Errorf("FromVtRPCError(%v): %v, want %v", tcase.in, got, tcase.want) + if !Equals(got, tcase.want) { + t.Errorf("FromVtRPCError(%v): [%v], want [%v]", tcase.in, got, tcase.want) } } } diff --git a/go/vt/vterrors/stack.go b/go/vt/vterrors/stack.go new file mode 100644 index 00000000000..2ba717ad3f0 --- /dev/null +++ b/go/vt/vterrors/stack.go @@ -0,0 +1,191 @@ +package vterrors + +/* This file is copied from https://github.com/pkg/errors/blob/v0.8.0/stack.go */ + +import ( + "fmt" + "io" + "path" + "runtime" + "strings" +) + +// Frame represents a program counter inside a stack frame. +type Frame uintptr + +// pc returns the program counter for this frame; +// multiple frames may have the same PC value. +func (f Frame) pc() uintptr { return uintptr(f) - 1 } + +// file returns the full path to the file that contains the +// function for this Frame's pc. +func (f Frame) file() string { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return "unknown" + } + file, _ := fn.FileLine(f.pc()) + return file +} + +// line returns the line number of source code of the +// function for this Frame's pc. +func (f Frame) line() int { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return 0 + } + _, line := fn.FileLine(f.pc()) + return line +} + +// Format formats the frame according to the fmt.Formatter interface. +// +// %s source file +// %d source line +// %n function name +// %v equivalent to %s:%d +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+s path of source file relative to the compile time GOPATH +// %+v equivalent to %+s:%d +func (f Frame) Format(s fmt.State, verb rune) { + switch verb { + case 's': + switch { + case s.Flag('+'): + pc := f.pc() + fn := runtime.FuncForPC(pc) + if fn == nil { + io.WriteString(s, "unknown") + } else { + file, _ := fn.FileLine(pc) + fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) + } + default: + io.WriteString(s, path.Base(f.file())) + } + case 'd': + fmt.Fprintf(s, "%d", f.line()) + case 'n': + name := runtime.FuncForPC(f.pc()).Name() + io.WriteString(s, funcname(name)) + case 'v': + f.Format(s, 's') + io.WriteString(s, ":") + f.Format(s, 'd') + } +} + +// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). +type StackTrace []Frame + +// Format format the stacktrace according to the fmt.Formatter interface. +// +// %s source file +// %d source line +// %n function name +// %v equivalent to %s:%d +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+s path of source file relative to the compile time GOPATH +// %+v equivalent to %+s:%d +func (st StackTrace) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case s.Flag('+'): + for _, f := range st { + fmt.Fprintf(s, "\n%+v", f) + } + case s.Flag('#'): + fmt.Fprintf(s, "%#v", []Frame(st)) + default: + fmt.Fprintf(s, "%v", []Frame(st)) + } + case 's': + fmt.Fprintf(s, "%s", []Frame(st)) + } +} + +// stack represents a stack of program counters. +type stack []uintptr + +func (s *stack) Format(st fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case st.Flag('+'): + for _, pc := range *s { + f := Frame(pc) + fmt.Fprintf(st, "\n%+v", f) + } + } + } +} + +func (s *stack) StackTrace() StackTrace { + f := make([]Frame, len(*s)) + for i := 0; i < len(f); i++ { + f[i] = Frame((*s)[i]) + } + return f +} + +func callers() *stack { + const depth = 32 + var pcs [depth]uintptr + n := runtime.Callers(3, pcs[:]) + var st stack = pcs[0:n] + return &st +} + +// funcname removes the path prefix component of a function's name reported by func.Name(). +func funcname(name string) string { + i := strings.LastIndex(name, "/") + name = name[i+1:] + i = strings.Index(name, ".") + return name[i+1:] +} + +func trimGOPATH(name, file string) string { + // Here we want to get the source file path relative to the compile time + // GOPATH. As of Go 1.6.x there is no direct way to know the compiled + // GOPATH at runtime, but we can infer the number of path segments in the + // GOPATH. We note that fn.Name() returns the function name qualified by + // the import path, which does not include the GOPATH. Thus we can trim + // segments from the beginning of the file path until the number of path + // separators remaining is one more than the number of path separators in + // the function name. For example, given: + // + // GOPATH /home/user + // file /home/user/src/pkg/sub/file.go + // fn.Name() pkg/sub.Type.Method + // + // We want to produce: + // + // pkg/sub/file.go + // + // From this we can easily see that fn.Name() has one less path separator + // than our desired output. We count separators from the end of the file + // path until it finds two more than in the function name and then move + // one character forward to preserve the initial path segment without a + // leading separator. + const sep = "/" + goal := strings.Count(name, sep) + 2 + i := len(file) + for n := 0; n < goal; n++ { + i = strings.LastIndex(file[:i], sep) + if i == -1 { + // not enough separators found, set i so that the slice expression + // below leaves file unmodified + i = -len(sep) + break + } + } + // get back to 0 or trim the leading separator + file = file[i+len(sep):] + return file +} diff --git a/go/vt/vterrors/vterrors.go b/go/vt/vterrors/vterrors.go index 2ce24160410..6e85542cb10 100644 --- a/go/vt/vterrors/vterrors.go +++ b/go/vt/vterrors/vterrors.go @@ -1,73 +1,153 @@ -/* -Copyright 2017 Google Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreedto in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - +// Package vterrors provides simple error handling primitives for Vitess +// +// In all Vitess code, errors should be propagated using vterrors.Wrapf() +// and not fmt.Errorf(). +// +// New errors should be created using vterrors.New +// +// Vitess uses canonical error codes for error reporting. This is based +// on years of industry experience with error reporting. This idea is +// that errors should be classified into a small set of errors (10 or so) +// with very specific meaning. Each error has a code, and a message. When +// errors are passed around (even through RPCs), the code is +// propagated. To handle errors, only the code should be looked at (and +// not string-matching on the error message). +// +// Error codes are defined in /proto/vtrpc.proto. Along with an +// RPCError message that can be used to transmit errors through RPCs, in +// the message payloads. These codes match the names and numbers defined +// by gRPC. +// +// A standardized error implementation that allows you to build an error +// with an associated canonical code is also defined. +// While sending an error through gRPC, these codes are transmitted +// using gRPC's error propagation mechanism and decoded back to +// the original code on the other end. +// +// Retrieving the cause of an error +// +// Using vterrors.Wrap constructs a stack of errors, adding context to the +// preceding error. Depending on the nature of the error it may be necessary +// to reverse the operation of errors.Wrap to retrieve the original error +// for inspection. Any error value which implements this interface +// +// type causer interface { +// Cause() error +// } +// +// can be inspected by vterrors.Cause and vterrors.RootCause. +// +// * vterrors.Cause will find the immediate cause if one is available, or nil +// if the error is not a `causer` or if no cause is available. +// * vterrors.RootCause will recursively retrieve +// the topmost error which does not implement causer, which is assumed to be +// the original cause. For example: +// +// switch err := errors.RootCause(err).(type) { +// case *MyError: +// // handle specifically +// default: +// // unknown error +// } +// +// causer interface is not exported by this package, but is considered a part +// of stable public API. +// +// Formatted printing of errors +// +// All error values returned from this package implement fmt.Formatter and can +// be formatted by the fmt package. The following verbs are supported +// +// %s print the error. If the error has a Cause it will be +// printed recursively +// %v see %s +// %+v extended format. Each Frame of the error's StackTrace will +// be printed in detail. +// +// Most but not all of the code in this file was originally copied from +// https://github.com/pkg/errors/blob/v0.8.0/errors.go package vterrors import ( "fmt" - "golang.org/x/net/context" - + "io" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) -type vtError struct { - code vtrpcpb.Code - err string +// New returns an error with the supplied message. +// New also records the stack trace at the point it was called. +func New(code vtrpcpb.Code, message string) error { + return &fundamental{ + msg: message, + code: code, + stack: callers(), + } } -// New creates a new error using the code and input string. -func New(code vtrpcpb.Code, in string) error { - if code == vtrpcpb.Code_OK { - panic("OK is an invalid error code; use INTERNAL instead") - } - return &vtError{ - code: code, - err: in, +// NewWithoutCode returns an error when no applicable error code is available +// It will record the stack trace when creating the error +func NewWithoutCode(message string) error { + return &fundamental{ + msg: message, + code: vtrpcpb.Code_UNKNOWN, + stack: callers(), } } -// Wrap wraps the given error, returning a new error with the given message as a prefix but with the same error code (if err was a vterror) and message of the passed error. -func Wrap(err error, message string) error { - return New(Code(err), fmt.Sprintf("%v: %v", message, err.Error())) +// Errorf formats according to a format specifier and returns the string +// as a value that satisfies error. +// Errorf also records the stack trace at the point it was called. +func Errorf(code vtrpcpb.Code, format string, args ...interface{}) error { + return &fundamental{ + msg: fmt.Sprintf(format, args...), + code: code, + stack: callers(), + } } -// Wrapf wraps the given error, returning a new error with the given format string as a prefix but with the same error code (if err was a vterror) and message of the passed error. -func Wrapf(err error, format string, args ...interface{}) error { - return Wrap(err, fmt.Sprintf(format, args...)) +// fundamental is an error that has a message and a stack, but no caller. +type fundamental struct { + msg string + code vtrpcpb.Code + *stack } -// Errorf returns a new error built using Printf style arguments. -func Errorf(code vtrpcpb.Code, format string, args ...interface{}) error { - return New(code, fmt.Sprintf(format, args...)) -} +func (f *fundamental) Error() string { return f.msg } -func (e *vtError) Error() string { - return e.err +func (f *fundamental) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + io.WriteString(s, "Code: "+f.code.String()+"\n") + io.WriteString(s, f.msg+"\n") + f.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, f.msg) + case 'q': + fmt.Fprintf(s, "%q", f.msg) + } } // Code returns the error code if it's a vtError. -// If err is nil, it returns ok. Otherwise, it returns unknown. +// If err is nil, it returns ok. func Code(err error) vtrpcpb.Code { if err == nil { return vtrpcpb.Code_OK } - if err, ok := err.(*vtError); ok { + if err, ok := err.(*fundamental); ok { return err.code } + + cause := Cause(err) + if cause != err && cause != nil { + // If we did not find an error code at the outer level, let's find the cause and check it's code + return Code(cause) + } + // Handle some special cases. switch err { case context.Canceled: @@ -78,16 +158,111 @@ func Code(err error) vtrpcpb.Code { return vtrpcpb.Code_UNKNOWN } +// Wrap returns an error annotating err with a stack trace +// at the point Wrap is called, and the supplied message. +// If err is nil, Wrap returns nil. +func Wrap(err error, message string) error { + if err == nil { + return nil + } + return &wrapping{ + cause: err, + msg: message, + stack: callers(), + } +} + +// Wrapf returns an error annotating err with a stack trace +// at the point Wrapf is call, and the format specifier. +// If err is nil, Wrapf returns nil. +func Wrapf(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + return &wrapping{ + cause: err, + msg: fmt.Sprintf(format, args...), + stack: callers(), + } +} + +type wrapping struct { + cause error + msg string + stack *stack +} + +func (w *wrapping) Error() string { return w.msg + ": " + w.cause.Error() } +func (w *wrapping) Cause() error { return w.cause } + +func (w *wrapping) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v\n", w.Cause()) + io.WriteString(s, w.msg) + w.stack.Format(s, verb) + return + } + fallthrough + case 's', 'q': + io.WriteString(s, w.Error()) + } +} + +// RootCause returns the underlying cause of the error, if possible. +// An error value has a cause if it implements the following +// interface: +// +// type causer interface { +// Cause() error +// } +// +// If the error does not implement Cause, the original error will +// be returned. If the error is nil, nil will be returned without further +// investigation. +func RootCause(err error) error { + for { + cause := Cause(err) + if cause == nil { + return err + } + err = cause + } +} + +// +// Cause will return the immediate cause, if possible. +// An error value has a cause if it implements the following +// interface: +// +// type causer interface { +// Cause() error +// } +// If the error does not implement Cause, nil will be returned +func Cause(err error) error { + type causer interface { + Cause() error + } + + causerObj, ok := err.(causer) + if !ok { + return nil + } + + return causerObj.Cause() +} + // Equals returns true iff the error message and the code returned by Code() -// is equal. +// are equal. func Equals(a, b error) bool { if a == nil && b == nil { // Both are nil. return true } - if a == nil && b != nil || a != nil && b == nil { - // One of the two is nil. + if a == nil || b == nil { + // One of the two is nil, since we know both are not nil. return false } @@ -97,5 +272,5 @@ func Equals(a, b error) bool { // Print is meant to print the vtError object in test failures. // For comparing two vterrors, use Equals() instead. func Print(err error) string { - return fmt.Sprintf("%v: %v", Code(err), err.Error()) + return fmt.Sprintf("%v: %v\n", Code(err), err.Error()) } diff --git a/go/vt/vtexplain/vtexplain_topo.go b/go/vt/vtexplain/vtexplain_topo.go index f168636eb80..e4e0f374969 100644 --- a/go/vt/vtexplain/vtexplain_topo.go +++ b/go/vt/vtexplain/vtexplain_topo.go @@ -55,8 +55,8 @@ func (et *ExplainTopo) getSrvVSchema() *vschemapb.SrvVSchema { } // GetTopoServer is part of the srvtopo.Server interface -func (et *ExplainTopo) GetTopoServer() *topo.Server { - return nil +func (et *ExplainTopo) GetTopoServer() (*topo.Server, error) { + return nil, nil } // GetSrvKeyspaceNames is part of the srvtopo.Server interface. diff --git a/go/vt/vtgate/buffer/buffer_test.go b/go/vt/vtgate/buffer/buffer_test.go index b42ed617669..1edd007f7d1 100644 --- a/go/vt/vtgate/buffer/buffer_test.go +++ b/go/vt/vtgate/buffer/buffer_test.go @@ -94,8 +94,8 @@ func TestBuffer(t *testing.T) { // an external failover tool, the timestamp will be increased (even though // the master did not change.) b.StatsUpdate(&discovery.TabletStats{ - Tablet: oldMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: oldMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: now.Unix(), }) @@ -124,8 +124,8 @@ func TestBuffer(t *testing.T) { // Mimic the failover end. now = now.Add(1 * time.Second) b.StatsUpdate(&discovery.TabletStats{ - Tablet: newMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: newMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: now.Unix(), }) @@ -185,8 +185,8 @@ func TestBuffer(t *testing.T) { } // Stop buffering. b.StatsUpdate(&discovery.TabletStats{ - Tablet: oldMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: oldMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: now.Unix(), }) if err := <-stopped4; err != nil { @@ -322,8 +322,8 @@ func TestDryRun(t *testing.T) { // End of failover is tracked as well. b.StatsUpdate(&discovery.TabletStats{ - Tablet: newMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: newMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: 1, // Use any value > 0. }) if err := waitForState(b, stateIdle); err != nil { @@ -374,8 +374,8 @@ func TestLastReparentTooRecent_BufferingSkipped(t *testing.T) { // very recently (time.Now()). // vtgate should see this immediately after the start. b.StatsUpdate(&discovery.TabletStats{ - Tablet: oldMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: oldMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: now.Unix(), }) @@ -383,8 +383,8 @@ func TestLastReparentTooRecent_BufferingSkipped(t *testing.T) { // Do not start buffering. now = now.Add(1 * time.Second) b.StatsUpdate(&discovery.TabletStats{ - Tablet: newMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: newMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: now.Unix(), }) @@ -418,8 +418,8 @@ func TestLastReparentTooRecent_Buffering(t *testing.T) { // very recently (time.Now()). // vtgate should see this immediately after the start. b.StatsUpdate(&discovery.TabletStats{ - Tablet: oldMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: oldMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: now.Unix(), }) @@ -427,8 +427,8 @@ func TestLastReparentTooRecent_Buffering(t *testing.T) { // there was 0 QPS traffic and no buffering was started. now = now.Add(1 * time.Second) b.StatsUpdate(&discovery.TabletStats{ - Tablet: newMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: newMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: now.Unix(), }) @@ -442,8 +442,8 @@ func TestLastReparentTooRecent_Buffering(t *testing.T) { } // And then the failover end. b.StatsUpdate(&discovery.TabletStats{ - Tablet: newMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: newMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: now.Unix(), }) @@ -481,8 +481,8 @@ func TestPassthroughDuringDrain(t *testing.T) { // Stop buffering and trigger drain. b.StatsUpdate(&discovery.TabletStats{ - Tablet: newMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: newMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: 1, // Use any value > 0. }) if got, want := b.getOrCreateBuffer(keyspace, shard).state, stateDraining; got != want { @@ -596,8 +596,8 @@ func testRequestCanceled(t *testing.T, explicitEnd bool) { if explicitEnd { b.StatsUpdate(&discovery.TabletStats{ - Tablet: newMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: newMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: 1, // Use any value > 0. }) } @@ -615,8 +615,8 @@ func testRequestCanceled(t *testing.T, explicitEnd bool) { // shortly after. In that case, the buffer should ignore it. if !explicitEnd { b.StatsUpdate(&discovery.TabletStats{ - Tablet: newMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: newMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: 1, // Use any value > 0. }) } @@ -661,8 +661,8 @@ func TestEviction(t *testing.T) { // End of failover. Stop buffering. b.StatsUpdate(&discovery.TabletStats{ - Tablet: newMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: newMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: 1, // Use any value > 0. }) @@ -744,8 +744,8 @@ func TestEvictionNotPossible(t *testing.T) { // End of failover. Stop buffering. b.StatsUpdate(&discovery.TabletStats{ - Tablet: newMaster, - Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, + Tablet: newMaster, + Target: &querypb.Target{Keyspace: keyspace, Shard: shard, TabletType: topodatapb.TabletType_MASTER}, TabletExternallyReparentedTimestamp: 1, // Use any value > 0. }) if err := <-stoppedFirstFailover; err != nil { diff --git a/go/vt/vtgate/gateway/discoverygateway.go b/go/vt/vtgate/gateway/discoverygateway.go index 6d5d4ed7522..26731f72e69 100644 --- a/go/vt/vtgate/gateway/discoverygateway.go +++ b/go/vt/vtgate/gateway/discoverygateway.go @@ -87,8 +87,13 @@ type discoveryGateway struct { func createDiscoveryGateway(hc discovery.HealthCheck, serv srvtopo.Server, cell string, retryCount int) Gateway { var topoServer *topo.Server if serv != nil { - topoServer = serv.GetTopoServer() + var err error + topoServer, err = serv.GetTopoServer() + if err != nil { + log.Exitf("Unable to create new discoverygateway: %v", err) + } } + dg := &discoveryGateway{ hc: hc, tsc: discovery.NewTabletStatsCacheDoNotSetListener(topoServer, cell), @@ -111,6 +116,10 @@ func createDiscoveryGateway(hc discovery.HealthCheck, serv srvtopo.Server, cell } var tr discovery.TabletRecorder = dg.hc if len(tabletFilters) > 0 { + if len(KeyspacesToWatch) > 0 { + log.Exitf("Only one of -keyspaces_to_watch and -tablet_filters may be specified at a time") + } + fbs, err := discovery.NewFilterByShard(dg.hc, tabletFilters) if err != nil { log.Exitf("Cannot parse tablet_filters parameter: %v", err) @@ -201,12 +210,6 @@ func (dg *discoveryGateway) GetMasterCell(keyspace, shard string) (string, query return cell, dg, err } -// StreamHealth is not forwarded to any other tablet, -// but we handle it directly here. -func (dg *discoveryGateway) StreamHealth(ctx context.Context, callback func(*querypb.StreamHealthResponse) error) error { - return StreamHealthFromTargetStatsListener(ctx, dg.tsc, callback) -} - // Close shuts down underlying connections. // This function hides the inner implementation. func (dg *discoveryGateway) Close(ctx context.Context) error { diff --git a/go/vt/vtgate/gateway/gateway.go b/go/vt/vtgate/gateway/gateway.go index 35893063b4a..d9f5166ea4f 100644 --- a/go/vt/vtgate/gateway/gateway.go +++ b/go/vt/vtgate/gateway/gateway.go @@ -20,17 +20,16 @@ package gateway import ( "flag" - "fmt" "time" "golang.org/x/net/context" + "vitess.io/vitess/go/flagutil" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/discovery" "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vttablet/queryservice" - querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" ) @@ -40,8 +39,17 @@ import ( var ( implementation = flag.String("gateway_implementation", "discoverygateway", "The implementation of gateway") initialTabletTimeout = flag.Duration("gateway_initial_tablet_timeout", 30*time.Second, "At startup, the gateway will wait up to that duration to get one tablet per keyspace/shard/tablettype") + + // KeyspacesToWatch - if provided this specifies which keyspaces should be + // visible to a vtgate. By default the vtgate will allow access to any + // keyspace. + KeyspacesToWatch flagutil.StringListValue ) +func init() { + flag.Var(&KeyspacesToWatch, "keyspaces_to_watch", "Specifies which keyspaces this vtgate should have access to while routing queries or accessing the vschema") +} + // A Gateway is the query processing module for each shard, // which is used by ScatterConn. type Gateway interface { @@ -118,53 +126,3 @@ func WaitForTablets(gw Gateway, tabletTypesToWait []topodatapb.TabletType) error } return err } - -// StreamHealthFromTargetStatsListener responds to a StreamHealth -// streaming RPC using a srvtopo.TargetStatsListener implementation. -func StreamHealthFromTargetStatsListener(ctx context.Context, l srvtopo.TargetStatsListener, callback func(*querypb.StreamHealthResponse) error) error { - // Subscribe to the TargetStatsListener aggregate stats. - id, entries, c, err := l.Subscribe() - if err != nil { - return err - } - defer func() { - // Unsubscribe so we don't receive more updates, and - // drain the channel. - l.Unsubscribe(id) - for range c { - } - }() - - // Send all current entries. - for _, e := range entries { - shr := &querypb.StreamHealthResponse{ - Target: e.Target, - TabletExternallyReparentedTimestamp: e.TabletExternallyReparentedTimestamp, - AggregateStats: e.Stats, - } - if err := callback(shr); err != nil { - return err - } - } - - // Now listen for updates, or the end of the connection. - for { - select { - case <-ctx.Done(): - return ctx.Err() - case e, ok := <-c: - if !ok { - // Channel is closed, should never happen. - return fmt.Errorf("channel closed") - } - shr := &querypb.StreamHealthResponse{ - Target: e.Target, - TabletExternallyReparentedTimestamp: e.TabletExternallyReparentedTimestamp, - AggregateStats: e.Stats, - } - if err := callback(shr); err != nil { - return err - } - } - } -} diff --git a/go/vt/vtgate/gateway/hybridgateway.go b/go/vt/vtgate/gateway/hybridgateway.go deleted file mode 100644 index a2a1b0f3f7c..00000000000 --- a/go/vt/vtgate/gateway/hybridgateway.go +++ /dev/null @@ -1,202 +0,0 @@ -/* -Copyright 2017 Google Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package gateway - -import ( - "fmt" - - "golang.org/x/net/context" - - "vitess.io/vitess/go/stats" - querypb "vitess.io/vitess/go/vt/proto/query" - topodatapb "vitess.io/vitess/go/vt/proto/topodata" - "vitess.io/vitess/go/vt/srvtopo" - "vitess.io/vitess/go/vt/topo" - "vitess.io/vitess/go/vt/vttablet/queryservice" -) - -// HybridGateway implements the gateway.Gateway interface by forwarding -// the queries to the right underlying implementation: -// - it has one gateway that watches for tablets. Usually a DiscoveryGateway. -// Useful for local tablets, or remote tablets that can be accessed. -// - it has a list of remote vtgate connections to talk to l2 vtgate processes. -// Useful for remote tablets that are far away, or if the number of local -// tablets grows too big. -// -// Note the WaitForTablets method for now only waits on the local gateway. -type HybridGateway struct { - queryservice.QueryService - - // gw is the local gateway that has the local connections. - gw Gateway - - // l2vtgates is the list of remote connections to other vtgate pools. - l2vtgates []*L2VTGateConn -} - -// NewHybridGateway returns a new HybridGateway based on the provided -// parameters. gw can be nil, in which case it is assumed there is no -// local tablets. -func NewHybridGateway(gw Gateway, addrs []string, retryCount int) (*HybridGateway, error) { - h := &HybridGateway{ - gw: gw, - } - - for i, addr := range addrs { - conn, err := NewL2VTGateConn(fmt.Sprintf("%v", i), addr, retryCount) - if err != nil { - h.Close(context.Background()) - return nil, fmt.Errorf("dialing %v failed: %v", addr, err) - } - h.l2vtgates = append(h.l2vtgates, conn) - } - - h.QueryService = queryservice.Wrap(nil, h.route) - return h, nil -} - -// Close is part of the queryservice.QueryService interface. -func (h *HybridGateway) Close(ctx context.Context) error { - for _, l := range h.l2vtgates { - l.Close(ctx) - } - return nil -} - -// WaitForTablets is part of the Gateway interface. -// We just forward to the local Gateway, if any. -func (h *HybridGateway) WaitForTablets(ctx context.Context, tabletTypesToWait []topodatapb.TabletType) error { - if h.gw != nil { - return h.gw.WaitForTablets(ctx, tabletTypesToWait) - } - - // No local tablets, we don't wait for anything here. - return nil -} - -// RegisterStats registers the l2vtgate connection counts stats. -func (h *HybridGateway) RegisterStats() { - stats.NewCountersFuncWithMultiLabels( - "L2VtgateConnections", - "number of l2vtgate connection", - []string{"Keyspace", "ShardName", "TabletType"}, - h.servingConnStats) -} - -func (h *HybridGateway) servingConnStats() map[string]int64 { - res := make(map[string]int64) - for _, l := range h.l2vtgates { - l.servingConnStats(res) - } - return res -} - -// CacheStatus is part of the Gateway interface. It just concatenates -// all statuses from all underlying parts. -func (h *HybridGateway) CacheStatus() TabletCacheStatusList { - var result TabletCacheStatusList - - // Start with the local Gateway part. - if h.gw != nil { - result = h.gw.CacheStatus() - } - - // Then add each gateway one at a time. - for _, l := range h.l2vtgates { - partial := l.CacheStatus() - result = append(result, partial...) - } - - return result -} - -// route sends the action to the right underlying implementation. -// This doesn't retry, and doesn't collect stats, as these two are -// done by the underlying gw or l2VTGateConn. -// -// FIXME(alainjobart) now we only use gw, or the one l2vtgates we have. -// Need to deprecate this code in favor of using GetAggregateStats. -func (h *HybridGateway) route(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, inTransaction bool, inner func(context.Context, *querypb.Target, queryservice.QueryService) (error, bool)) error { - if h.gw != nil { - err, _ := inner(ctx, target, h.gw) - return NewShardError(err, target, nil, inTransaction) - } - if len(h.l2vtgates) == 1 { - err, _ := inner(ctx, target, h.l2vtgates[0]) - return NewShardError(err, target, nil, inTransaction) - } - return NewShardError(topo.NewError(topo.NoNode, ""), target, nil, inTransaction) -} - -// GetAggregateStats is part of the srvtopo.TargetStats interface, included -// in the gateway.Gateway interface. -func (h *HybridGateway) GetAggregateStats(target *querypb.Target) (*querypb.AggregateStats, queryservice.QueryService, error) { - // Start with the local Gateway part. - if h.gw != nil { - stats, qs, err := h.gw.GetAggregateStats(target) - if !topo.IsErrType(err, topo.NoNode) { - // The local gateway either worked, or returned an - // error. But it knows about this target. - return stats, qs, err - } - } - - // The local gateway doesn't know about this target, - // try the remote ones. - for _, l := range h.l2vtgates { - stats, err := l.GetAggregateStats(target) - if !topo.IsErrType(err, topo.NoNode) { - // This remote gateway either worked, or returned an - // error. But it knows about this target. - return stats, l, err - } - } - - // We couldn't find a way to resolve this. - return nil, nil, topo.NewError(topo.NoNode, target.String()) -} - -// GetMasterCell is part of the srvtopo.TargetStats interface, included -// in the gateway.Gateway interface. -func (h *HybridGateway) GetMasterCell(keyspace, shard string) (cell string, qs queryservice.QueryService, err error) { - // Start with the local Gateway part. - if h.gw != nil { - cell, qs, err := h.gw.GetMasterCell(keyspace, shard) - if !topo.IsErrType(err, topo.NoNode) { - // The local gateway either worked, or returned an - // error. But it knows about this target. - return cell, qs, err - } - // The local gateway doesn't know about this target, - // try the remote ones. - } - - for _, l := range h.l2vtgates { - cell, err := l.GetMasterCell(keyspace, shard) - if !topo.IsErrType(err, topo.NoNode) { - // This remote gateway either worked, or returned an - // error. But it knows about this target. - return cell, l, err - } - } - - // We couldn't find a way to resolve this. - return "", nil, topo.NewError(topo.NoNode, keyspace+"/"+shard) -} - -var _ Gateway = (*HybridGateway)(nil) -var _ srvtopo.TargetStats = (*HybridGateway)(nil) diff --git a/go/vt/vtgate/gateway/l2vtgateconn.go b/go/vt/vtgate/gateway/l2vtgateconn.go deleted file mode 100644 index d4264eea828..00000000000 --- a/go/vt/vtgate/gateway/l2vtgateconn.go +++ /dev/null @@ -1,271 +0,0 @@ -/* -Copyright 2017 Google Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package gateway - -import ( - "fmt" - "sort" - "sync" - "time" - - "golang.org/x/net/context" - "vitess.io/vitess/go/vt/grpcclient" - "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/vt/topo" - "vitess.io/vitess/go/vt/topo/topoproto" - "vitess.io/vitess/go/vt/vttablet/queryservice" - "vitess.io/vitess/go/vt/vttablet/tabletconn" - - querypb "vitess.io/vitess/go/vt/proto/query" - topodatapb "vitess.io/vitess/go/vt/proto/topodata" -) - -// L2VTGateConn keeps a single connection to a vtgate backend. The -// underlying vtgate backend must have been started with the -// '-enable_forwarding' flag. -// -// It will keep a healthcheck connection going to the target, to get -// the list of available Targets. It remembers them, and exposes a -// srvtopo.TargetStats interface to query them. -type L2VTGateConn struct { - queryservice.QueryService - - // addr is the destination address. Immutable. - addr string - - // name is the name to display for stats. Immutable. - name string - - // retryCount is the number of times to retry an action. Immutable. - retryCount int - - // cancel is associated with the life cycle of this L2VTGateConn. - // It is called when Close is called. - cancel context.CancelFunc - - // mu protects the following fields. - mu sync.RWMutex - // stats has all the stats we received from the other side. - stats map[l2VTGateConnKey]*l2VTGateConnValue - // statusAggregators is a map indexed by the key - // name:keyspace/shard/tablet type - statusAggregators map[string]*TabletStatusAggregator -} - -type l2VTGateConnKey struct { - keyspace string - shard string - tabletType topodatapb.TabletType -} - -type l2VTGateConnValue struct { - tabletExternallyReparentedTimestamp int64 - - // aggregates has the per-cell aggregates. - aggregates map[string]*querypb.AggregateStats -} - -// NewL2VTGateConn creates a new L2VTGateConn object. It also starts -// the background go routine to monitor its health. -func NewL2VTGateConn(name, addr string, retryCount int) (*L2VTGateConn, error) { - conn, err := tabletconn.GetDialer()(&topodatapb.Tablet{ - Hostname: addr, - }, grpcclient.FailFast(true)) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithCancel(context.Background()) - c := &L2VTGateConn{ - addr: addr, - name: name, - cancel: cancel, - stats: make(map[l2VTGateConnKey]*l2VTGateConnValue), - statusAggregators: make(map[string]*TabletStatusAggregator), - } - c.QueryService = queryservice.Wrap(conn, c.withRetry) - go c.checkConn(ctx) - return c, nil -} - -// Close is part of the queryservice.QueryService interface. -func (c *L2VTGateConn) Close(ctx context.Context) error { - c.cancel() - return nil -} - -func (c *L2VTGateConn) servingConnStats(res map[string]int64) { - c.mu.Lock() - defer c.mu.Unlock() - for k, s := range c.stats { - key := fmt.Sprintf("%s.%s.%s", k.keyspace, k.shard, topoproto.TabletTypeLString(k.tabletType)) - var htc int32 - for _, stats := range s.aggregates { - htc += stats.HealthyTabletCount - } - res[key] += int64(htc) - } -} - -func (c *L2VTGateConn) checkConn(ctx context.Context) { - for { - err := c.StreamHealth(ctx, c.streamHealthCallback) - log.Warningf("StreamHealth to %v failed, will retry after 30s: %v", c.addr, err) - time.Sleep(30 * time.Second) - } -} - -func (c *L2VTGateConn) streamHealthCallback(shr *querypb.StreamHealthResponse) error { - key := l2VTGateConnKey{ - keyspace: shr.Target.Keyspace, - shard: shr.Target.Shard, - tabletType: shr.Target.TabletType, - } - c.mu.Lock() - defer c.mu.Unlock() - e, ok := c.stats[key] - if !ok { - // No current value for this keyspace/shard/tablet type. - // Check if we received a delete, drop it. - if shr.AggregateStats == nil || (shr.AggregateStats.HealthyTabletCount == 0 && shr.AggregateStats.UnhealthyTabletCount == 0) { - return nil - } - - // It's a record for a keyspace/shard/tablet type we - // don't know yet, just create our new record with one - // entry in the map for the cell. - c.stats[key] = &l2VTGateConnValue{ - tabletExternallyReparentedTimestamp: shr.TabletExternallyReparentedTimestamp, - aggregates: map[string]*querypb.AggregateStats{ - shr.Target.Cell: shr.AggregateStats, - }, - } - return nil - } - - // Save our new value. - e.tabletExternallyReparentedTimestamp = shr.TabletExternallyReparentedTimestamp - e.aggregates[shr.Target.Cell] = shr.AggregateStats - return nil -} - -// GetAggregateStats is the discovery part of srvtopo.TargetStats interface. -func (c *L2VTGateConn) GetAggregateStats(target *querypb.Target) (*querypb.AggregateStats, error) { - key := l2VTGateConnKey{ - keyspace: target.Keyspace, - shard: target.Shard, - tabletType: target.TabletType, - } - c.mu.RLock() - defer c.mu.RUnlock() - e, ok := c.stats[key] - if !ok { - return nil, topo.NewError(topo.NoNode, target.String()) - } - - a, ok := e.aggregates[target.Cell] - if !ok { - return nil, topo.NewError(topo.NoNode, target.String()) - } - return a, nil -} - -// GetMasterCell is the discovery part of the srvtopo.TargetStats interface. -func (c *L2VTGateConn) GetMasterCell(keyspace, shard string) (cell string, err error) { - key := l2VTGateConnKey{ - keyspace: keyspace, - shard: shard, - tabletType: topodatapb.TabletType_MASTER, - } - c.mu.RLock() - defer c.mu.RUnlock() - e, ok := c.stats[key] - if !ok { - return "", topo.NewError(topo.NoNode, keyspace+"/"+shard) - } - - for cell := range e.aggregates { - return cell, nil - } - return "", topo.NewError(topo.NoNode, keyspace+"/"+shard) -} - -// CacheStatus returns a list of TabletCacheStatus per -// name:keyspace/shard/tablet type. -func (c *L2VTGateConn) CacheStatus() TabletCacheStatusList { - c.mu.RLock() - res := make(TabletCacheStatusList, 0, len(c.statusAggregators)) - for _, aggr := range c.statusAggregators { - res = append(res, aggr.GetCacheStatus()) - } - c.mu.RUnlock() - sort.Sort(res) - return res -} - -func (c *L2VTGateConn) updateStats(target *querypb.Target, startTime time.Time, err error) { - elapsed := time.Now().Sub(startTime) - aggr := c.getStatsAggregator(target) - aggr.UpdateQueryInfo("", target.TabletType, elapsed, err != nil) -} - -func (c *L2VTGateConn) getStatsAggregator(target *querypb.Target) *TabletStatusAggregator { - key := fmt.Sprintf("%v:%v/%v/%v", c.name, target.Keyspace, target.Shard, target.TabletType.String()) - - // get existing aggregator - c.mu.RLock() - aggr, ok := c.statusAggregators[key] - c.mu.RUnlock() - if ok { - return aggr - } - - // create a new one, but check again before the creation - c.mu.Lock() - defer c.mu.Unlock() - aggr, ok = c.statusAggregators[key] - if ok { - return aggr - } - aggr = NewTabletStatusAggregator(target.Keyspace, target.Shard, target.TabletType, key) - c.statusAggregators[key] = aggr - return aggr -} - -// withRetry uses the connection to execute the action. If there are -// retryable errors, it retries retryCount times before failing. It -// does not retry if the connection is in the middle of a -// transaction. While returning the error check if it maybe a result -// of a resharding event, and set the re-resolve bit and let the upper -// layers re-resolve and retry. -func (c *L2VTGateConn) withRetry(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, inTransaction bool, inner func(context.Context, *querypb.Target, queryservice.QueryService) (error, bool)) error { - var err error - for i := 0; i < c.retryCount+1; i++ { - startTime := time.Now() - var canRetry bool - err, canRetry = inner(ctx, target, conn) - if target != nil { - // target can be nil for StreamHealth calls. - c.updateStats(target, startTime, err) - } - if canRetry { - continue - } - break - } - return NewShardError(err, target, nil, inTransaction) -} diff --git a/go/vt/vtgate/gatewaytest/grpc_discovery_test.go b/go/vt/vtgate/gatewaytest/grpc_discovery_test.go index b14c5d5af1e..6bcdec6e7cc 100644 --- a/go/vt/vtgate/gatewaytest/grpc_discovery_test.go +++ b/go/vt/vtgate/gatewaytest/grpc_discovery_test.go @@ -28,7 +28,6 @@ import ( "vitess.io/vitess/go/vt/discovery" "vitess.io/vitess/go/vt/srvtopo" - "vitess.io/vitess/go/vt/vtgate" "vitess.io/vitess/go/vt/vtgate/gateway" "vitess.io/vitess/go/vt/vttablet/grpcqueryservice" "vitess.io/vitess/go/vt/vttablet/tabletconntest" @@ -90,77 +89,3 @@ func TestGRPCDiscovery(t *testing.T) { // run the test suite. TestSuite(t, "discovery-grpc", dg, service) } - -// TestL2VTGateDiscovery tests the hybrid gateway with a gRPC -// connection from the gateway to a l2vtgate in-process object. -func TestL2VTGateDiscovery(t *testing.T) { - flag.Set("tablet_protocol", "grpc") - flag.Set("gateway_implementation", "discoverygateway") - flag.Set("enable_forwarding", "true") - - // Fake services for the tablet, topo server. - service, ts, cell := CreateFakeServers(t) - - // Tablet: listen on a random port. - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatalf("Cannot listen: %v", err) - } - host := listener.Addr().(*net.TCPAddr).IP.String() - port := listener.Addr().(*net.TCPAddr).Port - defer listener.Close() - - // Tablet: create a gRPC server and listen on the port. - server := grpc.NewServer() - grpcqueryservice.Register(server, service) - go server.Serve(listener) - defer server.Stop() - - // L2VTGate: Create the discovery healthcheck, and the gateway. - // Wait for the right tablets to be present. - hc := discovery.NewHealthCheck(10*time.Second, 2*time.Minute) - rs := srvtopo.NewResilientServer(ts, "TestL2VTGateDiscovery") - l2vtgate := vtgate.Init(context.Background(), hc, rs, cell, 2, nil) - hc.AddTablet(&topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: cell, - Uid: 44, - }, - Keyspace: tabletconntest.TestTarget.Keyspace, - Shard: tabletconntest.TestTarget.Shard, - Type: tabletconntest.TestTarget.TabletType, - Hostname: host, - PortMap: map[string]int32{ - "grpc": int32(port), - }, - }, "test_tablet") - ctx := context.Background() - err = l2vtgate.Gateway().WaitForTablets(ctx, []topodatapb.TabletType{tabletconntest.TestTarget.TabletType}) - if err != nil { - t.Fatalf("WaitForTablets failed: %v", err) - } - - // L2VTGate: listen on a random port. - listener, err = net.Listen("tcp", ":0") - if err != nil { - t.Fatalf("Cannot listen: %v", err) - } - defer listener.Close() - - // L2VTGate: create a gRPC server and listen on the port. - server = grpc.NewServer() - grpcqueryservice.Register(server, l2vtgate.L2VTGate()) - go server.Serve(listener) - defer server.Stop() - - // VTGate: create the HybridGateway, with no local gateway, - // and just the remote address in the l2vtgate pool. - hg, err := gateway.NewHybridGateway(nil, []string{listener.Addr().String()}, 2) - if err != nil { - t.Fatalf("gateway.NewHybridGateway() failed: %v", err) - } - defer hg.Close(ctx) - - // and run the test suite. - TestSuite(t, "l2vtgate-grpc", hg, service) -} diff --git a/go/vt/vtgate/gatewaytest/suite.go b/go/vt/vtgate/gatewaytest/suite.go index 4efe6c4e37f..138de99d295 100644 --- a/go/vt/vtgate/gatewaytest/suite.go +++ b/go/vt/vtgate/gatewaytest/suite.go @@ -49,8 +49,8 @@ func CreateFakeServers(t *testing.T) (*tabletconntest.FakeQueryService, *topo.Se f := tabletconntest.CreateFakeServer(t) f.TestingGateway = true f.StreamHealthResponse = &querypb.StreamHealthResponse{ - Target: tabletconntest.TestTarget, - Serving: true, + Target: tabletconntest.TestTarget, + Serving: true, TabletExternallyReparentedTimestamp: 1234589, RealtimeStats: &querypb.RealtimeStats{ SecondsBehindMaster: 1, diff --git a/go/vt/vtgate/l2vtgate.go b/go/vt/vtgate/l2vtgate.go deleted file mode 100644 index b33fd5f8613..00000000000 --- a/go/vt/vtgate/l2vtgate.go +++ /dev/null @@ -1,108 +0,0 @@ -/* -Copyright 2017 Google Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package vtgate - -import ( - "time" - - "golang.org/x/net/context" - - "vitess.io/vitess/go/stats" - "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/vt/servenv" - "vitess.io/vitess/go/vt/topo/topoproto" - "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/gateway" - "vitess.io/vitess/go/vt/vttablet/queryservice" - - querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" -) - -var ( - l2VTGate *L2VTGate -) - -// L2VTGate implements queryservice.QueryService and forwards queries to -// the underlying gateway. -type L2VTGate struct { - queryservice.QueryService - timings *stats.MultiTimings - errorCounts *stats.CountersWithMultiLabels - gateway gateway.Gateway -} - -// RegisterL2VTGate defines the type of registration mechanism. -type RegisterL2VTGate func(queryservice.QueryService) - -// RegisterL2VTGates stores register funcs for L2VTGate server. -var RegisterL2VTGates []RegisterL2VTGate - -// initL2VTGate creates the single L2VTGate with the provided parameters. -func initL2VTGate(gw gateway.Gateway) *L2VTGate { - if l2VTGate != nil { - log.Fatalf("L2VTGate already initialized") - } - - l2VTGate = &L2VTGate{ - timings: stats.NewMultiTimings( - "QueryServiceCall", - "l2VTGate query service call timings", - []string{"Operation", "Keyspace", "ShardName", "DbType"}), - errorCounts: stats.NewCountersWithMultiLabels( - "QueryServiceCallErrorCount", - "Error count from calls to the query service", - []string{"Operation", "Keyspace", "ShardName", "DbType"}), - gateway: gw, - } - l2VTGate.QueryService = queryservice.Wrap( - gw, - func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, inTransaction bool, inner func(context.Context, *querypb.Target, queryservice.QueryService) (error, bool)) (err error) { - if target != nil { - startTime, statsKey := l2VTGate.startAction(name, target) - defer l2VTGate.endAction(startTime, statsKey, &err) - } - err, _ = inner(ctx, target, conn) - return err - }, - ) - servenv.OnRun(func() { - for _, f := range RegisterL2VTGates { - f(l2VTGate) - } - }) - return l2VTGate -} - -func (l *L2VTGate) startAction(name string, target *querypb.Target) (time.Time, []string) { - statsKey := []string{name, target.Keyspace, target.Shard, topoproto.TabletTypeLString(target.TabletType)} - startTime := time.Now() - return startTime, statsKey -} - -func (l *L2VTGate) endAction(startTime time.Time, statsKey []string, err *error) { - if *err != nil { - // Don't increment the error counter for duplicate - // keys or bad queries, as those errors are caused by - // client queries and are not VTGate's fault. - ec := vterrors.Code(*err) - if ec != vtrpcpb.Code_ALREADY_EXISTS && ec != vtrpcpb.Code_INVALID_ARGUMENT { - l.errorCounts.Add(statsKey, 1) - } - } - l.timings.Record(statsKey, startTime) -} diff --git a/go/vt/vtgate/planbuilder/delete.go b/go/vt/vtgate/planbuilder/delete.go index 0e969084476..0d20ae4556c 100644 --- a/go/vt/vtgate/planbuilder/delete.go +++ b/go/vt/vtgate/planbuilder/delete.go @@ -56,27 +56,27 @@ func buildDeletePlan(del *sqlparser.Delete, vschema ContextVSchema) (*engine.Del if hasSubquery(del) { return nil, errors.New("unsupported: subqueries in sharded DML") } - var tableName sqlparser.TableName - for t := range pb.st.tables { - tableName = t + var vindexTable *vindexes.Table + for _, tval := range pb.st.tables { + vindexTable = tval.vindexTable } - table, _, destTabletType, destTarget, err := vschema.FindTable(tableName) - if err != nil { - return nil, err + edel.Table = vindexTable + if edel.Table == nil { + return nil, errors.New("internal error: table.vindexTable is mysteriously nil") } - edel.Table = table + var err error directives := sqlparser.ExtractCommentDirectives(del.Comments) if directives.IsSet(sqlparser.DirectiveMultiShardAutocommit) { edel.MultiShardAutocommit = true } - if destTarget != nil { - if destTabletType != topodatapb.TabletType_MASTER { + if rb.ERoute.TargetDestination != nil { + if rb.ERoute.TargetTabletType != topodatapb.TabletType_MASTER { return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported: DELETE statement with a replica target") } edel.Opcode = engine.DeleteByDestination - edel.TargetDestination = destTarget + edel.TargetDestination = rb.ERoute.TargetDestination return edel, nil } edel.Vindex, edel.Values, err = getDMLRouting(del.Where, edel.Table) diff --git a/go/vt/vtgate/sandbox_test.go b/go/vt/vtgate/sandbox_test.go index 593ca103be1..52384a429a7 100644 --- a/go/vt/vtgate/sandbox_test.go +++ b/go/vt/vtgate/sandbox_test.go @@ -235,8 +235,8 @@ func newSandboxForCells(cells []string) *sandboxTopo { } // GetTopoServer is part of the srvtopo.Server interface -func (sct *sandboxTopo) GetTopoServer() *topo.Server { - return sct.topoServer +func (sct *sandboxTopo) GetTopoServer() (*topo.Server, error) { + return sct.topoServer, nil } // GetSrvKeyspaceNames is part of the srvtopo.Server interface. diff --git a/go/vt/vtgate/vindexes/vschema.go b/go/vt/vtgate/vindexes/vschema.go index 96a2100fee7..b140933cb38 100644 --- a/go/vt/vtgate/vindexes/vschema.go +++ b/go/vt/vtgate/vindexes/vschema.go @@ -485,3 +485,24 @@ func LoadFormalKeyspace(filename string) (*vschemapb.Keyspace, error) { } return formal, nil } + +// FindVindexForSharding searches through the given slice +// to find the lowest cost unique vindex +// primary vindex is always unique +// if two have the same cost, use the one that occurs earlier in the definition +// if the final result is too expensive, return nil +func FindVindexForSharding(tableName string, colVindexes []*ColumnVindex) (*ColumnVindex, error) { + if len(colVindexes) == 0 { + return nil, fmt.Errorf("no vindex definition for table %v", tableName) + } + result := colVindexes[0] + for _, colVindex := range colVindexes { + if colVindex.Vindex.Cost() < result.Vindex.Cost() && colVindex.Vindex.IsUnique() { + result = colVindex + } + } + if result.Vindex.Cost() > 1 || !result.Vindex.IsUnique() { + return nil, fmt.Errorf("could not find a vindex to use for sharding table %v", tableName) + } + return result, nil +} diff --git a/go/vt/vtgate/vindexes/vschema_test.go b/go/vt/vtgate/vindexes/vschema_test.go index d1f04d1a600..724573cb98e 100644 --- a/go/vt/vtgate/vindexes/vschema_test.go +++ b/go/vt/vtgate/vindexes/vschema_test.go @@ -468,6 +468,123 @@ func TestShardedVSchemaOwned(t *testing.T) { wantjson, _ := json.Marshal(want) t.Errorf("BuildVSchema:\n%s, want\n%s", gotjson, wantjson) } + +} + +func TestFindVindexForSharding(t *testing.T) { + ks := &Keyspace{ + Name: "sharded", + Sharded: true, + } + vindex1 := &stFU{ + name: "stfu1", + Params: map[string]string{ + "stfu1": "1", + }, + } + vindex2 := &stLN{name: "stln1"} + t1 := &Table{ + Name: sqlparser.NewTableIdent("t1"), + Keyspace: ks, + ColumnVindexes: []*ColumnVindex{ + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c1")}, + Type: "stfu", + Name: "stfu1", + Vindex: vindex1, + }, + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c2")}, + Type: "stln", + Name: "stln1", + Owned: true, + Vindex: vindex2, + }, + }, + } + res, err := FindVindexForSharding(t1.Name.String(), t1.ColumnVindexes) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(res, t1.ColumnVindexes[0]) { + t.Errorf("FindVindexForSharding:\n got\n%v, want\n%v", res, t1.ColumnVindexes[0]) + } +} + +func TestFindVindexForShardingError(t *testing.T) { + ks := &Keyspace{ + Name: "sharded", + Sharded: true, + } + vindex1 := &stLU{name: "stlu1"} + vindex2 := &stLN{name: "stln1"} + t1 := &Table{ + Name: sqlparser.NewTableIdent("t1"), + Keyspace: ks, + ColumnVindexes: []*ColumnVindex{ + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c1")}, + Type: "stlu", + Name: "stlu1", + Vindex: vindex1, + }, + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c2")}, + Type: "stln", + Name: "stln1", + Owned: true, + Vindex: vindex2, + }, + }, + } + res, err := FindVindexForSharding(t1.Name.String(), t1.ColumnVindexes) + want := `could not find a vindex to use for sharding table t1` + if err == nil || err.Error() != want { + t.Errorf("FindVindexForSharding: %v, want %v", err, want) + } + if res != nil { + t.Errorf("FindVindexForSharding:\n got\n%v, want\n%v", res, nil) + } +} + +func TestFindVindexForSharding2(t *testing.T) { + ks := &Keyspace{ + Name: "sharded", + Sharded: true, + } + vindex1 := &stLU{name: "stlu1"} + vindex2 := &stFU{ + name: "stfu1", + Params: map[string]string{ + "stfu1": "1", + }, + } + t1 := &Table{ + Name: sqlparser.NewTableIdent("t1"), + Keyspace: ks, + ColumnVindexes: []*ColumnVindex{ + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c1")}, + Type: "stlu", + Name: "stlu1", + Vindex: vindex1, + }, + { + Columns: []sqlparser.ColIdent{sqlparser.NewColIdent("c2")}, + Type: "stfu", + Name: "stfu1", + Owned: true, + Vindex: vindex2, + }, + }, + } + res, err := FindVindexForSharding(t1.Name.String(), t1.ColumnVindexes) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(res, t1.ColumnVindexes[1]) { + t.Errorf("FindVindexForSharding:\n got\n%v, want\n%v", res, t1.ColumnVindexes[1]) + } } func TestShardedVSchemaMultiColumnVindex(t *testing.T) { diff --git a/go/vt/vtgate/vschema_manager.go b/go/vt/vtgate/vschema_manager.go index 5823c080c14..d3189f58b2d 100644 --- a/go/vt/vtgate/vschema_manager.go +++ b/go/vt/vtgate/vschema_manager.go @@ -122,22 +122,25 @@ func (vm *VSchemaManager) watchSrvVSchema(ctx context.Context, cell string) { // the given keyspace is updated in the global topo, and the full SrvVSchema // is updated in all known cells. func (vm *VSchemaManager) UpdateVSchema(ctx context.Context, ksName string, vschema *vschemapb.SrvVSchema) error { - topo := vm.e.serv.GetTopoServer() + topoServer, err := vm.e.serv.GetTopoServer() + if err != nil { + return err + } ks := vschema.Keyspaces[ksName] - err := topo.SaveVSchema(ctx, ksName, ks) + err = topoServer.SaveVSchema(ctx, ksName, ks) if err != nil { return err } - cells, err := vm.e.serv.GetTopoServer().GetKnownCells(ctx) + cells, err := topoServer.GetKnownCells(ctx) if err != nil { return err } // even if one cell fails, continue to try the others for _, cell := range cells { - cellErr := vm.e.serv.GetTopoServer().UpdateSrvVSchema(ctx, cell, vschema) + cellErr := topoServer.UpdateSrvVSchema(ctx, cell, vschema) if cellErr != nil { err = cellErr log.Errorf("error updating vschema in cell %s: %v", cell, cellErr) diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 0e85e7af4d7..2b2e7b81d8c 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -28,7 +28,6 @@ import ( "golang.org/x/net/context" "vitess.io/vitess/go/acl" - "vitess.io/vitess/go/flagutil" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/stats" "vitess.io/vitess/go/tb" @@ -61,7 +60,6 @@ var ( queryPlanCacheSize = flag.Int64("gate_query_cache_size", 10000, "gate server query cache size, maximum number of queries to be cached. vtgate analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache.") legacyAutocommit = flag.Bool("legacy_autocommit", false, "DEPRECATED: set this flag to true to get the legacy behavior: all transactions will need an explicit begin, and DMLs outside transactions will return an error.") enableForwarding = flag.Bool("enable_forwarding", false, "if specified, this process will also expose a QueryService interface that allows other vtgates to talk through this vtgate to the underlying tablets.") - l2vtgateAddrs flagutil.StringListValue disableLocalGateway = flag.Bool("disable_local_gateway", false, "if specified, this process will not route any queries to local tablets in the local cell") ) @@ -118,7 +116,6 @@ type VTGate struct { resolver *Resolver txConn *TxConn gw gateway.Gateway - l2vtgate *L2VTGate // stats objects. // TODO(sougou): This needs to be cleaned up. There @@ -162,30 +159,12 @@ func Init(ctx context.Context, hc discovery.HealthCheck, serv srvtopo.Server, ce // Start with the gateway. If we can't reach the topology service, // we can't go on much further, so we log.Fatal out. var gw gateway.Gateway - var l2vtgate *L2VTGate if !*disableLocalGateway { gw = gateway.GetCreator()(hc, serv, cell, retryCount) gw.RegisterStats() if err := gateway.WaitForTablets(gw, tabletTypesToWait); err != nil { log.Fatalf("gateway.WaitForTablets failed: %v", err) } - - // l2vtgate gives access to the underlying Gateway - // from an exported QueryService interface. - if *enableForwarding { - l2vtgate = initL2VTGate(gw) - } - } - - // If we have other vtgate pools to connect to, create a - // HybridGateway to perform the routing. - if len(l2vtgateAddrs) > 0 { - hgw, err := gateway.NewHybridGateway(gw, l2vtgateAddrs, retryCount) - if err != nil { - log.Fatalf("gateway.NewHybridGateway failed: %v", err) - } - hgw.RegisterStats() - gw = hgw } // Check we have something to do. @@ -193,6 +172,17 @@ func Init(ctx context.Context, hc discovery.HealthCheck, serv srvtopo.Server, ce log.Fatalf("'-disable_local_gateway' cannot be specified if 'l2vtgate_addrs' is also empty, otherwise this vtgate has no backend") } + // If we want to filter keyspaces replace the srvtopo.Server with a + // filtering server + if len(gateway.KeyspacesToWatch) > 0 { + log.Infof("Keyspace filtering enabled, selecting %v", gateway.KeyspacesToWatch) + var err error + serv, err = srvtopo.NewKeyspaceFilteringServer(serv, gateway.KeyspacesToWatch) + if err != nil { + log.Fatalf("Unable to construct SrvTopo server: %v", err.Error()) + } + } + tc := NewTxConn(gw, getTxMode()) // ScatterConn depends on TxConn to perform forced rollbacks. sc := NewScatterConn("VttabletCall", tc, gw, hc) @@ -204,7 +194,6 @@ func Init(ctx context.Context, hc discovery.HealthCheck, serv srvtopo.Server, ce resolver: resolver, txConn: tc, gw: gw, - l2vtgate: l2vtgate, timings: stats.NewMultiTimings( "VtgateApi", "VtgateApi timings", @@ -284,11 +273,6 @@ func (vtg *VTGate) Gateway() gateway.Gateway { return vtg.gw } -// L2VTGate returns the L2VTGate object. Mostly used for tests. -func (vtg *VTGate) L2VTGate() *L2VTGate { - return vtg.l2vtgate -} - // Execute executes a non-streaming query. This is a V3 function. func (vtg *VTGate) Execute(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (newSession *vtgatepb.Session, qr *sqltypes.Result, err error) { // In this context, we don't care if we can't fully parse destination @@ -1184,7 +1168,3 @@ func unambiguousKeyspaceBSQ(queries []*vtgatepb.BoundShardQuery) string { return keyspace } } - -func init() { - flag.Var(&l2vtgateAddrs, "l2vtgate_addrs", "Specifies a comma-separated list of other l2 vtgate pools to connect to. These other vtgates must run with the --enable_forwarding flag") -} diff --git a/go/vt/vttablet/tabletconntest/fakequeryservice.go b/go/vt/vttablet/tabletconntest/fakequeryservice.go index 4ac84275bea..813a48efd3a 100644 --- a/go/vt/vttablet/tabletconntest/fakequeryservice.go +++ b/go/vt/vttablet/tabletconntest/fakequeryservice.go @@ -751,14 +751,14 @@ var TestStreamHealthStreamHealthResponse = &querypb.StreamHealthResponse{ Shard: "test_shard", TabletType: topodatapb.TabletType_RDONLY, }, - Serving: true, + Serving: true, TabletExternallyReparentedTimestamp: 1234589, RealtimeStats: &querypb.RealtimeStats{ HealthError: "random error", SecondsBehindMaster: 234, BinlogPlayersCount: 1, SecondsBehindMasterFilteredReplication: 2, - CpuUsage: 1.0, + CpuUsage: 1.0, }, } diff --git a/go/vt/vttablet/tabletmanager/init_tablet.go b/go/vt/vttablet/tabletmanager/init_tablet.go index 9a28e0fac38..804590d2d82 100644 --- a/go/vt/vttablet/tabletmanager/init_tablet.go +++ b/go/vt/vttablet/tabletmanager/init_tablet.go @@ -30,19 +30,20 @@ import ( "vitess.io/vitess/go/flagutil" "vitess.io/vitess/go/netutil" "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/mysqlctl" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/topoproto" - - topodatapb "vitess.io/vitess/go/vt/proto/topodata" ) var ( - initDbNameOverride = flag.String("init_db_name_override", "", "(init parameter) override the name of the db used by vttablet") - initKeyspace = flag.String("init_keyspace", "", "(init parameter) keyspace to use for this tablet") - initShard = flag.String("init_shard", "", "(init parameter) shard to use for this tablet") - initTags flagutil.StringMapValue - initTabletType = flag.String("init_tablet_type", "", "(init parameter) the tablet type to use for this tablet.") - initTimeout = flag.Duration("init_timeout", 1*time.Minute, "(init parameter) timeout to use for the init phase.") + initDbNameOverride = flag.String("init_db_name_override", "", "(init parameter) override the name of the db used by vttablet") + initKeyspace = flag.String("init_keyspace", "", "(init parameter) keyspace to use for this tablet") + initShard = flag.String("init_shard", "", "(init parameter) shard to use for this tablet") + initTags flagutil.StringMapValue + initTabletType = flag.String("init_tablet_type", "", "(init parameter) the tablet type to use for this tablet.") + initTimeout = flag.Duration("init_timeout", 1*time.Minute, "(init parameter) timeout to use for the init phase.") + initPopulateMetadata = flag.Bool("init_populate_metadata", false, "(init parameter) populate metadata tables") ) func init() { @@ -206,5 +207,15 @@ func (agent *ActionAgent) InitTablet(port, gRPCPort int32) error { return vterrors.Wrap(err, "CreateTablet failed") } + // optionally populate metadata records + if *initPopulateMetadata { + agent.setTablet(tablet) + localMetadata := agent.getLocalMetadataValues(tablet.Type) + err := mysqlctl.PopulateMetadataTables(agent.MysqlDaemon, localMetadata) + if err != nil { + return vterrors.Wrap(err, "failed to -init_populate_metadata") + } + } + return nil } diff --git a/go/vt/vttablet/tabletmanager/init_tablet_test.go b/go/vt/vttablet/tabletmanager/init_tablet_test.go index d9083ab86ef..1171474637e 100644 --- a/go/vt/vttablet/tabletmanager/init_tablet_test.go +++ b/go/vt/vttablet/tabletmanager/init_tablet_test.go @@ -25,6 +25,8 @@ import ( "github.com/golang/protobuf/proto" "vitess.io/vitess/go/history" + "vitess.io/vitess/go/mysql/fakesqldb" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/mysqlctl/fakemysqldaemon" "vitess.io/vitess/go/vt/topo" @@ -169,11 +171,24 @@ func TestInitTablet(t *testing.T) { Cell: "cell1", Uid: 1, } + db := fakesqldb.New(t) + defer db.Close() + db.AddQueryPattern(`(SET|CREATE|BEGIN|INSERT|COMMIT)\b.*`, &sqltypes.Result{}) + /* + db.AddQuery("SET @@session.sql_log_bin = 0", &sqltypes.Result{}) + db.AddQuery("CREATE DATABASE IF NOT EXISTS _vt", &sqltypes.Result{}) + db.AddQueryPattern(`CREATE TABLE IF NOT EXISTS _vt\.local_metadata.*`, &sqltypes.Result{}) + db.AddQueryPattern(`CREATE TABLE IF NOT EXISTS _vt\.shard_metadata.*`, &sqltypes.Result{}) + db.AddQuery("BEGIN", &sqltypes.Result{}) + db.AddQueryPattern(`INSERT INTO _vt.local_metadata.*`, &sqltypes.Result{}) + db.AddQueryPattern(`INSERT INTO _vt.shard_metadata.*`, &sqltypes.Result{}) + db.AddQuery("COMMIT", &sqltypes.Result{}) + */ // start with a tablet record that doesn't exist port := int32(1234) gRPCPort := int32(3456) - mysqlDaemon := fakemysqldaemon.NewFakeMysqlDaemon(nil) + mysqlDaemon := fakemysqldaemon.NewFakeMysqlDaemon(db) agent := &ActionAgent{ TopoServer: ts, TabletAlias: tabletAlias, @@ -194,6 +209,7 @@ func TestInitTablet(t *testing.T) { *initKeyspace = "test_keyspace" *initShard = "-C0" *initTabletType = "replica" + *initPopulateMetadata = true tabletAlias = &topodatapb.TabletAlias{ Cell: "cell1", Uid: 2, diff --git a/go/vt/vttablet/tabletmanager/vreplication/controller_test.go b/go/vt/vttablet/tabletmanager/vreplication/controller_test.go index bfd84b53ece..52c2a0f0b69 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/controller_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/controller_test.go @@ -39,8 +39,8 @@ var ( InsertID: 0, Rows: [][]sqltypes.Value{ { - sqltypes.NewVarBinary("MariaDB/0-1-1083"), // pos - sqltypes.NULL, // stop_pos + sqltypes.NewVarBinary("MariaDB/0-1-1083"), // pos + sqltypes.NULL, // stop_pos sqltypes.NewVarBinary("9223372036854775807"), // max_tps sqltypes.NewVarBinary("9223372036854775807"), // max_replication_lag }, diff --git a/go/vt/vttablet/tabletserver/splitquery/splitter_test.go b/go/vt/vttablet/tabletserver/splitquery/splitter_test.go index 908ae71ed72..31670d911ce 100644 --- a/go/vt/vttablet/tabletserver/splitquery/splitter_test.go +++ b/go/vt/vttablet/tabletserver/splitquery/splitter_test.go @@ -459,7 +459,7 @@ func TestSplitWithExistingBindVariables(t *testing.T) { " (id < :_splitquery_end_id or" + " (id = :_splitquery_end_id and user_id < :_splitquery_end_user_id))", BindVariables: map[string]*querypb.BindVariable{ - "foo": sqltypes.Int64BindVariable(100), + "foo": sqltypes.Int64BindVariable(100), "_splitquery_start_id": sqltypes.Int64BindVariable(1), "_splitquery_start_user_id": sqltypes.Int64BindVariable(2), "_splitquery_end_id": sqltypes.Int64BindVariable(1), @@ -476,7 +476,7 @@ func TestSplitWithExistingBindVariables(t *testing.T) { " (id < :_splitquery_end_id or" + " (id = :_splitquery_end_id and user_id < :_splitquery_end_user_id))", BindVariables: map[string]*querypb.BindVariable{ - "foo": sqltypes.Int64BindVariable(100), + "foo": sqltypes.Int64BindVariable(100), "_splitquery_start_id": sqltypes.Int64BindVariable(1), "_splitquery_start_user_id": sqltypes.Int64BindVariable(3), "_splitquery_end_id": sqltypes.Int64BindVariable(5), @@ -490,7 +490,7 @@ func TestSplitWithExistingBindVariables(t *testing.T) { " :_splitquery_start_id < id or" + " (:_splitquery_start_id = id and :_splitquery_start_user_id <= user_id)", BindVariables: map[string]*querypb.BindVariable{ - "foo": sqltypes.Int64BindVariable(100), + "foo": sqltypes.Int64BindVariable(100), "_splitquery_start_user_id": sqltypes.Int64BindVariable(1), "_splitquery_start_id": sqltypes.Int64BindVariable(5), }, diff --git a/go/vt/vttablet/tabletserver/tabletenv/logstats.go b/go/vt/vttablet/tabletserver/tabletenv/logstats.go index d95cf046135..6e8323400c3 100644 --- a/go/vt/vttablet/tabletserver/tabletenv/logstats.go +++ b/go/vt/vttablet/tabletserver/tabletenv/logstats.go @@ -168,13 +168,13 @@ func (stats *LogStats) ErrorStr() string { return "" } -// RemoteAddrUsername returns some parts of CallInfo if set -func (stats *LogStats) RemoteAddrUsername() (string, string) { +// CallInfo returns some parts of CallInfo if set +func (stats *LogStats) CallInfo() (string, string) { ci, ok := callinfo.FromContext(stats.Ctx) if !ok { return "", "" } - return ci.RemoteAddr(), ci.Username() + return ci.Text(), ci.Username() } // Logf formats the log record to the given writer, either as @@ -195,7 +195,7 @@ func (stats *LogStats) Logf(w io.Writer, params url.Values) error { } // TODO: remove username here we fully enforce immediate caller id - remoteAddr, username := stats.RemoteAddrUsername() + callInfo, username := stats.CallInfo() // Valid options for the QueryLogFormat are text or json var fmtString string @@ -203,14 +203,14 @@ func (stats *LogStats) Logf(w io.Writer, params url.Values) error { case streamlog.QueryLogFormatText: fmtString = "%v\t%v\t%v\t'%v'\t'%v'\t%v\t%v\t%.6f\t%v\t%q\t%v\t%v\t%q\t%v\t%.6f\t%.6f\t%v\t%v\t%q\t\n" case streamlog.QueryLogFormatJSON: - fmtString = "{\"Method\": %q, \"RemoteAddr\": %q, \"Username\": %q, \"ImmediateCaller\": %q, \"Effective Caller\": %q, \"Start\": \"%v\", \"End\": \"%v\", \"TotalTime\": %.6f, \"PlanType\": %q, \"OriginalSQL\": %q, \"BindVars\": %v, \"Queries\": %v, \"RewrittenSQL\": %q, \"QuerySources\": %q, \"MysqlTime\": %.6f, \"ConnWaitTime\": %.6f, \"RowsAffected\": %v, \"ResponseSize\": %v, \"Error\": %q}\n" + fmtString = "{\"Method\": %q, \"CallInfo\": %q, \"Username\": %q, \"ImmediateCaller\": %q, \"Effective Caller\": %q, \"Start\": \"%v\", \"End\": \"%v\", \"TotalTime\": %.6f, \"PlanType\": %q, \"OriginalSQL\": %q, \"BindVars\": %v, \"Queries\": %v, \"RewrittenSQL\": %q, \"QuerySources\": %q, \"MysqlTime\": %.6f, \"ConnWaitTime\": %.6f, \"RowsAffected\": %v, \"ResponseSize\": %v, \"Error\": %q}\n" } _, err := fmt.Fprintf( w, fmtString, stats.Method, - remoteAddr, + callInfo, username, stats.ImmediateCaller(), stats.EffectiveCaller(), diff --git a/go/vt/vttablet/tabletserver/tabletenv/logstats_test.go b/go/vt/vttablet/tabletserver/tabletenv/logstats_test.go index 4b33f4fe424..2bd62184a0d 100644 --- a/go/vt/vttablet/tabletserver/tabletenv/logstats_test.go +++ b/go/vt/vttablet/tabletserver/tabletenv/logstats_test.go @@ -97,7 +97,7 @@ func TestLogStatsFormat(t *testing.T) { if err != nil { t.Errorf("logstats format: error marshaling json: %v -- got:\n%v", err, got) } - want = "{\n \"BindVars\": {\n \"intVal\": {\n \"type\": \"INT64\",\n \"value\": 1\n }\n },\n \"ConnWaitTime\": 0,\n \"Effective Caller\": \"\",\n \"End\": \"2017-01-01 01:02:04.000001\",\n \"Error\": \"\",\n \"ImmediateCaller\": \"\",\n \"Method\": \"test\",\n \"MysqlTime\": 0,\n \"OriginalSQL\": \"sql\",\n \"PlanType\": \"\",\n \"Queries\": 1,\n \"QuerySources\": \"mysql\",\n \"RemoteAddr\": \"\",\n \"ResponseSize\": 1,\n \"RewrittenSQL\": \"sql with pii\",\n \"RowsAffected\": 0,\n \"Start\": \"2017-01-01 01:02:03.000000\",\n \"TotalTime\": 1.000001,\n \"Username\": \"\"\n}" + want = "{\n \"BindVars\": {\n \"intVal\": {\n \"type\": \"INT64\",\n \"value\": 1\n }\n },\n \"CallInfo\": \"\",\n \"ConnWaitTime\": 0,\n \"Effective Caller\": \"\",\n \"End\": \"2017-01-01 01:02:04.000001\",\n \"Error\": \"\",\n \"ImmediateCaller\": \"\",\n \"Method\": \"test\",\n \"MysqlTime\": 0,\n \"OriginalSQL\": \"sql\",\n \"PlanType\": \"\",\n \"Queries\": 1,\n \"QuerySources\": \"mysql\",\n \"ResponseSize\": 1,\n \"RewrittenSQL\": \"sql with pii\",\n \"RowsAffected\": 0,\n \"Start\": \"2017-01-01 01:02:03.000000\",\n \"TotalTime\": 1.000001,\n \"Username\": \"\"\n}" if string(formatted) != want { t.Errorf("logstats format: got:\n%q\nwant:\n%v\n", string(formatted), want) } @@ -113,7 +113,7 @@ func TestLogStatsFormat(t *testing.T) { if err != nil { t.Errorf("logstats format: error marshaling json: %v -- got:\n%v", err, got) } - want = "{\n \"BindVars\": \"[REDACTED]\",\n \"ConnWaitTime\": 0,\n \"Effective Caller\": \"\",\n \"End\": \"2017-01-01 01:02:04.000001\",\n \"Error\": \"\",\n \"ImmediateCaller\": \"\",\n \"Method\": \"test\",\n \"MysqlTime\": 0,\n \"OriginalSQL\": \"sql\",\n \"PlanType\": \"\",\n \"Queries\": 1,\n \"QuerySources\": \"mysql\",\n \"RemoteAddr\": \"\",\n \"ResponseSize\": 1,\n \"RewrittenSQL\": \"[REDACTED]\",\n \"RowsAffected\": 0,\n \"Start\": \"2017-01-01 01:02:03.000000\",\n \"TotalTime\": 1.000001,\n \"Username\": \"\"\n}" + want = "{\n \"BindVars\": \"[REDACTED]\",\n \"CallInfo\": \"\",\n \"ConnWaitTime\": 0,\n \"Effective Caller\": \"\",\n \"End\": \"2017-01-01 01:02:04.000001\",\n \"Error\": \"\",\n \"ImmediateCaller\": \"\",\n \"Method\": \"test\",\n \"MysqlTime\": 0,\n \"OriginalSQL\": \"sql\",\n \"PlanType\": \"\",\n \"Queries\": 1,\n \"QuerySources\": \"mysql\",\n \"ResponseSize\": 1,\n \"RewrittenSQL\": \"[REDACTED]\",\n \"RowsAffected\": 0,\n \"Start\": \"2017-01-01 01:02:03.000000\",\n \"TotalTime\": 1.000001,\n \"Username\": \"\"\n}" if string(formatted) != want { t.Errorf("logstats format: got:\n%q\nwant:\n%v\n", string(formatted), want) } @@ -141,7 +141,7 @@ func TestLogStatsFormat(t *testing.T) { if err != nil { t.Errorf("logstats format: error marshaling json: %v -- got:\n%v", err, got) } - want = "{\n \"BindVars\": {\n \"strVal\": {\n \"type\": \"VARCHAR\",\n \"value\": \"abc\"\n }\n },\n \"ConnWaitTime\": 0,\n \"Effective Caller\": \"\",\n \"End\": \"2017-01-01 01:02:04.000001\",\n \"Error\": \"\",\n \"ImmediateCaller\": \"\",\n \"Method\": \"test\",\n \"MysqlTime\": 0,\n \"OriginalSQL\": \"sql\",\n \"PlanType\": \"\",\n \"Queries\": 1,\n \"QuerySources\": \"mysql\",\n \"RemoteAddr\": \"\",\n \"ResponseSize\": 1,\n \"RewrittenSQL\": \"sql with pii\",\n \"RowsAffected\": 0,\n \"Start\": \"2017-01-01 01:02:03.000000\",\n \"TotalTime\": 1.000001,\n \"Username\": \"\"\n}" + want = "{\n \"BindVars\": {\n \"strVal\": {\n \"type\": \"VARCHAR\",\n \"value\": \"abc\"\n }\n },\n \"CallInfo\": \"\",\n \"ConnWaitTime\": 0,\n \"Effective Caller\": \"\",\n \"End\": \"2017-01-01 01:02:04.000001\",\n \"Error\": \"\",\n \"ImmediateCaller\": \"\",\n \"Method\": \"test\",\n \"MysqlTime\": 0,\n \"OriginalSQL\": \"sql\",\n \"PlanType\": \"\",\n \"Queries\": 1,\n \"QuerySources\": \"mysql\",\n \"ResponseSize\": 1,\n \"RewrittenSQL\": \"sql with pii\",\n \"RowsAffected\": 0,\n \"Start\": \"2017-01-01 01:02:03.000000\",\n \"TotalTime\": 1.000001,\n \"Username\": \"\"\n}" if string(formatted) != want { t.Errorf("logstats format: got:\n%q\nwant:\n%v\n", string(formatted), want) } @@ -190,11 +190,11 @@ func TestLogStatsErrorStr(t *testing.T) { } } -func TestLogStatsRemoteAddrUsername(t *testing.T) { +func TestLogStatsCallInfo(t *testing.T) { logStats := NewLogStats(context.Background(), "test") - addr, user := logStats.RemoteAddrUsername() - if addr != "" { - t.Fatalf("remote addr should be empty") + caller, user := logStats.CallInfo() + if caller != "" { + t.Fatalf("caller should be empty") } if user != "" { t.Fatalf("username should be empty") @@ -204,13 +204,15 @@ func TestLogStatsRemoteAddrUsername(t *testing.T) { username := "vt" callInfo := &fakecallinfo.FakeCallInfo{ Remote: remoteAddr, + Method: "FakeExecute", User: username, } ctx := callinfo.NewContext(context.Background(), callInfo) logStats = NewLogStats(ctx, "test") - addr, user = logStats.RemoteAddrUsername() - if addr != remoteAddr { - t.Fatalf("expected to get remote addr: %s, but got: %s", remoteAddr, addr) + caller, user = logStats.CallInfo() + wantCaller := remoteAddr + ":FakeExecute(fakeRPC)" + if caller != wantCaller { + t.Fatalf("expected to get caller: %s, but got: %s", wantCaller, caller) } if user != username { t.Fatalf("expected to get username: %s, but got: %s", username, user) diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go index 07231489dce..e698969816e 100644 --- a/go/vt/vttablet/tabletserver/tabletserver.go +++ b/go/vt/vttablet/tabletserver/tabletserver.go @@ -1770,9 +1770,9 @@ func (tsv *TabletServer) BroadcastHealth(terTimestamp int64, stats *querypb.Real target := tsv.target tsv.mu.Unlock() shr := &querypb.StreamHealthResponse{ - Target: &target, - TabletAlias: &tsv.alias, - Serving: tsv.IsServing(), + Target: &target, + TabletAlias: &tsv.alias, + Serving: tsv.IsServing(), TabletExternallyReparentedTimestamp: terTimestamp, RealtimeStats: stats, } diff --git a/go/vt/vttablet/tabletserver/tabletserver_test.go b/go/vt/vttablet/tabletserver/tabletserver_test.go index fa89040de35..b2ba5a51e55 100644 --- a/go/vt/vttablet/tabletserver/tabletserver_test.go +++ b/go/vt/vttablet/tabletserver/tabletserver_test.go @@ -2497,8 +2497,8 @@ func TestTabletServerSplitQueryEqualSplitsOnStringColumn(t *testing.T) { &querypb.BoundQuery{Sql: sql}, // EQUAL_SPLITS should not work on a string column. []string{"name_string"}, /* splitColumns */ - 10, /* splitCount */ - 0, /* numRowsPerQueryPart */ + 10, /* splitCount */ + 0, /* numRowsPerQueryPart */ querypb.SplitQueryRequest_EQUAL_SPLITS) want := "using the EQUAL_SPLITS algorithm in SplitQuery" + diff --git a/go/vt/vttablet/tabletserver/txserializer/tx_serializer.go b/go/vt/vttablet/tabletserver/txserializer/tx_serializer.go index c4a1768adc3..466da683795 100644 --- a/go/vt/vttablet/tabletserver/txserializer/tx_serializer.go +++ b/go/vt/vttablet/tabletserver/txserializer/tx_serializer.go @@ -112,17 +112,17 @@ type TxSerializer struct { // New returns a TxSerializer object. func New(dryRun bool, maxQueueSize, maxGlobalQueueSize, concurrentTransactions int) *TxSerializer { return &TxSerializer{ - ConsolidatorCache: sync2.NewConsolidatorCache(1000), - dryRun: dryRun, - maxQueueSize: maxQueueSize, - maxGlobalQueueSize: maxGlobalQueueSize, - concurrentTransactions: concurrentTransactions, + ConsolidatorCache: sync2.NewConsolidatorCache(1000), + dryRun: dryRun, + maxQueueSize: maxQueueSize, + maxGlobalQueueSize: maxGlobalQueueSize, + concurrentTransactions: concurrentTransactions, log: logutil.NewThrottledLogger("HotRowProtection", 5*time.Second), logDryRun: logutil.NewThrottledLogger("HotRowProtection DryRun", 5*time.Second), logWaitsDryRun: logutil.NewThrottledLogger("HotRowProtection Waits DryRun", 5*time.Second), logQueueExceededDryRun: logutil.NewThrottledLogger("HotRowProtection QueueExceeded DryRun", 5*time.Second), logGlobalQueueExceededDryRun: logutil.NewThrottledLogger("HotRowProtection GlobalQueueExceeded DryRun", 5*time.Second), - queues: make(map[string]*queue), + queues: make(map[string]*queue), } } diff --git a/go/vt/vttablet/tabletserver/txthrottler/tx_throttler.go b/go/vt/vttablet/tabletserver/txthrottler/tx_throttler.go index 40dd699b087..c3c022ded99 100644 --- a/go/vt/vttablet/tabletserver/txthrottler/tx_throttler.go +++ b/go/vt/vttablet/tabletserver/txthrottler/tx_throttler.go @@ -255,8 +255,8 @@ func newTxThrottlerState(config *txThrottlerConfig, keyspace, shard string, ) (*txThrottlerState, error) { t, err := throttlerFactory( TxThrottlerName, - "TPS", /* unit */ - 1, /* threadCount */ + "TPS", /* unit */ + 1, /* threadCount */ throttler.MaxRateModuleDisabled, /* maxRate */ config.throttlerConfig.MaxReplicationLagSec /* maxReplicationLag */) if err != nil { diff --git a/go/vt/worker/chunk.go b/go/vt/worker/chunk.go index ba0865769ad..2de9dc1c9f8 100644 --- a/go/vt/worker/chunk.go +++ b/go/vt/worker/chunk.go @@ -97,7 +97,7 @@ func generateChunks(ctx context.Context, wr *wrangler.Wrangler, tablet *topodata qr, err := wr.TabletManagerClient().ExecuteFetchAsApp(shortCtx, tablet, true, []byte(query), 1) cancel() if err != nil { - return nil, vterrors.Wrapf(err, "tablet: %v, table: %v: cannot determine MIN and MAX of the first primary key column. ExecuteFetchAsApp: %v", topoproto.TabletAliasString(tablet.Alias), td.Name, err) + return nil, vterrors.Wrapf(err, "tablet: %v, table: %v: cannot determine MIN and MAX of the first primary key column. ExecuteFetchAsApp", topoproto.TabletAliasString(tablet.Alias), td.Name) } if len(qr.Rows) != 1 { return nil, fmt.Errorf("tablet: %v, table: %v: cannot determine MIN and MAX of the first primary key column. Zero rows were returned", topoproto.TabletAliasString(tablet.Alias), td.Name) diff --git a/go/vt/worker/diff_utils.go b/go/vt/worker/diff_utils.go index ae9ba1a6ba9..43b469fd21b 100644 --- a/go/vt/worker/diff_utils.go +++ b/go/vt/worker/diff_utils.go @@ -430,7 +430,7 @@ func CompareRows(fields []*querypb.Field, compareCount int, left, right []sqltyp r := rv.([]byte) return bytes.Compare(l, r), nil default: - return 0, fmt.Errorf("Unsuported type %T returned by mysql.proto.Convert", l) + return 0, fmt.Errorf("Unsupported type %T returned by mysql.proto.Convert", l) } } return 0, nil @@ -440,27 +440,27 @@ func CompareRows(fields []*querypb.Field, compareCount int, left, right []sqltyp // It assumes left and right are sorted by ascending primary key. // it will record errors if extra rows exist on either side. type RowDiffer struct { - left *RowReader - right *RowReader - pkFieldCount int + left *RowReader + right *RowReader + tableDefinition *tabletmanagerdatapb.TableDefinition } // NewRowDiffer returns a new RowDiffer -func NewRowDiffer(left, right *QueryResultReader, tableDefinition *tabletmanagerdatapb.TableDefinition) (*RowDiffer, error) { +func NewRowDiffer(left, right ResultReader, tableDefinition *tabletmanagerdatapb.TableDefinition) (*RowDiffer, error) { leftFields := left.Fields() rightFields := right.Fields() if len(leftFields) != len(rightFields) { - return nil, fmt.Errorf("Cannot diff inputs with different types") + return nil, fmt.Errorf("[table=%v] Cannot diff inputs with different types", tableDefinition.Name) } for i, field := range leftFields { if field.Type != rightFields[i].Type { - return nil, fmt.Errorf("Cannot diff inputs with different types: field %v types are %v and %v", i, field.Type, rightFields[i].Type) + return nil, fmt.Errorf("[table=%v] Cannot diff inputs with different types: field %v types are %v and %v", tableDefinition.Name, i, field.Type, rightFields[i].Type) } } return &RowDiffer{ - left: NewRowReader(left), - right: NewRowReader(right), - pkFieldCount: len(tableDefinition.PrimaryKeyColumns), + left: NewRowReader(left), + right: NewRowReader(right), + tableDefinition: tableDefinition, }, nil } @@ -529,10 +529,10 @@ func (rd *RowDiffer) Go(log logutil.Logger) (dr DiffReport, err error) { continue } - if f >= rd.pkFieldCount { + if f >= len(rd.tableDefinition.PrimaryKeyColumns) { // rows have the same primary key, only content is different if dr.mismatchedRows < 10 { - log.Errorf("Different content %v in same PK: %v != %v", dr.mismatchedRows, left, right) + log.Errorf("[table=%v] Different content %v in same PK: %v != %v", rd.tableDefinition.Name, dr.mismatchedRows, left, right) } dr.mismatchedRows++ advanceLeft = true @@ -541,20 +541,20 @@ func (rd *RowDiffer) Go(log logutil.Logger) (dr DiffReport, err error) { } // have to find the 'smallest' row and advance it - c, err := CompareRows(rd.left.Fields(), rd.pkFieldCount, left, right) + c, err := CompareRows(rd.left.Fields(), len(rd.tableDefinition.PrimaryKeyColumns), left, right) if err != nil { return dr, err } if c < 0 { if dr.extraRowsLeft < 10 { - log.Errorf("Extra row %v on left: %v", dr.extraRowsLeft, left) + log.Errorf("[table=%v] Extra row %v on left: %v", rd.tableDefinition.Name, dr.extraRowsLeft, left) } dr.extraRowsLeft++ advanceLeft = true continue } else if c > 0 { if dr.extraRowsRight < 10 { - log.Errorf("Extra row %v on right: %v", dr.extraRowsRight, right) + log.Errorf("[table=%v] Extra row %v on right: %v", rd.tableDefinition.Name, dr.extraRowsRight, right) } dr.extraRowsRight++ advanceRight = true @@ -565,7 +565,7 @@ func (rd *RowDiffer) Go(log logutil.Logger) (dr DiffReport, err error) { // they're the same. Logging a regular difference // then, and advancing both. if dr.mismatchedRows < 10 { - log.Errorf("Different content %v in same PK: %v != %v", dr.mismatchedRows, left, right) + log.Errorf("[table=%v] Different content %v in same PK: %v != %v", rd.tableDefinition.Name, dr.mismatchedRows, left, right) } dr.mismatchedRows++ advanceLeft = true diff --git a/go/vt/worker/key_resolver.go b/go/vt/worker/key_resolver.go index 193d807dba4..a17cdd27e0c 100644 --- a/go/vt/worker/key_resolver.go +++ b/go/vt/worker/key_resolver.go @@ -107,18 +107,10 @@ func newV3ResolverFromTableDefinition(keyspaceSchema *vindexes.KeyspaceSchema, t if !ok { return nil, fmt.Errorf("no vschema definition for table %v", td.Name) } - // the primary vindex is most likely the sharding key, and has to - // be unique. - if len(tableSchema.ColumnVindexes) == 0 { - return nil, fmt.Errorf("no vindex definition for table %v", td.Name) - } - colVindex := tableSchema.ColumnVindexes[0] - if colVindex.Vindex.Cost() > 1 { - return nil, fmt.Errorf("primary vindex cost is too high for table %v", td.Name) - } - if !colVindex.Vindex.IsUnique() { - // This is impossible, but just checking anyway. - return nil, fmt.Errorf("primary vindex is not unique for table %v", td.Name) + // use the lowest cost unique vindex as the sharding key + colVindex, err := vindexes.FindVindexForSharding(td.Name, tableSchema.ColumnVindexes) + if err != nil { + return nil, err } // Find the sharding key column index. @@ -139,18 +131,10 @@ func newV3ResolverFromColumnList(keyspaceSchema *vindexes.KeyspaceSchema, name s if !ok { return nil, fmt.Errorf("no vschema definition for table %v", name) } - // the primary vindex is most likely the sharding key, and has to - // be unique. - if len(tableSchema.ColumnVindexes) == 0 { - return nil, fmt.Errorf("no vindex definition for table %v", name) - } - colVindex := tableSchema.ColumnVindexes[0] - if colVindex.Vindex.Cost() > 1 { - return nil, fmt.Errorf("primary vindex cost is too high for table %v", name) - } - if !colVindex.Vindex.IsUnique() { - // This is impossible, but just checking anyway. - return nil, fmt.Errorf("primary vindex is not unique for table %v", name) + // use the lowest cost unique vindex as the sharding key + colVindex, err := vindexes.FindVindexForSharding(name, tableSchema.ColumnVindexes) + if err != nil { + return nil, err } // Find the sharding key column index. diff --git a/go/vt/worker/legacy_split_clone.go b/go/vt/worker/legacy_split_clone.go index 08bbb39ebff..3a870245688 100644 --- a/go/vt/worker/legacy_split_clone.go +++ b/go/vt/worker/legacy_split_clone.go @@ -206,7 +206,7 @@ func (scw *LegacySplitCloneWorker) Run(ctx context.Context) error { cerr := scw.cleaner.CleanUp(scw.wr) if cerr != nil { if err != nil { - scw.wr.Logger().Errorf("CleanUp failed in addition to job error: %v", cerr) + scw.wr.Logger().Errorf2(cerr, "CleanUp failed in addition to job error") } else { err = cerr } @@ -222,7 +222,7 @@ func (scw *LegacySplitCloneWorker) Run(ctx context.Context) error { } if scw.healthCheck != nil { if err := scw.healthCheck.Close(); err != nil { - scw.wr.Logger().Errorf("HealthCheck.Close() failed: %v", err) + scw.wr.Logger().Errorf2(err, "HealthCheck.Close() failed") } } diff --git a/go/vt/worker/legacy_split_clone_test.go b/go/vt/worker/legacy_split_clone_test.go index f1bd3cad9f7..8754f0c7d6d 100644 --- a/go/vt/worker/legacy_split_clone_test.go +++ b/go/vt/worker/legacy_split_clone_test.go @@ -180,7 +180,7 @@ func (tc *legacySplitCloneTestCase) setUp(v3 bool) { qs := fakes.NewStreamHealthQueryService(sourceRdonly.Target()) qs.AddDefaultHealthResponse() grpcqueryservice.Register(sourceRdonly.RPCServer, &legacyTestQueryService{ - t: tc.t, + t: tc.t, StreamHealthQueryService: qs, }) } diff --git a/go/vt/worker/multi_split_diff.go b/go/vt/worker/multi_split_diff.go new file mode 100644 index 00000000000..dbf899883fa --- /dev/null +++ b/go/vt/worker/multi_split_diff.go @@ -0,0 +1,747 @@ +/* +Copyright 2017 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package worker + +import ( + "fmt" + "html/template" + "sync" + + "vitess.io/vitess/go/vt/vterrors" + + "golang.org/x/net/context" + + "vitess.io/vitess/go/vt/concurrency" + "vitess.io/vitess/go/vt/mysqlctl/tmutils" + "vitess.io/vitess/go/vt/topo" + "vitess.io/vitess/go/vt/wrangler" + + "sort" + + "time" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/binlog/binlogplayer" + tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +// MultiSplitDiffWorker executes a diff between a destination shard and its +// source shards in a shard split case. +type MultiSplitDiffWorker struct { + StatusWorker + + wr *wrangler.Wrangler + cell string + keyspace string + shard string + excludeTables []string + minHealthyRdonlyTablets int + parallelDiffsCount int + waitForFixedTimeRatherThanGtidSet bool + cleaner *wrangler.Cleaner + + // populated during WorkerStateInit, read-only after that + keyspaceInfo *topo.KeyspaceInfo + shardInfo *topo.ShardInfo + sourceUID uint32 + destinationShards []*topo.ShardInfo + + // populated during WorkerStateFindTargets, read-only after that + sourceAlias *topodatapb.TabletAlias + destinationAliases []*topodatapb.TabletAlias // matches order of destinationShards +} + +// NewMultiSplitDiffWorker returns a new MultiSplitDiffWorker object. +func NewMultiSplitDiffWorker(wr *wrangler.Wrangler, cell, keyspace, shard string, excludeTables []string, minHealthyRdonlyTablets, parallelDiffsCount int, waitForFixedTimeRatherThanGtidSet bool) Worker { + return &MultiSplitDiffWorker{ + waitForFixedTimeRatherThanGtidSet: waitForFixedTimeRatherThanGtidSet, + StatusWorker: NewStatusWorker(), + wr: wr, + cell: cell, + keyspace: keyspace, + shard: shard, + excludeTables: excludeTables, + minHealthyRdonlyTablets: minHealthyRdonlyTablets, + parallelDiffsCount: parallelDiffsCount, + cleaner: &wrangler.Cleaner{}, + } +} + +// StatusAsHTML is part of the Worker interface +func (msdw *MultiSplitDiffWorker) StatusAsHTML() template.HTML { + state := msdw.State() + + result := "Working on: " + msdw.keyspace + "/" + msdw.shard + "
\n" + result += "State: " + state.String() + "
\n" + switch state { + case WorkerStateDiff: + result += "Running...
\n" + case WorkerStateDiffWillFail: + result += "Running - have already found differences...\n" + case WorkerStateDone: + result += "Success.
\n" + } + + return template.HTML(result) +} + +// StatusAsText is part of the Worker interface +func (msdw *MultiSplitDiffWorker) StatusAsText() string { + state := msdw.State() + + result := "Working on: " + msdw.keyspace + "/" + msdw.shard + "\n" + result += "State: " + state.String() + "\n" + switch state { + case WorkerStateDiff: + result += "Running...\n" + case WorkerStateDiffWillFail: + result += "Running - have already found differences...\n" + case WorkerStateDone: + result += "Success.\n" + } + return result +} + +// Run is mostly a wrapper to run the cleanup at the end. +func (msdw *MultiSplitDiffWorker) Run(ctx context.Context) error { + resetVars() + err := msdw.run(ctx) + + msdw.SetState(WorkerStateCleanUp) + cerr := msdw.cleaner.CleanUp(msdw.wr) + if cerr != nil { + if err != nil { + msdw.wr.Logger().Errorf("CleanUp failed in addition to job error: %v", cerr) + } else { + err = cerr + } + } + if err != nil { + msdw.wr.Logger().Errorf("Run() error: %v", err) + msdw.SetState(WorkerStateError) + return err + } + msdw.SetState(WorkerStateDone) + return nil +} + +func (msdw *MultiSplitDiffWorker) run(ctx context.Context) error { + // first state: read what we need to do + if err := msdw.init(ctx); err != nil { + return fmt.Errorf("init() failed: %v", err) + } + if err := checkDone(ctx); err != nil { + return err + } + + // second state: find targets + if err := msdw.findTargets(ctx); err != nil { + return fmt.Errorf("findTargets() failed: %v", err) + } + if err := checkDone(ctx); err != nil { + return err + } + + // third phase: synchronize replication + if err := msdw.synchronizeReplication(ctx); err != nil { + return fmt.Errorf("synchronizeReplication() failed: %v", err) + } + if err := checkDone(ctx); err != nil { + return err + } + + // fourth phase: diff + if err := msdw.diff(ctx); err != nil { + return fmt.Errorf("diff() failed: %v", err) + } + + return checkDone(ctx) +} + +// init phase: +// - read the shard info, make sure it has sources +func (msdw *MultiSplitDiffWorker) init(ctx context.Context) error { + msdw.SetState(WorkerStateInit) + + var err error + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + msdw.keyspaceInfo, err = msdw.wr.TopoServer().GetKeyspace(shortCtx, msdw.keyspace) + cancel() + if err != nil { + return fmt.Errorf("cannot read keyspace %v: %v", msdw.keyspace, err) + } + shortCtx, cancel = context.WithTimeout(ctx, *remoteActionsTimeout) + msdw.shardInfo, err = msdw.wr.TopoServer().GetShard(shortCtx, msdw.keyspace, msdw.shard) + cancel() + if err != nil { + return fmt.Errorf("cannot read shard %v/%v: %v", msdw.keyspace, msdw.shard, err) + } + + if !msdw.shardInfo.HasMaster() { + return fmt.Errorf("shard %v/%v has no master", msdw.keyspace, msdw.shard) + } + + destinationShards, err := msdw.findDestinationShards(ctx) + if err != nil { + return fmt.Errorf("findDestinationShards() failed for %v/%v/%v: %v", msdw.cell, msdw.keyspace, msdw.shard, err) + } + msdw.destinationShards = destinationShards + + return nil +} + +// findDestinationShards finds all the shards that have filtered replication from the source shard +func (msdw *MultiSplitDiffWorker) findDestinationShards(ctx context.Context) ([]*topo.ShardInfo, error) { + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + keyspaces, err := msdw.wr.TopoServer().GetKeyspaces(shortCtx) + cancel() + if err != nil { + return nil, vterrors.Wrap(err, "failed to get list of keyspaces") + } + + var resultArray []*topo.ShardInfo + + for _, keyspace := range keyspaces { + shardInfo, err := msdw.findShardsInKeyspace(ctx, keyspace) + if err != nil { + return nil, err + } + resultArray = append(resultArray, shardInfo...) + } + + if len(resultArray) == 0 { + return nil, fmt.Errorf("there are no destination shards") + } + return resultArray, nil +} + +func (msdw *MultiSplitDiffWorker) findShardsInKeyspace(ctx context.Context, keyspace string) ([]*topo.ShardInfo, error) { + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + shards, err := msdw.wr.TopoServer().GetShardNames(shortCtx, keyspace) + cancel() + if err != nil { + return nil, vterrors.Wrapf(err, "failed to get list of shards for keyspace '%v'", keyspace) + } + + var resultArray []*topo.ShardInfo + first := true + + for _, shard := range shards { + shardInfo, uid, err := msdw.getShardInfo(ctx, keyspace, shard) + if err != nil { + return nil, err + } + // There might not be any source shards here + if shardInfo != nil { + if first { + msdw.sourceUID = uid + first = false + } else if msdw.sourceUID != uid { + return nil, fmt.Errorf("found a source ID that was different, aborting. %v vs %v", msdw.sourceUID, uid) + } + + resultArray = append(resultArray, shardInfo) + } + } + + return resultArray, nil +} + +func (msdw *MultiSplitDiffWorker) getShardInfo(ctx context.Context, keyspace string, shard string) (*topo.ShardInfo, uint32, error) { + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + si, err := msdw.wr.TopoServer().GetShard(shortCtx, keyspace, shard) + cancel() + if err != nil { + return nil, 0, vterrors.Wrap(err, "failed to get shard info from toposerver") + } + + for _, sourceShard := range si.SourceShards { + if len(sourceShard.Tables) == 0 && sourceShard.Keyspace == msdw.keyspace && sourceShard.Shard == msdw.shard { + // Prevents the same shard from showing up multiple times + return si, sourceShard.Uid, nil + } + } + + return nil, 0, nil +} + +// findTargets phase: +// - find one rdonly in source shard +// - find one rdonly per destination shard +// - mark them all as 'worker' pointing back to us +func (msdw *MultiSplitDiffWorker) findTargets(ctx context.Context) error { + msdw.SetState(WorkerStateFindTargets) + + var err error + + // find an appropriate tablet in the source shard + msdw.sourceAlias, err = FindWorkerTablet( + ctx, + msdw.wr, + msdw.cleaner, + nil, /* tsc */ + msdw.cell, + msdw.keyspace, + msdw.shard, + 1, /* minHealthyTablets */ + topodatapb.TabletType_RDONLY) + if err != nil { + return fmt.Errorf("FindWorkerTablet() failed for %v/%v/%v: %v", msdw.cell, msdw.keyspace, msdw.shard, err) + } + + // find an appropriate tablet in each destination shard + msdw.destinationAliases = make([]*topodatapb.TabletAlias, len(msdw.destinationShards)) + for i, destinationShard := range msdw.destinationShards { + keyspace := destinationShard.Keyspace() + shard := destinationShard.ShardName() + destinationAlias, err := FindWorkerTablet( + ctx, + msdw.wr, + msdw.cleaner, + nil, /* tsc */ + msdw.cell, + keyspace, + shard, + msdw.minHealthyRdonlyTablets, + topodatapb.TabletType_RDONLY) + if err != nil { + return fmt.Errorf("FindWorkerTablet() failed for %v/%v/%v: %v", msdw.cell, keyspace, shard, err) + } + msdw.destinationAliases[i] = destinationAlias + } + if err != nil { + return fmt.Errorf("FindWorkerTablet() failed for %v/%v/%v: %v", msdw.cell, msdw.keyspace, msdw.shard, err) + } + + return nil +} + +// ask the master of the destination shard to pause filtered replication, +// and return the source binlog positions +// (add a cleanup task to restart filtered replication on master) +func (msdw *MultiSplitDiffWorker) stopReplicationOnAllDestinationMasters(ctx context.Context, masterInfos []*topo.TabletInfo) ([]string, error) { + destVreplicationPos := make([]string, len(msdw.destinationShards)) + + for i, shardInfo := range msdw.destinationShards { + masterInfo := masterInfos[i] + + msdw.wr.Logger().Infof("Stopping master binlog replication on %v", shardInfo.MasterAlias) + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + _, err := msdw.wr.TabletManagerClient().VReplicationExec(shortCtx, masterInfo.Tablet, binlogplayer.StopVReplication(msdw.sourceUID, "for split diff")) + cancel() + if err != nil { + return nil, fmt.Errorf("VReplicationExec(stop) for %v failed: %v", shardInfo.MasterAlias, err) + } + wrangler.RecordVReplicationAction(msdw.cleaner, masterInfo.Tablet, binlogplayer.StartVReplication(msdw.sourceUID)) + shortCtx, cancel = context.WithTimeout(ctx, *remoteActionsTimeout) + p3qr, err := msdw.wr.TabletManagerClient().VReplicationExec(shortCtx, masterInfo.Tablet, binlogplayer.ReadVReplicationPos(msdw.sourceUID)) + cancel() + if err != nil { + return nil, fmt.Errorf("VReplicationExec(stop) for %v failed: %v", msdw.shardInfo.MasterAlias, err) + } + qr := sqltypes.Proto3ToResult(p3qr) + if len(qr.Rows) != 1 || len(qr.Rows[0]) != 1 { + return nil, fmt.Errorf("unexpected result while reading position: %v", qr) + } + destVreplicationPos[i] = qr.Rows[0][0].ToString() + if err != nil { + return nil, fmt.Errorf("StopBlp for %v failed: %v", msdw.shardInfo.MasterAlias, err) + } + } + return destVreplicationPos, nil +} + +func (msdw *MultiSplitDiffWorker) getTabletInfoForShard(ctx context.Context, shardInfo *topo.ShardInfo) (*topo.TabletInfo, error) { + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + masterInfo, err := msdw.wr.TopoServer().GetTablet(shortCtx, shardInfo.MasterAlias) + cancel() + if err != nil { + return nil, fmt.Errorf("synchronizeReplication: cannot get Tablet record for master %v: %v", msdw.shardInfo.MasterAlias, err) + } + return masterInfo, nil +} + +// stop the source tablet at a binlog position higher than the +// destination masters. Return the reached position +// (add a cleanup task to restart binlog replication on the source tablet, and +// change the existing ChangeSlaveType cleanup action to 'spare' type) +func (msdw *MultiSplitDiffWorker) stopReplicationOnSourceRdOnlyTabletAt(ctx context.Context, destVreplicationPos []string) (string, error) { + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + sourceTablet, err := msdw.wr.TopoServer().GetTablet(shortCtx, msdw.sourceAlias) + cancel() + if err != nil { + return "", err + } + + var mysqlPos string // will be the last GTID that we stopped at + for _, vreplicationPos := range destVreplicationPos { + // We need to stop the source RDONLY tablet at a position which includes ALL of the positions of the destination + // shards. We do this by starting replication and then stopping at a minimum of each blp position separately. + // TODO this is not terribly efficient but it's possible to implement without changing the existing RPC, + // if we make StopSlaveMinimum take multiple blp positions then this will be a lot more efficient because you just + // check for each position using WAIT_UNTIL_SQL_THREAD_AFTER_GTIDS and then stop replication. + + msdw.wr.Logger().Infof("Stopping slave %v at a minimum of %v", msdw.sourceAlias, vreplicationPos) + // read the tablet + sourceTablet, err := msdw.wr.TopoServer().GetTablet(shortCtx, msdw.sourceAlias) + if err != nil { + return "", err + } + shortCtx, cancel = context.WithTimeout(ctx, *remoteActionsTimeout) + msdw.wr.TabletManagerClient().StartSlave(shortCtx, sourceTablet.Tablet) + cancel() + if err != nil { + return "", err + } + + shortCtx, cancel = context.WithTimeout(ctx, *remoteActionsTimeout) + mysqlPos, err = msdw.wr.TabletManagerClient().StopSlaveMinimum(shortCtx, sourceTablet.Tablet, vreplicationPos, *remoteActionsTimeout) + cancel() + if err != nil { + return "", fmt.Errorf("cannot stop slave %v at right binlog position %v: %v", msdw.sourceAlias, vreplicationPos, err) + } + } + // change the cleaner actions from ChangeSlaveType(rdonly) + // to StartSlave() + ChangeSlaveType(spare) + wrangler.RecordStartSlaveAction(msdw.cleaner, sourceTablet.Tablet) + + return mysqlPos, nil +} + +// ask the master of the destination shard to resume filtered replication +// up to the new list of positions, and return its binlog position. +func (msdw *MultiSplitDiffWorker) resumeReplicationOnDestinationMasterUntil(ctx context.Context, shardInfo *topo.ShardInfo, mysqlPos string, masterInfo *topo.TabletInfo) (string, error) { + msdw.wr.Logger().Infof("Restarting master %v until it catches up to %v", shardInfo.MasterAlias, mysqlPos) + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + _, err := msdw.wr.TabletManagerClient().VReplicationExec(shortCtx, masterInfo.Tablet, binlogplayer.StartVReplicationUntil(msdw.sourceUID, mysqlPos)) + cancel() + if err != nil { + return "", fmt.Errorf("VReplication(start until) for %v until %v failed: %v", shardInfo.MasterAlias, mysqlPos, err) + } + shortCtx, cancel = context.WithTimeout(ctx, *remoteActionsTimeout) + if err := msdw.wr.TabletManagerClient().VReplicationWaitForPos(shortCtx, masterInfo.Tablet, int(msdw.sourceUID), mysqlPos); err != nil { + cancel() + return "", fmt.Errorf("VReplicationWaitForPos for %v until %v failed: %v", shardInfo.MasterAlias, mysqlPos, err) + } + cancel() + + shortCtx, cancel = context.WithTimeout(ctx, *remoteActionsTimeout) + masterPos, err := msdw.wr.TabletManagerClient().MasterPosition(shortCtx, masterInfo.Tablet) + cancel() + if err != nil { + return "", fmt.Errorf("MasterPosition for %v failed: %v", msdw.shardInfo.MasterAlias, err) + } + return masterPos, nil +} + +// wait until the destination tablet is equal or passed that master +// binlog position, and stop its replication. +// (add a cleanup task to restart binlog replication on it, and change +// the existing ChangeSlaveType cleanup action to 'spare' type) +func (msdw *MultiSplitDiffWorker) stopReplicationOnDestinationRdOnlys(ctx context.Context, destinationAlias *topodatapb.TabletAlias, masterPos string) error { + if msdw.waitForFixedTimeRatherThanGtidSet { + msdw.wr.Logger().Infof("Workaround for broken GTID set in destination RDONLY. Just waiting for 1 minute for %v and assuming replication has caught up. (should be at %v)", destinationAlias, masterPos) + } else { + msdw.wr.Logger().Infof("Waiting for destination tablet %v to catch up to %v", destinationAlias, masterPos) + } + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + destinationTablet, err := msdw.wr.TopoServer().GetTablet(shortCtx, destinationAlias) + cancel() + if err != nil { + return err + } + + if msdw.waitForFixedTimeRatherThanGtidSet { + time.Sleep(1 * time.Minute) + } + + shortCtx, cancel = context.WithTimeout(ctx, *remoteActionsTimeout) + if msdw.waitForFixedTimeRatherThanGtidSet { + err = msdw.wr.TabletManagerClient().StopSlave(shortCtx, destinationTablet.Tablet) + } else { + _, err = msdw.wr.TabletManagerClient().StopSlaveMinimum(shortCtx, destinationTablet.Tablet, masterPos, *remoteActionsTimeout) + } + cancel() + if err != nil { + return fmt.Errorf("StopSlaveMinimum for %v at %v failed: %v", destinationAlias, masterPos, err) + } + wrangler.RecordStartSlaveAction(msdw.cleaner, destinationTablet.Tablet) + return nil +} + +// restart filtered replication on the destination master. +// (remove the cleanup task that does the same) +func (msdw *MultiSplitDiffWorker) restartReplicationOn(ctx context.Context, shardInfo *topo.ShardInfo, masterInfo *topo.TabletInfo) error { + msdw.wr.Logger().Infof("Restarting filtered replication on master %v", shardInfo.MasterAlias) + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + _, err := msdw.wr.TabletManagerClient().VReplicationExec(shortCtx, masterInfo.Tablet, binlogplayer.StartVReplication(msdw.sourceUID)) + if err != nil { + return fmt.Errorf("VReplicationExec(start) failed for %v: %v", shardInfo.MasterAlias, err) + } + cancel() + return nil +} + +// synchronizeReplication phase: +// At this point, the source and the destination tablet are stopped at the same +// point. + +func (msdw *MultiSplitDiffWorker) synchronizeReplication(ctx context.Context) error { + msdw.SetState(WorkerStateSyncReplication) + var err error + + masterInfos := make([]*topo.TabletInfo, len(msdw.destinationAliases)) + for i, shardInfo := range msdw.destinationShards { + masterInfos[i], err = msdw.getTabletInfoForShard(ctx, shardInfo) + if err != nil { + return err + } + } + + destVreplicationPos, err := msdw.stopReplicationOnAllDestinationMasters(ctx, masterInfos) + if err != nil { + return err + } + + mysqlPos, err := msdw.stopReplicationOnSourceRdOnlyTabletAt(ctx, destVreplicationPos) + if err != nil { + return err + } + + for i, shardInfo := range msdw.destinationShards { + masterInfo := masterInfos[i] + destinationAlias := msdw.destinationAliases[i] + + masterPos, err := msdw.resumeReplicationOnDestinationMasterUntil(ctx, shardInfo, mysqlPos, masterInfo) + if err != nil { + return err + } + + err = msdw.stopReplicationOnDestinationRdOnlys(ctx, destinationAlias, masterPos) + if err != nil { + return err + } + + err = msdw.restartReplicationOn(ctx, shardInfo, masterInfo) + if err != nil { + return err + } + } + + return nil +} + +func (msdw *MultiSplitDiffWorker) diffSingleTable(ctx context.Context, wg *sync.WaitGroup, tableDefinition *tabletmanagerdatapb.TableDefinition, keyspaceSchema *vindexes.KeyspaceSchema) error { + msdw.wr.Logger().Infof("Starting the diff on table %v", tableDefinition.Name) + + sourceQueryResultReader, err := TableScan(ctx, msdw.wr.Logger(), msdw.wr.TopoServer(), msdw.sourceAlias, tableDefinition) + if err != nil { + return fmt.Errorf("TableScan(source) failed: %v", err) + } + defer sourceQueryResultReader.Close(ctx) + + destinationQueryResultReaders := make([]ResultReader, len(msdw.destinationAliases)) + for i, destinationAlias := range msdw.destinationAliases { + destinationQueryResultReader, err := TableScan(ctx, msdw.wr.Logger(), msdw.wr.TopoServer(), destinationAlias, tableDefinition) + if err != nil { + return fmt.Errorf("TableScan(destination) failed: %v", err) + } + + // For the first result scanner, let's check the PKs are of types that we can work with + if i == 0 { + err = CheckValidTypesForResultMerger(destinationQueryResultReader.fields, len(tableDefinition.PrimaryKeyColumns)) + if err != nil { + return fmt.Errorf("invalid types for multi split diff. use the regular split diff instead %v", err.Error()) + } + } + + // We are knowingly using defer inside the for loop. + // All these readers need to be active until the diff is done + //noinspection GoDeferInLoop + defer destinationQueryResultReader.Close(ctx) + destinationQueryResultReaders[i] = destinationQueryResultReader + } + mergedResultReader, err := NewResultMerger(destinationQueryResultReaders, len(tableDefinition.PrimaryKeyColumns)) + if err != nil { + return fmt.Errorf("NewResultMerger failed: %v", err) + } + + // Create the row differ. + differ, err := NewRowDiffer(sourceQueryResultReader, mergedResultReader, tableDefinition) + if err != nil { + return fmt.Errorf("NewRowDiffer() failed: %v", err) + } + + // And run the diff. + report, err := differ.Go(msdw.wr.Logger()) + if err != nil { + return fmt.Errorf("Differ.Go failed: %v", err.Error()) + } + + if report.HasDifferences() { + return fmt.Errorf("table %v has differences: %v", tableDefinition.Name, report.String()) + } + + msdw.wr.Logger().Infof("Table %v checks out (%v rows processed, %v qps)", tableDefinition.Name, report.processedRows, report.processingQPS) + + return nil +} + +func (msdw *MultiSplitDiffWorker) tableDiffingConsumer(ctx context.Context, wg *sync.WaitGroup, tableChan chan *tabletmanagerdatapb.TableDefinition, rec *concurrency.AllErrorRecorder, keyspaceSchema *vindexes.KeyspaceSchema) { + defer wg.Done() + + for tableDefinition := range tableChan { + err := msdw.diffSingleTable(ctx, wg, tableDefinition, keyspaceSchema) + if err != nil { + msdw.markAsWillFail(rec, err) + msdw.wr.Logger().Errorf("%v", err) + } + } +} + +func (msdw *MultiSplitDiffWorker) gatherSchemaInfo(ctx context.Context) ([]*tabletmanagerdatapb.SchemaDefinition, *tabletmanagerdatapb.SchemaDefinition, error) { + msdw.wr.Logger().Infof("Gathering schema information...") + wg := sync.WaitGroup{} + rec := &concurrency.AllErrorRecorder{} + + // this array will have concurrent writes to it, but no two goroutines will write to the same slot in the array + destinationSchemaDefinitions := make([]*tabletmanagerdatapb.SchemaDefinition, len(msdw.destinationAliases)) + var sourceSchemaDefinition *tabletmanagerdatapb.SchemaDefinition + for i, destinationAlias := range msdw.destinationAliases { + wg.Add(1) + go func(i int, destinationAlias *topodatapb.TabletAlias) { + var err error + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + destinationSchemaDefinition, err := msdw.wr.GetSchema( + shortCtx, destinationAlias, nil /* tables */, msdw.excludeTables, false /* includeViews */) + cancel() + msdw.markAsWillFail(rec, err) + destinationSchemaDefinitions[i] = destinationSchemaDefinition + msdw.wr.Logger().Infof("Got schema from destination %v", destinationAlias) + wg.Done() + }(i, destinationAlias) + } + wg.Add(1) + go func() { + var err error + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + sourceSchemaDefinition, err = msdw.wr.GetSchema( + shortCtx, msdw.sourceAlias, nil /* tables */, msdw.excludeTables, false /* includeViews */) + cancel() + msdw.markAsWillFail(rec, err) + msdw.wr.Logger().Infof("Got schema from source %v", msdw.sourceAlias) + wg.Done() + }() + + wg.Wait() + if rec.HasErrors() { + return nil, nil, rec.Error() + } + + return destinationSchemaDefinitions, sourceSchemaDefinition, nil +} + +func (msdw *MultiSplitDiffWorker) diffSchemaInformation(ctx context.Context, destinationSchemaDefinitions []*tabletmanagerdatapb.SchemaDefinition, sourceSchemaDefinition *tabletmanagerdatapb.SchemaDefinition) { + msdw.wr.Logger().Infof("Diffing the schema...") + rec := &concurrency.AllErrorRecorder{} + sourceShardName := fmt.Sprintf("%v/%v", msdw.shardInfo.Keyspace(), msdw.shardInfo.ShardName()) + for i, destinationSchemaDefinition := range destinationSchemaDefinitions { + destinationShard := msdw.destinationShards[i] + destinationShardName := fmt.Sprintf("%v/%v", destinationShard.Keyspace(), destinationShard.ShardName()) + tmutils.DiffSchema(destinationShardName, destinationSchemaDefinition, sourceShardName, sourceSchemaDefinition, rec) + } + if rec.HasErrors() { + msdw.wr.Logger().Warningf("Different schemas: %v", rec.Error().Error()) + } else { + msdw.wr.Logger().Infof("Schema match, good.") + } +} + +func (msdw *MultiSplitDiffWorker) loadVSchema(ctx context.Context) (*vindexes.KeyspaceSchema, error) { + shortCtx, cancel := context.WithCancel(ctx) + kschema, err := msdw.wr.TopoServer().GetVSchema(shortCtx, msdw.keyspace) + cancel() + if err != nil { + return nil, fmt.Errorf("cannot load VSchema for keyspace %v: %v", msdw.keyspace, err) + } + if kschema == nil { + return nil, fmt.Errorf("no VSchema for keyspace %v", msdw.keyspace) + } + + keyspaceSchema, err := vindexes.BuildKeyspaceSchema(kschema, msdw.keyspace) + if err != nil { + return nil, fmt.Errorf("cannot build vschema for keyspace %v: %v", msdw.keyspace, err) + } + return keyspaceSchema, nil +} + +// diff phase: will log messages regarding the diff. +// - get the schema on all tablets +// - if some table schema mismatches, record them (use existing schema diff tools). +// - for each table in destination, run a diff pipeline. + +func (msdw *MultiSplitDiffWorker) diff(ctx context.Context) error { + msdw.SetState(WorkerStateDiff) + + destinationSchemaDefinitions, sourceSchemaDefinition, err := msdw.gatherSchemaInfo(ctx) + if err != nil { + return err + } + msdw.diffSchemaInformation(ctx, destinationSchemaDefinitions, sourceSchemaDefinition) + + // read the vschema if needed + var keyspaceSchema *vindexes.KeyspaceSchema + if *useV3ReshardingMode { + keyspaceSchema, err = msdw.loadVSchema(ctx) + if err != nil { + return err + } + } + + msdw.wr.Logger().Infof("Running the diffs...") + tableDefinitions := sourceSchemaDefinition.TableDefinitions + rec := &concurrency.AllErrorRecorder{} + + // sort tables by size + // if there are large deltas between table sizes then it's more efficient to start working on the large tables first + sort.Slice(tableDefinitions, func(i, j int) bool { return tableDefinitions[i].DataLength > tableDefinitions[j].DataLength }) + tableChan := make(chan *tabletmanagerdatapb.TableDefinition, len(tableDefinitions)) + for _, tableDefinition := range tableDefinitions { + tableChan <- tableDefinition + } + close(tableChan) + + consumers := sync.WaitGroup{} + // start as many goroutines we want parallel diffs running + for i := 0; i < msdw.parallelDiffsCount; i++ { + consumers.Add(1) + go msdw.tableDiffingConsumer(ctx, &consumers, tableChan, rec, keyspaceSchema) + } + + // wait for all consumers to wrap up their work + consumers.Wait() + + return rec.Error() +} + +// markAsWillFail records the error and changes the state of the worker to reflect this +func (msdw *MultiSplitDiffWorker) markAsWillFail(er concurrency.ErrorRecorder, err error) { + er.RecordError(err) + msdw.SetState(WorkerStateDiffWillFail) +} diff --git a/go/vt/worker/multi_split_diff_cmd.go b/go/vt/worker/multi_split_diff_cmd.go new file mode 100644 index 00000000000..353e47790e7 --- /dev/null +++ b/go/vt/worker/multi_split_diff_cmd.go @@ -0,0 +1,231 @@ +/* +Copyright 2017 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package worker + +import ( + "flag" + "fmt" + "html/template" + "net/http" + "strconv" + "strings" + "sync" + + "golang.org/x/net/context" + "vitess.io/vitess/go/vt/concurrency" + "vitess.io/vitess/go/vt/topo/topoproto" + "vitess.io/vitess/go/vt/wrangler" +) + +const multiSplitDiffHTML = ` + + + Multi Split Diff Action + + +

Multi Split Diff Action

+ + {{if .Error}} + Error: {{.Error}}
+ {{else}} + {{range $i, $si := .Shards}} +
  • {{$si.Keyspace}}/{{$si.Shard}}
  • + {{end}} + {{end}} + +` + +const multiSplitDiffHTML2 = ` + + + Multi Split Diff Action + + +

    Shard involved: {{.Keyspace}}/{{.Shard}}

    +

    Multi Split Diff Action

    +
    + +
    + +
    + +
    + +
    + +
    + + + +
    + +` + +var multiSplitDiffTemplate = mustParseTemplate("multiSplitDiff", multiSplitDiffHTML) +var multiSplitDiffTemplate2 = mustParseTemplate("multiSplitDiff2", multiSplitDiffHTML2) + +func commandMultiSplitDiff(wi *Instance, wr *wrangler.Wrangler, subFlags *flag.FlagSet, args []string) (Worker, error) { + excludeTables := subFlags.String("exclude_tables", "", "comma separated list of tables to exclude") + minHealthyRdonlyTablets := subFlags.Int("min_healthy_rdonly_tablets", defaultMinHealthyRdonlyTablets, "minimum number of healthy RDONLY tablets before taking out one") + parallelDiffsCount := subFlags.Int("parallel_diffs_count", defaultParallelDiffsCount, "number of tables to diff in parallel") + waitForFixedTimeRatherThanGtidSet := subFlags.Bool("wait_for_fixed_time_rather_than_gtid_set", false, "wait for 1m when syncing up the destination RDONLY tablet rather than using the GTID set. Use this when the GTID set on the RDONLY is broken. Make sure the RDONLY is not behind in replication when using this flag.") + if err := subFlags.Parse(args); err != nil { + return nil, err + } + if subFlags.NArg() != 1 { + subFlags.Usage() + return nil, fmt.Errorf("command MultiSplitDiff requires ") + } + keyspace, shard, err := topoproto.ParseKeyspaceShard(subFlags.Arg(0)) + if err != nil { + return nil, err + } + var excludeTableArray []string + if *excludeTables != "" { + excludeTableArray = strings.Split(*excludeTables, ",") + } + return NewMultiSplitDiffWorker(wr, wi.cell, keyspace, shard, excludeTableArray, *minHealthyRdonlyTablets, *parallelDiffsCount, *waitForFixedTimeRatherThanGtidSet), nil +} + +// shardSources returns all the shards that are SourceShards of at least one other shard. +func shardSources(ctx context.Context, wr *wrangler.Wrangler) ([]map[string]string, error) { + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + keyspaces, err := wr.TopoServer().GetKeyspaces(shortCtx) + cancel() + if err != nil { + return nil, fmt.Errorf("failed to get list of keyspaces: %v", err) + } + + wg := sync.WaitGroup{} + mu := sync.Mutex{} // protects sourceShards + // Use a map to dedupe source shards + sourceShards := make(map[string]map[string]string) + rec := concurrency.AllErrorRecorder{} + for _, keyspace := range keyspaces { + wg.Add(1) + go func(keyspace string) { + defer wg.Done() + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + shards, err := wr.TopoServer().GetShardNames(shortCtx, keyspace) + cancel() + if err != nil { + rec.RecordError(fmt.Errorf("failed to get list of shards for keyspace '%v': %v", keyspace, err)) + return + } + for _, shard := range shards { + wg.Add(1) + go func(keyspace, shard string) { + defer wg.Done() + shortCtx, cancel := context.WithTimeout(ctx, *remoteActionsTimeout) + si, err := wr.TopoServer().GetShard(shortCtx, keyspace, shard) + cancel() + if err != nil { + rec.RecordError(fmt.Errorf("failed to get details for shard '%v': %v", topoproto.KeyspaceShardString(keyspace, shard), err)) + return + } + + if len(si.SourceShards) > 0 && len(si.SourceShards[0].Tables) == 0 { + mu.Lock() + for _, sourceShard := range si.SourceShards { + sourceShards[fmt.Sprintf("%v/%v", sourceShard.Keyspace, sourceShard.Shard)] = + map[string]string{ + "Keyspace": sourceShard.Keyspace, + "Shard": sourceShard.Shard, + } + } + mu.Unlock() + } + }(keyspace, shard) + } + }(keyspace) + } + wg.Wait() + + if rec.HasErrors() { + return nil, rec.Error() + } + result := make([]map[string]string, 0, len(sourceShards)) + for _, shard := range sourceShards { + result = append(result, shard) + } + if len(result) == 0 { + return nil, fmt.Errorf("there are no shards with SourceShards") + } + return result, nil +} + +func interactiveMultiSplitDiff(ctx context.Context, wi *Instance, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) (Worker, *template.Template, map[string]interface{}, error) { + if err := r.ParseForm(); err != nil { + return nil, nil, nil, fmt.Errorf("cannot parse form: %s", err) + } + keyspace := r.FormValue("keyspace") + shard := r.FormValue("shard") + + if keyspace == "" || shard == "" { + // display the list of possible shards to chose from + result := make(map[string]interface{}) + shards, err := shardSources(ctx, wr) + if err != nil { + result["Error"] = err.Error() + } else { + result["Shards"] = shards + } + return nil, multiSplitDiffTemplate, result, nil + } + + submitButtonValue := r.FormValue("submit") + if submitButtonValue == "" { + // display the input form + result := make(map[string]interface{}) + result["Keyspace"] = keyspace + result["Shard"] = shard + result["DefaultSourceUID"] = "0" + result["DefaultMinHealthyRdonlyTablets"] = fmt.Sprintf("%v", defaultMinHealthyRdonlyTablets) + result["DefaultParallelDiffsCount"] = fmt.Sprintf("%v", defaultParallelDiffsCount) + return nil, multiSplitDiffTemplate2, result, nil + } + + // Process input form. + excludeTables := r.FormValue("excludeTables") + var excludeTableArray []string + if excludeTables != "" { + excludeTableArray = strings.Split(excludeTables, ",") + } + minHealthyRdonlyTabletsStr := r.FormValue("minHealthyRdonlyTablets") + parallelDiffsCountStr := r.FormValue("parallelDiffsCount") + minHealthyRdonlyTablets, err := strconv.ParseInt(minHealthyRdonlyTabletsStr, 0, 64) + parallelDiffsCount, err := strconv.ParseInt(parallelDiffsCountStr, 0, 64) + if err != nil { + return nil, nil, nil, fmt.Errorf("cannot parse minHealthyRdonlyTablets: %s", err) + } + waitForFixedTimeRatherThanGtidSetStr := r.FormValue("waitForFixedTimeRatherThanGtidSet") + waitForFixedTimeRatherThanGtidSet := waitForFixedTimeRatherThanGtidSetStr == "true" + if err != nil { + return nil, nil, nil, fmt.Errorf("cannot parse minHealthyRdonlyTablets: %s", err) + } + + // start the diff job + wrk := NewMultiSplitDiffWorker(wr, wi.cell, keyspace, shard, excludeTableArray, int(minHealthyRdonlyTablets), int(parallelDiffsCount), waitForFixedTimeRatherThanGtidSet) + return wrk, nil, nil, nil +} + +func init() { + AddCommand("Diffs", Command{"MultiSplitDiff", + commandMultiSplitDiff, interactiveMultiSplitDiff, + "[--exclude_tables=''] ", + "Diffs a rdonly destination shard against its SourceShards"}) +} diff --git a/go/vt/worker/multi_split_diff_test.go b/go/vt/worker/multi_split_diff_test.go new file mode 100644 index 00000000000..7811f5b1b54 --- /dev/null +++ b/go/vt/worker/multi_split_diff_test.go @@ -0,0 +1,337 @@ +/* +Copyright 2017 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package worker + +import ( + "fmt" + "strings" + "testing" + "time" + + "golang.org/x/net/context" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/logutil" + "vitess.io/vitess/go/vt/mysqlctl/tmutils" + "vitess.io/vitess/go/vt/topo/memorytopo" + "vitess.io/vitess/go/vt/vttablet/grpcqueryservice" + "vitess.io/vitess/go/vt/vttablet/queryservice/fakes" + "vitess.io/vitess/go/vt/wrangler" + "vitess.io/vitess/go/vt/wrangler/testlib" + + querypb "vitess.io/vitess/go/vt/proto/query" + tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vschemapb "vitess.io/vitess/go/vt/proto/vschema" +) + +// msdDestinationTabletServer is a local QueryService implementation to +// support the tests +type msdDestinationTabletServer struct { + t *testing.T + + *fakes.StreamHealthQueryService + excludedTable string + shardIndex int +} + +func (sq *msdDestinationTabletServer) StreamExecute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(reply *sqltypes.Result) error) error { + if strings.Contains(sql, sq.excludedTable) { + sq.t.Errorf("Split Diff operation on destination should skip the excluded table: %v query: %v", sq.excludedTable, sql) + } + + if hasKeyspace := strings.Contains(sql, "WHERE `keyspace_id`"); hasKeyspace == true { + sq.t.Errorf("Sql query on destination should not contain a keyspace_id WHERE clause; query received: %v", sql) + } + + sq.t.Logf("msdDestinationTabletServer: got query: %v", sql) + + // Send the headers + if err := callback(&sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "id", + Type: sqltypes.Int64, + }, + { + Name: "msg", + Type: sqltypes.VarChar, + }, + { + Name: "keyspace_id", + Type: sqltypes.Int64, + }, + }, + }); err != nil { + return err + } + + // Send the values + ksids := []uint64{0x2000000000000000, 0x6000000000000000} + for i := 0; i < 100; i++ { + // skip the out-of-range values + if i%2 == sq.shardIndex { + continue + } + if err := callback(&sqltypes.Result{ + Rows: [][]sqltypes.Value{ + { + sqltypes.NewVarBinary(fmt.Sprintf("%v", i)), + sqltypes.NewVarBinary(fmt.Sprintf("Text for %v", i)), + sqltypes.NewVarBinary(fmt.Sprintf("%v", ksids[i%2])), + }, + }, + }); err != nil { + return err + } + } + return nil +} + +// msdSourceTabletServer is a local QueryService implementation to support the tests +type msdSourceTabletServer struct { + t *testing.T + + *fakes.StreamHealthQueryService + excludedTable string + v3 bool +} + +func (sq *msdSourceTabletServer) StreamExecute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(reply *sqltypes.Result) error) error { + if strings.Contains(sql, sq.excludedTable) { + sq.t.Errorf("Split Diff operation on source should skip the excluded table: %v query: %v", sq.excludedTable, sql) + } + + // we test for a keyspace_id where clause, except for v3 + if !sq.v3 { + if hasKeyspace := strings.Contains(sql, "WHERE `keyspace_id` < 4611686018427387904"); hasKeyspace != true { + sq.t.Errorf("Sql query on source should contain a keyspace_id WHERE clause; query received: %v", sql) + } + } + + sq.t.Logf("msdSourceTabletServer: got query: %v", sql) + + // Send the headers + if err := callback(&sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "id", + Type: sqltypes.Int64, + }, + { + Name: "msg", + Type: sqltypes.VarChar, + }, + { + Name: "keyspace_id", + Type: sqltypes.Int64, + }, + }, + }); err != nil { + return err + } + + // Send the values + ksids := []uint64{0x2000000000000000, 0x6000000000000000} + for i := 0; i < 100; i++ { + if !sq.v3 && i%2 == 1 { + // for v2, filtering is done at SQL layer + continue + } + if err := callback(&sqltypes.Result{ + Rows: [][]sqltypes.Value{ + { + sqltypes.NewVarBinary(fmt.Sprintf("%v", i)), + sqltypes.NewVarBinary(fmt.Sprintf("Text for %v", i)), + sqltypes.NewVarBinary(fmt.Sprintf("%v", ksids[i%2])), + }, + }, + }); err != nil { + return err + } + } + return nil +} + +// TODO(aaijazi): Create a test in which source and destination data does not match + +func testMultiSplitDiff(t *testing.T, v3 bool) { + *useV3ReshardingMode = v3 + ts := memorytopo.NewServer("cell1", "cell2") + ctx := context.Background() + wi := NewInstance(ts, "cell1", time.Second) + + if v3 { + if err := ts.CreateKeyspace(ctx, "ks", &topodatapb.Keyspace{}); err != nil { + t.Fatalf("CreateKeyspace v3 failed: %v", err) + } + + vs := &vschemapb.Keyspace{ + Sharded: true, + Vindexes: map[string]*vschemapb.Vindex{ + "table1_index": { + Type: "numeric", + }, + }, + Tables: map[string]*vschemapb.Table{ + "table1": { + ColumnVindexes: []*vschemapb.ColumnVindex{ + { + Column: "keyspace_id", + Name: "table1_index", + }, + }, + }, + }, + } + if err := ts.SaveVSchema(ctx, "ks", vs); err != nil { + t.Fatalf("SaveVSchema v3 failed: %v", err) + } + } else { + if err := ts.CreateKeyspace(ctx, "ks", &topodatapb.Keyspace{ + ShardingColumnName: "keyspace_id", + ShardingColumnType: topodatapb.KeyspaceIdType_UINT64, + }); err != nil { + t.Fatalf("CreateKeyspace failed: %v", err) + } + } + + sourceMaster := testlib.NewFakeTablet(t, wi.wr, "cell1", 0, + topodatapb.TabletType_MASTER, nil, testlib.TabletKeyspaceShard(t, "ks", "-80")) + sourceRdonly1 := testlib.NewFakeTablet(t, wi.wr, "cell1", 1, + topodatapb.TabletType_RDONLY, nil, testlib.TabletKeyspaceShard(t, "ks", "-80")) + sourceRdonly2 := testlib.NewFakeTablet(t, wi.wr, "cell1", 2, + topodatapb.TabletType_RDONLY, nil, testlib.TabletKeyspaceShard(t, "ks", "-80")) + + leftMaster := testlib.NewFakeTablet(t, wi.wr, "cell1", 10, + topodatapb.TabletType_MASTER, nil, testlib.TabletKeyspaceShard(t, "ks", "-40")) + leftRdonly1 := testlib.NewFakeTablet(t, wi.wr, "cell1", 11, + topodatapb.TabletType_RDONLY, nil, testlib.TabletKeyspaceShard(t, "ks", "-40")) + leftRdonly2 := testlib.NewFakeTablet(t, wi.wr, "cell1", 12, + topodatapb.TabletType_RDONLY, nil, testlib.TabletKeyspaceShard(t, "ks", "-40")) + + rightMaster := testlib.NewFakeTablet(t, wi.wr, "cell1", 20, + topodatapb.TabletType_MASTER, nil, testlib.TabletKeyspaceShard(t, "ks", "40-80")) + rightRdonly1 := testlib.NewFakeTablet(t, wi.wr, "cell1", 21, + topodatapb.TabletType_RDONLY, nil, testlib.TabletKeyspaceShard(t, "ks", "40-80")) + rightRdonly2 := testlib.NewFakeTablet(t, wi.wr, "cell1", 22, + topodatapb.TabletType_RDONLY, nil, testlib.TabletKeyspaceShard(t, "ks", "40-80")) + + // add the topo and schema data we'll need + if err := ts.CreateShard(ctx, "ks", "80-"); err != nil { + t.Fatalf("CreateShard(\"-80\") failed: %v", err) + } + wi.wr.SetSourceShards(ctx, "ks", "-40", []*topodatapb.TabletAlias{sourceRdonly1.Tablet.Alias}, nil) + if err := wi.wr.SetKeyspaceShardingInfo(ctx, "ks", "keyspace_id", topodatapb.KeyspaceIdType_UINT64, false); err != nil { + t.Fatalf("SetKeyspaceShardingInfo failed: %v", err) + } + wi.wr.SetSourceShards(ctx, "ks", "40-80", []*topodatapb.TabletAlias{sourceRdonly1.Tablet.Alias}, nil) + if err := wi.wr.SetKeyspaceShardingInfo(ctx, "ks", "keyspace_id", topodatapb.KeyspaceIdType_UINT64, false); err != nil { + t.Fatalf("SetKeyspaceShardingInfo failed: %v", err) + } + if err := wi.wr.RebuildKeyspaceGraph(ctx, "ks", nil); err != nil { + t.Fatalf("RebuildKeyspaceGraph failed: %v", err) + } + + excludedTable := "excludedTable1" + + for _, rdonly := range []*testlib.FakeTablet{sourceRdonly1, sourceRdonly2, leftRdonly1, leftRdonly2, rightRdonly1, rightRdonly2} { + // The destination only has half the data. + // For v2, we do filtering at the SQL level. + // For v3, we do it in the client. + // So in any case, we need real data. + rdonly.FakeMysqlDaemon.Schema = &tabletmanagerdatapb.SchemaDefinition{ + DatabaseSchema: "", + TableDefinitions: []*tabletmanagerdatapb.TableDefinition{ + { + Name: "table1", + Columns: []string{"id", "msg", "keyspace_id"}, + PrimaryKeyColumns: []string{"id"}, + Type: tmutils.TableBaseTable, + }, + { + Name: excludedTable, + Columns: []string{"id", "msg", "keyspace_id"}, + PrimaryKeyColumns: []string{"id"}, + Type: tmutils.TableBaseTable, + }, + }, + } + } + + for _, sourceRdonly := range []*testlib.FakeTablet{sourceRdonly1, sourceRdonly2} { + qs := fakes.NewStreamHealthQueryService(sourceRdonly.Target()) + qs.AddDefaultHealthResponse() + grpcqueryservice.Register(sourceRdonly.RPCServer, &msdSourceTabletServer{ + t: t, + StreamHealthQueryService: qs, + excludedTable: excludedTable, + v3: v3, + }) + } + + for _, destRdonly := range []*testlib.FakeTablet{leftRdonly1, leftRdonly2} { + qs := fakes.NewStreamHealthQueryService(destRdonly.Target()) + qs.AddDefaultHealthResponse() + grpcqueryservice.Register(destRdonly.RPCServer, &msdDestinationTabletServer{ + t: t, + StreamHealthQueryService: qs, + excludedTable: excludedTable, + shardIndex: 0, + }) + } + + for _, destRdonly := range []*testlib.FakeTablet{rightRdonly1, rightRdonly2} { + qs := fakes.NewStreamHealthQueryService(destRdonly.Target()) + qs.AddDefaultHealthResponse() + grpcqueryservice.Register(destRdonly.RPCServer, &msdDestinationTabletServer{ + t: t, + StreamHealthQueryService: qs, + excludedTable: excludedTable, + shardIndex: 1, + }) + } + + // Start action loop after having registered all RPC services. + for _, ft := range []*testlib.FakeTablet{sourceMaster, sourceRdonly1, sourceRdonly2, leftMaster, leftRdonly1, leftRdonly2, rightMaster, rightRdonly1, rightRdonly2} { + ft.StartActionLoop(t, wi.wr) + defer ft.StopActionLoop(t) + } + + // Run the vtworker command. + args := []string{ + "MultiSplitDiff", + "-exclude_tables", excludedTable, + "ks/-80", + } + // We need to use FakeTabletManagerClient because we don't + // have a good way to fake the binlog player yet, which is + // necessary for synchronizing replication. + wr := wrangler.New(logutil.NewConsoleLogger(), ts, newFakeTMCTopo(ts)) + if err := runCommand(t, wi, wr, args); err != nil { + t.Fatal(err) + } +} + +func TestMultiSplitDiffv2(t *testing.T) { + // TODO: Make MultiSplitDiff work with V2 + // testMultiSplitDiff(t, false) +} + +func TestMultiSplitDiffv3(t *testing.T) { + testMultiSplitDiff(t, true) +} diff --git a/go/vt/worker/restartable_result_reader.go b/go/vt/worker/restartable_result_reader.go index d21110d418e..73fadf04145 100644 --- a/go/vt/worker/restartable_result_reader.go +++ b/go/vt/worker/restartable_result_reader.go @@ -207,7 +207,7 @@ func (r *RestartableResultReader) nextWithRetries() (*sqltypes.Result, error) { retryable, err = r.getTablet() if err != nil { if !retryable { - r.logger.Errorf("table=%v chunk=%v: Failed to restart streaming query (attempt %d) and failover to a different tablet (%v) due to a non-retryable error: %v", r.td.Name, r.chunk, attempt, r.tablet, err) + r.logger.Errorf2(err, "table=%v chunk=%v: Failed to restart streaming query (attempt %d) and failover to a different tablet (%v) due to a non-retryable error", r.td.Name, r.chunk, attempt, r.tablet) return nil, err } goto retry @@ -219,7 +219,7 @@ func (r *RestartableResultReader) nextWithRetries() (*sqltypes.Result, error) { retryable, err = r.startStream() if err != nil { if !retryable { - r.logger.Errorf("tablet=%v table=%v chunk=%v: Failed to restart streaming query (attempt %d) with query '%v' and stopped due to a non-retryable error: %v", topoproto.TabletAliasString(r.tablet.Alias), r.td.Name, r.chunk, attempt, r.query, err) + r.logger.Errorf2(err, "tablet=%v table=%v chunk=%v: Failed to restart streaming query (attempt %d) with query '%v' and stopped due to a non-retryable error", topoproto.TabletAliasString(r.tablet.Alias), r.td.Name, r.chunk, attempt, r.query) return nil, err } goto retry diff --git a/go/vt/worker/result_merger.go b/go/vt/worker/result_merger.go index d5a2ca26c94..68250d06ad0 100644 --- a/go/vt/worker/result_merger.go +++ b/go/vt/worker/result_merger.go @@ -68,11 +68,9 @@ func NewResultMerger(inputs []ResultReader, pkFieldCount int) (*ResultMerger, er return nil, err } - for i := 0; i < pkFieldCount; i++ { - typ := fields[i].Type - if !sqltypes.IsIntegral(typ) && !sqltypes.IsFloat(typ) && !sqltypes.IsBinary(typ) { - return nil, fmt.Errorf("unsupported type: %v cannot compare fields with this type. Use the vtworker LegacySplitClone command instead", typ) - } + err := CheckValidTypesForResultMerger(fields, pkFieldCount) + if err != nil { + return nil, fmt.Errorf("invalid PK types for ResultMerger. Use the vtworker LegacySplitClone command instead. %v", err.Error()) } // Initialize the priority queue with all input ResultReader which have at @@ -100,6 +98,17 @@ func NewResultMerger(inputs []ResultReader, pkFieldCount int) (*ResultMerger, er return rm, nil } +// CheckValidTypesForResultMerger returns an error if the provided fields are not compatible with how ResultMerger works +func CheckValidTypesForResultMerger(fields []*querypb.Field, pkFieldCount int) error { + for i := 0; i < pkFieldCount; i++ { + typ := fields[i].Type + if !sqltypes.IsIntegral(typ) && !sqltypes.IsFloat(typ) && !sqltypes.IsBinary(typ) { + return fmt.Errorf("unsupported type: %v cannot compare fields with this type", typ) + } + } + return nil +} + // Fields returns the field information for the columns in the result. // It is part of the ResultReader interface. func (rm *ResultMerger) Fields() []*querypb.Field { diff --git a/go/vt/worker/split_clone.go b/go/vt/worker/split_clone.go index ff18878cb74..3cc7638d4a1 100644 --- a/go/vt/worker/split_clone.go +++ b/go/vt/worker/split_clone.go @@ -397,7 +397,7 @@ func (scw *SplitCloneWorker) Run(ctx context.Context) error { cerr := scw.cleaner.CleanUp(scw.wr) if cerr != nil { if err != nil { - scw.wr.Logger().Errorf("CleanUp failed in addition to job error: %v", cerr) + scw.wr.Logger().Errorf2(cerr, "CleanUp failed in addition to job error: %v") } else { err = cerr } @@ -412,7 +412,7 @@ func (scw *SplitCloneWorker) Run(ctx context.Context) error { // After Close returned, we can be sure that it won't call our listener // implementation (method StatsUpdate) anymore. if err := scw.healthCheck.Close(); err != nil { - scw.wr.Logger().Errorf("HealthCheck.Close() failed: %v", err) + scw.wr.Logger().Errorf2(err, "HealthCheck.Close() failed") } } diff --git a/go/vt/worker/split_clone_test.go b/go/vt/worker/split_clone_test.go index 25bb32b4dea..911a5c8ff82 100644 --- a/go/vt/worker/split_clone_test.go +++ b/go/vt/worker/split_clone_test.go @@ -330,8 +330,8 @@ func newTestQueryService(t *testing.T, target querypb.Target, shqs *fakes.Stream fields = v3Fields } return &testQueryService{ - t: t, - target: target, + t: t, + target: target, StreamHealthQueryService: shqs, shardIndex: shardIndex, shardCount: shardCount, diff --git a/go/vt/worker/split_diff.go b/go/vt/worker/split_diff.go index e6cc6a7ff7a..82c3211f017 100644 --- a/go/vt/worker/split_diff.go +++ b/go/vt/worker/split_diff.go @@ -131,13 +131,13 @@ func (sdw *SplitDiffWorker) Run(ctx context.Context) error { cerr := sdw.cleaner.CleanUp(sdw.wr) if cerr != nil { if err != nil { - sdw.wr.Logger().Errorf("CleanUp failed in addition to job error: %v", cerr) + sdw.wr.Logger().Errorf2(cerr, "CleanUp failed in addition to job error") } else { err = cerr } } if err != nil { - sdw.wr.Logger().Errorf("Run() error: %v", err) + sdw.wr.Logger().Errorf2(err, "Run() error") sdw.SetState(WorkerStateError) return err } @@ -504,7 +504,7 @@ func (sdw *SplitDiffWorker) diff(ctx context.Context) error { if err != nil { newErr := vterrors.Wrap(err, "TableScan(ByKeyRange?)(source) failed") sdw.markAsWillFail(rec, newErr) - sdw.wr.Logger().Errorf("%v", newErr) + sdw.wr.Logger().Error(newErr) return } defer sourceQueryResultReader.Close(ctx) @@ -520,7 +520,7 @@ func (sdw *SplitDiffWorker) diff(ctx context.Context) error { if err != nil { newErr := vterrors.Wrap(err, "TableScan(ByKeyRange?)(destination) failed") sdw.markAsWillFail(rec, newErr) - sdw.wr.Logger().Errorf("%v", newErr) + sdw.wr.Logger().Error(newErr) return } defer destinationQueryResultReader.Close(ctx) @@ -530,7 +530,7 @@ func (sdw *SplitDiffWorker) diff(ctx context.Context) error { if err != nil { newErr := vterrors.Wrap(err, "NewRowDiffer() failed") sdw.markAsWillFail(rec, newErr) - sdw.wr.Logger().Errorf("%v", newErr) + sdw.wr.Logger().Error(newErr) return } @@ -539,7 +539,7 @@ func (sdw *SplitDiffWorker) diff(ctx context.Context) error { if err != nil { newErr := fmt.Errorf("Differ.Go failed: %v", err.Error()) sdw.markAsWillFail(rec, newErr) - sdw.wr.Logger().Errorf("%v", newErr) + sdw.wr.Logger().Error(newErr) } else { if report.HasDifferences() { err := fmt.Errorf("Table %v has differences: %v", tableDefinition.Name, report.String()) diff --git a/go/vt/worker/split_diff_test.go b/go/vt/worker/split_diff_test.go index c724d57964a..efc8f4be612 100644 --- a/go/vt/worker/split_diff_test.go +++ b/go/vt/worker/split_diff_test.go @@ -265,7 +265,7 @@ func testSplitDiff(t *testing.T, v3 bool, destinationTabletType topodatapb.Table qs := fakes.NewStreamHealthQueryService(sourceRdonly.Target()) qs.AddDefaultHealthResponse() grpcqueryservice.Register(sourceRdonly.RPCServer, &sourceTabletServer{ - t: t, + t: t, StreamHealthQueryService: qs, excludedTable: excludedTable, v3: v3, @@ -276,7 +276,7 @@ func testSplitDiff(t *testing.T, v3 bool, destinationTabletType topodatapb.Table qs := fakes.NewStreamHealthQueryService(destRdonly.Target()) qs.AddDefaultHealthResponse() grpcqueryservice.Register(destRdonly.RPCServer, &destinationTabletServer{ - t: t, + t: t, StreamHealthQueryService: qs, excludedTable: excludedTable, }) diff --git a/go/vt/worker/vertical_split_diff.go b/go/vt/worker/vertical_split_diff.go index 2980db2fdbf..eed0d695883 100644 --- a/go/vt/worker/vertical_split_diff.go +++ b/go/vt/worker/vertical_split_diff.go @@ -68,11 +68,11 @@ type VerticalSplitDiffWorker struct { // NewVerticalSplitDiffWorker returns a new VerticalSplitDiffWorker object. func NewVerticalSplitDiffWorker(wr *wrangler.Wrangler, cell, keyspace, shard string, minHealthyRdonlyTablets, parallelDiffsCount int, destintationTabletType topodatapb.TabletType) Worker { return &VerticalSplitDiffWorker{ - StatusWorker: NewStatusWorker(), - wr: wr, - cell: cell, - keyspace: keyspace, - shard: shard, + StatusWorker: NewStatusWorker(), + wr: wr, + cell: cell, + keyspace: keyspace, + shard: shard, minHealthyRdonlyTablets: minHealthyRdonlyTablets, destinationTabletType: destintationTabletType, parallelDiffsCount: parallelDiffsCount, @@ -124,7 +124,7 @@ func (vsdw *VerticalSplitDiffWorker) Run(ctx context.Context) error { cerr := vsdw.cleaner.CleanUp(vsdw.wr) if cerr != nil { if err != nil { - vsdw.wr.Logger().Errorf("CleanUp failed in addition to job error: %v", cerr) + vsdw.wr.Logger().Errorf2(cerr, "CleanUp failed in addition to job error") } else { err = cerr } @@ -418,7 +418,7 @@ func (vsdw *VerticalSplitDiffWorker) diff(ctx context.Context) error { if err != nil { newErr := vterrors.Wrap(err, "TableScan(source) failed") vsdw.markAsWillFail(rec, newErr) - vsdw.wr.Logger().Errorf("%v", newErr) + vsdw.wr.Logger().Error(newErr) return } defer sourceQueryResultReader.Close(ctx) @@ -427,7 +427,7 @@ func (vsdw *VerticalSplitDiffWorker) diff(ctx context.Context) error { if err != nil { newErr := vterrors.Wrap(err, "TableScan(destination) failed") vsdw.markAsWillFail(rec, newErr) - vsdw.wr.Logger().Errorf("%v", newErr) + vsdw.wr.Logger().Error(newErr) return } defer destinationQueryResultReader.Close(ctx) @@ -436,18 +436,18 @@ func (vsdw *VerticalSplitDiffWorker) diff(ctx context.Context) error { if err != nil { newErr := vterrors.Wrap(err, "NewRowDiffer() failed") vsdw.markAsWillFail(rec, newErr) - vsdw.wr.Logger().Errorf("%v", newErr) + vsdw.wr.Logger().Error(newErr) return } report, err := differ.Go(vsdw.wr.Logger()) if err != nil { - vsdw.wr.Logger().Errorf("Differ.Go failed: %v", err) + vsdw.wr.Logger().Errorf2(err, "Differ.Go failed") } else { if report.HasDifferences() { - err := fmt.Errorf("Table %v has differences: %v", tableDefinition.Name, report.String()) + err := fmt.Errorf("table %v has differences: %v", tableDefinition.Name, report.String()) vsdw.markAsWillFail(rec, err) - vsdw.wr.Logger().Errorf("%v", err) + vsdw.wr.Logger().Error(err) } else { vsdw.wr.Logger().Infof("Table %v checks out (%v rows processed, %v qps)", tableDefinition.Name, report.processedRows, report.processingQPS) } diff --git a/go/vt/worker/vertical_split_diff_test.go b/go/vt/worker/vertical_split_diff_test.go index 4723f1d86ce..c2c05a97f2c 100644 --- a/go/vt/worker/vertical_split_diff_test.go +++ b/go/vt/worker/vertical_split_diff_test.go @@ -171,7 +171,7 @@ func TestVerticalSplitDiff(t *testing.T) { qs := fakes.NewStreamHealthQueryService(rdonly.Target()) qs.AddDefaultHealthResponse() grpcqueryservice.Register(rdonly.RPCServer, &verticalDiffTabletServer{ - t: t, + t: t, StreamHealthQueryService: qs, }) } diff --git a/go/vt/wrangler/cleaner.go b/go/vt/wrangler/cleaner.go index d65d7b6eed0..668e369a15f 100644 --- a/go/vt/wrangler/cleaner.go +++ b/go/vt/wrangler/cleaner.go @@ -93,7 +93,7 @@ func (cleaner *Cleaner) CleanUp(wr *Wrangler) error { if err != nil { helper.err = err rec.RecordError(err) - wr.Logger().Errorf("action %v failed on %v: %v", actionReference.name, actionReference.target, err) + wr.Logger().Errorf2(err, "action %v failed on %v", actionReference.name, actionReference.target) } else { wr.Logger().Infof("action %v successful on %v", actionReference.name, actionReference.target) } diff --git a/go/vt/wrangler/keyspace.go b/go/vt/wrangler/keyspace.go index d65e721ba3f..1e5cd10200a 100644 --- a/go/vt/wrangler/keyspace.go +++ b/go/vt/wrangler/keyspace.go @@ -607,11 +607,11 @@ func (wr *Wrangler) masterMigrateServedType(ctx context.Context, keyspace string func (wr *Wrangler) cancelMasterMigrateServedTypes(ctx context.Context, sourceShards []*topo.ShardInfo) { if err := wr.updateShardRecords(ctx, sourceShards, nil, topodatapb.TabletType_MASTER, false); err != nil { - wr.Logger().Errorf("failed to re-enable source masters: %v", err) + wr.Logger().Errorf2(err, "failed to re-enable source masters") return } if err := wr.refreshMasters(ctx, sourceShards); err != nil { - wr.Logger().Errorf("failed to refresh source masters: %v", err) + wr.Logger().Errorf2(err, "failed to refresh source masters") } } diff --git a/go/vt/wrangler/reparent.go b/go/vt/wrangler/reparent.go index b19e9e267fe..4b8a1d09148 100644 --- a/go/vt/wrangler/reparent.go +++ b/go/vt/wrangler/reparent.go @@ -495,7 +495,7 @@ func (wr *Wrangler) plannedReparentShardLocked(ctx context.Context, ev *events.R // Wait for the slaves to complete. wgSlaves.Wait() if err := rec.Error(); err != nil { - wr.Logger().Errorf("Some slaves failed to reparent: %v", err) + wr.Logger().Errorf2(err, "some slaves failed to reparent") return err } @@ -791,7 +791,7 @@ func (wr *Wrangler) emergencyReparentShardLocked(ctx context.Context, ev *events // will rebuild the shard serving graph anyway wgSlaves.Wait() if err := rec.Error(); err != nil { - wr.Logger().Errorf("Some slaves failed to reparent: %v", err) + wr.Logger().Errorf2(err, "some slaves failed to reparent") return err } diff --git a/go/vt/wrangler/validator.go b/go/vt/wrangler/validator.go index 41119d5913d..231b6f65221 100644 --- a/go/vt/wrangler/validator.go +++ b/go/vt/wrangler/validator.go @@ -49,7 +49,7 @@ func (wr *Wrangler) waitForResults(wg *sync.WaitGroup, results chan error) error var finalErr error for err := range results { finalErr = errors.New("some validation errors - see log") - wr.Logger().Errorf("%v", err) + wr.Logger().Error(err) } return finalErr } diff --git a/helm/release.sh b/helm/release.sh new file mode 100755 index 00000000000..7e0f3a1642a --- /dev/null +++ b/helm/release.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +version_tag=1.0.3 + +docker pull vitess/k8s:latest +docker tag vitess/k8s:latest vitess/k8s:helm-$version_tag +docker push vitess/k8s:helm-$version_tag + +docker pull vitess/vtgate:latest +docker tag vitess/vtgate:latest vitess/vtgate:helm-$version_tag +docker push vitess/vtgate:helm-$version_tag + +docker pull vitess/vttablet:latest +docker tag vitess/vttablet:latest vitess/vttablet:helm-$version_tag +docker push vitess/vttablet:helm-$version_tag + +docker pull vitess/mysqlctld:latest +docker tag vitess/mysqlctld:latest vitess/mysqlctld:helm-$version_tag +docker push vitess/mysqlctld:helm-$version_tag + +docker pull vitess/vtctl:latest +docker tag vitess/vtctl:latest vitess/vtctl:helm-$version_tag +docker push vitess/vtctl:helm-$version_tag + +docker pull vitess/vtctlclient:latest +docker tag vitess/vtctlclient:latest vitess/vtctlclient:helm-$version_tag +docker push vitess/vtctlclient:helm-$version_tag + +docker pull vitess/vtctld:latest +docker tag vitess/vtctld:latest vitess/vtctld:helm-$version_tag +docker push vitess/vtctld:helm-$version_tag + +docker pull vitess/vtworker:latest +docker tag vitess/vtworker:latest vitess/vtworker:helm-$version_tag +docker push vitess/vtworker:helm-$version_tag + +docker pull vitess/logrotate:latest +docker tag vitess/logrotate:latest vitess/logrotate:helm-$version_tag +docker push vitess/logrotate:helm-$version_tag + +docker pull vitess/logtail:latest +docker tag vitess/logtail:latest vitess/logtail:helm-$version_tag +docker push vitess/logtail:helm-$version_tag + +docker pull vitess/pmm-client:latest +docker tag vitess/pmm-client:latest vitess/pmm-client:helm-$version_tag +docker push vitess/pmm-client:helm-$version_tag + +docker pull vitess/orchestrator:latest +docker tag vitess/orchestrator:latest vitess/orchestrator:helm-$version_tag +docker push vitess/orchestrator:helm-$version_tag diff --git a/helm/vitess/CHANGELOG.md b/helm/vitess/CHANGELOG.md new file mode 100644 index 00000000000..992437e922f --- /dev/null +++ b/helm/vitess/CHANGELOG.md @@ -0,0 +1,31 @@ +## 1.0.3 - 2018-12-20 + +### Changes +* Start tagging helm images and use them as default +* Added commonly used flags to values.yaml for vtgate & vttablet for discoverability. +Some match the binary flag defaults, and some have been set to more production ready values. +* Extended vttablet terminationGracePeriodSeconds from 600 to 60000000. +This will block on `PlannedReparent` in the `preStopHook` forever to prevent +unsafe `EmergencyReparent` operations when the pod is killed. + +### Bug fixes +* Use `$MYSQL_FLAVOR` to set flavor instead of `$EXTRA_MY_CNF` + +## 1.0.2 - 2018-12-11 + +### Bug fixes +* Renamed ImagePullPolicy to imagePullPolicy +* Added user-secret-volumes to backup CronJob + +## 1.0.1 - 2018-12-07 + +### Changes +* Added support for [MySQL Custom Queries](https://www.percona.com/blog/2018/10/10/percona-monitoring-and-management-pmm-1-15-0-is-now-available/) in PMM +* Added Linux host monitoring for PMM +* Added keyspace and shard labels to jobs +* Remove old mysql.sock file in vttablet InitContainer + +### Bug fixes +* PMM wouldn't bootstrap correctly on a new cluster + +## 1.0.0 - 2018-12-03 Vitess Helm Chart goes GA! \ No newline at end of file diff --git a/helm/vitess/Chart.yaml b/helm/vitess/Chart.yaml index d4ec24f79f7..b65208e7bf0 100644 --- a/helm/vitess/Chart.yaml +++ b/helm/vitess/Chart.yaml @@ -1,6 +1,6 @@ apiVersion: v1 name: vitess -version: 1.0.0 +version: 1.0.3 description: Single-Chart Vitess Cluster keywords: - vitess diff --git a/helm/vitess/templates/_cron-jobs.tpl b/helm/vitess/templates/_cron-jobs.tpl index ebbe576c612..3b4ffd13852 100644 --- a/helm/vitess/templates/_cron-jobs.tpl +++ b/helm/vitess/templates/_cron-jobs.tpl @@ -27,6 +27,7 @@ metadata: cell: {{ $cellClean | quote }} keyspace: {{ $keyspaceClean | quote }} shard: {{ $shardClean | quote }} + backupJob: "true" spec: schedule: {{ $shard.backup.cron.schedule | default $backup.cron.schedule | quote }} @@ -45,6 +46,8 @@ spec: cell: {{ $cellClean | quote }} keyspace: {{ $keyspaceClean | quote }} shard: {{ $shardClean | quote }} + backupJob: "true" + # pod spec spec: restartPolicy: Never @@ -71,6 +74,8 @@ spec: requests: cpu: 10m memory: 20Mi + volumes: +{{ include "user-secret-volumes" $defaultVtctlclient.secrets | indent 12 }} {{ end }} diff --git a/helm/vitess/templates/_helpers.tpl b/helm/vitess/templates/_helpers.tpl index 3463c7916d7..f83683c0e4d 100644 --- a/helm/vitess/templates/_helpers.tpl +++ b/helm/vitess/templates/_helpers.tpl @@ -103,26 +103,27 @@ nodeAffinity: {{- define "mycnf-exec" -}} if [ "$VT_DB_FLAVOR" = "percona" ]; then - FLAVOR_MYCNF=/vt/config/mycnf/master_mysql56.cnf + MYSQL_FLAVOR=Percona elif [ "$VT_DB_FLAVOR" = "mysql" ]; then - FLAVOR_MYCNF=/vt/config/mycnf/master_mysql56.cnf + MYSQL_FLAVOR=MySQL56 elif [ "$VT_DB_FLAVOR" = "mysql56" ]; then - FLAVOR_MYCNF=/vt/config/mycnf/master_mysql56.cnf + MYSQL_FLAVOR=MySQL56 elif [ "$VT_DB_FLAVOR" = "maria" ]; then - FLAVOR_MYCNF=/vt/config/mycnf/master_mariadb.cnf + MYSQL_FLAVOR=MariaDB elif [ "$VT_DB_FLAVOR" = "mariadb" ]; then - FLAVOR_MYCNF=/vt/config/mycnf/master_mariadb.cnf + MYSQL_FLAVOR=MariaDB elif [ "$VT_DB_FLAVOR" = "mariadb103" ]; then - FLAVOR_MYCNF=/vt/config/mycnf/master_mariadb103.cnf + MYSQL_FLAVOR=MariaDB103 fi -export EXTRA_MY_CNF="$FLAVOR_MYCNF:/vtdataroot/tabletdata/report-host.cnf:/vt/config/mycnf/rbr.cnf" +export MYSQL_FLAVOR +export EXTRA_MY_CNF="/vtdataroot/tabletdata/report-host.cnf:/vt/config/mycnf/rbr.cnf" {{ if . }} for filename in /vt/userconfig/*.cnf; do diff --git a/helm/vitess/templates/_jobs.tpl b/helm/vitess/templates/_jobs.tpl index 6567e1cf91b..45d389ff50e 100644 --- a/helm/vitess/templates/_jobs.tpl +++ b/helm/vitess/templates/_jobs.tpl @@ -20,6 +20,12 @@ metadata: spec: backoffLimit: 1 template: + metadata: + labels: + app: vitess + component: vtctlclient + vtctlclientJob: "true" + spec: restartPolicy: OnFailure containers: @@ -62,6 +68,12 @@ metadata: spec: backoffLimit: 1 template: + metadata: + labels: + app: vitess + component: vtworker + vtworkerJob: "true" + spec: {{ include "pod-security" . | indent 6 }} restartPolicy: OnFailure diff --git a/helm/vitess/templates/_orchestrator.tpl b/helm/vitess/templates/_orchestrator.tpl index c17308f330c..d0babc41a59 100644 --- a/helm/vitess/templates/_orchestrator.tpl +++ b/helm/vitess/templates/_orchestrator.tpl @@ -89,7 +89,7 @@ spec: containers: - name: orchestrator image: {{ $orc.image | quote }} - ImagePullPolicy: IfNotPresent + imagePullPolicy: IfNotPresent ports: - containerPort: 3000 name: web @@ -123,8 +123,8 @@ spec: value: "15999" - name: recovery-log - image: vitess/logtail:latest - ImagePullPolicy: IfNotPresent + image: vitess/logtail:helm-1.0.3 + imagePullPolicy: IfNotPresent env: - name: TAIL_FILEPATH value: /tmp/recovery.log @@ -133,8 +133,8 @@ spec: mountPath: /tmp - name: audit-log - image: vitess/logtail:latest - ImagePullPolicy: IfNotPresent + image: vitess/logtail:helm-1.0.3 + imagePullPolicy: IfNotPresent env: - name: TAIL_FILEPATH value: /tmp/orchestrator-audit.log diff --git a/helm/vitess/templates/_pmm.tpl b/helm/vitess/templates/_pmm.tpl index 522ae60d4ac..9ff20aa94a5 100644 --- a/helm/vitess/templates/_pmm.tpl +++ b/helm/vitess/templates/_pmm.tpl @@ -133,14 +133,20 @@ spec: ################################### {{ define "cont-pmm-client" -}} {{- $pmm := index . 0 -}} -{{- $namespace := index . 1 }} +{{- $namespace := index . 1 -}} +{{- $keyspace := index . 2 }} - name: "pmm-client" image: "vitess/pmm-client:{{ $pmm.pmmTag }}" - ImagePullPolicy: IfNotPresent + imagePullPolicy: IfNotPresent volumeMounts: - name: vtdataroot mountPath: "/vtdataroot" +{{ if $keyspace.pmm }}{{if $keyspace.pmm.config }} + - name: config + mountPath: "/vt-pmm-config" +{{ end }}{{ end }} + ports: - containerPort: 42001 name: query-data @@ -173,10 +179,23 @@ spec: ln -s /vtdataroot/pmm/init.d /etc/init.d ln -s /vtdataroot/pmm/pmm-mysql-metrics-42002.log /var/log/pmm-mysql-metrics-42002.log - # workaround for when pod ips change if [ ! -z "$FIRST_RUN" ]; then cp -r /usr/local/percona_tmp/* /vtdataroot/pmm/percona || : cp -r /etc/init.d_tmp/* /vtdataroot/pmm/init.d || : + fi + +{{ if $keyspace.pmm }}{{if $keyspace.pmm.config }} + # link all the configmap files into their expected file locations + for filename in /vt-pmm-config/*; do + DEST_FILE=/vtdataroot/pmm/percona/pmm-client/$(basename "$filename") + rm -f $DEST_FILE + ln -s "$filename" $DEST_FILE + done +{{ end }}{{ end }} + + # if this doesn't return an error, pmm-admin has already been configured + # and we want to stop/remove running services, in case pod ips have changed + if pmm-admin info; then pmm-admin stop --all pmm-admin rm --all fi @@ -191,6 +210,7 @@ spec: done # creates systemd services + pmm-admin add linux:metrics pmm-admin add mysql:metrics --user root --socket /vtdataroot/tabletdata/mysql.sock --force pmm-admin add mysql:queries --user root --socket /vtdataroot/tabletdata/mysql.sock --force --query-source=perfschema @@ -198,8 +218,8 @@ spec: trap : TERM INT; sleep infinity & wait - name: pmm-client-metrics-log - image: vitess/logtail:latest - ImagePullPolicy: IfNotPresent + image: vitess/logtail:helm-1.0.3 + imagePullPolicy: IfNotPresent env: - name: TAIL_FILEPATH value: /vtdataroot/pmm/pmm-mysql-metrics-42002.log diff --git a/helm/vitess/templates/_shard.tpl b/helm/vitess/templates/_shard.tpl index d8c6c3f5846..3d075c24f85 100644 --- a/helm/vitess/templates/_shard.tpl +++ b/helm/vitess/templates/_shard.tpl @@ -29,6 +29,15 @@ metadata: spec: backoffLimit: 1 template: + metadata: + labels: + app: vitess + component: vttablet + cell: {{ $cellClean | quote }} + keyspace: {{ $keyspaceClean | quote }} + shard: {{ $shardClean | quote }} + initShardMasterJob: "true" + spec: restartPolicy: OnFailure containers: @@ -123,6 +132,15 @@ metadata: spec: backoffLimit: 1 template: + metadata: + labels: + app: vitess + component: vttablet + cell: {{ $cellClean | quote }} + keyspace: {{ $keyspaceClean | quote }} + shard: {{ $shardClean | quote }} + copySchemaShardJob: "true" + spec: restartPolicy: OnFailure containers: diff --git a/helm/vitess/templates/_vtctld.tpl b/helm/vitess/templates/_vtctld.tpl index 765898c1cfe..af178035e99 100644 --- a/helm/vitess/templates/_vtctld.tpl +++ b/helm/vitess/templates/_vtctld.tpl @@ -60,7 +60,7 @@ spec: containers: - name: vtctld image: vitess/vtctld:{{$vitessTag}} - ImagePullPolicy: IfNotPresent + imagePullPolicy: IfNotPresent readinessProbe: httpGet: path: /debug/health diff --git a/helm/vitess/templates/_vtgate.tpl b/helm/vitess/templates/_vtgate.tpl index 585a1a26562..1455e9c47a2 100644 --- a/helm/vitess/templates/_vtgate.tpl +++ b/helm/vitess/templates/_vtgate.tpl @@ -75,7 +75,7 @@ spec: containers: - name: vtgate image: vitess/vtgate:{{$vitessTag}} - ImagePullPolicy: IfNotPresent + imagePullPolicy: IfNotPresent readinessProbe: httpGet: path: /debug/health @@ -224,7 +224,7 @@ affinity: - name: init-mysql-creds image: "vitess/vtgate:{{$vitessTag}}" - ImagePullPolicy: IfNotPresent + imagePullPolicy: IfNotPresent volumeMounts: - name: creds mountPath: "/mysqlcreds" diff --git a/helm/vitess/templates/_vttablet.tpl b/helm/vitess/templates/_vttablet.tpl index c2b85187a4d..0c3e4643fc3 100644 --- a/helm/vitess/templates/_vttablet.tpl +++ b/helm/vitess/templates/_vttablet.tpl @@ -96,7 +96,7 @@ spec: shard: {{ $shardClean | quote }} type: {{ $tablet.type | quote }} spec: - terminationGracePeriodSeconds: 600 + terminationGracePeriodSeconds: 60000000 {{ include "pod-security" . | indent 6 }} {{ include "vttablet-affinity" (tuple $cellClean $keyspaceClean $shardClean $cell.region) | indent 6 }} @@ -111,7 +111,7 @@ spec: {{ include "cont-mysql-generallog" . | indent 8 }} {{ include "cont-mysql-errorlog" . | indent 8 }} {{ include "cont-mysql-slowlog" . | indent 8 }} -{{ if $pmm.enabled }}{{ include "cont-pmm-client" (tuple $pmm $namespace) | indent 8 }}{{ end }} +{{ if $pmm.enabled }}{{ include "cont-pmm-client" (tuple $pmm $namespace $keyspace) | indent 8 }}{{ end }} volumes: - name: vt @@ -119,6 +119,11 @@ spec: {{ include "backup-volume" $config.backup | indent 8 }} {{ include "user-config-volume" (.extraMyCnf | default $defaultVttablet.extraMyCnf) | indent 8 }} {{ include "user-secret-volumes" (.secrets | default $defaultVttablet.secrets) | indent 8 }} +{{ if $keyspace.pmm }}{{if $keyspace.pmm.config }} + - name: config + configMap: + name: {{ $keyspace.pmm.config }} +{{ end }}{{ end }} volumeClaimTemplates: - metadata: @@ -162,7 +167,7 @@ spec: - name: "init-mysql" image: "vitess/mysqlctld:{{$vitessTag}}" - ImagePullPolicy: IfNotPresent + imagePullPolicy: IfNotPresent volumeMounts: - name: vtdataroot mountPath: "/vtdataroot" @@ -188,6 +193,9 @@ spec: touch /vtdataroot/tabletdata/slow-query.log touch /vtdataroot/tabletdata/general.log + # remove the old socket file if it is still around + rm -f /vtdataroot/tabletdata/mysql.sock + {{- end -}} ################################### @@ -203,7 +211,7 @@ spec: - name: init-vttablet image: "vitess/vtctl:{{$vitessTag}}" - ImagePullPolicy: IfNotPresent + imagePullPolicy: IfNotPresent volumeMounts: - name: vtdataroot mountPath: "/vtdataroot" @@ -268,7 +276,7 @@ spec: - name: vttablet image: "vitess/vttablet:{{$vitessTag}}" - ImagePullPolicy: IfNotPresent + imagePullPolicy: IfNotPresent readinessProbe: httpGet: path: /debug/health @@ -336,7 +344,7 @@ spec: # - use GTID_SUBTRACT RETRY_COUNT=0 - MAX_RETRY_COUNT=5 + MAX_RETRY_COUNT=100000 # retry reparenting until [ $DONE_REPARENTING ]; do @@ -428,7 +436,7 @@ spec: - name: mysql image: {{.mysqlImage | default $defaultVttablet.mysqlImage | quote}} - ImagePullPolicy: IfNotPresent + imagePullPolicy: IfNotPresent readinessProbe: exec: command: ["mysqladmin", "ping", "-uroot", "--socket=/vtdataroot/tabletdata/mysql.sock"] @@ -505,8 +513,8 @@ spec: {{ define "cont-logrotate" }} - name: logrotate - image: vitess/logrotate:latest - ImagePullPolicy: IfNotPresent + image: vitess/logrotate:helm-1.0.3 + imagePullPolicy: IfNotPresent volumeMounts: - name: vtdataroot mountPath: /vtdataroot @@ -519,8 +527,8 @@ spec: {{ define "cont-mysql-errorlog" }} - name: error-log - image: vitess/logtail:latest - ImagePullPolicy: IfNotPresent + image: vitess/logtail:helm-1.0.3 + imagePullPolicy: IfNotPresent env: - name: TAIL_FILEPATH @@ -537,8 +545,8 @@ spec: {{ define "cont-mysql-slowlog" }} - name: slow-log - image: vitess/logtail:latest - ImagePullPolicy: IfNotPresent + image: vitess/logtail:helm-1.0.3 + imagePullPolicy: IfNotPresent env: - name: TAIL_FILEPATH @@ -555,8 +563,8 @@ spec: {{ define "cont-mysql-generallog" }} - name: general-log - image: vitess/logtail:latest - ImagePullPolicy: IfNotPresent + image: vitess/logtail:helm-1.0.3 + imagePullPolicy: IfNotPresent env: - name: TAIL_FILEPATH diff --git a/helm/vitess/values.yaml b/helm/vitess/values.yaml index 63db7cca30f..03d200e2a71 100644 --- a/helm/vitess/values.yaml +++ b/helm/vitess/values.yaml @@ -129,6 +129,13 @@ topology: # } # } + ## this defines keyspace specific information for PMM + # pmm: + ## PMM supports collecting metrics from custom SQL queries in a file named queries-mysqld.yml + # The specified ConfigMap will be mounted in a directory, so the file name is important. + # https://www.percona.com/blog/2018/10/10/percona-monitoring-and-management-pmm-1-15-0-is-now-available/ + # config: pmm-commerce-config + # enable or disable mysql protocol support, with accompanying auth details mysqlProtocol: enabled: false @@ -170,7 +177,7 @@ etcd: # Default values for vtctld resources defined in 'topology' vtctld: serviceType: ClusterIP - vitessTag: latest + vitessTag: helm-1.0.3 resources: # requests: # cpu: 100m @@ -181,23 +188,32 @@ vtctld: # Default values for vtgate resources defined in 'topology' vtgate: serviceType: ClusterIP - vitessTag: latest + vitessTag: helm-1.0.3 resources: # requests: # cpu: 500m # memory: 512Mi - extraFlags: {} + + # Additional flags that will be appended to the vtgate command. + # The options below are the most commonly adjusted, but any flag can be put here. + # run vtgate --help to see all available flags + extraFlags: + # MySQL server version to advertise. (default "5.5.10-Vitess") + # If running 8.0, you may need to use something like "8.0.13-Vitess" + # to prevent db clients from running deprecated queries on startup + mysql_server_version: "5.5.10-Vitess" + secrets: [] # secrets are mounted under /vt/usersecrets/{secretname} # Default values for vtctlclient resources defined in 'topology' vtctlclient: - vitessTag: latest + vitessTag: helm-1.0.3 extraFlags: {} secrets: [] # secrets are mounted under /vt/usersecrets/{secretname} # Default values for vtworker resources defined in 'jobs' vtworker: - vitessTag: latest + vitessTag: helm-1.0.3 extraFlags: {} resources: # requests: @@ -208,7 +224,7 @@ vtworker: # Default values for vttablet resources defined in 'topology' vttablet: - vitessTag: latest + vitessTag: helm-1.0.3 # valid values are # - mysql56 (for MySQL 8.0) @@ -241,8 +257,38 @@ vttablet: # If the value is "test", then mysql is instanitated with a smaller footprint. mysqlSize: prod - # Additional flags that will be appended to the vttablet command - extraFlags: {} + # Additional flags that will be appended to the vttablet command. + # The options below are the most commonly adjusted, but any flag can be put here. + # run vttablet --help to see all available flags + extraFlags: + # query server max result size, maximum number of rows allowed to return + # from vttablet for non-streaming queries. + queryserver-config-max-result-size: 10000 + + # query server query timeout (in seconds), this is the query timeout in vttablet side. + # If a query takes more than this timeout, it will be killed. + queryserver-config-query-timeout: 30 + + # query server connection pool size, connection pool is used by + # regular queries (non streaming, not in a transaction) + queryserver-config-pool-size: 24 + + # query server stream connection pool size, stream pool is used by stream queries: + # queries that return results to client in a streaming fashion + queryserver-config-stream-pool-size: 100 + + # query server transaction cap is the maximum number of transactions allowed to + # happen at any given point of a time for a single vttablet. + # e.g. by setting transaction cap to 100, there are at most 100 transactions + # will be processed by a vttablet and the 101th transaction will be blocked + # (and fail if it cannot get connection within specified timeout) + queryserver-config-transaction-cap: 300 + + # Size of the connection pool for app connections + app_pool_size: 40 + + # Size of the connection pool for dba connections + dba_pool_size: 20 # User secrets that will be mounted under /vt/usersecrets/{secretname}/ secrets: [] diff --git a/java/jdbc/pom.xml b/java/jdbc/pom.xml index b2585ea79ee..02ada256b32 100644 --- a/java/jdbc/pom.xml +++ b/java/jdbc/pom.xml @@ -81,15 +81,6 @@ - - org.apache.maven.plugins - maven-compiler-plugin - 3.5 - - 7 - 7 - - org.apache.maven.plugins maven-surefire-plugin diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/ConnectionProperties.java b/java/jdbc/src/main/java/io/vitess/jdbc/ConnectionProperties.java index 72850bb666a..f978311831e 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/ConnectionProperties.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/ConnectionProperties.java @@ -16,6 +16,11 @@ package io.vitess.jdbc; +import io.vitess.proto.Query; +import io.vitess.proto.Topodata; +import io.vitess.util.Constants; +import io.vitess.util.StringUtils; + import java.io.UnsupportedEncodingException; import java.lang.reflect.Field; import java.sql.DriverPropertyInfo; @@ -27,11 +32,6 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; -import io.vitess.proto.Query; -import io.vitess.proto.Topodata; -import io.vitess.util.Constants; -import io.vitess.util.StringUtils; - public class ConnectionProperties { private static final ArrayList PROPERTY_LIST = new ArrayList<>(); diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/FieldWithMetadata.java b/java/jdbc/src/main/java/io/vitess/jdbc/FieldWithMetadata.java index e26253e017f..910c83825e3 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/FieldWithMetadata.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/FieldWithMetadata.java @@ -17,9 +17,6 @@ package io.vitess.jdbc; import com.google.common.annotations.VisibleForTesting; -import java.sql.SQLException; -import java.sql.Types; -import java.util.regex.PatternSyntaxException; import io.vitess.proto.Query; import io.vitess.util.Constants; @@ -27,6 +24,10 @@ import io.vitess.util.StringUtils; import io.vitess.util.charset.CharsetMapping; +import java.sql.SQLException; +import java.sql.Types; +import java.util.regex.PatternSyntaxException; + public class FieldWithMetadata { private final ConnectionProperties connectionProperties; diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessConnection.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessConnection.java index 40a110bbf79..9fcd3d947be 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessConnection.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessConnection.java @@ -16,6 +16,14 @@ package io.vitess.jdbc; +import io.vitess.client.Context; +import io.vitess.client.VTGateConnection; +import io.vitess.client.VTSession; +import io.vitess.proto.Query; +import io.vitess.util.CommonUtils; +import io.vitess.util.Constants; +import io.vitess.util.MysqlDefs; + import java.sql.Array; import java.sql.Blob; import java.sql.CallableStatement; @@ -41,16 +49,6 @@ import java.util.Properties; import java.util.Set; import java.util.concurrent.Executor; -import java.util.logging.Logger; - -import io.vitess.client.Context; -import io.vitess.client.VTGateConnection; -import io.vitess.client.VTSession; -import io.vitess.proto.Query; -import io.vitess.proto.Vtgate; -import io.vitess.util.CommonUtils; -import io.vitess.util.Constants; -import io.vitess.util.MysqlDefs; /** * Created by harshit.gangal on 23/01/16. @@ -58,7 +56,6 @@ public class VitessConnection extends ConnectionProperties implements Connection { /* Get actual class name to be printed on */ - private static Logger logger = Logger.getLogger(VitessConnection.class.getName()); private static DatabaseMetaData databaseMetaData = null; /** @@ -223,9 +220,8 @@ public void close() throws SQLException { * Return Connection state * * @return DatabaseMetadata Object - * @throws SQLException */ - public boolean isClosed() throws SQLException { + public boolean isClosed() { return this.closed; } @@ -486,9 +482,8 @@ public void setClientInfo(String name, String value) throws SQLClientInfoExcepti * * @param name - Property Name * @return Property Value - * @throws SQLException */ - public String getClientInfo(String name) throws SQLException { + public String getClientInfo(String name) { return null; } @@ -496,9 +491,8 @@ public String getClientInfo(String name) throws SQLException { * TODO: For Implementation Possibility * * @return - Property Object - * @throws SQLException */ - public Properties getClientInfo() throws SQLException { + public Properties getClientInfo() { return null; } @@ -818,11 +812,11 @@ private String initializeDBProperties() throws SQLException { if (metadataNullOrClosed()) { String versionValue; - ResultSet resultSet = null; - VitessStatement vitessStatement = new VitessStatement(this); - try { - resultSet = vitessStatement.executeQuery( - "SHOW VARIABLES WHERE VARIABLE_NAME IN (\'tx_isolation\',\'INNODB_VERSION\', \'lower_case_table_names\')"); + + try(VitessStatement vitessStatement = new VitessStatement(this); + ResultSet resultSet = vitessStatement.executeQuery( + "SHOW VARIABLES WHERE VARIABLE_NAME IN (\'tx_isolation\',\'INNODB_VERSION\', \'lower_case_table_names\')") + ) { while (resultSet.next()) { dbVariables.put(resultSet.getString(1), resultSet.getString(2)); } @@ -855,13 +849,7 @@ private String initializeDBProperties() throws SQLException { } this.dbProperties = new DBProperties(productVersion, majorVersion, minorVersion, isolationLevel, lowerCaseTables); - } finally { - if (null != resultSet) { - resultSet.close(); - } - vitessStatement.close(); } - } return dbEngine; } diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessDatabaseMetaData.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessDatabaseMetaData.java index e0f3c0fbd9d..61942951a15 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessDatabaseMetaData.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessDatabaseMetaData.java @@ -16,14 +16,14 @@ package io.vitess.jdbc; +import io.vitess.util.Constants; + import java.sql.DatabaseMetaData; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; import java.util.logging.Logger; -import io.vitess.util.Constants; - /** * Created by harshit.gangal on 25/01/16. */ diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessDriver.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessDriver.java index b78579b1abf..a06f7df9904 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessDriver.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessDriver.java @@ -16,6 +16,8 @@ package io.vitess.jdbc; +import io.vitess.util.Constants; + import java.sql.Connection; import java.sql.Driver; import java.sql.DriverManager; @@ -25,8 +27,6 @@ import java.util.Properties; import java.util.logging.Logger; -import io.vitess.util.Constants; - /** * VitessDriver is the official JDBC driver for Vitess. * @@ -38,9 +38,6 @@ */ public class VitessDriver implements Driver { - /* Get actual class name to be printed on */ - private static Logger logger = Logger.getLogger(VitessDriver.class.getName()); - static { try { DriverManager.registerDriver(new VitessDriver()); @@ -73,7 +70,7 @@ public Connection connect(String url, Properties info) throws SQLException { * TODO: Write a better regex */ @Override - public boolean acceptsURL(String url) throws SQLException { + public boolean acceptsURL(String url) { return null != url && url.startsWith(Constants.URL_PREFIX); } diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessJDBCUrl.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessJDBCUrl.java index ca7f39c22ee..416a76734a6 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessJDBCUrl.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessJDBCUrl.java @@ -16,6 +16,9 @@ package io.vitess.jdbc; +import io.vitess.util.Constants; +import io.vitess.util.StringUtils; + import java.io.UnsupportedEncodingException; import java.net.URLDecoder; import java.sql.SQLException; @@ -26,9 +29,6 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -import io.vitess.util.Constants; -import io.vitess.util.StringUtils; - /** * VitessJDBCUrl is responsible for parsing a driver URL and Properties object, * returning a new Properties object with configuration from the URL and passed in Properties diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessMariaDBDatabaseMetadata.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessMariaDBDatabaseMetadata.java index 1dc87d55d7e..5b0f47f7e07 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessMariaDBDatabaseMetadata.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessMariaDBDatabaseMetadata.java @@ -16,6 +16,9 @@ package io.vitess.jdbc; +import io.vitess.proto.Query; +import io.vitess.util.Constants; + import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.ResultSet; @@ -24,9 +27,6 @@ import java.sql.SQLFeatureNotSupportedException; import java.util.logging.Logger; -import io.vitess.proto.Query; -import io.vitess.util.Constants; - /** * Created by ashudeep.sharma on 15/02/16. */ diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessMySQLDatabaseMetadata.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessMySQLDatabaseMetadata.java index 8a72ef686bb..e94541dacde 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessMySQLDatabaseMetadata.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessMySQLDatabaseMetadata.java @@ -17,6 +17,12 @@ package io.vitess.jdbc; import com.google.common.annotations.VisibleForTesting; +import org.apache.commons.lang.StringUtils; + +import io.vitess.proto.Query; +import io.vitess.util.Constants; +import io.vitess.util.MysqlDefs; + import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.ResultSet; @@ -35,11 +41,6 @@ import java.util.StringTokenizer; import java.util.TreeMap; import java.util.logging.Logger; -import org.apache.commons.lang.StringUtils; - -import io.vitess.proto.Query; -import io.vitess.util.Constants; -import io.vitess.util.MysqlDefs; /** * Created by ashudeep.sharma on 15/02/16. diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessPreparedStatement.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessPreparedStatement.java index b554a224fae..d53d7c86ceb 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessPreparedStatement.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessPreparedStatement.java @@ -16,6 +16,14 @@ package io.vitess.jdbc; +import io.vitess.client.Context; +import io.vitess.client.VTGateConnection; +import io.vitess.client.cursor.Cursor; +import io.vitess.client.cursor.CursorWithError; +import io.vitess.mysql.DateTime; +import io.vitess.util.Constants; +import io.vitess.util.StringUtils; + import java.io.InputStream; import java.io.Reader; import java.math.BigDecimal; @@ -48,15 +56,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.logging.Logger; - -import io.vitess.client.Context; -import io.vitess.client.VTGateConnection; -import io.vitess.client.cursor.Cursor; -import io.vitess.client.cursor.CursorWithError; -import io.vitess.mysql.DateTime; -import io.vitess.util.Constants; -import io.vitess.util.StringUtils; /** * Created by harshit.gangal on 25/01/16. @@ -70,7 +69,6 @@ public class VitessPreparedStatement extends VitessStatement implements PreparedStatement { /* Get actual class name to be printed on */ - private static Logger logger = Logger.getLogger(VitessPreparedStatement.class.getName()); private final String sql; private final Map bindVariables; /** @@ -394,7 +392,7 @@ public void addBatch() throws SQLException { vtGateConn = this.vitessConnection.getVtGateConn(); this.retrieveGeneratedKeys = true; // mimicking mysql-connector-j - /** + /* * Current api does not support single query and multiple bindVariables list. * So, List of the query is created to match the bindVariables list. */ @@ -466,9 +464,6 @@ private int calculateParameterCount() throws SQLException { continue; // inline quote escape } - inQuotes = !inQuotes; - currentQuoteChar = 0; - } else if (((c == '\'') || (c == '"')) && c == currentQuoteChar) { inQuotes = !inQuotes; currentQuoteChar = 0; } diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessResultSet.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessResultSet.java index aa6cdcd36b3..f13a31d24fb 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessResultSet.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessResultSet.java @@ -18,6 +18,15 @@ import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; + +import io.vitess.client.cursor.Cursor; +import io.vitess.client.cursor.Row; +import io.vitess.client.cursor.SimpleCursor; +import io.vitess.proto.Query; +import io.vitess.util.Constants; +import io.vitess.util.StringUtils; + +import javax.sql.rowset.serial.SerialClob; import java.io.ByteArrayInputStream; import java.io.InputStream; import java.io.Reader; @@ -46,24 +55,12 @@ import java.util.Calendar; import java.util.List; import java.util.Map; -import java.util.logging.Logger; -import javax.sql.rowset.serial.SerialClob; - -import io.vitess.client.cursor.Cursor; -import io.vitess.client.cursor.Row; -import io.vitess.client.cursor.SimpleCursor; -import io.vitess.proto.Query; -import io.vitess.util.Constants; -import io.vitess.util.StringUtils; /** * Created by harshit.gangal on 23/01/16. */ public class VitessResultSet implements ResultSet { - /* Get actual class name to be printed on */ - private static Logger logger = Logger.getLogger(VitessResultSet.class.getName()); - private Cursor cursor; private List fields; private VitessStatement vitessStatement; @@ -765,7 +762,7 @@ public Timestamp getTimestamp(String columnLabel, Calendar cal) throws SQLExcept return getTimestamp(columnIndex, cal); } - public boolean isClosed() throws SQLException { + public boolean isClosed() { return this.closed; } diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessResultSetMetaData.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessResultSetMetaData.java index a7f8c7cb4a3..8eddb40752f 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessResultSetMetaData.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessResultSetMetaData.java @@ -17,15 +17,16 @@ package io.vitess.jdbc; import com.google.common.collect.ImmutableList; + +import io.vitess.proto.Query; +import io.vitess.util.Constants; + import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Types; import java.util.List; import java.util.logging.Logger; -import io.vitess.proto.Query; -import io.vitess.util.Constants; - /** * Created by harshit.gangal on 25/01/16. diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessStatement.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessStatement.java index f4a0904824f..bde82e6f01e 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessStatement.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessStatement.java @@ -16,16 +16,6 @@ package io.vitess.jdbc; -import java.sql.BatchUpdateException; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.SQLFeatureNotSupportedException; -import java.sql.SQLWarning; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.List; -import java.util.logging.Logger; - import io.vitess.client.Context; import io.vitess.client.Proto; import io.vitess.client.VTGateConnection; @@ -36,6 +26,15 @@ import io.vitess.util.Constants; import io.vitess.util.StringUtils; +import java.sql.BatchUpdateException; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.SQLWarning; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; + /** * Created by harshit.gangal on 19/01/16. *

    @@ -48,8 +47,6 @@ */ public class VitessStatement implements Statement { protected static final String[] ON_DUPLICATE_KEY_UPDATE_CLAUSE = new String[] { "ON", "DUPLICATE", "KEY", "UPDATE" }; - /* Get actual class name to be printed on */ - private static Logger logger = Logger.getLogger(VitessStatement.class.getName()); protected VitessResultSet vitessResultSet; protected VitessConnection vitessConnection; protected boolean closed; @@ -268,9 +265,8 @@ public SQLWarning getWarnings() throws SQLException { /** * Clear the warnings - Not saving Warnings * - * @throws SQLException */ - public void clearWarnings() throws SQLException { + public void clearWarnings() { //no-op } @@ -316,7 +312,7 @@ public VitessConnection getConnection() throws SQLException { return vitessConnection; } - public boolean isClosed() throws SQLException { + public boolean isClosed() { return this.closed; } @@ -376,15 +372,15 @@ public ResultSet getGeneratedKeys() throws SQLException { } } else if (this.batchGeneratedKeys != null) { long totalAffected = 0; - for (int i = 0; i < this.batchGeneratedKeys.length; i++) { - long rowsAffected = this.batchGeneratedKeys[i][1]; + for (long[] batchGeneratedKey : this.batchGeneratedKeys) { + long rowsAffected = batchGeneratedKey[1]; totalAffected += rowsAffected; } data = new String[(int) totalAffected][1]; int idx = 0; - for (int i = 0; i < this.batchGeneratedKeys.length; i++) { - long insertId = this.batchGeneratedKeys[i][0]; - long rowsAffected = this.batchGeneratedKeys[i][1]; + for (long[] batchGeneratedKey : this.batchGeneratedKeys) { + long insertId = batchGeneratedKey[0]; + long rowsAffected = batchGeneratedKey[1]; for (int j = 0; j < rowsAffected; j++) { data[idx++][0] = String.valueOf(insertId + j); } @@ -591,7 +587,7 @@ protected void checkSQLNullOrEmpty(String sql) throws SQLException { protected int[] generateBatchUpdateResult(List cursorWithErrorList, List batchedArgs) throws BatchUpdateException { int[] updateCounts = new int[cursorWithErrorList.size()]; - ArrayList generatedKeys = new ArrayList(); + ArrayList generatedKeys = new ArrayList<>(); Vtrpc.RPCError rpcError = null; String batchCommand = null; @@ -663,7 +659,7 @@ protected void checkAndBeginTransaction() throws SQLException { if (!(this.vitessConnection.getAutoCommit() || this.vitessConnection.isInTransaction())) { Context context = this.vitessConnection.createContext(this.queryTimeoutInMillis); VTGateConnection vtGateConn = this.vitessConnection.getVtGateConn(); - Cursor cursor = vtGateConn.execute(context,"begin",null,this.vitessConnection.getVtSession()).checkedGet(); + vtGateConn.execute(context,"begin",null,this.vitessConnection.getVtSession()).checkedGet(); } } diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessVTGateManager.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessVTGateManager.java index 895f8ef3f70..b9047d2a8c9 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessVTGateManager.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessVTGateManager.java @@ -16,6 +16,14 @@ package io.vitess.jdbc; +import io.vitess.client.Context; +import io.vitess.client.RefreshableVTGateConnection; +import io.vitess.client.VTGateConnection; +import io.vitess.client.grpc.GrpcClientFactory; +import io.vitess.client.grpc.RetryingInterceptorConfig; +import io.vitess.client.grpc.tls.TlsOptions; +import io.vitess.util.Constants; + import java.io.IOException; import java.sql.SQLException; import java.util.ArrayList; @@ -26,16 +34,8 @@ import java.util.TimerTask; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; -import java.util.logging.Logger; import java.util.logging.Level; - -import io.vitess.client.Context; -import io.vitess.client.VTGateConnection; -import io.vitess.client.RefreshableVTGateConnection; -import io.vitess.client.grpc.GrpcClientFactory; -import io.vitess.client.grpc.RetryingInterceptorConfig; -import io.vitess.client.grpc.tls.TlsOptions; -import io.vitess.util.Constants; +import java.util.logging.Logger; /** * Created by naveen.nahata on 24/02/16. diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/BaseTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/BaseTest.java index 0d807e2f7ec..5d4f7fbfaff 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/BaseTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/BaseTest.java @@ -16,11 +16,12 @@ package io.vitess.jdbc; -import java.sql.SQLException; -import java.util.Properties; import org.junit.Assert; import org.junit.BeforeClass; +import java.sql.SQLException; +import java.util.Properties; + public class BaseTest { String dbURL = "jdbc:vitess://locahost:9000/vt_keyspace/keyspace"; diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/ConnectionPropertiesTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/ConnectionPropertiesTest.java index 6d1e8cc5c4b..195b867fa4d 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/ConnectionPropertiesTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/ConnectionPropertiesTest.java @@ -16,11 +16,6 @@ package io.vitess.jdbc; -import java.sql.DriverPropertyInfo; -import java.sql.SQLException; -import java.util.Arrays; -import java.util.Properties; - import org.junit.Assert; import org.junit.Test; import org.mockito.Mockito; @@ -28,6 +23,23 @@ import io.vitess.proto.Query; import io.vitess.proto.Topodata; import io.vitess.util.Constants; +import io.vitess.util.Constants.ZeroDateTimeBehavior; + +import java.sql.DriverPropertyInfo; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Properties; + +import static io.vitess.util.Constants.DEFAULT_EXECUTE_TYPE; +import static io.vitess.util.Constants.DEFAULT_INCLUDED_FIELDS; +import static io.vitess.util.Constants.DEFAULT_TABLET_TYPE; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; public class ConnectionPropertiesTest { @@ -42,8 +54,8 @@ public void testReflection() throws Exception { // Just testing that we are properly picking up all the fields defined in the properties // For each field we call initializeFrom, which should call getProperty and remove - Mockito.verify(info, Mockito.times(NUM_PROPS)).getProperty(Mockito.anyString()); - Mockito.verify(info, Mockito.times(NUM_PROPS)).remove(Mockito.anyString()); + verify(info, times(NUM_PROPS)).getProperty(anyString()); + verify(info, times(NUM_PROPS)).remove(anyString()); } @Test @@ -52,24 +64,24 @@ public void testDefaults() throws SQLException { ConnectionProperties props = new ConnectionProperties(); props.initializeProperties(new Properties()); - Assert.assertEquals("blobsAreStrings", false, props.getBlobsAreStrings()); - Assert.assertEquals("functionsNeverReturnBlobs", false, props.getFunctionsNeverReturnBlobs()); - Assert.assertEquals("tinyInt1isBit", true, props.getTinyInt1isBit()); - Assert.assertEquals("yearIsDateType", true, props.getYearIsDateType()); - Assert.assertEquals("useBlobToStoreUTF8OutsideBMP", false, props.getUseBlobToStoreUTF8OutsideBMP()); - Assert.assertEquals("utf8OutsideBmpIncludedColumnNamePattern", null, props.getUtf8OutsideBmpIncludedColumnNamePattern()); - Assert.assertEquals("utf8OutsideBmpExcludedColumnNamePattern", null, props.getUtf8OutsideBmpExcludedColumnNamePattern()); - Assert.assertEquals("zeroDateTimeBehavior", Constants.ZeroDateTimeBehavior.GARBLE, props.getZeroDateTimeBehavior()); - Assert.assertEquals("characterEncoding", null, props.getEncoding()); - Assert.assertEquals("executeType", Constants.DEFAULT_EXECUTE_TYPE, props.getExecuteType()); - Assert.assertEquals("twopcEnabled", false, props.getTwopcEnabled()); - Assert.assertEquals("includedFields", Constants.DEFAULT_INCLUDED_FIELDS, props.getIncludedFields()); - Assert.assertEquals("includedFieldsCache", true, props.isIncludeAllFields()); - Assert.assertEquals("tabletType", Constants.DEFAULT_TABLET_TYPE, props.getTabletType()); - Assert.assertEquals("useSSL", false, props.getUseSSL()); - Assert.assertEquals("useAffectedRows", true, props.getUseAffectedRows()); - Assert.assertEquals("refreshConnection", false, props.getRefreshConnection()); - Assert.assertEquals("refreshSeconds", 60, props.getRefreshSeconds()); + assertEquals("blobsAreStrings", false, props.getBlobsAreStrings()); + assertEquals("functionsNeverReturnBlobs", false, props.getFunctionsNeverReturnBlobs()); + assertEquals("tinyInt1isBit", true, props.getTinyInt1isBit()); + assertEquals("yearIsDateType", true, props.getYearIsDateType()); + assertEquals("useBlobToStoreUTF8OutsideBMP", false, props.getUseBlobToStoreUTF8OutsideBMP()); + assertEquals("utf8OutsideBmpIncludedColumnNamePattern", null, props.getUtf8OutsideBmpIncludedColumnNamePattern()); + assertEquals("utf8OutsideBmpExcludedColumnNamePattern", null, props.getUtf8OutsideBmpExcludedColumnNamePattern()); + assertEquals("zeroDateTimeBehavior", ZeroDateTimeBehavior.GARBLE, props.getZeroDateTimeBehavior()); + assertEquals("characterEncoding", null, props.getEncoding()); + assertEquals("executeType", DEFAULT_EXECUTE_TYPE, props.getExecuteType()); + assertEquals("twopcEnabled", false, props.getTwopcEnabled()); + assertEquals("includedFields", DEFAULT_INCLUDED_FIELDS, props.getIncludedFields()); + assertEquals("includedFieldsCache", true, props.isIncludeAllFields()); + assertEquals("tabletType", DEFAULT_TABLET_TYPE, props.getTabletType()); + assertEquals("useSSL", false, props.getUseSSL()); + assertEquals("useAffectedRows", true, props.getUseAffectedRows()); + assertEquals("refreshConnection", false, props.getRefreshConnection()); + assertEquals("refreshSeconds", 60, props.getRefreshSeconds()); } @Test @@ -93,24 +105,24 @@ public void testInitializeFromProperties() throws SQLException { props.initializeProperties(info); - Assert.assertEquals("blobsAreStrings", true, props.getBlobsAreStrings()); - Assert.assertEquals("functionsNeverReturnBlobs", true, props.getFunctionsNeverReturnBlobs()); - Assert.assertEquals("tinyInt1isBit", true, props.getTinyInt1isBit()); - Assert.assertEquals("yearIsDateType", true, props.getYearIsDateType()); - Assert.assertEquals("useBlobToStoreUTF8OutsideBMP", true, props.getUseBlobToStoreUTF8OutsideBMP()); - Assert.assertEquals("utf8OutsideBmpIncludedColumnNamePattern", "(foo|bar)?baz", props.getUtf8OutsideBmpIncludedColumnNamePattern()); - Assert.assertEquals("utf8OutsideBmpExcludedColumnNamePattern", "(foo|bar)?baz", props.getUtf8OutsideBmpExcludedColumnNamePattern()); - Assert.assertEquals("zeroDateTimeBehavior", Constants.ZeroDateTimeBehavior.CONVERTTONULL, props.getZeroDateTimeBehavior()); - Assert.assertEquals("characterEncoding", "utf-8", props.getEncoding()); - Assert.assertEquals("executeType", Constants.QueryExecuteType.STREAM, props.getExecuteType()); - Assert.assertEquals("twopcEnabled", true, props.getTwopcEnabled()); - Assert.assertEquals("includedFields", Query.ExecuteOptions.IncludedFields.TYPE_ONLY, props.getIncludedFields()); - Assert.assertEquals("includedFieldsCache", false, props.isIncludeAllFields()); - Assert.assertEquals("tabletType", Topodata.TabletType.BACKUP, props.getTabletType()); + assertEquals("blobsAreStrings", true, props.getBlobsAreStrings()); + assertEquals("functionsNeverReturnBlobs", true, props.getFunctionsNeverReturnBlobs()); + assertEquals("tinyInt1isBit", true, props.getTinyInt1isBit()); + assertEquals("yearIsDateType", true, props.getYearIsDateType()); + assertEquals("useBlobToStoreUTF8OutsideBMP", true, props.getUseBlobToStoreUTF8OutsideBMP()); + assertEquals("utf8OutsideBmpIncludedColumnNamePattern", "(foo|bar)?baz", props.getUtf8OutsideBmpIncludedColumnNamePattern()); + assertEquals("utf8OutsideBmpExcludedColumnNamePattern", "(foo|bar)?baz", props.getUtf8OutsideBmpExcludedColumnNamePattern()); + assertEquals("zeroDateTimeBehavior", ZeroDateTimeBehavior.CONVERTTONULL, props.getZeroDateTimeBehavior()); + assertEquals("characterEncoding", "utf-8", props.getEncoding()); + assertEquals("executeType", Constants.QueryExecuteType.STREAM, props.getExecuteType()); + assertEquals("twopcEnabled", true, props.getTwopcEnabled()); + assertEquals("includedFields", Query.ExecuteOptions.IncludedFields.TYPE_ONLY, props.getIncludedFields()); + assertEquals("includedFieldsCache", false, props.isIncludeAllFields()); + assertEquals("tabletType", Topodata.TabletType.BACKUP, props.getTabletType()); } - @Test(expected = SQLException.class) - public void testEncodingValidation() throws SQLException { + @Test + public void testEncodingValidation() { ConnectionProperties props = new ConnectionProperties(); Properties info = new Properties(); @@ -118,10 +130,9 @@ public void testEncodingValidation() throws SQLException { info.setProperty("characterEncoding", fakeEncoding); try { props.initializeProperties(info); - Assert.fail("should have failed to parse encoding " + fakeEncoding); + fail("should have failed to parse encoding " + fakeEncoding); } catch (SQLException e) { - Assert.assertEquals("Unsupported character encoding: " + fakeEncoding, e.getMessage()); - throw e; + assertEquals("Unsupported character encoding: " + fakeEncoding, e.getMessage()); } } @@ -129,14 +140,14 @@ public void testEncodingValidation() throws SQLException { public void testDriverPropertiesOutput() throws SQLException { Properties info = new Properties(); DriverPropertyInfo[] infos = ConnectionProperties.exposeAsDriverPropertyInfo(info, 0); - Assert.assertEquals(NUM_PROPS, infos.length); + assertEquals(NUM_PROPS, infos.length); // Test the expected fields for just 1 int indexForFullTest = 3; - Assert.assertEquals("executeType", infos[indexForFullTest].name); - Assert.assertEquals("Query execution type: simple or stream", + assertEquals("executeType", infos[indexForFullTest].name); + assertEquals("Query execution type: simple or stream", infos[indexForFullTest].description); - Assert.assertEquals(false, infos[indexForFullTest].required); + assertEquals(false, infos[indexForFullTest].required); Constants.QueryExecuteType[] enumConstants = Constants.QueryExecuteType.values(); String[] allowed = new String[enumConstants.length]; for (int i = 0; i < enumConstants.length; i++) { @@ -145,17 +156,17 @@ public void testDriverPropertiesOutput() throws SQLException { Assert.assertArrayEquals(allowed, infos[indexForFullTest].choices); // Test that name exists for the others, as a sanity check - Assert.assertEquals("dbName", infos[1].name); - Assert.assertEquals("characterEncoding", infos[2].name); - Assert.assertEquals("executeType", infos[3].name); - Assert.assertEquals("functionsNeverReturnBlobs", infos[4].name); - Assert.assertEquals("grpcRetriesEnabled", infos[5].name); - Assert.assertEquals("grpcRetriesBackoffMultiplier", infos[6].name); - Assert.assertEquals("grpcRetriesInitialBackoffMillis", infos[7].name); - Assert.assertEquals("grpcRetriesMaxBackoffMillis", infos[8].name); - Assert.assertEquals(Constants.Property.INCLUDED_FIELDS, infos[9].name); - Assert.assertEquals(Constants.Property.TABLET_TYPE, infos[21].name); - Assert.assertEquals(Constants.Property.TWOPC_ENABLED, infos[29].name); + assertEquals("dbName", infos[1].name); + assertEquals("characterEncoding", infos[2].name); + assertEquals("executeType", infos[3].name); + assertEquals("functionsNeverReturnBlobs", infos[4].name); + assertEquals("grpcRetriesEnabled", infos[5].name); + assertEquals("grpcRetriesBackoffMultiplier", infos[6].name); + assertEquals("grpcRetriesInitialBackoffMillis", infos[7].name); + assertEquals("grpcRetriesMaxBackoffMillis", infos[8].name); + assertEquals(Constants.Property.INCLUDED_FIELDS, infos[9].name); + assertEquals(Constants.Property.TABLET_TYPE, infos[21].name); + assertEquals(Constants.Property.TWOPC_ENABLED, infos[29].name); } @Test @@ -172,14 +183,12 @@ public void testValidBooleanValues() throws SQLException { info.setProperty(Constants.Property.TWOPC_ENABLED, "false-ish"); try { props.initializeProperties(info); - Assert.fail("should have thrown an exception on bad value false-ish"); + fail("should have thrown an exception on bad value false-ish"); } catch (IllegalArgumentException e) { - Assert.assertEquals( - "Property '" + Constants.Property.TWOPC_ENABLED + "' Value 'false-ish' not in the list of allowable values: " - + Arrays.toString(new String[] { Boolean.toString(true), Boolean.toString(false), "yes", "no"}) - , e.getMessage()); + String expected = String.format("Property '%s' Value 'false-ish' not in the list of allowable values: [true, false, yes, no]", + Constants.Property.TWOPC_ENABLED); + assertEquals(expected , e.getMessage()); } - } @Test @@ -190,9 +199,9 @@ public void testValidEnumValues() throws SQLException { info.setProperty("executeType", "foo"); try { props.initializeProperties(info); - Assert.fail("should have thrown an exception on bad value foo"); + fail("should have thrown an exception on bad value foo"); } catch (IllegalArgumentException e) { - Assert.assertEquals( + assertEquals( "Property 'executeType' Value 'foo' not in the list of allowable values: " + Arrays.toString(Constants.QueryExecuteType.values()) , e.getMessage()); @@ -205,16 +214,16 @@ public void testSettersUpdateCaches() throws SQLException { props.initializeProperties(new Properties()); // included fields and all boolean cache - Assert.assertEquals(Constants.DEFAULT_INCLUDED_FIELDS, props.getIncludedFields()); - Assert.assertEquals(true, props.isIncludeAllFields()); + assertEquals(DEFAULT_INCLUDED_FIELDS, props.getIncludedFields()); + assertTrue(props.isIncludeAllFields()); // execute type and simple boolean cache - Assert.assertEquals(Constants.DEFAULT_EXECUTE_TYPE, props.getExecuteType()); - Assert.assertEquals(Constants.DEFAULT_EXECUTE_TYPE == Constants.QueryExecuteType.SIMPLE, props.isSimpleExecute()); + assertEquals(DEFAULT_EXECUTE_TYPE, props.getExecuteType()); + assertEquals(DEFAULT_EXECUTE_TYPE == Constants.QueryExecuteType.SIMPLE, props.isSimpleExecute()); // tablet type and twopc - Assert.assertEquals(Constants.DEFAULT_TABLET_TYPE, props.getTabletType()); - Assert.assertEquals(false, props.getTwopcEnabled()); + assertEquals(DEFAULT_TABLET_TYPE, props.getTabletType()); + assertFalse(props.getTwopcEnabled()); props.setIncludedFields(Query.ExecuteOptions.IncludedFields.TYPE_AND_NAME); props.setExecuteType(Constants.QueryExecuteType.STREAM); @@ -222,16 +231,16 @@ public void testSettersUpdateCaches() throws SQLException { props.setTwopcEnabled(true); // included fields and all boolean cache - Assert.assertEquals(Query.ExecuteOptions.IncludedFields.TYPE_AND_NAME, props.getIncludedFields()); - Assert.assertEquals(false, props.isIncludeAllFields()); + assertEquals(Query.ExecuteOptions.IncludedFields.TYPE_AND_NAME, props.getIncludedFields()); + assertFalse(props.isIncludeAllFields()); // execute type and simple boolean cache - Assert.assertEquals(Constants.QueryExecuteType.STREAM, props.getExecuteType()); - Assert.assertEquals(Constants.DEFAULT_EXECUTE_TYPE != Constants.QueryExecuteType.SIMPLE, props.isSimpleExecute()); + assertEquals(Constants.QueryExecuteType.STREAM, props.getExecuteType()); + assertEquals(DEFAULT_EXECUTE_TYPE != Constants.QueryExecuteType.SIMPLE, props.isSimpleExecute()); // tablet type and twopc - Assert.assertEquals(Topodata.TabletType.BACKUP, props.getTabletType()); - Assert.assertEquals(true, props.getTwopcEnabled()); + assertEquals(Topodata.TabletType.BACKUP, props.getTabletType()); + assertTrue(props.getTwopcEnabled()); } @Test @@ -242,33 +251,33 @@ public void testTarget() throws SQLException { Properties info = new Properties(); info.setProperty(Constants.Property.KEYSPACE, "test_keyspace"); props.initializeProperties(info); - Assert.assertEquals("target", "test_keyspace@master", props.getTarget()); + assertEquals("target", "test_keyspace@master", props.getTarget()); // Setting keyspace and shard info = new Properties(); info.setProperty(Constants.Property.KEYSPACE, "test_keyspace"); info.setProperty(Constants.Property.SHARD, "80-c0"); props.initializeProperties(info); - Assert.assertEquals("target", "test_keyspace:80-c0@master", props.getTarget()); + assertEquals("target", "test_keyspace:80-c0@master", props.getTarget()); // Setting tablet type info = new Properties(); info.setProperty(Constants.Property.TABLET_TYPE, "replica"); props.initializeProperties(info); - Assert.assertEquals("target", "@replica", props.getTarget()); + assertEquals("target", "@replica", props.getTarget()); // Setting shard which will have no impact without keyspace info = new Properties(); info.setProperty(Constants.Property.SHARD, "80-c0"); props.initializeProperties(info); - Assert.assertEquals("target", "@master", props.getTarget()); + assertEquals("target", "@master", props.getTarget()); // Setting shard and tablet type. Shard will have no impact. info = new Properties(); info.setProperty(Constants.Property.SHARD, "80-c0"); info.setProperty(Constants.Property.TABLET_TYPE, "replica"); props.initializeProperties(info); - Assert.assertEquals("target", "@replica", props.getTarget()); + assertEquals("target", "@replica", props.getTarget()); // Setting keyspace, shard and tablet type. info = new Properties(); @@ -276,7 +285,7 @@ public void testTarget() throws SQLException { info.setProperty(Constants.Property.SHARD, "80-c0"); info.setProperty(Constants.Property.TABLET_TYPE, "rdonly"); props.initializeProperties(info); - Assert.assertEquals("target", "test_keyspace:80-c0@rdonly", props.getTarget()); + assertEquals("target", "test_keyspace:80-c0@rdonly", props.getTarget()); // Setting keyspace, shard, tablet type and target. Target supersede others. info = new Properties(); @@ -285,6 +294,6 @@ public void testTarget() throws SQLException { info.setProperty(Constants.Property.TABLET_TYPE, "rdonly"); info.setProperty(Constants.Property.TARGET, "dummy"); props.initializeProperties(info); - Assert.assertEquals("target", "dummy", props.getTarget()); + assertEquals("target", "dummy", props.getTarget()); } } diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/FieldWithMetadataTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/FieldWithMetadataTest.java index c81ce9d5d7f..863399930a3 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/FieldWithMetadataTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/FieldWithMetadataTest.java @@ -16,8 +16,6 @@ package io.vitess.jdbc; -import java.sql.SQLException; -import java.sql.Types; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -31,6 +29,9 @@ import io.vitess.util.MysqlDefs; import io.vitess.util.charset.CharsetMapping; +import java.sql.SQLException; +import java.sql.Types; + @PrepareForTest(FieldWithMetadata.class) @RunWith(PowerMockRunner.class) public class FieldWithMetadataTest extends BaseTest { diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessConnectionTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessConnectionTest.java index 3121857c70f..64df7a23a58 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessConnectionTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessConnectionTest.java @@ -1,12 +1,12 @@ /* * Copyright 2017 Google Inc. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,14 +16,6 @@ package io.vitess.jdbc; -import java.lang.reflect.Field; -import java.sql.Connection; -import java.sql.PreparedStatement; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.Properties; - -import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mockito; @@ -32,224 +24,247 @@ import io.vitess.client.VTSession; import io.vitess.proto.Query; +import io.vitess.proto.Query.ExecuteOptions.TransactionIsolation; import io.vitess.proto.Topodata; import io.vitess.util.Constants; +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + /** * Created by harshit.gangal on 19/01/16. */ @RunWith(PowerMockRunner.class) public class VitessConnectionTest extends BaseTest { - @Test public void testVitessConnection() throws SQLException { + @Test + public void testVitessConnection() throws SQLException { VitessConnection vitessConnection = new VitessConnection(dbURL, new Properties()); - Assert.assertEquals(false, vitessConnection.isClosed()); - Assert.assertNull(vitessConnection.getDbProperties()); + assertFalse(vitessConnection.isClosed()); + assertNull(vitessConnection.getDbProperties()); } - @Test public void testCreateStatement() throws SQLException { + @Test + public void testCreateStatement() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); Statement statement = vitessConnection.createStatement(); - Assert.assertEquals(vitessConnection, statement.getConnection()); + assertEquals(vitessConnection, statement.getConnection()); } - - @Test(expected = SQLException.class) public void testCreateStatementForClose() - throws SQLException { + @Test + public void testCreateStatementForClose() + throws SQLException { VitessConnection vitessConnection = getVitessConnection(); - vitessConnection.close(); - try { - Statement statement = vitessConnection.createStatement(); - } catch (SQLException e) { - throw new SQLException(Constants.SQLExceptionMessages.CONN_CLOSED); - } + assertFailsOnClosedConnection(vitessConnection, vitessConnection::createStatement); } - @Test public void testnativeSQL() throws SQLException { + @Test + public void testnativeSQL() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); - Assert.assertEquals("query", vitessConnection.nativeSQL("query")); + assertEquals("query", vitessConnection.nativeSQL("query")); } - @Test public void testCreatePreperedStatement() throws SQLException { + @Test + public void testCreatePreperedStatement() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); PreparedStatement preparedStatementstatement = vitessConnection.prepareStatement("query"); - Assert.assertEquals(vitessConnection, preparedStatementstatement.getConnection()); + assertEquals(vitessConnection, preparedStatementstatement.getConnection()); } - - @Test(expected = SQLException.class) public void testCreatePreperedStatementForClose() - throws SQLException { + @Test + public void testCreatePreparedStatementForClose() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); - vitessConnection.close(); - try { - PreparedStatement preparedStatementstatement = - vitessConnection.prepareStatement("query"); - } catch (SQLException e) { - throw new SQLException(Constants.SQLExceptionMessages.CONN_CLOSED); - } + assertFailsOnClosedConnection(vitessConnection, () -> vitessConnection.prepareStatement("query")); } - @Test public void testDefaultGetAutoCommit() throws SQLException { + @Test + public void testDefaultGetAutoCommit() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); - Assert.assertEquals(true, vitessConnection.getAutoCommit()); + assertTrue(vitessConnection.getAutoCommit()); } - @Test(expected = SQLException.class) public void testDefaultGetAutoCommitForClose() - throws SQLException { + @Test + public void testDefaultGetAutoCommitForClose() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); - vitessConnection.close(); - try { - boolean autoCommit = vitessConnection.getAutoCommit(); - } catch (SQLException e) { - throw new SQLException(Constants.SQLExceptionMessages.CONN_CLOSED); - } + assertFailsOnClosedConnection(vitessConnection, vitessConnection::getAutoCommit); } - @Test public void testDefaultSetAutoCommit() throws SQLException { + @Test + public void testDefaultSetAutoCommit() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); vitessConnection.setAutoCommit(false); - Assert.assertEquals(false, vitessConnection.getAutoCommit()); + assertFalse(vitessConnection.getAutoCommit()); } - @Test(expected = SQLException.class) public void testDefaultSetAutoCommitForClose() - throws SQLException { + @Test + public void testDefaultSetAutoCommitForClose() + throws SQLException { VitessConnection vitessConnection = getVitessConnection(); - vitessConnection.close(); - try { - boolean autoCommit = vitessConnection.getAutoCommit(); - } catch (SQLException e) { - throw new SQLException(Constants.SQLExceptionMessages.CONN_CLOSED); - } + assertFailsOnClosedConnection(vitessConnection, () -> vitessConnection.setAutoCommit(false)); } - @Test public void testCommit() throws SQLException { + @Test + public void testCommit() throws Exception { VTSession mockSession = PowerMockito.mock(VTSession.class); VitessConnection vitessConnection = getVitessConnection(); - try { - Field privateVTSessionField = VitessConnection.class.getDeclaredField("vtSession"); - privateVTSessionField.setAccessible(true); - privateVTSessionField.set(vitessConnection, mockSession); - PowerMockito.when(mockSession.isInTransaction()).thenReturn(false); - PowerMockito.when(mockSession.isAutoCommit()).thenReturn(false); - } catch (NoSuchFieldException | IllegalAccessException e) { - Assert.fail(e.getMessage()); - } + Field privateVTSessionField = VitessConnection.class.getDeclaredField("vtSession"); + privateVTSessionField.setAccessible(true); + privateVTSessionField.set(vitessConnection, mockSession); + PowerMockito.when(mockSession.isInTransaction()).thenReturn(false); + PowerMockito.when(mockSession.isAutoCommit()).thenReturn(false); vitessConnection.commit(); } - @Test(expected = SQLException.class) public void testCommitForException() throws SQLException { + @Test(expected = SQLException.class) + public void testCommitForException() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); vitessConnection.setAutoCommit(true); vitessConnection.commit(); } - @Test public void testRollback() throws SQLException { + @Test + public void testRollback() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); vitessConnection.setAutoCommit(false); vitessConnection.rollback(); } - @Test(expected = SQLException.class) public void testRollbackForException() - throws SQLException { + @Test(expected = SQLException.class) + public void testRollbackForException() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); vitessConnection.setAutoCommit(true); vitessConnection.rollback(); } - @Test public void testClosed() throws SQLException { + @Test + public void testClosed() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); vitessConnection.setAutoCommit(false); vitessConnection.close(); - Assert.assertEquals(true, vitessConnection.isClosed()); + assertTrue(vitessConnection.isClosed()); } - @Test(expected = SQLException.class) public void testClosedForException() throws SQLException { + @Test(expected = SQLException.class) + public void testClosedForException() throws Exception { VTSession mockSession = PowerMockito.mock(VTSession.class); VitessConnection vitessConnection = getVitessConnection(); - try { - Field privateVTSessionField = VitessConnection.class.getDeclaredField("vtSession"); - privateVTSessionField.setAccessible(true); - privateVTSessionField.set(vitessConnection, mockSession); - //vtSession.setSession(mockSession.getSession()); - PowerMockito.when(mockSession.isInTransaction()).thenReturn(true); - PowerMockito.when(mockSession.isAutoCommit()).thenReturn(true); - } catch (NoSuchFieldException | IllegalAccessException e) { - Assert.fail(e.getMessage()); - } + Field privateVTSessionField = VitessConnection.class.getDeclaredField("vtSession"); + privateVTSessionField.setAccessible(true); + privateVTSessionField.set(vitessConnection, mockSession); + PowerMockito.when(mockSession.isInTransaction()).thenReturn(true); + PowerMockito.when(mockSession.isAutoCommit()).thenReturn(true); vitessConnection.close(); } - @Test public void testGetCatalog() throws SQLException { + @Test + public void testGetCatalog() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); - Assert.assertEquals("keyspace", vitessConnection.getCatalog()); + assertEquals("keyspace", vitessConnection.getCatalog()); } - @Test public void testSetCatalog() throws SQLException { + @Test + public void testSetCatalog() throws SQLException { VitessConnection vitessConnection = getVitessConnection(); vitessConnection.setCatalog("myDB"); - Assert.assertEquals("myDB", vitessConnection.getCatalog()); + assertEquals("myDB", vitessConnection.getCatalog()); } - @Test public void testPropertiesFromJdbcUrl() throws SQLException { + @Test + public void testPropertiesFromJdbcUrl() throws SQLException { String url = "jdbc:vitess://locahost:9000/vt_keyspace/keyspace?TABLET_TYPE=replica&includedFields=type_and_name&blobsAreStrings=yes"; VitessConnection conn = new VitessConnection(url, new Properties()); // Properties from the url should be passed into the connection properties, and override whatever defaults we've defined - Assert.assertEquals(Query.ExecuteOptions.IncludedFields.TYPE_AND_NAME, conn.getIncludedFields()); - Assert.assertEquals(false, conn.isIncludeAllFields()); - Assert.assertEquals(Topodata.TabletType.REPLICA, conn.getTabletType()); - Assert.assertEquals(true, conn.getBlobsAreStrings()); + assertEquals(Query.ExecuteOptions.IncludedFields.TYPE_AND_NAME, conn.getIncludedFields()); + assertFalse(conn.isIncludeAllFields()); + assertEquals(Topodata.TabletType.REPLICA, conn.getTabletType()); + assertTrue(conn.getBlobsAreStrings()); } - @Test public void testClientFoundRows() throws SQLException { + @Test + public void testClientFoundRows() throws SQLException { String url = "jdbc:vitess://locahost:9000/vt_keyspace/keyspace?TABLET_TYPE=replica&useAffectedRows=true"; VitessConnection conn = new VitessConnection(url, new Properties()); - Assert.assertEquals(true, conn.getUseAffectedRows()); - Assert.assertEquals(false, conn.getVtSession().getSession().getOptions().getClientFoundRows()); + assertTrue(conn.getUseAffectedRows()); + assertFalse(conn.getVtSession().getSession().getOptions().getClientFoundRows()); + } - url = "jdbc:vitess://locahost:9000/vt_keyspace/keyspace?TABLET_TYPE=replica&useAffectedRows=false"; - conn = new VitessConnection(url, new Properties()); + @Test + public void testClientFoundRows2() throws SQLException { + String url = "jdbc:vitess://locahost:9000/vt_keyspace/keyspace?TABLET_TYPE=replica&useAffectedRows=false"; + VitessConnection conn = new VitessConnection(url, new Properties()); - Assert.assertEquals(false, conn.getUseAffectedRows()); - Assert.assertEquals(true, conn.getVtSession().getSession().getOptions().getClientFoundRows()); + assertFalse(conn.getUseAffectedRows()); + assertTrue(conn.getVtSession().getSession().getOptions().getClientFoundRows()); } - @Test public void testWorkload() throws SQLException { + @Test + public void testWorkload() throws SQLException { for (Query.ExecuteOptions.Workload workload : Query.ExecuteOptions.Workload.values()) { if (workload == Query.ExecuteOptions.Workload.UNRECOGNIZED) { continue; } String url = "jdbc:vitess://locahost:9000/vt_keyspace/keyspace?TABLET_TYPE=replica&workload=" + workload.toString().toLowerCase(); VitessConnection conn = new VitessConnection(url, new Properties()); - - Assert.assertEquals(workload, conn.getWorkload()); - Assert.assertEquals(workload, conn.getVtSession().getSession().getOptions().getWorkload()); + + assertEquals(workload, conn.getWorkload()); + assertEquals(workload, conn.getVtSession().getSession().getOptions().getWorkload()); } } - @Test public void testTransactionIsolation() throws SQLException { + @Test + public void testTransactionIsolation() throws SQLException { VitessConnection conn = Mockito.spy(getVitessConnection()); - Mockito.doReturn(new DBProperties("random", "random", "random", Connection.TRANSACTION_REPEATABLE_READ, "random")) - .when(conn) - .getDbProperties(); - Mockito.doReturn(new VitessMySQLDatabaseMetadata(conn)).when(conn).getMetaData(); + doReturn(new DBProperties("random", "random", "random", Connection.TRANSACTION_REPEATABLE_READ, "random")) + .when(conn) + .getDbProperties(); + doReturn(new VitessMySQLDatabaseMetadata(conn)).when(conn).getMetaData(); - Assert.assertEquals(Query.ExecuteOptions.TransactionIsolation.DEFAULT, conn.getVtSession().getSession().getOptions().getTransactionIsolation()); - Assert.assertEquals(Connection.TRANSACTION_REPEATABLE_READ, conn.getTransactionIsolation()); + assertEquals(TransactionIsolation.DEFAULT, conn.getVtSession().getSession().getOptions().getTransactionIsolation()); + assertEquals(Connection.TRANSACTION_REPEATABLE_READ, conn.getTransactionIsolation()); conn.setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED); - Assert.assertEquals(Query.ExecuteOptions.TransactionIsolation.READ_COMMITTED, conn.getVtSession().getSession().getOptions().getTransactionIsolation()); - Assert.assertEquals(Connection.TRANSACTION_READ_COMMITTED, conn.getTransactionIsolation()); + assertEquals(TransactionIsolation.READ_COMMITTED, conn.getVtSession().getSession().getOptions().getTransactionIsolation()); + assertEquals(Connection.TRANSACTION_READ_COMMITTED, conn.getTransactionIsolation()); - VitessStatement statement = Mockito.mock(VitessStatement.class); - Mockito.when(conn.createStatement()).thenReturn(statement); - Mockito.when(conn.isInTransaction()).thenReturn(true); + VitessStatement statement = mock(VitessStatement.class); + when(conn.createStatement()).thenReturn(statement); + when(conn.isInTransaction()).thenReturn(true); conn.setTransactionIsolation(Connection.TRANSACTION_READ_UNCOMMITTED); - Mockito.verify(statement).executeUpdate("rollback"); - Assert.assertEquals(Query.ExecuteOptions.TransactionIsolation.READ_UNCOMMITTED, conn.getVtSession().getSession().getOptions().getTransactionIsolation()); - Assert.assertEquals(Connection.TRANSACTION_READ_UNCOMMITTED, conn.getTransactionIsolation()); + verify(statement).executeUpdate("rollback"); + assertEquals(TransactionIsolation.READ_UNCOMMITTED, conn.getVtSession().getSession().getOptions().getTransactionIsolation()); + assertEquals(Connection.TRANSACTION_READ_UNCOMMITTED, conn.getTransactionIsolation()); + } + + interface Runthis { + void run() throws SQLException; } + + private void assertFailsOnClosedConnection(VitessConnection connection, Runthis failingRunnable) throws SQLException { + connection.close(); + try { + failingRunnable.run(); + fail("expected this to fail on a closed connection"); + } catch (SQLException e) { + assertEquals(e.getMessage(), Constants.SQLExceptionMessages.CONN_CLOSED); + } + } + } diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessDatabaseMetadataTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessDatabaseMetadataTest.java index 07d7fd1c2b8..d930e08d3a7 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessDatabaseMetadataTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessDatabaseMetadataTest.java @@ -19,6 +19,18 @@ import com.google.common.base.Charsets; import com.google.common.io.CharStreams; import com.google.protobuf.ByteString; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import io.vitess.client.cursor.Cursor; +import io.vitess.client.cursor.SimpleCursor; +import io.vitess.proto.Query; +import io.vitess.util.Constants; + import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -31,17 +43,6 @@ import java.util.List; import java.util.Properties; import java.util.Scanner; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.api.mockito.PowerMockito; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; - -import io.vitess.client.cursor.Cursor; -import io.vitess.client.cursor.SimpleCursor; -import io.vitess.proto.Query; -import io.vitess.util.Constants; /** * Created by ashudeep.sharma on 08/03/16. diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessDriverTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessDriverTest.java index ba6f756545e..326403f8298 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessDriverTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessDriverTest.java @@ -1,12 +1,12 @@ /* * Copyright 2017 Google Inc. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,15 +16,18 @@ package io.vitess.jdbc; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.vitess.util.Constants; + import java.sql.DriverManager; import java.sql.DriverPropertyInfo; import java.sql.SQLException; import java.util.Properties; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import io.vitess.util.Constants; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * Created by harshit.gangal on 19/01/16. @@ -34,84 +37,80 @@ public class VitessDriverTest { private static VitessDriver driver = new VitessDriver(); String dbURL = - "jdbc:vitess://localhost:9000/shipment/vt_shipment?tabletType=master&executeType=stream&userName" - + "=user"; + "jdbc:vitess://localhost:9000/shipment/vt_shipment?tabletType=master&executeType=stream&userName" + + "=user"; - @BeforeClass public static void setUp() { + @BeforeClass + public static void setUp() { // load Vitess driver try { Class.forName("io.vitess.jdbc.VitessDriver"); } catch (ClassNotFoundException e) { - Assert.fail("Driver is not in the CLASSPATH -> " + e); + fail("Driver is not in the CLASSPATH -> " + e); } } - @Test public void testConnect() { - try { - VitessConnection connection = + @Test + public void testConnect() throws SQLException { + VitessConnection connection = (VitessConnection) DriverManager.getConnection(dbURL, new Properties()); - Assert.assertEquals(connection.getUrl().getUrl(), dbURL); - } catch (SQLException e) { - Assert.fail("SQLException Not Expected"); - } + assertEquals(connection.getUrl().getUrl(), dbURL); } - @Test public void testAcceptsURL() { - try { - Assert.assertEquals(true, driver.acceptsURL(dbURL)); - } catch (SQLException e) { - Assert.fail("SQLException Not Expected"); - } + @Test + public void testAcceptsURL() { + assertEquals(true, driver.acceptsURL(dbURL)); } - @Test public void testAcceptsMalformedURL() { - try { - String url = + @Test + public void testAcceptsMalformedURL() { + String url = "jdbc:MalfromdedUrl://localhost:9000/shipment/vt_shipment?tabletType=master"; - Assert.assertEquals(false, driver.acceptsURL(url)); - } catch (SQLException e) { - Assert.fail("SQLException Not Expected"); - } + assertEquals(false, driver.acceptsURL(url)); } - @Test public void testGetPropertyInfo() throws SQLException { + @Test + public void testGetPropertyInfo() throws SQLException { // Used to ensure that we're properly adding the below URL-based properties at the beginning // of the full ConnectionProperties configuration DriverPropertyInfo[] underlying = ConnectionProperties.exposeAsDriverPropertyInfo(new Properties(), 0); int additionalProp = 2; DriverPropertyInfo[] driverPropertyInfos = driver.getPropertyInfo(dbURL, null); - Assert.assertEquals(underlying.length + additionalProp, driverPropertyInfos.length); + assertEquals(underlying.length + additionalProp, driverPropertyInfos.length); - Assert.assertEquals(driverPropertyInfos[0].description, Constants.VITESS_HOST); - Assert.assertEquals(driverPropertyInfos[0].required, true); - Assert.assertEquals(driverPropertyInfos[0].name, Constants.Property.HOST); - Assert.assertEquals(driverPropertyInfos[0].value, "localhost"); + assertEquals(driverPropertyInfos[0].description, Constants.VITESS_HOST); + assertEquals(driverPropertyInfos[0].required, true); + assertEquals(driverPropertyInfos[0].name, Constants.Property.HOST); + assertEquals(driverPropertyInfos[0].value, "localhost"); - Assert.assertEquals(driverPropertyInfos[1].description, Constants.VITESS_PORT); - Assert.assertEquals(driverPropertyInfos[1].required, false); - Assert.assertEquals(driverPropertyInfos[1].name, Constants.Property.PORT); - Assert.assertEquals(driverPropertyInfos[1].value, "9000"); + assertEquals(driverPropertyInfos[1].description, Constants.VITESS_PORT); + assertEquals(driverPropertyInfos[1].required, false); + assertEquals(driverPropertyInfos[1].name, Constants.Property.PORT); + assertEquals(driverPropertyInfos[1].value, "9000"); // Validate the remainder of the driver properties match up with the underlying for (int i = additionalProp; i < driverPropertyInfos.length; i++) { - Assert.assertEquals(underlying[i - additionalProp].description, driverPropertyInfos[i].description); - Assert.assertEquals(underlying[i - additionalProp].required, driverPropertyInfos[i].required); - Assert.assertEquals(underlying[i - additionalProp].name, driverPropertyInfos[i].name); - Assert.assertEquals(underlying[i - additionalProp].value, driverPropertyInfos[i].value); + assertEquals(underlying[i - additionalProp].description, driverPropertyInfos[i].description); + assertEquals(underlying[i - additionalProp].required, driverPropertyInfos[i].required); + assertEquals(underlying[i - additionalProp].name, driverPropertyInfos[i].name); + assertEquals(underlying[i - additionalProp].value, driverPropertyInfos[i].value); } } - @Test public void testGetMajorVersion() { - Assert.assertEquals(driver.getMajorVersion(), Constants.DRIVER_MAJOR_VERSION); + @Test + public void testGetMajorVersion() { + assertEquals(driver.getMajorVersion(), Constants.DRIVER_MAJOR_VERSION); } - @Test public void testGetMinorVersion() { - Assert.assertEquals(driver.getMinorVersion(), Constants.DRIVER_MINOR_VERSION); + @Test + public void testGetMinorVersion() { + assertEquals(driver.getMinorVersion(), Constants.DRIVER_MINOR_VERSION); } - @Test public void testJdbcCompliant() { - Assert.assertEquals(false, driver.jdbcCompliant()); + @Test + public void testJdbcCompliant() { + assertEquals(false, driver.jdbcCompliant()); } } diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessJDBCUrlTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessJDBCUrlTest.java index d0d29078383..130d5bedf31 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessJDBCUrlTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessJDBCUrlTest.java @@ -16,14 +16,15 @@ package io.vitess.jdbc; -import java.sql.SQLException; -import java.util.Properties; import org.junit.Assert; import org.junit.Test; import io.vitess.proto.Topodata; import io.vitess.util.Constants; +import java.sql.SQLException; +import java.util.Properties; + /** * Created by naveen.nahata on 18/02/16. */ diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessParameterMetaDataTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessParameterMetaDataTest.java index 3c313d11814..57212aa40d5 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessParameterMetaDataTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessParameterMetaDataTest.java @@ -16,9 +16,6 @@ package io.vitess.jdbc; -import java.sql.ParameterMetaData; -import java.sql.SQLException; -import java.sql.Types; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -27,6 +24,10 @@ import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; +import java.sql.ParameterMetaData; +import java.sql.SQLException; +import java.sql.Types; + @RunWith(PowerMockRunner.class) @PrepareForTest(VitessParameterMetaData.class) public class VitessParameterMetaDataTest { diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessPreparedStatementTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessPreparedStatementTest.java index 31b7356f1f0..95512fd6f8d 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessPreparedStatementTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessPreparedStatementTest.java @@ -16,25 +16,7 @@ package io.vitess.jdbc; -import java.lang.reflect.Field; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.sql.BatchUpdateException; -import java.sql.Date; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.sql.Time; -import java.sql.Timestamp; -import java.sql.Types; -import java.util.ArrayList; -import java.util.Calendar; -import java.util.List; -import java.util.Map; -import java.util.TimeZone; - -import javax.sql.rowset.serial.SerialClob; - +import com.google.common.collect.ImmutableMap; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -44,8 +26,6 @@ import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; -import com.google.common.collect.ImmutableMap; - import io.vitess.client.Context; import io.vitess.client.SQLFuture; import io.vitess.client.VTGateConnection; @@ -57,6 +37,32 @@ import io.vitess.proto.Vtrpc; import io.vitess.util.Constants; +import javax.sql.rowset.serial.SerialClob; +import java.lang.reflect.Field; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.BatchUpdateException; +import java.sql.Date; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyMap; +import static org.mockito.Matchers.anyString; +import static org.powermock.api.mockito.PowerMockito.mock; +import static org.powermock.api.mockito.PowerMockito.when; + /** * Created by harshit.gangal on 09/02/16. @@ -69,51 +75,50 @@ private String sqlUpdate = "update test_table set msg = null"; private String sqlInsert = "insert into test_table(msg) values (?)"; - @Test public void testStatementExecute() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test public void testStatementExecute() { + VitessConnection mockConn = mock(VitessConnection.class); VitessPreparedStatement preparedStatement; try { preparedStatement = new VitessPreparedStatement(mockConn, sqlShow); preparedStatement.executeQuery(sqlSelect); - Assert.fail("Should have thrown exception for calling this method"); + fail("Should have thrown exception for calling this method"); } catch (SQLException ex) { - Assert.assertEquals("This method cannot be called using this class object", + assertEquals("This method cannot be called using this class object", ex.getMessage()); } try { preparedStatement = new VitessPreparedStatement(mockConn, sqlShow); preparedStatement.executeUpdate(sqlUpdate); - Assert.fail("Should have thrown exception for calling this method"); + fail("Should have thrown exception for calling this method"); } catch (SQLException ex) { - Assert.assertEquals("This method cannot be called using this class object", + assertEquals("This method cannot be called using this class object", ex.getMessage()); } try { preparedStatement = new VitessPreparedStatement(mockConn, sqlShow); preparedStatement.execute(sqlShow); - Assert.fail("Should have thrown exception for calling this method"); + fail("Should have thrown exception for calling this method"); } catch (SQLException ex) { - Assert.assertEquals("This method cannot be called using this class object", + assertEquals("This method cannot be called using this class object", ex.getMessage()); } } @Test public void testExecuteQuery() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockConn.getExecuteType()) + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn.execute(any(Context.class), anyString(), anyMap(),any(VTSession.class))). + thenReturn(mockSqlFutureCursor); + when(mockConn.getExecuteType()) .thenReturn(Constants.QueryExecuteType.SIMPLE); - PowerMockito.when(mockConn.isSimpleExecute()).thenReturn(true); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockConn.isSimpleExecute()).thenReturn(true); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); VitessPreparedStatement preparedStatement; try { @@ -121,64 +126,64 @@ //Empty Sql Statement try { new VitessPreparedStatement(mockConn, ""); - Assert.fail("Should have thrown exception for empty sql"); + fail("Should have thrown exception for empty sql"); } catch (SQLException ex) { - Assert.assertEquals("SQL statement is not valid", ex.getMessage()); + assertEquals("SQL statement is not valid", ex.getMessage()); } //show query preparedStatement = new VitessPreparedStatement(mockConn, sqlShow); ResultSet rs = preparedStatement.executeQuery(); - Assert.assertEquals(-1, preparedStatement.getUpdateCount()); + assertEquals(-1, preparedStatement.getUpdateCount()); //select on replica with bind variables preparedStatement = new VitessPreparedStatement(mockConn, sqlSelect, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); rs = preparedStatement.executeQuery(); - Assert.assertEquals(-1, preparedStatement.getUpdateCount()); + assertEquals(-1, preparedStatement.getUpdateCount()); //select on replica without bind variables preparedStatement = new VitessPreparedStatement(mockConn, sqlSelect, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); rs = preparedStatement.executeQuery(); - Assert.assertEquals(-1, preparedStatement.getUpdateCount()); + assertEquals(-1, preparedStatement.getUpdateCount()); //select on master rs = preparedStatement.executeQuery(); - Assert.assertEquals(-1, preparedStatement.getUpdateCount()); + assertEquals(-1, preparedStatement.getUpdateCount()); try { //when returned cursor is null - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(null); + when(mockSqlFutureCursor.checkedGet()).thenReturn(null); preparedStatement.executeQuery(); - Assert.fail("Should have thrown exception for cursor null"); + fail("Should have thrown exception for cursor null"); } catch (SQLException ex) { - Assert.assertEquals("Failed to execute this method", ex.getMessage()); + assertEquals("Failed to execute this method", ex.getMessage()); } } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + fail("Test failed " + e.getMessage()); } } @Test public void testExecuteQueryWithStream() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .streamExecute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockCursor); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockConn.getExecuteType()) + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .streamExecute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockCursor); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockConn.getExecuteType()) .thenReturn(Constants.QueryExecuteType.STREAM); VitessPreparedStatement preparedStatement; @@ -187,65 +192,65 @@ //Empty Sql Statement try { new VitessPreparedStatement(mockConn, ""); - Assert.fail("Should have thrown exception for empty sql"); + fail("Should have thrown exception for empty sql"); } catch (SQLException ex) { - Assert.assertEquals("SQL statement is not valid", ex.getMessage()); + assertEquals("SQL statement is not valid", ex.getMessage()); } //show query preparedStatement = new VitessPreparedStatement(mockConn, sqlShow); ResultSet rs = preparedStatement.executeQuery(); - Assert.assertEquals(-1, preparedStatement.getUpdateCount()); + assertEquals(-1, preparedStatement.getUpdateCount()); //select on replica with bind variables preparedStatement = new VitessPreparedStatement(mockConn, sqlSelect, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); rs = preparedStatement.executeQuery(); - Assert.assertEquals(-1, preparedStatement.getUpdateCount()); + assertEquals(-1, preparedStatement.getUpdateCount()); //select on replica without bind variables preparedStatement = new VitessPreparedStatement(mockConn, sqlSelect, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); rs = preparedStatement.executeQuery(); - Assert.assertEquals(-1, preparedStatement.getUpdateCount()); + assertEquals(-1, preparedStatement.getUpdateCount()); //select on master rs = preparedStatement.executeQuery(); - Assert.assertEquals(-1, preparedStatement.getUpdateCount()); + assertEquals(-1, preparedStatement.getUpdateCount()); try { //when returned cursor is null - PowerMockito.when(mockVtGateConn - .streamExecute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(null); + when(mockVtGateConn + .streamExecute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(null); preparedStatement.executeQuery(); - Assert.fail("Should have thrown exception for cursor null"); + fail("Should have thrown exception for cursor null"); } catch (SQLException ex) { - Assert.assertEquals("Failed to execute this method", ex.getMessage()); + assertEquals("Failed to execute this method", ex.getMessage()); } } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + fail("Test failed " + e.getMessage()); } } @Test public void testExecuteUpdate() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - List fieldList = PowerMockito.mock(ArrayList.class); - - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + List fieldList = mock(ArrayList.class); + + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); VitessPreparedStatement preparedStatement; try { @@ -255,76 +260,76 @@ new VitessPreparedStatement(mockConn, sqlUpdate, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); int updateCount = preparedStatement.executeUpdate(); - Assert.assertEquals(0, updateCount); + assertEquals(0, updateCount); //tx is null & autoCommit is true - PowerMockito.when(mockConn.getAutoCommit()).thenReturn(true); + when(mockConn.getAutoCommit()).thenReturn(true); preparedStatement = new VitessPreparedStatement(mockConn, sqlUpdate); updateCount = preparedStatement.executeUpdate(); - Assert.assertEquals(0, updateCount); + assertEquals(0, updateCount); //cursor fields is not null - PowerMockito.when(mockCursor.getFields()).thenReturn(fieldList); - PowerMockito.when(fieldList.isEmpty()).thenReturn(false); + when(mockCursor.getFields()).thenReturn(fieldList); + when(fieldList.isEmpty()).thenReturn(false); try { preparedStatement.executeUpdate(); - Assert.fail("Should have thrown exception for field not null"); + fail("Should have thrown exception for field not null"); } catch (SQLException ex) { - Assert.assertEquals("ResultSet generation is not allowed through this method", + assertEquals("ResultSet generation is not allowed through this method", ex.getMessage()); } //cursor is null - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(null); + when(mockSqlFutureCursor.checkedGet()).thenReturn(null); try { preparedStatement.executeUpdate(); - Assert.fail("Should have thrown exception for cursor null"); + fail("Should have thrown exception for cursor null"); } catch (SQLException ex) { - Assert.assertEquals("Failed to execute this method", ex.getMessage()); + assertEquals("Failed to execute this method", ex.getMessage()); } //read only - PowerMockito.when(mockConn.isReadOnly()).thenReturn(true); + when(mockConn.isReadOnly()).thenReturn(true); try { preparedStatement.executeUpdate(); - Assert.fail("Should have thrown exception for read only"); + fail("Should have thrown exception for read only"); } catch (SQLException ex) { - Assert.assertEquals(Constants.SQLExceptionMessages.READ_ONLY, ex.getMessage()); + assertEquals(Constants.SQLExceptionMessages.READ_ONLY, ex.getMessage()); } //read only - PowerMockito.when(mockConn.isReadOnly()).thenReturn(true); + when(mockConn.isReadOnly()).thenReturn(true); try { preparedStatement.executeBatch(); - Assert.fail("Should have thrown exception for read only"); + fail("Should have thrown exception for read only"); } catch (SQLException ex) { - Assert.assertEquals(Constants.SQLExceptionMessages.READ_ONLY, ex.getMessage()); + assertEquals(Constants.SQLExceptionMessages.READ_ONLY, ex.getMessage()); } } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + fail("Test failed " + e.getMessage()); } } @Test public void testExecute() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - List mockFieldList = PowerMockito.spy(new ArrayList()); + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + List mockFieldList = PowerMockito.spy(new ArrayList<>()); - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockConn.getExecuteType()) + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockConn.getExecuteType()) .thenReturn(Constants.QueryExecuteType.SIMPLE); - PowerMockito.when(mockConn.isSimpleExecute()).thenReturn(true); + when(mockConn.isSimpleExecute()).thenReturn(true); - PowerMockito.when(mockConn.getAutoCommit()).thenReturn(true); + when(mockConn.getAutoCommit()).thenReturn(true); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(mockFieldList); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(mockFieldList); VitessPreparedStatement preparedStatement = new VitessPreparedStatement(mockConn, sqlSelect, ResultSet.TYPE_FORWARD_ONLY, @@ -332,7 +337,7 @@ try { int fieldSize = 5; - PowerMockito.when(mockCursor.getFields()).thenReturn(mockFieldList); + when(mockCursor.getFields()).thenReturn(mockFieldList); PowerMockito.doReturn(fieldSize).when(mockFieldList).size(); PowerMockito.doReturn(false).when(mockFieldList).isEmpty(); boolean hasResultSet = preparedStatement.execute(); @@ -345,26 +350,26 @@ Assert.assertNotNull(preparedStatement.getResultSet()); int mockUpdateCount = 10; - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); - PowerMockito.when(mockCursor.getRowsAffected()).thenReturn((long) mockUpdateCount); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + when(mockCursor.getRowsAffected()).thenReturn((long) mockUpdateCount); preparedStatement = new VitessPreparedStatement(mockConn, sqlUpdate); hasResultSet = preparedStatement.execute(); Assert.assertFalse(hasResultSet); Assert.assertNull(preparedStatement.getResultSet()); - Assert.assertEquals(mockUpdateCount, preparedStatement.getUpdateCount()); + assertEquals(mockUpdateCount, preparedStatement.getUpdateCount()); //cursor is null - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(null); + when(mockSqlFutureCursor.checkedGet()).thenReturn(null); try { preparedStatement = new VitessPreparedStatement(mockConn, sqlShow); preparedStatement.execute(); - Assert.fail("Should have thrown exception for cursor null"); + fail("Should have thrown exception for cursor null"); } catch (SQLException ex) { - Assert.assertEquals("Failed to execute this method", ex.getMessage()); + assertEquals("Failed to execute this method", ex.getMessage()); } } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + fail("Test failed " + e.getMessage()); } } @@ -376,82 +381,82 @@ } private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRunExecute, boolean shouldRunStreamExecute) throws SQLException { - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - PowerMockito.when(mockConn.isSimpleExecute()).thenReturn(simpleExecute); - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + VitessConnection mockConn = mock(VitessConnection.class); + when(mockConn.isSimpleExecute()).thenReturn(simpleExecute); + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockVtGateConn - .streamExecute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockCursor); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockVtGateConn + .streamExecute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockCursor); VitessPreparedStatement statement = new VitessPreparedStatement(mockConn, sqlSelect); statement.setFetchSize(fetchSize); statement.executeQuery(); if (shouldRunExecute) { - Mockito.verify(mockVtGateConn, Mockito.times(2)).execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class)); + Mockito.verify(mockVtGateConn, Mockito.times(2)).execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class)); } if (shouldRunStreamExecute) { - Mockito.verify(mockVtGateConn).streamExecute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class)); + Mockito.verify(mockVtGateConn).streamExecute(any(Context.class), anyString(), anyMap(), + any(VTSession.class)); } } @Test public void testGetUpdateCount() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFuture = PowerMockito.mock(SQLFuture.class); - - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFuture); - PowerMockito.when(mockSqlFuture.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFuture = mock(SQLFuture.class); + + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFuture); + when(mockSqlFuture.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); VitessPreparedStatement preparedStatement = new VitessPreparedStatement(mockConn, sqlSelect); try { - PowerMockito.when(mockCursor.getRowsAffected()).thenReturn(10L); + when(mockCursor.getRowsAffected()).thenReturn(10L); int updateCount = preparedStatement.executeUpdate(); - Assert.assertEquals(10L, updateCount); - Assert.assertEquals(10L, preparedStatement.getUpdateCount()); + assertEquals(10L, updateCount); + assertEquals(10L, preparedStatement.getUpdateCount()); // Truncated Update Count - PowerMockito.when(mockCursor.getRowsAffected()) + when(mockCursor.getRowsAffected()) .thenReturn((long) Integer.MAX_VALUE + 10); updateCount = preparedStatement.executeUpdate(); - Assert.assertEquals(Integer.MAX_VALUE, updateCount); - Assert.assertEquals(Integer.MAX_VALUE, preparedStatement.getUpdateCount()); + assertEquals(Integer.MAX_VALUE, updateCount); + assertEquals(Integer.MAX_VALUE, preparedStatement.getUpdateCount()); - PowerMockito.when(mockConn.isSimpleExecute()).thenReturn(true); + when(mockConn.isSimpleExecute()).thenReturn(true); preparedStatement.executeQuery(); - Assert.assertEquals(-1, preparedStatement.getUpdateCount()); + assertEquals(-1, preparedStatement.getUpdateCount()); } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + fail("Test failed " + e.getMessage()); } } @Test public void testSetParameters() throws Exception { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + VitessConnection mockConn = mock(VitessConnection.class); Mockito.when(mockConn.getTreatUtilDateAsTimestamp()).thenReturn(true); VitessPreparedStatement preparedStatement = new VitessPreparedStatement(mockConn, sqlSelect); - Boolean boolValue = Boolean.TRUE; + Boolean boolValue = true; Byte byteValue = Byte.MAX_VALUE; Short shortValue = Short.MAX_VALUE; Integer intValue = Integer.MAX_VALUE; @@ -519,57 +524,57 @@ private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRun bindVariablesMap.setAccessible(true); Map bindVariables = (Map) bindVariablesMap.get(preparedStatement); - Assert.assertEquals(null, bindVariables.get("v1")); - Assert.assertEquals(boolValue, bindVariables.get("v2")); - Assert.assertEquals(byteValue, bindVariables.get("v3")); - Assert.assertEquals(shortValue, bindVariables.get("v4")); - Assert.assertEquals(intValue, bindVariables.get("v5")); - Assert.assertEquals(longValue, bindVariables.get("v6")); - Assert.assertEquals(floatValue, bindVariables.get("v7")); - Assert.assertEquals(doubleValue, bindVariables.get("v8")); - Assert.assertEquals(bigDecimalValue, bindVariables.get("v9")); - Assert.assertEquals(bigIntegerValue, bindVariables.get("v10")); - Assert.assertEquals(stringValue, bindVariables.get("v11")); - Assert.assertEquals(bytesValue, bindVariables.get("v12")); - Assert.assertEquals(dateValue.toString(), bindVariables.get("v13")); - Assert.assertEquals(timeValue.toString(), bindVariables.get("v14")); - Assert.assertEquals(timestampValue.toString(), bindVariables.get("v15")); - Assert.assertEquals(dateValue.toString(), bindVariables.get("v16")); - Assert.assertEquals(timeValue.toString(), bindVariables.get("v17")); - Assert.assertEquals(timestampValue.toString(), bindVariables.get("v18")); - Assert.assertEquals(boolValue, bindVariables.get("v19")); - Assert.assertEquals(byteValue, bindVariables.get("v20")); - Assert.assertEquals(shortValue, bindVariables.get("v21")); - Assert.assertEquals(intValue, bindVariables.get("v22")); - Assert.assertEquals(longValue, bindVariables.get("v23")); - Assert.assertEquals(floatValue, bindVariables.get("v24")); - Assert.assertEquals(doubleValue, bindVariables.get("v25")); - Assert.assertEquals(bigDecimalValue, bindVariables.get("v26")); - Assert.assertEquals(bigIntegerValue, bindVariables.get("v27")); - Assert.assertEquals(stringValue, bindVariables.get("v28")); - Assert.assertEquals(dateValue.toString(), bindVariables.get("v29")); - Assert.assertEquals(timeValue.toString(), bindVariables.get("v30")); - Assert.assertEquals(timestampValue.toString(), bindVariables.get("v31")); - Assert.assertEquals("a", bindVariables.get("v32")); - Assert.assertEquals(null, bindVariables.get("v33")); - Assert.assertEquals(boolValue, bindVariables.get("v34")); - Assert.assertEquals(shortValue.intValue(), bindVariables.get("v35")); - Assert.assertEquals(longValue, bindVariables.get("v36")); - Assert.assertEquals((double) floatValue, (double) bindVariables.get("v37"), 0.1); - Assert.assertEquals(doubleValue, (double) bindVariables.get("v38"), 0.1); - Assert.assertEquals(expectedDecimalValue, bindVariables.get("v39")); - Assert.assertEquals(stringValue, bindVariables.get("v40")); - Assert.assertEquals(dateValue.toString(), bindVariables.get("v41")); - Assert.assertEquals(timeValue.toString(), bindVariables.get("v42")); - Assert.assertEquals(timestampValue.toString(), bindVariables.get("v43")); - Assert.assertEquals("clob", bindVariables.get("v44")); + assertEquals(null, bindVariables.get("v1")); + assertEquals(boolValue, bindVariables.get("v2")); + assertEquals(byteValue, bindVariables.get("v3")); + assertEquals(shortValue, bindVariables.get("v4")); + assertEquals(intValue, bindVariables.get("v5")); + assertEquals(longValue, bindVariables.get("v6")); + assertEquals(floatValue, bindVariables.get("v7")); + assertEquals(doubleValue, bindVariables.get("v8")); + assertEquals(bigDecimalValue, bindVariables.get("v9")); + assertEquals(bigIntegerValue, bindVariables.get("v10")); + assertEquals(stringValue, bindVariables.get("v11")); + assertEquals(bytesValue, bindVariables.get("v12")); + assertEquals(dateValue.toString(), bindVariables.get("v13")); + assertEquals(timeValue.toString(), bindVariables.get("v14")); + assertEquals(timestampValue.toString(), bindVariables.get("v15")); + assertEquals(dateValue.toString(), bindVariables.get("v16")); + assertEquals(timeValue.toString(), bindVariables.get("v17")); + assertEquals(timestampValue.toString(), bindVariables.get("v18")); + assertEquals(boolValue, bindVariables.get("v19")); + assertEquals(byteValue, bindVariables.get("v20")); + assertEquals(shortValue, bindVariables.get("v21")); + assertEquals(intValue, bindVariables.get("v22")); + assertEquals(longValue, bindVariables.get("v23")); + assertEquals(floatValue, bindVariables.get("v24")); + assertEquals(doubleValue, bindVariables.get("v25")); + assertEquals(bigDecimalValue, bindVariables.get("v26")); + assertEquals(bigIntegerValue, bindVariables.get("v27")); + assertEquals(stringValue, bindVariables.get("v28")); + assertEquals(dateValue.toString(), bindVariables.get("v29")); + assertEquals(timeValue.toString(), bindVariables.get("v30")); + assertEquals(timestampValue.toString(), bindVariables.get("v31")); + assertEquals("a", bindVariables.get("v32")); + assertEquals(null, bindVariables.get("v33")); + assertEquals(true, bindVariables.get("v34")); + assertEquals(shortValue.intValue(), bindVariables.get("v35")); + assertEquals(longValue, bindVariables.get("v36")); + assertEquals((double) floatValue, (double) bindVariables.get("v37"), 0.1); + assertEquals(doubleValue, (double) bindVariables.get("v38"), 0.1); + assertEquals(expectedDecimalValue, bindVariables.get("v39")); + assertEquals(stringValue, bindVariables.get("v40")); + assertEquals(dateValue.toString(), bindVariables.get("v41")); + assertEquals(timeValue.toString(), bindVariables.get("v42")); + assertEquals(timestampValue.toString(), bindVariables.get("v43")); + assertEquals("clob", bindVariables.get("v44")); Assert.assertArrayEquals(bytesValue, (byte[])bindVariables.get("v45")); preparedStatement.clearParameters(); } @Test public void testTreatUtilDateAsTimestamp() throws Exception { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + VitessConnection mockConn = mock(VitessConnection.class); VitessPreparedStatement preparedStatement = new VitessPreparedStatement(mockConn, sqlSelect); @@ -577,7 +582,7 @@ private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRun Timestamp timestamp = new Timestamp(utilDateValue.getTime()); try { preparedStatement.setObject(1, utilDateValue); - Assert.fail("setObject on java.util.Date should have failed with SQLException"); + fail("setObject on java.util.Date should have failed with SQLException"); } catch (SQLException e) { Assert.assertTrue(e.getMessage().startsWith(Constants.SQLExceptionMessages.SQL_TYPE_INFER)); } @@ -593,57 +598,57 @@ private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRun Map bindVariables = (Map) bindVariablesMap.get(preparedStatement); - Assert.assertEquals(DateTime.formatTimestamp(timestamp), bindVariables.get("v1")); + assertEquals(DateTime.formatTimestamp(timestamp), bindVariables.get("v1")); } @Test public void testAutoGeneratedKeys() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); try { long expectedFirstGeneratedId = 121; long[] expectedGeneratedIds = {121, 122}; int expectedAffectedRows = 2; - PowerMockito.when(mockCursor.getInsertId()).thenReturn(expectedFirstGeneratedId); - PowerMockito.when(mockCursor.getRowsAffected()) + when(mockCursor.getInsertId()).thenReturn(expectedFirstGeneratedId); + when(mockCursor.getRowsAffected()) .thenReturn(Long.valueOf(expectedAffectedRows)); //Executing Insert Statement VitessPreparedStatement preparedStatement = new VitessPreparedStatement(mockConn, sqlInsert, Statement.RETURN_GENERATED_KEYS); int updateCount = preparedStatement.executeUpdate(); - Assert.assertEquals(expectedAffectedRows, updateCount); + assertEquals(expectedAffectedRows, updateCount); ResultSet rs = preparedStatement.getGeneratedKeys(); int i = 0; while (rs.next()) { long generatedId = rs.getLong(1); - Assert.assertEquals(expectedGeneratedIds[i++], generatedId); + assertEquals(expectedGeneratedIds[i++], generatedId); } } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + fail("Test failed " + e.getMessage()); } } @Test public void testAddBatch() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + VitessConnection mockConn = mock(VitessConnection.class); VitessPreparedStatement statement = new VitessPreparedStatement(mockConn, sqlInsert); try { statement.addBatch(this.sqlInsert); - Assert.fail("Should have thrown Exception"); + fail("Should have thrown Exception"); } catch (SQLException ex) { - Assert.assertEquals(Constants.SQLExceptionMessages.METHOD_NOT_ALLOWED, ex.getMessage()); + assertEquals(Constants.SQLExceptionMessages.METHOD_NOT_ALLOWED, ex.getMessage()); } statement.setString(1, "string1"); statement.addBatch(); @@ -651,17 +656,17 @@ private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRun Field privateStringField = VitessPreparedStatement.class.getDeclaredField("batchedArgs"); privateStringField.setAccessible(true); - Assert.assertEquals("string1", + assertEquals("string1", (((List>) privateStringField.get(statement)).get(0)).get("v1")); } catch (NoSuchFieldException e) { - Assert.fail("Private Field should exists: batchedArgs"); + fail("Private Field should exists: batchedArgs"); } catch (IllegalAccessException e) { - Assert.fail("Private Field should be accessible: batchedArgs"); + fail("Private Field should be accessible: batchedArgs"); } } @Test public void testClearBatch() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + VitessConnection mockConn = mock(VitessConnection.class); VitessPreparedStatement statement = new VitessPreparedStatement(mockConn, sqlInsert); statement.setString(1, "string1"); statement.addBatch(); @@ -673,44 +678,44 @@ private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRun Assert.assertTrue( ((List>) privateStringField.get(statement)).isEmpty()); } catch (NoSuchFieldException e) { - Assert.fail("Private Field should exists: batchedArgs"); + fail("Private Field should exists: batchedArgs"); } catch (IllegalAccessException e) { - Assert.fail("Private Field should be accessible: batchedArgs"); + fail("Private Field should be accessible: batchedArgs"); } } @Test public void testExecuteBatch() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + VitessConnection mockConn = mock(VitessConnection.class); VitessPreparedStatement statement = new VitessPreparedStatement(mockConn, sqlInsert); int[] updateCounts = statement.executeBatch(); - Assert.assertEquals(0, updateCounts.length); + assertEquals(0, updateCounts.length); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockConn.getAutoCommit()).thenReturn(true); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockConn.getAutoCommit()).thenReturn(true); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - PowerMockito.when(mockVtGateConn - .executeBatch(Matchers.any(Context.class), Matchers.anyList(), Matchers.anyList(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + when(mockVtGateConn + .executeBatch(any(Context.class), Matchers.anyList(), Matchers.anyList(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); List mockCursorWithErrorList = new ArrayList<>(); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursorWithErrorList); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursorWithErrorList); - CursorWithError mockCursorWithError1 = PowerMockito.mock(CursorWithError.class); - PowerMockito.when(mockCursorWithError1.getError()).thenReturn(null); - PowerMockito.when(mockCursorWithError1.getCursor()) - .thenReturn(PowerMockito.mock(Cursor.class)); + CursorWithError mockCursorWithError1 = mock(CursorWithError.class); + when(mockCursorWithError1.getError()).thenReturn(null); + when(mockCursorWithError1.getCursor()) + .thenReturn(mock(Cursor.class)); mockCursorWithErrorList.add(mockCursorWithError1); statement.setString(1, "string1"); statement.addBatch(); updateCounts = statement.executeBatch(); - Assert.assertEquals(1, updateCounts.length); + assertEquals(1, updateCounts.length); - CursorWithError mockCursorWithError2 = PowerMockito.mock(CursorWithError.class); + CursorWithError mockCursorWithError2 = mock(CursorWithError.class); Vtrpc.RPCError rpcError = Vtrpc.RPCError.newBuilder().setMessage("preparedStatement execute batch error").build(); - PowerMockito.when(mockCursorWithError2.getError()) + when(mockCursorWithError2.getError()) .thenReturn(rpcError); mockCursorWithErrorList.add(mockCursorWithError2); statement.setString(1, "string1"); @@ -719,16 +724,16 @@ private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRun statement.addBatch(); try { statement.executeBatch(); - Assert.fail("Should have thrown Exception"); + fail("Should have thrown Exception"); } catch (BatchUpdateException ex) { - Assert.assertEquals(rpcError.toString(), ex.getMessage()); - Assert.assertEquals(2, ex.getUpdateCounts().length); - Assert.assertEquals(Statement.EXECUTE_FAILED, ex.getUpdateCounts()[1]); + assertEquals(rpcError.toString(), ex.getMessage()); + assertEquals(2, ex.getUpdateCounts().length); + assertEquals(Statement.EXECUTE_FAILED, ex.getUpdateCounts()[1]); } } @Test public void testStatementCount() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + VitessConnection mockConn = mock(VitessConnection.class); Map testCases = ImmutableMap.builder() .put("select * from foo where a = ?", 1) .put("select * from foo where a = ? and b = ?", 2) @@ -750,7 +755,7 @@ private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRun for (Map.Entry testCase : testCases.entrySet()) { VitessPreparedStatement statement = new VitessPreparedStatement(mockConn, testCase.getKey()); - Assert.assertEquals(testCase.getKey(), testCase.getValue().longValue(), statement.getParameterMetaData().getParameterCount()); + assertEquals(testCase.getKey(), testCase.getValue().longValue(), statement.getParameterMetaData().getParameterCount()); } } } diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessResultSetMetadataTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessResultSetMetadataTest.java index 3f3140aadbe..3ac7be7b5eb 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessResultSetMetadataTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessResultSetMetadataTest.java @@ -16,11 +16,6 @@ package io.vitess.jdbc; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.Types; -import java.util.ArrayList; -import java.util.List; import org.junit.Assert; import org.junit.Test; @@ -28,6 +23,12 @@ import io.vitess.util.Constants; import io.vitess.util.charset.CharsetMapping; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Types; +import java.util.ArrayList; +import java.util.List; + /** * Created by ashudeep.sharma on 08/02/16. */ diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessResultSetTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessResultSetTest.java index 12502faa082..ce2c8b79aa2 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessResultSetTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessResultSetTest.java @@ -16,17 +16,7 @@ package io.vitess.jdbc; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.sql.Clob; -import java.sql.SQLException; -import java.sql.Time; -import java.sql.Timestamp; -import java.util.Properties; - +import com.google.protobuf.ByteString; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,8 +26,6 @@ import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; -import com.google.protobuf.ByteString; - import io.vitess.client.cursor.Cursor; import io.vitess.client.cursor.SimpleCursor; import io.vitess.proto.Query; @@ -45,6 +33,19 @@ import io.vitess.util.StringUtils; import io.vitess.util.charset.CharsetMapping; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.Clob; +import java.sql.SQLException; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; + /** * Created by harshit.gangal on 19/01/16. */ @@ -85,34 +86,34 @@ public Cursor getCursorWithRows() { UNRECOGNIZED(-1, -1); */ return new SimpleCursor(Query.QueryResult.newBuilder() - .addFields(Query.Field.newBuilder().setName("col1").setType(Query.Type.INT8).build()) - .addFields(Query.Field.newBuilder().setName("col2").setType(Query.Type.UINT8).build()) - .addFields(Query.Field.newBuilder().setName("col3").setType(Query.Type.INT16).build()) - .addFields(Query.Field.newBuilder().setName("col4").setType(Query.Type.UINT16).build()) - .addFields(Query.Field.newBuilder().setName("col5").setType(Query.Type.INT24).build()) - .addFields(Query.Field.newBuilder().setName("col6").setType(Query.Type.UINT24).build()) - .addFields(Query.Field.newBuilder().setName("col7").setType(Query.Type.INT32).build()) - .addFields(Query.Field.newBuilder().setName("col8").setType(Query.Type.UINT32).build()) - .addFields(Query.Field.newBuilder().setName("col9").setType(Query.Type.INT64).build()) - .addFields(Query.Field.newBuilder().setName("col10").setType(Query.Type.UINT64).build()) - .addFields(Query.Field.newBuilder().setName("col11").setType(Query.Type.FLOAT32).build()) - .addFields(Query.Field.newBuilder().setName("col12").setType(Query.Type.FLOAT64).build()) - .addFields(Query.Field.newBuilder().setName("col13").setType(Query.Type.TIMESTAMP).build()) - .addFields(Query.Field.newBuilder().setName("col14").setType(Query.Type.DATE).build()) - .addFields(Query.Field.newBuilder().setName("col15").setType(Query.Type.TIME).build()) - .addFields(Query.Field.newBuilder().setName("col16").setType(Query.Type.DATETIME).build()) - .addFields(Query.Field.newBuilder().setName("col17").setType(Query.Type.YEAR).build()) - .addFields(Query.Field.newBuilder().setName("col18").setType(Query.Type.DECIMAL).build()) - .addFields(Query.Field.newBuilder().setName("col19").setType(Query.Type.TEXT).build()) - .addFields(Query.Field.newBuilder().setName("col20").setType(Query.Type.BLOB).build()) - .addFields(Query.Field.newBuilder().setName("col21").setType(Query.Type.VARCHAR).build()) - .addFields(Query.Field.newBuilder().setName("col22").setType(Query.Type.VARBINARY).build()) - .addFields(Query.Field.newBuilder().setName("col23").setType(Query.Type.CHAR).build()) - .addFields(Query.Field.newBuilder().setName("col24").setType(Query.Type.BINARY).build()) - .addFields(Query.Field.newBuilder().setName("col25").setType(Query.Type.BIT).build()) - .addFields(Query.Field.newBuilder().setName("col26").setType(Query.Type.ENUM).build()) - .addFields(Query.Field.newBuilder().setName("col27").setType(Query.Type.SET).build()) - .addFields(Query.Field.newBuilder().setName("col28").setType(Query.Type.TIMESTAMP).build()) + .addFields(getField("col1", Query.Type.INT8)) + .addFields(getField("col2", Query.Type.UINT8)) + .addFields(getField("col3", Query.Type.INT16)) + .addFields(getField("col4", Query.Type.UINT16)) + .addFields(getField("col5", Query.Type.INT24)) + .addFields(getField("col6", Query.Type.UINT24)) + .addFields(getField("col7", Query.Type.INT32)) + .addFields(getField("col8", Query.Type.UINT32)) + .addFields(getField("col9", Query.Type.INT64)) + .addFields(getField("col10", Query.Type.UINT64)) + .addFields(getField("col11", Query.Type.FLOAT32)) + .addFields(getField("col12", Query.Type.FLOAT64)) + .addFields(getField("col13", Query.Type.TIMESTAMP)) + .addFields(getField("col14", Query.Type.DATE)) + .addFields(getField("col15", Query.Type.TIME)) + .addFields(getField("col16", Query.Type.DATETIME)) + .addFields(getField("col17", Query.Type.YEAR)) + .addFields(getField("col18", Query.Type.DECIMAL)) + .addFields(getField("col19", Query.Type.TEXT)) + .addFields(getField("col20", Query.Type.BLOB)) + .addFields(getField("col21", Query.Type.VARCHAR)) + .addFields(getField("col22", Query.Type.VARBINARY)) + .addFields(getField("col23", Query.Type.CHAR)) + .addFields(getField("col24", Query.Type.BINARY)) + .addFields(getField("col25", Query.Type.BIT)) + .addFields(getField("col26", Query.Type.ENUM)) + .addFields(getField("col27", Query.Type.SET)) + .addFields(getField("col28", Query.Type.TIMESTAMP)) .addRows(Query.Row.newBuilder().addLengths("-50".length()).addLengths("50".length()) .addLengths("-23000".length()).addLengths("23000".length()) .addLengths("-100".length()).addLengths("100".length()).addLengths("-100".length()) @@ -133,6 +134,14 @@ public Cursor getCursorWithRows() { " TDS TEAMHELLO TDS TEAMNHELLO TDS TEAM1val123val1230000-00-00 00:00:00"))).build()); } + private Query.Field getField(String fieldName, Query.Type typ) { + return Query.Field.newBuilder().setName(fieldName).setType(typ).build(); + } + + private Query.Field getField(String fieldName) { + return Query.Field.newBuilder().setName(fieldName).build(); + } + public Cursor getCursorWithRowsAsNull() { /* INT8(1, 257), -50 @@ -166,33 +175,33 @@ public Cursor getCursorWithRowsAsNull() { UNRECOGNIZED(-1, -1); */ return new SimpleCursor(Query.QueryResult.newBuilder() - .addFields(Query.Field.newBuilder().setName("col1").setType(Query.Type.INT8).build()) - .addFields(Query.Field.newBuilder().setName("col2").setType(Query.Type.UINT8).build()) - .addFields(Query.Field.newBuilder().setName("col3").setType(Query.Type.INT16).build()) - .addFields(Query.Field.newBuilder().setName("col4").setType(Query.Type.UINT16).build()) - .addFields(Query.Field.newBuilder().setName("col5").setType(Query.Type.INT24).build()) - .addFields(Query.Field.newBuilder().setName("col6").setType(Query.Type.UINT24).build()) - .addFields(Query.Field.newBuilder().setName("col7").setType(Query.Type.INT32).build()) - .addFields(Query.Field.newBuilder().setName("col8").setType(Query.Type.UINT32).build()) - .addFields(Query.Field.newBuilder().setName("col9").setType(Query.Type.INT64).build()) - .addFields(Query.Field.newBuilder().setName("col10").setType(Query.Type.UINT64).build()) - .addFields(Query.Field.newBuilder().setName("col11").setType(Query.Type.FLOAT32).build()) - .addFields(Query.Field.newBuilder().setName("col12").setType(Query.Type.FLOAT64).build()) - .addFields(Query.Field.newBuilder().setName("col13").setType(Query.Type.TIMESTAMP).build()) - .addFields(Query.Field.newBuilder().setName("col14").setType(Query.Type.DATE).build()) - .addFields(Query.Field.newBuilder().setName("col15").setType(Query.Type.TIME).build()) - .addFields(Query.Field.newBuilder().setName("col16").setType(Query.Type.DATETIME).build()) - .addFields(Query.Field.newBuilder().setName("col17").setType(Query.Type.YEAR).build()) - .addFields(Query.Field.newBuilder().setName("col18").setType(Query.Type.DECIMAL).build()) - .addFields(Query.Field.newBuilder().setName("col19").setType(Query.Type.TEXT).build()) - .addFields(Query.Field.newBuilder().setName("col20").setType(Query.Type.BLOB).build()) - .addFields(Query.Field.newBuilder().setName("col21").setType(Query.Type.VARCHAR).build()) - .addFields(Query.Field.newBuilder().setName("col22").setType(Query.Type.VARBINARY).build()) - .addFields(Query.Field.newBuilder().setName("col23").setType(Query.Type.CHAR).build()) - .addFields(Query.Field.newBuilder().setName("col24").setType(Query.Type.BINARY).build()) - .addFields(Query.Field.newBuilder().setName("col25").setType(Query.Type.BIT).build()) - .addFields(Query.Field.newBuilder().setName("col26").setType(Query.Type.ENUM).build()) - .addFields(Query.Field.newBuilder().setName("col27").setType(Query.Type.SET).build()) + .addFields(getField("col1", Query.Type.INT8)) + .addFields(getField("col2", Query.Type.UINT8)) + .addFields(getField("col3", Query.Type.INT16)) + .addFields(getField("col4", Query.Type.UINT16)) + .addFields(getField("col5", Query.Type.INT24)) + .addFields(getField("col6", Query.Type.UINT24)) + .addFields(getField("col7", Query.Type.INT32)) + .addFields(getField("col8", Query.Type.UINT32)) + .addFields(getField("col9", Query.Type.INT64)) + .addFields(getField("col10", Query.Type.UINT64)) + .addFields(getField("col11", Query.Type.FLOAT32)) + .addFields(getField("col12", Query.Type.FLOAT64)) + .addFields(getField("col13", Query.Type.TIMESTAMP)) + .addFields(getField("col14", Query.Type.DATE)) + .addFields(getField("col15", Query.Type.TIME)) + .addFields(getField("col16", Query.Type.DATETIME)) + .addFields(getField("col17", Query.Type.YEAR)) + .addFields(getField("col18", Query.Type.DECIMAL)) + .addFields(getField("col19", Query.Type.TEXT)) + .addFields(getField("col20", Query.Type.BLOB)) + .addFields(getField("col21", Query.Type.VARCHAR)) + .addFields(getField("col22", Query.Type.VARBINARY)) + .addFields(getField("col23", Query.Type.CHAR)) + .addFields(getField("col24", Query.Type.BINARY)) + .addFields(getField("col25", Query.Type.BIT)) + .addFields(getField("col26", Query.Type.ENUM)) + .addFields(getField("col27", Query.Type.SET)) .addRows(Query.Row.newBuilder().addLengths("-50".length()).addLengths("50".length()) .addLengths("-23000".length()).addLengths("23000".length()) .addLengths("-100".length()).addLengths("100".length()).addLengths("-100".length()) @@ -212,52 +221,52 @@ public Cursor getCursorWithRowsAsNull() { @Test public void testNextWithZeroRows() throws Exception { Cursor cursor = new SimpleCursor(Query.QueryResult.newBuilder() - .addFields(Query.Field.newBuilder().setName("col0").build()) - .addFields(Query.Field.newBuilder().setName("col1").build()) - .addFields(Query.Field.newBuilder().setName("col2").build()).build()); + .addFields(getField("col0")) + .addFields(getField("col1")) + .addFields(getField("col2")).build()); VitessResultSet vitessResultSet = new VitessResultSet(cursor); - Assert.assertEquals(false, vitessResultSet.next()); + assertEquals(false, vitessResultSet.next()); } @Test public void testNextWithNonZeroRows() throws Exception { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor); - Assert.assertEquals(true, vitessResultSet.next()); - Assert.assertEquals(false, vitessResultSet.next()); + assertEquals(true, vitessResultSet.next()); + assertEquals(false, vitessResultSet.next()); } @Test public void testgetString() throws SQLException { Cursor cursor = getCursorWithRowsAsNull(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals("-50", vitessResultSet.getString(1)); - Assert.assertEquals("50", vitessResultSet.getString(2)); - Assert.assertEquals("-23000", vitessResultSet.getString(3)); - Assert.assertEquals("23000", vitessResultSet.getString(4)); - Assert.assertEquals("-100", vitessResultSet.getString(5)); - Assert.assertEquals("100", vitessResultSet.getString(6)); - Assert.assertEquals("-100", vitessResultSet.getString(7)); - Assert.assertEquals("100", vitessResultSet.getString(8)); - Assert.assertEquals("-1000", vitessResultSet.getString(9)); - Assert.assertEquals("1000", vitessResultSet.getString(10)); - Assert.assertEquals("24.52", vitessResultSet.getString(11)); - Assert.assertEquals("100.43", vitessResultSet.getString(12)); - Assert.assertEquals("2016-02-06 14:15:16.0", vitessResultSet.getString(13)); - Assert.assertEquals("2016-02-06", vitessResultSet.getString(14)); - Assert.assertEquals("12:34:56", vitessResultSet.getString(15)); - Assert.assertEquals("2016-02-06 14:15:16.0", vitessResultSet.getString(16)); - Assert.assertEquals("2016", vitessResultSet.getString(17)); - Assert.assertEquals("1234.56789", vitessResultSet.getString(18)); - Assert.assertEquals("HELLO TDS TEAM", vitessResultSet.getString(19)); - Assert.assertEquals("HELLO TDS TEAM", vitessResultSet.getString(20)); - Assert.assertEquals("HELLO TDS TEAM", vitessResultSet.getString(21)); - Assert.assertEquals("HELLO TDS TEAM", vitessResultSet.getString(22)); - Assert.assertEquals("N", vitessResultSet.getString(23)); - Assert.assertEquals("HELLO TDS TEAM", vitessResultSet.getString(24)); - Assert.assertEquals("0", vitessResultSet.getString(25)); - Assert.assertEquals("val123", vitessResultSet.getString(26)); - Assert.assertEquals(null, vitessResultSet.getString(27)); + assertEquals("-50", vitessResultSet.getString(1)); + assertEquals("50", vitessResultSet.getString(2)); + assertEquals("-23000", vitessResultSet.getString(3)); + assertEquals("23000", vitessResultSet.getString(4)); + assertEquals("-100", vitessResultSet.getString(5)); + assertEquals("100", vitessResultSet.getString(6)); + assertEquals("-100", vitessResultSet.getString(7)); + assertEquals("100", vitessResultSet.getString(8)); + assertEquals("-1000", vitessResultSet.getString(9)); + assertEquals("1000", vitessResultSet.getString(10)); + assertEquals("24.52", vitessResultSet.getString(11)); + assertEquals("100.43", vitessResultSet.getString(12)); + assertEquals("2016-02-06 14:15:16.0", vitessResultSet.getString(13)); + assertEquals("2016-02-06", vitessResultSet.getString(14)); + assertEquals("12:34:56", vitessResultSet.getString(15)); + assertEquals("2016-02-06 14:15:16.0", vitessResultSet.getString(16)); + assertEquals("2016", vitessResultSet.getString(17)); + assertEquals("1234.56789", vitessResultSet.getString(18)); + assertEquals("HELLO TDS TEAM", vitessResultSet.getString(19)); + assertEquals("HELLO TDS TEAM", vitessResultSet.getString(20)); + assertEquals("HELLO TDS TEAM", vitessResultSet.getString(21)); + assertEquals("HELLO TDS TEAM", vitessResultSet.getString(22)); + assertEquals("N", vitessResultSet.getString(23)); + assertEquals("HELLO TDS TEAM", vitessResultSet.getString(24)); + assertEquals("0", vitessResultSet.getString(25)); + assertEquals("val123", vitessResultSet.getString(26)); + assertEquals(null, vitessResultSet.getString(27)); } @Test public void getObjectUint64AsBigInteger() throws SQLException { @@ -265,7 +274,7 @@ public Cursor getCursorWithRowsAsNull() { VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new BigInteger("1000"), vitessResultSet.getObject(10)); + assertEquals(new BigInteger("1000"), vitessResultSet.getObject(10)); } @Test public void getBigInteger() throws SQLException { @@ -273,7 +282,7 @@ public Cursor getCursorWithRowsAsNull() { VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new BigInteger("1000"), vitessResultSet.getBigInteger(10)); + assertEquals(new BigInteger("1000"), vitessResultSet.getBigInteger(10)); } @Test public void testgetBoolean() throws SQLException { @@ -281,62 +290,62 @@ public Cursor getCursorWithRowsAsNull() { Cursor cursorWithRowsAsNull = getCursorWithRowsAsNull(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(true, vitessResultSet.getBoolean(25)); - Assert.assertEquals(false, vitessResultSet.getBoolean(1)); + assertEquals(true, vitessResultSet.getBoolean(25)); + assertEquals(false, vitessResultSet.getBoolean(1)); vitessResultSet = new VitessResultSet(cursorWithRowsAsNull, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(false, vitessResultSet.getBoolean(25)); - Assert.assertEquals(false, vitessResultSet.getBoolean(1)); + assertEquals(false, vitessResultSet.getBoolean(25)); + assertEquals(false, vitessResultSet.getBoolean(1)); } @Test public void testgetByte() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(-50, vitessResultSet.getByte(1)); - Assert.assertEquals(1, vitessResultSet.getByte(25)); + assertEquals(-50, vitessResultSet.getByte(1)); + assertEquals(1, vitessResultSet.getByte(25)); } @Test public void testgetShort() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(-23000, vitessResultSet.getShort(3)); + assertEquals(-23000, vitessResultSet.getShort(3)); } @Test public void testgetInt() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(-100, vitessResultSet.getInt(7)); + assertEquals(-100, vitessResultSet.getInt(7)); } @Test public void testgetLong() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(-1000, vitessResultSet.getInt(9)); + assertEquals(-1000, vitessResultSet.getInt(9)); } @Test public void testgetFloat() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(24.52f, vitessResultSet.getFloat(11), 0.001); + assertEquals(24.52f, vitessResultSet.getFloat(11), 0.001); } @Test public void testgetDouble() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(100.43, vitessResultSet.getFloat(12), 0.001); + assertEquals(100.43, vitessResultSet.getFloat(12), 0.001); } @Test public void testBigDecimal() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new BigDecimal(BigInteger.valueOf(123456789), 5), + assertEquals(new BigDecimal(BigInteger.valueOf(123456789), 5), vitessResultSet.getBigDecimal(18)); } @@ -351,21 +360,21 @@ public Cursor getCursorWithRowsAsNull() { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new java.sql.Date(116, 1, 6), vitessResultSet.getDate(14)); + assertEquals(new java.sql.Date(116, 1, 6), vitessResultSet.getDate(14)); } @Test public void testgetTime() throws SQLException, UnsupportedEncodingException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new Time(12, 34, 56), vitessResultSet.getTime(15)); + assertEquals(new Time(12, 34, 56), vitessResultSet.getTime(15)); } @Test public void testgetTimestamp() throws SQLException, UnsupportedEncodingException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new Timestamp(116, 1, 6, 14, 15, 16, 0), + assertEquals(new Timestamp(116, 1, 6, 14, 15, 16, 0), vitessResultSet.getTimestamp(13)); } @@ -375,7 +384,7 @@ public Cursor getCursorWithRowsAsNull() { new VitessStatement(new VitessConnection( "jdbc:vitess://locahost:9000/vt_keyspace/keyspace?zeroDateTimeBehavior=garble", new Properties()))); vitessResultSet.next(); - Assert.assertEquals("0002-11-30 00:00:00.0", + assertEquals("0002-11-30 00:00:00.0", vitessResultSet.getTimestamp(28).toString()); } @@ -406,7 +415,7 @@ public Cursor getCursorWithRowsAsNull() { new VitessStatement(new VitessConnection( "jdbc:vitess://locahost:9000/vt_keyspace/keyspace?zeroDateTimeBehavior=round", new Properties()))); vitessResultSet.next(); - Assert.assertEquals("0001-01-01 00:00:00.0", vitessResultSet.getTimestamp(28).toString()); + assertEquals("0001-01-01 00:00:00.0", vitessResultSet.getTimestamp(28).toString()); } @Test public void testgetZeroDateRound() throws SQLException, UnsupportedEncodingException { @@ -415,84 +424,84 @@ public Cursor getCursorWithRowsAsNull() { new VitessStatement(new VitessConnection( "jdbc:vitess://locahost:9000/vt_keyspace/keyspace?zeroDateTimeBehavior=round", new Properties()))); vitessResultSet.next(); - Assert.assertEquals("0001-01-01", vitessResultSet.getDate(28).toString()); + assertEquals("0001-01-01", vitessResultSet.getDate(28).toString()); } @Test public void testgetStringbyColumnLabel() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals("-50", vitessResultSet.getString("col1")); - Assert.assertEquals("50", vitessResultSet.getString("col2")); - Assert.assertEquals("-23000", vitessResultSet.getString("col3")); - Assert.assertEquals("23000", vitessResultSet.getString("col4")); - Assert.assertEquals("-100", vitessResultSet.getString("col5")); - Assert.assertEquals("100", vitessResultSet.getString("col6")); - Assert.assertEquals("-100", vitessResultSet.getString("col7")); - Assert.assertEquals("100", vitessResultSet.getString("col8")); - Assert.assertEquals("-1000", vitessResultSet.getString("col9")); - Assert.assertEquals("1000", vitessResultSet.getString("col10")); - Assert.assertEquals("24.52", vitessResultSet.getString("col11")); - Assert.assertEquals("100.43", vitessResultSet.getString("col12")); - Assert.assertEquals("2016-02-06 14:15:16.0", vitessResultSet.getString("col13")); - Assert.assertEquals("2016-02-06", vitessResultSet.getString("col14")); - Assert.assertEquals("12:34:56", vitessResultSet.getString("col15")); - Assert.assertEquals("2016-02-06 14:15:16.0", vitessResultSet.getString("col16")); - Assert.assertEquals("2016", vitessResultSet.getString("col17")); - Assert.assertEquals("1234.56789", vitessResultSet.getString("col18")); - Assert.assertEquals("HELLO TDS TEAM", vitessResultSet.getString("col19")); - Assert.assertEquals("HELLO TDS TEAM", vitessResultSet.getString("col20")); - Assert.assertEquals("HELLO TDS TEAM", vitessResultSet.getString("col21")); - Assert.assertEquals("HELLO TDS TEAM", vitessResultSet.getString("col22")); - Assert.assertEquals("N", vitessResultSet.getString("col23")); - Assert.assertEquals("HELLO TDS TEAM", vitessResultSet.getString("col24")); - Assert.assertEquals("1", vitessResultSet.getString("col25")); - Assert.assertEquals("val123", vitessResultSet.getString("col26")); - Assert.assertEquals("val123", vitessResultSet.getString("col27")); + assertEquals("-50", vitessResultSet.getString("col1")); + assertEquals("50", vitessResultSet.getString("col2")); + assertEquals("-23000", vitessResultSet.getString("col3")); + assertEquals("23000", vitessResultSet.getString("col4")); + assertEquals("-100", vitessResultSet.getString("col5")); + assertEquals("100", vitessResultSet.getString("col6")); + assertEquals("-100", vitessResultSet.getString("col7")); + assertEquals("100", vitessResultSet.getString("col8")); + assertEquals("-1000", vitessResultSet.getString("col9")); + assertEquals("1000", vitessResultSet.getString("col10")); + assertEquals("24.52", vitessResultSet.getString("col11")); + assertEquals("100.43", vitessResultSet.getString("col12")); + assertEquals("2016-02-06 14:15:16.0", vitessResultSet.getString("col13")); + assertEquals("2016-02-06", vitessResultSet.getString("col14")); + assertEquals("12:34:56", vitessResultSet.getString("col15")); + assertEquals("2016-02-06 14:15:16.0", vitessResultSet.getString("col16")); + assertEquals("2016", vitessResultSet.getString("col17")); + assertEquals("1234.56789", vitessResultSet.getString("col18")); + assertEquals("HELLO TDS TEAM", vitessResultSet.getString("col19")); + assertEquals("HELLO TDS TEAM", vitessResultSet.getString("col20")); + assertEquals("HELLO TDS TEAM", vitessResultSet.getString("col21")); + assertEquals("HELLO TDS TEAM", vitessResultSet.getString("col22")); + assertEquals("N", vitessResultSet.getString("col23")); + assertEquals("HELLO TDS TEAM", vitessResultSet.getString("col24")); + assertEquals("1", vitessResultSet.getString("col25")); + assertEquals("val123", vitessResultSet.getString("col26")); + assertEquals("val123", vitessResultSet.getString("col27")); } @Test public void testgetBooleanbyColumnLabel() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(true, vitessResultSet.getBoolean("col25")); - Assert.assertEquals(false, vitessResultSet.getBoolean("col1")); + assertEquals(true, vitessResultSet.getBoolean("col25")); + assertEquals(false, vitessResultSet.getBoolean("col1")); } @Test public void testgetBytebyColumnLabel() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(-50, vitessResultSet.getByte("col1")); - Assert.assertEquals(1, vitessResultSet.getByte("col25")); + assertEquals(-50, vitessResultSet.getByte("col1")); + assertEquals(1, vitessResultSet.getByte("col25")); } @Test public void testgetShortbyColumnLabel() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(-23000, vitessResultSet.getShort("col3")); + assertEquals(-23000, vitessResultSet.getShort("col3")); } @Test public void testgetIntbyColumnLabel() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(-100, vitessResultSet.getInt("col7")); + assertEquals(-100, vitessResultSet.getInt("col7")); } @Test public void testgetLongbyColumnLabel() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(-1000, vitessResultSet.getInt("col9")); + assertEquals(-1000, vitessResultSet.getInt("col9")); } @Test public void testBigIntegerbyColumnLabel() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new BigInteger("1000"), + assertEquals(new BigInteger("1000"), vitessResultSet.getBigInteger("col10")); } @@ -500,21 +509,21 @@ public Cursor getCursorWithRowsAsNull() { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(24.52f, vitessResultSet.getFloat("col11"), 0.001); + assertEquals(24.52f, vitessResultSet.getFloat("col11"), 0.001); } @Test public void testgetDoublebyColumnLabel() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(100.43, vitessResultSet.getFloat("col12"), 0.001); + assertEquals(100.43, vitessResultSet.getFloat("col12"), 0.001); } @Test public void testBigDecimalbyColumnLabel() throws SQLException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new BigDecimal(BigInteger.valueOf(123456789), 5), + assertEquals(new BigDecimal(BigInteger.valueOf(123456789), 5), vitessResultSet.getBigDecimal("col18")); } @@ -531,14 +540,14 @@ public Cursor getCursorWithRowsAsNull() { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new java.sql.Date(116, 1, 6), vitessResultSet.getDate("col14")); + assertEquals(new java.sql.Date(116, 1, 6), vitessResultSet.getDate("col14")); } @Test public void testgetTimebyColumnLabel() throws SQLException, UnsupportedEncodingException { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new Time(12, 34, 56), vitessResultSet.getTime("col15")); + assertEquals(new Time(12, 34, 56), vitessResultSet.getTime("col15")); } @Test public void testgetTimestampbyColumnLabel() @@ -546,7 +555,7 @@ public Cursor getCursorWithRowsAsNull() { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); vitessResultSet.next(); - Assert.assertEquals(new Timestamp(116, 1, 6, 14, 15, 16, 0), + assertEquals(new Timestamp(116, 1, 6, 14, 15, 16, 0), vitessResultSet.getTimestamp("col13")); } @@ -573,13 +582,13 @@ public Cursor getCursorWithRowsAsNull() { vitessResultSet.getBinaryStream(22).read(ba3, 0, 128); Assert.assertArrayEquals(ba1, ba3); - Assert.assertEquals(null, vitessResultSet.getBinaryStream(27)); + assertEquals(null, vitessResultSet.getBinaryStream(27)); } @Test public void testEnhancedFieldsFromCursor() throws Exception { Cursor cursor = getCursorWithRows(); VitessResultSet vitessResultSet = new VitessResultSet(cursor, getVitessStatement()); - Assert.assertEquals(cursor.getFields().size(), vitessResultSet.getFields().size()); + assertEquals(cursor.getFields().size(), vitessResultSet.getFields().size()); } @Test public void testGetStringUsesEncoding() throws Exception { @@ -626,8 +635,8 @@ public Cursor getCursorWithRowsAsNull() { VitessResultSet vitessResultSet = PowerMockito.spy(new VitessResultSet(new SimpleCursor(result), new VitessStatement(conn))); vitessResultSet.next(); - Assert.assertEquals(true, vitessResultSet.getObject(1)); - Assert.assertEquals(false, vitessResultSet.getObject(2)); + assertEquals(true, vitessResultSet.getObject(1)); + assertEquals(false, vitessResultSet.getObject(2)); Assert.assertArrayEquals(new byte[] {1,2,3,4}, (byte[]) vitessResultSet.getObject(3)); PowerMockito.verifyPrivate(vitessResultSet, VerificationModeFactory.times(3)).invoke("convertBytesIfPossible", Matchers.any(byte[].class), Matchers.any(FieldWithMetadata.class)); @@ -785,12 +794,12 @@ public Cursor getCursorWithRowsAsNull() { VitessResultSet vitessResultSet = PowerMockito.spy(new VitessResultSet(new SimpleCursor(result), new VitessStatement(conn))); vitessResultSet.next(); - Assert.assertEquals(trimmedCharStr, vitessResultSet.getObject(1)); - Assert.assertEquals(varcharStr, vitessResultSet.getObject(2)); + assertEquals(trimmedCharStr, vitessResultSet.getObject(1)); + assertEquals(varcharStr, vitessResultSet.getObject(2)); Assert.assertArrayEquals(opaqueBinary, (byte[]) vitessResultSet.getObject(3)); - Assert.assertEquals(masqueradingBlobStr, vitessResultSet.getObject(4)); - Assert.assertEquals(textStr, vitessResultSet.getObject(5)); - Assert.assertEquals(jsonStr, vitessResultSet.getObject(6)); + assertEquals(masqueradingBlobStr, vitessResultSet.getObject(4)); + assertEquals(textStr, vitessResultSet.getObject(5)); + assertEquals(jsonStr, vitessResultSet.getObject(6)); PowerMockito.verifyPrivate(vitessResultSet, VerificationModeFactory.times(6)).invoke("convertBytesIfPossible", Matchers.any(byte[].class), Matchers.any(FieldWithMetadata.class)); @@ -816,9 +825,9 @@ public Cursor getCursorWithRowsAsNull() { Assert.assertTrue(vitessResultSet.next()); Clob clob = vitessResultSet.getClob(1); - Assert.assertEquals("clobValue", clob.getSubString(1, (int) clob.length())); + assertEquals("clobValue", clob.getSubString(1, (int) clob.length())); clob = vitessResultSet.getClob("clob"); - Assert.assertEquals("clobValue", clob.getSubString(1, (int) clob.length())); + assertEquals("clobValue", clob.getSubString(1, (int) clob.length())); } } diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessStatementTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessStatementTest.java index 39482e9f14c..2880d5ab3fe 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessStatementTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessStatementTest.java @@ -1,12 +1,12 @@ /* * Copyright 2017 Google Inc. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,17 +16,8 @@ package io.vitess.jdbc; -import java.lang.reflect.Field; -import java.sql.BatchUpdateException; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.List; -import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Matchers; import org.mockito.Mockito; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; @@ -42,12 +33,37 @@ import io.vitess.proto.Vtrpc; import io.vitess.util.Constants; +import java.lang.reflect.Field; +import java.sql.BatchUpdateException; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyList; +import static org.mockito.Matchers.anyMap; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.verify; +import static org.powermock.api.mockito.PowerMockito.doReturn; +import static org.powermock.api.mockito.PowerMockito.mock; +import static org.powermock.api.mockito.PowerMockito.when; + /** * Created by harshit.gangal on 19/01/16. */ -@RunWith(PowerMockRunner.class) @PrepareForTest({VTGateConnection.class, - Vtrpc.RPCError.class}) public class VitessStatementTest { +@RunWith(PowerMockRunner.class) +@PrepareForTest({VTGateConnection.class, + Vtrpc.RPCError.class}) +public class VitessStatementTest { private String sqlSelect = "select 1 from test_table"; private String sqlShow = "show tables"; @@ -56,132 +72,117 @@ private String sqlUpsert = "insert into test_table(msg) values ('abc') on duplicate key update msg = 'def'"; - @Test public void testGetConnection() { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testGetConnection() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - try { - Assert.assertEquals(mockConn, statement.getConnection()); - } catch (SQLException e) { - Assert.fail("Connection Object is different than expect"); - } + assertEquals(mockConn, statement.getConnection()); } - @Test public void testGetResultSet() { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testGetResultSet() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - try { - Assert.assertEquals(null, statement.getResultSet()); - } catch (SQLException e) { - Assert.fail("ResultSet Object is different than expect"); - } + assertEquals(null, statement.getResultSet()); } - @Test public void testExecuteQuery() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), Matchers.any( - VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockConn.isSimpleExecute()).thenReturn(true); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + @Test + public void testExecuteQuery() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), any( + VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockConn.isSimpleExecute()).thenReturn(true); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); VitessStatement statement = new VitessStatement(mockConn); + //Empty Sql Statement try { + statement.executeQuery(""); + fail("Should have thrown exception for empty sql"); + } catch (SQLException ex) { + assertEquals("SQL statement is not valid", ex.getMessage()); + } + + ResultSet rs = statement.executeQuery(sqlSelect); + assertEquals(-1, statement.getUpdateCount()); + + //autocommit is false and not in transaction + when(mockConn.getAutoCommit()).thenReturn(false); + when(mockConn.isInTransaction()).thenReturn(false); + rs = statement.executeQuery(sqlSelect); + assertEquals(-1, statement.getUpdateCount()); - //Empty Sql Statement - try { - statement.executeQuery(""); - Assert.fail("Should have thrown exception for empty sql"); - } catch (SQLException ex) { - Assert.assertEquals("SQL statement is not valid", ex.getMessage()); - } - - ResultSet rs = statement.executeQuery(sqlSelect); - Assert.assertEquals(-1, statement.getUpdateCount()); - - //autocommit is false and not in transaction - PowerMockito.when(mockConn.getAutoCommit()).thenReturn(false); - PowerMockito.when(mockConn.isInTransaction()).thenReturn(false); - rs = statement.executeQuery(sqlSelect); - Assert.assertEquals(-1, statement.getUpdateCount()); - - //when returned cursor is null - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(null); - try { - statement.executeQuery(sqlSelect); - Assert.fail("Should have thrown exception for cursor null"); - } catch (SQLException ex) { - Assert.assertEquals("Failed to execute this method", ex.getMessage()); - } - - } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + //when returned cursor is null + when(mockSqlFutureCursor.checkedGet()).thenReturn(null); + try { + statement.executeQuery(sqlSelect); + fail("Should have thrown exception for cursor null"); + } catch (SQLException ex) { + assertEquals("Failed to execute this method", ex.getMessage()); } } - @Test public void testExecuteQueryWithStreamExecuteType() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .streamExecute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockCursor); - PowerMockito.when(mockConn.getExecuteType()) - .thenReturn(Constants.QueryExecuteType.STREAM); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + @Test + public void testExecuteQueryWithStreamExecuteType() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .streamExecute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockCursor); + when(mockConn.getExecuteType()) + .thenReturn(Constants.QueryExecuteType.STREAM); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); VitessStatement statement = new VitessStatement(mockConn); + //Empty Sql Statement try { + statement.executeQuery(""); + fail("Should have thrown exception for empty sql"); + } catch (SQLException ex) { + assertEquals("SQL statement is not valid", ex.getMessage()); + } + + //select on replica + ResultSet rs = statement.executeQuery(sqlSelect); + assertEquals(-1, statement.getUpdateCount()); + + //show query + rs = statement.executeQuery(sqlShow); + assertEquals(-1, statement.getUpdateCount()); + + //select on master when tx is null and autocommit is false + when(mockConn.getAutoCommit()).thenReturn(false); + when(mockConn.isInTransaction()).thenReturn(false); + rs = statement.executeQuery(sqlSelect); + assertEquals(-1, statement.getUpdateCount()); - //Empty Sql Statement - try { - statement.executeQuery(""); - Assert.fail("Should have thrown exception for empty sql"); - } catch (SQLException ex) { - Assert.assertEquals("SQL statement is not valid", ex.getMessage()); - } - - //select on replica - ResultSet rs = statement.executeQuery(sqlSelect); - Assert.assertEquals(-1, statement.getUpdateCount()); - - //show query - rs = statement.executeQuery(sqlShow); - Assert.assertEquals(-1, statement.getUpdateCount()); - - //select on master when tx is null and autocommit is false - PowerMockito.when(mockConn.getAutoCommit()).thenReturn(false); - PowerMockito.when(mockConn.isInTransaction()).thenReturn(false); - rs = statement.executeQuery(sqlSelect); - Assert.assertEquals(-1, statement.getUpdateCount()); - - //when returned cursor is null - PowerMockito.when(mockVtGateConn - .streamExecute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(null); - try { - statement.executeQuery(sqlSelect); - Assert.fail("Should have thrown exception for cursor null"); - } catch (SQLException ex) { - Assert.assertEquals("Failed to execute this method", ex.getMessage()); - } - - } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + //when returned cursor is null + when(mockVtGateConn + .streamExecute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(null); + try { + statement.executeQuery(sqlSelect); + fail("Should have thrown exception for cursor null"); + } catch (SQLException ex) { + assertEquals("Failed to execute this method", ex.getMessage()); } } - @Test public void testExecuteFetchSizeAsStreaming() throws SQLException { + @Test + public void testExecuteFetchSizeAsStreaming() throws SQLException { testExecute(5, true, false, true); testExecute(5, false, false, true); testExecute(0, true, true, false); @@ -189,504 +190,483 @@ } private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRunExecute, boolean shouldRunStreamExecute) throws SQLException { - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - PowerMockito.when(mockConn.isSimpleExecute()).thenReturn(simpleExecute); - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockVtGateConn - .streamExecute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockCursor); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + + VitessConnection mockConn = mock(VitessConnection.class); + when(mockConn.isSimpleExecute()).thenReturn(simpleExecute); + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockVtGateConn + .streamExecute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockCursor); VitessStatement statement = new VitessStatement(mockConn); statement.setFetchSize(fetchSize); statement.executeQuery(sqlSelect); if (shouldRunExecute) { - Mockito.verify(mockVtGateConn, Mockito.times(2)).execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class)); + verify(mockVtGateConn, Mockito.times(2)).execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class)); } if (shouldRunStreamExecute) { - Mockito.verify(mockVtGateConn).streamExecute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class)); + verify(mockVtGateConn).streamExecute(any(Context.class), anyString(), anyMap(), + any(VTSession.class)); } } - @Test public void testExecuteUpdate() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - List fieldList = PowerMockito.mock(ArrayList.class); - - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + @Test + public void testExecuteUpdate() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + List fieldList = mock(ArrayList.class); + + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); VitessStatement statement = new VitessStatement(mockConn); + //executing dml on master + int updateCount = statement.executeUpdate(sqlUpdate); + assertEquals(0, updateCount); + + //tx is null & autoCommit is true + when(mockConn.getAutoCommit()).thenReturn(true); + updateCount = statement.executeUpdate(sqlUpdate); + assertEquals(0, updateCount); + + //cursor fields is not null + when(mockCursor.getFields()).thenReturn(fieldList); + when(fieldList.isEmpty()).thenReturn(false); try { - - //executing dml on master - int updateCount = statement.executeUpdate(sqlUpdate); - Assert.assertEquals(0, updateCount); - - //tx is null & autoCommit is true - PowerMockito.when(mockConn.getAutoCommit()).thenReturn(true); - updateCount = statement.executeUpdate(sqlUpdate); - Assert.assertEquals(0, updateCount); - - //cursor fields is not null - PowerMockito.when(mockCursor.getFields()).thenReturn(fieldList); - PowerMockito.when(fieldList.isEmpty()).thenReturn(false); - try { - statement.executeUpdate(sqlSelect); - Assert.fail("Should have thrown exception for field not null"); - } catch (SQLException ex) { - Assert.assertEquals("ResultSet generation is not allowed through this method", + statement.executeUpdate(sqlSelect); + fail("Should have thrown exception for field not null"); + } catch (SQLException ex) { + assertEquals("ResultSet generation is not allowed through this method", ex.getMessage()); - } - - //cursor is null - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(null); - try { - statement.executeUpdate(sqlUpdate); - Assert.fail("Should have thrown exception for cursor null"); - } catch (SQLException ex) { - Assert.assertEquals("Failed to execute this method", ex.getMessage()); - } - - //read only - PowerMockito.when(mockConn.isReadOnly()).thenReturn(true); - try { - statement.execute("UPDATE SET foo = 1 ON mytable WHERE id = 1"); - Assert.fail("Should have thrown exception for read only"); - } catch (SQLException ex) { - Assert.assertEquals(Constants.SQLExceptionMessages.READ_ONLY, ex.getMessage()); - } - - //read only - PowerMockito.when(mockConn.isReadOnly()).thenReturn(true); - try { - statement.executeBatch(); - Assert.fail("Should have thrown exception for read only"); - } catch (SQLException ex) { - Assert.assertEquals(Constants.SQLExceptionMessages.READ_ONLY, ex.getMessage()); - } - - } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); } - } - @Test public void testExecute() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - List mockFieldList = PowerMockito.spy(new ArrayList()); - - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockConn.getAutoCommit()).thenReturn(true); - PowerMockito.when(mockConn.getExecuteType()) - .thenReturn(Constants.QueryExecuteType.SIMPLE); - PowerMockito.when(mockConn.isSimpleExecute()).thenReturn(true); - - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(mockFieldList); + //cursor is null + when(mockSqlFutureCursor.checkedGet()).thenReturn(null); + try { + statement.executeUpdate(sqlUpdate); + fail("Should have thrown exception for cursor null"); + } catch (SQLException ex) { + assertEquals("Failed to execute this method", ex.getMessage()); + } - VitessStatement statement = new VitessStatement(mockConn); + //read only + when(mockConn.isReadOnly()).thenReturn(true); try { + statement.execute("UPDATE SET foo = 1 ON mytable WHERE id = 1"); + fail("Should have thrown exception for read only"); + } catch (SQLException ex) { + assertEquals(Constants.SQLExceptionMessages.READ_ONLY, ex.getMessage()); + } - int fieldSize = 5; - PowerMockito.when(mockCursor.getFields()).thenReturn(mockFieldList); - PowerMockito.doReturn(fieldSize).when(mockFieldList).size(); - PowerMockito.doReturn(false).when(mockFieldList).isEmpty(); - - boolean hasResultSet = statement.execute(sqlSelect); - Assert.assertTrue(hasResultSet); - Assert.assertNotNull(statement.getResultSet()); - - hasResultSet = statement.execute(sqlShow); - Assert.assertTrue(hasResultSet); - Assert.assertNotNull(statement.getResultSet()); - - int mockUpdateCount = 10; - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); - PowerMockito.when(mockCursor.getRowsAffected()).thenReturn((long) mockUpdateCount); - hasResultSet = statement.execute(sqlUpdate); - Assert.assertFalse(hasResultSet); - Assert.assertNull(statement.getResultSet()); - Assert.assertEquals(mockUpdateCount, statement.getUpdateCount()); - - //cursor is null - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(null); - try { - statement.execute(sqlUpdate); - Assert.fail("Should have thrown exception for cursor null"); - } catch (SQLException ex) { - Assert.assertEquals("Failed to execute this method", ex.getMessage()); - } - - } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + //read only + when(mockConn.isReadOnly()).thenReturn(true); + try { + statement.executeBatch(); + fail("Should have thrown exception for read only"); + } catch (SQLException ex) { + assertEquals(Constants.SQLExceptionMessages.READ_ONLY, ex.getMessage()); } } - @Test public void testGetUpdateCount() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFuture = PowerMockito.mock(SQLFuture.class); - - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFuture); - PowerMockito.when(mockSqlFuture.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + @Test + public void testExecute() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + List mockFieldList = PowerMockito.spy(new ArrayList<>()); + + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockConn.getAutoCommit()).thenReturn(true); + when(mockConn.getExecuteType()) + .thenReturn(Constants.QueryExecuteType.SIMPLE); + when(mockConn.isSimpleExecute()).thenReturn(true); + + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(mockFieldList); VitessStatement statement = new VitessStatement(mockConn); + int fieldSize = 5; + when(mockCursor.getFields()).thenReturn(mockFieldList); + doReturn(fieldSize).when(mockFieldList).size(); + doReturn(false).when(mockFieldList).isEmpty(); + + boolean hasResultSet = statement.execute(sqlSelect); + assertTrue(hasResultSet); + assertNotNull(statement.getResultSet()); + + hasResultSet = statement.execute(sqlShow); + assertTrue(hasResultSet); + assertNotNull(statement.getResultSet()); + + int mockUpdateCount = 10; + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + when(mockCursor.getRowsAffected()).thenReturn((long) mockUpdateCount); + hasResultSet = statement.execute(sqlUpdate); + assertFalse(hasResultSet); + assertNull(statement.getResultSet()); + assertEquals(mockUpdateCount, statement.getUpdateCount()); + + //cursor is null + when(mockSqlFutureCursor.checkedGet()).thenReturn(null); try { + statement.execute(sqlUpdate); + fail("Should have thrown exception for cursor null"); + } catch (SQLException ex) { + assertEquals("Failed to execute this method", ex.getMessage()); + } + } - PowerMockito.when(mockCursor.getRowsAffected()).thenReturn(10L); - int updateCount = statement.executeUpdate(sqlUpdate); - Assert.assertEquals(10L, updateCount); - Assert.assertEquals(10L, statement.getUpdateCount()); + @Test + public void testGetUpdateCount() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFuture = mock(SQLFuture.class); - // Truncated Update Count - PowerMockito.when(mockCursor.getRowsAffected()) - .thenReturn((long) Integer.MAX_VALUE + 10); - updateCount = statement.executeUpdate(sqlUpdate); - Assert.assertEquals(Integer.MAX_VALUE, updateCount); - Assert.assertEquals(Integer.MAX_VALUE, statement.getUpdateCount()); + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFuture); + when(mockSqlFuture.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); - PowerMockito.when(mockConn.isSimpleExecute()).thenReturn(true); - statement.executeQuery(sqlSelect); - Assert.assertEquals(-1, statement.getUpdateCount()); + VitessStatement statement = new VitessStatement(mockConn); + when(mockCursor.getRowsAffected()).thenReturn(10L); + int updateCount = statement.executeUpdate(sqlUpdate); + assertEquals(10L, updateCount); + assertEquals(10L, statement.getUpdateCount()); - } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); - } + // Truncated Update Count + when(mockCursor.getRowsAffected()) + .thenReturn((long) Integer.MAX_VALUE + 10); + updateCount = statement.executeUpdate(sqlUpdate); + assertEquals(Integer.MAX_VALUE, updateCount); + assertEquals(Integer.MAX_VALUE, statement.getUpdateCount()); + + when(mockConn.isSimpleExecute()).thenReturn(true); + statement.executeQuery(sqlSelect); + assertEquals(-1, statement.getUpdateCount()); } - @Test public void testClose() throws Exception { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockConn.getExecuteType()) - .thenReturn(Constants.QueryExecuteType.SIMPLE); - PowerMockito.when(mockConn.isSimpleExecute()).thenReturn(true); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + @Test + public void testClose() throws Exception { + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockConn.getExecuteType()) + .thenReturn(Constants.QueryExecuteType.SIMPLE); + when(mockConn.isSimpleExecute()).thenReturn(true); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); VitessStatement statement = new VitessStatement(mockConn); + ResultSet rs = statement.executeQuery(sqlSelect); + statement.close(); try { - ResultSet rs = statement.executeQuery(sqlSelect); - statement.close(); - try { - statement.executeQuery(sqlSelect); - Assert.fail("Should have thrown exception for statement closed"); - } catch (SQLException ex) { - Assert.assertEquals("Statement is closed", ex.getMessage()); - } - } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + statement.executeQuery(sqlSelect); + fail("Should have thrown exception for statement closed"); + } catch (SQLException ex) { + assertEquals("Statement is closed", ex.getMessage()); } } - @Test public void testGetMaxFieldSize() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testGetMaxFieldSize() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - Assert.assertEquals(65535, statement.getMaxFieldSize()); + assertEquals(65535, statement.getMaxFieldSize()); } - @Test public void testGetMaxRows() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testGetMaxRows() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); statement.setMaxRows(10); - Assert.assertEquals(10, statement.getMaxRows()); + assertEquals(10, statement.getMaxRows()); try { statement.setMaxRows(-1); - Assert.fail("Should have thrown exception for wrong value"); + fail("Should have thrown exception for wrong value"); } catch (SQLException ex) { - Assert.assertEquals("Illegal value for max row", ex.getMessage()); + assertEquals("Illegal value for max row", ex.getMessage()); } } - @Test public void testGetQueryTimeout() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - Mockito.when(mockConn.getTimeout()).thenReturn((long)Constants.DEFAULT_TIMEOUT); + @Test + public void testGetQueryTimeout() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); + Mockito.when(mockConn.getTimeout()).thenReturn((long) Constants.DEFAULT_TIMEOUT); VitessStatement statement = new VitessStatement(mockConn); - Assert.assertEquals(30, statement.getQueryTimeout()); + assertEquals(30, statement.getQueryTimeout()); } - @Test public void testGetQueryTimeoutZeroDefault() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testGetQueryTimeoutZeroDefault() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); Mockito.when(mockConn.getTimeout()).thenReturn(0L); VitessStatement statement = new VitessStatement(mockConn); - Assert.assertEquals(0, statement.getQueryTimeout()); + assertEquals(0, statement.getQueryTimeout()); } - @Test public void testSetQueryTimeout() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - Mockito.when(mockConn.getTimeout()).thenReturn((long)Constants.DEFAULT_TIMEOUT); + @Test + public void testSetQueryTimeout() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); + Mockito.when(mockConn.getTimeout()).thenReturn((long) Constants.DEFAULT_TIMEOUT); VitessStatement statement = new VitessStatement(mockConn); int queryTimeout = 10; statement.setQueryTimeout(queryTimeout); - Assert.assertEquals(queryTimeout, statement.getQueryTimeout()); + assertEquals(queryTimeout, statement.getQueryTimeout()); try { queryTimeout = -1; statement.setQueryTimeout(queryTimeout); - Assert.fail("Should have thrown exception for wrong value"); + fail("Should have thrown exception for wrong value"); } catch (SQLException ex) { - Assert.assertEquals("Illegal value for query timeout", ex.getMessage()); + assertEquals("Illegal value for query timeout", ex.getMessage()); } statement.setQueryTimeout(0); - Assert.assertEquals(30, statement.getQueryTimeout()); + assertEquals(30, statement.getQueryTimeout()); } - @Test public void testGetWarnings() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testGetWarnings() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - Assert.assertNull(statement.getWarnings()); + assertNull(statement.getWarnings()); } - @Test public void testGetFetchDirection() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testGetFetchDirection() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - Assert.assertEquals(ResultSet.FETCH_FORWARD, statement.getFetchDirection()); + assertEquals(ResultSet.FETCH_FORWARD, statement.getFetchDirection()); } - @Test public void testGetFetchSize() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testGetFetchSize() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - Assert.assertEquals(0, statement.getFetchSize()); + assertEquals(0, statement.getFetchSize()); } - @Test public void testGetResultSetConcurrency() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testGetResultSetConcurrency() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - Assert.assertEquals(ResultSet.CONCUR_READ_ONLY, statement.getResultSetConcurrency()); + assertEquals(ResultSet.CONCUR_READ_ONLY, statement.getResultSetConcurrency()); } - @Test public void testGetResultSetType() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testGetResultSetType() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - Assert.assertEquals(ResultSet.TYPE_FORWARD_ONLY, statement.getResultSetType()); + assertEquals(ResultSet.TYPE_FORWARD_ONLY, statement.getResultSetType()); } - @Test public void testIsClosed() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testIsClosed() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - Assert.assertFalse(statement.isClosed()); + assertFalse(statement.isClosed()); statement.close(); - Assert.assertTrue(statement.isClosed()); + assertTrue(statement.isClosed()); } - @Test public void testAutoGeneratedKeys() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); + @Test + public void testAutoGeneratedKeys() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockVtGateConn - .execute(Matchers.any(Context.class), Matchers.anyString(), Matchers.anyMap(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockVtGateConn + .execute(any(Context.class), anyString(), anyMap(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); VitessStatement statement = new VitessStatement(mockConn); - try { - - long expectedFirstGeneratedId = 121; - long[] expectedGeneratedIds = {121, 122, 123, 124, 125}; - int expectedAffectedRows = 5; - PowerMockito.when(mockCursor.getInsertId()).thenReturn(expectedFirstGeneratedId); - PowerMockito.when(mockCursor.getRowsAffected()) + long expectedFirstGeneratedId = 121; + long[] expectedGeneratedIds = {121, 122, 123, 124, 125}; + int expectedAffectedRows = 5; + when(mockCursor.getInsertId()).thenReturn(expectedFirstGeneratedId); + when(mockCursor.getRowsAffected()) .thenReturn(Long.valueOf(expectedAffectedRows)); - //Executing Insert Statement - int updateCount = statement.executeUpdate(sqlInsert, Statement.RETURN_GENERATED_KEYS); - Assert.assertEquals(expectedAffectedRows, updateCount); - - ResultSet rs = statement.getGeneratedKeys(); - int i = 0; - while (rs.next()) { - long generatedId = rs.getLong(1); - Assert.assertEquals(expectedGeneratedIds[i++], generatedId); - } - - //Fetching Generated Keys without notifying the driver - statement.executeUpdate(sqlInsert); - try { - statement.getGeneratedKeys(); - Assert.fail("Should have thrown exception for not setting autoGeneratedKey flag"); - } catch (SQLException ex) { - Assert.assertEquals("Generated keys not requested. You need to specify Statement" - + ".RETURN_GENERATED_KEYS to Statement.executeUpdate() or Connection.prepareStatement()", - ex.getMessage()); - } + //Executing Insert Statement + int updateCount = statement.executeUpdate(sqlInsert, Statement.RETURN_GENERATED_KEYS); + assertEquals(expectedAffectedRows, updateCount); - //Fetching Generated Keys on update query - expectedFirstGeneratedId = 0; - PowerMockito.when(mockCursor.getInsertId()).thenReturn(expectedFirstGeneratedId); - updateCount = statement.executeUpdate(sqlUpdate, Statement.RETURN_GENERATED_KEYS); - Assert.assertEquals(expectedAffectedRows, updateCount); - - rs = statement.getGeneratedKeys(); - Assert.assertFalse(rs.next()); + ResultSet rs = statement.getGeneratedKeys(); + int i = 0; + while (rs.next()) { + long generatedId = rs.getLong(1); + assertEquals(expectedGeneratedIds[i++], generatedId); + } - } catch (SQLException e) { - Assert.fail("Test failed " + e.getMessage()); + //Fetching Generated Keys without notifying the driver + statement.executeUpdate(sqlInsert); + try { + statement.getGeneratedKeys(); + fail("Should have thrown exception for not setting autoGeneratedKey flag"); + } catch (SQLException ex) { + assertEquals("Generated keys not requested. You need to specify Statement" + + ".RETURN_GENERATED_KEYS to Statement.executeUpdate() or Connection.prepareStatement()", + ex.getMessage()); } + + //Fetching Generated Keys on update query + expectedFirstGeneratedId = 0; + when(mockCursor.getInsertId()).thenReturn(expectedFirstGeneratedId); + updateCount = statement.executeUpdate(sqlUpdate, Statement.RETURN_GENERATED_KEYS); + assertEquals(expectedAffectedRows, updateCount); + + rs = statement.getGeneratedKeys(); + assertFalse(rs.next()); } - @Test public void testAddBatch() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testAddBatch() throws Exception { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); statement.addBatch(sqlInsert); - try { - Field privateStringField = VitessStatement.class.getDeclaredField("batchedArgs"); - privateStringField.setAccessible(true); - Assert - .assertEquals(sqlInsert, ((List) privateStringField.get(statement)).get(0)); - } catch (NoSuchFieldException e) { - Assert.fail("Private Field should exists: batchedArgs"); - } catch (IllegalAccessException e) { - Assert.fail("Private Field should be accessible: batchedArgs"); - } + Field privateStringField = VitessStatement.class.getDeclaredField("batchedArgs"); + privateStringField.setAccessible(true); + assertEquals(sqlInsert, ((List) privateStringField.get(statement)).get(0)); } - @Test public void testClearBatch() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testClearBatch() throws Exception { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); statement.addBatch(sqlInsert); statement.clearBatch(); - try { - Field privateStringField = VitessStatement.class.getDeclaredField("batchedArgs"); - privateStringField.setAccessible(true); - Assert.assertTrue(((List) privateStringField.get(statement)).isEmpty()); - } catch (NoSuchFieldException e) { - Assert.fail("Private Field should exists: batchedArgs"); - } catch (IllegalAccessException e) { - Assert.fail("Private Field should be accessible: batchedArgs"); - } + Field privateStringField = VitessStatement.class.getDeclaredField("batchedArgs"); + privateStringField.setAccessible(true); + assertTrue(((List) privateStringField.get(statement)).isEmpty()); } - @Test public void testExecuteBatch() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testExecuteBatch() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); int[] updateCounts = statement.executeBatch(); - Assert.assertEquals(0, updateCounts.length); + assertEquals(0, updateCounts.length); - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockConn.getAutoCommit()).thenReturn(true); + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockConn.getAutoCommit()).thenReturn(true); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - PowerMockito.when(mockVtGateConn - .executeBatch(Matchers.any(Context.class), Matchers.anyList(), Matchers.anyList(), - Matchers.any(VTSession.class))).thenReturn(mockSqlFutureCursor); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + when(mockVtGateConn + .executeBatch(any(Context.class), anyList(), anyList(), + any(VTSession.class))).thenReturn(mockSqlFutureCursor); List mockCursorWithErrorList = new ArrayList<>(); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursorWithErrorList); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursorWithErrorList); - CursorWithError mockCursorWithError1 = PowerMockito.mock(CursorWithError.class); - PowerMockito.when(mockCursorWithError1.getError()).thenReturn(null); - PowerMockito.when(mockCursorWithError1.getCursor()) - .thenReturn(PowerMockito.mock(Cursor.class)); + CursorWithError mockCursorWithError1 = mock(CursorWithError.class); + when(mockCursorWithError1.getError()).thenReturn(null); + when(mockCursorWithError1.getCursor()) + .thenReturn(mock(Cursor.class)); mockCursorWithErrorList.add(mockCursorWithError1); statement.addBatch(sqlUpdate); updateCounts = statement.executeBatch(); - Assert.assertEquals(1, updateCounts.length); + assertEquals(1, updateCounts.length); - CursorWithError mockCursorWithError2 = PowerMockito.mock(CursorWithError.class); + CursorWithError mockCursorWithError2 = mock(CursorWithError.class); Vtrpc.RPCError rpcError = Vtrpc.RPCError.newBuilder().setMessage("statement execute batch error").build(); - PowerMockito.when(mockCursorWithError2.getError()) - .thenReturn(rpcError); + when(mockCursorWithError2.getError()) + .thenReturn(rpcError); mockCursorWithErrorList.add(mockCursorWithError2); statement.addBatch(sqlUpdate); statement.addBatch(sqlUpdate); try { statement.executeBatch(); - Assert.fail("Should have thrown Exception"); + fail("Should have thrown Exception"); } catch (BatchUpdateException ex) { - Assert.assertEquals(rpcError.toString(), ex.getMessage()); - Assert.assertEquals(2, ex.getUpdateCounts().length); - Assert.assertEquals(Statement.EXECUTE_FAILED, ex.getUpdateCounts()[1]); + assertEquals(rpcError.toString(), ex.getMessage()); + assertEquals(2, ex.getUpdateCounts().length); + assertEquals(Statement.EXECUTE_FAILED, ex.getUpdateCounts()[1]); } } - @Test public void testBatchGeneratedKeys() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testBatchGeneratedKeys() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockConn.getAutoCommit()).thenReturn(true); - - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); - - PowerMockito.when(mockVtGateConn - .executeBatch(Matchers.any(Context.class), - Matchers.anyList(), - Matchers.anyList(), - Matchers.any(VTSession.class))) - .thenReturn(mockSqlFutureCursor); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockConn.getAutoCommit()).thenReturn(true); + + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + + when(mockVtGateConn + .executeBatch(any(Context.class), + anyList(), + anyList(), + any(VTSession.class))) + .thenReturn(mockSqlFutureCursor); List mockCursorWithErrorList = new ArrayList<>(); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursorWithErrorList); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursorWithErrorList); - CursorWithError mockCursorWithError = PowerMockito.mock(CursorWithError.class); - PowerMockito.when(mockCursorWithError.getError()).thenReturn(null); - PowerMockito.when(mockCursorWithError.getCursor()).thenReturn(mockCursor); + CursorWithError mockCursorWithError = mock(CursorWithError.class); + when(mockCursorWithError.getError()).thenReturn(null); + when(mockCursorWithError.getCursor()).thenReturn(mockCursor); mockCursorWithErrorList.add(mockCursorWithError); long expectedFirstGeneratedId = 121; long[] expectedGeneratedIds = {121, 122, 123, 124, 125}; - PowerMockito.when(mockCursor.getInsertId()).thenReturn(expectedFirstGeneratedId); - PowerMockito.when(mockCursor.getRowsAffected()).thenReturn(Long.valueOf(expectedGeneratedIds.length)); + when(mockCursor.getInsertId()).thenReturn(expectedFirstGeneratedId); + when(mockCursor.getRowsAffected()).thenReturn(Long.valueOf(expectedGeneratedIds.length)); statement.addBatch(sqlInsert); statement.executeBatch(); @@ -695,41 +675,42 @@ private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRun int i = 0; while (rs.next()) { long generatedId = rs.getLong(1); - Assert.assertEquals(expectedGeneratedIds[i++], generatedId); + assertEquals(expectedGeneratedIds[i++], generatedId); } } - @Test public void testBatchUpsertGeneratedKeys() throws SQLException { - VitessConnection mockConn = PowerMockito.mock(VitessConnection.class); + @Test + public void testBatchUpsertGeneratedKeys() throws SQLException { + VitessConnection mockConn = mock(VitessConnection.class); VitessStatement statement = new VitessStatement(mockConn); - Cursor mockCursor = PowerMockito.mock(Cursor.class); - SQLFuture mockSqlFutureCursor = PowerMockito.mock(SQLFuture.class); - - VTGateConnection mockVtGateConn = PowerMockito.mock(VTGateConnection.class); - PowerMockito.when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); - PowerMockito.when(mockConn.getAutoCommit()).thenReturn(true); - - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); - PowerMockito.when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); - - PowerMockito.when(mockVtGateConn - .executeBatch(Matchers.any(Context.class), - Matchers.anyList(), - Matchers.anyList(), - Matchers.any(VTSession.class))) - .thenReturn(mockSqlFutureCursor); + Cursor mockCursor = mock(Cursor.class); + SQLFuture mockSqlFutureCursor = mock(SQLFuture.class); + + VTGateConnection mockVtGateConn = mock(VTGateConnection.class); + when(mockConn.getVtGateConn()).thenReturn(mockVtGateConn); + when(mockConn.getAutoCommit()).thenReturn(true); + + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursor); + when(mockCursor.getFields()).thenReturn(Query.QueryResult.getDefaultInstance().getFieldsList()); + + when(mockVtGateConn + .executeBatch(any(Context.class), + anyList(), + anyList(), + any(VTSession.class))) + .thenReturn(mockSqlFutureCursor); List mockCursorWithErrorList = new ArrayList<>(); - PowerMockito.when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursorWithErrorList); + when(mockSqlFutureCursor.checkedGet()).thenReturn(mockCursorWithErrorList); - CursorWithError mockCursorWithError = PowerMockito.mock(CursorWithError.class); - PowerMockito.when(mockCursorWithError.getError()).thenReturn(null); - PowerMockito.when(mockCursorWithError.getCursor()).thenReturn(mockCursor); + CursorWithError mockCursorWithError = mock(CursorWithError.class); + when(mockCursorWithError.getError()).thenReturn(null); + when(mockCursorWithError.getCursor()).thenReturn(mockCursor); mockCursorWithErrorList.add(mockCursorWithError); long expectedFirstGeneratedId = 121; long[] expectedGeneratedIds = {121, 122}; - PowerMockito.when(mockCursor.getInsertId()).thenReturn(expectedFirstGeneratedId); - PowerMockito.when(mockCursor.getRowsAffected()).thenReturn(Long.valueOf(expectedGeneratedIds.length)); + when(mockCursor.getInsertId()).thenReturn(expectedFirstGeneratedId); + when(mockCursor.getRowsAffected()).thenReturn(Long.valueOf(expectedGeneratedIds.length)); statement.addBatch(sqlUpsert); statement.executeBatch(); @@ -738,19 +719,19 @@ private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRun int i = 0; while (rs.next()) { long generatedId = rs.getLong(1); - Assert.assertEquals(expectedGeneratedIds[i], generatedId); - Assert.assertEquals(i, 0); // we should only have one + assertEquals(expectedGeneratedIds[i], generatedId); + assertEquals(i, 0); // we should only have one i++; } VitessStatement noUpdate = new VitessStatement(mockConn); - PowerMockito.when(mockCursor.getInsertId()).thenReturn(0L); - PowerMockito.when(mockCursor.getRowsAffected()).thenReturn(1L); + when(mockCursor.getInsertId()).thenReturn(0L); + when(mockCursor.getRowsAffected()).thenReturn(1L); noUpdate.addBatch(sqlUpsert); noUpdate.executeBatch(); ResultSet empty = noUpdate.getGeneratedKeys(); - Assert.assertFalse(empty.next()); + assertFalse(empty.next()); } } diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessVTGateManagerTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessVTGateManagerTest.java index b8a6e9253e6..86eb64f43d7 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessVTGateManagerTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessVTGateManagerTest.java @@ -16,11 +16,6 @@ package io.vitess.jdbc; -import java.io.IOException; -import java.lang.reflect.Field; -import java.sql.SQLException; -import java.util.Properties; -import java.util.concurrent.ConcurrentHashMap; import org.joda.time.Duration; import org.junit.Assert; import org.junit.Test; @@ -31,6 +26,12 @@ import io.vitess.client.grpc.GrpcClientFactory; import io.vitess.proto.Vtrpc; +import java.io.IOException; +import java.lang.reflect.Field; +import java.sql.SQLException; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; + /** * Created by naveen.nahata on 29/02/16. */ diff --git a/test/base_sharding.py b/test/base_sharding.py index 89b1490804f..3fcc92c1fb1 100644 --- a/test/base_sharding.py +++ b/test/base_sharding.py @@ -28,6 +28,7 @@ keyspace_id_type = keyrange_constants.KIT_UINT64 use_rbr = False +use_multi_split_diff = False pack_keyspace_id = struct.Struct('!Q').pack # fixed_parent_id is used as fixed value for the "parent_id" column in all rows. diff --git a/test/config.json b/test/config.json index 2222f062857..d629184ec8e 100644 --- a/test/config.json +++ b/test/config.json @@ -90,17 +90,6 @@ "worker_test" ] }, - "initial_sharding_l2vtgate": { - "File": "initial_sharding_l2vtgate.py", - "Args": [], - "Command": [], - "Manual": false, - "Shard": 2, - "RetryMax": 0, - "Tags": [ - "worker_test" - ] - }, "legacy_resharding": { "File": "legacy_resharding.py", "Args": [], @@ -426,17 +415,6 @@ "site_test" ] }, - "vtgatev2_l2vtgate": { - "File": "vtgatev2_l2vtgate_test.py", - "Args": [], - "Command": [], - "Manual": false, - "Shard": 1, - "RetryMax": 0, - "Tags": [ - "site_test" - ] - }, "vtgatev3": { "File": "vtgatev3_test.py", "Args": [], diff --git a/test/initial_sharding.py b/test/initial_sharding.py index 0a101192ae0..cb4fb3428b7 100755 --- a/test/initial_sharding.py +++ b/test/initial_sharding.py @@ -28,7 +28,6 @@ import logging import unittest - from vtdb import keyrange_constants import base_sharding @@ -36,16 +35,6 @@ import tablet import utils -# use_l2vtgate is set if we want to use l2vtgate processes. -# We'll set them up to have: -# l2vtgate1: covers the initial shard, and -80 -# l2vtgate2: covers 80- -use_l2vtgate = False - -# the l2vtgate processes, if applicable -l2vtgate1 = None -l2vtgate2 = None - # initial shard, covers everything shard_master = tablet.Tablet() shard_replica = tablet.Tablet() @@ -218,8 +207,6 @@ def _check_lots_not_present(self, count, base=0): should_be_here=False) def test_resharding(self): - global l2vtgate1, l2vtgate2 - # create the keyspace with just one shard shard_master.init_tablet( 'replica', @@ -281,33 +268,12 @@ def test_resharding(self): # We must start vtgate after tablets are up, or else wait until 1min refresh # (that is the tablet_refresh_interval parameter for discovery gateway) # we want cache_ttl at zero so we re-read the topology for every test query. - if use_l2vtgate: - l2vtgate1 = utils.VtGate() - l2vtgate1.start(extra_args=['--enable_forwarding'], tablets= - [shard_master, shard_replica, shard_rdonly1]) - l2vtgate1.wait_for_endpoints('test_keyspace.0.master', 1) - l2vtgate1.wait_for_endpoints('test_keyspace.0.replica', 1) - l2vtgate1.wait_for_endpoints('test_keyspace.0.rdonly', 1) - - _, l2vtgate1_addr = l2vtgate1.rpc_endpoint() - - # Clear utils.vtgate, so it doesn't point to the previous l2vtgate1. - utils.vtgate = None - utils.VtGate().start(cache_ttl='0', l2vtgates=[l2vtgate1_addr,], - extra_args=['-disable_local_gateway']) - utils.vtgate.wait_for_endpoints('test_keyspace.0.master', 1, - var='L2VtgateConnections') - utils.vtgate.wait_for_endpoints('test_keyspace.0.replica', 1, - var='L2VtgateConnections') - utils.vtgate.wait_for_endpoints('test_keyspace.0.rdonly', 1, - var='L2VtgateConnections') - else: - utils.VtGate().start(cache_ttl='0', tablets=[ - shard_master, shard_replica, shard_rdonly1]) - utils.vtgate.wait_for_endpoints('test_keyspace.0.master', 1) - utils.vtgate.wait_for_endpoints('test_keyspace.0.replica', 1) - utils.vtgate.wait_for_endpoints('test_keyspace.0.rdonly', 1) + utils.VtGate().start(cache_ttl='0', tablets=[ + shard_master, shard_replica, shard_rdonly1]) + utils.vtgate.wait_for_endpoints('test_keyspace.0.master', 1) + utils.vtgate.wait_for_endpoints('test_keyspace.0.replica', 1) + utils.vtgate.wait_for_endpoints('test_keyspace.0.rdonly', 1) # check the Map Reduce API works correctly, should use ExecuteShards, # as we're not sharded yet. @@ -392,62 +358,13 @@ def test_resharding(self): # must restart vtgate after tablets are up, or else wait until 1min refresh # we want cache_ttl at zero so we re-read the topology for every test query. utils.vtgate.kill() - if use_l2vtgate: - l2vtgate1.kill() - - l2vtgate1 = utils.VtGate() - l2vtgate1.start(extra_args=['--enable_forwarding', - '-tablet_filters', - 'test_keyspace|0,test_keyspace|-80'], - tablets=[shard_master, shard_replica, shard_rdonly1, - shard_0_master, shard_0_replica, - shard_0_rdonly1]) - l2vtgate1.wait_for_endpoints('test_keyspace.0.master', 1) - l2vtgate1.wait_for_endpoints('test_keyspace.0.replica', 1) - l2vtgate1.wait_for_endpoints('test_keyspace.0.rdonly', 1) - l2vtgate1.wait_for_endpoints('test_keyspace.-80.master', 1) - l2vtgate1.wait_for_endpoints('test_keyspace.-80.replica', 1) - l2vtgate1.wait_for_endpoints('test_keyspace.-80.rdonly', 1) - l2vtgate1.verify_no_endpoint('test_keyspace.80-.master') - l2vtgate1.verify_no_endpoint('test_keyspace.80-.replica') - l2vtgate1.verify_no_endpoint('test_keyspace.80-.rdonly') - - # FIXME(alainjobart) we clear tablet_types_to_wait, as this - # l2vtgate2 doesn't serve the current test_keyspace shard, which - # is test_keyspace.0. This is not ideal, we should re-work - # which keyspace/shard a l2vtgate can wait for, as the ones - # filtered by tablet_filters. - l2vtgate2 = utils.VtGate() - l2vtgate2.start(extra_args=['--enable_forwarding', - '-tablet_filters', - 'test_keyspace|80-'], tablets= - [shard_1_master, shard_1_replica, shard_1_rdonly1], - tablet_types_to_wait='') - l2vtgate2.wait_for_endpoints('test_keyspace.80-.master', 1) - l2vtgate2.wait_for_endpoints('test_keyspace.80-.replica', 1) - l2vtgate2.wait_for_endpoints('test_keyspace.80-.rdonly', 1) - l2vtgate2.verify_no_endpoint('test_keyspace.0.master') - l2vtgate2.verify_no_endpoint('test_keyspace.0.replica') - l2vtgate2.verify_no_endpoint('test_keyspace.0.rdonly') - l2vtgate2.verify_no_endpoint('test_keyspace.-80.master') - l2vtgate2.verify_no_endpoint('test_keyspace.-80.replica') - l2vtgate2.verify_no_endpoint('test_keyspace.-80.rdonly') - - _, l2vtgate1_addr = l2vtgate1.rpc_endpoint() - _, l2vtgate2_addr = l2vtgate2.rpc_endpoint() - utils.vtgate = None - utils.VtGate().start(cache_ttl='0', l2vtgates=[l2vtgate1_addr, - l2vtgate2_addr,], - extra_args=['-disable_local_gateway']) - var = 'L2VtgateConnections' - else: - utils.vtgate = None - utils.VtGate().start(cache_ttl='0', tablets=[ - shard_master, shard_replica, shard_rdonly1, - shard_0_master, shard_0_replica, shard_0_rdonly1, - shard_1_master, shard_1_replica, shard_1_rdonly1]) - var = None + utils.vtgate = None + utils.VtGate().start(cache_ttl='0', tablets=[ + shard_master, shard_replica, shard_rdonly1, + shard_0_master, shard_0_replica, shard_0_rdonly1, + shard_1_master, shard_1_replica, shard_1_rdonly1]) + var = None # Wait for the endpoints, either local or remote. utils.vtgate.wait_for_endpoints('test_keyspace.0.master', 1, var=var) @@ -577,23 +494,32 @@ def test_resharding(self): min_statements=1000, min_transactions=1000) # use vtworker to compare the data - logging.debug('Running vtworker SplitDiff for -80') for t in [shard_0_rdonly1, shard_1_rdonly1]: utils.run_vtctl(['RunHealthCheck', t.tablet_alias]) - utils.run_vtworker(['-cell', 'test_nj', - '--use_v3_resharding_mode=false', - 'SplitDiff', - '--min_healthy_rdonly_tablets', '1', - 'test_keyspace/-80'], - auto_log=True) - - logging.debug('Running vtworker SplitDiff for 80-') - utils.run_vtworker(['-cell', 'test_nj', - '--use_v3_resharding_mode=false', - 'SplitDiff', - '--min_healthy_rdonly_tablets', '1', - 'test_keyspace/80-'], - auto_log=True) + + if base_sharding.use_multi_split_diff: + logging.debug('Running vtworker MultiSplitDiff for 0') + utils.run_vtworker(['-cell', 'test_nj', + '--use_v3_resharding_mode=false', + 'MultiSplitDiff', + '--min_healthy_rdonly_tablets', '1', + 'test_keyspace/0'], + auto_log=True) + else: + logging.debug('Running vtworker SplitDiff for -80') + utils.run_vtworker(['-cell', 'test_nj', + '--use_v3_resharding_mode=false', + 'SplitDiff', + '--min_healthy_rdonly_tablets', '1', + 'test_keyspace/-80'], + auto_log=True) + logging.debug('Running vtworker SplitDiff for 80-') + utils.run_vtworker(['-cell', 'test_nj', + '--use_v3_resharding_mode=false', + 'SplitDiff', + '--min_healthy_rdonly_tablets', '1', + 'test_keyspace/80-'], + auto_log=True) utils.pause('Good time to test vtworker for diffs') @@ -618,12 +544,9 @@ def test_resharding(self): # make sure rdonly tablets are back to serving before hitting vtgate. for t in [shard_0_rdonly1, shard_1_rdonly1]: t.wait_for_vttablet_state('SERVING') - if use_l2vtgate: - l2vtgate1.wait_for_endpoints('test_keyspace.-80.rdonly', 1) - l2vtgate2.wait_for_endpoints('test_keyspace.80-.rdonly', 1) - else: - utils.vtgate.wait_for_endpoints('test_keyspace.-80.rdonly', 1) - utils.vtgate.wait_for_endpoints('test_keyspace.80-.rdonly', 1) + + utils.vtgate.wait_for_endpoints('test_keyspace.-80.rdonly', 1) + utils.vtgate.wait_for_endpoints('test_keyspace.80-.rdonly', 1) # check the Map Reduce API works correctly, should use ExecuteKeyRanges # on both destination shards now. diff --git a/test/vtgatev2_l2vtgate_test.py b/test/initial_sharding_multi_split_diff.py similarity index 66% rename from test/vtgatev2_l2vtgate_test.py rename to test/initial_sharding_multi_split_diff.py index cf869a2701e..d21c41cf162 100755 --- a/test/vtgatev2_l2vtgate_test.py +++ b/test/initial_sharding_multi_split_diff.py @@ -1,26 +1,29 @@ #!/usr/bin/env python # # Copyright 2017 Google Inc. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Re-runs vtgatev2_test.py with a l2vtgate process.""" +"""Re-runs initial_sharding.py using multiple-split-diff.""" + +from vtdb import keyrange_constants +import base_sharding +import initial_sharding import utils -import vtgatev2_test -# This test is just re-running an entire vtgatev2_test.py with a -# l2vtgate process in the middle. +# this test is just re-running an entire initial_sharding.py with a +# varbinary keyspace_id if __name__ == '__main__': - vtgatev2_test.use_l2vtgate = True - utils.main(vtgatev2_test) + base_sharding.use_multi_split_diff = True + utils.main(initial_sharding) diff --git a/test/resharding.py b/test/resharding.py index 86c55480a7b..a942c6dd466 100755 --- a/test/resharding.py +++ b/test/resharding.py @@ -880,14 +880,25 @@ def test_resharding(self): # use vtworker to compare the data (after health-checking the destination # rdonly tablets so discovery works) utils.run_vtctl(['RunHealthCheck', shard_3_rdonly1.tablet_alias]) - logging.debug('Running vtworker SplitDiff') - utils.run_vtworker(['-cell', 'test_nj', - '--use_v3_resharding_mode=false', - 'SplitDiff', - '--exclude_tables', 'unrelated', - '--min_healthy_rdonly_tablets', '1', - 'test_keyspace/c0-'], - auto_log=True) + + if base_sharding.use_multi_split_diff: + logging.debug('Running vtworker MultiSplitDiff') + utils.run_vtworker(['-cell', 'test_nj', + '--use_v3_resharding_mode=false', + 'MultiSplitDiff', + '--exclude_tables', 'unrelated', + '--min_healthy_rdonly_tablets', '1', + 'test_keyspace/80-'], + auto_log=True) + else: + logging.debug('Running vtworker SplitDiff') + utils.run_vtworker(['-cell', 'test_nj', + '--use_v3_resharding_mode=false', + 'SplitDiff', + '--exclude_tables', 'unrelated', + '--min_healthy_rdonly_tablets', '1', + 'test_keyspace/c0-'], + auto_log=True) utils.run_vtctl(['ChangeSlaveType', shard_1_rdonly1.tablet_alias, 'rdonly'], auto_log=True) utils.run_vtctl(['ChangeSlaveType', shard_3_rdonly1.tablet_alias, 'rdonly'], @@ -1067,14 +1078,25 @@ def test_resharding(self): self._check_lots_timeout(3000, 80, 10, base=2000) # use vtworker to compare the data again - logging.debug('Running vtworker SplitDiff') - utils.run_vtworker(['-cell', 'test_nj', - '--use_v3_resharding_mode=false', - 'SplitDiff', - '--exclude_tables', 'unrelated', - '--min_healthy_rdonly_tablets', '1', - 'test_keyspace/c0-'], - auto_log=True) + if base_sharding.use_multi_split_diff: + logging.debug('Running vtworker MultiSplitDiff') + utils.run_vtworker(['-cell', 'test_nj', + '--use_v3_resharding_mode=false', + 'MultiSplitDiff', + '--exclude_tables', 'unrelated', + '--min_healthy_rdonly_tablets', '1', + 'test_keyspace/80-'], + auto_log=True) + else: + logging.debug('Running vtworker SplitDiff') + utils.run_vtworker(['-cell', 'test_nj', + '--use_v3_resharding_mode=false', + 'SplitDiff', + '--exclude_tables', 'unrelated', + '--min_healthy_rdonly_tablets', '1', + 'test_keyspace/c0-'], + auto_log=True) + utils.run_vtctl(['ChangeSlaveType', shard_1_rdonly1.tablet_alias, 'rdonly'], auto_log=True) utils.run_vtctl(['ChangeSlaveType', shard_3_rdonly1.tablet_alias, 'rdonly'], diff --git a/test/initial_sharding_l2vtgate.py b/test/resharding_multi_split_diff.py similarity index 79% rename from test/initial_sharding_l2vtgate.py rename to test/resharding_multi_split_diff.py index 5a2a3df68a8..64413128816 100755 --- a/test/initial_sharding_l2vtgate.py +++ b/test/resharding_multi_split_diff.py @@ -14,11 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Re-runs initial_sharding.py with a l2vtgate process.""" +"""Re-runs resharding.py with MultiSplitDiff.""" -import initial_sharding +import base_sharding +import resharding import utils if __name__ == '__main__': - initial_sharding.use_l2vtgate = True - utils.main(initial_sharding) + base_sharding.use_multi_split_diff = True + utils.main(resharding) diff --git a/test/vtgatev2_test.py b/test/vtgatev2_test.py index a95d640202f..4d45f44d307 100755 --- a/test/vtgatev2_test.py +++ b/test/vtgatev2_test.py @@ -35,16 +35,6 @@ from vtdb import vtgate_client from vtdb import vtgate_cursor -# use_l2vtgate controls if we're adding a l2vtgate process in between -# vtgate and the tablets. -use_l2vtgate = False - -# l2vtgate is the L2VTGate object, if any -l2vtgate = None - -# l2vtgate_addr is the address of the l2vtgate to send to vtgate -l2vtgate_addr = None - shard_0_master = tablet.Tablet() shard_0_replica1 = tablet.Tablet() shard_0_replica2 = tablet.Tablet() @@ -154,8 +144,6 @@ def tearDownModule(): logging.debug('Tearing down the servers and setup') if utils.vtgate: utils.vtgate.kill() - if l2vtgate: - l2vtgate.kill() tablet.kill_tablets([shard_0_master, shard_0_replica1, shard_0_replica2, shard_1_master, @@ -184,7 +172,6 @@ def tearDownModule(): def setup_tablets(): """Start up a master mysql and vttablet.""" - global l2vtgate, l2vtgate_addr logging.debug('Setting up tablets') utils.run_vtctl(['CreateKeyspace', KEYSPACE_NAME]) @@ -252,46 +239,22 @@ def setup_tablets(): 'Partitions(rdonly): -80 80-\n' 'Partitions(replica): -80 80-\n') - if use_l2vtgate: - l2vtgate = utils.VtGate() - l2vtgate.start(extra_args=['--enable_forwarding'], tablets= - [shard_0_master, shard_0_replica1, shard_0_replica2, - shard_1_master, shard_1_replica1, shard_1_replica2]) - _, l2vtgate_addr = l2vtgate.rpc_endpoint() - - # Clear utils.vtgate, so it doesn't point to the previous l2vtgate. - utils.vtgate = None - - # This vgate doesn't watch any local tablets, so we disable_local_gateway. - utils.VtGate().start(l2vtgates=[l2vtgate_addr,], - extra_args=['-disable_local_gateway']) - else: - utils.VtGate().start(tablets= - [shard_0_master, shard_0_replica1, shard_0_replica2, - shard_1_master, shard_1_replica1, shard_1_replica2]) + utils.VtGate().start(tablets= + [shard_0_master, shard_0_replica1, shard_0_replica2, + shard_1_master, shard_1_replica1, shard_1_replica2]) wait_for_all_tablets() def restart_vtgate(port): - if use_l2vtgate: - utils.VtGate(port=port).start(l2vtgates=[l2vtgate_addr,], - extra_args=['-disable_local_gateway']) - else: - utils.VtGate(port=port).start( - tablets=[shard_0_master, shard_0_replica1, shard_0_replica2, - shard_1_master, shard_1_replica1, shard_1_replica2]) + utils.VtGate(port=port).start( + tablets=[shard_0_master, shard_0_replica1, shard_0_replica2, + shard_1_master, shard_1_replica1, shard_1_replica2]) def wait_for_endpoints(name, count): - if use_l2vtgate: - # Wait for the l2vtgate to have a healthy connection. - l2vtgate.wait_for_endpoints(name, count) - # Also wait for vtgate to have received the remote healthy connection. - utils.vtgate.wait_for_endpoints(name, count, var='L2VtgateConnections') - else: - utils.vtgate.wait_for_endpoints(name, count) + utils.vtgate.wait_for_endpoints(name, count) def wait_for_all_tablets(): @@ -411,17 +374,12 @@ def test_query_routing(self): self.assertIn(kid, SHARD_KID_MAP[SHARD_NAMES[shard_index]]) # Do a cross shard range query and assert all rows are fetched. - # Use this test to also test the vtgate vars (and l2vtgate vars if - # applicable) are correctly updated. + # Use this test to also test the vtgate vars are correctly updated. v = utils.vtgate.get_vars() key0 = 'Execute.' + KEYSPACE_NAME + '.' + SHARD_NAMES[0] + '.master' key1 = 'Execute.' + KEYSPACE_NAME + '.' + SHARD_NAMES[1] + '.master' before0 = v['VttabletCall']['Histograms'][key0]['Count'] before1 = v['VttabletCall']['Histograms'][key1]['Count'] - if use_l2vtgate: - lv = l2vtgate.get_vars() - lbefore0 = lv['QueryServiceCall']['Histograms'][key0]['Count'] - lbefore1 = lv['QueryServiceCall']['Histograms'][key1]['Count'] cursor = vtgate_conn.cursor( tablet_type='master', keyspace=KEYSPACE_NAME, @@ -435,12 +393,6 @@ def test_query_routing(self): after1 = v['VttabletCall']['Histograms'][key1]['Count'] self.assertEqual(after0 - before0, 1) self.assertEqual(after1 - before1, 1) - if use_l2vtgate: - lv = l2vtgate.get_vars() - lafter0 = lv['QueryServiceCall']['Histograms'][key0]['Count'] - lafter1 = lv['QueryServiceCall']['Histograms'][key1]['Count'] - self.assertEqual(lafter0 - lbefore0, 1) - self.assertEqual(lafter1 - lbefore1, 1) def test_rollback(self): vtgate_conn = get_connection() diff --git a/tools/build_version_flags.sh b/tools/build_version_flags.sh index ae5aae32e97..dc0ad5d630b 100755 --- a/tools/build_version_flags.sh +++ b/tools/build_version_flags.sh @@ -17,11 +17,24 @@ DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) source $DIR/shell_functions.inc +# Normal builds run directly against the git repo, but when packaging (for example with rpms) +# a tar ball might be used, which will prevent the git metadata from being available. +# Should this be the case then allow environment variables to be used to source +# this information instead. +_build_git_rev=$(git rev-parse --short HEAD) +if [ -z "$_build_git_rev" ]; then + _build_git_rev="$BUILD_GIT_REV" +fi +_build_git_branch=$(git rev-parse --abbrev-ref HEAD) +if [ -z "$_build_git_branch" ]; then + _build_git_branch="$BUILD_GIT_BRANCH" +fi + echo "\ -X 'vitess.io/vitess/go/vt/servenv.buildHost=$(hostname)' \ -X 'vitess.io/vitess/go/vt/servenv.buildUser=$(whoami)' \ - -X 'vitess.io/vitess/go/vt/servenv.buildGitRev=$(git rev-parse --short HEAD)' \ - -X 'vitess.io/vitess/go/vt/servenv.buildGitBranch=$(git rev-parse --abbrev-ref HEAD)' \ + -X 'vitess.io/vitess/go/vt/servenv.buildGitRev=${_build_git_rev}' \ + -X 'vitess.io/vitess/go/vt/servenv.buildGitBranch=${_build_git_branch}' \ -X 'vitess.io/vitess/go/vt/servenv.buildTime=$(LC_ALL=C date)' \ -X 'vitess.io/vitess/go/vt/servenv.jenkinsBuildNumberStr=${BUILD_NUMBER}' \ "