diff --git a/pkg/driver/node.go b/pkg/driver/node.go index b1efbfd895..f1036a91ed 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -149,8 +149,29 @@ func (d *Driver) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolu return nil, status.Error(codes.InvalidArgument, "Staging target not provided") } + // Check if target directory is a mount point. GetDeviceNameFromMount + // given a mnt point, finds the device from /proc/mounts + // returns the device name, reference count, and error code + dev, refCount, err := mount.GetDeviceNameFromMount(d.mounter, target) + if err != nil { + msg := fmt.Sprintf("failed to check if volume is mounted: %v", err) + return nil, status.Error(codes.Internal, msg) + } + + // From the spec: If the volume corresponding to the volume_id + // is not staged to the staging_target_path, the Plugin MUST + // reply 0 OK. + if refCount == 0 { + klog.V(5).Infof("NodeUnstageVolume: %s target not mounted", target) + return &csi.NodeUnstageVolumeResponse{}, nil + } + + if refCount > 1 { + klog.Warningf("NodeUnstageVolume: found %d references to device %s mounted at target path %s", refCount, dev, target) + } + klog.V(5).Infof("NodeUnstageVolume: unmounting %s", target) - err := d.mounter.Interface.Unmount(target) + err = d.mounter.Interface.Unmount(target) if err != nil { return nil, status.Errorf(codes.Internal, "Could not unmount target %q: %v", target, err) } diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index 8c48245a94..716a65fb29 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -202,11 +202,12 @@ func TestNodeStageVolume(t *testing.T) { t.Fatalf("Expected error %v, got no error", tc.expErrCode) } - if len(tc.expActions) > 0 && !reflect.DeepEqual(fakeMounter.Log, tc.expActions) { + // if fake mounter did anything we should + // check if it was expected + if len(fakeMounter.Log) > 0 && !reflect.DeepEqual(fakeMounter.Log, tc.expActions) { t.Fatalf("Expected actions {%+v}, got {%+v}", tc.expActions, fakeMounter.Log) } - - if len(tc.expMountPoints) > 0 && !reflect.DeepEqual(fakeMounter.MountPoints, tc.expMountPoints) { + if len(fakeMounter.MountPoints) > 0 && !reflect.DeepEqual(fakeMounter.MountPoints, tc.expMountPoints) { t.Fatalf("Expected mount points {%+v}, got {%+v}", tc.expMountPoints, fakeMounter.MountPoints) } }) @@ -215,11 +216,14 @@ func TestNodeStageVolume(t *testing.T) { func TestNodeUnstageVolume(t *testing.T) { testCases := []struct { - name string - req *csi.NodeUnstageVolumeRequest - expErrCode codes.Code - fakeMountPoint *mount.MountPoint - expActions []mount.FakeAction + name string + req *csi.NodeUnstageVolumeRequest + expErrCode codes.Code + fakeMountPoints []mount.MountPoint + // expected fake mount actions the test will make + expActions []mount.FakeAction + // expected mount points when test finishes + expMountPoints []mount.MountPoint }{ { name: "success normal", @@ -227,15 +231,38 @@ func TestNodeUnstageVolume(t *testing.T) { StagingTargetPath: "/test/path", VolumeId: "vol-test", }, - fakeMountPoint: &mount.MountPoint{ - Device: "/dev/fake", - Path: "/test/path", + fakeMountPoints: []mount.MountPoint{ + {Device: "/dev/fake", Path: "/test/path"}, }, expActions: []mount.FakeAction{ - { - Action: "unmount", - Target: "/test/path", - }, + {Action: "unmount", Target: "/test/path"}, + }, + }, + { + name: "success no device mounted at target", + req: &csi.NodeUnstageVolumeRequest{ + StagingTargetPath: "/test/path", + VolumeId: "vol-test", + }, + expActions: []mount.FakeAction{}, + }, + { + name: "success device mounted at multiple targets", + req: &csi.NodeUnstageVolumeRequest{ + StagingTargetPath: "/test/path", + VolumeId: "vol-test", + }, + // mount a fake device in two locations + fakeMountPoints: []mount.MountPoint{ + {Device: "/dev/fake", Path: "/test/path"}, + {Device: "/dev/fake", Path: "/foo/bar"}, + }, + // it should unmount from the original + expActions: []mount.FakeAction{ + {Action: "unmount", Target: "/test/path"}, + }, + expMountPoints: []mount.MountPoint{ + {Device: "/dev/fake", Path: "/foo/bar"}, }, }, { @@ -257,8 +284,8 @@ func TestNodeUnstageVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { fakeMounter := NewFakeMounter() - if tc.fakeMountPoint != nil { - fakeMounter.MountPoints = append(fakeMounter.MountPoints, *tc.fakeMountPoint) + if len(tc.fakeMountPoints) > 0 { + fakeMounter.MountPoints = tc.fakeMountPoints } awsDriver := NewFakeDriver("", fakeMounter) @@ -274,10 +301,14 @@ func TestNodeUnstageVolume(t *testing.T) { } else if tc.expErrCode != codes.OK { t.Fatalf("Expected error %v, got no error", tc.expErrCode) } - - if len(tc.expActions) > 0 && !reflect.DeepEqual(fakeMounter.Log, tc.expActions) { + // if fake mounter did anything we should + // check if it was expected + if len(fakeMounter.Log) > 0 && !reflect.DeepEqual(fakeMounter.Log, tc.expActions) { t.Fatalf("Expected actions {%+v}, got {%+v}", tc.expActions, fakeMounter.Log) } + if len(fakeMounter.MountPoints) > 0 && !reflect.DeepEqual(fakeMounter.MountPoints, tc.expMountPoints) { + t.Fatalf("Expected mount points {%+v}, got {%+v}", tc.expMountPoints, fakeMounter.MountPoints) + } }) } } @@ -468,11 +499,12 @@ func TestNodePublishVolume(t *testing.T) { t.Fatalf("Expected error %v and got no error", tc.expErrCode) } - if len(tc.expActions) > 0 && !reflect.DeepEqual(fakeMounter.Log, tc.expActions) { + // if fake mounter did anything we should + // check if it was expected + if len(fakeMounter.Log) > 0 && !reflect.DeepEqual(fakeMounter.Log, tc.expActions) { t.Fatalf("Expected actions {%+v}, got {%+v}", tc.expActions, fakeMounter.Log) } - - if len(tc.expMountPoints) > 0 && !reflect.DeepEqual(fakeMounter.MountPoints, tc.expMountPoints) { + if len(fakeMounter.MountPoints) > 0 && !reflect.DeepEqual(fakeMounter.MountPoints, tc.expMountPoints) { t.Fatalf("Expected mount points {%+v}, got {%+v}", tc.expMountPoints, fakeMounter.MountPoints) } }) @@ -544,7 +576,9 @@ func TestNodeUnpublishVolume(t *testing.T) { t.Fatalf("Expected error %v, got no error", tc.expErrCode) } - if len(tc.expActions) > 0 && !reflect.DeepEqual(fakeMounter.Log, tc.expActions) { + // if fake mounter did anything we should + // check if it was expected + if len(fakeMounter.Log) > 0 && !reflect.DeepEqual(fakeMounter.Log, tc.expActions) { t.Fatalf("Expected actions {%+v}, got {%+v}", tc.expActions, fakeMounter.Log) } })