Merge pull request #505 from vitaliy-leschenko/smb-unmount

Calls RemoveSmbGlobalMapping when it necessary
This commit is contained in:
Kubernetes Prow Robot 2022-08-13 06:08:12 -07:00 committed by GitHub
commit eb9ddbcf81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 414 additions and 13 deletions

View File

@ -0,0 +1,122 @@
//go:build windows
// +build windows
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mounter
import (
"crypto/md5"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
)
var basePath = "c:\\csi\\smbmounts"
var mutexes sync.Map
func lock(key string) func() {
value, _ := mutexes.LoadOrStore(key, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
return func() { mtx.Unlock() }
}
// getRootMappingPath - returns root of smb share path or empty string if the path is invalid. For example:
//
// \\hostname\share\subpath => \\hostname\share, error is nil
// \\hostname\share => \\hostname\share, error is nil
// \\hostname => '', error is 'remote path (\\hostname) is invalid'
func getRootMappingPath(path string) (string, error) {
items := strings.Split(path, "\\")
parts := []string{}
for _, s := range items {
if len(s) > 0 {
parts = append(parts, s)
if len(parts) == 2 {
break
}
}
}
if len(parts) != 2 {
return "", fmt.Errorf("remote path (%s) is invalid", path)
}
// parts[0] is a smb host name
// parts[1] is a smb share name
return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil
}
// incementRemotePathReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist.
// How it works:
// 1. MappingPath contains two components: hostname, sharename
// 2. We create directory in basePath related to each mappingPath. It will be used as container for references.
// Example: c:\\csi\\smbmounts\\hostname\\sharename
// 3. Each reference is a file with name based on MD5 of remotePath. For debug it also will contains remotePath in body of the file.
// So, in incementRemotePathReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file.
// Example: c:\\csi\\smbmounts\\hostname\\sharename\\092f1413e6c1d03af8b5da6f44619af8
func incementRemotePathReferencesCount(mappingPath, remotePath string) error {
remotePath = strings.TrimSuffix(remotePath, "\\")
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if err := os.MkdirAll(path, os.ModeDir); err != nil {
return err
}
filePath := filepath.Join(path, getMd5(remotePath))
file, err := os.Create(filePath)
if err != nil {
return err
}
defer func() {
file.Close()
}()
_, err = file.WriteString(remotePath)
return err
}
// decrementRemotePathReferencesCount - removes reference between mappingPath and remotePath.
// See incementRemotePathReferencesCount to understand how references work.
func decrementRemotePathReferencesCount(mappingPath, remotePath string) error {
remotePath = strings.TrimSuffix(remotePath, "\\")
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if err := os.MkdirAll(path, os.ModeDir); err != nil {
return err
}
filePath := filepath.Join(path, getMd5(remotePath))
return os.Remove(filePath)
}
// getRemotePathReferencesCount - returns count of references between mappingPath and remotePath.
// See incementRemotePathReferencesCount to understand how references work.
func getRemotePathReferencesCount(mappingPath string) int {
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
if os.MkdirAll(path, os.ModeDir) != nil {
return -1
}
files, err := os.ReadDir(path)
if err != nil {
return -1
}
return len(files)
}
func getMd5(path string) string {
data := []byte(strings.ToLower(path))
return fmt.Sprintf("%x", md5.Sum(data))
}

View File

@ -0,0 +1,227 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mounter
import (
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestLockUnlock(t *testing.T) {
key := "resource name"
unlock := lock(key)
defer unlock()
_, loaded := mutexes.Load(key)
assert.True(t, loaded)
}
func TestLockLockedResource(t *testing.T) {
locked := true
unlock := lock("a")
go func() {
time.Sleep(500 * time.Microsecond)
locked = false
unlock()
}()
// try to lock already locked resource
unlock2 := lock("a")
defer unlock2()
if locked {
assert.Fail(t, "access to locked resource")
}
}
func TestLockDifferentKeys(t *testing.T) {
unlocka := lock("a")
unlockb := lock("b")
unlocka()
unlockb()
}
func TestGetRootMappingPath(t *testing.T) {
testCases := []struct {
remote string
expectResult string
expectError bool
}{
{
remote: "",
expectResult: "",
expectError: true,
},
{
remote: "hostname",
expectResult: "",
expectError: true,
},
{
remote: "\\\\hostname\\path",
expectResult: "\\\\hostname\\path",
expectError: false,
},
{
remote: "\\\\hostname\\path\\",
expectResult: "\\\\hostname\\path",
expectError: false,
},
{
remote: "\\\\hostname\\path\\subpath",
expectResult: "\\\\hostname\\path",
expectError: false,
},
}
for _, tc := range testCases {
result, err := getRootMappingPath(tc.remote)
if tc.expectError && err == nil {
t.Errorf("Expected error but getRootMappingPath returned a nil error")
}
if !tc.expectError {
if err != nil {
t.Errorf("Expected no errors but getRootMappingPath returned error: %v", err)
}
if tc.expectResult != result {
t.Errorf("Expected (%s) but getRootMappingPath returned (%s)", tc.expectResult, result)
}
}
}
}
func TestRemotePathReferencesCounter(t *testing.T) {
remotePath1 := "\\\\servername\\share\\subpath\\1"
remotePath2 := "\\\\servername\\share\\subpath\\2"
mappingPath, err := getRootMappingPath(remotePath1)
assert.Nil(t, err)
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()
// by default we have no any files in `mappingPath`. So, `count` should be zero
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
// add reference to `remotePath1`. So, `count` should be equal `1`
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath1))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
// add reference to `remotePath2`. So, `count` should be equal `2`
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath2))
assert.Equal(t, 2, getRemotePathReferencesCount(mappingPath))
// remove reference to `remotePath1`. So, `count` should be equal `1`
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath1))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
// remove reference to `remotePath2`. So, `count` should be equal `0`
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath2))
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
}
func TestIncementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
mappingPathContainer := basePath + "\\servername\\share"
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
t.Error("mapping file container does not exist")
}
reference := mappingPathContainer + "\\" + getMd5(remotePath)
if file, err := os.Stat(reference); os.IsNotExist(err) || file.IsDir() {
t.Error("reference file does not exist")
}
}
func TestDecrementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
mappingPathContainer := basePath + "\\servername\\share"
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
t.Error("mapping file container does not exist")
}
reference := mappingPathContainer + "\\" + getMd5(remotePath)
if _, err := os.Stat(reference); os.IsExist(err) {
t.Error("reference file exists")
}
}
func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
// next calls of `incementMappingPathCount` with the same arguments should be ignored
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
}
func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) {
remotePath := "\\\\servername\\share\\subpath"
mappingPath, err := getRootMappingPath(remotePath)
assert.Nil(t, err)
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
os.RemoveAll(basePath)
defer func() {
// cleanup temp folder
os.RemoveAll(basePath)
}()
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
assert.NotNil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
}

