// Source: components-contrib/bindings/localstorage/localstorage_test.go

/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package localstorage
import (
"os"
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/bindings"
"github.com/dapr/kit/logger"
)
// TestParseMetadata verifies that the "rootPath" metadata property is copied
// verbatim into the parsed RootPath field. An OS-appropriate absolute path is
// used so the test behaves the same on Windows and Unix-like systems.
func TestParseMetadata(t *testing.T) {
	rootPath := "/files"
	if runtime.GOOS == "windows" {
		rootPath = "C:\\files"
	}

	md := bindings.Metadata{}
	md.Properties = map[string]string{"rootPath": rootPath}

	ls := NewLocalStorage(logger.NewLogger("test")).(*LocalStorage)
	parsed, err := ls.parseMetadata(md)
	require.NoError(t, err)
	assert.Equal(t, rootPath, parsed.RootPath)
}
// TestValidateRootPath exercises validateRootPath against empty and relative
// paths, existing and non-existing directories, symlinks, regular files, and
// paths (direct or via symlink) that fall inside a disallowed location.
//
// The package-level disallowedRootPaths list is replaced for the duration of
// the test and restored afterwards via t.Cleanup.
func TestValidateRootPath(t *testing.T) {
	// Relative inputs are expected to resolve against the current working directory.
	cwd, err := os.Getwd()
	require.NoError(t, err)

	// Fixture tree under tmpDir:
	//   aaa/bbb, aaa/ccc, zzz/aaa, notgood   directories
	//   zzz/link    -> aaa/bbb               symlink
	//   aaa/notgood -> notgood               symlink into a disallowed directory
	//   aaa/file                             regular file
	tmpDir := t.TempDir()
	require.NoError(t, os.MkdirAll(filepath.Join(tmpDir, "aaa/bbb"), 0o755))
	require.NoError(t, os.MkdirAll(filepath.Join(tmpDir, "aaa/ccc"), 0o755))
	require.NoError(t, os.MkdirAll(filepath.Join(tmpDir, "zzz/aaa"), 0o755))
	require.NoError(t, os.MkdirAll(filepath.Join(tmpDir, "notgood"), 0o755))
	require.NoError(t, os.Symlink(filepath.Join(tmpDir, "aaa/bbb"), filepath.Join(tmpDir, "zzz/link")))
	require.NoError(t, os.Symlink(filepath.Join(tmpDir, "notgood"), filepath.Join(tmpDir, "aaa/notgood")))

	// Check the Create error before touching f (it is nil on failure), and
	// surface any Close error instead of discarding it.
	f, err := os.Create(filepath.Join(tmpDir, "aaa/file"))
	require.NoError(t, err)
	require.NoError(t, f.Close())

	rootDir := "/"
	if runtime.GOOS == "windows" {
		rootDir = "C:\\"
	}

	// Set the list of disallowed paths to some locations that don't exist.
	// This is because the real list contains folders that are otherwise
	// resolved as symlinks on some OSes (like macOS).
	oldDisallowedRootPaths := disallowedRootPaths
	t.Cleanup(func() {
		disallowedRootPaths = oldDisallowedRootPaths
	})
	disallowedRootPaths = []string{
		// Explicitly set both the Linux and Windows formats.
		"/notgood",
		"C:\\notgood",
		filepath.Join(joinWithMustEvalSymlinks(tmpDir), "notgood"),
	}

	tests := []struct {
		name     string
		rootPath string
		// Expected resolved path returned by validateRootPath.
		wantRes string
		// Substring that must appear in the returned error.
		// If no error is expected, set this to an empty string.
		wantErr string
	}{
		{name: "empty", rootPath: "", wantErr: "must not be empty"},
		{name: "relative path 1", rootPath: "path", wantRes: filepath.Join(cwd, "path")},
		{name: "relative path 2", rootPath: filepath.Clean("../path"), wantRes: filepath.Join(cwd, "..", "path")},
		{name: "existing path 1", rootPath: filepath.Join(tmpDir, "aaa/bbb"), wantRes: joinWithMustEvalSymlinks(tmpDir, "aaa/bbb")},
		{name: "existing path 2", rootPath: filepath.Join(tmpDir, "zzz/aaa"), wantRes: joinWithMustEvalSymlinks(tmpDir, "zzz/aaa")},
		{name: "path does not exist 1", rootPath: filepath.Join(tmpDir, "zzz/foo"), wantRes: filepath.Join(joinWithMustEvalSymlinks(tmpDir, "zzz"), "foo")},
		{name: "path does not exist 2", rootPath: filepath.Join(tmpDir, "zzz/aaa/deep/deep"), wantRes: filepath.Join(joinWithMustEvalSymlinks(tmpDir, "zzz/aaa"), "deep/deep")},
		{name: "resolve symlinks", rootPath: filepath.Join(tmpDir, "zzz/link"), wantRes: joinWithMustEvalSymlinks(tmpDir, "aaa/bbb")},
		{name: "resolve symlinks subfolder", rootPath: filepath.Join(tmpDir, "zzz/link/sub"), wantRes: filepath.Join(joinWithMustEvalSymlinks(tmpDir, "aaa/bbb"), "sub")},
		{name: "file", rootPath: filepath.Join(tmpDir, "aaa/file"), wantErr: "not a directory"},
		{name: "file in higher level", rootPath: filepath.Join(tmpDir, "aaa/file/2"), wantErr: "not a directory"},
		{name: "disallowed path 1", rootPath: filepath.Join(rootDir, "notgood"), wantErr: "disallowed location"},
		{name: "disallowed path 1 subfolder", rootPath: filepath.Join(rootDir, "notgood", "foo"), wantErr: "disallowed location"},
		{name: "disallowed path 2", rootPath: filepath.Join(tmpDir, "notgood"), wantErr: "disallowed location"},
		{name: "disallowed path 2 subfolder", rootPath: filepath.Join(tmpDir, "notgood", "foo"), wantErr: "disallowed location"},
		{name: "symlink to disallowed path", rootPath: filepath.Join(tmpDir, "aaa/notgood"), wantErr: "disallowed location"},
		{name: "symlink to disallowed path subfolder", rootPath: filepath.Join(tmpDir, "aaa/notgood/foo"), wantErr: "disallowed location"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			res, err := validateRootPath(tt.rootPath)
			if tt.wantErr != "" {
				// ErrorContains already fails when err is nil.
				require.ErrorContains(t, err, tt.wantErr)
			} else {
				require.NoError(t, err)
			}
			require.Equal(t, tt.wantRes, res)
		})
	}
}
// joinWithMustEvalSymlinks joins the given path elements with filepath.Join
// and resolves any symbolic links in the result with filepath.EvalSymlinks.
// It panics if resolution fails (for example, when the path does not exist),
// which makes it suitable only for test fixtures where failure is a bug.
func joinWithMustEvalSymlinks(v ...string) string {
	joined := filepath.Join(v...)
	resolved, evalErr := filepath.EvalSymlinks(joined)
	if evalErr == nil {
		return resolved
	}
	panic(evalErr)
}