-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Calls RemoveSmbGlobalMapping when it necessary
- Loading branch information
1 parent
013d044
commit 5c67dfa
Showing
5 changed files
with
420 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
//go:build windows | ||
// +build windows | ||
|
||
/* | ||
Copyright 2020 The Kubernetes Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package mounter | ||
|
||
import ( | ||
"crypto/md5" | ||
"fmt" | ||
"os" | ||
"path/filepath" | ||
"strings" | ||
"sync" | ||
) | ||
|
||
var basePath = "c:\\csi\\smbmounts" | ||
var mutexes sync.Map | ||
|
||
func lock(key string) func() { | ||
value, _ := mutexes.LoadOrStore(key, &sync.Mutex{}) | ||
mtx := value.(*sync.Mutex) | ||
mtx.Lock() | ||
|
||
return func() { mtx.Unlock() } | ||
} | ||
|
||
// getRootMappingPath - returns root of smb share path or empty string if the path is invalid. For example: | ||
// | ||
// \\hostname\share\subpath => \\hostname\share, error is nil | ||
// \\hostname\share => \\hostname\share, error is nil | ||
// \\hostname => '', error is 'remote path (\\hostname) is invalid' | ||
func getRootMappingPath(path string) (string, error) { | ||
items := strings.Split(path, "\\") | ||
parts := []string{} | ||
for _, s := range items { | ||
if len(s) > 0 { | ||
parts = append(parts, s) | ||
if len(parts) == 2 { | ||
break | ||
} | ||
} | ||
} | ||
if len(parts) != 2 { | ||
return "", fmt.Errorf("remote path (%s) is invalid", path) | ||
} | ||
// parts[0] is a smb host name | ||
// parts[1] is a smb share name | ||
return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil | ||
} | ||
|
||
// incementRemotePathReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist. | ||
// How it works: | ||
// 1. MappingPath contains two components: hostname, sharename | ||
// 2. We create directory in basePath related to each mappingPath. It will be used as container for references. | ||
// Example: c:\\csi\\smbmounts\\hostname\\sharename | ||
// 3. Each reference is a file with name based on MD5 of remotePath. For debug it also will contains remotePath in body of the file. | ||
// So, in incementRemotePathReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file. | ||
// Example: c:\\csi\\smbmounts\\hostname\\sharename\\092f1413e6c1d03af8b5da6f44619af8 | ||
func incementRemotePathReferencesCount(mappingPath, remotePath string) error { | ||
remotePath = strings.TrimSuffix(remotePath, "\\") | ||
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\")) | ||
if err := os.MkdirAll(path, os.ModeDir); err != nil { | ||
return err | ||
} | ||
filePath := filepath.Join(path, getMd5(remotePath)) | ||
file, err := os.Create(filePath) | ||
if err != nil { | ||
return err | ||
} | ||
defer func() { | ||
file.Close() | ||
}() | ||
|
||
_, err = file.WriteString(remotePath) | ||
return err | ||
} | ||
|
||
// decrementRemotePathReferencesCount - removes reference between mappingPath and remotePath. | ||
// See incementRemotePathReferencesCount to understand how references work. | ||
func decrementRemotePathReferencesCount(mappingPath, remotePath string) error { | ||
remotePath = strings.TrimSuffix(remotePath, "\\") | ||
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\")) | ||
if err := os.MkdirAll(path, os.ModeDir); err != nil { | ||
return err | ||
} | ||
filePath := filepath.Join(path, getMd5(remotePath)) | ||
return os.Remove(filePath) | ||
} | ||
|
||
// getRemotePathReferencesCount - returns count of references between mappingPath and remotePath. | ||
// See incementRemotePathReferencesCount to understand how references work. | ||
func getRemotePathReferencesCount(mappingPath string) int { | ||
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\")) | ||
if os.MkdirAll(path, os.ModeDir) != nil { | ||
return -1 | ||
} | ||
files, err := os.ReadDir(path) | ||
if err != nil { | ||
return -1 | ||
} | ||
return len(files) | ||
} | ||
|
||
func getMd5(path string) string { | ||
data := []byte(strings.ToLower(path)) | ||
return fmt.Sprintf("%x", md5.Sum(data)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
/* | ||
Copyright 2020 The Kubernetes Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package mounter | ||
|
||
import ( | ||
"os" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestLockUnlock(t *testing.T) { | ||
key := "resource name" | ||
|
||
unlock := lock(key) | ||
defer unlock() | ||
|
||
_, loaded := mutexes.Load(key) | ||
assert.True(t, loaded) | ||
} | ||
|
||
func TestLockLockedResource(t *testing.T) { | ||
locked := true | ||
unlock := lock("a") | ||
go func() { | ||
time.Sleep(500 * time.Microsecond) | ||
locked = false | ||
unlock() | ||
}() | ||
|
||
// try to lock already locked resource | ||
unlock2 := lock("a") | ||
defer unlock2() | ||
if locked { | ||
assert.Fail(t, "access to locked resource") | ||
} | ||
} | ||
|
||
func TestLockDifferentKeys(t *testing.T) { | ||
unlocka := lock("a") | ||
unlockb := lock("b") | ||
unlocka() | ||
unlockb() | ||
} | ||
|
||
func TestGetRootMappingPath(t *testing.T) { | ||
testCases := []struct { | ||
remote string | ||
expectResult string | ||
expectError bool | ||
}{ | ||
{ | ||
remote: "", | ||
expectResult: "", | ||
expectError: true, | ||
}, | ||
{ | ||
remote: "hostname", | ||
expectResult: "", | ||
expectError: true, | ||
}, | ||
{ | ||
remote: "\\\\hostname\\path", | ||
expectResult: "\\\\hostname\\path", | ||
expectError: false, | ||
}, | ||
{ | ||
remote: "\\\\hostname\\path\\", | ||
expectResult: "\\\\hostname\\path", | ||
expectError: false, | ||
}, | ||
{ | ||
remote: "\\\\hostname\\path\\subpath", | ||
expectResult: "\\\\hostname\\path", | ||
expectError: false, | ||
}, | ||
} | ||
for _, tc := range testCases { | ||
result, err := getRootMappingPath(tc.remote) | ||
if tc.expectError && err == nil { | ||
t.Errorf("Expected error but getRootMappingPath returned a nil error") | ||
} | ||
if !tc.expectError { | ||
if err != nil { | ||
t.Errorf("Expected no errors but getRootMappingPath returned error: %v", err) | ||
} | ||
if tc.expectResult != result { | ||
t.Errorf("Expected (%s) but getRootMappingPath returned (%s)", tc.expectResult, result) | ||
} | ||
} | ||
} | ||
} | ||
|
||
func TestRemotePathReferencesCounter(t *testing.T) { | ||
remotePath1 := "\\\\servername\\share\\subpath\\1" | ||
remotePath2 := "\\\\servername\\share\\subpath\\2" | ||
mappingPath, err := getRootMappingPath(remotePath1) | ||
assert.Nil(t, err) | ||
|
||
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" | ||
os.RemoveAll(basePath) | ||
defer func() { | ||
// cleanup temp folder | ||
os.RemoveAll(basePath) | ||
}() | ||
|
||
// by default we have no any files in `mappingPath`. So, `count` should be zero | ||
assert.Zero(t, getRemotePathReferencesCount(mappingPath)) | ||
// add reference to `remotePath1`. So, `count` should be equal `1` | ||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath1)) | ||
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath)) | ||
// add reference to `remotePath2`. So, `count` should be equal `2` | ||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath2)) | ||
assert.Equal(t, 2, getRemotePathReferencesCount(mappingPath)) | ||
// remove reference to `remotePath1`. So, `count` should be equal `1` | ||
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath1)) | ||
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath)) | ||
// remove reference to `remotePath2`. So, `count` should be equal `0` | ||
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath2)) | ||
assert.Zero(t, getRemotePathReferencesCount(mappingPath)) | ||
} | ||
|
||
func TestIncementRemotePathReferencesCount(t *testing.T) { | ||
remotePath := "\\\\servername\\share\\subpath" | ||
mappingPath, err := getRootMappingPath(remotePath) | ||
assert.Nil(t, err) | ||
|
||
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" | ||
os.RemoveAll(basePath) | ||
defer func() { | ||
// cleanup temp folder | ||
os.RemoveAll(basePath) | ||
}() | ||
|
||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) | ||
|
||
mappingPathContainer := basePath + "\\servername\\share" | ||
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() { | ||
t.Error("mapping file container does not exist") | ||
} | ||
|
||
reference := mappingPathContainer + "\\" + getMd5(remotePath) | ||
if file, err := os.Stat(reference); os.IsNotExist(err) || file.IsDir() { | ||
t.Error("reference file does not exist") | ||
} | ||
} | ||
|
||
func TestDecrementRemotePathReferencesCount(t *testing.T) { | ||
remotePath := "\\\\servername\\share\\subpath" | ||
mappingPath, err := getRootMappingPath(remotePath) | ||
assert.Nil(t, err) | ||
|
||
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" | ||
os.RemoveAll(basePath) | ||
defer func() { | ||
// cleanup temp folder | ||
os.RemoveAll(basePath) | ||
}() | ||
|
||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) | ||
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath)) | ||
|
||
mappingPathContainer := basePath + "\\servername\\share" | ||
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() { | ||
t.Error("mapping file container does not exist") | ||
} | ||
|
||
reference := mappingPathContainer + "\\" + getMd5(remotePath) | ||
if _, err := os.Stat(reference); os.IsExist(err) { | ||
t.Error("reference file exists") | ||
} | ||
} | ||
|
||
func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) { | ||
remotePath := "\\\\servername\\share\\subpath" | ||
mappingPath, err := getRootMappingPath(remotePath) | ||
assert.Nil(t, err) | ||
|
||
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" | ||
os.RemoveAll(basePath) | ||
defer func() { | ||
// cleanup temp folder | ||
os.RemoveAll(basePath) | ||
}() | ||
|
||
assert.Zero(t, getRemotePathReferencesCount(mappingPath)) | ||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) | ||
// next calls of `incementMappingPathCount` with the same arguments should be ignored | ||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) | ||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) | ||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) | ||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) | ||
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath)) | ||
} | ||
|
||
func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) { | ||
remotePath := "\\\\servername\\share\\subpath" | ||
mappingPath, err := getRootMappingPath(remotePath) | ||
assert.Nil(t, err) | ||
|
||
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter" | ||
os.RemoveAll(basePath) | ||
defer func() { | ||
// cleanup temp folder | ||
os.RemoveAll(basePath) | ||
}() | ||
|
||
assert.Zero(t, getRemotePathReferencesCount(mappingPath)) | ||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath)) | ||
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath)) | ||
assert.NotNil(t, decrementRemotePathReferencesCount(mappingPath, remotePath)) | ||
} |
Oops, something went wrong.