Skip to content

Commit

Permalink
Use volumeID for refs counting
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliy-leschenko committed Oct 18, 2023
1 parent 5d78fe3 commit d754461
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 82 deletions.
25 changes: 12 additions & 13 deletions pkg/mounter/refcounter_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,21 @@ func getRootMappingPath(path string) (string, error) {
return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil
}

// incementRemotePathReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist.
// incementVolumeIDReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist.
// How it works:
// 1. MappingPath contains two components: hostname, sharename
// 2. We create directory in basePath related to each mappingPath. It will be used as container for references.
// Example: c:\\csi\\smbmounts\\hostname\\sharename
// 3. Each reference is a file with name based on MD5 of remotePath. For debug it also will contains remotePath in body of the file.
// So, in incementRemotePathReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file.
// 3. Each reference is a file with name based on MD5 of volumeID. For debug it also will contains remotePath in body of the file.
// So, in incementVolumeIDReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file.
// Example: c:\\csi\\smbmounts\\hostname\\sharename\\092f1413e6c1d03af8b5da6f44619af8
func incementRemotePathReferencesCount(mappingPath, remotePath string) error {
func incementVolumeIDReferencesCount(mappingPath, remotePath string, volumeID string) error {
remotePath = strings.TrimSuffix(remotePath, "\\")
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if err := os.MkdirAll(path, os.ModeDir); err != nil {
return err
}
filePath := filepath.Join(path, getMd5(remotePath))
filePath := filepath.Join(path, getMd5(volumeID))
file, err := os.Create(filePath)
if err != nil {
return err
Expand All @@ -90,21 +90,20 @@ func incementRemotePathReferencesCount(mappingPath, remotePath string) error {
return err
}

// decrementRemotePathReferencesCount - removes reference between mappingPath and remotePath.
// See incementRemotePathReferencesCount to understand how references work.
func decrementRemotePathReferencesCount(mappingPath, remotePath string) error {
remotePath = strings.TrimSuffix(remotePath, "\\")
// decrementVolumeIDReferencesCount - removes reference between mappingPath and remotePath.
// See incementVolumeIDReferencesCount to understand how references work.
func decrementVolumeIDReferencesCount(mappingPath, volumeID string) error {
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if err := os.MkdirAll(path, os.ModeDir); err != nil {
return err
}
filePath := filepath.Join(path, getMd5(remotePath))
filePath := filepath.Join(path, getMd5(volumeID))
return os.Remove(filePath)
}

// getRemotePathReferencesCount - returns count of references between mappingPath and remotePath.
// See incementRemotePathReferencesCount to understand how references work.
func getRemotePathReferencesCount(mappingPath string) int {
// getVolumeIDReferencesCount - returns count of references between mappingPath and remotePath.
// See incementVolumeIDReferencesCount to understand how references work.
func getVolumeIDReferencesCount(mappingPath string) int {
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if os.MkdirAll(path, os.ModeDir) != nil {
return -1
Expand Down
105 changes: 60 additions & 45 deletions pkg/mounter/refcounter_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,36 +106,51 @@ func TestGetRootMappingPath(t *testing.T) {
}
}

func TestRemotePathReferencesCounter(t *testing.T) {
remotePath1 := "\\\\servername\\share\\subpath\\1"
remotePath2 := "\\\\servername\\share\\subpath\\2"
mappingPath, err := getRootMappingPath(remotePath1)
assert.Nil(t, err)
func TestVolumeIDReferencesCounter(t *testing.T) {
testCases := []struct {
path1 string
path2 string
}{
{
path1: "\\\\servername\\share\\subpath\\1",
path2: "\\\\servername\\share\\subpath\\2",
},
{
path1: "\\\\servername\\share",
path2: "\\\\servername\\share",
},
}
for _, tc := range testCases {
remotePath1 := tc.path1
remotePath2 := tc.path2
mappingPath, err := getRootMappingPath(remotePath1)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
}()

// by default we have no any files in `mappingPath`. So, `count` should be zero
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
// add reference to `remotePath1`. So, `count` should be equal `1`
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath1))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
// add reference to `remotePath2`. So, `count` should be equal `2`
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath2))
assert.Equal(t, 2, getRemotePathReferencesCount(mappingPath))
// remove reference to `remotePath1`. So, `count` should be equal `1`
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath1))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
// remove reference to `remotePath2`. So, `count` should be equal `0`
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath2))
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

// by default we have no any files in `mappingPath`. So, `count` should be zero
assert.Zero(t, getVolumeIDReferencesCount(mappingPath))
// add reference to `remotePath1`. So, `count` should be equal `1`
assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath1, "vol1"))
assert.Equal(t, 1, getVolumeIDReferencesCount(mappingPath))
// add reference to `remotePath2`. So, `count` should be equal `2`
assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath2, "vol2"))
assert.Equal(t, 2, getVolumeIDReferencesCount(mappingPath))
// remove reference to `remotePath1`. So, `count` should be equal `1`
assert.Nil(t, decrementVolumeIDReferencesCount(mappingPath, "vol1"))
assert.Equal(t, 1, getVolumeIDReferencesCount(mappingPath))
// remove reference to `remotePath2`. So, `count` should be equal `0`
assert.Nil(t, decrementVolumeIDReferencesCount(mappingPath, "vol2"))
assert.Zero(t, getVolumeIDReferencesCount(mappingPath))
}
}

