diff --git a/graph/load.go b/graph/load.go index a58c5a3cf9..a3e3551252 100644 --- a/graph/load.go +++ b/graph/load.go @@ -106,13 +106,14 @@ func (s *TagStore) recursiveLoad(address, tmpImageDir string) error { } // ensure no two downloads of the same layer happen at the same time - if ps, found := s.poolAdd("pull", "layer:"+img.ID); found { + poolKey := "layer:" + img.ID + broadcaster, found := s.poolAdd("pull", poolKey) + if found { logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID) - ps.Wait() - return nil + return broadcaster.Wait() } - defer s.poolRemove("pull", "layer:"+img.ID) + defer s.poolRemove("pull", poolKey) if img.Parent != "" { if !s.graph.Exists(img.Parent) { diff --git a/graph/pull_v1.go b/graph/pull_v1.go index 13741292fd..79fb1709fd 100644 --- a/graph/pull_v1.go +++ b/graph/pull_v1.go @@ -138,16 +138,14 @@ func (p *v1Puller) pullRepository(askedTag string) error { } // ensure no two downloads of the same image happen at the same time - broadcaster, found := p.poolAdd("pull", "img:"+img.ID) + poolKey := "img:" + img.ID + broadcaster, found := p.poolAdd("pull", poolKey) + broadcaster.Add(out) if found { - broadcaster.Add(out) - broadcaster.Wait() - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) - errors <- nil + errors <- broadcaster.Wait() return } - broadcaster.Add(out) - defer p.poolRemove("pull", "img:"+img.ID) + defer p.poolRemove("pull", poolKey) // we need to retain it until tagging p.graph.Retain(sessionID, img.ID) @@ -188,6 +186,7 @@ func (p *v1Puller) pullRepository(askedTag string) error { err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr) broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil)) errors <- err + broadcaster.CloseWithError(err) return } broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) @@ -225,8 +224,9 @@ func (p *v1Puller) pullRepository(askedTag string) error { return nil } -func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (bool, error) { - history, err := p.session.GetRemoteHistory(imgID, endpoint) +func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (layersDownloaded bool, err error) { + var history []string + history, err = p.session.GetRemoteHistory(imgID, endpoint) if err != nil { return false, err } @@ -239,20 +239,28 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri p.graph.Retain(sessionID, history[1:]...) defer p.graph.Release(sessionID, history[1:]...) - layersDownloaded := false + layersDownloaded = false for i := len(history) - 1; i >= 0; i-- { id := history[i] // ensure no two downloads of the same layer happen at the same time - broadcaster, found := p.poolAdd("pull", "layer:"+id) + poolKey := "layer:" + id + broadcaster, found := p.poolAdd("pull", poolKey) + broadcaster.Add(out) if found { logrus.Debugf("Image (id: %s) pull is already running, skipping", id) - broadcaster.Add(out) - broadcaster.Wait() - } else { - broadcaster.Add(out) + err = broadcaster.Wait() + if err != nil { + return layersDownloaded, err + } + continue } - defer p.poolRemove("pull", "layer:"+id) + + // This must use a closure so it captures the value of err when + // the function returns, not when the 'defer' is evaluated. 
+ defer func() { + p.poolRemoveWithError("pull", poolKey, err) + }() if !p.graph.Exists(id) { broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil)) @@ -328,6 +336,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri } } broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil)) + broadcaster.Close() } return layersDownloaded, nil } diff --git a/graph/pull_v2.go b/graph/pull_v2.go index 116fee77c5..5691b09e10 100644 --- a/graph/pull_v2.go +++ b/graph/pull_v2.go @@ -74,14 +74,17 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) { } broadcaster, found := p.poolAdd("pull", taggedName) + broadcaster.Add(p.config.OutStream) if found { // Another pull of the same repository is already taking place; just wait for it to finish - broadcaster.Add(p.config.OutStream) - broadcaster.Wait() - return nil + return broadcaster.Wait() } - defer p.poolRemove("pull", taggedName) - broadcaster.Add(p.config.OutStream) + + // This must use a closure so it captures the value of err when the + // function returns, not when the 'defer' is evaluated. + defer func() { + p.poolRemoveWithError("pull", taggedName, err) + }() var layersDownloaded bool for _, tag := range tags { @@ -101,13 +104,15 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) { // downloadInfo is used to pass information from download to extractor type downloadInfo struct { - img *image.Image - tmpFile *os.File - digest digest.Digest - layer distribution.ReadSeekCloser - size int64 - err chan error - out io.Writer // Download progress is written here. + img *image.Image + tmpFile *os.File + digest digest.Digest + layer distribution.ReadSeekCloser + size int64 + err chan error + out io.Writer // Download progress is written here. + poolKey string + broadcaster *progressreader.Broadcaster } type errVerification struct{} @@ -117,19 +122,15 @@ func (errVerification) Error() string { return "verification failed" } func (p *v2Puller) download(di *downloadInfo) { logrus.Debugf("pulling blob %q to %s", di.digest, di.img.ID) - out := di.out - - broadcaster, found := p.poolAdd("pull", "img:"+di.img.ID) + di.poolKey = "layer:" + di.img.ID + broadcaster, found := p.poolAdd("pull", di.poolKey) + broadcaster.Add(di.out) + di.broadcaster = broadcaster if found { - broadcaster.Add(out) - broadcaster.Wait() - out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil)) di.err <- nil return } - broadcaster.Add(out) - defer p.poolRemove("pull", "img:"+di.img.ID) tmpFile, err := ioutil.TempFile("", "GetImageBlob") if err != nil { di.err <- err @@ -279,6 +280,7 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo // run clean for all downloads to prevent leftovers for _, d := range downloads { defer func(d *downloadInfo) { + p.poolRemoveWithError("pull", d.poolKey, err) if d.tmpFile != nil { d.tmpFile.Close() if err := os.RemoveAll(d.tmpFile.Name()); err != nil { @@ -293,14 +295,21 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo if err := <-d.err; err != nil { return false, err } + if d.layer == nil { + // Wait for a different pull to download and extract + // this layer. 
+ err = d.broadcaster.Wait() + if err != nil { + return false, err + } continue } - // if tmpFile is empty assume download and extracted elsewhere + d.tmpFile.Seek(0, 0) reader := progressreader.New(progressreader.Config{ In: d.tmpFile, - Out: out, + Out: d.broadcaster, Formatter: p.sf, Size: d.size, NewLines: false, @@ -317,8 +326,8 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo return false, err } - // FIXME: Pool release here for parallel tag pull (ensures any downloads block until fully extracted) - out.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil)) + d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil)) + d.broadcaster.Close() tagUpdated = true } diff --git a/graph/tags.go b/graph/tags.go index 09a9bc8530..804898c1d3 100644 --- a/graph/tags.go +++ b/graph/tags.go @@ -462,18 +462,18 @@ func (store *TagStore) poolAdd(kind, key string) (*progressreader.Broadcaster, b return broadcaster, false } -func (store *TagStore) poolRemove(kind, key string) error { +func (store *TagStore) poolRemoveWithError(kind, key string, broadcasterResult error) error { store.Lock() defer store.Unlock() switch kind { case "pull": - if ps, exists := store.pullingPool[key]; exists { - ps.Close() + if broadcaster, exists := store.pullingPool[key]; exists { + broadcaster.CloseWithError(broadcasterResult) delete(store.pullingPool, key) } case "push": - if ps, exists := store.pushingPool[key]; exists { - ps.Close() + if broadcaster, exists := store.pushingPool[key]; exists { + broadcaster.CloseWithError(broadcasterResult) delete(store.pushingPool, key) } default: @@ -481,3 +481,7 @@ func (store *TagStore) poolRemove(kind, key string) error { } return nil } + +func (store *TagStore) poolRemove(kind, key string) error { + return store.poolRemoveWithError(kind, key, nil) +} diff --git a/integration-cli/docker_cli_pull_local_test.go b/integration-cli/docker_cli_pull_local_test.go index 350e871cf7..a875f38456 100644 --- a/integration-cli/docker_cli_pull_local_test.go +++ b/integration-cli/docker_cli_pull_local_test.go @@ -2,6 +2,8 @@ package main import ( "fmt" + "os/exec" + "strings" "github.com/go-check/check" ) @@ -37,3 +39,134 @@ func (s *DockerRegistrySuite) TestPullImageWithAliases(c *check.C) { } } } + +// TestConcurrentPullWholeRepo pulls the same repo concurrently. +func (s *DockerRegistrySuite) TestConcurrentPullWholeRepo(c *check.C) { + repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL) + + repos := []string{} + for _, tag := range []string{"recent", "fresh", "todays"} { + repo := fmt.Sprintf("%v:%v", repoName, tag) + _, err := buildImage(repo, fmt.Sprintf(` + FROM busybox + ENTRYPOINT ["/bin/echo"] + ENV FOO foo + ENV BAR bar + CMD echo %s + `, repo), true) + if err != nil { + c.Fatal(err) + } + dockerCmd(c, "push", repo) + repos = append(repos, repo) + } + + // Clear local images store. + args := append([]string{"rmi"}, repos...) + dockerCmd(c, args...) + + // Run multiple re-pulls concurrently + results := make(chan error) + numPulls := 3 + + for i := 0; i != numPulls; i++ { + go func() { + _, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", "-a", repoName)) + results <- err + }() + } + + // These checks are separate from the loop above because the check + // package is not goroutine-safe. 
+ for i := 0; i != numPulls; i++ { + err := <-results + c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err)) + } + + // Ensure all tags were pulled successfully + for _, repo := range repos { + dockerCmd(c, "inspect", repo) + out, _ := dockerCmd(c, "run", "--rm", repo) + if strings.TrimSpace(out) != "/bin/sh -c echo "+repo { + c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out) + } + } +} + +// TestConcurrentFailingPull tries a concurrent pull that doesn't succeed. +func (s *DockerRegistrySuite) TestConcurrentFailingPull(c *check.C) { + repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL) + + // Run multiple pulls concurrently + results := make(chan error) + numPulls := 3 + + for i := 0; i != numPulls; i++ { + go func() { + _, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repoName+":asdfasdf")) + results <- err + }() + } + + // These checks are separate from the loop above because the check + // package is not goroutine-safe. + for i := 0; i != numPulls; i++ { + err := <-results + if err == nil { + c.Fatal("expected pull to fail") + } + } +} + +// TestConcurrentPullMultipleTags pulls multiple tags from the same repo +// concurrently. +func (s *DockerRegistrySuite) TestConcurrentPullMultipleTags(c *check.C) { + repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL) + + repos := []string{} + for _, tag := range []string{"recent", "fresh", "todays"} { + repo := fmt.Sprintf("%v:%v", repoName, tag) + _, err := buildImage(repo, fmt.Sprintf(` + FROM busybox + ENTRYPOINT ["/bin/echo"] + ENV FOO foo + ENV BAR bar + CMD echo %s + `, repo), true) + if err != nil { + c.Fatal(err) + } + dockerCmd(c, "push", repo) + repos = append(repos, repo) + } + + // Clear local images store. + args := append([]string{"rmi"}, repos...) + dockerCmd(c, args...) + + // Re-pull individual tags, in parallel + results := make(chan error) + + for _, repo := range repos { + go func(repo string) { + _, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repo)) + results <- err + }(repo) + } + + // These checks are separate from the loop above because the check + // package is not goroutine-safe. + for range repos { + err := <-results + c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err)) + } + + // Ensure all tags were pulled successfully + for _, repo := range repos { + dockerCmd(c, "inspect", repo) + out, _ := dockerCmd(c, "run", "--rm", repo) + if strings.TrimSpace(out) != "/bin/sh -c echo "+repo { + c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out) + } + } +} diff --git a/pkg/progressreader/broadcaster.go b/pkg/progressreader/broadcaster.go index 5118e9e2f8..58604aa44b 100644 --- a/pkg/progressreader/broadcaster.go +++ b/pkg/progressreader/broadcaster.go @@ -27,6 +27,9 @@ type Broadcaster struct { // isClosed is set to true when Close is called to avoid closing c // multiple times. isClosed bool + // result is the argument passed to the first call of Close, and + // returned to callers of Wait + result error } // NewBroadcaster returns a Broadcaster structure @@ -134,23 +137,33 @@ func (broadcaster *Broadcaster) Add(w io.Writer) error { return nil } -// Close signals to all observers that the operation has finished. -func (broadcaster *Broadcaster) Close() { +// CloseWithError signals to all observers that the operation has finished. Its +// argument is a result that should be returned to waiters blocking on Wait. 
+func (broadcaster *Broadcaster) CloseWithError(result error) {
 	broadcaster.Lock()
 	if broadcaster.isClosed {
 		broadcaster.Unlock()
 		return
 	}
 	broadcaster.isClosed = true
+	broadcaster.result = result
 	close(broadcaster.c)
 	broadcaster.cond.Broadcast()
 	broadcaster.Unlock()
 
-	// Don't return from Close until all writers have caught up.
+	// Don't return until all writers have caught up.
 	broadcaster.wg.Wait()
 }
 
-// Wait blocks until the operation is marked as completed by the Done method.
-func (broadcaster *Broadcaster) Wait() {
-	<-broadcaster.c
+// Close signals to all observers that the operation has finished. It causes
+// all calls to Wait to return nil.
+func (broadcaster *Broadcaster) Close() {
+	broadcaster.CloseWithError(nil)
+}
+
+// Wait blocks until the operation is marked as completed by the Close or
+// CloseWithError method. It returns the result passed to CloseWithError.
+func (broadcaster *Broadcaster) Wait() error {
+	<-broadcaster.c
+	return broadcaster.result
}
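
A few of the patterns in this patch are worth illustrating outside the diff. First, the result-carrying broadcaster: the sketch below is a minimal, self-contained rendering of the pattern (it mirrors the names in pkg/progressreader/broadcaster.go but omits the real Broadcaster's writer fan-out, history replay, and sync.Cond machinery), showing how the first CloseWithError records an error that every Wait then returns.

package main

import (
	"errors"
	"fmt"
	"sync"
	"time"
)

// broadcaster distills the pattern from pkg/progressreader/broadcaster.go:
// a channel is closed exactly once to wake all waiters, and the error
// recorded by the first CloseWithError becomes every Wait's return value.
type broadcaster struct {
	sync.Mutex
	c        chan struct{} // closed when the operation finishes
	isClosed bool          // guards against closing c twice
	result   error         // outcome handed to waiters
}

func newBroadcaster() *broadcaster {
	return &broadcaster{c: make(chan struct{})}
}

// CloseWithError records the outcome and wakes all waiters. Any later
// call (including Close) is a no-op, so the first result wins.
func (b *broadcaster) CloseWithError(result error) {
	b.Lock()
	defer b.Unlock()
	if b.isClosed {
		return
	}
	b.isClosed = true
	b.result = result
	close(b.c)
}

// Close signals success; all Wait calls return nil.
func (b *broadcaster) Close() { b.CloseWithError(nil) }

// Wait blocks until CloseWithError or Close has run, then returns the
// recorded result.
func (b *broadcaster) Wait() error {
	<-b.c
	return b.result
}

func main() {
	b := newBroadcaster()
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			fmt.Printf("waiter %d got: %v\n", i, b.Wait())
		}(i)
	}
	time.Sleep(10 * time.Millisecond) // let the waiters block
	b.CloseWithError(errors.New("layer download failed"))
	b.Close() // no-op: a result was already recorded
	wg.Wait()
}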
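
The broadcaster only pays off together with the single-flight pool in graph/tags.go. The sketch below (again illustrative, not the TagStore code; it reuses the broadcaster type from the previous example) shows the add/removeWithError shape that poolAdd and poolRemoveWithError give their callers.

// pool mirrors the shape of TagStore's pullingPool: one broadcaster
// per in-flight key.
type pool struct {
	mu      sync.Mutex
	pulling map[string]*broadcaster
}

func newPool() *pool {
	return &pool{pulling: make(map[string]*broadcaster)}
}

// add returns the broadcaster for key and reports whether a pull for
// that key was already in flight.
func (p *pool) add(key string) (*broadcaster, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if b, ok := p.pulling[key]; ok {
		return b, true // duplicate: the caller should just Wait
	}
	b := newBroadcaster()
	p.pulling[key] = b
	return b, false // first caller: do the work, then removeWithError
}

// removeWithError publishes the outcome to all waiters and frees the key.
func (p *pool) removeWithError(key string, result error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if b, ok := p.pulling[key]; ok {
		b.CloseWithError(result)
		delete(p.pulling, key)
	}
}

A caller then follows the shape of the patched recursiveLoad: when found is true it simply returns b.Wait(), and otherwise it defers removeWithError(key, err) inside a closure and performs the download itself.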
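
Finally, the "must use a closure" comments that appear twice above deserve a standalone demonstration. In this sketch, cleanup is a hypothetical stand-in for poolRemoveWithError; it shows the trap the comments warn about: defer evaluates its arguments at the point of the defer statement, so only a deferred closure observes the final value of the named return err.

package main

import (
	"errors"
	"fmt"
)

// cleanup stands in for poolRemoveWithError: it needs the final error.
func cleanup(tag string, err error) {
	fmt.Printf("%s: cleanup saw err = %v\n", tag, err)
}

// snapshot shows the bug: defer evaluates its arguments immediately,
// so cleanup receives the value err held before the failure, always nil.
func snapshot() (err error) {
	defer cleanup("snapshot", err)
	err = errors.New("boom")
	return err
}

// closure defers an anonymous function instead, so err is read only when
// the surrounding function returns; this is the pattern the patch uses.
func closure() (err error) {
	defer func() {
		cleanup("closure", err)
	}()
	err = errors.New("boom")
	return err
}

func main() {
	snapshot() // prints: snapshot: cleanup saw err = <nil>
	closure()  // prints: closure: cleanup saw err = boom
}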