From c2aaa5179fee4f379186ef8e6202f420ff97b953 Mon Sep 17 00:00:00 2001 From: Tim Gross Date: Wed, 25 Nov 2020 11:15:57 -0500 Subject: [PATCH] CSI: fix transaction handling in state store (#9438) When making updates to CSI plugins, the state store methods that have open write transactions were querying the state store using the same methods used by the CSI RPC endpoint, but these methods create their own top-level read transactions. During concurrent plugin updates (as happens when a plugin job is stopped), this can cause write skew in the plugin counts. * Refactor the CSIPlugin query methods to have an implementation method that accepts a transaction, which can be called with either a read txn or a write txn. * Refactor the CSIVolume query methods to have an implementation method that accepts a transaction, which can be called with either a read txn or a write txn. * CSI volumes need to be "denormalized" with their plugins and (optionally) allocations. Read-only RPC endpoints should take a snapshot so that we can make multiple state store method calls with a consistent view. 
--- nomad/csi_endpoint.go | 54 ++++++----- nomad/state/state_store.go | 159 ++++++++++++++++++++------------ nomad/state/state_store_test.go | 78 ++++++++++++++++ 3 files changed, 211 insertions(+), 80 deletions(-) diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go index a88d24c0a84..dd610a6d2e4 100644 --- a/nomad/csi_endpoint.go +++ b/nomad/csi_endpoint.go @@ -116,16 +116,19 @@ func (v *CSIVolume) List(args *structs.CSIVolumeListRequest, reply *structs.CSIV queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, run: func(ws memdb.WatchSet, state *state.StateStore) error { + snap, err := state.Snapshot() + if err != nil { + return err + } // Query all volumes - var err error var iter memdb.ResultIterator if args.NodeID != "" { - iter, err = state.CSIVolumesByNodeID(ws, args.NodeID) + iter, err = snap.CSIVolumesByNodeID(ws, args.NodeID) } else if args.PluginID != "" { - iter, err = state.CSIVolumesByPluginID(ws, ns, args.PluginID) + iter, err = snap.CSIVolumesByPluginID(ws, ns, args.PluginID) } else { - iter, err = state.CSIVolumesByNamespace(ws, ns) + iter, err = snap.CSIVolumesByNamespace(ws, ns) } if err != nil { @@ -140,23 +143,25 @@ func (v *CSIVolume) List(args *structs.CSIVolumeListRequest, reply *structs.CSIV if raw == nil { break } - vol := raw.(*structs.CSIVolume) - vol, err := state.CSIVolumeDenormalizePlugins(ws, vol.Copy()) - if err != nil { - return err - } - // Remove (possibly again) by PluginID to handle passing both NodeID and PluginID + // Remove (possibly again) by PluginID to handle passing both + // NodeID and PluginID if args.PluginID != "" && args.PluginID != vol.PluginID { continue } - // Remove by Namespace, since CSIVolumesByNodeID hasn't used the Namespace yet + // Remove by Namespace, since CSIVolumesByNodeID hasn't used + // the Namespace yet if vol.Namespace != ns { continue } + vol, err := snap.CSIVolumeDenormalizePlugins(ws, vol.Copy()) + if err != nil { + return err + } + vs = append(vs, vol.Stub()) } reply.Volumes = vs 
@@ -195,12 +200,17 @@ func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVol queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, run: func(ws memdb.WatchSet, state *state.StateStore) error { - vol, err := state.CSIVolumeByID(ws, ns, args.ID) + snap, err := state.Snapshot() + if err != nil { + return err + } + + vol, err := snap.CSIVolumeByID(ws, ns, args.ID) if err != nil { return err } if vol != nil { - vol, err = state.CSIVolumeDenormalize(ws, vol) + vol, err = snap.CSIVolumeDenormalize(ws, vol) } if err != nil { return err @@ -214,9 +224,8 @@ func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVol func (v *CSIVolume) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume) (*structs.CSIPlugin, error) { state := v.srv.fsm.State() - ws := memdb.NewWatchSet() - plugin, err := state.CSIPluginByID(ws, vol.PluginID) + plugin, err := state.CSIPluginByID(nil, vol.PluginID) if err != nil { return nil, err } @@ -481,9 +490,7 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest, func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlugin, *structs.CSIVolume, error) { state := v.srv.fsm.State() - ws := memdb.NewWatchSet() - - vol, err := state.CSIVolumeByID(ws, namespace, volID) + vol, err := state.CSIVolumeByID(nil, namespace, volID) if err != nil { return nil, nil, err } @@ -497,7 +504,7 @@ func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlu // note: we do this same lookup in CSIVolumeByID but then throw // away the pointer to the plugin rather than attaching it to // the volume so we have to do it again here. 
- plug, err := state.CSIPluginByID(ws, vol.PluginID) + plug, err := state.CSIPluginByID(nil, vol.PluginID) if err != nil { return nil, nil, err } @@ -870,7 +877,12 @@ func (v *CSIPlugin) Get(args *structs.CSIPluginGetRequest, reply *structs.CSIPlu queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, run: func(ws memdb.WatchSet, state *state.StateStore) error { - plug, err := state.CSIPluginByID(ws, args.ID) + snap, err := state.Snapshot() + if err != nil { + return err + } + + plug, err := snap.CSIPluginByID(ws, args.ID) if err != nil { return err } @@ -880,7 +892,7 @@ func (v *CSIPlugin) Get(args *structs.CSIPluginGetRequest, reply *structs.CSIPlu } if withAllocs { - plug, err = state.CSIPluginDenormalize(ws, plug.Copy()) + plug, err = snap.CSIPluginDenormalize(ws, plug.Copy()) if err != nil { return err } diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index 977eec5bfbc..9e5059407ba 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -72,7 +72,7 @@ type StateStore struct { // abandoned (usually during a restore). This is only ever closed. abandonCh chan struct{} - // TODO: refactor abondonCh to use a context so that both can use the same + // TODO: refactor abandonCh to use a context so that both can use the same // cancel mechanism. stopEventBroker func() } @@ -1272,7 +1272,7 @@ func deleteNodeCSIPlugins(txn *txn, node *structs.Node, index uint64) error { } // updateOrGCPlugin updates a plugin but will delete it if the plugin is empty -func updateOrGCPlugin(index uint64, txn *txn, plug *structs.CSIPlugin) error { +func updateOrGCPlugin(index uint64, txn Txn, plug *structs.CSIPlugin) error { plug.ModifyIndex = index if plug.IsEmpty() { @@ -1291,7 +1291,7 @@ func updateOrGCPlugin(index uint64, txn *txn, plug *structs.CSIPlugin) error { // deleteJobFromPlugins removes the allocations of this job from any plugins the job is // running, possibly deleting the plugin if it's no longer in use. 
It's called in DeleteJobTxn -func (s *StateStore) deleteJobFromPlugins(index uint64, txn *txn, job *structs.Job) error { +func (s *StateStore) deleteJobFromPlugins(index uint64, txn Txn, job *structs.Job) error { ws := memdb.NewWatchSet() summary, err := s.JobSummaryByID(ws, job.Namespace, job.ID) if err != nil { @@ -1348,7 +1348,7 @@ func (s *StateStore) deleteJobFromPlugins(index uint64, txn *txn, job *structs.J plug, ok := plugins[x.pluginID] if !ok { - plug, err = s.CSIPluginByID(ws, x.pluginID) + plug, err = s.CSIPluginByIDTxn(txn, nil, x.pluginID) if err != nil { return fmt.Errorf("error getting plugin: %s, %v", x.pluginID, err) } @@ -1826,22 +1826,20 @@ func (s *StateStore) JobsByIDPrefix(ws memdb.WatchSet, namespace, id string) (me func (s *StateStore) JobVersionsByID(ws memdb.WatchSet, namespace, id string) ([]*structs.Job, error) { txn := s.db.ReadTxn() - return s.jobVersionByID(txn, &ws, namespace, id) + return s.jobVersionByID(txn, ws, namespace, id) } // jobVersionByID is the underlying implementation for retrieving all tracked // versions of a job and is called under an existing transaction. A watch set // can optionally be passed in to add the job histories to the watch set. 
-func (s *StateStore) jobVersionByID(txn *txn, ws *memdb.WatchSet, namespace, id string) ([]*structs.Job, error) { +func (s *StateStore) jobVersionByID(txn *txn, ws memdb.WatchSet, namespace, id string) ([]*structs.Job, error) { // Get all the historic jobs for this ID iter, err := txn.Get("job_version", "id_prefix", namespace, id) if err != nil { return nil, err } - if ws != nil { - ws.Add(iter.WatchCh()) - } + ws.Add(iter.WatchCh()) var all []*structs.Job for { @@ -1884,9 +1882,7 @@ func (s *StateStore) jobByIDAndVersionImpl(ws memdb.WatchSet, namespace, id stri return nil, err } - if ws != nil { - ws.Add(watchCh) - } + ws.Add(watchCh) if existing != nil { job := existing.(*structs.Job) @@ -2096,7 +2092,8 @@ func (s *StateStore) CSIVolumeRegister(index uint64, volumes []*structs.CSIVolum return txn.Commit() } -// CSIVolumes returns the unfiltered list of all volumes +// CSIVolumes returns the unfiltered list of all volumes. Caller should +// snapshot if it wants to also denormalize the plugins. func (s *StateStore) CSIVolumes(ws memdb.WatchSet) (memdb.ResultIterator, error) { txn := s.db.ReadTxn() defer txn.Abort() @@ -2111,8 +2108,9 @@ func (s *StateStore) CSIVolumes(ws memdb.WatchSet) (memdb.ResultIterator, error) return iter, nil } -// CSIVolumeByID is used to lookup a single volume. Returns a copy of the volume -// because its plugins are denormalized to provide accurate Health. +// CSIVolumeByID is used to lookup a single volume. Returns a copy of the +// volume because its plugins and allocations are denormalized to provide +// accurate Health. 
func (s *StateStore) CSIVolumeByID(ws memdb.WatchSet, namespace, id string) (*structs.CSIVolume, error) { txn := s.db.ReadTxn() @@ -2120,17 +2118,21 @@ func (s *StateStore) CSIVolumeByID(ws memdb.WatchSet, namespace, id string) (*st if err != nil { return nil, fmt.Errorf("volume lookup failed: %s %v", id, err) } + ws.Add(watchCh) if obj == nil { return nil, nil } + // we return the volume with the plugins denormalized by default, + // because the scheduler needs them for feasibility checking vol := obj.(*structs.CSIVolume) - return s.CSIVolumeDenormalizePlugins(ws, vol.Copy()) + return s.CSIVolumeDenormalizePluginsTxn(txn, vol.Copy()) } -// CSIVolumes looks up csi_volumes by pluginID +// CSIVolumes looks up csi_volumes by pluginID. Caller should snapshot if it +// wants to also denormalize the plugins. func (s *StateStore) CSIVolumesByPluginID(ws memdb.WatchSet, namespace, pluginID string) (memdb.ResultIterator, error) { txn := s.db.ReadTxn() @@ -2152,7 +2154,8 @@ func (s *StateStore) CSIVolumesByPluginID(ws memdb.WatchSet, namespace, pluginID return wrap, nil } -// CSIVolumesByIDPrefix supports search +// CSIVolumesByIDPrefix supports search. Caller should snapshot if it wants to +// also denormalize the plugins. func (s *StateStore) CSIVolumesByIDPrefix(ws memdb.WatchSet, namespace, volumeID string) (memdb.ResultIterator, error) { txn := s.db.ReadTxn() @@ -2162,10 +2165,12 @@ func (s *StateStore) CSIVolumesByIDPrefix(ws memdb.WatchSet, namespace, volumeID } ws.Add(iter.WatchCh()) + return iter, nil } -// CSIVolumesByNodeID looks up CSIVolumes in use on a node +// CSIVolumesByNodeID looks up CSIVolumes in use on a node. Caller should +// snapshot if it wants to also denormalize the plugins. 
func (s *StateStore) CSIVolumesByNodeID(ws memdb.WatchSet, nodeID string) (memdb.ResultIterator, error) { allocs, err := s.AllocsByNode(ws, nodeID) if err != nil { @@ -2202,6 +2207,8 @@ func (s *StateStore) CSIVolumesByNodeID(ws memdb.WatchSet, nodeID string) (memdb iter.Add(raw) } + ws.Add(iter.WatchCh()) + return iter, nil } @@ -2213,6 +2220,7 @@ func (s *StateStore) CSIVolumesByNamespace(ws memdb.WatchSet, namespace string) if err != nil { return nil, fmt.Errorf("volume lookup failed: %v", err) } + ws.Add(iter.WatchCh()) return iter, nil @@ -2222,7 +2230,6 @@ func (s *StateStore) CSIVolumesByNamespace(ws memdb.WatchSet, namespace string) func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, claim *structs.CSIVolumeClaim) error { txn := s.db.WriteTxn(index) defer txn.Abort() - ws := memdb.NewWatchSet() row, err := txn.First("csi_volumes", "id", namespace, id) if err != nil { @@ -2239,7 +2246,7 @@ func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, claim *s var alloc *structs.Allocation if claim.State == structs.CSIVolumeClaimStateTaken { - alloc, err = s.AllocByID(ws, claim.AllocationID) + alloc, err = s.allocByIDImpl(txn, nil, claim.AllocationID) if err != nil { s.logger.Error("AllocByID failed", "error", err) return fmt.Errorf(structs.ErrUnknownAllocationPrefix) @@ -2252,12 +2259,11 @@ func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, claim *s } } - volume, err := s.CSIVolumeDenormalizePlugins(ws, orig.Copy()) + volume, err := s.CSIVolumeDenormalizePluginsTxn(txn, orig.Copy()) if err != nil { return err } - - volume, err = s.CSIVolumeDenormalize(ws, volume) + volume, err = s.CSIVolumeDenormalizeTxn(txn, nil, volume) if err != nil { return err } @@ -2321,7 +2327,7 @@ func (s *StateStore) CSIVolumeDeregister(index uint64, namespace string, ids []s // allocations have been stopped but claims can't be freed because // ex. the plugins have all been removed. 
if vol.InUse() { - if !force || !s.volSafeToForce(vol) { + if !force || !s.volSafeToForce(txn, vol) { return fmt.Errorf("volume in use: %s", id) } } @@ -2340,9 +2346,8 @@ func (s *StateStore) CSIVolumeDeregister(index uint64, namespace string, ids []s // volSafeToForce checks if the any of the remaining allocations // are in a non-terminal state. -func (s *StateStore) volSafeToForce(v *structs.CSIVolume) bool { - ws := memdb.NewWatchSet() - vol, err := s.CSIVolumeDenormalize(ws, v) +func (s *StateStore) volSafeToForce(txn Txn, v *structs.CSIVolume) bool { + vol, err := s.CSIVolumeDenormalizeTxn(txn, nil, v) if err != nil { return false } @@ -2360,19 +2365,30 @@ func (s *StateStore) volSafeToForce(v *structs.CSIVolume) bool { return true } -// CSIVolumeDenormalizePlugins returns a CSIVolume with current health and plugins, but -// without allocations -// Use this for current volume metadata, handling lists of volumes -// Use CSIVolumeDenormalize for volumes containing both health and current allocations +// CSIVolumeDenormalizePlugins returns a CSIVolume with current health and +// plugins, but without allocations. +// Use this for current volume metadata, handling lists of volumes. +// Use CSIVolumeDenormalize for volumes containing both health and current +// allocations. func (s *StateStore) CSIVolumeDenormalizePlugins(ws memdb.WatchSet, vol *structs.CSIVolume) (*structs.CSIVolume, error) { if vol == nil { return nil, nil } - // Lookup CSIPlugin, the health records, and calculate volume health txn := s.db.ReadTxn() defer txn.Abort() + return s.CSIVolumeDenormalizePluginsTxn(txn, vol) +} - plug, err := s.CSIPluginByID(ws, vol.PluginID) +// CSIVolumeDenormalizePluginsTxn returns a CSIVolume with current health and +// plugins, but without allocations. +// Use this for current volume metadata, handling lists of volumes. +// Use CSIVolumeDenormalize for volumes containing both health and current +// allocations. 
+func (s *StateStore) CSIVolumeDenormalizePluginsTxn(txn Txn, vol *structs.CSIVolume) (*structs.CSIVolume, error) { + if vol == nil { + return nil, nil + } + plug, err := s.CSIPluginByIDTxn(txn, nil, vol.PluginID) if err != nil { return nil, fmt.Errorf("plugin lookup error: %s %v", vol.PluginID, err) } @@ -2403,8 +2419,17 @@ func (s *StateStore) CSIVolumeDenormalizePlugins(ws memdb.WatchSet, vol *structs // CSIVolumeDenormalize returns a CSIVolume with allocations func (s *StateStore) CSIVolumeDenormalize(ws memdb.WatchSet, vol *structs.CSIVolume) (*structs.CSIVolume, error) { + txn := s.db.ReadTxn() + return s.CSIVolumeDenormalizeTxn(txn, ws, vol) +} + +// CSIVolumeDenormalizeTxn populates a CSIVolume with allocations +func (s *StateStore) CSIVolumeDenormalizeTxn(txn Txn, ws memdb.WatchSet, vol *structs.CSIVolume) (*structs.CSIVolume, error) { + if vol == nil { + return nil, nil + } for id := range vol.ReadAllocs { - a, err := s.AllocByID(ws, id) + a, err := s.allocByIDImpl(txn, ws, id) if err != nil { return nil, err } @@ -2425,7 +2450,7 @@ func (s *StateStore) CSIVolumeDenormalize(ws memdb.WatchSet, vol *structs.CSIVol } for id := range vol.WriteAllocs { - a, err := s.AllocByID(ws, id) + a, err := s.allocByIDImpl(txn, ws, id) if err != nil { return nil, err } @@ -2474,27 +2499,40 @@ func (s *StateStore) CSIPluginsByIDPrefix(ws memdb.WatchSet, pluginID string) (m return iter, nil } -// CSIPluginByID returns the one named CSIPlugin +// CSIPluginByID returns a named CSIPlugin. This method creates a new +// transaction so you should not call it from within another transaction. 
func (s *StateStore) CSIPluginByID(ws memdb.WatchSet, id string) (*structs.CSIPlugin, error) { txn := s.db.ReadTxn() - defer txn.Abort() - - raw, err := txn.First("csi_plugins", "id_prefix", id) + plugin, err := s.CSIPluginByIDTxn(txn, ws, id) if err != nil { - return nil, fmt.Errorf("csi_plugin lookup failed: %s %v", id, err) + return nil, err } + return plugin, nil +} - if raw == nil { - return nil, nil +// CSIPluginByIDTxn returns a named CSIPlugin +func (s *StateStore) CSIPluginByIDTxn(txn Txn, ws memdb.WatchSet, id string) (*structs.CSIPlugin, error) { + + watchCh, obj, err := txn.FirstWatch("csi_plugins", "id_prefix", id) + if err != nil { + return nil, fmt.Errorf("csi_plugin lookup failed: %s %v", id, err) } - plug := raw.(*structs.CSIPlugin) + ws.Add(watchCh) - return plug, nil + if obj != nil { + return obj.(*structs.CSIPlugin), nil + } + return nil, nil } // CSIPluginDenormalize returns a CSIPlugin with allocation details. Always called on a copy of the plugin. func (s *StateStore) CSIPluginDenormalize(ws memdb.WatchSet, plug *structs.CSIPlugin) (*structs.CSIPlugin, error) { + txn := s.db.ReadTxn() + return s.CSIPluginDenormalizeTxn(txn, ws, plug) +} + +func (s *StateStore) CSIPluginDenormalizeTxn(txn Txn, ws memdb.WatchSet, plug *structs.CSIPlugin) (*structs.CSIPlugin, error) { if plug == nil { return nil, nil } @@ -2509,7 +2547,7 @@ func (s *StateStore) CSIPluginDenormalize(ws memdb.WatchSet, plug *structs.CSIPl } for id := range ids { - alloc, err := s.AllocByID(ws, id) + alloc, err := s.allocByIDImpl(txn, ws, id) if err != nil { return nil, err } @@ -2553,9 +2591,8 @@ func (s *StateStore) UpsertCSIPlugin(index uint64, plug *structs.CSIPlugin) erro func (s *StateStore) DeleteCSIPlugin(index uint64, id string) error { txn := s.db.WriteTxn(index) defer txn.Abort() - ws := memdb.NewWatchSet() - plug, err := s.CSIPluginByID(ws, id) + plug, err := s.CSIPluginByIDTxn(txn, nil, id) if err != nil { return err } @@ -2564,7 +2601,7 @@ func (s *StateStore) 
DeleteCSIPlugin(index uint64, id string) error { return nil } - plug, err = s.CSIPluginDenormalize(ws, plug.Copy()) + plug, err = s.CSIPluginDenormalizeTxn(txn, nil, plug.Copy()) if err != nil { return err } @@ -3307,18 +3344,25 @@ func (s *StateStore) nestedUpdateAllocDesiredTransition( // AllocByID is used to lookup an allocation by its ID func (s *StateStore) AllocByID(ws memdb.WatchSet, id string) (*structs.Allocation, error) { txn := s.db.ReadTxn() + return s.allocByIDImpl(txn, ws, id) +} - watchCh, existing, err := txn.FirstWatch("allocs", "id", id) +// allocByIDImpl retrieves an allocation and is called under an existing +// transaction. An optional watch set can be passed to add allocations to the +// watch set +func (s *StateStore) allocByIDImpl(txn Txn, ws memdb.WatchSet, id string) (*structs.Allocation, error) { + watchCh, raw, err := txn.FirstWatch("allocs", "id", id) if err != nil { return nil, fmt.Errorf("alloc lookup failed: %v", err) } ws.Add(watchCh) - if existing != nil { - return existing.(*structs.Allocation), nil + if raw == nil { + return nil, nil } - return nil, nil + alloc := raw.(*structs.Allocation) + return alloc, nil } // AllocsByIDPrefix is used to lookup allocs by prefix @@ -4613,7 +4657,6 @@ func (s *StateStore) updateJobScalingPolicies(index uint64, job *structs.Job, tx // updateJobCSIPlugins runs on job update, and indexes the job in the plugin func (s *StateStore) updateJobCSIPlugins(index uint64, job, prev *structs.Job, txn *txn) error { - ws := memdb.NewWatchSet() plugIns := make(map[string]*structs.CSIPlugin) loop := func(job *structs.Job, delete bool) error { @@ -4625,7 +4668,7 @@ func (s *StateStore) updateJobCSIPlugins(index uint64, job, prev *structs.Job, t plugIn, ok := plugIns[t.CSIPluginConfig.ID] if !ok { - p, err := s.CSIPluginByID(ws, t.CSIPluginConfig.ID) + p, err := s.CSIPluginByIDTxn(txn, nil, t.CSIPluginConfig.ID) if err != nil { return err } @@ -4909,12 +4952,11 @@ func (s *StateStore) updatePluginWithAlloc(index
uint64, alloc *structs.Allocati return nil } - ws := memdb.NewWatchSet() tg := alloc.Job.LookupTaskGroup(alloc.TaskGroup) for _, t := range tg.Tasks { if t.CSIPluginConfig != nil { pluginID := t.CSIPluginConfig.ID - plug, err := s.CSIPluginByID(ws, pluginID) + plug, err := s.CSIPluginByIDTxn(txn, nil, pluginID) if err != nil { return err } @@ -4943,7 +4985,6 @@ func (s *StateStore) updatePluginWithAlloc(index uint64, alloc *structs.Allocati func (s *StateStore) updatePluginWithJobSummary(index uint64, summary *structs.JobSummary, alloc *structs.Allocation, txn *txn) error { - ws := memdb.NewWatchSet() tg := alloc.Job.LookupTaskGroup(alloc.TaskGroup) if tg == nil { return nil @@ -4952,7 +4993,7 @@ func (s *StateStore) updatePluginWithJobSummary(index uint64, summary *structs.J for _, t := range tg.Tasks { if t.CSIPluginConfig != nil { pluginID := t.CSIPluginConfig.ID - plug, err := s.CSIPluginByID(ws, pluginID) + plug, err := s.CSIPluginByIDTxn(txn, nil, pluginID) if err != nil { return err } diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index 982080b2135..76bd0ebb363 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -3520,6 +3520,84 @@ func TestStateStore_CSIPluginMultiNodeUpdates(t *testing.T) { } +// TestStateStore_CSIPlugin_ConcurrentStop tests that concurrent allocation +// updates don't cause the count to drift unexpectedly or cause allocation +// update errors. 
+func TestStateStore_CSIPlugin_ConcurrentStop(t *testing.T) { + t.Parallel() + index := uint64(999) + state := testStateStore(t) + ws := memdb.NewWatchSet() + + var err error + + // Create Nomad client Nodes + ns := []*structs.Node{mock.Node(), mock.Node(), mock.Node()} + for _, n := range ns { + index++ + err = state.UpsertNode(structs.MsgTypeTestSetup, index, n) + require.NoError(t, err) + } + + plugID := "foo" + plugCfg := &structs.TaskCSIPluginConfig{ID: plugID} + + allocs := []*structs.Allocation{} + + // Fingerprint 3 running node plugins and their allocs + for _, n := range ns[:] { + alloc := mock.Alloc() + n, _ := state.NodeByID(ws, n.ID) + n.CSINodePlugins = map[string]*structs.CSIInfo{ + plugID: { + PluginID: plugID, + AllocID: alloc.ID, + Healthy: true, + UpdateTime: time.Now(), + RequiresControllerPlugin: true, + RequiresTopologies: false, + NodeInfo: &structs.CSINodeInfo{}, + }, + } + index++ + err = state.UpsertNode(structs.MsgTypeTestSetup, index, n) + require.NoError(t, err) + + alloc.NodeID = n.ID + alloc.DesiredStatus = "run" + alloc.ClientStatus = "running" + alloc.Job.TaskGroups[0].Tasks[0].CSIPluginConfig = plugCfg + + index++ + err = state.UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc}) + require.NoError(t, err) + + allocs = append(allocs, alloc) + } + + plug, err := state.CSIPluginByID(ws, plugID) + require.NoError(t, err) + require.Equal(t, 3, plug.NodesHealthy, "nodes healthy") + require.Equal(t, 3, len(plug.Nodes), "nodes expected") + + // stop all the allocs + for _, alloc := range allocs { + alloc.DesiredStatus = "stop" + alloc.ClientStatus = "complete" + } + + // this is somewhat artificial b/c we get alloc updates from multiple + // nodes concurrently but not in a single RPC call. 
But this guarantees + // we'll trigger any nested transaction setup bugs + index++ + err = state.UpsertAllocs(structs.MsgTypeTestSetup, index, allocs) + require.NoError(t, err) + + plug, err = state.CSIPluginByID(ws, plugID) + require.NoError(t, err) + require.Nil(t, plug) +} + func TestStateStore_CSIPluginJobs(t *testing.T) { s := testStateStore(t) index := uint64(1001)