Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpc: use tls wrapped connection for streaming rpc #5954

Merged
merged 2 commits into from
Jul 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions nomad/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,18 +540,14 @@ func (r *rpcHandler) streamingRpc(server *serverParts, method string) (net.Conn,
tcp.SetNoDelay(true)
}

if err := r.streamingRpcImpl(conn, server.Region, method); err != nil {
return nil, err
}

return conn, nil
return r.streamingRpcImpl(conn, server.Region, method)
}

// streamingRpcImpl takes a pre-established connection to a server and conducts
// the handshake to establish a streaming RPC for the given method. If an error
// is returned, the underlying connection has been closed. Otherwise it is
// assumed that the connection has been hijacked by the RPC method.
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) error {
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) (net.Conn, error) {
// Check if TLS is enabled
r.tlsWrapLock.RLock()
tlsWrap := r.tlsWrap
Expand All @@ -561,22 +557,22 @@ func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) erro
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
conn.Close()
return err
return nil, err
}

// Wrap the connection in a TLS client
tlsConn, err := tlsWrap(region, conn)
if err != nil {
conn.Close()
return err
return nil, err
}
conn = tlsConn
}

// Write the multiplex byte to set the mode
if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
conn.Close()
return err
return nil, err
}

// Send the header
Expand All @@ -587,22 +583,22 @@ func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) erro
}
if err := encoder.Encode(header); err != nil {
conn.Close()
return err
return nil, err
}

// Wait for the acknowledgement
var ack structs.StreamingRpcAck
if err := decoder.Decode(&ack); err != nil {
conn.Close()
return err
return nil, err
}

if ack.Error != "" {
conn.Close()
return errors.New(ack.Error)
return nil, errors.New(ack.Error)
}

return nil
return conn, nil
}

// raftApplyFuture is used to encode a message, run it through raft, and return the Raft future.
Expand Down
134 changes: 133 additions & 1 deletion nomad/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
"time"

msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/pool"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/structs/config"
Expand All @@ -20,6 +22,7 @@ import (
"github.com/hashicorp/yamux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ugorji/go/codec"
)

// rpcClient is a test helper method to return a ClientCodec to use to make rpc
Expand Down Expand Up @@ -267,6 +270,135 @@ func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) {
require.True(structs.IsErrUnknownMethod(err))
}

func TestRPC_streamingRpcConn_goodMethod_Plaintext(t *testing.T) {
t.Parallel()
require := require.New(t)
dir := tmpDir(t)
defer os.RemoveAll(dir)
s1 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DevDisableBootstrap = true
c.DataDir = path.Join(dir, "node1")
})
defer s1.Shutdown()

s2 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DevDisableBootstrap = true
c.DataDir = path.Join(dir, "node2")
})
defer s2.Shutdown()

TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)

s1.peerLock.RLock()
ok, parts := isNomadServer(s2.LocalMember())
require.True(ok)
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
require.NotNil(server)
s1.peerLock.RUnlock()

conn, err := s1.streamingRpc(server, "FileSystem.Logs")
require.NotNil(conn)
require.NoError(err)

decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

allocID := uuid.Generate()
require.NoError(encoder.Encode(cstructs.FsStreamRequest{
AllocID: allocID,
QueryOptions: structs.QueryOptions{
Region: "regionFoo",
},
}))

var result cstructs.StreamErrWrapper
require.NoError(decoder.Decode(&result))
require.Empty(result.Payload)
require.True(structs.IsErrUnknownAllocation(result.Error))
}

func TestRPC_streamingRpcConn_goodMethod_TLS(t *testing.T) {
t.Parallel()
require := require.New(t)
const (
cafile = "../helper/tlsutil/testdata/ca.pem"
foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem"
)
dir := tmpDir(t)
defer os.RemoveAll(dir)
s1 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DevDisableBootstrap = true
c.DataDir = path.Join(dir, "node1")
c.TLSConfig = &config.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer s1.Shutdown()

s2 := TestServer(t, func(c *Config) {
c.Region = "regionFoo"
c.BootstrapExpect = 2
c.DevMode = false
c.DevDisableBootstrap = true
c.DataDir = path.Join(dir, "node2")
c.TLSConfig = &config.TLSConfig{
EnableHTTP: true,
EnableRPC: true,
VerifyServerHostname: true,
CAFile: cafile,
CertFile: foocert,
KeyFile: fookey,
}
})
defer s2.Shutdown()

TestJoin(t, s1, s2)
testutil.WaitForLeader(t, s1.RPC)

s1.peerLock.RLock()
ok, parts := isNomadServer(s2.LocalMember())
require.True(ok)
server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
require.NotNil(server)
s1.peerLock.RUnlock()

conn, err := s1.streamingRpc(server, "FileSystem.Logs")
require.NotNil(conn)
require.NoError(err)

decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

allocID := uuid.Generate()
require.NoError(encoder.Encode(cstructs.FsStreamRequest{
AllocID: allocID,
QueryOptions: structs.QueryOptions{
Region: "regionFoo",
},
}))

var result cstructs.StreamErrWrapper
require.NoError(decoder.Decode(&result))
require.Empty(result.Payload)
require.True(structs.IsErrUnknownAllocation(result.Error))
}

// COMPAT: Remove in 0.10
// This is a very low level test to assert that the V2 handling works. It is
// making manual RPC calls since no helpers exist at this point since we are
Expand Down Expand Up @@ -321,7 +453,7 @@ func TestRPC_handleMultiplexV2(t *testing.T) {
require.NotEmpty(l)

// Make a streaming RPC
err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
_, err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
require.NotNil(err)
require.Contains(err.Error(), "Bogus")
require.True(structs.IsErrUnknownMethod(err))
Expand Down