From 8fd9633a6bb136af6666626a76056b1592fef0c7 Mon Sep 17 00:00:00 2001 From: "Guillaume J. Charmes" Date: Wed, 11 Dec 2013 17:12:53 -0800 Subject: [PATCH] Improve FollowLink to handle recursive link and be more strick --- container.go | 2 +- utils/fs.go | 50 ++++++++++++++++++++++++++----------------- utils/fs_test.go | 39 ++++++++++++++++----------------- utils/testdata/fs/b/h | 1 + utils/testdata/fs/g | 1 + 5 files changed, 52 insertions(+), 41 deletions(-) create mode 120000 utils/testdata/fs/b/h create mode 120000 utils/testdata/fs/g diff --git a/container.go b/container.go index 5959ec7600..e207c3223d 100644 --- a/container.go +++ b/container.go @@ -827,7 +827,7 @@ func (container *Container) createVolumes() error { // Create the mountpoint volPath = path.Join(container.RootfsPath(), volPath) - rootVolPath, err := utils.FollowSymlink(volPath, container.RootfsPath()) + rootVolPath, err := utils.FollowSymlinkInScope(volPath, container.RootfsPath()) if err != nil { panic(err) } diff --git a/utils/fs.go b/utils/fs.go index e4897506d9..e710926210 100644 --- a/utils/fs.go +++ b/utils/fs.go @@ -1,6 +1,7 @@ package utils import ( + "fmt" "os" "path/filepath" "strings" @@ -38,43 +39,52 @@ func TreeSize(dir string) (size int64, err error) { // FollowSymlink will follow an existing link and scope it to the root // path provided. func FollowSymlinkInScope(link, root string) (string, error) { - prev := "." + prev := "/" root, err := filepath.Abs(root) if err != nil { return "", err } - root = filepath.Clean(root) - link, err := filepath.Abs(link) + + link, err = filepath.Abs(link) if err != nil { return "", err } - link = filepath.Clean(link) + + if !strings.HasPrefix(filepath.Dir(link), root) { + return "", fmt.Errorf("%s is not within %s", link, root) + } for _, p := range strings.Split(link, "/") { prev = filepath.Join(prev, p) prev = filepath.Clean(prev) - stat, err := os.Lstat(prev) - if err != nil { - if os.IsNotExist(err) { - continue - } - return "", err - } - if stat.Mode()&os.ModeSymlink == os.ModeSymlink { - dest, err := os.Readlink(prev) + for { + stat, err := os.Lstat(prev) if err != nil { + if os.IsNotExist(err) { + break + } return "", err } - - switch dest[0] { - case '/': - prev = filepath.Join(root, dest) - case '.': - if prev = filepath.Clean(filepath.Join(filepath.Dir(prev), dest)); len(prev) < len(root) { - prev = filepath.Join(root, filepath.Base(dest)) + if stat.Mode()&os.ModeSymlink == os.ModeSymlink { + dest, err := os.Readlink(prev) + if err != nil { + return "", err } + + switch dest[0] { + case '/': + prev = filepath.Join(root, dest) + case '.': + prev, _ = filepath.Abs(prev) + + if prev = filepath.Clean(filepath.Join(filepath.Dir(prev), dest)); len(prev) < len(root) { + prev = filepath.Join(root, filepath.Base(dest)) + } + } + } else { + break } } } diff --git a/utils/fs_test.go b/utils/fs_test.go index 5f99ea771d..dd5d97be40 100644 --- a/utils/fs_test.go +++ b/utils/fs_test.go @@ -1,15 +1,14 @@ package utils import ( - "os" "path/filepath" "testing" ) -func abs(p string) string { +func abs(t *testing.T, p string) string { o, err := filepath.Abs(p) if err != nil { - panic(err) + t.Fatal(err) } return o } @@ -17,36 +16,31 @@ func abs(p string) string { func TestFollowSymLinkNormal(t *testing.T) { link := "testdata/fs/a/d/c/data" - rewrite, err := FollowSymlink(link, "test") + rewrite, err := FollowSymlinkInScope(link, "testdata") if err != nil { t.Fatal(err) } - if expected := abs("test/b/c/data"); expected != rewrite { + if expected := abs(t, "testdata/b/c/data"); expected != rewrite { t.Fatalf("Expected %s got %s", expected, rewrite) } } func TestFollowSymLinkRandomString(t *testing.T) { - rewrite, err := FollowSymlink("toto", "test") - if err != nil { - t.Fatal(err) - } - - if rewrite != "toto" { - t.Fatalf("Expected toto got %s", rewrite) + if _, err := FollowSymlinkInScope("toto", "testdata"); err == nil { + t.Fatal("Random string should fail but didn't") } } func TestFollowSymLinkLastLink(t *testing.T) { link := "testdata/fs/a/d" - rewrite, err := FollowSymlink(link, "test") + rewrite, err := FollowSymlinkInScope(link, "testdata") if err != nil { t.Fatal(err) } - if expected := abs("test/b"); expected != rewrite { + if expected := abs(t, "testdata/b"); expected != rewrite { t.Fatalf("Expected %s got %s", expected, rewrite) } } @@ -54,31 +48,36 @@ func TestFollowSymLinkLastLink(t *testing.T) { func TestFollowSymLinkRelativeLink(t *testing.T) { link := "testdata/fs/a/e/c/data" - rewrite, err := FollowSymlink(link, "test") + rewrite, err := FollowSymlinkInScope(link, "testdata") if err != nil { t.Fatal(err) } - if expected := abs("testdata/fs/a/e/c/data"); expected != rewrite { + if expected := abs(t, "testdata/fs/b/c/data"); expected != rewrite { t.Fatalf("Expected %s got %s", expected, rewrite) } } func TestFollowSymLinkRelativeLinkScope(t *testing.T) { link := "testdata/fs/a/f" - pwd, err := os.Getwd() + + rewrite, err := FollowSymlinkInScope(link, "testdata") if err != nil { t.Fatal(err) } - root := filepath.Join(pwd, "testdata") + if expected := abs(t, "testdata/test"); expected != rewrite { + t.Fatalf("Expected %s got %s", expected, rewrite) + } - rewrite, err := FollowSymlink(link, root) + link = "testdata/fs/b/h" + + rewrite, err = FollowSymlinkInScope(link, "testdata") if err != nil { t.Fatal(err) } - if expected := abs("testdata/test"); expected != rewrite { + if expected := abs(t, "testdata/root"); expected != rewrite { t.Fatalf("Expected %s got %s", expected, rewrite) } } diff --git a/utils/testdata/fs/b/h b/utils/testdata/fs/b/h new file mode 120000 index 0000000000..24387a68fb --- /dev/null +++ b/utils/testdata/fs/b/h @@ -0,0 +1 @@ +../g \ No newline at end of file diff --git a/utils/testdata/fs/g b/utils/testdata/fs/g new file mode 120000 index 0000000000..0ce5de0647 --- /dev/null +++ b/utils/testdata/fs/g @@ -0,0 +1 @@ +../../../../../../../../../../../../root \ No newline at end of file