func TestIncementRemotePathReferencesCount(t *testing.T) {
func TestIncementVolumeIDReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)
Expand All @@ -147,20 +162,20 @@ func TestIncementRemotePathReferencesCount(t *testing.T) {
os.RemoveAll(basePath)
}()

assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol1"))

mappingPathContainer := basePath + "\\servername\\share"
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
t.Error("mapping file container does not exist")
}

reference := mappingPathContainer + "\\" + getMd5(remotePath)
reference := mappingPathContainer + "\\" + "vol1"
if file, err := os.Stat(reference); os.IsNotExist(err) || file.IsDir() {
t.Error("reference file does not exist")
}
}

func TestDecrementRemotePathReferencesCount(t *testing.T) {
func TestDecrementVolumeIDReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)
Expand All @@ -172,21 +187,21 @@ func TestDecrementRemotePathReferencesCount(t *testing.T) {
os.RemoveAll(basePath)
}()

assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol1"))
assert.Nil(t, decrementVolumeIDReferencesCount(mappingPath, "vol1"))

mappingPathContainer := basePath + "\\servername\\share"
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
t.Error("mapping file container does not exist")
}

reference := mappingPathContainer + "\\" + getMd5(remotePath)
reference := mappingPathContainer + "\\" + "vol1"
if _, err := os.Stat(reference); os.IsExist(err) {
t.Error("reference file exists")
}
}

func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) {
func TestMultiplyCallsOfIncementVolumeIDReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)
Expand All @@ -198,17 +213,17 @@ func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) {
os.RemoveAll(basePath)
}()

assert.Zero(t, getRemotePathReferencesCount(mappingPath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Zero(t, getVolumeIDReferencesCount(mappingPath))
assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol"))
// next calls of `incementMappingPathCount` with the same arguments should be ignored
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol"))
assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol"))
assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol"))
assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol"))
assert.Equal(t, 1, getVolumeIDReferencesCount(mappingPath))
}

func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) {
func TestMultiplyCallsOfDecrementVolumeIDReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)
Expand All @@ -220,8 +235,8 @@ func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) {
os.RemoveAll(basePath)
}()

