diff --git a/future/future.go b/future/future.go index a3dd9f7efa..6a1b4e2311 100644 --- a/future/future.go +++ b/future/future.go @@ -107,9 +107,8 @@ func Curl(url string, stderr io.Writer) (io.Reader, error) { } // Request a given URL and return an io.Reader -func Download(url string, stderr io.Writer) (io.Reader, error) { +func Download(url string, stderr io.Writer) (*http.Response, error) { var resp *http.Response - var archive io.ReadCloser = nil var err error = nil fmt.Fprintf(stderr, "Download start\n") // FIXME: Replace with progress bar @@ -119,7 +118,30 @@ func Download(url string, stderr io.Writer) (io.Reader, error) { if resp.StatusCode >= 400 { return nil, errors.New("Got HTTP status code >= 400: " + resp.Status) } - archive = resp.Body fmt.Fprintf(stderr, "Download end\n") // FIXME: Replace with progress bar - return archive, nil + return resp, nil +} + +// Reader with progress bar +type progressReader struct { + reader io.ReadCloser // Stream to read from + output io.Writer // Where to send progress bar to + read_total int // Expected stream length (bytes) + read_progress int // How much has been read so far (bytes) +} +func (r *progressReader) Read(p []byte) (n int, err error) { + read, err := io.ReadCloser(r.reader).Read(p) + // FIXME: Don't print progress bar on every read + r.read_progress += read + fmt.Fprintf(r.output, "%d/%d (%.2f%%)\n", + r.read_progress, + r.read_total, + float64(r.read_progress) / float64(r.read_total) * 100) + return read, err +} +func (r *progressReader) Close() error { + return io.ReadCloser(r.reader).Close() +} +func ProgressReader(r io.ReadCloser, size int, output io.Writer) *progressReader { + return &progressReader{r, output, size, 0} } diff --git a/server/server.go b/server/server.go index d70b413161..9c61ffaee8 100644 --- a/server/server.go +++ b/server/server.go @@ -12,7 +12,6 @@ import ( "github.com/dotcloud/docker/rcli" "io" "io/ioutil" - "net/http" "net/url" "os" "path" @@ -427,19 +426,11 @@ func (srv *Server) CmdPull(stdin io.ReadCloser, stdout io.Writer, args ...string u.Path = path.Join("/docker.io/images", u.Path) } fmt.Fprintf(stdout, "Downloading from %s\n", u.String()) - // Download with curl (pretty progress bar) - // If curl is not available or receives a HTTP error, fallback - // to http.Get() - archive, err := future.Curl(u.String(), stdout) + resp, err := future.Download(u.String(), stdout) + // FIXME: Validate ContentLength + archive := future.ProgressReader(resp.Body, int(resp.ContentLength), stdout) if err != nil { - if resp, err := http.Get(u.String()); err != nil { - return err - } else { - if resp.StatusCode >= 400 { - return errors.New("Got HTTP status code >= 400: " + resp.Status) - } - archive = resp.Body - } + return err } fmt.Fprintf(stdout, "Unpacking to %s\n", name) img, err := srv.images.Import(name, archive, nil) @@ -815,7 +806,10 @@ func (srv *Server) CmdRun(stdin io.ReadCloser, stdout io.Writer, args ...string) return nil } name := cmd.Arg(0) + var img_name string + //var img_version string // Only here for reference var cmdline []string + if len(cmd.Args()) >= 2 { cmdline = cmd.Args()[1:] } @@ -823,6 +817,13 @@ func (srv *Server) CmdRun(stdin io.ReadCloser, stdout io.Writer, args ...string) if name == "" { name = "base" } + + // Separate the name:version tag + if strings.Contains(name, ":") { + parts := strings.SplitN(name, ":", 2) + img_name = parts[0] + //img_version = parts[1] // Only here for reference + } // Choose a default command if needed if len(cmdline) == 0 { *fl_stdin = true @@ -835,7 +836,7 @@ func (srv *Server) CmdRun(stdin io.ReadCloser, stdout io.Writer, args ...string) img = srv.images.Find(name) if img == nil { stdin_noclose := ioutil.NopCloser(stdin) - if err := srv.CmdPull(stdin_noclose, stdout, name); err != nil { + if err := srv.CmdPull(stdin_noclose, stdout, img_name); err != nil { return err } img = srv.images.Find(name)