Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes tests and plumbs more ServerAddress vs. string parameters. #127

Merged
merged 9 commits into from
Jul 1, 2016
Merged
48 changes: 23 additions & 25 deletions inmem_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (

// NewInmemAddr returns a new in-memory addr with
// a randomly generate UUID as the ID.
func NewInmemAddr() string {
return generateUUID()
func NewInmemAddr() ServerAddress {
return ServerAddress(generateUUID())
}

// inmemPipeline is used to pipeline requests for the in-mem transport.
type inmemPipeline struct {
trans *InmemTransport
peer *InmemTransport
peerAddr string
peerAddr ServerAddress

doneCh chan AppendFuture
inprogressCh chan *inmemPipelineInflight
Expand All @@ -37,22 +37,22 @@ type inmemPipelineInflight struct {
type InmemTransport struct {
sync.RWMutex
consumerCh chan RPC
localAddr string
peers map[string]*InmemTransport
localAddr ServerAddress
peers map[ServerAddress]*InmemTransport
pipelines []*inmemPipeline
timeout time.Duration
}

// NewInmemTransport is used to initialize a new transport
// and generates a random local address if none is specified
func NewInmemTransport(addr string) (string, *InmemTransport) {
if addr == "" {
func NewInmemTransport(addr ServerAddress) (ServerAddress, *InmemTransport) {
if string(addr) == "" {
addr = NewInmemAddr()
}
trans := &InmemTransport{
consumerCh: make(chan RPC, 16),
localAddr: addr,
peers: make(map[string]*InmemTransport),
peers: make(map[ServerAddress]*InmemTransport),
timeout: 50 * time.Millisecond,
}
return addr, trans
Expand All @@ -69,13 +69,13 @@ func (i *InmemTransport) Consumer() <-chan RPC {
}

// LocalAddr implements the Transport interface.
func (i *InmemTransport) LocalAddr() string {
func (i *InmemTransport) LocalAddr() ServerAddress {
return i.localAddr
}

// AppendEntriesPipeline returns an interface that can be used to pipeline
// AppendEntries requests.
func (i *InmemTransport) AppendEntriesPipeline(target string) (AppendPipeline, error) {
func (i *InmemTransport) AppendEntriesPipeline(target ServerAddress) (AppendPipeline, error) {
i.RLock()
peer, ok := i.peers[target]
i.RUnlock()
Expand All @@ -90,7 +90,7 @@ func (i *InmemTransport) AppendEntriesPipeline(target string) (AppendPipeline, e
}

// AppendEntries implements the Transport interface.
func (i *InmemTransport) AppendEntries(target string, args *AppendEntriesRequest, resp *AppendEntriesResponse) error {
func (i *InmemTransport) AppendEntries(target ServerAddress, args *AppendEntriesRequest, resp *AppendEntriesResponse) error {
rpcResp, err := i.makeRPC(target, args, nil, i.timeout)
if err != nil {
return err
Expand All @@ -103,7 +103,7 @@ func (i *InmemTransport) AppendEntries(target string, args *AppendEntriesRequest
}

// RequestVote implements the Transport interface.
func (i *InmemTransport) RequestVote(target string, args *RequestVoteRequest, resp *RequestVoteResponse) error {
func (i *InmemTransport) RequestVote(target ServerAddress, args *RequestVoteRequest, resp *RequestVoteResponse) error {
rpcResp, err := i.makeRPC(target, args, nil, i.timeout)
if err != nil {
return err
Expand All @@ -116,7 +116,7 @@ func (i *InmemTransport) RequestVote(target string, args *RequestVoteRequest, re
}

// InstallSnapshot implements the Transport interface.
func (i *InmemTransport) InstallSnapshot(target string, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error {
func (i *InmemTransport) InstallSnapshot(target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error {
rpcResp, err := i.makeRPC(target, args, data, 10*i.timeout)
if err != nil {
return err
Expand All @@ -128,7 +128,7 @@ func (i *InmemTransport) InstallSnapshot(target string, args *InstallSnapshotReq
return nil
}

func (i *InmemTransport) makeRPC(target string, args interface{}, r io.Reader, timeout time.Duration) (rpcResp RPCResponse, err error) {
func (i *InmemTransport) makeRPC(target ServerAddress, args interface{}, r io.Reader, timeout time.Duration) (rpcResp RPCResponse, err error) {
i.RLock()
peer, ok := i.peers[target]
i.RUnlock()
Expand Down Expand Up @@ -158,29 +158,27 @@ func (i *InmemTransport) makeRPC(target string, args interface{}, r io.Reader, t
return
}

// EncodePeer implements the Transport interface. It uses the UUID as the
// address directly.
func (i *InmemTransport) EncodePeer(p string) []byte {
// EncodePeer implements the Transport interface.
func (i *InmemTransport) EncodePeer(p ServerAddress) []byte {
return []byte(p)
}

// DecodePeer implements the Transport interface. It wraps the UUID in an
// InmemAddr.
func (i *InmemTransport) DecodePeer(buf []byte) string {
return string(buf)
// DecodePeer implements the Transport interface.
func (i *InmemTransport) DecodePeer(buf []byte) ServerAddress {
return ServerAddress(buf)
}

// Connect is used to connect this transport to another transport for
// a given peer name. This allows for local routing.
func (i *InmemTransport) Connect(peer string, t Transport) {
func (i *InmemTransport) Connect(peer ServerAddress, t Transport) {
trans := t.(*InmemTransport)
i.Lock()
defer i.Unlock()
i.peers[peer] = trans
}

// Disconnect is used to remove the ability to route to a given peer.
func (i *InmemTransport) Disconnect(peer string) {
func (i *InmemTransport) Disconnect(peer ServerAddress) {
i.Lock()
defer i.Unlock()
delete(i.peers, peer)
Expand All @@ -202,7 +200,7 @@ func (i *InmemTransport) Disconnect(peer string) {
func (i *InmemTransport) DisconnectAll() {
i.Lock()
defer i.Unlock()
i.peers = make(map[string]*InmemTransport)
i.peers = make(map[ServerAddress]*InmemTransport)

// Handle pipelines
for _, pipeline := range i.pipelines {
Expand All @@ -217,7 +215,7 @@ func (i *InmemTransport) Close() error {
return nil
}

func newInmemPipeline(trans *InmemTransport, peer *InmemTransport, addr string) *inmemPipeline {
func newInmemPipeline(trans *InmemTransport, peer *InmemTransport, addr ServerAddress) *inmemPipeline {
i := &inmemPipeline{
trans: trans,
peer: peer,
Expand Down
15 changes: 13 additions & 2 deletions integ_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ func MakeRaft(t *testing.T, conf *Config) *RaftEnv {
}
env.trans = trans

var configuration Configuration
configuration.Servers = append(configuration.Servers, Server{
Suffrage: Voter,
ID: ServerID(trans.LocalAddr()),
Address: trans.LocalAddr(),
})
err = BootstrapCluster(conf, stable, stable, snap, configuration)
if err != nil {
t.Fatalf("err: %v", err)
}

log.Printf("[INFO] Starting node at %v", trans.LocalAddr())
raft, err := NewRaft(conf, env.fsm, stable, stable, snap, trans)
if err != nil {
Expand Down Expand Up @@ -238,8 +249,8 @@ func TestRaft_Integ(t *testing.T) {
}

// Remove the old nodes
NoErr(WaitFuture(leader.raft.RemovePeer(string(rm1.raft.localAddr)), t), t)
NoErr(WaitFuture(leader.raft.RemovePeer(string(rm2.raft.localAddr)), t), t)
NoErr(WaitFuture(leader.raft.RemovePeer(rm1.raft.localAddr), t), t)
NoErr(WaitFuture(leader.raft.RemovePeer(rm2.raft.localAddr), t), t)

// Shoot the leader
env1.Release()
Expand Down
32 changes: 16 additions & 16 deletions net_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ is not known if there is an error.

*/
type NetworkTransport struct {
connPool map[string][]*netConn
connPool map[ServerAddress][]*netConn
connPoolLock sync.Mutex

consumeCh chan RPC
Expand Down Expand Up @@ -84,11 +84,11 @@ type StreamLayer interface {
net.Listener

// Dial is used to create a new outgoing connection
Dial(address string, timeout time.Duration) (net.Conn, error)
Dial(address ServerAddress, timeout time.Duration) (net.Conn, error)
}

type netConn struct {
target string
target ServerAddress
conn net.Conn
r *bufio.Reader
w *bufio.Writer
Expand Down Expand Up @@ -142,7 +142,7 @@ func NewNetworkTransportWithLogger(
logger = log.New(os.Stderr, "", log.LstdFlags)
}
trans := &NetworkTransport{
connPool: make(map[string][]*netConn),
connPool: make(map[ServerAddress][]*netConn),
consumeCh: make(chan RPC),
logger: logger,
maxPool: maxPool,
Expand Down Expand Up @@ -183,8 +183,8 @@ func (n *NetworkTransport) Consumer() <-chan RPC {
}

// LocalAddr implements the Transport interface.
func (n *NetworkTransport) LocalAddr() string {
return n.stream.Addr().String()
func (n *NetworkTransport) LocalAddr() ServerAddress {
return ServerAddress(n.stream.Addr().String())
}

// IsShutdown is used to check if the transport is shutdown.
Expand All @@ -198,7 +198,7 @@ func (n *NetworkTransport) IsShutdown() bool {
}

// getExistingConn is used to grab a pooled connection.
func (n *NetworkTransport) getPooledConn(target string) *netConn {
func (n *NetworkTransport) getPooledConn(target ServerAddress) *netConn {
n.connPoolLock.Lock()
defer n.connPoolLock.Unlock()

Expand All @@ -215,7 +215,7 @@ func (n *NetworkTransport) getPooledConn(target string) *netConn {
}

// getConn is used to get a connection from the pool.
func (n *NetworkTransport) getConn(target string) (*netConn, error) {
func (n *NetworkTransport) getConn(target ServerAddress) (*netConn, error) {
// Check for a pooled conn
if conn := n.getPooledConn(target); conn != nil {
return conn, nil
Expand Down Expand Up @@ -260,7 +260,7 @@ func (n *NetworkTransport) returnConn(conn *netConn) {

// AppendEntriesPipeline returns an interface that can be used to pipeline
// AppendEntries requests.
func (n *NetworkTransport) AppendEntriesPipeline(target string) (AppendPipeline, error) {
func (n *NetworkTransport) AppendEntriesPipeline(target ServerAddress) (AppendPipeline, error) {
// Get a connection
conn, err := n.getConn(target)
if err != nil {
Expand All @@ -272,17 +272,17 @@ func (n *NetworkTransport) AppendEntriesPipeline(target string) (AppendPipeline,
}

// AppendEntries implements the Transport interface.
func (n *NetworkTransport) AppendEntries(target string, args *AppendEntriesRequest, resp *AppendEntriesResponse) error {
func (n *NetworkTransport) AppendEntries(target ServerAddress, args *AppendEntriesRequest, resp *AppendEntriesResponse) error {
return n.genericRPC(target, rpcAppendEntries, args, resp)
}

// RequestVote implements the Transport interface.
func (n *NetworkTransport) RequestVote(target string, args *RequestVoteRequest, resp *RequestVoteResponse) error {
func (n *NetworkTransport) RequestVote(target ServerAddress, args *RequestVoteRequest, resp *RequestVoteResponse) error {
return n.genericRPC(target, rpcRequestVote, args, resp)
}

// genericRPC handles a simple request/response RPC.
func (n *NetworkTransport) genericRPC(target string, rpcType uint8, args interface{}, resp interface{}) error {
func (n *NetworkTransport) genericRPC(target ServerAddress, rpcType uint8, args interface{}, resp interface{}) error {
// Get a conn
conn, err := n.getConn(target)
if err != nil {
Expand All @@ -308,7 +308,7 @@ func (n *NetworkTransport) genericRPC(target string, rpcType uint8, args interfa
}

// InstallSnapshot implements the Transport interface.
func (n *NetworkTransport) InstallSnapshot(target string, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error {
func (n *NetworkTransport) InstallSnapshot(target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error {
// Get a conn, always close for InstallSnapshot
conn, err := n.getConn(target)
if err != nil {
Expand Down Expand Up @@ -346,13 +346,13 @@ func (n *NetworkTransport) InstallSnapshot(target string, args *InstallSnapshotR
}

// EncodePeer implements the Transport interface.
func (n *NetworkTransport) EncodePeer(p string) []byte {
func (n *NetworkTransport) EncodePeer(p ServerAddress) []byte {
return []byte(p)
}

// DecodePeer implements the Transport interface.
func (n *NetworkTransport) DecodePeer(buf []byte) string {
return string(buf)
func (n *NetworkTransport) DecodePeer(buf []byte) ServerAddress {
return ServerAddress(buf)
}

// listen is used to handling incoming connections.
Expand Down
5 changes: 0 additions & 5 deletions observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@ type Observation struct {
Data interface{}
}

// LeaderObservation is used for the data when leadership changes.
type LeaderObservation struct {
leader string
}

// nextObserverId is used to provide a unique ID for each observer to aid in
// deregistration.
var nextObserverId uint64
Expand Down
22 changes: 8 additions & 14 deletions raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,8 @@ func NewRaft(conf *Config, fsm FSM, logs LogStore, stable StableStore, snaps Sna
localAddr := ServerAddress(trans.LocalAddr())
localID := conf.LocalID
if localID == "" {
logger.Printf("[WARN] raft: No server ID given, using network address: %v",
logger.Printf("[WARN] raft: No server ID given, using network address: %v. This default will be removed in the future. Set server ID explicitly in config.",
localAddr)
logger.Printf("[WARN] raft: This default will be removed in the future. Set server ID explicitly in Config")
localID = ServerID(localAddr)
}

Expand Down Expand Up @@ -353,23 +352,18 @@ func NewRaft(conf *Config, fsm FSM, logs LogStore, stable StableStore, snaps Sna
// Leader is used to return the current leader of the cluster.
// It may return empty string if there is no current leader
// or the leader is unknown.
func (r *Raft) Leader() string {
// TODO: change return type to ServerAddress?
func (r *Raft) Leader() ServerAddress {
r.leaderLock.RLock()
leader := string(r.leader)
leader := r.leader
r.leaderLock.RUnlock()
return leader
}

// setLeader is used to modify the current leader of the cluster
func (r *Raft) setLeader(leader ServerAddress) {
r.leaderLock.Lock()
oldLeader := r.leader
r.leader = leader
r.leaderLock.Unlock()
if oldLeader != leader {
r.observe(LeaderObservation{leader: string(leader)})
}
}

// Apply is used to apply a command to the FSM in a highly consistent
Expand Down Expand Up @@ -452,15 +446,15 @@ func (r *Raft) VerifyLeader() Future {

// AddPeer (deprecated) is used to add a new peer into the cluster. This must be
// run on the leader or it will fail. Use AddVoter/AddNonvoter instead.
func (r *Raft) AddPeer(peer string) Future {
return r.AddVoter(ServerID(peer), ServerAddress(peer), 0, 0)
func (r *Raft) AddPeer(peer ServerAddress) Future {
return r.AddVoter(ServerID(peer), peer, 0, 0)
}

// RemovePeer (deprecated) is used to remove a peer from the cluster. If the
// current leader is being removed, it will cause a new election
// to occur. This must be run on the leader or it will fail.
// Use RemoveServer instead.
func (r *Raft) RemovePeer(peer string) Future {
func (r *Raft) RemovePeer(peer ServerAddress) Future {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally wanted to keep this as string but this one was confusing (it's an address that we assume is the same as the ID). The trivial work to fix this is probably worth the clarity that the type provides.

return r.RemoveServer(ServerID(peer), 0, 0)
}

Expand Down Expand Up @@ -1866,7 +1860,7 @@ func (r *Raft) electSelf() <-chan *voteResult {
lastIdx, lastTerm := r.getLastEntry()
req := &RequestVoteRequest{
Term: r.getCurrentTerm(),
Candidate: r.trans.EncodePeer(string(r.localAddr)),
Candidate: r.trans.EncodePeer(r.localAddr),
LastLogIndex: lastIdx,
LastLogTerm: lastTerm,
}
Expand All @@ -1876,7 +1870,7 @@ func (r *Raft) electSelf() <-chan *voteResult {
r.goFunc(func() {
defer metrics.MeasureSince([]string{"raft", "candidate", "electSelf"}, time.Now())
resp := &voteResult{voterID: peer.ID}
err := r.trans.RequestVote(string(peer.Address), req, &resp.RequestVoteResponse)
err := r.trans.RequestVote(peer.Address, req, &resp.RequestVoteResponse)
if err != nil {
r.logger.Printf("[ERR] raft: Failed to make RequestVote RPC to %v: %v", peer, err)
resp.Term = req.Term
Expand Down
Loading