diff --git a/nomad/core_sched.go b/nomad/core_sched.go index 0b1d7e62420..08f65e5bf82 100644 --- a/nomad/core_sched.go +++ b/nomad/core_sched.go @@ -8,9 +8,7 @@ import ( log "github.com/hashicorp/go-hclog" memdb "github.com/hashicorp/go-memdb" - multierror "github.com/hashicorp/go-multierror" version "github.com/hashicorp/go-version" - cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/scheduler" @@ -711,212 +709,30 @@ func allocGCEligible(a *structs.Allocation, job *structs.Job, gcTime time.Time, return timeDiff > interval.Nanoseconds() } +// TODO: we need a periodic trigger to iterate over all the volumes and split +// them up into separate work items, same as we do for jobs. + // csiVolumeClaimGC is used to garbage collect CSI volume claims func (c *CoreScheduler) csiVolumeClaimGC(eval *structs.Evaluation) error { - c.logger.Trace("garbage collecting unclaimed CSI volume claims") + c.logger.Trace("garbage collecting unclaimed CSI volume claims", "eval.JobID", eval.JobID) // Volume ID smuggled in with the eval's own JobID evalVolID := strings.Split(eval.JobID, ":") - if len(evalVolID) != 3 { + + // COMPAT(1.0): 0.11.0 shipped with 3 fields. tighten this check to len == 2 + if len(evalVolID) < 2 { c.logger.Error("volume gc called without volID") return nil } volID := evalVolID[1] - runningAllocs := evalVolID[2] == "purge" - return volumeClaimReap(c.srv, volID, eval.Namespace, - c.srv.config.Region, eval.LeaderACL, runningAllocs) -} - -func volumeClaimReap(srv RPCServer, volID, namespace, region, leaderACL string, runningAllocs bool) error { - - ws := memdb.NewWatchSet() - - vol, err := srv.State().CSIVolumeByID(ws, namespace, volID) - if err != nil { - return err - } - if vol == nil { - return nil - } - vol, err = srv.State().CSIVolumeDenormalize(ws, vol) - if err != nil { - return err - } - - plug, err := srv.State().CSIPluginByID(ws, vol.PluginID) - if err != nil { - return err - } - - nodeClaims := collectClaimsToGCImpl(vol, runningAllocs) - - var result *multierror.Error - for _, claim := range vol.PastClaims { - nodeClaims, err = volumeClaimReapImpl(srv, - &volumeClaimReapArgs{ - vol: vol, - plug: plug, - claim: claim, - namespace: namespace, - region: region, - leaderACL: leaderACL, - nodeClaims: nodeClaims, - }, - ) - if err != nil { - result = multierror.Append(result, err) - continue - } + req := &structs.CSIVolumeClaimRequest{ + VolumeID: volID, + Claim: structs.CSIVolumeClaimRelease, } - return result.ErrorOrNil() - -} + req.Namespace = eval.Namespace + req.Region = c.srv.config.Region -func collectClaimsToGCImpl(vol *structs.CSIVolume, runningAllocs bool) map[string]int { - nodeClaims := map[string]int{} // node IDs -> count - - collectFunc := func(allocs map[string]*structs.Allocation, - claims map[string]*structs.CSIVolumeClaim) { - - for allocID, alloc := range allocs { - claim, ok := claims[allocID] - if !ok { - // COMPAT(1.0): the CSIVolumeClaim fields were added - // after 0.11.1, so claims made before that may be - // missing this value. note that we'll have non-nil - // allocs here because we called denormalize on the - // value. 
- claim = &structs.CSIVolumeClaim{ - AllocationID: allocID, - NodeID: alloc.NodeID, - State: structs.CSIVolumeClaimStateTaken, - } - } - nodeClaims[claim.NodeID]++ - if runningAllocs || alloc.Terminated() { - // only overwrite the PastClaim if this is new, - // so that we can track state between subsequent calls - if _, exists := vol.PastClaims[claim.AllocationID]; !exists { - claim.State = structs.CSIVolumeClaimStateTaken - vol.PastClaims[claim.AllocationID] = claim - } - } - } - } - - collectFunc(vol.WriteAllocs, vol.WriteClaims) - collectFunc(vol.ReadAllocs, vol.ReadClaims) - return nodeClaims -} - -type volumeClaimReapArgs struct { - vol *structs.CSIVolume - plug *structs.CSIPlugin - claim *structs.CSIVolumeClaim - region string - namespace string - leaderACL string - nodeClaims map[string]int // node IDs -> count -} - -func volumeClaimReapImpl(srv RPCServer, args *volumeClaimReapArgs) (map[string]int, error) { - vol := args.vol - claim := args.claim - - var err error - var nReq *cstructs.ClientCSINodeDetachVolumeRequest - - checkpoint := func(claimState structs.CSIVolumeClaimState) error { - req := &structs.CSIVolumeClaimRequest{ - VolumeID: vol.ID, - AllocationID: claim.AllocationID, - Claim: structs.CSIVolumeClaimRelease, - WriteRequest: structs.WriteRequest{ - Region: args.region, - Namespace: args.namespace, - AuthToken: args.leaderACL, - }, - } - return srv.RPC("CSIVolume.Claim", req, &structs.CSIVolumeClaimResponse{}) - } - - // previous checkpoints may have set the past claim state already. - // in practice we should never see CSIVolumeClaimStateControllerDetached - // but having an option for the state makes it easy to add a checkpoint - // in a backwards compatible way if we need one later - switch claim.State { - case structs.CSIVolumeClaimStateNodeDetached: - goto NODE_DETACHED - case structs.CSIVolumeClaimStateControllerDetached: - goto RELEASE_CLAIM - case structs.CSIVolumeClaimStateReadyToFree: - goto RELEASE_CLAIM - } - - // (1) NodePublish / NodeUnstage must be completed before controller - // operations or releasing the claim. 
- nReq = &cstructs.ClientCSINodeDetachVolumeRequest{ - PluginID: args.plug.ID, - VolumeID: vol.ID, - ExternalID: vol.RemoteID(), - AllocID: claim.AllocationID, - NodeID: claim.NodeID, - AttachmentMode: vol.AttachmentMode, - AccessMode: vol.AccessMode, - ReadOnly: claim.Mode == structs.CSIVolumeClaimRead, - } - err = srv.RPC("ClientCSI.NodeDetachVolume", nReq, - &cstructs.ClientCSINodeDetachVolumeResponse{}) - if err != nil { - return args.nodeClaims, err - } - err = checkpoint(structs.CSIVolumeClaimStateNodeDetached) - if err != nil { - return args.nodeClaims, err - } - -NODE_DETACHED: - args.nodeClaims[claim.NodeID]-- - - // (2) we only emit the controller unpublish if no other allocs - // on the node need it, but we also only want to make this - // call at most once per node - if vol.ControllerRequired && args.nodeClaims[claim.NodeID] < 1 { - - // we need to get the CSI Node ID, which is not the same as - // the Nomad Node ID - ws := memdb.NewWatchSet() - targetNode, err := srv.State().NodeByID(ws, claim.NodeID) - if err != nil { - return args.nodeClaims, err - } - if targetNode == nil { - return args.nodeClaims, fmt.Errorf("%s: %s", - structs.ErrUnknownNodePrefix, claim.NodeID) - } - targetCSIInfo, ok := targetNode.CSINodePlugins[args.plug.ID] - if !ok { - return args.nodeClaims, fmt.Errorf("Failed to find NodeInfo for node: %s", targetNode.ID) - } - - cReq := &cstructs.ClientCSIControllerDetachVolumeRequest{ - VolumeID: vol.RemoteID(), - ClientCSINodeID: targetCSIInfo.NodeInfo.ID, - } - cReq.PluginID = args.plug.ID - err = srv.RPC("ClientCSI.ControllerDetachVolume", cReq, - &cstructs.ClientCSIControllerDetachVolumeResponse{}) - if err != nil { - return args.nodeClaims, err - } - } - -RELEASE_CLAIM: - // (3) release the claim from the state store, allowing it to be rescheduled - err = checkpoint(structs.CSIVolumeClaimStateReadyToFree) - if err != nil { - return args.nodeClaims, err - } - return args.nodeClaims, nil + err := c.srv.RPC("CSIVolume.Claim", req, &structs.CSIVolumeClaimResponse{}) + return err } diff --git a/nomad/core_sched_test.go b/nomad/core_sched_test.go index 819e0908ddb..70b500a82bc 100644 --- a/nomad/core_sched_test.go +++ b/nomad/core_sched_test.go @@ -6,10 +6,8 @@ import ( "time" memdb "github.com/hashicorp/go-memdb" - cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" - "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" "github.com/stretchr/testify/assert" @@ -2195,268 +2193,3 @@ func TestAllocation_GCEligible(t *testing.T) { alloc.ClientStatus = structs.AllocClientStatusComplete require.True(allocGCEligible(alloc, nil, time.Now(), 1000)) } - -func TestCSI_GCVolumeClaims_Collection(t *testing.T) { - t.Parallel() - srv, shutdownSrv := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) - defer shutdownSrv() - testutil.WaitForLeader(t, srv.RPC) - - state := srv.fsm.State() - ws := memdb.NewWatchSet() - index := uint64(100) - - // Create a client node, plugin, and volume - node := mock.Node() - node.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early version - node.CSINodePlugins = map[string]*structs.CSIInfo{ - "csi-plugin-example": { - PluginID: "csi-plugin-example", - Healthy: true, - RequiresControllerPlugin: true, - NodeInfo: &structs.CSINodeInfo{}, - }, - } - node.CSIControllerPlugins = map[string]*structs.CSIInfo{ - 
"csi-plugin-example": { - PluginID: "csi-plugin-example", - Healthy: true, - RequiresControllerPlugin: true, - ControllerInfo: &structs.CSIControllerInfo{ - SupportsReadOnlyAttach: true, - SupportsAttachDetach: true, - SupportsListVolumes: true, - SupportsListVolumesAttachedNodes: false, - }, - }, - } - err := state.UpsertNode(99, node) - require.NoError(t, err) - volId0 := uuid.Generate() - ns := structs.DefaultNamespace - vols := []*structs.CSIVolume{{ - ID: volId0, - Namespace: ns, - PluginID: "csi-plugin-example", - AccessMode: structs.CSIVolumeAccessModeMultiNodeSingleWriter, - AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, - }} - - err = state.CSIVolumeRegister(index, vols) - index++ - require.NoError(t, err) - vol, err := state.CSIVolumeByID(ws, ns, volId0) - - require.NoError(t, err) - require.True(t, vol.ControllerRequired) - require.Len(t, vol.ReadAllocs, 0) - require.Len(t, vol.WriteAllocs, 0) - - // Create a job with 2 allocations - job := mock.Job() - job.TaskGroups[0].Volumes = map[string]*structs.VolumeRequest{ - "_": { - Name: "someVolume", - Type: structs.VolumeTypeCSI, - Source: volId0, - ReadOnly: false, - }, - } - err = state.UpsertJob(index, job) - index++ - require.NoError(t, err) - - alloc1 := mock.Alloc() - alloc1.JobID = job.ID - alloc1.NodeID = node.ID - err = state.UpsertJobSummary(index, mock.JobSummary(alloc1.JobID)) - index++ - require.NoError(t, err) - alloc1.TaskGroup = job.TaskGroups[0].Name - - alloc2 := mock.Alloc() - alloc2.JobID = job.ID - alloc2.NodeID = node.ID - err = state.UpsertJobSummary(index, mock.JobSummary(alloc2.JobID)) - index++ - require.NoError(t, err) - alloc2.TaskGroup = job.TaskGroups[0].Name - - err = state.UpsertAllocs(104, []*structs.Allocation{alloc1, alloc2}) - require.NoError(t, err) - - // Claim the volumes and verify the claims were set - err = state.CSIVolumeClaim(index, ns, volId0, &structs.CSIVolumeClaim{ - AllocationID: alloc1.ID, - NodeID: alloc1.NodeID, - Mode: structs.CSIVolumeClaimWrite, - }) - index++ - require.NoError(t, err) - - err = state.CSIVolumeClaim(index, ns, volId0, &structs.CSIVolumeClaim{ - AllocationID: alloc2.ID, - NodeID: alloc2.NodeID, - Mode: structs.CSIVolumeClaimRead, - }) - index++ - require.NoError(t, err) - - vol, err = state.CSIVolumeByID(ws, ns, volId0) - require.NoError(t, err) - require.Len(t, vol.ReadAllocs, 1) - require.Len(t, vol.WriteAllocs, 1) - - // Update both allocs as failed/terminated - alloc1.ClientStatus = structs.AllocClientStatusFailed - alloc2.ClientStatus = structs.AllocClientStatusFailed - err = state.UpdateAllocsFromClient(index, []*structs.Allocation{alloc1, alloc2}) - require.NoError(t, err) - - vol, err = state.CSIVolumeDenormalize(ws, vol) - require.NoError(t, err) - - nodeClaims := collectClaimsToGCImpl(vol, false) - require.Equal(t, nodeClaims[node.ID], 2) - require.Len(t, vol.PastClaims, 2) -} - -func TestCSI_GCVolumeClaims_Reap(t *testing.T) { - t.Parallel() - require := require.New(t) - - s, shutdownSrv := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) - defer shutdownSrv() - testutil.WaitForLeader(t, s.RPC) - - node := mock.Node() - plugin := mock.CSIPlugin() - vol := mock.CSIVolume(plugin) - alloc := mock.Alloc() - - cases := []struct { - Name string - ClaimsCount map[string]int - ControllerRequired bool - ExpectedErr string - ExpectedCount int - ExpectedClaimsCount int - ExpectedNodeDetachVolumeCount int - ExpectedControllerDetachVolumeCount int - ExpectedVolumeClaimCount int - srv *MockRPCServer - }{ - { - Name: "NodeDetachVolume fails", - 
ClaimsCount: map[string]int{node.ID: 1}, - ControllerRequired: true, - ExpectedErr: "node plugin missing", - ExpectedClaimsCount: 1, - ExpectedNodeDetachVolumeCount: 1, - srv: &MockRPCServer{ - state: s.State(), - nextCSINodeDetachVolumeError: fmt.Errorf("node plugin missing"), - }, - }, - { - Name: "ControllerDetachVolume no controllers", - ClaimsCount: map[string]int{node.ID: 1}, - ControllerRequired: true, - ExpectedErr: fmt.Sprintf("Unknown node: %s", node.ID), - ExpectedClaimsCount: 0, - ExpectedNodeDetachVolumeCount: 1, - ExpectedControllerDetachVolumeCount: 0, - ExpectedVolumeClaimCount: 1, - srv: &MockRPCServer{ - state: s.State(), - }, - }, - { - Name: "ControllerDetachVolume node-only", - ClaimsCount: map[string]int{node.ID: 1}, - ControllerRequired: false, - ExpectedClaimsCount: 0, - ExpectedNodeDetachVolumeCount: 1, - ExpectedControllerDetachVolumeCount: 0, - ExpectedVolumeClaimCount: 2, - srv: &MockRPCServer{ - state: s.State(), - }, - }, - } - - for _, tc := range cases { - t.Run(tc.Name, func(t *testing.T) { - vol.ControllerRequired = tc.ControllerRequired - claim := &structs.CSIVolumeClaim{ - AllocationID: alloc.ID, - NodeID: node.ID, - State: structs.CSIVolumeClaimStateTaken, - Mode: structs.CSIVolumeClaimRead, - } - nodeClaims, err := volumeClaimReapImpl(tc.srv, &volumeClaimReapArgs{ - vol: vol, - plug: plugin, - claim: claim, - region: "global", - namespace: "default", - leaderACL: "not-in-use", - nodeClaims: tc.ClaimsCount, - }) - if tc.ExpectedErr != "" { - require.EqualError(err, tc.ExpectedErr) - } else { - require.NoError(err) - } - require.Equal(tc.ExpectedClaimsCount, - nodeClaims[claim.NodeID], "expected claims remaining") - require.Equal(tc.ExpectedNodeDetachVolumeCount, - tc.srv.countCSINodeDetachVolume, "node detach RPC count") - require.Equal(tc.ExpectedControllerDetachVolumeCount, - tc.srv.countCSIControllerDetachVolume, "controller detach RPC count") - require.Equal(tc.ExpectedVolumeClaimCount, - tc.srv.countCSIVolumeClaim, "volume claim RPC count") - }) - } -} - -type MockRPCServer struct { - state *state.StateStore - - // mock responses for ClientCSI.NodeDetachVolume - nextCSINodeDetachVolumeResponse *cstructs.ClientCSINodeDetachVolumeResponse - nextCSINodeDetachVolumeError error - countCSINodeDetachVolume int - - // mock responses for ClientCSI.ControllerDetachVolume - nextCSIControllerDetachVolumeResponse *cstructs.ClientCSIControllerDetachVolumeResponse - nextCSIControllerDetachVolumeError error - countCSIControllerDetachVolume int - - // mock responses for CSI.VolumeClaim - nextCSIVolumeClaimResponse *structs.CSIVolumeClaimResponse - nextCSIVolumeClaimError error - countCSIVolumeClaim int -} - -func (srv *MockRPCServer) RPC(method string, args interface{}, reply interface{}) error { - switch method { - case "ClientCSI.NodeDetachVolume": - reply = srv.nextCSINodeDetachVolumeResponse - srv.countCSINodeDetachVolume++ - return srv.nextCSINodeDetachVolumeError - case "ClientCSI.ControllerDetachVolume": - reply = srv.nextCSIControllerDetachVolumeResponse - srv.countCSIControllerDetachVolume++ - return srv.nextCSIControllerDetachVolumeError - case "CSIVolume.Claim": - reply = srv.nextCSIVolumeClaimResponse - srv.countCSIVolumeClaim++ - return srv.nextCSIVolumeClaimError - default: - return fmt.Errorf("unexpected method %q passed to mock", method) - } - -} - -func (srv *MockRPCServer) State() *state.StateStore { return srv.state } diff --git a/nomad/fsm.go b/nomad/fsm.go index 9ec1ef08651..b9f412393dd 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -270,6 
+270,8 @@ func (n *nomadFSM) Apply(log *raft.Log) interface{} { return n.applyCSIVolumeDeregister(buf[1:], log.Index) case structs.CSIVolumeClaimRequestType: return n.applyCSIVolumeClaim(buf[1:], log.Index) + case structs.CSIVolumeClaimBatchRequestType: + return n.applyCSIVolumeBatchClaim(buf[1:], log.Index) case structs.ScalingEventRegisterRequestType: return n.applyUpsertScalingEvent(buf[1:], log.Index) } @@ -1156,33 +1158,35 @@ func (n *nomadFSM) applyCSIVolumeDeregister(buf []byte, index uint64) interface{ return nil } -func (n *nomadFSM) applyCSIVolumeClaim(buf []byte, index uint64) interface{} { - var req structs.CSIVolumeClaimRequest - if err := structs.Decode(buf, &req); err != nil { +func (n *nomadFSM) applyCSIVolumeBatchClaim(buf []byte, index uint64) interface{} { + var batch *structs.CSIVolumeClaimBatchRequest + if err := structs.Decode(buf, &batch); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) } - defer metrics.MeasureSince([]string{"nomad", "fsm", "apply_csi_volume_claim"}, time.Now()) + defer metrics.MeasureSince([]string{"nomad", "fsm", "apply_csi_volume_batch_claim"}, time.Now()) - ws := memdb.NewWatchSet() - alloc, err := n.state.AllocByID(ws, req.AllocationID) - if err != nil { - n.logger.Error("AllocByID failed", "error", err) - return err - } - if alloc == nil { - n.logger.Error("AllocByID failed to find alloc", "alloc_id", req.AllocationID) + for _, req := range batch.Claims { + err := n.state.CSIVolumeClaim(index, req.RequestNamespace(), + req.VolumeID, req.ToClaim()) if err != nil { - return err + n.logger.Error("CSIVolumeClaim for batch failed", "error", err) + return err // note: fails the remaining batch } + } + return nil +} - return structs.ErrUnknownAllocationPrefix +func (n *nomadFSM) applyCSIVolumeClaim(buf []byte, index uint64) interface{} { + var req structs.CSIVolumeClaimRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) } + defer metrics.MeasureSince([]string{"nomad", "fsm", "apply_csi_volume_claim"}, time.Now()) if err := n.state.CSIVolumeClaim(index, req.RequestNamespace(), req.VolumeID, req.ToClaim()); err != nil { n.logger.Error("CSIVolumeClaim failed", "error", err) return err } - return nil } diff --git a/nomad/interfaces.go b/nomad/interfaces.go deleted file mode 100644 index 4dc266d8b80..00000000000 --- a/nomad/interfaces.go +++ /dev/null @@ -1,11 +0,0 @@ -package nomad - -import "github.com/hashicorp/nomad/nomad/state" - -// RPCServer is a minimal interface of the Server, intended as -// an aid for testing logic surrounding server-to-server or -// server-to-client RPC calls -type RPCServer interface { - RPC(method string, args interface{}, reply interface{}) error - State() *state.StateStore -} diff --git a/nomad/job_endpoint.go b/nomad/job_endpoint.go index a451817fb58..3cbd3f204db 100644 --- a/nomad/job_endpoint.go +++ b/nomad/job_endpoint.go @@ -737,19 +737,13 @@ func (j *Job) Deregister(args *structs.JobDeregisterRequest, reply *structs.JobD for _, vol := range volumesToGC { // we have to build this eval by hand rather than calling srv.CoreJob // here because we need to use the volume's namespace - - runningAllocs := ":ok" - if args.Purge { - runningAllocs = ":purge" - } - eval := &structs.Evaluation{ ID: uuid.Generate(), Namespace: job.Namespace, Priority: structs.CoreJobPriority, Type: structs.JobTypeCore, TriggeredBy: structs.EvalTriggerAllocStop, - JobID: structs.CoreJobCSIVolumeClaimGC + ":" + vol.Source + runningAllocs, + JobID: 
structs.CoreJobCSIVolumeClaimGC + ":" + vol.Source, LeaderACL: j.srv.getLeaderAcl(), Status: structs.EvalStatusPending, CreateTime: now, diff --git a/nomad/leader.go b/nomad/leader.go index b43d4abd2f5..29550dc7992 100644 --- a/nomad/leader.go +++ b/nomad/leader.go @@ -241,6 +241,9 @@ func (s *Server) establishLeadership(stopCh chan struct{}) error { // Enable the NodeDrainer s.nodeDrainer.SetEnabled(true, s.State()) + // Enable the volume watcher, since we are now the leader + s.volumeWatcher.SetEnabled(true, s.State()) + // Restore the eval broker state if err := s.restoreEvals(); err != nil { return err @@ -870,6 +873,9 @@ func (s *Server) revokeLeadership() error { // Disable the node drainer s.nodeDrainer.SetEnabled(false, nil) + // Disable the volume watcher + s.volumeWatcher.SetEnabled(false, nil) + // Disable any enterprise systems required. if err := s.revokeEnterpriseLeadership(); err != nil { return err diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index fcfbcfcc232..7308c00aa63 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -1149,7 +1149,7 @@ func (n *Node) UpdateAlloc(args *structs.AllocUpdateRequest, reply *structs.Gene Priority: structs.CoreJobPriority, Type: structs.JobTypeCore, TriggeredBy: structs.EvalTriggerAllocStop, - JobID: structs.CoreJobCSIVolumeClaimGC + ":" + volAndNamespace[0] + ":no", + JobID: structs.CoreJobCSIVolumeClaimGC + ":" + volAndNamespace[0], LeaderACL: n.srv.getLeaderAcl(), Status: structs.EvalStatusPending, CreateTime: now.UTC().UnixNano(), diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index c1d54ebc849..e687614409a 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -2414,7 +2414,7 @@ func TestClientEndpoint_UpdateAlloc_UnclaimVolumes(t *testing.T) { // Verify the eval for the claim GC was emitted // Lookup the evaluations - eval, err := state.EvalsByJob(ws, job.Namespace, structs.CoreJobCSIVolumeClaimGC+":"+volId0+":no") + eval, err := state.EvalsByJob(ws, job.Namespace, structs.CoreJobCSIVolumeClaimGC+":"+volId0) require.NotNil(t, eval) require.Nil(t, err) } diff --git a/nomad/server.go b/nomad/server.go index d691faeda80..8a1353f985d 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -35,6 +35,7 @@ import ( "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" + "github.com/hashicorp/nomad/nomad/volumewatcher" "github.com/hashicorp/nomad/scheduler" "github.com/hashicorp/raft" raftboltdb "github.com/hashicorp/raft-boltdb" @@ -186,6 +187,9 @@ type Server struct { // nodeDrainer is used to drain allocations from nodes. nodeDrainer *drainer.NodeDrainer + // volumeWatcher is used to release volume claims + volumeWatcher *volumewatcher.Watcher + // evalBroker is used to manage the in-progress evaluations // that are waiting to be brokered to a sub-scheduler evalBroker *EvalBroker @@ -399,6 +403,12 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulACLs consu return nil, fmt.Errorf("failed to create deployment watcher: %v", err) } + // Setup the volume watcher + if err := s.setupVolumeWatcher(); err != nil { + s.logger.Error("failed to create volume watcher", "error", err) + return nil, fmt.Errorf("failed to create volume watcher: %v", err) + } + // Setup the node drainer. 
 	s.setupNodeDrainer()
@@ -993,6 +1003,27 @@ func (s *Server) setupDeploymentWatcher() error {
 	return nil
 }
 
+// setupVolumeWatcher creates a volume watcher that consumes the RPC
+// endpoints for state information and makes transitions via Raft through a
+// shim that provides the appropriate methods.
+func (s *Server) setupVolumeWatcher() error {
+
+	// Create the raft shim type to restrict the set of raft methods that
+	// can be called
+	raftShim := &volumeWatcherRaftShim{
+		apply: s.raftApply,
+	}
+
+	// Create the volume watcher
+	s.volumeWatcher = volumewatcher.NewVolumesWatcher(
+		s.logger, raftShim,
+		s.staticEndpoints.ClientCSI,
+		volumewatcher.LimitStateQueriesPerSecond,
+		volumewatcher.CrossVolumeUpdateBatchDuration)
+
+	return nil
+}
+
 // setupNodeDrainer creates a node drainer which will be enabled when a server
 // becomes a leader.
 func (s *Server) setupNodeDrainer() {
diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go
index 22b9801ca70..3f76bc68a5d 100644
--- a/nomad/state/state_store.go
+++ b/nomad/state/state_store.go
@@ -2068,9 +2068,14 @@ func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, claim *s
 		return err
 	}
 
-	err = volume.Claim(claim, alloc)
-	if err != nil {
-		return err
+	// in the case of a job deregistration, there will be no allocation ID
+	// for the claim but we still want to write an updated index to the volume
+	// so that volume reaping is triggered
+	if claim.AllocationID != "" {
+		err = volume.Claim(claim, alloc)
+		if err != nil {
+			return err
+		}
 	}
 
 	volume.ModifyIndex = index
diff --git a/nomad/structs/csi.go b/nomad/structs/csi.go
index bea3439ea68..5428f89ef69 100644
--- a/nomad/structs/csi.go
+++ b/nomad/structs/csi.go
@@ -575,6 +575,10 @@ const (
 	CSIVolumeClaimRelease
 )
 
+type CSIVolumeClaimBatchRequest struct {
+	Claims []CSIVolumeClaimRequest
+}
+
 type CSIVolumeClaimRequest struct {
 	VolumeID     string
 	AllocationID string
diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go
index ce36d15fb41..8f3eb060f76 100644
--- a/nomad/structs/structs.go
+++ b/nomad/structs/structs.go
@@ -90,6 +90,7 @@ const (
 	CSIVolumeRegisterRequestType
 	CSIVolumeDeregisterRequestType
 	CSIVolumeClaimRequestType
+	CSIVolumeClaimBatchRequestType
 	ScalingEventRegisterRequestType
 )
 
diff --git a/nomad/volumewatcher/batcher.go b/nomad/volumewatcher/batcher.go
new file mode 100644
index 00000000000..a67ef1bb8d5
--- /dev/null
+++ b/nomad/volumewatcher/batcher.go
@@ -0,0 +1,125 @@
+package volumewatcher
+
+import (
+	"context"
+	"time"
+
+	"github.com/hashicorp/nomad/nomad/structs"
+)
+
+// VolumeUpdateBatcher is used to batch the updates for volume claims
+type VolumeUpdateBatcher struct {
+	// batch is the batching duration
+	batch time.Duration
+
+	// raft is used to actually commit the updates
+	raft VolumeRaftEndpoints
+
+	// workCh is used to pass claim updates to the batcher goroutine
+	workCh chan *updateWrapper
+
+	// ctx is used to exit the daemon batcher
+	ctx context.Context
+}
+
+// NewVolumeUpdateBatcher returns a VolumeUpdateBatcher that uses the
+// passed raft endpoints to create the updates to volume claims, and
+// exits the batcher when the passed context is canceled.
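+//
+// A minimal usage sketch (hypothetical caller; raftShim and leaderCtx
+// stand in for the server's raft shim and a leader-scoped context, and
+// are not part of this API):
+//
+//	batcher := NewVolumeUpdateBatcher(250*time.Millisecond, raftShim, leaderCtx)
+//	index, err := batcher.CreateUpdate(claims).Results() // blocks until the batch commits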
+func NewVolumeUpdateBatcher(batchDuration time.Duration, raft VolumeRaftEndpoints, ctx context.Context) *VolumeUpdateBatcher { + b := &VolumeUpdateBatcher{ + batch: batchDuration, + raft: raft, + ctx: ctx, + workCh: make(chan *updateWrapper, 10), + } + + go b.batcher() + return b +} + +// CreateUpdate batches the volume claim update and returns a future +// that tracks the completion of the request. +func (b *VolumeUpdateBatcher) CreateUpdate(claims []structs.CSIVolumeClaimRequest) *BatchFuture { + wrapper := &updateWrapper{ + claims: claims, + f: make(chan *BatchFuture, 1), + } + + b.workCh <- wrapper + return <-wrapper.f +} + +type updateWrapper struct { + claims []structs.CSIVolumeClaimRequest + f chan *BatchFuture +} + +// batcher is the long lived batcher goroutine +func (b *VolumeUpdateBatcher) batcher() { + var timerCh <-chan time.Time + claims := make(map[string]structs.CSIVolumeClaimRequest) + future := NewBatchFuture() + for { + select { + case <-b.ctx.Done(): + // note: we can't flush here because we're likely no + // longer the leader + return + case w := <-b.workCh: + if timerCh == nil { + timerCh = time.After(b.batch) + } + + // de-dupe and store the claim update, and attach the future + for _, upd := range w.claims { + claims[upd.VolumeID+upd.RequestNamespace()] = upd + } + w.f <- future + case <-timerCh: + // Capture the future and create a new one + f := future + future = NewBatchFuture() + + // Create the batch request + req := structs.CSIVolumeClaimBatchRequest{} + for _, claim := range claims { + req.Claims = append(req.Claims, claim) + } + + // Upsert the claims in a go routine + go f.Set(b.raft.UpsertVolumeClaims(&req)) + + // Reset the claims list and timer + claims = make(map[string]structs.CSIVolumeClaimRequest) + timerCh = nil + } + } +} + +// BatchFuture is a future that can be used to retrieve the index for +// the update or any error in the update process +type BatchFuture struct { + index uint64 + err error + waitCh chan struct{} +} + +// NewBatchFuture returns a new BatchFuture +func NewBatchFuture() *BatchFuture { + return &BatchFuture{ + waitCh: make(chan struct{}), + } +} + +// Set sets the results of the future, unblocking any client. +func (f *BatchFuture) Set(index uint64, err error) { + f.index = index + f.err = err + close(f.waitCh) +} + +// Results returns the creation index and any error. 
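+//
+// A sketch of the intended call pattern (names are illustrative):
+//
+//	go f.Set(b.raft.UpsertVolumeClaims(&req)) // the batcher sets the result exactly once
+//	index, err := f.Results()                 // callers block, then share that result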
+func (f *BatchFuture) Results() (uint64, error) { + <-f.waitCh + return f.index, f.err +} diff --git a/nomad/volumewatcher/batcher_test.go b/nomad/volumewatcher/batcher_test.go new file mode 100644 index 00000000000..7f9915d193b --- /dev/null +++ b/nomad/volumewatcher/batcher_test.go @@ -0,0 +1,85 @@ +package volumewatcher + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/require" +) + +// TestVolumeWatch_Batcher tests the update batching logic +func TestVolumeWatch_Batcher(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx, exitFn := context.WithCancel(context.Background()) + defer exitFn() + + srv := &MockBatchingRPCServer{} + srv.state = state.TestStateStore(t) + srv.volumeUpdateBatcher = NewVolumeUpdateBatcher(CrossVolumeUpdateBatchDuration, srv, ctx) + + plugin := mock.CSIPlugin() + node := testNode(nil, plugin, srv.State()) + + // because we wait for the results to return from the batch for each + // Watcher.updateClaims, we can't test that we're batching except across + // multiple volume watchers. create 2 volumes and their watchers here. + alloc0 := mock.Alloc() + alloc0.ClientStatus = structs.AllocClientStatusComplete + vol0 := testVolume(nil, plugin, alloc0, node.ID) + w0 := &volumeWatcher{ + v: vol0, + rpc: srv, + state: srv.State(), + updateClaims: srv.UpdateClaims, + logger: testlog.HCLogger(t), + } + + alloc1 := mock.Alloc() + alloc1.ClientStatus = structs.AllocClientStatusComplete + vol1 := testVolume(nil, plugin, alloc1, node.ID) + w1 := &volumeWatcher{ + v: vol1, + rpc: srv, + state: srv.State(), + updateClaims: srv.UpdateClaims, + logger: testlog.HCLogger(t), + } + + srv.nextCSIControllerDetachError = fmt.Errorf("some controller plugin error") + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + w0.volumeReapImpl(vol0) + wg.Done() + }() + go func() { + w1.volumeReapImpl(vol1) + wg.Done() + }() + + wg.Wait() + + require.Equal(structs.CSIVolumeClaimStateNodeDetached, vol0.PastClaims[alloc0.ID].State) + require.Equal(structs.CSIVolumeClaimStateNodeDetached, vol1.PastClaims[alloc1.ID].State) + require.Equal(2, srv.countCSINodeDetachVolume) + require.Equal(2, srv.countCSIControllerDetachVolume) + require.Equal(2, srv.countUpdateClaims) + + // note: it's technically possible that the volumeReapImpl + // goroutines get de-scheduled and we don't write both updates in + // the same batch. but this seems really unlikely, so we're + // testing for both cases here so that if we start seeing a flake + // here in the future we have a clear cause for it. + require.GreaterOrEqual(srv.countUpsertVolumeClaims, 1) + require.Equal(1, srv.countUpsertVolumeClaims) +} diff --git a/nomad/volumewatcher/interfaces.go b/nomad/volumewatcher/interfaces.go new file mode 100644 index 00000000000..55d82c55b7c --- /dev/null +++ b/nomad/volumewatcher/interfaces.go @@ -0,0 +1,28 @@ +package volumewatcher + +import ( + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/nomad/structs" +) + +// VolumeRaftEndpoints exposes the volume watcher to a set of functions +// to apply data transforms via Raft. 
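+//
+// The server side satisfies this interface with a thin shim over
+// Server.raftApply (wired up in setupVolumeWatcher above; the shim file
+// itself is outside this diff). A plausible sketch, not the verbatim
+// implementation:
+//
+//	type volumeWatcherRaftShim struct {
+//		apply func(structs.MessageType, interface{}) (interface{}, uint64, error)
+//	}
+//
+//	func (shim *volumeWatcherRaftShim) UpsertVolumeClaims(req *structs.CSIVolumeClaimBatchRequest) (uint64, error) {
+//		resp, index, err := shim.apply(structs.CSIVolumeClaimBatchRequestType, req)
+//		if err != nil {
+//			return 0, err
+//		}
+//		if respErr, ok := resp.(error); ok {
+//			return index, respErr
+//		}
+//		return index, nil
+//	}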
+type VolumeRaftEndpoints interface {
+
+	// UpsertVolumeClaims applies a batch of claims to raft
+	UpsertVolumeClaims(*structs.CSIVolumeClaimBatchRequest) (uint64, error)
+}
+
+// ClientRPC is a minimal interface of the Server, intended as an aid
+// for testing logic surrounding server-to-server or server-to-client
+// RPC calls and to avoid circular references between the nomad
+// package and the volumewatcher
+type ClientRPC interface {
+	ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error
+	NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeRequest, reply *cstructs.ClientCSINodeDetachVolumeResponse) error
+}
+
+// updateClaimsFn is the function signature used to update claims on
+// behalf of a volume (used to wrap batch updates so that we can test
+// volumeWatcher methods synchronously without batching)
+type updateClaimsFn func(claims []structs.CSIVolumeClaimRequest) (uint64, error)
diff --git a/nomad/volumewatcher/interfaces_test.go b/nomad/volumewatcher/interfaces_test.go
new file mode 100644
index 00000000000..068a76e52de
--- /dev/null
+++ b/nomad/volumewatcher/interfaces_test.go
@@ -0,0 +1,148 @@
+package volumewatcher
+
+import (
+	cstructs "github.com/hashicorp/nomad/client/structs"
+	"github.com/hashicorp/nomad/nomad/mock"
+	"github.com/hashicorp/nomad/nomad/state"
+	"github.com/hashicorp/nomad/nomad/structs"
+)
+
+// Create a client node with plugin info
+func testNode(node *structs.Node, plugin *structs.CSIPlugin, s *state.StateStore) *structs.Node {
+	if node != nil {
+		return node
+	}
+	node = mock.Node()
+	node.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early version
+	node.CSINodePlugins = map[string]*structs.CSIInfo{
+		plugin.ID: {
+			PluginID:                 plugin.ID,
+			Healthy:                  true,
+			RequiresControllerPlugin: plugin.ControllerRequired,
+			NodeInfo:                 &structs.CSINodeInfo{},
+		},
+	}
+	if plugin.ControllerRequired {
+		node.CSIControllerPlugins = map[string]*structs.CSIInfo{
+			plugin.ID: {
+				PluginID:                 plugin.ID,
+				Healthy:                  true,
+				RequiresControllerPlugin: true,
+				ControllerInfo: &structs.CSIControllerInfo{
+					SupportsReadOnlyAttach:           true,
+					SupportsAttachDetach:             true,
+					SupportsListVolumes:              true,
+					SupportsListVolumesAttachedNodes: false,
+				},
+			},
+		}
+	} else {
+		node.CSIControllerPlugins = map[string]*structs.CSIInfo{}
+	}
+	s.UpsertNode(99, node)
+	return node
+}
+
+// Create a test volume with claim info
+func testVolume(vol *structs.CSIVolume, plugin *structs.CSIPlugin, alloc *structs.Allocation, nodeID string) *structs.CSIVolume {
+	if vol != nil {
+		return vol
+	}
+	vol = mock.CSIVolume(plugin)
+	vol.ControllerRequired = plugin.ControllerRequired
+
+	vol.ReadAllocs = map[string]*structs.Allocation{alloc.ID: alloc}
+	vol.ReadClaims = map[string]*structs.CSIVolumeClaim{
+		alloc.ID: {
+			AllocationID: alloc.ID,
+			NodeID:       nodeID,
+			Mode:         structs.CSIVolumeClaimRead,
+			State:        structs.CSIVolumeClaimStateTaken,
+		},
+	}
+	return vol
+}
+
+// COMPAT(1.0): the claim fields were added after 0.11.1; this
+// mock and the associated test cases can be removed for 1.0
+func testOldVolume(vol *structs.CSIVolume, plugin *structs.CSIPlugin, alloc *structs.Allocation, nodeID string) *structs.CSIVolume {
+	if vol != nil {
+		return vol
+	}
+	vol = mock.CSIVolume(plugin)
+	vol.ControllerRequired = plugin.ControllerRequired
+
+	vol.ReadAllocs = map[string]*structs.Allocation{alloc.ID: alloc}
+	return vol
+}
+
+type MockRPCServer struct {
+	state *state.StateStore
+
+	// mock responses for ClientCSI.NodeDetachVolume
+	nextCSINodeDetachResponse *cstructs.ClientCSINodeDetachVolumeResponse
+	nextCSINodeDetachError    error
+	countCSINodeDetachVolume  int
+
+	// mock responses for ClientCSI.ControllerDetachVolume
+	nextCSIControllerDetachVolumeResponse *cstructs.ClientCSIControllerDetachVolumeResponse
+	nextCSIControllerDetachError          error
+	countCSIControllerDetachVolume        int
+
+	countUpdateClaims        int
+	countUpsertVolumeClaims  int
+}
+
+func (srv *MockRPCServer) ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error {
+	reply = srv.nextCSIControllerDetachVolumeResponse
+	srv.countCSIControllerDetachVolume++
+	return srv.nextCSIControllerDetachError
+}
+
+func (srv *MockRPCServer) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeRequest, reply *cstructs.ClientCSINodeDetachVolumeResponse) error {
+	reply = srv.nextCSINodeDetachResponse
+	srv.countCSINodeDetachVolume++
+	return srv.nextCSINodeDetachError
+}
+
+func (srv *MockRPCServer) UpsertVolumeClaims(*structs.CSIVolumeClaimBatchRequest) (uint64, error) {
+	srv.countUpsertVolumeClaims++
+	return 0, nil
+}
+
+func (srv *MockRPCServer) State() *state.StateStore { return srv.state }
+
+func (srv *MockRPCServer) UpdateClaims(claims []structs.CSIVolumeClaimRequest) (uint64, error) {
+	srv.countUpdateClaims++
+	return 0, nil
+}
+
+type MockBatchingRPCServer struct {
+	MockRPCServer
+	volumeUpdateBatcher *VolumeUpdateBatcher
+}
+
+func (srv *MockBatchingRPCServer) UpdateClaims(claims []structs.CSIVolumeClaimRequest) (uint64, error) {
+	srv.countUpdateClaims++
+	return srv.volumeUpdateBatcher.CreateUpdate(claims).Results()
+}
+
+type MockStatefulRPCServer struct {
+	MockRPCServer
+	volumeUpdateBatcher *VolumeUpdateBatcher
+}
+
+func (srv *MockStatefulRPCServer) UpsertVolumeClaims(batch *structs.CSIVolumeClaimBatchRequest) (uint64, error) {
+	srv.countUpsertVolumeClaims++
+	index, _ := srv.state.LatestIndex()
+	for _, req := range batch.Claims {
+		index++
+		err := srv.state.CSIVolumeClaim(index, req.RequestNamespace(),
+			req.VolumeID, req.ToClaim())
+		if err != nil {
+			return 0, err
+		}
+	}
+	return index, nil
+}
diff --git a/nomad/volumewatcher/volume_watcher.go b/nomad/volumewatcher/volume_watcher.go
new file mode 100644
index 00000000000..6177ae10d20
--- /dev/null
+++ b/nomad/volumewatcher/volume_watcher.go
@@ -0,0 +1,382 @@
+package volumewatcher
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	log "github.com/hashicorp/go-hclog"
+	memdb "github.com/hashicorp/go-memdb"
+	multierror "github.com/hashicorp/go-multierror"
+	cstructs "github.com/hashicorp/nomad/client/structs"
+	"github.com/hashicorp/nomad/nomad/state"
+	"github.com/hashicorp/nomad/nomad/structs"
+)
+
+// volumeWatcher is used to watch a single volume and release its
+// claims as the claims' allocations reach terminal states.
+type volumeWatcher struct {
+	// v is the volume being watched
+	v *structs.CSIVolume
+
+	// state is the state that is watched for state changes.
+ state *state.StateStore + + // updateClaims is the function used to apply claims to raft + updateClaims updateClaimsFn + + // server interface for CSI client RPCs + rpc ClientRPC + + logger log.Logger + shutdownCtx context.Context // parent context + ctx context.Context // own context + exitFn context.CancelFunc + + // updateCh is triggered when there is an updated volume + updateCh chan *structs.CSIVolume + + wLock sync.RWMutex + running bool +} + +// newVolumeWatcher returns a volume watcher that is used to watch +// volumes +func newVolumeWatcher(parent *Watcher, vol *structs.CSIVolume) *volumeWatcher { + + w := &volumeWatcher{ + updateCh: make(chan *structs.CSIVolume, 1), + updateClaims: parent.updateClaims, + v: vol, + state: parent.state, + rpc: parent.rpc, + logger: parent.logger.With("volume_id", vol.ID, "namespace", vol.Namespace), + shutdownCtx: parent.ctx, + } + + // Start the long lived watcher that scans for allocation updates + w.Start() + return w +} + +// Notify signals an update to the tracked volume. +func (vw *volumeWatcher) Notify(v *structs.CSIVolume) { + if !vw.isRunning() { + vw.Start() + } + select { + case vw.updateCh <- v: + case <-vw.shutdownCtx.Done(): // prevent deadlock if we stopped + case <-vw.ctx.Done(): // prevent deadlock if we stopped + } +} + +func (vw *volumeWatcher) Start() { + vw.logger.Trace("starting watcher", "id", vw.v.ID, "namespace", vw.v.Namespace) + vw.wLock.Lock() + defer vw.wLock.Unlock() + vw.running = true + ctx, exitFn := context.WithCancel(vw.shutdownCtx) + vw.ctx = ctx + vw.exitFn = exitFn + go vw.watch() +} + +// Stop stops watching the volume. This should be called whenever a +// volume's claims are fully reaped or the watcher is no longer needed. +func (vw *volumeWatcher) Stop() { + vw.logger.Trace("no more claims", "id", vw.v.ID, "namespace", vw.v.Namespace) + vw.exitFn() +} + +func (vw *volumeWatcher) isRunning() bool { + vw.wLock.RLock() + defer vw.wLock.RUnlock() + select { + case <-vw.shutdownCtx.Done(): + return false + case <-vw.ctx.Done(): + return false + default: + return vw.running + } +} + +// watch is the long-running function that watches for changes to a volume. +// Each pass steps the volume's claims through the various states of reaping +// until the volume has no more claims eligible to be reaped. 
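+//
+// The per-claim progression driven from here is, roughly:
+//
+//	taken -> NodeDetached -> (controller detach) -> ReadyToFree -> released
+//
+// with a claim checkpoint written through raft after each step, so that a
+// new leader can resume from wherever the previous pass left off.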
+func (vw *volumeWatcher) watch() {
+	for {
+		select {
+		// TODO(tgross): currently server->client RPCs have no cancellation
+		// context, so we can't stop the long-running RPCs gracefully
+		case <-vw.shutdownCtx.Done():
+			return
+		case <-vw.ctx.Done():
+			return
+		case vol := <-vw.updateCh:
+			// while we won't make raft writes if we get a stale update,
+			// we can still fire extra CSI RPC calls if we don't check this
+			if vol == nil || vw.v == nil || vol.ModifyIndex >= vw.v.ModifyIndex {
+				vol = vw.getVolume(vol)
+				if vol == nil {
+					return
+				}
+				vw.volumeReap(vol)
+			}
+		}
+	}
+}
+
+// getVolume returns the tracked volume, fully populated with the current
+// state
+func (vw *volumeWatcher) getVolume(vol *structs.CSIVolume) *structs.CSIVolume {
+	// note: we take the write lock because we update vw.v below
+	vw.wLock.Lock()
+	defer vw.wLock.Unlock()
+
+	var err error
+	ws := memdb.NewWatchSet()
+
+	vol, err = vw.state.CSIVolumeDenormalizePlugins(ws, vol.Copy())
+	if err != nil {
+		vw.logger.Error("could not query plugins for volume", "error", err)
+		return nil
+	}
+
+	vol, err = vw.state.CSIVolumeDenormalize(ws, vol)
+	if err != nil {
+		vw.logger.Error("could not query allocs for volume", "error", err)
+		return nil
+	}
+	vw.v = vol
+	return vol
+}
+
+// volumeReap collects errors for logging but doesn't return them
+// to the main loop.
+func (vw *volumeWatcher) volumeReap(vol *structs.CSIVolume) {
+	vw.logger.Trace("releasing unused volume claims", "id", vol.ID, "namespace", vol.Namespace)
+	err := vw.volumeReapImpl(vol)
+	if err != nil {
+		vw.logger.Error("error releasing volume claims", "error", err)
+	}
+	if vw.isUnclaimed(vol) {
+		vw.Stop()
+	}
+}
+
+func (vw *volumeWatcher) isUnclaimed(vol *structs.CSIVolume) bool {
+	return len(vol.ReadClaims) == 0 && len(vol.WriteClaims) == 0 && len(vol.PastClaims) == 0
+}
+
+func (vw *volumeWatcher) volumeReapImpl(vol *structs.CSIVolume) error {
+	var result *multierror.Error
+	nodeClaims := map[string]int{} // node IDs -> count
+	jobs := map[string]bool{}      // jobID -> stopped
+
+	// if a job is purged, the subsequent alloc updates can't
+	// trigger a GC job because there's no job for them to query.
+	// Job.Deregister will send a claim release on all claims
+	// but the allocs will not yet be terminated. Save the status
+	// for each job so that we don't requery it in this pass.
+	checkStopped := func(jobID string) bool {
+		namespace := vw.v.Namespace
+		isStopped, ok := jobs[jobID]
+		if !ok {
+			ws := memdb.NewWatchSet()
+			job, err := vw.state.JobByID(ws, namespace, jobID)
+			if err != nil {
+				isStopped = true
+			}
+			if job == nil || job.Stopped() {
+				isStopped = true
+			}
+			jobs[jobID] = isStopped
+		}
+		return isStopped
+	}
+
+	collect := func(allocs map[string]*structs.Allocation,
+		claims map[string]*structs.CSIVolumeClaim) {
+
+		for allocID, alloc := range allocs {
+
+			if alloc == nil {
+				_, exists := vol.PastClaims[allocID]
+				if !exists {
+					vol.PastClaims[allocID] = &structs.CSIVolumeClaim{
+						AllocationID: allocID,
+						State:        structs.CSIVolumeClaimStateReadyToFree,
+					}
+				}
+				continue
+			}
+
+			nodeClaims[alloc.NodeID]++
+
+			if alloc.Terminated() || checkStopped(alloc.JobID) {
+				// don't overwrite the PastClaim if we've seen it before,
+				// so that we can track state between subsequent calls
+				_, exists := vol.PastClaims[allocID]
+				if !exists {
+					claim, ok := claims[allocID]
+					if !ok {
+						claim = &structs.CSIVolumeClaim{
+							AllocationID: allocID,
+							NodeID:       alloc.NodeID,
+						}
+					}
+					claim.State = structs.CSIVolumeClaimStateTaken
+					vol.PastClaims[allocID] = claim
+				}
+			}
+		}
+	}
+
+	collect(vol.ReadAllocs, vol.ReadClaims)
+	collect(vol.WriteAllocs, vol.WriteClaims)
+
+	if len(vol.PastClaims) == 0 {
+		return nil
+	}
+
+	for _, claim := range vol.PastClaims {
+
+		var err error
+
+		// previous checkpoints may have set the past claim state already.
+		// in practice we should never see CSIVolumeClaimStateControllerDetached
+		// but having an option for the state makes it easy to add a checkpoint
+		// in a backwards compatible way if we need one later
+		switch claim.State {
+		case structs.CSIVolumeClaimStateNodeDetached:
+			goto NODE_DETACHED
+		case structs.CSIVolumeClaimStateControllerDetached:
+			goto RELEASE_CLAIM
+		case structs.CSIVolumeClaimStateReadyToFree:
+			goto RELEASE_CLAIM
+		}
+
+		err = vw.nodeDetach(vol, claim)
+		if err != nil {
+			result = multierror.Append(result, err)
+			break
+		}
+
+	NODE_DETACHED:
+		nodeClaims[claim.NodeID]--
+		err = vw.controllerDetach(vol, claim, nodeClaims)
+		if err != nil {
+			result = multierror.Append(result, err)
+			break
+		}
+
+	RELEASE_CLAIM:
+		err = vw.checkpoint(vol, claim)
+		if err != nil {
+			result = multierror.Append(result, err)
+			break
+		}
+		// the checkpoint deletes from the state store, but this operates
+		// on our local copy which aids in testing
+		delete(vol.PastClaims, claim.AllocationID)
+	}
+
+	return result.ErrorOrNil()
+}
+
+// nodeDetach makes the client NodePublish / NodeUnstage RPCs, which
+// must be completed before controller operations or releasing the claim.
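+// If the node RPC fails, the claim state is left unchanged so that the
+// next pass through volumeReapImpl retries from this same step; on
+// success the claim is checkpointed as CSIVolumeClaimStateNodeDetached
+// before returning.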
+func (vw *volumeWatcher) nodeDetach(vol *structs.CSIVolume, claim *structs.CSIVolumeClaim) error { + vw.logger.Trace("detaching node", "id", vol.ID, "namespace", vol.Namespace) + nReq := &cstructs.ClientCSINodeDetachVolumeRequest{ + PluginID: vol.PluginID, + VolumeID: vol.ID, + ExternalID: vol.RemoteID(), + AllocID: claim.AllocationID, + NodeID: claim.NodeID, + AttachmentMode: vol.AttachmentMode, + AccessMode: vol.AccessMode, + ReadOnly: claim.Mode == structs.CSIVolumeClaimRead, + } + + err := vw.rpc.NodeDetachVolume(nReq, + &cstructs.ClientCSINodeDetachVolumeResponse{}) + if err != nil { + return fmt.Errorf("could not detach from node: %v", err) + } + claim.State = structs.CSIVolumeClaimStateNodeDetached + return vw.checkpoint(vol, claim) +} + +// controllerDetach makes the client RPC to the controller to +// unpublish the volume if a controller is required and no other +// allocs on the node need it +func (vw *volumeWatcher) controllerDetach(vol *structs.CSIVolume, claim *structs.CSIVolumeClaim, nodeClaims map[string]int) error { + if !vol.ControllerRequired || nodeClaims[claim.NodeID] > 1 { + claim.State = structs.CSIVolumeClaimStateReadyToFree + return nil + } + vw.logger.Trace("detaching controller", "id", vol.ID, "namespace", vol.Namespace) + // note: we need to get the CSI Node ID, which is not the same as + // the Nomad Node ID + ws := memdb.NewWatchSet() + targetNode, err := vw.state.NodeByID(ws, claim.NodeID) + if err != nil { + return err + } + if targetNode == nil { + return fmt.Errorf("%s: %s", structs.ErrUnknownNodePrefix, claim.NodeID) + } + targetCSIInfo, ok := targetNode.CSINodePlugins[vol.PluginID] + if !ok { + return fmt.Errorf("failed to find NodeInfo for node: %s", targetNode.ID) + } + + plug, err := vw.state.CSIPluginByID(ws, vol.PluginID) + if err != nil { + return fmt.Errorf("plugin lookup error: %s %v", vol.PluginID, err) + } + if plug == nil { + return fmt.Errorf("plugin lookup error: %s missing plugin", vol.PluginID) + } + + cReq := &cstructs.ClientCSIControllerDetachVolumeRequest{ + VolumeID: vol.RemoteID(), + ClientCSINodeID: targetCSIInfo.NodeInfo.ID, + } + cReq.PluginID = plug.ID + err = vw.rpc.ControllerDetachVolume(cReq, + &cstructs.ClientCSIControllerDetachVolumeResponse{}) + if err != nil { + return fmt.Errorf("could not detach from controller: %v", err) + } + claim.State = structs.CSIVolumeClaimStateReadyToFree + return nil +} + +func (vw *volumeWatcher) checkpoint(vol *structs.CSIVolume, claim *structs.CSIVolumeClaim) error { + vw.logger.Trace("checkpointing claim", "id", vol.ID, "namespace", vol.Namespace) + req := structs.CSIVolumeClaimRequest{ + VolumeID: vol.ID, + AllocationID: claim.AllocationID, + NodeID: claim.NodeID, + Claim: structs.CSIVolumeClaimRelease, + State: claim.State, + WriteRequest: structs.WriteRequest{ + Namespace: vol.Namespace, + // Region: vol.Region, // TODO(tgross) should volumes have regions? 
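+			// note: Region is left unset here; a plausible reading (not
+			// confirmed by this diff) is that it is unnecessary because the
+			// batched claim is committed via a local raft apply on the
+			// leader rather than through a forwarded RPC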
+ }, + } + index, err := vw.updateClaims([]structs.CSIVolumeClaimRequest{req}) + if err == nil && index != 0 { + vw.wLock.Lock() + defer vw.wLock.Unlock() + vw.v.ModifyIndex = index + } + if err != nil { + return fmt.Errorf("could not checkpoint claim release: %v", err) + } + return nil +} diff --git a/nomad/volumewatcher/volume_watcher_test.go b/nomad/volumewatcher/volume_watcher_test.go new file mode 100644 index 00000000000..a2b5ab03350 --- /dev/null +++ b/nomad/volumewatcher/volume_watcher_test.go @@ -0,0 +1,294 @@ +package volumewatcher + +import ( + "context" + "fmt" + "testing" + + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/require" +) + +// TestVolumeWatch_OneReap tests one pass through the reaper +func TestVolumeWatch_OneReap(t *testing.T) { + t.Parallel() + require := require.New(t) + + cases := []struct { + Name string + Volume *structs.CSIVolume + Node *structs.Node + ControllerRequired bool + ExpectedErr string + ExpectedClaimsCount int + ExpectedNodeDetachCount int + ExpectedControllerDetachCount int + ExpectedUpdateClaimsCount int + srv *MockRPCServer + }{ + { + Name: "No terminal allocs", + Volume: mock.CSIVolume(mock.CSIPlugin()), + ControllerRequired: true, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSINodeDetachError: fmt.Errorf("should never see this"), + }, + }, + { + Name: "NodeDetachVolume fails", + ControllerRequired: true, + ExpectedErr: "some node plugin error", + ExpectedNodeDetachCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSINodeDetachError: fmt.Errorf("some node plugin error"), + }, + }, + { + Name: "NodeDetachVolume node-only happy path", + ControllerRequired: false, + ExpectedNodeDetachCount: 1, + ExpectedUpdateClaimsCount: 2, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + { + Name: "ControllerDetachVolume no controllers available", + Node: mock.Node(), + ControllerRequired: true, + ExpectedErr: "Unknown node", + ExpectedNodeDetachCount: 1, + ExpectedUpdateClaimsCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + { + Name: "ControllerDetachVolume controller error", + ControllerRequired: true, + ExpectedErr: "some controller error", + ExpectedNodeDetachCount: 1, + ExpectedControllerDetachCount: 1, + ExpectedUpdateClaimsCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSIControllerDetachError: fmt.Errorf("some controller error"), + }, + }, + { + Name: "ControllerDetachVolume happy path", + ControllerRequired: true, + ExpectedNodeDetachCount: 1, + ExpectedControllerDetachCount: 1, + ExpectedUpdateClaimsCount: 2, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + + plugin := mock.CSIPlugin() + plugin.ControllerRequired = tc.ControllerRequired + node := testNode(tc.Node, plugin, tc.srv.State()) + alloc := mock.Alloc() + alloc.NodeID = node.ID + alloc.ClientStatus = structs.AllocClientStatusComplete + vol := testVolume(tc.Volume, plugin, alloc, node.ID) + ctx, exitFn := context.WithCancel(context.Background()) + w := &volumeWatcher{ + v: vol, + rpc: tc.srv, + state: tc.srv.State(), + updateClaims: tc.srv.UpdateClaims, + ctx: ctx, + exitFn: exitFn, + logger: testlog.HCLogger(t), + } + + err := w.volumeReapImpl(vol) + if tc.ExpectedErr != "" { + 
require.Error(err, fmt.Sprintf("expected: %q", tc.ExpectedErr)) + require.Contains(err.Error(), tc.ExpectedErr) + } else { + require.NoError(err) + } + require.Equal(tc.ExpectedNodeDetachCount, + tc.srv.countCSINodeDetachVolume, "node detach RPC count") + require.Equal(tc.ExpectedControllerDetachCount, + tc.srv.countCSIControllerDetachVolume, "controller detach RPC count") + require.Equal(tc.ExpectedUpdateClaimsCount, + tc.srv.countUpdateClaims, "update claims count") + }) + } +} + +// TestVolumeWatch_OldVolume_OneReap tests one pass through the reaper +// COMPAT(1.0): the claim fields were added after 0.11.1; this test +// can be removed for 1.0 +func TestVolumeWatch_OldVolume_OneReap(t *testing.T) { + t.Parallel() + require := require.New(t) + + cases := []struct { + Name string + Volume *structs.CSIVolume + Node *structs.Node + ControllerRequired bool + ExpectedErr string + ExpectedClaimsCount int + ExpectedNodeDetachCount int + ExpectedControllerDetachCount int + ExpectedUpdateClaimsCount int + srv *MockRPCServer + }{ + { + Name: "No terminal allocs", + Volume: mock.CSIVolume(mock.CSIPlugin()), + ControllerRequired: true, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSINodeDetachError: fmt.Errorf("should never see this"), + }, + }, + { + Name: "NodeDetachVolume fails", + ControllerRequired: true, + ExpectedErr: "some node plugin error", + ExpectedNodeDetachCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSINodeDetachError: fmt.Errorf("some node plugin error"), + }, + }, + { + Name: "NodeDetachVolume node-only happy path", + ControllerRequired: false, + ExpectedNodeDetachCount: 1, + ExpectedUpdateClaimsCount: 2, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + { + Name: "ControllerDetachVolume no controllers available", + Node: mock.Node(), + ControllerRequired: true, + ExpectedErr: "Unknown node", + ExpectedNodeDetachCount: 1, + ExpectedUpdateClaimsCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + { + Name: "ControllerDetachVolume controller error", + ControllerRequired: true, + ExpectedErr: "some controller error", + ExpectedNodeDetachCount: 1, + ExpectedControllerDetachCount: 1, + ExpectedUpdateClaimsCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSIControllerDetachError: fmt.Errorf("some controller error"), + }, + }, + { + Name: "ControllerDetachVolume happy path", + ControllerRequired: true, + ExpectedNodeDetachCount: 1, + ExpectedControllerDetachCount: 1, + ExpectedUpdateClaimsCount: 2, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + + plugin := mock.CSIPlugin() + plugin.ControllerRequired = tc.ControllerRequired + node := testNode(tc.Node, plugin, tc.srv.State()) + alloc := mock.Alloc() + alloc.ClientStatus = structs.AllocClientStatusComplete + alloc.NodeID = node.ID + vol := testOldVolume(tc.Volume, plugin, alloc, node.ID) + ctx, exitFn := context.WithCancel(context.Background()) + w := &volumeWatcher{ + v: vol, + rpc: tc.srv, + state: tc.srv.State(), + updateClaims: tc.srv.UpdateClaims, + ctx: ctx, + exitFn: exitFn, + logger: testlog.HCLogger(t), + } + + err := w.volumeReapImpl(vol) + if tc.ExpectedErr != "" { + require.Error(err, fmt.Sprintf("expected: %q", tc.ExpectedErr)) + require.Contains(err.Error(), tc.ExpectedErr) + } else { + require.NoError(err) + } + require.Equal(tc.ExpectedNodeDetachCount, + tc.srv.countCSINodeDetachVolume, "node detach RPC count") + 
require.Equal(tc.ExpectedControllerDetachCount, + tc.srv.countCSIControllerDetachVolume, "controller detach RPC count") + require.Equal(tc.ExpectedUpdateClaimsCount, + tc.srv.countUpdateClaims, "update claims count") + }) + } +} + +// TestVolumeWatch_OneReap tests multiple passes through the reaper, +// updating state after each one +func TestVolumeWatch_ReapStates(t *testing.T) { + t.Parallel() + require := require.New(t) + + srv := &MockRPCServer{state: state.TestStateStore(t)} + plugin := mock.CSIPlugin() + node := testNode(nil, plugin, srv.State()) + alloc := mock.Alloc() + alloc.ClientStatus = structs.AllocClientStatusComplete + vol := testVolume(nil, plugin, alloc, node.ID) + + w := &volumeWatcher{ + v: vol, + rpc: srv, + state: srv.State(), + updateClaims: srv.UpdateClaims, + logger: testlog.HCLogger(t), + } + + srv.nextCSINodeDetachError = fmt.Errorf("some node plugin error") + err := w.volumeReapImpl(vol) + require.Error(err) + require.Equal(structs.CSIVolumeClaimStateTaken, vol.PastClaims[alloc.ID].State) + require.Equal(1, srv.countCSINodeDetachVolume) + require.Equal(0, srv.countCSIControllerDetachVolume) + require.Equal(0, srv.countUpdateClaims) + + srv.nextCSINodeDetachError = nil + srv.nextCSIControllerDetachError = fmt.Errorf("some controller plugin error") + err = w.volumeReapImpl(vol) + require.Error(err) + require.Equal(structs.CSIVolumeClaimStateNodeDetached, vol.PastClaims[alloc.ID].State) + require.Equal(1, srv.countUpdateClaims) + + srv.nextCSIControllerDetachError = nil + err = w.volumeReapImpl(vol) + require.NoError(err) + require.Equal(0, len(vol.PastClaims)) + require.Equal(2, srv.countUpdateClaims) +} diff --git a/nomad/volumewatcher/volumes_watcher.go b/nomad/volumewatcher/volumes_watcher.go new file mode 100644 index 00000000000..63446c461ac --- /dev/null +++ b/nomad/volumewatcher/volumes_watcher.go @@ -0,0 +1,232 @@ +package volumewatcher + +import ( + "context" + "sync" + "time" + + log "github.com/hashicorp/go-hclog" + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" + "golang.org/x/time/rate" +) + +const ( + // LimitStateQueriesPerSecond is the number of state queries allowed per + // second + LimitStateQueriesPerSecond = 100.0 + + // CrossVolumeUpdateBatchDuration is the duration in which volume + // claim updates are batched across all volume watchers before + // being committed to Raft. + CrossVolumeUpdateBatchDuration = 250 * time.Millisecond +) + +// Watcher is used to watch volumes and their allocations created +// by the scheduler and trigger the scheduler when allocation health +// transitions. +type Watcher struct { + enabled bool + logger log.Logger + + // queryLimiter is used to limit the rate of blocking queries + queryLimiter *rate.Limiter + + // updateBatchDuration is the duration in which volume + // claim updates are batched across all volume watchers + // before being committed to Raft. + updateBatchDuration time.Duration + + // raft contains the set of Raft endpoints that can be used by the + // volumes watcher + raft VolumeRaftEndpoints + + // rpc contains the set of Server methods that can be used by + // the volumes watcher for RPC + rpc ClientRPC + + // state is the state that is watched for state changes. 
diff --git a/nomad/volumewatcher/volumes_watcher.go b/nomad/volumewatcher/volumes_watcher.go
new file mode 100644
index 00000000000..63446c461ac
--- /dev/null
+++ b/nomad/volumewatcher/volumes_watcher.go
@@ -0,0 +1,232 @@
+package volumewatcher
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	log "github.com/hashicorp/go-hclog"
+	memdb "github.com/hashicorp/go-memdb"
+	"github.com/hashicorp/nomad/nomad/state"
+	"github.com/hashicorp/nomad/nomad/structs"
+	"golang.org/x/time/rate"
+)
+
+const (
+	// LimitStateQueriesPerSecond is the number of state queries allowed per
+	// second
+	LimitStateQueriesPerSecond = 100.0
+
+	// CrossVolumeUpdateBatchDuration is the duration in which volume
+	// claim updates are batched across all volume watchers before
+	// being committed to Raft.
+	CrossVolumeUpdateBatchDuration = 250 * time.Millisecond
+)
+
+// Watcher is used to watch volumes and their claims created by the
+// scheduler, and to release the claims of terminal allocations.
+type Watcher struct {
+	enabled bool
+	logger  log.Logger
+
+	// queryLimiter is used to limit the rate of blocking queries
+	queryLimiter *rate.Limiter
+
+	// updateBatchDuration is the duration in which volume
+	// claim updates are batched across all volume watchers
+	// before being committed to Raft.
+	updateBatchDuration time.Duration
+
+	// raft contains the set of Raft endpoints that can be used by the
+	// volumes watcher
+	raft VolumeRaftEndpoints
+
+	// rpc contains the set of Server methods that can be used by
+	// the volumes watcher for RPC
+	rpc ClientRPC
+
+	// state is the state store to watch for volume changes
+	state *state.StateStore
+
+	// watchers is the set of active watchers, one per volume
+	watchers map[string]*volumeWatcher
+
+	// volumeUpdateBatcher is used to batch volume claim updates
+	volumeUpdateBatcher *VolumeUpdateBatcher
+
+	// ctx and exitFn are used to cancel the watcher
+	ctx    context.Context
+	exitFn context.CancelFunc
+
+	wlock sync.RWMutex
+}
+
+// NewVolumesWatcher returns a volumes watcher that is used to watch
+// volumes and release their claims as needed.
+func NewVolumesWatcher(logger log.Logger,
+	raft VolumeRaftEndpoints, rpc ClientRPC, stateQueriesPerSecond float64,
+	updateBatchDuration time.Duration) *Watcher {
+
+	// the leader step-down calls SetEnabled(false) which is what
+	// cancels this context, rather than passing in its own shutdown
+	// context
+	ctx, exitFn := context.WithCancel(context.Background())
+
+	return &Watcher{
+		raft:                raft,
+		rpc:                 rpc,
+		queryLimiter:        rate.NewLimiter(rate.Limit(stateQueriesPerSecond), 100),
+		updateBatchDuration: updateBatchDuration,
+		logger:              logger.Named("volumes_watcher"),
+		ctx:                 ctx,
+		exitFn:              exitFn,
+	}
+}
+
+// SetEnabled is used to control if the watcher is enabled. The
+// watcher should only be enabled on the active leader. When being
+// enabled the state is passed in as it is no longer valid once a
+// leader election has taken place.
+func (w *Watcher) SetEnabled(enabled bool, state *state.StateStore) {
+	w.wlock.Lock()
+	defer w.wlock.Unlock()
+
+	wasEnabled := w.enabled
+	w.enabled = enabled
+
+	if state != nil {
+		w.state = state
+	}
+
+	// Flush the state to create the necessary objects
+	w.flush()
+
+	// If we are starting now, launch the watch daemon
+	if enabled && !wasEnabled {
+		go w.watchVolumes(w.ctx)
+	}
+}
+
+// flush is used to clear the state of the watcher
+func (w *Watcher) flush() {
+	// Stop all the watchers and clear it
+	for _, watcher := range w.watchers {
+		watcher.Stop()
+	}
+
+	// Kill everything associated with the watcher
+	if w.exitFn != nil {
+		w.exitFn()
+	}
+
+	w.watchers = make(map[string]*volumeWatcher, 32)
+	w.ctx, w.exitFn = context.WithCancel(context.Background())
+	w.volumeUpdateBatcher = NewVolumeUpdateBatcher(w.updateBatchDuration, w.raft, w.ctx)
+}
+
+// watchVolumes is the long-lived goroutine that watches for volumes to
+// add or remove watchers on.
+func (w *Watcher) watchVolumes(ctx context.Context) {
+	vIndex := uint64(1)
+	for {
+		volumes, idx, err := w.getVolumes(ctx, vIndex)
+		if err != nil {
+			if err == context.Canceled {
+				return
+			}
+			w.logger.Error("failed to retrieve volumes", "error", err)
+			// retry from the last-seen index rather than resetting it
+			// to zero, which would make the blocking query return
+			// immediately in a tight loop
+			continue
+		}
+
+		vIndex = idx // last-seen index
+		for _, v := range volumes {
+			if err := w.add(v); err != nil {
+				w.logger.Error("failed to track volume", "volume_id", v.ID, "error", err)
+			}
+		}
+	}
+}
+
+// getVolumes retrieves all volumes blocking at the given index.
+func (w *Watcher) getVolumes(ctx context.Context, minIndex uint64) ([]*structs.CSIVolume, uint64, error) {
+	resp, index, err := w.state.BlockingQuery(w.getVolumesImpl, minIndex, ctx)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	return resp.([]*structs.CSIVolume), index, nil
+}
+// getVolumesImpl retrieves all volumes from the passed state store.
+func (w *Watcher) getVolumesImpl(ws memdb.WatchSet, state *state.StateStore) (interface{}, uint64, error) {
+
+	iter, err := state.CSIVolumes(ws)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	var volumes []*structs.CSIVolume
+	for {
+		raw := iter.Next()
+		if raw == nil {
+			break
+		}
+		volume := raw.(*structs.CSIVolume)
+		volumes = append(volumes, volume)
+	}
+
+	// Use the last index that affected the volume table
+	index, err := state.Index("csi_volumes")
+	if err != nil {
+		return nil, 0, err
+	}
+
+	return volumes, index, nil
+}
+
+// add adds a volume to the watch list
+func (w *Watcher) add(v *structs.CSIVolume) error {
+	w.wlock.Lock()
+	defer w.wlock.Unlock()
+	_, err := w.addLocked(v)
+	return err
+}
+
+// addLocked adds a volume to the watch list and should only be called when
+// locked. Creating the volumeWatcher starts a goroutine to .watch() it
+func (w *Watcher) addLocked(v *structs.CSIVolume) (*volumeWatcher, error) {
+	// Not enabled so no-op
+	if !w.enabled {
+		return nil, nil
+	}
+
+	// Already watched so trigger an update for the volume
+	if watcher, ok := w.watchers[v.ID+v.Namespace]; ok {
+		watcher.Notify(v)
+		return nil, nil
+	}
+
+	watcher := newVolumeWatcher(w, v)
+	w.watchers[v.ID+v.Namespace] = watcher
+	return watcher, nil
+}
+
+// TODO: this is currently dead code; we'll call a public remove
+// method on the Watcher once we have a periodic GC job.
+// removeLocked stops watching a volume and should only be called when locked.
+func (w *Watcher) removeLocked(volID, namespace string) {
+	if !w.enabled {
+		return
+	}
+	if watcher, ok := w.watchers[volID+namespace]; ok {
+		watcher.Stop()
+		delete(w.watchers, volID+namespace)
+	}
+}
+
+// updateClaims sends the claims to the batch updater and waits for
+// the results
+func (w *Watcher) updateClaims(claims []structs.CSIVolumeClaimRequest) (uint64, error) {
+	return w.volumeUpdateBatcher.CreateUpdate(claims).Results()
+}
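SetEnabled above is the hook the leader lifecycle drives. A minimal sketch of that wiring, assuming illustrative helper names (enableVolumeWatcher and disableVolumeWatcher are not part of this diff):

package nomad

import (
	"github.com/hashicorp/nomad/nomad/state"
	"github.com/hashicorp/nomad/nomad/volumewatcher"
)

// On step-up the leader hands the watcher a fresh state store, since
// the previous one is no longer valid after an election.
func enableVolumeWatcher(w *volumewatcher.Watcher, s *state.StateStore) {
	w.SetEnabled(true, s)
}

// On step-down, disabling stops every per-volume watcher and cancels
// the watch context so no RPCs fire from a non-leader.
func disableVolumeWatcher(w *volumewatcher.Watcher) {
	// a nil state store keeps the previous reference; flush still runs
	w.SetEnabled(false, nil)
}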
diff --git a/nomad/volumewatcher/volumes_watcher_test.go b/nomad/volumewatcher/volumes_watcher_test.go
new file mode 100644
index 00000000000..b7ae7aea2c5
--- /dev/null
+++ b/nomad/volumewatcher/volumes_watcher_test.go
@@ -0,0 +1,310 @@
+package volumewatcher
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	memdb "github.com/hashicorp/go-memdb"
+	"github.com/hashicorp/nomad/helper/testlog"
+	"github.com/hashicorp/nomad/nomad/mock"
+	"github.com/hashicorp/nomad/nomad/state"
+	"github.com/hashicorp/nomad/nomad/structs"
+	"github.com/stretchr/testify/require"
+)
+
+// TestVolumeWatch_EnableDisable tests the watcher registration logic that needs
+// to happen during leader step-up/step-down
+func TestVolumeWatch_EnableDisable(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	srv := &MockRPCServer{}
+	srv.state = state.TestStateStore(t)
+	index := uint64(100)
+
+	watcher := NewVolumesWatcher(testlog.HCLogger(t),
+		srv, srv,
+		LimitStateQueriesPerSecond,
+		CrossVolumeUpdateBatchDuration)
+
+	watcher.SetEnabled(true, srv.State())
+
+	plugin := mock.CSIPlugin()
+	node := testNode(nil, plugin, srv.State())
+	alloc := mock.Alloc()
+	alloc.ClientStatus = structs.AllocClientStatusComplete
+	vol := testVolume(nil, plugin, alloc, node.ID)
+
+	index++
+	err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
+	require.NoError(err)
+
+	claim := &structs.CSIVolumeClaim{Mode: structs.CSIVolumeClaimRelease}
+	index++
+	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
+	require.NoError(err)
+	require.Eventually(func() bool {
+		return 1 == len(watcher.watchers)
+	}, time.Second, 10*time.Millisecond)
+
+	watcher.SetEnabled(false, srv.State())
+	require.Equal(0, len(watcher.watchers))
+}
+
+// TestVolumeWatch_Checkpoint tests the checkpointing of progress across
+// leader step-up/step-down
+func TestVolumeWatch_Checkpoint(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	srv := &MockRPCServer{}
+	srv.state = state.TestStateStore(t)
+	index := uint64(100)
+
+	watcher := NewVolumesWatcher(testlog.HCLogger(t),
+		srv, srv,
+		LimitStateQueriesPerSecond,
+		CrossVolumeUpdateBatchDuration)
+
+	plugin := mock.CSIPlugin()
+	node := testNode(nil, plugin, srv.State())
+	alloc := mock.Alloc()
+	alloc.ClientStatus = structs.AllocClientStatusComplete
+	vol := testVolume(nil, plugin, alloc, node.ID)
+
+	watcher.SetEnabled(true, srv.State())
+
+	index++
+	err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
+	require.NoError(err)
+
+	// we should get or start up a watcher when we get an update for
+	// the volume from the state store
+	require.Eventually(func() bool {
+		return 1 == len(watcher.watchers)
+	}, time.Second, 10*time.Millisecond)
+
+	// step-down (this is sync, but step-up is async)
+	watcher.SetEnabled(false, srv.State())
+	require.Equal(0, len(watcher.watchers))
+
+	// step-up again
+	watcher.SetEnabled(true, srv.State())
+	require.Eventually(func() bool {
+		return 1 == len(watcher.watchers)
+	}, time.Second, 10*time.Millisecond)
+
+	require.True(watcher.watchers[vol.ID+vol.Namespace].isRunning())
+}
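The lifecycle test below drives claim updates through the batching path exposed by updateClaims in the watcher. For orientation, a caller-side sketch that would live alongside the code above in the volumewatcher package (releaseClaim is a hypothetical wrapper; CreateUpdate(...).Results() is the batcher call used in this diff):

// Callers block until the batched write commits to Raft, and get back
// the index it was applied at.
func releaseClaim(w *Watcher, req structs.CSIVolumeClaimRequest) (uint64, error) {
	return w.volumeUpdateBatcher.CreateUpdate(
		[]structs.CSIVolumeClaimRequest{req}).Results()
}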
+
+// TestVolumeWatch_StartStop tests the start and stop of the watcher when
+// it receives notifications and has completed its work
+func TestVolumeWatch_StartStop(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	ctx, exitFn := context.WithCancel(context.Background())
+	defer exitFn()
+
+	srv := &MockStatefulRPCServer{}
+	srv.state = state.TestStateStore(t)
+	index := uint64(100)
+	srv.volumeUpdateBatcher = NewVolumeUpdateBatcher(
+		CrossVolumeUpdateBatchDuration, srv, ctx)
+
+	watcher := NewVolumesWatcher(testlog.HCLogger(t),
+		srv, srv,
+		LimitStateQueriesPerSecond,
+		CrossVolumeUpdateBatchDuration)
+
+	watcher.SetEnabled(true, srv.State())
+	require.Equal(0, len(watcher.watchers))
+
+	plugin := mock.CSIPlugin()
+	node := testNode(nil, plugin, srv.State())
+	alloc := mock.Alloc()
+	alloc.ClientStatus = structs.AllocClientStatusRunning
+	alloc2 := mock.Alloc()
+	alloc2.Job = alloc.Job
+	alloc2.ClientStatus = structs.AllocClientStatusRunning
+	index++
+	err := srv.State().UpsertJob(index, alloc.Job)
+	require.NoError(err)
+	index++
+	err = srv.State().UpsertAllocs(index, []*structs.Allocation{alloc, alloc2})
+	require.NoError(err)
+
+	// register a volume
+	vol := testVolume(nil, plugin, alloc, node.ID)
+	index++
+	err = srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
+	require.NoError(err)
+
+	// assert we get a running watcher
+	require.Eventually(func() bool {
+		return 1 == len(watcher.watchers)
+	}, time.Second, 10*time.Millisecond)
+	require.True(watcher.watchers[vol.ID+vol.Namespace].isRunning())
+
+	// claim the volume for both allocs
+	claim := &structs.CSIVolumeClaim{
+		AllocationID: alloc.ID,
+		NodeID:       node.ID,
+		Mode:         structs.CSIVolumeClaimRead,
+	}
+	index++
+	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
+	require.NoError(err)
+	claim.AllocationID = alloc2.ID
+	index++
+	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
+	require.NoError(err)
+
+	// release the claim and assert nothing has happened yet, since
+	// both allocs are still running
+	claim = &structs.CSIVolumeClaim{
+		AllocationID: alloc.ID,
+		NodeID:       node.ID,
+		Mode:         structs.CSIVolumeClaimRelease,
+	}
+	index++
+	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
+	require.NoError(err)
+	require.True(watcher.watchers[vol.ID+vol.Namespace].isRunning())
+
+	// alloc becomes terminal
+	alloc.ClientStatus = structs.AllocClientStatusComplete
+	index++
+	err = srv.State().UpsertAllocs(index, []*structs.Allocation{alloc})
+	require.NoError(err)
+	index++
+	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
+	require.NoError(err)
+
+	// 1 claim has been released but the watcher is still running
+	require.Eventually(func() bool {
+		ws := memdb.NewWatchSet()
+		vol, _ := srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID)
+		return len(vol.ReadAllocs) == 1 && len(vol.PastClaims) == 0
+	}, time.Second*2, 10*time.Millisecond)
+
+	require.True(watcher.watchers[vol.ID+vol.Namespace].isRunning())
+
+	// the watcher will have incremented the index so we need to make sure
+	// our inserts will trigger new events
+	index, _ = srv.State().LatestIndex()
+
+	// remaining alloc's job is stopped (alloc is not marked terminal)
+	alloc2.Job.Stop = true
+	index++
+	err = srv.State().UpsertJob(index, alloc2.Job)
+	require.NoError(err)
+
+	// job deregistration writes a claim with no allocations or nodes
+	claim = &structs.CSIVolumeClaim{
+		Mode: structs.CSIVolumeClaimRelease,
+	}
+	index++
+	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
+	require.NoError(err)
+
+	// the past claims have been processed and the watcher has stopped;
+	// the claim of the remaining non-terminal alloc stays in place
+	require.Eventually(func() bool {
+		ws := memdb.NewWatchSet()
+		vol, _ := srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID)
+		return len(vol.ReadAllocs) == 1 && len(vol.PastClaims) == 0
+	}, time.Second*2, 10*time.Millisecond)
+
+	require.Eventually(func() bool {
+		return !watcher.watchers[vol.ID+vol.Namespace].isRunning()
+	}, time.Second*1, 10*time.Millisecond)
+
+	// the watcher will have incremented the index so we need to make sure
+	// our inserts will trigger new events
+	index, _ = srv.State().LatestIndex()
+
+	// create a new claim
+	alloc3 := mock.Alloc()
+	alloc3.ClientStatus = structs.AllocClientStatusRunning
+	index++
+	err = srv.State().UpsertAllocs(index, []*structs.Allocation{alloc3})
+	require.NoError(err)
+	claim3 := &structs.CSIVolumeClaim{
+		AllocationID: alloc3.ID,
+		NodeID:       node.ID,
+		Mode:         structs.CSIVolumeClaimRelease,
+	}
+	index++
+	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim3)
+	require.NoError(err)
+
+	// a stopped watcher should restore itself on notification
+	require.Eventually(func() bool {
+		return watcher.watchers[vol.ID+vol.Namespace].isRunning()
+	}, time.Second*1, 10*time.Millisecond)
+}
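These tests repeat the same polling assertion over len(watcher.watchers). A small helper of the kind they could share, living in the same test package (hypothetical, not in the diff), which also takes the read lock the inline assertions skip:

// waitForWatchers polls until the expected number of per-volume
// watchers has been created, guarding the map with the watcher's
// RWMutex.
func waitForWatchers(t *testing.T, w *Watcher, n int) {
	t.Helper()
	require.Eventually(t, func() bool {
		w.wlock.RLock()
		defer w.wlock.RUnlock()
		return len(w.watchers) == n
	}, time.Second, 10*time.Millisecond)
}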
+
+// TestVolumeWatch_RegisterDeregister tests the start and stop of
+// watchers around registration
+func TestVolumeWatch_RegisterDeregister(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	ctx, exitFn := context.WithCancel(context.Background())
+	defer exitFn()
+
+	srv := &MockStatefulRPCServer{}
+	srv.state = state.TestStateStore(t)
+	srv.volumeUpdateBatcher = NewVolumeUpdateBatcher(
+		CrossVolumeUpdateBatchDuration, srv, ctx)
+
+	index := uint64(100)
+
+	watcher := NewVolumesWatcher(testlog.HCLogger(t),
+		srv, srv,
+		LimitStateQueriesPerSecond,
+		CrossVolumeUpdateBatchDuration)
+
+	watcher.SetEnabled(true, srv.State())
+	require.Equal(0, len(watcher.watchers))
+
+	plugin := mock.CSIPlugin()
+	node := testNode(nil, plugin, srv.State())
+	alloc := mock.Alloc()
+	alloc.ClientStatus = structs.AllocClientStatusComplete
+
+	// register a volume
+	vol := testVolume(nil, plugin, alloc, node.ID)
+	index++
+	err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
+	require.NoError(err)
+
+	require.Eventually(func() bool {
+		return 1 == len(watcher.watchers)
+	}, time.Second, 10*time.Millisecond)
+
+	// reap the volume and assert we've cleaned up
+	w := watcher.watchers[vol.ID+vol.Namespace]
+	w.Notify(vol)
+
+	require.Eventually(func() bool {
+		ws := memdb.NewWatchSet()
+		vol, _ := srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID)
+		return len(vol.ReadAllocs) == 0 && len(vol.PastClaims) == 0
+	}, time.Second*2, 10*time.Millisecond)
+
+	require.Eventually(func() bool {
+		return !watcher.watchers[vol.ID+vol.Namespace].isRunning()
+	}, time.Second*1, 10*time.Millisecond)
+
+	require.Equal(1, srv.countCSINodeDetachVolume, "node detach RPC count")
+	require.Equal(1, srv.countCSIControllerDetachVolume, "controller detach RPC count")
+	require.Equal(2, srv.countUpsertVolumeClaims, "upsert claims count")
+
+	// deregistering the volume doesn't cause an update that triggers
+	// a watcher; we'll clean up this watcher in a GC later
+	err = srv.State().CSIVolumeDeregister(index, vol.Namespace, []string{vol.ID})
+	require.NoError(err)
+	require.Equal(1, len(watcher.watchers))
+	require.False(watcher.watchers[vol.ID+vol.Namespace].isRunning())
+}
diff --git a/nomad/volumewatcher_shim.go b/nomad/volumewatcher_shim.go
new file mode 100644
index 00000000000..5148d7f5ba9
--- /dev/null
+++ b/nomad/volumewatcher_shim.go
@@ -0,0 +1,31 @@
+package nomad
+
+import (
+	"github.com/hashicorp/nomad/nomad/structs"
+)
+
+// volumeWatcherRaftShim is the shim that provides the Raft apply
+// methods. It is created by the server and passed to the volume
+// watcher.
+type volumeWatcherRaftShim struct {
+	// apply is used to apply a message to Raft
+	apply raftApplyFn
+}
+
+// convertApplyErrors parses the results of a raftApply and returns the index
+// at which it was applied and any error that occurred. raftApply returns two
+// separate kinds of errors: Raft library errors and user-returned errors from
+// the FSM. This helper joins them by inspecting the applyResponse for an error.
+func (shim *volumeWatcherRaftShim) convertApplyErrors(applyResp interface{}, index uint64, err error) (uint64, error) {
+	if applyResp != nil {
+		if fsmErr, ok := applyResp.(error); ok && fsmErr != nil {
+			return index, fsmErr
+		}
+	}
+	return index, err
+}
+
+func (shim *volumeWatcherRaftShim) UpsertVolumeClaims(req *structs.CSIVolumeClaimBatchRequest) (uint64, error) {
+	fsmErrIntf, index, raftErr := shim.apply(structs.CSIVolumeClaimBatchRequestType, req)
+	return shim.convertApplyErrors(fsmErrIntf, index, raftErr)
+}
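A rough sketch of how the server side might wire the shim and watcher together; the Server field names and the raftApply method reference are assumptions based on the raftApplyFn type above, not code from this diff:

package nomad

import "github.com/hashicorp/nomad/nomad/volumewatcher"

// Illustrative wiring: the shim adapts the server's raftApply into the
// VolumeRaftEndpoints interface the batcher commits claim updates
// through, while the server itself satisfies the ClientRPC interface.
func (s *Server) setupVolumeWatcher() {
	shim := &volumeWatcherRaftShim{apply: s.raftApply}
	s.volumeWatcher = volumewatcher.NewVolumesWatcher(
		s.logger, shim, s,
		volumewatcher.LimitStateQueriesPerSecond,
		volumewatcher.CrossVolumeUpdateBatchDuration,
	)
}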