Skip to content

Commit

Permalink
backport of commit 299d897
Browse files Browse the repository at this point in the history
  • Loading branch information
tgross committed Jul 19, 2023
1 parent 2cb8786 commit 8170212
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 17 deletions.
3 changes: 3 additions & 0 deletions .changelog/17996.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
csi: Fixed a bug in sending concurrent requests to CSI controller plugins by serializing them per plugin
```
21 changes: 10 additions & 11 deletions nomad/client_csi_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package nomad

import (
"fmt"
"math/rand"
"sort"
"strings"
"time"

Expand Down Expand Up @@ -220,9 +220,9 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {

ws := memdb.NewWatchSet()

// note: plugin IDs are not scoped to region/DC but volumes are.
// so any node we get for a controller is already in the same
// region/DC for the volume.
// note: plugin IDs are not scoped to region but volumes are. so any Nomad
// client we get for a controller is already in the same region for the
// volume.
plugin, err := snap.CSIPluginByID(ws, pluginID)
if err != nil {
return nil, fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
Expand All @@ -231,13 +231,10 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
return nil, fmt.Errorf("plugin missing: %s", pluginID)
}

// iterating maps is "random" but unspecified and isn't particularly
// random with small maps, so not well-suited for load balancing.
// so we shuffle the keys and iterate over them.
clientIDs := []string{}

for clientID, controller := range plugin.Controllers {
if !controller.IsController() {
if !controller.IsController() || !controller.Healthy {
// we don't have separate types for CSIInfo depending on
// whether it's a controller or node. this error shouldn't
// make it to production but is to aid developers during
Expand All @@ -253,9 +250,11 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
return nil, fmt.Errorf("failed to find clients running controller plugin %q", pluginID)
}

rand.Shuffle(len(clientIDs), func(i, j int) {
clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
})
// Many plugins don't handle concurrent requests as described in the spec,
// and have undocumented expectations of using k8s-specific sidecars to
// leader elect. Sort the client IDs so that we prefer sending requests to
// the same controller to hack around this.
clientIDs = sort.StringSlice(clientIDs)

return clientIDs, nil
}
75 changes: 69 additions & 6 deletions nomad/csi_endpoint.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nomad

import (
"context"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -558,7 +559,9 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest,
cReq.PluginID = plug.ID
cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{}

err = v.srv.RPC(method, cReq, cResp)
err = v.serializedControllerRPC(plug.ID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
if err != nil {
if strings.Contains(err.Error(), "FailedPrecondition") {
return fmt.Errorf("%v: %v", structs.ErrCSIClientRPCRetryable, err)
Expand Down Expand Up @@ -595,6 +598,57 @@ func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlu
return plug, vol, nil
}

// serializedControllerRPC ensures we're only sending a single controller RPC to
// a given plugin if the RPC can cause conflicting state changes.
//
// The CSI specification says that we SHOULD send no more than one in-flight
// request per *volume* at a time, with an allowance for losing state
// (ex. leadership transitions) which the plugins SHOULD handle gracefully.
//
// In practice many CSI plugins rely on k8s-specific sidecars for serializing
// storage provider API calls globally (ex. concurrently attaching EBS volumes
// to an EC2 instance results in a race for device names). So we have to be much
// more conservative about concurrency in Nomad than the spec allows.
func (v *CSIVolume) serializedControllerRPC(pluginID string, fn func() error) error {

for {
v.srv.volumeControllerLock.Lock()
future := v.srv.volumeControllerFutures[pluginID]
if future == nil {
future, futureDone := context.WithCancel(v.srv.shutdownCtx)
v.srv.volumeControllerFutures[pluginID] = future
v.srv.volumeControllerLock.Unlock()

err := fn()

// close the future while holding the lock and not in a defer so
// that we can ensure we've cleared it from the map before allowing
// anyone else to take the lock and write a new one
v.srv.volumeControllerLock.Lock()
futureDone()
delete(v.srv.volumeControllerFutures, pluginID)
v.srv.volumeControllerLock.Unlock()

return err
} else {
v.srv.volumeControllerLock.Unlock()

select {
case <-future.Done():
continue
case <-v.srv.shutdownCh:
// The csi_hook publish workflow on the client will retry if it
// gets this error. On unpublish, we don't want to block client
// shutdown so we give up on error. The new leader's
// volumewatcher will iterate all the claims at startup to
// detect this and mop up any claims in the NodeDetached state
// (volume GC will run periodically as well)
return structs.ErrNoLeader
}
}
}
}

// allowCSIMount is called on Job register to check mount permission
func allowCSIMount(aclObj *acl.ACL, namespace string) bool {
return aclObj.AllowPluginRead() &&
Expand Down Expand Up @@ -866,8 +920,11 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str
Secrets: vol.Secrets,
}
req.PluginID = vol.PluginID
err = v.srv.RPC("ClientCSI.ControllerDetachVolume", req,
&cstructs.ClientCSIControllerDetachVolumeResponse{})

err = v.serializedControllerRPC(vol.PluginID, func() error {
return v.srv.RPC("ClientCSI.ControllerDetachVolume", req,
&cstructs.ClientCSIControllerDetachVolumeResponse{})
})
if err != nil {
return fmt.Errorf("could not detach from controller: %v", err)
}
Expand Down Expand Up @@ -1133,7 +1190,9 @@ func (v *CSIVolume) deleteVolume(vol *structs.CSIVolume, plugin *structs.CSIPlug
cReq.PluginID = plugin.ID
cResp := &cstructs.ClientCSIControllerDeleteVolumeResponse{}

return v.srv.RPC(method, cReq, cResp)
return v.serializedControllerRPC(plugin.ID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
}

func (v *CSIVolume) ListExternal(args *structs.CSIVolumeExternalListRequest, reply *structs.CSIVolumeExternalListResponse) error {
Expand Down Expand Up @@ -1270,7 +1329,9 @@ func (v *CSIVolume) CreateSnapshot(args *structs.CSISnapshotCreateRequest, reply
}
cReq.PluginID = pluginID
cResp := &cstructs.ClientCSIControllerCreateSnapshotResponse{}
err = v.srv.RPC(method, cReq, cResp)
err = v.serializedControllerRPC(pluginID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
if err != nil {
multierror.Append(&mErr, fmt.Errorf("could not create snapshot: %v", err))
continue
Expand Down Expand Up @@ -1339,7 +1400,9 @@ func (v *CSIVolume) DeleteSnapshot(args *structs.CSISnapshotDeleteRequest, reply
cReq := &cstructs.ClientCSIControllerDeleteSnapshotRequest{ID: snap.ID}
cReq.PluginID = plugin.ID
cResp := &cstructs.ClientCSIControllerDeleteSnapshotResponse{}
err = v.srv.RPC(method, cReq, cResp)
err = v.serializedControllerRPC(plugin.ID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
if err != nil {
multierror.Append(&mErr, fmt.Errorf("could not delete %q: %v", snap.ID, err))
}
Expand Down
48 changes: 48 additions & 0 deletions nomad/csi_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package nomad
import (
"fmt"
"strings"
"sync"
"testing"
"time"

Expand All @@ -18,6 +19,7 @@ import (
cconfig "github.com/hashicorp/nomad/client/config"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/lib/lang"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs"
Expand Down Expand Up @@ -1968,3 +1970,49 @@ func TestCSI_RPCVolumeAndPluginLookup(t *testing.T) {
require.Nil(t, vol)
require.EqualError(t, err, fmt.Sprintf("volume not found: %s", id2))
}

func TestCSI_SerializedControllerRPC(t *testing.T) {
ci.Parallel(t)

srv, shutdown := TestServer(t, func(c *Config) { c.NumSchedulers = 0 })
defer shutdown()
testutil.WaitForLeader(t, srv.RPC)

var wg sync.WaitGroup
wg.Add(3)

timeCh := make(chan lang.Pair[string, time.Duration])

testFn := func(pluginID string, dur time.Duration) {
defer wg.Done()
c := NewCSIVolumeEndpoint(srv, nil)
now := time.Now()
err := c.serializedControllerRPC(pluginID, func() error {
time.Sleep(dur)
return nil
})
elapsed := time.Since(now)
timeCh <- lang.Pair[string, time.Duration]{pluginID, elapsed}
must.NoError(t, err)
}

go testFn("plugin1", 50*time.Millisecond)
go testFn("plugin2", 50*time.Millisecond)
go testFn("plugin1", 50*time.Millisecond)

totals := map[string]time.Duration{}
for i := 0; i < 3; i++ {
pair := <-timeCh
totals[pair.First] += pair.Second
}

wg.Wait()

// plugin1 RPCs should block each other
must.GreaterEq(t, 150*time.Millisecond, totals["plugin1"])
must.Less(t, 200*time.Millisecond, totals["plugin1"])

// plugin1 RPCs should not block plugin2 RPCs
must.GreaterEq(t, 50*time.Millisecond, totals["plugin2"])
must.Less(t, 100*time.Millisecond, totals["plugin2"])
}
8 changes: 8 additions & 0 deletions nomad/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ type Server struct {
// volumeWatcher is used to release volume claims
volumeWatcher *volumewatcher.Watcher

// volumeControllerFutures is a map of plugin IDs to pending controller RPCs. If
// no RPC is pending for a given plugin, this may be nil.
volumeControllerFutures map[string]context.Context

// volumeControllerLock synchronizes access controllerFutures map
volumeControllerLock sync.Mutex

// keyringReplicator is used to replicate root encryption keys from the
// leader
keyringReplicator *KeyringReplicator
Expand Down Expand Up @@ -463,6 +470,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulConfigEntr
s.logger.Error("failed to create volume watcher", "error", err)
return nil, fmt.Errorf("failed to create volume watcher: %v", err)
}
s.volumeControllerFutures = map[string]context.Context{}

// Start the eval broker notification system so any subscribers can get
// updates when the processes SetEnabled is triggered.
Expand Down

0 comments on commit 8170212

Please sign in to comment.