Skip to content

Commit

Permalink
rpc: use tls wrapped connection for streaming rpc
Browse files Browse the repository at this point in the history
This ensures that server-to-server streaming RPC calls use the tls
wrapped connections.

Prior to this, `streamingRpcImpl` function uses tls for setting header
and invoking the rpc method, but returns unwrapped tls connection.
Thus, streaming writes fail with tls errors.

This tls streaming bug existed since 0.8.0[1], but PR #5654[2]
exacerbated it in 0.9.2.  Prior to PR #5654, nomad client used to
shuffle servers at every heartbeat -- `servers.Manager.setServers`[3]
always shuffled servers and was called by heartbeat code[4].  Shuffling
servers meant that a nomad client would heartbeat and establish a
connection against all nomad servers eventually.  When handling
streaming RPC calls, nomad servers used these local connection to
communicate directly to the client.  The server-to-server forwarding
logic was left mostly unexercised.

PR #5654 means that a nomad client may connect to a single server only
and caused the server-to-server forward streaming RPC code to get
exercised more and unearthed the problem.

[1] https://github.com/hashicorp/nomad/blob/v0.8.0/nomad/rpc.go#L501-L515
[2] #5654
[3] https://github.com/hashicorp/nomad/blob/v0.9.1/client/servers/manager.go#L198-L216
[4] https://github.com/hashicorp/nomad/blob/v0.9.1/client/client.go#L1603
  • Loading branch information
Mahmood Ali committed Jul 12, 2019
1 parent a6604f8 commit b0d98d1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
22 changes: 9 additions & 13 deletions nomad/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,18 +540,14 @@ func (r *rpcHandler) streamingRpc(server *serverParts, method string) (net.Conn,
tcp.SetNoDelay(true)
}

if err := r.streamingRpcImpl(conn, server.Region, method); err != nil {
return nil, err
}

return conn, nil
return r.streamingRpcImpl(conn, server.Region, method)
}

// streamingRpcImpl takes a pre-established connection to a server and conducts
// the handshake to establish a streaming RPC for the given method. If an error
// is returned, the underlying connection has been closed. Otherwise it is
// assumed that the connection has been hijacked by the RPC method.
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) error {
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) (net.Conn, error) {
// Check if TLS is enabled
r.tlsWrapLock.RLock()
tlsWrap := r.tlsWrap
Expand All @@ -561,22 +557,22 @@ func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) erro
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
conn.Close()
return err
return nil, err
}

// Wrap the connection in a TLS client
tlsConn, err := tlsWrap(region, conn)
if err != nil {
conn.Close()
return err
return nil, err
}
conn = tlsConn
}

// Write the multiplex byte to set the mode
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
conn.Close()
return err
return nil, err
}

// Send the header
Expand All @@ -587,22 +583,22 @@ func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) erro
}
if err := encoder.Encode(header); err != nil {
conn.Close()
return err
return nil, err
}

// Wait for the acknowledgement
var ack structs.StreamingRpcAck
if err := decoder.Decode(&ack); err != nil {
conn.Close()
return err
return nil, err
}

if ack.Error != "" {
conn.Close()
return errors.New(ack.Error)
return nil, errors.New(ack.Error)
}

return nil
return conn, nil
}

// raftApplyFuture is used to encode a message, run it through raft, and return the Raft future.
Expand Down
2 changes: 1 addition & 1 deletion nomad/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ func TestRPC_handleMultiplexV2(t *testing.T) {
require.NotEmpty(l)

// Make a streaming RPC
err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
_, err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
require.NotNil(err)
require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
Expand Down

0 comments on commit b0d98d1

Please sign in to comment.