View File

@ -101,6 +101,17 @@ func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOpt
}
source = strings.Replace(source, "/", "\\", -1)
if strings.HasSuffix(source, "\\") {
source = strings.TrimSuffix(source, "\\")
}
mappingPath, err := getRootMappingPath(source)
if err != nil {
return fmt.Errorf("getRootMappingPath(%s) failed with error: %v", source, err)
}
unlock := lock(mappingPath)
defer unlock()
normalizedTarget := normalizeWindowsPath(target)
smbMountRequest := &smb.NewSmbGlobalMappingRequest{
LocalPath: normalizedTarget,
@ -113,13 +124,53 @@ func (mounter *csiProxyMounter) SMBMount(source, target, fsType string, mountOpt
return fmt.Errorf("smb mapping failed with error: %v", err)
}
klog.V(2).Infof("mount %s on %s successfully", source, normalizedTarget)
if err = incementRemotePathReferencesCount(mappingPath, source); err != nil {
klog.Warningf("incementMappingPathCount(%s, %s) failed with error: %v", mappingPath, source, err)
}
return nil
}
func (mounter *csiProxyMounter) SMBUnmount(target string) error {
klog.V(4).Infof("SMBUnmount: local path: %s", target)
// TODO: We need to remove the SMB mapping. The change to remove the
// directory brings the CSI code in parity with the in-tree.
if remotePath, err := os.Readlink(target); err != nil {
klog.Warningf("SMBUnmount: can't get remote path: %v", err)
} else {
if strings.HasSuffix(remotePath, "\\") {
remotePath = strings.TrimSuffix(remotePath, "\\")
}
mappingPath, err := getRootMappingPath(remotePath)
if err != nil {
klog.Warningf("getRootMappingPath(%s) failed with error: %v", remotePath, err)
} else {
klog.V(4).Infof("SMBUnmount: remote path: %s, mapping path: %s", remotePath, mappingPath)
unlock := lock(mappingPath)
defer unlock()
if err := decrementRemotePathReferencesCount(mappingPath, remotePath); err != nil {
klog.Warningf("decrementMappingPathCount(%s, %d) failed with error: %v", mappingPath, remotePath, err)
} else {
count := getRemotePathReferencesCount(mappingPath)
if count == 0 {
smbUnmountRequest := &smb.RemoveSmbGlobalMappingRequest{
RemotePath: remotePath,
}
klog.V(2).Infof("begin to unmount %s on %s", remotePath, target)
if _, err := mounter.SMBClient.RemoveSmbGlobalMapping(context.Background(), smbUnmountRequest); err != nil {
return fmt.Errorf("smb unmapping failed with error: %v", err)
} else {
klog.V(2).Infof("unmount %s on %s successfully", remotePath, target)
}
} else {
klog.Infof("SMBUnmount: found %f links to %s", count, mappingPath)
}
}
}
}
return mounter.Rmdir(target)
}

