Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSI: improve controller RPC reliability #17996

Merged
merged 2 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .changelog/17996.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
```release-note:bug
csi: Fixed a bug in sending concurrent requests to CSI controller plugins by serializing them per plugin
```

```release-note:bug
csi: Fixed a bug where CSI controller requests could be sent to unhealthy plugins
```

```release-note:bug
csi: Fixed a bug where CSI controller requests could not be sent to controllers on nodes ineligible for scheduling
```
59 changes: 42 additions & 17 deletions nomad/client_csi_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
package nomad

import (
"errors"
"fmt"
"math/rand"
"sort"
"strings"
"time"

Expand Down Expand Up @@ -262,9 +263,9 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {

ws := memdb.NewWatchSet()

// note: plugin IDs are not scoped to region/DC but volumes are.
// so any node we get for a controller is already in the same
// region/DC for the volume.
// note: plugin IDs are not scoped to region but volumes are. so any Nomad
// client we get for a controller is already in the same region for the
// volume.
plugin, err := snap.CSIPluginByID(ws, pluginID)
if err != nil {
return nil, fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
Expand All @@ -273,31 +274,55 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
return nil, fmt.Errorf("plugin missing: %s", pluginID)
}

// iterating maps is "random" but unspecified and isn't particularly
// random with small maps, so not well-suited for load balancing.
// so we shuffle the keys and iterate over them.
clientIDs := []string{}

if len(plugin.Controllers) == 0 {
return nil, fmt.Errorf("failed to find instances of controller plugin %q", pluginID)
}

var merr error
for clientID, controller := range plugin.Controllers {
if !controller.IsController() {
// we don't have separate types for CSIInfo depending on
// whether it's a controller or node. this error shouldn't
// make it to production but is to aid developers during
// development
// we don't have separate types for CSIInfo depending on whether
// it's a controller or node. this error should never make it to
// production
merr = errors.Join(merr, fmt.Errorf(
"plugin instance %q is not a controller but was registered as one - this is always a bug", controller.AllocID))
continue
}

if !controller.Healthy {
merr = errors.Join(merr, fmt.Errorf(
"plugin instance %q is not healthy", controller.AllocID))
continue
}

node, err := getNodeForRpc(snap, clientID)
if err == nil && node != nil && node.Ready() {
clientIDs = append(clientIDs, clientID)
if err != nil || node == nil {
merr = errors.Join(merr, fmt.Errorf(
"cannot find node %q for plugin instance %q", clientID, controller.AllocID))
continue
}

if node.Status != structs.NodeStatusReady {
merr = errors.Join(merr, fmt.Errorf(
"node %q for plugin instance %q is not ready", clientID, controller.AllocID))
continue
}

clientIDs = append(clientIDs, clientID)
}

if len(clientIDs) == 0 {
return nil, fmt.Errorf("failed to find clients running controller plugin %q", pluginID)
return nil, fmt.Errorf("failed to find clients running controller plugin %q: %v",
pluginID, merr)
}

rand.Shuffle(len(clientIDs), func(i, j int) {
clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
})
// Many plugins don't handle concurrent requests as described in the spec,
// and have undocumented expectations of using k8s-specific sidecars to
// leader elect. Sort the client IDs so that we prefer sending requests to
// the same controller to hack around this.
clientIDs = sort.StringSlice(clientIDs)

return clientIDs, nil
}
75 changes: 69 additions & 6 deletions nomad/csi_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package nomad

import (
"context"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -549,7 +550,9 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest,
cReq.PluginID = plug.ID
cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{}

err = v.srv.RPC(method, cReq, cResp)
err = v.serializedControllerRPC(plug.ID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
if err != nil {
if strings.Contains(err.Error(), "FailedPrecondition") {
return fmt.Errorf("%v: %v", structs.ErrCSIClientRPCRetryable, err)
Expand Down Expand Up @@ -586,6 +589,57 @@ func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlu
return plug, vol, nil
}

// serializedControllerRPC ensures we're only sending a single controller RPC to
// a given plugin if the RPC can cause conflicting state changes.
Comment on lines +592 to +593
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're good, but sharing here some of my thought process: For this to work, all such RPCs need to be run on a single server to block in memory. Looking at what I think are the relevant RPCs, it seems like they are all(?) v.srv.forward()ed (to the leader). I suppose this is what you meant by "with an allowance for losing state (ex. leadership transitions)" - i.e. this serialization could be broken by switching leaders? and we can't reasonably avoid that caveat.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, that's exactly it.

//
// The CSI specification says that we SHOULD send no more than one in-flight
// request per *volume* at a time, with an allowance for losing state
// (ex. leadership transitions) which the plugins SHOULD handle gracefully.
//
// In practice many CSI plugins rely on k8s-specific sidecars for serializing
// storage provider API calls globally (ex. concurrently attaching EBS volumes
// to an EC2 instance results in a race for device names). So we have to be much
// more conservative about concurrency in Nomad than the spec allows.
func (v *CSIVolume) serializedControllerRPC(pluginID string, fn func() error) error {

for {
v.srv.volumeControllerLock.Lock()
future := v.srv.volumeControllerFutures[pluginID]
if future == nil {
future, futureDone := context.WithCancel(v.srv.shutdownCtx)
v.srv.volumeControllerFutures[pluginID] = future
v.srv.volumeControllerLock.Unlock()

err := fn()

// close the future while holding the lock and not in a defer so
// that we can ensure we've cleared it from the map before allowing
// anyone else to take the lock and write a new one
v.srv.volumeControllerLock.Lock()
futureDone()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find myself wanting simpler map thread-safety, and a sync.Map seems just about right instead of a manual mutex, but its auto-locking-for-you behavior might not be compatible with your concern here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we're locking around both access to the map and the construction/cancellation of contexts. With a sync.Map that would end up looking like this:

// spin up a context goroutine whether we need it or not!
future, futureDone := context.WithCancel(v.srv.shutdownCtx)
defer futureDone()

for {
	ok := v.srv.volumeControllerFutures.CompareAndSwap(pluginID, nil, future)
	if ok {
		// success path
		err := fn()
		futureDone()
		v.srv.volumeControllerFutures.CompareAndDelete(pluginID, future)
		return err
	} else {
		// wait path

		// racy! what if it's been deleted?
		waitOn, ok := v.srv.volumeControllerFutures.Load(pluginID) 
		if !ok {
			continue // lost the race with the delete, start again
		}
		select {
			case <-waitOn.Done():
				continue // other goroutine is done
			case <-v.srv.shutdownCh:
				return structs.ErrNoLeader
		}
	}
}

There are a few reasons why this isn't the best solution:

  • A sync.Map doesn't let us compare-and-swap unless we've already constructed the object we're intending to write. So we end up awkwardly constructing the context outside the loop.
  • In the "success path" we need to cancel the context before deleting it from the map, otherwise someone else can start their future before we've closed our context (and any downstream contexts we add in the future). But doing so will unblock the "wait path" before the CompareAndSwap will succeed, potentially leading to that goroutine trying to grab the CompareAndSwap multiple times before it can succeed, which is wasteful.
  • CompareAndSwap doesn't return the existing value, so we need to Load it the value and it may have been deleted in the meanwhile. This adds more contention around that lock.
  • The CompareAndSwap operation returns false when it hasn't been previously set! We could probably hack around that by creating a dummy context that's always treated as "unused" but then we've spun up a goroutine per plugin just to hold that spot.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks for the mock-up and detailed explanation!

delete(v.srv.volumeControllerFutures, pluginID)
v.srv.volumeControllerLock.Unlock()

return err
} else {
v.srv.volumeControllerLock.Unlock()

select {
case <-future.Done():
continue
case <-v.srv.shutdownCh:
// The csi_hook publish workflow on the client will retry if it
// gets this error. On unpublish, we don't want to block client
// shutdown so we give up on error. The new leader's
// volumewatcher will iterate all the claims at startup to
// detect this and mop up any claims in the NodeDetached state
// (volume GC will run periodically as well)
return structs.ErrNoLeader
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth wrapping this with any more info about where it came from, or does that happen upstream of here?

Copy link
Member Author

@tgross tgross Jul 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The caller's logger will wrap it sufficiently. "No leader" is fairly foundational and there's nothing we can meaningfully add.

}
}
}
}

// allowCSIMount is called on Job register to check mount permission
func allowCSIMount(aclObj *acl.ACL, namespace string) bool {
return aclObj.AllowPluginRead() &&
Expand Down Expand Up @@ -863,8 +917,11 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str
Secrets: vol.Secrets,
}
req.PluginID = vol.PluginID
err = v.srv.RPC("ClientCSI.ControllerDetachVolume", req,
&cstructs.ClientCSIControllerDetachVolumeResponse{})

err = v.serializedControllerRPC(vol.PluginID, func() error {
return v.srv.RPC("ClientCSI.ControllerDetachVolume", req,
&cstructs.ClientCSIControllerDetachVolumeResponse{})
})
if err != nil {
return fmt.Errorf("could not detach from controller: %v", err)
}
Expand Down Expand Up @@ -1139,7 +1196,9 @@ func (v *CSIVolume) deleteVolume(vol *structs.CSIVolume, plugin *structs.CSIPlug
cReq.PluginID = plugin.ID
cResp := &cstructs.ClientCSIControllerDeleteVolumeResponse{}

return v.srv.RPC(method, cReq, cResp)
return v.serializedControllerRPC(plugin.ID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
}

func (v *CSIVolume) ListExternal(args *structs.CSIVolumeExternalListRequest, reply *structs.CSIVolumeExternalListResponse) error {
Expand Down Expand Up @@ -1286,7 +1345,9 @@ func (v *CSIVolume) CreateSnapshot(args *structs.CSISnapshotCreateRequest, reply
}
cReq.PluginID = pluginID
cResp := &cstructs.ClientCSIControllerCreateSnapshotResponse{}
err = v.srv.RPC(method, cReq, cResp)
err = v.serializedControllerRPC(pluginID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
if err != nil {
multierror.Append(&mErr, fmt.Errorf("could not create snapshot: %v", err))
continue
Expand Down Expand Up @@ -1360,7 +1421,9 @@ func (v *CSIVolume) DeleteSnapshot(args *structs.CSISnapshotDeleteRequest, reply
cReq := &cstructs.ClientCSIControllerDeleteSnapshotRequest{ID: snap.ID}
cReq.PluginID = plugin.ID
cResp := &cstructs.ClientCSIControllerDeleteSnapshotResponse{}
err = v.srv.RPC(method, cReq, cResp)
err = v.serializedControllerRPC(plugin.ID, func() error {
return v.srv.RPC(method, cReq, cResp)
})
if err != nil {
multierror.Append(&mErr, fmt.Errorf("could not delete %q: %v", snap.ID, err))
}
Expand Down
48 changes: 48 additions & 0 deletions nomad/csi_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package nomad
import (
"fmt"
"strings"
"sync"
"testing"
"time"

Expand All @@ -21,6 +22,7 @@ import (
cconfig "github.com/hashicorp/nomad/client/config"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/lib/lang"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs"
Expand Down Expand Up @@ -1971,3 +1973,49 @@ func TestCSI_RPCVolumeAndPluginLookup(t *testing.T) {
require.Nil(t, vol)
require.EqualError(t, err, fmt.Sprintf("volume not found: %s", id2))
}

func TestCSI_SerializedControllerRPC(t *testing.T) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've done some end-to-end testing of this new function with reschedules but I want to follow-up with some additional testing of drains, broken plugin configurations, etc., to make sure that everything is recoverable even in the corner cases.

ci.Parallel(t)

srv, shutdown := TestServer(t, func(c *Config) { c.NumSchedulers = 0 })
defer shutdown()
testutil.WaitForLeader(t, srv.RPC)

var wg sync.WaitGroup
wg.Add(3)

timeCh := make(chan lang.Pair[string, time.Duration])

testFn := func(pluginID string, dur time.Duration) {
defer wg.Done()
c := NewCSIVolumeEndpoint(srv, nil)
now := time.Now()
err := c.serializedControllerRPC(pluginID, func() error {
time.Sleep(dur)
return nil
})
elapsed := time.Since(now)
timeCh <- lang.Pair[string, time.Duration]{pluginID, elapsed}
must.NoError(t, err)
}

go testFn("plugin1", 50*time.Millisecond)
go testFn("plugin2", 50*time.Millisecond)
go testFn("plugin1", 50*time.Millisecond)

totals := map[string]time.Duration{}
for i := 0; i < 3; i++ {
pair := <-timeCh
totals[pair.First] += pair.Second
}

wg.Wait()

// plugin1 RPCs should block each other
must.GreaterEq(t, 150*time.Millisecond, totals["plugin1"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the test passes, but how does 2 * 50 milliseconds add up to > 150 ? "plugin1" should only be blocked behind itself, so "plugin2"'s 50ms shouldn't be included, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For plugin1: the goroutine on line 2002 runs for 50ms. The goroutine on line 2004 waits at least 50ms for the other to finish and then runs for 50ms, for a total of 100ms. We're summing up the elapsed time of all the goroutines for a given plugin, because that way we're asserting the wait time of the method under test from the perspective of each goroutine and not just the elapsed wall clock time. So it's a minimum of 50ms + 100ms.

The maximum value of 200ms is assuming that any overhead is (far) less than 50ms, so that asserts that we're not accidentally waiting behind plugin2 as well.

must.Less(t, 200*time.Millisecond, totals["plugin1"])

// plugin1 RPCs should not block plugin2 RPCs
must.GreaterEq(t, 50*time.Millisecond, totals["plugin2"])
must.Less(t, 100*time.Millisecond, totals["plugin2"])
}
8 changes: 8 additions & 0 deletions nomad/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ type Server struct {
// volumeWatcher is used to release volume claims
volumeWatcher *volumewatcher.Watcher

// volumeControllerFutures is a map of plugin IDs to pending controller RPCs. If
// no RPC is pending for a given plugin, this may be nil.
volumeControllerFutures map[string]context.Context

// volumeControllerLock synchronizes access controllerFutures map
volumeControllerLock sync.Mutex

// keyringReplicator is used to replicate root encryption keys from the
// leader
keyringReplicator *KeyringReplicator
Expand Down Expand Up @@ -445,6 +452,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulConfigEntr
s.logger.Error("failed to create volume watcher", "error", err)
return nil, fmt.Errorf("failed to create volume watcher: %v", err)
}
s.volumeControllerFutures = map[string]context.Context{}

// Start the eval broker notification system so any subscribers can get
// updates when the processes SetEnabled is triggered.
Expand Down