Skip to content

Commit

Permalink
CSI: allocation should not block its own controller unpublish
Browse files Browse the repository at this point in the history
Nomad user reported problems with CSI volumes associated with failed
allocations, where the Nomad server did not send a controller unpublish RPC.

The controller unpublish is skipped if other non-terminal allocations on the
same node claim the volume. The check has a bug where the allocation belonging
to the claim being freed was included in the check incorrectly. During a normal
allocation stop for job stop or a new version of the job, the allocation is
terminal. But allocations that fail are not yet marked terminal at the point in
time when the client sends the unpublish RPC to the server.

For CSI plugins that support controller attach/detach, this means that the
controller will not be able to detach the volume from the allocation's host and
the replacement claim will fail until a GC is run. This changeset fixes the
conditional so that the claim's own allocation is not included, and makes the
logic easier to read. Include a test case covering this path.
  • Loading branch information
tgross committed Sep 8, 2022
1 parent 25e2302 commit 2dc870c
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 20 deletions.
30 changes: 23 additions & 7 deletions nomad/csi_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ func (v *CSIVolume) nodeUnpublishVolumeImpl(vol *structs.CSIVolume, claim *struc
// be called on a copy of the volume.
func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *structs.CSIVolumeClaim) error {
v.logger.Trace("controller unpublish", "vol", vol.ID)

if !vol.ControllerRequired {
claim.State = structs.CSIVolumeClaimStateReadyToFree
return nil
Expand All @@ -801,26 +802,39 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str
} else if plugin == nil {
return fmt.Errorf("no such plugin: %q", vol.PluginID)
}

if !plugin.HasControllerCapability(structs.CSIControllerSupportsAttachDetach) {
claim.State = structs.CSIVolumeClaimStateReadyToFree
return nil
}

// we only send a controller detach if a Nomad client no longer has
// any claim to the volume, so we need to check the status of claimed
// allocations
vol, err = state.CSIVolumeDenormalize(ws, vol)
if err != nil {
return err
}
for _, alloc := range vol.ReadAllocs {
if alloc != nil && alloc.NodeID == claim.NodeID && !alloc.TerminalStatus() {

// we only send a controller detach if a Nomad client no longer has any
// claim to the volume, so we need to check the status of any other claimed
// allocations
shouldCancel := func(alloc *structs.Allocation) bool {
if alloc != nil && alloc.ID != claim.AllocationID &&
alloc.NodeID == claim.NodeID && !alloc.TerminalStatus() {
claim.State = structs.CSIVolumeClaimStateReadyToFree
v.logger.Debug(
"controller unpublish canceled: another non-terminal alloc is on this node",
"vol", vol.ID, "alloc", alloc.ID)
return true
}
return false
}

for _, alloc := range vol.ReadAllocs {
if shouldCancel(alloc) {
return nil
}
}
for _, alloc := range vol.WriteAllocs {
if alloc != nil && alloc.NodeID == claim.NodeID && !alloc.TerminalStatus() {
claim.State = structs.CSIVolumeClaimStateReadyToFree
if shouldCancel(alloc) {
return nil
}
}
Expand All @@ -846,6 +860,8 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str
if err != nil {
return fmt.Errorf("could not detach from controller: %v", err)
}

v.logger.Trace("controller detach complete", "vol", vol.ID)
claim.State = structs.CSIVolumeClaimStateReadyToFree
return v.checkpointClaim(vol, claim)
}
Expand Down
69 changes: 56 additions & 13 deletions nomad/csi_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import (
"time"

msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/shoenig/test"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/hashicorp/nomad/acl"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client"
Expand All @@ -17,7 +22,6 @@ import (
"github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/require"
)

func TestCSIVolumeEndpoint_Get(t *testing.T) {
Expand Down Expand Up @@ -499,37 +503,52 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) {
},
}
index++
require.NoError(t, state.UpsertNode(structs.MsgTypeTestSetup, index, node))
must.NoError(t, state.UpsertNode(structs.MsgTypeTestSetup, index, node))

type tc struct {
name string
startingState structs.CSIVolumeClaimState
endState structs.CSIVolumeClaimState
nodeID string
otherNodeID string
expectedErrMsg string
}
testCases := []tc{
{
name: "success",
startingState: structs.CSIVolumeClaimStateControllerDetached,
nodeID: node.ID,
otherNodeID: uuid.Generate(),
},
{
name: "non-terminal allocation on same node",
startingState: structs.CSIVolumeClaimStateNodeDetached,
nodeID: node.ID,
otherNodeID: node.ID,
},
{
name: "unpublish previously detached node",
startingState: structs.CSIVolumeClaimStateNodeDetached,
endState: structs.CSIVolumeClaimStateNodeDetached,
expectedErrMsg: "could not detach from controller: controller detach volume: No path to node",
nodeID: node.ID,
otherNodeID: uuid.Generate(),
},
{
name: "unpublish claim on garbage collected node",
startingState: structs.CSIVolumeClaimStateTaken,
endState: structs.CSIVolumeClaimStateNodeDetached,
expectedErrMsg: "could not detach from controller: controller detach volume: No path to node",
nodeID: uuid.Generate(),
otherNodeID: uuid.Generate(),
},
{
name: "first unpublish",
startingState: structs.CSIVolumeClaimStateTaken,
endState: structs.CSIVolumeClaimStateNodeDetached,
expectedErrMsg: "could not detach from controller: controller detach volume: No path to node",
nodeID: node.ID,
otherNodeID: uuid.Generate(),
},
}

Expand All @@ -551,15 +570,20 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) {

index++
err = state.UpsertCSIVolume(index, []*structs.CSIVolume{vol})
require.NoError(t, err)
must.NoError(t, err)

// setup: create an alloc that will claim our volume
alloc := mock.BatchAlloc()
alloc.NodeID = tc.nodeID
alloc.ClientStatus = structs.AllocClientStatusFailed

otherAlloc := mock.BatchAlloc()
otherAlloc.NodeID = tc.otherNodeID
otherAlloc.ClientStatus = structs.AllocClientStatusRunning

index++
require.NoError(t, state.UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc}))
must.NoError(t, state.UpsertAllocs(structs.MsgTypeTestSetup, index,
[]*structs.Allocation{alloc, otherAlloc}))

// setup: claim the volume for our alloc
claim := &structs.CSIVolumeClaim{
Expand All @@ -572,7 +596,20 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) {
index++
claim.State = structs.CSIVolumeClaimStateTaken
err = state.CSIVolumeClaim(index, ns, volID, claim)
require.NoError(t, err)
must.NoError(t, err)

// setup: claim the volume for our other alloc
otherClaim := &structs.CSIVolumeClaim{
AllocationID: otherAlloc.ID,
NodeID: tc.otherNodeID,
ExternalNodeID: "i-example",
Mode: structs.CSIVolumeClaimRead,
}

index++
otherClaim.State = structs.CSIVolumeClaimStateTaken
err = state.CSIVolumeClaim(index, ns, volID, otherClaim)
must.NoError(t, err)

// test: unpublish and check the results
claim.State = tc.startingState
Expand All @@ -589,17 +626,23 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) {
err = msgpackrpc.CallWithCodec(codec, "CSIVolume.Unpublish", req,
&structs.CSIVolumeUnpublishResponse{})

vol, volErr := state.CSIVolumeByID(nil, ns, volID)
must.NoError(t, volErr)
must.NotNil(t, vol)

if tc.expectedErrMsg == "" {
require.NoError(t, err)
vol, err = state.CSIVolumeByID(nil, ns, volID)
require.NoError(t, err)
require.NotNil(t, vol)
require.Len(t, vol.ReadAllocs, 0)
must.NoError(t, err)
assert.Len(t, vol.ReadAllocs, 1)
} else {
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), tc.expectedErrMsg),
"error message %q did not contain %q", err.Error(), tc.expectedErrMsg)
must.Error(t, err)
assert.Len(t, vol.ReadAllocs, 2)
test.True(t, strings.Contains(err.Error(), tc.expectedErrMsg),
test.Sprintf("error %v did not contain %q", err, tc.expectedErrMsg))
claim = vol.PastClaims[alloc.ID]
must.NotNil(t, claim)
test.Eq(t, tc.endState, claim.State)
}

})
}

Expand Down

0 comments on commit 2dc870c

Please sign in to comment.