diff --git a/api/allocations.go b/api/allocations.go index cf4400486c0..42b5c8636b3 100644 --- a/api/allocations.go +++ b/api/allocations.go @@ -48,13 +48,9 @@ func (a *Allocations) Info(allocID string, q *QueryOptions) (*Allocation, *Query } func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) { - nodeClient, err := a.client.GetNodeClient(alloc.NodeID, q) - if err != nil { - return nil, err - } - var resp AllocResourceUsage - _, err = nodeClient.query("/v1/client/allocation/"+alloc.ID+"/stats", &resp, nil) + path := fmt.Sprintf("/v1/client/allocation/%s/stats", alloc.ID) + _, err := a.client.query(path, &resp, q) return &resp, err } diff --git a/api/api.go b/api/api.go index 50b97954bae..17766c064cd 100644 --- a/api/api.go +++ b/api/api.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/url" "os" @@ -18,6 +19,13 @@ import ( rootcerts "github.com/hashicorp/go-rootcerts" ) +var ( + // ClientConnTimeout is the timeout applied when attempting to contact a + // client directly before switching to a connection through the Nomad + // server. + ClientConnTimeout = 1 * time.Second +) + // QueryOptions are used to parameterize a query type QueryOptions struct { // Providing a datacenter overwrites the region provided @@ -145,6 +153,8 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config { WaitTime: c.WaitTime, TLSConfig: c.TLSConfig.Copy(), } + + // Update the tls server name for connecting to a client if tlsEnabled && config.TLSConfig != nil { config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", region) } @@ -249,6 +259,34 @@ func DefaultConfig() *Config { return config } +// SetTimeout is used to place a timeout for connecting to Nomad. A negative +// duration is ignored, a duration of zero means no timeout, and any other value +// will add a timeout. +func (c *Config) SetTimeout(t time.Duration) error { + if c == nil { + return fmt.Errorf("nil config") + } else if c.httpClient == nil { + return fmt.Errorf("nil HTTP client") + } else if c.httpClient.Transport == nil { + return fmt.Errorf("nil HTTP client transport") + } + + // Apply a timeout. + if t.Nanoseconds() >= 0 { + transport, ok := c.httpClient.Transport.(*http.Transport) + if !ok { + return fmt.Errorf("unexpected HTTP transport: %T", c.httpClient.Transport) + } + + transport.DialContext = (&net.Dialer{ + Timeout: t, + KeepAlive: 30 * time.Second, + }).DialContext + } + + return nil +} + // ConfigureTLS applies a set of TLS configurations to the the HTTP client. func (c *Config) ConfigureTLS() error { if c.TLSConfig == nil { @@ -343,7 +381,15 @@ func (c *Client) SetNamespace(namespace string) { // GetNodeClient returns a new Client that will dial the specified node. If the // QueryOptions is set, its region will be used. func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error) { - return c.getNodeClientImpl(nodeID, q, c.Nodes().Info) + return c.getNodeClientImpl(nodeID, -1, q, c.Nodes().Info) +} + +// GetNodeClientWithTimeout returns a new Client that will dial the specified +// node using the specified timeout. If the QueryOptions is set, its region will +// be used. +func (c *Client) GetNodeClientWithTimeout( + nodeID string, timeout time.Duration, q *QueryOptions) (*Client, error) { + return c.getNodeClientImpl(nodeID, timeout, q, c.Nodes().Info) } // nodeLookup is the definition of a function used to lookup a node. 
This is @@ -353,7 +399,7 @@ type nodeLookup func(nodeID string, q *QueryOptions) (*Node, *QueryMeta, error) // getNodeClientImpl is the implementation of creating a API client for // contacting a node. It takes a function to lookup the node such that it can be // mocked during tests. -func (c *Client) getNodeClientImpl(nodeID string, q *QueryOptions, lookup nodeLookup) (*Client, error) { +func (c *Client) getNodeClientImpl(nodeID string, timeout time.Duration, q *QueryOptions, lookup nodeLookup) (*Client, error) { node, _, err := lookup(nodeID, q) if err != nil { return nil, err @@ -380,6 +426,10 @@ func (c *Client) getNodeClientImpl(nodeID string, q *QueryOptions, lookup nodeLo // Get an API client for the node conf := c.config.ClientConfig(region, node.HTTPAddr, node.TLSEnabled) + + // Set the timeout + conf.SetTimeout(timeout) + return NewClient(conf) } diff --git a/api/api_test.go b/api/api_test.go index 79c86a2d496..55cf170908b 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -426,7 +426,7 @@ func TestClient_NodeClient(t *testing.T) { name := fmt.Sprintf("%s__%s__%s", c.ExpectedAddr, c.ExpectedRegion, c.ExpectedTLSServerName) t.Run(name, func(t *testing.T) { assert := assert.New(t) - nodeClient, err := c.Client.getNodeClientImpl("testID", c.QueryOptions, c.Node) + nodeClient, err := c.Client.getNodeClientImpl("testID", -1, c.QueryOptions, c.Node) assert.Nil(err) assert.Equal(c.ExpectedRegion, nodeClient.config.Region) assert.Equal(c.ExpectedAddr, nodeClient.config.Address) diff --git a/api/fs.go b/api/fs.go index c412db5416d..a2d17769e9b 100644 --- a/api/fs.go +++ b/api/fs.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io" + "net" "strconv" "sync" "time" @@ -51,22 +52,16 @@ func (c *Client) AllocFS() *AllocFS { // List is used to list the files at a given path of an allocation directory func (a *AllocFS) List(alloc *Allocation, path string, q *QueryOptions) ([]*AllocFileInfo, *QueryMeta, error) { - nodeClient, err := a.client.GetNodeClient(alloc.NodeID, q) - if err != nil { - return nil, nil, err - } - if q == nil { q = &QueryOptions{} } if q.Params == nil { q.Params = make(map[string]string) } - q.Params["path"] = path var resp []*AllocFileInfo - qm, err := nodeClient.query(fmt.Sprintf("/v1/client/fs/ls/%s", alloc.ID), &resp, q) + qm, err := a.client.query(fmt.Sprintf("/v1/client/fs/ls/%s", alloc.ID), &resp, q) if err != nil { return nil, nil, err } @@ -76,11 +71,6 @@ func (a *AllocFS) List(alloc *Allocation, path string, q *QueryOptions) ([]*Allo // Stat is used to stat a file at a given path of an allocation directory func (a *AllocFS) Stat(alloc *Allocation, path string, q *QueryOptions) (*AllocFileInfo, *QueryMeta, error) { - nodeClient, err := a.client.GetNodeClient(alloc.NodeID, q) - if err != nil { - return nil, nil, err - } - if q == nil { q = &QueryOptions{} } @@ -91,7 +81,7 @@ func (a *AllocFS) Stat(alloc *Allocation, path string, q *QueryOptions) (*AllocF q.Params["path"] = path var resp AllocFileInfo - qm, err := nodeClient.query(fmt.Sprintf("/v1/client/fs/stat/%s", alloc.ID), &resp, q) + qm, err := a.client.query(fmt.Sprintf("/v1/client/fs/stat/%s", alloc.ID), &resp, q) if err != nil { return nil, nil, err } @@ -101,7 +91,7 @@ func (a *AllocFS) Stat(alloc *Allocation, path string, q *QueryOptions) (*AllocF // ReadAt is used to read bytes at a given offset until limit at the given path // in an allocation directory. If limit is <= 0, there is no limit. 
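// The four handlers below (ReadAt, Cat, Stream, Logs) all adopt the same new
// pattern introduced by this diff: dial the node directly with the short
// ClientConnTimeout, and if the failure is a networking one, retry the same
// request routed through the Nomad server. A minimal sketch of that shared
// shape, using the rawQuery/QueryOptions types of this package; the helper
// name rawQueryWithFallback is hypothetical and not part of this change:
func (a *AllocFS) rawQueryWithFallback(nodeClient *Client, reqPath string, q *QueryOptions) (io.ReadCloser, error) {
	r, err := nodeClient.rawQuery(reqPath, q)
	if err == nil {
		return r, nil
	}
	// Non-networking errors (e.g. an API error returned by the node) are
	// surfaced unchanged; only a net.Error triggers the server fallback.
	if _, ok := err.(net.Error); !ok {
		return nil, err
	}
	return a.client.rawQuery(reqPath, q)
}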
func (a *AllocFS) ReadAt(alloc *Allocation, path string, offset int64, limit int64, q *QueryOptions) (io.ReadCloser, error) { - nodeClient, err := a.client.GetNodeClient(alloc.NodeID, q) + nodeClient, err := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q) if err != nil { return nil, err } @@ -117,17 +107,28 @@ func (a *AllocFS) ReadAt(alloc *Allocation, path string, offset int64, limit int q.Params["offset"] = strconv.FormatInt(offset, 10) q.Params["limit"] = strconv.FormatInt(limit, 10) - r, err := nodeClient.rawQuery(fmt.Sprintf("/v1/client/fs/readat/%s", alloc.ID), q) + reqPath := fmt.Sprintf("/v1/client/fs/readat/%s", alloc.ID) + r, err := nodeClient.rawQuery(reqPath, q) if err != nil { - return nil, err + // There was a networking error when talking directly to the client. + if _, ok := err.(net.Error); !ok { + return nil, err + } + + // Try via the server + r, err = a.client.rawQuery(reqPath, q) + if err != nil { + return nil, err + } } + return r, nil } // Cat is used to read contents of a file at the given path in an allocation // directory func (a *AllocFS) Cat(alloc *Allocation, path string, q *QueryOptions) (io.ReadCloser, error) { - nodeClient, err := a.client.GetNodeClient(alloc.NodeID, q) + nodeClient, err := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q) if err != nil { return nil, err } @@ -140,11 +141,21 @@ func (a *AllocFS) Cat(alloc *Allocation, path string, q *QueryOptions) (io.ReadC } q.Params["path"] = path - - r, err := nodeClient.rawQuery(fmt.Sprintf("/v1/client/fs/cat/%s", alloc.ID), q) + reqPath := fmt.Sprintf("/v1/client/fs/cat/%s", alloc.ID) + r, err := nodeClient.rawQuery(reqPath, q) if err != nil { - return nil, err + // There was a networking error when talking directly to the client. + if _, ok := err.(net.Error); !ok { + return nil, err + } + + // Try via the server + r, err = a.client.rawQuery(reqPath, q) + if err != nil { + return nil, err + } } + return r, nil } @@ -160,7 +171,7 @@ func (a *AllocFS) Stream(alloc *Allocation, path, origin string, offset int64, cancel <-chan struct{}, q *QueryOptions) (<-chan *StreamFrame, <-chan error) { errCh := make(chan error, 1) - nodeClient, err := a.client.GetNodeClient(alloc.NodeID, q) + nodeClient, err := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q) if err != nil { errCh <- err return nil, errCh @@ -177,10 +188,21 @@ func (a *AllocFS) Stream(alloc *Allocation, path, origin string, offset int64, q.Params["offset"] = strconv.FormatInt(offset, 10) q.Params["origin"] = origin - r, err := nodeClient.rawQuery(fmt.Sprintf("/v1/client/fs/stream/%s", alloc.ID), q) + reqPath := fmt.Sprintf("/v1/client/fs/stream/%s", alloc.ID) + r, err := nodeClient.rawQuery(reqPath, q) if err != nil { - errCh <- err - return nil, errCh + // There was a networking error when talking directly to the client. 
+ if _, ok := err.(net.Error); !ok { + errCh <- err + return nil, errCh + } + + // Try via the server + r, err = a.client.rawQuery(reqPath, q) + if err != nil { + errCh <- err + return nil, errCh + } } // Create the output channel @@ -236,7 +258,7 @@ func (a *AllocFS) Logs(alloc *Allocation, follow bool, task, logType, origin str offset int64, cancel <-chan struct{}, q *QueryOptions) (<-chan *StreamFrame, <-chan error) { errCh := make(chan error, 1) - nodeClient, err := a.client.GetNodeClient(alloc.NodeID, q) + nodeClient, err := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q) if err != nil { errCh <- err return nil, errCh @@ -255,10 +277,21 @@ func (a *AllocFS) Logs(alloc *Allocation, follow bool, task, logType, origin str q.Params["origin"] = origin q.Params["offset"] = strconv.FormatInt(offset, 10) - r, err := nodeClient.rawQuery(fmt.Sprintf("/v1/client/fs/logs/%s", alloc.ID), q) + reqPath := fmt.Sprintf("/v1/client/fs/logs/%s", alloc.ID) + r, err := nodeClient.rawQuery(reqPath, q) if err != nil { - errCh <- err - return nil, errCh + // There was a networking error when talking directly to the client. + if _, ok := err.(net.Error); !ok { + errCh <- err + return nil, errCh + } + + // Try via the server + r, err = a.client.rawQuery(reqPath, q) + if err != nil { + errCh <- err + return nil, errCh + } } // Create the output channel diff --git a/api/nodes.go b/api/nodes.go index 194affde458..6a61ab90b8a 100644 --- a/api/nodes.go +++ b/api/nodes.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "sort" "strconv" ) @@ -72,25 +73,26 @@ func (n *Nodes) ForceEvaluate(nodeID string, q *WriteOptions) (string, *WriteMet } func (n *Nodes) Stats(nodeID string, q *QueryOptions) (*HostStats, error) { - nodeClient, err := n.client.GetNodeClient(nodeID, q) - if err != nil { - return nil, err - } var resp HostStats - if _, err := nodeClient.query("/v1/client/stats", &resp, nil); err != nil { + path := fmt.Sprintf("/v1/client/stats?node_id=%s", nodeID) + if _, err := n.client.query(path, &resp, q); err != nil { return nil, err } return &resp, nil } func (n *Nodes) GC(nodeID string, q *QueryOptions) error { - nodeClient, err := n.client.GetNodeClient(nodeID, q) - if err != nil { - return err - } + var resp struct{} + path := fmt.Sprintf("/v1/client/gc?node_id=%s", nodeID) + _, err := n.client.query(path, &resp, q) + return err +} +// TODO Add tests +func (n *Nodes) GcAlloc(allocID string, q *QueryOptions) error { var resp struct{} - _, err = nodeClient.query("/v1/client/gc", &resp, nil) + path := fmt.Sprintf("/v1/client/allocation/%s/gc", allocID) + _, err := n.client.query(path, &resp, q) return err } diff --git a/api/nodes_test.go b/api/nodes_test.go index b3cc6c2b141..2a2727788ed 100644 --- a/api/nodes_test.go +++ b/api/nodes_test.go @@ -8,7 +8,10 @@ import ( "testing" "time" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" + "github.com/stretchr/testify/require" ) func TestNodes_List(t *testing.T) { @@ -275,3 +278,27 @@ func TestNodes_Sort(t *testing.T) { t.Fatalf("\n\n%#v\n\n%#v", nodes, expect) } } + +func TestNodes_GC(t *testing.T) { + t.Parallel() + require := require.New(t) + c, s := makeClient(t, nil, nil) + defer s.Stop() + nodes := c.Nodes() + + err := nodes.GC(uuid.Generate(), nil) + require.NotNil(err) + require.True(structs.IsErrUnknownNode(err)) +} + +func TestNodes_GcAlloc(t *testing.T) { + t.Parallel() + require := require.New(t) + c, s := makeClient(t, nil, nil) + 
defer s.Stop() + nodes := c.Nodes() + + err := nodes.GcAlloc(uuid.Generate(), nil) + require.NotNil(err) + require.True(structs.IsErrUnknownAllocation(err)) +} diff --git a/client/acl_test.go b/client/acl_test.go index c4fc1463fde..bb21fcb8bd1 100644 --- a/client/acl_test.go +++ b/client/acl_test.go @@ -17,7 +17,7 @@ func TestClient_ACL_resolveTokenValue(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = s1 c.ACLEnabled = true }) @@ -66,7 +66,7 @@ func TestClient_ACL_resolvePolicies(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = s1 c.ACLEnabled = true }) @@ -106,7 +106,7 @@ func TestClient_ACL_ResolveToken_Disabled(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = s1 }) defer c1.Shutdown() @@ -122,7 +122,7 @@ func TestClient_ACL_ResolveToken(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = s1 c.ACLEnabled = true }) diff --git a/client/alloc_endpoint.go b/client/alloc_endpoint.go new file mode 100644 index 00000000000..6b1e4eec09a --- /dev/null +++ b/client/alloc_endpoint.go @@ -0,0 +1,75 @@ +package client + +import ( + "time" + + metrics "github.com/armon/go-metrics" + "github.com/hashicorp/nomad/acl" + cstructs "github.com/hashicorp/nomad/client/structs" + nstructs "github.com/hashicorp/nomad/nomad/structs" +) + +// Allocations endpoint is used for interacting with client allocations +type Allocations struct { + c *Client +} + +// GarbageCollectAll is used to garbage collect all allocations on a client. +func (a *Allocations) GarbageCollectAll(args *nstructs.NodeSpecificRequest, reply *nstructs.GenericResponse) error { + defer metrics.MeasureSince([]string{"client", "allocations", "garbage_collect_all"}, time.Now()) + + // Check node write permissions + if aclObj, err := a.c.ResolveToken(args.AuthToken); err != nil { + return err + } else if aclObj != nil && !aclObj.AllowNodeWrite() { + return nstructs.ErrPermissionDenied + } + + a.c.CollectAllAllocs() + return nil +} + +// GarbageCollect is used to garbage collect an allocation on a client. 
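// Each RPC in this file opens with the same ACL guard: resolve the request's
// token and, when ACLs are enabled (aclObj non-nil), require the relevant
// capability before acting. A sketch of that guard factored out, using the
// ResolveToken/AllowNsOp calls seen in this diff; the helper name
// checkNamespaceCap is hypothetical:
func (a *Allocations) checkNamespaceCap(authToken, namespace, capability string) error {
	aclObj, err := a.c.ResolveToken(authToken)
	if err != nil {
		return err
	}
	// A nil ACL object means ACLs are disabled, so the request is allowed.
	if aclObj != nil && !aclObj.AllowNsOp(namespace, capability) {
		return nstructs.ErrPermissionDenied
	}
	return nil
}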
+func (a *Allocations) GarbageCollect(args *nstructs.AllocSpecificRequest, reply *nstructs.GenericResponse) error { + defer metrics.MeasureSince([]string{"client", "allocations", "garbage_collect"}, time.Now()) + + // Check submit job permissions + if aclObj, err := a.c.ResolveToken(args.AuthToken); err != nil { + return err + } else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilitySubmitJob) { + return nstructs.ErrPermissionDenied + } + + if !a.c.CollectAllocation(args.AllocID) { + // Could not find alloc + return nstructs.NewErrUnknownAllocation(args.AllocID) + } + + return nil +} + +// Stats is used to collect allocation statistics +func (a *Allocations) Stats(args *cstructs.AllocStatsRequest, reply *cstructs.AllocStatsResponse) error { + defer metrics.MeasureSince([]string{"client", "allocations", "stats"}, time.Now()) + + // Check read job permissions + if aclObj, err := a.c.ResolveToken(args.AuthToken); err != nil { + return err + } else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilityReadJob) { + return nstructs.ErrPermissionDenied + } + + clientStats := a.c.StatsReporter() + aStats, err := clientStats.GetAllocStats(args.AllocID) + if err != nil { + return err + } + + stats, err := aStats.LatestAllocStats(args.Task) + if err != nil { + return err + } + + reply.Stats = stats + return nil +} diff --git a/client/alloc_endpoint_test.go b/client/alloc_endpoint_test.go new file mode 100644 index 00000000000..c111f209b6b --- /dev/null +++ b/client/alloc_endpoint_test.go @@ -0,0 +1,264 @@ +package client + +import ( + "fmt" + "testing" + + "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/client/config" + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/nomad/mock" + nstructs "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" + "github.com/stretchr/testify/require" +) + +func TestAllocations_GarbageCollectAll(t *testing.T) { + t.Parallel() + require := require.New(t) + client := TestClient(t, nil) + + req := &nstructs.NodeSpecificRequest{} + var resp nstructs.GenericResponse + require.Nil(client.ClientRPC("Allocations.GarbageCollectAll", &req, &resp)) +} + +func TestAllocations_GarbageCollectAll_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + server, addr, root := testACLServer(t, nil) + defer server.Shutdown() + + client := TestClient(t, func(c *config.Config) { + c.Servers = []string{addr} + c.ACLEnabled = true + }) + defer client.Shutdown() + + // Try request without a token and expect failure + { + req := &nstructs.NodeSpecificRequest{} + var resp nstructs.GenericResponse + err := client.ClientRPC("Allocations.GarbageCollectAll", &req, &resp) + require.NotNil(err) + require.EqualError(err, nstructs.ErrPermissionDenied.Error()) + } + + // Try request with an invalid token and expect failure + { + token := mock.CreatePolicyAndToken(t, server.State(), 1005, "invalid", mock.NodePolicy(acl.PolicyDeny)) + req := &nstructs.NodeSpecificRequest{} + req.AuthToken = token.SecretID + + var resp nstructs.GenericResponse + err := client.ClientRPC("Allocations.GarbageCollectAll", &req, &resp) + + require.NotNil(err) + require.EqualError(err, nstructs.ErrPermissionDenied.Error()) + } + + // Try request with a valid token + { + token := mock.CreatePolicyAndToken(t, server.State(), 1007, "valid", mock.NodePolicy(acl.PolicyWrite)) + req := &nstructs.NodeSpecificRequest{} + req.AuthToken = 
token.SecretID + var resp nstructs.GenericResponse + require.Nil(client.ClientRPC("Allocations.GarbageCollectAll", &req, &resp)) + } + + // Try request with a management token + { + req := &nstructs.NodeSpecificRequest{} + req.AuthToken = root.SecretID + var resp nstructs.GenericResponse + require.Nil(client.ClientRPC("Allocations.GarbageCollectAll", &req, &resp)) + } +} + +func TestAllocations_GarbageCollect(t *testing.T) { + t.Parallel() + require := require.New(t) + client := TestClient(t, func(c *config.Config) { + c.GCDiskUsageThreshold = 100.0 + }) + + a := mock.Alloc() + a.Job.TaskGroups[0].Tasks[0].Driver = "mock_driver" + a.Job.TaskGroups[0].RestartPolicy = &nstructs.RestartPolicy{ + Attempts: 0, + Mode: nstructs.RestartPolicyModeFail, + } + a.Job.TaskGroups[0].Tasks[0].Config = map[string]interface{}{ + "run_for": "10ms", + } + require.Nil(client.addAlloc(a, "")) + + // Try with bad alloc + req := &nstructs.AllocSpecificRequest{} + var resp nstructs.GenericResponse + err := client.ClientRPC("Allocations.GarbageCollect", &req, &resp) + require.NotNil(err) + + // Try with good alloc + req.AllocID = a.ID + testutil.WaitForResult(func() (bool, error) { + // Check if has been removed first + if ar, ok := client.allocs[a.ID]; !ok || ar.IsDestroyed() { + return true, nil + } + + var resp2 nstructs.GenericResponse + err := client.ClientRPC("Allocations.GarbageCollect", &req, &resp2) + return err == nil, err + }, func(err error) { + t.Fatalf("err: %v", err) + }) +} + +func TestAllocations_GarbageCollect_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + server, addr, root := testACLServer(t, nil) + defer server.Shutdown() + + client := TestClient(t, func(c *config.Config) { + c.Servers = []string{addr} + c.ACLEnabled = true + }) + defer client.Shutdown() + + // Try request without a token and expect failure + { + req := &nstructs.AllocSpecificRequest{} + var resp nstructs.GenericResponse + err := client.ClientRPC("Allocations.GarbageCollect", &req, &resp) + require.NotNil(err) + require.EqualError(err, nstructs.ErrPermissionDenied.Error()) + } + + // Try request with an invalid token and expect failure + { + token := mock.CreatePolicyAndToken(t, server.State(), 1005, "invalid", mock.NodePolicy(acl.PolicyDeny)) + req := &nstructs.AllocSpecificRequest{} + req.AuthToken = token.SecretID + + var resp nstructs.GenericResponse + err := client.ClientRPC("Allocations.GarbageCollect", &req, &resp) + + require.NotNil(err) + require.EqualError(err, nstructs.ErrPermissionDenied.Error()) + } + + // Try request with a valid token + { + token := mock.CreatePolicyAndToken(t, server.State(), 1005, "test-valid", + mock.NamespacePolicy(nstructs.DefaultNamespace, "", []string{acl.NamespaceCapabilitySubmitJob})) + req := &nstructs.AllocSpecificRequest{} + req.AuthToken = token.SecretID + req.Namespace = nstructs.DefaultNamespace + + var resp nstructs.GenericResponse + err := client.ClientRPC("Allocations.GarbageCollect", &req, &resp) + require.True(nstructs.IsErrUnknownAllocation(err)) + } + + // Try request with a management token + { + req := &nstructs.AllocSpecificRequest{} + req.AuthToken = root.SecretID + + var resp nstructs.GenericResponse + err := client.ClientRPC("Allocations.GarbageCollect", &req, &resp) + require.True(nstructs.IsErrUnknownAllocation(err)) + } +} + +func TestAllocations_Stats(t *testing.T) { + t.Parallel() + require := require.New(t) + client := TestClient(t, nil) + + a := mock.Alloc() + require.Nil(client.addAlloc(a, "")) + + // Try with bad alloc + req := 
&cstructs.AllocStatsRequest{} + var resp cstructs.AllocStatsResponse + err := client.ClientRPC("Allocations.Stats", &req, &resp) + require.NotNil(err) + + // Try with good alloc + req.AllocID = a.ID + testutil.WaitForResult(func() (bool, error) { + var resp2 cstructs.AllocStatsResponse + err := client.ClientRPC("Allocations.Stats", &req, &resp2) + if err != nil { + return false, err + } + if resp2.Stats == nil { + return false, fmt.Errorf("invalid stats object") + } + + return true, nil + }, func(err error) { + t.Fatalf("err: %v", err) + }) +} + +func TestAllocations_Stats_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + server, addr, root := testACLServer(t, nil) + defer server.Shutdown() + + client := TestClient(t, func(c *config.Config) { + c.Servers = []string{addr} + c.ACLEnabled = true + }) + defer client.Shutdown() + + // Try request without a token and expect failure + { + req := &cstructs.AllocStatsRequest{} + var resp cstructs.AllocStatsResponse + err := client.ClientRPC("Allocations.Stats", &req, &resp) + require.NotNil(err) + require.EqualError(err, nstructs.ErrPermissionDenied.Error()) + } + + // Try request with an invalid token and expect failure + { + token := mock.CreatePolicyAndToken(t, server.State(), 1005, "invalid", mock.NodePolicy(acl.PolicyDeny)) + req := &cstructs.AllocStatsRequest{} + req.AuthToken = token.SecretID + + var resp cstructs.AllocStatsResponse + err := client.ClientRPC("Allocations.Stats", &req, &resp) + + require.NotNil(err) + require.EqualError(err, nstructs.ErrPermissionDenied.Error()) + } + + // Try request with a valid token + { + token := mock.CreatePolicyAndToken(t, server.State(), 1005, "test-valid", + mock.NamespacePolicy(nstructs.DefaultNamespace, "", []string{acl.NamespaceCapabilityReadJob})) + req := &cstructs.AllocStatsRequest{} + req.AuthToken = token.SecretID + req.Namespace = nstructs.DefaultNamespace + + var resp cstructs.AllocStatsResponse + err := client.ClientRPC("Allocations.Stats", &req, &resp) + require.True(nstructs.IsErrUnknownAllocation(err)) + } + + // Try request with a management token + { + req := &cstructs.AllocStatsRequest{} + req.AuthToken = root.SecretID + + var resp cstructs.AllocStatsResponse + err := client.ClientRPC("Allocations.Stats", &req, &resp) + require.True(nstructs.IsErrUnknownAllocation(err)) + } +} diff --git a/client/alloc_runner_test.go b/client/alloc_runner_test.go index 47348e29bd0..28abf0f43de 100644 --- a/client/alloc_runner_test.go +++ b/client/alloc_runner_test.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/consul/api" "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/command/agent/consul" + "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -62,8 +63,7 @@ func allocationBucketExists(tx *bolt.Tx, allocID string) bool { return alloc != nil } -func testAllocRunnerFromAlloc(alloc *structs.Allocation, restarts bool) (*MockAllocStateUpdater, *AllocRunner) { - logger := testLogger() +func testAllocRunnerFromAlloc(t *testing.T, alloc *structs.Allocation, restarts bool) (*MockAllocStateUpdater, *AllocRunner) { conf := config.DefaultConfig() conf.Node = mock.Node() conf.StateDir = os.TempDir() @@ -76,22 +76,22 @@ func testAllocRunnerFromAlloc(alloc *structs.Allocation, restarts bool) (*MockAl alloc.Job.Type = structs.JobTypeBatch } vclient := vaultclient.NewMockVaultClient() - ar := 
NewAllocRunner(logger, conf, db, upd.Update, alloc, vclient, newMockConsulServiceClient(), noopPrevAlloc{}) + ar := NewAllocRunner(testlog.Logger(t), conf, db, upd.Update, alloc, vclient, newMockConsulServiceClient(t), noopPrevAlloc{}) return upd, ar } -func testAllocRunner(restarts bool) (*MockAllocStateUpdater, *AllocRunner) { +func testAllocRunner(t *testing.T, restarts bool) (*MockAllocStateUpdater, *AllocRunner) { // Use mock driver alloc := mock.Alloc() task := alloc.Job.TaskGroups[0].Tasks[0] task.Driver = "mock_driver" task.Config["run_for"] = "500ms" - return testAllocRunnerFromAlloc(alloc, restarts) + return testAllocRunnerFromAlloc(t, alloc, restarts) } func TestAllocRunner_SimpleRun(t *testing.T) { t.Parallel() - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) go ar.Run() defer ar.Destroy() @@ -115,7 +115,7 @@ func TestAllocRunner_DeploymentHealth_Unhealthy_BadStart(t *testing.T) { assert := assert.New(t) // Ensure the task fails and restarts - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Make the task fail task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -163,7 +163,7 @@ func TestAllocRunner_DeploymentHealth_Unhealthy_Deadline(t *testing.T) { assert := assert.New(t) // Ensure the task fails and restarts - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Make the task block task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -211,7 +211,7 @@ func TestAllocRunner_DeploymentHealth_Healthy_NoChecks(t *testing.T) { t.Parallel() // Ensure the task fails and restarts - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Make the task run healthy task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -259,7 +259,7 @@ func TestAllocRunner_DeploymentHealth_Healthy_Checks(t *testing.T) { t.Parallel() // Ensure the task fails and restarts - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Make the task fail task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -352,7 +352,7 @@ func TestAllocRunner_DeploymentHealth_Unhealthy_Checks(t *testing.T) { assert := assert.New(t) // Ensure the task fails and restarts - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Make the task fail task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -421,7 +421,7 @@ func TestAllocRunner_DeploymentHealth_Healthy_UpdatedDeployment(t *testing.T) { t.Parallel() // Ensure the task fails and restarts - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Make the task run healthy task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -502,7 +502,7 @@ func TestAllocRunner_RetryArtifact(t *testing.T) { } alloc.Job.TaskGroups[0].Tasks = append(alloc.Job.TaskGroups[0].Tasks, badtask) - upd, ar := testAllocRunnerFromAlloc(alloc, true) + upd, ar := testAllocRunnerFromAlloc(t, alloc, true) go ar.Run() defer ar.Destroy() @@ -538,7 +538,7 @@ func TestAllocRunner_RetryArtifact(t *testing.T) { func TestAllocRunner_TerminalUpdate_Destroy(t *testing.T) { t.Parallel() - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Ensure task takes some time task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -637,7 +637,7 @@ func TestAllocRunner_TerminalUpdate_Destroy(t *testing.T) { func TestAllocRunner_Destroy(t *testing.T) { t.Parallel() - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Ensure task takes some time task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -693,7 +693,7 @@ func TestAllocRunner_Destroy(t *testing.T) { func 
TestAllocRunner_Update(t *testing.T) { t.Parallel() - _, ar := testAllocRunner(false) + _, ar := testAllocRunner(t, false) // Deep copy the alloc to avoid races when updating newAlloc := ar.Alloc().Copy() @@ -728,7 +728,7 @@ func TestAllocRunner_SaveRestoreState(t *testing.T) { "run_for": "10s", } - upd, ar := testAllocRunnerFromAlloc(alloc, false) + upd, ar := testAllocRunnerFromAlloc(t, alloc, false) go ar.Run() defer ar.Destroy() @@ -796,7 +796,7 @@ func TestAllocRunner_SaveRestoreState(t *testing.T) { func TestAllocRunner_SaveRestoreState_TerminalAlloc(t *testing.T) { t.Parallel() - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) ar.logger = prefixedTestLogger("ar1: ") // Ensure task takes some time @@ -929,7 +929,7 @@ func TestAllocRunner_SaveRestoreState_Upgrade(t *testing.T) { "run_for": "10s", } - upd, ar := testAllocRunnerFromAlloc(alloc, false) + upd, ar := testAllocRunnerFromAlloc(t, alloc, false) // Hack in old version to cause an upgrade on RestoreState origConfig := ar.config.Copy() ar.config.Version = &version.VersionInfo{Version: "0.5.6"} @@ -1112,7 +1112,7 @@ func TestAllocRunner_RestoreOldState(t *testing.T) { *alloc.Job.LookupTaskGroup(alloc.TaskGroup).RestartPolicy = structs.RestartPolicy{Attempts: 0} alloc.Job.Type = structs.JobTypeBatch vclient := vaultclient.NewMockVaultClient() - cclient := newMockConsulServiceClient() + cclient := newMockConsulServiceClient(t) ar := NewAllocRunner(logger, conf, db, upd.Update, alloc, vclient, cclient, noopPrevAlloc{}) defer ar.Destroy() @@ -1140,7 +1140,7 @@ func TestAllocRunner_RestoreOldState(t *testing.T) { func TestAllocRunner_TaskFailed_KillTG(t *testing.T) { t.Parallel() - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Create two tasks in the task group task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -1208,7 +1208,7 @@ func TestAllocRunner_TaskFailed_KillTG(t *testing.T) { func TestAllocRunner_TaskLeader_KillTG(t *testing.T) { t.Parallel() - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Create two tasks in the task group task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -1282,7 +1282,7 @@ func TestAllocRunner_TaskLeader_KillTG(t *testing.T) { // with a leader the leader is stopped before other tasks. 
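// The thread running through this test refactor is that helpers such as
// testAllocRunner and newMockConsulServiceClient now take *testing.T, so log
// output is attached to the specific test via testlog.Logger(t). A sketch of
// the shape such a helper takes (an assumption about helper/testlog, not a
// copy of it):
type testWriter struct{ t *testing.T }

// Write routes log output through t.Logf so it is captured per-test and only
// printed on failure or with -v.
func (w testWriter) Write(p []byte) (int, error) {
	w.t.Logf("%s", p)
	return len(p), nil
}

func testLoggerFor(t *testing.T) *log.Logger {
	return log.New(testWriter{t}, "", log.LstdFlags)
}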
func TestAllocRunner_TaskLeader_StopTG(t *testing.T) { t.Parallel() - upd, ar := testAllocRunner(false) + upd, ar := testAllocRunner(t, false) // Create 3 tasks in the task group task := ar.alloc.Job.TaskGroups[0].Tasks[0] @@ -1371,7 +1371,7 @@ func TestAllocRunner_TaskLeader_StopTG(t *testing.T) { // See https://github.com/hashicorp/nomad/issues/3420#issuecomment-341666932 func TestAllocRunner_TaskLeader_StopRestoredTG(t *testing.T) { t.Parallel() - _, ar := testAllocRunner(false) + _, ar := testAllocRunner(t, false) defer ar.Destroy() // Create a leader and follower task in the task group @@ -1468,7 +1468,7 @@ func TestAllocRunner_MoveAllocDir(t *testing.T) { task.Config = map[string]interface{}{ "run_for": "1s", } - upd, ar := testAllocRunnerFromAlloc(alloc, false) + upd, ar := testAllocRunnerFromAlloc(t, alloc, false) go ar.Run() defer ar.Destroy() @@ -1501,7 +1501,7 @@ func TestAllocRunner_MoveAllocDir(t *testing.T) { task.Config = map[string]interface{}{ "run_for": "1s", } - upd2, ar2 := testAllocRunnerFromAlloc(alloc2, false) + upd2, ar2 := testAllocRunnerFromAlloc(t, alloc2, false) // Set prevAlloc like Client does ar2.prevAlloc = newAllocWatcher(alloc2, ar, nil, ar2.config, ar2.logger, "") diff --git a/client/alloc_watcher_test.go b/client/alloc_watcher_test.go index 45972e7593f..43048363b67 100644 --- a/client/alloc_watcher_test.go +++ b/client/alloc_watcher_test.go @@ -23,7 +23,7 @@ import ( // TestPrevAlloc_LocalPrevAlloc asserts that when a previous alloc runner is // set a localPrevAlloc will block on it. func TestPrevAlloc_LocalPrevAlloc(t *testing.T) { - _, prevAR := testAllocRunner(false) + _, prevAR := testAllocRunner(t, false) prevAR.alloc.Job.TaskGroups[0].Tasks[0].Config["run_for"] = "10s" newAlloc := mock.Alloc() @@ -177,7 +177,7 @@ func TestPrevAlloc_StreamAllocDir_Ok(t *testing.T) { } defer os.RemoveAll(dir1) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = nil }) defer c1.Shutdown() diff --git a/client/allocdir/alloc_dir.go b/client/allocdir/alloc_dir.go index d34c8b8ea1c..1cccec1d598 100644 --- a/client/allocdir/alloc_dir.go +++ b/client/allocdir/alloc_dir.go @@ -2,6 +2,7 @@ package allocdir import ( "archive/tar" + "context" "fmt" "io" "io/ioutil" @@ -10,11 +11,11 @@ import ( "path/filepath" "time" - "gopkg.in/tomb.v1" - "github.com/hashicorp/go-multierror" + cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/structs" "github.com/hpcloud/tail/watch" + tomb "gopkg.in/tomb.v1" ) const ( @@ -75,23 +76,14 @@ type AllocDir struct { logger *log.Logger } -// AllocFileInfo holds information about a file inside the AllocDir -type AllocFileInfo struct { - Name string - IsDir bool - Size int64 - FileMode string - ModTime time.Time -} - // AllocDirFS exposes file operations on the alloc dir type AllocDirFS interface { - List(path string) ([]*AllocFileInfo, error) - Stat(path string) (*AllocFileInfo, error) + List(path string) ([]*cstructs.AllocFileInfo, error) + Stat(path string) (*cstructs.AllocFileInfo, error) ReadAt(path string, offset int64) (io.ReadCloser, error) Snapshot(w io.Writer) error - BlockUntilExists(path string, t *tomb.Tomb) (chan error, error) - ChangeEvents(path string, curOffset int64, t *tomb.Tomb) (*watch.FileChanges, error) + BlockUntilExists(ctx context.Context, path string) (chan error, error) + ChangeEvents(ctx context.Context, path string, curOffset int64) (*watch.FileChanges, error) } // NewAllocDir initializes the 
AllocDir struct with allocDir as base path for @@ -335,7 +327,7 @@ func (d *AllocDir) Build() error { } // List returns the list of files at a path relative to the alloc dir -func (d *AllocDir) List(path string) ([]*AllocFileInfo, error) { +func (d *AllocDir) List(path string) ([]*cstructs.AllocFileInfo, error) { if escapes, err := structs.PathEscapesAllocDir("", path); err != nil { return nil, fmt.Errorf("Failed to check if path escapes alloc directory: %v", err) } else if escapes { @@ -345,11 +337,11 @@ func (d *AllocDir) List(path string) ([]*AllocFileInfo, error) { p := filepath.Join(d.AllocDir, path) finfos, err := ioutil.ReadDir(p) if err != nil { - return []*AllocFileInfo{}, err + return []*cstructs.AllocFileInfo{}, err } - files := make([]*AllocFileInfo, len(finfos)) + files := make([]*cstructs.AllocFileInfo, len(finfos)) for idx, info := range finfos { - files[idx] = &AllocFileInfo{ + files[idx] = &cstructs.AllocFileInfo{ Name: info.Name(), IsDir: info.IsDir(), Size: info.Size(), @@ -361,7 +353,7 @@ func (d *AllocDir) List(path string) ([]*AllocFileInfo, error) { } // Stat returns information about the file at a path relative to the alloc dir -func (d *AllocDir) Stat(path string) (*AllocFileInfo, error) { +func (d *AllocDir) Stat(path string) (*cstructs.AllocFileInfo, error) { if escapes, err := structs.PathEscapesAllocDir("", path); err != nil { return nil, fmt.Errorf("Failed to check if path escapes alloc directory: %v", err) } else if escapes { @@ -374,7 +366,7 @@ func (d *AllocDir) Stat(path string) (*AllocFileInfo, error) { return nil, err } - return &AllocFileInfo{ + return &cstructs.AllocFileInfo{ Size: info.Size(), Name: info.Name(), IsDir: info.IsDir(), @@ -411,8 +403,8 @@ func (d *AllocDir) ReadAt(path string, offset int64) (io.ReadCloser, error) { } // BlockUntilExists blocks until the passed file relative the allocation -// directory exists. The block can be cancelled with the passed tomb. -func (d *AllocDir) BlockUntilExists(path string, t *tomb.Tomb) (chan error, error) { +// directory exists. The block can be cancelled with the passed context. +func (d *AllocDir) BlockUntilExists(ctx context.Context, path string) (chan error, error) { if escapes, err := structs.PathEscapesAllocDir("", path); err != nil { return nil, fmt.Errorf("Failed to check if path escapes alloc directory: %v", err) } else if escapes { @@ -423,6 +415,11 @@ func (d *AllocDir) BlockUntilExists(path string, t *tomb.Tomb) (chan error, erro p := filepath.Join(d.AllocDir, path) watcher := getFileWatcher(p) returnCh := make(chan error, 1) + t := &tomb.Tomb{} + go func() { + <-ctx.Done() + t.Kill(nil) + }() go func() { returnCh <- watcher.BlockUntilExists(t) close(returnCh) @@ -431,15 +428,21 @@ func (d *AllocDir) BlockUntilExists(path string, t *tomb.Tomb) (chan error, erro } // ChangeEvents watches for changes to the passed path relative to the -// allocation directory. The offset should be the last read offset. The tomb is +// allocation directory. The offset should be the last read offset. The context is // used to clean up the watch. 
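// The migration from tomb-based to context-based cancellation in
// BlockUntilExists and ChangeEvents is done with a small bridge: a private
// tomb is killed once the context is cancelled, and that tomb is handed to the
// hpcloud/tail watcher exactly as before. The same shape as a standalone
// helper (the name tombFromContext is hypothetical; the diff inlines this):
func tombFromContext(ctx context.Context) *tomb.Tomb {
	t := &tomb.Tomb{}
	go func() {
		// Kill the tomb (with no error) once the caller's context ends.
		<-ctx.Done()
		t.Kill(nil)
	}()
	return t
}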
-func (d *AllocDir) ChangeEvents(path string, curOffset int64, t *tomb.Tomb) (*watch.FileChanges, error) { +func (d *AllocDir) ChangeEvents(ctx context.Context, path string, curOffset int64) (*watch.FileChanges, error) { if escapes, err := structs.PathEscapesAllocDir("", path); err != nil { return nil, fmt.Errorf("Failed to check if path escapes alloc directory: %v", err) } else if escapes { return nil, fmt.Errorf("Path escapes the alloc directory") } + t := &tomb.Tomb{} + go func() { + <-ctx.Done() + t.Kill(nil) + }() + // Get the path relative to the alloc directory p := filepath.Join(d.AllocDir, path) watcher := getFileWatcher(p) diff --git a/client/allocdir/alloc_dir_test.go b/client/allocdir/alloc_dir_test.go index a89ac39486b..922ce52c633 100644 --- a/client/allocdir/alloc_dir_test.go +++ b/client/allocdir/alloc_dir_test.go @@ -3,6 +3,7 @@ package allocdir import ( "archive/tar" "bytes" + "context" "io" "io/ioutil" "log" @@ -12,8 +13,6 @@ import ( "syscall" "testing" - tomb "gopkg.in/tomb.v1" - cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/client/testutil" "github.com/hashicorp/nomad/nomad/structs" @@ -314,13 +313,12 @@ func TestAllocDir_EscapeChecking(t *testing.T) { } // BlockUntilExists - tomb := tomb.Tomb{} - if _, err := d.BlockUntilExists("../foo", &tomb); err == nil || !strings.Contains(err.Error(), "escapes") { + if _, err := d.BlockUntilExists(context.Background(), "../foo"); err == nil || !strings.Contains(err.Error(), "escapes") { t.Fatalf("BlockUntilExists of escaping path didn't error: %v", err) } // ChangeEvents - if _, err := d.ChangeEvents("../foo", 0, &tomb); err == nil || !strings.Contains(err.Error(), "escapes") { + if _, err := d.ChangeEvents(context.Background(), "../foo", 0); err == nil || !strings.Contains(err.Error(), "escapes") { t.Fatalf("ChangeEvents of escaping path didn't error: %v", err) } } diff --git a/client/client.go b/client/client.go index e27d2d62caf..757f3279900 100644 --- a/client/client.go +++ b/client/client.go @@ -6,8 +6,10 @@ import ( "io/ioutil" "log" "net" + "net/rpc" "os" "path/filepath" + "sort" "strconv" "strings" "sync" @@ -20,14 +22,16 @@ import ( multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/client/allocdir" "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/client/servers" "github.com/hashicorp/nomad/client/stats" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/client/vaultclient" "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/helper" + "github.com/hashicorp/nomad/helper/pool" + hstats "github.com/hashicorp/nomad/helper/stats" "github.com/hashicorp/nomad/helper/tlsutil" "github.com/hashicorp/nomad/helper/uuid" - "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" nconfig "github.com/hashicorp/nomad/nomad/structs/config" vaultapi "github.com/hashicorp/vault/api" @@ -107,10 +111,15 @@ type Client struct { logger *log.Logger - connPool *nomad.ConnPool + connPool *pool.ConnPool - // servers is the (optionally prioritized) list of nomad servers - servers *serverlist + // tlsWrap is used to wrap outbound connections using TLS. It should be + // accessed using the lock. 
+ tlsWrap tlsutil.RegionWrapper + tlsWrapLock sync.RWMutex + + // servers is the list of nomad servers + servers *servers.Manager // heartbeat related times for tracking how often to heartbeat lastHeartbeat time.Time @@ -157,6 +166,11 @@ type Client struct { // clientACLResolver holds the ACL resolution state clientACLResolver + // rpcServer is used to serve RPCs by the local agent. + rpcServer *rpc.Server + endpoints rpcEndpoints + streamingRpcs *structs.StreamingRpcRegistery + // baseLabels are used when emitting tagged metrics. All client metrics will // have these tags, and optionally more. baseLabels []metrics.Label @@ -187,21 +201,28 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic consulCatalog: consulCatalog, consulService: consulService, start: time.Now(), - connPool: nomad.NewPool(cfg.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap), + connPool: pool.NewPool(cfg.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap), + tlsWrap: tlsWrap, + streamingRpcs: structs.NewStreamingRpcRegistery(), logger: logger, allocs: make(map[string]*AllocRunner), allocUpdates: make(chan *structs.Allocation, 64), shutdownCh: make(chan struct{}), - servers: newServerList(), triggerDiscoveryCh: make(chan struct{}), serversDiscoveredCh: make(chan struct{}), } + // Initialize the server manager + c.servers = servers.New(c.logger, c.shutdownCh, c) + // Initialize the client if err := c.init(); err != nil { return nil, fmt.Errorf("failed to initialize client: %v", err) } + // Setup the clients RPC server + c.setupClientRpc() + // Initialize the ACL state if err := c.clientACLResolver.init(); err != nil { return nil, fmt.Errorf("failed to initialize ACL state: %v", err) @@ -248,7 +269,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic // Set the preconfigured list of static servers c.configLock.RLock() if len(c.configCopy.Servers) > 0 { - if err := c.SetServers(c.configCopy.Servers); err != nil { + if err := c.setServersImpl(c.configCopy.Servers, true); err != nil { logger.Printf("[WARN] client: None of the configured servers are valid: %v", err) } } @@ -257,7 +278,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic // Setup Consul discovery if enabled if c.configCopy.ConsulConfig.ClientAutoJoin != nil && *c.configCopy.ConsulConfig.ClientAutoJoin { go c.consulDiscovery() - if len(c.servers.all()) == 0 { + if c.servers.NumServers() == 0 { // No configured servers; trigger discovery manually c.triggerDiscoveryCh <- struct{}{} } @@ -374,6 +395,11 @@ func (c *Client) reloadTLSConnections(newConfig *nconfig.TLSConfig) error { tlsWrap = tw } + // Store the new tls wrapper. + c.tlsWrapLock.Lock() + c.tlsWrap = tlsWrap + c.tlsWrapLock.Unlock() + // Keep the client configuration up to date as we use configuration values to // decide on what type of connections to accept c.configLock.Lock() @@ -474,35 +500,6 @@ func (c *Client) Shutdown() error { return c.saveState() } -// RPC is used to forward an RPC call to a nomad server, or fail if no servers. 
-func (c *Client) RPC(method string, args interface{}, reply interface{}) error { - // Invoke the RPCHandler if it exists - if c.config.RPCHandler != nil { - return c.config.RPCHandler.RPC(method, args, reply) - } - - servers := c.servers.all() - if len(servers) == 0 { - return noServersErr - } - - var mErr multierror.Error - for _, s := range servers { - // Make the RPC request - if err := c.connPool.RPC(c.Region(), s.addr, c.RPCMajorVersion(), method, args, reply); err != nil { - errmsg := fmt.Errorf("RPC failed to server %s: %v", s.addr, err) - mErr.Errors = append(mErr.Errors, errmsg) - c.logger.Printf("[DEBUG] client: %v", errmsg) - c.servers.failed(s) - continue - } - c.servers.good(s) - return nil - } - - return mErr.ErrorOrNil() -} - // Stats is used to return statistics for debugging and insight // for various sub-systems func (c *Client) Stats() map[string]map[string]string { @@ -511,12 +508,12 @@ func (c *Client) Stats() map[string]map[string]string { stats := map[string]map[string]string{ "client": { "node_id": c.NodeID(), - "known_servers": c.servers.all().String(), + "known_servers": strings.Join(c.GetServers(), ","), "num_allocations": strconv.Itoa(c.NumAllocs()), "last_heartbeat": fmt.Sprintf("%v", time.Since(c.lastHeartbeat)), "heartbeat_ttl": fmt.Sprintf("%v", c.heartbeatTTL), }, - "runtime": nomad.RuntimeStats(), + "runtime": hstats.RuntimeStats(), } return stats } @@ -551,7 +548,7 @@ func (c *Client) GetAllocStats(allocID string) (AllocStatsReporter, error) { defer c.allocLock.RUnlock() ar, ok := c.allocs[allocID] if !ok { - return nil, fmt.Errorf("unknown allocation ID %q", allocID) + return nil, structs.NewErrUnknownAllocation(allocID) } return ar.StatsReporter(), nil } @@ -569,7 +566,7 @@ func (c *Client) ValidateMigrateToken(allocID, migrateToken string) bool { return true } - return nomad.CompareMigrateToken(allocID, c.secretNodeID(), migrateToken) + return structs.CompareMigrateToken(allocID, c.secretNodeID(), migrateToken) } // GetAllocFS returns the AllocFS interface for the alloc dir of an allocation @@ -579,7 +576,7 @@ func (c *Client) GetAllocFS(allocID string) (allocdir.AllocDirFS, error) { ar, ok := c.allocs[allocID] if !ok { - return nil, fmt.Errorf("unknown allocation ID %q", allocID) + return nil, structs.NewErrUnknownAllocation(allocID) } return ar.GetAllocDir(), nil } @@ -589,39 +586,71 @@ func (c *Client) GetClientAlloc(allocID string) (*structs.Allocation, error) { all := c.allAllocs() alloc, ok := all[allocID] if !ok { - return nil, fmt.Errorf("unknown allocation ID %q", allocID) + return nil, structs.NewErrUnknownAllocation(allocID) } return alloc, nil } // GetServers returns the list of nomad servers this client is aware of. func (c *Client) GetServers() []string { - endpoints := c.servers.all() + endpoints := c.servers.GetServers() res := make([]string, len(endpoints)) for i := range endpoints { - res[i] = endpoints[i].addr.String() + res[i] = endpoints[i].String() } + sort.Strings(res) return res } // SetServers sets a new list of nomad servers to connect to. As long as one // server is resolvable no error is returned. -func (c *Client) SetServers(servers []string) error { - endpoints := make([]*endpoint, 0, len(servers)) +func (c *Client) SetServers(in []string) error { + return c.setServersImpl(in, false) +} + +// setServersImpl sets a new list of nomad servers to connect to. If force is +// set, we add the server to the internal server list even if the server could not +// be pinged. 
An error is returned if no endpoints were valid when force is not set. +// +// Force should be used when setting the servers from the initial configuration +// since the server may be starting up in parallel and initial pings may fail. +func (c *Client) setServersImpl(in []string, force bool) error { + var mu sync.Mutex + var wg sync.WaitGroup var merr multierror.Error - for _, s := range servers { - addr, err := resolveServer(s) - if err != nil { - c.logger.Printf("[DEBUG] client: ignoring server %s due to resolution error: %v", s, err) - merr.Errors = append(merr.Errors, err) - continue - } - // Valid endpoint, append it without a priority as this API - // doesn't support different priorities for different servers - endpoints = append(endpoints, &endpoint{name: s, addr: addr}) + endpoints := make([]*servers.Server, 0, len(in)) + wg.Add(len(in)) + + for _, s := range in { + go func(srv string) { + defer wg.Done() + addr, err := resolveServer(srv) + if err != nil { + c.logger.Printf("[DEBUG] client: ignoring server %s due to resolution error: %v", srv, err) + merr.Errors = append(merr.Errors, err) + return + } + + // Try to ping to check if it is a real server + if err := c.Ping(addr); err != nil { + merr.Errors = append(merr.Errors, fmt.Errorf("Server at address %s failed ping: %v", addr, err)) + + // If we are forcing the setting of the servers, add it to + // the server list even if we can't ping immediately. + if !force { + return + } + } + + mu.Lock() + endpoints = append(endpoints, &servers.Server{Addr: addr}) + mu.Unlock() + }(s) } + wg.Wait() + // Only return errors if no servers are valid if len(endpoints) == 0 { if len(merr.Errors) > 0 { @@ -630,7 +659,7 @@ func (c *Client) SetServers(servers []string) error { return noServersErr } - c.servers.set(endpoints) + c.servers.SetServers(endpoints) return nil } @@ -1185,26 +1214,25 @@ func (c *Client) updateNodeStatus() error { } } - // Convert []*NodeServerInfo to []*endpoints - localdc := c.Datacenter() - servers := make(endpoints, 0, len(resp.Servers)) + // Update the number of nodes in the cluster so we can adjust our server + // rebalance rate. + c.servers.SetNumNodes(resp.NumNodes) + + // Convert []*NodeServerInfo to []*servers.Server + nomadServers := make([]*servers.Server, 0, len(resp.Servers)) for _, s := range resp.Servers { addr, err := resolveServer(s.RPCAdvertiseAddr) if err != nil { c.logger.Printf("[WARN] client: ignoring invalid server %q: %v", s.RPCAdvertiseAddr, err) continue } - e := endpoint{name: s.RPCAdvertiseAddr, addr: addr} - if s.Datacenter != localdc { - // server is non-local; de-prioritize - e.priority = 1 - } - servers = append(servers, &e) + e := &servers.Server{DC: s.Datacenter, Addr: addr} + nomadServers = append(nomadServers, e) } - if len(servers) == 0 { - return fmt.Errorf("server returned no valid servers") + if len(nomadServers) == 0 { + return fmt.Errorf("heartbeat response returned no valid servers") } - c.servers.set(servers) + c.servers.SetServers(nomadServers) // Begin polling Consul if there is no Nomad leader. 
We could be // heartbeating to a Nomad server that is in the minority of a @@ -1797,7 +1825,7 @@ func (c *Client) consulDiscoveryImpl() error { serviceName := c.configCopy.ConsulConfig.ServerServiceName var mErr multierror.Error - var servers endpoints + var nomadServers servers.Servers c.logger.Printf("[DEBUG] client.consul: bootstrap contacting following Consul DCs: %+q", dcs) DISCOLOOP: for _, dc := range dcs { @@ -1837,22 +1865,23 @@ DISCOLOOP: if err != nil { mErr.Errors = append(mErr.Errors, err) } - servers = append(servers, &endpoint{name: p, addr: addr}) + srv := &servers.Server{Addr: addr} + nomadServers = append(nomadServers, srv) } - if len(servers) > 0 { + if len(nomadServers) > 0 { break DISCOLOOP } } } - if len(servers) == 0 { + if len(nomadServers) == 0 { if len(mErr.Errors) > 0 { return mErr.ErrorOrNil() } return fmt.Errorf("no Nomad Servers advertising service %q in Consul datacenters: %+q", serviceName, dcs) } - c.logger.Printf("[INFO] client.consul: discovered following Servers: %s", servers) - c.servers.set(servers) + c.logger.Printf("[INFO] client.consul: discovered following Servers: %s", nomadServers) + c.servers.SetServers(nomadServers) // Notify waiting rpc calls. If a goroutine just failed an RPC call and // isn't receiving on this chan yet they'll still retry eventually. @@ -2164,19 +2193,3 @@ func (c *Client) allAllocs() map[string]*structs.Allocation { } return allocs } - -// resolveServer given a sever's address as a string, return it's resolved -// net.Addr or an error. -func resolveServer(s string) (net.Addr, error) { - const defaultClientPort = "4647" // default client RPC port - host, port, err := net.SplitHostPort(s) - if err != nil { - if strings.Contains(err.Error(), "missing port") { - host = s - port = defaultClientPort - } else { - return nil, err - } - } - return net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port)) -} diff --git a/client/client_stats_endpoint.go b/client/client_stats_endpoint.go new file mode 100644 index 00000000000..630707ca68c --- /dev/null +++ b/client/client_stats_endpoint.go @@ -0,0 +1,30 @@ +package client + +import ( + "time" + + metrics "github.com/armon/go-metrics" + "github.com/hashicorp/nomad/client/structs" + nstructs "github.com/hashicorp/nomad/nomad/structs" +) + +// ClientStats endpoint is used for retrieving stats about a client +type ClientStats struct { + c *Client +} + +// Stats is used to retrieve the client's stats. 
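// Callers reach this endpoint through the client's in-process RPC server,
// exactly as the tests added below do. A minimal sketch, assuming a running
// *Client; the helper name latestHostStats is hypothetical:
func latestHostStats(c *Client) (*structs.ClientStatsResponse, error) {
	req := &nstructs.NodeSpecificRequest{}
	var resp structs.ClientStatsResponse
	// ClientRPC dispatches to the rpcServer registered on this client, so no
	// round trip through a Nomad server is involved.
	if err := c.ClientRPC("ClientStats.Stats", req, &resp); err != nil {
		return nil, err
	}
	return &resp, nil
}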
+func (s *ClientStats) Stats(args *nstructs.NodeSpecificRequest, reply *structs.ClientStatsResponse) error { + defer metrics.MeasureSince([]string{"client", "client_stats", "stats"}, time.Now()) + + // Check node read permissions + if aclObj, err := s.c.ResolveToken(args.AuthToken); err != nil { + return err + } else if aclObj != nil && !aclObj.AllowNodeRead() { + return nstructs.ErrPermissionDenied + } + + clientStats := s.c.StatsReporter() + reply.HostStats = clientStats.LatestHostStats() + return nil +} diff --git a/client/client_stats_endpoint_test.go b/client/client_stats_endpoint_test.go new file mode 100644 index 00000000000..c16f91f3dcc --- /dev/null +++ b/client/client_stats_endpoint_test.go @@ -0,0 +1,85 @@ +package client + +import ( + "testing" + + "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/nomad/mock" + nstructs "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/require" +) + +func TestClientStats_Stats(t *testing.T) { + t.Parallel() + require := require.New(t) + client := TestClient(t, nil) + + req := &nstructs.NodeSpecificRequest{} + var resp structs.ClientStatsResponse + require.Nil(client.ClientRPC("ClientStats.Stats", &req, &resp)) + require.NotNil(resp.HostStats) + require.NotNil(resp.HostStats.AllocDirStats) + require.NotZero(resp.HostStats.Uptime) +} + +func TestClientStats_Stats_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + server, addr, root := testACLServer(t, nil) + defer server.Shutdown() + + client := TestClient(t, func(c *config.Config) { + c.Servers = []string{addr} + c.ACLEnabled = true + }) + defer client.Shutdown() + + // Try request without a token and expect failure + { + req := &nstructs.NodeSpecificRequest{} + var resp structs.ClientStatsResponse + err := client.ClientRPC("ClientStats.Stats", &req, &resp) + require.NotNil(err) + require.EqualError(err, nstructs.ErrPermissionDenied.Error()) + } + + // Try request with an invalid token and expect failure + { + token := mock.CreatePolicyAndToken(t, server.State(), 1005, "invalid", mock.NodePolicy(acl.PolicyDeny)) + req := &nstructs.NodeSpecificRequest{} + req.AuthToken = token.SecretID + + var resp structs.ClientStatsResponse + err := client.ClientRPC("ClientStats.Stats", &req, &resp) + + require.NotNil(err) + require.EqualError(err, nstructs.ErrPermissionDenied.Error()) + } + + // Try request with a valid token + { + token := mock.CreatePolicyAndToken(t, server.State(), 1007, "valid", mock.NodePolicy(acl.PolicyRead)) + req := &nstructs.NodeSpecificRequest{} + req.AuthToken = token.SecretID + + var resp structs.ClientStatsResponse + err := client.ClientRPC("ClientStats.Stats", &req, &resp) + + require.Nil(err) + require.NotNil(resp.HostStats) + } + + // Try request with a management token + { + req := &nstructs.NodeSpecificRequest{} + req.AuthToken = root.SecretID + + var resp structs.ClientStatsResponse + err := client.ClientRPC("ClientStats.Stats", &req, &resp) + + require.Nil(err) + require.NotNil(resp.HostStats) + } +} diff --git a/client/client_test.go b/client/client_test.go index da6a14fb7b4..13d4debfb01 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -4,20 +4,15 @@ import ( "fmt" "io/ioutil" "log" - "math/rand" - "net" "os" "path/filepath" "testing" "time" - "github.com/hashicorp/consul/lib/freeport" memdb "github.com/hashicorp/go-memdb" 
"github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/driver" - "github.com/hashicorp/nomad/client/fingerprint" "github.com/hashicorp/nomad/command/agent/consul" - "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/mock" @@ -32,110 +27,18 @@ import ( ) func testACLServer(t *testing.T, cb func(*nomad.Config)) (*nomad.Server, string, *structs.ACLToken) { - server, addr := testServer(t, func(c *nomad.Config) { - c.ACLEnabled = true - if cb != nil { - cb(c) - } - }) - token := mock.ACLManagementToken() - err := server.State().BootstrapACLTokens(1, 0, token) - if err != nil { - t.Fatalf("failed to bootstrap ACL token: %v", err) - } - return server, addr, token + server, token := nomad.TestACLServer(t, cb) + return server, server.GetConfig().RPCAddr.String(), token } func testServer(t *testing.T, cb func(*nomad.Config)) (*nomad.Server, string) { - // Setup the default settings - config := nomad.DefaultConfig() - config.VaultConfig.Enabled = helper.BoolToPtr(false) - config.Build = "unittest" - config.DevMode = true - - // Tighten the Serf timing - config.SerfConfig.MemberlistConfig.BindAddr = "127.0.0.1" - config.SerfConfig.MemberlistConfig.SuspicionMult = 2 - config.SerfConfig.MemberlistConfig.RetransmitMult = 2 - config.SerfConfig.MemberlistConfig.ProbeTimeout = 50 * time.Millisecond - config.SerfConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond - config.SerfConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond - - // Tighten the Raft timing - config.RaftConfig.LeaderLeaseTimeout = 20 * time.Millisecond - config.RaftConfig.HeartbeatTimeout = 40 * time.Millisecond - config.RaftConfig.ElectionTimeout = 40 * time.Millisecond - config.RaftConfig.StartAsLeader = true - config.RaftTimeout = 500 * time.Millisecond - - logger := log.New(config.LogOutput, "", log.LstdFlags) - catalog := consul.NewMockCatalog(logger) - - // Invoke the callback if any - if cb != nil { - cb(config) - } - - // Enable raft as leader if we have bootstrap on - config.RaftConfig.StartAsLeader = !config.DevDisableBootstrap - - for i := 10; i >= 0; i-- { - ports := freeport.GetT(t, 2) - config.RPCAddr = &net.TCPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: ports[0], - } - config.NodeName = fmt.Sprintf("Node %d", config.RPCAddr.Port) - config.SerfConfig.MemberlistConfig.BindPort = ports[1] - - // Create server - server, err := nomad.NewServer(config, catalog, logger) - if err == nil { - return server, config.RPCAddr.String() - } else if i == 0 { - t.Fatalf("err: %v", err) - } else { - wait := time.Duration(rand.Int31n(2000)) * time.Millisecond - time.Sleep(wait) - } - } - return nil, "" -} - -func testClient(t *testing.T, cb func(c *config.Config)) *Client { - conf := config.DefaultConfig() - conf.VaultConfig.Enabled = helper.BoolToPtr(false) - conf.DevMode = true - conf.Node = &structs.Node{ - Reserved: &structs.Resources{ - DiskMB: 0, - }, - } - - // Tighten the fingerprinter timeouts - if conf.Options == nil { - conf.Options = make(map[string]string) - } - conf.Options[fingerprint.TightenNetworkTimeoutsConfig] = "true" - - if cb != nil { - cb(conf) - } - - logger := log.New(conf.LogOutput, "", log.LstdFlags) - catalog := consul.NewMockCatalog(logger) - mockService := newMockConsulServiceClient() - mockService.logger = logger - client, err := NewClient(conf, catalog, mockService, logger) - if err != nil { - t.Fatalf("err: %v", 
err) - } - return client + server := nomad.TestServer(t, cb) + return server, server.GetConfig().RPCAddr.String() } func TestClient_StartStop(t *testing.T) { t.Parallel() - client := testClient(t, nil) + client := TestClient(t, nil) if err := client.Shutdown(); err != nil { t.Fatalf("err: %v", err) } @@ -147,7 +50,7 @@ func TestClient_BaseLabels(t *testing.T) { t.Parallel() assert := assert.New(t) - client := testClient(t, nil) + client := TestClient(t, nil) if err := client.Shutdown(); err != nil { t.Fatalf("err: %v", err) } @@ -172,7 +75,7 @@ func TestClient_RPC(t *testing.T) { s1, addr := testServer(t, nil) defer s1.Shutdown() - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.Servers = []string{addr} }) defer c1.Shutdown() @@ -192,7 +95,7 @@ func TestClient_RPC_Passthrough(t *testing.T) { s1, _ := testServer(t, nil) defer s1.Shutdown() - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = s1 }) defer c1.Shutdown() @@ -213,7 +116,7 @@ func TestClient_Fingerprint(t *testing.T) { driver.CheckForMockDriver(t) - c := testClient(t, nil) + c := TestClient(t, nil) defer c.Shutdown() // Ensure default values are present @@ -225,7 +128,7 @@ func TestClient_Fingerprint(t *testing.T) { func TestClient_HasNodeChanged(t *testing.T) { t.Parallel() - c := testClient(t, nil) + c := TestClient(t, nil) defer c.Shutdown() node := c.config.Node @@ -261,7 +164,7 @@ func TestClient_Fingerprint_Periodic(t *testing.T) { // these constants are only defined when nomad_test is enabled, so these fail // our linter without explicit disabling. - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.Options = map[string]string{ driver.ShutdownPeriodicAfter: "true", // nolint: varcheck driver.ShutdownPeriodicDuration: "3", // nolint: varcheck @@ -317,7 +220,7 @@ func TestClient_MixedTLS(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.Servers = []string{addr} }) defer c1.Shutdown() @@ -367,7 +270,7 @@ func TestClient_BadTLS(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.Servers = []string{addr} c.TLSConfig = &nconfig.TLSConfig{ EnableHTTP: true, @@ -405,7 +308,7 @@ func TestClient_Register(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = s1 }) defer c1.Shutdown() @@ -439,7 +342,7 @@ func TestClient_Heartbeat(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = s1 }) defer c1.Shutdown() @@ -471,7 +374,7 @@ func TestClient_UpdateAllocStatus(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = s1 }) defer c1.Shutdown() @@ -522,7 +425,7 @@ func TestClient_WatchAllocs(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = s1 }) defer c1.Shutdown() @@ -617,7 +520,7 @@ func TestClient_SaveRestoreState(t *testing.T) { defer s1.Shutdown() 
testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.DevMode = false c.RPCHandler = s1 }) @@ -671,7 +574,7 @@ func TestClient_SaveRestoreState(t *testing.T) { // Create a new client logger := log.New(c1.config.LogOutput, "", log.LstdFlags) catalog := consul.NewMockCatalog(logger) - mockService := newMockConsulServiceClient() + mockService := newMockConsulServiceClient(t) mockService.logger = logger c2, err := NewClient(c1.config, catalog, mockService, logger) if err != nil { @@ -734,7 +637,7 @@ func TestClient_BlockedAllocations(t *testing.T) { defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.RPCHandler = s1 }) defer c1.Shutdown() @@ -845,13 +748,13 @@ func TestClient_ValidateMigrateToken_ValidToken(t *testing.T) { t.Parallel() assert := assert.New(t) - c := testClient(t, func(c *config.Config) { + c := TestClient(t, func(c *config.Config) { c.ACLEnabled = true }) defer c.Shutdown() alloc := mock.Alloc() - validToken, err := nomad.GenerateMigrateToken(alloc.ID, c.secretNodeID()) + validToken, err := structs.GenerateMigrateToken(alloc.ID, c.secretNodeID()) assert.Nil(err) assert.Equal(c.ValidateMigrateToken(alloc.ID, validToken), true) @@ -861,7 +764,7 @@ func TestClient_ValidateMigrateToken_InvalidToken(t *testing.T) { t.Parallel() assert := assert.New(t) - c := testClient(t, func(c *config.Config) { + c := TestClient(t, func(c *config.Config) { c.ACLEnabled = true }) defer c.Shutdown() @@ -877,7 +780,7 @@ func TestClient_ValidateMigrateToken_ACLDisabled(t *testing.T) { t.Parallel() assert := assert.New(t) - c := testClient(t, func(c *config.Config) {}) + c := TestClient(t, func(c *config.Config) {}) defer c.Shutdown() assert.Equal(c.ValidateMigrateToken("", ""), true) @@ -899,7 +802,7 @@ func TestClient_ReloadTLS_UpgradePlaintextToTLS(t *testing.T) { fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" ) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.Servers = []string{addr} }) defer c1.Shutdown() @@ -975,7 +878,7 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) { fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" ) - c1 := testClient(t, func(c *config.Config) { + c1 := TestClient(t, func(c *config.Config) { c.Servers = []string{addr} c.TLSConfig = &nconfig.TLSConfig{ EnableHTTP: true, @@ -1002,10 +905,9 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) { return false, fmt.Errorf("client RPC succeeded when it should have failed :\n%+v", err) } return true, nil + }, func(err error) { + t.Fatalf(err.Error()) }, - func(err error) { - t.Fatalf(err.Error()) - }, ) } @@ -1028,10 +930,33 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) { return false, fmt.Errorf("client RPC failed when it should have succeeded:\n%+v", err) } return true, nil + }, func(err error) { + t.Fatalf(err.Error()) }, - func(err error) { - t.Fatalf(err.Error()) - }, ) } } + +// TestClient_ServerList tests client methods that interact with the internal +// nomad server list. 
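The test below pins down the contract the new server list depends on: every entry must resolve as host:port, a bare host falls back to the default client RPC port, and a list containing a bad entry reports an error while the resolvable entries are still installed. A sketch of that resolution rule, modeled on the resolveServer helper this change deletes (the 4647 default comes from there); resolveAddr is an illustrative name:

import (
	"net"
	"strings"
)

// resolveAddr gives bare hosts the default client RPC port before resolving.
func resolveAddr(s string) (net.Addr, error) {
	host, port, err := net.SplitHostPort(s)
	if err != nil {
		if !strings.Contains(err.Error(), "missing port") {
			return nil, err
		}
		host, port = s, "4647"
	}
	return net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port))
}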
+func TestClient_ServerList(t *testing.T) {
+	t.Parallel()
+	client := TestClient(t, func(c *config.Config) {})
+
+	if s := client.GetServers(); len(s) != 0 {
+		t.Fatalf("expected server list to be empty but found: %+q", s)
+	}
+	if err := client.SetServers(nil); err != noServersErr {
+		t.Fatalf("expected setting an empty list to return a 'no servers' error but received %v", err)
+	}
+	if err := client.SetServers([]string{"123.456.13123.123.13:80"}); err == nil {
+		t.Fatalf("expected setting a bad server to return an error")
+	}
+	if err := client.SetServers([]string{"123.456.13123.123.13:80", "127.0.0.1:1234", "127.0.0.1"}); err == nil {
+		t.Fatalf("expected setting a list containing a bad server to return an error")
+	}
+	s := client.GetServers()
+	if len(s) != 2 {
+		t.Fatalf("expected 2 servers but received: %+q", s)
+	}
+}
diff --git a/client/config/config.go b/client/config/config.go
index 8a16eb8bebe..ff30f20c822 100644
--- a/client/config/config.go
+++ b/client/config/config.go
@@ -195,6 +195,13 @@ type Config struct {
 	// BackwardsCompatibleMetrics determines whether to show methods of
 	// displaying metrics for older verions, or to only show the new format
 	BackwardsCompatibleMetrics bool
+
+	// RPCHoldTimeout is how long an RPC can be "held" before it is errored.
+	// This is used to paper over a loss of leadership by instead holding RPCs,
+	// so that the caller experiences a slow response rather than an error.
+	// This period is meant to be long enough for a leader election to take
+	// place, and a small jitter is applied to avoid a thundering herd.
+	RPCHoldTimeout time.Duration
 }

 func (c *Config) Copy() *Config {
@@ -228,6 +235,7 @@ func DefaultConfig() *Config {
 		NoHostUUID:                 true,
 		DisableTaggedMetrics:       false,
 		BackwardsCompatibleMetrics: false,
+		RPCHoldTimeout:             5 * time.Second,
 	}
 }
diff --git a/client/consul_test.go b/client/consul_testing.go
similarity index 93%
rename from client/consul_test.go
rename to client/consul_testing.go
index 8703cdd215a..4a2d2631bc6 100644
--- a/client/consul_test.go
+++ b/client/consul_testing.go
@@ -2,16 +2,15 @@ package client

 import (
 	"fmt"
-	"io/ioutil"
 	"log"
-	"os"
 	"sync"
-	"testing"

 	"github.com/hashicorp/nomad/client/driver"
 	cstructs "github.com/hashicorp/nomad/client/structs"
 	"github.com/hashicorp/nomad/command/agent/consul"
+	"github.com/hashicorp/nomad/helper/testlog"
 	"github.com/hashicorp/nomad/nomad/structs"
+	"github.com/mitchellh/go-testing-interface"
 )

 // mockConsulOp represents the register/deregister operations.
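RPCHoldTimeout, added to the client config above, papers over a leader election by holding a failed RPC rather than surfacing the error immediately; the actual wiring lives in the client's RPC path, outside this hunk. A hedged sketch of the pattern with hypothetical names:

import (
	"math/rand"
	"time"
)

// callWithHold retries a failed call with a little jitter until the hold
// period elapses, so a brief loss of leadership shows up as latency rather
// than an error.
func callWithHold(call func() error, hold time.Duration) error {
	deadline := time.Now().Add(hold)
	for {
		err := call()
		if err == nil || time.Now().After(deadline) {
			return err
		}
		// Jitter the retry to avoid a thundering herd against the new leader.
		time.Sleep(50*time.Millisecond + time.Duration(rand.Int63n(int64(50*time.Millisecond))))
	}
}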
@@ -49,13 +48,10 @@ type mockConsulServiceClient struct { allocRegistrationsFn func(allocID string) (*consul.AllocRegistration, error) } -func newMockConsulServiceClient() *mockConsulServiceClient { +func newMockConsulServiceClient(t testing.T) *mockConsulServiceClient { m := mockConsulServiceClient{ ops: make([]mockConsulOp, 0, 20), - logger: log.New(ioutil.Discard, "", 0), - } - if testing.Verbose() { - m.logger = log.New(os.Stderr, "", log.LstdFlags) + logger: testlog.Logger(t), } return &m } diff --git a/client/driver/mock_driver.go b/client/driver/mock_driver.go index 15cc56b5b41..29d6a4a9d70 100644 --- a/client/driver/mock_driver.go +++ b/client/driver/mock_driver.go @@ -1,5 +1,3 @@ -//+build nomad_test - package driver import ( @@ -7,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "log" "os" "strconv" @@ -15,6 +14,7 @@ import ( "github.com/mitchellh/mapstructure" + "github.com/hashicorp/nomad/client/driver/logging" dstructs "github.com/hashicorp/nomad/client/driver/structs" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/structs" @@ -32,11 +32,6 @@ const ( ShutdownPeriodicDuration = "test.shutdown_periodic_duration" ) -// Add the mock driver to the list of builtin drivers -func init() { - BuiltinDrivers["mock_driver"] = NewMockDriver -} - // MockDriverConfig is the driver configuration for the MockDriver type MockDriverConfig struct { @@ -83,6 +78,15 @@ type MockDriverConfig struct { // DriverPortMap will parse a label:number pair and return it in // DriverNetwork.PortMap from Start(). DriverPortMap string `mapstructure:"driver_port_map"` + + // StdoutString is the string that should be sent to stdout + StdoutString string `mapstructure:"stdout_string"` + + // StdoutRepeat is the number of times the output should be sent. + StdoutRepeat int `mapstructure:"stdout_repeat"` + + // StdoutRepeatDur is the duration between repeated outputs. 
+ StdoutRepeatDur time.Duration `mapstructure:"stdout_repeat_duration"` } // MockDriver is a driver which is used for testing purposes @@ -169,15 +173,20 @@ func (m *MockDriver) Start(ctx *ExecContext, task *structs.Task) (*StartResponse } h := mockDriverHandle{ - taskName: task.Name, - runFor: driverConfig.RunFor, - killAfter: driverConfig.KillAfter, - killTimeout: task.KillTimeout, - exitCode: driverConfig.ExitCode, - exitSignal: driverConfig.ExitSignal, - logger: m.logger, - doneCh: make(chan struct{}), - waitCh: make(chan *dstructs.WaitResult, 1), + ctx: ctx, + task: task, + taskName: task.Name, + runFor: driverConfig.RunFor, + killAfter: driverConfig.KillAfter, + killTimeout: task.KillTimeout, + exitCode: driverConfig.ExitCode, + exitSignal: driverConfig.ExitSignal, + stdoutString: driverConfig.StdoutString, + stdoutRepeat: driverConfig.StdoutRepeat, + stdoutRepeatDur: driverConfig.StdoutRepeatDur, + logger: m.logger, + doneCh: make(chan struct{}), + waitCh: make(chan *dstructs.WaitResult, 1), } if driverConfig.ExitErrMsg != "" { h.exitErr = errors.New(driverConfig.ExitErrMsg) @@ -233,19 +242,29 @@ func (m *MockDriver) Fingerprint(req *cstructs.FingerprintRequest, resp *cstruct return nil } +// When testing, poll for updates +func (m *MockDriver) Periodic() (bool, time.Duration) { + return true, 500 * time.Millisecond +} + // MockDriverHandle is a driver handler which supervises a mock task type mockDriverHandle struct { - taskName string - runFor time.Duration - killAfter time.Duration - killTimeout time.Duration - exitCode int - exitSignal int - exitErr error - signalErr error - logger *log.Logger - waitCh chan *dstructs.WaitResult - doneCh chan struct{} + ctx *ExecContext + task *structs.Task + taskName string + runFor time.Duration + killAfter time.Duration + killTimeout time.Duration + exitCode int + exitSignal int + exitErr error + signalErr error + logger *log.Logger + stdoutString string + stdoutRepeat int + stdoutRepeatDur time.Duration + waitCh chan *dstructs.WaitResult + doneCh chan struct{} } type mockDriverID struct { @@ -355,6 +374,11 @@ func (h *mockDriverHandle) Stats() (*cstructs.TaskResourceUsage, error) { // run waits for the configured amount of time and then indicates the task has // terminated func (h *mockDriverHandle) run() { + // Setup logging output + if h.stdoutString != "" { + go h.handleLogging() + } + timer := time.NewTimer(h.runFor) defer timer.Stop() for { @@ -374,7 +398,43 @@ func (h *mockDriverHandle) run() { } } -// When testing, poll for updates -func (m *MockDriver) Periodic() (bool, time.Duration) { - return true, 500 * time.Millisecond +// handleLogging handles logging stdout messages +func (h *mockDriverHandle) handleLogging() { + if h.stdoutString == "" { + return + } + + // Setup a log rotator + logFileSize := int64(h.task.LogConfig.MaxFileSizeMB * 1024 * 1024) + lro, err := logging.NewFileRotator(h.ctx.TaskDir.LogDir, fmt.Sprintf("%v.stdout", h.taskName), + h.task.LogConfig.MaxFiles, logFileSize, h.logger) + if err != nil { + h.exitErr = err + close(h.doneCh) + h.logger.Printf("[ERR] mock_driver: failed to setup file rotator: %v", err) + return + } + defer lro.Close() + + // Do initial write to stdout. 
+ if _, err := io.WriteString(lro, h.stdoutString); err != nil { + h.exitErr = err + close(h.doneCh) + h.logger.Printf("[ERR] mock_driver: failed to write to stdout: %v", err) + return + } + + for i := 0; i < h.stdoutRepeat; i++ { + select { + case <-h.doneCh: + return + case <-time.After(h.stdoutRepeatDur): + if _, err := io.WriteString(lro, h.stdoutString); err != nil { + h.exitErr = err + close(h.doneCh) + h.logger.Printf("[ERR] mock_driver: failed to write to stdout: %v", err) + return + } + } + } } diff --git a/client/driver/mock_driver_testing.go b/client/driver/mock_driver_testing.go new file mode 100644 index 00000000000..1b1e861a891 --- /dev/null +++ b/client/driver/mock_driver_testing.go @@ -0,0 +1,8 @@ +//+build nomad_test + +package driver + +// Add the mock driver to the list of builtin drivers +func init() { + BuiltinDrivers["mock_driver"] = NewMockDriver +} diff --git a/client/fs_endpoint.go b/client/fs_endpoint.go new file mode 100644 index 00000000000..eaff009c7df --- /dev/null +++ b/client/fs_endpoint.go @@ -0,0 +1,921 @@ +package client + +import ( + "bytes" + "context" + "fmt" + "io" + "math" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "syscall" + "time" + + metrics "github.com/armon/go-metrics" + "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/client/allocdir" + sframer "github.com/hashicorp/nomad/client/lib/streamframer" + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/helper" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hpcloud/tail/watch" + "github.com/ugorji/go/codec" +) + +var ( + allocIDNotPresentErr = fmt.Errorf("must provide a valid alloc id") + pathNotPresentErr = fmt.Errorf("must provide a file path") + taskNotPresentErr = fmt.Errorf("must provide task name") + logTypeNotPresentErr = fmt.Errorf("must provide log type (stdout/stderr)") + invalidOrigin = fmt.Errorf("origin must be start or end") +) + +const ( + // streamFramesBuffer is the number of stream frames that will be buffered + // before back pressure is applied on the stream framer. + streamFramesBuffer = 32 + + // streamFrameSize is the maximum number of bytes to send in a single frame + streamFrameSize = 64 * 1024 + + // streamHeartbeatRate is the rate at which a heartbeat will occur to detect + // a closed connection without sending any additional data + streamHeartbeatRate = 1 * time.Second + + // streamBatchWindow is the window in which file content is batched before + // being flushed if the frame size has not been hit. + streamBatchWindow = 200 * time.Millisecond + + // nextLogCheckRate is the rate at which we check for a log entry greater + // than what we are watching for. This is to handle the case in which logs + // rotate faster than we can detect and we have to rely on a normal + // directory listing. + nextLogCheckRate = 100 * time.Millisecond + + // deleteEvent and truncateEvent are the file events that can be sent in a + // StreamFrame + deleteEvent = "file deleted" + truncateEvent = "file truncated" + + // OriginStart and OriginEnd are the available parameters for the origin + // argument when streaming a file. They respectively offset from the start + // and end of a file. + OriginStart = "start" + OriginEnd = "end" +) + +// FileSystem endpoint is used for accessing the logs and filesystem of +// allocations. 
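List and Stat on the endpoint below are plain request/response RPCs over the allocation directory, gated on the namespace's read-fs capability; only Logs and Stream need the streaming machinery. The new tests drive List through the client RPC layer roughly like so; listAllocRoot is an illustrative wrapper, not part of the change:

import (
	"github.com/hashicorp/nomad/client"
	cstructs "github.com/hashicorp/nomad/client/structs"
	"github.com/hashicorp/nomad/nomad/structs"
)

// listAllocRoot lists the root of an allocation's directory.
func listAllocRoot(c *client.Client, allocID string) ([]*cstructs.AllocFileInfo, error) {
	req := &cstructs.FsListRequest{
		AllocID:      allocID,
		Path:         "/",
		QueryOptions: structs.QueryOptions{Region: "global"},
	}
	var resp cstructs.FsListResponse
	if err := c.ClientRPC("FileSystem.List", req, &resp); err != nil {
		return nil, err
	}
	return resp.Files, nil
}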
+type FileSystem struct {
+	c *Client
+}
+
+func NewFileSystemEndpoint(c *Client) *FileSystem {
+	f := &FileSystem{c}
+	f.c.streamingRpcs.Register("FileSystem.Logs", f.logs)
+	f.c.streamingRpcs.Register("FileSystem.Stream", f.stream)
+	return f
+}
+
+// handleStreamResultError is a helper for sending an error with a potential
+// error code. The transmission of the error is ignored if the error has been
+// generated by the closing of the underlying transport.
+func (f *FileSystem) handleStreamResultError(err error, code *int64, encoder *codec.Encoder) {
+	// Nothing to do as the conn is closed
+	if err == io.EOF || strings.Contains(err.Error(), "closed") {
+		return
+	}
+
+	encoder.Encode(&cstructs.StreamErrWrapper{
+		Error: cstructs.NewRpcError(err, code),
+	})
+}
+
+// List is used to list the contents of an allocation's directory.
+func (f *FileSystem) List(args *cstructs.FsListRequest, reply *cstructs.FsListResponse) error {
+	defer metrics.MeasureSince([]string{"client", "file_system", "list"}, time.Now())
+
+	// Check read permissions
+	if aclObj, err := f.c.ResolveToken(args.QueryOptions.AuthToken); err != nil {
+		return err
+	} else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilityReadFS) {
+		return structs.ErrPermissionDenied
+	}
+
+	fs, err := f.c.GetAllocFS(args.AllocID)
+	if err != nil {
+		return err
+	}
+	files, err := fs.List(args.Path)
+	if err != nil {
+		return err
+	}
+
+	reply.Files = files
+	return nil
+}
+
+// Stat is used to stat a file in the allocation's directory.
+func (f *FileSystem) Stat(args *cstructs.FsStatRequest, reply *cstructs.FsStatResponse) error {
+	defer metrics.MeasureSince([]string{"client", "file_system", "stat"}, time.Now())
+
+	// Check read permissions
+	if aclObj, err := f.c.ResolveToken(args.QueryOptions.AuthToken); err != nil {
+		return err
+	} else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilityReadFS) {
+		return structs.ErrPermissionDenied
+	}
+
+	fs, err := f.c.GetAllocFS(args.AllocID)
+	if err != nil {
+		return err
+	}
+	info, err := fs.Stat(args.Path)
+	if err != nil {
+		return err
+	}
+
+	reply.Info = info
+	return nil
+}
+
+// stream is used to stream the contents of a file in an allocation's
+// directory.
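The streaming handlers that follow all speak the same wire shape: the request arrives msgpack-encoded, and each reply is a cstructs.StreamErrWrapper whose Payload holds either raw bytes (PlainText) or a JSON-encoded stream frame. A minimal consumer, mirroring the decode loop the new tests use; drainStream is an illustrative name:

import (
	"io"

	cstructs "github.com/hashicorp/nomad/client/structs"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/ugorji/go/codec"
)

// drainStream copies stream payloads to out until the stream ends or errors.
func drainStream(conn io.Reader, out io.Writer) error {
	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
	for {
		var msg cstructs.StreamErrWrapper
		if err := decoder.Decode(&msg); err != nil {
			return err // io.EOF or a "closed" error means the peer finished
		}
		if msg.Error != nil {
			return msg.Error // server-side failure, with an optional code
		}
		if _, err := out.Write(msg.Payload); err != nil {
			return err
		}
	}
}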
+func (f *FileSystem) stream(conn io.ReadWriteCloser) { + defer metrics.MeasureSince([]string{"client", "file_system", "stream"}, time.Now()) + defer conn.Close() + + // Decode the arguments + var req cstructs.FsStreamRequest + decoder := codec.NewDecoder(conn, structs.MsgpackHandle) + encoder := codec.NewEncoder(conn, structs.MsgpackHandle) + + if err := decoder.Decode(&req); err != nil { + f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder) + return + } + + // Check read permissions + if aclObj, err := f.c.ResolveToken(req.QueryOptions.AuthToken); err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } else if aclObj != nil && !aclObj.AllowNsOp(req.Namespace, acl.NamespaceCapabilityReadFS) { + f.handleStreamResultError(structs.ErrPermissionDenied, nil, encoder) + return + } + + // Validate the arguments + if req.AllocID == "" { + f.handleStreamResultError(allocIDNotPresentErr, helper.Int64ToPtr(400), encoder) + return + } + if req.Path == "" { + f.handleStreamResultError(pathNotPresentErr, helper.Int64ToPtr(400), encoder) + return + } + switch req.Origin { + case "start", "end": + case "": + req.Origin = "start" + default: + f.handleStreamResultError(invalidOrigin, helper.Int64ToPtr(400), encoder) + return + } + + fs, err := f.c.GetAllocFS(req.AllocID) + if err != nil { + code := helper.Int64ToPtr(500) + if structs.IsErrUnknownAllocation(err) { + code = helper.Int64ToPtr(404) + } + + f.handleStreamResultError(err, code, encoder) + return + } + + // Calculate the offset + fileInfo, err := fs.Stat(req.Path) + if err != nil { + f.handleStreamResultError(err, helper.Int64ToPtr(400), encoder) + return + } + if fileInfo.IsDir { + f.handleStreamResultError( + fmt.Errorf("file %q is a directory", req.Path), + helper.Int64ToPtr(400), encoder) + return + } + + // If offsetting from the end subtract from the size + if req.Origin == "end" { + req.Offset = fileInfo.Size - req.Offset + if req.Offset < 0 { + req.Offset = 0 + } + } + + frames := make(chan *sframer.StreamFrame, streamFramesBuffer) + errCh := make(chan error) + var buf bytes.Buffer + frameCodec := codec.NewEncoder(&buf, structs.JsonHandle) + + // Create the framer + framer := sframer.NewStreamFramer(frames, streamHeartbeatRate, streamBatchWindow, streamFrameSize) + framer.Run() + defer framer.Destroy() + + // If we aren't following end as soon as we hit EOF + var eofCancelCh chan error + if !req.Follow { + eofCancelCh = make(chan error) + close(eofCancelCh) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start streaming + go func() { + if err := f.streamFile(ctx, req.Offset, req.Path, req.Limit, fs, framer, eofCancelCh); err != nil { + select { + case errCh <- err: + case <-ctx.Done(): + } + } + + framer.Destroy() + }() + + // Create a goroutine to detect the remote side closing + go func() { + for { + if _, err := conn.Read(nil); err != nil { + if err == io.EOF { + cancel() + return + } + select { + case errCh <- err: + case <-ctx.Done(): + return + } + } + } + }() + + var streamErr error +OUTER: + for { + select { + case streamErr = <-errCh: + break OUTER + case frame, ok := <-frames: + if !ok { + break OUTER + } + + var resp cstructs.StreamErrWrapper + if req.PlainText { + resp.Payload = frame.Data + } else { + if err = frameCodec.Encode(frame); err != nil { + streamErr = err + break OUTER + } + + resp.Payload = buf.Bytes() + buf.Reset() + } + + if err := encoder.Encode(resp); err != nil { + streamErr = err + break OUTER + } + case <-ctx.Done(): + break OUTER + } + } 
+
+	if streamErr != nil {
+		f.handleStreamResultError(streamErr, helper.Int64ToPtr(500), encoder)
+		return
+	}
+}
+
+// logs is used to stream a task's logs.
+func (f *FileSystem) logs(conn io.ReadWriteCloser) {
+	defer metrics.MeasureSince([]string{"client", "file_system", "logs"}, time.Now())
+	defer conn.Close()
+
+	// Decode the arguments
+	var req cstructs.FsLogsRequest
+	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
+	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
+
+	if err := decoder.Decode(&req); err != nil {
+		f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
+		return
+	}
+
+	// Check read permissions
+	if aclObj, err := f.c.ResolveToken(req.QueryOptions.AuthToken); err != nil {
+		f.handleStreamResultError(err, nil, encoder)
+		return
+	} else if aclObj != nil {
+		readfs := aclObj.AllowNsOp(req.QueryOptions.Namespace, acl.NamespaceCapabilityReadFS)
+		logs := aclObj.AllowNsOp(req.QueryOptions.Namespace, acl.NamespaceCapabilityReadLogs)
+		if !readfs && !logs {
+			f.handleStreamResultError(structs.ErrPermissionDenied, nil, encoder)
+			return
+		}
+	}
+
+	// Validate the arguments
+	if req.AllocID == "" {
+		f.handleStreamResultError(allocIDNotPresentErr, helper.Int64ToPtr(400), encoder)
+		return
+	}
+	if req.Task == "" {
+		f.handleStreamResultError(taskNotPresentErr, helper.Int64ToPtr(400), encoder)
+		return
+	}
+	switch req.LogType {
+	case "stdout", "stderr":
+	default:
+		f.handleStreamResultError(logTypeNotPresentErr, helper.Int64ToPtr(400), encoder)
+		return
+	}
+	switch req.Origin {
+	case "start", "end":
+	case "":
+		req.Origin = "start"
+	default:
+		f.handleStreamResultError(invalidOrigin, helper.Int64ToPtr(400), encoder)
+		return
+	}
+
+	fs, err := f.c.GetAllocFS(req.AllocID)
+	if err != nil {
+		code := helper.Int64ToPtr(500)
+		if structs.IsErrUnknownAllocation(err) {
+			code = helper.Int64ToPtr(404)
+		}
+
+		f.handleStreamResultError(err, code, encoder)
+		return
+	}
+
+	alloc, err := f.c.GetClientAlloc(req.AllocID)
+	if err != nil {
+		code := helper.Int64ToPtr(500)
+		if structs.IsErrUnknownAllocation(err) {
+			code = helper.Int64ToPtr(404)
+		}
+
+		f.handleStreamResultError(err, code, encoder)
+		return
+	}
+
+	// Check that the task is there
+	tg := alloc.Job.LookupTaskGroup(alloc.TaskGroup)
+	if tg == nil {
+		f.handleStreamResultError(fmt.Errorf("Failed to lookup task group for allocation"),
+			helper.Int64ToPtr(500), encoder)
+		return
+	} else if taskStruct := tg.LookupTask(req.Task); taskStruct == nil {
+		f.handleStreamResultError(
+			fmt.Errorf("task group %q does not have task with name %q", alloc.TaskGroup, req.Task),
+			helper.Int64ToPtr(400),
+			encoder)
+		return
+	}
+
+	state, ok := alloc.TaskStates[req.Task]
+	if !ok || state.StartedAt.IsZero() {
+		f.handleStreamResultError(fmt.Errorf("task %q not started yet. No logs available", req.Task),
+			helper.Int64ToPtr(404), encoder)
+		return
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	frames := make(chan *sframer.StreamFrame, streamFramesBuffer)
+	errCh := make(chan error)
+	var buf bytes.Buffer
+	frameCodec := codec.NewEncoder(&buf, structs.JsonHandle)
+
+	// Start streaming
+	go func() {
+		if err := f.logsImpl(ctx, req.Follow, req.PlainText,
+			req.Offset, req.Origin, req.Task, req.LogType, fs, frames); err != nil {
+			select {
+			case errCh <- err:
+			case <-ctx.Done():
+			}
+		}
+	}()
+
+	// Create a goroutine to detect the remote side closing
+	go func() {
+		for {
+			if _, err := conn.Read(nil); err != nil {
+				if err == io.EOF {
+					cancel()
+					return
+				}
+				select {
+				case errCh <- err:
+				case <-ctx.Done():
+					return
+				}
+			}
+		}
+	}()
+
+	var streamErr error
+OUTER:
+	for {
+		select {
+		case streamErr = <-errCh:
+			break OUTER
+		case frame, ok := <-frames:
+			if !ok {
+				break OUTER
+			}
+
+			var resp cstructs.StreamErrWrapper
+			if req.PlainText {
+				resp.Payload = frame.Data
+			} else {
+				if err = frameCodec.Encode(frame); err != nil {
+					streamErr = err
+					break OUTER
+				}
+
+				resp.Payload = buf.Bytes()
+				buf.Reset()
+			}
+
+			if err := encoder.Encode(resp); err != nil {
+				streamErr = err
+				break OUTER
+			}
+		}
+	}
+
+	if streamErr != nil {
+		f.handleStreamResultError(streamErr, helper.Int64ToPtr(500), encoder)
+		return
+	}
+}
+
+// logsImpl is used to stream the logs of the given task. Output is sent on the
+// passed frames channel. The method returns at EOF unless follow is true, in
+// which case it runs until the context is cancelled or an error occurs.
+func (f *FileSystem) logsImpl(ctx context.Context, follow, plain bool, offset int64,
+	origin, task, logType string,
+	fs allocdir.AllocDirFS, frames chan<- *sframer.StreamFrame) error {
+
+	// Create the framer
+	framer := sframer.NewStreamFramer(frames, streamHeartbeatRate, streamBatchWindow, streamFrameSize)
+	framer.Run()
+	defer framer.Destroy()
+
+	// Path to the logs
+	logPath := filepath.Join(allocdir.SharedAllocName, allocdir.LogDirName)
+
+	// nextIdx is the next index to read logs from
+	var nextIdx int64
+	switch origin {
+	case "start":
+		nextIdx = 0
+	case "end":
+		nextIdx = math.MaxInt64
+		offset *= -1
+	default:
+		return invalidOrigin
+	}
+
+	for {
+		// Logic for picking next file is:
+		// 1) List log files
+		// 2) Pick log file closest to desired index
+		// 3) Open log file at correct offset
+		// 3a) No error, read contents
+		// 3b) If file doesn't exist, goto 1 as it may have been rotated out
+		entries, err := fs.List(logPath)
+		if err != nil {
+			return fmt.Errorf("failed to list entries: %v", err)
+		}
+
+		// If we are not following logs, determine the max index for the logs we are
+		// interested in so we can stop there.
+		maxIndex := int64(math.MaxInt64)
+		if !follow {
+			_, idx, _, err := findClosest(entries, maxIndex, 0, task, logType)
+			if err != nil {
+				return err
+			}
+			maxIndex = idx
+		}
+
+		logEntry, idx, openOffset, err := findClosest(entries, nextIdx, offset, task, logType)
+		if err != nil {
+			return err
+		}
+
+		var eofCancelCh chan error
+		exitAfter := false
+		if !follow && idx > maxIndex {
+			// Exceeded what was there initially so return
+			return nil
+		} else if !follow && idx == maxIndex {
+			// At the end
+			eofCancelCh = make(chan error)
+			close(eofCancelCh)
+			exitAfter = true
+		} else {
+			eofCancelCh = blockUntilNextLog(ctx, fs, logPath, task, logType, idx+1)
+		}
+
+		p := filepath.Join(logPath, logEntry.Name)
+		err = f.streamFile(ctx, openOffset, p, 0, fs, framer, eofCancelCh)
+
+		// Check if the context is cancelled
+		select {
+		case <-ctx.Done():
+			return nil
+		default:
+		}
+
+		if err != nil {
+			// Check if there was an error where the file does not exist. That means
+			// it got rotated out from under us.
+			if os.IsNotExist(err) {
+				continue
+			}
+
+			// Check if the connection was closed
+			if err == syscall.EPIPE {
+				return nil
+			}
+
+			return fmt.Errorf("failed to stream %q: %v", p, err)
+		}
+
+		if exitAfter {
+			return nil
+		}
+
+		// Defensively check that the StreamFramer hasn't stopped running, to
+		// avoid tight loops with goroutine leaks as in #3342.
+		select {
+		case <-framer.ExitCh():
+			err := parseFramerErr(framer.Err())
+			if err == syscall.EPIPE {
+				// EPIPE just means the connection was closed
+				return nil
+			}
+			return err
+		default:
+		}
+
+		// Since we successfully streamed, update the overall offset/idx.
+		offset = int64(0)
+		nextIdx = idx + 1
+	}
+}
+
+// streamFile is the internal method to stream the content of a file. If limit
+// is greater than zero, the stream ends once that many bytes have been read.
+// eofCancelCh is used to cancel the stream if triggered while at EOF. If the
+// connection is broken, an EPIPE error is returned.
+func (f *FileSystem) streamFile(ctx context.Context, offset int64, path string, limit int64,
+	fs allocdir.AllocDirFS, framer *sframer.StreamFramer, eofCancelCh chan error) error {
+
+	// Get the reader
+	file, err := fs.ReadAt(path, offset)
+	if err != nil {
+		return err
+	}
+	defer file.Close()
+
+	var fileReader io.Reader
+	if limit <= 0 {
+		fileReader = file
+	} else {
+		fileReader = io.LimitReader(file, limit)
+	}
+
+	// Create a cancellable context for the file watch events
+	waitCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// Create a variable to allow setting the last event
+	var lastEvent string
+
+	// Only create the file change watcher once. But we need to do it after we
+	// read and reach EOF.
+	var changes *watch.FileChanges
+
+	// Start streaming the data
+	bufSize := int64(streamFrameSize)
+	if limit > 0 && limit < streamFrameSize {
+		bufSize = limit
+	}
+	data := make([]byte, bufSize)
+OUTER:
+	for {
+		// Read up to the max frame size
+		n, readErr := fileReader.Read(data)
+
+		// Update the offset
+		offset += int64(n)
+
+		// Return non-EOF errors
+		if readErr != nil && readErr != io.EOF {
+			return readErr
+		}
+
+		// Send the frame
+		if n != 0 || lastEvent != "" {
+			if err := framer.Send(path, lastEvent, data[:n], offset); err != nil {
+				return parseFramerErr(err)
+			}
+		}
+
+		// Clear the last event
+		if lastEvent != "" {
+			lastEvent = ""
+		}
+
+		// Just keep reading since we aren't at the end of the file, so we can
+		// avoid setting up a file event watcher.
+		if readErr == nil {
+			continue
+		}
+
+		// If EOF is hit, wait for a change to the file
+		if changes == nil {
+			changes, err = fs.ChangeEvents(waitCtx, path, offset)
+			if err != nil {
+				return err
+			}
+		}
+
+		for {
+			select {
+			case <-changes.Modified:
+				continue OUTER
+			case <-changes.Deleted:
+				return parseFramerErr(framer.Send(path, deleteEvent, nil, offset))
+			case <-changes.Truncated:
+				// Close the current reader
+				if err := file.Close(); err != nil {
+					return err
+				}
+
+				// Get a new reader at offset zero
+				offset = 0
+				var err error
+				file, err = fs.ReadAt(path, offset)
+				if err != nil {
+					return err
+				}
+				defer file.Close()
+
+				if limit <= 0 {
+					fileReader = file
+				} else {
+					// Get the current limit
+					lr, ok := fileReader.(*io.LimitedReader)
+					if !ok {
+						return fmt.Errorf("unable to determine remaining read limit")
+					}
+
+					fileReader = io.LimitReader(file, lr.N)
+				}
+
+				// Store the last event
+				lastEvent = truncateEvent
+				continue OUTER
+			case <-framer.ExitCh():
+				return parseFramerErr(framer.Err())
+			case <-ctx.Done():
+				return nil
+			case err, ok := <-eofCancelCh:
+				if !ok {
+					return nil
+				}
+
+				return err
+			}
+		}
+	}
+}
+
+// blockUntilNextLog returns a channel that will have data sent when the next
+// log index or anything greater is created.
+func blockUntilNextLog(ctx context.Context, fs allocdir.AllocDirFS, logPath, task, logType string, nextIndex int64) chan error {
+	nextPath := filepath.Join(logPath, fmt.Sprintf("%s.%s.%d", task, logType, nextIndex))
+	next := make(chan error, 1)
+
+	go func() {
+		eofCancelCh, err := fs.BlockUntilExists(ctx, nextPath)
+		if err != nil {
+			next <- err
+			close(next)
+			return
+		}
+
+		ticker := time.NewTicker(nextLogCheckRate)
+		defer ticker.Stop()
+		scanCh := ticker.C
+		for {
+			select {
+			case <-ctx.Done():
+				next <- nil
+				close(next)
+				return
+			case err := <-eofCancelCh:
+				next <- err
+				close(next)
+				return
+			case <-scanCh:
+				entries, err := fs.List(logPath)
+				if err != nil {
+					next <- fmt.Errorf("failed to list entries: %v", err)
+					close(next)
+					return
+				}
+
+				indexes, err := logIndexes(entries, task, logType)
+				if err != nil {
+					next <- err
+					close(next)
+					return
+				}
+
+				// Scan and see if there are any entries larger than what we are
+				// waiting for.
+				for _, entry := range indexes {
+					if entry.idx >= nextIndex {
+						next <- nil
+						close(next)
+						return
+					}
+				}
+			}
+		}
+	}()
+
+	return next
+}
+
+// indexTuple and indexTupleArray are used to find the correct log entry to
+// start streaming logs from.
+type indexTuple struct {
+	idx   int64
+	entry *cstructs.AllocFileInfo
+}
+
+type indexTupleArray []indexTuple
+
+func (a indexTupleArray) Len() int           { return len(a) }
+func (a indexTupleArray) Less(i, j int) bool { return a[i].idx < a[j].idx }
+func (a indexTupleArray) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+
+// logIndexes takes a set of entries and returns an indexTupleArray of the
+// desired log file entries. If the indexes could not be determined, an error
+// is returned.
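Task log files are named task.logType.index, which is what lets the helpers below order rotations numerically and walk a byte offset across them. A worked example for findClosest (defined after logIndexes), runnable inside this package: with web.stdout.0 at 100 bytes and web.stdout.1 at 50 bytes, an offset of -30 from the newest index walks back into the last file:

entries := []*cstructs.AllocFileInfo{
	{Name: "web.stdout.0", Size: 100},
	{Name: "web.stdout.1", Size: 50},
}
entry, idx, offset, err := findClosest(entries, math.MaxInt64, -30, "web", "stdout")
// entry.Name == "web.stdout.1", idx == 1, offset == 20, err == nil:
// 30 bytes back from the end of the newest 50-byte file.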
+func logIndexes(entries []*cstructs.AllocFileInfo, task, logType string) (indexTupleArray, error) {
+	var indexes []indexTuple
+	prefix := fmt.Sprintf("%s.%s.", task, logType)
+	for _, entry := range entries {
+		if entry.IsDir {
+			continue
+		}
+
+		// If nothing was trimmed, then it is not a match
+		idxStr := strings.TrimPrefix(entry.Name, prefix)
+		if idxStr == entry.Name {
+			continue
+		}
+
+		// Convert to an int
+		idx, err := strconv.Atoi(idxStr)
+		if err != nil {
+			return nil, fmt.Errorf("failed to convert %q to a log index: %v", idxStr, err)
+		}
+
+		indexes = append(indexes, indexTuple{idx: int64(idx), entry: entry})
+	}
+
+	return indexTupleArray(indexes), nil
+}
+
+// findClosest takes a list of entries, the desired log index and desired log
+// offset (which can be negative, treated as an offset from the end), task name
+// and log type, and returns the log entry, the log index, the offset to read
+// from and a potential error.
+func findClosest(entries []*cstructs.AllocFileInfo, desiredIdx, desiredOffset int64,
+	task, logType string) (*cstructs.AllocFileInfo, int64, int64, error) {
+
+	// Build the matching indexes
+	indexes, err := logIndexes(entries, task, logType)
+	if err != nil {
+		return nil, 0, 0, err
+	}
+	if len(indexes) == 0 {
+		return nil, 0, 0, fmt.Errorf("log entry for task %q and log type %q not found", task, logType)
+	}
+
+	// Binary search the indexes to get the desiredIdx
+	sort.Sort(indexes)
+	i := sort.Search(len(indexes), func(i int) bool { return indexes[i].idx >= desiredIdx })
+	l := len(indexes)
+	if i == l {
+		// Use the last index if the number is bigger than all of them.
+		i = l - 1
+	}
+
+	// Get to the correct offset
+	offset := desiredOffset
+	idx := int64(i)
+	for {
+		s := indexes[idx].entry.Size
+
+		// Base case
+		if offset == 0 {
+			break
+		} else if offset < 0 {
+			// Going backwards
+			if newOffset := s + offset; newOffset >= 0 {
+				// Current file works
+				offset = newOffset
+				break
+			} else if idx == 0 {
+				// Already at the oldest file
+				offset = 0
+				break
+			} else {
+				// Try the file before
+				offset = newOffset
+				idx--
+				continue
+			}
+		} else {
+			// Going forward
+			if offset <= s {
+				// Current file works
+				break
+			} else if idx == int64(l-1) {
+				// Already at the end
+				offset = s
+				break
+			} else {
+				// Try the next file
+				offset = offset - s
+				idx++
+				continue
+			}
+		}
+	}
+
+	return indexes[idx].entry, indexes[idx].idx, offset, nil
+}
+
+// parseFramerErr normalizes errors caused by the connection closing to
+// syscall.EPIPE; any other error is returned unchanged.
+func parseFramerErr(err error) error {
+	if err == nil {
+		return nil
+	}
+
+	errMsg := err.Error()
+
+	if strings.Contains(errMsg, io.ErrClosedPipe.Error()) {
+		// The pipe check is for tests
+		return syscall.EPIPE
+	}
+
+	// The connection was closed by our peer
+	if strings.Contains(errMsg, syscall.EPIPE.Error()) || strings.Contains(errMsg, syscall.ECONNRESET.Error()) {
+		return syscall.EPIPE
+	}
+
+	// Windows version of ECONNRESET
+	//XXX(schmichael) I could find no existing error or constant to
+	// compare this against.
+	if strings.Contains(errMsg, "forcibly closed") {
+		return syscall.EPIPE
+	}
+
+	return err
+}
diff --git a/client/fs_endpoint_test.go b/client/fs_endpoint_test.go
new file mode 100644
index 00000000000..4e90daf0620
--- /dev/null
+++ b/client/fs_endpoint_test.go
@@ -0,0 +1,2036 @@
+package client
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"log"
+	"math"
+	"net"
+	"os"
+	"path/filepath"
+	"reflect"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/hashicorp/nomad/acl"
+	"github.com/hashicorp/nomad/client/allocdir"
+	"github.com/hashicorp/nomad/client/config"
+	sframer "github.com/hashicorp/nomad/client/lib/streamframer"
+	cstructs "github.com/hashicorp/nomad/client/structs"
+	"github.com/hashicorp/nomad/helper/uuid"
+	"github.com/hashicorp/nomad/nomad"
+	"github.com/hashicorp/nomad/nomad/mock"
+	"github.com/hashicorp/nomad/nomad/structs"
+	"github.com/hashicorp/nomad/testutil"
+	"github.com/stretchr/testify/require"
+	"github.com/ugorji/go/codec"
+)
+
+// tempAllocDir returns a new alloc dir that is rooted in a temp dir. The
+// caller should destroy the temp dir.
+func tempAllocDir(t testing.TB) *allocdir.AllocDir {
+	dir, err := ioutil.TempDir("", "")
+	if err != nil {
+		t.Fatalf("TempDir() failed: %v", err)
+	}
+
+	if err := os.Chmod(dir, 0777); err != nil {
+		t.Fatalf("failed to chmod dir: %v", err)
+	}
+
+	return allocdir.NewAllocDir(log.New(os.Stderr, "", log.LstdFlags), dir)
+}
+
+type nopWriteCloser struct {
+	io.Writer
+}
+
+func (n nopWriteCloser) Close() error {
+	return nil
+}
+
+func TestFS_Stat_NoAlloc(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	// Start a client
+	c := TestClient(t, nil)
+	defer c.Shutdown()
+
+	// Make the request with a bad allocation id
+	req := &cstructs.FsStatRequest{
+		AllocID:      uuid.Generate(),
+		Path:         "foo",
+		QueryOptions: structs.QueryOptions{Region: "global"},
+	}
+
+	var resp cstructs.FsStatResponse
+	err := c.ClientRPC("FileSystem.Stat", req, &resp)
+	require.NotNil(err)
+	require.True(structs.IsErrUnknownAllocation(err))
+}
+
+func TestFS_Stat(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	// Start a client
+	c := TestClient(t, nil)
+	defer c.Shutdown()
+
+	// Create and add an alloc
+	a := mock.Alloc()
+	c.addAlloc(a, "")
+
+	// Wait for the client to start it
+	testutil.WaitForResult(func() (bool, error) {
+		ar, ok := c.allocs[a.ID]
+		if !ok {
+			return false, fmt.Errorf("alloc doesn't exist")
+		}
+
+		return len(ar.tasks) != 0, fmt.Errorf("tasks not running")
+	}, func(err error) {
+		t.Fatal(err)
+	})
+
+	// Make the request with a valid allocation id
+	req := &cstructs.FsStatRequest{
+		AllocID:      a.ID,
+		Path:         "/",
+		QueryOptions: structs.QueryOptions{Region: "global"},
+	}
+
+	var resp cstructs.FsStatResponse
+	err := c.ClientRPC("FileSystem.Stat", req, &resp)
+	require.Nil(err)
+	require.NotNil(resp.Info)
+	require.True(resp.Info.IsDir)
+}
+
+func TestFS_Stat_ACL(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	// Start a server
+	s, root := nomad.TestACLServer(t, nil)
+	defer s.Shutdown()
+	testutil.WaitForLeader(t, s.RPC)
+
+	client := TestClient(t, func(c *config.Config) {
+		c.ACLEnabled = true
+		c.Servers = []string{s.GetConfig().RPCAddr.String()}
+	})
+	defer client.Shutdown()
+
+	// Create a bad token
+	policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityDeny})
+	tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad)
+
+	policyGood :=
mock.NamespacePolicy(structs.DefaultNamespace, "", + []string{acl.NamespaceCapabilityReadLogs, acl.NamespaceCapabilityReadFS}) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood) + + cases := []struct { + Name string + Token string + ExpectedError string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedError: structs.ErrPermissionDenied.Error(), + }, + { + Name: "good token", + Token: tokenGood.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + { + Name: "root token", + Token: root.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + // Make the request with bad allocation id + req := &cstructs.FsStatRequest{ + AllocID: uuid.Generate(), + Path: "/", + QueryOptions: structs.QueryOptions{ + Region: "global", + AuthToken: c.Token, + Namespace: structs.DefaultNamespace, + }, + } + + var resp cstructs.FsStatResponse + err := client.ClientRPC("FileSystem.Stat", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), c.ExpectedError) + }) + } +} + +func TestFS_List_NoAlloc(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a client + c := TestClient(t, nil) + defer c.Shutdown() + + // Make the request with bad allocation id + req := &cstructs.FsListRequest{ + AllocID: uuid.Generate(), + Path: "foo", + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + var resp cstructs.FsListResponse + err := c.ClientRPC("FileSystem.List", req, &resp) + require.NotNil(err) + require.True(structs.IsErrUnknownAllocation(err)) +} + +func TestFS_List(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a client + c := TestClient(t, nil) + defer c.Shutdown() + + // Create and add an alloc + a := mock.Alloc() + c.addAlloc(a, "") + + // Wait for the client to start it + testutil.WaitForResult(func() (bool, error) { + ar, ok := c.allocs[a.ID] + if !ok { + return false, fmt.Errorf("alloc doesn't exist") + } + + return len(ar.tasks) != 0, fmt.Errorf("tasks not running") + }, func(err error) { + t.Fatal(err) + }) + + // Make the request + req := &cstructs.FsListRequest{ + AllocID: a.ID, + Path: "/", + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + var resp cstructs.FsListResponse + err := c.ClientRPC("FileSystem.List", req, &resp) + require.Nil(err) + require.NotEmpty(resp.Files) + require.True(resp.Files[0].IsDir) +} + +func TestFS_List_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server + s, root := nomad.TestACLServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + client := TestClient(t, func(c *config.Config) { + c.ACLEnabled = true + c.Servers = []string{s.GetConfig().RPCAddr.String()} + }) + defer client.Shutdown() + + // Create a bad token + policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityDeny}) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.NamespacePolicy(structs.DefaultNamespace, "", + []string{acl.NamespaceCapabilityReadLogs, acl.NamespaceCapabilityReadFS}) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood) + + cases := []struct { + Name string + Token string + ExpectedError string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedError: structs.ErrPermissionDenied.Error(), + }, + { + Name: "good token", + Token: tokenGood.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + { + 
Name: "root token", + Token: root.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + // Make the request with bad allocation id + req := &cstructs.FsListRequest{ + AllocID: uuid.Generate(), + Path: "/", + QueryOptions: structs.QueryOptions{ + Region: "global", + AuthToken: c.Token, + Namespace: structs.DefaultNamespace, + }, + } + + var resp cstructs.FsListResponse + err := client.ClientRPC("FileSystem.List", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), c.ExpectedError) + }) + } +} + +func TestFS_Stream_NoAlloc(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a client + c := TestClient(t, nil) + defer c.Shutdown() + + // Make the request with bad allocation id + req := &cstructs.FsStreamRequest{ + AllocID: uuid.Generate(), + Path: "foo", + Origin: "start", + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := c.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(3 * time.Second) + +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + t.Logf("Got msg %+v", msg) + if msg.Error == nil { + continue + } + + if structs.IsErrUnknownAllocation(msg.Error) { + break OUTER + } else { + t.Fatalf("bad error: %v", err) + } + } + } +} + +func TestFS_Stream_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server + s, root := nomad.TestACLServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + client := TestClient(t, func(c *config.Config) { + c.ACLEnabled = true + c.Servers = []string{s.GetConfig().RPCAddr.String()} + }) + defer client.Shutdown() + + // Create a bad token + policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityReadFS}) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.NamespacePolicy(structs.DefaultNamespace, "", + []string{acl.NamespaceCapabilityReadLogs, acl.NamespaceCapabilityReadFS}) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood) + + cases := []struct { + Name string + Token string + ExpectedError string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedError: structs.ErrPermissionDenied.Error(), + }, + { + Name: "good token", + Token: tokenGood.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + { + Name: "root token", + Token: root.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + // Make the request with bad allocation id + req := &cstructs.FsStreamRequest{ + AllocID: uuid.Generate(), + Path: "foo", + Origin: "start", + QueryOptions: 
structs.QueryOptions{ + Namespace: structs.DefaultNamespace, + Region: "global", + AuthToken: c.Token, + }, + } + + // Get the handler + handler, err := client.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(5 * time.Second) + + OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error == nil { + continue + } + + if strings.Contains(msg.Error.Error(), c.ExpectedError) { + break OUTER + } else { + t.Fatalf("Bad error: %v", msg.Error) + } + } + } + }) + } +} + +func TestFS_Stream(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := nomad.TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + c := TestClient(t, func(c *config.Config) { + c.Servers = []string{s.GetConfig().RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + expected := "Hello from the other side" + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + "stdout_string": expected, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + node, err := s.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + if node == nil { + return false, fmt.Errorf("unknown node") + } + + return node.Status == structs.NodeStatusReady, fmt.Errorf("bad node status") + }, func(err error) { + t.Fatal(err) + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Make the request + req := &cstructs.FsStreamRequest{ + AllocID: a.ID, + Path: "alloc/logs/web.stdout.0", + PlainText: true, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := c.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + // Wrap the pipe so we can check it is closed + pipeChecker := 
&ReadWriteCloseChecker{ReadWriteCloser: p2} + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(pipeChecker) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(3 * time.Second) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } + + testutil.WaitForResult(func() (bool, error) { + return pipeChecker.Closed, nil + }, func(err error) { + t.Fatal("Pipe not closed") + }) +} + +type ReadWriteCloseChecker struct { + io.ReadWriteCloser + Closed bool +} + +func (r *ReadWriteCloseChecker) Close() error { + r.Closed = true + return r.ReadWriteCloser.Close() +} + +func TestFS_Stream_Follow(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := nomad.TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + c := TestClient(t, func(c *config.Config) { + c.Servers = []string{s.GetConfig().RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + expectedBase := "Hello from the other side" + repeat := 10 + + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "20s", + "stdout_string": expectedBase, + "stdout_repeat": repeat, + "stdout_repeat_duration": 200 * time.Millisecond, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + node, err := s.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + if node == nil { + return false, fmt.Errorf("unknown node") + } + + return node.Status == structs.NodeStatusReady, fmt.Errorf("bad node status") + }, func(err error) { + t.Fatal(err) + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusRunning { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not running: %v", c.NodeID(), err) + }) + + // Make the request + req := &cstructs.FsStreamRequest{ + AllocID: a.ID, + Path: "alloc/logs/web.stdout.0", + PlainText: true, + Follow: true, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := 
c.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(20 * time.Second) + expected := strings.Repeat(expectedBase, repeat+1) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} + +func TestFS_Stream_Limit(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := nomad.TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + c := TestClient(t, func(c *config.Config) { + c.Servers = []string{s.GetConfig().RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + var limit int64 = 5 + full := "Hello from the other side" + expected := full[:limit] + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + "stdout_string": full, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + node, err := s.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + if node == nil { + return false, fmt.Errorf("unknown node") + } + + return node.Status == structs.NodeStatusReady, fmt.Errorf("bad node status") + }, func(err error) { + t.Fatal(err) + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Make the request + req := &cstructs.FsStreamRequest{ + AllocID: a.ID, + Path: "alloc/logs/web.stdout.0", + PlainText: true, + Limit: limit, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := c.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the 
handler
+	go handler(p2)
+
+	// Start the decoder
+	go func() {
+		decoder := codec.NewDecoder(p1, structs.MsgpackHandle)
+		for {
+			var msg cstructs.StreamErrWrapper
+			if err := decoder.Decode(&msg); err != nil {
+				if err == io.EOF || strings.Contains(err.Error(), "closed") {
+					return
+				}
+				errCh <- fmt.Errorf("error decoding: %v", err)
+			}
+
+			streamMsg <- &msg
+		}
+	}()
+
+	// Send the request
+	encoder := codec.NewEncoder(p1, structs.MsgpackHandle)
+	require.Nil(encoder.Encode(req))
+
+	timeout := time.After(3 * time.Second)
+	received := ""
+OUTER:
+	for {
+		select {
+		case <-timeout:
+			t.Fatal("timeout")
+		case err := <-errCh:
+			t.Fatal(err)
+		case msg := <-streamMsg:
+			if msg.Error != nil {
+				t.Fatalf("Got error: %v", msg.Error.Error())
+			}
+
+			// Add the payload
+			received += string(msg.Payload)
+			if received == expected {
+				break OUTER
+			}
+		}
+	}
+}
+
+func TestFS_Logs_NoAlloc(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	// Start a client
+	c := TestClient(t, nil)
+	defer c.Shutdown()
+
+	// Make the request with a bad allocation ID
+	req := &cstructs.FsLogsRequest{
+		AllocID: uuid.Generate(),
+		Task: "foo",
+		LogType: "stdout",
+		Origin: "start",
+		QueryOptions: structs.QueryOptions{Region: "global"},
+	}
+
+	// Get the handler
+	handler, err := c.StreamingRpcHandler("FileSystem.Logs")
+	require.Nil(err)
+
+	// Create a pipe
+	p1, p2 := net.Pipe()
+	defer p1.Close()
+	defer p2.Close()
+
+	errCh := make(chan error)
+	streamMsg := make(chan *cstructs.StreamErrWrapper)
+
+	// Start the handler
+	go handler(p2)
+
+	// Start the decoder
+	go func() {
+		decoder := codec.NewDecoder(p1, structs.MsgpackHandle)
+		for {
+			var msg cstructs.StreamErrWrapper
+			if err := decoder.Decode(&msg); err != nil {
+				if err == io.EOF || strings.Contains(err.Error(), "closed") {
+					return
+				}
+				errCh <- fmt.Errorf("error decoding: %v", err)
+			}
+
+			streamMsg <- &msg
+		}
+	}()
+
+	// Send the request
+	encoder := codec.NewEncoder(p1, structs.MsgpackHandle)
+	require.Nil(encoder.Encode(req))
+
+	timeout := time.After(3 * time.Second)
+
+OUTER:
+	for {
+		select {
+		case <-timeout:
+			t.Fatal("timeout")
+		case err := <-errCh:
+			t.Fatal(err)
+		case msg := <-streamMsg:
+			t.Logf("Got msg %+v", msg)
+			if msg.Error == nil {
+				continue
+			}
+
+			if structs.IsErrUnknownAllocation(msg.Error) {
+				break OUTER
+			} else {
+				t.Fatalf("bad error: %v", msg.Error)
+			}
+		}
+	}
+}
+
+func TestFS_Logs_ACL(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	// Start a server
+	s, root := nomad.TestACLServer(t, nil)
+	defer s.Shutdown()
+	testutil.WaitForLeader(t, s.RPC)
+
+	client := TestClient(t, func(c *config.Config) {
+		c.ACLEnabled = true
+		c.Servers = []string{s.GetConfig().RPCAddr.String()}
+	})
+	defer client.Shutdown()
+
+	// Create a bad token
+	policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityReadFS})
+	tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad)
+
+	policyGood := mock.NamespacePolicy(structs.DefaultNamespace, "",
+		[]string{acl.NamespaceCapabilityReadLogs, acl.NamespaceCapabilityReadFS})
+	tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood)
+
+	cases := []struct {
+		Name string
+		Token string
+		ExpectedError string
+	}{
+		{
+			Name: "bad token",
+			Token: tokenBad.SecretID,
+			ExpectedError: structs.ErrPermissionDenied.Error(),
+		},
+		{
+			Name: "good token",
+			Token: tokenGood.SecretID,
+			ExpectedError: structs.ErrUnknownAllocationPrefix,
+		},
+		{
+			Name: "root token",
+			Token: root.SecretID,
+			ExpectedError: 
structs.ErrUnknownAllocationPrefix, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + // Make the request with bad allocation id + req := &cstructs.FsLogsRequest{ + AllocID: uuid.Generate(), + Task: "foo", + LogType: "stdout", + Origin: "start", + QueryOptions: structs.QueryOptions{ + Namespace: structs.DefaultNamespace, + Region: "global", + AuthToken: c.Token, + }, + } + + // Get the handler + handler, err := client.StreamingRpcHandler("FileSystem.Logs") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(5 * time.Second) + + OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error == nil { + continue + } + + if strings.Contains(msg.Error.Error(), c.ExpectedError) { + break OUTER + } else { + t.Fatalf("Bad error: %v", msg.Error) + } + } + } + }) + } +} + +func TestFS_Logs(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := nomad.TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + c := TestClient(t, func(c *config.Config) { + c.Servers = []string{s.GetConfig().RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + expected := "Hello from the other side" + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + "stdout_string": expected, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + node, err := s.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + if node == nil { + return false, fmt.Errorf("unknown node") + } + + return node.Status == structs.NodeStatusReady, fmt.Errorf("bad node status") + }, func(err error) { + t.Fatal(err) + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Make the request + req := &cstructs.FsLogsRequest{ + AllocID: a.ID, + Task: a.Job.TaskGroups[0].Tasks[0].Name, + LogType: "stdout", + Origin: "start", + 
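+		// PlainText asks for the raw log bytes in each payload (rather than
+		// a framed stream), so the loop below can concatenate the payloads
+		// and compare the result to the expected string directly.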
PlainText: true, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := c.StreamingRpcHandler("FileSystem.Logs") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(3 * time.Second) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} + +func TestFS_Logs_Follow(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := nomad.TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + c := TestClient(t, func(c *config.Config) { + c.Servers = []string{s.GetConfig().RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + expectedBase := "Hello from the other side" + repeat := 10 + + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "20s", + "stdout_string": expectedBase, + "stdout_repeat": repeat, + "stdout_repeat_duration": 200 * time.Millisecond, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + node, err := s.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + if node == nil { + return false, fmt.Errorf("unknown node") + } + + return node.Status == structs.NodeStatusReady, fmt.Errorf("bad node status") + }, func(err error) { + t.Fatal(err) + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusRunning { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not running: %v", c.NodeID(), err) + }) + + // Make the request + req := &cstructs.FsLogsRequest{ + AllocID: a.ID, + Task: a.Job.TaskGroups[0].Tasks[0].Name, + LogType: "stdout", + Origin: "start", + PlainText: true, + Follow: true, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := c.StreamingRpcHandler("FileSystem.Logs") + require.Nil(err) + + // Create a 
pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(20 * time.Second) + expected := strings.Repeat(expectedBase, repeat+1) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} + +func TestFS_findClosest(t *testing.T) { + task := "foo" + entries := []*cstructs.AllocFileInfo{ + { + Name: "foo.stdout.0", + Size: 100, + }, + { + Name: "foo.stdout.1", + Size: 100, + }, + { + Name: "foo.stdout.2", + Size: 100, + }, + { + Name: "foo.stdout.3", + Size: 100, + }, + { + Name: "foo.stderr.0", + Size: 100, + }, + { + Name: "foo.stderr.1", + Size: 100, + }, + { + Name: "foo.stderr.2", + Size: 100, + }, + } + + cases := []struct { + Entries []*cstructs.AllocFileInfo + DesiredIdx int64 + DesiredOffset int64 + Task string + LogType string + ExpectedFile string + ExpectedIdx int64 + ExpectedOffset int64 + Error bool + }{ + // Test error cases + { + Entries: nil, + DesiredIdx: 0, + Task: task, + LogType: "stdout", + Error: true, + }, + { + Entries: entries[0:3], + DesiredIdx: 0, + Task: task, + LogType: "stderr", + Error: true, + }, + + // Test beginning cases + { + Entries: entries, + DesiredIdx: 0, + Task: task, + LogType: "stdout", + ExpectedFile: entries[0].Name, + ExpectedIdx: 0, + }, + { + // Desired offset should be ignored at edges + Entries: entries, + DesiredIdx: 0, + DesiredOffset: -100, + Task: task, + LogType: "stdout", + ExpectedFile: entries[0].Name, + ExpectedIdx: 0, + ExpectedOffset: 0, + }, + { + // Desired offset should be ignored at edges + Entries: entries, + DesiredIdx: 1, + DesiredOffset: -1000, + Task: task, + LogType: "stdout", + ExpectedFile: entries[0].Name, + ExpectedIdx: 0, + ExpectedOffset: 0, + }, + { + Entries: entries, + DesiredIdx: 0, + Task: task, + LogType: "stderr", + ExpectedFile: entries[4].Name, + ExpectedIdx: 0, + }, + { + Entries: entries, + DesiredIdx: 0, + Task: task, + LogType: "stdout", + ExpectedFile: entries[0].Name, + ExpectedIdx: 0, + }, + + // Test middle cases + { + Entries: entries, + DesiredIdx: 1, + Task: task, + LogType: "stdout", + ExpectedFile: entries[1].Name, + ExpectedIdx: 1, + }, + { + Entries: entries, + DesiredIdx: 1, + DesiredOffset: 10, + Task: task, + LogType: "stdout", + ExpectedFile: entries[1].Name, + ExpectedIdx: 1, + ExpectedOffset: 10, + }, + { + Entries: entries, + DesiredIdx: 1, + DesiredOffset: 110, + Task: task, + LogType: "stdout", + ExpectedFile: entries[2].Name, + ExpectedIdx: 2, + ExpectedOffset: 10, + }, + { + Entries: entries, + DesiredIdx: 1, + Task: task, + LogType: "stderr", + ExpectedFile: entries[5].Name, + ExpectedIdx: 1, + }, + // Test end cases + { + Entries: entries, + DesiredIdx: math.MaxInt64, + Task: task, + LogType: 
"stdout", + ExpectedFile: entries[3].Name, + ExpectedIdx: 3, + }, + { + Entries: entries, + DesiredIdx: math.MaxInt64, + DesiredOffset: math.MaxInt64, + Task: task, + LogType: "stdout", + ExpectedFile: entries[3].Name, + ExpectedIdx: 3, + ExpectedOffset: 100, + }, + { + Entries: entries, + DesiredIdx: math.MaxInt64, + DesiredOffset: -10, + Task: task, + LogType: "stdout", + ExpectedFile: entries[3].Name, + ExpectedIdx: 3, + ExpectedOffset: 90, + }, + { + Entries: entries, + DesiredIdx: math.MaxInt64, + Task: task, + LogType: "stderr", + ExpectedFile: entries[6].Name, + ExpectedIdx: 2, + }, + } + + for i, c := range cases { + entry, idx, offset, err := findClosest(c.Entries, c.DesiredIdx, c.DesiredOffset, c.Task, c.LogType) + if err != nil { + if !c.Error { + t.Fatalf("case %d: Unexpected error: %v", i, err) + } + continue + } + + if entry.Name != c.ExpectedFile { + t.Fatalf("case %d: Got file %q; want %q", i, entry.Name, c.ExpectedFile) + } + if idx != c.ExpectedIdx { + t.Fatalf("case %d: Got index %d; want %d", i, idx, c.ExpectedIdx) + } + if offset != c.ExpectedOffset { + t.Fatalf("case %d: Got offset %d; want %d", i, offset, c.ExpectedOffset) + } + } +} + +func TestFS_streamFile_NoFile(t *testing.T) { + t.Parallel() + require := require.New(t) + c := TestClient(t, nil) + defer c.Shutdown() + + ad := tempAllocDir(t) + defer os.RemoveAll(ad.AllocDir) + + frames := make(chan *sframer.StreamFrame, 32) + framer := sframer.NewStreamFramer(frames, streamHeartbeatRate, streamBatchWindow, streamFrameSize) + framer.Run() + defer framer.Destroy() + + err := c.endpoints.FileSystem.streamFile( + context.Background(), 0, "foo", 0, ad, framer, nil) + require.NotNil(err) + require.Contains(err.Error(), "no such file") +} + +func TestFS_streamFile_Modify(t *testing.T) { + t.Parallel() + + c := TestClient(t, nil) + defer c.Shutdown() + + // Get a temp alloc dir + ad := tempAllocDir(t) + defer os.RemoveAll(ad.AllocDir) + + // Create a file in the temp dir + streamFile := "stream_file" + f, err := os.Create(filepath.Join(ad.AllocDir, streamFile)) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + defer f.Close() + + data := []byte("helloworld") + + // Start the reader + resultCh := make(chan struct{}) + frames := make(chan *sframer.StreamFrame, 4) + go func() { + var collected []byte + for { + frame := <-frames + if frame.IsHeartbeat() { + continue + } + + collected = append(collected, frame.Data...) + if reflect.DeepEqual(data, collected) { + resultCh <- struct{}{} + return + } + } + }() + + // Write a few bytes + if _, err := f.Write(data[:3]); err != nil { + t.Fatalf("write failed: %v", err) + } + + framer := sframer.NewStreamFramer(frames, streamHeartbeatRate, streamBatchWindow, streamFrameSize) + framer.Run() + defer framer.Destroy() + + // Start streaming + go func() { + if err := c.endpoints.FileSystem.streamFile( + context.Background(), 0, streamFile, 0, ad, framer, nil); err != nil { + t.Fatalf("stream() failed: %v", err) + } + }() + + // Sleep a little before writing more. This lets us check if the watch + // is working. 
+ time.Sleep(1 * time.Duration(testutil.TestMultiplier()) * time.Second) + if _, err := f.Write(data[3:]); err != nil { + t.Fatalf("write failed: %v", err) + } + + select { + case <-resultCh: + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): + t.Fatalf("failed to send new data") + } +} + +func TestFS_streamFile_Truncate(t *testing.T) { + t.Parallel() + c := TestClient(t, nil) + defer c.Shutdown() + + // Get a temp alloc dir + ad := tempAllocDir(t) + defer os.RemoveAll(ad.AllocDir) + + // Create a file in the temp dir + data := []byte("helloworld") + streamFile := "stream_file" + streamFilePath := filepath.Join(ad.AllocDir, streamFile) + f, err := os.Create(streamFilePath) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + defer f.Close() + + // Start the reader + truncateCh := make(chan struct{}) + dataPostTruncCh := make(chan struct{}) + frames := make(chan *sframer.StreamFrame, 4) + go func() { + var collected []byte + for { + frame := <-frames + if frame.IsHeartbeat() { + continue + } + + if frame.FileEvent == truncateEvent { + close(truncateCh) + } + + collected = append(collected, frame.Data...) + if reflect.DeepEqual(data, collected) { + close(dataPostTruncCh) + return + } + } + }() + + // Write a few bytes + if _, err := f.Write(data[:3]); err != nil { + t.Fatalf("write failed: %v", err) + } + + framer := sframer.NewStreamFramer(frames, streamHeartbeatRate, streamBatchWindow, streamFrameSize) + framer.Run() + defer framer.Destroy() + + // Start streaming + go func() { + if err := c.endpoints.FileSystem.streamFile( + context.Background(), 0, streamFile, 0, ad, framer, nil); err != nil { + t.Fatalf("stream() failed: %v", err) + } + }() + + // Sleep a little before truncating. This lets us check if the watch + // is working. + time.Sleep(1 * time.Duration(testutil.TestMultiplier()) * time.Second) + if err := f.Truncate(0); err != nil { + t.Fatalf("truncate failed: %v", err) + } + if err := f.Sync(); err != nil { + t.Fatalf("sync failed: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("failed to close file: %v", err) + } + + f2, err := os.OpenFile(streamFilePath, os.O_RDWR, 0) + if err != nil { + t.Fatalf("failed to reopen file: %v", err) + } + defer f2.Close() + if _, err := f2.Write(data[3:5]); err != nil { + t.Fatalf("write failed: %v", err) + } + + select { + case <-truncateCh: + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): + t.Fatalf("did not receive truncate") + } + + // Sleep a little before writing more. This lets us check if the watch + // is working. 
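+	// The truncate event has been observed at this point; writing the
+	// remaining bytes should let the reader assemble the full payload and
+	// close dataPostTruncCh.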
+ time.Sleep(1 * time.Duration(testutil.TestMultiplier()) * time.Second) + if _, err := f2.Write(data[5:]); err != nil { + t.Fatalf("write failed: %v", err) + } + + select { + case <-dataPostTruncCh: + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): + t.Fatalf("did not receive post truncate data") + } +} + +func TestFS_streamImpl_Delete(t *testing.T) { + t.Parallel() + + c := TestClient(t, nil) + defer c.Shutdown() + + // Get a temp alloc dir + ad := tempAllocDir(t) + defer os.RemoveAll(ad.AllocDir) + + // Create a file in the temp dir + data := []byte("helloworld") + streamFile := "stream_file" + streamFilePath := filepath.Join(ad.AllocDir, streamFile) + f, err := os.Create(streamFilePath) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + defer f.Close() + + // Start the reader + deleteCh := make(chan struct{}) + frames := make(chan *sframer.StreamFrame, 4) + go func() { + for { + frame := <-frames + if frame.IsHeartbeat() { + continue + } + + if frame.FileEvent == deleteEvent { + close(deleteCh) + return + } + } + }() + + // Write a few bytes + if _, err := f.Write(data[:3]); err != nil { + t.Fatalf("write failed: %v", err) + } + + framer := sframer.NewStreamFramer(frames, streamHeartbeatRate, streamBatchWindow, streamFrameSize) + framer.Run() + defer framer.Destroy() + + // Start streaming + go func() { + if err := c.endpoints.FileSystem.streamFile( + context.Background(), 0, streamFile, 0, ad, framer, nil); err != nil { + t.Fatalf("stream() failed: %v", err) + } + }() + + // Sleep a little before deleting. This lets us check if the watch + // is working. + time.Sleep(1 * time.Duration(testutil.TestMultiplier()) * time.Second) + if err := os.Remove(streamFilePath); err != nil { + t.Fatalf("delete failed: %v", err) + } + + select { + case <-deleteCh: + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): + t.Fatalf("did not receive delete") + } +} + +func TestFS_logsImpl_NoFollow(t *testing.T) { + t.Parallel() + + c := TestClient(t, nil) + defer c.Shutdown() + + // Get a temp alloc dir and create the log dir + ad := tempAllocDir(t) + defer os.RemoveAll(ad.AllocDir) + + logDir := filepath.Join(ad.SharedDir, allocdir.LogDirName) + if err := os.MkdirAll(logDir, 0777); err != nil { + t.Fatalf("Failed to make log dir: %v", err) + } + + // Create a series of log files in the temp dir + task := "foo" + logType := "stdout" + expected := []byte("012") + for i := 0; i < 3; i++ { + logFile := fmt.Sprintf("%s.%s.%d", task, logType, i) + logFilePath := filepath.Join(logDir, logFile) + err := ioutil.WriteFile(logFilePath, expected[i:i+1], 777) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + } + + // Start the reader + resultCh := make(chan struct{}) + frames := make(chan *sframer.StreamFrame, 4) + var received []byte + go func() { + for { + frame, ok := <-frames + if !ok { + return + } + + if frame.IsHeartbeat() { + continue + } + + received = append(received, frame.Data...) 
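+			// Each log file holds a single byte ("0", "1", "2") and the
+			// files are streamed in index order, so the collected bytes
+			// must equal "012" exactly.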
+ if reflect.DeepEqual(received, expected) { + close(resultCh) + return + } + } + }() + + // Start streaming logs + go func() { + if err := c.endpoints.FileSystem.logsImpl( + context.Background(), false, false, 0, + OriginStart, task, logType, ad, frames); err != nil { + t.Fatalf("logs() failed: %v", err) + } + }() + + select { + case <-resultCh: + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): + t.Fatalf("did not receive data: got %q", string(received)) + } +} + +func TestFS_logsImpl_Follow(t *testing.T) { + t.Parallel() + + c := TestClient(t, nil) + defer c.Shutdown() + + // Get a temp alloc dir and create the log dir + ad := tempAllocDir(t) + defer os.RemoveAll(ad.AllocDir) + + logDir := filepath.Join(ad.SharedDir, allocdir.LogDirName) + if err := os.MkdirAll(logDir, 0777); err != nil { + t.Fatalf("Failed to make log dir: %v", err) + } + + // Create a series of log files in the temp dir + task := "foo" + logType := "stdout" + expected := []byte("012345") + initialWrites := 3 + + writeToFile := func(index int, data []byte) { + logFile := fmt.Sprintf("%s.%s.%d", task, logType, index) + logFilePath := filepath.Join(logDir, logFile) + err := ioutil.WriteFile(logFilePath, data, 777) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + } + for i := 0; i < initialWrites; i++ { + writeToFile(i, expected[i:i+1]) + } + + // Start the reader + firstResultCh := make(chan struct{}) + fullResultCh := make(chan struct{}) + frames := make(chan *sframer.StreamFrame, 4) + var received []byte + go func() { + for { + frame, ok := <-frames + if !ok { + return + } + + if frame.IsHeartbeat() { + continue + } + + received = append(received, frame.Data...) + if reflect.DeepEqual(received, expected[:initialWrites]) { + close(firstResultCh) + } else if reflect.DeepEqual(received, expected) { + close(fullResultCh) + return + } + } + }() + + // Start streaming logs + go c.endpoints.FileSystem.logsImpl( + context.Background(), true, false, 0, + OriginStart, task, logType, ad, frames) + + select { + case <-firstResultCh: + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): + t.Fatalf("did not receive data: got %q", string(received)) + } + + // We got the first chunk of data, write out the rest to the next file + // at an index much ahead to check that it is following and detecting + // skips + skipTo := initialWrites + 10 + writeToFile(skipTo, expected[initialWrites:]) + + select { + case <-fullResultCh: + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): + t.Fatalf("did not receive data: got %q", string(received)) + } +} diff --git a/client/gc_test.go b/client/gc_test.go index ac28239a193..f94b8bb7b2c 100644 --- a/client/gc_test.go +++ b/client/gc_test.go @@ -26,10 +26,10 @@ func TestIndexedGCAllocPQ(t *testing.T) { t.Parallel() pq := NewIndexedGCAllocPQ() - _, ar1 := testAllocRunnerFromAlloc(mock.Alloc(), false) - _, ar2 := testAllocRunnerFromAlloc(mock.Alloc(), false) - _, ar3 := testAllocRunnerFromAlloc(mock.Alloc(), false) - _, ar4 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar1 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) + _, ar2 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) + _, ar3 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) + _, ar4 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) pq.Push(ar1) pq.Push(ar2) @@ -108,7 +108,7 @@ func TestAllocGarbageCollector_MarkForCollection(t *testing.T) { logger := testLogger() gc := NewAllocGarbageCollector(logger, 
&MockStatsCollector{}, &MockAllocCounter{}, gcConfig()) - _, ar1 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar1 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) gc.MarkForCollection(ar1) gcAlloc := gc.allocRunners.Pop() @@ -122,8 +122,8 @@ func TestAllocGarbageCollector_Collect(t *testing.T) { logger := testLogger() gc := NewAllocGarbageCollector(logger, &MockStatsCollector{}, &MockAllocCounter{}, gcConfig()) - _, ar1 := testAllocRunnerFromAlloc(mock.Alloc(), false) - _, ar2 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar1 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) + _, ar2 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) gc.MarkForCollection(ar1) gc.MarkForCollection(ar2) @@ -143,8 +143,8 @@ func TestAllocGarbageCollector_CollectAll(t *testing.T) { logger := testLogger() gc := NewAllocGarbageCollector(logger, &MockStatsCollector{}, &MockAllocCounter{}, gcConfig()) - _, ar1 := testAllocRunnerFromAlloc(mock.Alloc(), false) - _, ar2 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar1 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) + _, ar2 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) gc.MarkForCollection(ar1) gc.MarkForCollection(ar2) @@ -163,9 +163,9 @@ func TestAllocGarbageCollector_MakeRoomForAllocations_EnoughSpace(t *testing.T) conf.ReservedDiskMB = 20 gc := NewAllocGarbageCollector(logger, statsCollector, &MockAllocCounter{}, conf) - _, ar1 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar1 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar1.waitCh) - _, ar2 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar2 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar2.waitCh) gc.MarkForCollection(ar1) gc.MarkForCollection(ar2) @@ -198,9 +198,9 @@ func TestAllocGarbageCollector_MakeRoomForAllocations_GC_Partial(t *testing.T) { conf.ReservedDiskMB = 20 gc := NewAllocGarbageCollector(logger, statsCollector, &MockAllocCounter{}, conf) - _, ar1 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar1 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar1.waitCh) - _, ar2 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar2 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar2.waitCh) gc.MarkForCollection(ar1) gc.MarkForCollection(ar2) @@ -234,9 +234,9 @@ func TestAllocGarbageCollector_MakeRoomForAllocations_GC_All(t *testing.T) { conf.ReservedDiskMB = 20 gc := NewAllocGarbageCollector(logger, statsCollector, &MockAllocCounter{}, conf) - _, ar1 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar1 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar1.waitCh) - _, ar2 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar2 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar2.waitCh) gc.MarkForCollection(ar1) gc.MarkForCollection(ar2) @@ -266,9 +266,9 @@ func TestAllocGarbageCollector_MakeRoomForAllocations_GC_Fallback(t *testing.T) conf.ReservedDiskMB = 20 gc := NewAllocGarbageCollector(logger, statsCollector, &MockAllocCounter{}, conf) - _, ar1 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar1 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar1.waitCh) - _, ar2 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar2 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar2.waitCh) gc.MarkForCollection(ar1) gc.MarkForCollection(ar2) @@ -298,7 +298,7 @@ func TestAllocGarbageCollector_MaxAllocs(t *testing.T) { testutil.WaitForLeader(t, server.RPC) const maxAllocs = 6 - client := testClient(t, func(c *config.Config) { + client := TestClient(t, func(c 
*config.Config) { c.GCMaxAllocs = maxAllocs c.GCDiskUsageThreshold = 100 c.GCInodeUsageThreshold = 100 @@ -425,9 +425,9 @@ func TestAllocGarbageCollector_UsageBelowThreshold(t *testing.T) { conf.ReservedDiskMB = 20 gc := NewAllocGarbageCollector(logger, statsCollector, &MockAllocCounter{}, conf) - _, ar1 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar1 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar1.waitCh) - _, ar2 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar2 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar2.waitCh) gc.MarkForCollection(ar1) gc.MarkForCollection(ar2) @@ -457,9 +457,9 @@ func TestAllocGarbageCollector_UsedPercentThreshold(t *testing.T) { conf.ReservedDiskMB = 20 gc := NewAllocGarbageCollector(logger, statsCollector, &MockAllocCounter{}, conf) - _, ar1 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar1 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar1.waitCh) - _, ar2 := testAllocRunnerFromAlloc(mock.Alloc(), false) + _, ar2 := testAllocRunnerFromAlloc(t, mock.Alloc(), false) close(ar2.waitCh) gc.MarkForCollection(ar1) gc.MarkForCollection(ar2) diff --git a/client/lib/streamframer/framer.go b/client/lib/streamframer/framer.go new file mode 100644 index 00000000000..b0caa4a047b --- /dev/null +++ b/client/lib/streamframer/framer.go @@ -0,0 +1,309 @@ +package framer + +import ( + "bytes" + "fmt" + "sync" + "time" +) + +var ( + // HeartbeatStreamFrame is the StreamFrame to send as a heartbeat, avoiding + // creating many instances of the empty StreamFrame + HeartbeatStreamFrame = &StreamFrame{} +) + +// StreamFrame is used to frame data of a file when streaming +type StreamFrame struct { + // Offset is the offset the data was read from + Offset int64 `json:",omitempty"` + + // Data is the read data + Data []byte `json:",omitempty"` + + // File is the file that the data was read from + File string `json:",omitempty"` + + // FileEvent is the last file event that occurred that could cause the + // streams position to change or end + FileEvent string `json:",omitempty"` +} + +// IsHeartbeat returns if the frame is a heartbeat frame +func (s *StreamFrame) IsHeartbeat() bool { + return s.Offset == 0 && len(s.Data) == 0 && s.File == "" && s.FileEvent == "" +} + +func (s *StreamFrame) Clear() { + s.Offset = 0 + s.Data = nil + s.File = "" + s.FileEvent = "" +} + +func (s *StreamFrame) IsCleared() bool { + if s.Offset != 0 { + return false + } else if s.Data != nil { + return false + } else if s.File != "" { + return false + } else if s.FileEvent != "" { + return false + } else { + return true + } +} + +// StreamFramer is used to buffer and send frames as well as heartbeat. +type StreamFramer struct { + out chan<- *StreamFrame + + frameSize int + + heartbeat *time.Ticker + flusher *time.Ticker + + shutdown bool + shutdownCh chan struct{} + exitCh chan struct{} + + // The mutex protects everything below + l sync.Mutex + + // The current working frame + f StreamFrame + data *bytes.Buffer + + // Captures whether the framer is running and any error that occurred to + // cause it to stop. + running bool + err error +} + +// NewStreamFramer creates a new stream framer that will output StreamFrames to +// the passed output channel. 
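+//
+// A rough lifecycle sketch (the channel size, rates, and frame size below are
+// illustrative only):
+//
+//	frames := make(chan *StreamFrame, 32)
+//	sf := NewStreamFramer(frames, 1*time.Second, 200*time.Millisecond, 64*1024)
+//	sf.Run()
+//	defer sf.Destroy()
+//	if err := sf.Send("app.log", "", []byte("data"), 0); err != nil {
+//		// the framer has stopped running
+//	}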
+func NewStreamFramer(out chan<- *StreamFrame, + heartbeatRate, batchWindow time.Duration, frameSize int) *StreamFramer { + + // Create the heartbeat and flush ticker + heartbeat := time.NewTicker(heartbeatRate) + flusher := time.NewTicker(batchWindow) + + return &StreamFramer{ + out: out, + frameSize: frameSize, + heartbeat: heartbeat, + flusher: flusher, + data: bytes.NewBuffer(make([]byte, 0, 2*frameSize)), + shutdownCh: make(chan struct{}), + exitCh: make(chan struct{}), + } +} + +// Destroy is used to cleanup the StreamFramer and flush any pending frames +func (s *StreamFramer) Destroy() { + s.l.Lock() + + wasShutdown := s.shutdown + s.shutdown = true + + if !wasShutdown { + close(s.shutdownCh) + } + + s.heartbeat.Stop() + s.flusher.Stop() + running := s.running + s.l.Unlock() + + // Ensure things were flushed + if running { + <-s.exitCh + } + if !wasShutdown { + close(s.out) + } +} + +// Run starts a long lived goroutine that handles sending data as well as +// heartbeating +func (s *StreamFramer) Run() { + s.l.Lock() + defer s.l.Unlock() + if s.running { + return + } + + s.running = true + go s.run() +} + +// ExitCh returns a channel that will be closed when the run loop terminates. +func (s *StreamFramer) ExitCh() <-chan struct{} { + return s.exitCh +} + +// Err returns the error that caused the StreamFramer to exit +func (s *StreamFramer) Err() error { + s.l.Lock() + defer s.l.Unlock() + return s.err +} + +// run is the internal run method. It exits if Destroy is called or an error +// occurs, in which case the exit channel is closed. +func (s *StreamFramer) run() { + var err error + defer func() { + s.l.Lock() + s.running = false + s.err = err + s.l.Unlock() + close(s.exitCh) + }() + +OUTER: + for { + select { + case <-s.shutdownCh: + break OUTER + case <-s.flusher.C: + // Skip if there is nothing to flush + s.l.Lock() + if s.f.IsCleared() { + s.l.Unlock() + continue + } + + // Read the data for the frame, and send it + s.f.Data = s.readData() + err = s.send(&s.f) + s.f.Clear() + s.l.Unlock() + if err != nil { + return + } + case <-s.heartbeat.C: + // Send a heartbeat frame + if err = s.send(HeartbeatStreamFrame); err != nil { + return + } + } + } + + s.l.Lock() + if !s.f.IsCleared() { + s.f.Data = s.readData() + err = s.send(&s.f) + s.f.Clear() + } + s.l.Unlock() +} + +// send takes a StreamFrame, encodes and sends it +func (s *StreamFramer) send(f *StreamFrame) error { + sending := *f + f.Data = nil + + select { + case s.out <- &sending: + return nil + case <-s.exitCh: + return nil + } +} + +// readData is a helper which reads the buffered data returning up to the frame +// size of data. Must be called with the lock held. The returned value is +// invalid on the next read or write into the StreamFramer buffer +func (s *StreamFramer) readData() []byte { + // Compute the amount to read from the buffer + size := s.data.Len() + if size > s.frameSize { + size = s.frameSize + } + if size == 0 { + return nil + } + d := s.data.Next(size) + return d +} + +// Send creates and sends a StreamFrame based on the passed parameters. An error +// is returned if the run routine hasn't run or encountered an error. Send is +// asynchronous and does not block for the data to be transferred. +func (s *StreamFramer) Send(file, fileEvent string, data []byte, offset int64) error { + s.l.Lock() + defer s.l.Unlock() + + // If we are not running, return the error that caused us to not run or + // indicated that it was never started. 
+	if !s.running {
+		if s.err != nil {
+			return s.err
+		}
+
+		return fmt.Errorf("StreamFramer not running")
+	}
+
+	// Check if not mergeable
+	if !s.f.IsCleared() && (s.f.File != file || s.f.FileEvent != fileEvent) {
+		// Flush the old frame
+		s.f.Data = s.readData()
+		select {
+		case <-s.exitCh:
+			return nil
+		default:
+		}
+		err := s.send(&s.f)
+		s.f.Clear()
+		if err != nil {
+			return err
+		}
+	}
+
+	// Store the new data as the current frame.
+	if s.f.IsCleared() {
+		s.f.Offset = offset
+		s.f.File = file
+		s.f.FileEvent = fileEvent
+	}
+
+	// Write the data to the buffer
+	s.data.Write(data)
+
+	// Handle the delete case in which there is no data
+	force := s.data.Len() == 0 && s.f.FileEvent != ""
+
+	// Flush until we are under the max frame size
+	for s.data.Len() >= s.frameSize || force {
+		// Clear force since we are flushing the frame and capturing the
+		// file event. Subsequent data frames will be flushed based on the
+		// data size alone since they share the same file event.
+		if force {
+			force = false
+		}
+
+		// Create a new frame and send it
+		s.f.Data = s.readData()
+		select {
+		case <-s.exitCh:
+			return nil
+		default:
+		}
+
+		if err := s.send(&s.f); err != nil {
+			return err
+		}
+
+		// Update the offset
+		s.f.Offset += int64(len(s.f.Data))
+	}
+
+	if s.data.Len() == 0 {
+		s.f.Clear()
+	}
+
+	return nil
+}
diff --git a/client/lib/streamframer/framer_test.go b/client/lib/streamframer/framer_test.go
new file mode 100644
index 00000000000..13be32141ab
--- /dev/null
+++ b/client/lib/streamframer/framer_test.go
@@ -0,0 +1,259 @@
+package framer
+
+import (
+	"bytes"
+	"reflect"
+	"strconv"
+	"testing"
+	"time"
+
+	"github.com/hashicorp/nomad/testutil"
+)
+
+// This test checks that even if the frame size has not been hit, a flush will
+// still occur periodically.
+func TestStreamFramer_Flush(t *testing.T) {
+	// Create the stream framer
+	frames := make(chan *StreamFrame, 10)
+	hRate, bWindow := 100*time.Millisecond, 100*time.Millisecond
+	sf := NewStreamFramer(frames, hRate, bWindow, 100)
+	sf.Run()
+
+	f := "foo"
+	fe := "bar"
+	d := []byte{0xa}
+	o := int64(10)
+
+	// Start the reader
+	resultCh := make(chan struct{})
+	go func() {
+		for {
+			frame := <-frames
+
+			if frame.IsHeartbeat() {
+				continue
+			}
+
+			if reflect.DeepEqual(frame.Data, d) && frame.Offset == o && frame.File == f && frame.FileEvent == fe {
+				resultCh <- struct{}{}
+				return
+			}
+
+		}
+	}()
+
+	// Write only 1 byte so we do not hit the frame size
+	if err := sf.Send(f, fe, d, o); err != nil {
+		t.Fatalf("Send() failed %v", err)
+	}
+
+	select {
+	case <-resultCh:
+	case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * bWindow):
+		t.Fatalf("failed to flush")
+	}
+
+	// Shutdown
+	sf.Destroy()
+
+	select {
+	case <-sf.ExitCh():
+	case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * hRate):
+		t.Fatalf("exit channel should close")
+	}
+
+	if _, ok := <-frames; ok {
+		t.Fatal("out channel should be closed")
+	}
+}
+
+// This test checks that frames are batched until the frame size is hit (when
+// that happens before the flush window).
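+// The framer below is created with a 3-byte frame size: the first 1-byte Send
+// must not emit a frame, while sending the remaining 2 bytes crosses the
+// threshold and forces a flush before the batch window expires.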
+func TestStreamFramer_Batch(t *testing.T) { + // Ensure the batch window doesn't get hit + hRate, bWindow := 100*time.Millisecond, 500*time.Millisecond + + // Create the stream framer + frames := make(chan *StreamFrame, 10) + sf := NewStreamFramer(frames, hRate, bWindow, 3) + sf.Run() + + f := "foo" + fe := "bar" + d := []byte{0xa, 0xb, 0xc} + o := int64(10) + + // Start the reader + resultCh := make(chan struct{}) + go func() { + for { + frame := <-frames + if frame.IsHeartbeat() { + continue + } + + if reflect.DeepEqual(frame.Data, d) && frame.Offset == o && frame.File == f && frame.FileEvent == fe { + resultCh <- struct{}{} + return + } + } + }() + + // Write only 1 byte so we do not hit the frame size + if err := sf.Send(f, fe, d[:1], o); err != nil { + t.Fatalf("Send() failed %v", err) + } + + // Ensure we didn't get any data + select { + case <-resultCh: + t.Fatalf("Got data before frame size reached") + case <-time.After(bWindow / 2): + } + + // Write the rest so we hit the frame size + if err := sf.Send(f, fe, d[1:], o); err != nil { + t.Fatalf("Send() failed %v", err) + } + + // Ensure we get data + select { + case <-resultCh: + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * bWindow): + t.Fatalf("Did not receive data after batch size reached") + } + + // Shutdown + sf.Destroy() + + select { + case <-sf.ExitCh(): + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * hRate): + t.Fatalf("exit channel should close") + } + + if _, ok := <-frames; ok { + t.Fatal("out channel should be closed") + } +} + +func TestStreamFramer_Heartbeat(t *testing.T) { + // Create the stream framer + frames := make(chan *StreamFrame, 10) + hRate, bWindow := 100*time.Millisecond, 100*time.Millisecond + sf := NewStreamFramer(frames, hRate, bWindow, 100) + sf.Run() + + // Start the reader + resultCh := make(chan struct{}) + go func() { + for { + frame := <-frames + if frame.IsHeartbeat() { + resultCh <- struct{}{} + return + } + } + }() + + select { + case <-resultCh: + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * hRate): + t.Fatalf("failed to heartbeat") + } + + // Shutdown + sf.Destroy() + + select { + case <-sf.ExitCh(): + case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * hRate): + t.Fatalf("exit channel should close") + } + + if _, ok := <-frames; ok { + t.Fatal("out channel should be closed") + } +} + +// This test checks that frames are received in order +func TestStreamFramer_Order(t *testing.T) { + // Ensure the batch window doesn't get hit + hRate, bWindow := 100*time.Millisecond, 10*time.Millisecond + // Create the stream framer + frames := make(chan *StreamFrame, 10) + sf := NewStreamFramer(frames, hRate, bWindow, 10) + sf.Run() + + files := []string{"1", "2", "3", "4", "5"} + input := bytes.NewBuffer(make([]byte, 0, 100000)) + for i := 0; i <= 1000; i++ { + str := strconv.Itoa(i) + "," + input.WriteString(str) + } + + expected := bytes.NewBuffer(make([]byte, 0, 100000)) + for range files { + expected.Write(input.Bytes()) + } + receivedBuf := bytes.NewBuffer(make([]byte, 0, 100000)) + + // Start the reader + resultCh := make(chan struct{}) + go func() { + for { + frame := <-frames + if frame.IsHeartbeat() { + continue + } + + receivedBuf.Write(frame.Data) + + if reflect.DeepEqual(expected, receivedBuf) { + resultCh <- struct{}{} + return + } + } + }() + + // Send the data + b := input.Bytes() + shards := 10 + each := len(b) / shards + for _, f := range files { + for i := 0; i < shards; i++ { + l, r := each*i, each*(i+1) + if i 
== shards-1 {
+				r = len(b)
+			}
+
+			if err := sf.Send(f, "", b[l:r], 0); err != nil {
+				t.Fatalf("Send() failed %v", err)
+			}
+		}
+	}
+
+	// Ensure we get the data; on timeout, report what was received so far
+	select {
+	case <-resultCh:
+	case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * bWindow):
+		if !reflect.DeepEqual(expected, receivedBuf) {
+			got := receivedBuf.String()
+			want := expected.String()
+			t.Fatalf("Got %v; want %v", got, want)
+		}
+	}
+
+	// Shutdown
+	sf.Destroy()
+
+	select {
+	case <-sf.ExitCh():
+	case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * hRate):
+		t.Fatalf("exit channel should close")
+	}
+
+	if _, ok := <-frames; ok {
+		t.Fatal("out channel should be closed")
+	}
+}
diff --git a/client/rpc.go b/client/rpc.go
new file mode 100644
index 00000000000..90a1eec47a0
--- /dev/null
+++ b/client/rpc.go
@@ -0,0 +1,384 @@
+package client
+
+import (
+	"errors"
+	"io"
+	"net"
+	"net/rpc"
+	"strings"
+	"time"
+
+	metrics "github.com/armon/go-metrics"
+	"github.com/hashicorp/consul/lib"
+	"github.com/hashicorp/nomad/client/servers"
+	inmem "github.com/hashicorp/nomad/helper/codec"
+	"github.com/hashicorp/nomad/helper/pool"
+	"github.com/hashicorp/nomad/nomad/structs"
+	"github.com/hashicorp/yamux"
+	"github.com/ugorji/go/codec"
+)
+
+// rpcEndpoints holds the RPC endpoints
+type rpcEndpoints struct {
+	ClientStats *ClientStats
+	FileSystem *FileSystem
+	Allocations *Allocations
+}
+
+// ClientRPC is used to make a local, client-only RPC call
+func (c *Client) ClientRPC(method string, args interface{}, reply interface{}) error {
+	codec := &inmem.InmemCodec{
+		Method: method,
+		Args: args,
+		Reply: reply,
+	}
+	if err := c.rpcServer.ServeRequest(codec); err != nil {
+		return err
+	}
+	return codec.Err
+}
+
+// StreamingRpcHandler is used to make a local, client-only streaming RPC
+// call.
+func (c *Client) StreamingRpcHandler(method string) (structs.StreamingRpcHandler, error) {
+	return c.streamingRpcs.GetHandler(method)
+}
+
+// RPC is used to forward an RPC call to a Nomad server, failing if there are
+// no servers.
+func (c *Client) RPC(method string, args interface{}, reply interface{}) error {
+	// Invoke the RPCHandler if it exists
+	if c.config.RPCHandler != nil {
+		return c.config.RPCHandler.RPC(method, args, reply)
+	}
+
+	// This is subtle but we start measuring the time on the client side
+	// right at the time of the first request, vs. on the first retry as
+	// is done on the server side inside forward(). This is because the
+	// servers may already be applying the RPCHoldTimeout up there, so by
+	// starting the timer here we won't potentially double up the delay.
+	firstCheck := time.Now()
+
+TRY:
+	server := c.servers.FindServer()
+	if server == nil {
+		return noServersErr
+	}
+
+	// Make the request.
+	rpcErr := c.connPool.RPC(c.Region(), server.Addr, c.RPCMajorVersion(), method, args, reply)
+	if rpcErr == nil {
+		return nil
+	}
+
+	// Move off to another server, and see if we can retry.
+	c.logger.Printf("[ERR] nomad: %q RPC failed to server %s: %v", method, server.Addr, rpcErr)
+	c.servers.NotifyFailedServer(server)
+	if retry := canRetry(args, rpcErr); !retry {
+		return rpcErr
+	}
+
+	// We can wait a bit and retry!
+	if time.Since(firstCheck) < c.config.RPCHoldTimeout {
+		jitter := lib.RandomStagger(c.config.RPCHoldTimeout / structs.JitterFraction)
+		select {
+		case <-time.After(jitter):
+			goto TRY
+		case <-c.shutdownCh:
+		}
+	}
+	return rpcErr
+}
+
+// canRetry returns true if the given situation is safe for a retry.
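+// Two situations qualify: "no leader" errors, where no state can have changed
+// on the servers, and EOF-style stream errors on read-only requests, which
+// can happen when a server shuts down mid-connection.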
+func canRetry(args interface{}, err error) bool { + // No leader errors are always safe to retry since no state could have + // been changed. + if structs.IsErrNoLeader(err) { + return true + } + + // Reads are safe to retry for stream errors, such as if a server was + // being shut down. + info, ok := args.(structs.RPCInfo) + if ok && info.IsRead() && lib.IsErrEOF(err) { + return true + } + + return false +} + +// RemoteStreamingRpcHandler is used to make a streaming RPC call to a remote +// server. +func (c *Client) RemoteStreamingRpcHandler(method string) (structs.StreamingRpcHandler, error) { + server := c.servers.FindServer() + if server == nil { + return nil, noServersErr + } + + conn, err := c.streamingRpcConn(server, method) + if err != nil { + // Move off to another server + c.logger.Printf("[ERR] nomad: %q RPC failed to server %s: %v", method, server.Addr, err) + c.servers.NotifyFailedServer(server) + return nil, err + } + + return bridgedStreamingRpcHandler(conn), nil +} + +// bridgedStreamingRpcHandler creates a bridged streaming RPC handler by copying +// data between the two sides. +func bridgedStreamingRpcHandler(sideA io.ReadWriteCloser) structs.StreamingRpcHandler { + return func(sideB io.ReadWriteCloser) { + defer sideA.Close() + defer sideB.Close() + structs.Bridge(sideA, sideB) + } +} + +// streamingRpcConn is used to retrieve a connection to a server to conduct a +// streaming RPC. +func (c *Client) streamingRpcConn(server *servers.Server, method string) (net.Conn, error) { + // Dial the server + conn, err := net.DialTimeout("tcp", server.Addr.String(), 10*time.Second) + if err != nil { + return nil, err + } + + // Cast to TCPConn + if tcp, ok := conn.(*net.TCPConn); ok { + tcp.SetKeepAlive(true) + tcp.SetNoDelay(true) + } + + // Check if TLS is enabled + c.tlsWrapLock.RLock() + tlsWrap := c.tlsWrap + c.tlsWrapLock.RUnlock() + + if tlsWrap != nil { + // Switch the connection into TLS mode + if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil { + conn.Close() + return nil, err + } + + // Wrap the connection in a TLS client + tlsConn, err := tlsWrap(c.Region(), conn) + if err != nil { + conn.Close() + return nil, err + } + conn = tlsConn + } + + // Write the multiplex byte to set the mode + if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil { + conn.Close() + return nil, err + } + + // Send the header + encoder := codec.NewEncoder(conn, structs.MsgpackHandle) + decoder := codec.NewDecoder(conn, structs.MsgpackHandle) + header := structs.StreamingRpcHeader{ + Method: method, + } + if err := encoder.Encode(header); err != nil { + conn.Close() + return nil, err + } + + // Wait for the acknowledgement + var ack structs.StreamingRpcAck + if err := decoder.Decode(&ack); err != nil { + conn.Close() + return nil, err + } + + if ack.Error != "" { + conn.Close() + return nil, errors.New(ack.Error) + } + + return conn, nil +} + +// setupClientRpc is used to setup the Client's RPC endpoints +func (c *Client) setupClientRpc() { + // Initialize the RPC handlers + c.endpoints.ClientStats = &ClientStats{c} + c.endpoints.FileSystem = NewFileSystemEndpoint(c) + c.endpoints.Allocations = &Allocations{c} + + // Create the RPC Server + c.rpcServer = rpc.NewServer() + + // Register the endpoints with the RPC server + c.setupClientRpcServer(c.rpcServer) + + go c.rpcConnListener() +} + +// setupClientRpcServer is used to populate a client RPC server with endpoints. 
+func (c *Client) setupClientRpcServer(server *rpc.Server) { + // Register the endpoints + server.Register(c.endpoints.ClientStats) + server.Register(c.endpoints.FileSystem) + server.Register(c.endpoints.Allocations) +} + +// rpcConnListener is a long lived function that listens for new connections +// being made on the connection pool and starts an RPC listener for each +// connection. +func (c *Client) rpcConnListener() { + // Make a channel for new connections. + conns := make(chan *yamux.Session, 4) + c.connPool.SetConnListener(conns) + + for { + select { + case <-c.shutdownCh: + return + case session, ok := <-conns: + if !ok { + continue + } + + go c.listenConn(session) + } + } +} + +// listenConn is used to listen for connections being made from the server on +// pre-existing connection. This should be called in a goroutine. +func (c *Client) listenConn(s *yamux.Session) { + for { + conn, err := s.Accept() + if err != nil { + if s.IsClosed() { + return + } + + c.logger.Printf("[ERR] client.rpc: failed to accept RPC conn: %v", err) + continue + } + + go c.handleConn(conn) + metrics.IncrCounter([]string{"client", "rpc", "accept_conn"}, 1) + } +} + +// handleConn is used to determine if this is a RPC or Streaming RPC connection and +// invoke the correct handler +func (c *Client) handleConn(conn net.Conn) { + // Read a single byte + buf := make([]byte, 1) + if _, err := conn.Read(buf); err != nil { + if err != io.EOF { + c.logger.Printf("[ERR] client.rpc: failed to read byte: %v", err) + } + conn.Close() + return + } + + // Switch on the byte + switch pool.RPCType(buf[0]) { + case pool.RpcNomad: + c.handleNomadConn(conn) + + case pool.RpcStreaming: + c.handleStreamingConn(conn) + + default: + c.logger.Printf("[ERR] client.rpc: unrecognized RPC byte: %v", buf[0]) + conn.Close() + return + } +} + +// handleNomadConn is used to handle a single Nomad RPC connection. +func (c *Client) handleNomadConn(conn net.Conn) { + defer conn.Close() + rpcCodec := pool.NewServerCodec(conn) + for { + select { + case <-c.shutdownCh: + return + default: + } + + if err := c.rpcServer.ServeRequest(rpcCodec); err != nil { + if err != io.EOF && !strings.Contains(err.Error(), "closed") { + c.logger.Printf("[ERR] client.rpc: RPC error: %v (%v)", err, conn) + metrics.IncrCounter([]string{"client", "rpc", "request_error"}, 1) + } + return + } + metrics.IncrCounter([]string{"client", "rpc", "request"}, 1) + } +} + +// handleStreamingConn is used to handle a single Streaming Nomad RPC connection. 
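+// The handshake, as implemented below: the remote side sends a
+// structs.StreamingRpcHeader naming the method; this side looks up the
+// handler and replies with a structs.StreamingRpcAck carrying any lookup
+// error; only on success is the raw connection handed to the handler.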
+func (c *Client) handleStreamingConn(conn net.Conn) {
+	defer conn.Close()
+
+	// Decode the header
+	var header structs.StreamingRpcHeader
+	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
+	if err := decoder.Decode(&header); err != nil {
+		if err != io.EOF && !strings.Contains(err.Error(), "closed") {
+			c.logger.Printf("[ERR] client.rpc: Streaming RPC error: %v (%v)", err, conn)
+			metrics.IncrCounter([]string{"client", "streaming_rpc", "request_error"}, 1)
+		}
+
+		return
+	}
+
+	ack := structs.StreamingRpcAck{}
+	handler, err := c.streamingRpcs.GetHandler(header.Method)
+	if err != nil {
+		c.logger.Printf("[ERR] client.rpc: Streaming RPC error: %v (%v)", err, conn)
+		metrics.IncrCounter([]string{"client", "streaming_rpc", "request_error"}, 1)
+		ack.Error = err.Error()
+	}
+
+	// Send the acknowledgement
+	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
+	if err := encoder.Encode(ack); err != nil {
+		conn.Close()
+		return
+	}
+
+	if ack.Error != "" {
+		return
+	}
+
+	// Invoke the handler
+	metrics.IncrCounter([]string{"client", "streaming_rpc", "request"}, 1)
+	handler(conn)
+}
+
+// resolveServer takes a server's address as a string and returns its resolved
+// net.Addr or an error.
+func resolveServer(s string) (net.Addr, error) {
+	const defaultClientPort = "4647" // default client RPC port
+	host, port, err := net.SplitHostPort(s)
+	if err != nil {
+		if strings.Contains(err.Error(), "missing port") {
+			host = s
+			port = defaultClientPort
+		} else {
+			return nil, err
+		}
+	}
+	return net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port))
+}
+
+// Ping is used to ping a particular server, returning an error if it is
+// unhealthy or unreachable.
+func (c *Client) Ping(srv net.Addr) error {
+	var reply struct{}
+	err := c.connPool.RPC(c.Region(), srv, c.RPCMajorVersion(), "Status.Ping", struct{}{}, &reply)
+	return err
+}
diff --git a/client/rpc_test.go b/client/rpc_test.go
new file mode 100644
index 00000000000..c25033923e8
--- /dev/null
+++ b/client/rpc_test.go
@@ -0,0 +1,115 @@
+package client
+
+import (
+	"errors"
+	"testing"
+
+	"github.com/hashicorp/nomad/client/config"
+	"github.com/hashicorp/nomad/nomad"
+	"github.com/hashicorp/nomad/nomad/structs"
+	sconfig "github.com/hashicorp/nomad/nomad/structs/config"
+	"github.com/hashicorp/nomad/testutil"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRpc_streamingRpcConn_badEndpoint(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+	s1 := nomad.TestServer(t, nil)
+	defer s1.Shutdown()
+	testutil.WaitForLeader(t, s1.RPC)
+
+	c := TestClient(t, func(c *config.Config) {
+		c.Servers = []string{s1.GetConfig().RPCAddr.String()}
+	})
+	defer c.Shutdown()
+
+	// Wait for the client to connect
+	testutil.WaitForResult(func() (bool, error) {
+		node, err := s1.State().NodeByID(nil, c.NodeID())
+		if err != nil {
+			return false, err
+		}
+		if node == nil {
+			return false, errors.New("no node")
+		}
+
+		return node.Status == structs.NodeStatusReady, errors.New("wrong status")
+	}, func(err error) {
+		t.Fatalf("should have a client")
+	})
+
+	// Get the server
+	server := c.servers.FindServer()
+	require.NotNil(server)
+
+	conn, err := c.streamingRpcConn(server, "Bogus")
+	require.Nil(conn)
+	require.NotNil(err)
+	require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"")
+}
+
+func TestRpc_streamingRpcConn_badEndpoint_TLS(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	const (
+		cafile = "../helper/tlsutil/testdata/ca.pem"
+		foocert = 
"../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + + s1 := nomad.TestServer(t, func(c *nomad.Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 1 + c.DevDisableBootstrap = true + c.TLSConfig = &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s1.Shutdown() + testutil.WaitForLeader(t, s1.RPC) + + c := TestClient(t, func(c *config.Config) { + c.Region = "regionFoo" + c.Servers = []string{s1.GetConfig().RPCAddr.String()} + c.TLSConfig = &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer c.Shutdown() + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + node, err := s1.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + if node == nil { + return false, errors.New("no node") + } + + return node.Status == structs.NodeStatusReady, errors.New("wrong status") + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Get the server + server := c.servers.FindServer() + require.NotNil(server) + + conn, err := c.streamingRpcConn(server, "Bogus") + require.Nil(conn) + require.NotNil(err) + require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"") +} diff --git a/client/serverlist.go b/client/serverlist.go deleted file mode 100644 index 87aec05e61f..00000000000 --- a/client/serverlist.go +++ /dev/null @@ -1,111 +0,0 @@ -package client - -import ( - "math/rand" - "net" - "sort" - "strings" - "sync" -) - -// serverlist is a prioritized randomized list of nomad servers. Users should -// call all() to retrieve the full list, followed by failed(e) on each endpoint -// that's failed and good(e) when a valid endpoint is found. -type serverlist struct { - e endpoints - mu sync.RWMutex -} - -func newServerList() *serverlist { - return &serverlist{} -} - -// set the server list to a new list. The new list will be shuffled and sorted -// by priority. -func (s *serverlist) set(in endpoints) { - s.mu.Lock() - s.e = in - s.mu.Unlock() -} - -// all returns a copy of the full server list, shuffled and then sorted by -// priority -func (s *serverlist) all() endpoints { - s.mu.RLock() - out := make(endpoints, len(s.e)) - copy(out, s.e) - s.mu.RUnlock() - - // Randomize the order - for i, j := range rand.Perm(len(out)) { - out[i], out[j] = out[j], out[i] - } - - // Sort by priority - sort.Sort(out) - return out -} - -// failed endpoint will be deprioritized if its still in the list. -func (s *serverlist) failed(e *endpoint) { - s.mu.Lock() - defer s.mu.Unlock() - for _, cur := range s.e { - if cur.equal(e) { - cur.priority++ - return - } - } -} - -// good endpoint will get promoted to the highest priority if it's still in the -// list. 
-func (s *serverlist) good(e *endpoint) { - s.mu.Lock() - defer s.mu.Unlock() - for _, cur := range s.e { - if cur.equal(e) { - cur.priority = 0 - return - } - } -} - -func (e endpoints) Len() int { - return len(e) -} - -func (e endpoints) Less(i int, j int) bool { - // Sort only by priority as endpoints should be shuffled and ordered - // only by priority - return e[i].priority < e[j].priority -} - -func (e endpoints) Swap(i int, j int) { - e[i], e[j] = e[j], e[i] -} - -type endpoints []*endpoint - -func (e endpoints) String() string { - names := make([]string, 0, len(e)) - for _, endpoint := range e { - names = append(names, endpoint.name) - } - return strings.Join(names, ",") -} - -type endpoint struct { - name string - addr net.Addr - - // 0 being the highest priority - priority int -} - -// equal returns true if the name and addr match between two endpoints. -// Priority is ignored because the same endpoint may be added by discovery and -// heartbeating with different priorities. -func (e *endpoint) equal(o *endpoint) bool { - return e.name == o.name && e.addr == o.addr -} diff --git a/client/serverlist_test.go b/client/serverlist_test.go deleted file mode 100644 index e23ef4b7fe5..00000000000 --- a/client/serverlist_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package client - -import ( - "log" - "os" - "strings" - "testing" -) - -func TestServerList(t *testing.T) { - t.Parallel() - s := newServerList() - - // New lists should be empty - if e := s.all(); len(e) != 0 { - t.Fatalf("expected empty list to return an empty list, but received: %+q", e) - } - - mklist := func() endpoints { - return endpoints{ - &endpoint{"b", nil, 1}, - &endpoint{"c", nil, 1}, - &endpoint{"g", nil, 2}, - &endpoint{"d", nil, 1}, - &endpoint{"e", nil, 1}, - &endpoint{"f", nil, 1}, - &endpoint{"h", nil, 2}, - &endpoint{"a", nil, 0}, - } - } - s.set(mklist()) - - orig := mklist() - all := s.all() - if len(all) != len(orig) { - t.Fatalf("expected %d endpoints but only have %d", len(orig), len(all)) - } - - // Assert list is properly randomized+sorted - for i, pri := range []int{0, 1, 1, 1, 1, 1, 2, 2} { - if all[i].priority != pri { - t.Errorf("expected endpoint %d (%+q) to be priority %d", i, all[i], pri) - } - } - - // Subsequent sets should reshuffle (try multiple times as they may - // shuffle in the same order) - tries := 0 - max := 3 - for ; tries < max; tries++ { - if s.all().String() == s.all().String() { - // eek, matched; try again in case we just got unlucky - continue - } - break - } - if tries == max { - t.Fatalf("after %d attempts servers were still not random reshuffled", tries) - } - - // Mark an endpoint as failed enough that it should be at the end of the list - sa := &endpoint{"a", nil, 0} - s.failed(sa) - s.failed(sa) - s.failed(sa) - all2 := s.all() - if len(all2) != len(orig) { - t.Fatalf("marking should not have changed list length") - } - if all2[len(all)-1].name != sa.name { - t.Fatalf("failed endpoint should be at end of list: %+q", all2) - } - - // But if the bad endpoint succeeds even once it should be bumped to the top group - s.good(sa) - found := false - for _, e := range s.all() { - if e.name == sa.name { - if e.priority != 0 { - t.Fatalf("server newly marked good should have highest priority") - } - found = true - } - } - if !found { - t.Fatalf("what happened to endpoint A?!") - } -} - -// TestClient_ServerList tests client methods that interact with the internal -// nomad server list. 
-func TestClient_ServerList(t *testing.T) { - t.Parallel() - // manually create a mostly empty client to avoid spinning up a ton of - // goroutines that complicate testing - client := Client{servers: newServerList(), logger: log.New(os.Stderr, "", log.Ltime|log.Lshortfile)} - - if s := client.GetServers(); len(s) != 0 { - t.Fatalf("expected server lit to be empty but found: %+q", s) - } - if err := client.SetServers(nil); err != noServersErr { - t.Fatalf("expected setting an empty list to return a 'no servers' error but received %v", err) - } - if err := client.SetServers([]string{"123.456.13123.123.13:80"}); err == nil { - t.Fatalf("expected setting a bad server to return an error") - } - if err := client.SetServers([]string{"123.456.13123.123.13:80", "127.0.0.1:1234", "127.0.0.1"}); err != nil { - t.Fatalf("expected setting at least one good server to succeed but received: %v", err) - } - s := client.GetServers() - if len(s) != 2 { - t.Fatalf("expected 2 servers but received: %+q", s) - } - for _, host := range s { - if !strings.HasPrefix(host, "127.0.0.1:") { - t.Errorf("expected both servers to be localhost and include port but found: %s", host) - } - } -} diff --git a/client/servers/manager.go b/client/servers/manager.go new file mode 100644 index 00000000000..6dac0c7e489 --- /dev/null +++ b/client/servers/manager.go @@ -0,0 +1,296 @@ +// Package servers provides an interface for choosing Servers to communicate +// with from a Nomad Client perspective. The package does not provide any API +// guarantees and should be called only by `hashicorp/nomad`. +package servers + +import ( + "log" + "math/rand" + "net" + "strings" + "sync" + "time" + + "github.com/hashicorp/consul/lib" +) + +const ( + // clientRPCMinReuseDuration controls the minimum amount of time RPC + // queries are sent over an established connection to a single server + clientRPCMinReuseDuration = 5 * time.Minute + + // Limit the number of new connections a server receives per second + // for connection rebalancing. This limit caps the load caused by + // continual rebalancing efforts when a cluster is in equilibrium. A + // lower value comes at the cost of increased recovery time after a + // partition. This parameter begins to take effect when there are + // more than ~48K clients querying 5x servers or at lower server + // values when there is a partition. + // + // For example, in a 100K Nomad cluster with 5x servers, it will + // take ~5min for all servers to rebalance their connections. If + // 99,995 agents are in the minority talking to only one server, it + // will take ~26min for all servers to rebalance. A 10K cluster in + // the same scenario will take ~2.6min to rebalance. + newRebalanceConnsPerSecPerServer = 64 +) + +// Pinger is an interface for pinging a server to see if it is healthy. +type Pinger interface { + Ping(addr net.Addr) error +} + +// Server contains the address of a server and metadata that can be used for +// choosing a server to contact. 
+type Server struct {
+	// Addr is the resolved address of the server
+	Addr net.Addr
+	addr string
+	sync.Mutex
+
+	// DC is the datacenter of the server
+	DC string
+}
+
+func (s *Server) Copy() *Server {
+	s.Lock()
+	defer s.Unlock()
+
+	return &Server{
+		Addr: s.Addr,
+		addr: s.addr,
+		DC:   s.DC,
+	}
+}
+
+func (s *Server) String() string {
+	s.Lock()
+	defer s.Unlock()
+
+	if s.addr == "" {
+		s.addr = s.Addr.String()
+	}
+
+	return s.addr
+}
+
+type Servers []*Server
+
+func (s Servers) String() string {
+	addrs := make([]string, 0, len(s))
+	for _, srv := range s {
+		addrs = append(addrs, srv.String())
+	}
+	return strings.Join(addrs, ",")
+}
+
+// cycle cycles a list of servers in-place
+func (s Servers) cycle() {
+	numServers := len(s)
+	if numServers < 2 {
+		return // No action required
+	}
+
+	start := s[0]
+	for i := 1; i < numServers; i++ {
+		s[i-1] = s[i]
+	}
+	s[numServers-1] = start
+}
+
+// shuffle shuffles the server list in place
+func (s Servers) shuffle() {
+	for i := len(s) - 1; i > 0; i-- {
+		j := rand.Int31n(int32(i + 1))
+		s[i], s[j] = s[j], s[i]
+	}
+}
+
+type Manager struct {
+	// servers is the list of all known Nomad servers.
+	servers Servers
+
+	// rebalanceTimer controls the duration of the rebalance interval
+	rebalanceTimer *time.Timer
+
+	// shutdownCh is a copy of the channel in Nomad.Client
+	shutdownCh chan struct{}
+
+	logger *log.Logger
+
+	// numNodes is used to estimate the approximate number of nodes in
+	// a cluster and limit the rate at which it rebalances server
+	// connections. This should be read and set using atomic.
+	numNodes int32
+
+	// connPoolPinger is used to test the health of a server in the connection
+	// pool. Pinger is an interface that wraps client.ConnPool.
+	connPoolPinger Pinger
+
+	sync.Mutex
+}
+
+// New is the only way to safely create a new Manager struct.
+func New(logger *log.Logger, shutdownCh chan struct{}, connPoolPinger Pinger) (m *Manager) {
+	return &Manager{
+		logger:         logger,
+		connPoolPinger: connPoolPinger,
+		rebalanceTimer: time.NewTimer(clientRPCMinReuseDuration),
+		shutdownCh:     shutdownCh,
+	}
+}
+
+// Start is used to start and manage the task of automatically shuffling and
+// rebalancing the list of Nomad servers in order to distribute load across
+// all known and available Nomad servers.
+func (m *Manager) Start() {
+	for {
+		select {
+		case <-m.rebalanceTimer.C:
+			m.RebalanceServers()
+			m.refreshServerRebalanceTimer()
+
+		case <-m.shutdownCh:
+			m.logger.Printf("[DEBUG] manager: shutting down")
+			return
+		}
+	}
+}
+
+func (m *Manager) SetServers(servers Servers) {
+	m.Lock()
+	defer m.Unlock()
+	m.servers = servers
+}
+
+// FindServer returns a server to send an RPC to. If there are no servers, nil
+// is returned.
+func (m *Manager) FindServer() *Server {
+	m.Lock()
+	defer m.Unlock()
+
+	if len(m.servers) == 0 {
+		m.logger.Printf("[WARN] manager: No servers available")
+		return nil
+	}
+
+	// Return whatever is at the front of the list because it is
+	// assumed to be the oldest in the server list (unless -
+	// hypothetically - the server list was rotated right after a
+	// server was added).
+	return m.servers[0]
+}
+
+// NumNodes returns the approximate number of nodes in the cluster.
+func (m *Manager) NumNodes() int32 {
+	m.Lock()
+	defer m.Unlock()
+	return m.numNodes
+}
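Usage-wise, `FindServer` always returns the head of the list, and it is the caller's responsibility to report failures so that `NotifyFailedServer` (defined just below) can rotate the bad server to the back. A minimal sketch of that calling pattern; `rpcWithFailover` and its `call` parameter are illustrative names, not part of this change:

```go
package main

import (
	"errors"

	"github.com/hashicorp/nomad/client/servers"
)

// rpcWithFailover sketches the intended Manager usage: try the server at the
// head of the list, rotate on failure, and stop after one full pass.
func rpcWithFailover(m *servers.Manager, call func(*servers.Server) error) error {
	var lastErr error
	for i := 0; i < m.NumServers(); i++ {
		srv := m.FindServer()
		if srv == nil {
			return errors.New("no known servers")
		}
		if err := call(srv); err != nil {
			// NotifyFailedServer cycles the failed head to the end,
			// so the next FindServer call returns a different server.
			m.NotifyFailedServer(srv)
			lastErr = err
			continue
		}
		return nil
	}
	return lastErr
}
```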
+// SetNumNodes stores the approximate number of nodes in the cluster.
+func (m *Manager) SetNumNodes(n int32) {
+	m.Lock()
+	defer m.Unlock()
+	m.numNodes = n
+}
+
+// NotifyFailedServer marks the passed-in server as "failed" by rotating it
+// to the end of the server list.
+func (m *Manager) NotifyFailedServer(s *Server) {
+	m.Lock()
+	defer m.Unlock()
+
+	// If the server being failed is not the first server on the list,
+	// this is a noop. If, however, the server is failed and first on
+	// the list, move the server to the end of the list.
+	if len(m.servers) > 1 && m.servers[0] == s {
+		m.servers.cycle()
+	}
+}
+
+// NumServers returns the total number of known servers whether healthy or not.
+func (m *Manager) NumServers() int {
+	m.Lock()
+	defer m.Unlock()
+	return len(m.servers)
+}
+
+// GetServers returns a copy of the current list of servers.
+func (m *Manager) GetServers() Servers {
+	m.Lock()
+	defer m.Unlock()
+
+	copy := make([]*Server, 0, len(m.servers))
+	for _, s := range m.servers {
+		copy = append(copy, s.Copy())
+	}
+
+	return copy
+}
+
+// RebalanceServers shuffles the order in which Servers will be contacted. The
+// function will shuffle the set of potential servers to contact and then attempt
+// to contact each server. If a server successfully responds it is used, otherwise
+// it is rotated such that it will be the last attempted server.
+func (m *Manager) RebalanceServers() {
+	// Shuffle servers so we have a chance of picking a new one.
+	servers := m.GetServers()
+	servers.shuffle()
+
+	// Iterate through the shuffled server list to find an assumed
+	// healthy server. NOTE: Do not iterate on the list directly because
+	// this loop mutates the server list in-place.
+	var foundHealthyServer bool
+	for i := 0; i < len(m.servers); i++ {
+		// Always test the first server. Failed servers are cycled
+		// to the end of the list until a healthy one is found.
+		srv := servers[0]
+
+		err := m.connPoolPinger.Ping(srv.Addr)
+		if err == nil {
+			foundHealthyServer = true
+			break
+		}
+		m.logger.Printf(`[DEBUG] manager: pinging server "%s" failed: %s`, srv, err)
+
+		servers.cycle()
+	}
+
+	if !foundHealthyServer {
+		m.logger.Printf("[DEBUG] manager: No healthy servers during rebalance")
+		return
+	}
+
+	// Save the servers
+	m.Lock()
+	m.servers = servers
+	m.Unlock()
+}
+
+// refreshServerRebalanceTimer is only called once m.rebalanceTimer expires.
+func (m *Manager) refreshServerRebalanceTimer() time.Duration {
+	m.Lock()
+	defer m.Unlock()
+	numServers := len(m.servers)
+
+	// Limit this connection's life based on the size (and health) of the
+	// cluster. Never rebalance a connection more frequently than
+	// clientRPCMinReuseDuration, and make sure we never exceed
+	// clusterWideRebalanceConnsPerSec operations/s across the cluster's
+	// nodes.
+	clusterWideRebalanceConnsPerSec := float64(numServers * newRebalanceConnsPerSecPerServer)
+
+	connRebalanceTimeout := lib.RateScaledInterval(clusterWideRebalanceConnsPerSec, clientRPCMinReuseDuration, int(m.numNodes))
+	connRebalanceTimeout += lib.RandomStagger(connRebalanceTimeout)
+
+	m.rebalanceTimer.Reset(connRebalanceTimeout)
+	return connRebalanceTimeout
+}
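To make the timer scaling above concrete: assuming consul/lib's `RateScaledInterval` returns roughly max(min, numNodes/rate seconds) and `RandomStagger` adds a random duration up to its argument, the largest case in the test table below works out as follows.

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// 19 servers, each willing to absorb 64 rebalance conns/sec.
	const perServer = 64         // newRebalanceConnsPerSecPerServer
	rate := 19.0 * perServer     // cluster-wide budget: 1216 conns/sec

	// Assumed RateScaledInterval behavior: max(min, numNodes/rate seconds).
	numNodes := 1e6
	interval := time.Duration(numNodes / rate * float64(time.Second))
	fmt.Println(interval) // ~13m42s

	// RandomStagger then adds up to the same amount again, so a 1M-node,
	// 19-server cluster rebalances roughly every 14-27 minutes, which is
	// why the test table below uses 10 minutes as the floor for that row.
}
```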
+// ResetRebalanceTimer resets the rebalance timer. This method exists for
+// testing and should not be used directly.
+func (m *Manager) ResetRebalanceTimer() {
+	m.Lock()
+	defer m.Unlock()
+	m.rebalanceTimer.Reset(clientRPCMinReuseDuration)
+}
diff --git a/client/servers/manager_internal_test.go b/client/servers/manager_internal_test.go
new file mode 100644
index 00000000000..e6ad03bb3c6
--- /dev/null
+++ b/client/servers/manager_internal_test.go
@@ -0,0 +1,158 @@
+package servers
+
+import (
+	"fmt"
+	"log"
+	"math/rand"
+	"net"
+	"os"
+	"testing"
+	"time"
+)
+
+func init() {
+	// Seed the random number generator
+	rand.Seed(time.Now().UnixNano())
+}
+
+type fauxAddr struct {
+	Addr string
+}
+
+func (fa *fauxAddr) String() string  { return fa.Addr }
+func (fa *fauxAddr) Network() string { return fa.Addr }
+
+type fauxConnPool struct {
+	// failPct between 0.0 and 1.0 == pct of time a Ping should fail
+	failPct float64
+}
+
+func (cp *fauxConnPool) Ping(net.Addr) error {
+	successProb := rand.Float64()
+	if successProb > cp.failPct {
+		return nil
+	}
+	return fmt.Errorf("bad server")
+}
+
+func testManager(t *testing.T) (m *Manager) {
+	logger := log.New(os.Stderr, "", 0)
+	shutdownCh := make(chan struct{})
+	m = New(logger, shutdownCh, &fauxConnPool{})
+	return m
+}
+
+func testManagerFailProb(failPct float64) (m *Manager) {
+	logger := log.New(os.Stderr, "", 0)
+	shutdownCh := make(chan struct{})
+	m = New(logger, shutdownCh, &fauxConnPool{failPct: failPct})
+	return m
+}
+
+func TestManagerInternal_cycleServer(t *testing.T) {
+	server0 := &Server{Addr: &fauxAddr{"server1"}}
+	server1 := &Server{Addr: &fauxAddr{"server2"}}
+	server2 := &Server{Addr: &fauxAddr{"server3"}}
+	srvs := Servers([]*Server{server0, server1, server2})
+
+	srvs.cycle()
+	if len(srvs) != 3 {
+		t.Fatalf("server length incorrect: %d/3", len(srvs))
+	}
+	if srvs[0] != server1 &&
+		srvs[1] != server2 &&
+		srvs[2] != server0 {
+		t.Fatalf("server ordering after one cycle not correct")
+	}
+
+	srvs.cycle()
+	if srvs[0] != server2 &&
+		srvs[1] != server0 &&
+		srvs[2] != server1 {
+		t.Fatalf("server ordering after two cycles not correct")
+	}
+
+	srvs.cycle()
+	if srvs[0] != server0 &&
+		srvs[1] != server1 &&
+		srvs[2] != server2 {
+		t.Fatalf("server ordering after three cycles not correct")
+	}
+}
+
+func TestManagerInternal_New(t *testing.T) {
+	m := testManager(t)
+	if m == nil {
+		t.Fatalf("Manager nil")
+	}
+
+	if m.logger == nil {
+		t.Fatalf("Manager.logger nil")
+	}
+
+	if m.shutdownCh == nil {
+		t.Fatalf("Manager.shutdownCh nil")
+	}
+}
+
+// Tests Manager.refreshServerRebalanceTimer across a range of cluster sizes.
+func TestManagerInternal_refreshServerRebalanceTimer(t *testing.T) {
+	type clusterSizes struct {
+		numNodes     int32
+		numServers   int
+		minRebalance time.Duration
+	}
+	clusters := []clusterSizes{
+		{1, 0, 5 * time.Minute}, // partitioned cluster
+		{1, 3, 5 * time.Minute},
+		{2, 3, 5 * time.Minute},
+		{100, 0, 5 * time.Minute}, // partitioned
+		{100, 1, 5 * time.Minute}, // partitioned
+		{100, 3, 5 * time.Minute},
+		{1024, 1, 5 * time.Minute}, // partitioned
+		{1024, 3, 5 * time.Minute}, // partitioned
+		{1024, 5, 5 * time.Minute},
+		{16384, 1, 4 * time.Minute}, // partitioned
+		{16384, 2, 5 * time.Minute}, // partitioned
+		{16384, 3, 5 * time.Minute}, // partitioned
+		{16384, 5, 5 * time.Minute},
+		{32768, 0, 5 * time.Minute}, // partitioned
+		{32768, 1, 8 * time.Minute}, // partitioned
+		{32768, 2, 3 * time.Minute}, // partitioned
+		{32768, 3, 5 * time.Minute}, // partitioned
+		{32768, 5, 3 * time.Minute}, // partitioned
+		{65535, 7, 5 * time.Minute},
+		{65535, 0, 5 * time.Minute}, // partitioned
+		{65535, 1, 8 * time.Minute}, //
partitioned + {65535, 2, 3 * time.Minute}, // partitioned + {65535, 3, 5 * time.Minute}, // partitioned + {65535, 5, 3 * time.Minute}, // partitioned + {65535, 7, 5 * time.Minute}, + {1000000, 1, 4 * time.Hour}, // partitioned + {1000000, 2, 2 * time.Hour}, // partitioned + {1000000, 3, 80 * time.Minute}, // partitioned + {1000000, 5, 50 * time.Minute}, // partitioned + {1000000, 11, 20 * time.Minute}, // partitioned + {1000000, 19, 10 * time.Minute}, + } + + logger := log.New(os.Stderr, "", log.LstdFlags) + shutdownCh := make(chan struct{}) + + for _, s := range clusters { + m := New(logger, shutdownCh, &fauxConnPool{}) + m.SetNumNodes(s.numNodes) + servers := make([]*Server, 0, s.numServers) + for i := 0; i < s.numServers; i++ { + nodeName := fmt.Sprintf("s%02d", i) + servers = append(servers, &Server{Addr: &fauxAddr{nodeName}}) + } + m.SetServers(servers) + + d := m.refreshServerRebalanceTimer() + t.Logf("Nodes: %d; Servers: %d; Refresh: %v; Min: %v", s.numNodes, s.numServers, d, s.minRebalance) + if d < s.minRebalance { + t.Errorf("duration too short for cluster of size %d and %d servers (%s < %s)", s.numNodes, s.numServers, d, s.minRebalance) + } + } +} diff --git a/client/servers/manager_test.go b/client/servers/manager_test.go new file mode 100644 index 00000000000..deea7f48f00 --- /dev/null +++ b/client/servers/manager_test.go @@ -0,0 +1,238 @@ +package servers_test + +import ( + "fmt" + "log" + "math/rand" + "net" + "os" + "strings" + "testing" + + "github.com/hashicorp/nomad/client/servers" +) + +type fauxAddr struct { + Addr string +} + +func (fa *fauxAddr) String() string { return fa.Addr } +func (fa *fauxAddr) Network() string { return fa.Addr } + +type fauxConnPool struct { + // failPct between 0.0 and 1.0 == pct of time a Ping should fail + failPct float64 +} + +func (cp *fauxConnPool) Ping(net.Addr) error { + successProb := rand.Float64() + if successProb > cp.failPct { + return nil + } + return fmt.Errorf("bad server") +} + +func testManager() (m *servers.Manager) { + logger := log.New(os.Stderr, "", log.LstdFlags) + shutdownCh := make(chan struct{}) + m = servers.New(logger, shutdownCh, &fauxConnPool{}) + return m +} + +func testManagerFailProb(failPct float64) (m *servers.Manager) { + logger := log.New(os.Stderr, "", log.LstdFlags) + shutdownCh := make(chan struct{}) + m = servers.New(logger, shutdownCh, &fauxConnPool{failPct: failPct}) + return m +} + +func TestServers_SetServers(t *testing.T) { + m := testManager() + var num int + num = m.NumServers() + if num != 0 { + t.Fatalf("Expected zero servers to start") + } + + s1 := &servers.Server{Addr: &fauxAddr{"server1"}} + s2 := &servers.Server{Addr: &fauxAddr{"server2"}} + m.SetServers([]*servers.Server{s1, s2}) + num = m.NumServers() + if num != 2 { + t.Fatalf("Expected two servers") + } + + all := m.GetServers() + if l := len(all); l != 2 { + t.Fatalf("expected 2 servers got %d", l) + } + + if all[0] == s1 || all[0] == s2 { + t.Fatalf("expected a copy, got actual server") + } +} + +func TestServers_FindServer(t *testing.T) { + m := testManager() + + if m.FindServer() != nil { + t.Fatalf("Expected nil return") + } + + var srvs []*servers.Server + srvs = append(srvs, &servers.Server{Addr: &fauxAddr{"s1"}}) + m.SetServers(srvs) + if m.NumServers() != 1 { + t.Fatalf("Expected one server") + } + + s1 := m.FindServer() + if s1 == nil { + t.Fatalf("Expected non-nil server") + } + if s1.String() != "s1" { + t.Fatalf("Expected s1 server") + } + + s1 = m.FindServer() + if s1 == nil || s1.String() != "s1" { + 
t.Fatalf("Expected s1 server (still)") + } + + srvs = append(srvs, &servers.Server{Addr: &fauxAddr{"s2"}}) + m.SetServers(srvs) + if m.NumServers() != 2 { + t.Fatalf("Expected two servers") + } + s1 = m.FindServer() + if s1 == nil || s1.String() != "s1" { + t.Fatalf("Expected s1 server (still)") + } + + m.NotifyFailedServer(s1) + s2 := m.FindServer() + if s2 == nil || s2.String() != "s2" { + t.Fatalf("Expected s2 server") + } + + m.NotifyFailedServer(s2) + s1 = m.FindServer() + if s1 == nil || s1.String() != "s1" { + t.Fatalf("Expected s1 server") + } +} + +func TestServers_New(t *testing.T) { + logger := log.New(os.Stderr, "", log.LstdFlags) + shutdownCh := make(chan struct{}) + m := servers.New(logger, shutdownCh, &fauxConnPool{}) + if m == nil { + t.Fatalf("Manager nil") + } +} + +func TestServers_NotifyFailedServer(t *testing.T) { + m := testManager() + + if m.NumServers() != 0 { + t.Fatalf("Expected zero servers to start") + } + + s1 := &servers.Server{Addr: &fauxAddr{"s1"}} + s2 := &servers.Server{Addr: &fauxAddr{"s2"}} + + // Try notifying for a server that is not managed by Manager + m.NotifyFailedServer(s1) + if m.NumServers() != 0 { + t.Fatalf("Expected zero servers to start") + } + m.SetServers([]*servers.Server{s1}) + + // Test again w/ a server not in the list + m.NotifyFailedServer(s2) + if m.NumServers() != 1 { + t.Fatalf("Expected one server") + } + + m.SetServers([]*servers.Server{s1, s2}) + if m.NumServers() != 2 { + t.Fatalf("Expected two servers") + } + + s1 = m.FindServer() + if s1 == nil || s1.String() != "s1" { + t.Fatalf("Expected s1 server") + } + + m.NotifyFailedServer(s2) + s1 = m.FindServer() + if s1 == nil || s1.String() != "s1" { + t.Fatalf("Expected s1 server (still)") + } + + m.NotifyFailedServer(s1) + s2 = m.FindServer() + if s2 == nil || s2.String() != "s2" { + t.Fatalf("Expected s2 server") + } + + m.NotifyFailedServer(s2) + s1 = m.FindServer() + if s1 == nil || s1.String() != "s1" { + t.Fatalf("Expected s1 server") + } +} + +func TestServers_NumServers(t *testing.T) { + m := testManager() + var num int + num = m.NumServers() + if num != 0 { + t.Fatalf("Expected zero servers to start") + } + + s := &servers.Server{Addr: &fauxAddr{"server1"}} + m.SetServers([]*servers.Server{s}) + num = m.NumServers() + if num != 1 { + t.Fatalf("Expected one server after SetServers") + } +} + +func TestServers_RebalanceServers(t *testing.T) { + const failPct = 0.5 + m := testManagerFailProb(failPct) + const maxServers = 100 + const numShuffleTests = 100 + const uniquePassRate = 0.5 + + // Make a huge list of nodes. + var srvs []*servers.Server + for i := 0; i < maxServers; i++ { + nodeName := fmt.Sprintf("s%02d", i) + srvs = append(srvs, &servers.Server{Addr: &fauxAddr{nodeName}}) + } + m.SetServers(srvs) + + // Keep track of how many unique shuffles we get. + uniques := make(map[string]struct{}, maxServers) + for i := 0; i < numShuffleTests; i++ { + m.RebalanceServers() + + var names []string + for j := 0; j < maxServers; j++ { + server := m.FindServer() + m.NotifyFailedServer(server) + names = append(names, server.String()) + } + key := strings.Join(names, "|") + uniques[key] = struct{}{} + } + + // We have to allow for the fact that there won't always be a unique + // shuffle each pass, so we just look for smell here without the test + // being flaky. 
+ if len(uniques) < int(maxServers*uniquePassRate) { + t.Fatalf("unique shuffle ratio too low: %d/%d", len(uniques), maxServers) + } +} diff --git a/client/stats/host.go b/client/stats/host.go index 8f0f92377db..1da2b464180 100644 --- a/client/stats/host.go +++ b/client/stats/host.go @@ -1,6 +1,7 @@ package stats import ( + "fmt" "log" "math" "runtime" @@ -93,7 +94,12 @@ func NewHostStatsCollector(logger *log.Logger, allocDir string) *HostStatsCollec func (h *HostStatsCollector) Collect() error { h.hostStatsLock.Lock() defer h.hostStatsLock.Unlock() + return h.collectLocked() +} +// collectLocked collects stats related to resource usage of the host but should +// be called with the lock held. +func (h *HostStatsCollector) collectLocked() error { hs := &HostStats{Timestamp: time.Now().UTC().UnixNano()} // Determine up-time @@ -128,7 +134,7 @@ func (h *HostStatsCollector) Collect() error { // Getting the disk stats for the allocation directory usage, err := disk.Usage(h.allocDir) if err != nil { - return err + return fmt.Errorf("failed to find disk usage of alloc_dir %q: %v", h.allocDir, err) } hs.AllocDirStats = h.toDiskStats(usage, nil) @@ -185,6 +191,13 @@ func (h *HostStatsCollector) collectDiskStats() ([]*DiskStats, error) { func (h *HostStatsCollector) Stats() *HostStats { h.hostStatsLock.RLock() defer h.hostStatsLock.RUnlock() + + if h.hostStats == nil { + if err := h.collectLocked(); err != nil { + h.logger.Printf("[WARN] client: error fetching host resource usage stats: %v", err) + } + } + return h.hostStats } diff --git a/client/structs/structs.go b/client/structs/structs.go index 97887232de0..c038d1ff084 100644 --- a/client/structs/structs.go +++ b/client/structs/structs.go @@ -1,14 +1,170 @@ package structs +//go:generate codecgen -d 102 -o structs.generated.go structs.go + import ( "crypto/md5" "io" "strconv" + "time" "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/client/stats" "github.com/hashicorp/nomad/nomad/structs" ) +// RpcError is used for serializing errors with a potential error code +type RpcError struct { + Message string + Code *int64 +} + +func NewRpcError(err error, code *int64) *RpcError { + return &RpcError{ + Message: err.Error(), + Code: code, + } +} + +func (r *RpcError) Error() string { + return r.Message +} + +// ClientStatsResponse is used to return statistics about a node. +type ClientStatsResponse struct { + HostStats *stats.HostStats + structs.QueryMeta +} + +// AllocFileInfo holds information about a file inside the AllocDir +type AllocFileInfo struct { + Name string + IsDir bool + Size int64 + FileMode string + ModTime time.Time +} + +// FsListRequest is used to list an allocation's directory. +type FsListRequest struct { + // AllocID is the allocation to list from + AllocID string + + // Path is the path to list + Path string + + structs.QueryOptions +} + +// FsListResponse is used to return the listings of an allocation's directory. +type FsListResponse struct { + // Files are the result of listing a directory. 
+	Files []*AllocFileInfo
+
+	structs.QueryMeta
+}
+
+// FsStatRequest is used to stat a file
+type FsStatRequest struct {
+	// AllocID is the allocation to stat the file in
+	AllocID string
+
+	// Path is the path to stat
+	Path string
+
+	structs.QueryOptions
+}
+
+// FsStatResponse is used to return the stat results of a file
+type FsStatResponse struct {
+	// Info is the result of stating a file
+	Info *AllocFileInfo
+
+	structs.QueryMeta
+}
+
+// FsStreamRequest is the initial request for streaming the content of a file.
+type FsStreamRequest struct {
+	// AllocID is the allocation to stream the file from
+	AllocID string
+
+	// Path is the path to the file to stream
+	Path string
+
+	// Offset is the offset to start streaming data at.
+	Offset int64
+
+	// Origin can either be "start" or "end" and determines where the offset is
+	// applied.
+	Origin string
+
+	// PlainText disables base64 encoding.
+	PlainText bool
+
+	// Limit is the number of bytes to read
+	Limit int64
+
+	// Follow follows the file.
+	Follow bool
+
+	structs.QueryOptions
+}
+
+// FsLogsRequest is the initial request for accessing allocation logs.
+type FsLogsRequest struct {
+	// AllocID is the allocation to stream logs from
+	AllocID string
+
+	// Task is the task to stream logs from
+	Task string
+
+	// LogType indicates whether "stderr" or "stdout" should be streamed
+	LogType string
+
+	// Offset is the offset to start streaming data at.
+	Offset int64
+
+	// Origin can either be "start" or "end" and determines where the offset is
+	// applied.
+	Origin string
+
+	// PlainText disables base64 encoding.
+	PlainText bool
+
+	// Follow follows logs.
+	Follow bool
+
+	structs.QueryOptions
+}
+
+// StreamErrWrapper is used to serialize output of a stream of a file or logs.
+type StreamErrWrapper struct {
+	// Error stores any error that may have occurred.
+	Error *RpcError
+
+	// Payload is the stream payload
+	Payload []byte
+}
+
+// AllocStatsRequest is used to request the resource usage of a given
+// allocation, potentially filtering by task
+type AllocStatsRequest struct {
+	// AllocID is the allocation to retrieve stats for
+	AllocID string
+
+	// Task is an optional filter to only request stats for the task.
+	Task string
+
+	structs.QueryOptions
+}
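These types line up with `handleStreamingConn` in client/rpc.go earlier in this diff: header in, ack out, then the handler owns the connection. A hypothetical consumer of a logs stream might look like the sketch below; the `FileSystem.Logs` method name, the post-ack request encoding, and the frame loop are assumptions about handler behavior, not an API this diff defines.

```go
package main

import (
	"errors"
	"net"
	"os"

	"github.com/hashicorp/nomad/client/structs" // the cstructs types above
	nstructs "github.com/hashicorp/nomad/nomad/structs"
	"github.com/ugorji/go/codec"
)

// followLogs drives the streaming handshake from the caller side: header in,
// ack out, then StreamErrWrapper frames until the stream closes.
func followLogs(conn net.Conn, req *structs.FsLogsRequest) error {
	enc := codec.NewEncoder(conn, nstructs.MsgpackHandle)
	dec := codec.NewDecoder(conn, nstructs.MsgpackHandle)

	// Announce the streaming method we want (see StreamingRpcHeader).
	header := nstructs.StreamingRpcHeader{Method: "FileSystem.Logs"}
	if err := enc.Encode(header); err != nil {
		return err
	}

	// The ack carries any handler-registration error back to us.
	var ack nstructs.StreamingRpcAck
	if err := dec.Decode(&ack); err != nil {
		return err
	}
	if ack.Error != "" {
		return errors.New(ack.Error)
	}

	// Hand the handler its arguments, then drain frames; Error and
	// Payload are the two halves of StreamErrWrapper defined above.
	if err := enc.Encode(req); err != nil {
		return err
	}
	for {
		var frame structs.StreamErrWrapper
		if err := dec.Decode(&frame); err != nil {
			return err // io.EOF once the stream ends
		}
		if frame.Error != nil {
			return errors.New(frame.Error.Message)
		}
		os.Stdout.Write(frame.Payload)
	}
}
```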
+// AllocStatsResponse is used to return the resource usage of a given
+// allocation.
+type AllocStatsResponse struct {
+	Stats *AllocResourceUsage
+	structs.QueryMeta
+}
+
 // MemoryStats holds memory usage related stats
 type MemoryStats struct {
 	RSS uint64
diff --git a/client/task_runner_test.go b/client/task_runner_test.go
index d91f5a96bbc..aa30f64b589 100644
--- a/client/task_runner_test.go
+++ b/client/task_runner_test.go
@@ -642,7 +642,7 @@ func TestTaskRunner_UnregisterConsul_Retries(t *testing.T) {
 	ctx := testTaskRunnerFromAlloc(t, true, alloc)
 
 	// Use mockConsulServiceClient
-	consul := newMockConsulServiceClient()
+	consul := newMockConsulServiceClient(t)
 	ctx.tr.consul = consul
 
 	ctx.tr.MarkReceived()
diff --git a/client/testing.go b/client/testing.go
new file mode 100644
index 00000000000..a86728365ab
--- /dev/null
+++ b/client/testing.go
@@ -0,0 +1,43 @@
+package client
+
+import (
+	"github.com/hashicorp/nomad/client/config"
+	"github.com/hashicorp/nomad/client/fingerprint"
+	"github.com/hashicorp/nomad/command/agent/consul"
+	"github.com/hashicorp/nomad/helper"
+	"github.com/hashicorp/nomad/helper/testlog"
+	"github.com/hashicorp/nomad/nomad/structs"
+	"github.com/mitchellh/go-testing-interface"
+)
+
+// TestClient creates an in-memory client for testing purposes.
+func TestClient(t testing.T, cb func(c *config.Config)) *Client {
+	conf := config.DefaultConfig()
+	conf.VaultConfig.Enabled = helper.BoolToPtr(false)
+	conf.DevMode = true
+	conf.Node = &structs.Node{
+		Reserved: &structs.Resources{
+			DiskMB: 0,
+		},
+	}
+
+	// Tighten the fingerprinter timeouts
+	if conf.Options == nil {
+		conf.Options = make(map[string]string)
+	}
+	conf.Options[fingerprint.TightenNetworkTimeoutsConfig] = "true"
+
+	if cb != nil {
+		cb(conf)
+	}
+
+	logger := testlog.Logger(t)
+	catalog := consul.NewMockCatalog(logger)
+	mockService := newMockConsulServiceClient(t)
+	mockService.logger = logger
+	client, err := NewClient(conf, catalog, mockService, logger)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	return client
+}
diff --git a/command/agent/agent_endpoint_test.go b/command/agent/agent_endpoint_test.go
index 46748bb93aa..211cd74cf6f 100644
--- a/command/agent/agent_endpoint_test.go
+++ b/command/agent/agent_endpoint_test.go
@@ -4,14 +4,20 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"net"
 	"net/http"
 	"net/http/httptest"
+	"net/url"
 	"testing"
+	"time"
 
+	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
 	"github.com/hashicorp/nomad/acl"
+	"github.com/hashicorp/nomad/helper/pool"
 	"github.com/hashicorp/nomad/nomad/mock"
 	"github.com/hashicorp/nomad/nomad/structs"
-	"github.com/stretchr/testify/assert"
+	"github.com/hashicorp/nomad/testutil"
+	"github.com/stretchr/testify/require"
 )
 
 func TestHTTP_AgentSelf(t *testing.T) {
@@ -44,7 +50,7 @@ func TestHTTP_AgentSelf(t *testing.T) {
 		t.Fatalf("bad: %#v", self)
 	}
 
-	// Assign a Vault token and assert it is redacted.
+	// Assign a Vault token and require it is redacted.
s.Config.Vault.Token = "badc0deb-adc0-deba-dc0d-ebadc0debadc" respW = httptest.NewRecorder() obj, err = s.Server.AgentSelfRequest(respW, req) @@ -60,21 +66,21 @@ func TestHTTP_AgentSelf(t *testing.T) { func TestHTTP_AgentSelf_ACL(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) httpACLTest(t, nil, func(s *TestAgent) { state := s.Agent.server.State() // Make the HTTP request req, err := http.NewRequest("GET", "/v1/agent/self", nil) - assert.Nil(err) + require.Nil(err) // Try request without a token and expect failure { respW := httptest.NewRecorder() _, err := s.Server.AgentSelfRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with an invalid token and expect failure @@ -83,8 +89,8 @@ func TestHTTP_AgentSelf_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyWrite)) setToken(req, token) _, err := s.Server.AgentSelfRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with a valid token @@ -93,11 +99,11 @@ func TestHTTP_AgentSelf_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyWrite)) setToken(req, token) obj, err := s.Server.AgentSelfRequest(respW, req) - assert.Nil(err) + require.Nil(err) self := obj.(agentSelf) - assert.NotNil(self.Config) - assert.NotNil(self.Stats) + require.NotNil(self.Config) + require.NotNil(self.Stats) } // Try request with a root token @@ -105,18 +111,17 @@ func TestHTTP_AgentSelf_ACL(t *testing.T) { respW := httptest.NewRecorder() setToken(req, s.RootToken) obj, err := s.Server.AgentSelfRequest(respW, req) - assert.Nil(err) + require.Nil(err) self := obj.(agentSelf) - assert.NotNil(self.Config) - assert.NotNil(self.Stats) + require.NotNil(self.Config) + require.NotNil(self.Stats) } }) } func TestHTTP_AgentJoin(t *testing.T) { - // TODO(alexdadgar) - // t.Parallel() + t.Parallel() httpTest(t, nil, func(s *TestAgent) { // Determine the join address member := s.Agent.Server().LocalMember() @@ -173,21 +178,21 @@ func TestHTTP_AgentMembers(t *testing.T) { func TestHTTP_AgentMembers_ACL(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) httpACLTest(t, nil, func(s *TestAgent) { state := s.Agent.server.State() // Make the HTTP request req, err := http.NewRequest("GET", "/v1/agent/members", nil) - assert.Nil(err) + require.Nil(err) // Try request without a token and expect failure { respW := httptest.NewRecorder() _, err := s.Server.AgentMembersRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with an invalid token and expect failure @@ -196,8 +201,8 @@ func TestHTTP_AgentMembers_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.AgentPolicy(acl.PolicyWrite)) setToken(req, token) _, err := s.Server.AgentMembersRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with a valid token @@ -206,10 +211,10 @@ func TestHTTP_AgentMembers_ACL(t 
*testing.T) { token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.NodePolicy(acl.PolicyRead)) setToken(req, token) obj, err := s.Server.AgentMembersRequest(respW, req) - assert.Nil(err) + require.Nil(err) members := obj.(structs.ServerMembersResponse) - assert.Len(members.Members, 1) + require.Len(members.Members, 1) } // Try request with a root token @@ -217,10 +222,10 @@ func TestHTTP_AgentMembers_ACL(t *testing.T) { respW := httptest.NewRecorder() setToken(req, s.RootToken) obj, err := s.Server.AgentMembersRequest(respW, req) - assert.Nil(err) + require.Nil(err) members := obj.(structs.ServerMembersResponse) - assert.Len(members.Members, 1) + require.Len(members.Members, 1) } }) } @@ -245,21 +250,21 @@ func TestHTTP_AgentForceLeave(t *testing.T) { func TestHTTP_AgentForceLeave_ACL(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) httpACLTest(t, nil, func(s *TestAgent) { state := s.Agent.server.State() // Make the HTTP request req, err := http.NewRequest("PUT", "/v1/agent/force-leave?node=foo", nil) - assert.Nil(err) + require.Nil(err) // Try request without a token and expect failure { respW := httptest.NewRecorder() _, err := s.Server.AgentForceLeaveRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with an invalid token and expect failure @@ -268,8 +273,8 @@ func TestHTTP_AgentForceLeave_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyRead)) setToken(req, token) _, err := s.Server.AgentForceLeaveRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with a valid token @@ -278,8 +283,8 @@ func TestHTTP_AgentForceLeave_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyWrite)) setToken(req, token) _, err := s.Server.AgentForceLeaveRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) } // Try request with a root token @@ -287,71 +292,113 @@ func TestHTTP_AgentForceLeave_ACL(t *testing.T) { respW := httptest.NewRecorder() setToken(req, s.RootToken) _, err := s.Server.AgentForceLeaveRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) } }) } func TestHTTP_AgentSetServers(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) httpTest(t, nil, func(s *TestAgent) { + addr := s.Config.AdvertiseAddrs.RPC + testutil.WaitForResult(func() (bool, error) { + conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err != nil { + return false, err + } + defer conn.Close() + + // Write the Consul RPC byte to set the mode + if _, err := conn.Write([]byte{byte(pool.RpcNomad)}); err != nil { + return false, err + } + + codec := pool.NewClientCodec(conn) + args := &structs.GenericRequest{} + var leader string + err = msgpackrpc.CallWithCodec(codec, "Status.Leader", args, &leader) + return leader != "", err + }, func(err error) { + t.Fatalf("failed to find leader: %v", err) + }) + // Create the request req, err := http.NewRequest("PUT", "/v1/agent/servers", nil) - assert.Nil(err) + require.Nil(err) // Send the request respW 
:= httptest.NewRecorder() _, err = s.Server.AgentServersRequest(respW, req) - assert.NotNil(err) - assert.Contains(err.Error(), "missing server address") + require.NotNil(err) + require.Contains(err.Error(), "missing server address") // Create a valid request req, err = http.NewRequest("PUT", "/v1/agent/servers?address=127.0.0.1%3A4647&address=127.0.0.2%3A4647&address=127.0.0.3%3A4647", nil) - assert.Nil(err) + require.Nil(err) - // Send the request + // Send the request which should fail respW = httptest.NewRecorder() _, err = s.Server.AgentServersRequest(respW, req) - assert.Nil(err) + require.NotNil(err) // Retrieve the servers again req, err = http.NewRequest("GET", "/v1/agent/servers", nil) - assert.Nil(err) + require.Nil(err) respW = httptest.NewRecorder() // Make the request and check the result expected := []string{ - "127.0.0.1:4647", - "127.0.0.2:4647", - "127.0.0.3:4647", + s.GetConfig().AdvertiseAddrs.RPC, } out, err := s.Server.AgentServersRequest(respW, req) - assert.Nil(err) + require.Nil(err) servers := out.([]string) - assert.Len(servers, len(expected)) - assert.Equal(expected, servers) + require.Len(servers, len(expected)) + require.Equal(expected, servers) }) } func TestHTTP_AgentSetServers_ACL(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) httpACLTest(t, nil, func(s *TestAgent) { state := s.Agent.server.State() + addr := s.Config.AdvertiseAddrs.RPC + testutil.WaitForResult(func() (bool, error) { + conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err != nil { + return false, err + } + defer conn.Close() + + // Write the Consul RPC byte to set the mode + if _, err := conn.Write([]byte{byte(pool.RpcNomad)}); err != nil { + return false, err + } + + codec := pool.NewClientCodec(conn) + args := &structs.GenericRequest{} + var leader string + err = msgpackrpc.CallWithCodec(codec, "Status.Leader", args, &leader) + return leader != "", err + }, func(err error) { + t.Fatalf("failed to find leader: %v", err) + }) // Make the HTTP request - req, err := http.NewRequest("PUT", "/v1/agent/servers?address=127.0.0.1%3A4647&address=127.0.0.2%3A4647&address=127.0.0.3%3A4647", nil) - assert.Nil(err) + path := fmt.Sprintf("/v1/agent/servers?address=%s", url.QueryEscape(s.GetConfig().AdvertiseAddrs.RPC)) + req, err := http.NewRequest("PUT", path, nil) + require.Nil(err) // Try request without a token and expect failure { respW := httptest.NewRecorder() _, err := s.Server.AgentServersRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with an invalid token and expect failure @@ -360,8 +407,8 @@ func TestHTTP_AgentSetServers_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyRead)) setToken(req, token) _, err := s.Server.AgentServersRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with a valid token @@ -370,8 +417,8 @@ func TestHTTP_AgentSetServers_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyWrite)) setToken(req, token) _, err := s.Server.AgentServersRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) } // Try 
request with a root token @@ -379,47 +426,33 @@ func TestHTTP_AgentSetServers_ACL(t *testing.T) { respW := httptest.NewRecorder() setToken(req, s.RootToken) _, err := s.Server.AgentServersRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) } }) } func TestHTTP_AgentListServers_ACL(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) httpACLTest(t, nil, func(s *TestAgent) { state := s.Agent.server.State() - // Set some servers - { - req, err := http.NewRequest("PUT", "/v1/agent/servers?address=127.0.0.1%3A4647&address=127.0.0.2%3A4647&address=127.0.0.3%3A4647", nil) - assert.Nil(err) - - respW := httptest.NewRecorder() - setToken(req, s.RootToken) - _, err = s.Server.AgentServersRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) - } - // Create list request req, err := http.NewRequest("GET", "/v1/agent/servers", nil) - assert.Nil(err) + require.Nil(err) expected := []string{ - "127.0.0.1:4647", - "127.0.0.2:4647", - "127.0.0.3:4647", + s.GetConfig().AdvertiseAddrs.RPC, } // Try request without a token and expect failure { respW := httptest.NewRecorder() _, err := s.Server.AgentServersRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with an invalid token and expect failure @@ -428,20 +461,27 @@ func TestHTTP_AgentListServers_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyRead)) setToken(req, token) _, err := s.Server.AgentServersRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } + // Wait for client to have a server + testutil.WaitForResult(func() (bool, error) { + return len(s.client.GetServers()) != 0, fmt.Errorf("no servers") + }, func(err error) { + t.Fatal(err) + }) + // Try request with a valid token { respW := httptest.NewRecorder() token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyRead)) setToken(req, token) out, err := s.Server.AgentServersRequest(respW, req) - assert.Nil(err) + require.Nil(err) servers := out.([]string) - assert.Len(servers, len(expected)) - assert.Equal(expected, servers) + require.Len(servers, len(expected)) + require.Equal(expected, servers) } // Try request with a root token @@ -449,10 +489,10 @@ func TestHTTP_AgentListServers_ACL(t *testing.T) { respW := httptest.NewRecorder() setToken(req, s.RootToken) out, err := s.Server.AgentServersRequest(respW, req) - assert.Nil(err) + require.Nil(err) servers := out.([]string) - assert.Len(servers, len(expected)) - assert.Equal(expected, servers) + require.Len(servers, len(expected)) + require.Equal(expected, servers) } }) } @@ -472,19 +512,15 @@ func TestHTTP_AgentListKeys(t *testing.T) { respW := httptest.NewRecorder() out, err := s.Server.KeyringOperationRequest(respW, req) - if err != nil { - t.Fatalf("err: %s", err) - } + require.Nil(t, err) kresp := out.(structs.KeyringResponse) - if len(kresp.Keys) != 1 { - t.Fatalf("bad: %v", kresp) - } + require.Len(t, kresp.Keys, 1) }) } func TestHTTP_AgentListKeys_ACL(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) key1 := "HS5lJ+XuTlYKWaeGYyG+/A==" @@ -497,14 +533,14 @@ func 
TestHTTP_AgentListKeys_ACL(t *testing.T) { // Make the HTTP request req, err := http.NewRequest("GET", "/v1/agent/keyring/list", nil) - assert.Nil(err) + require.Nil(err) // Try request without a token and expect failure { respW := httptest.NewRecorder() _, err := s.Server.KeyringOperationRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with an invalid token and expect failure @@ -513,8 +549,8 @@ func TestHTTP_AgentListKeys_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.AgentPolicy(acl.PolicyRead)) setToken(req, token) _, err := s.Server.KeyringOperationRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with a valid token @@ -523,10 +559,10 @@ func TestHTTP_AgentListKeys_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyWrite)) setToken(req, token) out, err := s.Server.KeyringOperationRequest(respW, req) - assert.Nil(err) + require.Nil(err) kresp := out.(structs.KeyringResponse) - assert.Len(kresp.Keys, 1) - assert.Contains(kresp.Keys, key1) + require.Len(kresp.Keys, 1) + require.Contains(kresp.Keys, key1) } // Try request with a root token @@ -534,17 +570,16 @@ func TestHTTP_AgentListKeys_ACL(t *testing.T) { respW := httptest.NewRecorder() setToken(req, s.RootToken) out, err := s.Server.KeyringOperationRequest(respW, req) - assert.Nil(err) + require.Nil(err) kresp := out.(structs.KeyringResponse) - assert.Len(kresp.Keys, 1) - assert.Contains(kresp.Keys, key1) + require.Len(kresp.Keys, 1) + require.Contains(kresp.Keys, key1) } }) } func TestHTTP_AgentInstallKey(t *testing.T) { - // TODO(alexdadgar) - // t.Parallel() + t.Parallel() key1 := "HS5lJ+XuTlYKWaeGYyG+/A==" key2 := "wH1Bn9hlJ0emgWB1JttVRA==" @@ -584,8 +619,7 @@ func TestHTTP_AgentInstallKey(t *testing.T) { } func TestHTTP_AgentRemoveKey(t *testing.T) { - // TODO(alexdadgar) - // t.Parallel() + t.Parallel() key1 := "HS5lJ+XuTlYKWaeGYyG+/A==" key2 := "wH1Bn9hlJ0emgWB1JttVRA==" @@ -635,87 +669,87 @@ func TestHTTP_AgentRemoveKey(t *testing.T) { func TestHTTP_AgentHealth_Ok(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) // Enable ACLs to ensure they're not enforced httpACLTest(t, nil, func(s *TestAgent) { // No ?type= { req, err := http.NewRequest("GET", "/v1/agent/health", nil) - assert.Nil(err) + require.Nil(err) respW := httptest.NewRecorder() healthI, err := s.Server.HealthRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) - assert.NotNil(healthI) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) + require.NotNil(healthI) health := healthI.(*healthResponse) - assert.NotNil(health.Client) - assert.True(health.Client.Ok) - assert.Equal("ok", health.Client.Message) - assert.NotNil(health.Server) - assert.True(health.Server.Ok) - assert.Equal("ok", health.Server.Message) + require.NotNil(health.Client) + require.True(health.Client.Ok) + require.Equal("ok", health.Client.Message) + require.NotNil(health.Server) + require.True(health.Server.Ok) + require.Equal("ok", health.Server.Message) } // type=client { req, err := http.NewRequest("GET", "/v1/agent/health?type=client", nil) - assert.Nil(err) + require.Nil(err) respW := httptest.NewRecorder() 
healthI, err := s.Server.HealthRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) - assert.NotNil(healthI) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) + require.NotNil(healthI) health := healthI.(*healthResponse) - assert.NotNil(health.Client) - assert.True(health.Client.Ok) - assert.Equal("ok", health.Client.Message) - assert.Nil(health.Server) + require.NotNil(health.Client) + require.True(health.Client.Ok) + require.Equal("ok", health.Client.Message) + require.Nil(health.Server) } // type=server { req, err := http.NewRequest("GET", "/v1/agent/health?type=server", nil) - assert.Nil(err) + require.Nil(err) respW := httptest.NewRecorder() healthI, err := s.Server.HealthRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) - assert.NotNil(healthI) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) + require.NotNil(healthI) health := healthI.(*healthResponse) - assert.NotNil(health.Server) - assert.True(health.Server.Ok) - assert.Equal("ok", health.Server.Message) - assert.Nil(health.Client) + require.NotNil(health.Server) + require.True(health.Server.Ok) + require.Equal("ok", health.Server.Message) + require.Nil(health.Client) } // type=client&type=server { req, err := http.NewRequest("GET", "/v1/agent/health?type=client&type=server", nil) - assert.Nil(err) + require.Nil(err) respW := httptest.NewRecorder() healthI, err := s.Server.HealthRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) - assert.NotNil(healthI) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) + require.NotNil(healthI) health := healthI.(*healthResponse) - assert.NotNil(health.Client) - assert.True(health.Client.Ok) - assert.Equal("ok", health.Client.Message) - assert.NotNil(health.Server) - assert.True(health.Server.Ok) - assert.Equal("ok", health.Server.Message) + require.NotNil(health.Client) + require.True(health.Client.Ok) + require.Equal("ok", health.Client.Message) + require.NotNil(health.Server) + require.True(health.Server.Ok) + require.Equal("ok", health.Server.Message) } }) } func TestHTTP_AgentHealth_BadServer(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) // Enable ACLs to ensure they're not enforced httpACLTest(t, nil, func(s *TestAgent) { @@ -726,39 +760,39 @@ func TestHTTP_AgentHealth_BadServer(t *testing.T) { // No ?type= means server is just skipped { req, err := http.NewRequest("GET", "/v1/agent/health", nil) - assert.Nil(err) + require.Nil(err) respW := httptest.NewRecorder() healthI, err := s.Server.HealthRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) - assert.NotNil(healthI) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) + require.NotNil(healthI) health := healthI.(*healthResponse) - assert.NotNil(health.Client) - assert.True(health.Client.Ok) - assert.Equal("ok", health.Client.Message) - assert.Nil(health.Server) + require.NotNil(health.Client) + require.True(health.Client.Ok) + require.Equal("ok", health.Client.Message) + require.Nil(health.Server) } // type=server means server is considered unhealthy { req, err := http.NewRequest("GET", "/v1/agent/health?type=server", nil) - assert.Nil(err) + require.Nil(err) respW := httptest.NewRecorder() _, err = s.Server.HealthRequest(respW, req) - assert.NotNil(err) + require.NotNil(err) httpErr, ok := err.(HTTPCodedError) - assert.True(ok) - assert.Equal(500, httpErr.Code()) - assert.Equal(`{"server":{"ok":false,"message":"server not enabled"}}`, 
err.Error()) + require.True(ok) + require.Equal(500, httpErr.Code()) + require.Equal(`{"server":{"ok":false,"message":"server not enabled"}}`, err.Error()) } }) } func TestHTTP_AgentHealth_BadClient(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) // Enable ACLs to ensure they're not enforced httpACLTest(t, nil, func(s *TestAgent) { @@ -769,32 +803,32 @@ func TestHTTP_AgentHealth_BadClient(t *testing.T) { // No ?type= means client is just skipped { req, err := http.NewRequest("GET", "/v1/agent/health", nil) - assert.Nil(err) + require.Nil(err) respW := httptest.NewRecorder() healthI, err := s.Server.HealthRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) - assert.NotNil(healthI) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) + require.NotNil(healthI) health := healthI.(*healthResponse) - assert.NotNil(health.Server) - assert.True(health.Server.Ok) - assert.Equal("ok", health.Server.Message) - assert.Nil(health.Client) + require.NotNil(health.Server) + require.True(health.Server.Ok) + require.Equal("ok", health.Server.Message) + require.Nil(health.Client) } // type=client means client is considered unhealthy { req, err := http.NewRequest("GET", "/v1/agent/health?type=client", nil) - assert.Nil(err) + require.Nil(err) respW := httptest.NewRecorder() _, err = s.Server.HealthRequest(respW, req) - assert.NotNil(err) + require.NotNil(err) httpErr, ok := err.(HTTPCodedError) - assert.True(ok) - assert.Equal(500, httpErr.Code()) - assert.Equal(`{"client":{"ok":false,"message":"client not enabled"}}`, err.Error()) + require.True(ok) + require.Equal(500, httpErr.Code()) + require.Equal(`{"client":{"ok":false,"message":"client not enabled"}}`, err.Error()) } }) } diff --git a/command/agent/alloc_endpoint.go b/command/agent/alloc_endpoint.go index 03d8bf8c127..422f8906cb0 100644 --- a/command/agent/alloc_endpoint.go +++ b/command/agent/alloc_endpoint.go @@ -6,7 +6,7 @@ import ( "strings" "github.com/golang/snappy" - "github.com/hashicorp/nomad/acl" + cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/structs" ) @@ -79,9 +79,6 @@ func (s *HTTPServer) AllocSpecificRequest(resp http.ResponseWriter, req *http.Re } func (s *HTTPServer) ClientAllocRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if s.agent.client == nil { - return nil, clientNotRunning - } reqSuffix := strings.TrimPrefix(req.URL.Path, "/v1/client/allocation/") @@ -96,6 +93,10 @@ func (s *HTTPServer) ClientAllocRequest(resp http.ResponseWriter, req *http.Requ case "stats": return s.allocStats(allocID, resp, req) case "snapshot": + if s.agent.client == nil { + return nil, clientNotRunning + } + return s.allocSnapshot(allocID, resp, req) case "gc": return s.allocGC(allocID, resp, req) @@ -105,43 +106,70 @@ func (s *HTTPServer) ClientAllocRequest(resp http.ResponseWriter, req *http.Requ } func (s *HTTPServer) ClientGCRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if s.agent.client == nil { - return nil, clientNotRunning - } - - var secret string - s.parseToken(req, &secret) - - // Check node write permissions - if aclObj, err := s.agent.Client().ResolveToken(secret); err != nil { - return nil, err - } else if aclObj != nil && !aclObj.AllowNodeWrite() { - return nil, structs.ErrPermissionDenied + // Get the requested Node ID + requestedNode := req.URL.Query().Get("node_id") + + // Build the request and parse the ACL token + args := 
structs.NodeSpecificRequest{ + NodeID: requestedNode, + } + s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) + + // Determine the handler to use + useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForNode(requestedNode) + + // Make the RPC + var reply structs.GenericResponse + var rpcErr error + if useLocalClient { + rpcErr = s.agent.Client().ClientRPC("Allocations.GarbageCollectAll", &args, &reply) + } else if useClientRPC { + rpcErr = s.agent.Client().RPC("ClientAllocations.GarbageCollectAll", &args, &reply) + } else if useServerRPC { + rpcErr = s.agent.Server().RPC("ClientAllocations.GarbageCollectAll", &args, &reply) + } else { + rpcErr = CodedError(400, "No local Node and node_id not provided") + } + + if rpcErr != nil { + if structs.IsErrNoNodeConn(rpcErr) { + rpcErr = CodedError(404, rpcErr.Error()) + } } - s.agent.Client().CollectAllAllocs() - return nil, nil + return nil, rpcErr } func (s *HTTPServer) allocGC(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) { - var secret string - s.parseToken(req, &secret) + // Build the request and parse the ACL token + args := structs.AllocSpecificRequest{ + AllocID: allocID, + } + s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) - var namespace string - parseNamespace(req, &namespace) + // Determine the handler to use + useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForAlloc(allocID) - // Check namespace submit-job permissions - if aclObj, err := s.agent.Client().ResolveToken(secret); err != nil { - return nil, err - } else if aclObj != nil && !aclObj.AllowNsOp(namespace, acl.NamespaceCapabilitySubmitJob) { - return nil, structs.ErrPermissionDenied + // Make the RPC + var reply structs.GenericResponse + var rpcErr error + if useLocalClient { + rpcErr = s.agent.Client().ClientRPC("Allocations.GarbageCollect", &args, &reply) + } else if useClientRPC { + rpcErr = s.agent.Client().RPC("ClientAllocations.GarbageCollect", &args, &reply) + } else if useServerRPC { + rpcErr = s.agent.Server().RPC("ClientAllocations.GarbageCollect", &args, &reply) + } else { + rpcErr = CodedError(400, "No local Node and node_id not provided") } - if !s.agent.Client().CollectAllocation(allocID) { - // Could not find alloc - return nil, fmt.Errorf("unable to collect allocation: not present") + if rpcErr != nil { + if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) { + rpcErr = CodedError(404, rpcErr.Error()) + } } - return nil, nil + + return nil, rpcErr } func (s *HTTPServer) allocSnapshot(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) { @@ -162,25 +190,36 @@ func (s *HTTPServer) allocSnapshot(allocID string, resp http.ResponseWriter, req } func (s *HTTPServer) allocStats(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) { - var secret string - s.parseToken(req, &secret) - var namespace string - parseNamespace(req, &namespace) + // Build the request and parse the ACL token + task := req.URL.Query().Get("task") + args := cstructs.AllocStatsRequest{ + AllocID: allocID, + Task: task, + } + s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) - // Check namespace read-job permissions - if aclObj, err := s.agent.Client().ResolveToken(secret); err != nil { - return nil, err - } else if aclObj != nil && !aclObj.AllowNsOp(namespace, acl.NamespaceCapabilityReadJob) { - return nil, structs.ErrPermissionDenied + // Determine the handler to use + useLocalClient, useClientRPC, useServerRPC := 
s.rpcHandlerForAlloc(allocID) + + // Make the RPC + var reply cstructs.AllocStatsResponse + var rpcErr error + if useLocalClient { + rpcErr = s.agent.Client().ClientRPC("Allocations.Stats", &args, &reply) + } else if useClientRPC { + rpcErr = s.agent.Client().RPC("ClientAllocations.Stats", &args, &reply) + } else if useServerRPC { + rpcErr = s.agent.Server().RPC("ClientAllocations.Stats", &args, &reply) + } else { + rpcErr = CodedError(400, "No local Node and node_id not provided") } - clientStats := s.agent.client.StatsReporter() - aStats, err := clientStats.GetAllocStats(allocID) - if err != nil { - return nil, err + if rpcErr != nil { + if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) { + rpcErr = CodedError(404, rpcErr.Error()) + } } - task := req.URL.Query().Get("task") - return aStats.LatestAllocStats(task) + return reply.Stats, rpcErr } diff --git a/command/agent/alloc_endpoint_test.go b/command/agent/alloc_endpoint_test.go index 957c354d8f0..e005753ba51 100644 --- a/command/agent/alloc_endpoint_test.go +++ b/command/agent/alloc_endpoint_test.go @@ -15,11 +15,11 @@ import ( "github.com/golang/snappy" "github.com/hashicorp/nomad/acl" "github.com/hashicorp/nomad/client/allocdir" - "github.com/hashicorp/nomad/nomad" + "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestHTTP_AllocsList(t *testing.T) { @@ -78,9 +78,9 @@ func TestHTTP_AllocsList(t *testing.T) { } expectedMsg := "Task's sibling failed" displayMsg1 := allocs[0].TaskStates["test"].Events[0].DisplayMessage - assert.Equal(t, expectedMsg, displayMsg1, "DisplayMessage should be set") + require.Equal(t, expectedMsg, displayMsg1, "DisplayMessage should be set") displayMsg2 := allocs[0].TaskStates["test"].Events[0].DisplayMessage - assert.Equal(t, expectedMsg, displayMsg2, "DisplayMessage should be set") + require.Equal(t, expectedMsg, displayMsg2, "DisplayMessage should be set") }) } @@ -151,7 +151,7 @@ func TestHTTP_AllocsPrefixList(t *testing.T) { } expectedMsg := "Task's sibling failed" displayMsg1 := n[0].TaskStates["test"].Events[0].DisplayMessage - assert.Equal(t, expectedMsg, displayMsg1, "DisplayMessage should be set") + require.Equal(t, expectedMsg, displayMsg1, "DisplayMessage should be set") }) } @@ -262,31 +262,77 @@ func TestHTTP_AllocQuery_Payload(t *testing.T) { func TestHTTP_AllocStats(t *testing.T) { t.Parallel() + require := require.New(t) + httpTest(t, nil, func(s *TestAgent) { - // Make the HTTP request - req, err := http.NewRequest("GET", "/v1/client/allocation/123/foo", nil) - if err != nil { - t.Fatalf("err: %v", err) + // Local node, local resp + { + // Make the HTTP request + req, err := http.NewRequest("GET", fmt.Sprintf("/v1/client/allocation/%s/stats", uuid.Generate()), nil) + if err != nil { + t.Fatalf("err: %v", err) + } + respW := httptest.NewRecorder() + + // Make the request + _, err = s.Server.ClientAllocRequest(respW, req) + require.NotNil(err) + require.True(structs.IsErrUnknownAllocation(err)) } - respW := httptest.NewRecorder() - // Make the request - _, err = s.Server.ClientAllocRequest(respW, req) - if !strings.Contains(err.Error(), resourceNotFoundErr) { - t.Fatalf("err: %v", err) + // Local node, server resp + { + srv := s.server + s.server = nil + + req, err := http.NewRequest("GET", 
fmt.Sprintf("/v1/client/allocation/%s/stats", uuid.Generate()), nil) + require.Nil(err) + + respW := httptest.NewRecorder() + _, err = s.Server.ClientAllocRequest(respW, req) + require.NotNil(err) + require.True(structs.IsErrUnknownAllocation(err)) + + s.server = srv + } + + // no client, server resp + { + c := s.client + s.client = nil + + testutil.WaitForResult(func() (bool, error) { + n, err := s.server.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + return n != nil, nil + }, func(err error) { + t.Fatalf("should have client: %v", err) + }) + + req, err := http.NewRequest("GET", fmt.Sprintf("/v1/client/allocation/%s/stats", uuid.Generate()), nil) + require.Nil(err) + + respW := httptest.NewRecorder() + _, err = s.Server.ClientAllocRequest(respW, req) + require.NotNil(err) + require.True(structs.IsErrUnknownAllocation(err)) + + s.client = c } }) } func TestHTTP_AllocStats_ACL(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) httpACLTest(t, nil, func(s *TestAgent) { state := s.Agent.server.State() // Make the HTTP request - req, err := http.NewRequest("GET", "/v1/client/allocation/123/stats", nil) + req, err := http.NewRequest("GET", fmt.Sprintf("/v1/client/allocation/%s/stats", uuid.Generate()), nil) if err != nil { t.Fatalf("err: %v", err) } @@ -295,8 +341,8 @@ func TestHTTP_AllocStats_ACL(t *testing.T) { { respW := httptest.NewRecorder() _, err := s.Server.ClientAllocRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with an invalid token and expect failure @@ -305,8 +351,8 @@ func TestHTTP_AllocStats_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyWrite)) setToken(req, token) _, err := s.Server.ClientAllocRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with a valid token @@ -317,8 +363,8 @@ func TestHTTP_AllocStats_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1007, "valid", policy) setToken(req, token) _, err := s.Server.ClientAllocRequest(respW, req) - assert.NotNil(err) - assert.Contains(err.Error(), "unknown allocation ID") + require.NotNil(err) + require.True(structs.IsErrUnknownAllocation(err)) } // Try request with a management token @@ -327,8 +373,8 @@ func TestHTTP_AllocStats_ACL(t *testing.T) { respW := httptest.NewRecorder() setToken(req, s.RootToken) _, err := s.Server.ClientAllocRequest(respW, req) - assert.NotNil(err) - assert.Contains(err.Error(), "unknown allocation ID") + require.NotNil(err) + require.True(structs.IsErrUnknownAllocation(err)) } }) } @@ -353,35 +399,35 @@ func TestHTTP_AllocSnapshot(t *testing.T) { func TestHTTP_AllocSnapshot_WithMigrateToken(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) httpACLTest(t, nil, func(s *TestAgent) { // Request without a token fails req, err := http.NewRequest("GET", "/v1/client/allocation/123/snapshot", nil) - assert.Nil(err) + require.Nil(err) // Make the unauthorized request respW := httptest.NewRecorder() _, err = s.Server.ClientAllocRequest(respW, req) - assert.NotNil(err) - assert.EqualError(err, structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.EqualError(err, structs.ErrPermissionDenied.Error()) // Create 
an allocation alloc := mock.Alloc() - validMigrateToken, err := nomad.GenerateMigrateToken(alloc.ID, s.Agent.Client().Node().SecretID) - assert.Nil(err) + validMigrateToken, err := structs.GenerateMigrateToken(alloc.ID, s.Agent.Client().Node().SecretID) + require.Nil(err) // Request with a token succeeds url := fmt.Sprintf("/v1/client/allocation/%s/snapshot", alloc.ID) req, err = http.NewRequest("GET", url, nil) - assert.Nil(err) + require.Nil(err) req.Header.Set("X-Nomad-Token", validMigrateToken) // Make the unauthorized request respW = httptest.NewRecorder() _, err = s.Server.ClientAllocRequest(respW, req) - assert.NotContains(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotContains(err.Error(), structs.ErrPermissionDenied.Error()) }) } @@ -427,7 +473,7 @@ func TestHTTP_AllocSnapshot_Atomic(t *testing.T) { // Remove the task dir to break Snapshot os.RemoveAll(allocDir.TaskDirs["web"].LocalDir) - // Assert Snapshot fails + // require Snapshot fails if err := allocDir.Snapshot(ioutil.Discard); err != nil { s.logger.Printf("[DEBUG] agent.test: snapshot returned error: %v", err) } else { @@ -493,31 +539,84 @@ func TestHTTP_AllocSnapshot_Atomic(t *testing.T) { func TestHTTP_AllocGC(t *testing.T) { t.Parallel() + require := require.New(t) + path := fmt.Sprintf("/v1/client/allocation/%s/gc", uuid.Generate()) httpTest(t, nil, func(s *TestAgent) { - // Make the HTTP request - req, err := http.NewRequest("GET", "/v1/client/allocation/123/gc", nil) - if err != nil { - t.Fatalf("err: %v", err) + // Local node, local resp + { + req, err := http.NewRequest("GET", path, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + respW := httptest.NewRecorder() + _, err = s.Server.ClientAllocRequest(respW, req) + if !structs.IsErrUnknownAllocation(err) { + t.Fatalf("unexpected err: %v", err) + } } - respW := httptest.NewRecorder() - // Make the request - _, err = s.Server.ClientAllocRequest(respW, req) - if !strings.Contains(err.Error(), "unable to collect allocation") { - t.Fatalf("err: %v", err) + // Local node, server resp + { + srv := s.server + s.server = nil + + req, err := http.NewRequest("GET", path, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + respW := httptest.NewRecorder() + _, err = s.Server.ClientAllocRequest(respW, req) + if !structs.IsErrUnknownAllocation(err) { + t.Fatalf("unexpected err: %v", err) + } + + s.server = srv + } + + // no client, server resp + { + c := s.client + s.client = nil + + testutil.WaitForResult(func() (bool, error) { + n, err := s.server.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + return n != nil, nil + }, func(err error) { + t.Fatalf("should have client: %v", err) + }) + + req, err := http.NewRequest("GET", path, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + respW := httptest.NewRecorder() + _, err = s.Server.ClientAllocRequest(respW, req) + require.NotNil(err) + if !structs.IsErrUnknownAllocation(err) { + t.Fatalf("unexpected err: %v", err) + } + + s.client = c } }) } func TestHTTP_AllocGC_ACL(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) + path := fmt.Sprintf("/v1/client/allocation/%s/gc", uuid.Generate()) httpACLTest(t, nil, func(s *TestAgent) { state := s.Agent.server.State() // Make the HTTP request - req, err := http.NewRequest("GET", "/v1/client/allocation/123/gc", nil) + req, err := http.NewRequest("GET", path, nil) if err != nil { t.Fatalf("err: %v", err) } @@ -526,8 +625,8 @@ func TestHTTP_AllocGC_ACL(t *testing.T) { { respW := 
httptest.NewRecorder() _, err := s.Server.ClientAllocRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with an invalid token and expect failure @@ -536,8 +635,8 @@ func TestHTTP_AllocGC_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyWrite)) setToken(req, token) _, err := s.Server.ClientAllocRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with a valid token @@ -548,8 +647,8 @@ func TestHTTP_AllocGC_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1007, "valid", policy) setToken(req, token) _, err := s.Server.ClientAllocRequest(respW, req) - assert.NotNil(err) - assert.Contains(err.Error(), "not present") + require.NotNil(err) + require.True(structs.IsErrUnknownAllocation(err)) } // Try request with a management token @@ -558,26 +657,72 @@ func TestHTTP_AllocGC_ACL(t *testing.T) { respW := httptest.NewRecorder() setToken(req, s.RootToken) _, err := s.Server.ClientAllocRequest(respW, req) - assert.NotNil(err) - assert.Contains(err.Error(), "not present") + require.NotNil(err) + require.True(structs.IsErrUnknownAllocation(err)) } }) } func TestHTTP_AllocAllGC(t *testing.T) { t.Parallel() + require := require.New(t) httpTest(t, nil, func(s *TestAgent) { - // Make the HTTP request - req, err := http.NewRequest("GET", "/v1/client/gc", nil) - if err != nil { - t.Fatalf("err: %v", err) + // Local node, local resp + { + req, err := http.NewRequest("GET", "/v1/client/gc", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + respW := httptest.NewRecorder() + _, err = s.Server.ClientGCRequest(respW, req) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } } - respW := httptest.NewRecorder() - // Make the request - _, err = s.Server.ClientGCRequest(respW, req) - if err != nil { - t.Fatalf("err: %v", err) + // Local node, server resp + { + srv := s.server + s.server = nil + + req, err := http.NewRequest("GET", fmt.Sprintf("/v1/client/gc?node_id=%s", uuid.Generate()), nil) + require.Nil(err) + + respW := httptest.NewRecorder() + _, err = s.Server.ClientGCRequest(respW, req) + require.NotNil(err) + require.Contains(err.Error(), "Unknown node") + + s.server = srv + } + + // no client, server resp + { + c := s.client + s.client = nil + + testutil.WaitForResult(func() (bool, error) { + n, err := s.server.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + return n != nil, nil + }, func(err error) { + t.Fatalf("should have client: %v", err) + }) + + req, err := http.NewRequest("GET", fmt.Sprintf("/v1/client/gc?node_id=%s", c.NodeID()), nil) + require.Nil(err) + + respW := httptest.NewRecorder() + _, err = s.Server.ClientGCRequest(respW, req) + require.NotNil(err) + + // The dev agent uses in-mem RPC so just assert the no route error + require.Contains(err.Error(), structs.ErrNoNodeConn.Error()) + + s.client = c } }) @@ -585,20 +730,20 @@ func TestHTTP_AllocAllGC(t *testing.T) { func TestHTTP_AllocAllGC_ACL(t *testing.T) { t.Parallel() - assert := assert.New(t) + require := require.New(t) httpACLTest(t, nil, func(s *TestAgent) { state := s.Agent.server.State() // Make the HTTP request req, err := http.NewRequest("GET", "/v1/client/gc", nil) - assert.Nil(err) + require.Nil(err) // 
Try request without a token and expect failure { respW := httptest.NewRecorder() _, err := s.Server.ClientGCRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with an invalid token and expect failure @@ -607,8 +752,8 @@ func TestHTTP_AllocAllGC_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyRead)) setToken(req, token) _, err := s.Server.ClientGCRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + require.NotNil(err) + require.Equal(err.Error(), structs.ErrPermissionDenied.Error()) } // Try request with a valid token @@ -617,8 +762,8 @@ func TestHTTP_AllocAllGC_ACL(t *testing.T) { token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.NodePolicy(acl.PolicyWrite)) setToken(req, token) _, err := s.Server.ClientGCRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) } // Try request with a management token @@ -626,9 +771,8 @@ func TestHTTP_AllocAllGC_ACL(t *testing.T) { respW := httptest.NewRecorder() setToken(req, s.RootToken) _, err := s.Server.ClientGCRequest(respW, req) - assert.Nil(err) - assert.Equal(http.StatusOK, respW.Code) + require.Nil(err) + require.Equal(http.StatusOK, respW.Code) } }) - } diff --git a/command/agent/fs_endpoint.go b/command/agent/fs_endpoint.go index 1c09103665c..c19684c8983 100644 --- a/command/agent/fs_endpoint.go +++ b/command/agent/fs_endpoint.go @@ -1,29 +1,18 @@ package agent -//go:generate codecgen -d 101 -o fs_endpoint.generated.go fs_endpoint.go - import ( "bytes" + "context" "fmt" "io" - "math" + "net" "net/http" - "os" - "path/filepath" - "sort" "strconv" "strings" - "sync" - "syscall" - "time" - - "gopkg.in/tomb.v1" "github.com/docker/docker/pkg/ioutils" - "github.com/hashicorp/nomad/acl" - "github.com/hashicorp/nomad/client/allocdir" + cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/structs" - "github.com/hpcloud/tail/watch" "github.com/ugorji/go/codec" ) @@ -36,88 +25,20 @@ var ( invalidOrigin = fmt.Errorf("origin must be start or end") ) -const ( - // streamFrameSize is the maximum number of bytes to send in a single frame - streamFrameSize = 64 * 1024 - - // streamHeartbeatRate is the rate at which a heartbeat will occur to detect - // a closed connection without sending any additional data - streamHeartbeatRate = 1 * time.Second - - // streamBatchWindow is the window in which file content is batched before - // being flushed if the frame size has not been hit. - streamBatchWindow = 200 * time.Millisecond - - // nextLogCheckRate is the rate at which we check for a log entry greater - // than what we are watching for. This is to handle the case in which logs - // rotate faster than we can detect and we have to rely on a normal - // directory listing. - nextLogCheckRate = 100 * time.Millisecond - - // deleteEvent and truncateEvent are the file events that can be sent in a - // StreamFrame - deleteEvent = "file deleted" - truncateEvent = "file truncated" - - // OriginStart and OriginEnd are the available parameters for the origin - // argument when streaming a file. They respectively offset from the start - // and end of a file. 
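
The origin constants being removed here encode the semantics the comment above describes: an offset is applied from the start or from the end of the file. A minimal sketch of that calculation, with a hypothetical helper name:

package example

import "fmt"

// readPosition returns the absolute file offset implied by an origin of
// "start" or "end" plus a relative offset, mirroring the semantics the
// removed comment describes. Hypothetical helper, for illustration only.
func readPosition(origin string, offset, size int64) (int64, error) {
	switch origin {
	case "start":
		return offset, nil
	case "end":
		// An offset from the end is subtracted from the file size.
		return size - offset, nil
	default:
		return 0, fmt.Errorf("origin must be start or end")
	}
}
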
- OriginStart = "start" - OriginEnd = "end" -) - func (s *HTTPServer) FsRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if s.agent.client == nil { - return nil, clientNotRunning - } - - var secret string - s.parseToken(req, &secret) - - var namespace string - parseNamespace(req, &namespace) - - aclObj, err := s.agent.Client().ResolveToken(secret) - if err != nil { - return nil, err - } - path := strings.TrimPrefix(req.URL.Path, "/v1/client/fs/") switch { case strings.HasPrefix(path, "ls/"): - if aclObj != nil && !aclObj.AllowNsOp(namespace, acl.NamespaceCapabilityReadFS) { - return nil, structs.ErrPermissionDenied - } return s.DirectoryListRequest(resp, req) case strings.HasPrefix(path, "stat/"): - if aclObj != nil && !aclObj.AllowNsOp(namespace, acl.NamespaceCapabilityReadFS) { - return nil, structs.ErrPermissionDenied - } return s.FileStatRequest(resp, req) case strings.HasPrefix(path, "readat/"): - if aclObj != nil && !aclObj.AllowNsOp(namespace, acl.NamespaceCapabilityReadFS) { - return nil, structs.ErrPermissionDenied - } return s.FileReadAtRequest(resp, req) case strings.HasPrefix(path, "cat/"): - if aclObj != nil && !aclObj.AllowNsOp(namespace, acl.NamespaceCapabilityReadFS) { - return nil, structs.ErrPermissionDenied - } return s.FileCatRequest(resp, req) case strings.HasPrefix(path, "stream/"): - if aclObj != nil && !aclObj.AllowNsOp(namespace, acl.NamespaceCapabilityReadFS) { - return nil, structs.ErrPermissionDenied - } return s.Stream(resp, req) case strings.HasPrefix(path, "logs/"): - // Logs can be accessed with ReadFS or ReadLogs caps - if aclObj != nil { - readfs := aclObj.AllowNsOp(namespace, acl.NamespaceCapabilityReadFS) - logs := aclObj.AllowNsOp(namespace, acl.NamespaceCapabilityReadLogs) - if !readfs && !logs { - return nil, structs.ErrPermissionDenied - } - } return s.Logs(resp, req) default: return nil, CodedError(404, ErrInvalidMethod) @@ -133,11 +54,36 @@ func (s *HTTPServer) DirectoryListRequest(resp http.ResponseWriter, req *http.Re if path = req.URL.Query().Get("path"); path == "" { path = "/" } - fs, err := s.agent.client.GetAllocFS(allocID) - if err != nil { - return nil, err + + // Create the request + args := &cstructs.FsListRequest{ + AllocID: allocID, + Path: path, } - return fs.List(path) + s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) + + // Make the RPC + localClient, remoteClient, localServer := s.rpcHandlerForAlloc(allocID) + + var reply cstructs.FsListResponse + var rpcErr error + if localClient { + rpcErr = s.agent.Client().ClientRPC("FileSystem.List", &args, &reply) + } else if remoteClient { + rpcErr = s.agent.Client().RPC("FileSystem.List", &args, &reply) + } else if localServer { + rpcErr = s.agent.Server().RPC("FileSystem.List", &args, &reply) + } + + if rpcErr != nil { + if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) { + rpcErr = CodedError(404, rpcErr.Error()) + } + + return nil, rpcErr + } + + return reply.Files, nil } func (s *HTTPServer) FileStatRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) { @@ -148,11 +94,36 @@ func (s *HTTPServer) FileStatRequest(resp http.ResponseWriter, req *http.Request if path = req.URL.Query().Get("path"); path == "" { return nil, fileNameNotPresentErr } - fs, err := s.agent.client.GetAllocFS(allocID) - if err != nil { - return nil, err + + // Create the request + args := &cstructs.FsStatRequest{ + AllocID: allocID, + Path: path, + } + s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) + + // 
Make the RPC + localClient, remoteClient, localServer := s.rpcHandlerForAlloc(allocID) + + var reply cstructs.FsStatResponse + var rpcErr error + if localClient { + rpcErr = s.agent.Client().ClientRPC("FileSystem.Stat", &args, &reply) + } else if remoteClient { + rpcErr = s.agent.Client().RPC("FileSystem.Stat", &args, &reply) + } else if localServer { + rpcErr = s.agent.Server().RPC("FileSystem.Stat", &args, &reply) + } + + if rpcErr != nil { + if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) { + rpcErr = CodedError(404, rpcErr.Error()) + } + + return nil, rpcErr } - return fs.Stat(path) + + return reply.Info, nil } func (s *HTTPServer) FileReadAtRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) { @@ -180,37 +151,23 @@ func (s *HTTPServer) FileReadAtRequest(resp http.ResponseWriter, req *http.Reque } } - fs, err := s.agent.client.GetAllocFS(allocID) - if err != nil { - return nil, err - } - - rc, err := fs.ReadAt(path, offset) - if limit > 0 { - rc = &ReadCloserWrapper{ - Reader: io.LimitReader(rc, limit), - Closer: rc, - } - } - - if err != nil { - return nil, err + // Create the request arguments + fsReq := &cstructs.FsStreamRequest{ + AllocID: allocID, + Path: path, + Offset: offset, + Origin: "start", + Limit: limit, + PlainText: true, } + s.parse(resp, req, &fsReq.QueryOptions.Region, &fsReq.QueryOptions) - io.Copy(resp, rc) - return nil, rc.Close() -} - -// ReadCloserWrapper wraps a LimitReader so that a file is closed once it has been -// read -type ReadCloserWrapper struct { - io.Reader - io.Closer + // Make the request + return s.fsStreamImpl(resp, req, "FileSystem.Stream", fsReq, fsReq.AllocID) } func (s *HTTPServer) FileCatRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) { var allocID, path string - var err error q := req.URL.Query() @@ -220,326 +177,18 @@ func (s *HTTPServer) FileCatRequest(resp http.ResponseWriter, req *http.Request) if path = q.Get("path"); path == "" { return nil, fileNameNotPresentErr } - fs, err := s.agent.client.GetAllocFS(allocID) - if err != nil { - return nil, err - } - fileInfo, err := fs.Stat(path) - if err != nil { - return nil, err - } - if fileInfo.IsDir { - return nil, fmt.Errorf("file %q is a directory", path) + // Create the request arguments + fsReq := &cstructs.FsStreamRequest{ + AllocID: allocID, + Path: path, + Origin: "start", + PlainText: true, } + s.parse(resp, req, &fsReq.QueryOptions.Region, &fsReq.QueryOptions) - r, err := fs.ReadAt(path, int64(0)) - if err != nil { - return nil, err - } - io.Copy(resp, r) - return nil, r.Close() -} - -var ( - // HeartbeatStreamFrame is the StreamFrame to send as a heartbeat, avoiding - // creating many instances of the empty StreamFrame - HeartbeatStreamFrame = &StreamFrame{} -) - -// StreamFrame is used to frame data of a file when streaming -type StreamFrame struct { - // Offset is the offset the data was read from - Offset int64 `json:",omitempty"` - - // Data is the read data - Data []byte `json:",omitempty"` - - // File is the file that the data was read from - File string `json:",omitempty"` - - // FileEvent is the last file event that occurred that could cause the - // streams position to change or end - FileEvent string `json:",omitempty"` -} - -// IsHeartbeat returns if the frame is a heartbeat frame -func (s *StreamFrame) IsHeartbeat() bool { - return s.Offset == 0 && len(s.Data) == 0 && s.File == "" && s.FileEvent == "" -} - -func (s *StreamFrame) Clear() { - s.Offset = 0 - s.Data = nil - s.File = "" - s.FileEvent 
= "" -} - -func (s *StreamFrame) IsCleared() bool { - if s.Offset != 0 { - return false - } else if s.Data != nil { - return false - } else if s.File != "" { - return false - } else if s.FileEvent != "" { - return false - } else { - return true - } -} - -// StreamFramer is used to buffer and send frames as well as heartbeat. -type StreamFramer struct { - // plainTxt determines whether we frame or just send plain text data. - plainTxt bool - - out io.WriteCloser - enc *codec.Encoder - encLock sync.Mutex - - frameSize int - - heartbeat *time.Ticker - flusher *time.Ticker - - shutdownCh chan struct{} - exitCh chan struct{} - - // The mutex protects everything below - l sync.Mutex - - // The current working frame - f StreamFrame - data *bytes.Buffer - - // Captures whether the framer is running and any error that occurred to - // cause it to stop. - running bool - err error -} - -// NewStreamFramer creates a new stream framer that will output StreamFrames to -// the passed output. If plainTxt is set we do not frame and just batch plain -// text data. -func NewStreamFramer(out io.WriteCloser, plainTxt bool, - heartbeatRate, batchWindow time.Duration, frameSize int) *StreamFramer { - - // Create a JSON encoder - enc := codec.NewEncoder(out, structs.JsonHandle) - - // Create the heartbeat and flush ticker - heartbeat := time.NewTicker(heartbeatRate) - flusher := time.NewTicker(batchWindow) - - return &StreamFramer{ - plainTxt: plainTxt, - out: out, - enc: enc, - frameSize: frameSize, - heartbeat: heartbeat, - flusher: flusher, - data: bytes.NewBuffer(make([]byte, 0, 2*frameSize)), - shutdownCh: make(chan struct{}), - exitCh: make(chan struct{}), - } -} - -// Destroy is used to cleanup the StreamFramer and flush any pending frames -func (s *StreamFramer) Destroy() { - s.l.Lock() - close(s.shutdownCh) - s.heartbeat.Stop() - s.flusher.Stop() - running := s.running - s.l.Unlock() - - // Ensure things were flushed - if running { - <-s.exitCh - } - s.out.Close() -} - -// Run starts a long lived goroutine that handles sending data as well as -// heartbeating -func (s *StreamFramer) Run() { - s.l.Lock() - defer s.l.Unlock() - if s.running { - return - } - - s.running = true - go s.run() -} - -// ExitCh returns a channel that will be closed when the run loop terminates. -func (s *StreamFramer) ExitCh() <-chan struct{} { - return s.exitCh -} - -// Err returns the error that caused the StreamFramer to exit -func (s *StreamFramer) Err() error { - s.l.Lock() - defer s.l.Unlock() - return s.err -} - -// run is the internal run method. It exits if Destroy is called or an error -// occurs, in which case the exit channel is closed. 
-func (s *StreamFramer) run() { - var err error - defer func() { - s.l.Lock() - s.running = false - s.err = err - s.l.Unlock() - close(s.exitCh) - }() - -OUTER: - for { - select { - case <-s.shutdownCh: - break OUTER - case <-s.flusher.C: - // Skip if there is nothing to flush - s.l.Lock() - if s.f.IsCleared() { - s.l.Unlock() - continue - } - - // Read the data for the frame, and send it - s.f.Data = s.readData() - err = s.send(&s.f) - s.f.Clear() - s.l.Unlock() - if err != nil { - return - } - case <-s.heartbeat.C: - // Send a heartbeat frame - if err = s.send(HeartbeatStreamFrame); err != nil { - return - } - } - } - - s.l.Lock() - if !s.f.IsCleared() { - s.f.Data = s.readData() - err = s.send(&s.f) - s.f.Clear() - } - s.l.Unlock() -} - -// send takes a StreamFrame, encodes and sends it -func (s *StreamFramer) send(f *StreamFrame) error { - s.encLock.Lock() - defer s.encLock.Unlock() - if s.plainTxt { - _, err := io.Copy(s.out, bytes.NewReader(f.Data)) - return err - } - return s.enc.Encode(f) -} - -// readData is a helper which reads the buffered data returning up to the frame -// size of data. Must be called with the lock held. The returned value is -// invalid on the next read or write into the StreamFramer buffer -func (s *StreamFramer) readData() []byte { - // Compute the amount to read from the buffer - size := s.data.Len() - if size > s.frameSize { - size = s.frameSize - } - if size == 0 { - return nil - } - d := s.data.Next(size) - return d -} - -// Send creates and sends a StreamFrame based on the passed parameters. An error -// is returned if the run routine hasn't run or encountered an error. Send is -// asynchronous and does not block for the data to be transferred. -func (s *StreamFramer) Send(file, fileEvent string, data []byte, offset int64) error { - s.l.Lock() - defer s.l.Unlock() - - // If we are not running, return the error that caused us to not run or - // indicated that it was never started. - if !s.running { - if s.err != nil { - return s.err - } - - return fmt.Errorf("StreamFramer not running") - } - - // Check if not mergeable - if !s.f.IsCleared() && (s.f.File != file || s.f.FileEvent != fileEvent) { - // Flush the old frame - s.f.Data = s.readData() - select { - case <-s.exitCh: - return nil - default: - } - err := s.send(&s.f) - s.f.Clear() - if err != nil { - return err - } - } - - // Store the new data as the current frame. - if s.f.IsCleared() { - s.f.Offset = offset - s.f.File = file - s.f.FileEvent = fileEvent - } - - // Write the data to the buffer - s.data.Write(data) - - // Handle the delete case in which there is no data - force := false - if s.data.Len() == 0 && s.f.FileEvent != "" { - force = true - } - - // Flush till we are under the max frame size - for s.data.Len() >= s.frameSize || force { - // Clear - if force { - force = false - } - - // Create a new frame to send it - s.f.Data = s.readData() - select { - case <-s.exitCh: - return nil - default: - } - - if err := s.send(&s.f); err != nil { - return err - } - - // Update the offset - s.f.Offset += int64(len(s.f.Data)) - } - - if s.data.Len() == 0 { - s.f.Clear() - } - - return nil + // Make the request + return s.fsStreamImpl(resp, req, "FileSystem.Stream", fsReq, fsReq.AllocID) } // Stream streams the content of a file blocking on EOF. @@ -550,7 +199,6 @@ func (s *StreamFramer) Send(file, fileEvent string, data []byte, offset int64) e // applied. Defaults to "start". 
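
For orientation, the Stream handler below accepts path, offset, and origin query parameters on /v1/client/fs/stream/<alloc>. A hedged usage sketch against a local agent; the address, file path, and helper name are placeholders:

package example

import (
	"fmt"
	"io"
	"net/http"
	"os"
)

// streamFile calls the stream endpoint handled below. The agent address
// and the file path inside the allocation directory are placeholders.
func streamFile(allocID string) error {
	url := fmt.Sprintf("http://127.0.0.1:4646/v1/client/fs/stream/%s?path=alloc/logs/web.stdout.0&origin=start&offset=0", allocID)
	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Frames arrive until the connection closes; dump them to stdout.
	_, err = io.Copy(os.Stdout, resp.Body)
	return err
}
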
func (s *HTTPServer) Stream(resp http.ResponseWriter, req *http.Request) (interface{}, error) { var allocID, path string - var err error q := req.URL.Query() @@ -580,173 +228,18 @@ func (s *HTTPServer) Stream(resp http.ResponseWriter, req *http.Request) (interf return nil, invalidOrigin } - fs, err := s.agent.client.GetAllocFS(allocID) - if err != nil { - return nil, err - } - - fileInfo, err := fs.Stat(path) - if err != nil { - return nil, err - } - if fileInfo.IsDir { - return nil, fmt.Errorf("file %q is a directory", path) - } - - // If offsetting from the end subtract from the size - if origin == "end" { - offset = fileInfo.Size - offset - - } - - // Create an output that gets flushed on every write - output := ioutils.NewWriteFlusher(resp) - - // Create the framer - framer := NewStreamFramer(output, false, streamHeartbeatRate, streamBatchWindow, streamFrameSize) - framer.Run() - defer framer.Destroy() - - err = s.stream(offset, path, fs, framer, nil) - if err != nil && err != syscall.EPIPE { - return nil, err + // Create the request arguments + fsReq := &cstructs.FsStreamRequest{ + AllocID: allocID, + Path: path, + Origin: origin, + Offset: offset, + Follow: true, } + s.parse(resp, req, &fsReq.QueryOptions.Region, &fsReq.QueryOptions) - return nil, nil -} - -// parseFramerErr takes an error and returns an error. The error will -// potentially change if it was caused by the connection being closed. -func parseFramerErr(err error) error { - if err == nil { - return nil - } - - errMsg := err.Error() - - if strings.Contains(errMsg, io.ErrClosedPipe.Error()) { - // The pipe check is for tests - return syscall.EPIPE - } - - // The connection was closed by our peer - if strings.Contains(errMsg, syscall.EPIPE.Error()) || strings.Contains(errMsg, syscall.ECONNRESET.Error()) { - return syscall.EPIPE - } - - // Windows version of ECONNRESET - //XXX(schmichael) I could find no existing error or constant to - // compare this against. - if strings.Contains(errMsg, "forcibly closed") { - return syscall.EPIPE - } - - return err -} - -// stream is the internal method to stream the content of a file. eofCancelCh is -// used to cancel the stream if triggered while at EOF. If the connection is -// broken an EPIPE error is returned -func (s *HTTPServer) stream(offset int64, path string, - fs allocdir.AllocDirFS, framer *StreamFramer, - eofCancelCh chan error) error { - - // Get the reader - f, err := fs.ReadAt(path, offset) - if err != nil { - return err - } - defer f.Close() - - // Create a tomb to cancel watch events - t := tomb.Tomb{} - defer func() { - t.Kill(nil) - t.Done() - }() - - // Create a variable to allow setting the last event - var lastEvent string - - // Only create the file change watcher once. But we need to do it after we - // read and reach EOF. 
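
The loop removed below implements a tail: read the file to EOF, then block on change notifications before reading again. A minimal sketch of that structure, where watch stands in for fs.ChangeEvents:

package example

import "io"

// tailLoop sketches the read-then-wait structure of the stream method
// being deleted here: drain the reader, and at EOF block until a change
// notification arrives before reading again.
func tailLoop(f io.Reader, send func([]byte) error, watch func() <-chan struct{}) error {
	buf := make([]byte, 64*1024)
	for {
		n, err := f.Read(buf)
		if n > 0 {
			if sendErr := send(buf[:n]); sendErr != nil {
				return sendErr
			}
		}
		if err == nil {
			continue
		}
		if err != io.EOF {
			return err
		}
		// EOF: wait for the file to grow, rotate, or be truncated.
		<-watch()
	}
}
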
- var changes *watch.FileChanges - - // Start streaming the data - data := make([]byte, streamFrameSize) -OUTER: - for { - // Read up to the max frame size - n, readErr := f.Read(data) - - // Update the offset - offset += int64(n) - - // Return non-EOF errors - if readErr != nil && readErr != io.EOF { - return readErr - } - - // Send the frame - if n != 0 || lastEvent != "" { - if err := framer.Send(path, lastEvent, data[:n], offset); err != nil { - return parseFramerErr(err) - } - } - - // Clear the last event - if lastEvent != "" { - lastEvent = "" - } - - // Just keep reading - if readErr == nil { - continue - } - - // If EOF is hit, wait for a change to the file - if changes == nil { - changes, err = fs.ChangeEvents(path, offset, &t) - if err != nil { - return err - } - } - - for { - select { - case <-changes.Modified: - continue OUTER - case <-changes.Deleted: - return parseFramerErr(framer.Send(path, deleteEvent, nil, offset)) - case <-changes.Truncated: - // Close the current reader - if err := f.Close(); err != nil { - return err - } - - // Get a new reader at offset zero - offset = 0 - var err error - f, err = fs.ReadAt(path, offset) - if err != nil { - return err - } - defer f.Close() - - // Store the last event - lastEvent = truncateEvent - continue OUTER - case <-framer.ExitCh(): - return parseFramerErr(framer.Err()) - case err, ok := <-eofCancelCh: - if !ok { - return nil - } - - return err - } - } - } + // Make the request + return s.fsStreamImpl(resp, req, "FileSystem.Stream", fsReq, fsReq.AllocID) } // Logs streams the content of a log blocking on EOF. The parameters are: @@ -762,7 +255,6 @@ func (s *HTTPServer) Logs(resp http.ResponseWriter, req *http.Request) (interfac var err error q := req.URL.Query() - if allocID = strings.TrimPrefix(req.URL.Path, "/v1/client/fs/logs/"); allocID == "" { return nil, allocIDNotPresentErr } @@ -808,318 +300,108 @@ func (s *HTTPServer) Logs(resp http.ResponseWriter, req *http.Request) (interfac return nil, invalidOrigin } - fs, err := s.agent.client.GetAllocFS(allocID) - if err != nil { - return nil, err + // Create the request arguments + fsReq := &cstructs.FsLogsRequest{ + AllocID: allocID, + Task: task, + LogType: logType, + Offset: offset, + Origin: origin, + PlainText: plain, + Follow: follow, } + s.parse(resp, req, &fsReq.QueryOptions.Region, &fsReq.QueryOptions) - alloc, err := s.agent.client.GetClientAlloc(allocID) - if err != nil { - return nil, err - } - - // Check that the task is there - tg := alloc.Job.LookupTaskGroup(alloc.TaskGroup) - if tg == nil { - return nil, fmt.Errorf("Failed to lookup task group for allocation") - } else if taskStruct := tg.LookupTask(task); taskStruct == nil { - return nil, CodedError(404, fmt.Sprintf("task group %q does not have task with name %q", alloc.TaskGroup, task)) - } - - state, ok := alloc.TaskStates[task] - if !ok || state.StartedAt.IsZero() { - return nil, CodedError(404, fmt.Sprintf("task %q not started yet. 
No logs available", task)) - } - - // Create an output that gets flushed on every write - output := ioutils.NewWriteFlusher(resp) - - return nil, s.logs(follow, plain, offset, origin, task, logType, fs, output) + // Make the request + return s.fsStreamImpl(resp, req, "FileSystem.Logs", fsReq, fsReq.AllocID) } -func (s *HTTPServer) logs(follow, plain bool, offset int64, - origin, task, logType string, - fs allocdir.AllocDirFS, output io.WriteCloser) error { - - // Create the framer - framer := NewStreamFramer(output, plain, streamHeartbeatRate, streamBatchWindow, streamFrameSize) - framer.Run() - defer framer.Destroy() - - // Path to the logs - logPath := filepath.Join(allocdir.SharedAllocName, allocdir.LogDirName) +// fsStreamImpl is used to make a streaming filesystem call that serializes the +// args and then expects a stream of StreamErrWrapper results where the payload +// is copied to the response body. +func (s *HTTPServer) fsStreamImpl(resp http.ResponseWriter, + req *http.Request, method string, args interface{}, allocID string) (interface{}, error) { - // nextIdx is the next index to read logs from - var nextIdx int64 - switch origin { - case "start": - nextIdx = 0 - case "end": - nextIdx = math.MaxInt64 - offset *= -1 - default: - return invalidOrigin + // Get the correct handler + localClient, remoteClient, localServer := s.rpcHandlerForAlloc(allocID) + var handler structs.StreamingRpcHandler + var handlerErr error + if localClient { + handler, handlerErr = s.agent.Client().StreamingRpcHandler(method) + } else if remoteClient { + handler, handlerErr = s.agent.Client().RemoteStreamingRpcHandler(method) + } else if localServer { + handler, handlerErr = s.agent.Server().StreamingRpcHandler(method) } - // Create a tomb to cancel watch events - t := tomb.Tomb{} - defer func() { - t.Kill(nil) - t.Done() - }() - - for { - // Logic for picking next file is: - // 1) List log files - // 2) Pick log file closest to desired index - // 3) Open log file at correct offset - // 3a) No error, read contents - // 3b) If file doesn't exist, goto 1 as it may have been rotated out - entries, err := fs.List(logPath) - if err != nil { - return fmt.Errorf("failed to list entries: %v", err) - } - - // If we are not following logs, determine the max index for the logs we are - // interested in so we can stop there. - maxIndex := int64(math.MaxInt64) - if !follow { - _, idx, _, err := findClosest(entries, maxIndex, 0, task, logType) - if err != nil { - return err - } - maxIndex = idx - } - - logEntry, idx, openOffset, err := findClosest(entries, nextIdx, offset, task, logType) - if err != nil { - return err - } - - var eofCancelCh chan error - exitAfter := false - if !follow && idx > maxIndex { - // Exceeded what was there initially so return - return nil - } else if !follow && idx == maxIndex { - // At the end - eofCancelCh = make(chan error) - close(eofCancelCh) - exitAfter = true - } else { - eofCancelCh = blockUntilNextLog(fs, &t, logPath, task, logType, idx+1) - } - - p := filepath.Join(logPath, logEntry.Name) - err = s.stream(openOffset, p, fs, framer, eofCancelCh) - - if err != nil { - // Check if there was an error where the file does not exist. That means - // it got rotated out from under us. 
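
The deleted logs loop treats a file that has vanished as having been rotated away and simply retries from a fresh directory listing. In outline, with streamNext as a hypothetical stand-in for the list-and-stream step:

package example

import "os"

// retryRotated shows the rotation handling of the deleted loop: a missing
// file means it rotated out from under us, so list again and stream the
// next one; any other error (or clean exit) ends the loop.
func retryRotated(streamNext func() error) error {
	for {
		err := streamNext()
		if os.IsNotExist(err) {
			continue // rotated out from under us; pick up the next file
		}
		return err
	}
}
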
- if os.IsNotExist(err) { - continue - } - - // Check if the connection was closed - if err == syscall.EPIPE { - return nil - } - - return fmt.Errorf("failed to stream %q: %v", p, err) - } - - if exitAfter { - return nil - } + if handlerErr != nil { + return nil, CodedError(500, handlerErr.Error()) + } - // defensively check to make sure StreamFramer hasn't stopped - // running to avoid tight loops with goroutine leaks as in - // #3342 - select { - case <-framer.ExitCh(): - err := parseFramerErr(framer.Err()) - if err == syscall.EPIPE { - // EPIPE just means the connection was closed - return nil - } - return err - default: - } + p1, p2 := net.Pipe() + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) - // Since we successfully streamed, update the overall offset/idx. - offset = int64(0) - nextIdx = idx + 1 - } -} + // Create a goroutine that closes the pipe if the connection closes. + ctx, cancel := context.WithCancel(req.Context()) + go func() { + <-ctx.Done() + p1.Close() + }() -// blockUntilNextLog returns a channel that will have data sent when the next -// log index or anything greater is created. -func blockUntilNextLog(fs allocdir.AllocDirFS, t *tomb.Tomb, logPath, task, logType string, nextIndex int64) chan error { - nextPath := filepath.Join(logPath, fmt.Sprintf("%s.%s.%d", task, logType, nextIndex)) - next := make(chan error, 1) + // Create an output that gets flushed on every write + output := ioutils.NewWriteFlusher(resp) + // Create a channel that decodes the results + errCh := make(chan HTTPCodedError) go func() { - eofCancelCh, err := fs.BlockUntilExists(nextPath, t) - if err != nil { - next <- err - close(next) + // Send the request + if err := encoder.Encode(args); err != nil { + errCh <- CodedError(500, err.Error()) + cancel() return } - ticker := time.NewTicker(nextLogCheckRate) - defer ticker.Stop() - scanCh := ticker.C for { select { - case <-t.Dead(): - next <- fmt.Errorf("shutdown triggered") - close(next) + case <-ctx.Done(): + errCh <- nil + cancel() return - case err := <-eofCancelCh: - next <- err - close(next) + default: + } + + var res cstructs.StreamErrWrapper + if err := decoder.Decode(&res); err != nil { + errCh <- CodedError(500, err.Error()) + cancel() return - case <-scanCh: - entries, err := fs.List(logPath) - if err != nil { - next <- fmt.Errorf("failed to list entries: %v", err) - close(next) - return - } + } - indexes, err := logIndexes(entries, task, logType) - if err != nil { - next <- err - close(next) + if err := res.Error; err != nil { + if err.Code != nil { + errCh <- CodedError(int(*err.Code), err.Error()) + cancel() return } + } - // Scan and see if there are any entries larger than what we are - // waiting for. 
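
blockUntilNextLog, removed below, races a file-exists watch against a periodic directory scan for any log index at or above the target. A compact sketch of that select, with hypothetical signatures:

package example

import "time"

// waitForIndex condenses the deleted blockUntilNextLog: return when the
// watched file appears (exists fires), or when a periodic scan sees any
// log index at or above the one we want. listIndexes stands in for
// logIndexes applied to a directory listing.
func waitForIndex(next int64, exists <-chan error, listIndexes func() []int64) error {
	scan := time.NewTicker(100 * time.Millisecond)
	defer scan.Stop()
	for {
		select {
		case err := <-exists:
			return err
		case <-scan.C:
			for _, idx := range listIndexes() {
				if idx >= next {
					return nil
				}
			}
		}
	}
}
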
- for _, entry := range indexes { - if entry.idx >= nextIndex { - next <- nil - close(next) - return - } - } + if _, err := io.Copy(output, bytes.NewBuffer(res.Payload)); err != nil { + errCh <- CodedError(500, err.Error()) + cancel() + return } } }() - return next -} - -// indexTuple and indexTupleArray are used to find the correct log entry to -// start streaming logs from -type indexTuple struct { - idx int64 - entry *allocdir.AllocFileInfo -} - -type indexTupleArray []indexTuple - -func (a indexTupleArray) Len() int { return len(a) } -func (a indexTupleArray) Less(i, j int) bool { return a[i].idx < a[j].idx } -func (a indexTupleArray) Swap(i, j int) { a[i], a[j] = a[j], a[i] } - -// logIndexes takes a set of entries and returns a indexTupleArray of -// the desired log file entries. If the indexes could not be determined, an -// error is returned. -func logIndexes(entries []*allocdir.AllocFileInfo, task, logType string) (indexTupleArray, error) { - var indexes []indexTuple - prefix := fmt.Sprintf("%s.%s.", task, logType) - for _, entry := range entries { - if entry.IsDir { - continue - } - - // If nothing was trimmed, then it is not a match - idxStr := strings.TrimPrefix(entry.Name, prefix) - if idxStr == entry.Name { - continue - } - - // Convert to an int - idx, err := strconv.Atoi(idxStr) - if err != nil { - return nil, fmt.Errorf("failed to convert %q to a log index: %v", idxStr, err) - } - - indexes = append(indexes, indexTuple{idx: int64(idx), entry: entry}) - } - - return indexTupleArray(indexes), nil -} - -// findClosest takes a list of entries, the desired log index and desired log -// offset (which can be negative, treated as offset from end), task name and log -// type and returns the log entry, the log index, the offset to read from and a -// potential error. -func findClosest(entries []*allocdir.AllocFileInfo, desiredIdx, desiredOffset int64, - task, logType string) (*allocdir.AllocFileInfo, int64, int64, error) { - - // Build the matching indexes - indexes, err := logIndexes(entries, task, logType) - if err != nil { - return nil, 0, 0, err - } - if len(indexes) == 0 { - return nil, 0, 0, fmt.Errorf("log entry for task %q and log type %q not found", task, logType) - } - - // Binary search the indexes to get the desiredIdx - sort.Sort(indexes) - i := sort.Search(len(indexes), func(i int) bool { return indexes[i].idx >= desiredIdx }) - l := len(indexes) - if i == l { - // Use the last index if the number is bigger than all of them. 
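
The walk that follows is the subtle part of the deleted findClosest: a negative offset is carried backwards through earlier log files until it fits. The arithmetic in isolation, with an illustrative trace in the comment:

package example

// seekBackwards reproduces the negative-offset case of the deleted
// findClosest: starting at file idx, walk toward file 0 until size+offset
// lands inside a file. sizes holds each log file's size in index order.
// For sizes={100,100,50}, idx=2, offset=-120 it yields file 1 at offset 30.
func seekBackwards(sizes []int64, idx int, offset int64) (int, int64) {
	for {
		newOffset := sizes[idx] + offset
		if newOffset >= 0 {
			return idx, newOffset // fits inside the current file
		}
		if idx == 0 {
			return 0, 0 // ran past the oldest file; clamp to its start
		}
		offset = newOffset // carry the remaining deficit backwards
		idx--
	}
}
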
- i = l - 1 + handler(p2) + cancel() + codedErr := <-errCh + if codedErr != nil && + (codedErr == io.EOF || + strings.Contains(codedErr.Error(), "closed") || + strings.Contains(codedErr.Error(), "EOF")) { + codedErr = nil } - - // Get to the correct offset - offset := desiredOffset - idx := int64(i) - for { - s := indexes[idx].entry.Size - - // Base case - if offset == 0 { - break - } else if offset < 0 { - // Going backwards - if newOffset := s + offset; newOffset >= 0 { - // Current file works - offset = newOffset - break - } else if idx == 0 { - // Already at the end - offset = 0 - break - } else { - // Try the file before - offset = newOffset - idx -= 1 - continue - } - } else { - // Going forward - if offset <= s { - // Current file works - break - } else if idx == int64(l-1) { - // Already at the end - offset = s - break - } else { - // Try the next file - offset = offset - s - idx += 1 - continue - } - - } - } - - return indexes[idx].entry, indexes[idx].idx, offset, nil + return nil, codedErr } diff --git a/command/agent/fs_endpoint_test.go b/command/agent/fs_endpoint_test.go index ea31b9ee82e..f59bbd953b0 100644 --- a/command/agent/fs_endpoint_test.go +++ b/command/agent/fs_endpoint_test.go @@ -1,1470 +1,468 @@ package agent import ( - "bytes" + "encoding/base64" "fmt" "io" "io/ioutil" - "log" - "math" "net/http" "net/http/httptest" - "os" - "path/filepath" - "reflect" - "runtime" - "strconv" "strings" "testing" "time" - "github.com/hashicorp/nomad/acl" - "github.com/hashicorp/nomad/client/allocdir" + cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" - "github.com/stretchr/testify/assert" - "github.com/ugorji/go/codec" + "github.com/stretchr/testify/require" ) -func TestAllocDirFS_List_MissingParams(t *testing.T) { - t.Parallel() - httpTest(t, nil, func(s *TestAgent) { - req, err := http.NewRequest("GET", "/v1/client/fs/ls/", nil) - if err != nil { - t.Fatalf("err: %v", err) - } - respW := httptest.NewRecorder() - - _, err = s.Server.DirectoryListRequest(respW, req) - if err != allocIDNotPresentErr { - t.Fatalf("expected err: %v, actual: %v", allocIDNotPresentErr, err) - } - }) -} - -func TestAllocDirFS_Stat_MissingParams(t *testing.T) { - t.Parallel() - httpTest(t, nil, func(s *TestAgent) { - req, err := http.NewRequest("GET", "/v1/client/fs/stat/", nil) - if err != nil { - t.Fatalf("err: %v", err) - } - respW := httptest.NewRecorder() - - _, err = s.Server.FileStatRequest(respW, req) - if err != allocIDNotPresentErr { - t.Fatalf("expected err: %v, actual: %v", allocIDNotPresentErr, err) - } - - req, err = http.NewRequest("GET", "/v1/client/fs/stat/foo", nil) - if err != nil { - t.Fatalf("err: %v", err) - } - respW = httptest.NewRecorder() +const ( + defaultLoggerMockDriverStdout = "Hello from the other side" +) - _, err = s.Server.FileStatRequest(respW, req) - if err != fileNameNotPresentErr { - t.Fatalf("expected err: %v, actual: %v", allocIDNotPresentErr, err) - } +var ( + defaultLoggerMockDriver = map[string]interface{}{ + "run_for": "2s", + "stdout_string": defaultLoggerMockDriverStdout, + } +) - }) -} +type clientAllocWaiter int -func TestAllocDirFS_ReadAt_MissingParams(t *testing.T) { - t.Parallel() - httpTest(t, nil, func(s *TestAgent) { - req, err := http.NewRequest("GET", "/v1/client/fs/readat/", nil) - if err != nil { - t.Fatalf("err: %v", err) - } - respW := 
httptest.NewRecorder() +const ( + noWaitClientAlloc clientAllocWaiter = iota + runningClientAlloc + terminalClientAlloc +) - _, err = s.Server.FileReadAtRequest(respW, req) - if err == nil { - t.Fatal("expected error") - } +func addAllocToClient(agent *TestAgent, alloc *structs.Allocation, wait clientAllocWaiter) { + require := require.New(agent.T) - req, err = http.NewRequest("GET", "/v1/client/fs/readat/foo", nil) + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + node, err := agent.server.State().NodeByID(nil, agent.client.NodeID()) if err != nil { - t.Fatalf("err: %v", err) - } - respW = httptest.NewRecorder() - - _, err = s.Server.FileReadAtRequest(respW, req) - if err == nil { - t.Fatal("expected error") + return false, err } - - req, err = http.NewRequest("GET", "/v1/client/fs/readat/foo?path=/path/to/file", nil) - if err != nil { - t.Fatalf("err: %v", err) + if node == nil { + return false, fmt.Errorf("unknown node") } - respW = httptest.NewRecorder() - _, err = s.Server.FileReadAtRequest(respW, req) - if err == nil { - t.Fatal("expected error") - } + return node.Status == structs.NodeStatusReady, fmt.Errorf("bad node status") + }, func(err error) { + agent.T.Fatal(err) }) -} - -func TestAllocDirFS_ACL(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - for _, endpoint := range []string{"ls", "stat", "readat", "cat", "stream"} { - httpACLTest(t, nil, func(s *TestAgent) { - state := s.Agent.server.State() - - req, err := http.NewRequest("GET", fmt.Sprintf("/v1/client/fs/%s/", endpoint), nil) - assert.Nil(err) - - // Try request without a token and expect failure - { - respW := httptest.NewRecorder() - _, err := s.Server.FsRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) - } - - // Try request with an invalid token and expect failure - { - respW := httptest.NewRecorder() - policy := mock.NamespacePolicy(structs.DefaultNamespace, "", []string{acl.NamespaceCapabilityReadLogs}) - token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", policy) - setToken(req, token) - _, err := s.Server.FsRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) - } - // Try request with a valid token - // No alloc id set, so expect an error - just not a permissions error - { - respW := httptest.NewRecorder() - policy := mock.NamespacePolicy(structs.DefaultNamespace, "", []string{acl.NamespaceCapabilityReadFS}) - token := mock.CreatePolicyAndToken(t, state, 1007, "valid", policy) - setToken(req, token) - _, err := s.Server.FsRequest(respW, req) - assert.NotNil(err) - assert.Equal(allocIDNotPresentErr, err) - } + // Upsert the allocation + state := agent.server.State() + require.Nil(state.UpsertJob(999, alloc.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{alloc})) - // Try request with a management token - // No alloc id set, so expect an error - just not a permissions error - { - respW := httptest.NewRecorder() - setToken(req, s.RootToken) - _, err := s.Server.FsRequest(respW, req) - assert.NotNil(err) - assert.Equal(allocIDNotPresentErr, err) - } - }) + if wait == noWaitClientAlloc { + return } -} - -func TestAllocDirFS_Logs_ACL(t *testing.T) { - t.Parallel() - assert := assert.New(t) - httpACLTest(t, nil, func(s *TestAgent) { - state := s.Agent.server.State() - - req, err := http.NewRequest("GET", "/v1/client/fs/logs/", nil) - assert.Nil(err) - - // Try request without a token and expect failure - { - respW := 
httptest.NewRecorder() - _, err := s.Server.FsRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, alloc.ID) + if err != nil { + return false, err } - - // Try request with an invalid token and expect failure - { - respW := httptest.NewRecorder() - policy := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityReadFS}) - token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", policy) - setToken(req, token) - _, err := s.Server.FsRequest(respW, req) - assert.NotNil(err) - assert.Equal(err.Error(), structs.ErrPermissionDenied.Error()) + if alloc == nil { + return false, fmt.Errorf("unknown alloc") } - // Try request with a valid token (ReadFS) - // No alloc id set, so expect an error - just not a permissions error - { - respW := httptest.NewRecorder() - policy := mock.NamespacePolicy(structs.DefaultNamespace, "", []string{acl.NamespaceCapabilityReadFS}) - token := mock.CreatePolicyAndToken(t, state, 1007, "valid1", policy) - setToken(req, token) - _, err := s.Server.FsRequest(respW, req) - assert.NotNil(err) - assert.Equal(allocIDNotPresentErr, err) + expectation := alloc.ClientStatus == structs.AllocClientStatusComplete || + alloc.ClientStatus == structs.AllocClientStatusFailed + if wait == runningClientAlloc { + expectation = expectation || alloc.ClientStatus == structs.AllocClientStatusRunning } - // Try request with a valid token (ReadLogs) - // No alloc id set, so expect an error - just not a permissions error - { - respW := httptest.NewRecorder() - policy := mock.NamespacePolicy(structs.DefaultNamespace, "", []string{acl.NamespaceCapabilityReadLogs}) - token := mock.CreatePolicyAndToken(t, state, 1009, "valid2", policy) - setToken(req, token) - _, err := s.Server.FsRequest(respW, req) - assert.NotNil(err) - assert.Equal(allocIDNotPresentErr, err) + if !expectation { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) } - // Try request with a management token - // No alloc id set, so expect an error - just not a permissions error - { - respW := httptest.NewRecorder() - setToken(req, s.RootToken) - _, err := s.Server.FsRequest(respW, req) - assert.NotNil(err) - assert.Equal(allocIDNotPresentErr, err) - } + return true, nil + }, func(err error) { + agent.T.Fatal(err) }) } -type WriteCloseChecker struct { - io.WriteCloser - Closed bool -} - -func (w *WriteCloseChecker) Close() error { - w.Closed = true - return w.WriteCloser.Close() -} - -// This test checks, that even if the frame size has not been hit, a flush will -// periodically occur. -func TestStreamFramer_Flush(t *testing.T) { - // Create the stream framer - r, w := io.Pipe() - wrappedW := &WriteCloseChecker{WriteCloser: w} - hRate, bWindow := 100*time.Millisecond, 100*time.Millisecond - sf := NewStreamFramer(wrappedW, false, hRate, bWindow, 100) - sf.Run() - - // Create a decoder - dec := codec.NewDecoder(r, structs.JsonHandle) - - f := "foo" - fe := "bar" - d := []byte{0xa} - o := int64(10) - - // Start the reader - resultCh := make(chan struct{}) - go func() { - for { - var frame StreamFrame - if err := dec.Decode(&frame); err != nil { - t.Fatalf("failed to decode") - } - - if frame.IsHeartbeat() { - continue - } +// mockFSAlloc returns a suitable mock alloc for testing the fs system. If +// config isn't provided, the defaultLoggerMockDriver config is used. 
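
A sketch of how mockFSAlloc and addAllocToClient combine in the tests further down; abbreviated, and assuming the httpTest harness used throughout this file:

httpTest(t, nil, func(s *TestAgent) {
	a := mockFSAlloc(s.client.NodeID(), nil) // nil config selects defaultLoggerMockDriver
	addAllocToClient(s, a, terminalClientAlloc)
	// ...exercise the /v1/client/fs/... endpoints against a.ID...
})
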
+func mockFSAlloc(nodeID string, config map[string]interface{}) *structs.Allocation { + a := mock.Alloc() + a.NodeID = nodeID + a.Job.Type = structs.JobTypeBatch + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0].Driver = "mock_driver" - if reflect.DeepEqual(frame.Data, d) && frame.Offset == o && frame.File == f && frame.FileEvent == fe { - resultCh <- struct{}{} - return - } - - } - }() - - // Write only 1 byte so we do not hit the frame size - if err := sf.Send(f, fe, d, o); err != nil { - t.Fatalf("Send() failed %v", err) - } - - select { - case <-resultCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * bWindow): - t.Fatalf("failed to flush") + if config != nil { + a.Job.TaskGroups[0].Tasks[0].Config = config + } else { + a.Job.TaskGroups[0].Tasks[0].Config = defaultLoggerMockDriver } - // Close the reader and wait. This should cause the runner to exit - if err := r.Close(); err != nil { - t.Fatalf("failed to close reader") - } - - select { - case <-sf.ExitCh(): - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * hRate): - t.Fatalf("exit channel should close") - } - - sf.Destroy() - if !wrappedW.Closed { - t.Fatalf("writer not closed") - } + return a } -// This test checks that frames will be batched till the frame size is hit (in -// the case that is before the flush). -func TestStreamFramer_Batch(t *testing.T) { - // Create the stream framer - r, w := io.Pipe() - wrappedW := &WriteCloseChecker{WriteCloser: w} - // Ensure the batch window doesn't get hit - hRate, bWindow := 100*time.Millisecond, 500*time.Millisecond - sf := NewStreamFramer(wrappedW, false, hRate, bWindow, 3) - sf.Run() - - // Create a decoder - dec := codec.NewDecoder(r, structs.JsonHandle) - - f := "foo" - fe := "bar" - d := []byte{0xa, 0xb, 0xc} - o := int64(10) - - // Start the reader - resultCh := make(chan struct{}) - go func() { - for { - var frame StreamFrame - if err := dec.Decode(&frame); err != nil { - t.Fatalf("failed to decode") - } - - if frame.IsHeartbeat() { - continue - } - - if reflect.DeepEqual(frame.Data, d) && frame.Offset == o && frame.File == f && frame.FileEvent == fe { - resultCh <- struct{}{} - return - } - } - }() - - // Write only 1 byte so we do not hit the frame size - if err := sf.Send(f, fe, d[:1], o); err != nil { - t.Fatalf("Send() failed %v", err) - } - - // Ensure we didn't get any data - select { - case <-resultCh: - t.Fatalf("Got data before frame size reached") - case <-time.After(bWindow / 2): - } - - // Write the rest so we hit the frame size - if err := sf.Send(f, fe, d[1:], o); err != nil { - t.Fatalf("Send() failed %v", err) - } - - // Ensure we get data - select { - case <-resultCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * bWindow): - t.Fatalf("Did not receive data after batch size reached") - } - - // Close the reader and wait. 
This should cause the runner to exit - if err := r.Close(); err != nil { - t.Fatalf("failed to close reader") - } - - select { - case <-sf.ExitCh(): - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * hRate): - t.Fatalf("exit channel should close") - } - - sf.Destroy() - if !wrappedW.Closed { - t.Fatalf("writer not closed") - } +func TestHTTP_FS_List_MissingParams(t *testing.T) { + t.Parallel() + require := require.New(t) + httpTest(t, nil, func(s *TestAgent) { + req, err := http.NewRequest("GET", "/v1/client/fs/ls/", nil) + require.Nil(err) + respW := httptest.NewRecorder() + _, err = s.Server.DirectoryListRequest(respW, req) + require.EqualError(err, allocIDNotPresentErr.Error()) + }) } -func TestStreamFramer_Heartbeat(t *testing.T) { - // Create the stream framer - r, w := io.Pipe() - wrappedW := &WriteCloseChecker{WriteCloser: w} - hRate, bWindow := 100*time.Millisecond, 100*time.Millisecond - sf := NewStreamFramer(wrappedW, false, hRate, bWindow, 100) - sf.Run() - - // Create a decoder - dec := codec.NewDecoder(r, structs.JsonHandle) - - // Start the reader - resultCh := make(chan struct{}) - go func() { - for { - var frame StreamFrame - if err := dec.Decode(&frame); err != nil { - t.Fatalf("failed to decode") - } - - if frame.IsHeartbeat() { - resultCh <- struct{}{} - return - } - } - }() - - select { - case <-resultCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * hRate): - t.Fatalf("failed to heartbeat") - } +func TestHTTP_FS_Stat_MissingParams(t *testing.T) { + t.Parallel() + require := require.New(t) + httpTest(t, nil, func(s *TestAgent) { + req, err := http.NewRequest("GET", "/v1/client/fs/stat/", nil) + require.Nil(err) + respW := httptest.NewRecorder() - // Close the reader and wait. This should cause the runner to exit - if err := r.Close(); err != nil { - t.Fatalf("failed to close reader") - } + _, err = s.Server.FileStatRequest(respW, req) + require.EqualError(err, allocIDNotPresentErr.Error()) - select { - case <-sf.ExitCh(): - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * hRate): - t.Fatalf("exit channel should close") - } + req, err = http.NewRequest("GET", "/v1/client/fs/stat/foo", nil) + require.Nil(err) + respW = httptest.NewRecorder() - sf.Destroy() - if !wrappedW.Closed { - t.Fatalf("writer not closed") - } + _, err = s.Server.FileStatRequest(respW, req) + require.EqualError(err, fileNameNotPresentErr.Error()) + }) } -// This test checks that frames are received in order -func TestStreamFramer_Order(t *testing.T) { - // Create the stream framer - r, w := io.Pipe() - wrappedW := &WriteCloseChecker{WriteCloser: w} - // Ensure the batch window doesn't get hit - hRate, bWindow := 100*time.Millisecond, 10*time.Millisecond - sf := NewStreamFramer(wrappedW, false, hRate, bWindow, 10) - sf.Run() - - // Create a decoder - dec := codec.NewDecoder(r, structs.JsonHandle) - - files := []string{"1", "2", "3", "4", "5"} - input := bytes.NewBuffer(make([]byte, 0, 100000)) - for i := 0; i <= 1000; i++ { - str := strconv.Itoa(i) + "," - input.WriteString(str) - } - - expected := bytes.NewBuffer(make([]byte, 0, 100000)) - for range files { - expected.Write(input.Bytes()) - } - receivedBuf := bytes.NewBuffer(make([]byte, 0, 100000)) - - // Start the reader - resultCh := make(chan struct{}) - go func() { - for { - var frame StreamFrame - if err := dec.Decode(&frame); err != nil { - t.Fatalf("failed to decode") - } - - if frame.IsHeartbeat() { - continue - } - - receivedBuf.Write(frame.Data) - - if reflect.DeepEqual(expected, 
receivedBuf) { - resultCh <- struct{}{} - return - } - } - }() - - // Send the data - b := input.Bytes() - shards := 10 - each := len(b) / shards - for _, f := range files { - for i := 0; i < shards; i++ { - l, r := each*i, each*(i+1) - if i == shards-1 { - r = len(b) - } +func TestHTTP_FS_ReadAt_MissingParams(t *testing.T) { + t.Parallel() + require := require.New(t) + httpTest(t, nil, func(s *TestAgent) { + req, err := http.NewRequest("GET", "/v1/client/fs/readat/", nil) + require.Nil(err) + respW := httptest.NewRecorder() - if err := sf.Send(f, "", b[l:r], 0); err != nil { - t.Fatalf("Send() failed %v", err) - } - } - } + _, err = s.Server.FileReadAtRequest(respW, req) + require.NotNil(err) - // Ensure we get data - select { - case <-resultCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * bWindow): - if reflect.DeepEqual(expected, receivedBuf) { - got := receivedBuf.String() - want := expected.String() - t.Fatalf("Got %v; want %v", got, want) - } - } + req, err = http.NewRequest("GET", "/v1/client/fs/readat/foo", nil) + require.Nil(err) + respW = httptest.NewRecorder() - // Close the reader and wait. This should cause the runner to exit - if err := r.Close(); err != nil { - t.Fatalf("failed to close reader") - } + _, err = s.Server.FileReadAtRequest(respW, req) + require.NotNil(err) - select { - case <-sf.ExitCh(): - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * hRate): - t.Fatalf("exit channel should close") - } + req, err = http.NewRequest("GET", "/v1/client/fs/readat/foo?path=/path/to/file", nil) + require.Nil(err) + respW = httptest.NewRecorder() - sf.Destroy() - if !wrappedW.Closed { - t.Fatalf("writer not closed") - } + _, err = s.Server.FileReadAtRequest(respW, req) + require.NotNil(err) + }) } -// This test checks that frames are received in order -func TestStreamFramer_Order_PlainText(t *testing.T) { - // Create the stream framer - r, w := io.Pipe() - wrappedW := &WriteCloseChecker{WriteCloser: w} - // Ensure the batch window doesn't get hit - hRate, bWindow := 100*time.Millisecond, 10*time.Millisecond - sf := NewStreamFramer(wrappedW, true, hRate, bWindow, 10) - sf.Run() - - files := []string{"1", "2", "3", "4", "5"} - input := bytes.NewBuffer(make([]byte, 0, 100000)) - for i := 0; i <= 1000; i++ { - str := strconv.Itoa(i) + "," - input.WriteString(str) - } - - expected := bytes.NewBuffer(make([]byte, 0, 100000)) - for range files { - expected.Write(input.Bytes()) - } - receivedBuf := bytes.NewBuffer(make([]byte, 0, 100000)) - - // Start the reader - resultCh := make(chan struct{}) - go func() { - OUTER: - for { - if _, err := receivedBuf.ReadFrom(r); err != nil { - if strings.Contains(err.Error(), "closed pipe") { - resultCh <- struct{}{} - return - } - t.Fatalf("bad read: %v", err) - } - - if expected.Len() != receivedBuf.Len() { - continue - } - expectedBytes := expected.Bytes() - actualBytes := receivedBuf.Bytes() - for i, e := range expectedBytes { - if a := actualBytes[i]; a != e { - continue OUTER - } - } - resultCh <- struct{}{} - return - - } - }() - - // Send the data - b := input.Bytes() - shards := 10 - each := len(b) / shards - for _, f := range files { - for i := 0; i < shards; i++ { - l, r := each*i, each*(i+1) - if i == shards-1 { - r = len(b) - } - - if err := sf.Send(f, "", b[l:r], 0); err != nil { - t.Fatalf("Send() failed %v", err) - } - } - } +func TestHTTP_FS_Cat_MissingParams(t *testing.T) { + t.Parallel() + require := require.New(t) + httpTest(t, nil, func(s *TestAgent) { + req, err := http.NewRequest("GET", 
"/v1/client/fs/cat/", nil) + require.Nil(err) + respW := httptest.NewRecorder() - // Ensure we get data - select { - case <-resultCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * bWindow): - if expected.Len() != receivedBuf.Len() { - t.Fatalf("Got %v; want %v", expected.Len(), receivedBuf.Len()) - } - expectedBytes := expected.Bytes() - actualBytes := receivedBuf.Bytes() - for i, e := range expectedBytes { - if a := actualBytes[i]; a != e { - t.Fatalf("Index %d; Got %q; want %q", i, a, e) - } - } - } + _, err = s.Server.FileCatRequest(respW, req) + require.EqualError(err, allocIDNotPresentErr.Error()) - // Close the reader and wait. This should cause the runner to exit - if err := r.Close(); err != nil { - t.Fatalf("failed to close reader") - } + req, err = http.NewRequest("GET", "/v1/client/fs/stat/foo", nil) + require.Nil(err) + respW = httptest.NewRecorder() - sf.Destroy() - if !wrappedW.Closed { - t.Fatalf("writer not closed") - } + _, err = s.Server.FileCatRequest(respW, req) + require.EqualError(err, fileNameNotPresentErr.Error()) + }) } -func TestHTTP_Stream_MissingParams(t *testing.T) { +func TestHTTP_FS_Stream_MissingParams(t *testing.T) { t.Parallel() + require := require.New(t) httpTest(t, nil, func(s *TestAgent) { req, err := http.NewRequest("GET", "/v1/client/fs/stream/", nil) - if err != nil { - t.Fatalf("err: %v", err) - } + require.Nil(err) respW := httptest.NewRecorder() _, err = s.Server.Stream(respW, req) - if err == nil { - t.Fatal("expected error") - } + require.EqualError(err, allocIDNotPresentErr.Error()) req, err = http.NewRequest("GET", "/v1/client/fs/stream/foo", nil) - if err != nil { - t.Fatalf("err: %v", err) - } + require.Nil(err) respW = httptest.NewRecorder() _, err = s.Server.Stream(respW, req) - if err == nil { - t.Fatal("expected error") - } + require.EqualError(err, fileNameNotPresentErr.Error()) req, err = http.NewRequest("GET", "/v1/client/fs/stream/foo?path=/path/to/file", nil) - if err != nil { - t.Fatalf("err: %v", err) - } + require.Nil(err) respW = httptest.NewRecorder() _, err = s.Server.Stream(respW, req) - if err == nil { - t.Fatal("expected error") - } + require.Nil(err) }) } -// tempAllocDir returns a new alloc dir that is rooted in a temp dir. The caller -// should destroy the temp dir. 
-func tempAllocDir(t testing.TB) *allocdir.AllocDir { - dir, err := ioutil.TempDir("", "") - if err != nil { - t.Fatalf("TempDir() failed: %v", err) - } +func TestHTTP_FS_Logs_MissingParams(t *testing.T) { + t.Parallel() + require := require.New(t) + httpTest(t, nil, func(s *TestAgent) { + req, err := http.NewRequest("GET", "/v1/client/fs/logs/", nil) + require.Nil(err) + respW := httptest.NewRecorder() - if err := os.Chmod(dir, 0777); err != nil { - t.Fatalf("failed to chmod dir: %v", err) - } + _, err = s.Server.Logs(respW, req) + require.EqualError(err, allocIDNotPresentErr.Error()) - return allocdir.NewAllocDir(log.New(os.Stderr, "", log.LstdFlags), dir) -} + req, err = http.NewRequest("GET", "/v1/client/fs/logs/foo", nil) + require.Nil(err) + respW = httptest.NewRecorder() -type nopWriteCloser struct { - io.Writer -} + _, err = s.Server.Logs(respW, req) + require.EqualError(err, taskNotPresentErr.Error()) -func (n nopWriteCloser) Close() error { - return nil -} + req, err = http.NewRequest("GET", "/v1/client/fs/logs/foo?task=foo", nil) + require.Nil(err) + respW = httptest.NewRecorder() -func TestHTTP_Stream_NoFile(t *testing.T) { - t.Parallel() - httpTest(t, nil, func(s *TestAgent) { - // Get a temp alloc dir - ad := tempAllocDir(t) - defer os.RemoveAll(ad.AllocDir) + _, err = s.Server.Logs(respW, req) + require.EqualError(err, logTypeNotPresentErr.Error()) - framer := NewStreamFramer(nopWriteCloser{ioutil.Discard}, false, streamHeartbeatRate, streamBatchWindow, streamFrameSize) - framer.Run() - defer framer.Destroy() + req, err = http.NewRequest("GET", "/v1/client/fs/logs/foo?task=foo&type=stdout", nil) + require.Nil(err) + respW = httptest.NewRecorder() - if err := s.Server.stream(0, "foo", ad, framer, nil); err == nil { - t.Fatalf("expected an error when streaming unknown file") - } + _, err = s.Server.Logs(respW, req) + require.Nil(err) }) } -func TestHTTP_Stream_Modify(t *testing.T) { +func TestHTTP_FS_List(t *testing.T) { t.Parallel() + require := require.New(t) httpTest(t, nil, func(s *TestAgent) { - // Get a temp alloc dir - ad := tempAllocDir(t) - defer os.RemoveAll(ad.AllocDir) - - // Create a file in the temp dir - streamFile := "stream_file" - f, err := os.Create(filepath.Join(ad.AllocDir, streamFile)) - if err != nil { - t.Fatalf("Failed to create file: %v", err) - } - defer f.Close() - - // Create a decoder - r, w := io.Pipe() - defer r.Close() - defer w.Close() - dec := codec.NewDecoder(r, structs.JsonHandle) + a := mockFSAlloc(s.client.NodeID(), nil) + addAllocToClient(s, a, terminalClientAlloc) - data := []byte("helloworld") - - // Start the reader - resultCh := make(chan struct{}) - go func() { - var collected []byte - for { - var frame StreamFrame - if err := dec.Decode(&frame); err != nil { - t.Fatalf("failed to decode: %v", err) - } - - if frame.IsHeartbeat() { - continue - } - - collected = append(collected, frame.Data...) - if reflect.DeepEqual(data, collected) { - resultCh <- struct{}{} - return - } - } - }() - - // Write a few bytes - if _, err := f.Write(data[:3]); err != nil { - t.Fatalf("write failed: %v", err) - } - - framer := NewStreamFramer(w, false, streamHeartbeatRate, streamBatchWindow, streamFrameSize) - framer.Run() - defer framer.Destroy() - - // Start streaming - go func() { - if err := s.Server.stream(0, streamFile, ad, framer, nil); err != nil { - t.Fatalf("stream() failed: %v", err) - } - }() - - // Sleep a little before writing more. This lets us check if the watch - // is working. 
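//
// TestHTTP_FS_Logs_MissingParams above walks the full parameter chain the
// logs handler requires; each step asserts the specific sentinel error:
//
//	no alloc ID         -> allocIDNotPresentErr
//	missing ?task=      -> taskNotPresentErr
//	missing ?type=      -> logTypeNotPresentErr
//	all params present  -> nil (the request is dispatched)
//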
- time.Sleep(1 * time.Duration(testutil.TestMultiplier()) * time.Second) - if _, err := f.Write(data[3:]); err != nil { - t.Fatalf("write failed: %v", err) - } + req, err := http.NewRequest("GET", "/v1/client/fs/ls/"+a.ID, nil) + require.Nil(err) + respW := httptest.NewRecorder() + raw, err := s.Server.DirectoryListRequest(respW, req) + require.Nil(err) - select { - case <-resultCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): - t.Fatalf("failed to send new data") - } + files, ok := raw.([]*cstructs.AllocFileInfo) + require.True(ok) + require.NotEmpty(files) + require.True(files[0].IsDir) }) } -func TestHTTP_Stream_Truncate(t *testing.T) { +func TestHTTP_FS_Stat(t *testing.T) { t.Parallel() + require := require.New(t) httpTest(t, nil, func(s *TestAgent) { - // Get a temp alloc dir - ad := tempAllocDir(t) - defer os.RemoveAll(ad.AllocDir) - - // Create a file in the temp dir - streamFile := "stream_file" - streamFilePath := filepath.Join(ad.AllocDir, streamFile) - f, err := os.Create(streamFilePath) - if err != nil { - t.Fatalf("Failed to create file: %v", err) - } - defer f.Close() - - // Create a decoder - r, w := io.Pipe() - defer r.Close() - defer w.Close() - dec := codec.NewDecoder(r, structs.JsonHandle) + a := mockFSAlloc(s.client.NodeID(), nil) + addAllocToClient(s, a, terminalClientAlloc) - data := []byte("helloworld") + path := fmt.Sprintf("/v1/client/fs/stat/%s?path=alloc/", a.ID) + req, err := http.NewRequest("GET", path, nil) + require.Nil(err) + respW := httptest.NewRecorder() + raw, err := s.Server.FileStatRequest(respW, req) + require.Nil(err) - // Start the reader - truncateCh := make(chan struct{}) - dataPostTruncCh := make(chan struct{}) - go func() { - var collected []byte - for { - var frame StreamFrame - if err := dec.Decode(&frame); err != nil { - t.Fatalf("failed to decode: %v", err) - } - - if frame.IsHeartbeat() { - continue - } - - if frame.FileEvent == truncateEvent { - close(truncateCh) - } - - collected = append(collected, frame.Data...) - if reflect.DeepEqual(data, collected) { - close(dataPostTruncCh) - return - } - } - }() + info, ok := raw.(*cstructs.AllocFileInfo) + require.True(ok) + require.NotNil(info) + require.True(info.IsDir) + }) +} - // Write a few bytes - if _, err := f.Write(data[:3]); err != nil { - t.Fatalf("write failed: %v", err) - } +func TestHTTP_FS_ReadAt(t *testing.T) { + t.Parallel() + require := require.New(t) + httpTest(t, nil, func(s *TestAgent) { + a := mockFSAlloc(s.client.NodeID(), nil) + addAllocToClient(s, a, terminalClientAlloc) - framer := NewStreamFramer(w, false, streamHeartbeatRate, streamBatchWindow, streamFrameSize) - framer.Run() - defer framer.Destroy() + offset := 1 + limit := 3 + expectation := defaultLoggerMockDriverStdout[offset : offset+limit] + path := fmt.Sprintf("/v1/client/fs/readat/%s?path=alloc/logs/web.stdout.0&offset=%d&limit=%d", + a.ID, offset, limit) - // Start streaming - go func() { - if err := s.Server.stream(0, streamFile, ad, framer, nil); err != nil { - t.Fatalf("stream() failed: %v", err) - } - }() + req, err := http.NewRequest("GET", path, nil) + require.Nil(err) + respW := httptest.NewRecorder() + _, err = s.Server.FileReadAtRequest(respW, req) + require.Nil(err) - // Sleep a little before truncating. This lets us check if the watch - // is working. 
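//
// TestHTTP_FS_ReadAt above derives its expected bytes directly from the
// query parameters: with offset=1 and limit=3, the body must equal
//
//	defaultLoggerMockDriverStdout[1:4] // i.e. [offset : offset+limit]
//
// so any off-by-one in the handler's seek or limit handling fails the
// EqualValues comparison against the recorded body.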
- time.Sleep(1 * time.Duration(testutil.TestMultiplier()) * time.Second) - if err := f.Truncate(0); err != nil { - t.Fatalf("truncate failed: %v", err) - } - if err := f.Sync(); err != nil { - t.Fatalf("sync failed: %v", err) - } - if err := f.Close(); err != nil { - t.Fatalf("failed to close file: %v", err) - } + output, err := ioutil.ReadAll(respW.Result().Body) + require.Nil(err) + require.EqualValues(expectation, output) + }) +} - f2, err := os.OpenFile(streamFilePath, os.O_RDWR, 0) - if err != nil { - t.Fatalf("failed to reopen file: %v", err) - } - defer f2.Close() - if _, err := f2.Write(data[3:5]); err != nil { - t.Fatalf("write failed: %v", err) - } +func TestHTTP_FS_Cat(t *testing.T) { + t.Parallel() + require := require.New(t) + httpTest(t, nil, func(s *TestAgent) { + a := mockFSAlloc(s.client.NodeID(), nil) + addAllocToClient(s, a, terminalClientAlloc) - select { - case <-truncateCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): - t.Fatalf("did not receive truncate") - } + path := fmt.Sprintf("/v1/client/fs/cat/%s?path=alloc/logs/web.stdout.0", a.ID) - // Sleep a little before writing more. This lets us check if the watch - // is working. - time.Sleep(1 * time.Duration(testutil.TestMultiplier()) * time.Second) - if _, err := f2.Write(data[5:]); err != nil { - t.Fatalf("write failed: %v", err) - } + req, err := http.NewRequest("GET", path, nil) + require.Nil(err) + respW := httptest.NewRecorder() + _, err = s.Server.FileCatRequest(respW, req) + require.Nil(err) - select { - case <-dataPostTruncCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): - t.Fatalf("did not receive post truncate data") - } + output, err := ioutil.ReadAll(respW.Result().Body) + require.Nil(err) + require.EqualValues(defaultLoggerMockDriverStdout, output) }) } -func TestHTTP_Stream_Delete(t *testing.T) { +func TestHTTP_FS_Stream(t *testing.T) { t.Parallel() + require := require.New(t) httpTest(t, nil, func(s *TestAgent) { - // Get a temp alloc dir - ad := tempAllocDir(t) - defer os.RemoveAll(ad.AllocDir) - - // Create a file in the temp dir - streamFile := "stream_file" - streamFilePath := filepath.Join(ad.AllocDir, streamFile) - f, err := os.Create(streamFilePath) - if err != nil { - t.Fatalf("Failed to create file: %v", err) - } - defer f.Close() + a := mockFSAlloc(s.client.NodeID(), nil) + addAllocToClient(s, a, terminalClientAlloc) - // Create a decoder - r, w := io.Pipe() - wrappedW := &WriteCloseChecker{WriteCloser: w} - defer r.Close() - defer w.Close() - dec := codec.NewDecoder(r, structs.JsonHandle) + offset := 4 + expectation := base64.StdEncoding.EncodeToString( + []byte(defaultLoggerMockDriverStdout[len(defaultLoggerMockDriverStdout)-offset:])) + path := fmt.Sprintf("/v1/client/fs/stream/%s?path=alloc/logs/web.stdout.0&offset=%d&origin=end", + a.ID, offset) - data := []byte("helloworld") + p, _ := io.Pipe() - // Start the reader - deleteCh := make(chan struct{}) + req, err := http.NewRequest("GET", path, p) + require.Nil(err) + respW := httptest.NewRecorder() + doneCh := make(chan struct{}) go func() { - for { - var frame StreamFrame - if err := dec.Decode(&frame); err != nil { - t.Fatalf("failed to decode: %v", err) - } - - if frame.IsHeartbeat() { - continue - } - - if frame.FileEvent == deleteEvent { - close(deleteCh) - return - } - } + _, err = s.Server.Stream(respW, req) + require.Nil(err) + close(doneCh) }() - // Write a few bytes - if _, err := f.Write(data[:3]); err != nil { - t.Fatalf("write failed: %v", 
err) - } - - framer := NewStreamFramer(wrappedW, false, streamHeartbeatRate, streamBatchWindow, streamFrameSize) - framer.Run() - - // Start streaming - go func() { - if err := s.Server.stream(0, streamFile, ad, framer, nil); err != nil { - t.Fatalf("stream() failed: %v", err) + out := "" + testutil.WaitForResult(func() (bool, error) { + output, err := ioutil.ReadAll(respW.Body) + if err != nil { + return false, err } - }() - // Sleep a little before deleting. This lets us check if the watch - // is working. - time.Sleep(1 * time.Duration(testutil.TestMultiplier()) * time.Second) - if err := os.Remove(streamFilePath); err != nil { - t.Fatalf("delete failed: %v", err) - } + out += string(output) + return strings.Contains(out, expectation), fmt.Errorf("%q doesn't contain %q", out, expectation) + }, func(err error) { + t.Fatal(err) + }) select { - case <-deleteCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): - t.Fatalf("did not receive delete") + case <-doneCh: + t.Fatal("shouldn't close") + case <-time.After(1 * time.Second): } - framer.Destroy() - testutil.WaitForResult(func() (bool, error) { - return wrappedW.Closed, nil - }, func(err error) { - t.Fatalf("connection not closed") - }) - + p.Close() }) } -func TestHTTP_Logs_NoFollow(t *testing.T) { +func TestHTTP_FS_Logs(t *testing.T) { t.Parallel() + require := require.New(t) httpTest(t, nil, func(s *TestAgent) { - // Get a temp alloc dir and create the log dir - ad := tempAllocDir(t) - defer os.RemoveAll(ad.AllocDir) - - logDir := filepath.Join(ad.SharedDir, allocdir.LogDirName) - if err := os.MkdirAll(logDir, 0777); err != nil { - t.Fatalf("Failed to make log dir: %v", err) - } - - // Create a series of log files in the temp dir - task := "foo" - logType := "stdout" - expected := []byte("012") - for i := 0; i < 3; i++ { - logFile := fmt.Sprintf("%s.%s.%d", task, logType, i) - logFilePath := filepath.Join(logDir, logFile) - err := ioutil.WriteFile(logFilePath, expected[i:i+1], 777) - if err != nil { - t.Fatalf("Failed to create file: %v", err) - } - } + a := mockFSAlloc(s.client.NodeID(), nil) + addAllocToClient(s, a, terminalClientAlloc) - // Create a decoder - r, w := io.Pipe() - wrappedW := &WriteCloseChecker{WriteCloser: w} - defer r.Close() - defer w.Close() - dec := codec.NewDecoder(r, structs.JsonHandle) + offset := 4 + expectation := defaultLoggerMockDriverStdout[len(defaultLoggerMockDriverStdout)-offset:] + path := fmt.Sprintf("/v1/client/fs/logs/%s?type=stdout&task=web&offset=%d&origin=end&plain=true", + a.ID, offset) - var received []byte - - // Start the reader - resultCh := make(chan struct{}) + p, _ := io.Pipe() + req, err := http.NewRequest("GET", path, p) + require.Nil(err) + respW := httptest.NewRecorder() go func() { - for { - var frame StreamFrame - if err := dec.Decode(&frame); err != nil { - if err == io.EOF { - t.Logf("EOF") - return - } - - t.Fatalf("failed to decode: %v", err) - } - - if frame.IsHeartbeat() { - continue - } - - received = append(received, frame.Data...) 
- if reflect.DeepEqual(received, expected) { - close(resultCh) - return - } - } + _, err = s.Server.Logs(respW, req) + require.Nil(err) }() - // Start streaming logs - go func() { - if err := s.Server.logs(false, false, 0, OriginStart, task, logType, ad, wrappedW); err != nil { - t.Fatalf("logs() failed: %v", err) + out := "" + testutil.WaitForResult(func() (bool, error) { + output, err := ioutil.ReadAll(respW.Body) + if err != nil { + return false, err } - }() - - select { - case <-resultCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): - t.Fatalf("did not receive data: got %q", string(received)) - } - testutil.WaitForResult(func() (bool, error) { - return wrappedW.Closed, nil + out += string(output) + return out == expectation, fmt.Errorf("%q != %q", out, expectation) }, func(err error) { - t.Fatalf("connection not closed") + t.Fatal(err) }) + p.Close() }) } -func TestHTTP_Logs_Follow(t *testing.T) { +func TestHTTP_FS_Logs_Follow(t *testing.T) { t.Parallel() + require := require.New(t) httpTest(t, nil, func(s *TestAgent) { - // Get a temp alloc dir and create the log dir - ad := tempAllocDir(t) - defer os.RemoveAll(ad.AllocDir) + a := mockFSAlloc(s.client.NodeID(), nil) + addAllocToClient(s, a, terminalClientAlloc) - logDir := filepath.Join(ad.SharedDir, allocdir.LogDirName) - if err := os.MkdirAll(logDir, 0777); err != nil { - t.Fatalf("Failed to make log dir: %v", err) - } + offset := 4 + expectation := defaultLoggerMockDriverStdout[len(defaultLoggerMockDriverStdout)-offset:] + path := fmt.Sprintf("/v1/client/fs/logs/%s?type=stdout&task=web&offset=%d&origin=end&plain=true&follow=true", + a.ID, offset) - // Create a series of log files in the temp dir - task := "foo" - logType := "stdout" - expected := []byte("012345") - initialWrites := 3 - - writeToFile := func(index int, data []byte) { - logFile := fmt.Sprintf("%s.%s.%d", task, logType, index) - logFilePath := filepath.Join(logDir, logFile) - err := ioutil.WriteFile(logFilePath, data, 777) - if err != nil { - t.Fatalf("Failed to create file: %v", err) - } - } - for i := 0; i < initialWrites; i++ { - writeToFile(i, expected[i:i+1]) - } - - // Create a decoder - r, w := io.Pipe() - wrappedW := &WriteCloseChecker{WriteCloser: w} - defer r.Close() - defer w.Close() - dec := codec.NewDecoder(r, structs.JsonHandle) - - var received []byte - - // Start the reader - firstResultCh := make(chan struct{}) - fullResultCh := make(chan struct{}) - go func() { - for { - var frame StreamFrame - if err := dec.Decode(&frame); err != nil { - if err == io.EOF { - t.Logf("EOF") - return - } - - t.Fatalf("failed to decode: %v", err) - } - - if frame.IsHeartbeat() { - continue - } - - received = append(received, frame.Data...) 
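//
// The converted stream and log tests share one polling idiom: drain whatever
// the recorder has buffered so far, accumulate it, and retry until the
// expected output has arrived. Condensed:
//
//	out := ""
//	testutil.WaitForResult(func() (bool, error) {
//		b, err := ioutil.ReadAll(respW.Body)
//		if err != nil {
//			return false, err
//		}
//		out += string(b)
//		return out == expectation, fmt.Errorf("%q != %q", out, expectation)
//	}, func(err error) { t.Fatal(err) })
//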
- if reflect.DeepEqual(received, expected[:initialWrites]) { - close(firstResultCh) - } else if reflect.DeepEqual(received, expected) { - close(fullResultCh) - return - } - } - }() - - // Start streaming logs + p, _ := io.Pipe() + req, err := http.NewRequest("GET", path, p) + require.Nil(err) + respW := httptest.NewRecorder() + doneCh := make(chan struct{}) go func() { - if err := s.Server.logs(true, false, 0, OriginStart, task, logType, ad, wrappedW); err != nil { - t.Fatalf("logs() failed: %v", err) - } + _, err = s.Server.Logs(respW, req) + require.Nil(err) + close(doneCh) }() - select { - case <-firstResultCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): - t.Fatalf("did not receive data: got %q", string(received)) - } - - // We got the first chunk of data, write out the rest to the next file - // at an index much ahead to check that it is following and detecting - // skips - skipTo := initialWrites + 10 - writeToFile(skipTo, expected[initialWrites:]) - - select { - case <-fullResultCh: - case <-time.After(10 * time.Duration(testutil.TestMultiplier()) * streamBatchWindow): - t.Fatalf("did not receive data: got %q", string(received)) - } - - // Close the reader - r.Close() - + out := "" testutil.WaitForResult(func() (bool, error) { - return wrappedW.Closed, nil - }, func(err error) { - t.Fatalf("connection not closed") - }) - }) -} - -func BenchmarkHTTP_Logs_Follow(t *testing.B) { - runtime.MemProfileRate = 1 - - s := makeHTTPServer(t, nil) - defer s.Shutdown() - testutil.WaitForLeader(t, s.Agent.RPC) - - // Get a temp alloc dir and create the log dir - ad := tempAllocDir(t) - s.Agent.logger.Printf("ALEX: LOG DIR: %q", ad.SharedDir) - //defer os.RemoveAll(ad.AllocDir) - - logDir := filepath.Join(ad.SharedDir, allocdir.LogDirName) - if err := os.MkdirAll(logDir, 0777); err != nil { - t.Fatalf("Failed to make log dir: %v", err) - } - - // Create a series of log files in the temp dir - task := "foo" - logType := "stdout" - expected := make([]byte, 1024*1024*100) - initialWrites := 3 - - writeToFile := func(index int, data []byte) { - logFile := fmt.Sprintf("%s.%s.%d", task, logType, index) - logFilePath := filepath.Join(logDir, logFile) - err := ioutil.WriteFile(logFilePath, data, 777) - if err != nil { - t.Fatalf("Failed to create file: %v", err) - } - } - - part := (len(expected) / 3) - 50 - goodEnough := (8 * len(expected)) / 10 - for i := 0; i < initialWrites; i++ { - writeToFile(i, expected[i*part:(i+1)*part]) - } - - t.ResetTimer() - for i := 0; i < t.N; i++ { - s.Agent.logger.Printf("BENCHMARK %d", i) - - // Create a decoder - r, w := io.Pipe() - wrappedW := &WriteCloseChecker{WriteCloser: w} - defer r.Close() - defer w.Close() - dec := codec.NewDecoder(r, structs.JsonHandle) - - var received []byte - - // Start the reader - fullResultCh := make(chan struct{}) - go func() { - for { - var frame StreamFrame - if err := dec.Decode(&frame); err != nil { - if err == io.EOF { - t.Logf("EOF") - return - } - - t.Fatalf("failed to decode: %v", err) - } - - if frame.IsHeartbeat() { - continue - } - - received = append(received, frame.Data...) 
- if len(received) > goodEnough { - close(fullResultCh) - return - } - } - }() - - // Start streaming logs - go func() { - if err := s.Server.logs(true, false, 0, OriginStart, task, logType, ad, wrappedW); err != nil { - t.Fatalf("logs() failed: %v", err) + output, err := ioutil.ReadAll(respW.Body) + if err != nil { + return false, err } - }() - - select { - case <-fullResultCh: - case <-time.After(time.Duration(60 * time.Second)): - t.Fatalf("did not receive data: %d < %d", len(received), goodEnough) - } - s.Agent.logger.Printf("ALEX: CLOSING") - - // Close the reader - r.Close() - s.Agent.logger.Printf("ALEX: CLOSED") - - s.Agent.logger.Printf("ALEX: WAITING FOR WRITER TO CLOSE") - testutil.WaitForResult(func() (bool, error) { - return wrappedW.Closed, nil + out += string(output) + return out == expectation, fmt.Errorf("%q != %q", out, expectation) }, func(err error) { - t.Fatalf("connection not closed") + t.Fatal(err) }) - s.Agent.logger.Printf("ALEX: WRITER CLOSED") - } -} -func TestLogs_findClosest(t *testing.T) { - task := "foo" - entries := []*allocdir.AllocFileInfo{ - { - Name: "foo.stdout.0", - Size: 100, - }, - { - Name: "foo.stdout.1", - Size: 100, - }, - { - Name: "foo.stdout.2", - Size: 100, - }, - { - Name: "foo.stdout.3", - Size: 100, - }, - { - Name: "foo.stderr.0", - Size: 100, - }, - { - Name: "foo.stderr.1", - Size: 100, - }, - { - Name: "foo.stderr.2", - Size: 100, - }, - } - - cases := []struct { - Entries []*allocdir.AllocFileInfo - DesiredIdx int64 - DesiredOffset int64 - Task string - LogType string - ExpectedFile string - ExpectedIdx int64 - ExpectedOffset int64 - Error bool - }{ - // Test error cases - { - Entries: nil, - DesiredIdx: 0, - Task: task, - LogType: "stdout", - Error: true, - }, - { - Entries: entries[0:3], - DesiredIdx: 0, - Task: task, - LogType: "stderr", - Error: true, - }, - - // Test beginning cases - { - Entries: entries, - DesiredIdx: 0, - Task: task, - LogType: "stdout", - ExpectedFile: entries[0].Name, - ExpectedIdx: 0, - }, - { - // Desired offset should be ignored at edges - Entries: entries, - DesiredIdx: 0, - DesiredOffset: -100, - Task: task, - LogType: "stdout", - ExpectedFile: entries[0].Name, - ExpectedIdx: 0, - ExpectedOffset: 0, - }, - { - // Desired offset should be ignored at edges - Entries: entries, - DesiredIdx: 1, - DesiredOffset: -1000, - Task: task, - LogType: "stdout", - ExpectedFile: entries[0].Name, - ExpectedIdx: 0, - ExpectedOffset: 0, - }, - { - Entries: entries, - DesiredIdx: 0, - Task: task, - LogType: "stderr", - ExpectedFile: entries[4].Name, - ExpectedIdx: 0, - }, - { - Entries: entries, - DesiredIdx: 0, - Task: task, - LogType: "stdout", - ExpectedFile: entries[0].Name, - ExpectedIdx: 0, - }, - - // Test middle cases - { - Entries: entries, - DesiredIdx: 1, - Task: task, - LogType: "stdout", - ExpectedFile: entries[1].Name, - ExpectedIdx: 1, - }, - { - Entries: entries, - DesiredIdx: 1, - DesiredOffset: 10, - Task: task, - LogType: "stdout", - ExpectedFile: entries[1].Name, - ExpectedIdx: 1, - ExpectedOffset: 10, - }, - { - Entries: entries, - DesiredIdx: 1, - DesiredOffset: 110, - Task: task, - LogType: "stdout", - ExpectedFile: entries[2].Name, - ExpectedIdx: 2, - ExpectedOffset: 10, - }, - { - Entries: entries, - DesiredIdx: 1, - Task: task, - LogType: "stderr", - ExpectedFile: entries[5].Name, - ExpectedIdx: 1, - }, - // Test end cases - { - Entries: entries, - DesiredIdx: math.MaxInt64, - Task: task, - LogType: "stdout", - ExpectedFile: entries[3].Name, - ExpectedIdx: 3, - }, - { - Entries: entries, - 
DesiredIdx:     math.MaxInt64,
-			DesiredOffset:  math.MaxInt64,
-			Task:           task,
-			LogType:        "stdout",
-			ExpectedFile:   entries[3].Name,
-			ExpectedIdx:    3,
-			ExpectedOffset: 100,
-		},
-		{
-			Entries:        entries,
-			DesiredIdx:     math.MaxInt64,
-			DesiredOffset:  -10,
-			Task:           task,
-			LogType:        "stdout",
-			ExpectedFile:   entries[3].Name,
-			ExpectedIdx:    3,
-			ExpectedOffset: 90,
-		},
-		{
-			Entries:      entries,
-			DesiredIdx:   math.MaxInt64,
-			Task:         task,
-			LogType:      "stderr",
-			ExpectedFile: entries[6].Name,
-			ExpectedIdx:  2,
-		},
-	}
-
-	for i, c := range cases {
-		entry, idx, offset, err := findClosest(c.Entries, c.DesiredIdx, c.DesiredOffset, c.Task, c.LogType)
-		if err != nil {
-			if !c.Error {
-				t.Fatalf("case %d: Unexpected error: %v", i, err)
-			}
-			continue
+		select {
+		case <-doneCh:
+			t.Fatal("shouldn't close")
+		case <-time.After(1 * time.Second):
 		}

-		if entry.Name != c.ExpectedFile {
-			t.Fatalf("case %d: Got file %q; want %q", i, entry.Name, c.ExpectedFile)
-		}
-		if idx != c.ExpectedIdx {
-			t.Fatalf("case %d: Got index %d; want %d", i, idx, c.ExpectedIdx)
-		}
-		if offset != c.ExpectedOffset {
-			t.Fatalf("case %d: Got offset %d; want %d", i, offset, c.ExpectedOffset)
-		}
-	}
+		p.Close()
+	})
 }
diff --git a/command/agent/helpers.go b/command/agent/helpers.go
new file mode 100644
index 00000000000..50542416ca9
--- /dev/null
+++ b/command/agent/helpers.go
@@ -0,0 +1,52 @@
+package agent
+
+// rpcHandlerForAlloc is a helper that, given an allocation ID, returns whether
+// to use the local client's RPC, the local client's remote RPC, or the server
+// on the agent.
+func (s *HTTPServer) rpcHandlerForAlloc(allocID string) (localClient, remoteClient, server bool) {
+	c := s.agent.Client()
+	srv := s.agent.Server()
+
+	// See if the local client can handle the request.
+	localAlloc := false
+	if c != nil {
+		// If there is an error it means that the client doesn't have the
+		// allocation so we can't use the local client
+		_, err := c.GetClientAlloc(allocID)
+		if err == nil {
+			localAlloc = true
+		}
+	}
+
+	// Only use the client RPC to the server if we don't have a server and
+	// the local client can't handle the call.
+	useClientRPC := c != nil && !localAlloc && srv == nil
+
+	// Use the server as a last resort.
+	useServerRPC := !localAlloc && !useClientRPC && srv != nil
+
+	return localAlloc, useClientRPC, useServerRPC
+}
+
+// rpcHandlerForNode is a helper that, given a node ID, returns whether to use
+// the local client's RPC, the local client's remote RPC, or the server on the
+// agent. If there is a local node and no node ID is given, it is assumed the
+// local node is being targeted.
+func (s *HTTPServer) rpcHandlerForNode(nodeID string) (localClient, remoteClient, server bool) {
+	c := s.agent.Client()
+	srv := s.agent.Server()
+
+	// See if the local client can handle the request.
+	localClient = c != nil && // Must have a client
+		(nodeID == "" || // If no node ID is given
+			nodeID == c.NodeID()) // Requested node is the local node.
+
+	// Only use the client RPC to the server if we don't have a server and
+	// the local client can't handle the call.
+	useClientRPC := c != nil && !localClient && srv == nil
+
+	// Use the server as a last resort.
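//
// Spelled out, the precedence both helpers implement is:
//
//	1. the local client can serve the target -> use the local client handler
//	2. client present but no local server    -> forward over the client's RPC
//	3. otherwise                             -> use the server's RPC, if any
//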
+	useServerRPC := !localClient && !useClientRPC && srv != nil && nodeID != ""
+
+	return localClient, useClientRPC, useServerRPC
+}
diff --git a/command/agent/helpers_test.go b/command/agent/helpers_test.go
new file mode 100644
index 00000000000..10532310ec1
--- /dev/null
+++ b/command/agent/helpers_test.go
@@ -0,0 +1,92 @@
+package agent
+
+import (
+	"testing"
+
+	"github.com/hashicorp/nomad/helper/uuid"
+	"github.com/stretchr/testify/require"
+)
+
+func TestHTTP_rpcHandlerForAlloc(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+	agent := NewTestAgent(t, t.Name(), nil)
+
+	a := mockFSAlloc(agent.client.NodeID(), nil)
+	addAllocToClient(agent, a, terminalClientAlloc)
+
+	// Case 1: Client has allocation
+	// Outcome: Use local client
+	lc, rc, s := agent.Server.rpcHandlerForAlloc(a.ID)
+	require.True(lc)
+	require.False(rc)
+	require.False(s)
+
+	// Case 2: Client doesn't have allocation and there is a server
+	// Outcome: Use server
+	lc, rc, s = agent.Server.rpcHandlerForAlloc(uuid.Generate())
+	require.False(lc)
+	require.False(rc)
+	require.True(s)
+
+	// Case 3: Client doesn't have allocation and there is no server
+	// Outcome: Use client RPC to server
+	srv := agent.server
+	agent.server = nil
+	lc, rc, s = agent.Server.rpcHandlerForAlloc(uuid.Generate())
+	require.False(lc)
+	require.True(rc)
+	require.False(s)
+	agent.server = srv
+
+	// Case 4: No client
+	// Outcome: Use server
+	client := agent.client
+	agent.client = nil
+	lc, rc, s = agent.Server.rpcHandlerForAlloc(uuid.Generate())
+	require.False(lc)
+	require.False(rc)
+	require.True(s)
+	agent.client = client
+}
+
+func TestHTTP_rpcHandlerForNode(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+	agent := NewTestAgent(t, t.Name(), nil)
+	cID := agent.client.NodeID()
+
+	// Case 1: Node running, no node ID given
+	// Outcome: Use local node
+	lc, rc, s := agent.Server.rpcHandlerForNode("")
+	require.True(lc)
+	require.False(rc)
+	require.False(s)
+
+	// Case 2: Node running, its ID given
+	// Outcome: Use local node
+	lc, rc, s = agent.Server.rpcHandlerForNode(cID)
+	require.True(lc)
+	require.False(rc)
+	require.False(s)
+
+	// Case 3: Local node but wrong ID and there is no server
+	// Outcome: Use client RPC to server
+	srv := agent.server
+	agent.server = nil
+	lc, rc, s = agent.Server.rpcHandlerForNode(uuid.Generate())
+	require.False(lc)
+	require.True(rc)
+	require.False(s)
+	agent.server = srv
+
+	// Case 4: No client
+	// Outcome: Use server
+	client := agent.client
+	agent.client = nil
+	lc, rc, s = agent.Server.rpcHandlerForNode(uuid.Generate())
+	require.False(lc)
+	require.False(rc)
+	require.True(s)
+	agent.client = client
+}
diff --git a/command/agent/stats_endpoint.go b/command/agent/stats_endpoint.go
index ba18d9c28eb..b87a231837e 100644
--- a/command/agent/stats_endpoint.go
+++ b/command/agent/stats_endpoint.go
@@ -2,25 +2,47 @@ package agent

 import (
 	"net/http"
+	"strings"

+	cstructs "github.com/hashicorp/nomad/client/structs"
 	"github.com/hashicorp/nomad/nomad/structs"
 )

 func (s *HTTPServer) ClientStatsRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
-	if s.agent.client == nil {
-		return nil, clientNotRunning
+	// Get the requested Node ID
+	requestedNode := req.URL.Query().Get("node_id")
+
+	// Build the request and parse the ACL token
+	args := structs.NodeSpecificRequest{
+		NodeID: requestedNode,
+	}
+	s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)
+
+	// Determine the handler to use
+	useLocalClient,
useClientRPC, useServerRPC := s.rpcHandlerForNode(requestedNode) + + // Make the RPC + var reply cstructs.ClientStatsResponse + var rpcErr error + if useLocalClient { + rpcErr = s.agent.Client().ClientRPC("ClientStats.Stats", &args, &reply) + } else if useClientRPC { + rpcErr = s.agent.Client().RPC("ClientStats.Stats", &args, &reply) + } else if useServerRPC { + rpcErr = s.agent.Server().RPC("ClientStats.Stats", &args, &reply) + } else { + rpcErr = CodedError(400, "No local Node and node_id not provided") } - var secret string - s.parseToken(req, &secret) + if rpcErr != nil { + if structs.IsErrNoNodeConn(rpcErr) { + rpcErr = CodedError(404, rpcErr.Error()) + } else if strings.Contains(rpcErr.Error(), "Unknown node") { + rpcErr = CodedError(404, rpcErr.Error()) + } - // Check node read permissions - if aclObj, err := s.agent.Client().ResolveToken(secret); err != nil { - return nil, err - } else if aclObj != nil && !aclObj.AllowNodeRead() { - return nil, structs.ErrPermissionDenied + return nil, rpcErr } - clientStats := s.agent.client.StatsReporter() - return clientStats.LatestHostStats(), nil + return reply.HostStats, nil } diff --git a/command/agent/stats_endpoint_test.go b/command/agent/stats_endpoint_test.go index 9661878c290..6cc1f178ec6 100644 --- a/command/agent/stats_endpoint_test.go +++ b/command/agent/stats_endpoint_test.go @@ -1,28 +1,81 @@ package agent import ( + "fmt" "net/http" "net/http/httptest" "testing" "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestClientStatsRequest(t *testing.T) { t.Parallel() + require := require.New(t) httpTest(t, nil, func(s *TestAgent) { - req, err := http.NewRequest("GET", "/v1/client/stats/?since=foo", nil) - if err != nil { - t.Fatalf("err: %v", err) + + // Local node, local resp + { + req, err := http.NewRequest("GET", "/v1/client/stats/?since=foo", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + respW := httptest.NewRecorder() + _, err = s.Server.ClientStatsRequest(respW, req) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + } + + // Local node, server resp + { + srv := s.server + s.server = nil + + req, err := http.NewRequest("GET", fmt.Sprintf("/v1/client/stats?node_id=%s", uuid.Generate()), nil) + require.Nil(err) + + respW := httptest.NewRecorder() + _, err = s.Server.ClientStatsRequest(respW, req) + require.NotNil(err) + require.Contains(err.Error(), "Unknown node") + + s.server = srv } - respW := httptest.NewRecorder() - _, err = s.Server.ClientStatsRequest(respW, req) - if err != nil { - t.Fatalf("unexpected err: %v", err) + // no client, server resp + { + c := s.client + s.client = nil + + testutil.WaitForResult(func() (bool, error) { + n, err := s.server.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + return n != nil, nil + }, func(err error) { + t.Fatalf("should have client: %v", err) + }) + + req, err := http.NewRequest("GET", fmt.Sprintf("/v1/client/stats?node_id=%s", c.NodeID()), nil) + require.Nil(err) + + respW := httptest.NewRecorder() + _, err = s.Server.ClientStatsRequest(respW, req) + require.NotNil(err) + + // The dev agent uses in-mem RPC so just assert the no route error + require.Contains(err.Error(), structs.ErrNoNodeConn.Error()) + + s.client = c } }) } diff --git 
a/command/agent/testagent.go b/command/agent/testagent.go
index 539890004b1..6e309b7a5e8 100644
--- a/command/agent/testagent.go
+++ b/command/agent/testagent.go
@@ -19,6 +19,7 @@ import (
 	"github.com/hashicorp/consul/lib/freeport"
 	"github.com/hashicorp/nomad/api"
 	"github.com/hashicorp/nomad/client/fingerprint"
+	"github.com/hashicorp/nomad/helper/testlog"
 	"github.com/hashicorp/nomad/nomad"
 	"github.com/hashicorp/nomad/nomad/mock"
 	"github.com/hashicorp/nomad/nomad/structs"
@@ -115,7 +116,10 @@ func (a *TestAgent) Start() *TestAgent {
 		a.Config.NomadConfig.DataDir = d
 	}

-	for i := 10; i >= 0; i-- {
+	i := 10
+
+RETRY:
+	for ; i >= 0; i-- {
 		a.pickRandomPorts(a.Config)
 		if a.Config.NodeName == "" {
 			a.Config.NodeName = fmt.Sprintf("Node %d", a.Config.Ports.RPC)
@@ -137,14 +141,14 @@ func (a *TestAgent) Start() *TestAgent {
 			a.Agent = agent
 			break
 		} else if i == 0 {
-			fmt.Println(a.Name, "Error starting agent:", err)
+			a.T.Logf("%s: Error starting agent: %v", a.Name, err)
 			runtime.Goexit()
 		} else {
 			if agent != nil {
 				agent.Shutdown()
 			}
 			wait := time.Duration(rand.Int31n(2000)) * time.Millisecond
-			fmt.Println(a.Name, "retrying in", wait)
+			a.T.Logf("%s: retrying in %v", a.Name, wait)
 			time.Sleep(wait)
 		}

@@ -153,12 +157,13 @@ func (a *TestAgent) Start() *TestAgent {
 		// the data dir, such as in the Raft configuration.
 		if a.DataDir != "" {
 			if err := os.RemoveAll(a.DataDir); err != nil {
-				fmt.Println(a.Name, "Error resetting data dir:", err)
+				a.T.Logf("%s: Error resetting data dir: %v", a.Name, err)
 				runtime.Goexit()
 			}
 		}
 	}

+	failed := false
 	if a.Config.NomadConfig.Bootstrap && a.Config.Server.Enabled {
 		testutil.WaitForResult(func() (bool, error) {
 			args := &structs.GenericRequest{}
@@ -166,7 +171,8 @@ func (a *TestAgent) Start() *TestAgent {
 			err := a.RPC("Status.Leader", args, &leader)
 			return leader != "", err
 		}, func(err error) {
-			a.T.Fatalf("failed to find leader: %v", err)
+			a.T.Logf("failed to find leader: %v", err)
+			failed = true
 		})
 	} else {
 		testutil.WaitForResult(func() (bool, error) {
@@ -175,9 +181,14 @@ func (a *TestAgent) Start() *TestAgent {
 			_, err := a.Server.AgentSelfRequest(resp, req)
 			return err == nil && resp.Code == 200, err
 		}, func(err error) {
-			a.T.Fatalf("failed OK response: %v", err)
+			a.T.Logf("failed OK response: %v", err)
+			failed = true
 		})
 	}
+	if failed {
+		a.Agent.Shutdown()
+		goto RETRY
+	}

 	// Check if ACLs enabled. Use special value of PolicyTTL 0s
 	// to do a bypass of this step. This is so we can test bootstrap
@@ -194,7 +205,7 @@ func (a *TestAgent) Start() *TestAgent {

 func (a *TestAgent) start() (*Agent, error) {
 	if a.LogOutput == nil {
-		a.LogOutput = os.Stderr
+		a.LogOutput = testlog.NewWriter(a.T)
 	}

 	inm := metrics.NewInmemSink(10*time.Second, time.Minute)
@@ -264,6 +275,15 @@ func (a *TestAgent) pickRandomPorts(c *Config) {
 	c.Ports.RPC = ports[1]
 	c.Ports.Serf = ports[2]

+	// Clear out the advertise addresses such that through retries we
+	// re-normalize the addresses correctly instead of using the values from
+	// the last port selection that had a port conflict.
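//
// Without this reset, the normalizeAddrs call below would keep advertise
// addresses derived from a previous, conflicting port selection. Blanking
// them forces re-derivation from the freshly picked ports, e.g. (values
// illustrative):
//
//	c.Ports.HTTP = 20001       // newly picked random port
//	c.AdvertiseAddrs.HTTP = "" // drop stale "127.0.0.1:9999"
//	c.normalizeAddrs()         // re-fills from c.Ports
//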
+ if c.AdvertiseAddrs != nil { + c.AdvertiseAddrs.HTTP = "" + c.AdvertiseAddrs.RPC = "" + c.AdvertiseAddrs.Serf = "" + } + if err := c.normalizeAddrs(); err != nil { a.T.Fatalf("error normalizing config: %v", err) } diff --git a/command/client_config_test.go b/command/client_config_test.go index e00bb6e3051..cb9275ca0fb 100644 --- a/command/client_config_test.go +++ b/command/client_config_test.go @@ -33,23 +33,16 @@ func TestClientConfigCommand_UpdateServers(t *testing.T) { } ui.ErrorWriter.Reset() - // Set the servers list + // Set the servers list with bad addresses code = cmd.Run([]string{"-address=" + url, "-update-servers", "127.0.0.42", "198.18.5.5"}) - if code != 0 { - t.Fatalf("expected exit 0, got: %d", code) + if code != 1 { + t.Fatalf("expected exit 1, got: %d", code) } - // Query the servers list - code = cmd.Run([]string{"-address=" + url, "-servers"}) + // Set the servers list with good addresses + code = cmd.Run([]string{"-address=" + url, "-update-servers", srv.Config.AdvertiseAddrs.RPC}) if code != 0 { - t.Fatalf("expect exit 0, got: %d", code) - } - out := ui.OutputWriter.String() - if !strings.Contains(out, "127.0.0.42") { - t.Fatalf("missing 127.0.0.42") - } - if !strings.Contains(out, "198.18.5.5") { - t.Fatalf("missing 198.18.5.5") + t.Fatalf("expected exit 0, got: %d", code) } } diff --git a/helper/codec/inmem.go b/helper/codec/inmem.go new file mode 100644 index 00000000000..cb69e89e1b1 --- /dev/null +++ b/helper/codec/inmem.go @@ -0,0 +1,42 @@ +package codec + +import ( + "errors" + "net/rpc" + "reflect" +) + +// InmemCodec is used to do an RPC call without going over a network +type InmemCodec struct { + Method string + Args interface{} + Reply interface{} + Err error +} + +func (i *InmemCodec) ReadRequestHeader(req *rpc.Request) error { + req.ServiceMethod = i.Method + return nil +} + +func (i *InmemCodec) ReadRequestBody(args interface{}) error { + sourceValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(i.Args))) + dst := reflect.Indirect(reflect.Indirect(reflect.ValueOf(args))) + dst.Set(sourceValue) + return nil +} + +func (i *InmemCodec) WriteResponse(resp *rpc.Response, reply interface{}) error { + if resp.Error != "" { + i.Err = errors.New(resp.Error) + return nil + } + sourceValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(reply))) + dst := reflect.Indirect(reflect.Indirect(reflect.ValueOf(i.Reply))) + dst.Set(sourceValue) + return nil +} + +func (i *InmemCodec) Close() error { + return nil +} diff --git a/helper/pool/conn.go b/helper/pool/conn.go new file mode 100644 index 00000000000..d5dcc5703f9 --- /dev/null +++ b/helper/pool/conn.go @@ -0,0 +1,15 @@ +package pool + +type RPCType byte + +const ( + RpcNomad RPCType = 0x01 + RpcRaft = 0x02 + RpcMultiplex = 0x03 + RpcTLS = 0x04 + RpcStreaming = 0x05 + + // RpcMultiplexV2 allows a multiplexed connection to switch modes between + // RpcNomad and RpcStreaming per opened stream. + RpcMultiplexV2 = 0x06 +) diff --git a/nomad/pool.go b/helper/pool/pool.go similarity index 86% rename from nomad/pool.go rename to helper/pool/pool.go index 017621c99e8..b1a57c2df23 100644 --- a/nomad/pool.go +++ b/helper/pool/pool.go @@ -1,4 +1,4 @@ -package nomad +package pool import ( "container/list" @@ -12,9 +12,20 @@ import ( msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/helper/tlsutil" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/yamux" ) +// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls. 
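//
// Both codecs use the shared structs.HashiMsgpackHandle, so a client built
// on NewClientCodec can interoperate with a server reading the same stream
// via NewServerCodec. A minimal caller-side sketch, assuming conn is already
// a raw connection switched into the Nomad RPC mode and args/reply match the
// target endpoint:
//
//	client := rpc.NewClientWithCodec(NewClientCodec(conn))
//	defer client.Close()
//	err := client.Call("ClientStats.Stats", &args, &reply)
//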
+func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { + return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle) +} + +// NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs. +func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec { + return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle) +} + // streamClient is used to wrap a stream with an RPC client type StreamClient struct { stream net.Conn @@ -81,7 +92,7 @@ func (c *Conn) getClient() (*StreamClient, error) { return sc, nil } -// returnStream is used when done with a stream +// returnClient is used when done with a stream // to allow re-use by a future RPC func (c *Conn) returnClient(client *StreamClient) { didSave := false @@ -134,6 +145,10 @@ type ConnPool struct { // Used to indicate the pool is shutdown shutdown bool shutdownCh chan struct{} + + // connListener is used to notify a potential listener of a new connection + // being made. + connListener chan<- *yamux.Session } // NewPool is used to make a new connection pool @@ -170,6 +185,12 @@ func (p *ConnPool) Shutdown() error { if p.shutdown { return nil } + + if p.connListener != nil { + close(p.connListener) + p.connListener = nil + } + p.shutdown = true close(p.shutdownCh) return nil @@ -188,6 +209,21 @@ func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) { p.tlsWrap = tlsWrap } +// SetConnListener is used to listen to new connections being made. The +// channel will be closed when the conn pool is closed or a new listener is set. +func (p *ConnPool) SetConnListener(l chan<- *yamux.Session) { + p.Lock() + defer p.Unlock() + + // Close the old listener + if p.connListener != nil { + close(p.connListener) + } + + // Store the new listener + p.connListener = l +} + // Acquire is used to get a connection that is // pooled or to return a new connection func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) { @@ -227,6 +263,15 @@ func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, er } p.pool[addr.String()] = c + + // If there is a connection listener, notify them of the new connection. 
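//
// The notification below is deliberately non-blocking: if the listener is
// not ready, the session simply goes unreported rather than stalling
// acquire. A typical consumer (handleSession is a placeholder):
//
//	sessions := make(chan *yamux.Session, 4)
//	pool.SetConnListener(sessions)
//	go func() {
//		for s := range sessions { // closed on Shutdown or a new listener
//			handleSession(s)
//		}
//	}()
//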
+ if p.connListener != nil { + select { + case p.connListener <- c.session: + default: + } + } + p.Unlock() return c, nil } @@ -268,7 +313,7 @@ func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, // Check if TLS is enabled if p.tlsWrap != nil { // Switch the connection into TLS mode - if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil { + if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil { conn.Close() return nil, err } @@ -283,7 +328,7 @@ func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, } // Write the multiplex byte to set the mode - if _, err := conn.Write([]byte{byte(rpcMultiplex)}); err != nil { + if _, err := conn.Write([]byte{byte(RpcMultiplex)}); err != nil { conn.Close() return nil, err } @@ -312,7 +357,8 @@ func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, return c, nil } -// clearConn is used to clear any cached connection, potentially in response to an erro +// clearConn is used to clear any cached connection, potentially in response to +// an error func (p *ConnPool) clearConn(conn *Conn) { // Ensure returned streams are closed atomic.StoreInt32(&conn.shouldClose, 1) diff --git a/helper/pool/pool_test.go b/helper/pool/pool_test.go new file mode 100644 index 00000000000..becf7d46814 --- /dev/null +++ b/helper/pool/pool_test.go @@ -0,0 +1,66 @@ +package pool + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/hashicorp/consul/lib/freeport" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/yamux" + "github.com/stretchr/testify/require" +) + +func newTestPool(t *testing.T) *ConnPool { + w := testlog.NewWriter(t) + p := NewPool(w, 1*time.Minute, 10, nil) + return p +} + +func TestConnPool_ConnListener(t *testing.T) { + require := require.New(t) + + ports := freeport.GetT(t, 1) + addrStr := fmt.Sprintf("127.0.0.1:%d", ports[0]) + addr, err := net.ResolveTCPAddr("tcp", addrStr) + require.Nil(err) + + exitCh := make(chan struct{}) + defer close(exitCh) + go func() { + ln, err := net.Listen("tcp", addrStr) + require.Nil(err) + defer ln.Close() + conn, _ := ln.Accept() + defer conn.Close() + + <-exitCh + }() + + time.Sleep(100 * time.Millisecond) + + // Create a test pool + pool := newTestPool(t) + + // Setup a listener + c := make(chan *yamux.Session, 1) + pool.SetConnListener(c) + + // Make an RPC + _, err = pool.acquire("test", addr, structs.ApiMajorVersion) + require.Nil(err) + + // Assert we get a connection. + select { + case <-c: + case <-time.After(100 * time.Millisecond): + t.Fatalf("timeout") + } + + // Test that the channel is closed when the pool shuts down. 
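//
// Receiving from a closed channel yields the zero value with ok == false,
// which is exactly what the assertion below relies on:
//
//	_, ok := <-c // ok is false once Shutdown has closed the listener
//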
+ require.Nil(pool.Shutdown()) + _, ok := <-c + require.False(ok) +} diff --git a/helper/stats/runtime.go b/helper/stats/runtime.go new file mode 100644 index 00000000000..6aa9ee66ee8 --- /dev/null +++ b/helper/stats/runtime.go @@ -0,0 +1,18 @@ +package stats + +import ( + "runtime" + "strconv" +) + +// RuntimeStats is used to return various runtime information +func RuntimeStats() map[string]string { + return map[string]string{ + "kernel.name": runtime.GOOS, + "arch": runtime.GOARCH, + "version": runtime.Version(), + "max_procs": strconv.FormatInt(int64(runtime.GOMAXPROCS(0)), 10), + "goroutines": strconv.FormatInt(int64(runtime.NumGoroutine()), 10), + "cpu_count": strconv.FormatInt(int64(runtime.NumCPU()), 10), + } +} diff --git a/helper/testlog/testlog.go b/helper/testlog/testlog.go index 7f6c6cb042c..b72fcfb28be 100644 --- a/helper/testlog/testlog.go +++ b/helper/testlog/testlog.go @@ -42,5 +42,5 @@ func WithPrefix(t LogPrinter, prefix string) *log.Logger { // NewLog logger with "TEST" prefix and the Lmicroseconds flag. func Logger(t LogPrinter) *log.Logger { - return WithPrefix(t, "TEST ") + return WithPrefix(t, "") } diff --git a/nomad/acl_endpoint_test.go b/nomad/acl_endpoint_test.go index 944f7c5a8f9..506e78e199e 100644 --- a/nomad/acl_endpoint_test.go +++ b/nomad/acl_endpoint_test.go @@ -19,7 +19,7 @@ import ( func TestACLEndpoint_GetPolicy(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -74,7 +74,7 @@ func TestACLEndpoint_GetPolicy(t *testing.T) { func TestACLEndpoint_GetPolicy_Blocking(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -153,7 +153,7 @@ func TestACLEndpoint_GetPolicy_Blocking(t *testing.T) { func TestACLEndpoint_GetPolicies(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -192,7 +192,7 @@ func TestACLEndpoint_GetPolicies(t *testing.T) { func TestACLEndpoint_GetPolicies_TokenSubset(t *testing.T) { t.Parallel() - s1, _ := testACLServer(t, nil) + s1, _ := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -232,7 +232,7 @@ func TestACLEndpoint_GetPolicies_TokenSubset(t *testing.T) { func TestACLEndpoint_GetPolicies_Blocking(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -312,7 +312,7 @@ func TestACLEndpoint_GetPolicies_Blocking(t *testing.T) { func TestACLEndpoint_ListPolicies(t *testing.T) { assert := assert.New(t) t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -378,7 +378,7 @@ func TestACLEndpoint_ListPolicies(t *testing.T) { func TestACLEndpoint_ListPolicies_Blocking(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -438,7 +438,7 @@ func TestACLEndpoint_ListPolicies_Blocking(t *testing.T) { func TestACLEndpoint_DeletePolicies(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := 
rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -464,7 +464,7 @@ func TestACLEndpoint_DeletePolicies(t *testing.T) { func TestACLEndpoint_UpsertPolicies(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -494,7 +494,7 @@ func TestACLEndpoint_UpsertPolicies(t *testing.T) { func TestACLEndpoint_UpsertPolicies_Invalid(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -521,7 +521,7 @@ func TestACLEndpoint_UpsertPolicies_Invalid(t *testing.T) { func TestACLEndpoint_GetToken(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -566,7 +566,7 @@ func TestACLEndpoint_GetToken(t *testing.T) { func TestACLEndpoint_GetToken_Blocking(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -645,7 +645,7 @@ func TestACLEndpoint_GetToken_Blocking(t *testing.T) { func TestACLEndpoint_GetTokens(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -683,7 +683,7 @@ func TestACLEndpoint_GetTokens(t *testing.T) { func TestACLEndpoint_GetTokens_Blocking(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -762,7 +762,7 @@ func TestACLEndpoint_GetTokens_Blocking(t *testing.T) { func TestACLEndpoint_ListTokens(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -823,7 +823,7 @@ func TestACLEndpoint_ListTokens(t *testing.T) { func TestACLEndpoint_ListTokens_Blocking(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -883,7 +883,7 @@ func TestACLEndpoint_ListTokens_Blocking(t *testing.T) { func TestACLEndpoint_DeleteTokens(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -911,7 +911,7 @@ func TestACLEndpoint_DeleteTokens_WithNonexistantToken(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -936,7 +936,7 @@ func TestACLEndpoint_DeleteTokens_WithNonexistantToken(t *testing.T) { func TestACLEndpoint_Bootstrap(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.ACLEnabled = true }) defer s1.Shutdown() @@ -973,7 +973,7 @@ func TestACLEndpoint_Bootstrap_Reset(t *testing.T) { t.Parallel() dir := tmpDir(t) defer os.RemoveAll(dir) - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.ACLEnabled = true c.DataDir = dir c.DevMode = false @@ -1035,7 +1035,7 @@ func TestACLEndpoint_Bootstrap_Reset(t *testing.T) { func 
TestACLEndpoint_UpsertTokens(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1091,7 +1091,7 @@ func TestACLEndpoint_UpsertTokens(t *testing.T) { func TestACLEndpoint_UpsertTokens_Invalid(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1118,7 +1118,7 @@ func TestACLEndpoint_UpsertTokens_Invalid(t *testing.T) { func TestACLEndpoint_ResolveToken(t *testing.T) { t.Parallel() - s1, _ := testACLServer(t, nil) + s1, _ := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) diff --git a/nomad/acl_test.go b/nomad/acl_test.go index 8fce091bb85..df5d5b7241a 100644 --- a/nomad/acl_test.go +++ b/nomad/acl_test.go @@ -95,7 +95,7 @@ func TestResolveACLToken(t *testing.T) { func TestResolveACLToken_LeaderToken(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, _ := testACLServer(t, nil) + s1, _ := TestACLServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) diff --git a/nomad/alloc_endpoint_test.go b/nomad/alloc_endpoint_test.go index ec5c372ece0..2a72de7d23d 100644 --- a/nomad/alloc_endpoint_test.go +++ b/nomad/alloc_endpoint_test.go @@ -16,7 +16,7 @@ import ( func TestAllocEndpoint_List(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -82,7 +82,7 @@ func TestAllocEndpoint_List(t *testing.T) { func TestAllocEndpoint_List_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -138,7 +138,7 @@ func TestAllocEndpoint_List_ACL(t *testing.T) { func TestAllocEndpoint_List_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -213,7 +213,7 @@ func TestAllocEndpoint_List_Blocking(t *testing.T) { func TestAllocEndpoint_GetAlloc(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -253,7 +253,7 @@ func TestAllocEndpoint_GetAlloc(t *testing.T) { func TestAllocEndpoint_GetAlloc_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -327,7 +327,7 @@ func TestAllocEndpoint_GetAlloc_ACL(t *testing.T) { func TestAllocEndpoint_GetAlloc_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -382,7 +382,7 @@ func TestAllocEndpoint_GetAlloc_Blocking(t *testing.T) { func TestAllocEndpoint_GetAllocs(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -429,7 +429,7 @@ func TestAllocEndpoint_GetAllocs(t *testing.T) { func TestAllocEndpoint_GetAllocs_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) diff --git a/nomad/autopilot_test.go b/nomad/autopilot_test.go 
index 13cee504471..a4e2f5d9032 100644 --- a/nomad/autopilot_test.go +++ b/nomad/autopilot_test.go @@ -77,26 +77,26 @@ func testCleanupDeadServer(t *testing.T, raftVersion int) { c.BootstrapExpect = 3 c.RaftConfig.ProtocolVersion = raft.ProtocolVersion(raftVersion) } - s1 := testServer(t, conf) + s1 := TestServer(t, conf) defer s1.Shutdown() - s2 := testServer(t, conf) + s2 := TestServer(t, conf) defer s2.Shutdown() - s3 := testServer(t, conf) + s3 := TestServer(t, conf) defer s3.Shutdown() servers := []*Server{s1, s2, s3} // Try to join - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) for _, s := range servers { retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s, 3)) }) } // Bring up a new server - s4 := testServer(t, conf) + s4 := TestServer(t, conf) defer s4.Shutdown() // Kill a non-leader server @@ -114,7 +114,7 @@ func testCleanupDeadServer(t *testing.T, raftVersion int) { }) // Join the new server - testJoin(t, s1, s4) + TestJoin(t, s1, s4) servers[2] = s4 // Make sure the dead server is removed and we're back to 3 total peers @@ -125,30 +125,30 @@ func testCleanupDeadServer(t *testing.T, raftVersion int) { func TestAutopilot_CleanupDeadServerPeriodic(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() conf := func(c *Config) { c.DevDisableBootstrap = true } - s2 := testServer(t, conf) + s2 := TestServer(t, conf) defer s2.Shutdown() - s3 := testServer(t, conf) + s3 := TestServer(t, conf) defer s3.Shutdown() - s4 := testServer(t, conf) + s4 := TestServer(t, conf) defer s4.Shutdown() - s5 := testServer(t, conf) + s5 := TestServer(t, conf) defer s5.Shutdown() servers := []*Server{s1, s2, s3, s4, s5} // Join the servers to s1, and wait until they are all promoted to // voters. - testJoin(t, s1, servers[1:]...) + TestJoin(t, s1, servers[1:]...) retry.Run(t, func(r *retry.R) { r.Check(wantRaft(servers)) for _, s := range servers { @@ -171,7 +171,7 @@ func TestAutopilot_CleanupDeadServerPeriodic(t *testing.T) { func TestAutopilot_RollingUpdate(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.RaftConfig.ProtocolVersion = 3 }) defer s1.Shutdown() @@ -181,16 +181,16 @@ func TestAutopilot_RollingUpdate(t *testing.T) { c.RaftConfig.ProtocolVersion = 3 } - s2 := testServer(t, conf) + s2 := TestServer(t, conf) defer s2.Shutdown() - s3 := testServer(t, conf) + s3 := TestServer(t, conf) defer s3.Shutdown() // Join the servers to s1, and wait until they are all promoted to // voters. servers := []*Server{s1, s2, s3} - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) retry.Run(t, func(r *retry.R) { r.Check(wantRaft(servers)) for _, s := range servers { @@ -199,9 +199,9 @@ func TestAutopilot_RollingUpdate(t *testing.T) { }) // Add one more server like we are doing a rolling update. 
- s4 := testServer(t, conf) + s4 := TestServer(t, conf) defer s4.Shutdown() - testJoin(t, s1, s4) + TestJoin(t, s1, s4) servers = append(servers, s4) retry.Run(t, func(r *retry.R) { r.Check(wantRaft(servers)) @@ -243,25 +243,25 @@ func TestAutopilot_CleanupStaleRaftServer(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() conf := func(c *Config) { c.DevDisableBootstrap = true } - s2 := testServer(t, conf) + s2 := TestServer(t, conf) defer s2.Shutdown() - s3 := testServer(t, conf) + s3 := TestServer(t, conf) defer s3.Shutdown() - s4 := testServer(t, conf) + s4 := TestServer(t, conf) defer s4.Shutdown() servers := []*Server{s1, s2, s3} // Join the servers to s1 - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) for _, s := range servers { retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s, 3)) }) @@ -293,7 +293,7 @@ func TestAutopilot_CleanupStaleRaftServer(t *testing.T) { func TestAutopilot_PromoteNonVoter(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.RaftConfig.ProtocolVersion = 3 }) defer s1.Shutdown() @@ -301,12 +301,12 @@ func TestAutopilot_PromoteNonVoter(t *testing.T) { defer codec.Close() testutil.WaitForLeader(t, s1.RPC) - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true c.RaftConfig.ProtocolVersion = 3 }) defer s2.Shutdown() - testJoin(t, s1, s2) + TestJoin(t, s1, s2) // Make sure we see it as a nonvoter initially. We wait until half // the stabilization period has passed. diff --git a/nomad/client_alloc_endpoint.go b/nomad/client_alloc_endpoint.go new file mode 100644 index 00000000000..bb28d2d39f0 --- /dev/null +++ b/nomad/client_alloc_endpoint.go @@ -0,0 +1,158 @@ +package nomad + +import ( + "errors" + "time" + + metrics "github.com/armon/go-metrics" + "github.com/hashicorp/nomad/acl" + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/nomad/structs" +) + +// ClientAllocations is used to forward RPC requests to the targeted Nomad client's +// Allocation endpoint. +type ClientAllocations struct { + srv *Server +}
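Every RPC on this endpoint follows the same routing shape, so it is worth spelling out once for review. A condensed, hypothetical helper, not part of the diff, using only the names this change introduces (getNodeConn, NodeRpc, findNodeConnAndForward):

```go
// routeToNode sketches the routing shared by the methods below: use a direct
// multiplexed session to the client when this server holds one, otherwise
// locate a peer server that does and forward the RPC to it.
func (a *ClientAllocations) routeToNode(nodeID, method string, args, reply interface{}) error {
	// Direct path: this server already has a session to the node.
	if state, ok := a.srv.getNodeConn(nodeID); ok {
		return NodeRpc(state.Session, method, args, reply)
	}

	// Indirect path: consult the state store and forward to a server that
	// holds a connection to the node.
	snap, err := a.srv.State().Snapshot()
	if err != nil {
		return err
	}
	return findNodeConnAndForward(a.srv, snap, nodeID, method, args, reply)
}
```

This shape is also why each method below forces AllowStale: the only state consulted is the node or allocation registration, and insisting on a leader read would add another hop to the forwarding chain for little benefit.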
+ if args.NodeID == "" { + return errors.New("missing NodeID") + } + + // Get the connection to the client + state, ok := a.srv.getNodeConn(args.NodeID) + if !ok { + // Check if the node even exists + snap, err := a.srv.State().Snapshot() + if err != nil { + return err + } + + return findNodeConnAndForward(a.srv, snap, args.NodeID, "ClientAllocations.GarbageCollectAll", args, reply) + } + + // Make the RPC + return NodeRpc(state.Session, "Allocations.GarbageCollectAll", args, reply) +} + +// GarbageCollect is used to garbage collect an allocation on a client. +func (a *ClientAllocations) GarbageCollect(args *structs.AllocSpecificRequest, reply *structs.GenericResponse) error { + // We only allow stale reads since the only potentially stale information is + // the Node registration and the cost is fairly high for adding another hope + // in the forwarding chain. + args.QueryOptions.AllowStale = true + + // Potentially forward to a different region. + if done, err := a.srv.forward("ClientAllocations.GarbageCollect", args, args, reply); done { + return err + } + defer metrics.MeasureSince([]string{"nomad", "client_allocations", "garbage_collect"}, time.Now()) + + // Check node read permissions + if aclObj, err := a.srv.ResolveToken(args.AuthToken); err != nil { + return err + } else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilitySubmitJob) { + return structs.ErrPermissionDenied + } + + // Verify the arguments. + if args.AllocID == "" { + return errors.New("missing AllocID") + } + + // Find the allocation + snap, err := a.srv.State().Snapshot() + if err != nil { + return err + } + + alloc, err := snap.AllocByID(nil, args.AllocID) + if err != nil { + return err + } + + if alloc == nil { + return structs.NewErrUnknownAllocation(args.AllocID) + } + + // Get the connection to the client + state, ok := a.srv.getNodeConn(alloc.NodeID) + if !ok { + return findNodeConnAndForward(a.srv, snap, alloc.NodeID, "ClientAllocations.GarbageCollect", args, reply) + } + + // Make the RPC + return NodeRpc(state.Session, "Allocations.GarbageCollect", args, reply) +} + +// Stats is used to collect allocation statistics +func (a *ClientAllocations) Stats(args *cstructs.AllocStatsRequest, reply *cstructs.AllocStatsResponse) error { + // We only allow stale reads since the only potentially stale information is + // the Node registration and the cost is fairly high for adding another hope + // in the forwarding chain. + args.QueryOptions.AllowStale = true + + // Potentially forward to a different region. + if done, err := a.srv.forward("ClientAllocations.Stats", args, args, reply); done { + return err + } + defer metrics.MeasureSince([]string{"nomad", "client_allocations", "stats"}, time.Now()) + + // Check node read permissions + if aclObj, err := a.srv.ResolveToken(args.AuthToken); err != nil { + return err + } else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilityReadJob) { + return structs.ErrPermissionDenied + } + + // Verify the arguments. 
+ if args.AllocID == "" { + return errors.New("missing AllocID") + } + + // Find the allocation + snap, err := a.srv.State().Snapshot() + if err != nil { + return err + } + + alloc, err := snap.AllocByID(nil, args.AllocID) + if err != nil { + return err + } + + if alloc == nil { + return structs.NewErrUnknownAllocation(args.AllocID) + } + + // Get the connection to the client + state, ok := a.srv.getNodeConn(alloc.NodeID) + if !ok { + return findNodeConnAndForward(a.srv, snap, alloc.NodeID, "ClientAllocations.Stats", args, reply) + } + + // Make the RPC + return NodeRpc(state.Session, "Allocations.Stats", args, reply) +} diff --git a/nomad/client_alloc_endpoint_test.go b/nomad/client_alloc_endpoint_test.go new file mode 100644 index 00000000000..c905b1b7924 --- /dev/null +++ b/nomad/client_alloc_endpoint_test.go @@ -0,0 +1,650 @@ +package nomad + +import ( + "fmt" + "testing" + + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/client" + "github.com/hashicorp/nomad/client/config" + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" + "github.com/stretchr/testify/require" +) + +func TestClientAllocations_GarbageCollectAll_Local(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s.config.RPCAddr.String()} + }) + defer c.Shutdown() + + testutil.WaitForResult(func() (bool, error) { + nodes := s.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Make the request without having a node-id + req := &structs.NodeSpecificRequest{ + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp structs.GenericResponse + err := msgpackrpc.CallWithCodec(codec, "ClientAllocations.GarbageCollectAll", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), "missing") + + // Fetch the response setting the node id + req.NodeID = c.NodeID() + var resp2 structs.GenericResponse + err = msgpackrpc.CallWithCodec(codec, "ClientAllocations.GarbageCollectAll", req, &resp2) + require.Nil(err) +} + +func TestClientAllocations_GarbageCollectAll_Local_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server + s, root := TestACLServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + // Create a bad token + policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityReadFS}) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.NodePolicy(acl.PolicyWrite) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood) + + cases := []struct { + Name string + Token string + ExpectedError string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedError: structs.ErrPermissionDenied.Error(), + }, + { + Name: "good token", + Token: tokenGood.SecretID, + ExpectedError: "Unknown node", + }, + { + Name: "root token", + Token: root.SecretID, + ExpectedError: "Unknown node", + }, + } + + for _, c := range cases 
+ t.Run(c.Name, func(t *testing.T) { + + // Make the request with a nonexistent node ID + req := &structs.NodeSpecificRequest{ + NodeID: uuid.Generate(), + QueryOptions: structs.QueryOptions{ + AuthToken: c.Token, + Region: "global", + }, + } + + // Fetch the response + var resp structs.GenericResponse + err := msgpackrpc.CallWithCodec(codec, "ClientAllocations.GarbageCollectAll", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), c.ExpectedError) + }) + } +} + +func TestClientAllocations_GarbageCollectAll_NoNode(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + // Make the request with a nonexistent node ID + req := &structs.NodeSpecificRequest{ + NodeID: uuid.Generate(), + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp structs.GenericResponse + err := msgpackrpc.CallWithCodec(codec, "ClientAllocations.GarbageCollectAll", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), "Unknown node") +} + +func TestClientAllocations_GarbageCollectAll_Remote(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + codec := rpcClient(t, s2) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s2.config.RPCAddr.String()} + c.GCDiskUsageThreshold = 100.0 + }) + defer c.Shutdown() + + testutil.WaitForResult(func() (bool, error) { + nodes := s2.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a client") + }) + + // Force remove the connection locally in case it exists + s1.nodeConnsLock.Lock() + delete(s1.nodeConns, c.NodeID()) + s1.nodeConnsLock.Unlock() + + // Make the request + req := &structs.NodeSpecificRequest{ + NodeID: c.NodeID(), + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp cstructs.ClientStatsResponse + err := msgpackrpc.CallWithCodec(codec, "ClientAllocations.GarbageCollectAll", req, &resp) + require.Nil(err) +} + +func TestClientAllocations_GarbageCollect_Local(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s.config.RPCAddr.String()} + c.GCDiskUsageThreshold = 100.0 + }) + defer c.Shutdown() + + // Force an allocation onto the node + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + testutil.WaitForResult(func() (bool, error) { + nodes := s.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a client") + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation
+ testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Make the request without having an alloc id + req := &structs.AllocSpecificRequest{ + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp structs.GenericResponse + err := msgpackrpc.CallWithCodec(codec, "ClientAllocations.GarbageCollect", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), "missing") + + // Fetch the response, this time setting the alloc ID + req.AllocID = a.ID + var resp2 structs.GenericResponse + err = msgpackrpc.CallWithCodec(codec, "ClientAllocations.GarbageCollect", req, &resp2) + require.Nil(err) +} + +func TestClientAllocations_GarbageCollect_Local_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server + s, root := TestACLServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + // Create a bad token + policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityReadFS}) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.NamespacePolicy(structs.DefaultNamespace, "", []string{acl.NamespaceCapabilitySubmitJob}) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood) + + cases := []struct { + Name string + Token string + ExpectedError string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedError: structs.ErrPermissionDenied.Error(), + }, + { + Name: "good token", + Token: tokenGood.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + { + Name: "root token", + Token: root.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + + // Make the request with a nonexistent alloc ID + req := &structs.AllocSpecificRequest{ + AllocID: uuid.Generate(), + QueryOptions: structs.QueryOptions{ + AuthToken: c.Token, + Region: "global", + Namespace: structs.DefaultNamespace, + }, + } + + // Fetch the response + var resp structs.GenericResponse + err := msgpackrpc.CallWithCodec(codec, "ClientAllocations.GarbageCollect", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), c.ExpectedError) + }) + } +}
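The allocation fixture above reappears in nearly every test in this file. Condensed into one hypothetical helper (not in the diff) to make the shared intent visible: a batch allocation pinned to a node whose mock_driver task exits after two seconds, so the polling loops can wait for AllocClientStatusComplete deterministically:

```go
package nomad

import (
	"github.com/hashicorp/nomad/nomad/mock"
	"github.com/hashicorp/nomad/nomad/structs"
)

// shortBatchAlloc is a hypothetical condensation of the repeated fixture.
func shortBatchAlloc(nodeID string) *structs.Allocation {
	a := mock.Alloc()
	a.Job.Type = structs.JobTypeBatch // batch jobs are expected to complete
	a.NodeID = nodeID
	a.Job.TaskGroups[0].Count = 1
	a.Job.TaskGroups[0].Tasks[0] = &structs.Task{
		Name:   "web",
		Driver: "mock_driver",
		Config: map[string]interface{}{
			"run_for": "2s", // finish quickly so tests can poll for completion
		},
		LogConfig: structs.DefaultLogConfig(),
		Resources: &structs.Resources{
			CPU:      500,
			MemoryMB: 256,
		},
	}
	return a
}
```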
+ +func TestClientAllocations_GarbageCollect_Remote(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + codec := rpcClient(t, s2) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s2.config.RPCAddr.String()} + c.GCDiskUsageThreshold = 100.0 + }) + defer c.Shutdown() + + // Force an allocation onto the node + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + testutil.WaitForResult(func() (bool, error) { + nodes := s2.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a client") + }) + + // Upsert the allocation + state1 := s1.State() + state2 := s2.State() + require.Nil(state1.UpsertJob(999, a.Job)) + require.Nil(state1.UpsertAllocs(1003, []*structs.Allocation{a})) + require.Nil(state2.UpsertJob(999, a.Job)) + require.Nil(state2.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state2.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Force remove the connection locally in case it exists + s1.nodeConnsLock.Lock() + delete(s1.nodeConns, c.NodeID()) + s1.nodeConnsLock.Unlock() + + // Make the request + req := &structs.AllocSpecificRequest{ + AllocID: a.ID, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp cstructs.ClientStatsResponse + err := msgpackrpc.CallWithCodec(codec, "ClientAllocations.GarbageCollect", req, &resp) + require.Nil(err) +}
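The "Force remove the connection locally" step in these Remote tests deserves a note. The client heartbeats to s2, but the request enters through s2 and is presumably forwarded to the leader s1, so evicting any session s1 happens to hold pins the test to the server-to-server path through findNodeConnAndForward. As a hypothetical helper distilled from the three repeated lines:

```go
// forceRemotePath evicts any cached session the given server holds for the
// node, so an RPC it handles must route back out to a peer server rather
// than straight down a local session. Hypothetical test helper.
func forceRemotePath(s *Server, nodeID string) {
	s.nodeConnsLock.Lock()
	defer s.nodeConnsLock.Unlock()
	delete(s.nodeConns, nodeID)
}
```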
+ +func TestClientAllocations_Stats_Local(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + testutil.WaitForResult(func() (bool, error) { + nodes := s.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a client") + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Make the request without having an alloc id + req := &structs.AllocSpecificRequest{ + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp cstructs.AllocStatsResponse + err := msgpackrpc.CallWithCodec(codec, "ClientAllocations.Stats", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), "missing") + + // Fetch the response, this time setting the alloc ID + req.AllocID = a.ID + var resp2 cstructs.AllocStatsResponse + err = msgpackrpc.CallWithCodec(codec, "ClientAllocations.Stats", req, &resp2) + require.Nil(err) + require.NotNil(resp2.Stats) +} + +func TestClientAllocations_Stats_Local_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server + s, root := TestACLServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + // Create a bad token + policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityReadFS}) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.NamespacePolicy(structs.DefaultNamespace, "", []string{acl.NamespaceCapabilityReadJob}) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood) + + cases := []struct { + Name string + Token string + ExpectedError string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedError: structs.ErrPermissionDenied.Error(), + }, + { + Name: "good token", + Token: tokenGood.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + { + Name: "root token", + Token: root.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + + // Make the request with a nonexistent alloc ID + req := &structs.AllocSpecificRequest{ + AllocID: uuid.Generate(), + QueryOptions: structs.QueryOptions{ + AuthToken: c.Token, + Region: "global", + Namespace: structs.DefaultNamespace, + }, + } + + // Fetch the response + var resp cstructs.AllocStatsResponse + err := msgpackrpc.CallWithCodec(codec, "ClientAllocations.Stats", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), c.ExpectedError) + }) + } +} + +func TestClientAllocations_Stats_Remote(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + codec := rpcClient(t, s2) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s2.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + testutil.WaitForResult(func() (bool, error) { + nodes := s2.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a client") + }) + + // Upsert the allocation + state1 := s1.State() + state2 := s2.State() + require.Nil(state1.UpsertJob(999, a.Job)) + require.Nil(state1.UpsertAllocs(1003, []*structs.Allocation{a})) + require.Nil(state2.UpsertJob(999, a.Job)) + require.Nil(state2.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state2.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete {
return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Force remove the connection locally in case it exists + s1.nodeConnsLock.Lock() + delete(s1.nodeConns, c.NodeID()) + s1.nodeConnsLock.Unlock() + + // Make the request + req := &structs.AllocSpecificRequest{ + AllocID: a.ID, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp cstructs.AllocStatsResponse + err := msgpackrpc.CallWithCodec(codec, "ClientAllocations.Stats", req, &resp) + require.Nil(err) + require.NotNil(resp.Stats) +} diff --git a/nomad/client_fs_endpoint.go b/nomad/client_fs_endpoint.go new file mode 100644 index 00000000000..23cbd26b5a3 --- /dev/null +++ b/nomad/client_fs_endpoint.go @@ -0,0 +1,394 @@ +package nomad + +import ( + "errors" + "io" + "net" + "strings" + "time" + + metrics "github.com/armon/go-metrics" + "github.com/hashicorp/nomad/acl" + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/helper" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/ugorji/go/codec" +) + +// FileSystem endpoint is used for accessing the logs and filesystem of +// allocations from a Node. +type FileSystem struct { + srv *Server +} + +func (f *FileSystem) register() { + f.srv.streamingRpcs.Register("FileSystem.Logs", f.logs) + f.srv.streamingRpcs.Register("FileSystem.Stream", f.stream) +} + +// handleStreamResultError is a helper for sending an error with a potential +// error code. The transmission of the error is ignored if the error has been +// generated by the closing of the underlying transport. +func (f *FileSystem) handleStreamResultError(err error, code *int64, encoder *codec.Encoder) { + // Nothing to do as the conn is closed + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + + // Attempt to send the error + encoder.Encode(&cstructs.StreamErrWrapper{ + Error: cstructs.NewRpcError(err, code), + }) +} + +// forwardRegionStreamingRpc is used to make a streaming RPC to a different +// region. It looks up the allocation in the remote region to determine what +// remote server can route the request. +func (f *FileSystem) forwardRegionStreamingRpc(conn io.ReadWriteCloser, + encoder *codec.Encoder, args interface{}, method, allocID string, qo *structs.QueryOptions) { + // Request the allocation from the target region + allocReq := &structs.AllocSpecificRequest{ + AllocID: allocID, + QueryOptions: *qo, + } + var allocResp structs.SingleAllocResponse + if err := f.srv.forwardRegion(qo.RequestRegion(), "Alloc.GetAlloc", allocReq, &allocResp); err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + + if allocResp.Alloc == nil { + f.handleStreamResultError(structs.NewErrUnknownAllocation(allocID), helper.Int64ToPtr(404), encoder) + return + } + + // Determine the Server that has a connection to the node. + srv, err := f.srv.serverWithNodeConn(allocResp.Alloc.NodeID, qo.RequestRegion()) + if err != nil { + var code *int64 + if structs.IsErrNoNodeConn(err) { + code = helper.Int64ToPtr(404) + } + f.handleStreamResultError(err, code, encoder) + return + } + + // Get a connection to the server + srvConn, err := f.srv.streamingRpc(srv, method) + if err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + defer srvConn.Close() + + // Send the request. 
+ outEncoder := codec.NewEncoder(srvConn, structs.MsgpackHandle) + if err := outEncoder.Encode(args); err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + + structs.Bridge(conn, srvConn) +} + +// List is used to list the contents of an allocation's directory. +func (f *FileSystem) List(args *cstructs.FsListRequest, reply *cstructs.FsListResponse) error { + // We only allow stale reads since the only potentially stale information is + // the Node registration and the cost is fairly high for adding another hop + // in the forwarding chain. + args.QueryOptions.AllowStale = true + + // Potentially forward to a different region. + if done, err := f.srv.forward("FileSystem.List", args, args, reply); done { + return err + } + defer metrics.MeasureSince([]string{"nomad", "file_system", "list"}, time.Now()) + + // Check filesystem read permissions + if aclObj, err := f.srv.ResolveToken(args.AuthToken); err != nil { + return err + } else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilityReadFS) { + return structs.ErrPermissionDenied + } + + // Verify the arguments. + if args.AllocID == "" { + return errors.New("missing allocation ID") + } + + // Lookup the allocation + snap, err := f.srv.State().Snapshot() + if err != nil { + return err + } + + alloc, err := snap.AllocByID(nil, args.AllocID) + if err != nil { + return err + } + if alloc == nil { + return structs.NewErrUnknownAllocation(args.AllocID) + } + + // Get the connection to the client + state, ok := f.srv.getNodeConn(alloc.NodeID) + if !ok { + return findNodeConnAndForward(f.srv, snap, alloc.NodeID, "FileSystem.List", args, reply) + } + + // Make the RPC + return NodeRpc(state.Session, "FileSystem.List", args, reply) +} + +// Stat is used to stat a file in the allocation's directory. +func (f *FileSystem) Stat(args *cstructs.FsStatRequest, reply *cstructs.FsStatResponse) error { + // We only allow stale reads since the only potentially stale information is + // the Node registration and the cost is fairly high for adding another hop + // in the forwarding chain. + args.QueryOptions.AllowStale = true + + // Potentially forward to a different region. + if done, err := f.srv.forward("FileSystem.Stat", args, args, reply); done { + return err + } + defer metrics.MeasureSince([]string{"nomad", "file_system", "stat"}, time.Now()) + + // Check filesystem read permissions + if aclObj, err := f.srv.ResolveToken(args.AuthToken); err != nil { + return err + } else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilityReadFS) { + return structs.ErrPermissionDenied + } + + // Verify the arguments. + if args.AllocID == "" { + return errors.New("missing allocation ID") + } + + // Lookup the allocation + snap, err := f.srv.State().Snapshot() + if err != nil { + return err + } + + alloc, err := snap.AllocByID(nil, args.AllocID) + if err != nil { + return err + } + if alloc == nil { + return structs.NewErrUnknownAllocation(args.AllocID) + } + + // Get the connection to the client + state, ok := f.srv.getNodeConn(alloc.NodeID) + if !ok { + return findNodeConnAndForward(f.srv, snap, alloc.NodeID, "FileSystem.Stat", args, reply) + } + + // Make the RPC + return NodeRpc(state.Session, "FileSystem.Stat", args, reply) +}
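Both streaming handlers below finish by handing the two connections to structs.Bridge, whose implementation is not part of this diff. Conceptually it is a bidirectional copy that tears both ends down once either direction finishes; a sketch under that assumption:

```go
package structs

import (
	"io"
	"sync"
)

// bridgeSketch copies bytes both ways and closes both ends when either copy
// returns, which unblocks the copy running in the other direction.
// Illustrative only; the real structs.Bridge may differ in detail.
func bridgeSketch(a, b io.ReadWriteCloser) {
	var wg sync.WaitGroup
	wg.Add(2)
	pipe := func(dst io.WriteCloser, src io.Reader) {
		defer wg.Done()
		io.Copy(dst, src)
		a.Close()
		b.Close()
	}
	go pipe(a, b)
	go pipe(b, a)
	wg.Wait()
}
```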
+ +// stream is used to stream the contents of a file in an allocation's +// directory. +func (f *FileSystem) stream(conn io.ReadWriteCloser) { + defer conn.Close() + defer metrics.MeasureSince([]string{"nomad", "file_system", "stream"}, time.Now()) + + // Decode the arguments + var args cstructs.FsStreamRequest + decoder := codec.NewDecoder(conn, structs.MsgpackHandle) + encoder := codec.NewEncoder(conn, structs.MsgpackHandle) + + if err := decoder.Decode(&args); err != nil { + f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder) + return + } + + // Check if we need to forward to a different region + if r := args.RequestRegion(); r != f.srv.Region() { + f.forwardRegionStreamingRpc(conn, encoder, &args, "FileSystem.Stream", + args.AllocID, &args.QueryOptions) + return + } + + // Check filesystem read permissions + if aclObj, err := f.srv.ResolveToken(args.AuthToken); err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilityReadFS) { + f.handleStreamResultError(structs.ErrPermissionDenied, nil, encoder) + return + } + + // Verify the arguments. + if args.AllocID == "" { + f.handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder) + return + } + + // Retrieve the allocation + snap, err := f.srv.State().Snapshot() + if err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + + alloc, err := snap.AllocByID(nil, args.AllocID) + if err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + if alloc == nil { + f.handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder) + return + } + nodeID := alloc.NodeID + + // Get the connection to the client either by forwarding to another server + // or creating a direct stream + var clientConn net.Conn + state, ok := f.srv.getNodeConn(nodeID) + if !ok { + // Determine the Server that has a connection to the node. + srv, err := f.srv.serverWithNodeConn(nodeID, f.srv.Region()) + if err != nil { + var code *int64 + if structs.IsErrNoNodeConn(err) { + code = helper.Int64ToPtr(404) + } + f.handleStreamResultError(err, code, encoder) + return + } + + // Get a connection to the server + conn, err := f.srv.streamingRpc(srv, "FileSystem.Stream") + if err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + + clientConn = conn + } else { + stream, err := NodeStreamingRpc(state.Session, "FileSystem.Stream") + if err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + clientConn = stream + } + defer clientConn.Close()
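Errors on these streams travel in-band as cstructs.StreamErrWrapper frames (see handleStreamResultError above). A hypothetical client-side loop, assuming each frame decodes into that wrapper and a non-nil Error terminates the stream; the wrapper's payload fields are outside this diff:

```go
package nomad

import (
	"fmt"
	"io"

	cstructs "github.com/hashicorp/nomad/client/structs"
	"github.com/hashicorp/nomad/nomad/structs"
	codec "github.com/ugorji/go/codec"
)

// drainStream decodes wrapper frames until EOF or an in-band error.
// Sketch only: the frame layout beyond the Error field is assumed.
func drainStream(conn io.ReadWriteCloser) error {
	dec := codec.NewDecoder(conn, structs.MsgpackHandle)
	for {
		var frame cstructs.StreamErrWrapper
		if err := dec.Decode(&frame); err != nil {
			if err == io.EOF {
				return nil // remote side finished cleanly
			}
			return err
		}
		if frame.Error != nil {
			return fmt.Errorf("remote error: %v", frame.Error)
		}
		// ... deliver the rest of the frame to the caller ...
	}
}
```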
+ + // Send the request. + outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle) + if err := outEncoder.Encode(args); err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + + structs.Bridge(conn, clientConn) + return +} + +// logs is used to access a task's logs for a given allocation +func (f *FileSystem) logs(conn io.ReadWriteCloser) { + defer conn.Close() + defer metrics.MeasureSince([]string{"nomad", "file_system", "logs"}, time.Now()) + + // Decode the arguments + var args cstructs.FsLogsRequest + decoder := codec.NewDecoder(conn, structs.MsgpackHandle) + encoder := codec.NewEncoder(conn, structs.MsgpackHandle) + + if err := decoder.Decode(&args); err != nil { + f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder) + return + } + + // Check if we need to forward to a different region + if r := args.RequestRegion(); r != f.srv.Region() { + f.forwardRegionStreamingRpc(conn, encoder, &args, "FileSystem.Logs", + args.AllocID, &args.QueryOptions) + return + } + + // Check filesystem and log read permissions + if aclObj, err := f.srv.ResolveToken(args.AuthToken); err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } else if aclObj != nil { + readfs := aclObj.AllowNsOp(args.QueryOptions.Namespace, acl.NamespaceCapabilityReadFS) + logs := aclObj.AllowNsOp(args.QueryOptions.Namespace, acl.NamespaceCapabilityReadLogs) + if !readfs && !logs { + f.handleStreamResultError(structs.ErrPermissionDenied, nil, encoder) + return + } + } + + // Verify the arguments. + if args.AllocID == "" { + f.handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder) + return + } + + // Retrieve the allocation + snap, err := f.srv.State().Snapshot() + if err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + + alloc, err := snap.AllocByID(nil, args.AllocID) + if err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + if alloc == nil { + f.handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder) + return + } + nodeID := alloc.NodeID + + // Get the connection to the client either by forwarding to another server + // or creating a direct stream + var clientConn net.Conn + state, ok := f.srv.getNodeConn(nodeID) + if !ok { + // Determine the Server that has a connection to the node. + srv, err := f.srv.serverWithNodeConn(nodeID, f.srv.Region()) + if err != nil { + var code *int64 + if structs.IsErrNoNodeConn(err) { + code = helper.Int64ToPtr(404) + } + f.handleStreamResultError(err, code, encoder) + return + } + + // Get a connection to the server + conn, err := f.srv.streamingRpc(srv, "FileSystem.Logs") + if err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + + clientConn = conn + } else { + stream, err := NodeStreamingRpc(state.Session, "FileSystem.Logs") + if err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + clientConn = stream + } + defer clientConn.Close() + + // Send the request.
+ outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle) + if err := outEncoder.Encode(args); err != nil { + f.handleStreamResultError(err, nil, encoder) + return + } + + structs.Bridge(conn, clientConn) + return +} diff --git a/nomad/client_fs_endpoint_test.go b/nomad/client_fs_endpoint_test.go new file mode 100644 index 00000000000..c396f833f5a --- /dev/null +++ b/nomad/client_fs_endpoint_test.go @@ -0,0 +1,1986 @@ +package nomad + +import ( + "fmt" + "io" + "net" + "strings" + "testing" + "time" + + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/client" + "github.com/hashicorp/nomad/client/config" + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" + "github.com/stretchr/testify/require" + codec "github.com/ugorji/go/codec" +) + +func TestClientFS_List_Local(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a client") + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Make the request without having an alloc ID + req := &cstructs.FsListRequest{ + Path: "/", + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp cstructs.FsListResponse + err := msgpackrpc.CallWithCodec(codec, "FileSystem.List", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), "missing") + + // Fetch the response setting the alloc id + req.AllocID = a.ID + var resp2 cstructs.FsListResponse + err = msgpackrpc.CallWithCodec(codec, "FileSystem.List", req, &resp2) + require.Nil(err) + require.NotEmpty(resp2.Files) +} + +func TestClientFS_List_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server + s, root := TestACLServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + // Create a bad token
+ policyBad := mock.NamespacePolicy(structs.DefaultNamespace, "", []string{acl.NamespaceCapabilityDeny}) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.NamespacePolicy(structs.DefaultNamespace, "", []string{acl.NamespaceCapabilityReadFS}) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood) + + cases := []struct { + Name string + Token string + ExpectedError string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedError: structs.ErrPermissionDenied.Error(), + }, + { + Name: "good token", + Token: tokenGood.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + { + Name: "root token", + Token: root.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + + // Make the request + req := &cstructs.FsListRequest{ + AllocID: uuid.Generate(), + Path: "/", + QueryOptions: structs.QueryOptions{ + Region: "global", + Namespace: structs.DefaultNamespace, + AuthToken: c.Token, + }, + } + + // Fetch the response + var resp cstructs.FsListResponse + err := msgpackrpc.CallWithCodec(codec, "FileSystem.List", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), c.ExpectedError) + }) + } +} + +func TestClientFS_List_Remote(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + codec := rpcClient(t, s2) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s2.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s2.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a client") + }) + + // Upsert the allocation + state1 := s1.State() + state2 := s2.State() + require.Nil(state1.UpsertJob(999, a.Job)) + require.Nil(state1.UpsertAllocs(1003, []*structs.Allocation{a})) + require.Nil(state2.UpsertJob(999, a.Job)) + require.Nil(state2.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state2.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Force remove the connection locally in case it exists + s1.nodeConnsLock.Lock() + delete(s1.nodeConns, c.NodeID()) + s1.nodeConnsLock.Unlock() + + // Make the request + req := &cstructs.FsListRequest{ + AllocID: a.ID, + Path: "/",
QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp cstructs.FsListResponse + err := msgpackrpc.CallWithCodec(codec, "FileSystem.List", req, &resp) + require.Nil(err) + require.NotEmpty(resp.Files) +} + +func TestClientFS_Stat_Local(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Make the request without having a node-id + req := &cstructs.FsStatRequest{ + Path: "/", + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp cstructs.FsStatResponse + err := msgpackrpc.CallWithCodec(codec, "FileSystem.Stat", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), "missing") + + // Fetch the response setting the alloc id + req.AllocID = a.ID + var resp2 cstructs.FsStatResponse + err = msgpackrpc.CallWithCodec(codec, "FileSystem.Stat", req, &resp2) + require.Nil(err) + require.NotNil(resp2.Info) +} + +func TestClientFS_Stat_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server + s, root := TestACLServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + // Create a bad token + policyBad := mock.NamespacePolicy(structs.DefaultNamespace, "", []string{acl.NamespaceCapabilityDeny}) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.NamespacePolicy(structs.DefaultNamespace, "", []string{acl.NamespaceCapabilityReadFS}) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood) + + cases := []struct { + Name string + Token string + ExpectedError string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedError: structs.ErrPermissionDenied.Error(), + }, + { + Name: "good token", + Token: tokenGood.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + { + Name: "root token", + Token: root.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + } + + for _, c := range cases 
{ + t.Run(c.Name, func(t *testing.T) { + + // Make the request + req := &cstructs.FsStatRequest{ + AllocID: uuid.Generate(), + Path: "/", + QueryOptions: structs.QueryOptions{ + Region: "global", + Namespace: structs.DefaultNamespace, + AuthToken: c.Token, + }, + } + + // Fetch the response + var resp cstructs.FsStatResponse + err := msgpackrpc.CallWithCodec(codec, "FileSystem.Stat", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), c.ExpectedError) + }) + } +} + +func TestClientFS_Stat_Remote(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + codec := rpcClient(t, s2) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s2.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s2.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Upsert the allocation + state1 := s1.State() + state2 := s2.State() + require.Nil(state1.UpsertJob(999, a.Job)) + require.Nil(state1.UpsertAllocs(1003, []*structs.Allocation{a})) + require.Nil(state2.UpsertJob(999, a.Job)) + require.Nil(state2.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state2.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Force remove the connection locally in case it exists + s1.nodeConnsLock.Lock() + delete(s1.nodeConns, c.NodeID()) + s1.nodeConnsLock.Unlock() + + // Make the request without having a node-id + req := &cstructs.FsStatRequest{ + AllocID: a.ID, + Path: "/", + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp cstructs.FsStatResponse + err := msgpackrpc.CallWithCodec(codec, "FileSystem.Stat", req, &resp) + require.Nil(err) + require.NotNil(resp.Info) +} + +func TestClientFS_Streaming_NoAlloc(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + // Make the request with bad allocation id + req := &cstructs.FsStreamRequest{ + AllocID: uuid.Generate(), + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := s.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := 
make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(5 * time.Second) + +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error == nil { + continue + } + + if structs.IsErrUnknownAllocation(msg.Error) { + break OUTER + } + } + } +} + +func TestClientFS_Streaming_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server + s, root := TestACLServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + // Create a bad token + policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityReadFS}) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.NamespacePolicy(structs.DefaultNamespace, "", + []string{acl.NamespaceCapabilityReadLogs, acl.NamespaceCapabilityReadFS}) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood) + + cases := []struct { + Name string + Token string + ExpectedError string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedError: structs.ErrPermissionDenied.Error(), + }, + { + Name: "good token", + Token: tokenGood.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + { + Name: "root token", + Token: root.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + // Make the request with bad allocation id + req := &cstructs.FsStreamRequest{ + AllocID: uuid.Generate(), + QueryOptions: structs.QueryOptions{ + Namespace: structs.DefaultNamespace, + Region: "global", + AuthToken: c.Token, + }, + } + + // Get the handler + handler, err := s.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(5 * time.Second) + + OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error == nil { + continue + } + + if strings.Contains(msg.Error.Error(), c.ExpectedError) { + break OUTER + } else { + t.Fatalf("Bad error: %v", msg.Error) + } + } + } + }) + } +} + +func TestClientFS_Streaming_Local(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + c 
:= client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + expected := "Hello from the other side" + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + "stdout_string": expected, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Make the request + req := &cstructs.FsStreamRequest{ + AllocID: a.ID, + Path: "alloc/logs/web.stdout.0", + Origin: "start", + PlainText: true, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := s.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(3 * time.Second) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} + +func TestClientFS_Streaming_Local_Follow(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + expectedBase := "Hello from the other side" + repeat := 10 + + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "20s", + 
"stdout_string": expectedBase, + "stdout_repeat": repeat, + "stdout_repeat_duration": 200 * time.Millisecond, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusRunning { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not running: %v", c.NodeID(), err) + }) + + // Make the request + req := &cstructs.FsStreamRequest{ + AllocID: a.ID, + Path: "alloc/logs/web.stdout.0", + Origin: "start", + PlainText: true, + Follow: true, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := s.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(20 * time.Second) + expected := strings.Repeat(expectedBase, repeat+1) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} + +func TestClientFS_Streaming_Remote_Server(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s2.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + expected := "Hello from the other side" + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + "stdout_string": expected, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client 
to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s2.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Upsert the allocation + state1 := s1.State() + state2 := s2.State() + require.Nil(state1.UpsertJob(999, a.Job)) + require.Nil(state1.UpsertAllocs(1003, []*structs.Allocation{a})) + require.Nil(state2.UpsertJob(999, a.Job)) + require.Nil(state2.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state2.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Force remove the connection locally in case it exists + s1.nodeConnsLock.Lock() + delete(s1.nodeConns, c.NodeID()) + s1.nodeConnsLock.Unlock() + + // Make the request + req := &cstructs.FsStreamRequest{ + AllocID: a.ID, + Path: "alloc/logs/web.stdout.0", + Origin: "start", + PlainText: true, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := s1.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(3 * time.Second) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} + +func TestClientFS_Streaming_Remote_Region(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.Region = "two" + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s2.config.RPCAddr.String()} + c.Region = "two" + }) + defer c.Shutdown() + + // Force an allocation onto the node + expected := "Hello from the other side" + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + "stdout_string": expected, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + 
testutil.WaitForResult(func() (bool, error) { + nodes := s2.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a client") + }) + + // Upsert the allocation + state2 := s2.State() + require.Nil(state2.UpsertJob(999, a.Job)) + require.Nil(state2.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state2.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Force remove the connection locally in case it exists + s1.nodeConnsLock.Lock() + delete(s1.nodeConns, c.NodeID()) + s1.nodeConnsLock.Unlock() + + // Make the request + req := &cstructs.FsStreamRequest{ + AllocID: a.ID, + Path: "alloc/logs/web.stdout.0", + Origin: "start", + PlainText: true, + QueryOptions: structs.QueryOptions{Region: "two"}, + } + + // Get the handler + handler, err := s1.StreamingRpcHandler("FileSystem.Stream") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(3 * time.Second) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} + +func TestClientFS_Logs_NoAlloc(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + // Make the request with bad allocation id + req := &cstructs.FsLogsRequest{ + AllocID: uuid.Generate(), + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := s.StreamingRpcHandler("FileSystem.Logs") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(5 * time.Second) + +OUTER: 
+ for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error == nil { + continue + } + + if structs.IsErrUnknownAllocation(msg.Error) { + break OUTER + } + } + } +} + +func TestClientFS_Logs_ACL(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server + s, root := TestACLServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + // Create a bad token + policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityReadFS}) + tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad) + + policyGood := mock.NamespacePolicy(structs.DefaultNamespace, "", + []string{acl.NamespaceCapabilityReadLogs, acl.NamespaceCapabilityReadFS}) + tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood) + + cases := []struct { + Name string + Token string + ExpectedError string + }{ + { + Name: "bad token", + Token: tokenBad.SecretID, + ExpectedError: structs.ErrPermissionDenied.Error(), + }, + { + Name: "good token", + Token: tokenGood.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + { + Name: "root token", + Token: root.SecretID, + ExpectedError: structs.ErrUnknownAllocationPrefix, + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + // Make the request with bad allocation id + req := &cstructs.FsLogsRequest{ + AllocID: uuid.Generate(), + QueryOptions: structs.QueryOptions{ + Namespace: structs.DefaultNamespace, + Region: "global", + AuthToken: c.Token, + }, + } + + // Get the handler + handler, err := s.StreamingRpcHandler("FileSystem.Logs") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(5 * time.Second) + + OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error == nil { + continue + } + + if strings.Contains(msg.Error.Error(), c.ExpectedError) { + break OUTER + } else { + t.Fatalf("Bad error: %v", msg.Error) + } + } + } + }) + } +} + +func TestClientFS_Logs_Local(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + expected := "Hello from the other side" + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + "stdout_string": expected, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // 
Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Make the request + req := &cstructs.FsLogsRequest{ + AllocID: a.ID, + Task: a.Job.TaskGroups[0].Tasks[0].Name, + LogType: "stdout", + Origin: "start", + PlainText: true, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := s.StreamingRpcHandler("FileSystem.Logs") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(3 * time.Second) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} + +func TestClientFS_Logs_Local_Follow(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + expectedBase := "Hello from the other side" + repeat := 10 + + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "20s", + "stdout_string": expectedBase, + "stdout_repeat": repeat, + "stdout_repeat_duration": 200 * time.Millisecond, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Upsert the allocation + state := s.State() + require.Nil(state.UpsertJob(999, a.Job)) + require.Nil(state.UpsertAllocs(1003, 
[]*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusRunning { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not running: %v", c.NodeID(), err) + }) + + // Make the request + req := &cstructs.FsLogsRequest{ + AllocID: a.ID, + Task: a.Job.TaskGroups[0].Tasks[0].Name, + LogType: "stdout", + Origin: "start", + PlainText: true, + Follow: true, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := s.StreamingRpcHandler("FileSystem.Logs") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(20 * time.Second) + expected := strings.Repeat(expectedBase, repeat+1) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} + +func TestClientFS_Logs_Remote_Server(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s2.config.RPCAddr.String()} + }) + defer c.Shutdown() + + // Force an allocation onto the node + expected := "Hello from the other side" + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + "stdout_string": expected, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s2.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Upsert the allocation + state1 := s1.State() + state2 := s2.State() + require.Nil(state1.UpsertJob(999, a.Job)) + require.Nil(state1.UpsertAllocs(1003, []*structs.Allocation{a})) + require.Nil(state2.UpsertJob(999, a.Job)) + require.Nil(state2.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + 
testutil.WaitForResult(func() (bool, error) { + alloc, err := state2.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Force remove the connection locally in case it exists + s1.nodeConnsLock.Lock() + delete(s1.nodeConns, c.NodeID()) + s1.nodeConnsLock.Unlock() + + // Make the request + req := &cstructs.FsLogsRequest{ + AllocID: a.ID, + Task: a.Job.TaskGroups[0].Tasks[0].Name, + LogType: "stdout", + Origin: "start", + PlainText: true, + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Get the handler + handler, err := s1.StreamingRpcHandler("FileSystem.Logs") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(3 * time.Second) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} + +func TestClientFS_Logs_Remote_Region(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.Region = "two" + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s2.config.RPCAddr.String()} + c.Region = "two" + }) + defer c.Shutdown() + + // Force an allocation onto the node + expected := "Hello from the other side" + a := mock.Alloc() + a.Job.Type = structs.JobTypeBatch + a.NodeID = c.NodeID() + a.Job.TaskGroups[0].Count = 1 + a.Job.TaskGroups[0].Tasks[0] = &structs.Task{ + Name: "web", + Driver: "mock_driver", + Config: map[string]interface{}{ + "run_for": "2s", + "stdout_string": expected, + }, + LogConfig: structs.DefaultLogConfig(), + Resources: &structs.Resources{ + CPU: 500, + MemoryMB: 256, + }, + } + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + nodes := s2.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a client") + }) + + // Upsert the allocation + state2 := s2.State() + require.Nil(state2.UpsertJob(999, a.Job)) + require.Nil(state2.UpsertAllocs(1003, []*structs.Allocation{a})) + + // Wait for the client to run the allocation + testutil.WaitForResult(func() (bool, error) { + alloc, err := state2.AllocByID(nil, a.ID) + if err != nil { + return false, err + } + if 
alloc == nil { + return false, fmt.Errorf("unknown alloc") + } + if alloc.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus) + } + + return true, nil + }, func(err error) { + t.Fatalf("Alloc on node %q not finished: %v", c.NodeID(), err) + }) + + // Force remove the connection locally in case it exists + s1.nodeConnsLock.Lock() + delete(s1.nodeConns, c.NodeID()) + s1.nodeConnsLock.Unlock() + + // Make the request + req := &cstructs.FsLogsRequest{ + AllocID: a.ID, + Task: a.Job.TaskGroups[0].Tasks[0].Name, + LogType: "stdout", + Origin: "start", + PlainText: true, + QueryOptions: structs.QueryOptions{Region: "two"}, + } + + // Get the handler + handler, err := s1.StreamingRpcHandler("FileSystem.Logs") + require.Nil(err) + + // Create a pipe + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + errCh := make(chan error) + streamMsg := make(chan *cstructs.StreamErrWrapper) + + // Start the handler + go handler(p2) + + // Start the decoder + go func() { + decoder := codec.NewDecoder(p1, structs.MsgpackHandle) + for { + var msg cstructs.StreamErrWrapper + if err := decoder.Decode(&msg); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "closed") { + return + } + errCh <- fmt.Errorf("error decoding: %v", err) + } + + streamMsg <- &msg + } + }() + + // Send the request + encoder := codec.NewEncoder(p1, structs.MsgpackHandle) + require.Nil(encoder.Encode(req)) + + timeout := time.After(3 * time.Second) + received := "" +OUTER: + for { + select { + case <-timeout: + t.Fatal("timeout") + case err := <-errCh: + t.Fatal(err) + case msg := <-streamMsg: + if msg.Error != nil { + t.Fatalf("Got error: %v", msg.Error.Error()) + } + + // Add the payload + received += string(msg.Payload) + if received == expected { + break OUTER + } + } + } +} diff --git a/nomad/client_rpc.go b/nomad/client_rpc.go new file mode 100644 index 00000000000..2b8776253f0 --- /dev/null +++ b/nomad/client_rpc.go @@ -0,0 +1,263 @@ +package nomad + +import ( + "errors" + "fmt" + "net" + "time" + + multierror "github.com/hashicorp/go-multierror" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/hashicorp/nomad/helper/pool" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/yamux" + "github.com/ugorji/go/codec" +) + +// nodeConnState is used to track connection information about a Nomad Client. +type nodeConnState struct { + // Session holds the multiplexed yamux Session for dialing back. + Session *yamux.Session + + // Established is when the connection was established. + Established time.Time + + // Ctx is the full RPC context + Ctx *RPCContext +} + +// getNodeConn returns the connection to the given node and whether it exists. +func (s *Server) getNodeConn(nodeID string) (*nodeConnState, bool) { + s.nodeConnsLock.RLock() + defer s.nodeConnsLock.RUnlock() + state, ok := s.nodeConns[nodeID] + return state, ok +} + +// connectedNodes returns the set of nodes we have a connection with. +func (s *Server) connectedNodes() map[string]time.Time { + s.nodeConnsLock.RLock() + defer s.nodeConnsLock.RUnlock() + nodes := make(map[string]time.Time, len(s.nodeConns)) + for nodeID, state := range s.nodeConns { + nodes[nodeID] = state.Established + } + return nodes +} + +// addNodeConn adds the mapping between a node and its session. 
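+// The stored session is what server endpoints (for example ClientStats.Stats)
+// later use to dial back into the node; registering the same node again simply
+// replaces the previous entry.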
+func (s *Server) addNodeConn(ctx *RPCContext) {
+	// Hotpath the no-op
+	if ctx == nil || ctx.NodeID == "" {
+		return
+	}
+
+	s.nodeConnsLock.Lock()
+	defer s.nodeConnsLock.Unlock()
+	s.nodeConns[ctx.NodeID] = &nodeConnState{
+		Session:     ctx.Session,
+		Established: time.Now(),
+		Ctx:         ctx,
+	}
+}
+
+// removeNodeConn removes the mapping between a node and its session.
+func (s *Server) removeNodeConn(ctx *RPCContext) {
+	// Hotpath the no-op
+	if ctx == nil || ctx.NodeID == "" {
+		return
+	}
+
+	s.nodeConnsLock.Lock()
+	defer s.nodeConnsLock.Unlock()
+	state, ok := s.nodeConns[ctx.NodeID]
+	if !ok {
+		return
+	}
+
+	// It is important that we check that the connection being removed is the
+	// actual stored connection for the client. It is possible for the client to
+	// dial various addresses that all route to the same server. The most common
+	// case for this is the original address the client uses to connect to the
+	// server differs from the advertised address sent by the heartbeat.
+	if state.Ctx.Conn.LocalAddr().String() == ctx.Conn.LocalAddr().String() &&
+		state.Ctx.Conn.RemoteAddr().String() == ctx.Conn.RemoteAddr().String() {
+		delete(s.nodeConns, ctx.NodeID)
+	}
+}
+
+// serverWithNodeConn is used to determine which remote server has the most
+// recent connection to the given node. The local server is not queried.
+// ErrNoNodeConn is returned if all local peers could be queried but did not
+// have a connection to the node. Otherwise, if a connection could not be found
+// and there were RPC errors, an error is returned.
+func (s *Server) serverWithNodeConn(nodeID, region string) (*serverParts, error) {
+	// We skip ourselves.
+	selfAddr := s.LocalMember().Addr.String()
+
+	// Build the request
+	req := &structs.NodeSpecificRequest{
+		NodeID: nodeID,
+		QueryOptions: structs.QueryOptions{
+			Region: s.config.Region,
+		},
+	}
+
+	// Select the list of servers to check based on what region we are querying
+	s.peerLock.RLock()
+
+	var rawTargets []*serverParts
+	if region == s.Region() {
+		rawTargets = make([]*serverParts, 0, len(s.localPeers))
+		for _, srv := range s.localPeers {
+			rawTargets = append(rawTargets, srv)
+		}
+	} else {
+		peers, ok := s.peers[region]
+		if !ok {
+			s.peerLock.RUnlock()
+			return nil, structs.ErrNoRegionPath
+		}
+		rawTargets = peers
+	}
+
+	targets := make([]*serverParts, 0, len(rawTargets))
+	for _, target := range rawTargets {
+		targets = append(targets, target.Copy())
+	}
+	s.peerLock.RUnlock()
+
+	// Track the server with the most recent connection to the requested node.
+	var mostRecentServer *serverParts
+	var mostRecent time.Time
+
+	var rpcErr multierror.Error
+	for _, server := range targets {
+		if server.Addr.String() == selfAddr {
+			continue
+		}
+
+		// Make the RPC
+		var resp structs.NodeConnQueryResponse
+		err := s.connPool.RPC(s.config.Region, server.Addr, server.MajorVersion,
+			"Status.HasNodeConn", &req, &resp)
+		if err != nil {
+			multierror.Append(&rpcErr, fmt.Errorf("failed querying server %q: %v", server.Addr.String(), err))
+			continue
+		}
+
+		if resp.Connected && resp.Established.After(mostRecent) {
+			mostRecentServer = server
+			mostRecent = resp.Established
+		}
+	}
+
+	// Return an error if there is no route to the node.
+	if mostRecentServer == nil {
+		if err := rpcErr.ErrorOrNil(); err != nil {
+			return nil, err
+		}
+
+		return nil, structs.ErrNoNodeConn
+	}
+
+	return mostRecentServer, nil
+}
+
+// NodeRpc is used to make an RPC call to a node. The method takes the
+// Yamux session for the node and the method to be called.
+func NodeRpc(session *yamux.Session, method string, args, reply interface{}) error {
+	// Open a new session
+	stream, err := session.Open()
+	if err != nil {
+		return err
+	}
+	defer stream.Close()
+
+	// Write the RpcNomad byte to set the mode
+	if _, err := stream.Write([]byte{byte(pool.RpcNomad)}); err != nil {
+		return err
+	}
+
+	// Make the RPC
+	return msgpackrpc.CallWithCodec(pool.NewClientCodec(stream), method, args, reply)
+}
+
+// NodeStreamingRpc is used to make a streaming RPC call to a node. The method
+// takes the Yamux session for the node and the method to be called. It conducts
+// the initial handshake and returns a connection to be used or an error. It is
+// the caller's responsibility to close the connection if there is no error.
+func NodeStreamingRpc(session *yamux.Session, method string) (net.Conn, error) {
+	// Open a new session
+	stream, err := session.Open()
+	if err != nil {
+		return nil, err
+	}
+
+	// Write the RpcStreaming byte to set the mode
+	if _, err := stream.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
+		stream.Close()
+		return nil, err
+	}
+
+	// Send the header
+	encoder := codec.NewEncoder(stream, structs.MsgpackHandle)
+	decoder := codec.NewDecoder(stream, structs.MsgpackHandle)
+	header := structs.StreamingRpcHeader{
+		Method: method,
+	}
+	if err := encoder.Encode(header); err != nil {
+		stream.Close()
+		return nil, err
+	}
+
+	// Wait for the acknowledgement
+	var ack structs.StreamingRpcAck
+	if err := decoder.Decode(&ack); err != nil {
+		stream.Close()
+		return nil, err
+	}
+
+	if ack.Error != "" {
+		stream.Close()
+		return nil, errors.New(ack.Error)
+	}
+
+	return stream, nil
+}
+
+// findNodeConnAndForward is a helper for finding the server with a connection
+// to the given node and forwarding the RPC to the correct server. This does not
+// work for streaming RPCs.
+func findNodeConnAndForward(srv *Server, snap *state.StateSnapshot,
+	nodeID, method string, args, reply interface{}) error {
+
+	node, err := snap.NodeByID(nil, nodeID)
+	if err != nil {
+		return err
+	}
+
+	if node == nil {
+		return fmt.Errorf("Unknown node %q", nodeID)
+	}
+
+	// Determine the Server that has a connection to the node.
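+	// serverWithNodeConn asks the peer servers, via Status.HasNodeConn, which of
+	// them holds the most recently established connection to the node.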
+ srvWithConn, err := srv.serverWithNodeConn(nodeID, srv.Region()) + if err != nil { + return err + } + + if srvWithConn == nil { + return structs.ErrNoNodeConn + } + + return srv.forwardServer(srvWithConn, method, args, reply) +} diff --git a/nomad/client_rpc_test.go b/nomad/client_rpc_test.go new file mode 100644 index 00000000000..c64eecec029 --- /dev/null +++ b/nomad/client_rpc_test.go @@ -0,0 +1,283 @@ +package nomad + +import ( + "net" + "testing" + + "github.com/hashicorp/nomad/client" + "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/testutil" + "github.com/stretchr/testify/require" +) + +type namedConnWrapper struct { + net.Conn + name string +} + +type namedAddr string + +func (n namedAddr) String() string { return string(n) } +func (n namedAddr) Network() string { return string(n) } + +func (n namedConnWrapper) LocalAddr() net.Addr { + return namedAddr(n.name) +} + +func TestServer_removeNodeConn_differentAddrs(t *testing.T) { + t.Parallel() + require := require.New(t) + s1 := TestServer(t, nil) + defer s1.Shutdown() + testutil.WaitForLeader(t, s1.RPC) + + p1, p2 := net.Pipe() + w1 := namedConnWrapper{ + Conn: p1, + name: "a", + } + w2 := namedConnWrapper{ + Conn: p2, + name: "b", + } + + // Add the connections + nodeID := uuid.Generate() + ctx1 := &RPCContext{ + Conn: w1, + NodeID: nodeID, + } + ctx2 := &RPCContext{ + Conn: w2, + NodeID: nodeID, + } + + s1.addNodeConn(ctx1) + s1.addNodeConn(ctx2) + require.Len(s1.connectedNodes(), 1) + + // Check that the value is the second conn. + state, ok := s1.getNodeConn(nodeID) + require.True(ok) + require.Equal(state.Ctx.Conn.LocalAddr().String(), w2.name) + + // Delete the first + s1.removeNodeConn(ctx1) + require.Len(s1.connectedNodes(), 1) + + // Check that the value is the second conn. 
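+	// removeNodeConn only deletes an entry whose local and remote addresses both
+	// match the stored connection, so the newer connection must survive.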
+ state, ok = s1.getNodeConn(nodeID) + require.True(ok) + require.Equal(state.Ctx.Conn.LocalAddr().String(), w2.name) + + // Delete the second + s1.removeNodeConn(ctx2) + require.Len(s1.connectedNodes(), 0) +} + +func TestServerWithNodeConn_NoPath(t *testing.T) { + t.Parallel() + require := require.New(t) + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + + nodeID := uuid.Generate() + srv, err := s1.serverWithNodeConn(nodeID, s1.Region()) + require.Nil(srv) + require.EqualError(err, structs.ErrNoNodeConn.Error()) +} + +func TestServerWithNodeConn_NoPath_Region(t *testing.T) { + t.Parallel() + require := require.New(t) + s1 := TestServer(t, nil) + defer s1.Shutdown() + testutil.WaitForLeader(t, s1.RPC) + + nodeID := uuid.Generate() + srv, err := s1.serverWithNodeConn(nodeID, "fake-region") + require.Nil(srv) + require.EqualError(err, structs.ErrNoRegionPath.Error()) +} + +func TestServerWithNodeConn_Path(t *testing.T) { + t.Parallel() + require := require.New(t) + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + + // Create a fake connection for the node on server 2 + nodeID := uuid.Generate() + s2.addNodeConn(&RPCContext{ + NodeID: nodeID, + }) + + srv, err := s1.serverWithNodeConn(nodeID, s1.Region()) + require.NotNil(srv) + require.Equal(srv.Addr.String(), s2.config.RPCAddr.String()) + require.Nil(err) +} + +func TestServerWithNodeConn_Path_Region(t *testing.T) { + t.Parallel() + require := require.New(t) + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.Region = "two" + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + + // Create a fake connection for the node on server 2 + nodeID := uuid.Generate() + s2.addNodeConn(&RPCContext{ + NodeID: nodeID, + }) + + srv, err := s1.serverWithNodeConn(nodeID, s2.Region()) + require.NotNil(srv) + require.Equal(srv.Addr.String(), s2.config.RPCAddr.String()) + require.Nil(err) +} + +func TestServerWithNodeConn_Path_Newest(t *testing.T) { + t.Parallel() + require := require.New(t) + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + s3 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s3.Shutdown() + TestJoin(t, s1, s2, s3) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + testutil.WaitForLeader(t, s3.RPC) + + // Create a fake connection for the node on server 2 and 3 + nodeID := uuid.Generate() + s2.addNodeConn(&RPCContext{ + NodeID: nodeID, + }) + s3.addNodeConn(&RPCContext{ + NodeID: nodeID, + }) + + srv, err := s1.serverWithNodeConn(nodeID, s1.Region()) + require.NotNil(srv) + require.Equal(srv.Addr.String(), s3.config.RPCAddr.String()) + require.Nil(err) +} + +func TestServerWithNodeConn_PathAndErr(t *testing.T) { + t.Parallel() + require := require.New(t) + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + s3 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s3.Shutdown() + TestJoin(t, 
s1, s2, s3)
+	testutil.WaitForLeader(t, s1.RPC)
+	testutil.WaitForLeader(t, s2.RPC)
+	testutil.WaitForLeader(t, s3.RPC)
+
+	// Create a fake connection for the node on server 2
+	nodeID := uuid.Generate()
+	s2.addNodeConn(&RPCContext{
+		NodeID: nodeID,
+	})
+
+	// Shutdown the RPC layer for server 3
+	s3.rpcListener.Close()
+
+	srv, err := s1.serverWithNodeConn(nodeID, s1.Region())
+	require.NotNil(srv)
+	require.Equal(srv.Addr.String(), s2.config.RPCAddr.String())
+	require.Nil(err)
+}
+
+func TestServerWithNodeConn_NoPathAndErr(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+	s1 := TestServer(t, nil)
+	defer s1.Shutdown()
+	s2 := TestServer(t, func(c *Config) {
+		c.DevDisableBootstrap = true
+	})
+	defer s2.Shutdown()
+	s3 := TestServer(t, func(c *Config) {
+		c.DevDisableBootstrap = true
+	})
+	defer s3.Shutdown()
+	TestJoin(t, s1, s2, s3)
+	testutil.WaitForLeader(t, s1.RPC)
+	testutil.WaitForLeader(t, s2.RPC)
+	testutil.WaitForLeader(t, s3.RPC)
+
+	// Shutdown the RPC layer for server 3
+	s3.rpcListener.Close()
+
+	srv, err := s1.serverWithNodeConn(uuid.Generate(), s1.Region())
+	require.Nil(srv)
+	require.NotNil(err)
+	require.Contains(err.Error(), "failed querying")
+}
+
+func TestNodeStreamingRpc_badEndpoint(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+	s1 := TestServer(t, nil)
+	defer s1.Shutdown()
+	testutil.WaitForLeader(t, s1.RPC)
+
+	c := client.TestClient(t, func(c *config.Config) {
+		c.Servers = []string{s1.config.RPCAddr.String()}
+	})
+	defer c.Shutdown()
+
+	// Wait for the client to connect
+	testutil.WaitForResult(func() (bool, error) {
+		nodes := s1.connectedNodes()
+		return len(nodes) == 1, nil
+	}, func(err error) {
+		t.Fatalf("should have a client")
+	})
+
+	state, ok := s1.getNodeConn(c.NodeID())
+	require.True(ok)
+
+	conn, err := NodeStreamingRpc(state.Session, "Bogus")
+	require.Nil(conn)
+	require.NotNil(err)
+	require.Contains(err.Error(), "Bogus")
+	require.True(structs.IsErrUnknownMethod(err))
+}
diff --git a/nomad/client_stats_endpoint.go b/nomad/client_stats_endpoint.go
new file mode 100644
index 00000000000..7042bbf2b97
--- /dev/null
+++ b/nomad/client_stats_endpoint.go
@@ -0,0 +1,76 @@
+package nomad
+
+import (
+	"errors"
+	"fmt"
+	"time"
+
+	metrics "github.com/armon/go-metrics"
+	"github.com/hashicorp/nomad/client/structs"
+	nstructs "github.com/hashicorp/nomad/nomad/structs"
+)
+
+// ClientStats is used to forward RPC requests to the targeted Nomad client's
+// ClientStats endpoint.
+type ClientStats struct {
+	srv *Server
+}
+
+func (s *ClientStats) Stats(args *nstructs.NodeSpecificRequest, reply *structs.ClientStatsResponse) error {
+	// We only allow stale reads since the only potentially stale information is
+	// the Node registration and the cost is fairly high for adding another hop
+	// in the forwarding chain.
+	args.QueryOptions.AllowStale = true
+
+	// Potentially forward to a different region.
+	if done, err := s.srv.forward("ClientStats.Stats", args, args, reply); done {
+		return err
+	}
+	defer metrics.MeasureSince([]string{"nomad", "client_stats", "stats"}, time.Now())
+
+	// Check node read permissions
+	if aclObj, err := s.srv.ResolveToken(args.AuthToken); err != nil {
+		return err
+	} else if aclObj != nil && !aclObj.AllowNodeRead() {
+		return nstructs.ErrPermissionDenied
+	}
+
+	// Verify the arguments.
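+	// A NodeID is required; without it the request cannot be routed to a client.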
+	if args.NodeID == "" {
+		return errors.New("missing NodeID")
+	}
+
+	// Get the connection to the client
+	state, ok := s.srv.getNodeConn(args.NodeID)
+	if !ok {
+		// Check if the node even exists
+		snap, err := s.srv.State().Snapshot()
+		if err != nil {
+			return err
+		}
+
+		node, err := snap.NodeByID(nil, args.NodeID)
+		if err != nil {
+			return err
+		}
+
+		if node == nil {
+			return fmt.Errorf("Unknown node %q", args.NodeID)
+		}
+
+		// Determine the Server that has a connection to the node.
+		srv, err := s.srv.serverWithNodeConn(args.NodeID, s.srv.Region())
+		if err != nil {
+			return err
+		}
+
+		if srv == nil {
+			return nstructs.ErrNoNodeConn
+		}
+
+		return s.srv.forwardServer(srv, "ClientStats.Stats", args, reply)
+	}
+
+	// Make the RPC
+	return NodeRpc(state.Session, "ClientStats.Stats", args, reply)
+}
diff --git a/nomad/client_stats_endpoint_test.go b/nomad/client_stats_endpoint_test.go
new file mode 100644
index 00000000000..232be4cfb80
--- /dev/null
+++ b/nomad/client_stats_endpoint_test.go
@@ -0,0 +1,188 @@
+package nomad
+
+import (
+	"testing"
+
+	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
+	"github.com/hashicorp/nomad/acl"
+	"github.com/hashicorp/nomad/client"
+	"github.com/hashicorp/nomad/client/config"
+	cstructs "github.com/hashicorp/nomad/client/structs"
+	"github.com/hashicorp/nomad/helper/uuid"
+	"github.com/hashicorp/nomad/nomad/mock"
+	"github.com/hashicorp/nomad/nomad/structs"
+	"github.com/hashicorp/nomad/testutil"
+	"github.com/stretchr/testify/require"
+)
+
+func TestClientStats_Stats_Local(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	// Start a server and client
+	s := TestServer(t, nil)
+	defer s.Shutdown()
+	codec := rpcClient(t, s)
+	testutil.WaitForLeader(t, s.RPC)
+
+	c := client.TestClient(t, func(c *config.Config) {
+		c.Servers = []string{s.config.RPCAddr.String()}
+	})
+	defer c.Shutdown()
+
+	testutil.WaitForResult(func() (bool, error) {
+		nodes := s.connectedNodes()
+		return len(nodes) == 1, nil
+	}, func(err error) {
+		t.Fatalf("should have a client")
+	})
+
+	// Make the request without having a node-id
+	req := &structs.NodeSpecificRequest{
+		QueryOptions: structs.QueryOptions{Region: "global"},
+	}
+
+	// Fetch the response
+	var resp cstructs.ClientStatsResponse
+	err := msgpackrpc.CallWithCodec(codec, "ClientStats.Stats", req, &resp)
+	require.NotNil(err)
+	require.Contains(err.Error(), "missing")
+
+	// Fetch the response setting the node id
+	req.NodeID = c.NodeID()
+	var resp2 cstructs.ClientStatsResponse
+	err = msgpackrpc.CallWithCodec(codec, "ClientStats.Stats", req, &resp2)
+	require.Nil(err)
+	require.NotNil(resp2.HostStats)
+}
+
+func TestClientStats_Stats_Local_ACL(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	// Start a server
+	s, root := TestACLServer(t, nil)
+	defer s.Shutdown()
+	codec := rpcClient(t, s)
+	testutil.WaitForLeader(t, s.RPC)
+
+	// Create a bad token
+	policyBad := mock.NamespacePolicy("other", "", []string{acl.NamespaceCapabilityReadFS})
+	tokenBad := mock.CreatePolicyAndToken(t, s.State(), 1005, "invalid", policyBad)
+
+	policyGood := mock.NodePolicy(acl.PolicyRead)
+	tokenGood := mock.CreatePolicyAndToken(t, s.State(), 1009, "valid2", policyGood)
+
+	cases := []struct {
+		Name          string
+		Token         string
+		ExpectedError string
+	}{
+		{
+			Name:          "bad token",
+			Token:         tokenBad.SecretID,
+			ExpectedError: structs.ErrPermissionDenied.Error(),
+		},
+		{
+			Name:  "good token",
+			Token: tokenGood.SecretID,
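+			// A node-read token passes the ACL check, so the request
+			// proceeds far enough to fail on the random node ID.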
ExpectedError: "Unknown node", + }, + { + Name: "root token", + Token: root.SecretID, + ExpectedError: "Unknown node", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + + // Make the request without having a node-id + req := &structs.NodeSpecificRequest{ + NodeID: uuid.Generate(), + QueryOptions: structs.QueryOptions{ + AuthToken: c.Token, + Region: "global", + }, + } + + // Fetch the response + var resp cstructs.ClientStatsResponse + err := msgpackrpc.CallWithCodec(codec, "ClientStats.Stats", req, &resp) + require.NotNil(err) + require.Contains(err.Error(), c.ExpectedError) + }) + } +} + +func TestClientStats_Stats_NoNode(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s := TestServer(t, nil) + defer s.Shutdown() + codec := rpcClient(t, s) + testutil.WaitForLeader(t, s.RPC) + + // Make the request without having a node-id + req := &structs.NodeSpecificRequest{ + NodeID: uuid.Generate(), + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + var resp cstructs.ClientStatsResponse + err := msgpackrpc.CallWithCodec(codec, "ClientStats.Stats", req, &resp) + require.Nil(resp.HostStats) + require.NotNil(err) + require.Contains(err.Error(), "Unknown node") +} + +func TestClientStats_Stats_Remote(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Start a server and client + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + codec := rpcClient(t, s2) + + c := client.TestClient(t, func(c *config.Config) { + c.Servers = []string{s2.config.RPCAddr.String()} + }) + defer c.Shutdown() + + testutil.WaitForResult(func() (bool, error) { + nodes := s2.connectedNodes() + return len(nodes) == 1, nil + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Force remove the connection locally in case it exists + s1.nodeConnsLock.Lock() + delete(s1.nodeConns, c.NodeID()) + s1.nodeConnsLock.Unlock() + + // Make the request without having a node-id + req := &structs.NodeSpecificRequest{ + NodeID: uuid.Generate(), + QueryOptions: structs.QueryOptions{Region: "global"}, + } + + // Fetch the response + req.NodeID = c.NodeID() + var resp cstructs.ClientStatsResponse + err := msgpackrpc.CallWithCodec(codec, "ClientStats.Stats", req, &resp) + require.Nil(err) + require.NotNil(resp.HostStats) +} diff --git a/nomad/core_sched_test.go b/nomad/core_sched_test.go index 36c61c530ca..616d1800e3f 100644 --- a/nomad/core_sched_test.go +++ b/nomad/core_sched_test.go @@ -16,7 +16,7 @@ import ( func TestCoreScheduler_EvalGC(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) require := require.New(t) @@ -109,7 +109,7 @@ func TestCoreScheduler_EvalGC(t *testing.T) { // Tests GC behavior on allocations being rescheduled func TestCoreScheduler_EvalGC_ReshedulingAllocs(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) require := require.New(t) @@ -210,7 +210,7 @@ func TestCoreScheduler_EvalGC_ReshedulingAllocs(t *testing.T) { // Tests GC behavior on stopped job with reschedulable allocs func TestCoreScheduler_EvalGC_StoppedJob_Reschedulable(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() 
testutil.WaitForLeader(t, s1.RPC) require := require.New(t) @@ -285,7 +285,7 @@ func TestCoreScheduler_EvalGC_StoppedJob_Reschedulable(t *testing.T) { // An EvalGC should never reap a batch job that has not been stopped func TestCoreScheduler_EvalGC_Batch(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -386,7 +386,7 @@ func TestCoreScheduler_EvalGC_Batch(t *testing.T) { // An EvalGC should reap a batch job that has been stopped func TestCoreScheduler_EvalGC_BatchStopped(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -482,7 +482,7 @@ func TestCoreScheduler_EvalGC_BatchStopped(t *testing.T) { func TestCoreScheduler_EvalGC_Partial(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) require := require.New(t) @@ -603,9 +603,9 @@ func TestCoreScheduler_EvalGC_Force(t *testing.T) { require := require.New(t) var server *Server if withAcl { - server, _ = testACLServer(t, nil) + server, _ = TestACLServer(t, nil) } else { - server = testServer(t, nil) + server = TestServer(t, nil) } defer server.Shutdown() testutil.WaitForLeader(t, server.RPC) @@ -685,9 +685,9 @@ func TestCoreScheduler_NodeGC(t *testing.T) { t.Run(fmt.Sprintf("with acl %v", withAcl), func(t *testing.T) { var server *Server if withAcl { - server, _ = testACLServer(t, nil) + server, _ = TestACLServer(t, nil) } else { - server = testServer(t, nil) + server = TestServer(t, nil) } defer server.Shutdown() testutil.WaitForLeader(t, server.RPC) @@ -737,7 +737,7 @@ func TestCoreScheduler_NodeGC(t *testing.T) { func TestCoreScheduler_NodeGC_TerminalAllocs(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -792,7 +792,7 @@ func TestCoreScheduler_NodeGC_TerminalAllocs(t *testing.T) { func TestCoreScheduler_NodeGC_RunningAllocs(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -849,7 +849,7 @@ func TestCoreScheduler_NodeGC_RunningAllocs(t *testing.T) { func TestCoreScheduler_NodeGC_Force(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -892,7 +892,7 @@ func TestCoreScheduler_NodeGC_Force(t *testing.T) { func TestCoreScheduler_JobGC_OutstandingEvals(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -1015,7 +1015,7 @@ func TestCoreScheduler_JobGC_OutstandingEvals(t *testing.T) { func TestCoreScheduler_JobGC_OutstandingAllocs(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -1160,7 +1160,7 @@ func TestCoreScheduler_JobGC_OutstandingAllocs(t *testing.T) { // allocs/evals and job or nothing func TestCoreScheduler_JobGC_OneShot(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -1272,7 +1272,7 @@ func TestCoreScheduler_JobGC_OneShot(t *testing.T) { // This test ensures that stopped jobs are GCd func TestCoreScheduler_JobGC_Stopped(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, 
s1.RPC) @@ -1376,9 +1376,9 @@ func TestCoreScheduler_JobGC_Force(t *testing.T) { t.Run(fmt.Sprintf("with acl %v", withAcl), func(t *testing.T) { var server *Server if withAcl { - server, _ = testACLServer(t, nil) + server, _ = TestACLServer(t, nil) } else { - server = testServer(t, nil) + server = TestServer(t, nil) } defer server.Shutdown() testutil.WaitForLeader(t, server.RPC) @@ -1443,7 +1443,7 @@ func TestCoreScheduler_JobGC_Force(t *testing.T) { // This test ensures parameterized jobs only get gc'd when stopped func TestCoreScheduler_JobGC_Parameterized(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -1523,7 +1523,7 @@ func TestCoreScheduler_JobGC_Parameterized(t *testing.T) { func TestCoreScheduler_JobGC_Periodic(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -1596,7 +1596,7 @@ func TestCoreScheduler_JobGC_Periodic(t *testing.T) { func TestCoreScheduler_DeploymentGC(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) assert := assert.New(t) @@ -1650,9 +1650,9 @@ func TestCoreScheduler_DeploymentGC_Force(t *testing.T) { t.Run(fmt.Sprintf("with acl %v", withAcl), func(t *testing.T) { var server *Server if withAcl { - server, _ = testACLServer(t, nil) + server, _ = TestACLServer(t, nil) } else { - server = testServer(t, nil) + server = TestServer(t, nil) } defer server.Shutdown() testutil.WaitForLeader(t, server.RPC) @@ -1691,7 +1691,7 @@ func TestCoreScheduler_DeploymentGC_Force(t *testing.T) { func TestCoreScheduler_PartitionEvalReap(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -1733,7 +1733,7 @@ func TestCoreScheduler_PartitionEvalReap(t *testing.T) { func TestCoreScheduler_PartitionDeploymentReap(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) diff --git a/nomad/deployment_endpoint_test.go b/nomad/deployment_endpoint_test.go index c03bfb129ca..fedda98ebe4 100644 --- a/nomad/deployment_endpoint_test.go +++ b/nomad/deployment_endpoint_test.go @@ -16,7 +16,7 @@ import ( func TestDeploymentEndpoint_GetDeployment(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -47,7 +47,7 @@ func TestDeploymentEndpoint_GetDeployment(t *testing.T) { func TestDeploymentEndpoint_GetDeployment_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -100,7 +100,7 @@ func TestDeploymentEndpoint_GetDeployment_ACL(t *testing.T) { func TestDeploymentEndpoint_GetDeployment_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -149,7 +149,7 @@ func TestDeploymentEndpoint_GetDeployment_Blocking(t *testing.T) { func TestDeploymentEndpoint_Fail(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -198,7 +198,7 @@ func TestDeploymentEndpoint_Fail(t *testing.T) { func 
TestDeploymentEndpoint_Fail_ACL(t *testing.T) { t.Parallel() - s1, _ := testACLServer(t, func(c *Config) { + s1, _ := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -273,7 +273,7 @@ func TestDeploymentEndpoint_Fail_ACL(t *testing.T) { func TestDeploymentEndpoint_Fail_Rollback(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -350,7 +350,7 @@ func TestDeploymentEndpoint_Fail_Rollback(t *testing.T) { func TestDeploymentEndpoint_Pause(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -392,7 +392,7 @@ func TestDeploymentEndpoint_Pause(t *testing.T) { func TestDeploymentEndpoint_Pause_ACL(t *testing.T) { t.Parallel() - s1, _ := testACLServer(t, func(c *Config) { + s1, _ := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -460,7 +460,7 @@ func TestDeploymentEndpoint_Pause_ACL(t *testing.T) { func TestDeploymentEndpoint_Promote(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -524,7 +524,7 @@ func TestDeploymentEndpoint_Promote(t *testing.T) { func TestDeploymentEndpoint_Promote_ACL(t *testing.T) { t.Parallel() - s1, _ := testACLServer(t, func(c *Config) { + s1, _ := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -614,7 +614,7 @@ func TestDeploymentEndpoint_Promote_ACL(t *testing.T) { func TestDeploymentEndpoint_SetAllocHealth(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -681,7 +681,7 @@ func TestDeploymentEndpoint_SetAllocHealth(t *testing.T) { func TestDeploymentEndpoint_SetAllocHealth_ACL(t *testing.T) { t.Parallel() - s1, _ := testACLServer(t, func(c *Config) { + s1, _ := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -774,7 +774,7 @@ func TestDeploymentEndpoint_SetAllocHealth_ACL(t *testing.T) { func TestDeploymentEndpoint_SetAllocHealth_Rollback(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -863,7 +863,7 @@ func TestDeploymentEndpoint_SetAllocHealth_Rollback(t *testing.T) { // tests rollback upon alloc health failure to job with identical spec does not succeed func TestDeploymentEndpoint_SetAllocHealth_NoRollback(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -949,7 +949,7 @@ func TestDeploymentEndpoint_SetAllocHealth_NoRollback(t *testing.T) { func TestDeploymentEndpoint_List(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -995,7 +995,7 @@ func TestDeploymentEndpoint_List(t *testing.T) { func TestDeploymentEndpoint_List_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) 
defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1063,7 +1063,7 @@ func TestDeploymentEndpoint_List_ACL(t *testing.T) { func TestDeploymentEndpoint_List_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -1120,7 +1120,7 @@ func TestDeploymentEndpoint_List_Blocking(t *testing.T) { func TestDeploymentEndpoint_Allocations(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1157,7 +1157,7 @@ func TestDeploymentEndpoint_Allocations(t *testing.T) { func TestDeploymentEndpoint_Allocations_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1231,7 +1231,7 @@ func TestDeploymentEndpoint_Allocations_ACL(t *testing.T) { func TestDeploymentEndpoint_Allocations_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -1298,7 +1298,7 @@ func TestDeploymentEndpoint_Allocations_Blocking(t *testing.T) { func TestDeploymentEndpoint_Reap(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) diff --git a/nomad/endpoints_oss.go b/nomad/endpoints_oss.go index 3d59b57ead0..006b05552c5 100644 --- a/nomad/endpoints_oss.go +++ b/nomad/endpoints_oss.go @@ -2,6 +2,8 @@ package nomad +import "net/rpc" + // EnterpriseEndpoints holds the set of enterprise only endpoints to register type EnterpriseEndpoints struct{} @@ -12,4 +14,4 @@ func NewEnterpriseEndpoints(s *Server) *EnterpriseEndpoints { } // Register is a no-op in oss. 
-func (e *EnterpriseEndpoints) Register(s *Server) {} +func (e *EnterpriseEndpoints) Register(s *rpc.Server) {} diff --git a/nomad/eval_endpoint_test.go b/nomad/eval_endpoint_test.go index ea0c42a019c..68504379db1 100644 --- a/nomad/eval_endpoint_test.go +++ b/nomad/eval_endpoint_test.go @@ -1,7 +1,6 @@ package nomad import ( - "encoding/base64" "fmt" "reflect" "strings" @@ -21,7 +20,7 @@ import ( func TestEvalEndpoint_GetEval(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -62,7 +61,7 @@ func TestEvalEndpoint_GetEval(t *testing.T) { func TestEvalEndpoint_GetEval_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -122,7 +121,7 @@ func TestEvalEndpoint_GetEval_ACL(t *testing.T) { func TestEvalEndpoint_GetEval_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -200,7 +199,7 @@ func TestEvalEndpoint_GetEval_Blocking(t *testing.T) { func TestEvalEndpoint_Dequeue(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -242,7 +241,7 @@ func TestEvalEndpoint_Dequeue(t *testing.T) { func TestEvalEndpoint_Dequeue_WaitIndex(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -289,7 +288,7 @@ func TestEvalEndpoint_Dequeue_WaitIndex(t *testing.T) { func TestEvalEndpoint_Dequeue_UpdateWaitIndex(t *testing.T) { // test enqueueing an eval, updating a plan result for the same eval and de-queueing the eval t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -355,7 +354,7 @@ func TestEvalEndpoint_Dequeue_UpdateWaitIndex(t *testing.T) { func TestEvalEndpoint_Dequeue_Version_Mismatch(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -381,7 +380,7 @@ func TestEvalEndpoint_Dequeue_Version_Mismatch(t *testing.T) { func TestEvalEndpoint_Ack(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) @@ -421,7 +420,7 @@ func TestEvalEndpoint_Ack(t *testing.T) { func TestEvalEndpoint_Nack(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { // Disable all of the schedulers so we can manually dequeue // evals and check the queue status c.NumSchedulers = 0 @@ -474,7 +473,7 @@ func TestEvalEndpoint_Nack(t *testing.T) { func TestEvalEndpoint_Update(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) @@ -522,7 +521,7 @@ func TestEvalEndpoint_Update(t *testing.T) { func TestEvalEndpoint_Create(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -574,7 +573,7 @@ func TestEvalEndpoint_Create(t *testing.T) { func 
TestEvalEndpoint_Reap(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -609,7 +608,7 @@ func TestEvalEndpoint_Reap(t *testing.T) { func TestEvalEndpoint_List(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -664,7 +663,7 @@ func TestEvalEndpoint_List(t *testing.T) { func TestEvalEndpoint_List_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -729,7 +728,7 @@ func TestEvalEndpoint_List_ACL(t *testing.T) { func TestEvalEndpoint_List_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -795,7 +794,7 @@ func TestEvalEndpoint_List_Blocking(t *testing.T) { func TestEvalEndpoint_Allocations(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -833,7 +832,7 @@ func TestEvalEndpoint_Allocations(t *testing.T) { func TestEvalEndpoint_Allocations_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -897,7 +896,7 @@ func TestEvalEndpoint_Allocations_ACL(t *testing.T) { func TestEvalEndpoint_Allocations_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -952,7 +951,7 @@ func TestEvalEndpoint_Allocations_Blocking(t *testing.T) { func TestEvalEndpoint_Reblock_NonExistent(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -988,7 +987,7 @@ func TestEvalEndpoint_Reblock_NonExistent(t *testing.T) { func TestEvalEndpoint_Reblock_NonBlocked(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1030,7 +1029,7 @@ func TestEvalEndpoint_Reblock_NonBlocked(t *testing.T) { func TestEvalEndpoint_Reblock(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1076,25 +1075,3 @@ func TestEvalEndpoint_Reblock(t *testing.T) { t.Fatalf("ReblockEval didn't insert eval into the blocked eval tracker") } } - -// TestGenerateMigrateToken asserts the migrate token is valid for use in HTTP -// headers and CompareMigrateToken works as expected. 
-func TestGenerateMigrateToken(t *testing.T) { - assert := assert.New(t) - allocID := uuid.Generate() - nodeSecret := uuid.Generate() - token, err := GenerateMigrateToken(allocID, nodeSecret) - assert.Nil(err) - _, err = base64.URLEncoding.DecodeString(token) - assert.Nil(err) - - assert.True(CompareMigrateToken(allocID, nodeSecret, token)) - assert.False(CompareMigrateToken("x", nodeSecret, token)) - assert.False(CompareMigrateToken(allocID, "x", token)) - assert.False(CompareMigrateToken(allocID, nodeSecret, "x")) - - token2, err := GenerateMigrateToken("x", nodeSecret) - assert.Nil(err) - assert.False(CompareMigrateToken(allocID, nodeSecret, token2)) - assert.True(CompareMigrateToken("x", nodeSecret, token2)) -} diff --git a/nomad/heartbeat.go b/nomad/heartbeat.go index 89bc8601015..54e885337cb 100644 --- a/nomad/heartbeat.go +++ b/nomad/heartbeat.go @@ -100,7 +100,7 @@ func (s *Server) invalidateHeartbeat(id string) { }, } var resp structs.NodeUpdateResponse - if err := s.endpoints.Node.UpdateStatus(&req, &resp); err != nil { + if err := s.staticEndpoints.Node.UpdateStatus(&req, &resp); err != nil { s.logger.Printf("[ERR] nomad.heartbeat: update status failed: %v", err) } } diff --git a/nomad/heartbeat_test.go b/nomad/heartbeat_test.go index 24d8283fdc3..0afbee73582 100644 --- a/nomad/heartbeat_test.go +++ b/nomad/heartbeat_test.go @@ -14,7 +14,7 @@ import ( func TestInitializeHeartbeatTimers(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -40,7 +40,7 @@ func TestInitializeHeartbeatTimers(t *testing.T) { func TestResetHeartbeatTimer(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -62,7 +62,7 @@ func TestResetHeartbeatTimer(t *testing.T) { func TestResetHeartbeatTimerLocked(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -83,7 +83,7 @@ func TestResetHeartbeatTimerLocked(t *testing.T) { func TestResetHeartbeatTimerLocked_Renew(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -122,7 +122,7 @@ func TestResetHeartbeatTimerLocked_Renew(t *testing.T) { func TestInvalidateHeartbeat(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -150,7 +150,7 @@ func TestInvalidateHeartbeat(t *testing.T) { func TestClearHeartbeatTimer(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -170,7 +170,7 @@ func TestClearHeartbeatTimer(t *testing.T) { func TestClearAllHeartbeatTimers(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -192,20 +192,20 @@ func TestClearAllHeartbeatTimers(t *testing.T) { func TestServer_HeartbeatTTL_Failover(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s3.Shutdown() servers := []*Server{s1, s2, s3} - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) 
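 	// Wait for the cluster to establish its peer set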
testutil.WaitForResult(func() (bool, error) { peers, _ := s1.numPeers() diff --git a/nomad/job_endpoint_test.go b/nomad/job_endpoint_test.go index d9ff378194e..3f9dbf31010 100644 --- a/nomad/job_endpoint_test.go +++ b/nomad/job_endpoint_test.go @@ -21,7 +21,7 @@ import ( func TestJobEndpoint_Register(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -100,7 +100,7 @@ func TestJobEndpoint_Register(t *testing.T) { func TestJobEndpoint_Register_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -143,7 +143,7 @@ func TestJobEndpoint_Register_ACL(t *testing.T) { func TestJobEndpoint_Register_InvalidNamespace(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -179,7 +179,7 @@ func TestJobEndpoint_Register_InvalidNamespace(t *testing.T) { func TestJobEndpoint_Register_InvalidDriverConfig(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -212,7 +212,7 @@ func TestJobEndpoint_Register_InvalidDriverConfig(t *testing.T) { func TestJobEndpoint_Register_Payload(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -245,7 +245,7 @@ func TestJobEndpoint_Register_Payload(t *testing.T) { func TestJobEndpoint_Register_Existing(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -361,7 +361,7 @@ func TestJobEndpoint_Register_Existing(t *testing.T) { func TestJobEndpoint_Register_Periodic(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -413,7 +413,7 @@ func TestJobEndpoint_Register_Periodic(t *testing.T) { func TestJobEndpoint_Register_ParameterizedJob(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -461,7 +461,7 @@ func TestJobEndpoint_Register_ParameterizedJob(t *testing.T) { func TestJobEndpoint_Register_EnforceIndex(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -590,7 +590,7 @@ func TestJobEndpoint_Register_EnforceIndex(t *testing.T) { func TestJobEndpoint_Register_Vault_Disabled(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue f := false c.VaultConfig.Enabled = &f @@ -623,7 +623,7 @@ func TestJobEndpoint_Register_Vault_Disabled(t *testing.T) { func TestJobEndpoint_Register_Vault_AllowUnauthenticated(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -676,7 +676,7 
@@ func TestJobEndpoint_Register_Vault_AllowUnauthenticated(t *testing.T) { func TestJobEndpoint_Register_Vault_NoToken(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -716,7 +716,7 @@ func TestJobEndpoint_Register_Vault_NoToken(t *testing.T) { func TestJobEndpoint_Register_Vault_Policies(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -857,7 +857,7 @@ func TestJobEndpoint_Register_Vault_Policies(t *testing.T) { func TestJobEndpoint_Revert(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1028,7 +1028,7 @@ func TestJobEndpoint_Revert_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) @@ -1091,7 +1091,7 @@ func TestJobEndpoint_Revert_ACL(t *testing.T) { func TestJobEndpoint_Stable(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1156,7 +1156,7 @@ func TestJobEndpoint_Stable_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1221,7 +1221,7 @@ func TestJobEndpoint_Stable_ACL(t *testing.T) { func TestJobEndpoint_Evaluate(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1302,7 +1302,7 @@ func TestJobEndpoint_Evaluate_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1372,7 +1372,7 @@ func TestJobEndpoint_Evaluate_ACL(t *testing.T) { func TestJobEndpoint_Evaluate_Periodic(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1415,7 +1415,7 @@ func TestJobEndpoint_Evaluate_Periodic(t *testing.T) { func TestJobEndpoint_Evaluate_ParameterizedJob(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1460,7 +1460,7 @@ func TestJobEndpoint_Evaluate_ParameterizedJob(t *testing.T) { func TestJobEndpoint_Deregister(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1607,7 +1607,7 @@ func TestJobEndpoint_Deregister_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1686,7 +1686,7 @@ func TestJobEndpoint_Deregister_ACL(t *testing.T) { func 
TestJobEndpoint_Deregister_NonExistent(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1746,7 +1746,7 @@ func TestJobEndpoint_Deregister_NonExistent(t *testing.T) { func TestJobEndpoint_Deregister_Periodic(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1804,7 +1804,7 @@ func TestJobEndpoint_Deregister_Periodic(t *testing.T) { func TestJobEndpoint_Deregister_ParameterizedJob(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -1864,7 +1864,7 @@ func TestJobEndpoint_Deregister_ParameterizedJob(t *testing.T) { func TestJobEndpoint_GetJob(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1943,7 +1943,7 @@ func TestJobEndpoint_GetJob_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1999,7 +1999,7 @@ func TestJobEndpoint_GetJob_ACL(t *testing.T) { func TestJobEndpoint_GetJob_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -2075,7 +2075,7 @@ func TestJobEndpoint_GetJob_Blocking(t *testing.T) { func TestJobEndpoint_GetJobVersions(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -2149,7 +2149,7 @@ func TestJobEndpoint_GetJobVersions_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -2214,7 +2214,7 @@ func TestJobEndpoint_GetJobVersions_ACL(t *testing.T) { func TestJobEndpoint_GetJobVersions_Diff(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -2310,7 +2310,7 @@ func TestJobEndpoint_GetJobVersions_Diff(t *testing.T) { func TestJobEndpoint_GetJobVersions_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -2395,7 +2395,7 @@ func TestJobEndpoint_GetJobVersions_Blocking(t *testing.T) { func TestJobEndpoint_GetJobSummary(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) @@ -2458,7 +2458,7 @@ func TestJobEndpoint_Summary_ACL(t *testing.T) { assert := assert.New(t) t.Parallel() - srv, root := testACLServer(t, func(c *Config) { + srv, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer srv.Shutdown() @@ -2543,7 +2543,7 @@ func TestJobEndpoint_Summary_ACL(t *testing.T) { func TestJobEndpoint_GetJobSummary_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, 
s1) @@ -2636,7 +2636,7 @@ func TestJobEndpoint_GetJobSummary_Blocking(t *testing.T) { func TestJobEndpoint_ListJobs(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -2699,7 +2699,7 @@ func TestJobEndpoint_ListJobs_WithACL(t *testing.T) { assert := assert.New(t) t.Parallel() - srv, root := testACLServer(t, func(c *Config) { + srv, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer srv.Shutdown() @@ -2757,7 +2757,7 @@ func TestJobEndpoint_ListJobs_WithACL(t *testing.T) { func TestJobEndpoint_ListJobs_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -2823,7 +2823,7 @@ func TestJobEndpoint_ListJobs_Blocking(t *testing.T) { func TestJobEndpoint_Allocations(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -2866,7 +2866,7 @@ func TestJobEndpoint_Allocations_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -2927,7 +2927,7 @@ func TestJobEndpoint_Allocations_ACL(t *testing.T) { func TestJobEndpoint_Allocations_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -2984,7 +2984,7 @@ func TestJobEndpoint_Allocations_Blocking(t *testing.T) { func TestJobEndpoint_Evaluations(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -3025,7 +3025,7 @@ func TestJobEndpoint_Evaluations_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -3084,7 +3084,7 @@ func TestJobEndpoint_Evaluations_ACL(t *testing.T) { func TestJobEndpoint_Evaluations_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -3139,7 +3139,7 @@ func TestJobEndpoint_Evaluations_Blocking(t *testing.T) { func TestJobEndpoint_Deployments(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -3174,7 +3174,7 @@ func TestJobEndpoint_Deployments_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -3234,7 +3234,7 @@ func TestJobEndpoint_Deployments_ACL(t *testing.T) { func TestJobEndpoint_Deployments_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -3280,7 +3280,7 @@ func TestJobEndpoint_Deployments_Blocking(t *testing.T) { func TestJobEndpoint_LatestDeployment(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) 
testutil.WaitForLeader(t, s1.RPC) @@ -3318,7 +3318,7 @@ func TestJobEndpoint_LatestDeployment_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -3383,7 +3383,7 @@ func TestJobEndpoint_LatestDeployment_ACL(t *testing.T) { func TestJobEndpoint_LatestDeployment_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -3429,7 +3429,7 @@ func TestJobEndpoint_LatestDeployment_Blocking(t *testing.T) { func TestJobEndpoint_Plan_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -3462,7 +3462,7 @@ func TestJobEndpoint_Plan_ACL(t *testing.T) { func TestJobEndpoint_Plan_WithDiff(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -3521,7 +3521,7 @@ func TestJobEndpoint_Plan_WithDiff(t *testing.T) { func TestJobEndpoint_Plan_NoDiff(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -3580,7 +3580,7 @@ func TestJobEndpoint_Plan_NoDiff(t *testing.T) { func TestJobEndpoint_ImplicitConstraints_Vault(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -3649,7 +3649,7 @@ func TestJobEndpoint_ImplicitConstraints_Vault(t *testing.T) { func TestJobEndpoint_ImplicitConstraints_Signals(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -3814,7 +3814,7 @@ func TestJobEndpoint_ValidateJobUpdate_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -3850,7 +3850,7 @@ func TestJobEndpoint_Dispatch_ACL(t *testing.T) { t.Parallel() assert := assert.New(t) - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) @@ -4102,7 +4102,7 @@ func TestJobEndpoint_Dispatch(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() diff --git a/nomad/leader_test.go b/nomad/leader_test.go index 4689cbfcbd5..721577e69f7 100644 --- a/nomad/leader_test.go +++ b/nomad/leader_test.go @@ -16,20 +16,20 @@ import ( ) func TestLeader_LeftServer(t *testing.T) { - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s3.Shutdown() servers := []*Server{s1, s2, s3} - testJoin(t, s1, s2, s3) + 
TestJoin(t, s1, s2, s3) for _, s := range servers { testutil.WaitForResult(func() (bool, error) { @@ -76,20 +76,20 @@ func TestLeader_LeftServer(t *testing.T) { } func TestLeader_LeftLeader(t *testing.T) { - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s3.Shutdown() servers := []*Server{s1, s2, s3} - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) for _, s := range servers { testutil.WaitForResult(func() (bool, error) { @@ -128,13 +128,13 @@ func TestLeader_LeftLeader(t *testing.T) { } func TestLeader_MultiBootstrap(t *testing.T) { - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() - s2 := testServer(t, nil) + s2 := TestServer(t, nil) defer s2.Shutdown() servers := []*Server{s1, s2} - testJoin(t, s1, s2) + TestJoin(t, s1, s2) for _, s := range servers { testutil.WaitForResult(func() (bool, error) { @@ -155,20 +155,20 @@ func TestLeader_MultiBootstrap(t *testing.T) { } func TestLeader_PlanQueue_Reset(t *testing.T) { - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s3.Shutdown() servers := []*Server{s1, s2, s3} - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) for _, s := range servers { testutil.WaitForResult(func() (bool, error) { @@ -227,24 +227,24 @@ func TestLeader_PlanQueue_Reset(t *testing.T) { } func TestLeader_EvalBroker_Reset(t *testing.T) { - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.DevDisableBootstrap = true }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.DevDisableBootstrap = true }) defer s3.Shutdown() servers := []*Server{s1, s2, s3} - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) testutil.WaitForLeader(t, s1.RPC) for _, s := range servers { @@ -304,24 +304,24 @@ func TestLeader_EvalBroker_Reset(t *testing.T) { } func TestLeader_PeriodicDispatcher_Restore_Adds(t *testing.T) { - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.DevDisableBootstrap = true }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.DevDisableBootstrap = true }) defer s3.Shutdown() servers := []*Server{s1, s2, s3} - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) testutil.WaitForLeader(t, s1.RPC) for _, s := range servers { @@ -411,7 +411,7 @@ func TestLeader_PeriodicDispatcher_Restore_Adds(t *testing.T) { } func TestLeader_PeriodicDispatcher_Restore_NoEvals(t *testing.T) { - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) defer s1.Shutdown() @@ -467,7 +467,7 @@ func TestLeader_PeriodicDispatcher_Restore_NoEvals(t *testing.T) { } func TestLeader_PeriodicDispatcher_Restore_Evals(t 
*testing.T) { - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) defer s1.Shutdown() @@ -524,7 +524,7 @@ func TestLeader_PeriodicDispatcher_Restore_Evals(t *testing.T) { } func TestLeader_PeriodicDispatch(t *testing.T) { - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EvalGCInterval = 5 * time.Millisecond }) @@ -544,7 +544,7 @@ func TestLeader_PeriodicDispatch(t *testing.T) { } func TestLeader_ReapFailedEval(t *testing.T) { - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EvalDeliveryLimit = 1 }) @@ -615,7 +615,7 @@ func TestLeader_ReapFailedEval(t *testing.T) { } func TestLeader_ReapDuplicateEval(t *testing.T) { - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) defer s1.Shutdown() @@ -643,7 +643,7 @@ func TestLeader_ReapDuplicateEval(t *testing.T) { } func TestLeader_RestoreVaultAccessors(t *testing.T) { - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) defer s1.Shutdown() @@ -672,13 +672,13 @@ func TestLeader_RestoreVaultAccessors(t *testing.T) { func TestLeader_ReplicateACLPolicies(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.Region = "region1" c.AuthoritativeRegion = "region1" c.ACLEnabled = true }) defer s1.Shutdown() - s2, _ := testACLServer(t, func(c *Config) { + s2, _ := TestACLServer(t, func(c *Config) { c.Region = "region2" c.AuthoritativeRegion = "region1" c.ACLEnabled = true @@ -686,7 +686,7 @@ func TestLeader_ReplicateACLPolicies(t *testing.T) { c.ReplicationToken = root.SecretID }) defer s2.Shutdown() - testJoin(t, s1, s2) + TestJoin(t, s1, s2) testutil.WaitForLeader(t, s1.RPC) testutil.WaitForLeader(t, s2.RPC) @@ -740,13 +740,13 @@ func TestLeader_DiffACLPolicies(t *testing.T) { func TestLeader_ReplicateACLTokens(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.Region = "region1" c.AuthoritativeRegion = "region1" c.ACLEnabled = true }) defer s1.Shutdown() - s2, _ := testACLServer(t, func(c *Config) { + s2, _ := TestACLServer(t, func(c *Config) { c.Region = "region2" c.AuthoritativeRegion = "region1" c.ACLEnabled = true @@ -754,7 +754,7 @@ func TestLeader_ReplicateACLTokens(t *testing.T) { c.ReplicationToken = root.SecretID }) defer s2.Shutdown() - testJoin(t, s1, s2) + TestJoin(t, s1, s2) testutil.WaitForLeader(t, s1.RPC) testutil.WaitForLeader(t, s2.RPC) @@ -815,18 +815,19 @@ func TestLeader_DiffACLTokens(t *testing.T) { func TestLeader_UpgradeRaftVersion(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { + c.Datacenter = "dc1" c.RaftConfig.ProtocolVersion = 2 }) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true c.RaftConfig.ProtocolVersion = 1 }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true c.RaftConfig.ProtocolVersion = 2 }) @@ -835,7 +836,7 @@ func TestLeader_UpgradeRaftVersion(t *testing.T) { servers := []*Server{s1, s2, s3} // Try to join - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) for _, s := range servers { testutil.WaitForResult(func() (bool, error) { @@ -862,13 +863,13 @@ func TestLeader_UpgradeRaftVersion(t *testing.T) 
{ } // Replace the dead server with one running raft protocol v3 - s4 := testServer(t, func(c *Config) { + s4 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true c.Datacenter = "dc1" c.RaftConfig.ProtocolVersion = 3 }) defer s4.Shutdown() - testJoin(t, s1, s4) + TestJoin(t, s1, s4) servers[1] = s4 // Make sure we're back to 3 total peers with the new one added via ID @@ -903,18 +904,18 @@ func TestLeader_UpgradeRaftVersion(t *testing.T) { func TestLeader_RollRaftServer(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.RaftConfig.ProtocolVersion = 2 }) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true c.RaftConfig.ProtocolVersion = 1 }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true c.RaftConfig.ProtocolVersion = 2 }) @@ -923,7 +924,7 @@ func TestLeader_RollRaftServer(t *testing.T) { servers := []*Server{s1, s2, s3} // Try to join - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) for _, s := range servers { retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s, 3)) }) @@ -945,12 +946,12 @@ func TestLeader_RollRaftServer(t *testing.T) { } // Replace the dead server with one running raft protocol v3 - s4 := testServer(t, func(c *Config) { + s4 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true c.RaftConfig.ProtocolVersion = 3 }) defer s4.Shutdown() - testJoin(t, s4, s1) + TestJoin(t, s4, s1) servers[1] = s4 // Make sure the dead server is removed and we're back to 3 total peers diff --git a/nomad/mock/acl.go b/nomad/mock/acl.go index 1eeb61cbe5a..a7f4a23b518 100644 --- a/nomad/mock/acl.go +++ b/nomad/mock/acl.go @@ -2,6 +2,8 @@ package mock import ( "fmt" + "strconv" + "strings" "github.com/hashicorp/nomad/nomad/structs" "github.com/mitchellh/go-testing-interface" @@ -23,7 +25,13 @@ func NamespacePolicy(namespace string, policy string, capabilities []string) str policyHCL += fmt.Sprintf("\n\tpolicy = %q", policy) } if len(capabilities) != 0 { - policyHCL += fmt.Sprintf("\n\tcapabilities = %q", capabilities) + for i, s := range capabilities { + if !strings.HasPrefix(s, "\"") { + capabilities[i] = strconv.Quote(s) + } + } + + policyHCL += fmt.Sprintf("\n\tcapabilities = [%v]", strings.Join(capabilities, ",")) } policyHCL += "\n}" return policyHCL diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 2ee7a68e056..75e0cf9ffe0 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -2,14 +2,11 @@ package nomad import ( "context" - "crypto/subtle" - "encoding/base64" "fmt" "strings" "sync" "time" - "golang.org/x/crypto/blake2b" "golang.org/x/sync/errgroup" "github.com/armon/go-metrics" @@ -36,6 +33,9 @@ const ( type Node struct { srv *Server + // ctx provides context regarding the underlying connection + ctx *RPCContext + // updates holds pending client status updates for allocations updates []*structs.Allocation @@ -114,6 +114,13 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp } } + // We have a valid node connection, so add the mapping to cache the + // connection and allow the server to send RPCs to the client. 
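+	// The context's NodeID is empty until the first RPC from this
+	// connection, so the mapping is added at most once. A rough sketch of
+	// how a server-side endpoint can later reuse the cached session (see
+	// ClientStats.Stats for the real usage):
+	//
+	//	if state, ok := srv.getNodeConn(nodeID); ok {
+	//		return NodeRpc(state.Session, "ClientStats.Stats", args, reply)
+	//	}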
+ if n.ctx != nil && n.ctx.NodeID == "" { + n.ctx.NodeID = args.Node.ID + n.srv.addNodeConn(n.ctx) + } + // Commit this update via Raft _, index, err := n.srv.raftApply(structs.NodeRegisterRequestType, args) if err != nil { @@ -305,6 +312,13 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct return fmt.Errorf("node not found") } + // We have a valid node connection, so add the mapping to cache the + // connection and allow the server to send RPCs to the client. + if n.ctx != nil && n.ctx.NodeID == "" { + n.ctx.NodeID = args.NodeID + n.srv.addNodeConn(n.ctx) + } + // XXX: Could use the SecretID here but have to update the heartbeat system // to track SecretIDs. @@ -658,33 +672,6 @@ func (n *Node) GetAllocs(args *structs.NodeSpecificRequest, return n.srv.blockingRPC(&opts) } -// GenerateMigrateToken will create a token for a client to access an -// authenticated volume of another client to migrate data for sticky volumes. -func GenerateMigrateToken(allocID, nodeSecretID string) (string, error) { - h, err := blake2b.New512([]byte(nodeSecretID)) - if err != nil { - return "", err - } - h.Write([]byte(allocID)) - return base64.URLEncoding.EncodeToString(h.Sum(nil)), nil -} - -// CompareMigrateToken returns true if two migration tokens can be computed and -// are equal. -func CompareMigrateToken(allocID, nodeSecretID, otherMigrateToken string) bool { - h, err := blake2b.New512([]byte(nodeSecretID)) - if err != nil { - return false - } - h.Write([]byte(allocID)) - - otherBytes, err := base64.URLEncoding.DecodeString(otherMigrateToken) - if err != nil { - return false - } - return subtle.ConstantTimeCompare(h.Sum(nil), otherBytes) == 1 -} - // GetClientAllocs is used to request a lightweight list of alloc modify indexes // per allocation. func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, @@ -724,6 +711,13 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, return fmt.Errorf("node secret ID does not match") } + // We have a valid node connection, so add the mapping to cache the + // connection and allow the server to send RPCs to the client. 
+ if n.ctx != nil && n.ctx.NodeID == "" { + n.ctx.NodeID = args.NodeID + n.srv.addNodeConn(n.ctx) + } + var err error allocs, err = state.AllocsByNode(ws, args.NodeID) if err != nil { @@ -767,7 +761,7 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, continue } - token, err := GenerateMigrateToken(prevAllocation.ID, allocNode.SecretID) + token, err := structs.GenerateMigrateToken(prevAllocation.ID, allocNode.SecretID) if err != nil { return err } diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index cd2553f596e..9bd9db8a7f5 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -21,11 +21,15 @@ import ( func TestClientEndpoint_Register(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + require := require.New(t) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + // Check that we have no client connections + require.Empty(s1.connectedNodes()) + // Create the register request node := mock.Node() req := &structs.NodeRegisterRequest{ @@ -42,6 +46,11 @@ func TestClientEndpoint_Register(t *testing.T) { t.Fatalf("bad index: %d", resp.Index) } + // Check that we have the client connections + nodes := s1.connectedNodes() + require.Len(nodes, 1) + require.Contains(nodes, node.ID) + // Check for the node in the FSM state := s1.fsm.State() ws := memdb.NewWatchSet() @@ -58,11 +67,20 @@ func TestClientEndpoint_Register(t *testing.T) { if out.ComputedClass == "" { t.Fatal("ComputedClass not set") } + + // Close the connection and check that we remove the client connections + require.Nil(codec.Close()) + testutil.WaitForResult(func() (bool, error) { + nodes := s1.connectedNodes() + return len(nodes) == 0, nil + }, func(err error) { + t.Fatalf("should have no clients") + }) } func TestClientEndpoint_Register_SecretMismatch(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -90,7 +108,7 @@ func TestClientEndpoint_Register_SecretMismatch(t *testing.T) { func TestClientEndpoint_Deregister(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -135,7 +153,7 @@ func TestClientEndpoint_Deregister(t *testing.T) { func TestClientEndpoint_Deregister_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -200,7 +218,7 @@ func TestClientEndpoint_Deregister_ACL(t *testing.T) { func TestClientEndpoint_Deregister_Vault(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -261,11 +279,15 @@ func TestClientEndpoint_Deregister_Vault(t *testing.T) { func TestClientEndpoint_UpdateStatus(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + require := require.New(t) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + // Check that we have no client connections + require.Empty(s1.connectedNodes()) + // Create the register request node := mock.Node() reg := &structs.NodeRegisterRequest{ @@ -305,6 +327,11 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { t.Fatalf("bad: %#v", ttl) } + // Check that we have the client connections + nodes := s1.connectedNodes() + 
require.Len(nodes, 1) + require.Contains(nodes, node.ID) + // Check for the node in the FSM state := s1.fsm.State() ws := memdb.NewWatchSet() @@ -318,11 +345,20 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { if out.ModifyIndex != resp2.Index { t.Fatalf("index mis-match") } + + // Close the connection and check that we remove the client connections + require.Nil(codec.Close()) + testutil.WaitForResult(func() (bool, error) { + nodes := s1.connectedNodes() + return len(nodes) == 0, nil + }, func(err error) { + t.Fatalf("should have no clients") + }) } func TestClientEndpoint_UpdateStatus_Vault(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -380,7 +416,7 @@ func TestClientEndpoint_UpdateStatus_Vault(t *testing.T) { func TestClientEndpoint_Register_GetEvals(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -474,7 +510,7 @@ func TestClientEndpoint_Register_GetEvals(t *testing.T) { func TestClientEndpoint_UpdateStatus_GetEvals(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -557,20 +593,20 @@ func TestClientEndpoint_UpdateStatus_GetEvals(t *testing.T) { func TestClientEndpoint_UpdateStatus_HeartbeatOnly(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s3.Shutdown() servers := []*Server{s1, s2, s3} - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) for _, s := range servers { testutil.WaitForResult(func() (bool, error) { @@ -632,7 +668,7 @@ func TestClientEndpoint_UpdateStatus_HeartbeatOnly(t *testing.T) { func TestClientEndpoint_UpdateDrain(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -678,7 +714,7 @@ func TestClientEndpoint_UpdateDrain(t *testing.T) { func TestClientEndpoint_UpdateDrain_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -735,7 +771,7 @@ func TestClientEndpoint_UpdateDrain_ACL(t *testing.T) { // pending/running state to lost when a node is marked as down. 
func TestClientEndpoint_Drain_Down(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -871,7 +907,7 @@ func TestClientEndpoint_Drain_Down(t *testing.T) { func TestClientEndpoint_GetNode(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -930,7 +966,7 @@ func TestClientEndpoint_GetNode(t *testing.T) { func TestClientEndpoint_GetNode_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -993,7 +1029,7 @@ func TestClientEndpoint_GetNode_ACL(t *testing.T) { func TestClientEndpoint_GetNode_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -1095,7 +1131,7 @@ func TestClientEndpoint_GetNode_Blocking(t *testing.T) { func TestClientEndpoint_GetAllocs(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1157,7 +1193,7 @@ func TestClientEndpoint_GetAllocs(t *testing.T) { func TestClientEndpoint_GetAllocs_ACL_Basic(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1231,30 +1267,23 @@ func TestClientEndpoint_GetAllocs_ACL_Basic(t *testing.T) { func TestClientEndpoint_GetClientAllocs(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + require := require.New(t) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) + // Check that we have no client connections + require.Empty(s1.connectedNodes()) + // Create the register request node := mock.Node() - reg := &structs.NodeRegisterRequest{ - Node: node, - WriteRequest: structs.WriteRequest{Region: "global"}, - } - - // Fetch the response - var resp structs.GenericResponse - if err := msgpackrpc.CallWithCodec(codec, "Node.Register", reg, &resp); err != nil { - t.Fatalf("err: %v", err) - } - node.CreateIndex = resp.Index - node.ModifyIndex = resp.Index + state := s1.fsm.State() + require.Nil(state.UpsertNode(98, node)) // Inject fake evaluations alloc := mock.Alloc() alloc.NodeID = node.ID - state := s1.fsm.State() state.UpsertJobSummary(99, mock.JobSummary(alloc.JobID)) err := state.UpsertAllocs(100, []*structs.Allocation{alloc}) if err != nil { @@ -1279,6 +1308,11 @@ func TestClientEndpoint_GetClientAllocs(t *testing.T) { t.Fatalf("bad: %#v", resp2.Allocs) } + // Check that we have the client connections + nodes := s1.connectedNodes() + require.Len(nodes, 1) + require.Contains(nodes, node.ID) + // Lookup node with bad SecretID get.SecretID = "foobarbaz" var resp3 structs.NodeClientAllocsResponse @@ -1299,11 +1333,20 @@ func TestClientEndpoint_GetClientAllocs(t *testing.T) { if len(resp4.Allocs) != 0 { t.Fatalf("unexpected node %#v", resp3.Allocs) } + + // Close the connection and check that we remove the client connections + require.Nil(codec.Close()) + testutil.WaitForResult(func() (bool, error) { + nodes := s1.connectedNodes() + return len(nodes) == 0, nil + }, func(err error) { + t.Fatalf("should have no clients") + }) } func TestClientEndpoint_GetClientAllocs_Blocking(t *testing.T) 
{ t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1425,7 +1468,7 @@ func TestClientEndpoint_GetClientAllocs_Blocking(t *testing.T) { func TestClientEndpoint_GetClientAllocs_Blocking_GC(t *testing.T) { t.Parallel() assert := assert.New(t) - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1502,7 +1545,7 @@ func TestClientEndpoint_GetClientAllocs_WithoutMigrateTokens(t *testing.T) { t.Parallel() assert := assert.New(t) - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1553,7 +1596,7 @@ func TestClientEndpoint_GetClientAllocs_WithoutMigrateTokens(t *testing.T) { func TestClientEndpoint_GetAllocs_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1645,7 +1688,7 @@ func TestClientEndpoint_GetAllocs_Blocking(t *testing.T) { func TestClientEndpoint_UpdateAlloc(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1723,7 +1766,7 @@ func TestClientEndpoint_UpdateAlloc(t *testing.T) { func TestClientEndpoint_BatchUpdate(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1758,7 +1801,7 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) { // Call to do the batch update bf := NewBatchFuture() - endpoint := s1.endpoints.Node + endpoint := s1.staticEndpoints.Node endpoint.batchUpdate(bf, []*structs.Allocation{clientAlloc}, nil) if err := bf.Wait(); err != nil { t.Fatalf("err: %v", err) @@ -1780,7 +1823,7 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) { func TestClientEndpoint_UpdateAlloc_Vault(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -1865,7 +1908,7 @@ func TestClientEndpoint_UpdateAlloc_Vault(t *testing.T) { func TestClientEndpoint_CreateNodeEvals(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -1884,7 +1927,7 @@ func TestClientEndpoint_CreateNodeEvals(t *testing.T) { } // Create some evaluations - ids, index, err := s1.endpoints.Node.createNodeEvals(alloc.NodeID, 1) + ids, index, err := s1.staticEndpoints.Node.createNodeEvals(alloc.NodeID, 1) if err != nil { t.Fatalf("err: %v", err) } @@ -1953,7 +1996,7 @@ func TestClientEndpoint_CreateNodeEvals(t *testing.T) { func TestClientEndpoint_Evaluate(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -2034,7 +2077,7 @@ func TestClientEndpoint_Evaluate(t *testing.T) { func TestClientEndpoint_Evaluate_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -2092,7 +2135,7 @@ func TestClientEndpoint_Evaluate_ACL(t *testing.T) { func TestClientEndpoint_ListNodes(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() 
codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -2153,7 +2196,7 @@ func TestClientEndpoint_ListNodes(t *testing.T) { func TestClientEndpoint_ListNodes_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -2207,7 +2250,7 @@ func TestClientEndpoint_ListNodes_ACL(t *testing.T) { func TestClientEndpoint_ListNodes_Blocking(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -2348,7 +2391,7 @@ func TestBatchFuture(t *testing.T) { func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -2429,7 +2472,7 @@ func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) { func TestClientEndpoint_DeriveVaultToken(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) @@ -2521,7 +2564,7 @@ func TestClientEndpoint_DeriveVaultToken(t *testing.T) { func TestClientEndpoint_DeriveVaultToken_VaultError(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() state := s1.fsm.State() codec := rpcClient(t, s1) diff --git a/nomad/operator_endpoint_test.go b/nomad/operator_endpoint_test.go index 64115d01a90..1ef2875d7b8 100644 --- a/nomad/operator_endpoint_test.go +++ b/nomad/operator_endpoint_test.go @@ -18,7 +18,7 @@ import ( func TestOperator_RaftGetConfiguration(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -61,7 +61,7 @@ func TestOperator_RaftGetConfiguration(t *testing.T) { func TestOperator_RaftGetConfiguration_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -124,7 +124,7 @@ func TestOperator_RaftGetConfiguration_ACL(t *testing.T) { func TestOperator_RaftRemovePeerByAddress(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -180,7 +180,7 @@ func TestOperator_RaftRemovePeerByAddress(t *testing.T) { func TestOperator_RaftRemovePeerByAddress_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -228,7 +228,7 @@ func TestOperator_RaftRemovePeerByAddress_ACL(t *testing.T) { func TestOperator_RaftRemovePeerByID(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.RaftConfig.ProtocolVersion = 3 }) defer s1.Shutdown() @@ -286,7 +286,7 @@ func TestOperator_RaftRemovePeerByID(t *testing.T) { func TestOperator_RaftRemovePeerByID_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.RaftConfig.ProtocolVersion = 3 }) defer s1.Shutdown() diff --git a/nomad/periodic_endpoint_test.go b/nomad/periodic_endpoint_test.go index 575b9dd1ab9..1049f960587 100644 --- a/nomad/periodic_endpoint_test.go +++ b/nomad/periodic_endpoint_test.go @@ -14,7 +14,7 @@ import ( func 
TestPeriodicEndpoint_Force(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) state := s1.fsm.State() @@ -64,7 +64,7 @@ func TestPeriodicEndpoint_Force(t *testing.T) { func TestPeriodicEndpoint_Force_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, func(c *Config) { + s1, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) defer s1.Shutdown() @@ -142,7 +142,7 @@ func TestPeriodicEndpoint_Force_ACL(t *testing.T) { func TestPeriodicEndpoint_Force_NonPeriodic(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue }) state := s1.fsm.State() diff --git a/nomad/periodic_test.go b/nomad/periodic_test.go index 4bc3d20cc82..93554d37bfc 100644 --- a/nomad/periodic_test.go +++ b/nomad/periodic_test.go @@ -656,7 +656,7 @@ func deriveChildJob(parent *structs.Job) *structs.Job { func TestPeriodicDispatch_RunningChildren_NoEvals(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -679,7 +679,7 @@ func TestPeriodicDispatch_RunningChildren_NoEvals(t *testing.T) { func TestPeriodicDispatch_RunningChildren_ActiveEvals(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -715,7 +715,7 @@ func TestPeriodicDispatch_RunningChildren_ActiveEvals(t *testing.T) { func TestPeriodicDispatch_RunningChildren_ActiveAllocs(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) diff --git a/nomad/plan_apply_test.go b/nomad/plan_apply_test.go index 93e44e6175c..a6e6529d283 100644 --- a/nomad/plan_apply_test.go +++ b/nomad/plan_apply_test.go @@ -5,6 +5,7 @@ import ( "testing" memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -62,7 +63,7 @@ func testRegisterJob(t *testing.T, s *Server, j *structs.Job) { func TestPlanApply_applyPlan(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() testutil.WaitForLeader(t, s1.RPC) @@ -252,7 +253,7 @@ func TestPlanApply_EvalPlan_Simple(t *testing.T) { pool := NewEvaluatePool(workerPoolSize, workerPoolBufferSize) defer pool.Shutdown() - result, err := evaluatePlan(pool, snap, plan, testLogger()) + result, err := evaluatePlan(pool, snap, plan, testlog.Logger(t)) if err != nil { t.Fatalf("err: %v", err) } @@ -299,7 +300,7 @@ func TestPlanApply_EvalPlan_Partial(t *testing.T) { pool := NewEvaluatePool(workerPoolSize, workerPoolBufferSize) defer pool.Shutdown() - result, err := evaluatePlan(pool, snap, plan, testLogger()) + result, err := evaluatePlan(pool, snap, plan, testlog.Logger(t)) if err != nil { t.Fatalf("err: %v", err) } @@ -360,7 +361,7 @@ func TestPlanApply_EvalPlan_Partial_AllAtOnce(t *testing.T) { pool := NewEvaluatePool(workerPoolSize, workerPoolBufferSize) defer pool.Shutdown() - result, err := evaluatePlan(pool, snap, plan, testLogger()) + result, err := evaluatePlan(pool, snap, plan, testlog.Logger(t)) if err != nil { t.Fatalf("err: %v", err) } diff --git a/nomad/plan_endpoint_test.go b/nomad/plan_endpoint_test.go 
index ca4784ba145..ed71ec4171e 100644 --- a/nomad/plan_endpoint_test.go +++ b/nomad/plan_endpoint_test.go @@ -12,7 +12,7 @@ import ( func TestPlanEndpoint_Submit(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) defer s1.Shutdown() diff --git a/nomad/raft_rpc.go b/nomad/raft_rpc.go index e7f73357d57..164867a7f93 100644 --- a/nomad/raft_rpc.go +++ b/nomad/raft_rpc.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/hashicorp/nomad/helper/pool" "github.com/hashicorp/nomad/helper/tlsutil" "github.com/hashicorp/raft" ) @@ -111,7 +112,7 @@ func (l *RaftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net // Check for tls mode if tlsWrapper != nil { // Switch the connection into TLS mode - if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil { + if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil { conn.Close() return nil, err } @@ -124,7 +125,7 @@ func (l *RaftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net } // Write the Raft byte to set the mode - _, err = conn.Write([]byte{byte(rpcRaft)}) + _, err = conn.Write([]byte{byte(pool.RpcRaft)}) if err != nil { conn.Close() return nil, err diff --git a/nomad/regions_endpoint_test.go b/nomad/regions_endpoint_test.go index 0da399e0ac0..8bd79d0bb5f 100644 --- a/nomad/regions_endpoint_test.go +++ b/nomad/regions_endpoint_test.go @@ -12,13 +12,13 @@ import ( func TestRegionList(t *testing.T) { t.Parallel() // Make the servers - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.Region = "region1" }) defer s1.Shutdown() codec := rpcClient(t, s1) - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.Region = "region2" }) defer s2.Shutdown() diff --git a/nomad/rpc.go b/nomad/rpc.go index 1b7e8bccab8..537e73d9df1 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -3,6 +3,8 @@ package nomad import ( "context" "crypto/tls" + "crypto/x509" + "errors" "fmt" "io" "math/rand" @@ -14,20 +16,12 @@ import ( metrics "github.com/armon/go-metrics" "github.com/hashicorp/consul/lib" memdb "github.com/hashicorp/go-memdb" - msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/hashicorp/nomad/helper/pool" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/raft" "github.com/hashicorp/yamux" -) - -type RPCType byte - -const ( - rpcNomad RPCType = 0x01 - rpcRaft = 0x02 - rpcMultiplex = 0x03 - rpcTLS = 0x04 + "github.com/ugorji/go/codec" ) const ( @@ -38,12 +32,6 @@ const ( // if no time is specified. Previously we would wait the maxQueryTime. defaultQueryTime = 300 * time.Second - // jitterFraction is a the limit to the amount of jitter we apply - // to a user specified MaxQueryTime. We divide the specified time by - // the fraction. So 16 == 6.25% limit of jitter. This jitter is also - // applied to RPCHoldTimeout. - jitterFraction = 16 - // Warn if the Raft command is larger than this. // If it's over 1MB something is probably being abusive. raftWarnSize = 1024 * 1024 @@ -55,16 +43,23 @@ const ( enqueueLimit = 30 * time.Second ) -// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls to -// the Nomad Server. 
-func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
-	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle)
-}
+// RPCContext provides metadata about the RPC connection.
+type RPCContext struct {
+	// Conn exposes the raw connection.
+	Conn net.Conn
+
+	// Session exposes the multiplexed connection session.
+	Session *yamux.Session
+
+	// TLS marks whether the RPC is over a TLS-based connection
+	TLS bool
 
-// NewServerCodec returns a new rpc.ServerCodec to be used by the Nomad Server
-// to handle rpcs.
-func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
-	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle)
+	// VerifiedChains are the verified certificate chains presented by the
+	// incoming connection.
+	VerifiedChains [][]*x509.Certificate
+
+	// NodeID marks the NodeID that initiated the connection.
+	NodeID string
 }
 
 // listen is used to listen for incoming RPC connections
@@ -94,14 +89,14 @@ func (s *Server) listen(ctx context.Context) {
 			continue
 		}
 
-		go s.handleConn(ctx, conn, false)
+		go s.handleConn(ctx, conn, &RPCContext{Conn: conn})
 		metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1)
 	}
 }
 
 // handleConn is used to determine if this is a Raft or
 // Nomad type RPC connection and invoke the correct handler
-func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) {
+func (s *Server) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
 	// Read a single byte
 	buf := make([]byte, 1)
 	if _, err := conn.Read(buf); err != nil {
@@ -113,7 +108,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) {
 	}
 
 	// Enforce TLS if EnableRPC is set
-	if s.config.TLSConfig.EnableRPC && !isTLS && RPCType(buf[0]) != rpcTLS {
+	if s.config.TLSConfig.EnableRPC && !rpcCtx.TLS && pool.RPCType(buf[0]) != pool.RpcTLS {
 		if !s.config.TLSConfig.RPCUpgradeMode {
 			s.logger.Printf("[WARN] nomad.rpc: Non-TLS connection attempted from %s with RequireTLS set", conn.RemoteAddr().String())
 			conn.Close()
@@ -122,25 +117,62 @@
 	}
 
 	// Switch on the byte
-	switch RPCType(buf[0]) {
-	case rpcNomad:
-		s.handleNomadConn(ctx, conn)
-
-	case rpcRaft:
+	switch pool.RPCType(buf[0]) {
+	case pool.RpcNomad:
+		// Create an RPC Server and handle the request
+		server := rpc.NewServer()
+		s.setupRpcServer(server, rpcCtx)
+		s.handleNomadConn(ctx, conn, server)
+
+		// Remove any potential mapping between a NodeID and this connection
+		// and close the underlying connection.
+		s.removeNodeConn(rpcCtx)
+
+	case pool.RpcRaft:
 		metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1)
 		s.raftLayer.Handoff(ctx, conn)
 
-	case rpcMultiplex:
-		s.handleMultiplex(ctx, conn)
+	case pool.RpcMultiplex:
+		s.handleMultiplex(ctx, conn, rpcCtx)
 
-	case rpcTLS:
+	case pool.RpcTLS:
 		if s.rpcTLS == nil {
 			s.logger.Printf("[WARN] nomad.rpc: TLS connection attempted, server not configured for TLS")
 			conn.Close()
 			return
 		}
 		conn = tls.Server(conn, s.rpcTLS)
-		s.handleConn(ctx, conn, true)
+
+		// Force a handshake so we can get information about the TLS connection
+		// state.
+		tlsConn, ok := conn.(*tls.Conn)
+		if !ok {
+			s.logger.Printf("[ERR] nomad.rpc: expected TLS connection but got %T", conn)
+			conn.Close()
+			return
+		}
+
+		if err := tlsConn.Handshake(); err != nil {
+			s.logger.Printf("[WARN] nomad.rpc: failed TLS handshake from %v: %v", tlsConn.RemoteAddr(), err)
+			conn.Close()
+			return
+		}
+
+		// Update the connection context with the fact that the connection is
+		// using TLS
+		rpcCtx.TLS = true
+
+		// Store the verified chains so they can be inspected later.
+		state := tlsConn.ConnectionState()
+		rpcCtx.VerifiedChains = state.VerifiedChains
+
+		s.handleConn(ctx, conn, rpcCtx)
+
+	case pool.RpcStreaming:
+		s.handleStreamingConn(conn)
+
+	case pool.RpcMultiplexV2:
+		s.handleMultiplexV2(ctx, conn, rpcCtx)
 
 	default:
 		s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0])
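The single byte read above is the entire connection-type negotiation: a client dials, writes one mode byte, and the server dispatches on it. A minimal client-side sketch of that protocol follows; the helper name and address handling are illustrative only, while the mode bytes and codec come from helper/pool as used by the rpcClient test helper later in this diff.

```go
package example

import (
	"net"
	"time"

	"github.com/hashicorp/nomad/helper/pool"
)

// dialNomadRPC opens a plain TCP connection and selects the single msgpack
// RPC stream mode; handleConn reads exactly this one byte to pick a handler.
func dialNomadRPC(addr string) (net.Conn, error) {
	conn, err := net.DialTimeout("tcp", addr, 10*time.Second)
	if err != nil {
		return nil, err
	}
	if _, err := conn.Write([]byte{byte(pool.RpcNomad)}); err != nil {
		conn.Close()
		return nil, err
	}
	// RPCs can now be issued over pool.NewClientCodec(conn).
	return conn, nil
}
```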
@@ -151,11 +183,29 @@
 // handleMultiplex is used to multiplex a single incoming connection
 // using the Yamux multiplexer
-func (s *Server) handleMultiplex(ctx context.Context, conn net.Conn) {
-	defer conn.Close()
+func (s *Server) handleMultiplex(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
+	defer func() {
+		// Remove any potential mapping between a NodeID and this connection
+		// and close the underlying connection.
+		s.removeNodeConn(rpcCtx)
+		conn.Close()
+	}()
+
 	conf := yamux.DefaultConfig()
 	conf.LogOutput = s.config.LogOutput
-	server, _ := yamux.Server(conn, conf)
+	server, err := yamux.Server(conn, conf)
+	if err != nil {
+		s.logger.Printf("[ERR] nomad.rpc: multiplex failed to create yamux server: %v", err)
+		return
+	}
+
+	// Update the context to store the yamux session
+	rpcCtx.Session = server
+
+	// Create the RPC server for this connection
+	rpcServer := rpc.NewServer()
+	s.setupRpcServer(rpcServer, rpcCtx)
+
 	for {
 		sub, err := server.Accept()
 		if err != nil {
@@ -164,14 +214,14 @@
 			}
 			return
 		}
-		go s.handleNomadConn(ctx, sub)
+		go s.handleNomadConn(ctx, sub, rpcServer)
 	}
 }
 
 // handleNomadConn is used to service a single Nomad RPC connection
-func (s *Server) handleNomadConn(ctx context.Context, conn net.Conn) {
+func (s *Server) handleNomadConn(ctx context.Context, conn net.Conn, server *rpc.Server) {
 	defer conn.Close()
-	rpcCodec := NewServerCodec(conn)
+	rpcCodec := pool.NewServerCodec(conn)
 	for {
 		select {
 		case <-ctx.Done():
@@ -182,7 +232,7 @@
 		default:
 		}
 
-		if err := s.rpcServer.ServeRequest(rpcCodec); err != nil {
+		if err := server.ServeRequest(rpcCodec); err != nil {
 			if err != io.EOF && !strings.Contains(err.Error(), "closed") {
 				s.logger.Printf("[ERR] nomad.rpc: RPC error: %v (%v)", err, conn)
 				metrics.IncrCounter([]string{"nomad", "rpc", "request_error"}, 1)
@@ -193,6 +243,106 @@
 	}
 }
 
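The removeNodeConn calls above are the teardown half of the node-connection cache that Node.Register, Node.UpdateStatus, and Node.GetClientAllocs populate via addNodeConn. Neither helper's body appears in this diff, so the sketch below only shows the assumed shape: the real Server fields (nodeConns, nodeConnsLock) are modeled by a standalone type, and the nodeConnState fields and the connectedNodes return type are guesses inferred from how the tests use them.

```go
package example

import "sync"

// RPCContext stands in for the type in nomad/rpc.go; only NodeID matters here.
type RPCContext struct {
	NodeID string
}

// nodeConnState retains the context (and thus the yamux session) for a node.
type nodeConnState struct {
	Ctx *RPCContext
}

type connCache struct {
	mu    sync.RWMutex
	conns map[string]*nodeConnState
}

func newConnCache() *connCache {
	return &connCache{conns: make(map[string]*nodeConnState)}
}

// add mirrors what Server.addNodeConn is assumed to do: cache the connection
// under the node ID so the server can send RPCs back to the client.
func (c *connCache) add(ctx *RPCContext) {
	if ctx == nil || ctx.NodeID == "" {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.conns[ctx.NodeID] = &nodeConnState{Ctx: ctx}
}

// remove mirrors Server.removeNodeConn: drop the mapping when the underlying
// connection goes away, but only if it still refers to this context.
func (c *connCache) remove(ctx *RPCContext) {
	if ctx == nil || ctx.NodeID == "" {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if state, ok := c.conns[ctx.NodeID]; ok && state.Ctx == ctx {
		delete(c.conns, ctx.NodeID)
	}
}

// connected lists cached node IDs; the endpoint tests poll the analogous
// Server.connectedNodes after closing the client codec.
func (c *connCache) connected() map[string]struct{} {
	c.mu.RLock()
	defer c.mu.RUnlock()
	nodes := make(map[string]struct{}, len(c.conns))
	for id := range c.conns {
		nodes[id] = struct{}{}
	}
	return nodes
}
```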
+// handleStreamingConn is used to handle a single Streaming Nomad RPC connection.
+func (s *Server) handleStreamingConn(conn net.Conn) {
+	defer conn.Close()
+
+	// Decode the header
+	var header structs.StreamingRpcHeader
+	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
+	if err := decoder.Decode(&header); err != nil {
+		if err != io.EOF && !strings.Contains(err.Error(), "closed") {
+			s.logger.Printf("[ERR] nomad.rpc: Streaming RPC error: %v (%v)", err, conn)
+			metrics.IncrCounter([]string{"nomad", "streaming_rpc", "request_error"}, 1)
+		}
+
+		return
+	}
+
+	ack := structs.StreamingRpcAck{}
+	handler, err := s.streamingRpcs.GetHandler(header.Method)
+	if err != nil {
+		s.logger.Printf("[ERR] nomad.rpc: Streaming RPC error: %v (%v)", err, conn)
+		metrics.IncrCounter([]string{"nomad", "streaming_rpc", "request_error"}, 1)
+		ack.Error = err.Error()
+	}
+
+	// Send the acknowledgement
+	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
+	if err := encoder.Encode(ack); err != nil {
+		conn.Close()
+		return
+	}
+
+	if ack.Error != "" {
+		return
+	}
+
+	// Invoke the handler
+	metrics.IncrCounter([]string{"nomad", "streaming_rpc", "request"}, 1)
+	handler(conn)
+}
+
+// handleMultiplexV2 is used to multiplex a single incoming connection
+// using the Yamux multiplexer. Version 2 handling allows a single connection to
+// switch streams between regular RPCs and Streaming RPCs.
+func (s *Server) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
+	defer func() {
+		// Remove any potential mapping between a NodeID and this connection
+		// and close the underlying connection.
+		s.removeNodeConn(rpcCtx)
+		conn.Close()
+	}()
+
+	conf := yamux.DefaultConfig()
+	conf.LogOutput = s.config.LogOutput
+	server, err := yamux.Server(conn, conf)
+	if err != nil {
+		s.logger.Printf("[ERR] nomad.rpc: multiplex_v2 failed to create yamux server: %v", err)
+		return
+	}
+
+	// Update the context to store the yamux session
+	rpcCtx.Session = server
+
+	// Create the RPC server for this connection
+	rpcServer := rpc.NewServer()
+	s.setupRpcServer(rpcServer, rpcCtx)
+
+	for {
+		// Accept a new stream
+		sub, err := server.Accept()
+		if err != nil {
+			if err != io.EOF {
+				s.logger.Printf("[ERR] nomad.rpc: multiplex_v2 conn accept failed: %v", err)
+			}
+			return
+		}
+
+		// Read a single byte
+		buf := make([]byte, 1)
+		if _, err := sub.Read(buf); err != nil {
+			if err != io.EOF {
+				s.logger.Printf("[ERR] nomad.rpc: multiplex_v2 failed to read byte: %v", err)
+			}
+			return
+		}
+
+		// Determine which handler to use
+		switch pool.RPCType(buf[0]) {
+		case pool.RpcNomad:
+			go s.handleNomadConn(ctx, sub, rpcServer)
+		case pool.RpcStreaming:
+			go s.handleStreamingConn(sub)
+
+		default:
+			s.logger.Printf("[ERR] nomad.rpc: multiplex_v2 unrecognized RPC byte: %v", buf[0])
+			return
+		}
+	}
+}
+
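TestRPC_handleMultiplexV2 near the end of this diff drives this path manually; condensed, the client side of a MultiplexV2 connection looks roughly like the sketch below. The helper name is illustrative and yamux defaults are assumed; the mode bytes match the switch in handleMultiplexV2 above.

```go
package example

import (
	"net"

	"github.com/hashicorp/nomad/helper/pool"
	"github.com/hashicorp/yamux"
)

// openMultiplexV2 upgrades an established connection to MultiplexV2 and opens
// one stream per payload type; each stream carries its own leading mode byte,
// which handleMultiplexV2 dispatches on.
func openMultiplexV2(conn net.Conn) (rpcStream, streamStream net.Conn, err error) {
	if _, err = conn.Write([]byte{byte(pool.RpcMultiplexV2)}); err != nil {
		return nil, nil, err
	}
	session, err := yamux.Client(conn, yamux.DefaultConfig())
	if err != nil {
		return nil, nil, err
	}
	if rpcStream, err = session.Open(); err != nil {
		return nil, nil, err
	}
	if _, err = rpcStream.Write([]byte{byte(pool.RpcNomad)}); err != nil { // regular RPCs
		return nil, nil, err
	}
	if streamStream, err = session.Open(); err != nil {
		return nil, nil, err
	}
	if _, err = streamStream.Write([]byte{byte(pool.RpcStreaming)}); err != nil { // streaming RPCs
		return nil, nil, err
	}
	return rpcStream, streamStream, nil
}
```

Streams tagged RpcStreaming then perform the StreamingRpcHeader/StreamingRpcAck handshake implemented client-side in streamingRpcImpl below.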
 // forward is used to forward to a remote region or to forward to the local leader
 // Returns a bool indicating whether forwarding was performed, as well as any error
 func (s *Server) forward(method string, info structs.RPCInfo, args interface{}, reply interface{}) (bool, error) {
@@ -234,7 +384,7 @@ CHECK_LEADER:
 		firstCheck = time.Now()
 	}
 	if time.Now().Sub(firstCheck) < s.config.RPCHoldTimeout {
-		jitter := lib.RandomStagger(s.config.RPCHoldTimeout / jitterFraction)
+		jitter := lib.RandomStagger(s.config.RPCHoldTimeout / structs.JitterFraction)
 		select {
 		case <-time.After(jitter):
 			goto CHECK_LEADER
@@ -279,6 +429,15 @@ func (s *Server) forwardLeader(server *serverParts, method string, args interfac
 	return s.connPool.RPC(s.config.Region, server.Addr, server.MajorVersion, method, args, reply)
 }
 
+// forwardServer is used to forward an RPC call to a particular server
+func (s *Server) forwardServer(server *serverParts, method string, args interface{}, reply interface{}) error {
+	// Handle a missing server
+	if server == nil {
+		return errors.New("must be given a valid server address")
+	}
+	return s.connPool.RPC(s.config.Region, server.Addr, server.MajorVersion, method, args, reply)
+}
+
 // forwardRegion is used to forward an RPC call to a remote region, or fail if no servers
 func (s *Server) forwardRegion(region, method string, args interface{}, reply interface{}) error {
 	// Bail if we can't find any servers
@@ -301,6 +460,87 @@
 	return s.connPool.RPC(region, server.Addr, server.MajorVersion, method, args, reply)
 }
 
+// streamingRpc creates a connection to the given server and conducts the
+// initial handshake, returning the connection or an error. It is the caller's
+// responsibility to close the connection if there is no returned error.
+func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, error) {
+	// Try to dial the server
+	conn, err := net.DialTimeout("tcp", server.Addr.String(), 10*time.Second)
+	if err != nil {
+		return nil, err
+	}
+
+	// Cast to TCPConn
+	if tcp, ok := conn.(*net.TCPConn); ok {
+		tcp.SetKeepAlive(true)
+		tcp.SetNoDelay(true)
+	}
+
+	if err := s.streamingRpcImpl(conn, server.Region, method); err != nil {
+		return nil, err
+	}
+
+	return conn, nil
+}
+
+// streamingRpcImpl takes a pre-established connection to a server and conducts
+// the handshake to establish a streaming RPC for the given method. If an error
+// is returned, the underlying connection has been closed. Otherwise it is
+// assumed that the connection has been hijacked by the RPC method.
+func (s *Server) streamingRpcImpl(conn net.Conn, region, method string) error {
+	// Check if TLS is enabled
+	s.tlsWrapLock.RLock()
+	tlsWrap := s.tlsWrap
+	s.tlsWrapLock.RUnlock()
+
+	if tlsWrap != nil {
+		// Switch the connection into TLS mode
+		if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
+			conn.Close()
+			return err
+		}
+
+		// Wrap the connection in a TLS client
+		tlsConn, err := tlsWrap(region, conn)
+		if err != nil {
+			conn.Close()
+			return err
+		}
+		conn = tlsConn
+	}
+
+	// Write the streaming byte to set the mode
+	if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
+		conn.Close()
+		return err
+	}
+
+	// Send the header
+	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
+	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
+	header := structs.StreamingRpcHeader{
+		Method: method,
+	}
+	if err := encoder.Encode(header); err != nil {
+		conn.Close()
+		return err
+	}
+
+	// Wait for the acknowledgement
+	var ack structs.StreamingRpcAck
+	if err := decoder.Decode(&ack); err != nil {
+		conn.Close()
+		return err
+	}
+
+	if ack.Error != "" {
+		conn.Close()
+		return errors.New(ack.Error)
+	}
+
+	return nil
+}
+
 // raftApplyFuture is used to encode a message, run it through raft, and return the Raft future.
func (s *Server) raftApplyFuture(t structs.MessageType, msg interface{}) (raft.ApplyFuture, error) { buf, err := structs.Encode(t, msg) @@ -378,7 +618,7 @@ func (s *Server) blockingRPC(opts *blockingOptions) error { } // Apply a small amount of jitter to the request - opts.queryOpts.MaxQueryTime += lib.RandomStagger(opts.queryOpts.MaxQueryTime / jitterFraction) + opts.queryOpts.MaxQueryTime += lib.RandomStagger(opts.queryOpts.MaxQueryTime / structs.JitterFraction) // Setup a query timeout ctx, cancel = context.WithTimeout(context.Background(), opts.queryOpts.MaxQueryTime) diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index 392bb6870bf..c876c6adb1d 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -1,6 +1,7 @@ package nomad import ( + "context" "net" "net/rpc" "os" @@ -9,11 +10,16 @@ import ( "time" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/hashicorp/nomad/helper/pool" + "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" "github.com/hashicorp/nomad/testutil" + "github.com/hashicorp/raft" + "github.com/hashicorp/yamux" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // rpcClient is a test helper method to return a ClientCodec to use to make rpc @@ -25,19 +31,19 @@ func rpcClient(t *testing.T, s *Server) rpc.ClientCodec { t.Fatalf("err: %v", err) } // Write the Consul RPC byte to set the mode - conn.Write([]byte{byte(rpcNomad)}) - return NewClientCodec(conn) + conn.Write([]byte{byte(pool.RpcNomad)}) + return pool.NewClientCodec(conn) } func TestRPC_forwardLeader(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.DevDisableBootstrap = true }) defer s2.Shutdown() - testJoin(t, s1, s2) + TestJoin(t, s1, s2) testutil.WaitForLeader(t, s1.RPC) testutil.WaitForLeader(t, s2.RPC) @@ -70,13 +76,13 @@ func TestRPC_forwardLeader(t *testing.T) { func TestRPC_forwardRegion(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.Region = "region2" }) defer s2.Shutdown() - testJoin(t, s1, s2) + TestJoin(t, s1, s2) testutil.WaitForLeader(t, s1.RPC) testutil.WaitForLeader(t, s2.RPC) @@ -104,7 +110,7 @@ func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) { dir := tmpDir(t) defer os.RemoveAll(dir) - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.DataDir = path.Join(dir, "node1") c.TLSConfig = &config.TLSConfig{ EnableRPC: true, @@ -147,7 +153,7 @@ func TestRPC_PlaintextRPCFailsWhenNotInUpgradeMode(t *testing.T) { dir := tmpDir(t) defer os.RemoveAll(dir) - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.DataDir = path.Join(dir, "node1") c.TLSConfig = &config.TLSConfig{ EnableRPC: true, @@ -171,3 +177,152 @@ func TestRPC_PlaintextRPCFailsWhenNotInUpgradeMode(t *testing.T) { err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp) assert.NotNil(err) } + +func TestRPC_streamingRpcConn_badMethod(t *testing.T) { + t.Parallel() + require := require.New(t) + + s1 := TestServer(t, nil) + defer s1.Shutdown() + s2 := TestServer(t, func(c *Config) { + c.DevDisableBootstrap = true + }) + defer 
s2.Shutdown() + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) + + s1.peerLock.RLock() + ok, parts := isNomadServer(s2.LocalMember()) + require.True(ok) + server := s1.localPeers[raft.ServerAddress(parts.Addr.String())] + require.NotNil(server) + s1.peerLock.RUnlock() + + conn, err := s1.streamingRpc(server, "Bogus") + require.Nil(conn) + require.NotNil(err) + require.Contains(err.Error(), "Bogus") + require.True(structs.IsErrUnknownMethod(err)) +} + +func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) { + t.Parallel() + require := require.New(t) + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + s1 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 2 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node1") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s1.Shutdown() + + s2 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 2 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node2") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s2.Shutdown() + + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + + s1.peerLock.RLock() + ok, parts := isNomadServer(s2.LocalMember()) + require.True(ok) + server := s1.localPeers[raft.ServerAddress(parts.Addr.String())] + require.NotNil(server) + s1.peerLock.RUnlock() + + conn, err := s1.streamingRpc(server, "Bogus") + require.Nil(conn) + require.NotNil(err) + require.Contains(err.Error(), "Bogus") + require.True(structs.IsErrUnknownMethod(err)) +} + +// COMPAT: Remove in 0.10 +// This is a very low level test to assert that the V2 handling works. It is +// making manual RPC calls since no helpers exist at this point since we are +// only implementing support for v2 but not using it yet. In the future we can +// switch the conn pool to establishing v2 connections and we can deprecate this +// test. 
+func TestRPC_handleMultiplexV2(t *testing.T) { + t.Parallel() + require := require.New(t) + s := TestServer(t, nil) + defer s.Shutdown() + testutil.WaitForLeader(t, s.RPC) + + p1, p2 := net.Pipe() + defer p1.Close() + defer p2.Close() + + // Start the handler + doneCh := make(chan struct{}) + go func() { + s.handleConn(context.Background(), p2, &RPCContext{Conn: p2}) + close(doneCh) + }() + + // Establish the MultiplexV2 connection + _, err := p1.Write([]byte{byte(pool.RpcMultiplexV2)}) + require.Nil(err) + + // Make two streams + conf := yamux.DefaultConfig() + conf.LogOutput = testlog.NewWriter(t) + session, err := yamux.Client(p1, conf) + require.Nil(err) + + s1, err := session.Open() + require.Nil(err) + defer s1.Close() + + s2, err := session.Open() + require.Nil(err) + defer s2.Close() + + // Make an RPC + _, err = s1.Write([]byte{byte(pool.RpcNomad)}) + require.Nil(err) + + args := &structs.GenericRequest{} + var l string + err = msgpackrpc.CallWithCodec(pool.NewClientCodec(s1), "Status.Leader", args, &l) + require.Nil(err) + require.NotEmpty(l) + + // Make a streaming RPC + err = s.streamingRpcImpl(s2, s.Region(), "Bogus") + require.NotNil(err) + require.Contains(err.Error(), "Bogus") + require.True(structs.IsErrUnknownMethod(err)) + +} diff --git a/nomad/search_endpoint_test.go b/nomad/search_endpoint_test.go index 2631b695876..de27be60bc3 100644 --- a/nomad/search_endpoint_test.go +++ b/nomad/search_endpoint_test.go @@ -31,7 +31,7 @@ func TestSearch_PrefixSearch_Job(t *testing.T) { prefix := "aaaaaaaa-e8f7-fd38-c855-ab94ceb8970" t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -65,7 +65,7 @@ func TestSearch_PrefixSearch_ACL(t *testing.T) { jobID := "aaaaaaaa-e8f7-fd38-c855-ab94ceb8970" t.Parallel() - s, root := testACLServer(t, func(c *Config) { + s, root := TestACLServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -178,7 +178,7 @@ func TestSearch_PrefixSearch_All_JobWithHyphen(t *testing.T) { prefix := "example-test-------" // Assert that a job with more than 4 hyphens works t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -225,7 +225,7 @@ func TestSearch_PrefixSearch_All_LongJob(t *testing.T) { prefix := strings.Repeat("a", 100) t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -272,7 +272,7 @@ func TestSearch_PrefixSearch_Truncate(t *testing.T) { prefix := "aaaaaaaa-e8f7-fd38-c855-ab94ceb8970" t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -309,7 +309,7 @@ func TestSearch_PrefixSearch_AllWithJob(t *testing.T) { prefix := "aaaaaaaa-e8f7-fd38-c855-ab94ceb8970" t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -347,7 +347,7 @@ func TestSearch_PrefixSearch_AllWithJob(t *testing.T) { func TestSearch_PrefixSearch_Evals(t *testing.T) { assert := assert.New(t) t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -384,7 +384,7 @@ func TestSearch_PrefixSearch_Evals(t *testing.T) { func TestSearch_PrefixSearch_Allocation(t *testing.T) { assert := assert.New(t) t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -429,7 +429,7 @@ func TestSearch_PrefixSearch_Allocation(t *testing.T) { func TestSearch_PrefixSearch_All_UUID(t 
*testing.T) { assert := assert.New(t) t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -481,7 +481,7 @@ func TestSearch_PrefixSearch_All_UUID(t *testing.T) { func TestSearch_PrefixSearch_Node(t *testing.T) { assert := assert.New(t) t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -522,7 +522,7 @@ func TestSearch_PrefixSearch_Node(t *testing.T) { func TestSearch_PrefixSearch_Deployment(t *testing.T) { assert := assert.New(t) t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -559,7 +559,7 @@ func TestSearch_PrefixSearch_Deployment(t *testing.T) { func TestSearch_PrefixSearch_AllContext(t *testing.T) { assert := assert.New(t) t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -612,7 +612,7 @@ func TestSearch_PrefixSearch_NoPrefix(t *testing.T) { prefix := "aaaaaaaa-e8f7-fd38-c855-ab94ceb8970" t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -649,7 +649,7 @@ func TestSearch_PrefixSearch_NoMatches(t *testing.T) { prefix := "aaaaaaaa-e8f7-fd38-c855-ab94ceb8970" t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -684,7 +684,7 @@ func TestSearch_PrefixSearch_RoundDownToEven(t *testing.T) { prefix := "aaafa" t.Parallel() - s := testServer(t, func(c *Config) { + s := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) @@ -719,19 +719,19 @@ func TestSearch_PrefixSearch_MultiRegion(t *testing.T) { jobName := "exampleexample" t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.Region = "foo" }) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.Region = "bar" }) defer s2.Shutdown() - testJoin(t, s1, s2) + TestJoin(t, s1, s2) testutil.WaitForLeader(t, s1.RPC) job := registerAndVerifyJob(s1, t, jobName, 0) diff --git a/nomad/serf_test.go b/nomad/serf_test.go index 46706483347..14d27ad41b8 100644 --- a/nomad/serf_test.go +++ b/nomad/serf_test.go @@ -14,13 +14,13 @@ import ( func TestNomad_JoinPeer(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.Region = "region2" }) defer s2.Shutdown() - testJoin(t, s1, s2) + TestJoin(t, s1, s2) testutil.WaitForResult(func() (bool, error) { if members := s1.Members(); len(members) != 2 { @@ -55,13 +55,13 @@ func TestNomad_JoinPeer(t *testing.T) { func TestNomad_RemovePeer(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.Region = "region2" }) defer s2.Shutdown() - testJoin(t, s1, s2) + TestJoin(t, s1, s2) testutil.WaitForResult(func() (bool, error) { if members := s1.Members(); len(members) != 2 { @@ -96,7 +96,7 @@ func TestNomad_ReapPeer(t *testing.T) { t.Parallel() dir := tmpDir(t) defer os.RemoveAll(dir) - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NodeName = "node1" c.BootstrapExpect = 3 c.DevMode = false @@ -104,7 +104,7 @@ func TestNomad_ReapPeer(t *testing.T) { c.DataDir = path.Join(dir, "node1") }) defer s1.Shutdown() - s2 := 
testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.NodeName = "node2" c.BootstrapExpect = 3 c.DevMode = false @@ -112,7 +112,7 @@ func TestNomad_ReapPeer(t *testing.T) { c.DataDir = path.Join(dir, "node2") }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.NodeName = "node3" c.BootstrapExpect = 3 c.DevMode = false @@ -120,11 +120,11 @@ func TestNomad_ReapPeer(t *testing.T) { c.DataDir = path.Join(dir, "node3") }) defer s3.Shutdown() - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) testutil.WaitForResult(func() (bool, error) { // Retry the join to decrease flakiness - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) if members := s1.Members(); len(members) != 3 { return false, fmt.Errorf("bad s1: %#v", members) } @@ -191,32 +191,32 @@ func TestNomad_BootstrapExpect(t *testing.T) { dir := tmpDir(t) defer os.RemoveAll(dir) - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true c.DataDir = path.Join(dir, "node1") }) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true c.DataDir = path.Join(dir, "node2") }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true c.DataDir = path.Join(dir, "node3") }) defer s3.Shutdown() - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) testutil.WaitForResult(func() (bool, error) { // Retry the join to decrease flakiness - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) peers, err := s1.numPeers() if err != nil { return false, err @@ -254,7 +254,7 @@ func TestNomad_BootstrapExpect(t *testing.T) { // Join a fourth server after quorum has already been formed and ensure // there is no election - s4 := testServer(t, func(c *Config) { + s4 := TestServer(t, func(c *Config) { c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true @@ -300,18 +300,18 @@ func TestNomad_BootstrapExpect(t *testing.T) { func TestNomad_BadExpect(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 c.DevDisableBootstrap = true }) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.BootstrapExpect = 3 c.DevDisableBootstrap = true }) defer s2.Shutdown() servers := []*Server{s1, s2} - testJoin(t, s1, s2) + TestJoin(t, s1, s2) // Serf members should update testutil.WaitForResult(func() (bool, error) { diff --git a/nomad/server.go b/nomad/server.go index b67ddec97fe..a70c2f8905f 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -3,7 +3,6 @@ package nomad import ( "context" "crypto/tls" - "errors" "fmt" "io/ioutil" "log" @@ -11,7 +10,6 @@ import ( "net/rpc" "os" "path/filepath" - "reflect" "sort" "strconv" "sync" @@ -24,6 +22,9 @@ import ( multierror "github.com/hashicorp/go-multierror" lru "github.com/hashicorp/golang-lru" "github.com/hashicorp/nomad/command/agent/consul" + "github.com/hashicorp/nomad/helper/codec" + "github.com/hashicorp/nomad/helper/pool" + "github.com/hashicorp/nomad/helper/stats" "github.com/hashicorp/nomad/helper/tlsutil" "github.com/hashicorp/nomad/nomad/deploymentwatcher" "github.com/hashicorp/nomad/nomad/state" @@ -90,10 +91,7 @@ type Server struct { logger 
*log.Logger // Connection pool to other Nomad servers - connPool *ConnPool - - // Endpoints holds our RPC endpoints - endpoints endpoints + connPool *pool.ConnPool // The raft instance is used among Nomad nodes within the // region to protect operations that require strong consistency @@ -114,13 +112,33 @@ type Server struct { rpcListener net.Listener listenerCh chan struct{} - rpcServer *rpc.Server + // tlsWrap is used to wrap outbound connections using TLS. It should be + // accessed using the lock. + tlsWrap tlsutil.RegionWrapper + tlsWrapLock sync.RWMutex + + // rpcServer is the static RPC server that is used by the local agent. + rpcServer *rpc.Server + + // rpcAdvertise is the advertised address for the RPC listener. rpcAdvertise net.Addr // rpcTLS is the TLS config for incoming TLS requests rpcTLS *tls.Config rpcCancel context.CancelFunc + // staticEndpoints is the set of static endpoints that can be reused across + // all RPC connections + staticEndpoints endpoints + + // streamingRpcs is the registry holding our streaming RPC handlers. + streamingRpcs *structs.StreamingRpcRegistery + + // nodeConns is the set of multiplexed node connections we have keyed by + // NodeID + nodeConns map[string]*nodeConnState + nodeConnsLock sync.RWMutex + // peers is used to track the known Nomad servers. This is // used for region forwarding and clustering. peers map[string][]*serverParts @@ -210,6 +228,11 @@ type endpoints struct { Operator *Operator ACL *ACL Enterprise *EnterpriseEndpoints + + // Client endpoints + ClientStats *ClientStats + FileSystem *FileSystem + ClientAllocations *ClientAllocations } // NewServer is used to construct a new Nomad server from the @@ -256,9 +279,12 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg s := &Server{ config: config, consulCatalog: consulCatalog, - connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap), + connPool: pool.NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap), logger: logger, + tlsWrap: tlsWrap, rpcServer: rpc.NewServer(), + streamingRpcs: structs.NewStreamingRpcRegistery(), + nodeConns: make(map[string]*nodeConnState), peers: make(map[string][]*serverParts), localPeers: make(map[raft.ServerAddress]*serverParts), reconcileCh: make(chan serf.Member, 32), @@ -415,6 +441,11 @@ func (s *Server) reloadTLSConnections(newTLSConfig *config.TLSConfig) error { return err } + // Store the new tls wrapper. 
+ s.tlsWrapLock.Lock() + s.tlsWrap = tlsWrap + s.tlsWrapLock.Unlock() + if s.rpcCancel == nil { err = fmt.Errorf("No existing RPC server to reset.") s.logger.Printf("[ERR] nomad: %s", err) @@ -855,37 +886,8 @@ func (s *Server) setupVaultClient() error { // setupRPC is used to setup the RPC listener func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { - // Create endpoints - s.endpoints.ACL = &ACL{s} - s.endpoints.Alloc = &Alloc{s} - s.endpoints.Eval = &Eval{s} - s.endpoints.Job = &Job{s} - s.endpoints.Node = &Node{srv: s} - s.endpoints.Deployment = &Deployment{srv: s} - s.endpoints.Operator = &Operator{s} - s.endpoints.Periodic = &Periodic{s} - s.endpoints.Plan = &Plan{s} - s.endpoints.Region = &Region{s} - s.endpoints.Status = &Status{s} - s.endpoints.System = &System{s} - s.endpoints.Search = &Search{s} - s.endpoints.Enterprise = NewEnterpriseEndpoints(s) - - // Register the handlers - s.rpcServer.Register(s.endpoints.ACL) - s.rpcServer.Register(s.endpoints.Alloc) - s.rpcServer.Register(s.endpoints.Eval) - s.rpcServer.Register(s.endpoints.Job) - s.rpcServer.Register(s.endpoints.Node) - s.rpcServer.Register(s.endpoints.Deployment) - s.rpcServer.Register(s.endpoints.Operator) - s.rpcServer.Register(s.endpoints.Periodic) - s.rpcServer.Register(s.endpoints.Plan) - s.rpcServer.Register(s.endpoints.Region) - s.rpcServer.Register(s.endpoints.Status) - s.rpcServer.Register(s.endpoints.System) - s.rpcServer.Register(s.endpoints.Search) - s.endpoints.Enterprise.Register(s) + // Populate the static RPC server + s.setupRpcServer(s.rpcServer, nil) listener, err := s.createRPCListener() if err != nil { @@ -915,6 +917,60 @@ func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { return nil } +// setupRpcServer is used to populate an RPC server with endpoints +func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { + // Add the static endpoints to the RPC server. 
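+	// Static endpoints are built once and shared between the static RPC
+	// server and every per-connection server; only endpoints that need the
+	// connection context (currently Node) are constructed per invocation at
+	// the end of this function.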
+ if s.staticEndpoints.Status == nil { + // Initialize the list just once + s.staticEndpoints.ACL = &ACL{s} + s.staticEndpoints.Alloc = &Alloc{s} + s.staticEndpoints.Eval = &Eval{s} + s.staticEndpoints.Job = &Job{s} + s.staticEndpoints.Node = &Node{srv: s} // Add but don't register + s.staticEndpoints.Deployment = &Deployment{srv: s} + s.staticEndpoints.Operator = &Operator{s} + s.staticEndpoints.Periodic = &Periodic{s} + s.staticEndpoints.Plan = &Plan{s} + s.staticEndpoints.Region = &Region{s} + s.staticEndpoints.Status = &Status{s} + s.staticEndpoints.System = &System{s} + s.staticEndpoints.Search = &Search{s} + s.staticEndpoints.Enterprise = NewEnterpriseEndpoints(s) + + // Client endpoints + s.staticEndpoints.ClientStats = &ClientStats{s} + s.staticEndpoints.ClientAllocations = &ClientAllocations{s} + + // Streaming endpoints + s.staticEndpoints.FileSystem = &FileSystem{s} + s.staticEndpoints.FileSystem.register() + } + + // Register the static handlers + server.Register(s.staticEndpoints.ACL) + server.Register(s.staticEndpoints.Alloc) + server.Register(s.staticEndpoints.Eval) + server.Register(s.staticEndpoints.Job) + server.Register(s.staticEndpoints.Deployment) + server.Register(s.staticEndpoints.Operator) + server.Register(s.staticEndpoints.Periodic) + server.Register(s.staticEndpoints.Plan) + server.Register(s.staticEndpoints.Region) + server.Register(s.staticEndpoints.Status) + server.Register(s.staticEndpoints.System) + server.Register(s.staticEndpoints.Search) + s.staticEndpoints.Enterprise.Register(server) + server.Register(s.staticEndpoints.ClientStats) + server.Register(s.staticEndpoints.ClientAllocations) + server.Register(s.staticEndpoints.FileSystem) + + // Create new dynamic endpoints and add them to the RPC server. + node := &Node{srv: s, ctx: ctx} + + // Register the dynamic endpoints + server.Register(node) +} + // setupRaft is used to setup and initialize Raft func (s *Server) setupRaft() error { // If we have an unclean exit then attempt to close the Raft store. @@ -1231,52 +1287,22 @@ func (s *Server) Regions() []string { return regions } -// inmemCodec is used to do an RPC call without going over a network -type inmemCodec struct { - method string - args interface{} - reply interface{} - err error -} - -func (i *inmemCodec) ReadRequestHeader(req *rpc.Request) error { - req.ServiceMethod = i.method - return nil -} - -func (i *inmemCodec) ReadRequestBody(args interface{}) error { - sourceValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(i.args))) - dst := reflect.Indirect(reflect.Indirect(reflect.ValueOf(args))) - dst.Set(sourceValue) - return nil -} - -func (i *inmemCodec) WriteResponse(resp *rpc.Response, reply interface{}) error { - if resp.Error != "" { - i.err = errors.New(resp.Error) - return nil - } - sourceValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(reply))) - dst := reflect.Indirect(reflect.Indirect(reflect.ValueOf(i.reply))) - dst.Set(sourceValue) - return nil -} - -func (i *inmemCodec) Close() error { - return nil -} - // RPC is used to make a local RPC call func (s *Server) RPC(method string, args interface{}, reply interface{}) error { - codec := &inmemCodec{ - method: method, - args: args, - reply: reply, + codec := &codec.InmemCodec{ + Method: method, + Args: args, + Reply: reply, } if err := s.rpcServer.ServeRequest(codec); err != nil { return err } - return codec.err + return codec.Err +} + +// StreamingRpcHandler is used to make a streaming RPC call. 
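+// It returns the handler registered for the method, or an error for which
+// structs.IsErrUnknownMethod returns true. An illustrative use, with a
+// hypothetical method name and an already established connection conn:
+//
+//	handler, err := s.StreamingRpcHandler("FileSystem.Logs")
+//	if err == nil {
+//		go handler(conn)
+//	}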
+func (s *Server) StreamingRpcHandler(method string) (structs.StreamingRpcHandler, error) { + return s.streamingRpcs.GetHandler(method) } // Stats is used to return statistics for debugging and insight @@ -1295,7 +1321,7 @@ func (s *Server) Stats() map[string]map[string]string { }, "raft": s.raft.Stats(), "serf": s.serf.Stats(), - "runtime": RuntimeStats(), + "runtime": stats.RuntimeStats(), } return stats diff --git a/nomad/server_test.go b/nomad/server_test.go index bfe4fbf4ba2..81b06197d99 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -3,19 +3,13 @@ package nomad import ( "fmt" "io/ioutil" - "log" - "math/rand" - "net" "os" "path" "strings" - "sync/atomic" "testing" "time" - "github.com/hashicorp/consul/lib/freeport" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" - "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -24,15 +18,8 @@ import ( "github.com/stretchr/testify/assert" ) -var ( - nodeNumber uint32 = 0 -) - -func testLogger() *log.Logger { - return log.New(os.Stderr, "", log.LstdFlags) -} - func tmpDir(t *testing.T) string { + t.Helper() dir, err := ioutil.TempDir("", "nomad") if err != nil { t.Fatalf("err: %v", err) @@ -40,110 +27,9 @@ func tmpDir(t *testing.T) string { return dir } -func testACLServer(t *testing.T, cb func(*Config)) (*Server, *structs.ACLToken) { - server := testServer(t, func(c *Config) { - c.ACLEnabled = true - if cb != nil { - cb(c) - } - }) - token := mock.ACLManagementToken() - err := server.State().BootstrapACLTokens(1, 0, token) - if err != nil { - t.Fatalf("failed to bootstrap ACL token: %v", err) - } - return server, token -} - -func testServer(t *testing.T, cb func(*Config)) *Server { - // Setup the default settings - config := DefaultConfig() - config.Build = "0.8.0+unittest" - config.DevMode = true - nodeNum := atomic.AddUint32(&nodeNumber, 1) - config.NodeName = fmt.Sprintf("nomad-%03d", nodeNum) - - // Tighten the Serf timing - config.SerfConfig.MemberlistConfig.BindAddr = "127.0.0.1" - config.SerfConfig.MemberlistConfig.SuspicionMult = 2 - config.SerfConfig.MemberlistConfig.RetransmitMult = 2 - config.SerfConfig.MemberlistConfig.ProbeTimeout = 50 * time.Millisecond - config.SerfConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond - config.SerfConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond - - // Tighten the Raft timing - config.RaftConfig.LeaderLeaseTimeout = 50 * time.Millisecond - config.RaftConfig.HeartbeatTimeout = 50 * time.Millisecond - config.RaftConfig.ElectionTimeout = 50 * time.Millisecond - config.RaftTimeout = 500 * time.Millisecond - - // Tighten the autopilot timing - config.AutopilotConfig.ServerStabilizationTime = 100 * time.Millisecond - config.ServerHealthInterval = 50 * time.Millisecond - config.AutopilotInterval = 100 * time.Millisecond - - // Disable Vault - f := false - config.VaultConfig.Enabled = &f - - // Squelch output when -v isn't specified - if !testing.Verbose() { - config.LogOutput = ioutil.Discard - } - - // Invoke the callback if any - if cb != nil { - cb(config) - } - - // Enable raft as leader if we have bootstrap on - config.RaftConfig.StartAsLeader = !config.DevDisableBootstrap - - logger := log.New(config.LogOutput, fmt.Sprintf("[%s] ", config.NodeName), log.LstdFlags) - catalog := consul.NewMockCatalog(logger) - - for i := 10; i >= 0; i-- { - // Get random ports - ports := 
freeport.GetT(t, 2) - config.RPCAddr = &net.TCPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: ports[0], - } - config.SerfConfig.MemberlistConfig.BindPort = ports[1] - - // Create server - server, err := NewServer(config, catalog, logger) - if err == nil { - return server - } else if i == 0 { - t.Fatalf("err: %v", err) - } else { - if server != nil { - server.Shutdown() - } - wait := time.Duration(rand.Int31n(2000)) * time.Millisecond - time.Sleep(wait) - } - } - - return nil -} - -func testJoin(t *testing.T, s1 *Server, other ...*Server) { - addr := fmt.Sprintf("127.0.0.1:%d", - s1.config.SerfConfig.MemberlistConfig.BindPort) - for _, s2 := range other { - if num, err := s2.Join([]string{addr}); err != nil { - t.Fatalf("err: %v", err) - } else if num != 1 { - t.Fatalf("bad: %d", num) - } - } -} - func TestServer_RPC(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() var out struct{} @@ -152,7 +38,7 @@ func TestServer_RPC(t *testing.T) { } } -func TestServer_RPC_MixedTLS(t *testing.T) { +func TestServer_RPC_TLS(t *testing.T) { t.Parallel() const ( cafile = "../helper/tlsutil/testdata/ca.pem" @@ -161,7 +47,8 @@ func TestServer_RPC_MixedTLS(t *testing.T) { ) dir := tmpDir(t) defer os.RemoveAll(dir) - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true @@ -177,66 +64,124 @@ func TestServer_RPC_MixedTLS(t *testing.T) { }) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true c.DataDir = path.Join(dir, "node2") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } }) defer s2.Shutdown() - s3 := testServer(t, func(c *Config) { + s3 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" c.BootstrapExpect = 3 c.DevMode = false c.DevDisableBootstrap = true c.DataDir = path.Join(dir, "node3") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } }) defer s3.Shutdown() - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) + testutil.WaitForLeader(t, s1.RPC) + + // Part of a server joining is making an RPC request, so just by testing + // that there is a leader we verify that the RPCs are working over TLS. 
+} + +func TestServer_RPC_MixedTLS(t *testing.T) { + t.Parallel() + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + s1 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 3 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node1") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s1.Shutdown() - l1, l2, l3, shutdown := make(chan error, 1), make(chan error, 1), make(chan error, 1), make(chan struct{}, 1) + s2 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 3 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node2") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s2.Shutdown() + s3 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 3 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node3") + }) + defer s3.Shutdown() - wait := func(done chan error, rpc func(string, interface{}, interface{}) error) { - for { - select { - case <-shutdown: - return - default: - } + TestJoin(t, s1, s2, s3) - args := &structs.GenericRequest{} - var leader string - err := rpc("Status.Leader", args, &leader) - if err != nil || leader != "" { - done <- err - } + // Ensure that we do not form a quorum + start := time.Now() + for { + if time.Now().After(start.Add(2 * time.Second)) { + break } - } - go wait(l1, s1.RPC) - go wait(l2, s2.RPC) - go wait(l3, s3.RPC) - - select { - case <-time.After(5 * time.Second): - case err := <-l1: - t.Fatalf("Server 1 has leader or error: %v", err) - case err := <-l2: - t.Fatalf("Server 2 has leader or error: %v", err) - case err := <-l3: - t.Fatalf("Server 3 has leader or error: %v", err) + args := &structs.GenericRequest{} + var leader string + err := s1.RPC("Status.Leader", args, &leader) + if err == nil || leader != "" { + t.Fatalf("Got leader or no error: %q %v", leader, err) + } } } func TestServer_Regions(t *testing.T) { t.Parallel() // Make the servers - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.Region = "region1" }) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.Region = "region2" }) defer s2.Shutdown() @@ -262,7 +207,7 @@ func TestServer_Regions(t *testing.T) { func TestServer_Reload_Vault(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.Region = "region1" }) defer s1.Shutdown() @@ -303,7 +248,7 @@ func TestServer_Reload_TLSConnections_PlaintextToTLS(t *testing.T) { dir := tmpDir(t) defer os.RemoveAll(dir) - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.DataDir = path.Join(dir, "nodeA") }) defer s1.Shutdown() @@ -353,7 +298,7 @@ func TestServer_Reload_TLSConnections_TLSToPlaintext_RPC(t *testing.T) { dir := tmpDir(t) defer os.RemoveAll(dir) - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.DataDir = path.Join(dir, "nodeB") c.TLSConfig = &config.TLSConfig{ EnableHTTP: true, @@ -400,7 +345,7 @@ func TestServer_Reload_TLSConnections_Raft(t *testing.T) { 
dir := tmpDir(t) defer os.RemoveAll(dir) - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 c.DevMode = false c.DevDisableBootstrap = true @@ -410,7 +355,7 @@ func TestServer_Reload_TLSConnections_Raft(t *testing.T) { }) defer s1.Shutdown() - s2 := testServer(t, func(c *Config) { + s2 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 c.DevMode = false c.DevDisableBootstrap = true @@ -420,7 +365,7 @@ func TestServer_Reload_TLSConnections_Raft(t *testing.T) { }) defer s2.Shutdown() - testJoin(t, s1, s2) + TestJoin(t, s1, s2) servers := []*Server{s1, s2} testutil.WaitForLeader(t, s1.RPC) diff --git a/nomad/stats_fetcher.go b/nomad/stats_fetcher.go index 3d59ad6cbb5..a8c34d18f0f 100644 --- a/nomad/stats_fetcher.go +++ b/nomad/stats_fetcher.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/hashicorp/consul/agent/consul/autopilot" + "github.com/hashicorp/nomad/helper/pool" "github.com/hashicorp/serf/serf" ) @@ -18,14 +19,14 @@ import ( // as we run the health check fairly frequently. type StatsFetcher struct { logger *log.Logger - pool *ConnPool + pool *pool.ConnPool region string inflight map[string]struct{} inflightLock sync.Mutex } // NewStatsFetcher returns a stats fetcher. -func NewStatsFetcher(logger *log.Logger, pool *ConnPool, region string) *StatsFetcher { +func NewStatsFetcher(logger *log.Logger, pool *pool.ConnPool, region string) *StatsFetcher { return &StatsFetcher{ logger: logger, pool: pool, diff --git a/nomad/stats_fetcher_test.go b/nomad/stats_fetcher_test.go index a6b0052d119..d96987b8a69 100644 --- a/nomad/stats_fetcher_test.go +++ b/nomad/stats_fetcher_test.go @@ -17,16 +17,16 @@ func TestStatsFetcher(t *testing.T) { c.BootstrapExpect = 3 } - s1 := testServer(t, conf) + s1 := TestServer(t, conf) defer s1.Shutdown() - s2 := testServer(t, conf) + s2 := TestServer(t, conf) defer s2.Shutdown() - s3 := testServer(t, conf) + s3 := TestServer(t, conf) defer s3.Shutdown() - testJoin(t, s1, s2, s3) + TestJoin(t, s1, s2, s3) testutil.WaitForLeader(t, s1.RPC) members := s1.serf.Members() diff --git a/nomad/status_endpoint.go b/nomad/status_endpoint.go index baa700ff573..a79bbd22050 100644 --- a/nomad/status_endpoint.go +++ b/nomad/status_endpoint.go @@ -4,7 +4,10 @@ import ( "fmt" "strconv" + "errors" + "github.com/hashicorp/consul/agent/consul/autopilot" + "github.com/hashicorp/nomad/nomad/structs" ) @@ -126,3 +129,20 @@ func (s *Status) RaftStats(args struct{}, reply *autopilot.ServerStats) error { return nil } + +// HasNodeConn returns whether the server has a connection to the requested +// Node. 
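+// The NodeID must be the full node ID. A minimal caller, mirroring the
+// endpoint's test, looks like:
+//
+//	args := &structs.NodeSpecificRequest{NodeID: nodeID}
+//	var reply structs.NodeConnQueryResponse
+//	err := srv.RPC("Status.HasNodeConn", args, &reply)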
+func (s *Status) HasNodeConn(args *structs.NodeSpecificRequest, reply *structs.NodeConnQueryResponse) error { + // Validate the args + if args.NodeID == "" { + return errors.New("Must provide the NodeID") + } + + state, ok := s.srv.getNodeConn(args.NodeID) + if ok { + reply.Connected = true + reply.Established = state.Established + } + + return nil +} diff --git a/nomad/status_endpoint_test.go b/nomad/status_endpoint_test.go index ab48ab7f0ef..eb968b673e9 100644 --- a/nomad/status_endpoint_test.go +++ b/nomad/status_endpoint_test.go @@ -5,15 +5,17 @@ import ( "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestStatusVersion(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) @@ -44,7 +46,7 @@ func TestStatusVersion(t *testing.T) { func TestStatusPing(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) @@ -57,7 +59,7 @@ func TestStatusPing(t *testing.T) { func TestStatusLeader(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -79,7 +81,7 @@ func TestStatusLeader(t *testing.T) { func TestStatusPeers(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) @@ -100,7 +102,7 @@ func TestStatusPeers(t *testing.T) { func TestStatusMembers(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) assert := assert.New(t) @@ -119,7 +121,7 @@ func TestStatusMembers(t *testing.T) { func TestStatusMembers_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) assert := assert.New(t) @@ -169,3 +171,37 @@ func TestStatusMembers_ACL(t *testing.T) { assert.Len(out.Members, 1) } } + +func TestStatus_HasClientConn(t *testing.T) { + t.Parallel() + s1 := TestServer(t, nil) + defer s1.Shutdown() + codec := rpcClient(t, s1) + require := require.New(t) + + arg := &structs.NodeSpecificRequest{ + QueryOptions: structs.QueryOptions{ + Region: "global", + AllowStale: true, + }, + } + + // Try without setting a node id + var out structs.NodeConnQueryResponse + require.NotNil(msgpackrpc.CallWithCodec(codec, "Status.HasNodeConn", arg, &out)) + + // Set a bad node id + arg.NodeID = uuid.Generate() + var out2 structs.NodeConnQueryResponse + require.Nil(msgpackrpc.CallWithCodec(codec, "Status.HasNodeConn", arg, &out2)) + require.False(out2.Connected) + + // Create a connection on that node + s1.addNodeConn(&RPCContext{ + NodeID: arg.NodeID, + }) + var out3 structs.NodeConnQueryResponse + require.Nil(msgpackrpc.CallWithCodec(codec, "Status.HasNodeConn", arg, &out3)) + require.True(out3.Connected) + require.NotZero(out3.Established) +} diff --git a/nomad/structs/errors.go b/nomad/structs/errors.go new file mode 100644 index 00000000000..f1139cbd7ed --- /dev/null +++ b/nomad/structs/errors.go @@ -0,0 +1,126 @@ +package structs + +import ( + "errors" + "fmt" + "strings" +) + +const ( + 
errNoLeader = "No cluster leader" + errNoRegionPath = "No path to region" + errTokenNotFound = "ACL token not found" + errPermissionDenied = "Permission denied" + errNoNodeConn = "No path to node" + errUnknownMethod = "Unknown rpc method" + + // Prefix based errors that are used to check if the error is of a given + // type. These errors should be created with the associated constructor. + ErrUnknownAllocationPrefix = "Unknown allocation" + ErrUnknownNodePrefix = "Unknown node" + ErrUnknownJobPrefix = "Unknown job" + ErrUnknownEvaluationPrefix = "Unknown evaluation" + ErrUnknownDeploymentPrefix = "Unknown deployment" +) + +var ( + ErrNoLeader = errors.New(errNoLeader) + ErrNoRegionPath = errors.New(errNoRegionPath) + ErrTokenNotFound = errors.New(errTokenNotFound) + ErrPermissionDenied = errors.New(errPermissionDenied) + ErrNoNodeConn = errors.New(errNoNodeConn) + ErrUnknownMethod = errors.New(errUnknownMethod) +) + +// IsErrNoLeader returns whether the error is due to there being no leader. +func IsErrNoLeader(err error) bool { + return err != nil && strings.Contains(err.Error(), errNoLeader) +} + +// IsErrNoRegionPath returns whether the error is due to there being no path to +// the given region. +func IsErrNoRegionPath(err error) bool { + return err != nil && strings.Contains(err.Error(), errNoRegionPath) +} + +// IsErrTokenNotFound returns whether the error is due to the passed token not +// being resolvable. +func IsErrTokenNotFound(err error) bool { + return err != nil && strings.Contains(err.Error(), errTokenNotFound) +} + +// IsErrPermissionDenied returns whether the error is due to the operation not +// being allowed due to lack of permissions. +func IsErrPermissionDenied(err error) bool { + return err != nil && strings.Contains(err.Error(), errPermissionDenied) +} + +// IsErrNoNodeConn returns whether the error is due to there being no path to +// the given node. +func IsErrNoNodeConn(err error) bool { + return err != nil && strings.Contains(err.Error(), errNoNodeConn) +} + +// IsErrUnknownMethod returns whether the error is due to the operation not +// being allowed due to lack of permissions. +func IsErrUnknownMethod(err error) bool { + return err != nil && strings.Contains(err.Error(), errUnknownMethod) +} + +// NewErrUnknownAllocation returns a new error caused by the allocation being +// unknown. +func NewErrUnknownAllocation(allocID string) error { + return fmt.Errorf("%s %q", ErrUnknownAllocationPrefix, allocID) +} + +// NewErrUnknownNode returns a new error caused by the node being unknown. +func NewErrUnknownNode(nodeID string) error { + return fmt.Errorf("%s %q", ErrUnknownNodePrefix, nodeID) +} + +// NewErrUnknownJob returns a new error caused by the job being unknown. +func NewErrUnknownJob(jobID string) error { + return fmt.Errorf("%s %q", ErrUnknownJobPrefix, jobID) +} + +// NewErrUnknownEvaluation returns a new error caused by the evaluation being +// unknown. +func NewErrUnknownEvaluation(evaluationID string) error { + return fmt.Errorf("%s %q", ErrUnknownEvaluationPrefix, evaluationID) +} + +// NewErrUnknownDeployment returns a new error caused by the deployment being +// unknown. +func NewErrUnknownDeployment(deploymentID string) error { + return fmt.Errorf("%s %q", ErrUnknownDeploymentPrefix, deploymentID) +} + +// IsErrUnknownAllocation returns whether the error is due to an unknown +// allocation. 
+func IsErrUnknownAllocation(err error) bool {
+	return err != nil && strings.Contains(err.Error(), ErrUnknownAllocationPrefix)
+}
+
+// IsErrUnknownNode returns whether the error is due to an unknown
+// node.
+func IsErrUnknownNode(err error) bool {
+	return err != nil && strings.Contains(err.Error(), ErrUnknownNodePrefix)
+}
+
+// IsErrUnknownJob returns whether the error is due to an unknown
+// job.
+func IsErrUnknownJob(err error) bool {
+	return err != nil && strings.Contains(err.Error(), ErrUnknownJobPrefix)
+}
+
+// IsErrUnknownEvaluation returns whether the error is due to an unknown
+// evaluation.
+func IsErrUnknownEvaluation(err error) bool {
+	return err != nil && strings.Contains(err.Error(), ErrUnknownEvaluationPrefix)
+}
+
+// IsErrUnknownDeployment returns whether the error is due to an unknown
+// deployment.
+func IsErrUnknownDeployment(err error) bool {
+	return err != nil && strings.Contains(err.Error(), ErrUnknownDeploymentPrefix)
+}
diff --git a/nomad/structs/funcs.go b/nomad/structs/funcs.go
index ccd2eb6ed3b..c4ecd8b0e45 100644
--- a/nomad/structs/funcs.go
+++ b/nomad/structs/funcs.go
@@ -1,6 +1,8 @@
 package structs

 import (
+	"crypto/subtle"
+	"encoding/base64"
 	"encoding/binary"
 	"fmt"
 	"math"
@@ -292,3 +294,30 @@ func CompileACLObject(cache *lru.TwoQueueCache, policies []*ACLPolicy) (*acl.ACL
 	cache.Add(cacheKey, aclObj)
 	return aclObj, nil
 }
+
+// GenerateMigrateToken will create a token for a client to access an
+// authenticated volume of another client to migrate data for sticky volumes.
+func GenerateMigrateToken(allocID, nodeSecretID string) (string, error) {
+	h, err := blake2b.New512([]byte(nodeSecretID))
+	if err != nil {
+		return "", err
+	}
+	h.Write([]byte(allocID))
+	return base64.URLEncoding.EncodeToString(h.Sum(nil)), nil
+}
+
+// CompareMigrateToken returns true if otherMigrateToken matches the token
+// computed from allocID and nodeSecretID, using a constant-time comparison.
+func CompareMigrateToken(allocID, nodeSecretID, otherMigrateToken string) bool {
+	h, err := blake2b.New512([]byte(nodeSecretID))
+	if err != nil {
+		return false
+	}
+	h.Write([]byte(allocID))
+
+	otherBytes, err := base64.URLEncoding.DecodeString(otherMigrateToken)
+	if err != nil {
+		return false
+	}
+	return subtle.ConstantTimeCompare(h.Sum(nil), otherBytes) == 1
+}
diff --git a/nomad/structs/funcs_test.go b/nomad/structs/funcs_test.go
index f2cd88d5ed8..7c0ba5dcab1 100644
--- a/nomad/structs/funcs_test.go
+++ b/nomad/structs/funcs_test.go
@@ -1,6 +1,7 @@
 package structs

 import (
+	"encoding/base64"
 	"fmt"
 	"testing"

@@ -359,3 +360,25 @@ func TestCompileACLObject(t *testing.T) {
 		t.Fatalf("expected same object")
 	}
 }
+
+// TestGenerateMigrateToken asserts the migrate token is valid for use in HTTP
+// headers and CompareMigrateToken works as expected.
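+// A typical exchange, as a sketch: the server derives the token from the
+// migrating allocation's ID and the target node's SecretID, and the client
+// holding the data recomputes and compares it in constant time:
+//
+//	token, _ := GenerateMigrateToken(allocID, node.SecretID)
+//	ok := CompareMigrateToken(allocID, node.SecretID, token)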
+func TestGenerateMigrateToken(t *testing.T) {
+	assert := assert.New(t)
+	allocID := uuid.Generate()
+	nodeSecret := uuid.Generate()
+	token, err := GenerateMigrateToken(allocID, nodeSecret)
+	assert.Nil(err)
+	_, err = base64.URLEncoding.DecodeString(token)
+	assert.Nil(err)
+
+	assert.True(CompareMigrateToken(allocID, nodeSecret, token))
+	assert.False(CompareMigrateToken("x", nodeSecret, token))
+	assert.False(CompareMigrateToken(allocID, "x", token))
+	assert.False(CompareMigrateToken(allocID, nodeSecret, "x"))
+
+	token2, err := GenerateMigrateToken("x", nodeSecret)
+	assert.Nil(err)
+	assert.False(CompareMigrateToken(allocID, nodeSecret, token2))
+	assert.True(CompareMigrateToken("x", nodeSecret, token2))
+}
diff --git a/nomad/structs/streaming_rpc.go b/nomad/structs/streaming_rpc.go
new file mode 100644
index 00000000000..6172c05e6f1
--- /dev/null
+++ b/nomad/structs/streaming_rpc.go
@@ -0,0 +1,72 @@
+package structs
+
+import (
+	"fmt"
+	"io"
+	"sync"
+)
+
+// StreamingRpcHeader is the first struct serialized after entering the
+// streaming RPC mode. The header is used to dispatch to the correct method.
+type StreamingRpcHeader struct {
+	// Method is the name of the method to invoke.
+	Method string
+}
+
+// StreamingRpcAck is used to acknowledge receiving the StreamingRpcHeader and
+// routing to the requested handler.
+type StreamingRpcAck struct {
+	// Error is used to return whether an error occurred establishing the
+	// streaming RPC. This error occurs before entering the RPC handler.
+	Error string
+}
+
+// StreamingRpcHandler defines the handler for a streaming RPC.
+type StreamingRpcHandler func(conn io.ReadWriteCloser)
+
+// StreamingRpcRegistery is used to add and retrieve handlers.
+type StreamingRpcRegistery struct {
+	registry map[string]StreamingRpcHandler
+}
+
+// NewStreamingRpcRegistery creates a new registry. All registrations of
+// handlers should be done before retrieving handlers.
+func NewStreamingRpcRegistery() *StreamingRpcRegistery {
+	return &StreamingRpcRegistery{
+		registry: make(map[string]StreamingRpcHandler),
+	}
+}
+
+// Register registers a new handler for the given method name.
+func (s *StreamingRpcRegistery) Register(method string, handler StreamingRpcHandler) {
+	s.registry[method] = handler
+}
+
+// GetHandler returns a handler for the given method or an error if it doesn't exist.
+func (s *StreamingRpcRegistery) GetHandler(method string) (StreamingRpcHandler, error) {
+	h, ok := s.registry[method]
+	if !ok {
+		return nil, fmt.Errorf("%s: %q", ErrUnknownMethod, method)
+	}
+
+	return h, nil
+}
+
+// Bridge links the two connections together, copying traffic in both
+// directions until either side is closed.
+func Bridge(a, b io.ReadWriteCloser) {
+	wg := sync.WaitGroup{}
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		io.Copy(a, b)
+		a.Close()
+		b.Close()
+	}()
+	go func() {
+		defer wg.Done()
+		io.Copy(b, a)
+		a.Close()
+		b.Close()
+	}()
+	wg.Wait()
+}
diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go
index bcc408074a9..52630853383 100644
--- a/nomad/structs/structs.go
+++ b/nomad/structs/structs.go
@@ -39,11 +39,6 @@ import (
 )

 var (
-	ErrNoLeader         = fmt.Errorf("No cluster leader")
-	ErrNoRegionPath     = fmt.Errorf("No path to region")
-	ErrTokenNotFound    = errors.New("ACL token not found")
-	ErrPermissionDenied = errors.New("Permission denied")
-
 	// validPolicyName is used to validate a policy name
 	validPolicyName = regexp.MustCompile("^[a-zA-Z0-9-]{1,128}$")

@@ -121,6 +116,12 @@ const (
 	// DefaultNamespace is the default namespace.
DefaultNamespace = "default" DefaultNamespaceDescription = "Default shared namespace" + + // JitterFraction is a the limit to the amount of jitter we apply + // to a user specified MaxQueryTime. We divide the specified time by + // the fraction. So 16 == 6.25% limit of jitter. This jitter is also + // applied to RPCHoldTimeout. + JitterFraction = 16 ) // Context defines the scope in which a search for Nomad object operates, and @@ -1037,6 +1038,18 @@ type DeploymentUpdateResponse struct { WriteMeta } +// NodeConnQueryResponse is used to respond to a query of whether a server has +// a connection to a specific Node +type NodeConnQueryResponse struct { + // Connected indicates whether a connection to the Client exists + Connected bool + + // Established marks the time at which the connection was established + Established time.Time + + QueryMeta +} + const ( NodeStatusInit = "initializing" NodeStatusReady = "ready" @@ -5987,6 +6000,9 @@ var ( } ) +// TODO Figure out if we can remove this. This is our fork that is just way +// behind. I feel like its original purpose was to pin at a stable version but +// now we can accomplish this with vendoring. var HashiMsgpackHandle = func() *hcodec.MsgpackHandle { h := &hcodec.MsgpackHandle{RawToString: true} diff --git a/nomad/system_endpoint_test.go b/nomad/system_endpoint_test.go index 09f4e7dbdd1..ba9353e1e29 100644 --- a/nomad/system_endpoint_test.go +++ b/nomad/system_endpoint_test.go @@ -16,7 +16,7 @@ import ( func TestSystemEndpoint_GarbageCollect(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -66,7 +66,7 @@ func TestSystemEndpoint_GarbageCollect(t *testing.T) { func TestSystemEndpoint_GarbageCollect_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) assert := assert.New(t) @@ -110,7 +110,7 @@ func TestSystemEndpoint_GarbageCollect_ACL(t *testing.T) { func TestSystemEndpoint_ReconcileSummaries(t *testing.T) { t.Parallel() - s1 := testServer(t, nil) + s1 := TestServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) testutil.WaitForLeader(t, s1.RPC) @@ -172,7 +172,7 @@ func TestSystemEndpoint_ReconcileSummaries(t *testing.T) { func TestSystemEndpoint_ReconcileJobSummaries_ACL(t *testing.T) { t.Parallel() - s1, root := testACLServer(t, nil) + s1, root := TestACLServer(t, nil) defer s1.Shutdown() codec := rpcClient(t, s1) assert := assert.New(t) diff --git a/nomad/testing.go b/nomad/testing.go new file mode 100644 index 00000000000..2859dfb63db --- /dev/null +++ b/nomad/testing.go @@ -0,0 +1,120 @@ +package nomad + +import ( + "fmt" + "log" + "math/rand" + "net" + "sync/atomic" + "time" + + "github.com/hashicorp/consul/lib/freeport" + "github.com/hashicorp/nomad/command/agent/consul" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/mitchellh/go-testing-interface" +) + +var ( + nodeNumber uint32 = 0 +) + +func TestACLServer(t testing.T, cb func(*Config)) (*Server, *structs.ACLToken) { + server := TestServer(t, func(c *Config) { + c.ACLEnabled = true + if cb != nil { + cb(c) + } + }) + token := mock.ACLManagementToken() + err := server.State().BootstrapACLTokens(1, 0, token) + if err != nil { + t.Fatalf("failed to bootstrap ACL token: %v", err) + } + return server, token +} + +func 
TestServer(t testing.T, cb func(*Config)) *Server { + // Setup the default settings + config := DefaultConfig() + config.Build = "0.8.0+unittest" + config.DevMode = true + nodeNum := atomic.AddUint32(&nodeNumber, 1) + config.NodeName = fmt.Sprintf("nomad-%03d", nodeNum) + + // Tighten the Serf timing + config.SerfConfig.MemberlistConfig.BindAddr = "127.0.0.1" + config.SerfConfig.MemberlistConfig.SuspicionMult = 2 + config.SerfConfig.MemberlistConfig.RetransmitMult = 2 + config.SerfConfig.MemberlistConfig.ProbeTimeout = 50 * time.Millisecond + config.SerfConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond + config.SerfConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond + + // Tighten the Raft timing + config.RaftConfig.LeaderLeaseTimeout = 50 * time.Millisecond + config.RaftConfig.HeartbeatTimeout = 50 * time.Millisecond + config.RaftConfig.ElectionTimeout = 50 * time.Millisecond + config.RaftTimeout = 500 * time.Millisecond + + // Disable Vault + f := false + config.VaultConfig.Enabled = &f + + // Squelch output when -v isn't specified + config.LogOutput = testlog.NewWriter(t) + + // Tighten the autopilot timing + config.AutopilotConfig.ServerStabilizationTime = 100 * time.Millisecond + config.ServerHealthInterval = 50 * time.Millisecond + config.AutopilotInterval = 100 * time.Millisecond + + // Invoke the callback if any + if cb != nil { + cb(config) + } + + // Enable raft as leader if we have bootstrap on + config.RaftConfig.StartAsLeader = !config.DevDisableBootstrap + + logger := log.New(config.LogOutput, fmt.Sprintf("[%s] ", config.NodeName), log.LstdFlags) + catalog := consul.NewMockCatalog(logger) + + for i := 10; i >= 0; i-- { + // Get random ports + ports := freeport.GetT(t, 2) + config.RPCAddr = &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: ports[0], + } + config.SerfConfig.MemberlistConfig.BindPort = ports[1] + + // Create server + server, err := NewServer(config, catalog, logger) + if err == nil { + return server + } else if i == 0 { + t.Fatalf("err: %v", err) + } else { + if server != nil { + server.Shutdown() + } + wait := time.Duration(rand.Int31n(2000)) * time.Millisecond + time.Sleep(wait) + } + } + + return nil +} + +func TestJoin(t testing.T, s1 *Server, other ...*Server) { + addr := fmt.Sprintf("127.0.0.1:%d", + s1.config.SerfConfig.MemberlistConfig.BindPort) + for _, s2 := range other { + if num, err := s2.Join([]string{addr}); err != nil { + t.Fatalf("err: %v", err) + } else if num != 1 { + t.Fatalf("bad: %d", num) + } + } +} diff --git a/nomad/util.go b/nomad/util.go index be01dc41873..48c85050f7e 100644 --- a/nomad/util.go +++ b/nomad/util.go @@ -6,7 +6,6 @@ import ( "net" "os" "path/filepath" - "runtime" "strconv" version "github.com/hashicorp/go-version" @@ -21,18 +20,6 @@ func ensurePath(path string, dir bool) error { return os.MkdirAll(path, 0755) } -// RuntimeStats is used to return various runtime information -func RuntimeStats() map[string]string { - return map[string]string{ - "kernel.name": runtime.GOOS, - "arch": runtime.GOARCH, - "version": runtime.Version(), - "max_procs": strconv.FormatInt(int64(runtime.GOMAXPROCS(0)), 10), - "goroutines": strconv.FormatInt(int64(runtime.NumGoroutine()), 10), - "cpu_count": strconv.FormatInt(int64(runtime.NumCPU()), 10), - } -} - // serverParts is used to return the parts of a server role type serverParts struct { Name string @@ -56,6 +43,12 @@ func (s *serverParts) String() string { s.Name, s.Addr, s.Datacenter) } +func (s *serverParts) Copy() *serverParts { + ns := new(serverParts) + 
*ns = *s + return ns +} + // Returns if a member is a Nomad server. Returns a boolean, // and a struct with the various important components func isNomadServer(m serf.Member) (bool, *serverParts) { diff --git a/nomad/worker_test.go b/nomad/worker_test.go index faa9cc104e6..297102dfd87 100644 --- a/nomad/worker_test.go +++ b/nomad/worker_test.go @@ -46,7 +46,7 @@ func init() { func TestWorker_dequeueEvaluation(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -82,7 +82,7 @@ func TestWorker_dequeueEvaluation(t *testing.T) { // evals for the same job. func TestWorker_dequeueEvaluation_SerialJobs(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -150,7 +150,7 @@ func TestWorker_dequeueEvaluation_SerialJobs(t *testing.T) { func TestWorker_dequeueEvaluation_paused(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -197,7 +197,7 @@ func TestWorker_dequeueEvaluation_paused(t *testing.T) { func TestWorker_dequeueEvaluation_shutdown(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -226,7 +226,7 @@ func TestWorker_dequeueEvaluation_shutdown(t *testing.T) { func TestWorker_sendAck(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -273,7 +273,7 @@ func TestWorker_sendAck(t *testing.T) { func TestWorker_waitForIndex(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -308,7 +308,7 @@ func TestWorker_waitForIndex(t *testing.T) { func TestWorker_invokeScheduler(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -326,7 +326,7 @@ func TestWorker_invokeScheduler(t *testing.T) { func TestWorker_SubmitPlan(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -391,7 +391,7 @@ func TestWorker_SubmitPlan(t *testing.T) { func TestWorker_SubmitPlan_MissingNodeRefresh(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -461,7 +461,7 @@ func TestWorker_SubmitPlan_MissingNodeRefresh(t *testing.T) { func TestWorker_UpdateEval(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -508,7 +508,7 @@ func TestWorker_UpdateEval(t *testing.T) { func TestWorker_CreateEval(t *testing.T) { t.Parallel() - s1 := testServer(t, func(c *Config) { + s1 := TestServer(t, func(c *Config) { c.NumSchedulers = 0 c.EnabledSchedulers = []string{structs.JobTypeService} }) @@ -556,7 
+556,7 @@ func TestWorker_CreateEval(t *testing.T) {

 func TestWorker_ReblockEval(t *testing.T) {
 	t.Parallel()
-	s1 := testServer(t, func(c *Config) {
+	s1 := TestServer(t, func(c *Config) {
 		c.NumSchedulers = 0
 		c.EnabledSchedulers = []string{structs.JobTypeService}
 	})
diff --git a/website/source/api/client.html.md b/website/source/api/client.html.md
index e2af474b9a3..dc2d92fa5b8 100644
--- a/website/source/api/client.html.md
+++ b/website/source/api/client.html.md
@@ -3,15 +3,23 @@ layout: api
 page_title: Client - HTTP API
 sidebar_current: api-client
 description: |-
-  The /client endpoints interact with the local Nomad agent to interact with
-  client members.
+  The /client endpoints are used to access client statistics and inspect
+  allocations running on a particular client.
 ---

 # Client HTTP API

-The `/client` endpoints are used to interact with the Nomad clients. The API
-endpoints are hosted by the Nomad client and requests have to be made to the
-Client where the particular allocation was placed.
+The `/client` endpoints are used to interact with the Nomad clients.
+
+Since Nomad 0.8.0, both clients and servers can handle client endpoints. This
+is particularly useful when a direct connection to a client is not possible
+due to the network configuration. For high-volume access to the client
+endpoints, particularly endpoints streaming file contents, direct access to
+the node should be preferred as it avoids adding additional load to the
+servers.
+
+When accessing the endpoints via the server, if the desired node is ambiguous
+based on the URL, an additional `?node_id` query parameter must be provided to
+disambiguate.

 ## Read Stats

@@ -31,6 +39,13 @@ The table below shows this endpoint's support for
 | ---------------- | ------------ |
 | `NO`             | `node:read`  |

+### Parameters
+
+- `node_id` `(string: <required>)` - Specifies the node to query. This is
+  required when the endpoint is being accessed via a server. Note, this must
+  be the _full_ node ID, not the short 8-character one. This is specified as
+  a query string parameter.
+
 ### Sample Request

 ```text
@@ -132,12 +147,10 @@ $ curl \
 }
 ```

-## Read Allocation
+## Read Allocation Statistics

 The client `allocation` endpoint is used to query the actual resources consumed
-by an allocation. The API endpoint is hosted by the Nomad client and requests
-have to be made to the nomad client whose resource usage metrics are of
-interest.
+by an allocation.

 | Method | Path                                 | Produces                   |
 | ------ | ------------------------------------ | -------------------------- |
@@ -563,9 +576,37 @@ $ curl \

 ## GC Allocation

+This endpoint forces a garbage collection of a particular stopped allocation
+on a node.
+
+| Method | Path                              | Produces           |
+| ------ | --------------------------------- | ------------------ |
+| `GET`  | `/client/allocation/:alloc_id/gc` | `application/json` |
+
+The table below shows this endpoint's support for
+[blocking queries](/api/index.html#blocking-queries) and
+[required ACLs](/api/index.html#acls).
+
+| Blocking Queries | ACL Required           |
+| ---------------- | ---------------------- |
+| `NO`             | `namespace:submit-job` |
+
+### Parameters
+
+- `:alloc_id` `(string: <required>)` - Specifies the allocation ID to query.
+  This is specified as part of the path. Note, this must be the _full_
+  allocation ID, not the short 8-character one.
+
+### Sample Request
+
+```text
+$ curl \
+    https://nomad.rocks/v1/client/allocation/5fc98185-17ff-26bc-a802-0c74fa471c99/gc
+```
+
+## GC All Allocations
+
 This endpoint forces a garbage collection of all stopped allocations on a node.
-The API endpoint is hosted by the Nomad client and requests have to be made to
-the Nomad client whose allocations should be garbage collected.

 | Method | Path                         | Produces                   |
 | ------ | ---------------------------- | -------------------------- |
@@ -579,6 +620,13 @@ The table below shows this endpoint's support for
 | ---------------- | ------------ |
 | `NO`             | `node:write` |

+### Parameters
+
+- `node_id` `(string: <required>)` - Specifies the node to target. This is
+  required when the endpoint is being accessed via a server. Note, this must
+  be the _full_ node ID, not the short 8-character one. This is specified as
+  a query string parameter.
+
 ### Sample Request

 ```text
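$ curl \
    https://nomad.rocks/v1/client/gc?node_id=e02b6169-83bd-9df6-69bd-832765f333eb
```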