diff --git a/pkg/mounter/refcounter_windows.go b/pkg/mounter/refcounter_windows.go new file mode 100644 index 00000000000..257d6e52d73 --- /dev/null +++ b/pkg/mounter/refcounter_windows.go @@ -0,0 +1,122 @@ +//go:build windows +// +build windows + +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mounter + +import ( + "crypto/md5" + "fmt" + "os" + "path/filepath" + "strings" + "sync" +) + +var basePath = "c:\\csi\\smbmounts" +var mutexes sync.Map + +func lock(key string) func() { + value, _ := mutexes.LoadOrStore(key, &sync.Mutex{}) + mtx := value.(*sync.Mutex) + mtx.Lock() + + return func() { mtx.Unlock() } +} + +// getRootMappingPath - returns root of smb share path or empty string if the path is invalid. For example: +// +// \\hostname\share\subpath => \\hostname\share, error is nil +// \\hostname\share => \\hostname\share, error is nil +// \\hostname => '', error is 'remote path (\\hostname) is invalid' +func getRootMappingPath(path string) (string, error) { + items := strings.Split(path, "\\") + parts := []string{} + for _, s := range items { + if len(s) > 0 { + parts = append(parts, s) + if len(parts) == 2 { + break + } + } + } + if len(parts) != 2 { + return "", fmt.Errorf("remote path (%s) is invalid", path) + } + // parts[0] is a smb host name + // parts[1] is a smb share name + return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil +} + +// incementRemotePathReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist. +// How it works: +// 1. MappingPath contains two components: hostname, sharename +// 2. We create directory in basePath related to each mappingPath. It will be used as container for references. +// Example: c:\\csi\\smbmounts\\hostname\\sharename +// 3. Each reference is a file with name based on MD5 of remotePath. For debug it also will contains remotePath in body of the file. +// So, in incementRemotePathReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file. +// Example: c:\\csi\\smbmounts\\hostname\\sharename\\092f1413e6c1d03af8b5da6f44619af8 +func incementRemotePathReferencesCount(mappingPath, remotePath string) error { + remotePath = strings.TrimSuffix(remotePath, "\\") + path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\")) + if err := os.MkdirAll(path, os.ModeDir); err != nil { + return err + } + filePath := filepath.Join(path, getMd5(remotePath)) + file, err := os.Create(filePath) + if err != nil { + return err + } + defer func() { + file.Close() + }() + + _, err = file.WriteString(remotePath) + return err +} + +// decrementRemotePathReferencesCount - removes reference between mappingPath and remotePath. +// See incementRemotePathReferencesCount to understand how references work. +func decrementRemotePathReferencesCount(mappingPath, remotePath string) error { + remotePath = strings.TrimSuffix(remotePath, "\\") + path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\")) + if err := os.MkdirAll(path, os.ModeDir); err != nil { + return err + } + filePath := filepath.Join(path, getMd5(remotePath)) + return os.Remove(filePath) +} + +// getRemotePathReferencesCount - returns count of references between mappingPath and remotePath. +// See incementRemotePathReferencesCount to understand how references work. +func getRemotePathReferencesCount(mappingPath string) int { + path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\")) + if os.MkdirAll(path, os.ModeDir) != nil { + return -1 + } + files, err := os.ReadDir(path) + if err != nil { + return -1 + } + return len(files) +} + +func getMd5(path string) string { + data := []byte(strings.ToLower(path)) + return fmt.Sprintf("%x", md5.Sum(data)) +} diff --git a/pkg/mounter/refcounter_windows_test.go b/pkg/mounter/refcounter_windows_test.go new file mode 100644 index 00000000000..f1031702087 --- /dev/null +++ b/pkg/mounter/refcounter_windows_test.go @@ -0,0 +1,227 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mounter + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLockUnlock(t *testing.T) { + key := "resource name" + + unlock := lock(key) + defer unlock() + + _, loaded := mutexes.Load(key) + assert.True(t, loaded) +} + +func TestLockLockedResource(t *testing.T) { + locked := true + unlock := lock("a") + go func() { + time.Sleep(500 * time.Microsecond) + locked = false + unlock() + }() + + // try to lock already locked resource + unlock2 := lock("a") + defer unlock2() + if locked { + assert.Fail(t, "access to locked resource") + } +} + +func TestLockDifferentKeys(t *testing.T) { + unlocka := lock("a") + unlockb := lock("b") + unlocka() + unlockb() +} + +func TestGetRootMappingPath(t *testing.T) { + testCases := []struct { + remote string + expectResult string + expectError bool + }{ + { + remote: "", + expectResult: "", + expectError: true, + }, + { + remote: "hostname", + expectResult: "", + expectError: true, + }, + { + remote: "\\\\hostname\\path", + expectResult: "\\\\hostname\\path", + expectError: false, + }, + { + remote: "\\\\hostname\\path\\", + expectResult: "\\\\hostname\\path", + expectError: false, + }, + { + remote: "\\\\hostname\\path\\subpath", + expectResult: "\\\\hostname\\path", + expectError: false, + }, + } + for _, tc := range testCases { + result, err := getRootMappingPath(tc.remote) + if tc.expectError && err == nil { + t.Errorf("Expected error but getRootMappingPath returned a nil error") + } + if !tc.expectError { + if err != nil { + t.Errorf("Expected no errors but getRootMappingPath returned error: %v", err) + } + if tc.expectResult != result { + t.Errorf("Expected (%s) but getRootMappingPath returned (%s)", tc.expectResult, result) + } + } + } +} + +func TestRemotePathReferencesCounter(t *testing.T) { + remotePath1 := "\\\\servername\\share\\subpath\\1" + remotePath2 := "\\\\servername\\share\\subpath\\2" + mappingPath, err := getRootMappingPath(remotePath1) + assert.Nil(t, err) + + basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" + os.RemoveAll(basePath) + defer func() { + // cleanup temp folder + os.RemoveAll(basePath) + }() + + // by default we have no any files in `mappingPath`. So, `count` should be zero + assert.Zero(t, getRemotePathReferencesCount(mappingPath)) + // add reference to `remotePath1`. So, `count` should be equal `1` + assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath1)) + assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath)) + // add reference to `remotePath2`. So, `count` should be equal `2` + assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath2)) + assert.Equal(t, 2, getRemotePathReferencesCount(mappingPath)) + // remove reference to `remotePath1`. So, `count` should be equal `1` + assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath1)) + assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath)) + // remove reference to `remotePath2`. So, `count` should be equal `0` + assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath2)) + assert.Zero(t, getRemotePathReferencesCount(mappingPath)) +} + +func TestIncementRemotePathReferencesCount(t *testing.T) { + remotePath := "\\\\servername\\share\\subpath" + mappingPath, err := getRootMappingPath(remotePath) + assert.Nil(t, err) + + basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" + os.RemoveAll(basePath) + defer func() { + // cleanup temp folder + os.RemoveAll(basePath) + }() + + assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) + + mappingPathContainer := basePath + "\\servername\\share" + if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() { + t.Error("mapping file container does not exist") + } + + reference := mappingPathContainer + "\\" + getMd5(remotePath) + if file, err := os.Stat(reference); os.IsNotExist(err) || file.IsDir() { + t.Error("reference file does not exist") + } +} + +func TestDecrementRemotePathReferencesCount(t *testing.T) { + remotePath := "\\\\servername\\share\\subpath" + mappingPath, err := getRootMappingPath(remotePath) + assert.Nil(t, err) + + basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" + os.RemoveAll(basePath) + defer func() { + // cleanup temp folder + os.RemoveAll(basePath) + }() + + assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) + assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath)) + + mappingPathContainer := basePath + "\\servername\\share" + if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() { + t.Error("mapping file container does not exist") + } + + reference := mappingPathContainer + "\\" + getMd5(remotePath) + if _, err := os.Stat(reference); os.IsExist(err) { + t.Error("reference file exists") + } +} + +func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) { + remotePath := "\\\\servername\\share\\subpath" + mappingPath, err := getRootMappingPath(remotePath) + assert.Nil(t, err) + + basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" + os.RemoveAll(basePath) + defer func() { + // cleanup temp folder + os.RemoveAll(basePath) + }() + + assert.Zero(t, getRemotePathReferencesCount(mappingPath)) + assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) + // next calls of `incementMappingPathCount` with the same arguments should be ignored + assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) + assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) + assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) + assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) + assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath)) +} + +func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) { + remotePath := "\\\\servername\\share\\subpath" + mappingPath, err := getRootMappingPath(remotePath) + assert.Nil(t, err) + + basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" + os.RemoveAll(basePath) + defer func() { + // cleanup temp folder + os.RemoveAll(basePath) + }() + + assert.Zero(t, getRemotePathReferencesCount(mappingPath)) + assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) + assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath)) + assert.NotNil(t, decrementRemotePathReferencesCount(mappingPath, remotePath)) +} diff --git a/pkg/mounter/safe_mounter_windows.go b/pkg/mounter/safe_mounter_windows.go index 87e58855e2c..e67bb0a3dc4 100644 --- a/pkg/mounter/safe_mounter_windows.go +++ b/pkg/mounter/safe_mounter_windows.go @@ -101,6 +101,17 @@ func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOpt } source = strings.Replace(source, "/", "\\", -1) + if strings.HasSuffix(source, "\\") { + source = strings.TrimSuffix(source, "\\") + } + + mappingPath, err := getRootMappingPath(source) + if err != nil { + return fmt.Errorf("getRootMappingPath(%s) failed with error: %v", source, err) + } + unlock := lock(mappingPath) + defer unlock() + normalizedTarget := normalizeWindowsPath(target) smbMountRequest := &smb.NewSmbGlobalMappingRequest{ LocalPath: normalizedTarget, @@ -113,13 +124,53 @@ func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOpt return fmt.Errorf("smb mapping failed with error: %v", err) } klog.V(2).Infof("mount %s on %s successfully", source, normalizedTarget) + + if err = incementRemotePathReferencesCount(mappingPath, source); err != nil { + klog.Warningf("incementMappingPathCount(%s, %s) failed with error: %v", mappingPath, source, err) + } + return nil } func (mounter *csiProxyMounter) SMBUnmount(target string) error { klog.V(4).Infof("SMBUnmount: local path: %s", target) - // TODO: We need to remove the SMB mapping. The change to remove the - // directory brings the CSI code in parity with the in-tree. + + if remotePath, err := os.Readlink(target); err != nil { + klog.Warningf("SMBUnmount: can't get remote path: %v", err) + } else { + if strings.HasSuffix(remotePath, "\\") { + remotePath = strings.TrimSuffix(remotePath, "\\") + } + mappingPath, err := getRootMappingPath(remotePath) + if err != nil { + klog.Warningf("getRootMappingPath(%s) failed with error: %v", remotePath, err) + } else { + klog.V(4).Infof("SMBUnmount: remote path: %s, mapping path: %s", remotePath, mappingPath) + + unlock := lock(mappingPath) + defer unlock() + + if err := decrementRemotePathReferencesCount(mappingPath, remotePath); err != nil { + klog.Warningf("decrementMappingPathCount(%s, %d) failed with error: %v", mappingPath, remotePath, err) + } else { + count := getRemotePathReferencesCount(mappingPath) + if count == 0 { + smbUnmountRequest := &smb.RemoveSmbGlobalMappingRequest{ + RemotePath: remotePath, + } + klog.V(2).Infof("begin to unmount %s on %s", remotePath, target) + if _, err := mounter.SMBClient.RemoveSmbGlobalMapping(context.Background(), smbUnmountRequest); err != nil { + return fmt.Errorf("smb unmapping failed with error: %v", err) + } else { + klog.V(2).Infof("unmount %s on %s successfully", remotePath, target) + } + } else { + klog.Infof("SMBUnmount: found %f links to %s", count, mappingPath) + } + } + } + } + return mounter.Rmdir(target) } diff --git a/pkg/smb/nodeserver.go b/pkg/smb/nodeserver.go index 64466d432fb..8e92b3a2e44 100644 --- a/pkg/smb/nodeserver.go +++ b/pkg/smb/nodeserver.go @@ -98,7 +98,7 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish } klog.V(2).Infof("NodeUnpublishVolume: unmounting volume %s on %s", volumeID, targetPath) - err := CleanupSMBMountPoint(d.mounter, targetPath, true /*extensiveMountPointCheck*/) + err := CleanupMountPoint(d.mounter, targetPath, true /*extensiveMountPointCheck*/) if err != nil { return nil, status.Errorf(codes.Internal, "failed to unmount target %q: %v", targetPath, err) } diff --git a/pkg/smb/nodeserver_test.go b/pkg/smb/nodeserver_test.go index d177c63f104..6dd605923c5 100644 --- a/pkg/smb/nodeserver_test.go +++ b/pkg/smb/nodeserver_test.go @@ -60,11 +60,12 @@ func TestNodeStageVolume(t *testing.T) { smbFile := testutil.GetWorkDirPath("smb.go", t) sourceTest := testutil.GetWorkDirPath("source_test", t) + testSource := "\\\\hostname\\share\\test" volContext := map[string]string{ - sourceField: "test_source", + sourceField: testSource, } volContextWithMetadata := map[string]string{ - sourceField: "test_source", + sourceField: testSource, pvcNameKey: "pvcname", pvcNamespaceKey: "pvcnamespace", pvNameKey: "pvname", @@ -152,14 +153,14 @@ func TestNodeStageVolume(t *testing.T) { VolumeCapability: &stdVolCap, VolumeContext: volContext, Secrets: secrets}, - flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"test_source\" on %#v failed "+ + flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed "+ "with smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.", - errorMountSensSource), + strings.Replace(testSource, "\\", "\\\\", -1), errorMountSensSource), expectedErr: testutil.TestError{ DefaultError: status.Errorf(codes.Internal, - fmt.Sprintf("volume(vol_1##) mount \"test_source\" on \"%s\" failed with fake "+ + fmt.Sprintf("volume(vol_1##) mount \"%s\" on \"%s\" failed with fake "+ "MountSensitive: target error", - errorMountSensSource)), + strings.Replace(testSource, "\\", "\\\\", -1), errorMountSensSource)), }, }, { @@ -168,9 +169,9 @@ func TestNodeStageVolume(t *testing.T) { VolumeCapability: &stdVolCap, VolumeContext: volContext, Secrets: secrets}, - flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"test_source\" on %#v failed with "+ + flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed with "+ "smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.", - sourceTest), + strings.Replace(testSource, "\\", "\\\\", -1), sourceTest), expectedErr: testutil.TestError{}, }, { @@ -179,9 +180,9 @@ func TestNodeStageVolume(t *testing.T) { VolumeCapability: &stdVolCap, VolumeContext: volContextWithMetadata, Secrets: secrets}, - flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"test_source\" on %#v failed with "+ + flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed with "+ "smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.", - sourceTest), + strings.Replace(testSource, "\\", "\\\\", -1), sourceTest), expectedErr: testutil.TestError{}, }, }