Skip to content

Commit

Permalink
csi: retry controller client RPCs on next controller (#8561)
Browse files Browse the repository at this point in the history
The documentation encourages operators to run multiple controller plugin
instances for HA, but the client RPCs don't take advantage of this: they do
not retry on another controller when an RPC fails because the plugin is
unavailable (for example, because the node has drained or the alloc has
failed but we haven't received an updated fingerprint yet).

This changeset tries all known controllers on ready nodes before giving up,
and adds tests that exercise the client RPC routing and retries.
  • Loading branch information
tgross authored Aug 6, 2020
1 parent c44df58 commit 07ff0b9
Show file tree
Hide file tree
Showing 2 changed files with 268 additions and 174 deletions.
164 changes: 87 additions & 77 deletions nomad/client_csi_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package nomad
import (
"fmt"
"math/rand"
"strings"
"time"

metrics "github.com/armon/go-metrics"
Expand All @@ -20,74 +21,101 @@ type ClientCSI struct {

// ControllerAttachVolume sends the "CSI.ControllerAttachVolume" RPC to a
// Nomad client expected to be running the controller plugin args.PluginID,
// and records the call duration under the
// nomad.client_csi_controller.attach_volume metric.
//
// NOTE(review): this span is a rendered diff. Removed and added lines are
// interleaved without +/- markers, so it does not compile as written. Reading
// the lines built around clientIDsForController as the post-change version:
// each candidate client is tried in turn, a nil error from NodeRpc returns
// immediately, an error accepted by isRetryable moves on to the next client,
// and any other error is returned with a "controller attach volume" prefix.
func (a *ClientCSI) ControllerAttachVolume(args *cstructs.ClientCSIControllerAttachVolumeRequest, reply *cstructs.ClientCSIControllerAttachVolumeResponse) error {
defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "attach_volume"}, time.Now())
// Get a Nomad client node for the controller
nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)

clientIDs, err := a.clientIDsForController(args.PluginID)
if err != nil {
return err
return fmt.Errorf("controller attach volume: %v", err)
}
args.ControllerNodeID = nodeID

// Get the connection to the client
state, ok := a.srv.getNodeConn(args.ControllerNodeID)
if !ok {
return findNodeConnAndForward(a.srv, args.ControllerNodeID, "ClientCSI.ControllerAttachVolume", args, reply)
}
for _, clientID := range clientIDs {
// record which client this attempt targets before sending the request
args.ControllerNodeID = clientID
state, ok := a.srv.getNodeConn(clientID)
if !ok {
// NOTE(review): with no local connection the forwarded result is returned
// directly, so this path never falls through to the next client — confirm
// that is intended.
return findNodeConnAndForward(a.srv,
clientID, "ClientCSI.ControllerAttachVolume", args, reply)
}

// Make the RPC
err = NodeRpc(state.Session, "CSI.ControllerAttachVolume", args, reply)
if err != nil {
err = NodeRpc(state.Session, "CSI.ControllerAttachVolume", args, reply)
if err == nil {
return nil
}
if a.isRetryable(err, clientID, args.PluginID) {
// NOTE(review): the message uses printf verbs (%q, %v); if a.logger is an
// hclog-style logger taking key/value pairs they will not be interpolated
// — confirm.
a.logger.Debug("failed to reach controller on client %q: %v", clientID, err)
continue
}
return fmt.Errorf("controller attach volume: %v", err)
}
return nil
return fmt.Errorf("controller attach volume: %v", err)
}

// ControllerValidateVolume sends the "CSI.ControllerValidateVolume" RPC to a
// Nomad client expected to be running the controller plugin args.PluginID,
// and records the call duration under the
// nomad.client_csi_controller.validate_volume metric.
//
// NOTE(review): this span is a rendered diff. Removed and added lines are
// interleaved without +/- markers, so it does not compile as written. Reading
// the lines built around clientIDsForController as the post-change version:
// each candidate client is tried in turn, a nil error from NodeRpc returns
// immediately, an error accepted by isRetryable moves on to the next client,
// and any other error is returned with a "validate volume" prefix.
func (a *ClientCSI) ControllerValidateVolume(args *cstructs.ClientCSIControllerValidateVolumeRequest, reply *cstructs.ClientCSIControllerValidateVolumeResponse) error {
defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "validate_volume"}, time.Now())

// Get a Nomad client node for the controller
nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
clientIDs, err := a.clientIDsForController(args.PluginID)
if err != nil {
return err
return fmt.Errorf("validate volume: %v", err)
}
args.ControllerNodeID = nodeID

// Get the connection to the client
state, ok := a.srv.getNodeConn(args.ControllerNodeID)
if !ok {
return findNodeConnAndForward(a.srv, args.ControllerNodeID, "ClientCSI.ControllerValidateVolume", args, reply)
}
for _, clientID := range clientIDs {
// record which client this attempt targets before sending the request
args.ControllerNodeID = clientID
state, ok := a.srv.getNodeConn(clientID)
if !ok {
// NOTE(review): with no local connection the forwarded result is returned
// directly, so this path never falls through to the next client — confirm
// that is intended.
return findNodeConnAndForward(a.srv,
clientID, "ClientCSI.ControllerValidateVolume", args, reply)
}

// Make the RPC
err = NodeRpc(state.Session, "CSI.ControllerValidateVolume", args, reply)
if err != nil {
err = NodeRpc(state.Session, "CSI.ControllerValidateVolume", args, reply)
if err == nil {
return nil
}
if a.isRetryable(err, clientID, args.PluginID) {
// NOTE(review): the message uses printf verbs (%q, %v); if a.logger is an
// hclog-style logger taking key/value pairs they will not be interpolated
// — confirm.
a.logger.Debug("failed to reach controller on client %q: %v", clientID, err)
continue
}
return fmt.Errorf("validate volume: %v", err)
}
return nil
return fmt.Errorf("validate volume: %v", err)
}

// ControllerDetachVolume sends the "CSI.ControllerDetachVolume" RPC to a
// Nomad client expected to be running the controller plugin args.PluginID,
// and records the call duration under the
// nomad.client_csi_controller.detach_volume metric.
//
// NOTE(review): this span is a rendered diff. Removed and added lines are
// interleaved without +/- markers, so it does not compile as written. Reading
// the lines built around clientIDsForController as the post-change version:
// each candidate client is tried in turn, a nil error from NodeRpc returns
// immediately, an error accepted by isRetryable moves on to the next client,
// and any other error is returned with a "controller detach volume" prefix.
func (a *ClientCSI) ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error {
defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "detach_volume"}, time.Now())

