From 23e68679f080fee7ceb25cf791832f523a3a024a Mon Sep 17 00:00:00 2001 From: Aaron Lehmann Date: Tue, 25 Aug 2015 14:23:52 -0700 Subject: [PATCH] Fix race condition when waiting for a concurrent layer pull Before, this only waited for the download to complete. There was no guarantee that the layer had been registered in the graph and was ready use. This is especially problematic with v2 pulls, which wait for all downloads before extracting layers. Change Broadcaster to allow an error value to be propagated from Close to the waiters. Make the wait stop when the extraction is finished, rather than just the download. This also fixes v2 layer downloads to prefix the pool key with "layer:" instead of "img:". "img:" is the wrong prefix, because this is what v1 uses for entire images. A v1 pull waiting for one of these operations to finish would only wait for that particular layer, not all its dependencies. Signed-off-by: Aaron Lehmann --- graph/load.go | 9 +- graph/pull_v1.go | 41 +++--- graph/pull_v2.go | 57 ++++---- graph/tags.go | 14 +- integration-cli/docker_cli_pull_local_test.go | 133 ++++++++++++++++++ pkg/progressreader/broadcaster.go | 25 +++- 6 files changed, 224 insertions(+), 55 deletions(-) diff --git a/graph/load.go b/graph/load.go index a58c5a3cf9..a3e3551252 100644 --- a/graph/load.go +++ b/graph/load.go @@ -106,13 +106,14 @@ func (s *TagStore) recursiveLoad(address, tmpImageDir string) error { } // ensure no two downloads of the same layer happen at the same time - if ps, found := s.poolAdd("pull", "layer:"+img.ID); found { + poolKey := "layer:" + img.ID + broadcaster, found := s.poolAdd("pull", poolKey) + if found { logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID) - ps.Wait() - return nil + return broadcaster.Wait() } - defer s.poolRemove("pull", "layer:"+img.ID) + defer s.poolRemove("pull", poolKey) if img.Parent != "" { if !s.graph.Exists(img.Parent) { diff --git a/graph/pull_v1.go b/graph/pull_v1.go index 13741292fd..79fb1709fd 100644 --- a/graph/pull_v1.go +++ b/graph/pull_v1.go @@ -138,16 +138,14 @@ func (p *v1Puller) pullRepository(askedTag string) error { } // ensure no two downloads of the same image happen at the same time - broadcaster, found := p.poolAdd("pull", "img:"+img.ID) + poolKey := "img:" + img.ID + broadcaster, found := p.poolAdd("pull", poolKey) + broadcaster.Add(out) if found { - broadcaster.Add(out) - broadcaster.Wait() - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) - errors <- nil + errors <- broadcaster.Wait() return } - broadcaster.Add(out) - defer p.poolRemove("pull", "img:"+img.ID) + defer p.poolRemove("pull", poolKey) // we need to retain it until tagging p.graph.Retain(sessionID, img.ID) @@ -188,6 +186,7 @@ func (p *v1Puller) pullRepository(askedTag string) error { err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr) broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil)) errors <- err + broadcaster.CloseWithError(err) return } broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) @@ -225,8 +224,9 @@ func (p *v1Puller) pullRepository(askedTag string) error { return nil } -func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (bool, error) { - history, err := p.session.GetRemoteHistory(imgID, endpoint) +func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (layersDownloaded bool, err error) { + var history 
[]string + history, err = p.session.GetRemoteHistory(imgID, endpoint) if err != nil { return false, err } @@ -239,20 +239,28 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri p.graph.Retain(sessionID, history[1:]...) defer p.graph.Release(sessionID, history[1:]...) - layersDownloaded := false + layersDownloaded = false for i := len(history) - 1; i >= 0; i-- { id := history[i] // ensure no two downloads of the same layer happen at the same time - broadcaster, found := p.poolAdd("pull", "layer:"+id) + poolKey := "layer:" + id + broadcaster, found := p.poolAdd("pull", poolKey) + broadcaster.Add(out) if found { logrus.Debugf("Image (id: %s) pull is already running, skipping", id) - broadcaster.Add(out) - broadcaster.Wait() - } else { - broadcaster.Add(out) + err = broadcaster.Wait() + if err != nil { + return layersDownloaded, err + } + continue } - defer p.poolRemove("pull", "layer:"+id) + + // This must use a closure so it captures the value of err when + // the function returns, not when the 'defer' is evaluated. + defer func() { + p.poolRemoveWithError("pull", poolKey, err) + }() if !p.graph.Exists(id) { broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil)) @@ -328,6 +336,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri } } broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil)) + broadcaster.Close() } return layersDownloaded, nil } diff --git a/graph/pull_v2.go b/graph/pull_v2.go index 116fee77c5..5691b09e10 100644 --- a/graph/pull_v2.go +++ b/graph/pull_v2.go @@ -74,14 +74,17 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) { } broadcaster, found := p.poolAdd("pull", taggedName) + broadcaster.Add(p.config.OutStream) if found { // Another pull of the same repository is already taking place; just wait for it to finish - broadcaster.Add(p.config.OutStream) - broadcaster.Wait() - return nil + return broadcaster.Wait() } - defer p.poolRemove("pull", taggedName) - broadcaster.Add(p.config.OutStream) + + // This must use a closure so it captures the value of err when the + // function returns, not when the 'defer' is evaluated. + defer func() { + p.poolRemoveWithError("pull", taggedName, err) + }() var layersDownloaded bool for _, tag := range tags { @@ -101,13 +104,15 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) { // downloadInfo is used to pass information from download to extractor type downloadInfo struct { - img *image.Image - tmpFile *os.File - digest digest.Digest - layer distribution.ReadSeekCloser - size int64 - err chan error - out io.Writer // Download progress is written here. + img *image.Image + tmpFile *os.File + digest digest.Digest + layer distribution.ReadSeekCloser + size int64 + err chan error + out io.Writer // Download progress is written here. 
+ poolKey string + broadcaster *progressreader.Broadcaster } type errVerification struct{} @@ -117,19 +122,15 @@ func (errVerification) Error() string { return "verification failed" } func (p *v2Puller) download(di *downloadInfo) { logrus.Debugf("pulling blob %q to %s", di.digest, di.img.ID) - out := di.out - - broadcaster, found := p.poolAdd("pull", "img:"+di.img.ID) + di.poolKey = "layer:" + di.img.ID + broadcaster, found := p.poolAdd("pull", di.poolKey) + broadcaster.Add(di.out) + di.broadcaster = broadcaster if found { - broadcaster.Add(out) - broadcaster.Wait() - out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil)) di.err <- nil return } - broadcaster.Add(out) - defer p.poolRemove("pull", "img:"+di.img.ID) tmpFile, err := ioutil.TempFile("", "GetImageBlob") if err != nil { di.err <- err @@ -279,6 +280,7 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo // run clean for all downloads to prevent leftovers for _, d := range downloads { defer func(d *downloadInfo) { + p.poolRemoveWithError("pull", d.poolKey, err) if d.tmpFile != nil { d.tmpFile.Close() if err := os.RemoveAll(d.tmpFile.Name()); err != nil { @@ -293,14 +295,21 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo if err := <-d.err; err != nil { return false, err } + if d.layer == nil { + // Wait for a different pull to download and extract + // this layer. + err = d.broadcaster.Wait() + if err != nil { + return false, err + } continue } - // if tmpFile is empty assume download and extracted elsewhere + d.tmpFile.Seek(0, 0) reader := progressreader.New(progressreader.Config{ In: d.tmpFile, - Out: out, + Out: d.broadcaster, Formatter: p.sf, Size: d.size, NewLines: false, @@ -317,8 +326,8 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo return false, err } - // FIXME: Pool release here for parallel tag pull (ensures any downloads block until fully extracted) - out.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil)) + d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil)) + d.broadcaster.Close() tagUpdated = true } diff --git a/graph/tags.go b/graph/tags.go index 09a9bc8530..804898c1d3 100644 --- a/graph/tags.go +++ b/graph/tags.go @@ -462,18 +462,18 @@ func (store *TagStore) poolAdd(kind, key string) (*progressreader.Broadcaster, b return broadcaster, false } -func (store *TagStore) poolRemove(kind, key string) error { +func (store *TagStore) poolRemoveWithError(kind, key string, broadcasterResult error) error { store.Lock() defer store.Unlock() switch kind { case "pull": - if ps, exists := store.pullingPool[key]; exists { - ps.Close() + if broadcaster, exists := store.pullingPool[key]; exists { + broadcaster.CloseWithError(broadcasterResult) delete(store.pullingPool, key) } case "push": - if ps, exists := store.pushingPool[key]; exists { - ps.Close() + if broadcaster, exists := store.pushingPool[key]; exists { + broadcaster.CloseWithError(broadcasterResult) delete(store.pushingPool, key) } default: @@ -481,3 +481,7 @@ func (store *TagStore) poolRemove(kind, key string) error { } return nil } + +func (store *TagStore) poolRemove(kind, key string) error { + return store.poolRemoveWithError(kind, key, nil) +} diff --git a/integration-cli/docker_cli_pull_local_test.go b/integration-cli/docker_cli_pull_local_test.go index 350e871cf7..a875f38456 100644 --- a/integration-cli/docker_cli_pull_local_test.go +++ 
b/integration-cli/docker_cli_pull_local_test.go @@ -2,6 +2,8 @@ package main import ( "fmt" + "os/exec" + "strings" "github.com/go-check/check" ) @@ -37,3 +39,134 @@ func (s *DockerRegistrySuite) TestPullImageWithAliases(c *check.C) { } } } + +// TestConcurrentPullWholeRepo pulls the same repo concurrently. +func (s *DockerRegistrySuite) TestConcurrentPullWholeRepo(c *check.C) { + repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL) + + repos := []string{} + for _, tag := range []string{"recent", "fresh", "todays"} { + repo := fmt.Sprintf("%v:%v", repoName, tag) + _, err := buildImage(repo, fmt.Sprintf(` + FROM busybox + ENTRYPOINT ["/bin/echo"] + ENV FOO foo + ENV BAR bar + CMD echo %s + `, repo), true) + if err != nil { + c.Fatal(err) + } + dockerCmd(c, "push", repo) + repos = append(repos, repo) + } + + // Clear local images store. + args := append([]string{"rmi"}, repos...) + dockerCmd(c, args...) + + // Run multiple re-pulls concurrently + results := make(chan error) + numPulls := 3 + + for i := 0; i != numPulls; i++ { + go func() { + _, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", "-a", repoName)) + results <- err + }() + } + + // These checks are separate from the loop above because the check + // package is not goroutine-safe. + for i := 0; i != numPulls; i++ { + err := <-results + c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err)) + } + + // Ensure all tags were pulled successfully + for _, repo := range repos { + dockerCmd(c, "inspect", repo) + out, _ := dockerCmd(c, "run", "--rm", repo) + if strings.TrimSpace(out) != "/bin/sh -c echo "+repo { + c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out) + } + } +} + +// TestConcurrentFailingPull tries a concurrent pull that doesn't succeed. +func (s *DockerRegistrySuite) TestConcurrentFailingPull(c *check.C) { + repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL) + + // Run multiple pulls concurrently + results := make(chan error) + numPulls := 3 + + for i := 0; i != numPulls; i++ { + go func() { + _, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repoName+":asdfasdf")) + results <- err + }() + } + + // These checks are separate from the loop above because the check + // package is not goroutine-safe. + for i := 0; i != numPulls; i++ { + err := <-results + if err == nil { + c.Fatal("expected pull to fail") + } + } +} + +// TestConcurrentPullMultipleTags pulls multiple tags from the same repo +// concurrently. +func (s *DockerRegistrySuite) TestConcurrentPullMultipleTags(c *check.C) { + repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL) + + repos := []string{} + for _, tag := range []string{"recent", "fresh", "todays"} { + repo := fmt.Sprintf("%v:%v", repoName, tag) + _, err := buildImage(repo, fmt.Sprintf(` + FROM busybox + ENTRYPOINT ["/bin/echo"] + ENV FOO foo + ENV BAR bar + CMD echo %s + `, repo), true) + if err != nil { + c.Fatal(err) + } + dockerCmd(c, "push", repo) + repos = append(repos, repo) + } + + // Clear local images store. + args := append([]string{"rmi"}, repos...) + dockerCmd(c, args...) + + // Re-pull individual tags, in parallel + results := make(chan error) + + for _, repo := range repos { + go func(repo string) { + _, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repo)) + results <- err + }(repo) + } + + // These checks are separate from the loop above because the check + // package is not goroutine-safe. 
+ for range repos { + err := <-results + c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err)) + } + + // Ensure all tags were pulled successfully + for _, repo := range repos { + dockerCmd(c, "inspect", repo) + out, _ := dockerCmd(c, "run", "--rm", repo) + if strings.TrimSpace(out) != "/bin/sh -c echo "+repo { + c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out) + } + } +} diff --git a/pkg/progressreader/broadcaster.go b/pkg/progressreader/broadcaster.go index 5118e9e2f8..58604aa44b 100644 --- a/pkg/progressreader/broadcaster.go +++ b/pkg/progressreader/broadcaster.go @@ -27,6 +27,9 @@ type Broadcaster struct { // isClosed is set to true when Close is called to avoid closing c // multiple times. isClosed bool + // result is the argument passed to the first call of Close, and + // returned to callers of Wait + result error } // NewBroadcaster returns a Broadcaster structure @@ -134,23 +137,33 @@ func (broadcaster *Broadcaster) Add(w io.Writer) error { return nil } -// Close signals to all observers that the operation has finished. -func (broadcaster *Broadcaster) Close() { +// CloseWithError signals to all observers that the operation has finished. Its +// argument is a result that should be returned to waiters blocking on Wait. +func (broadcaster *Broadcaster) CloseWithError(result error) { broadcaster.Lock() if broadcaster.isClosed { broadcaster.Unlock() return } broadcaster.isClosed = true + broadcaster.result = result close(broadcaster.c) broadcaster.cond.Broadcast() broadcaster.Unlock() - // Don't return from Close until all writers have caught up. + // Don't return until all writers have caught up. broadcaster.wg.Wait() } -// Wait blocks until the operation is marked as completed by the Done method. -func (broadcaster *Broadcaster) Wait() { - <-broadcaster.c +// Close signals to all observers that the operation has finished. It causes +// all calls to Wait to return nil. +func (broadcaster *Broadcaster) Close() { + broadcaster.CloseWithError(nil) +} + +// Wait blocks until the operation is marked as completed by the Done method. +// It returns the argument that was passed to Close. +func (broadcaster *Broadcaster) Wait() error { + <-broadcaster.c + return broadcaster.result }
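
Note (illustrative sketch, not part of the patch): the program below models the pattern this change establishes, using stand-in names -- broadcaster, pool, add, removeWithError, pullLayer -- rather than the real progressreader.Broadcaster, TagStore.poolAdd, and poolRemoveWithError shown in the diff, and with the progress fan-out to attached writers omitted. The goroutine that first registers a "layer:<id>" key performs the download and extraction and publishes the outcome through the deferred removal (CloseWithError); every concurrent pull of the same layer blocks in Wait and now receives that outcome, instead of returning as soon as the download alone completes.

// Sketch only: simplified stand-ins for the types touched by this patch.
package main

import (
	"errors"
	"fmt"
	"sync"
)

type broadcaster struct {
	mu       sync.Mutex
	c        chan struct{}
	isClosed bool
	result   error // value handed to every Wait caller
}

func newBroadcaster() *broadcaster {
	return &broadcaster{c: make(chan struct{})}
}

// CloseWithError records the result and wakes all waiters exactly once.
func (b *broadcaster) CloseWithError(result error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.isClosed {
		return
	}
	b.isClosed = true
	b.result = result
	close(b.c)
}

// Wait blocks until the operation finishes and returns its result.
func (b *broadcaster) Wait() error {
	<-b.c
	return b.result
}

// pool mimics the pulling pool: add returns (broadcaster, found).
type pool struct {
	mu sync.Mutex
	m  map[string]*broadcaster
}

func (p *pool) add(key string) (*broadcaster, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if b, ok := p.m[key]; ok {
		return b, true
	}
	b := newBroadcaster()
	p.m[key] = b
	return b, false
}

func (p *pool) removeWithError(key string, err error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if b, ok := p.m[key]; ok {
		b.CloseWithError(err)
		delete(p.m, key)
	}
}

func pullLayer(p *pool, id string) (err error) {
	key := "layer:" + id
	b, found := p.add(key)
	if found {
		// Another pull owns this layer; wait for its download AND
		// extraction to finish, and inherit its error, if any.
		return b.Wait()
	}
	// Closure so the final value of err is what waiters receive.
	defer func() { p.removeWithError(key, err) }()

	// ... download and extract the layer here ...
	return errors.New("simulated extraction failure")
}

func main() {
	p := &pool{m: make(map[string]*broadcaster)}
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			fmt.Printf("pull %d: %v\n", i, pullLayer(p, "abc123"))
		}(i)
	}
	wg.Wait()
}

Because result is stored under the lock before the channel is closed, every goroutine returning from Wait observes the final error without additional synchronization, which is the property the real Broadcaster relies on to propagate a failed extraction to concurrent pulls.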