Skip to content

Commit

Permalink
Merge pull request #5954 from hashicorp/b-fix-streaming-rpc-tls
Browse files Browse the repository at this point in the history
rpc: use tls wrapped connection for streaming rpc
  • Loading branch information
Mahmood Ali authored Jul 12, 2019
2 parents 0a58242 + b0d98d1 commit e129c41
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 14 deletions.
22 changes: 9 additions & 13 deletions nomad/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,18 +540,14 @@ func (r *rpcHandler) streamingRpc(server *serverParts, method string) (net.Conn,
tcp.SetNoDelay(true)
}

if err := r.streamingRpcImpl(conn, server.Region, method); err != nil {
return nil, err
}

return conn, nil
return r.streamingRpcImpl(conn, server.Region, method)
}

// streamingRpcImpl takes a pre-established connection to a server and conducts
// the handshake to establish a streaming RPC for the given method. If an error
// is returned, the underlying connection has been closed. Otherwise it is
// assumed that the connection has been hijacked by the RPC method.
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) error {
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) (net.Conn, error) {
// Check if TLS is enabled
r.tlsWrapLock.RLock()
tlsWrap := r.tlsWrap
Expand All @@ -561,22 +557,22 @@ func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) erro
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
conn.Close()
return err
return nil, err
}

// Wrap the connection in a TLS client
tlsConn, err := tlsWrap(region, conn)
if err != nil {
conn.Close()
return err
return nil, err
}
conn = tlsConn
}

// Write the multiplex byte to set the mode
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
conn.Close()
return err
return nil, err
}

// Send the header
Expand All @@ -587,22 +583,22 @@ func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) erro
}
if err := encoder.Encode(header); err != nil {
conn.Close()
return err
return nil, err
}

// Wait for the acknowledgement
var ack structs.StreamingRpcAck
if err := decoder.Decode(&ack); err != nil {
conn.Close()
return err
return nil, err
}

if ack.Error != "" {
conn.Close()
return errors.New(ack.Error)
return nil, errors.New(ack.Error)
}

return nil
return conn, nil
}

// raftApplyFuture is used to encode a message, run it through raft, and return the Raft future.
Expand Down
134 changes: 133 additions & 1 deletion nomad/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
"time"

msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/pool"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/structs/config"
Expand All @@ -20,6 +22,7 @@ import (
"github.com/hashicorp/yamux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ugorji/go/codec"
)

// rpcClient is a test helper method to return a ClientCodec to use to make rpc
Expand Down Expand Up @@ -267,6 +270,135 @@ func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) {
require.True(structs.IsErrUnknownMethod(err))
}

func TestRPC_streamingRpcConn_goodMethod_Plaintext(t *testing.T) {
t.Parallel()
require := require.New(t)
dir := tmpDir(t)
defer os.RemoveAll(dir)
s1 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DevDisableBootstrap = true
c.DataDir = path.Join(dir, "node1")
})
defer s1.Shutdown()

s2 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DevDisableBootstrap = true
c.DataDir = path.Join(dir, "node2")
})
defer s2.Shutdown()

TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)

s1.peerLock.RLock()
ok, parts := isNomadServer(s2.LocalMember())
require.True(ok)
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
require.NotNil(server)
s1.peerLock.RUnlock()

conn, err := s1.streamingRpc(server, "FileSystem.Logs")
require.NotNil(conn)
require.NoError(err)

decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

allocID := uuid.Generate()
require.NoError(encoder.Encode(cstructs.FsStreamRequest{
AllocID: allocID,
QueryOptions: structs.QueryOptions{
Region: "regionFoo",
},
}))

var result cstructs.StreamErrWrapper
require.NoError(decoder.Decode(&result))
require.Empty(result.Payload)
require.True(structs.IsErrUnknownAllocation(result.Error))
}

func TestRPC_streamingRpcConn_goodMethod_TLS(t *testing.T) {
t.Parallel()
require := require.New(t)
const (
cafile = "../helper/tlsutil/testdata/ca.pem"
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
)
dir := tmpDir(t)
defer os.RemoveAll(dir)
s1 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DevDisableBootstrap = true
c.DataDir = path.Join(dir, "node1")
c.TLSConfig = &config.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer s1.Shutdown()

s2 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DevDisableBootstrap = true
c.DataDir = path.Join(dir, "node2")
c.TLSConfig = &config.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer s2.Shutdown()

TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)

s1.peerLock.RLock()
ok, parts := isNomadServer(s2.LocalMember())
require.True(ok)
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
require.NotNil(server)
s1.peerLock.RUnlock()

conn, err := s1.streamingRpc(server, "FileSystem.Logs")
require.NotNil(conn)
require.NoError(err)

decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

allocID := uuid.Generate()
require.NoError(encoder.Encode(cstructs.FsStreamRequest{
AllocID: allocID,
QueryOptions: structs.QueryOptions{
Region: "regionFoo",
},
}))

var result cstructs.StreamErrWrapper
require.NoError(decoder.Decode(&result))
require.Empty(result.Payload)
require.True(structs.IsErrUnknownAllocation(result.Error))
}

// COMPAT: Remove in 0.10
// This is a very low level test to assert that the V2 handling works. It is
// making manual RPC calls since no helpers exist at this point since we are
Expand Down Expand Up @@ -321,7 +453,7 @@ func TestRPC_handleMultiplexV2(t *testing.T) {
require.NotEmpty(l)

// Make a streaming RPC
err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
_, err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
require.NotNil(err)
require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
Expand Down

0 comments on commit e129c41

Please sign in to comment.