diff --git a/common/pkg/chown/chown.go b/common/pkg/chown/chown.go new file mode 100644 index 0000000000..fe794304ed --- /dev/null +++ b/common/pkg/chown/chown.go @@ -0,0 +1,122 @@ +package chown + +import ( + "os" + "os/user" + "path/filepath" + "syscall" + + "github.com/containers/storage/pkg/homedir" + "github.com/pkg/errors" +) + +// DangerousHostPath validates if a host path is dangerous and should not be modified +func DangerousHostPath(path string) (bool, error) { + excludePaths := map[string]bool{ + "/": true, + "/bin": true, + "/boot": true, + "/dev": true, + "/etc": true, + "/etc/passwd": true, + "/etc/pki": true, + "/etc/shadow": true, + "/home": true, + "/lib": true, + "/lib64": true, + "/media": true, + "/opt": true, + "/proc": true, + "/root": true, + "/run": true, + "/sbin": true, + "/srv": true, + "/sys": true, + "/tmp": true, + "/usr": true, + "/var": true, + "/var/lib": true, + "/var/log": true, + } + + if home := homedir.Get(); home != "" { + excludePaths[home] = true + } + + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + if usr, err := user.Lookup(sudoUser); err == nil { + excludePaths[usr.HomeDir] = true + } + } + + absPath, err := filepath.Abs(path) + if err != nil { + return true, err + } + + realPath, err := filepath.EvalSymlinks(absPath) + if err != nil { + return true, err + } + + if excludePaths[realPath] { + return true, nil + } + + return false, nil +} + +// ChangeHostPathOwnership changes the uid and gid ownership of a directory or file within the host. +// This is used by the volume U flag to change source volumes ownership +func ChangeHostPathOwnership(path string, recursive bool, uid, gid int) error { + // Validate if host path can be chowned + isDangerous, err := DangerousHostPath(path) + if err != nil { + return errors.Wrapf(err, "failed to validate if host path is dangerous") + } + + if isDangerous { + return errors.Errorf("chowning host path %q is not allowed. You can manually `chown -R %d:%d %s`", path, uid, gid, path) + } + + // Chown host path + if recursive { + err := filepath.Walk(path, func(filePath string, f os.FileInfo, err error) error { + if err != nil { + return err + } + + // Get current ownership + currentUID := int(f.Sys().(*syscall.Stat_t).Uid) + currentGID := int(f.Sys().(*syscall.Stat_t).Gid) + + if uid != currentUID || gid != currentGID { + return os.Lchown(filePath, uid, gid) + } + + return nil + }) + + if err != nil { + return errors.Wrapf(err, "failed to chown recursively host path") + } + } else { + // Get host path info + f, err := os.Lstat(path) + if err != nil { + return errors.Wrapf(err, "failed to get host path information") + } + + // Get current ownership + currentUID := int(f.Sys().(*syscall.Stat_t).Uid) + currentGID := int(f.Sys().(*syscall.Stat_t).Gid) + + if uid != currentUID || gid != currentGID { + if err := os.Lchown(path, uid, gid); err != nil { + return errors.Wrapf(err, "failed to chown host path") + } + } + } + + return nil +} diff --git a/common/pkg/chown/chown_test.go b/common/pkg/chown/chown_test.go new file mode 100644 index 0000000000..b92040e33c --- /dev/null +++ b/common/pkg/chown/chown_test.go @@ -0,0 +1,136 @@ +package chown + +import ( + "io/ioutil" + "os" + "runtime" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDangerousHostPath(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Current paths are supported only by Linux") + } + + // Create a temp dir that is not dangerous + td, err := ioutil.TempDir("/tmp", "validDir") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(td) + + tests := []struct { + Path string + Expected bool + ExpectError bool + ExpectedErrorMsg string + }{ + { + "/tmp", + true, + false, + "", + }, + { + td, + false, + false, + "", + }, + { + "/doesnotexist", + false, + true, + "no such file or directory", + }, + } + + for _, test := range tests { + result, err := DangerousHostPath(test.Path) + if test.ExpectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), test.ExpectedErrorMsg) + } else { + assert.NoError(t, err) + assert.Equal(t, test.Expected, result) + } + } +} + +func TestChangeHostPathOwnership(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Current paths are supported only by Linux") + } + + // Create a temp dir that is not dangerous + td, err := ioutil.TempDir("/tmp", "validDir") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(td) + + // Get host path info + f, err := os.Lstat(td) + if err != nil { + t.Fatal(err) + } + + // Get current ownership + currentUID := int(f.Sys().(*syscall.Stat_t).Uid) + currentGID := int(f.Sys().(*syscall.Stat_t).Gid) + + tests := []struct { + Path string + Recursive bool + UID int + GID int + ExpectError bool + ExpectedErrorMsg string + }{ + { + "/doesnotexist", + false, + 0, + 0, + true, + "no such file or directory", + }, + { + "/tmp", + false, + 0, + 0, + true, + "is not allowed", + }, + { + td, + false, + currentUID, + currentGID, + false, + "", + }, + { + td, + true, + currentUID, + currentGID, + false, + "", + }, + } + + for _, test := range tests { + err := ChangeHostPathOwnership(test.Path, test.Recursive, test.UID, test.GID) + if test.ExpectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), test.ExpectedErrorMsg) + } else { + assert.NoError(t, err) + } + } +} diff --git a/common/pkg/parse/parse.go b/common/pkg/parse/parse.go index 611b2e84bf..882953309f 100644 --- a/common/pkg/parse/parse.go +++ b/common/pkg/parse/parse.go @@ -13,7 +13,7 @@ import ( // ValidateVolumeOpts validates a volume's options func ValidateVolumeOpts(options []string) ([]string, error) { - var foundRootPropagation, foundRWRO, foundLabelChange, bindType, foundExec, foundDev, foundSuid int + var foundRootPropagation, foundRWRO, foundLabelChange, bindType, foundExec, foundDev, foundSuid, foundChown int finalOpts := make([]string, 0, len(options)) for _, opt := range options { switch opt { @@ -42,6 +42,11 @@ func ValidateVolumeOpts(options []string) ([]string, error) { if foundLabelChange > 1 { return nil, errors.Errorf("invalid options %q, can only specify 1 'z', 'Z', or 'O' option", strings.Join(options, ", ")) } + case "U": + foundChown++ + if foundChown > 1 { + return nil, errors.Errorf("invalid options %q, can only specify 1 'U' option", strings.Join(options, ", ")) + } case "private", "rprivate", "shared", "rshared", "slave", "rslave", "unbindable", "runbindable": foundRootPropagation++ if foundRootPropagation > 1 {