diff --git a/pkg/mounter/refcounter_windows.go b/pkg/mounter/refcounter_windows.go index 257d6e52d73..f36f2e1b4a5 100644 --- a/pkg/mounter/refcounter_windows.go +++ b/pkg/mounter/refcounter_windows.go @@ -63,21 +63,21 @@ func getRootMappingPath(path string) (string, error) { return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil } -// incementRemotePathReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist. +// incementVolumeIDReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist. // How it works: // 1. MappingPath contains two components: hostname, sharename // 2. We create directory in basePath related to each mappingPath. It will be used as container for references. // Example: c:\\csi\\smbmounts\\hostname\\sharename -// 3. Each reference is a file with name based on MD5 of remotePath. For debug it also will contains remotePath in body of the file. -// So, in incementRemotePathReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file. +// 3. Each reference is a file with name based on MD5 of volumeID. For debug it also will contains remotePath in body of the file. +// So, in incementVolumeIDReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file. // Example: c:\\csi\\smbmounts\\hostname\\sharename\\092f1413e6c1d03af8b5da6f44619af8 -func incementRemotePathReferencesCount(mappingPath, remotePath string) error { +func incementVolumeIDReferencesCount(mappingPath, remotePath string, volumeID string) error { remotePath = strings.TrimSuffix(remotePath, "\\") path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\")) if err := os.MkdirAll(path, os.ModeDir); err != nil { return err } - filePath := filepath.Join(path, getMd5(remotePath)) + filePath := filepath.Join(path, getMd5(volumeID)) file, err := os.Create(filePath) if err != nil { return err @@ -90,21 +90,20 @@ func incementRemotePathReferencesCount(mappingPath, remotePath string) error { return err } -// decrementRemotePathReferencesCount - removes reference between mappingPath and remotePath. -// See incementRemotePathReferencesCount to understand how references work. -func decrementRemotePathReferencesCount(mappingPath, remotePath string) error { - remotePath = strings.TrimSuffix(remotePath, "\\") +// decrementVolumeIDReferencesCount - removes reference between mappingPath and remotePath. +// See incementVolumeIDReferencesCount to understand how references work. +func decrementVolumeIDReferencesCount(mappingPath, volumeID string) error { path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\")) if err := os.MkdirAll(path, os.ModeDir); err != nil { return err } - filePath := filepath.Join(path, getMd5(remotePath)) + filePath := filepath.Join(path, getMd5(volumeID)) return os.Remove(filePath) } -// getRemotePathReferencesCount - returns count of references between mappingPath and remotePath. -// See incementRemotePathReferencesCount to understand how references work. -func getRemotePathReferencesCount(mappingPath string) int { +// getVolumeIDReferencesCount - returns count of references between mappingPath and remotePath. +// See incementVolumeIDReferencesCount to understand how references work. +func getVolumeIDReferencesCount(mappingPath string) int { path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\")) if os.MkdirAll(path, os.ModeDir) != nil { return -1 diff --git a/pkg/mounter/refcounter_windows_test.go b/pkg/mounter/refcounter_windows_test.go index f1031702087..54fb04bdb4b 100644 --- a/pkg/mounter/refcounter_windows_test.go +++ b/pkg/mounter/refcounter_windows_test.go @@ -106,36 +106,51 @@ func TestGetRootMappingPath(t *testing.T) { } } -func TestRemotePathReferencesCounter(t *testing.T) { - remotePath1 := "\\\\servername\\share\\subpath\\1" - remotePath2 := "\\\\servername\\share\\subpath\\2" - mappingPath, err := getRootMappingPath(remotePath1) - assert.Nil(t, err) +func TestVolumeIDReferencesCounter(t *testing.T) { + testCases := []struct { + path1 string + path2 string + }{ + { + path1: "\\\\servername\\share\\subpath\\1", + path2: "\\\\servername\\share\\subpath\\2", + }, + { + path1: "\\\\servername\\share", + path2: "\\\\servername\\share", + }, + } + for _, tc := range testCases { + remotePath1 := tc.path1 + remotePath2 := tc.path2 + mappingPath, err := getRootMappingPath(remotePath1) + assert.Nil(t, err) - basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" - os.RemoveAll(basePath) - defer func() { - // cleanup temp folder + basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" os.RemoveAll(basePath) - }() - - // by default we have no any files in `mappingPath`. So, `count` should be zero - assert.Zero(t, getRemotePathReferencesCount(mappingPath)) - // add reference to `remotePath1`. So, `count` should be equal `1` - assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath1)) - assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath)) - // add reference to `remotePath2`. So, `count` should be equal `2` - assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath2)) - assert.Equal(t, 2, getRemotePathReferencesCount(mappingPath)) - // remove reference to `remotePath1`. So, `count` should be equal `1` - assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath1)) - assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath)) - // remove reference to `remotePath2`. So, `count` should be equal `0` - assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath2)) - assert.Zero(t, getRemotePathReferencesCount(mappingPath)) + defer func() { + // cleanup temp folder + os.RemoveAll(basePath) + }() + + // by default we have no any files in `mappingPath`. So, `count` should be zero + assert.Zero(t, getVolumeIDReferencesCount(mappingPath)) + // add reference to `remotePath1`. So, `count` should be equal `1` + assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath1, "vol1")) + assert.Equal(t, 1, getVolumeIDReferencesCount(mappingPath)) + // add reference to `remotePath2`. So, `count` should be equal `2` + assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath2, "vol2")) + assert.Equal(t, 2, getVolumeIDReferencesCount(mappingPath)) + // remove reference to `remotePath1`. So, `count` should be equal `1` + assert.Nil(t, decrementVolumeIDReferencesCount(mappingPath, "vol1")) + assert.Equal(t, 1, getVolumeIDReferencesCount(mappingPath)) + // remove reference to `remotePath2`. So, `count` should be equal `0` + assert.Nil(t, decrementVolumeIDReferencesCount(mappingPath, "vol2")) + assert.Zero(t, getVolumeIDReferencesCount(mappingPath)) + } } -func TestIncementRemotePathReferencesCount(t *testing.T) { +func TestIncementVolumeIDReferencesCount(t *testing.T) { remotePath := "\\\\servername\\share\\subpath" mappingPath, err := getRootMappingPath(remotePath) assert.Nil(t, err) @@ -147,20 +162,20 @@ func TestIncementRemotePathReferencesCount(t *testing.T) { os.RemoveAll(basePath) }() - assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) + assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol1")) mappingPathContainer := basePath + "\\servername\\share" if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() { t.Error("mapping file container does not exist") } - reference := mappingPathContainer + "\\" + getMd5(remotePath) + reference := mappingPathContainer + "\\" + "vol1" if file, err := os.Stat(reference); os.IsNotExist(err) || file.IsDir() { t.Error("reference file does not exist") } } -func TestDecrementRemotePathReferencesCount(t *testing.T) { +func TestDecrementVolumeIDReferencesCount(t *testing.T) { remotePath := "\\\\servername\\share\\subpath" mappingPath, err := getRootMappingPath(remotePath) assert.Nil(t, err) @@ -172,21 +187,21 @@ func TestDecrementRemotePathReferencesCount(t *testing.T) { os.RemoveAll(basePath) }() - assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) - assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath)) + assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol1")) + assert.Nil(t, decrementVolumeIDReferencesCount(mappingPath, "vol1")) mappingPathContainer := basePath + "\\servername\\share" if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() { t.Error("mapping file container does not exist") } - reference := mappingPathContainer + "\\" + getMd5(remotePath) + reference := mappingPathContainer + "\\" + "vol1" if _, err := os.Stat(reference); os.IsExist(err) { t.Error("reference file exists") } } -func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) { +func TestMultiplyCallsOfIncementVolumeIDReferencesCount(t *testing.T) { remotePath := "\\\\servername\\share\\subpath" mappingPath, err := getRootMappingPath(remotePath) assert.Nil(t, err) @@ -198,17 +213,17 @@ func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) { os.RemoveAll(basePath) }() - assert.Zero(t, getRemotePathReferencesCount(mappingPath)) - assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) + assert.Zero(t, getVolumeIDReferencesCount(mappingPath)) + assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol")) // next calls of `incementMappingPathCount` with the same arguments should be ignored - assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) - assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) - assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) - assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) - assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath)) + assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol")) + assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol")) + assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol")) + assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol")) + assert.Equal(t, 1, getVolumeIDReferencesCount(mappingPath)) } -func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) { +func TestMultiplyCallsOfDecrementVolumeIDReferencesCount(t *testing.T) { remotePath := "\\\\servername\\share\\subpath" mappingPath, err := getRootMappingPath(remotePath) assert.Nil(t, err) @@ -220,8 +235,8 @@ func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) { os.RemoveAll(basePath) }() - assert.Zero(t, getRemotePathReferencesCount(mappingPath)) - assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) - assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath)) - assert.NotNil(t, decrementRemotePathReferencesCount(mappingPath, remotePath)) + assert.Zero(t, getVolumeIDReferencesCount(mappingPath)) + assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol")) + assert.Nil(t, decrementVolumeIDReferencesCount(mappingPath, "vol")) + assert.NotNil(t, decrementVolumeIDReferencesCount(mappingPath, "vol")) } diff --git a/pkg/mounter/safe_mounter_v1beta_windows.go b/pkg/mounter/safe_mounter_v1beta_windows.go index 429ea94081f..54a75ae9839 100644 --- a/pkg/mounter/safe_mounter_v1beta_windows.go +++ b/pkg/mounter/safe_mounter_v1beta_windows.go @@ -44,7 +44,7 @@ type csiProxyMounterV1Beta struct { SMBClient *smbclient.Client } -func (mounter *csiProxyMounterV1Beta) SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string) error { +func (mounter *csiProxyMounterV1Beta) SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string, volumeID string) error { klog.V(4).Infof("SMBMount: remote path: %s. local path: %s", source, target) if len(mountOptions) == 0 || len(sensitiveMountOptions) == 0 { @@ -93,7 +93,7 @@ func (mounter *csiProxyMounterV1Beta) SMBMount(source, target, fsType string, mo return nil } -func (mounter *csiProxyMounterV1Beta) SMBUnmount(target string) error { +func (mounter *csiProxyMounterV1Beta) SMBUnmount(target string, volumeID string) error { klog.V(4).Infof("SMBUnmount: local path: %s", target) // TODO: We need to remove the SMB mapping. The change to remove the // directory brings the CSI code in parity with the in-tree. diff --git a/pkg/mounter/safe_mounter_windows.go b/pkg/mounter/safe_mounter_windows.go index 46dff0081a0..52f0da75562 100644 --- a/pkg/mounter/safe_mounter_windows.go +++ b/pkg/mounter/safe_mounter_windows.go @@ -42,8 +42,8 @@ import ( type CSIProxyMounter interface { mount.Interface - SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string) error - SMBUnmount(target string) error + SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string, volumeID string) error + SMBUnmount(target string, volumeID string) error MakeDir(path string) error Rmdir(path string) error IsMountPointMatch(mp mount.MountPoint, dir string) bool @@ -68,7 +68,7 @@ func normalizeWindowsPath(path string) string { return normalizedPath } -func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string) error { +func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string, volumeID string) error { klog.V(2).Infof("SMBMount: remote path: %s local path: %s", source, target) if len(mountOptions) == 0 || len(sensitiveMountOptions) == 0 { @@ -124,14 +124,14 @@ func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOpt klog.V(2).Infof("NewSmbGlobalMapping %s on %s successfully", source, normalizedTarget) if mounter.RemoveSMBMappingDuringUnmount { - if err := incementRemotePathReferencesCount(mappingPath, source); err != nil { - return fmt.Errorf("incementMappingPathCount(%s, %s) failed with error: %v", mappingPath, source, err) + if err := incementVolumeIDReferencesCount(mappingPath, source, volumeID); err != nil { + return fmt.Errorf("incementRemotePathReferencesCount(%s, %s, %s) failed with error: %v", mappingPath, source, volumeID, err) } } return nil } -func (mounter *csiProxyMounter) SMBUnmount(target string) error { +func (mounter *csiProxyMounter) SMBUnmount(target string, volumeID string) error { klog.V(4).Infof("SMBUnmount: local path: %s", target) if remotePath, err := os.Readlink(target); err != nil { @@ -144,14 +144,14 @@ func (mounter *csiProxyMounter) SMBUnmount(target string) error { } klog.V(4).Infof("SMBUnmount: remote path: %s, mapping path: %s", remotePath, mappingPath) - unlock := lock(mappingPath) - defer unlock() - if mounter.RemoveSMBMappingDuringUnmount { - if err := decrementRemotePathReferencesCount(mappingPath, remotePath); err != nil { - return fmt.Errorf("decrementMappingPathCount(%s, %s) failed with error: %v", mappingPath, remotePath, err) + unlock := lock(mappingPath) + defer unlock() + + if err := decrementVolumeIDReferencesCount(mappingPath, volumeID); err != nil { + return fmt.Errorf("decrementRemotePathReferencesCount(%s, %s) failed with error: %v", mappingPath, volumeID, err) } - count := getRemotePathReferencesCount(mappingPath) + count := getVolumeIDReferencesCount(mappingPath) if count == 0 { smbUnmountRequest := &smb.RemoveSmbGlobalMappingRequest{ RemotePath: remotePath, diff --git a/pkg/smb/nodeserver.go b/pkg/smb/nodeserver.go index 59b63f91a72..9f2d59cb01a 100644 --- a/pkg/smb/nodeserver.go +++ b/pkg/smb/nodeserver.go @@ -225,7 +225,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe } mountComplete := false err = wait.PollImmediate(1*time.Second, 2*time.Minute, func() (bool, error) { - err := Mount(d.mounter, source, targetPath, "cifs", mountOptions, sensitiveMountOptions) + err := Mount(d.mounter, source, targetPath, "cifs", mountOptions, sensitiveMountOptions, volumeID) mountComplete = true return true, err }) @@ -258,7 +258,7 @@ func (d *Driver) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolu defer d.volumeLocks.Release(volumeID) klog.V(2).Infof("NodeUnstageVolume: CleanupMountPoint on %s with volume %s", stagingTargetPath, volumeID) - if err := CleanupSMBMountPoint(d.mounter, stagingTargetPath, true /*extensiveMountPointCheck*/); err != nil { + if err := CleanupSMBMountPoint(d.mounter, stagingTargetPath, true /*extensiveMountPointCheck*/, volumeID); err != nil { return nil, status.Errorf(codes.Internal, "failed to unmount staging target %q: %v", stagingTargetPath, err) } diff --git a/pkg/smb/smb_common_darwin.go b/pkg/smb/smb_common_darwin.go index 646779607fc..f4b4fcc3270 100644 --- a/pkg/smb/smb_common_darwin.go +++ b/pkg/smb/smb_common_darwin.go @@ -25,11 +25,11 @@ import ( mount "k8s.io/mount-utils" ) -func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, options []string, sensitiveMountOptions []string) error { +func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, options []string, sensitiveMountOptions []string, volumeID string) error { return m.MountSensitive(source, target, fsType, options, sensitiveMountOptions) } -func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool) error { +func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool, volumeID string) error { return mount.CleanupMountPoint(target, m, extensiveMountCheck) } diff --git a/pkg/smb/smb_common_linux.go b/pkg/smb/smb_common_linux.go index 086b71de97d..c163ae6a4d2 100644 --- a/pkg/smb/smb_common_linux.go +++ b/pkg/smb/smb_common_linux.go @@ -25,11 +25,11 @@ import ( mount "k8s.io/mount-utils" ) -func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, options, sensitiveMountOptions []string) error { +func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, options, sensitiveMountOptions []string, volumeID string) error { return m.MountSensitive(source, target, fsType, options, sensitiveMountOptions) } -func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool) error { +func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool, volumeID string) error { return mount.CleanupMountPoint(target, m, extensiveMountCheck) } diff --git a/pkg/smb/smb_common_windows.go b/pkg/smb/smb_common_windows.go index 7742d28e55c..61a86eeeff0 100644 --- a/pkg/smb/smb_common_windows.go +++ b/pkg/smb/smb_common_windows.go @@ -28,9 +28,9 @@ import ( mount "k8s.io/mount-utils" ) -func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, mountOptions, sensitiveMountOptions []string) error { +func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, mountOptions, sensitiveMountOptions []string, volumeID string) error { if proxy, ok := m.Interface.(mounter.CSIProxyMounter); ok { - return proxy.SMBMount(source, target, fsType, mountOptions, sensitiveMountOptions) + return proxy.SMBMount(source, target, fsType, mountOptions, sensitiveMountOptions, volumeID) } return fmt.Errorf("could not cast to csi proxy class") } @@ -38,9 +38,9 @@ func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, mountOpti // CleanupSMBMountPoint - In windows CSI proxy call to umount is used to unmount the SMB. // The clean up mount point point calls is supposed for fix the corrupted directories as well. // For alpha CSI proxy integration, we only do an unmount. -func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool) error { +func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool, volumeID string) error { if proxy, ok := m.Interface.(mounter.CSIProxyMounter); ok { - return proxy.SMBUnmount(target) + return proxy.SMBUnmount(target, volumeID) } return fmt.Errorf("could not cast to csi proxy class") }