// Get a Nomad client node for the controller
nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
clientIDs, err := a.clientIDsForController(args.PluginID)
if err != nil {
return err
return fmt.Errorf("controller detach volume: %v", err)
}
args.ControllerNodeID = nodeID

// Get the connection to the client
state, ok := a.srv.getNodeConn(args.ControllerNodeID)
if !ok {
return findNodeConnAndForward(a.srv, args.ControllerNodeID, "ClientCSI.ControllerDetachVolume", args, reply)
}
for _, clientID := range clientIDs {
// record which client this attempt targets before sending the request
args.ControllerNodeID = clientID
state, ok := a.srv.getNodeConn(clientID)
if !ok {
// NOTE(review): with no local connection the forwarded result is returned
// directly, so this path never falls through to the next client — confirm
// that is intended.
return findNodeConnAndForward(a.srv,
clientID, "ClientCSI.ControllerDetachVolume", args, reply)
}

// Make the RPC
err = NodeRpc(state.Session, "CSI.ControllerDetachVolume", args, reply)
if err != nil {
err = NodeRpc(state.Session, "CSI.ControllerDetachVolume", args, reply)
if err == nil {
return nil
}
if a.isRetryable(err, clientID, args.PluginID) {
// NOTE(review): the message uses printf verbs (%q, %v); if a.logger is an
// hclog-style logger taking key/value pairs they will not be interpolated
// — confirm.
a.logger.Debug("failed to reach controller on client %q: %v", clientID, err)
continue
}
return fmt.Errorf("controller detach volume: %v", err)
}
return nil
return fmt.Errorf("controller detach volume: %v", err)
}

