Skip to content

Commit

Permalink
refactor: make nodeForControllerPlugin private to ClientCSI (#7688)
Browse files Browse the repository at this point in the history
The current design of `ClientCSI` RPC requires that callers in the
server know about the free-standing `nodeForControllerPlugin`
function. This makes it difficult to send `ClientCSI` RPC messages
from subpackages of `nomad` and adds a bunch of boilerplate to every
server-side caller of a controller RPC.

This changeset makes the `ClientCSI` RPCs validate the controller's
client node ID when the caller passes one, and select a node themselves
when the ID is missing or fails validation, centralizing the logic of
picking and validating controller targets into the `nomad.ClientCSI` struct.
  • Loading branch information
tgross authored Apr 10, 2020
1 parent 47dfa76 commit 09abe0c
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 183 deletions.
88 changes: 42 additions & 46 deletions nomad/client_csi_endpoint.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package nomad

import (
"errors"
"fmt"
"math/rand"
"time"
Expand All @@ -10,7 +9,6 @@ import (
log "github.com/hashicorp/go-hclog"
memdb "github.com/hashicorp/go-memdb"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs"
)

Expand All @@ -23,22 +21,12 @@ type ClientCSI struct {

func (a *ClientCSI) ControllerAttachVolume(args *cstructs.ClientCSIControllerAttachVolumeRequest, reply *cstructs.ClientCSIControllerAttachVolumeResponse) error {
defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "attach_volume"}, time.Now())

// Verify the arguments.
if args.ControllerNodeID == "" {
return errors.New("missing ControllerNodeID")
}

// Make sure Node is valid and new enough to support RPC
snap, err := a.srv.State().Snapshot()
if err != nil {
return err
}

_, err = getNodeForRpc(snap, args.ControllerNodeID)
// Get a Nomad client node for the controller
nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
if err != nil {
return err
}
args.ControllerNodeID = nodeID

// Get the connection to the client
state, ok := a.srv.getNodeConn(args.ControllerNodeID)
Expand All @@ -57,21 +45,12 @@ func (a *ClientCSI) ControllerAttachVolume(args *cstructs.ClientCSIControllerAtt
func (a *ClientCSI) ControllerValidateVolume(args *cstructs.ClientCSIControllerValidateVolumeRequest, reply *cstructs.ClientCSIControllerValidateVolumeResponse) error {
defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "validate_volume"}, time.Now())

// Verify the arguments.
if args.ControllerNodeID == "" {
return errors.New("missing ControllerNodeID")
}

// Make sure Node is valid and new enough to support RPC
snap, err := a.srv.State().Snapshot()
if err != nil {
return err
}

_, err = getNodeForRpc(snap, args.ControllerNodeID)
// Get a Nomad client node for the controller
nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
if err != nil {
return err
}
args.ControllerNodeID = nodeID

// Get the connection to the client
state, ok := a.srv.getNodeConn(args.ControllerNodeID)
Expand All @@ -90,21 +69,12 @@ func (a *ClientCSI) ControllerValidateVolume(args *cstructs.ClientCSIControllerV
func (a *ClientCSI) ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error {
defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "detach_volume"}, time.Now())

// Verify the arguments.
if args.ControllerNodeID == "" {
return errors.New("missing ControllerNodeID")
}

// Make sure Node is valid and new enough to support RPC
snap, err := a.srv.State().Snapshot()
if err != nil {
return err
}

_, err = getNodeForRpc(snap, args.ControllerNodeID)
// Get a Nomad client node for the controller
nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
if err != nil {
return err
}
args.ControllerNodeID = nodeID

// Get the connection to the client
state, ok := a.srv.getNodeConn(args.ControllerNodeID)
Expand Down Expand Up @@ -178,17 +148,43 @@ func (srv *Server) volAndPluginLookup(namespace, volID string) (*structs.CSIPlug
return plug, vol, nil
}

// nodeForControllerPlugin returns the node ID for a random controller
// to load-balance long-blocking RPCs across client nodes.
func nodeForControllerPlugin(state *state.StateStore, plugin *structs.CSIPlugin) (string, error) {
// nodeForController validates that the Nomad client node ID for
// a plugin exists and is new enough to support client RPC. If no node
// ID is passed, select a random node ID for the controller to load-balance
// long blocking RPCs across client nodes.
func (a *ClientCSI) nodeForController(pluginID, nodeID string) (string, error) {

snap, err := a.srv.State().Snapshot()
if err != nil {
return "", err
}

if nodeID != "" {
_, err = getNodeForRpc(snap, nodeID)
if err == nil {
return nodeID, nil
}
}

if pluginID == "" {
return "", fmt.Errorf("missing plugin ID")
}
ws := memdb.NewWatchSet()

// note: plugin IDs are not scoped to region/DC but volumes are.
// so any node we get for a controller is already in the same
// region/DC for the volume.
plugin, err := snap.CSIPluginByID(ws, pluginID)
if err != nil {
return "", fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
}
if plugin == nil {
return "", fmt.Errorf("plugin missing: %s %v", pluginID, err)
}
count := len(plugin.Controllers)
if count == 0 {
return "", fmt.Errorf("no controllers available for plugin %q", plugin.ID)
}
snap, err := state.Snapshot()
if err != nil {
return "", err
}

// iterating maps is "random" but unspecified and isn't particularly
// random with small maps, so not well-suited for load balancing.
Expand Down
45 changes: 45 additions & 0 deletions nomad/client_csi_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package nomad
import (
"testing"

memdb "github.com/hashicorp/go-memdb"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/nomad/client"
"github.com/hashicorp/nomad/client/config"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -167,3 +170,45 @@ func TestClientCSIController_DetachVolume_Forwarded(t *testing.T) {
// Should receive an error from the client endpoint
require.Contains(err.Error(), "must specify plugin name to dispense")
}

// TestClientCSI_NodeForControllerPlugin checks that nodeForController picks a
// controller node when the caller passes an empty node ID.
//
// Three nodes are registered in state: node1 runs the "minnie" controller
// plugin and reports nomad.version 0.11.0, node2 runs the plugin with the
// mock's default version, and node3 does not run the plugin at all. The test
// expects node1 to be selected.
func TestClientCSI_NodeForControllerPlugin(t *testing.T) {
	t.Parallel()
	srv, shutdown := TestServer(t, func(c *Config) {})
	// Register cleanup before waiting for a leader so the server is shut
	// down on every exit path of the test.
	defer shutdown()
	testutil.WaitForLeader(t, srv.RPC)

	plugins := map[string]*structs.CSIInfo{
		"minnie": {PluginID: "minnie",
			Healthy:                  true,
			ControllerInfo:           &structs.CSIControllerInfo{},
			NodeInfo:                 &structs.CSINodeInfo{},
			RequiresControllerPlugin: true,
		},
	}
	state := srv.fsm.State()

	node1 := mock.Node()
	node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions
	node1.CSIControllerPlugins = plugins
	node2 := mock.Node()
	node2.CSIControllerPlugins = plugins
	node2.ID = uuid.Generate()
	node3 := mock.Node()
	node3.ID = uuid.Generate()

	err := state.UpsertNode(1002, node1)
	require.NoError(t, err)
	err = state.UpsertNode(1003, node2)
	require.NoError(t, err)
	err = state.UpsertNode(1004, node3)
	require.NoError(t, err)

	ws := memdb.NewWatchSet()

	plugin, err := state.CSIPluginByID(ws, "minnie")
	require.NoError(t, err)
	// Fail with a clear message instead of a nil-pointer panic on plugin.ID
	// if the plugin was not registered by the node upserts above.
	require.NotNil(t, plugin)

	nodeID, err := srv.staticEndpoints.ClientCSI.nodeForController(plugin.ID, "")
	// A selection error must fail the test here rather than surface as a
	// confusing empty-ID mismatch below.
	require.NoError(t, err)

	// only node1 has both the controller and a recent Nomad version
	require.Equal(t, node1.ID, nodeID)
}
5 changes: 0 additions & 5 deletions nomad/core_sched.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,16 +867,11 @@ func volumeClaimReapImpl(srv RPCServer, args *volumeClaimReapArgs) (map[string]i
return args.nodeClaims, fmt.Errorf("Failed to find NodeInfo for node: %s", targetNode.ID)
}

controllerNodeID, err := nodeForControllerPlugin(srv.State(), args.plug)
if err != nil || controllerNodeID == "" {
return args.nodeClaims, err
}
cReq := &cstructs.ClientCSIControllerDetachVolumeRequest{
VolumeID: vol.RemoteID(),
ClientCSINodeID: targetCSIInfo.NodeInfo.ID,
}
cReq.PluginID = args.plug.ID
cReq.ControllerNodeID = controllerNodeID
err = srv.RPC("ClientCSI.ControllerDetachVolume", cReq,
&cstructs.ClientCSIControllerDetachVolumeResponse{})
if err != nil {
Expand Down
Loading

0 comments on commit 09abe0c

Please sign in to comment.