assert.Zero(t, getRemotePathReferencesCount(mappingPath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
assert.NotNil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
assert.Zero(t, getVolumeIDReferencesCount(mappingPath))
assert.Nil(t, incementVolumeIDReferencesCount(mappingPath, remotePath, "vol"))
assert.Nil(t, decrementVolumeIDReferencesCount(mappingPath, "vol"))
assert.NotNil(t, decrementVolumeIDReferencesCount(mappingPath, "vol"))
}
4 changes: 2 additions & 2 deletions pkg/mounter/safe_mounter_v1beta_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type csiProxyMounterV1Beta struct {
SMBClient *smbclient.Client
}

func (mounter *csiProxyMounterV1Beta) SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string) error {
func (mounter *csiProxyMounterV1Beta) SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string, volumeID string) error {
klog.V(4).Infof("SMBMount: remote path: %s. local path: %s", source, target)

if len(mountOptions) == 0 || len(sensitiveMountOptions) == 0 {
Expand Down Expand Up @@ -93,7 +93,7 @@ func (mounter *csiProxyMounterV1Beta) SMBMount(source, target, fsType string, mo
return nil
}

func (mounter *csiProxyMounterV1Beta) SMBUnmount(target string) error {
func (mounter *csiProxyMounterV1Beta) SMBUnmount(target string, volumeID string) error {
klog.V(4).Infof("SMBUnmount: local path: %s", target)
// TODO: We need to remove the SMB mapping. The change to remove the
// directory brings the CSI code in parity with the in-tree.
Expand Down
24 changes: 12 additions & 12 deletions pkg/mounter/safe_mounter_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ import (
type CSIProxyMounter interface {
mount.Interface

SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string) error
SMBUnmount(target string) error
SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string, volumeID string) error
SMBUnmount(target string, volumeID string) error
MakeDir(path string) error
Rmdir(path string) error
IsMountPointMatch(mp mount.MountPoint, dir string) bool
Expand All @@ -68,7 +68,7 @@ func normalizeWindowsPath(path string) string {
return normalizedPath
}

func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string) error {
func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOptions, sensitiveMountOptions []string, volumeID string) error {
klog.V(2).Infof("SMBMount: remote path: %s local path: %s", source, target)

if len(mountOptions) == 0 || len(sensitiveMountOptions) == 0 {
Expand Down Expand Up @@ -124,14 +124,14 @@ func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOpt
klog.V(2).Infof("NewSmbGlobalMapping %s on %s successfully", source, normalizedTarget)

if mounter.RemoveSMBMappingDuringUnmount {
if err := incementRemotePathReferencesCount(mappingPath, source); err != nil {
return fmt.Errorf("incementMappingPathCount(%s, %s) failed with error: %v", mappingPath, source, err)
if err := incementVolumeIDReferencesCount(mappingPath, source, volumeID); err != nil {
return fmt.Errorf("incementRemotePathReferencesCount(%s, %s, %s) failed with error: %v", mappingPath, source, volumeID, err)
}
}
return nil
}

func (mounter *csiProxyMounter) SMBUnmount(target string) error {
func (mounter *csiProxyMounter) SMBUnmount(target string, volumeID string) error {
klog.V(4).Infof("SMBUnmount: local path: %s", target)

if remotePath, err := os.Readlink(target); err != nil {
Expand All @@ -144,14 +144,14 @@ func (mounter *csiProxyMounter) SMBUnmount(target string) error {
}
klog.V(4).Infof("SMBUnmount: remote path: %s, mapping path: %s", remotePath, mappingPath)

unlock := lock(mappingPath)
defer unlock()

if mounter.RemoveSMBMappingDuringUnmount {
if err := decrementRemotePathReferencesCount(mappingPath, remotePath); err != nil {
return fmt.Errorf("decrementMappingPathCount(%s, %s) failed with error: %v", mappingPath, remotePath, err)
unlock := lock(mappingPath)
defer unlock()

if err := decrementVolumeIDReferencesCount(mappingPath, volumeID); err != nil {
return fmt.Errorf("decrementRemotePathReferencesCount(%s, %s) failed with error: %v", mappingPath, volumeID, err)
}
count := getRemotePathReferencesCount(mappingPath)
count := getVolumeIDReferencesCount(mappingPath)
if count == 0 {
smbUnmountRequest := &smb.RemoveSmbGlobalMappingRequest{
RemotePath: remotePath,
Expand Down
4 changes: 2 additions & 2 deletions pkg/smb/nodeserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe
}
mountComplete := false
err = wait.PollImmediate(1*time.Second, 2*time.Minute, func() (bool, error) {
err := Mount(d.mounter, source, targetPath, "cifs", mountOptions, sensitiveMountOptions)
err := Mount(d.mounter, source, targetPath, "cifs", mountOptions, sensitiveMountOptions, volumeID)
mountComplete = true
return true, err
})
Expand Down Expand Up @@ -258,7 +258,7 @@ func (d *Driver) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolu
defer d.volumeLocks.Release(volumeID)

klog.V(2).Infof("NodeUnstageVolume: CleanupMountPoint on %s with volume %s", stagingTargetPath, volumeID)
if err := CleanupSMBMountPoint(d.mounter, stagingTargetPath, true /*extensiveMountPointCheck*/); err != nil {
if err := CleanupSMBMountPoint(d.mounter, stagingTargetPath, true /*extensiveMountPointCheck*/, volumeID); err != nil {
return nil, status.Errorf(codes.Internal, "failed to unmount staging target %q: %v", stagingTargetPath, err)
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/smb/smb_common_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ import (
mount "k8s.io/mount-utils"
)

func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, options []string, sensitiveMountOptions []string) error {
func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, options []string, sensitiveMountOptions []string, volumeID string) error {
return m.MountSensitive(source, target, fsType, options, sensitiveMountOptions)
}

func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool) error {
func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool, volumeID string) error {
return mount.CleanupMountPoint(target, m, extensiveMountCheck)
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/smb/smb_common_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ import (
mount "k8s.io/mount-utils"
)

func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, options, sensitiveMountOptions []string) error {
func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, options, sensitiveMountOptions []string, volumeID string) error {
return m.MountSensitive(source, target, fsType, options, sensitiveMountOptions)
}

func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool) error {
func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool, volumeID string) error {
return mount.CleanupMountPoint(target, m, extensiveMountCheck)
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/smb/smb_common_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ import (
mount "k8s.io/mount-utils"
)

func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, mountOptions, sensitiveMountOptions []string) error {
func Mount(m *mount.SafeFormatAndMount, source, target, fsType string, mountOptions, sensitiveMountOptions []string, volumeID string) error {
if proxy, ok := m.Interface.(mounter.CSIProxyMounter); ok {
return proxy.SMBMount(source, target, fsType, mountOptions, sensitiveMountOptions)
return proxy.SMBMount(source, target, fsType, mountOptions, sensitiveMountOptions, volumeID)
}
return fmt.Errorf("could not cast to csi proxy class")
}

// CleanupSMBMountPoint - In windows CSI proxy call to umount is used to unmount the SMB.
// The clean up mount point point calls is supposed for fix the corrupted directories as well.
// For alpha CSI proxy integration, we only do an unmount.
func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool) error {
func CleanupSMBMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool, volumeID string) error {
if proxy, ok := m.Interface.(mounter.CSIProxyMounter); ok {
return proxy.SMBUnmount(target)
return proxy.SMBUnmount(target, volumeID)
}
return fmt.Errorf("could not cast to csi proxy class")
}
Expand Down

0 comments on commit d754461

Please sign in to comment.