Skip to content

Commit

Permalink
Merge pull request #7044 from hashicorp/f-use-multiplexv2
Browse files Browse the repository at this point in the history
rpc: Use MultiplexV2 for connections
  • Loading branch information
Mahmood Ali authored Feb 13, 2020
2 parents 8d222b6 + cac99e1 commit f6cf206
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 42 deletions.
36 changes: 31 additions & 5 deletions helper/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (c *Conn) Close() error {
}

// getClient is used to get a cached or new client
func (c *Conn) getClient() (*StreamClient, error) {
func (c *Conn) getRPCClient() (*StreamClient, error) {
// Check for cached client
c.clientLock.Lock()
front := c.clients.Front()
Expand All @@ -85,6 +85,11 @@ func (c *Conn) getClient() (*StreamClient, error) {
return nil, err
}

if _, err := stream.Write([]byte{byte(RpcNomad)}); err != nil {
stream.Close()
return nil, err
}

// Create a client codec
codec := NewClientCodec(stream)

Expand Down Expand Up @@ -332,7 +337,7 @@ func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn,
}

// Write the multiplex byte to set the mode
if _, err := conn.Write([]byte{byte(RpcMultiplex)}); err != nil {
if _, err := conn.Write([]byte{byte(RpcMultiplexV2)}); err != nil {
conn.Close()
return nil, err
}
Expand Down Expand Up @@ -390,7 +395,7 @@ func (p *ConnPool) releaseConn(conn *Conn) {
}

// getClient is used to get a usable client for an address and protocol version
func (p *ConnPool) getClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
func (p *ConnPool) getRPCClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
retries := 0
START:
// Try to get a conn first
Expand All @@ -400,7 +405,7 @@ START:
}

// Get a client
client, err := conn.getClient()
client, err := conn.getRPCClient()
if err != nil {
p.clearConn(conn)
p.releaseConn(conn)
Expand All @@ -415,10 +420,31 @@ START:
return conn, client, nil
}

// StreamingRPC is used to make an streaming RPC call. Callers must
// close the connection when done.
func (p *ConnPool) StreamingRPC(region string, addr net.Addr, version int) (net.Conn, error) {
conn, err := p.acquire(region, addr, version)
if err != nil {
return nil, fmt.Errorf("failed to get conn: %v", err)
}

s, err := conn.session.Open()
if err != nil {
return nil, fmt.Errorf("failed to open a streaming connection: %v", err)
}

if _, err := s.Write([]byte{byte(RpcStreaming)}); err != nil {
conn.Close()
return nil, err
}

return s, nil
}

// RPC is used to make an RPC call to a remote host
func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
// Get a usable client
conn, sc, err := p.getClient(region, addr, version)
conn, sc, err := p.getRPCClient(region, addr, version)
if err != nil {
return fmt.Errorf("rpc error: %v", err)
}
Expand Down
39 changes: 3 additions & 36 deletions nomad/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,52 +653,19 @@ func (r *rpcHandler) getServer(region, serverID string) (*serverParts, error) {
// initial handshake, returning the connection or an error. It is the callers
// responsibility to close the connection if there is no returned error.
func (r *rpcHandler) streamingRpc(server *serverParts, method string) (net.Conn, error) {
// Try to dial the server
conn, err := net.DialTimeout("tcp", server.Addr.String(), 10*time.Second)
c, err := r.connPool.StreamingRPC(r.config.Region, server.Addr, server.MajorVersion)
if err != nil {
return nil, err
}

// Cast to TCPConn
if tcp, ok := conn.(*net.TCPConn); ok {
tcp.SetKeepAlive(true)
tcp.SetNoDelay(true)
}

return r.streamingRpcImpl(conn, server.Region, method)
return r.streamingRpcImpl(c, method)
}

// streamingRpcImpl takes a pre-established connection to a server and conducts
// the handshake to establish a streaming RPC for the given method. If an error
// is returned, the underlying connection has been closed. Otherwise it is
// assumed that the connection has been hijacked by the RPC method.
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) (net.Conn, error) {
// Check if TLS is enabled
r.tlsWrapLock.RLock()
tlsWrap := r.tlsWrap
r.tlsWrapLock.RUnlock()

if tlsWrap != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
conn.Close()
return nil, err
}

// Wrap the connection in a TLS client
tlsConn, err := tlsWrap(region, conn)
if err != nil {
conn.Close()
return nil, err
}
conn = tlsConn
}

// Write the multiplex byte to set the mode
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
conn.Close()
return nil, err
}
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, method string) (net.Conn, error) {

// Send the header
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
Expand Down
5 changes: 4 additions & 1 deletion nomad/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,10 @@ func TestRPC_handleMultiplexV2(t *testing.T) {
require.NotEmpty(l)

// Make a streaming RPC
_, err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
_, err = s2.Write([]byte{byte(pool.RpcStreaming)})
require.Nil(err)

_, err = s.streamingRpcImpl(s2, "Bogus")
require.NotNil(err)
require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
Expand Down

0 comments on commit f6cf206

Please sign in to comment.