Skip to content

Commit

Permalink
Merge pull request hashicorp#238 from hashicorp/issue_237
Browse files Browse the repository at this point in the history
  • Loading branch information
preetapan authored Aug 30, 2017
2 parents 2356637 + db0c156 commit c837e57
Show file tree
Hide file tree
Showing 11 changed files with 400 additions and 286 deletions.
1 change: 1 addition & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ func NewRaft(conf *Config, fsm FSM, logs LogStore, stable StableStore, snaps Sna
}
r.processConfigurationLogEntry(&entry)
}

r.logger.Printf("[INFO] raft: Initial configuration (index=%d): %+v",
r.configurations.latestIndex, r.configurations.latest.Servers)

Expand Down
2 changes: 1 addition & 1 deletion configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func encodePeers(configuration Configuration, trans Transport) []byte {
var encPeers [][]byte
for _, server := range configuration.Servers {
if server.Suffrage == Voter {
encPeers = append(encPeers, trans.EncodePeer(server.Address))
encPeers = append(encPeers, trans.EncodePeer(server.ID, server.Address))
}
}

Expand Down
10 changes: 5 additions & 5 deletions inmem_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (i *InmemTransport) LocalAddr() ServerAddress {

// AppendEntriesPipeline returns an interface that can be used to pipeline
// AppendEntries requests.
func (i *InmemTransport) AppendEntriesPipeline(target ServerAddress) (AppendPipeline, error) {
func (i *InmemTransport) AppendEntriesPipeline(id ServerID, target ServerAddress) (AppendPipeline, error) {
i.RLock()
peer, ok := i.peers[target]
i.RUnlock()
Expand All @@ -90,7 +90,7 @@ func (i *InmemTransport) AppendEntriesPipeline(target ServerAddress) (AppendPipe
}

// AppendEntries implements the Transport interface.
func (i *InmemTransport) AppendEntries(target ServerAddress, args *AppendEntriesRequest, resp *AppendEntriesResponse) error {
func (i *InmemTransport) AppendEntries(id ServerID, target ServerAddress, args *AppendEntriesRequest, resp *AppendEntriesResponse) error {
rpcResp, err := i.makeRPC(target, args, nil, i.timeout)
if err != nil {
return err
Expand All @@ -103,7 +103,7 @@ func (i *InmemTransport) AppendEntries(target ServerAddress, args *AppendEntries
}

// RequestVote implements the Transport interface.
func (i *InmemTransport) RequestVote(target ServerAddress, args *RequestVoteRequest, resp *RequestVoteResponse) error {
func (i *InmemTransport) RequestVote(id ServerID, target ServerAddress, args *RequestVoteRequest, resp *RequestVoteResponse) error {
rpcResp, err := i.makeRPC(target, args, nil, i.timeout)
if err != nil {
return err
Expand All @@ -116,7 +116,7 @@ func (i *InmemTransport) RequestVote(target ServerAddress, args *RequestVoteRequ
}

// InstallSnapshot implements the Transport interface.
func (i *InmemTransport) InstallSnapshot(target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error {
func (i *InmemTransport) InstallSnapshot(id ServerID, target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error {
rpcResp, err := i.makeRPC(target, args, data, 10*i.timeout)
if err != nil {
return err
Expand Down Expand Up @@ -159,7 +159,7 @@ func (i *InmemTransport) makeRPC(target ServerAddress, args interface{}, r io.Re
}

// EncodePeer implements the Transport interface.
func (i *InmemTransport) EncodePeer(p ServerAddress) []byte {
func (i *InmemTransport) EncodePeer(id ServerID, p ServerAddress) []byte {
return []byte(p)
}

Expand Down
112 changes: 83 additions & 29 deletions net_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ type NetworkTransport struct {

maxPool int

serverAddressProvider ServerAddressProvider

shutdown bool
shutdownCh chan struct{}
shutdownLock sync.Mutex
Expand All @@ -78,6 +80,28 @@ type NetworkTransport struct {
TimeoutScale int
}

// NetworkTransportConfig encapsulates configuration for the network transport layer.
type NetworkTransportConfig struct {
// ServerAddressProvider is used to override the target address when establishing a connection to invoke an RPC
ServerAddressProvider ServerAddressProvider

Logger *log.Logger

// Dialer
Stream StreamLayer

// MaxPool controls how many connections we will pool
MaxPool int

// Timeout is used to apply I/O deadlines. For InstallSnapshot, we multiply
// the timeout by (SnapshotSize / TimeoutScale).
Timeout time.Duration
}

type ServerAddressProvider interface {
ServerAddr(id ServerID) (ServerAddress, error)
}

// StreamLayer is used with the NetworkTransport to provide
// the low level stream abstraction.
type StreamLayer interface {
Expand Down Expand Up @@ -112,6 +136,28 @@ type netPipeline struct {
shutdownLock sync.Mutex
}

// NewNetworkTransportWithConfig creates a new network transport with the given config struct
func NewNetworkTransportWithConfig(
config *NetworkTransportConfig,
) *NetworkTransport {
if config.Logger == nil {
config.Logger = log.New(os.Stderr, "", log.LstdFlags)
}
trans := &NetworkTransport{
connPool: make(map[ServerAddress][]*netConn),
consumeCh: make(chan RPC),
logger: config.Logger,
maxPool: config.MaxPool,
shutdownCh: make(chan struct{}),
stream: config.Stream,
timeout: config.Timeout,
TimeoutScale: DefaultTimeoutScale,
serverAddressProvider: config.ServerAddressProvider,
}
go trans.listen()
return trans
}

// NewNetworkTransport creates a new network transport with the given dialer
// and listener. The maxPool controls how many connections we will pool. The
// timeout is used to apply I/O deadlines. For InstallSnapshot, we multiply
Expand All @@ -125,10 +171,12 @@ func NewNetworkTransport(
if logOutput == nil {
logOutput = os.Stderr
}
return NewNetworkTransportWithLogger(stream, maxPool, timeout, log.New(logOutput, "", log.LstdFlags))
logger := log.New(logOutput, "", log.LstdFlags)
config := &NetworkTransportConfig{Stream: stream, MaxPool: maxPool, Timeout: timeout, Logger: logger}
return NewNetworkTransportWithConfig(config)
}

// NewNetworkTransportWithLogger creates a new network transport with the given dialer
// NewNetworkTransportWithLogger creates a new network transport with the given logger, dialer
// and listener. The maxPool controls how many connections we will pool. The
// timeout is used to apply I/O deadlines. For InstallSnapshot, we multiply
// the timeout by (SnapshotSize / TimeoutScale).
Expand All @@ -138,21 +186,8 @@ func NewNetworkTransportWithLogger(
timeout time.Duration,
logger *log.Logger,
) *NetworkTransport {
if logger == nil {
logger = log.New(os.Stderr, "", log.LstdFlags)
}
trans := &NetworkTransport{
connPool: make(map[ServerAddress][]*netConn),
consumeCh: make(chan RPC),
logger: logger,
maxPool: maxPool,
shutdownCh: make(chan struct{}),
stream: stream,
timeout: timeout,
TimeoutScale: DefaultTimeoutScale,
}
go trans.listen()
return trans
config := &NetworkTransportConfig{Stream: stream, MaxPool: maxPool, Timeout: timeout, Logger: logger}
return NewNetworkTransportWithConfig(config)
}

// SetHeartbeatHandler is used to setup a heartbeat handler
Expand Down Expand Up @@ -214,6 +249,24 @@ func (n *NetworkTransport) getPooledConn(target ServerAddress) *netConn {
return conn
}

// getConnFromAddressProvider returns a connection from the server address provider if available, or defaults to a connection using the target server address
func (n *NetworkTransport) getConnFromAddressProvider(id ServerID, target ServerAddress) (*netConn, error) {
address := n.getProviderAddressOrFallback(id, target)
return n.getConn(address)
}

func (n *NetworkTransport) getProviderAddressOrFallback(id ServerID, target ServerAddress) ServerAddress {
if n.serverAddressProvider != nil {
serverAddressOverride, err := n.serverAddressProvider.ServerAddr(id)
if err != nil {
n.logger.Printf("[WARN] Unable to get address for server id %v, using fallback address %v: %v", id, target, err)
} else {
return serverAddressOverride
}
}
return target
}

// getConn is used to get a connection from the pool.
func (n *NetworkTransport) getConn(target ServerAddress) (*netConn, error) {
// Check for a pooled conn
Expand Down Expand Up @@ -260,9 +313,9 @@ func (n *NetworkTransport) returnConn(conn *netConn) {

// AppendEntriesPipeline returns an interface that can be used to pipeline
// AppendEntries requests.
func (n *NetworkTransport) AppendEntriesPipeline(target ServerAddress) (AppendPipeline, error) {
func (n *NetworkTransport) AppendEntriesPipeline(id ServerID, target ServerAddress) (AppendPipeline, error) {
// Get a connection
conn, err := n.getConn(target)
conn, err := n.getConnFromAddressProvider(id, target)
if err != nil {
return nil, err
}
Expand All @@ -272,19 +325,19 @@ func (n *NetworkTransport) AppendEntriesPipeline(target ServerAddress) (AppendPi
}

// AppendEntries implements the Transport interface.
func (n *NetworkTransport) AppendEntries(target ServerAddress, args *AppendEntriesRequest, resp *AppendEntriesResponse) error {
return n.genericRPC(target, rpcAppendEntries, args, resp)
func (n *NetworkTransport) AppendEntries(id ServerID, target ServerAddress, args *AppendEntriesRequest, resp *AppendEntriesResponse) error {
return n.genericRPC(id, target, rpcAppendEntries, args, resp)
}

// RequestVote implements the Transport interface.
func (n *NetworkTransport) RequestVote(target ServerAddress, args *RequestVoteRequest, resp *RequestVoteResponse) error {
return n.genericRPC(target, rpcRequestVote, args, resp)
func (n *NetworkTransport) RequestVote(id ServerID, target ServerAddress, args *RequestVoteRequest, resp *RequestVoteResponse) error {
return n.genericRPC(id, target, rpcRequestVote, args, resp)
}

// genericRPC handles a simple request/response RPC.
func (n *NetworkTransport) genericRPC(target ServerAddress, rpcType uint8, args interface{}, resp interface{}) error {
func (n *NetworkTransport) genericRPC(id ServerID, target ServerAddress, rpcType uint8, args interface{}, resp interface{}) error {
// Get a conn
conn, err := n.getConn(target)
conn, err := n.getConnFromAddressProvider(id, target)
if err != nil {
return err
}
Expand All @@ -308,9 +361,9 @@ func (n *NetworkTransport) genericRPC(target ServerAddress, rpcType uint8, args
}

// InstallSnapshot implements the Transport interface.
func (n *NetworkTransport) InstallSnapshot(target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error {
func (n *NetworkTransport) InstallSnapshot(id ServerID, target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error {
// Get a conn, always close for InstallSnapshot
conn, err := n.getConn(target)
conn, err := n.getConnFromAddressProvider(id, target)
if err != nil {
return err
}
Expand Down Expand Up @@ -346,8 +399,9 @@ func (n *NetworkTransport) InstallSnapshot(target ServerAddress, args *InstallSn
}

// EncodePeer implements the Transport interface.
func (n *NetworkTransport) EncodePeer(p ServerAddress) []byte {
return []byte(p)
func (n *NetworkTransport) EncodePeer(id ServerID, p ServerAddress) []byte {
address := n.getProviderAddressOrFallback(id, p)
return []byte(address)
}

// DecodePeer implements the Transport interface.
Expand Down
Loading

0 comments on commit c837e57

Please sign in to comment.