View File

@ -98,7 +98,7 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish
}
klog.V(2).Infof("NodeUnpublishVolume: unmounting volume %s on %s", volumeID, targetPath)
err := CleanupSMBMountPoint(d.mounter, targetPath, true /*extensiveMountPointCheck*/)
err := CleanupMountPoint(d.mounter, targetPath, true /*extensiveMountPointCheck*/)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to unmount target %q: %v", targetPath, err)
}

View File

@ -60,11 +60,12 @@ func TestNodeStageVolume(t *testing.T) {
smbFile := testutil.GetWorkDirPath("smb.go", t)
sourceTest := testutil.GetWorkDirPath("source_test", t)
testSource := "\\\\hostname\\share\\test"
volContext := map[string]string{
sourceField: "test_source",
sourceField: testSource,
}
volContextWithMetadata := map[string]string{
sourceField: "test_source",
sourceField: testSource,
pvcNameKey: "pvcname",
pvcNamespaceKey: "pvcnamespace",
pvNameKey: "pvname",
@ -152,14 +153,14 @@ func TestNodeStageVolume(t *testing.T) {
VolumeCapability: &stdVolCap,
VolumeContext: volContext,
Secrets: secrets},
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"test_source\" on %#v failed "+
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed "+
"with smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.",
errorMountSensSource),
strings.Replace(testSource, "\\", "\\\\", -1), errorMountSensSource),
expectedErr: testutil.TestError{
DefaultError: status.Errorf(codes.Internal,
fmt.Sprintf("volume(vol_1##) mount \"test_source\" on \"%s\" failed with fake "+
fmt.Sprintf("volume(vol_1##) mount \"%s\" on \"%s\" failed with fake "+
"MountSensitive: target error",
errorMountSensSource)),
strings.Replace(testSource, "\\", "\\\\", -1), errorMountSensSource)),
},
},
{
@ -168,9 +169,9 @@ func TestNodeStageVolume(t *testing.T) {
VolumeCapability: &stdVolCap,
VolumeContext: volContext,
Secrets: secrets},
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"test_source\" on %#v failed with "+
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed with "+
"smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.",
sourceTest),
strings.Replace(testSource, "\\", "\\\\", -1), sourceTest),
expectedErr: testutil.TestError{},
},
{
@ -179,9 +180,9 @@ func TestNodeStageVolume(t *testing.T) {
VolumeCapability: &stdVolCap,
VolumeContext: volContextWithMetadata,
Secrets: secrets},
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"test_source\" on %#v failed with "+
flakyWindowsErrorMessage: fmt.Sprintf("volume(vol_1##) mount \"%s\" on %#v failed with "+
"smb mapping failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.",
sourceTest),
strings.Replace(testSource, "\\", "\\\\", -1), sourceTest),
expectedErr: testutil.TestError{},
},
}