Merge pull request #505 from vitaliy-leschenko/smb-unmount
Calls RemoveSmbGlobalMapping when it necessary
This commit is contained in:
commit
eb9ddbcf81
|
|
@ -0,0 +1,122 @@
|
|||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
/*
|
||||
Copyright 2020 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package mounter
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var basePath = "c:\\csi\\smbmounts"
|
||||
var mutexes sync.Map
|
||||
|
||||
func lock(key string) func() {
|
||||
value, _ := mutexes.LoadOrStore(key, &sync.Mutex{})
|
||||
mtx := value.(*sync.Mutex)
|
||||
mtx.Lock()
|
||||
|
||||
return func() { mtx.Unlock() }
|
||||
}
|
||||
|
||||
// getRootMappingPath - returns root of smb share path or empty string if the path is invalid. For example:
|
||||
//
|
||||
// \\hostname\share\subpath => \\hostname\share, error is nil
|
||||
// \\hostname\share => \\hostname\share, error is nil
|
||||
// \\hostname => '', error is 'remote path (\\hostname) is invalid'
|
||||
func getRootMappingPath(path string) (string, error) {
|
||||
items := strings.Split(path, "\\")
|
||||
parts := []string{}
|
||||
for _, s := range items {
|
||||
if len(s) > 0 {
|
||||
parts = append(parts, s)
|
||||
if len(parts) == 2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(parts) != 2 {
|
||||
return "", fmt.Errorf("remote path (%s) is invalid", path)
|
||||
}
|
||||
// parts[0] is a smb host name
|
||||
// parts[1] is a smb share name
|
||||
return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil
|
||||
}
|
||||
|
||||
// incementRemotePathReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist.
|
||||
// How it works:
|
||||
// 1. MappingPath contains two components: hostname, sharename
|
||||
// 2. We create directory in basePath related to each mappingPath. It will be used as container for references.
|
||||
// Example: c:\\csi\\smbmounts\\hostname\\sharename
|
||||
// 3. Each reference is a file with name based on MD5 of remotePath. For debug it also will contains remotePath in body of the file.
|
||||
// So, in incementRemotePathReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file.
|
||||
// Example: c:\\csi\\smbmounts\\hostname\\sharename\\092f1413e6c1d03af8b5da6f44619af8
|
||||
func incementRemotePathReferencesCount(mappingPath, remotePath string) error {
|
||||
remotePath = strings.TrimSuffix(remotePath, "\\")
|
||||
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
|
||||
if err := os.MkdirAll(path, os.ModeDir); err != nil {
|
||||
return err
|
||||
}
|
||||
filePath := filepath.Join(path, getMd5(remotePath))
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
file.Close()
|
||||
}()
|
||||
|
||||
_, err = file.WriteString(remotePath)
|
||||
return err
|
||||
}
|
||||
|
||||
// decrementRemotePathReferencesCount - removes reference between mappingPath and remotePath.
|
||||
// See incementRemotePathReferencesCount to understand how references work.
|
||||
func decrementRemotePathReferencesCount(mappingPath, remotePath string) error {
|
||||
remotePath = strings.TrimSuffix(remotePath, "\\")
|
||||
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
|
||||
if err := os.MkdirAll(path, os.ModeDir); err != nil {
|
||||
return err
|
||||
}
|
||||
filePath := filepath.Join(path, getMd5(remotePath))
|
||||
return os.Remove(filePath)
|
||||
}
|
||||
|
||||
// getRemotePathReferencesCount - returns count of references between mappingPath and remotePath.
|
||||
// See incementRemotePathReferencesCount to understand how references work.
|
||||
func getRemotePathReferencesCount(mappingPath string) int {
|
||||
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
|
||||
if os.MkdirAll(path, os.ModeDir) != nil {
|
||||
return -1
|
||||
}
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
return len(files)
|
||||
}
|
||||
|
||||
func getMd5(path string) string {
|
||||
data := []byte(strings.ToLower(path))
|
||||
return fmt.Sprintf("%x", md5.Sum(data))
|
||||
}
|
||||
|
|
@ -0,0 +1,227 @@
|
|||
/*
|
||||
Copyright 2020 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package mounter
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLockUnlock(t *testing.T) {
|
||||
key := "resource name"
|
||||
|
||||
unlock := lock(key)
|
||||
defer unlock()
|
||||
|
||||
_, loaded := mutexes.Load(key)
|
||||
assert.True(t, loaded)
|
||||
}
|
||||
|
||||
func TestLockLockedResource(t *testing.T) {
|
||||
locked := true
|
||||
unlock := lock("a")
|
||||
go func() {
|
||||
time.Sleep(500 * time.Microsecond)
|
||||
locked = false
|
||||
unlock()
|
||||
}()
|
||||
|
||||
// try to lock already locked resource
|
||||
unlock2 := lock("a")
|
||||
defer unlock2()
|
||||
if locked {
|
||||
assert.Fail(t, "access to locked resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockDifferentKeys(t *testing.T) {
|
||||
unlocka := lock("a")
|
||||
unlockb := lock("b")
|
||||
unlocka()
|
||||
unlockb()
|
||||
}
|
||||
|
||||
func TestGetRootMappingPath(t *testing.T) {
|
||||
testCases := []struct {
|
||||
remote string
|
||||
expectResult string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
remote: "",
|
||||
expectResult: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
remote: "hostname",
|
||||
expectResult: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
remote: "\\\\hostname\\path",
|
||||
expectResult: "\\\\hostname\\path",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
remote: "\\\\hostname\\path\\",
|
||||
expectResult: "\\\\hostname\\path",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
remote: "\\\\hostname\\path\\subpath",
|
||||
expectResult: "\\\\hostname\\path",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
result, err := getRootMappingPath(tc.remote)
|
||||
if tc.expectError && err == nil {
|
||||
t.Errorf("Expected error but getRootMappingPath returned a nil error")
|
||||
}
|
||||
if !tc.expectError {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors but getRootMappingPath returned error: %v", err)
|
||||
}
|
||||
if tc.expectResult != result {
|
||||
t.Errorf("Expected (%s) but getRootMappingPath returned (%s)", tc.expectResult, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemotePathReferencesCounter(t *testing.T) {
|
||||
remotePath1 := "\\\\servername\\share\\subpath\\1"
|
||||
remotePath2 := "\\\\servername\\share\\subpath\\2"
|
||||
mappingPath, err := getRootMappingPath(remotePath1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
|
||||
os.RemoveAll(basePath)
|
||||
defer func() {
|
||||
// cleanup temp folder
|
||||
os.RemoveAll(basePath)
|
||||
}()
|
||||
|
||||
// by default we have no any files in `mappingPath`. So, `count` should be zero
|
||||
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
|
||||
// add reference to `remotePath1`. So, `count` should be equal `1`
|
||||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath1))
|
||||
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
|
||||
// add reference to `remotePath2`. So, `count` should be equal `2`
|
||||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath2))
|
||||
assert.Equal(t, 2, getRemotePathReferencesCount(mappingPath))
|
||||
// remove reference to `remotePath1`. So, `count` should be equal `1`
|
||||
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath1))
|
||||
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
|
||||
// remove reference to `remotePath2`. So, `count` should be equal `0`
|
||||
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath2))
|
||||
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
|
||||
}
|
||||
|
||||
func TestIncementRemotePathReferencesCount(t *testing.T) {
|
||||
remotePath := "\\\\servername\\share\\subpath"
|
||||
mappingPath, err := getRootMappingPath(remotePath)
|
||||
assert.Nil(t, err)
|
||||
|
||||
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
|
||||
os.RemoveAll(basePath)
|
||||
defer func() {
|
||||
// cleanup temp folder
|
||||
os.RemoveAll(basePath)
|
||||
}()
|
||||
|
||||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
|
||||
mappingPathContainer := basePath + "\\servername\\share"
|
||||
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
|
||||
t.Error("mapping file container does not exist")
|
||||
}
|
||||
|
||||
reference := mappingPathContainer + "\\" + getMd5(remotePath)
|
||||
if file, err := os.Stat(reference); os.IsNotExist(err) || file.IsDir() {
|
||||
t.Error("reference file does not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrementRemotePathReferencesCount(t *testing.T) {
|
||||
remotePath := "\\\\servername\\share\\subpath"
|
||||
mappingPath, err := getRootMappingPath(remotePath)
|
||||
assert.Nil(t, err)
|
||||
|
||||
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
|
||||
os.RemoveAll(basePath)
|
||||
defer func() {
|
||||
// cleanup temp folder
|
||||
os.RemoveAll(basePath)
|
||||
}()
|
||||
|
||||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
|
||||
mappingPathContainer := basePath + "\\servername\\share"
|
||||
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
|
||||
t.Error("mapping file container does not exist")
|
||||
}
|
||||
|
||||
reference := mappingPathContainer + "\\" + getMd5(remotePath)
|
||||
if _, err := os.Stat(reference); os.IsExist(err) {
|
||||
t.Error("reference file exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) {
|
||||
remotePath := "\\\\servername\\share\\subpath"
|
||||
mappingPath, err := getRootMappingPath(remotePath)
|
||||
assert.Nil(t, err)
|
||||
|
||||
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
|
||||
os.RemoveAll(basePath)
|
||||
defer func() {
|
||||
// cleanup temp folder
|
||||
os.RemoveAll(basePath)
|
||||
}()
|
||||
|
||||
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
|
||||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
// next calls of `incementMappingPathCount` with the same arguments should be ignored
|
||||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
|
||||
}
|
||||
|
||||
func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) {
|
||||
remotePath := "\\\\servername\\share\\subpath"
|
||||
mappingPath, err := getRootMappingPath(remotePath)
|
||||
assert.Nil(t, err)
|
||||
|
||||
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
|
||||
os.RemoveAll(basePath)
|
||||
defer func() {
|
||||
// cleanup temp folder
|
||||
os.RemoveAll(basePath)
|
||||
}()
|
||||
|
||||
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
|
||||
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
assert.NotNil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
|
||||
}
|
||||
|
|
@ -101,6 +101,17 @@ func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOpt
|
|||
}
|
||||
|
||||
source = strings.Replace(source, "/", "\\", -1)
|
||||
if strings.HasSuffix(source, "\\") {
|
||||
source = strings.TrimSuffix(source, "\\")
|
||||
}
|
||||
|
||||
mappingPath, err := getRootMappingPath(source)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getRootMappingPath(%s) failed with error: %v", source, err)
|
||||
}
|
||||
unlock := lock(mappingPath)
|
||||
defer unlock()
|
||||
|
||||
normalizedTarget := normalizeWindowsPath(target)
|
||||
smbMountRequest := &smb.NewSmbGlobalMappingRequest{
|
||||
LocalPath: normalizedTarget,
|
||||
|
|
@ -113,13 +124,53 @@ func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOpt
|
|||
return fmt.Errorf("smb mapping failed with error: %v", err)
|
||||
}
|
||||
klog.V(2).Infof("mount %s on %s successfully", source, normalizedTarget)
|
||||
|
||||
if err = incementRemotePathReferencesCount(mappingPath, source); err != nil {
|
||||
klog.Warningf("incementMappingPathCount(%s, %s) failed with error: %v", mappingPath, source, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mounter *csiProxyMounter) SMBUnmount(target string) error {
|
||||
klog.V(4).Infof("SMBUnmount: local path: %s", target)
|
||||
// TODO: We need to remove the SMB mapping. The change to remove the
|
||||
// directory brings the CSI code in parity with the in-tree.
|
||||
|
||||
if remotePath, err := os.Readlink(target); err != nil {
|
||||
klog.Warningf("SMBUnmount: can't get remote path: %v", err)
|
||||
} else {
|
||||
if strings.HasSuffix(remotePath, "\\") {
|
||||
remotePath = strings.TrimSuffix(remotePath, "\\")
|
||||
}
|
||||
mappingPath, err := getRootMappingPath(remotePath)
|
||||
if err != nil {
|
||||
klog.Warningf("getRootMappingPath(%s) failed with error: %v", remotePath, err)
|
||||
} else {
|
||||
klog.V(4).Infof("SMBUnmount: remote path: %s, mapping path: %s", remotePath, mappingPath)
|
||||
|
||||
unlock := lock(mappingPath)
|
||||
defer unlock()
|
||||
|
||||
if err := decrementRemotePathReferencesCount(mappingPath, remotePath); err != nil {
|
||||
klog.Warningf("decrementMappingPathCount(%s, %d) failed with error: %v", mappingPath, remotePath, err)
|
||||
} else {
|
||||
count := getRemotePathReferencesCount(mappingPath)
|
||||
if count == 0 {
|
||||
smbUnmountRequest := &smb.RemoveSmbGlobalMappingRequest{
|
||||
RemotePath: remotePath,
|
||||
}
|
||||
klog.V(2).Infof("begin to unmount %s on %s", remotePath, target)
|
||||
if _, err := mounter.SMBClient.RemoveSmbGlobalMapping(context.Background(), smbUnmountRequest); err != nil {
|
||||
return fmt.Errorf("smb unmapping failed with error: %v", err)
|
||||
} else {
|
||||
klog.V(2).Infof("unmount %s on %s successfully", remotePath, target)
|
||||
}
|
||||
} else {
|
||||
klog.Infof("SMBUnmount: found %f links to %s", count, mappingPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mounter.Rmdir(target)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish
|
|||
}
|
||||
|
||||
klog.V(2).Infof("NodeUnpublishVolume: unmounting volume %s on %s", volumeID, targetPath)
|
||||
err := CleanupSMBMountPoint(d.mounter, targetPath, true /*extensiveMountPointCheck*/)
|
||||
err := CleanupMountPoint(d.mounter, targetPath, true /*extensiveMountPointCheck*/)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to unmount target %q: %v", targetPath, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,11 +60,12 @@ func TestNodeStageVolume(t *testing.T) {
|
|||
smbFile := testutil.GetWorkDirPath("smb.go", t)
|
||||
sourceTest := testutil.GetWorkDirPath("source_test", t)
|
||||
|
||||
testSource := "\\\\hostname\\share\\test"
|
||||
volContext := map[string]string{
|
||||
sourceField: "test_source",
|
||||
sourceField: testSource,
|
||||
}
|
||||
volContextWithMetadata := map[string]string{
|
||||
sourceField: "test_source",
|
||||
sourceField: testSource,
|
||||
pvcNameKey: "pvcname",
|
||||
pvcNamespaceKey: "pvcnamespace",
|
||||
pvNameKey: "pvname",
|
||||
|
|
@ -152,14 +153,14 @@ func TestNodeStageVolume(t *testing.T) {
|
|||
VolumeCapability: &stdVolCap,
|
||||
VolumeContext: volContext,
|
||||
Secrets: secrets},
|
||||
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"test_source\" on %#v failed "+
|
||||
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed "+
|
||||
"with smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.",
|
||||
errorMountSensSource),
|
||||
strings.Replace(testSource, "\\", "\\\\", -1), errorMountSensSource),
|
||||
expectedErr: testutil.TestError{
|
||||
DefaultError: status.Errorf(codes.Internal,
|
||||
fmt.Sprintf("volume(vol_1##) mount \"test_source\" on \"%s\" failed with fake "+
|
||||
fmt.Sprintf("volume(vol_1##) mount \"%s\" on \"%s\" failed with fake "+
|
||||
"MountSensitive: target error",
|
||||
errorMountSensSource)),
|
||||
strings.Replace(testSource, "\\", "\\\\", -1), errorMountSensSource)),
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -168,9 +169,9 @@ func TestNodeStageVolume(t *testing.T) {
|
|||
VolumeCapability: &stdVolCap,
|
||||
VolumeContext: volContext,
|
||||
Secrets: secrets},
|
||||
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"test_source\" on %#v failed with "+
|
||||
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed with "+
|
||||
"smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.",
|
||||
sourceTest),
|
||||
strings.Replace(testSource, "\\", "\\\\", -1), sourceTest),
|
||||
expectedErr: testutil.TestError{},
|
||||
},
|
||||
{
|
||||
|
|
@ -179,9 +180,9 @@ func TestNodeStageVolume(t *testing.T) {
|
|||
VolumeCapability: &stdVolCap,
|
||||
VolumeContext: volContextWithMetadata,
|
||||
Secrets: secrets},
|
||||
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"test_source\" on %#v failed with "+
|
||||
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed with "+
|
||||
"smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.",
|
||||
sourceTest),
|
||||
strings.Replace(testSource, "\\", "\\\\", -1), sourceTest),
|
||||
expectedErr: testutil.TestError{},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue