From 572ce802306a4e919802e5b77cbeca94acda7c0a Mon Sep 17 00:00:00 2001 From: Aaron Lehmann Date: Fri, 13 Nov 2015 16:59:01 -0800 Subject: [PATCH] Improved push and pull with upload manager and download manager This commit adds a transfer manager which deduplicates and schedules transfers, and also an upload manager and download manager that build on top of the transfer manager to provide high-level interfaces for uploads and downloads. The push and pull code is modified to use these building blocks. Some benefits of the changes: - Simplification of push/pull code - Pushes can upload layers concurrently - Failed downloads and uploads are retried after backoff delays - Cancellation is supported, but individual transfers will only be cancelled if all pushes or pulls using them are cancelled. - The distribution code is decoupled from Docker Engine packages and API conventions (i.e. streamformatter), which will make it easier to split out. This commit also includes unit tests for the new distribution/xfer package. The tests cover 87.8% of the statements in the package. 
Signed-off-by: Aaron Lehmann --- api/client/build.go | 26 +- api/server/router/local/image.go | 18 +- builder/dockerfile/internals.go | 16 +- daemon/daemon.go | 79 +++- daemon/daemonbuilder/builder.go | 6 +- daemon/import.go | 14 +- distribution/metadata/v1_id_service.go | 10 +- distribution/metadata/v1_id_service_test.go | 8 +- distribution/pool.go | 51 --- distribution/pool_test.go | 28 -- distribution/pull.go | 57 ++- distribution/pull_v1.go | 353 ++++++---------- distribution/pull_v2.go | 297 +++++--------- distribution/push.go | 36 +- distribution/push_v1.go | 52 +-- distribution/push_v2.go | 223 ++++++----- distribution/push_v2_test.go | 8 +- distribution/registry.go | 24 +- distribution/registry_unit_test.go | 7 +- distribution/xfer/download.go | 420 ++++++++++++++++++++ distribution/xfer/download_test.go | 332 ++++++++++++++++ distribution/xfer/transfer.go | 343 ++++++++++++++++ distribution/xfer/transfer_test.go | 385 ++++++++++++++++++ distribution/xfer/upload.go | 159 ++++++++ distribution/xfer/upload_test.go | 153 +++++++ integration-cli/docker_cli_pull_test.go | 14 +- pkg/broadcaster/buffered.go | 167 -------- pkg/ioutils/readers.go | 71 ++++ pkg/ioutils/readers_test.go | 27 ++ pkg/progress/progress.go | 63 +++ pkg/progress/progressreader.go | 59 +++ pkg/progress/progressreader_test.go | 75 ++++ pkg/progressreader/progressreader.go | 68 ---- pkg/progressreader/progressreader_test.go | 94 ----- pkg/streamformatter/streamformatter.go | 39 ++ registry/session.go | 20 +- 36 files changed, 2675 insertions(+), 1127 deletions(-) delete mode 100644 distribution/pool.go delete mode 100644 distribution/pool_test.go create mode 100644 distribution/xfer/download.go create mode 100644 distribution/xfer/download_test.go create mode 100644 distribution/xfer/transfer.go create mode 100644 distribution/xfer/transfer_test.go create mode 100644 distribution/xfer/upload.go create mode 100644 distribution/xfer/upload_test.go delete mode 100644 pkg/broadcaster/buffered.go 
create mode 100644 pkg/progress/progress.go create mode 100644 pkg/progress/progressreader.go create mode 100644 pkg/progress/progressreader_test.go delete mode 100644 pkg/progressreader/progressreader.go delete mode 100644 pkg/progressreader/progressreader_test.go diff --git a/api/client/build.go b/api/client/build.go index 8af15b2465..0516f5385f 100644 --- a/api/client/build.go +++ b/api/client/build.go @@ -23,7 +23,7 @@ import ( "github.com/docker/docker/pkg/httputils" "github.com/docker/docker/pkg/jsonmessage" flag "github.com/docker/docker/pkg/mflag" - "github.com/docker/docker/pkg/progressreader" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/ulimit" "github.com/docker/docker/pkg/units" @@ -169,16 +169,9 @@ func (cli *DockerCli) CmdBuild(args ...string) error { context = replaceDockerfileTarWrapper(context, newDockerfile, relDockerfile) // Setup an upload progress bar - // FIXME: ProgressReader shouldn't be this annoying to use - sf := streamformatter.NewStreamFormatter() - var body io.Reader = progressreader.New(progressreader.Config{ - In: context, - Out: cli.out, - Formatter: sf, - NewLines: true, - ID: "", - Action: "Sending build context to Docker daemon", - }) + progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(cli.out, true) + + var body io.Reader = progress.NewProgressReader(context, progressOutput, 0, "", "Sending build context to Docker daemon") var memory int64 if *flMemoryString != "" { @@ -447,17 +440,10 @@ func getContextFromURL(out io.Writer, remoteURL, dockerfileName string) (absCont return "", "", fmt.Errorf("unable to download remote context %s: %v", remoteURL, err) } defer response.Body.Close() + progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(out, true) // Pass the response body through a progress reader. 
- progReader := &progressreader.Config{ - In: response.Body, - Out: out, - Formatter: streamformatter.NewStreamFormatter(), - Size: response.ContentLength, - NewLines: true, - ID: "", - Action: fmt.Sprintf("Downloading build context from remote url: %s", remoteURL), - } + progReader := progress.NewProgressReader(response.Body, progressOutput, response.ContentLength, "", fmt.Sprintf("Downloading build context from remote url: %s", remoteURL)) return getContextFromReader(progReader, dockerfileName) } diff --git a/api/server/router/local/image.go b/api/server/router/local/image.go index 29d64dff15..2ddc28610d 100644 --- a/api/server/router/local/image.go +++ b/api/server/router/local/image.go @@ -23,7 +23,7 @@ import ( "github.com/docker/docker/pkg/archive" "github.com/docker/docker/pkg/chrootarchive" "github.com/docker/docker/pkg/ioutils" - "github.com/docker/docker/pkg/progressreader" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/ulimit" "github.com/docker/docker/runconfig" @@ -325,7 +325,7 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R sf := streamformatter.NewJSONStreamFormatter() errf := func(err error) error { // Do not write the error in the http output if it's still empty. - // This prevents from writing a 200(OK) when there is an interal error. + // This prevents from writing a 200(OK) when there is an internal error. if !output.Flushed() { return err } @@ -401,23 +401,17 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R remoteURL := r.FormValue("remote") // Currently, only used if context is from a remote url. - // The field `In` is set by DetectContextFromRemoteURL. // Look at code in DetectContextFromRemoteURL for more information. 
- pReader := &progressreader.Config{ - // TODO: make progressreader streamformatter-agnostic - Out: output, - Formatter: sf, - Size: r.ContentLength, - NewLines: true, - ID: "Downloading context", - Action: remoteURL, + createProgressReader := func(in io.ReadCloser) io.ReadCloser { + progressOutput := sf.NewProgressOutput(output, true) + return progress.NewProgressReader(in, progressOutput, r.ContentLength, "Downloading context", remoteURL) } var ( context builder.ModifiableContext dockerfileName string ) - context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, pReader) + context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, createProgressReader) if err != nil { return errf(err) } diff --git a/builder/dockerfile/internals.go b/builder/dockerfile/internals.go index e4d8540c3f..bf407f278e 100644 --- a/builder/dockerfile/internals.go +++ b/builder/dockerfile/internals.go @@ -29,7 +29,7 @@ import ( "github.com/docker/docker/pkg/httputils" "github.com/docker/docker/pkg/ioutils" "github.com/docker/docker/pkg/jsonmessage" - "github.com/docker/docker/pkg/progressreader" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/pkg/stringutils" @@ -264,17 +264,11 @@ func (b *Builder) download(srcURL string) (fi builder.FileInfo, err error) { return } + stdoutFormatter := b.Stdout.(*streamformatter.StdoutFormatter) + progressOutput := stdoutFormatter.StreamFormatter.NewProgressOutput(stdoutFormatter.Writer, true) + progressReader := progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Downloading") // Download and dump result to tmp file - if _, err = io.Copy(tmpFile, progressreader.New(progressreader.Config{ - In: resp.Body, - // TODO: make progressreader streamformatter agnostic - Out: b.Stdout.(*streamformatter.StdoutFormatter).Writer, - Formatter: 
b.Stdout.(*streamformatter.StdoutFormatter).StreamFormatter, - Size: resp.ContentLength, - NewLines: true, - ID: "", - Action: "Downloading", - })); err != nil { + if _, err = io.Copy(tmpFile, progressReader); err != nil { tmpFile.Close() return } diff --git a/daemon/daemon.go b/daemon/daemon.go index 441e20c166..9d8a8e6eef 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -34,6 +34,7 @@ import ( "github.com/docker/docker/daemon/network" "github.com/docker/docker/distribution" dmetadata "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" derr "github.com/docker/docker/errors" "github.com/docker/docker/image" "github.com/docker/docker/image/tarexport" @@ -49,7 +50,9 @@ import ( "github.com/docker/docker/pkg/namesgenerator" "github.com/docker/docker/pkg/nat" "github.com/docker/docker/pkg/parsers/filters" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/signal" + "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/pkg/stringutils" "github.com/docker/docker/pkg/sysinfo" @@ -66,6 +69,16 @@ import ( lntypes "github.com/docker/libnetwork/types" "github.com/docker/libtrust" "github.com/opencontainers/runc/libcontainer" + "golang.org/x/net/context" +) + +const ( + // maxDownloadConcurrency is the maximum number of downloads that + // may take place at a time for each pull. + maxDownloadConcurrency = 3 + // maxUploadConcurrency is the maximum number of uploads that + // may take place at a time for each push. 
+ maxUploadConcurrency = 5 ) var ( @@ -126,7 +139,8 @@ type Daemon struct { containers *contStore execCommands *exec.Store tagStore tag.Store - distributionPool *distribution.Pool + downloadManager *xfer.LayerDownloadManager + uploadManager *xfer.LayerUploadManager distributionMetadataStore dmetadata.Store trustKey libtrust.PrivateKey idIndex *truncindex.TruncIndex @@ -738,7 +752,8 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo return nil, err } - distributionPool := distribution.NewPool() + d.downloadManager = xfer.NewLayerDownloadManager(d.layerStore, maxDownloadConcurrency) + d.uploadManager = xfer.NewLayerUploadManager(maxUploadConcurrency) ifs, err := image.NewFSStoreBackend(filepath.Join(imageRoot, "imagedb")) if err != nil { @@ -834,7 +849,6 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo d.containers = &contStore{s: make(map[string]*container.Container)} d.execCommands = exec.NewStore() d.tagStore = tagStore - d.distributionPool = distributionPool d.distributionMetadataStore = distributionMetadataStore d.trustKey = trustKey d.idIndex = truncindex.NewTruncIndex([]string{}) @@ -1038,23 +1052,53 @@ func (daemon *Daemon) TagImage(newTag reference.Named, imageName string) error { return nil } +func writeDistributionProgress(cancelFunc func(), outStream io.Writer, progressChan <-chan progress.Progress) { + progressOutput := streamformatter.NewJSONStreamFormatter().NewProgressOutput(outStream, false) + operationCancelled := false + + for prog := range progressChan { + if err := progressOutput.WriteProgress(prog); err != nil && !operationCancelled { + logrus.Errorf("error writing progress to client: %v", err) + cancelFunc() + operationCancelled = true + // Don't return, because we need to continue draining + // progressChan until it's closed to avoid a deadlock. + } + } +} + // PullImage initiates a pull operation. 
image is the repository name to pull, and // tag may be either empty, or indicate a specific tag to pull. func (daemon *Daemon) PullImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error { + // Include a buffer so that slow client connections don't affect + // transfer performance. + progressChan := make(chan progress.Progress, 100) + + writesDone := make(chan struct{}) + + ctx, cancelFunc := context.WithCancel(context.Background()) + + go func() { + writeDistributionProgress(cancelFunc, outStream, progressChan) + close(writesDone) + }() + imagePullConfig := &distribution.ImagePullConfig{ MetaHeaders: metaHeaders, AuthConfig: authConfig, - OutStream: outStream, + ProgressOutput: progress.ChanOutput(progressChan), RegistryService: daemon.RegistryService, EventsService: daemon.EventsService, MetadataStore: daemon.distributionMetadataStore, - LayerStore: daemon.layerStore, ImageStore: daemon.imageStore, TagStore: daemon.tagStore, - Pool: daemon.distributionPool, + DownloadManager: daemon.downloadManager, } - return distribution.Pull(ref, imagePullConfig) + err := distribution.Pull(ctx, ref, imagePullConfig) + close(progressChan) + <-writesDone + return err } // ExportImage exports a list of images to the given output stream. The @@ -1069,10 +1113,23 @@ func (daemon *Daemon) ExportImage(names []string, outStream io.Writer) error { // PushImage initiates a push operation on the repository named localName. func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error { + // Include a buffer so that slow client connections don't affect + // transfer performance. 
+ progressChan := make(chan progress.Progress, 100) + + writesDone := make(chan struct{}) + + ctx, cancelFunc := context.WithCancel(context.Background()) + + go func() { + writeDistributionProgress(cancelFunc, outStream, progressChan) + close(writesDone) + }() + imagePushConfig := &distribution.ImagePushConfig{ MetaHeaders: metaHeaders, AuthConfig: authConfig, - OutStream: outStream, + ProgressOutput: progress.ChanOutput(progressChan), RegistryService: daemon.RegistryService, EventsService: daemon.EventsService, MetadataStore: daemon.distributionMetadataStore, @@ -1080,9 +1137,13 @@ func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]st ImageStore: daemon.imageStore, TagStore: daemon.tagStore, TrustKey: daemon.trustKey, + UploadManager: daemon.uploadManager, } - return distribution.Push(ref, imagePushConfig) + err := distribution.Push(ctx, ref, imagePushConfig) + close(progressChan) + <-writesDone + return err } // LookupImage looks up an image by name and returns it as an ImageInspect diff --git a/daemon/daemonbuilder/builder.go b/daemon/daemonbuilder/builder.go index 632c484eba..4cd28c1bff 100644 --- a/daemon/daemonbuilder/builder.go +++ b/daemon/daemonbuilder/builder.go @@ -21,7 +21,6 @@ import ( "github.com/docker/docker/pkg/httputils" "github.com/docker/docker/pkg/idtools" "github.com/docker/docker/pkg/ioutils" - "github.com/docker/docker/pkg/progressreader" "github.com/docker/docker/pkg/urlutil" "github.com/docker/docker/registry" "github.com/docker/docker/runconfig" @@ -239,7 +238,7 @@ func (d Docker) Start(c *container.Container) error { // DetectContextFromRemoteURL returns a context and in certain cases the name of the dockerfile to be used // irrespective of user input. // progressReader is only used if remoteURL is actually a URL (not empty, and not a Git endpoint). 
-func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReader *progressreader.Config) (context builder.ModifiableContext, dockerfileName string, err error) { +func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, createProgressReader func(in io.ReadCloser) io.ReadCloser) (context builder.ModifiableContext, dockerfileName string, err error) { switch { case remoteURL == "": context, err = builder.MakeTarSumContext(r) @@ -262,8 +261,7 @@ func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReade }, // fallback handler (tar context) "": func(rc io.ReadCloser) (io.ReadCloser, error) { - progressReader.In = rc - return progressReader, nil + return createProgressReader(rc), nil }, }) default: diff --git a/daemon/import.go b/daemon/import.go index 749295a738..75010e346d 100644 --- a/daemon/import.go +++ b/daemon/import.go @@ -13,7 +13,7 @@ import ( "github.com/docker/docker/image" "github.com/docker/docker/layer" "github.com/docker/docker/pkg/httputils" - "github.com/docker/docker/pkg/progressreader" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/runconfig" ) @@ -47,16 +47,8 @@ func (daemon *Daemon) ImportImage(src string, newRef reference.Named, msg string if err != nil { return err } - progressReader := progressreader.New(progressreader.Config{ - In: resp.Body, - Out: outStream, - Formatter: sf, - Size: resp.ContentLength, - NewLines: true, - ID: "", - Action: "Importing", - }) - archive = progressReader + progressOutput := sf.NewProgressOutput(outStream, true) + archive = progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Importing") } defer archive.Close() diff --git a/distribution/metadata/v1_id_service.go b/distribution/metadata/v1_id_service.go index 4098f8db83..f6e4589248 100644 --- a/distribution/metadata/v1_id_service.go +++ b/distribution/metadata/v1_id_service.go @@ -23,20 +23,20 @@ func (idserv *V1IDService) 
namespace() string { } // Get finds a layer by its V1 ID. -func (idserv *V1IDService) Get(v1ID, registry string) (layer.ChainID, error) { +func (idserv *V1IDService) Get(v1ID, registry string) (layer.DiffID, error) { if err := v1.ValidateID(v1ID); err != nil { - return layer.ChainID(""), err + return layer.DiffID(""), err } idBytes, err := idserv.store.Get(idserv.namespace(), registry+","+v1ID) if err != nil { - return layer.ChainID(""), err + return layer.DiffID(""), err } - return layer.ChainID(idBytes), nil + return layer.DiffID(idBytes), nil } // Set associates an image with a V1 ID. -func (idserv *V1IDService) Set(v1ID, registry string, id layer.ChainID) error { +func (idserv *V1IDService) Set(v1ID, registry string, id layer.DiffID) error { if err := v1.ValidateID(v1ID); err != nil { return err } diff --git a/distribution/metadata/v1_id_service_test.go b/distribution/metadata/v1_id_service_test.go index bf0f23a6dc..556886581e 100644 --- a/distribution/metadata/v1_id_service_test.go +++ b/distribution/metadata/v1_id_service_test.go @@ -24,22 +24,22 @@ func TestV1IDService(t *testing.T) { testVectors := []struct { registry string v1ID string - layerID layer.ChainID + layerID layer.DiffID }{ { registry: "registry1", v1ID: "f0cd5ca10b07f35512fc2f1cbf9a6cefbdb5cba70ac6b0c9e5988f4497f71937", - layerID: layer.ChainID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"), + layerID: layer.DiffID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"), }, { registry: "registry2", v1ID: "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e", - layerID: layer.ChainID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"), + layerID: layer.DiffID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"), }, { registry: "registry1", v1ID: "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e", - layerID: 
layer.ChainID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"), + layerID: layer.DiffID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"), }, } diff --git a/distribution/pool.go b/distribution/pool.go deleted file mode 100644 index 8c648f6e8b..0000000000 --- a/distribution/pool.go +++ /dev/null @@ -1,51 +0,0 @@ -package distribution - -import ( - "sync" - - "github.com/docker/docker/pkg/broadcaster" -) - -// A Pool manages concurrent pulls. It deduplicates in-progress downloads. -type Pool struct { - sync.Mutex - pullingPool map[string]*broadcaster.Buffered -} - -// NewPool creates a new Pool. -func NewPool() *Pool { - return &Pool{ - pullingPool: make(map[string]*broadcaster.Buffered), - } -} - -// add checks if a pull is already running, and returns (broadcaster, true) -// if a running operation is found. Otherwise, it creates a new one and returns -// (broadcaster, false). -func (pool *Pool) add(key string) (*broadcaster.Buffered, bool) { - pool.Lock() - defer pool.Unlock() - - if p, exists := pool.pullingPool[key]; exists { - return p, true - } - - broadcaster := broadcaster.NewBuffered() - pool.pullingPool[key] = broadcaster - - return broadcaster, false -} - -func (pool *Pool) removeWithError(key string, broadcasterResult error) error { - pool.Lock() - defer pool.Unlock() - if broadcaster, exists := pool.pullingPool[key]; exists { - broadcaster.CloseWithError(broadcasterResult) - delete(pool.pullingPool, key) - } - return nil -} - -func (pool *Pool) remove(key string) error { - return pool.removeWithError(key, nil) -} diff --git a/distribution/pool_test.go b/distribution/pool_test.go deleted file mode 100644 index 80511e8342..0000000000 --- a/distribution/pool_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package distribution - -import ( - "testing" -) - -func TestPools(t *testing.T) { - p := NewPool() - - if _, found := p.add("test1"); found { - t.Fatal("Expected pull test1 not to be in progress") - } - if _, 
found := p.add("test2"); found { - t.Fatal("Expected pull test2 not to be in progress") - } - if _, found := p.add("test1"); !found { - t.Fatalf("Expected pull test1 to be in progress`") - } - if err := p.remove("test2"); err != nil { - t.Fatal(err) - } - if err := p.remove("test2"); err != nil { - t.Fatal(err) - } - if err := p.remove("test1"); err != nil { - t.Fatal(err) - } -} diff --git a/distribution/pull.go b/distribution/pull.go index 4232ce3ca1..dec47e2112 100644 --- a/distribution/pull.go +++ b/distribution/pull.go @@ -2,7 +2,7 @@ package distribution import ( "fmt" - "io" + "os" "strings" "github.com/Sirupsen/logrus" @@ -10,11 +10,12 @@ import ( "github.com/docker/docker/cliconfig" "github.com/docker/docker/daemon/events" "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/image" - "github.com/docker/docker/layer" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/registry" "github.com/docker/docker/tag" + "golang.org/x/net/context" ) // ImagePullConfig stores pull configuration. @@ -25,9 +26,9 @@ type ImagePullConfig struct { // AuthConfig holds authentication credentials for authenticating with // the registry. AuthConfig *cliconfig.AuthConfig - // OutStream is the output writer for showing the status of the pull + // ProgressOutput is the interface for showing the status of the pull // operation. - OutStream io.Writer + ProgressOutput progress.Output // RegistryService is the registry service to use for TLS configuration // and endpoint lookup. RegistryService *registry.Service @@ -36,14 +37,12 @@ type ImagePullConfig struct { // MetadataStore is the storage backend for distribution-specific // metadata. MetadataStore metadata.Store - // LayerStore manages layers. - LayerStore layer.Store // ImageStore manages images. ImageStore image.Store // TagStore manages tags. 
TagStore tag.Store - // Pool manages concurrent pulls. - Pool *Pool + // DownloadManager manages concurrent pulls. + DownloadManager *xfer.LayerDownloadManager } // Puller is an interface that abstracts pulling for different API versions. @@ -51,7 +50,7 @@ type Puller interface { // Pull tries to pull the image referenced by `tag` // Pull returns an error if any, as well as a boolean that determines whether to retry Pull on the next configured endpoint. // - Pull(ref reference.Named) (fallback bool, err error) + Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) } // newPuller returns a Puller interface that will pull from either a v1 or v2 @@ -59,14 +58,13 @@ type Puller interface { // whether a v1 or v2 puller will be created. The other parameters are passed // through to the underlying puller implementation for use during the actual // pull operation. -func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig, sf *streamformatter.StreamFormatter) (Puller, error) { +func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig) (Puller, error) { switch endpoint.Version { case registry.APIVersion2: return &v2Puller{ blobSumService: metadata.NewBlobSumService(imagePullConfig.MetadataStore), endpoint: endpoint, config: imagePullConfig, - sf: sf, repoInfo: repoInfo, }, nil case registry.APIVersion1: @@ -74,7 +72,6 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, v1IDService: metadata.NewV1IDService(imagePullConfig.MetadataStore), endpoint: endpoint, config: imagePullConfig, - sf: sf, repoInfo: repoInfo, }, nil } @@ -83,9 +80,7 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, // Pull initiates a pull operation. image is the repository name to pull, and // tag may be either empty, or indicate a specific tag to pull. 
-func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error { - var sf = streamformatter.NewJSONStreamFormatter() - +func Pull(ctx context.Context, ref reference.Named, imagePullConfig *ImagePullConfig) error { // Resolve the Repository name from fqn to RepositoryInfo repoInfo, err := imagePullConfig.RegistryService.ResolveRepository(ref) if err != nil { @@ -120,12 +115,19 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error { for _, endpoint := range endpoints { logrus.Debugf("Trying to pull %s from %s %s", repoInfo.LocalName, endpoint.URL, endpoint.Version) - puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf) + puller, err := newPuller(endpoint, repoInfo, imagePullConfig) if err != nil { errors = append(errors, err.Error()) continue } - if fallback, err := puller.Pull(ref); err != nil { + if fallback, err := puller.Pull(ctx, ref); err != nil { + // Was this pull cancelled? If so, don't try to fall + // back. + select { + case <-ctx.Done(): + fallback = false + default: + } if fallback { if _, ok := err.(registry.ErrNoSupport); !ok { // Because we found an error that's not ErrNoSupport, discard all subsequent ErrNoSupport errors. @@ -165,11 +167,11 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error { // status message indicates that a newer image was downloaded. Otherwise, it // indicates that the image is up to date. requestedTag is the tag the message // will refer to. 
-func writeStatus(requestedTag string, out io.Writer, sf *streamformatter.StreamFormatter, layersDownloaded bool) { +func writeStatus(requestedTag string, out progress.Output, layersDownloaded bool) { if layersDownloaded { - out.Write(sf.FormatStatus("", "Status: Downloaded newer image for %s", requestedTag)) + progress.Message(out, "", "Status: Downloaded newer image for "+requestedTag) } else { - out.Write(sf.FormatStatus("", "Status: Image is up to date for %s", requestedTag)) + progress.Message(out, "", "Status: Image is up to date for "+requestedTag) } } @@ -183,3 +185,16 @@ func validateRepoName(name string) error { } return nil } + +// tmpFileCloser creates a closer function for a temporary file that closes the file +// and also deletes it. +func tmpFileCloser(tmpFile *os.File) func() error { + return func() error { + tmpFile.Close() + if err := os.RemoveAll(tmpFile.Name()); err != nil { + logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name()) + } + + return nil + } +} diff --git a/distribution/pull_v1.go b/distribution/pull_v1.go index 5b6a9101b0..9bc229a83a 100644 --- a/distribution/pull_v1.go +++ b/distribution/pull_v1.go @@ -1,43 +1,42 @@ package distribution import ( - "encoding/json" "errors" "fmt" "io" + "io/ioutil" "net" "net/url" "strings" - "sync" "time" "github.com/Sirupsen/logrus" "github.com/docker/distribution/reference" "github.com/docker/distribution/registry/client/transport" "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/image" "github.com/docker/docker/image/v1" "github.com/docker/docker/layer" - "github.com/docker/docker/pkg/archive" - "github.com/docker/docker/pkg/progressreader" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/ioutils" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/registry" + "golang.org/x/net/context" ) type v1Puller struct { v1IDService 
*metadata.V1IDService endpoint registry.APIEndpoint config *ImagePullConfig - sf *streamformatter.StreamFormatter repoInfo *registry.RepositoryInfo session *registry.Session } -func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) { +func (p *v1Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) { if _, isDigested := ref.(reference.Digested); isDigested { // Allowing fallback, because HTTPS v1 is before HTTP v2 - return true, registry.ErrNoSupport{errors.New("Cannot pull by digest with v1 registry")} + return true, registry.ErrNoSupport{Err: errors.New("Cannot pull by digest with v1 registry")} } tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name) @@ -62,19 +61,17 @@ func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) { logrus.Debugf("Fallback from error: %s", err) return true, err } - if err := p.pullRepository(ref); err != nil { + if err := p.pullRepository(ctx, ref); err != nil { // TODO(dmcgowan): Check if should fallback return false, err } - out := p.config.OutStream - out.Write(p.sf.FormatStatus("", "%s: this image was pulled from a legacy registry. Important: This registry version will not be supported in future versions of docker.", p.repoInfo.CanonicalName.Name())) + progress.Message(p.config.ProgressOutput, "", p.repoInfo.CanonicalName.Name()+": this image was pulled from a legacy registry. 
Important: This registry version will not be supported in future versions of docker.") return false, nil } -func (p *v1Puller) pullRepository(ref reference.Named) error { - out := p.config.OutStream - out.Write(p.sf.FormatStatus("", "Pulling repository %s", p.repoInfo.CanonicalName.Name())) +func (p *v1Puller) pullRepository(ctx context.Context, ref reference.Named) error { + progress.Message(p.config.ProgressOutput, "", "Pulling repository "+p.repoInfo.CanonicalName.Name()) repoData, err := p.session.GetRepositoryData(p.repoInfo.RemoteName) if err != nil { @@ -112,46 +109,18 @@ func (p *v1Puller) pullRepository(ref reference.Named) error { } } - errors := make(chan error) - layerDownloaded := make(chan struct{}) - layersDownloaded := false - var wg sync.WaitGroup for _, imgData := range repoData.ImgList { if isTagged && imgData.Tag != tagged.Tag() { continue } - wg.Add(1) - go func(img *registry.ImgData) { - p.downloadImage(out, repoData, img, layerDownloaded, errors) - wg.Done() - }(imgData) - } - - go func() { - wg.Wait() - close(errors) - }() - - var lastError error -selectLoop: - for { - select { - case err, ok := <-errors: - if !ok { - break selectLoop - } - lastError = err - case <-layerDownloaded: - layersDownloaded = true + err := p.downloadImage(ctx, repoData, imgData, &layersDownloaded) + if err != nil { + return err } } - if lastError != nil { - return lastError - } - localNameRef := p.repoInfo.LocalName if isTagged { localNameRef, err = reference.WithTag(localNameRef, tagged.Tag()) @@ -159,194 +128,143 @@ selectLoop: localNameRef = p.repoInfo.LocalName } } - writeStatus(localNameRef.String(), out, p.sf, layersDownloaded) + writeStatus(localNameRef.String(), p.config.ProgressOutput, layersDownloaded) return nil } -func (p *v1Puller) downloadImage(out io.Writer, repoData *registry.RepositoryData, img *registry.ImgData, layerDownloaded chan struct{}, errors chan error) { +func (p *v1Puller) downloadImage(ctx context.Context, repoData 
*registry.RepositoryData, img *registry.ImgData, layersDownloaded *bool) error { if img.Tag == "" { logrus.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID) - return + return nil } localNameRef, err := reference.WithTag(p.repoInfo.LocalName, img.Tag) if err != nil { retErr := fmt.Errorf("Image (id: %s) has invalid tag: %s", img.ID, img.Tag) logrus.Debug(retErr.Error()) - errors <- retErr + return retErr } if err := v1.ValidateID(img.ID); err != nil { - errors <- err - return + return err } - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name()), nil)) + progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name()) success := false var lastErr error - var isDownloaded bool for _, ep := range p.repoInfo.Index.Mirrors { ep += "v1/" - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil)) - if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil { + progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep)) + if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil { // Don't report errors when pulling from mirrors. 
logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err) continue } - if isDownloaded { - layerDownloaded <- struct{}{} - } success = true break } if !success { for _, ep := range repoData.Endpoints { - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil)) - if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil { + progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep) + if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil { // It's not ideal that only the last error is returned, it would be better to concatenate the errors. // As the error is also given to the output stream the user will see the error. lastErr = err - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err), nil)) + progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err) continue } - if isDownloaded { - layerDownloaded <- struct{}{} - } success = true break } } if !success { err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName.Name(), lastErr) - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil)) - errors <- err - return + progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), err.Error()) + return err } - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) + progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Download complete") + return nil } -func (p *v1Puller) pullImage(out io.Writer, v1ID, endpoint 
string, localNameRef reference.Named) (layersDownloaded bool, err error) { +func (p *v1Puller) pullImage(ctx context.Context, v1ID, endpoint string, localNameRef reference.Named, layersDownloaded *bool) (err error) { var history []string history, err = p.session.GetRemoteHistory(v1ID, endpoint) if err != nil { - return false, err + return err } if len(history) < 1 { - return false, fmt.Errorf("empty history for image %s", v1ID) + return fmt.Errorf("empty history for image %s", v1ID) } - out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pulling dependent layers", nil)) - // FIXME: Try to stream the images? - // FIXME: Launch the getRemoteImage() in goroutines + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pulling dependent layers") var ( - referencedLayers []layer.Layer - parentID layer.ChainID - newHistory []image.History - img *image.V1Image - imgJSON []byte - imgSize int64 + descriptors []xfer.DownloadDescriptor + newHistory []image.History + imgJSON []byte + imgSize int64 ) - defer func() { - for _, l := range referencedLayers { - layer.ReleaseAndLog(p.config.LayerStore, l) - } - }() - - layersDownloaded = false - - // Iterate over layers from top-most to bottom-most, checking if any - // already exist on disk. - var i int - for i = 0; i != len(history); i++ { - v1LayerID := history[i] - // Do we have a mapping for this particular v1 ID on this - // registry? - if layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name); err == nil { - // Does the layer actually exist - if l, err := p.config.LayerStore.Get(layerID); err == nil { - for j := i; j >= 0; j-- { - logrus.Debugf("Layer already exists: %s", history[j]) - out.Write(p.sf.FormatProgress(stringid.TruncateID(history[j]), "Already exists", nil)) - } - referencedLayers = append(referencedLayers, l) - parentID = layerID - break - } - } - } - - needsDownload := i - // Iterate over layers, in order from bottom-most to top-most. 
Download - // config for all layers, and download actual layer data if needed. - for i = len(history) - 1; i >= 0; i-- { + // config for all layers and create descriptors. + for i := len(history) - 1; i >= 0; i-- { v1LayerID := history[i] - imgJSON, imgSize, err = p.downloadLayerConfig(out, v1LayerID, endpoint) + imgJSON, imgSize, err = p.downloadLayerConfig(v1LayerID, endpoint) if err != nil { - return layersDownloaded, err - } - - img = &image.V1Image{} - if err := json.Unmarshal(imgJSON, img); err != nil { - return layersDownloaded, err - } - - if i < needsDownload { - l, err := p.downloadLayer(out, v1LayerID, endpoint, parentID, imgSize, &layersDownloaded) - - // Note: This needs to be done even in the error case to avoid - // stale references to the layer. - if l != nil { - referencedLayers = append(referencedLayers, l) - } - if err != nil { - return layersDownloaded, err - } - - parentID = l.ChainID() + return err } // Create a new-style config from the legacy configs h, err := v1.HistoryFromConfig(imgJSON, false) if err != nil { - return layersDownloaded, err + return err } newHistory = append(newHistory, h) + + layerDescriptor := &v1LayerDescriptor{ + v1LayerID: v1LayerID, + indexName: p.repoInfo.Index.Name, + endpoint: endpoint, + v1IDService: p.v1IDService, + layersDownloaded: layersDownloaded, + layerSize: imgSize, + session: p.session, + } + + descriptors = append(descriptors, layerDescriptor) } rootFS := image.NewRootFS() - l := referencedLayers[len(referencedLayers)-1] - for l != nil { - rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...) 
- l = l.Parent() - } - - config, err := v1.MakeConfigFromV1Config(imgJSON, rootFS, newHistory) + resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput) if err != nil { - return layersDownloaded, err + return err + } + defer release() + + config, err := v1.MakeConfigFromV1Config(imgJSON, &resultRootFS, newHistory) + if err != nil { + return err } imageID, err := p.config.ImageStore.Create(config) if err != nil { - return layersDownloaded, err + return err } if err := p.config.TagStore.AddTag(localNameRef, imageID, true); err != nil { - return layersDownloaded, err + return err } - return layersDownloaded, nil + return nil } -func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) { - out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Pulling metadata", nil)) +func (p *v1Puller) downloadLayerConfig(v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) { + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Pulling metadata") retries := 5 for j := 1; j <= retries; j++ { imgJSON, imgSize, err := p.session.GetRemoteImageJSON(v1LayerID, endpoint) if err != nil && j == retries { - out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling layer metadata", nil)) + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Error pulling layer metadata") return nil, 0, err } else if err != nil { time.Sleep(time.Duration(j) * 500 * time.Millisecond) @@ -360,95 +278,66 @@ func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string return nil, 0, nil } -func (p *v1Puller) downloadLayer(out io.Writer, v1LayerID, endpoint string, parentID layer.ChainID, layerSize int64, layersDownloaded *bool) (l layer.Layer, err error) { - // ensure no two downloads of the same layer happen at the same time - poolKey := "layer:" + v1LayerID - broadcaster, found := 
p.config.Pool.add(poolKey) - broadcaster.Add(out) - if found { - logrus.Debugf("Image (id: %s) pull is already running, skipping", v1LayerID) - if err = broadcaster.Wait(); err != nil { - return nil, err - } - layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name) - if err != nil { - return nil, err - } - // Does the layer actually exist - l, err := p.config.LayerStore.Get(layerID) - if err != nil { - return nil, err - } - return l, nil - } +type v1LayerDescriptor struct { + v1LayerID string + indexName string + endpoint string + v1IDService *metadata.V1IDService + layersDownloaded *bool + layerSize int64 + session *registry.Session +} - // This must use a closure so it captures the value of err when - // the function returns, not when the 'defer' is evaluated. - defer func() { - p.config.Pool.removeWithError(poolKey, err) - }() +func (ld *v1LayerDescriptor) Key() string { + return "v1:" + ld.v1LayerID +} - retries := 5 - for j := 1; j <= retries; j++ { - // Get the layer - status := "Pulling fs layer" - if j > 1 { - status = fmt.Sprintf("Pulling fs layer [retries: %d]", j) - } - broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), status, nil)) - layerReader, err := p.session.GetRemoteImageLayer(v1LayerID, endpoint, layerSize) +func (ld *v1LayerDescriptor) ID() string { + return stringid.TruncateID(ld.v1LayerID) +} + +func (ld *v1LayerDescriptor) DiffID() (layer.DiffID, error) { + return ld.v1IDService.Get(ld.v1LayerID, ld.indexName) +} + +func (ld *v1LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) { + progress.Update(progressOutput, ld.ID(), "Pulling fs layer") + layerReader, err := ld.session.GetRemoteImageLayer(ld.v1LayerID, ld.endpoint, ld.layerSize) + if err != nil { + progress.Update(progressOutput, ld.ID(), "Error pulling dependent layers") if uerr, ok := err.(*url.Error); ok { err = uerr.Err } - if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries { - 
time.Sleep(time.Duration(j) * 500 * time.Millisecond) - continue - } else if err != nil { - broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling dependent layers", nil)) - return nil, err + if terr, ok := err.(net.Error); ok && terr.Timeout() { + return nil, 0, err } - *layersDownloaded = true - defer layerReader.Close() + return nil, 0, xfer.DoNotRetry{Err: err} + } + *ld.layersDownloaded = true - reader := progressreader.New(progressreader.Config{ - In: layerReader, - Out: broadcaster, - Formatter: p.sf, - Size: layerSize, - NewLines: false, - ID: stringid.TruncateID(v1LayerID), - Action: "Downloading", - }) - - inflatedLayerData, err := archive.DecompressStream(reader) - if err != nil { - return nil, fmt.Errorf("could not get decompression stream: %v", err) - } - - l, err := p.config.LayerStore.Register(inflatedLayerData, parentID) - if err != nil { - return nil, fmt.Errorf("failed to register layer: %v", err) - } - logrus.Debugf("layer %s registered successfully", l.DiffID()) - - if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries { - time.Sleep(time.Duration(j) * 500 * time.Millisecond) - continue - } else if err != nil { - broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error downloading dependent layers", nil)) - return nil, err - } - - // Cache mapping from this v1 ID to content-addressable layer ID - if err := p.v1IDService.Set(v1LayerID, p.repoInfo.Index.Name, l.ChainID()); err != nil { - return nil, err - } - - broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Download complete", nil)) - broadcaster.Close() - return l, nil + tmpFile, err := ioutil.TempFile("", "GetImageBlob") + if err != nil { + layerReader.Close() + return nil, 0, err } - // not reached - return nil, nil + reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerReader), progressOutput, ld.layerSize, ld.ID(), "Downloading") + defer reader.Close() + + _, err = io.Copy(tmpFile, 
reader) + if err != nil { + return nil, 0, err + } + + progress.Update(progressOutput, ld.ID(), "Download complete") + + logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name()) + + tmpFile.Seek(0, 0) + return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), ld.layerSize, nil +} + +func (ld *v1LayerDescriptor) Registered(diffID layer.DiffID) { + // Cache mapping from this layer's DiffID to the blobsum + ld.v1IDService.Set(ld.v1LayerID, ld.indexName, diffID) } diff --git a/distribution/pull_v2.go b/distribution/pull_v2.go index 0c65d8d744..50244608f1 100644 --- a/distribution/pull_v2.go +++ b/distribution/pull_v2.go @@ -15,13 +15,12 @@ import ( "github.com/docker/distribution/manifest/schema1" "github.com/docker/distribution/reference" "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/image" "github.com/docker/docker/image/v1" "github.com/docker/docker/layer" - "github.com/docker/docker/pkg/archive" - "github.com/docker/docker/pkg/broadcaster" - "github.com/docker/docker/pkg/progressreader" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/ioutils" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/registry" "golang.org/x/net/context" @@ -31,23 +30,19 @@ type v2Puller struct { blobSumService *metadata.BlobSumService endpoint registry.APIEndpoint config *ImagePullConfig - sf *streamformatter.StreamFormatter repoInfo *registry.RepositoryInfo repo distribution.Repository - sessionID string } -func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) { +func (p *v2Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) { // TODO(tiborvass): was ReceiveTimeout p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull") if err != nil { - logrus.Debugf("Error getting v2 registry: %v", err) + 
logrus.Warnf("Error getting v2 registry: %v", err) return true, err } - p.sessionID = stringid.GenerateRandomID() - - if err := p.pullV2Repository(ref); err != nil { + if err := p.pullV2Repository(ctx, ref); err != nil { if registry.ContinueOnError(err) { logrus.Debugf("Error trying v2 registry: %v", err) return true, err @@ -57,7 +52,7 @@ func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) { return false, nil } -func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) { +func (p *v2Puller) pullV2Repository(ctx context.Context, ref reference.Named) (err error) { var refs []reference.Named taggedName := p.repoInfo.LocalName if tagged, isTagged := ref.(reference.Tagged); isTagged { @@ -73,7 +68,7 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) { } refs = []reference.Named{taggedName} } else { - manSvc, err := p.repo.Manifests(context.Background()) + manSvc, err := p.repo.Manifests(ctx) if err != nil { return err } @@ -98,98 +93,109 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) { for _, pullRef := range refs { // pulledNew is true if either new layers were downloaded OR if existing images were newly tagged // TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus? 
- pulledNew, err := p.pullV2Tag(p.config.OutStream, pullRef) + pulledNew, err := p.pullV2Tag(ctx, pullRef) if err != nil { return err } layersDownloaded = layersDownloaded || pulledNew } - writeStatus(taggedName.String(), p.config.OutStream, p.sf, layersDownloaded) + writeStatus(taggedName.String(), p.config.ProgressOutput, layersDownloaded) return nil } -// downloadInfo is used to pass information from download to extractor -type downloadInfo struct { - tmpFile *os.File - digest digest.Digest - layer distribution.ReadSeekCloser - size int64 - err chan error - poolKey string - broadcaster *broadcaster.Buffered +type v2LayerDescriptor struct { + digest digest.Digest + repo distribution.Repository + blobSumService *metadata.BlobSumService } -type errVerification struct{} +func (ld *v2LayerDescriptor) Key() string { + return "v2:" + ld.digest.String() +} -func (errVerification) Error() string { return "verification failed" } +func (ld *v2LayerDescriptor) ID() string { + return stringid.TruncateID(ld.digest.String()) +} -func (p *v2Puller) download(di *downloadInfo) { - logrus.Debugf("pulling blob %q", di.digest) +func (ld *v2LayerDescriptor) DiffID() (layer.DiffID, error) { + return ld.blobSumService.GetDiffID(ld.digest) +} - blobs := p.repo.Blobs(context.Background()) +func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) { + logrus.Debugf("pulling blob %q", ld.digest) - layerDownload, err := blobs.Open(context.Background(), di.digest) + blobs := ld.repo.Blobs(ctx) + + layerDownload, err := blobs.Open(ctx, ld.digest) if err != nil { - logrus.Debugf("Error fetching layer: %v", err) - di.err <- err - return + logrus.Debugf("Error statting layer: %v", err) + if err == distribution.ErrBlobUnknown { + return nil, 0, xfer.DoNotRetry{Err: err} + } + return nil, 0, retryOnError(err) } - defer layerDownload.Close() - di.size, err = layerDownload.Seek(0, os.SEEK_END) + size, err := layerDownload.Seek(0, 
os.SEEK_END) if err != nil { // Seek failed, perhaps because there was no Content-Length // header. This shouldn't fail the download, because we can // still continue without a progress bar. - di.size = 0 + size = 0 } else { // Restore the seek offset at the beginning of the stream. _, err = layerDownload.Seek(0, os.SEEK_SET) if err != nil { - di.err <- err - return + return nil, 0, err } } - verifier, err := digest.NewDigestVerifier(di.digest) + reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerDownload), progressOutput, size, ld.ID(), "Downloading") + defer reader.Close() + + verifier, err := digest.NewDigestVerifier(ld.digest) if err != nil { - di.err <- err - return + return nil, 0, xfer.DoNotRetry{Err: err} } - digestStr := di.digest.String() + tmpFile, err := ioutil.TempFile("", "GetImageBlob") + if err != nil { + return nil, 0, xfer.DoNotRetry{Err: err} + } - reader := progressreader.New(progressreader.Config{ - In: ioutil.NopCloser(io.TeeReader(layerDownload, verifier)), - Out: di.broadcaster, - Formatter: p.sf, - Size: di.size, - NewLines: false, - ID: stringid.TruncateID(digestStr), - Action: "Downloading", - }) - io.Copy(di.tmpFile, reader) + _, err = io.Copy(tmpFile, io.TeeReader(reader, verifier)) + if err != nil { + return nil, 0, retryOnError(err) + } - di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Verifying Checksum", nil)) + progress.Update(progressOutput, ld.ID(), "Verifying Checksum") if !verifier.Verified() { - err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest) + err = fmt.Errorf("filesystem layer verification failed for digest %s", ld.digest) logrus.Error(err) - di.err <- err - return + tmpFile.Close() + if err := os.RemoveAll(tmpFile.Name()); err != nil { + logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name()) + } + + return nil, 0, xfer.DoNotRetry{Err: err} } - di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Download 
complete", nil)) + progress.Update(progressOutput, ld.ID(), "Download complete") - logrus.Debugf("Downloaded %s to tempfile %s", digestStr, di.tmpFile.Name()) - di.layer = layerDownload + logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name()) - di.err <- nil + tmpFile.Seek(0, 0) + return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), size, nil } -func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated bool, err error) { +func (ld *v2LayerDescriptor) Registered(diffID layer.DiffID) { + // Cache mapping from this layer's DiffID to the blobsum + ld.blobSumService.Add(diffID, ld.digest) +} + +func (p *v2Puller) pullV2Tag(ctx context.Context, ref reference.Named) (tagUpdated bool, err error) { tagOrDigest := "" if tagged, isTagged := ref.(reference.Tagged); isTagged { tagOrDigest = tagged.Tag() @@ -201,7 +207,7 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo logrus.Debugf("Pulling ref from V2 registry: %q", tagOrDigest) - manSvc, err := p.repo.Manifests(context.Background()) + manSvc, err := p.repo.Manifests(ctx) if err != nil { return false, err } @@ -231,33 +237,17 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo return false, err } - out.Write(p.sf.FormatStatus(tagOrDigest, "Pulling from %s", p.repo.Name())) + progress.Message(p.config.ProgressOutput, tagOrDigest, "Pulling from "+p.repo.Name()) - var downloads []*downloadInfo - - defer func() { - for _, d := range downloads { - p.config.Pool.removeWithError(d.poolKey, err) - if d.tmpFile != nil { - d.tmpFile.Close() - if err := os.RemoveAll(d.tmpFile.Name()); err != nil { - logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name()) - } - } - } - }() + var descriptors []xfer.DownloadDescriptor // Image history converted to the new format var history []image.History - poolKey := "v2layer:" - notFoundLocally := false - // Note that the order of this loop is in the direction of bottom-most // 
to top-most, so that the downloads slice gets ordered correctly. for i := len(verifiedManifest.FSLayers) - 1; i >= 0; i-- { blobSum := verifiedManifest.FSLayers[i].BlobSum - poolKey += blobSum.String() var throwAway struct { ThrowAway bool `json:"throwaway,omitempty"` @@ -276,119 +266,22 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo continue } - // Do we have a layer on disk corresponding to the set of - // blobsums up to this point? - if !notFoundLocally { - notFoundLocally = true - diffID, err := p.blobSumService.GetDiffID(blobSum) - if err == nil { - rootFS.Append(diffID) - if l, err := p.config.LayerStore.Get(rootFS.ChainID()); err == nil { - notFoundLocally = false - logrus.Debugf("Layer already exists: %s", blobSum.String()) - out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Already exists", nil)) - defer layer.ReleaseAndLog(p.config.LayerStore, l) - continue - } else { - rootFS.DiffIDs = rootFS.DiffIDs[:len(rootFS.DiffIDs)-1] - } - } + layerDescriptor := &v2LayerDescriptor{ + digest: blobSum, + repo: p.repo, + blobSumService: p.blobSumService, } - out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Pulling fs layer", nil)) - - tmpFile, err := ioutil.TempFile("", "GetImageBlob") - if err != nil { - return false, err - } - - d := &downloadInfo{ - poolKey: poolKey, - digest: blobSum, - tmpFile: tmpFile, - // TODO: seems like this chan buffer solved hanging problem in go1.5, - // this can indicate some deeper problem that somehow we never take - // error from channel in loop below - err: make(chan error, 1), - } - - downloads = append(downloads, d) - - broadcaster, found := p.config.Pool.add(d.poolKey) - broadcaster.Add(out) - d.broadcaster = broadcaster - if found { - d.err <- nil - } else { - go p.download(d) - } + descriptors = append(descriptors, layerDescriptor) } - for _, d := range downloads { - if err := <-d.err; err != nil { - return false, err - } - - if d.layer == nil { - // 
Wait for a different pull to download and extract - // this layer. - err = d.broadcaster.Wait() - if err != nil { - return false, err - } - - diffID, err := p.blobSumService.GetDiffID(d.digest) - if err != nil { - return false, err - } - rootFS.Append(diffID) - - l, err := p.config.LayerStore.Get(rootFS.ChainID()) - if err != nil { - return false, err - } - - defer layer.ReleaseAndLog(p.config.LayerStore, l) - - continue - } - - d.tmpFile.Seek(0, 0) - reader := progressreader.New(progressreader.Config{ - In: d.tmpFile, - Out: d.broadcaster, - Formatter: p.sf, - Size: d.size, - NewLines: false, - ID: stringid.TruncateID(d.digest.String()), - Action: "Extracting", - }) - - inflatedLayerData, err := archive.DecompressStream(reader) - if err != nil { - return false, fmt.Errorf("could not get decompression stream: %v", err) - } - - l, err := p.config.LayerStore.Register(inflatedLayerData, rootFS.ChainID()) - if err != nil { - return false, fmt.Errorf("failed to register layer: %v", err) - } - logrus.Debugf("layer %s registered successfully", l.DiffID()) - rootFS.Append(l.DiffID()) - - // Cache mapping from this layer's DiffID to the blobsum - if err := p.blobSumService.Add(l.DiffID(), d.digest); err != nil { - return false, err - } - - defer layer.ReleaseAndLog(p.config.LayerStore, l) - - d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.digest.String()), "Pull complete", nil)) - d.broadcaster.Close() - tagUpdated = true + resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput) + if err != nil { + return false, err } + defer release() - config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), rootFS, history) + config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), &resultRootFS, history) if err != nil { return false, err } @@ -403,30 +296,24 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo 
return false, err } - // Check for new tag if no layers downloaded - var oldTagImageID image.ID - if !tagUpdated { - oldTagImageID, err = p.config.TagStore.Get(ref) - if err != nil || oldTagImageID != imageID { - tagUpdated = true - } + if manifestDigest != "" { + progress.Message(p.config.ProgressOutput, "", "Digest: "+manifestDigest.String()) } - if tagUpdated { - if canonical, ok := ref.(reference.Canonical); ok { - if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil { - return false, err - } - } else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil { + oldTagImageID, err := p.config.TagStore.Get(ref) + if err == nil && oldTagImageID == imageID { + return false, nil + } + + if canonical, ok := ref.(reference.Canonical); ok { + if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil { return false, err } + } else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil { + return false, err } - if manifestDigest != "" { - out.Write(p.sf.FormatStatus("", "Digest: %s", manifestDigest)) - } - - return tagUpdated, nil + return true, nil } func verifyManifest(signedManifest *schema1.SignedManifest, ref reference.Reference) (m *schema1.Manifest, err error) { diff --git a/distribution/push.go b/distribution/push.go index ee41c2e1e3..7e6cd6c4d6 100644 --- a/distribution/push.go +++ b/distribution/push.go @@ -12,12 +12,14 @@ import ( "github.com/docker/docker/cliconfig" "github.com/docker/docker/daemon/events" "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/image" "github.com/docker/docker/layer" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/registry" "github.com/docker/docker/tag" "github.com/docker/libtrust" + "golang.org/x/net/context" ) // ImagePushConfig stores push configuration. 
@@ -28,9 +30,9 @@ type ImagePushConfig struct { // AuthConfig holds authentication credentials for authenticating with // the registry. AuthConfig *cliconfig.AuthConfig - // OutStream is the output writer for showing the status of the push + // ProgressOutput is the interface for showing the status of the push // operation. - OutStream io.Writer + ProgressOutput progress.Output // RegistryService is the registry service to use for TLS configuration // and endpoint lookup. RegistryService *registry.Service @@ -48,6 +50,8 @@ type ImagePushConfig struct { // TrustKey is the private key for legacy signatures. This is typically // an ephemeral key, since these signatures are no longer verified. TrustKey libtrust.PrivateKey + // UploadManager dispatches uploads. + UploadManager *xfer.LayerUploadManager } // Pusher is an interface that abstracts pushing for different API versions. @@ -56,7 +60,7 @@ type Pusher interface { // Push returns an error if any, as well as a boolean that determines whether to retry Push on the next configured endpoint. // // TODO(tiborvass): have Push() take a reference to repository + tag, so that the pusher itself is repository-agnostic. - Push() (fallback bool, err error) + Push(ctx context.Context) (fallback bool, err error) } const compressionBufSize = 32768 @@ -66,7 +70,7 @@ const compressionBufSize = 32768 // whether a v1 or v2 pusher will be created. The other parameters are passed // through to the underlying pusher implementation for use during the actual // push operation. 
-func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig, sf *streamformatter.StreamFormatter) (Pusher, error) { +func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig) (Pusher, error) { switch endpoint.Version { case registry.APIVersion2: return &v2Pusher{ @@ -75,8 +79,7 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg endpoint: endpoint, repoInfo: repoInfo, config: imagePushConfig, - sf: sf, - layersPushed: make(map[digest.Digest]bool), + layersPushed: pushMap{layersPushed: make(map[digest.Digest]bool)}, }, nil case registry.APIVersion1: return &v1Pusher{ @@ -85,7 +88,6 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg endpoint: endpoint, repoInfo: repoInfo, config: imagePushConfig, - sf: sf, }, nil } return nil, fmt.Errorf("unknown version %d for registry %s", endpoint.Version, endpoint.URL) @@ -94,11 +96,9 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg // Push initiates a push operation on the repository named localName. // ref is the specific variant of the image to be pushed. // If no tag is provided, all tags will be pushed. -func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error { +func Push(ctx context.Context, ref reference.Named, imagePushConfig *ImagePushConfig) error { // FIXME: Allow to interrupt current push when new push of same image is done. 
- var sf = streamformatter.NewJSONStreamFormatter() - // Resolve the Repository name from fqn to RepositoryInfo repoInfo, err := imagePushConfig.RegistryService.ResolveRepository(ref) if err != nil { @@ -110,7 +110,7 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error { return err } - imagePushConfig.OutStream.Write(sf.FormatStatus("", "The push refers to a repository [%s]", repoInfo.CanonicalName)) + progress.Messagef(imagePushConfig.ProgressOutput, "", "The push refers to a repository [%s]", repoInfo.CanonicalName.String()) associations := imagePushConfig.TagStore.ReferencesByName(repoInfo.LocalName) if len(associations) == 0 { @@ -121,12 +121,20 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error { for _, endpoint := range endpoints { logrus.Debugf("Trying to push %s to %s %s", repoInfo.CanonicalName, endpoint.URL, endpoint.Version) - pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig, sf) + pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig) if err != nil { lastErr = err continue } - if fallback, err := pusher.Push(); err != nil { + if fallback, err := pusher.Push(ctx); err != nil { + // Was this push cancelled? If so, don't try to fall + // back. 
+ select { + case <-ctx.Done(): + fallback = false + default: + } + if fallback { lastErr = err continue diff --git a/distribution/push_v1.go b/distribution/push_v1.go index a4a0de0811..155bbc8647 100644 --- a/distribution/push_v1.go +++ b/distribution/push_v1.go @@ -2,8 +2,6 @@ package distribution import ( "fmt" - "io" - "io/ioutil" "sync" "github.com/Sirupsen/logrus" @@ -15,25 +13,23 @@ import ( "github.com/docker/docker/image/v1" "github.com/docker/docker/layer" "github.com/docker/docker/pkg/ioutils" - "github.com/docker/docker/pkg/progressreader" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/registry" + "golang.org/x/net/context" ) type v1Pusher struct { + ctx context.Context v1IDService *metadata.V1IDService endpoint registry.APIEndpoint ref reference.Named repoInfo *registry.RepositoryInfo config *ImagePushConfig - sf *streamformatter.StreamFormatter session *registry.Session - - out io.Writer } -func (p *v1Pusher) Push() (fallback bool, err error) { +func (p *v1Pusher) Push(ctx context.Context) (fallback bool, err error) { tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name) if err != nil { return false, err @@ -55,7 +51,7 @@ func (p *v1Pusher) Push() (fallback bool, err error) { // TODO(dmcgowan): Check if should fallback return true, err } - if err := p.pushRepository(); err != nil { + if err := p.pushRepository(ctx); err != nil { // TODO(dmcgowan): Check if should fallback return false, err } @@ -306,12 +302,12 @@ func (p *v1Pusher) lookupImageOnEndpoint(wg *sync.WaitGroup, endpoint string, im logrus.Errorf("Error in LookupRemoteImage: %s", err) imagesToPush <- v1ID } else { - p.out.Write(p.sf.FormatStatus("", "Image %s already pushed, skipping", stringid.TruncateID(v1ID))) + progress.Messagef(p.config.ProgressOutput, "", "Image %s already pushed, skipping", stringid.TruncateID(v1ID)) } } } -func (p *v1Pusher) 
pushImageToEndpoint(endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error { +func (p *v1Pusher) pushImageToEndpoint(ctx context.Context, endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error { workerCount := len(imageList) // start a maximum of 5 workers to check if images exist on the specified endpoint. if workerCount > 5 { @@ -349,14 +345,14 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag for _, img := range imageList { v1ID := img.V1ID() if _, push := shouldPush[v1ID]; push { - if _, err := p.pushImage(img, endpoint); err != nil { + if _, err := p.pushImage(ctx, img, endpoint); err != nil { // FIXME: Continue on error? return err } } if topImage, isTopImage := img.(*v1TopImage); isTopImage { for _, tag := range tags[topImage.imageID] { - p.out.Write(p.sf.FormatStatus("", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag)) + progress.Messagef(p.config.ProgressOutput, "", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag) if err := p.session.PushRegistryTag(p.repoInfo.RemoteName, v1ID, tag, endpoint); err != nil { return err } @@ -367,8 +363,7 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag } // pushRepository pushes layers that do not already exist on the registry. 
-func (p *v1Pusher) pushRepository() error { - p.out = ioutils.NewWriteFlusher(p.config.OutStream) +func (p *v1Pusher) pushRepository(ctx context.Context) error { imgList, tags, referencedLayers, err := p.getImageList() defer func() { for _, l := range referencedLayers { @@ -378,7 +373,7 @@ func (p *v1Pusher) pushRepository() error { if err != nil { return err } - p.out.Write(p.sf.FormatStatus("", "Sending image list")) + progress.Message(p.config.ProgressOutput, "", "Sending image list") imageIndex := createImageIndex(imgList, tags) for _, data := range imageIndex { @@ -391,10 +386,10 @@ func (p *v1Pusher) pushRepository() error { if err != nil { return err } - p.out.Write(p.sf.FormatStatus("", "Pushing repository %s", p.repoInfo.CanonicalName)) + progress.Message(p.config.ProgressOutput, "", "Pushing repository "+p.repoInfo.CanonicalName.String()) // push the repository to each of the endpoints only if it does not exist. for _, endpoint := range repoData.Endpoints { - if err := p.pushImageToEndpoint(endpoint, imgList, tags, repoData); err != nil { + if err := p.pushImageToEndpoint(ctx, endpoint, imgList, tags, repoData); err != nil { return err } } @@ -402,11 +397,11 @@ func (p *v1Pusher) pushRepository() error { return err } -func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err error) { +func (p *v1Pusher) pushImage(ctx context.Context, v1Image v1Image, ep string) (checksum string, err error) { v1ID := v1Image.V1ID() jsonRaw := v1Image.Config() - p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pushing", nil)) + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pushing") // General rule is to use ID for graph accesses and compatibilityID for // calls to session.registry() @@ -417,7 +412,7 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e // Send the json if err := p.session.PushImageJSONRegistry(imgData, jsonRaw, ep); err != nil { if err == registry.ErrAlreadyExists { - 
p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image already pushed, skipping", nil)) + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image already pushed, skipping") return "", nil } return "", err @@ -437,15 +432,8 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e // Send the layer logrus.Debugf("rendered layer for %s of [%d] size", v1ID, size) - reader := progressreader.New(progressreader.Config{ - In: ioutil.NopCloser(arch), - Out: p.out, - Formatter: p.sf, - Size: size, - NewLines: false, - ID: stringid.TruncateID(v1ID), - Action: "Pushing", - }) + reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), p.config.ProgressOutput, size, stringid.TruncateID(v1ID), "Pushing") + defer reader.Close() checksum, checksumPayload, err := p.session.PushImageLayerRegistry(v1ID, reader, ep, jsonRaw) if err != nil { @@ -458,10 +446,10 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e return "", err } - if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.ChainID()); err != nil { + if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.DiffID()); err != nil { logrus.Warnf("Could not set v1 ID mapping: %v", err) } - p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image successfully pushed", nil)) + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image successfully pushed") return imgData.Checksum, nil } diff --git a/distribution/push_v2.go b/distribution/push_v2.go index f2c23ea0ea..4f725eb60a 100644 --- a/distribution/push_v2.go +++ b/distribution/push_v2.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "io" - "io/ioutil" + "sync" "time" "github.com/Sirupsen/logrus" @@ -15,11 +15,12 @@ import ( "github.com/docker/distribution/manifest/schema1" "github.com/docker/distribution/reference" "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/image" 
"github.com/docker/docker/image/v1" "github.com/docker/docker/layer" - "github.com/docker/docker/pkg/progressreader" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/ioutils" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/registry" "github.com/docker/docker/tag" @@ -32,16 +33,20 @@ type v2Pusher struct { endpoint registry.APIEndpoint repoInfo *registry.RepositoryInfo config *ImagePushConfig - sf *streamformatter.StreamFormatter repo distribution.Repository // layersPushed is the set of layers known to exist on the remote side. // This avoids redundant queries when pushing multiple tags that // involve the same layers. + layersPushed pushMap +} + +type pushMap struct { + sync.Mutex layersPushed map[digest.Digest]bool } -func (p *v2Pusher) Push() (fallback bool, err error) { +func (p *v2Pusher) Push(ctx context.Context) (fallback bool, err error) { p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "push", "pull") if err != nil { logrus.Debugf("Error getting v2 registry: %v", err) @@ -75,7 +80,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) { } for _, association := range associations { - if err := p.pushV2Tag(association); err != nil { + if err := p.pushV2Tag(ctx, association); err != nil { return false, err } } @@ -83,7 +88,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) { return false, nil } -func (p *v2Pusher) pushV2Tag(association tag.Association) error { +func (p *v2Pusher) pushV2Tag(ctx context.Context, association tag.Association) error { ref := association.Ref logrus.Debugf("Pushing repository: %s", ref.String()) @@ -92,8 +97,6 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error { return fmt.Errorf("could not find image from tag %s: %v", ref.String(), err) } - out := p.config.OutStream - var l layer.Layer topLayerID := img.RootFS.ChainID() @@ -107,33 +110,41 @@ func (p *v2Pusher) 
pushV2Tag(association tag.Association) error { defer layer.ReleaseAndLog(p.config.LayerStore, l) } - fsLayers := make(map[layer.DiffID]schema1.FSLayer) + var descriptors []xfer.UploadDescriptor // Push empty layer if necessary for _, h := range img.History { if h.EmptyLayer { - dgst, err := p.pushLayerIfNecessary(out, layer.EmptyLayer) - if err != nil { - return err + descriptors = []xfer.UploadDescriptor{ + &v2PushDescriptor{ + layer: layer.EmptyLayer, + blobSumService: p.blobSumService, + repo: p.repo, + layersPushed: &p.layersPushed, + }, } - p.layersPushed[dgst] = true - fsLayers[layer.EmptyLayer.DiffID()] = schema1.FSLayer{BlobSum: dgst} break } } + // Loop bounds condition is to avoid pushing the base layer on Windows. for i := 0; i < len(img.RootFS.DiffIDs); i++ { - dgst, err := p.pushLayerIfNecessary(out, l) - if err != nil { - return err + descriptor := &v2PushDescriptor{ + layer: l, + blobSumService: p.blobSumService, + repo: p.repo, + layersPushed: &p.layersPushed, } - - p.layersPushed[dgst] = true - fsLayers[l.DiffID()] = schema1.FSLayer{BlobSum: dgst} + descriptors = append(descriptors, descriptor) l = l.Parent() } + fsLayers, err := p.config.UploadManager.Upload(ctx, descriptors, p.config.ProgressOutput) + if err != nil { + return err + } + var tag string if tagged, isTagged := ref.(reference.Tagged); isTagged { tag = tagged.Tag() @@ -157,59 +168,124 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error { if tagged, isTagged := ref.(reference.Tagged); isTagged { // NOTE: do not change this format without first changing the trust client // code. This information is used to determine what was pushed and should be signed. 
- out.Write(p.sf.FormatStatus("", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize)) + progress.Messagef(p.config.ProgressOutput, "", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize) } } - manSvc, err := p.repo.Manifests(context.Background()) + manSvc, err := p.repo.Manifests(ctx) if err != nil { return err } return manSvc.Put(signed) } -func (p *v2Pusher) pushLayerIfNecessary(out io.Writer, l layer.Layer) (digest.Digest, error) { - logrus.Debugf("Pushing layer: %s", l.DiffID()) +type v2PushDescriptor struct { + layer layer.Layer + blobSumService *metadata.BlobSumService + repo distribution.Repository + layersPushed *pushMap +} + +func (pd *v2PushDescriptor) Key() string { + return "v2push:" + pd.repo.Name() + " " + pd.layer.DiffID().String() +} + +func (pd *v2PushDescriptor) ID() string { + return stringid.TruncateID(pd.layer.DiffID().String()) +} + +func (pd *v2PushDescriptor) DiffID() layer.DiffID { + return pd.layer.DiffID() +} + +func (pd *v2PushDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) { + diffID := pd.DiffID() + + logrus.Debugf("Pushing layer: %s", diffID) // Do we have any blobsums associated with this layer's DiffID? 
- possibleBlobsums, err := p.blobSumService.GetBlobSums(l.DiffID()) + possibleBlobsums, err := pd.blobSumService.GetBlobSums(diffID) if err == nil { - dgst, exists, err := p.blobSumAlreadyExists(possibleBlobsums) + dgst, exists, err := blobSumAlreadyExists(ctx, possibleBlobsums, pd.repo, pd.layersPushed) if err != nil { - out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Image push failed", nil)) - return "", err + progress.Update(progressOutput, pd.ID(), "Image push failed") + return "", retryOnError(err) } if exists { - out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Layer already exists", nil)) + progress.Update(progressOutput, pd.ID(), "Layer already exists") return dgst, nil } } // if digest was empty or not saved, or if blob does not exist on the remote repository, // then push the blob. - pushDigest, err := p.pushV2Layer(p.repo.Blobs(context.Background()), l) + bs := pd.repo.Blobs(ctx) + + // Send the layer + layerUpload, err := bs.Create(ctx) if err != nil { - return "", err + return "", retryOnError(err) } + defer layerUpload.Close() + + arch, err := pd.layer.TarStream() + if err != nil { + return "", xfer.DoNotRetry{Err: err} + } + + // don't care if this fails; best effort + size, _ := pd.layer.DiffSize() + + reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), progressOutput, size, pd.ID(), "Pushing") + defer reader.Close() + compressedReader := compress(reader) + + digester := digest.Canonical.New() + tee := io.TeeReader(compressedReader, digester.Hash()) + + nn, err := layerUpload.ReadFrom(tee) + compressedReader.Close() + if err != nil { + return "", retryOnError(err) + } + + pushDigest := digester.Digest() + if _, err := layerUpload.Commit(ctx, distribution.Descriptor{Digest: pushDigest}); err != nil { + return "", retryOnError(err) + } + + logrus.Debugf("uploaded layer %s (%s), %d bytes", diffID, pushDigest, nn) + progress.Update(progressOutput, pd.ID(), "Pushed") + // Cache 
mapping from this layer's DiffID to the blobsum - if err := p.blobSumService.Add(l.DiffID(), pushDigest); err != nil { - return "", err + if err := pd.blobSumService.Add(diffID, pushDigest); err != nil { + return "", xfer.DoNotRetry{Err: err} } + pd.layersPushed.Lock() + pd.layersPushed.layersPushed[pushDigest] = true + pd.layersPushed.Unlock() + return pushDigest, nil } // blobSumAlreadyExists checks if the registry already know about any of the // blobsums passed in the "blobsums" slice. If it finds one that the registry // knows about, it returns the known digest and "true". -func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest, bool, error) { +func blobSumAlreadyExists(ctx context.Context, blobsums []digest.Digest, repo distribution.Repository, layersPushed *pushMap) (digest.Digest, bool, error) { + layersPushed.Lock() for _, dgst := range blobsums { - if p.layersPushed[dgst] { + if layersPushed.layersPushed[dgst] { // it is already known that the push is not needed and // therefore doing a stat is unnecessary + layersPushed.Unlock() return dgst, true, nil } - _, err := p.repo.Blobs(context.Background()).Stat(context.Background(), dgst) + } + layersPushed.Unlock() + + for _, dgst := range blobsums { + _, err := repo.Blobs(ctx).Stat(ctx, dgst) switch err { case nil: return dgst, true, nil @@ -226,7 +302,7 @@ func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest // FSLayer digests. // FIXME: This should be moved to the distribution repo, since it will also // be useful for converting new manifests to the old format. 
-func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]schema1.FSLayer) (*schema1.Manifest, error) { +func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]digest.Digest) (*schema1.Manifest, error) { if len(img.History) == 0 { return nil, errors.New("empty history when trying to create V2 manifest") } @@ -271,7 +347,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif if !present { return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String()) } - dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent)) + dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent)) if err != nil { return nil, err } @@ -294,7 +370,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif reversedIndex := len(img.History) - i - 1 history[reversedIndex].V1Compatibility = string(jsonBytes) - fsLayerList[reversedIndex] = fsLayer + fsLayerList[reversedIndex] = schema1.FSLayer{BlobSum: fsLayer} parent = v1ID } @@ -315,11 +391,11 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String()) } - dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent + " " + string(img.RawJSON()))) + dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent + " " + string(img.RawJSON()))) if err != nil { return nil, err } - fsLayerList[0] = fsLayer + fsLayerList[0] = schema1.FSLayer{BlobSum: fsLayer} // Top-level v1compatibility string should be a modified version of the // image config. 
@@ -346,66 +422,3 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif History: history, }, nil } - -func rawJSON(value interface{}) *json.RawMessage { - jsonval, err := json.Marshal(value) - if err != nil { - return nil - } - return (*json.RawMessage)(&jsonval) -} - -func (p *v2Pusher) pushV2Layer(bs distribution.BlobService, l layer.Layer) (digest.Digest, error) { - out := p.config.OutStream - displayID := stringid.TruncateID(string(l.DiffID())) - - out.Write(p.sf.FormatProgress(displayID, "Preparing", nil)) - - arch, err := l.TarStream() - if err != nil { - return "", err - } - defer arch.Close() - - // Send the layer - layerUpload, err := bs.Create(context.Background()) - if err != nil { - return "", err - } - defer layerUpload.Close() - - // don't care if this fails; best effort - size, _ := l.DiffSize() - - reader := progressreader.New(progressreader.Config{ - In: ioutil.NopCloser(arch), // we'll take care of close here. - Out: out, - Formatter: p.sf, - Size: size, - NewLines: false, - ID: displayID, - Action: "Pushing", - }) - - compressedReader := compress(reader) - - digester := digest.Canonical.New() - tee := io.TeeReader(compressedReader, digester.Hash()) - - out.Write(p.sf.FormatProgress(displayID, "Pushing", nil)) - nn, err := layerUpload.ReadFrom(tee) - compressedReader.Close() - if err != nil { - return "", err - } - - dgst := digester.Digest() - if _, err := layerUpload.Commit(context.Background(), distribution.Descriptor{Digest: dgst}); err != nil { - return "", err - } - - logrus.Debugf("uploaded layer %s (%s), %d bytes", l.DiffID(), dgst, nn) - out.Write(p.sf.FormatProgress(displayID, "Pushed", nil)) - - return dgst, nil -} diff --git a/distribution/push_v2_test.go b/distribution/push_v2_test.go index ab9e6612c6..96a3939607 100644 --- a/distribution/push_v2_test.go +++ b/distribution/push_v2_test.go @@ -116,10 +116,10 @@ func TestCreateV2Manifest(t *testing.T) { t.Fatalf("json decoding failed: %v", err) } - fsLayers 
:= map[layer.DiffID]schema1.FSLayer{ - layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): {BlobSum: digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")}, - layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): {BlobSum: digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa")}, - layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): {BlobSum: digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")}, + fsLayers := map[layer.DiffID]digest.Digest{ + layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"), + layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"), + layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"), } manifest, err := CreateV2Manifest("testrepo", "testtag", img, fsLayers) diff --git a/distribution/registry.go b/distribution/registry.go index ed8b8c20e5..bb5b58a3af 100644 --- a/distribution/registry.go +++ b/distribution/registry.go @@ -13,10 +13,12 @@ import ( "github.com/docker/distribution" "github.com/docker/distribution/digest" "github.com/docker/distribution/manifest/schema1" + "github.com/docker/distribution/registry/api/errcode" "github.com/docker/distribution/registry/client" "github.com/docker/distribution/registry/client/auth" "github.com/docker/distribution/registry/client/transport" "github.com/docker/docker/cliconfig" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/registry" "golang.org/x/net/context" ) @@ -59,7 +61,7 @@ func 
NewV2Repository(repoInfo *registry.RepositoryInfo, endpoint registry.APIEnd authTransport := transport.NewTransport(base, modifiers...) pingClient := &http.Client{ Transport: authTransport, - Timeout: 5 * time.Second, + Timeout: 15 * time.Second, } endpointStr := strings.TrimRight(endpoint.URL, "/") + "/v2/" req, err := http.NewRequest("GET", endpointStr, nil) @@ -132,3 +134,23 @@ func (th *existingTokenHandler) AuthorizeRequest(req *http.Request, params map[s req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.token)) return nil } + +// retryOnError wraps the error in xfer.DoNotRetry if we should not retry the +// operation after this error. +func retryOnError(err error) error { + switch v := err.(type) { + case errcode.Errors: + return retryOnError(v[0]) + case errcode.Error: + switch v.Code { + case errcode.ErrorCodeUnauthorized, errcode.ErrorCodeUnsupported, errcode.ErrorCodeDenied: + return xfer.DoNotRetry{Err: err} + } + + } + // let's be nice and retry if the error is a completely + // unexpected one. + // If new errors have to be handled in some way, please + // add them to the switch above. 
+ return err +} diff --git a/distribution/registry_unit_test.go b/distribution/registry_unit_test.go index bf13934622..77d810e25b 100644 --- a/distribution/registry_unit_test.go +++ b/distribution/registry_unit_test.go @@ -11,9 +11,9 @@ import ( "github.com/docker/distribution/reference" "github.com/docker/distribution/registry/client/auth" "github.com/docker/docker/cliconfig" - "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/registry" "github.com/docker/docker/utils" + "golang.org/x/net/context" ) func TestTokenPassThru(t *testing.T) { @@ -72,8 +72,7 @@ func TestTokenPassThru(t *testing.T) { MetaHeaders: http.Header{}, AuthConfig: authConfig, } - sf := streamformatter.NewJSONStreamFormatter() - puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf) + puller, err := newPuller(endpoint, repoInfo, imagePullConfig) if err != nil { t.Fatal(err) } @@ -86,7 +85,7 @@ func TestTokenPassThru(t *testing.T) { logrus.Debug("About to pull") // We expect it to fail, since we haven't mock'd the full registry exchange in our handler above tag, _ := reference.WithTag(n, "tag_goes_here") - _ = p.pullV2Repository(tag) + _ = p.pullV2Repository(context.Background(), tag) if !gotToken { t.Fatal("Failed to receive registry token") diff --git a/distribution/xfer/download.go b/distribution/xfer/download.go new file mode 100644 index 0000000000..69c8bad031 --- /dev/null +++ b/distribution/xfer/download.go @@ -0,0 +1,420 @@ +package xfer + +import ( + "errors" + "fmt" + "io" + "time" + + "github.com/Sirupsen/logrus" + "github.com/docker/docker/image" + "github.com/docker/docker/layer" + "github.com/docker/docker/pkg/archive" + "github.com/docker/docker/pkg/ioutils" + "github.com/docker/docker/pkg/progress" + "golang.org/x/net/context" +) + +const maxDownloadAttempts = 5 + +// LayerDownloadManager figures out which layers need to be downloaded, then +// registers and downloads those, taking into account dependencies between +// layers. 
+type LayerDownloadManager struct { + layerStore layer.Store + tm TransferManager +} + +// NewLayerDownloadManager returns a new LayerDownloadManager. +func NewLayerDownloadManager(layerStore layer.Store, concurrencyLimit int) *LayerDownloadManager { + return &LayerDownloadManager{ + layerStore: layerStore, + tm: NewTransferManager(concurrencyLimit), + } +} + +type downloadTransfer struct { + Transfer + + layerStore layer.Store + layer layer.Layer + err error +} + +// result returns the layer resulting from the download, if the download +// and registration were successful. +func (d *downloadTransfer) result() (layer.Layer, error) { + return d.layer, d.err +} + +// A DownloadDescriptor references a layer that may need to be downloaded. +type DownloadDescriptor interface { + // Key returns the key used to deduplicate downloads. + Key() string + // ID returns the ID for display purposes. + ID() string + // DiffID should return the DiffID for this layer, or an error + // if it is unknown (for example, if it has not been downloaded + // before). + DiffID() (layer.DiffID, error) + // Download is called to perform the download. + Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) +} + +// DownloadDescriptorWithRegistered is a DownloadDescriptor that has an +// additional Registered method which gets called after a downloaded layer is +// registered. This allows the user of the download manager to know the DiffID +// of each registered layer. This method is called if a cast to +// DownloadDescriptorWithRegistered is successful. +type DownloadDescriptorWithRegistered interface { + DownloadDescriptor + Registered(diffID layer.DiffID) +} + +// Download is a blocking function which ensures the requested layers are +// present in the layer store. It uses the string returned by the Key method to +// deduplicate downloads. 
If a given layer is not already known to be present in +// the layer store, and the key is not used by an in-progress download, the +// Download method is called to get the layer tar data. Layers are then +// registered in the appropriate order. The caller must call the returned +// release function once it is done with the returned RootFS object. +func (ldm *LayerDownloadManager) Download(ctx context.Context, initialRootFS image.RootFS, layers []DownloadDescriptor, progressOutput progress.Output) (image.RootFS, func(), error) { + var ( + topLayer layer.Layer + topDownload *downloadTransfer + watcher *Watcher + missingLayer bool + transferKey = "" + downloadsByKey = make(map[string]*downloadTransfer) + ) + + rootFS := initialRootFS + for _, descriptor := range layers { + key := descriptor.Key() + transferKey += key + + if !missingLayer { + missingLayer = true + diffID, err := descriptor.DiffID() + if err == nil { + getRootFS := rootFS + getRootFS.Append(diffID) + l, err := ldm.layerStore.Get(getRootFS.ChainID()) + if err == nil { + // Layer already exists. + logrus.Debugf("Layer already exists: %s", descriptor.ID()) + progress.Update(progressOutput, descriptor.ID(), "Already exists") + if topLayer != nil { + layer.ReleaseAndLog(ldm.layerStore, topLayer) + } + topLayer = l + missingLayer = false + rootFS.Append(diffID) + continue + } + } + } + + // Does this layer have the same data as a previous layer in + // the stack? If so, avoid downloading it more than once. + var topDownloadUncasted Transfer + if existingDownload, ok := downloadsByKey[key]; ok { + xferFunc := ldm.makeDownloadFuncFromDownload(descriptor, existingDownload, topDownload) + defer topDownload.Transfer.Release(watcher) + topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput) + topDownload = topDownloadUncasted.(*downloadTransfer) + continue + } + + // Layer is not known to exist - download and register it. 
+ progress.Update(progressOutput, descriptor.ID(), "Pulling fs layer") + + var xferFunc DoFunc + if topDownload != nil { + xferFunc = ldm.makeDownloadFunc(descriptor, "", topDownload) + defer topDownload.Transfer.Release(watcher) + } else { + xferFunc = ldm.makeDownloadFunc(descriptor, rootFS.ChainID(), nil) + } + topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput) + topDownload = topDownloadUncasted.(*downloadTransfer) + downloadsByKey[key] = topDownload + } + + if topDownload == nil { + return rootFS, func() { layer.ReleaseAndLog(ldm.layerStore, topLayer) }, nil + } + + // Won't be using the list built up so far - will generate it + // from downloaded layers instead. + rootFS.DiffIDs = []layer.DiffID{} + + defer func() { + if topLayer != nil { + layer.ReleaseAndLog(ldm.layerStore, topLayer) + } + }() + + select { + case <-ctx.Done(): + topDownload.Transfer.Release(watcher) + return rootFS, func() {}, ctx.Err() + case <-topDownload.Done(): + break + } + + l, err := topDownload.result() + if err != nil { + topDownload.Transfer.Release(watcher) + return rootFS, func() {}, err + } + + // Must do this exactly len(layers) times, so we don't include the + // base layer on Windows. + for range layers { + if l == nil { + topDownload.Transfer.Release(watcher) + return rootFS, func() {}, errors.New("internal error: too few parent layers") + } + rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...) + l = l.Parent() + } + return rootFS, func() { topDownload.Transfer.Release(watcher) }, err +} + +// makeDownloadFunc returns a function that performs the layer download and +// registration. If parentDownload is non-nil, it waits for that download to +// complete before the registration step, and registers the downloaded data +// on top of parentDownload's resulting layer. Otherwise, it registers the +// layer on top of the ChainID given by parentLayer. 
+func (ldm *LayerDownloadManager) makeDownloadFunc(descriptor DownloadDescriptor, parentLayer layer.ChainID, parentDownload *downloadTransfer) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + d := &downloadTransfer{ + Transfer: NewTransfer(), + layerStore: ldm.layerStore, + } + + go func() { + defer func() { + close(progressChan) + }() + + progressOutput := progress.ChanOutput(progressChan) + + select { + case <-start: + default: + progress.Update(progressOutput, descriptor.ID(), "Waiting") + <-start + } + + if parentDownload != nil { + // Did the parent download already fail or get + // cancelled? + select { + case <-parentDownload.Done(): + _, err := parentDownload.result() + if err != nil { + d.err = err + return + } + default: + } + } + + var ( + downloadReader io.ReadCloser + size int64 + err error + retries int + ) + + for { + downloadReader, size, err = descriptor.Download(d.Transfer.Context(), progressOutput) + if err == nil { + break + } + + // If an error was returned because the context + // was cancelled, we shouldn't retry. 
+ select { + case <-d.Transfer.Context().Done(): + d.err = err + return + default: + } + + retries++ + if _, isDNR := err.(DoNotRetry); isDNR || retries == maxDownloadAttempts { + logrus.Errorf("Download failed: %v", err) + d.err = err + return + } + + logrus.Errorf("Download failed, retrying: %v", err) + delay := retries * 5 + ticker := time.NewTicker(time.Second) + + selectLoop: + for { + progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay) + select { + case <-ticker.C: + delay-- + if delay == 0 { + ticker.Stop() + break selectLoop + } + case <-d.Transfer.Context().Done(): + ticker.Stop() + d.err = errors.New("download cancelled during retry delay") + return + } + + } + } + + close(inactive) + + if parentDownload != nil { + select { + case <-d.Transfer.Context().Done(): + d.err = errors.New("layer registration cancelled") + downloadReader.Close() + return + case <-parentDownload.Done(): + } + + l, err := parentDownload.result() + if err != nil { + d.err = err + downloadReader.Close() + return + } + parentLayer = l.ChainID() + } + + reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(d.Transfer.Context(), downloadReader), progressOutput, size, descriptor.ID(), "Extracting") + defer reader.Close() + + inflatedLayerData, err := archive.DecompressStream(reader) + if err != nil { + d.err = fmt.Errorf("could not get decompression stream: %v", err) + return + } + + d.layer, err = d.layerStore.Register(inflatedLayerData, parentLayer) + if err != nil { + select { + case <-d.Transfer.Context().Done(): + d.err = errors.New("layer registration cancelled") + default: + d.err = fmt.Errorf("failed to register layer: %v", err) + } + return + } + + progress.Update(progressOutput, descriptor.ID(), "Pull complete") + withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered) + if hasRegistered { + withRegistered.Registered(d.layer.DiffID()) + } + + // Doesn't actually need to be its own goroutine, but + // done like 
this so we can defer close(c). + go func() { + <-d.Transfer.Released() + if d.layer != nil { + layer.ReleaseAndLog(d.layerStore, d.layer) + } + }() + }() + + return d + } +} + +// makeDownloadFuncFromDownload returns a function that performs the layer +// registration when the layer data is coming from an existing download. It +// waits for sourceDownload and parentDownload to complete, and then +// reregisters the data from sourceDownload's top layer on top of +// parentDownload. This function does not log progress output because it would +// interfere with the progress reporting for sourceDownload, which has the same +// Key. +func (ldm *LayerDownloadManager) makeDownloadFuncFromDownload(descriptor DownloadDescriptor, sourceDownload *downloadTransfer, parentDownload *downloadTransfer) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + d := &downloadTransfer{ + Transfer: NewTransfer(), + layerStore: ldm.layerStore, + } + + go func() { + defer func() { + close(progressChan) + }() + + <-start + + close(inactive) + + select { + case <-d.Transfer.Context().Done(): + d.err = errors.New("layer registration cancelled") + return + case <-parentDownload.Done(): + } + + l, err := parentDownload.result() + if err != nil { + d.err = err + return + } + parentLayer := l.ChainID() + + // sourceDownload should have already finished if + // parentDownload finished, but wait for it explicitly + // to be sure. 
+ select { + case <-d.Transfer.Context().Done(): + d.err = errors.New("layer registration cancelled") + return + case <-sourceDownload.Done(): + } + + l, err = sourceDownload.result() + if err != nil { + d.err = err + return + } + + layerReader, err := l.TarStream() + if err != nil { + d.err = err + return + } + defer layerReader.Close() + + d.layer, err = d.layerStore.Register(layerReader, parentLayer) + if err != nil { + d.err = fmt.Errorf("failed to register layer: %v", err) + return + } + + withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered) + if hasRegistered { + withRegistered.Registered(d.layer.DiffID()) + } + + // Doesn't actually need to be its own goroutine, but + // done like this so we can defer close(c). + go func() { + <-d.Transfer.Released() + if d.layer != nil { + layer.ReleaseAndLog(d.layerStore, d.layer) + } + }() + }() + + return d + } +} diff --git a/distribution/xfer/download_test.go b/distribution/xfer/download_test.go new file mode 100644 index 0000000000..ff665df344 --- /dev/null +++ b/distribution/xfer/download_test.go @@ -0,0 +1,332 @@ +package xfer + +import ( + "bytes" + "errors" + "io" + "io/ioutil" + "sync/atomic" + "testing" + "time" + + "github.com/docker/distribution/digest" + "github.com/docker/docker/image" + "github.com/docker/docker/layer" + "github.com/docker/docker/pkg/archive" + "github.com/docker/docker/pkg/progress" + "golang.org/x/net/context" +) + +const maxDownloadConcurrency = 3 + +type mockLayer struct { + layerData bytes.Buffer + diffID layer.DiffID + chainID layer.ChainID + parent layer.Layer +} + +func (ml *mockLayer) TarStream() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewBuffer(ml.layerData.Bytes())), nil +} + +func (ml *mockLayer) ChainID() layer.ChainID { + return ml.chainID +} + +func (ml *mockLayer) DiffID() layer.DiffID { + return ml.diffID +} + +func (ml *mockLayer) Parent() layer.Layer { + return ml.parent +} + +func (ml *mockLayer) Size() (size int64, err error) { 
+ return 0, nil +} + +func (ml *mockLayer) DiffSize() (size int64, err error) { + return 0, nil +} + +func (ml *mockLayer) Metadata() (map[string]string, error) { + return make(map[string]string), nil +} + +type mockLayerStore struct { + layers map[layer.ChainID]*mockLayer +} + +func createChainIDFromParent(parent layer.ChainID, dgsts ...layer.DiffID) layer.ChainID { + if len(dgsts) == 0 { + return parent + } + if parent == "" { + return createChainIDFromParent(layer.ChainID(dgsts[0]), dgsts[1:]...) + } + // H = "H(n-1) SHA256(n)" + dgst, err := digest.FromBytes([]byte(string(parent) + " " + string(dgsts[0]))) + if err != nil { + // Digest calculation is not expected to throw an error, + // any error at this point is a program error + panic(err) + } + return createChainIDFromParent(layer.ChainID(dgst), dgsts[1:]...) +} + +func (ls *mockLayerStore) Register(reader io.Reader, parentID layer.ChainID) (layer.Layer, error) { + var ( + parent layer.Layer + err error + ) + + if parentID != "" { + parent, err = ls.Get(parentID) + if err != nil { + return nil, err + } + } + + l := &mockLayer{parent: parent} + _, err = l.layerData.ReadFrom(reader) + if err != nil { + return nil, err + } + diffID, err := digest.FromBytes(l.layerData.Bytes()) + if err != nil { + return nil, err + } + l.diffID = layer.DiffID(diffID) + l.chainID = createChainIDFromParent(parentID, l.diffID) + + ls.layers[l.chainID] = l + return l, nil +} + +func (ls *mockLayerStore) Get(chainID layer.ChainID) (layer.Layer, error) { + l, ok := ls.layers[chainID] + if !ok { + return nil, layer.ErrLayerDoesNotExist + } + return l, nil +} + +func (ls *mockLayerStore) Release(l layer.Layer) ([]layer.Metadata, error) { + return []layer.Metadata{}, nil +} + +func (ls *mockLayerStore) Mount(id string, parent layer.ChainID, label string, init layer.MountInit) (layer.RWLayer, error) { + return nil, errors.New("not implemented") +} + +func (ls *mockLayerStore) Unmount(id string) error { + return errors.New("not 
implemented") +} + +func (ls *mockLayerStore) DeleteMount(id string) ([]layer.Metadata, error) { + return nil, errors.New("not implemented") +} + +func (ls *mockLayerStore) Changes(id string) ([]archive.Change, error) { + return nil, errors.New("not implemented") +} + +type mockDownloadDescriptor struct { + currentDownloads *int32 + id string + diffID layer.DiffID + registeredDiffID layer.DiffID + expectedDiffID layer.DiffID + simulateRetries int +} + +// Key returns the key used to deduplicate downloads. +func (d *mockDownloadDescriptor) Key() string { + return d.id +} + +// ID returns the ID for display purposes. +func (d *mockDownloadDescriptor) ID() string { + return d.id +} + +// DiffID should return the DiffID for this layer, or an error +// if it is unknown (for example, if it has not been downloaded +// before). +func (d *mockDownloadDescriptor) DiffID() (layer.DiffID, error) { + if d.diffID != "" { + return d.diffID, nil + } + return "", errors.New("no diffID available") +} + +func (d *mockDownloadDescriptor) Registered(diffID layer.DiffID) { + d.registeredDiffID = diffID +} + +func (d *mockDownloadDescriptor) mockTarStream() io.ReadCloser { + // The mock implementation returns the ID repeated 5 times as a tar + // stream instead of actual tar data. The data is ignored except for + // computing IDs. + return ioutil.NopCloser(bytes.NewBuffer([]byte(d.id + d.id + d.id + d.id + d.id))) +} + +// Download is called to perform the download. +func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) { + if d.currentDownloads != nil { + defer atomic.AddInt32(d.currentDownloads, -1) + + if atomic.AddInt32(d.currentDownloads, 1) > maxDownloadConcurrency { + return nil, 0, errors.New("concurrency limit exceeded") + } + } + + // Sleep a bit to simulate a time-consuming download. 
+ for i := int64(0); i <= 10; i++ { + select { + case <-ctx.Done(): + return nil, 0, ctx.Err() + case <-time.After(10 * time.Millisecond): + progressOutput.WriteProgress(progress.Progress{ID: d.ID(), Action: "Downloading", Current: i, Total: 10}) + } + } + + if d.simulateRetries != 0 { + d.simulateRetries-- + return nil, 0, errors.New("simulating retry") + } + + return d.mockTarStream(), 0, nil +} + +func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor { + return []DownloadDescriptor{ + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id1", + expectedDiffID: layer.DiffID("sha256:68e2c75dc5c78ea9240689c60d7599766c213ae210434c53af18470ae8c53ec1"), + }, + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id2", + expectedDiffID: layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"), + }, + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id3", + expectedDiffID: layer.DiffID("sha256:58745a8bbd669c25213e9de578c4da5c8ee1c836b3581432c2b50e38a6753300"), + }, + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id2", + expectedDiffID: layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"), + }, + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id4", + expectedDiffID: layer.DiffID("sha256:0dfb5b9577716cc173e95af7c10289322c29a6453a1718addc00c0c5b1330936"), + simulateRetries: 1, + }, + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id5", + expectedDiffID: layer.DiffID("sha256:0a5f25fa1acbc647f6112a6276735d0fa01e4ee2aa7ec33015e337350e1ea23d"), + }, + } +} + +func TestSuccessfulDownload(t *testing.T) { + layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)} + ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency) + + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + receivedProgress := make(map[string]int64) + + go 
func() { + for p := range progressChan { + if p.Action == "Downloading" { + receivedProgress[p.ID] = p.Current + } else if p.Action == "Already exists" { + receivedProgress[p.ID] = -1 + } + } + close(progressDone) + }() + + var currentDownloads int32 + descriptors := downloadDescriptors(&currentDownloads) + + firstDescriptor := descriptors[0].(*mockDownloadDescriptor) + + // Pre-register the first layer to simulate an already-existing layer + l, err := layerStore.Register(firstDescriptor.mockTarStream(), "") + if err != nil { + t.Fatal(err) + } + firstDescriptor.diffID = l.DiffID() + + rootFS, releaseFunc, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan)) + if err != nil { + t.Fatalf("download error: %v", err) + } + + releaseFunc() + + close(progressChan) + <-progressDone + + if len(rootFS.DiffIDs) != len(descriptors) { + t.Fatal("got wrong number of diffIDs in rootfs") + } + + for i, d := range descriptors { + descriptor := d.(*mockDownloadDescriptor) + + if descriptor.diffID != "" { + if receivedProgress[d.ID()] != -1 { + t.Fatalf("did not get 'already exists' message for %v", d.ID()) + } + } else if receivedProgress[d.ID()] != 10 { + t.Fatalf("missing or wrong progress output for %v (got: %d)", d.ID(), receivedProgress[d.ID()]) + } + + if rootFS.DiffIDs[i] != descriptor.expectedDiffID { + t.Fatalf("rootFS item %d has the wrong diffID (expected: %v got: %v)", i, descriptor.expectedDiffID, rootFS.DiffIDs[i]) + } + + if descriptor.diffID == "" && descriptor.registeredDiffID != rootFS.DiffIDs[i] { + t.Fatal("diffID mismatch between rootFS and Registered callback") + } + } +} + +func TestCancelledDownload(t *testing.T) { + ldm := NewLayerDownloadManager(&mockLayerStore{make(map[layer.ChainID]*mockLayer)}, maxDownloadConcurrency) + + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + + go func() { + for range progressChan { + } + close(progressDone) + }() + + ctx, cancel := 
context.WithCancel(context.Background()) + + go func() { + <-time.After(time.Millisecond) + cancel() + }() + + descriptors := downloadDescriptors(nil) + _, _, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan)) + if err != context.Canceled { + t.Fatal("expected download to be cancelled") + } + + close(progressChan) + <-progressDone +} diff --git a/distribution/xfer/transfer.go b/distribution/xfer/transfer.go new file mode 100644 index 0000000000..c0ae3d1c4b --- /dev/null +++ b/distribution/xfer/transfer.go @@ -0,0 +1,343 @@ +package xfer + +import ( + "sync" + + "github.com/docker/docker/pkg/progress" + "golang.org/x/net/context" +) + +// DoNotRetry is an error wrapper indicating that the error cannot be resolved +// with a retry. +type DoNotRetry struct { + Err error +} + +// Error returns the stringified representation of the encapsulated error. +func (e DoNotRetry) Error() string { + return e.Err.Error() +} + +// Watcher is returned by Watch and can be passed to Release to stop watching. +type Watcher struct { + // signalChan is used to signal to the watcher goroutine that + // new progress information is available, or that the transfer + // has finished. + signalChan chan struct{} + // releaseChan signals to the watcher goroutine that the watcher + // should be detached. + releaseChan chan struct{} + // running remains open as long as the watcher is watching the + // transfer. It gets closed if the transfer finishes or the + // watcher is detached. + running chan struct{} +} + +// Transfer represents an in-progress transfer. 
+type Transfer interface { + Watch(progressOutput progress.Output) *Watcher + Release(*Watcher) + Context() context.Context + Cancel() + Done() <-chan struct{} + Released() <-chan struct{} + Broadcast(masterProgressChan <-chan progress.Progress) +} + +type transfer struct { + mu sync.Mutex + + ctx context.Context + cancel context.CancelFunc + + // watchers keeps track of the goroutines monitoring progress output, + // indexed by the channels that release them. + watchers map[chan struct{}]*Watcher + + // lastProgress is the most recently received progress event. + lastProgress progress.Progress + // hasLastProgress is true when lastProgress has been set. + hasLastProgress bool + + // running remains open as long as the transfer is in progress. + running chan struct{} + // hasWatchers stays open until all watchers release the transfer. + hasWatchers chan struct{} + + // broadcastDone is true if the master progress channel has closed. + broadcastDone bool + // broadcastSyncChan allows watchers to "ping" the broadcasting + // goroutine to wait for it to deplete its input channel. This ensures + // a detaching watcher won't miss an event that was sent before it + // started detaching. + broadcastSyncChan chan struct{} +} + +// NewTransfer creates a new transfer. +func NewTransfer() Transfer { + t := &transfer{ + watchers: make(map[chan struct{}]*Watcher), + running: make(chan struct{}), + hasWatchers: make(chan struct{}), + broadcastSyncChan: make(chan struct{}), + } + + // This uses context.Background instead of a caller-supplied context + // so that a transfer won't be cancelled automatically if the client + // which requested it is ^C'd (there could be other viewers). + t.ctx, t.cancel = context.WithCancel(context.Background()) + + return t +} + +// Broadcast copies the progress and error output to all viewers.
+func (t *transfer) Broadcast(masterProgressChan <-chan progress.Progress) { + for { + var ( + p progress.Progress + ok bool + ) + select { + case p, ok = <-masterProgressChan: + default: + // We've depleted the channel, so now we can handle + // reads on broadcastSyncChan to let detaching watchers + // know we're caught up. + select { + case <-t.broadcastSyncChan: + continue + case p, ok = <-masterProgressChan: + } + } + + t.mu.Lock() + if ok { + t.lastProgress = p + t.hasLastProgress = true + for _, w := range t.watchers { + select { + case w.signalChan <- struct{}{}: + default: + } + } + + } else { + t.broadcastDone = true + } + t.mu.Unlock() + if !ok { + close(t.running) + return + } + } +} + +// Watch adds a watcher to the transfer. The supplied channel gets progress +// updates and is closed when the transfer finishes. +func (t *transfer) Watch(progressOutput progress.Output) *Watcher { + t.mu.Lock() + defer t.mu.Unlock() + + w := &Watcher{ + releaseChan: make(chan struct{}), + signalChan: make(chan struct{}), + running: make(chan struct{}), + } + + if t.broadcastDone { + close(w.running) + return w + } + + t.watchers[w.releaseChan] = w + + go func() { + defer func() { + close(w.running) + }() + done := false + for { + t.mu.Lock() + hasLastProgress := t.hasLastProgress + lastProgress := t.lastProgress + t.mu.Unlock() + + // This might write the last progress item a + // second time (since channel closure also gets + // us here), but that's fine. + if hasLastProgress { + progressOutput.WriteProgress(lastProgress) + } + + if done { + return + } + + select { + case <-w.signalChan: + case <-w.releaseChan: + done = true + // Since the watcher is going to detach, make + // sure the broadcaster is caught up so we + // don't miss anything. 
+ select { + case t.broadcastSyncChan <- struct{}{}: + case <-t.running: + } + case <-t.running: + done = true + } + } + }() + + return w +} + +// Release is the inverse of Watch, indicating that the watcher no longer wants +// to be notified about the progress of the transfer. All calls to Watch must +// be paired with later calls to Release so that the lifecycle of the transfer +// is properly managed. +func (t *transfer) Release(watcher *Watcher) { + t.mu.Lock() + delete(t.watchers, watcher.releaseChan) + + if len(t.watchers) == 0 { + close(t.hasWatchers) + t.cancel() + } + t.mu.Unlock() + + close(watcher.releaseChan) + // Block until the watcher goroutine completes + <-watcher.running +} + +// Done returns a channel which is closed if the transfer completes or is +// cancelled. Note that having 0 watchers causes a transfer to be cancelled. +func (t *transfer) Done() <-chan struct{} { + // Note that this doesn't return t.ctx.Done() because that channel will + // be closed the moment Cancel is called, and we need to return a + // channel that blocks until a cancellation is actually acknowledged by + // the transfer function. + return t.running +} + +// Released returns a channel which is closed once all watchers release the +// transfer. +func (t *transfer) Released() <-chan struct{} { + return t.hasWatchers +} + +// Context returns the context associated with the transfer. +func (t *transfer) Context() context.Context { + return t.ctx +} + +// Cancel cancels the context associated with the transfer. +func (t *transfer) Cancel() { + t.cancel() +} + +// DoFunc is a function called by the transfer manager to actually perform +// a transfer. It should be non-blocking. It should wait until the start channel +// is closed before transferring any data. If the function closes inactive, that +// signals to the transfer manager that the job is no longer actively moving +// data - for example, it may be waiting for a dependent transfer to finish.
+// This prevents it from taking up a slot. +type DoFunc func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer + +// TransferManager is used by LayerDownloadManager and LayerUploadManager to +// schedule and deduplicate transfers. It is up to the TransferManager +// implementation to make the scheduling and concurrency decisions. +type TransferManager interface { + // Transfer checks if a transfer with the given key is in progress. If + // so, it returns progress and error output from that transfer. + // Otherwise, it will call xferFunc to initiate the transfer. + Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher) +} + +type transferManager struct { + mu sync.Mutex + + concurrencyLimit int + activeTransfers int + transfers map[string]Transfer + waitingTransfers []chan struct{} +} + +// NewTransferManager returns a new TransferManager. +func NewTransferManager(concurrencyLimit int) TransferManager { + return &transferManager{ + concurrencyLimit: concurrencyLimit, + transfers: make(map[string]Transfer), + } +} + +// Transfer checks if a transfer matching the given key is in progress. If not, +// it starts one by calling xferFunc. The caller supplies a channel which +// receives progress output from the transfer. +func (tm *transferManager) Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher) { + tm.mu.Lock() + defer tm.mu.Unlock() + + if xfer, present := tm.transfers[key]; present { + // Transfer is already in progress. 
+ watcher := xfer.Watch(progressOutput) + return xfer, watcher + } + + start := make(chan struct{}) + inactive := make(chan struct{}) + + if tm.activeTransfers < tm.concurrencyLimit { + close(start) + tm.activeTransfers++ + } else { + tm.waitingTransfers = append(tm.waitingTransfers, start) + } + + masterProgressChan := make(chan progress.Progress) + xfer := xferFunc(masterProgressChan, start, inactive) + watcher := xfer.Watch(progressOutput) + go xfer.Broadcast(masterProgressChan) + tm.transfers[key] = xfer + + // When the transfer is finished, remove from the map. + go func() { + for { + select { + case <-inactive: + tm.mu.Lock() + tm.inactivate(start) + tm.mu.Unlock() + inactive = nil + case <-xfer.Done(): + tm.mu.Lock() + if inactive != nil { + tm.inactivate(start) + } + delete(tm.transfers, key) + tm.mu.Unlock() + return + } + } + }() + + return xfer, watcher +} + +func (tm *transferManager) inactivate(start chan struct{}) { + // If the transfer was started, remove it from the activeTransfers + // count. 
+ select { + case <-start: + // Start next transfer if any are waiting + if len(tm.waitingTransfers) != 0 { + close(tm.waitingTransfers[0]) + tm.waitingTransfers = tm.waitingTransfers[1:] + } else { + tm.activeTransfers-- + } + default: + } +} diff --git a/distribution/xfer/transfer_test.go b/distribution/xfer/transfer_test.go new file mode 100644 index 0000000000..7eeb304033 --- /dev/null +++ b/distribution/xfer/transfer_test.go @@ -0,0 +1,385 @@ +package xfer + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/docker/docker/pkg/progress" +) + +func TestTransfer(t *testing.T) { + makeXferFunc := func(id string) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + select { + case <-start: + default: + t.Fatalf("transfer function not started even though concurrency limit not reached") + } + + xfer := NewTransfer() + go func() { + for i := 0; i <= 10; i++ { + progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10} + time.Sleep(10 * time.Millisecond) + } + close(progressChan) + }() + return xfer + } + } + + tm := NewTransferManager(5) + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + receivedProgress := make(map[string]int64) + + go func() { + for p := range progressChan { + val, present := receivedProgress[p.ID] + if !present { + if p.Current != 0 { + t.Fatalf("got unexpected progress value: %d (expected 0)", p.Current) + } + } else if p.Current == 10 { + // Special case: last progress output may be + // repeated because the transfer finishing + // causes the latest progress output to be + // written to the channel (in case the watcher + // missed it). 
+ if p.Current != 9 && p.Current != 10 { + t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1) + } + } else if p.Current != val+1 { + t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1) + } + receivedProgress[p.ID] = p.Current + } + close(progressDone) + }() + + // Start a few transfers + ids := []string{"id1", "id2", "id3"} + xfers := make([]Transfer, len(ids)) + watchers := make([]*Watcher, len(ids)) + for i, id := range ids { + xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan)) + } + + for i, xfer := range xfers { + <-xfer.Done() + xfer.Release(watchers[i]) + } + close(progressChan) + <-progressDone + + for _, id := range ids { + if receivedProgress[id] != 10 { + t.Fatalf("final progress value %d instead of 10", receivedProgress[id]) + } + } +} + +func TestConcurrencyLimit(t *testing.T) { + concurrencyLimit := 3 + var runningJobs int32 + + makeXferFunc := func(id string) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + xfer := NewTransfer() + go func() { + <-start + totalJobs := atomic.AddInt32(&runningJobs, 1) + if int(totalJobs) > concurrencyLimit { + t.Fatalf("too many jobs running") + } + for i := 0; i <= 10; i++ { + progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10} + time.Sleep(10 * time.Millisecond) + } + atomic.AddInt32(&runningJobs, -1) + close(progressChan) + }() + return xfer + } + } + + tm := NewTransferManager(concurrencyLimit) + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + receivedProgress := make(map[string]int64) + + go func() { + for p := range progressChan { + receivedProgress[p.ID] = p.Current + } + close(progressDone) + }() + + // Start more transfers than the concurrency limit + ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"} + xfers := make([]Transfer, len(ids)) + watchers 
:= make([]*Watcher, len(ids)) + for i, id := range ids { + xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan)) + } + + for i, xfer := range xfers { + <-xfer.Done() + xfer.Release(watchers[i]) + } + close(progressChan) + <-progressDone + + for _, id := range ids { + if receivedProgress[id] != 10 { + t.Fatalf("final progress value %d instead of 10", receivedProgress[id]) + } + } +} + +func TestInactiveJobs(t *testing.T) { + concurrencyLimit := 3 + var runningJobs int32 + testDone := make(chan struct{}) + + makeXferFunc := func(id string) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + xfer := NewTransfer() + go func() { + <-start + totalJobs := atomic.AddInt32(&runningJobs, 1) + if int(totalJobs) > concurrencyLimit { + t.Fatalf("too many jobs running") + } + for i := 0; i <= 10; i++ { + progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10} + time.Sleep(10 * time.Millisecond) + } + atomic.AddInt32(&runningJobs, -1) + close(inactive) + <-testDone + close(progressChan) + }() + return xfer + } + } + + tm := NewTransferManager(concurrencyLimit) + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + receivedProgress := make(map[string]int64) + + go func() { + for p := range progressChan { + receivedProgress[p.ID] = p.Current + } + close(progressDone) + }() + + // Start more transfers than the concurrency limit + ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"} + xfers := make([]Transfer, len(ids)) + watchers := make([]*Watcher, len(ids)) + for i, id := range ids { + xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan)) + } + + close(testDone) + for i, xfer := range xfers { + <-xfer.Done() + xfer.Release(watchers[i]) + } + close(progressChan) + <-progressDone + + for _, id := range ids { + if receivedProgress[id] != 10 { + 
t.Fatalf("final progress value %d instead of 10", receivedProgress[id]) + } + } +} + +func TestWatchRelease(t *testing.T) { + ready := make(chan struct{}) + + makeXferFunc := func(id string) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + xfer := NewTransfer() + go func() { + defer func() { + close(progressChan) + }() + <-ready + for i := int64(0); ; i++ { + select { + case <-time.After(10 * time.Millisecond): + case <-xfer.Context().Done(): + return + } + progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10} + } + }() + return xfer + } + } + + tm := NewTransferManager(5) + + type watcherInfo struct { + watcher *Watcher + progressChan chan progress.Progress + progressDone chan struct{} + receivedFirstProgress chan struct{} + } + + progressConsumer := func(w watcherInfo) { + first := true + for range w.progressChan { + if first { + close(w.receivedFirstProgress) + } + first = false + } + close(w.progressDone) + } + + // Start a transfer + watchers := make([]watcherInfo, 5) + var xfer Transfer + watchers[0].progressChan = make(chan progress.Progress) + watchers[0].progressDone = make(chan struct{}) + watchers[0].receivedFirstProgress = make(chan struct{}) + xfer, watchers[0].watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(watchers[0].progressChan)) + go progressConsumer(watchers[0]) + + // Give it multiple watchers + for i := 1; i != len(watchers); i++ { + watchers[i].progressChan = make(chan progress.Progress) + watchers[i].progressDone = make(chan struct{}) + watchers[i].receivedFirstProgress = make(chan struct{}) + watchers[i].watcher = xfer.Watch(progress.ChanOutput(watchers[i].progressChan)) + go progressConsumer(watchers[i]) + } + + // Now that the watchers are set up, allow the transfer goroutine to + // proceed. + close(ready) + + // Confirm that each watcher gets progress output. 
+ for _, w := range watchers { + <-w.receivedFirstProgress + } + + // Release one watcher every 5ms + for _, w := range watchers { + xfer.Release(w.watcher) + <-time.After(5 * time.Millisecond) + } + + // Now that all watchers have been released, Released() should + // return a closed channel. + <-xfer.Released() + + // Done() should return a closed channel because the xfer func returned + // due to cancellation. + <-xfer.Done() + + for _, w := range watchers { + close(w.progressChan) + <-w.progressDone + } +} + +func TestDuplicateTransfer(t *testing.T) { + ready := make(chan struct{}) + + var xferFuncCalls int32 + + makeXferFunc := func(id string) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + atomic.AddInt32(&xferFuncCalls, 1) + xfer := NewTransfer() + go func() { + defer func() { + close(progressChan) + }() + <-ready + for i := int64(0); ; i++ { + select { + case <-time.After(10 * time.Millisecond): + case <-xfer.Context().Done(): + return + } + progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10} + } + }() + return xfer + } + } + + tm := NewTransferManager(5) + + type transferInfo struct { + xfer Transfer + watcher *Watcher + progressChan chan progress.Progress + progressDone chan struct{} + receivedFirstProgress chan struct{} + } + + progressConsumer := func(t transferInfo) { + first := true + for range t.progressChan { + if first { + close(t.receivedFirstProgress) + } + first = false + } + close(t.progressDone) + } + + // Try to start multiple transfers with the same ID + transfers := make([]transferInfo, 5) + for i := range transfers { + t := &transfers[i] + t.progressChan = make(chan progress.Progress) + t.progressDone = make(chan struct{}) + t.receivedFirstProgress = make(chan struct{}) + t.xfer, t.watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(t.progressChan)) + go progressConsumer(*t) + } + + // Allow the transfer goroutine 
to proceed. + close(ready) + + // Confirm that each watcher gets progress output. + for _, t := range transfers { + <-t.receivedFirstProgress + } + + // Confirm that the transfer function was called exactly once. + if xferFuncCalls != 1 { + t.Fatal("transfer function wasn't called exactly once") + } + + // Release one watcher every 5ms + for _, t := range transfers { + t.xfer.Release(t.watcher) + <-time.After(5 * time.Millisecond) + } + + for _, t := range transfers { + // Now that all watchers have been released, Released() should + // return a closed channel. + <-t.xfer.Released() + // Done() should return a closed channel because the xfer func returned + // due to cancellation. + <-t.xfer.Done() + } + + for _, t := range transfers { + close(t.progressChan) + <-t.progressDone + } +} diff --git a/distribution/xfer/upload.go b/distribution/xfer/upload.go new file mode 100644 index 0000000000..9a7d2c17a7 --- /dev/null +++ b/distribution/xfer/upload.go @@ -0,0 +1,159 @@ +package xfer + +import ( + "errors" + "time" + + "github.com/Sirupsen/logrus" + "github.com/docker/distribution/digest" + "github.com/docker/docker/layer" + "github.com/docker/docker/pkg/progress" + "golang.org/x/net/context" +) + +const maxUploadAttempts = 5 + +// LayerUploadManager provides task management and progress reporting for +// uploads. +type LayerUploadManager struct { + tm TransferManager +} + +// NewLayerUploadManager returns a new LayerUploadManager. +func NewLayerUploadManager(concurrencyLimit int) *LayerUploadManager { + return &LayerUploadManager{ + tm: NewTransferManager(concurrencyLimit), + } +} + +type uploadTransfer struct { + Transfer + + diffID layer.DiffID + digest digest.Digest + err error +} + +// An UploadDescriptor references a layer that may need to be uploaded. +type UploadDescriptor interface { + // Key returns the key used to deduplicate uploads. + Key() string + // ID returns the ID for display purposes. 
+ ID() string + // DiffID should return the DiffID for this layer. + DiffID() layer.DiffID + // Upload is called to perform the Upload. + Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) +} + +// Upload is a blocking function which ensures the listed layers are present on +// the remote registry. It uses the string returned by the Key method to +// deduplicate uploads. +func (lum *LayerUploadManager) Upload(ctx context.Context, layers []UploadDescriptor, progressOutput progress.Output) (map[layer.DiffID]digest.Digest, error) { + var ( + uploads []*uploadTransfer + digests = make(map[layer.DiffID]digest.Digest) + dedupDescriptors = make(map[string]struct{}) + ) + + for _, descriptor := range layers { + progress.Update(progressOutput, descriptor.ID(), "Preparing") + + key := descriptor.Key() + if _, present := dedupDescriptors[key]; present { + continue + } + dedupDescriptors[key] = struct{}{} + + xferFunc := lum.makeUploadFunc(descriptor) + upload, watcher := lum.tm.Transfer(descriptor.Key(), xferFunc, progressOutput) + defer upload.Release(watcher) + uploads = append(uploads, upload.(*uploadTransfer)) + } + + for _, upload := range uploads { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-upload.Transfer.Done(): + if upload.err != nil { + return nil, upload.err + } + digests[upload.diffID] = upload.digest + } + } + + return digests, nil +} + +func (lum *LayerUploadManager) makeUploadFunc(descriptor UploadDescriptor) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + u := &uploadTransfer{ + Transfer: NewTransfer(), + diffID: descriptor.DiffID(), + } + + go func() { + defer func() { + close(progressChan) + }() + + progressOutput := progress.ChanOutput(progressChan) + + select { + case <-start: + default: + progress.Update(progressOutput, descriptor.ID(), "Waiting") + <-start + } + + retries := 0 + for { + digest, err := 
descriptor.Upload(u.Transfer.Context(), progressOutput) + if err == nil { + u.digest = digest + break + } + + // If an error was returned because the context + // was cancelled, we shouldn't retry. + select { + case <-u.Transfer.Context().Done(): + u.err = err + return + default: + } + + retries++ + if _, isDNR := err.(DoNotRetry); isDNR || retries == maxUploadAttempts { + logrus.Errorf("Upload failed: %v", err) + u.err = err + return + } + + logrus.Errorf("Upload failed, retrying: %v", err) + delay := retries * 5 + ticker := time.NewTicker(time.Second) + + selectLoop: + for { + progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay) + select { + case <-ticker.C: + delay-- + if delay == 0 { + ticker.Stop() + break selectLoop + } + case <-u.Transfer.Context().Done(): + ticker.Stop() + u.err = errors.New("upload cancelled during retry delay") + return + } + } + } + }() + + return u + } +} diff --git a/distribution/xfer/upload_test.go b/distribution/xfer/upload_test.go new file mode 100644 index 0000000000..df5b2ba9c0 --- /dev/null +++ b/distribution/xfer/upload_test.go @@ -0,0 +1,153 @@ +package xfer + +import ( + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/docker/distribution/digest" + "github.com/docker/docker/layer" + "github.com/docker/docker/pkg/progress" + "golang.org/x/net/context" +) + +const maxUploadConcurrency = 3 + +type mockUploadDescriptor struct { + currentUploads *int32 + diffID layer.DiffID + simulateRetries int +} + +// Key returns the key used to deduplicate downloads. +func (u *mockUploadDescriptor) Key() string { + return u.diffID.String() +} + +// ID returns the ID for display purposes. +func (u *mockUploadDescriptor) ID() string { + return u.diffID.String() +} + +// DiffID should return the DiffID for this layer. +func (u *mockUploadDescriptor) DiffID() layer.DiffID { + return u.diffID +} + +// Upload is called to perform the upload. 
+func (u *mockUploadDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) { + if u.currentUploads != nil { + defer atomic.AddInt32(u.currentUploads, -1) + + if atomic.AddInt32(u.currentUploads, 1) > maxUploadConcurrency { + return "", errors.New("concurrency limit exceeded") + } + } + + // Sleep a bit to simulate a time-consuming upload. + for i := int64(0); i <= 10; i++ { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(10 * time.Millisecond): + progressOutput.WriteProgress(progress.Progress{ID: u.ID(), Current: i, Total: 10}) + } + } + + if u.simulateRetries != 0 { + u.simulateRetries-- + return "", errors.New("simulating retry") + } + + // For the mock implementation, use SHA256(DiffID) as the returned + // digest. + return digest.FromBytes([]byte(u.diffID.String())) +} + +func uploadDescriptors(currentUploads *int32) []UploadDescriptor { + return []UploadDescriptor{ + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0}, + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"), 0}, + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"), 0}, + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0}, + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"), 1}, + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"), 0}, + } +} + +var expectedDigests = map[layer.DiffID]digest.Digest{ + layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"): digest.Digest("sha256:c5095d6cf7ee42b7b064371dcc1dc3fb4af197f04d01a60009d484bd432724fc"), + 
layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"): digest.Digest("sha256:968cbfe2ff5269ea1729b3804767a1f57ffbc442d3bc86f47edbf7e688a4f36e"), + layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"): digest.Digest("sha256:8a5e56ab4b477a400470a7d5d4c1ca0c91235fd723ab19cc862636a06f3a735d"), + layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"): digest.Digest("sha256:5e733e5cd3688512fc240bd5c178e72671c9915947d17bb8451750d827944cb2"), + layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"): digest.Digest("sha256:ec4bb98d15e554a9f66c3ef9296cf46772c0ded3b1592bd8324d96e2f60f460c"), +} + +func TestSuccessfulUpload(t *testing.T) { + lum := NewLayerUploadManager(maxUploadConcurrency) + + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + receivedProgress := make(map[string]int64) + + go func() { + for p := range progressChan { + receivedProgress[p.ID] = p.Current + } + close(progressDone) + }() + + var currentUploads int32 + descriptors := uploadDescriptors(¤tUploads) + + digests, err := lum.Upload(context.Background(), descriptors, progress.ChanOutput(progressChan)) + if err != nil { + t.Fatalf("upload error: %v", err) + } + + close(progressChan) + <-progressDone + + if len(digests) != len(expectedDigests) { + t.Fatal("wrong number of keys in digests map") + } + + for key, val := range expectedDigests { + if digests[key] != val { + t.Fatalf("mismatch in digest array for key %v (expected %v, got %v)", key, val, digests[key]) + } + if receivedProgress[key.String()] != 10 { + t.Fatalf("missing or wrong progress output for %v", key) + } + } +} + +func TestCancelledUpload(t *testing.T) { + lum := NewLayerUploadManager(maxUploadConcurrency) + + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + + go func() { + for range progressChan { + } + close(progressDone) + }() + + ctx, 
cancel := context.WithCancel(context.Background()) + + go func() { + <-time.After(time.Millisecond) + cancel() + }() + + descriptors := uploadDescriptors(nil) + _, err := lum.Upload(ctx, descriptors, progress.ChanOutput(progressChan)) + if err != context.Canceled { + t.Fatal("expected upload to be cancelled") + } + + close(progressChan) + <-progressDone +} diff --git a/integration-cli/docker_cli_pull_test.go b/integration-cli/docker_cli_pull_test.go index 411acec539..8ea31a7c52 100644 --- a/integration-cli/docker_cli_pull_test.go +++ b/integration-cli/docker_cli_pull_test.go @@ -140,7 +140,7 @@ func (s *DockerHubPullSuite) TestPullAllTagsFromCentralRegistry(c *check.C) { } // TestPullClientDisconnect kills the client during a pull operation and verifies that the operation -// still succesfully completes on the daemon side. +// gets cancelled. // // Ref: docker/docker#15589 func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) { @@ -161,14 +161,8 @@ func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) { err = pullCmd.Process.Kill() c.Assert(err, checker.IsNil) - maxAttempts := 20 - for i := 0; ; i++ { - if _, err := s.CmdWithError("inspect", repoName); err == nil { - break - } - if i >= maxAttempts { - c.Fatal("timeout reached: image was not pulled after client disconnected") - } - time.Sleep(500 * time.Millisecond) + time.Sleep(2 * time.Second) + if _, err := s.CmdWithError("inspect", repoName); err == nil { + c.Fatal("image was pulled after client disconnected") } } diff --git a/pkg/broadcaster/buffered.go b/pkg/broadcaster/buffered.go deleted file mode 100644 index 57f5f97862..0000000000 --- a/pkg/broadcaster/buffered.go +++ /dev/null @@ -1,167 +0,0 @@ -package broadcaster - -import ( - "errors" - "io" - "sync" -) - -// Buffered keeps track of one or more observers watching the progress -// of an operation. For example, if multiple clients are trying to pull an -// image, they share a Buffered struct for the download operation. 
-type Buffered struct { - sync.Mutex - // c is a channel that observers block on, waiting for the operation - // to finish. - c chan struct{} - // cond is a condition variable used to wake up observers when there's - // new data available. - cond *sync.Cond - // history is a buffer of the progress output so far, so a new observer - // can catch up. The history is stored as a slice of separate byte - // slices, so that if the writer is a WriteFlusher, the flushes will - // happen in the right places. - history [][]byte - // wg is a WaitGroup used to wait for all writes to finish on Close - wg sync.WaitGroup - // result is the argument passed to the first call of Close, and - // returned to callers of Wait - result error -} - -// NewBuffered returns an initialized Buffered structure. -func NewBuffered() *Buffered { - b := &Buffered{ - c: make(chan struct{}), - } - b.cond = sync.NewCond(b) - return b -} - -// closed returns true if and only if the broadcaster has been closed -func (broadcaster *Buffered) closed() bool { - select { - case <-broadcaster.c: - return true - default: - return false - } -} - -// receiveWrites runs as a goroutine so that writes don't block the Write -// function. It writes the new data in broadcaster.history each time there's -// activity on the broadcaster.cond condition variable. -func (broadcaster *Buffered) receiveWrites(observer io.Writer) { - n := 0 - - broadcaster.Lock() - - // The condition variable wait is at the end of this loop, so that the - // first iteration will write the history so far. 
- for { - newData := broadcaster.history[n:] - // Make a copy of newData so we can release the lock - sendData := make([][]byte, len(newData), len(newData)) - copy(sendData, newData) - broadcaster.Unlock() - - for len(sendData) > 0 { - _, err := observer.Write(sendData[0]) - if err != nil { - broadcaster.wg.Done() - return - } - n++ - sendData = sendData[1:] - } - - broadcaster.Lock() - - // If we are behind, we need to catch up instead of waiting - // or handling a closure. - if len(broadcaster.history) != n { - continue - } - - // detect closure of the broadcast writer - if broadcaster.closed() { - broadcaster.Unlock() - broadcaster.wg.Done() - return - } - - broadcaster.cond.Wait() - - // Mutex is still locked as the loop continues - } -} - -// Write adds data to the history buffer, and also writes it to all current -// observers. -func (broadcaster *Buffered) Write(p []byte) (n int, err error) { - broadcaster.Lock() - defer broadcaster.Unlock() - - // Is the broadcaster closed? If so, the write should fail. - if broadcaster.closed() { - return 0, errors.New("attempted write to a closed broadcaster.Buffered") - } - - // Add message in p to the history slice - newEntry := make([]byte, len(p), len(p)) - copy(newEntry, p) - broadcaster.history = append(broadcaster.history, newEntry) - - broadcaster.cond.Broadcast() - - return len(p), nil -} - -// Add adds an observer to the broadcaster. The new observer receives the -// data from the history buffer, and also all subsequent data. -func (broadcaster *Buffered) Add(w io.Writer) error { - // The lock is acquired here so that Add can't race with Close - broadcaster.Lock() - defer broadcaster.Unlock() - - if broadcaster.closed() { - return errors.New("attempted to add observer to a closed broadcaster.Buffered") - } - - broadcaster.wg.Add(1) - go broadcaster.receiveWrites(w) - - return nil -} - -// CloseWithError signals to all observers that the operation has finished. 
Its -// argument is a result that should be returned to waiters blocking on Wait. -func (broadcaster *Buffered) CloseWithError(result error) { - broadcaster.Lock() - if broadcaster.closed() { - broadcaster.Unlock() - return - } - broadcaster.result = result - close(broadcaster.c) - broadcaster.cond.Broadcast() - broadcaster.Unlock() - - // Don't return until all writers have caught up. - broadcaster.wg.Wait() -} - -// Close signals to all observers that the operation has finished. It causes -// all calls to Wait to return nil. -func (broadcaster *Buffered) Close() { - broadcaster.CloseWithError(nil) -} - -// Wait blocks until the operation is marked as completed by the Close method, -// and all writer goroutines have completed. It returns the argument that was -// passed to Close. -func (broadcaster *Buffered) Wait() error { - <-broadcaster.c - broadcaster.wg.Wait() - return broadcaster.result -} diff --git a/pkg/ioutils/readers.go b/pkg/ioutils/readers.go index b4544de53c..e73b02bbf1 100644 --- a/pkg/ioutils/readers.go +++ b/pkg/ioutils/readers.go @@ -4,6 +4,8 @@ import ( "crypto/sha256" "encoding/hex" "io" + + "golang.org/x/net/context" ) type readCloserWrapper struct { @@ -81,3 +83,72 @@ func (r *OnEOFReader) runFunc() { r.Fn = nil } } + +// cancelReadCloser wraps an io.ReadCloser with a context for cancelling read +// operations. +type cancelReadCloser struct { + cancel func() + pR *io.PipeReader // Stream to read from + pW *io.PipeWriter +} + +// NewCancelReadCloser creates a wrapper that closes the ReadCloser when the +// context is cancelled. The returned io.ReadCloser must be closed when it is +// no longer needed. 
+func NewCancelReadCloser(ctx context.Context, in io.ReadCloser) io.ReadCloser { + pR, pW := io.Pipe() + + // Create a context used to signal when the pipe is closed + doneCtx, cancel := context.WithCancel(context.Background()) + + p := &cancelReadCloser{ + cancel: cancel, + pR: pR, + pW: pW, + } + + go func() { + _, err := io.Copy(pW, in) + select { + case <-ctx.Done(): + // If the context was closed, p.closeWithError + // was already called. Calling it again would + // change the error that Read returns. + default: + p.closeWithError(err) + } + in.Close() + }() + go func() { + for { + select { + case <-ctx.Done(): + p.closeWithError(ctx.Err()) + case <-doneCtx.Done(): + return + } + } + }() + + return p +} + +// Read wraps the Read method of the pipe that provides data from the wrapped +// ReadCloser. +func (p *cancelReadCloser) Read(buf []byte) (n int, err error) { + return p.pR.Read(buf) +} + +// closeWithError closes the wrapper and its underlying reader. It will +// cause future calls to Read to return err. +func (p *cancelReadCloser) closeWithError(err error) { + p.pW.CloseWithError(err) + p.cancel() +} + +// Close closes the wrapper its underlying reader. It will cause +// future calls to Read to return io.EOF. 
+func (p *cancelReadCloser) Close() error { + p.closeWithError(io.EOF) + return nil +} diff --git a/pkg/ioutils/readers_test.go b/pkg/ioutils/readers_test.go index 5be68cb16c..9abc1054df 100644 --- a/pkg/ioutils/readers_test.go +++ b/pkg/ioutils/readers_test.go @@ -2,8 +2,12 @@ package ioutils import ( "fmt" + "io/ioutil" "strings" "testing" + "time" + + "golang.org/x/net/context" ) // Implement io.Reader @@ -65,3 +69,26 @@ func TestHashData(t *testing.T) { t.Fatalf("Expecting %s, got %s", expected, actual) } } + +type perpetualReader struct{} + +func (p *perpetualReader) Read(buf []byte) (n int, err error) { + for i := 0; i != len(buf); i++ { + buf[i] = 'a' + } + return len(buf), nil +} + +func TestCancelReadCloser(t *testing.T) { + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + cancelReadCloser := NewCancelReadCloser(ctx, ioutil.NopCloser(&perpetualReader{})) + for { + var buf [128]byte + _, err := cancelReadCloser.Read(buf[:]) + if err == context.DeadlineExceeded { + break + } else if err != nil { + t.Fatalf("got unexpected error: %v", err) + } + } +} diff --git a/pkg/progress/progress.go b/pkg/progress/progress.go new file mode 100644 index 0000000000..1f3b34a832 --- /dev/null +++ b/pkg/progress/progress.go @@ -0,0 +1,63 @@ +package progress + +import ( + "fmt" +) + +// Progress represents the progress of a transfer. +type Progress struct { + ID string + + // Progress contains a Message or... + Message string + + // ...progress of an action + Action string + Current int64 + Total int64 + + LastUpdate bool +} + +// Output is an interface for writing progress information. It's +// like a writer for progress, but we don't call it Writer because +// that would be confusing next to ProgressReader (also, because it +// doesn't implement the io.Writer interface). 
+type Output interface { + WriteProgress(Progress) error +} + +type chanOutput chan<- Progress + +func (out chanOutput) WriteProgress(p Progress) error { + out <- p + return nil +} + +// ChanOutput returns a Output that writes progress updates to the +// supplied channel. +func ChanOutput(progressChan chan<- Progress) Output { + return chanOutput(progressChan) +} + +// Update is a convenience function to write a progress update to the channel. +func Update(out Output, id, action string) { + out.WriteProgress(Progress{ID: id, Action: action}) +} + +// Updatef is a convenience function to write a printf-formatted progress update +// to the channel. +func Updatef(out Output, id, format string, a ...interface{}) { + Update(out, id, fmt.Sprintf(format, a...)) +} + +// Message is a convenience function to write a progress message to the channel. +func Message(out Output, id, message string) { + out.WriteProgress(Progress{ID: id, Message: message}) +} + +// Messagef is a convenience function to write a printf-formatted progress +// message to the channel. +func Messagef(out Output, id, format string, a ...interface{}) { + Message(out, id, fmt.Sprintf(format, a...)) +} diff --git a/pkg/progress/progressreader.go b/pkg/progress/progressreader.go new file mode 100644 index 0000000000..c39e2b69fb --- /dev/null +++ b/pkg/progress/progressreader.go @@ -0,0 +1,59 @@ +package progress + +import ( + "io" +) + +// Reader is a Reader with progress bar. +type Reader struct { + in io.ReadCloser // Stream to read from + out Output // Where to send progress bar to + size int64 + current int64 + lastUpdate int64 + id string + action string +} + +// NewProgressReader creates a new ProgressReader. 
+func NewProgressReader(in io.ReadCloser, out Output, size int64, id, action string) *Reader { + return &Reader{ + in: in, + out: out, + size: size, + id: id, + action: action, + } +} + +func (p *Reader) Read(buf []byte) (n int, err error) { + read, err := p.in.Read(buf) + p.current += int64(read) + updateEvery := int64(1024 * 512) //512kB + if p.size > 0 { + // Update progress for every 1% read if 1% < 512kB + if increment := int64(0.01 * float64(p.size)); increment < updateEvery { + updateEvery = increment + } + } + if p.current-p.lastUpdate > updateEvery || err != nil { + p.updateProgress(err != nil && read == 0) + p.lastUpdate = p.current + } + + return read, err +} + +// Close closes the progress reader and its underlying reader. +func (p *Reader) Close() error { + if p.current < p.size { + // print a full progress bar when closing prematurely + p.current = p.size + p.updateProgress(false) + } + return p.in.Close() +} + +func (p *Reader) updateProgress(last bool) { + p.out.WriteProgress(Progress{ID: p.id, Action: p.action, Current: p.current, Total: p.size, LastUpdate: last}) +} diff --git a/pkg/progress/progressreader_test.go b/pkg/progress/progressreader_test.go new file mode 100644 index 0000000000..b14d401561 --- /dev/null +++ b/pkg/progress/progressreader_test.go @@ -0,0 +1,75 @@ +package progress + +import ( + "bytes" + "io" + "io/ioutil" + "testing" +) + +func TestOutputOnPrematureClose(t *testing.T) { + content := []byte("TESTING") + reader := ioutil.NopCloser(bytes.NewReader(content)) + progressChan := make(chan Progress, 10) + + pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read") + + part := make([]byte, 4, 4) + _, err := io.ReadFull(pr, part) + if err != nil { + pr.Close() + t.Fatal(err) + } + +drainLoop: + for { + select { + case <-progressChan: + default: + break drainLoop + } + } + + pr.Close() + + select { + case <-progressChan: + default: + t.Fatalf("Expected some output when closing prematurely") + } 
+} + +func TestCompleteSilently(t *testing.T) { + content := []byte("TESTING") + reader := ioutil.NopCloser(bytes.NewReader(content)) + progressChan := make(chan Progress, 10) + + pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read") + + out, err := ioutil.ReadAll(pr) + if err != nil { + pr.Close() + t.Fatal(err) + } + if string(out) != "TESTING" { + pr.Close() + t.Fatalf("Unexpected output %q from reader", string(out)) + } + +drainLoop: + for { + select { + case <-progressChan: + default: + break drainLoop + } + } + + pr.Close() + + select { + case <-progressChan: + t.Fatalf("Should have closed silently when read is complete") + default: + } +} diff --git a/pkg/progressreader/progressreader.go b/pkg/progressreader/progressreader.go deleted file mode 100644 index f48442b591..0000000000 --- a/pkg/progressreader/progressreader.go +++ /dev/null @@ -1,68 +0,0 @@ -// Package progressreader provides a Reader with a progress bar that can be -// printed out using the streamformatter package. -package progressreader - -import ( - "io" - - "github.com/docker/docker/pkg/jsonmessage" - "github.com/docker/docker/pkg/streamformatter" -) - -// Config contains the configuration for a Reader with progress bar. -type Config struct { - In io.ReadCloser // Stream to read from - Out io.Writer // Where to send progress bar to - Formatter *streamformatter.StreamFormatter - Size int64 - Current int64 - LastUpdate int64 - NewLines bool - ID string - Action string -} - -// New creates a new Config. 
-func New(newReader Config) *Config { - return &newReader -} - -func (config *Config) Read(p []byte) (n int, err error) { - read, err := config.In.Read(p) - config.Current += int64(read) - updateEvery := int64(1024 * 512) //512kB - if config.Size > 0 { - // Update progress for every 1% read if 1% < 512kB - if increment := int64(0.01 * float64(config.Size)); increment < updateEvery { - updateEvery = increment - } - } - if config.Current-config.LastUpdate > updateEvery || err != nil { - updateProgress(config) - config.LastUpdate = config.Current - } - - if err != nil && read == 0 { - updateProgress(config) - if config.NewLines { - config.Out.Write(config.Formatter.FormatStatus("", "")) - } - } - return read, err -} - -// Close closes the reader (Config). -func (config *Config) Close() error { - if config.Current < config.Size { - //print a full progress bar when closing prematurely - config.Current = config.Size - updateProgress(config) - } - return config.In.Close() -} - -func updateProgress(config *Config) { - progress := jsonmessage.JSONProgress{Current: config.Current, Total: config.Size} - fmtMessage := config.Formatter.FormatProgress(config.ID, config.Action, &progress) - config.Out.Write(fmtMessage) -} diff --git a/pkg/progressreader/progressreader_test.go b/pkg/progressreader/progressreader_test.go deleted file mode 100644 index 21d9b0f057..0000000000 --- a/pkg/progressreader/progressreader_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package progressreader - -import ( - "bufio" - "bytes" - "io" - "io/ioutil" - "testing" - - "github.com/docker/docker/pkg/streamformatter" -) - -func TestOutputOnPrematureClose(t *testing.T) { - var outBuf bytes.Buffer - content := []byte("TESTING") - reader := ioutil.NopCloser(bytes.NewReader(content)) - writer := bufio.NewWriter(&outBuf) - - prCfg := Config{ - In: reader, - Out: writer, - Formatter: streamformatter.NewStreamFormatter(), - Size: int64(len(content)), - NewLines: true, - ID: "Test", - Action: "Read", - } - pr := 
New(prCfg) - - part := make([]byte, 4, 4) - _, err := io.ReadFull(pr, part) - if err != nil { - pr.Close() - t.Fatal(err) - } - - if err := writer.Flush(); err != nil { - pr.Close() - t.Fatal(err) - } - - tlen := outBuf.Len() - pr.Close() - if err := writer.Flush(); err != nil { - t.Fatal(err) - } - - if outBuf.Len() == tlen { - t.Fatalf("Expected some output when closing prematurely") - } -} - -func TestCompleteSilently(t *testing.T) { - var outBuf bytes.Buffer - content := []byte("TESTING") - reader := ioutil.NopCloser(bytes.NewReader(content)) - writer := bufio.NewWriter(&outBuf) - - prCfg := Config{ - In: reader, - Out: writer, - Formatter: streamformatter.NewStreamFormatter(), - Size: int64(len(content)), - NewLines: true, - ID: "Test", - Action: "Read", - } - pr := New(prCfg) - - out, err := ioutil.ReadAll(pr) - if err != nil { - pr.Close() - t.Fatal(err) - } - if string(out) != "TESTING" { - pr.Close() - t.Fatalf("Unexpected output %q from reader", string(out)) - } - - if err := writer.Flush(); err != nil { - pr.Close() - t.Fatal(err) - } - - tlen := outBuf.Len() - pr.Close() - if err := writer.Flush(); err != nil { - t.Fatal(err) - } - - if outBuf.Len() > tlen { - t.Fatalf("Should have closed silently when read is complete") - } -} diff --git a/pkg/streamformatter/streamformatter.go b/pkg/streamformatter/streamformatter.go index d3ac39ebdb..b67a53d648 100644 --- a/pkg/streamformatter/streamformatter.go +++ b/pkg/streamformatter/streamformatter.go @@ -7,6 +7,7 @@ import ( "io" "github.com/docker/docker/pkg/jsonmessage" + "github.com/docker/docker/pkg/progress" ) // StreamFormatter formats a stream, optionally using JSON. @@ -92,6 +93,44 @@ func (sf *StreamFormatter) FormatProgress(id, action string, progress *jsonmessa return []byte(action + " " + progress.String() + endl) } +// NewProgressOutput returns a progress.Output object that can be passed to +// progress.NewProgressReader. 
+func (sf *StreamFormatter) NewProgressOutput(out io.Writer, newLines bool) progress.Output { + return &progressOutput{ + sf: sf, + out: out, + newLines: newLines, + } +} + +type progressOutput struct { + sf *StreamFormatter + out io.Writer + newLines bool +} + +// WriteProgress formats progress information from a ProgressReader. +func (out *progressOutput) WriteProgress(prog progress.Progress) error { + var formatted []byte + if prog.Message != "" { + formatted = out.sf.FormatStatus(prog.ID, prog.Message) + } else { + jsonProgress := jsonmessage.JSONProgress{Current: prog.Current, Total: prog.Total} + formatted = out.sf.FormatProgress(prog.ID, prog.Action, &jsonProgress) + } + _, err := out.out.Write(formatted) + if err != nil { + return err + } + + if out.newLines && prog.LastUpdate { + _, err = out.out.Write(out.sf.FormatStatus("", "")) + return err + } + + return nil +} + // StdoutFormatter is a streamFormatter that writes to the standard output. type StdoutFormatter struct { io.Writer diff --git a/registry/session.go b/registry/session.go index 645e5d44b3..5017aeacac 100644 --- a/registry/session.go +++ b/registry/session.go @@ -17,7 +17,6 @@ import ( "net/url" "strconv" "strings" - "time" "github.com/Sirupsen/logrus" "github.com/docker/distribution/reference" @@ -270,7 +269,6 @@ func (r *Session) GetRemoteImageJSON(imgID, registry string) ([]byte, int64, err // GetRemoteImageLayer retrieves an image layer from the registry func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io.ReadCloser, error) { var ( - retries = 5 statusCode = 0 res *http.Response err error @@ -281,14 +279,9 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io if err != nil { return nil, fmt.Errorf("Error while getting from the server: %v", err) } - // TODO(tiborvass): why are we doing retries at this level? 
- // These retries should be generic to both v1 and v2 - for i := 1; i <= retries; i++ { - statusCode = 0 - res, err = r.client.Do(req) - if err == nil { - break - } + statusCode = 0 + res, err = r.client.Do(req) + if err != nil { logrus.Debugf("Error contacting registry %s: %v", registry, err) if res != nil { if res.Body != nil { @@ -296,11 +289,8 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io } statusCode = res.StatusCode } - if i == retries { - return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)", - statusCode, imgID) - } - time.Sleep(time.Duration(i) * 5 * time.Second) + return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)", + statusCode, imgID) } if res.StatusCode != 200 {