diff --git a/pkg/mount/mounter_linux.go b/pkg/mount/mounter_linux.go index dd4280c777..3ef2ce6f0d 100644 --- a/pkg/mount/mounter_linux.go +++ b/pkg/mount/mounter_linux.go @@ -4,15 +4,50 @@ import ( "syscall" ) -func mount(device, target, mType string, flag uintptr, data string) error { - if err := syscall.Mount(device, target, mType, flag, data); err != nil { - return err +const ( + // ptypes is the set propagation types. + ptypes = syscall.MS_SHARED | syscall.MS_PRIVATE | syscall.MS_SLAVE | syscall.MS_UNBINDABLE + + // pflags is the full set valid flags for a change propagation call. + pflags = ptypes | syscall.MS_REC | syscall.MS_SILENT + + // broflags is the combination of bind and read only + broflags = syscall.MS_BIND | syscall.MS_RDONLY +) + +// isremount returns true if either device name or flags identify a remount request, false otherwise. +func isremount(device string, flags uintptr) bool { + switch { + // We treat device "" and "none" as a remount request to provide compatibility with + // requests that don't explicitly set MS_REMOUNT such as those manipulating bind mounts. + case flags&syscall.MS_REMOUNT != 0, device == "", device == "none": + return true + default: + return false + } +} + +func mount(device, target, mType string, flags uintptr, data string) error { + oflags := flags &^ ptypes + if !isremount(device, flags) { + // Initial call applying all non-propagation flags. + if err := syscall.Mount(device, target, mType, oflags, data); err != nil { + return err + } } - // If we have a bind mount or remount, remount... - if flag&syscall.MS_BIND == syscall.MS_BIND && flag&syscall.MS_RDONLY == syscall.MS_RDONLY { - return syscall.Mount(device, target, mType, flag|syscall.MS_REMOUNT, data) + if flags&ptypes != 0 { + // Change the propagation type. + if err := syscall.Mount("", target, "", flags&pflags, ""); err != nil { + return err + } } + + if oflags&broflags == broflags { + // Remount the bind to apply read only. + return syscall.Mount("", target, "", oflags|syscall.MS_REMOUNT, "") + } + return nil } diff --git a/pkg/mount/mounter_linux_test.go b/pkg/mount/mounter_linux_test.go new file mode 100644 index 0000000000..9c74b1f51d --- /dev/null +++ b/pkg/mount/mounter_linux_test.go @@ -0,0 +1,195 @@ +// +build linux + +package mount + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestMount(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("not root tests would fail") + } + + source, err := ioutil.TempDir("", "mount-test-source-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(source) + + // Ensure we have a known start point by mounting tmpfs with given options + if err := Mount("tmpfs", source, "tmpfs", "private"); err != nil { + t.Fatal(err) + } + defer ensureUnmount(t, source) + validateMount(t, source, "", "") + if t.Failed() { + t.FailNow() + } + + target, err := ioutil.TempDir("", "mount-test-target-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(target) + + tests := []struct { + source string + ftype string + options string + expectedOpts string + expectedOptional string + }{ + // No options + {"tmpfs", "tmpfs", "", "", ""}, + // Default rw / ro test + {source, "", "bind", "", ""}, + {source, "", "bind,private", "", ""}, + {source, "", "bind,shared", "", "shared"}, + {source, "", "bind,slave", "", "master"}, + {source, "", "bind,unbindable", "", "unbindable"}, + // Read Write tests + {source, "", "bind,rw", "rw", ""}, + {source, "", "bind,rw,private", "rw", ""}, + {source, "", "bind,rw,shared", "rw", "shared"}, + {source, "", "bind,rw,slave", "rw", "master"}, + {source, "", "bind,rw,unbindable", "rw", "unbindable"}, + // Read Only tests + {source, "", "bind,ro", "ro", ""}, + {source, "", "bind,ro,private", "ro", ""}, + {source, "", "bind,ro,shared", "ro", "shared"}, + {source, "", "bind,ro,slave", "ro", "master"}, + {source, "", "bind,ro,unbindable", "ro", "unbindable"}, + } + + for _, tc := range tests { + ftype, options := tc.ftype, tc.options + if tc.ftype == "" { + ftype = "none" + } + if tc.options == "" { + options = "none" + } + + t.Run(fmt.Sprintf("%v-%v", ftype, options), func(t *testing.T) { + if strings.Contains(tc.options, "slave") { + // Slave requires a shared source + if err := MakeShared(source); err != nil { + t.Fatal(err) + } + defer func() { + if err := MakePrivate(source); err != nil { + t.Fatal(err) + } + }() + } + if err := Mount(tc.source, target, tc.ftype, tc.options); err != nil { + t.Fatal(err) + } + defer ensureUnmount(t, target) + validateMount(t, target, tc.expectedOpts, tc.expectedOptional) + }) + } +} + +// ensureUnmount umounts mnt checking for errors +func ensureUnmount(t *testing.T, mnt string) { + if err := Unmount(mnt); err != nil { + t.Error(err) + } +} + +// validateMount checks that mnt has the given options +func validateMount(t *testing.T, mnt string, opts, optional string) { + info, err := GetMounts() + if err != nil { + t.Fatal(err) + } + + wantedOpts := make(map[string]struct{}) + if opts != "" { + for _, opt := range strings.Split(opts, ",") { + wantedOpts[opt] = struct{}{} + } + } + + wantedOptional := make(map[string]struct{}) + if optional != "" { + for _, opt := range strings.Split(optional, ",") { + wantedOptional[opt] = struct{}{} + } + } + + mnts := make(map[int]*Info, len(info)) + for _, mi := range info { + mnts[mi.ID] = mi + } + + for _, mi := range info { + if mi.Mountpoint != mnt { + continue + } + + // Use parent info as the defaults + p := mnts[mi.Parent] + pOpts := make(map[string]struct{}) + if p.Opts != "" { + for _, opt := range strings.Split(p.Opts, ",") { + pOpts[clean(opt)] = struct{}{} + } + } + pOptional := make(map[string]struct{}) + if p.Optional != "" { + for _, field := range strings.Split(p.Optional, ",") { + pOptional[clean(field)] = struct{}{} + } + } + + // Validate Opts + if mi.Opts != "" { + for _, opt := range strings.Split(mi.Opts, ",") { + opt = clean(opt) + if !has(wantedOpts, opt) && !has(pOpts, opt) { + t.Errorf("unexpected mount option %q expected %q", opt, opts) + } + delete(wantedOpts, opt) + } + } + for opt := range wantedOpts { + t.Errorf("missing mount option %q found %q", opt, mi.Opts) + } + + // Validate Optional + if mi.Optional != "" { + for _, field := range strings.Split(mi.Optional, ",") { + field = clean(field) + if !has(wantedOptional, field) && !has(pOptional, field) { + t.Errorf("unexpected optional failed %q expected %q", field, optional) + } + delete(wantedOptional, field) + } + } + for field := range wantedOptional { + t.Errorf("missing optional field %q found %q", field, mi.Optional) + } + + return + } + + t.Errorf("failed to find mount %q", mnt) +} + +// clean strips off any value param after the colon +func clean(v string) string { + return strings.SplitN(v, ":", 2)[0] +} + +// has returns true if key is a member of m +func has(m map[string]struct{}, key string) bool { + _, ok := m[key] + return ok +}