// isRetryable reports whether err, returned by a controller RPC sent to the
// client clientID for the plugin pluginID, means the same RPC can be retried
// on a different controller. That is the case when the client has stopped and
// been GC'd, or when the controller has stopped but we don't have the
// fingerprint update yet.
//
// The decision is made by matching substrings of err.Error(), so it depends
// on the exact wording of those upstream error messages. A nil err is never
// retryable.
func (a *ClientCSI) isRetryable(err error, clientID, pluginID string) bool {
	// Guard against a nil error so that callers which have not already
	// checked err do not panic on err.Error().
	if err == nil {
		return false
	}

	// TODO(tgross): it would be nicer to use errors.Is here but we
	// need to make sure we're using error wrapping to make that work
	errMsg := err.Error()
	return strings.Contains(errMsg, "Unknown node: "+clientID) ||
		strings.Contains(errMsg, "no plugins registered for type: csi-controller") ||
		strings.Contains(errMsg, "plugin "+pluginID+" for type controller not found")
}

func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeRequest, reply *cstructs.ClientCSINodeDetachVolumeResponse) error {
Expand Down Expand Up @@ -119,29 +147,17 @@ func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeReq

}

// NOTE(review): this span is a rendered diff. Lines of the removed
// nodeForController and of the added clientIDsForController are interleaved
// without +/- markers, and the "Expand All" line below appears to stand in
// for context lines that are not shown, so the span is neither complete nor
// compilable as written.
//
// nodeForController validates that the Nomad client node ID for
// a plugin exists and is new enough to support client RPC. If no node
// ID is passed, select a random node ID for the controller to load-balance
// long blocking RPCs across client nodes.
func (a *ClientCSI) nodeForController(pluginID, nodeID string) (string, error) {
// clientIDsForController returns a shuffled list of client IDs where the
// controller plugin is expected to be running.
func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {

snap, err := a.srv.State().Snapshot()
if err != nil {
return "", err
}

if nodeID != "" {
_, err = getNodeForRpc(snap, nodeID)
if err == nil {
return nodeID, nil
} else {
// we'll fall-through and select a node at random
a.logger.Trace("could not be used for client RPC", "node", nodeID, "error", err)
}
return nil, err
}

if pluginID == "" {
return "", fmt.Errorf("missing plugin ID")
return nil, fmt.Errorf("missing plugin ID")
}

ws := memdb.NewWatchSet()
Expand All @@ -151,43 +167,37 @@ func (a *ClientCSI) nodeForController(pluginID, nodeID string) (string, error) {
// region/DC for the volume.
plugin, err := snap.CSIPluginByID(ws, pluginID)
if err != nil {
return "", fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
return nil, fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
}
if plugin == nil {
return "", fmt.Errorf("plugin missing: %s %v", pluginID, err)
}
count := len(plugin.Controllers)
if count == 0 {
return "", fmt.Errorf("no controllers available for plugin %q", plugin.ID)
return nil, fmt.Errorf("plugin missing: %s", pluginID)
}

// iterating maps is "random" but unspecified and isn't particularly
// random with small maps, so not well-suited for load balancing.
// so we shuffle the keys and iterate over them.
clientIDs := make([]string, 0, count)
for clientID := range plugin.Controllers {
clientIDs = append(clientIDs, clientID)
}
rand.Shuffle(count, func(i, j int) {
clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
})
clientIDs := []string{}

for _, clientID := range clientIDs {
controller := plugin.Controllers[clientID]
for clientID, controller := range plugin.Controllers {
if !controller.IsController() {
// we don't have separate types for CSIInfo depending on
// whether it's a controller or node. this error shouldn't
// make it to production but is to aid developers during
// development
err = fmt.Errorf("plugin is not a controller")
continue
}
_, err = getNodeForRpc(snap, clientID)
if err != nil {
continue
// keep only clients that getNodeForRpc resolves without error and whose
// node reports Ready()
node, err := getNodeForRpc(snap, clientID)
if err == nil && node != nil && node.Ready() {
clientIDs = append(clientIDs, clientID)
}
return clientID, nil
}
// an empty candidate list is reported as an error instead of being returned
if len(clientIDs) == 0 {
return nil, fmt.Errorf("failed to find clients running controller plugin %q", pluginID)
}

// shuffle the collected IDs so callers iterate candidates in random order
rand.Shuffle(len(clientIDs), func(i, j int) {
clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
})

return "", err
return clientIDs, nil
}
Loading

0 comments on commit 07ff0b9

Please sign in to comment.