Skip to content

Commit

Permalink
CSI: fix transaction handling in state store (#9438)
Browse files Browse the repository at this point in the history
When making updates to CSI plugins, the state store methods that have open
write transactions were querying the state store using the same methods used
by the CSI RPC endpoint, but these methods create their own top-level read
transactions. During concurrent plugin updates (as happens when a plugin job
is stopped), this can cause write skew in the plugin counts.

* Refactor the CSIPlugin query methods to have an implementation method that
accepts a transaction, which can be called with either a read txn or a write
txn.
* Refactor the CSIVolume query methods to have an implementation method that
accepts a transaction, which can be called with either a read txn or a write
txn.
* CSI volumes need to be "denormalized" with their plugins and (optionally)
allocations. Read-only RPC endpoints should take a snapshot so that we can
make multiple state store method calls with a consistent view.
  • Loading branch information
tgross authored Nov 25, 2020
1 parent 0019fac commit c2aaa51
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 80 deletions.
54 changes: 33 additions & 21 deletions nomad/csi_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,19 @@ func (v *CSIVolume) List(args *structs.CSIVolumeListRequest, reply *structs.CSIV
queryOpts: &args.QueryOptions,
queryMeta: &reply.QueryMeta,
run: func(ws memdb.WatchSet, state *state.StateStore) error {
snap, err := state.Snapshot()
if err != nil {
return err
}
// Query all volumes
var err error
var iter memdb.ResultIterator

if args.NodeID != "" {
iter, err = state.CSIVolumesByNodeID(ws, args.NodeID)
iter, err = snap.CSIVolumesByNodeID(ws, args.NodeID)
} else if args.PluginID != "" {
iter, err = state.CSIVolumesByPluginID(ws, ns, args.PluginID)
iter, err = snap.CSIVolumesByPluginID(ws, ns, args.PluginID)
} else {
iter, err = state.CSIVolumesByNamespace(ws, ns)
iter, err = snap.CSIVolumesByNamespace(ws, ns)
}

if err != nil {
Expand All @@ -140,23 +143,25 @@ func (v *CSIVolume) List(args *structs.CSIVolumeListRequest, reply *structs.CSIV
if raw == nil {
break
}

vol := raw.(*structs.CSIVolume)
vol, err := state.CSIVolumeDenormalizePlugins(ws, vol.Copy())
if err != nil {
return err
}

// Remove (possibly again) by PluginID to handle passing both NodeID and PluginID
// Remove (possibly again) by PluginID to handle passing both
// NodeID and PluginID
if args.PluginID != "" && args.PluginID != vol.PluginID {
continue
}

// Remove by Namespace, since CSIVolumesByNodeID hasn't used the Namespace yet
// Remove by Namespace, since CSIVolumesByNodeID hasn't used
// the Namespace yet
if vol.Namespace != ns {
continue
}

vol, err := snap.CSIVolumeDenormalizePlugins(ws, vol.Copy())
if err != nil {
return err
}

vs = append(vs, vol.Stub())
}
reply.Volumes = vs
Expand Down Expand Up @@ -195,12 +200,17 @@ func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVol
queryOpts: &args.QueryOptions,
queryMeta: &reply.QueryMeta,
run: func(ws memdb.WatchSet, state *state.StateStore) error {
vol, err := state.CSIVolumeByID(ws, ns, args.ID)
snap, err := state.Snapshot()
if err != nil {
return err
}

vol, err := snap.CSIVolumeByID(ws, ns, args.ID)
if err != nil {
return err
}
if vol != nil {
vol, err = state.CSIVolumeDenormalize(ws, vol)
vol, err = snap.CSIVolumeDenormalize(ws, vol)
}
if err != nil {
return err
Expand All @@ -214,9 +224,8 @@ func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVol

func (v *CSIVolume) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume) (*structs.CSIPlugin, error) {
state := v.srv.fsm.State()
ws := memdb.NewWatchSet()

plugin, err := state.CSIPluginByID(ws, vol.PluginID)
plugin, err := state.CSIPluginByID(nil, vol.PluginID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -481,9 +490,7 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest,

func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlugin, *structs.CSIVolume, error) {
state := v.srv.fsm.State()
ws := memdb.NewWatchSet()

vol, err := state.CSIVolumeByID(ws, namespace, volID)
vol, err := state.CSIVolumeByID(nil, namespace, volID)
if err != nil {
return nil, nil, err
}
Expand All @@ -497,7 +504,7 @@ func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlu
// note: we do this same lookup in CSIVolumeByID but then throw
// away the pointer to the plugin rather than attaching it to
// the volume so we have to do it again here.
plug, err := state.CSIPluginByID(ws, vol.PluginID)
plug, err := state.CSIPluginByID(nil, vol.PluginID)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -870,7 +877,12 @@ func (v *CSIPlugin) Get(args *structs.CSIPluginGetRequest, reply *structs.CSIPlu
queryOpts: &args.QueryOptions,
queryMeta: &reply.QueryMeta,
run: func(ws memdb.WatchSet, state *state.StateStore) error {
plug, err := state.CSIPluginByID(ws, args.ID)
snap, err := state.Snapshot()
if err != nil {
return err
}

plug, err := snap.CSIPluginByID(ws, args.ID)
if err != nil {
return err
}
Expand All @@ -880,7 +892,7 @@ func (v *CSIPlugin) Get(args *structs.CSIPluginGetRequest, reply *structs.CSIPlu
}

if withAllocs {
plug, err = state.CSIPluginDenormalize(ws, plug.Copy())
plug, err = snap.CSIPluginDenormalize(ws, plug.Copy())
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit c2aaa51

Please sign in to comment.