From 253933220965574422aa6679255359d8bd15d435 Mon Sep 17 00:00:00 2001 From: Akihiro Suda Date: Wed, 6 Jul 2016 02:40:48 +0000 Subject: [PATCH] Validate arguments for `ps` in `docker top` Fix #24357 Signed-off-by: Akihiro Suda --- daemon/top_unix.go | 113 +++++++++++++++++++++++++++------------- daemon/top_unix_test.go | 76 +++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 36 deletions(-) create mode 100644 daemon/top_unix_test.go diff --git a/daemon/top_unix.go b/daemon/top_unix.go index d4a9528c98..935f38f29e 100644 --- a/daemon/top_unix.go +++ b/daemon/top_unix.go @@ -5,12 +5,82 @@ package daemon import ( "fmt" "os/exec" + "regexp" "strconv" "strings" "github.com/docker/engine-api/types" ) +func validatePSArgs(psArgs string) error { + // NOTE: \\s does not detect unicode whitespaces. + // So we use fieldsASCII instead of strings.Fields in parsePSOutput. + // See https://github.com/docker/docker/pull/24358 + re := regexp.MustCompile("\\s+([^\\s]*)=\\s*(PID[^\\s]*)") + for _, group := range re.FindAllStringSubmatch(psArgs, -1) { + if len(group) >= 3 { + k := group[1] + v := group[2] + if k != "pid" { + return fmt.Errorf("specifying \"%s=%s\" is not allowed", k, v) + } + } + } + return nil +} + +// fieldsASCII is similar to strings.Fields but only allows ASCII whitespaces +func fieldsASCII(s string) []string { + fn := func(r rune) bool { + switch r { + case '\t', '\n', '\f', '\r', ' ': + return true + } + return false + } + return strings.FieldsFunc(s, fn) +} + +func parsePSOutput(output []byte, pids []int) (*types.ContainerProcessList, error) { + procList := &types.ContainerProcessList{} + + lines := strings.Split(string(output), "\n") + procList.Titles = fieldsASCII(lines[0]) + + pidIndex := -1 + for i, name := range procList.Titles { + if name == "PID" { + pidIndex = i + } + } + if pidIndex == -1 { + return nil, fmt.Errorf("Couldn't find PID field in ps output") + } + + // loop through the output and extract the PID from each line + for _, line := range lines[1:] { + if len(line) == 0 { + continue + } + fields := fieldsASCII(line) + p, err := strconv.Atoi(fields[pidIndex]) + if err != nil { + return nil, fmt.Errorf("Unexpected pid '%s': %s", fields[pidIndex], err) + } + + for _, pid := range pids { + if pid == p { + // Make sure number of fields equals number of header titles + // merging "overhanging" fields + process := fields[:len(procList.Titles)-1] + process = append(process, strings.Join(fields[len(procList.Titles)-1:], " ")) + procList.Processes = append(procList.Processes, process) + } + } + } + return procList, nil +} + // ContainerTop lists the processes running inside of the given // container by calling ps with the given args, or with the flags // "-ef" if no args are given. An error is returned if the container @@ -21,6 +91,10 @@ func (daemon *Daemon) ContainerTop(name string, psArgs string) (*types.Container psArgs = "-ef" } + if err := validatePSArgs(psArgs); err != nil { + return nil, err + } + container, err := daemon.GetContainer(name) if err != nil { return nil, err @@ -43,42 +117,9 @@ func (daemon *Daemon) ContainerTop(name string, psArgs string) (*types.Container if err != nil { return nil, fmt.Errorf("Error running ps: %v", err) } - - procList := &types.ContainerProcessList{} - - lines := strings.Split(string(output), "\n") - procList.Titles = strings.Fields(lines[0]) - - pidIndex := -1 - for i, name := range procList.Titles { - if name == "PID" { - pidIndex = i - } - } - if pidIndex == -1 { - return nil, fmt.Errorf("Couldn't find PID field in ps output") - } - - // loop through the output and extract the PID from each line - for _, line := range lines[1:] { - if len(line) == 0 { - continue - } - fields := strings.Fields(line) - p, err := strconv.Atoi(fields[pidIndex]) - if err != nil { - return nil, fmt.Errorf("Unexpected pid '%s': %s", fields[pidIndex], err) - } - - for _, pid := range pids { - if pid == p { - // Make sure number of fields equals number of header titles - // merging "overhanging" fields - process := fields[:len(procList.Titles)-1] - process = append(process, strings.Join(fields[len(procList.Titles)-1:], " ")) - procList.Processes = append(procList.Processes, process) - } - } + procList, err := parsePSOutput(output, pids) + if err != nil { + return nil, err } daemon.LogContainerEvent(container, "top") return procList, nil diff --git a/daemon/top_unix_test.go b/daemon/top_unix_test.go new file mode 100644 index 0000000000..269ab6e947 --- /dev/null +++ b/daemon/top_unix_test.go @@ -0,0 +1,76 @@ +//+build !windows + +package daemon + +import ( + "testing" +) + +func TestContainerTopValidatePSArgs(t *testing.T) { + tests := map[string]bool{ + "ae -o uid=PID": true, + "ae -o \"uid= PID\"": true, // ascii space (0x20) + "ae -o \"uid= PID\"": false, // unicode space (U+2003, 0xe2 0x80 0x83) + "ae o uid=PID": true, + "aeo uid=PID": true, + "ae -O uid=PID": true, + "ae -o pid=PID2 -o uid=PID": true, + "ae -o pid=PID": false, + "ae -o pid=PID -o uid=PIDX": true, // FIXME: we do not need to prohibit this + "aeo pid=PID": false, + "ae": false, + "": false, + } + for psArgs, errExpected := range tests { + err := validatePSArgs(psArgs) + t.Logf("tested %q, got err=%v", psArgs, err) + if errExpected && err == nil { + t.Fatalf("expected error, got %v (%q)", err, psArgs) + } + if !errExpected && err != nil { + t.Fatalf("expected nil, got %v (%q)", err, psArgs) + } + } +} + +func TestContainerTopParsePSOutput(t *testing.T) { + tests := []struct { + output []byte + pids []int + errExpected bool + }{ + {[]byte(` PID COMMAND + 42 foo + 43 bar + 100 baz +`), []int{42, 43}, false}, + {[]byte(` UID COMMAND + 42 foo + 43 bar + 100 baz +`), []int{42, 43}, true}, + // unicode space (U+2003, 0xe2 0x80 0x83) + {[]byte(` PID COMMAND + 42 foo + 43 bar + 100 baz +`), []int{42, 43}, true}, + // the first space is U+2003, the second one is ascii. + {[]byte(` PID COMMAND + 42 foo + 43 bar + 100 baz +`), []int{42, 43}, true}, + } + + for _, f := range tests { + _, err := parsePSOutput(f.output, f.pids) + t.Logf("tested %q, got err=%v", string(f.output), err) + if f.errExpected && err == nil { + t.Fatalf("expected error, got %v (%q)", err, string(f.output)) + } + if !f.errExpected && err != nil { + t.Fatalf("expected nil, got %v (%q)", err, string(f.output)) + } + } +}