Skip to content

Commit

Permalink
Calls RemoveSmbGlobalMapping when it necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
vitaliy-leschenko committed Aug 9, 2022
1 parent 013d044 commit 5c67dfa
Show file tree
Hide file tree
Showing 5 changed files with 420 additions and 17 deletions.
122 changes: 122 additions & 0 deletions pkg/mounter/refcounter_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
//go:build windows
// +build windows

/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mounter

import (
"crypto/md5"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
)

var basePath = "c:\\csi\\smbmounts"
var mutexes sync.Map

func lock(key string) func() {
value, _ := mutexes.LoadOrStore(key, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()

return func() { mtx.Unlock() }
}

// getRootMappingPath - returns root of smb share path or empty string if the path is invalid. For example:
//
// \\hostname\share\subpath => \\hostname\share, error is nil
// \\hostname\share => \\hostname\share, error is nil
// \\hostname => '', error is 'remote path (\\hostname) is invalid'
func getRootMappingPath(path string) (string, error) {
items := strings.Split(path, "\\")
parts := []string{}
for _, s := range items {
if len(s) > 0 {
parts = append(parts, s)
if len(parts) == 2 {
break
}
}
}
if len(parts) != 2 {
return "", fmt.Errorf("remote path (%s) is invalid", path)
}
// parts[0] is a smb host name
// parts[1] is a smb share name
return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil
}

// incementRemotePathReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist.
// How it works:
// 1. MappingPath contains two components: hostname, sharename
// 2. We create directory in basePath related to each mappingPath. It will be used as container for references.
// Example: c:\\csi\\smbmounts\\hostname\\sharename
// 3. Each reference is a file with name based on MD5 of remotePath. For debug it also will contains remotePath in body of the file.
// So, in incementRemotePathReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file.
// Example: c:\\csi\\smbmounts\\hostname\\sharename\\092f1413e6c1d03af8b5da6f44619af8
func incementRemotePathReferencesCount(mappingPath, remotePath string) error {
remotePath = strings.TrimSuffix(remotePath, "\\")
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if err := os.MkdirAll(path, os.ModeDir); err != nil {
return err
}
filePath := filepath.Join(path, getMd5(remotePath))
file, err := os.Create(filePath)
if err != nil {
return err
}
defer func() {
file.Close()
}()

_, err = file.WriteString(remotePath)
return err
}

// decrementRemotePathReferencesCount - removes reference between mappingPath and remotePath.
// See incementRemotePathReferencesCount to understand how references work.
func decrementRemotePathReferencesCount(mappingPath, remotePath string) error {
remotePath = strings.TrimSuffix(remotePath, "\\")
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if err := os.MkdirAll(path, os.ModeDir); err != nil {
return err
}
filePath := filepath.Join(path, getMd5(remotePath))
return os.Remove(filePath)
}

// getRemotePathReferencesCount - returns count of references between mappingPath and remotePath.
// See incementRemotePathReferencesCount to understand how references work.
func getRemotePathReferencesCount(mappingPath string) int {
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if os.MkdirAll(path, os.ModeDir) != nil {
return -1
}
files, err := os.ReadDir(path)
if err != nil {
return -1
}
return len(files)
}

func getMd5(path string) string {
data := []byte(strings.ToLower(path))
return fmt.Sprintf("%x", md5.Sum(data))
}
227 changes: 227 additions & 0 deletions pkg/mounter/refcounter_windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mounter

import (
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestLockUnlock(t *testing.T) {
key := "resource name"

unlock := lock(key)
defer unlock()

_, loaded := mutexes.Load(key)
assert.True(t, loaded)
}

func TestLockLockedResource(t *testing.T) {
locked := true
unlock := lock("a")
go func() {
time.Sleep(500 * time.Microsecond)
locked = false
unlock()
}()

// try to lock already locked resource
unlock2 := lock("a")
defer unlock2()
if locked {
assert.Fail(t, "access to locked resource")
}
}

func TestLockDifferentKeys(t *testing.T) {
unlocka := lock("a")
unlockb := lock("b")
unlocka()
unlockb()
}

func TestGetRootMappingPath(t *testing.T) {
testCases := []struct {
remote string
expectResult string
expectError bool
}{
{
remote: "",
expectResult: "",
expectError: true,
},
{
remote: "hostname",
expectResult: "",
expectError: true,
},
{
remote: "\\\\hostname\\path",
expectResult: "\\\\hostname\\path",
expectError: false,
},
{
remote: "\\\\hostname\\path\\",
expectResult: "\\\\hostname\\path",
expectError: false,
},
{
remote: "\\\\hostname\\path\\subpath",
expectResult: "\\\\hostname\\path",
expectError: false,
},
}
for _, tc := range testCases {
result, err := getRootMappingPath(tc.remote)
if tc.expectError && err == nil {
t.Errorf("Expected error but getRootMappingPath returned a nil error")
}
if !tc.expectError {
if err != nil {
t.Errorf("Expected no errors but getRootMappingPath returned error: %v", err)
}
if tc.expectResult != result {
t.Errorf("Expected (%s) but getRootMappingPath returned (%s)", tc.expectResult, result)
}
}
}
}

func TestRemotePathReferencesCounter(t *testing.T) {
remotePath1 := "\\\\servername\\share\\subpath\\1"
remotePath2 := "\\\\servername\\share\\subpath\\2"
mappingPath, err := getRootMappingPath(remotePath1)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

// by default we have no any files in `mappingPath`. So, `count` should be zero
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
// add reference to `remotePath1`. So, `count` should be equal `1`
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath1))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
// add reference to `remotePath2`. So, `count` should be equal `2`
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath2))
assert.Equal(t, 2, getRemotePathReferencesCount(mappingPath))
// remove reference to `remotePath1`. So, `count` should be equal `1`
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath1))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
// remove reference to `remotePath2`. So, `count` should be equal `0`
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath2))
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
}

func TestIncementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))

mappingPathContainer := basePath + "\\servername\\share"
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
t.Error("mapping file container does not exist")
}

reference := mappingPathContainer + "\\" + getMd5(remotePath)
if file, err := os.Stat(reference); os.IsNotExist(err) || file.IsDir() {
t.Error("reference file does not exist")
}
}

func TestDecrementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))

mappingPathContainer := basePath + "\\servername\\share"
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
t.Error("mapping file container does not exist")
}

reference := mappingPathContainer + "\\" + getMd5(remotePath)
if _, err := os.Stat(reference); os.IsExist(err) {
t.Error("reference file exists")
}
}

func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

assert.Zero(t, getRemotePathReferencesCount(mappingPath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
// next calls of `incementMappingPathCount` with the same arguments should be ignored
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
}

func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)

basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()

assert.Zero(t, getRemotePathReferencesCount(mappingPath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
assert.NotNil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
}
Loading

0 comments on commit 5c67dfa

Please sign in to comment.