diff --git a/api/client/build.go b/api/client/build.go index 8af15b2465..0516f5385f 100644 --- a/api/client/build.go +++ b/api/client/build.go @@ -23,7 +23,7 @@ import ( "github.com/docker/docker/pkg/httputils" "github.com/docker/docker/pkg/jsonmessage" flag "github.com/docker/docker/pkg/mflag" - "github.com/docker/docker/pkg/progressreader" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/ulimit" "github.com/docker/docker/pkg/units" @@ -169,16 +169,9 @@ func (cli *DockerCli) CmdBuild(args ...string) error { context = replaceDockerfileTarWrapper(context, newDockerfile, relDockerfile) // Setup an upload progress bar - // FIXME: ProgressReader shouldn't be this annoying to use - sf := streamformatter.NewStreamFormatter() - var body io.Reader = progressreader.New(progressreader.Config{ - In: context, - Out: cli.out, - Formatter: sf, - NewLines: true, - ID: "", - Action: "Sending build context to Docker daemon", - }) + progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(cli.out, true) + + var body io.Reader = progress.NewProgressReader(context, progressOutput, 0, "", "Sending build context to Docker daemon") var memory int64 if *flMemoryString != "" { @@ -447,17 +440,10 @@ func getContextFromURL(out io.Writer, remoteURL, dockerfileName string) (absCont return "", "", fmt.Errorf("unable to download remote context %s: %v", remoteURL, err) } defer response.Body.Close() + progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(out, true) // Pass the response body through a progress reader. - progReader := &progressreader.Config{ - In: response.Body, - Out: out, - Formatter: streamformatter.NewStreamFormatter(), - Size: response.ContentLength, - NewLines: true, - ID: "", - Action: fmt.Sprintf("Downloading build context from remote url: %s", remoteURL), - } + progReader := progress.NewProgressReader(response.Body, progressOutput, response.ContentLength, "", fmt.Sprintf("Downloading build context from remote url: %s", remoteURL)) return getContextFromReader(progReader, dockerfileName) } diff --git a/api/server/router/local/image.go b/api/server/router/local/image.go index 29d64dff15..2ddc28610d 100644 --- a/api/server/router/local/image.go +++ b/api/server/router/local/image.go @@ -23,7 +23,7 @@ import ( "github.com/docker/docker/pkg/archive" "github.com/docker/docker/pkg/chrootarchive" "github.com/docker/docker/pkg/ioutils" - "github.com/docker/docker/pkg/progressreader" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/ulimit" "github.com/docker/docker/runconfig" @@ -325,7 +325,7 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R sf := streamformatter.NewJSONStreamFormatter() errf := func(err error) error { // Do not write the error in the http output if it's still empty. - // This prevents from writing a 200(OK) when there is an interal error. + // This prevents from writing a 200(OK) when there is an internal error. if !output.Flushed() { return err } @@ -401,23 +401,17 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R remoteURL := r.FormValue("remote") // Currently, only used if context is from a remote url. - // The field `In` is set by DetectContextFromRemoteURL. // Look at code in DetectContextFromRemoteURL for more information. 
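The client-side hunks above all take the same new shape: an io.Writer is adapted into a progress.Output via the stream formatter, and progress.NewProgressReader wraps whatever io.ReadCloser is being transferred, replacing the field-heavy progressreader.Config. A minimal sketch of that wiring outside the CLI; the payload and destination writer are stand-ins, and only the pkg/progress and pkg/streamformatter calls come from the patch:

```go
package main

import (
	"io"
	"io/ioutil"
	"os"
	"strings"

	"github.com/docker/docker/pkg/progress"
	"github.com/docker/docker/pkg/streamformatter"
)

func main() {
	// Stand-in for the build context or remote body being transferred.
	payload := ioutil.NopCloser(strings.NewReader("example payload"))

	// Adapt a plain io.Writer into a progress.Output; the boolean mirrors the
	// old NewLines field and requests a trailing newline after each update.
	progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(os.Stdout, true)

	// Wrap the reader so each Read is reported as progress; size 0 means the
	// total is unknown, so only the byte count is shown.
	body := progress.NewProgressReader(payload, progressOutput, 0, "", "Sending build context to Docker daemon")
	defer body.Close()

	// Draining the reader emits the progress updates as a side effect.
	io.Copy(ioutil.Discard, body)
}
```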
- pReader := &progressreader.Config{ - // TODO: make progressreader streamformatter-agnostic - Out: output, - Formatter: sf, - Size: r.ContentLength, - NewLines: true, - ID: "Downloading context", - Action: remoteURL, + createProgressReader := func(in io.ReadCloser) io.ReadCloser { + progressOutput := sf.NewProgressOutput(output, true) + return progress.NewProgressReader(in, progressOutput, r.ContentLength, "Downloading context", remoteURL) } var ( context builder.ModifiableContext dockerfileName string ) - context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, pReader) + context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, createProgressReader) if err != nil { return errf(err) } diff --git a/builder/dockerfile/internals.go b/builder/dockerfile/internals.go index e4d8540c3f..bf407f278e 100644 --- a/builder/dockerfile/internals.go +++ b/builder/dockerfile/internals.go @@ -29,7 +29,7 @@ import ( "github.com/docker/docker/pkg/httputils" "github.com/docker/docker/pkg/ioutils" "github.com/docker/docker/pkg/jsonmessage" - "github.com/docker/docker/pkg/progressreader" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/pkg/stringutils" @@ -264,17 +264,11 @@ func (b *Builder) download(srcURL string) (fi builder.FileInfo, err error) { return } + stdoutFormatter := b.Stdout.(*streamformatter.StdoutFormatter) + progressOutput := stdoutFormatter.StreamFormatter.NewProgressOutput(stdoutFormatter.Writer, true) + progressReader := progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Downloading") // Download and dump result to tmp file - if _, err = io.Copy(tmpFile, progressreader.New(progressreader.Config{ - In: resp.Body, - // TODO: make progressreader streamformatter agnostic - Out: b.Stdout.(*streamformatter.StdoutFormatter).Writer, - Formatter: b.Stdout.(*streamformatter.StdoutFormatter).StreamFormatter, - Size: resp.ContentLength, - NewLines: true, - ID: "", - Action: "Downloading", - })); err != nil { + if _, err = io.Copy(tmpFile, progressReader); err != nil { tmpFile.Close() return } diff --git a/daemon/daemon.go b/daemon/daemon.go index 441e20c166..9d8a8e6eef 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -34,6 +34,7 @@ import ( "github.com/docker/docker/daemon/network" "github.com/docker/docker/distribution" dmetadata "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" derr "github.com/docker/docker/errors" "github.com/docker/docker/image" "github.com/docker/docker/image/tarexport" @@ -49,7 +50,9 @@ import ( "github.com/docker/docker/pkg/namesgenerator" "github.com/docker/docker/pkg/nat" "github.com/docker/docker/pkg/parsers/filters" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/signal" + "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/pkg/stringutils" "github.com/docker/docker/pkg/sysinfo" @@ -66,6 +69,16 @@ import ( lntypes "github.com/docker/libnetwork/types" "github.com/docker/libtrust" "github.com/opencontainers/runc/libcontainer" + "golang.org/x/net/context" +) + +const ( + // maxDownloadConcurrency is the maximum number of downloads that + // may take place at a time for each pull. + maxDownloadConcurrency = 3 + // maxUploadConcurrency is the maximum number of uploads that + // may take place at a time for each push. 
+ maxUploadConcurrency = 5 ) var ( @@ -126,7 +139,8 @@ type Daemon struct { containers *contStore execCommands *exec.Store tagStore tag.Store - distributionPool *distribution.Pool + downloadManager *xfer.LayerDownloadManager + uploadManager *xfer.LayerUploadManager distributionMetadataStore dmetadata.Store trustKey libtrust.PrivateKey idIndex *truncindex.TruncIndex @@ -738,7 +752,8 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo return nil, err } - distributionPool := distribution.NewPool() + d.downloadManager = xfer.NewLayerDownloadManager(d.layerStore, maxDownloadConcurrency) + d.uploadManager = xfer.NewLayerUploadManager(maxUploadConcurrency) ifs, err := image.NewFSStoreBackend(filepath.Join(imageRoot, "imagedb")) if err != nil { @@ -834,7 +849,6 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo d.containers = &contStore{s: make(map[string]*container.Container)} d.execCommands = exec.NewStore() d.tagStore = tagStore - d.distributionPool = distributionPool d.distributionMetadataStore = distributionMetadataStore d.trustKey = trustKey d.idIndex = truncindex.NewTruncIndex([]string{}) @@ -1038,23 +1052,53 @@ func (daemon *Daemon) TagImage(newTag reference.Named, imageName string) error { return nil } +func writeDistributionProgress(cancelFunc func(), outStream io.Writer, progressChan <-chan progress.Progress) { + progressOutput := streamformatter.NewJSONStreamFormatter().NewProgressOutput(outStream, false) + operationCancelled := false + + for prog := range progressChan { + if err := progressOutput.WriteProgress(prog); err != nil && !operationCancelled { + logrus.Errorf("error writing progress to client: %v", err) + cancelFunc() + operationCancelled = true + // Don't return, because we need to continue draining + // progressChan until it's closed to avoid a deadlock. + } + } +} + // PullImage initiates a pull operation. image is the repository name to pull, and // tag may be either empty, or indicate a specific tag to pull. func (daemon *Daemon) PullImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error { + // Include a buffer so that slow client connections don't affect + // transfer performance. + progressChan := make(chan progress.Progress, 100) + + writesDone := make(chan struct{}) + + ctx, cancelFunc := context.WithCancel(context.Background()) + + go func() { + writeDistributionProgress(cancelFunc, outStream, progressChan) + close(writesDone) + }() + imagePullConfig := &distribution.ImagePullConfig{ MetaHeaders: metaHeaders, AuthConfig: authConfig, - OutStream: outStream, + ProgressOutput: progress.ChanOutput(progressChan), RegistryService: daemon.RegistryService, EventsService: daemon.EventsService, MetadataStore: daemon.distributionMetadataStore, - LayerStore: daemon.layerStore, ImageStore: daemon.imageStore, TagStore: daemon.tagStore, - Pool: daemon.distributionPool, + DownloadManager: daemon.downloadManager, } - return distribution.Pull(ref, imagePullConfig) + err := distribution.Pull(ctx, ref, imagePullConfig) + close(progressChan) + <-writesDone + return err } // ExportImage exports a list of images to the given output stream. The @@ -1069,10 +1113,23 @@ func (daemon *Daemon) ExportImage(names []string, outStream io.Writer) error { // PushImage initiates a push operation on the repository named localName. 
func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error { + // Include a buffer so that slow client connections don't affect + // transfer performance. + progressChan := make(chan progress.Progress, 100) + + writesDone := make(chan struct{}) + + ctx, cancelFunc := context.WithCancel(context.Background()) + + go func() { + writeDistributionProgress(cancelFunc, outStream, progressChan) + close(writesDone) + }() + imagePushConfig := &distribution.ImagePushConfig{ MetaHeaders: metaHeaders, AuthConfig: authConfig, - OutStream: outStream, + ProgressOutput: progress.ChanOutput(progressChan), RegistryService: daemon.RegistryService, EventsService: daemon.EventsService, MetadataStore: daemon.distributionMetadataStore, @@ -1080,9 +1137,13 @@ func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]st ImageStore: daemon.imageStore, TagStore: daemon.tagStore, TrustKey: daemon.trustKey, + UploadManager: daemon.uploadManager, } - return distribution.Push(ref, imagePushConfig) + err := distribution.Push(ctx, ref, imagePushConfig) + close(progressChan) + <-writesDone + return err } // LookupImage looks up an image by name and returns it as an ImageInspect diff --git a/daemon/daemonbuilder/builder.go b/daemon/daemonbuilder/builder.go index 632c484eba..4cd28c1bff 100644 --- a/daemon/daemonbuilder/builder.go +++ b/daemon/daemonbuilder/builder.go @@ -21,7 +21,6 @@ import ( "github.com/docker/docker/pkg/httputils" "github.com/docker/docker/pkg/idtools" "github.com/docker/docker/pkg/ioutils" - "github.com/docker/docker/pkg/progressreader" "github.com/docker/docker/pkg/urlutil" "github.com/docker/docker/registry" "github.com/docker/docker/runconfig" @@ -239,7 +238,7 @@ func (d Docker) Start(c *container.Container) error { // DetectContextFromRemoteURL returns a context and in certain cases the name of the dockerfile to be used // irrespective of user input. // progressReader is only used if remoteURL is actually a URL (not empty, and not a Git endpoint). 
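PullImage and PushImage above now share one lifecycle: the distribution code writes progress.Progress values into a buffered channel (adapted to progress.Output by progress.ChanOutput), a goroutine drains the channel and formats JSON for the client, and a cancellable context lets a dead client connection abort the transfer. A condensed sketch of that lifecycle, written as if it lived in package daemon so it can call writeDistributionProgress; the function name and the trimmed config fields are illustrative, not part of the patch:

```go
// Illustrative only: the shape shared by PullImage and PushImage above.
func pullWithProgress(ref reference.Named, outStream io.Writer, config *distribution.ImagePullConfig) error {
	// Buffered so a slow client connection does not stall layer transfers.
	progressChan := make(chan progress.Progress, 100)
	writesDone := make(chan struct{})

	ctx, cancelFunc := context.WithCancel(context.Background())

	go func() {
		// Drains progressChan and writes JSON progress to the client; on a
		// write error it cancels ctx but keeps draining until the channel is
		// closed, so the producer never blocks.
		writeDistributionProgress(cancelFunc, outStream, progressChan)
		close(writesDone)
	}()

	// The distribution code only ever sees a progress.Output.
	config.ProgressOutput = progress.ChanOutput(progressChan)

	err := distribution.Pull(ctx, ref, config)

	close(progressChan) // no more progress will be produced
	<-writesDone        // wait until every update has reached the client
	return err
}
```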
-func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReader *progressreader.Config) (context builder.ModifiableContext, dockerfileName string, err error) { +func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, createProgressReader func(in io.ReadCloser) io.ReadCloser) (context builder.ModifiableContext, dockerfileName string, err error) { switch { case remoteURL == "": context, err = builder.MakeTarSumContext(r) @@ -262,8 +261,7 @@ func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReade }, // fallback handler (tar context) "": func(rc io.ReadCloser) (io.ReadCloser, error) { - progressReader.In = rc - return progressReader, nil + return createProgressReader(rc), nil }, }) default: diff --git a/daemon/import.go b/daemon/import.go index 749295a738..75010e346d 100644 --- a/daemon/import.go +++ b/daemon/import.go @@ -13,7 +13,7 @@ import ( "github.com/docker/docker/image" "github.com/docker/docker/layer" "github.com/docker/docker/pkg/httputils" - "github.com/docker/docker/pkg/progressreader" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/runconfig" ) @@ -47,16 +47,8 @@ func (daemon *Daemon) ImportImage(src string, newRef reference.Named, msg string if err != nil { return err } - progressReader := progressreader.New(progressreader.Config{ - In: resp.Body, - Out: outStream, - Formatter: sf, - Size: resp.ContentLength, - NewLines: true, - ID: "", - Action: "Importing", - }) - archive = progressReader + progressOutput := sf.NewProgressOutput(outStream, true) + archive = progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Importing") } defer archive.Close() diff --git a/distribution/metadata/v1_id_service.go b/distribution/metadata/v1_id_service.go index 4098f8db83..f6e4589248 100644 --- a/distribution/metadata/v1_id_service.go +++ b/distribution/metadata/v1_id_service.go @@ -23,20 +23,20 @@ func (idserv *V1IDService) namespace() string { } // Get finds a layer by its V1 ID. -func (idserv *V1IDService) Get(v1ID, registry string) (layer.ChainID, error) { +func (idserv *V1IDService) Get(v1ID, registry string) (layer.DiffID, error) { if err := v1.ValidateID(v1ID); err != nil { - return layer.ChainID(""), err + return layer.DiffID(""), err } idBytes, err := idserv.store.Get(idserv.namespace(), registry+","+v1ID) if err != nil { - return layer.ChainID(""), err + return layer.DiffID(""), err } - return layer.ChainID(idBytes), nil + return layer.DiffID(idBytes), nil } // Set associates an image with a V1 ID. 
-func (idserv *V1IDService) Set(v1ID, registry string, id layer.ChainID) error { +func (idserv *V1IDService) Set(v1ID, registry string, id layer.DiffID) error { if err := v1.ValidateID(v1ID); err != nil { return err } diff --git a/distribution/metadata/v1_id_service_test.go b/distribution/metadata/v1_id_service_test.go index bf0f23a6dc..556886581e 100644 --- a/distribution/metadata/v1_id_service_test.go +++ b/distribution/metadata/v1_id_service_test.go @@ -24,22 +24,22 @@ func TestV1IDService(t *testing.T) { testVectors := []struct { registry string v1ID string - layerID layer.ChainID + layerID layer.DiffID }{ { registry: "registry1", v1ID: "f0cd5ca10b07f35512fc2f1cbf9a6cefbdb5cba70ac6b0c9e5988f4497f71937", - layerID: layer.ChainID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"), + layerID: layer.DiffID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"), }, { registry: "registry2", v1ID: "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e", - layerID: layer.ChainID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"), + layerID: layer.DiffID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"), }, { registry: "registry1", v1ID: "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e", - layerID: layer.ChainID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"), + layerID: layer.DiffID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"), }, } diff --git a/distribution/pool.go b/distribution/pool.go deleted file mode 100644 index 8c648f6e8b..0000000000 --- a/distribution/pool.go +++ /dev/null @@ -1,51 +0,0 @@ -package distribution - -import ( - "sync" - - "github.com/docker/docker/pkg/broadcaster" -) - -// A Pool manages concurrent pulls. It deduplicates in-progress downloads. -type Pool struct { - sync.Mutex - pullingPool map[string]*broadcaster.Buffered -} - -// NewPool creates a new Pool. -func NewPool() *Pool { - return &Pool{ - pullingPool: make(map[string]*broadcaster.Buffered), - } -} - -// add checks if a pull is already running, and returns (broadcaster, true) -// if a running operation is found. Otherwise, it creates a new one and returns -// (broadcaster, false). 
-func (pool *Pool) add(key string) (*broadcaster.Buffered, bool) { - pool.Lock() - defer pool.Unlock() - - if p, exists := pool.pullingPool[key]; exists { - return p, true - } - - broadcaster := broadcaster.NewBuffered() - pool.pullingPool[key] = broadcaster - - return broadcaster, false -} - -func (pool *Pool) removeWithError(key string, broadcasterResult error) error { - pool.Lock() - defer pool.Unlock() - if broadcaster, exists := pool.pullingPool[key]; exists { - broadcaster.CloseWithError(broadcasterResult) - delete(pool.pullingPool, key) - } - return nil -} - -func (pool *Pool) remove(key string) error { - return pool.removeWithError(key, nil) -} diff --git a/distribution/pool_test.go b/distribution/pool_test.go deleted file mode 100644 index 80511e8342..0000000000 --- a/distribution/pool_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package distribution - -import ( - "testing" -) - -func TestPools(t *testing.T) { - p := NewPool() - - if _, found := p.add("test1"); found { - t.Fatal("Expected pull test1 not to be in progress") - } - if _, found := p.add("test2"); found { - t.Fatal("Expected pull test2 not to be in progress") - } - if _, found := p.add("test1"); !found { - t.Fatalf("Expected pull test1 to be in progress`") - } - if err := p.remove("test2"); err != nil { - t.Fatal(err) - } - if err := p.remove("test2"); err != nil { - t.Fatal(err) - } - if err := p.remove("test1"); err != nil { - t.Fatal(err) - } -} diff --git a/distribution/pull.go b/distribution/pull.go index 4232ce3ca1..dec47e2112 100644 --- a/distribution/pull.go +++ b/distribution/pull.go @@ -2,7 +2,7 @@ package distribution import ( "fmt" - "io" + "os" "strings" "github.com/Sirupsen/logrus" @@ -10,11 +10,12 @@ import ( "github.com/docker/docker/cliconfig" "github.com/docker/docker/daemon/events" "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/image" - "github.com/docker/docker/layer" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/registry" "github.com/docker/docker/tag" + "golang.org/x/net/context" ) // ImagePullConfig stores pull configuration. @@ -25,9 +26,9 @@ type ImagePullConfig struct { // AuthConfig holds authentication credentials for authenticating with // the registry. AuthConfig *cliconfig.AuthConfig - // OutStream is the output writer for showing the status of the pull + // ProgressOutput is the interface for showing the status of the pull // operation. - OutStream io.Writer + ProgressOutput progress.Output // RegistryService is the registry service to use for TLS configuration // and endpoint lookup. RegistryService *registry.Service @@ -36,14 +37,12 @@ type ImagePullConfig struct { // MetadataStore is the storage backend for distribution-specific // metadata. MetadataStore metadata.Store - // LayerStore manages layers. - LayerStore layer.Store // ImageStore manages images. ImageStore image.Store // TagStore manages tags. TagStore tag.Store - // Pool manages concurrent pulls. - Pool *Pool + // DownloadManager manages concurrent pulls. + DownloadManager *xfer.LayerDownloadManager } // Puller is an interface that abstracts pulling for different API versions. @@ -51,7 +50,7 @@ type Puller interface { // Pull tries to pull the image referenced by `tag` // Pull returns an error if any, as well as a boolean that determines whether to retry Pull on the next configured endpoint. 
// - Pull(ref reference.Named) (fallback bool, err error) + Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) } // newPuller returns a Puller interface that will pull from either a v1 or v2 @@ -59,14 +58,13 @@ type Puller interface { // whether a v1 or v2 puller will be created. The other parameters are passed // through to the underlying puller implementation for use during the actual // pull operation. -func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig, sf *streamformatter.StreamFormatter) (Puller, error) { +func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig) (Puller, error) { switch endpoint.Version { case registry.APIVersion2: return &v2Puller{ blobSumService: metadata.NewBlobSumService(imagePullConfig.MetadataStore), endpoint: endpoint, config: imagePullConfig, - sf: sf, repoInfo: repoInfo, }, nil case registry.APIVersion1: @@ -74,7 +72,6 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, v1IDService: metadata.NewV1IDService(imagePullConfig.MetadataStore), endpoint: endpoint, config: imagePullConfig, - sf: sf, repoInfo: repoInfo, }, nil } @@ -83,9 +80,7 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, // Pull initiates a pull operation. image is the repository name to pull, and // tag may be either empty, or indicate a specific tag to pull. -func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error { - var sf = streamformatter.NewJSONStreamFormatter() - +func Pull(ctx context.Context, ref reference.Named, imagePullConfig *ImagePullConfig) error { // Resolve the Repository name from fqn to RepositoryInfo repoInfo, err := imagePullConfig.RegistryService.ResolveRepository(ref) if err != nil { @@ -120,12 +115,19 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error { for _, endpoint := range endpoints { logrus.Debugf("Trying to pull %s from %s %s", repoInfo.LocalName, endpoint.URL, endpoint.Version) - puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf) + puller, err := newPuller(endpoint, repoInfo, imagePullConfig) if err != nil { errors = append(errors, err.Error()) continue } - if fallback, err := puller.Pull(ref); err != nil { + if fallback, err := puller.Pull(ctx, ref); err != nil { + // Was this pull cancelled? If so, don't try to fall + // back. + select { + case <-ctx.Done(): + fallback = false + default: + } if fallback { if _, ok := err.(registry.ErrNoSupport); !ok { // Because we found an error that's not ErrNoSupport, discard all subsequent ErrNoSupport errors. @@ -165,11 +167,11 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error { // status message indicates that a newer image was downloaded. Otherwise, it // indicates that the image is up to date. requestedTag is the tag the message // will refer to. 
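One easily missed detail in the endpoint loop above: a cancelled context must not trigger fallback to the next endpoint, otherwise an interrupted pull would go on to retry v1 after the client has disconnected. The guard is a non-blocking read of ctx.Done() before the fallback branch; the push path later in the patch repeats it. Excerpted for emphasis:

```go
if fallback, err := puller.Pull(ctx, ref); err != nil {
	// Distinguish "the operation was cancelled" from "this endpoint failed";
	// only the latter may fall back to the next configured endpoint.
	select {
	case <-ctx.Done():
		fallback = false
	default:
	}
	if fallback {
		// record the error and continue with the next endpoint
	}
}
```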
-func writeStatus(requestedTag string, out io.Writer, sf *streamformatter.StreamFormatter, layersDownloaded bool) { +func writeStatus(requestedTag string, out progress.Output, layersDownloaded bool) { if layersDownloaded { - out.Write(sf.FormatStatus("", "Status: Downloaded newer image for %s", requestedTag)) + progress.Message(out, "", "Status: Downloaded newer image for "+requestedTag) } else { - out.Write(sf.FormatStatus("", "Status: Image is up to date for %s", requestedTag)) + progress.Message(out, "", "Status: Image is up to date for "+requestedTag) } } @@ -183,3 +185,16 @@ func validateRepoName(name string) error { } return nil } + +// tmpFileClose creates a closer function for a temporary file that closes the file +// and also deletes it. +func tmpFileCloser(tmpFile *os.File) func() error { + return func() error { + tmpFile.Close() + if err := os.RemoveAll(tmpFile.Name()); err != nil { + logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name()) + } + + return nil + } +} diff --git a/distribution/pull_v1.go b/distribution/pull_v1.go index 5b6a9101b0..9bc229a83a 100644 --- a/distribution/pull_v1.go +++ b/distribution/pull_v1.go @@ -1,43 +1,42 @@ package distribution import ( - "encoding/json" "errors" "fmt" "io" + "io/ioutil" "net" "net/url" "strings" - "sync" "time" "github.com/Sirupsen/logrus" "github.com/docker/distribution/reference" "github.com/docker/distribution/registry/client/transport" "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/image" "github.com/docker/docker/image/v1" "github.com/docker/docker/layer" - "github.com/docker/docker/pkg/archive" - "github.com/docker/docker/pkg/progressreader" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/ioutils" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/registry" + "golang.org/x/net/context" ) type v1Puller struct { v1IDService *metadata.V1IDService endpoint registry.APIEndpoint config *ImagePullConfig - sf *streamformatter.StreamFormatter repoInfo *registry.RepositoryInfo session *registry.Session } -func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) { +func (p *v1Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) { if _, isDigested := ref.(reference.Digested); isDigested { // Allowing fallback, because HTTPS v1 is before HTTP v2 - return true, registry.ErrNoSupport{errors.New("Cannot pull by digest with v1 registry")} + return true, registry.ErrNoSupport{Err: errors.New("Cannot pull by digest with v1 registry")} } tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name) @@ -62,19 +61,17 @@ func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) { logrus.Debugf("Fallback from error: %s", err) return true, err } - if err := p.pullRepository(ref); err != nil { + if err := p.pullRepository(ctx, ref); err != nil { // TODO(dmcgowan): Check if should fallback return false, err } - out := p.config.OutStream - out.Write(p.sf.FormatStatus("", "%s: this image was pulled from a legacy registry. Important: This registry version will not be supported in future versions of docker.", p.repoInfo.CanonicalName.Name())) + progress.Message(p.config.ProgressOutput, "", p.repoInfo.CanonicalName.Name()+": this image was pulled from a legacy registry. 
Important: This registry version will not be supported in future versions of docker.") return false, nil } -func (p *v1Puller) pullRepository(ref reference.Named) error { - out := p.config.OutStream - out.Write(p.sf.FormatStatus("", "Pulling repository %s", p.repoInfo.CanonicalName.Name())) +func (p *v1Puller) pullRepository(ctx context.Context, ref reference.Named) error { + progress.Message(p.config.ProgressOutput, "", "Pulling repository "+p.repoInfo.CanonicalName.Name()) repoData, err := p.session.GetRepositoryData(p.repoInfo.RemoteName) if err != nil { @@ -112,46 +109,18 @@ func (p *v1Puller) pullRepository(ref reference.Named) error { } } - errors := make(chan error) - layerDownloaded := make(chan struct{}) - layersDownloaded := false - var wg sync.WaitGroup for _, imgData := range repoData.ImgList { if isTagged && imgData.Tag != tagged.Tag() { continue } - wg.Add(1) - go func(img *registry.ImgData) { - p.downloadImage(out, repoData, img, layerDownloaded, errors) - wg.Done() - }(imgData) - } - - go func() { - wg.Wait() - close(errors) - }() - - var lastError error -selectLoop: - for { - select { - case err, ok := <-errors: - if !ok { - break selectLoop - } - lastError = err - case <-layerDownloaded: - layersDownloaded = true + err := p.downloadImage(ctx, repoData, imgData, &layersDownloaded) + if err != nil { + return err } } - if lastError != nil { - return lastError - } - localNameRef := p.repoInfo.LocalName if isTagged { localNameRef, err = reference.WithTag(localNameRef, tagged.Tag()) @@ -159,194 +128,143 @@ selectLoop: localNameRef = p.repoInfo.LocalName } } - writeStatus(localNameRef.String(), out, p.sf, layersDownloaded) + writeStatus(localNameRef.String(), p.config.ProgressOutput, layersDownloaded) return nil } -func (p *v1Puller) downloadImage(out io.Writer, repoData *registry.RepositoryData, img *registry.ImgData, layerDownloaded chan struct{}, errors chan error) { +func (p *v1Puller) downloadImage(ctx context.Context, repoData *registry.RepositoryData, img *registry.ImgData, layersDownloaded *bool) error { if img.Tag == "" { logrus.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID) - return + return nil } localNameRef, err := reference.WithTag(p.repoInfo.LocalName, img.Tag) if err != nil { retErr := fmt.Errorf("Image (id: %s) has invalid tag: %s", img.ID, img.Tag) logrus.Debug(retErr.Error()) - errors <- retErr + return retErr } if err := v1.ValidateID(img.ID); err != nil { - errors <- err - return + return err } - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name()), nil)) + progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name()) success := false var lastErr error - var isDownloaded bool for _, ep := range p.repoInfo.Index.Mirrors { ep += "v1/" - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil)) - if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil { + progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep)) + if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil { // Don't report errors when pulling from mirrors. 
logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err) continue } - if isDownloaded { - layerDownloaded <- struct{}{} - } success = true break } if !success { for _, ep := range repoData.Endpoints { - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil)) - if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil { + progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep) + if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil { // It's not ideal that only the last error is returned, it would be better to concatenate the errors. // As the error is also given to the output stream the user will see the error. lastErr = err - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err), nil)) + progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err) continue } - if isDownloaded { - layerDownloaded <- struct{}{} - } success = true break } } if !success { err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName.Name(), lastErr) - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil)) - errors <- err - return + progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), err.Error()) + return err } - out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) + progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Download complete") + return nil } -func (p *v1Puller) pullImage(out io.Writer, v1ID, endpoint string, localNameRef reference.Named) (layersDownloaded bool, err error) { +func (p *v1Puller) pullImage(ctx context.Context, v1ID, endpoint string, localNameRef reference.Named, layersDownloaded *bool) (err error) { var history []string history, err = p.session.GetRemoteHistory(v1ID, endpoint) if err != nil { - return false, err + return err } if len(history) < 1 { - return false, fmt.Errorf("empty history for image %s", v1ID) + return fmt.Errorf("empty history for image %s", v1ID) } - out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pulling dependent layers", nil)) - // FIXME: Try to stream the images? - // FIXME: Launch the getRemoteImage() in goroutines + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pulling dependent layers") var ( - referencedLayers []layer.Layer - parentID layer.ChainID - newHistory []image.History - img *image.V1Image - imgJSON []byte - imgSize int64 + descriptors []xfer.DownloadDescriptor + newHistory []image.History + imgJSON []byte + imgSize int64 ) - defer func() { - for _, l := range referencedLayers { - layer.ReleaseAndLog(p.config.LayerStore, l) - } - }() - - layersDownloaded = false - - // Iterate over layers from top-most to bottom-most, checking if any - // already exist on disk. - var i int - for i = 0; i != len(history); i++ { - v1LayerID := history[i] - // Do we have a mapping for this particular v1 ID on this - // registry? 
- if layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name); err == nil { - // Does the layer actually exist - if l, err := p.config.LayerStore.Get(layerID); err == nil { - for j := i; j >= 0; j-- { - logrus.Debugf("Layer already exists: %s", history[j]) - out.Write(p.sf.FormatProgress(stringid.TruncateID(history[j]), "Already exists", nil)) - } - referencedLayers = append(referencedLayers, l) - parentID = layerID - break - } - } - } - - needsDownload := i - // Iterate over layers, in order from bottom-most to top-most. Download - // config for all layers, and download actual layer data if needed. - for i = len(history) - 1; i >= 0; i-- { + // config for all layers and create descriptors. + for i := len(history) - 1; i >= 0; i-- { v1LayerID := history[i] - imgJSON, imgSize, err = p.downloadLayerConfig(out, v1LayerID, endpoint) + imgJSON, imgSize, err = p.downloadLayerConfig(v1LayerID, endpoint) if err != nil { - return layersDownloaded, err - } - - img = &image.V1Image{} - if err := json.Unmarshal(imgJSON, img); err != nil { - return layersDownloaded, err - } - - if i < needsDownload { - l, err := p.downloadLayer(out, v1LayerID, endpoint, parentID, imgSize, &layersDownloaded) - - // Note: This needs to be done even in the error case to avoid - // stale references to the layer. - if l != nil { - referencedLayers = append(referencedLayers, l) - } - if err != nil { - return layersDownloaded, err - } - - parentID = l.ChainID() + return err } // Create a new-style config from the legacy configs h, err := v1.HistoryFromConfig(imgJSON, false) if err != nil { - return layersDownloaded, err + return err } newHistory = append(newHistory, h) + + layerDescriptor := &v1LayerDescriptor{ + v1LayerID: v1LayerID, + indexName: p.repoInfo.Index.Name, + endpoint: endpoint, + v1IDService: p.v1IDService, + layersDownloaded: layersDownloaded, + layerSize: imgSize, + session: p.session, + } + + descriptors = append(descriptors, layerDescriptor) } rootFS := image.NewRootFS() - l := referencedLayers[len(referencedLayers)-1] - for l != nil { - rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...) 
- l = l.Parent() - } - - config, err := v1.MakeConfigFromV1Config(imgJSON, rootFS, newHistory) + resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput) if err != nil { - return layersDownloaded, err + return err + } + defer release() + + config, err := v1.MakeConfigFromV1Config(imgJSON, &resultRootFS, newHistory) + if err != nil { + return err } imageID, err := p.config.ImageStore.Create(config) if err != nil { - return layersDownloaded, err + return err } if err := p.config.TagStore.AddTag(localNameRef, imageID, true); err != nil { - return layersDownloaded, err + return err } - return layersDownloaded, nil + return nil } -func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) { - out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Pulling metadata", nil)) +func (p *v1Puller) downloadLayerConfig(v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) { + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Pulling metadata") retries := 5 for j := 1; j <= retries; j++ { imgJSON, imgSize, err := p.session.GetRemoteImageJSON(v1LayerID, endpoint) if err != nil && j == retries { - out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling layer metadata", nil)) + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Error pulling layer metadata") return nil, 0, err } else if err != nil { time.Sleep(time.Duration(j) * 500 * time.Millisecond) @@ -360,95 +278,66 @@ func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string return nil, 0, nil } -func (p *v1Puller) downloadLayer(out io.Writer, v1LayerID, endpoint string, parentID layer.ChainID, layerSize int64, layersDownloaded *bool) (l layer.Layer, err error) { - // ensure no two downloads of the same layer happen at the same time - poolKey := "layer:" + v1LayerID - broadcaster, found := p.config.Pool.add(poolKey) - broadcaster.Add(out) - if found { - logrus.Debugf("Image (id: %s) pull is already running, skipping", v1LayerID) - if err = broadcaster.Wait(); err != nil { - return nil, err - } - layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name) - if err != nil { - return nil, err - } - // Does the layer actually exist - l, err := p.config.LayerStore.Get(layerID) - if err != nil { - return nil, err - } - return l, nil - } +type v1LayerDescriptor struct { + v1LayerID string + indexName string + endpoint string + v1IDService *metadata.V1IDService + layersDownloaded *bool + layerSize int64 + session *registry.Session +} - // This must use a closure so it captures the value of err when - // the function returns, not when the 'defer' is evaluated. 
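With that, the hand-rolled download loop is gone from the v1 puller: it now only collects one download descriptor per layer and defers to the shared LayerDownloadManager, which is constructed in daemon.go above with the layer store and a concurrency limit and takes over what the deleted Pool and the per-image goroutines used to do. The calling convention, as used by both pullers; downloadManager, descriptors, progressOutput, imgJSON, newHistory and imageStore are stand-ins for the corresponding p.config fields and locals:

```go
resultRootFS, release, err := downloadManager.Download(ctx, *image.NewRootFS(), descriptors, progressOutput)
if err != nil {
	return err
}
// release must be called once the caller is done with the returned layers;
// both pullers defer it until the image has been created and tagged.
defer release()

// The returned root filesystem lists the DiffIDs of the registered layers and
// feeds directly into image-config creation, exactly as in the hunk above.
config, err := v1.MakeConfigFromV1Config(imgJSON, &resultRootFS, newHistory)
if err != nil {
	return err
}
if _, err := imageStore.Create(config); err != nil {
	return err
}
```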
- defer func() { - p.config.Pool.removeWithError(poolKey, err) - }() +func (ld *v1LayerDescriptor) Key() string { + return "v1:" + ld.v1LayerID +} - retries := 5 - for j := 1; j <= retries; j++ { - // Get the layer - status := "Pulling fs layer" - if j > 1 { - status = fmt.Sprintf("Pulling fs layer [retries: %d]", j) - } - broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), status, nil)) - layerReader, err := p.session.GetRemoteImageLayer(v1LayerID, endpoint, layerSize) +func (ld *v1LayerDescriptor) ID() string { + return stringid.TruncateID(ld.v1LayerID) +} + +func (ld *v1LayerDescriptor) DiffID() (layer.DiffID, error) { + return ld.v1IDService.Get(ld.v1LayerID, ld.indexName) +} + +func (ld *v1LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) { + progress.Update(progressOutput, ld.ID(), "Pulling fs layer") + layerReader, err := ld.session.GetRemoteImageLayer(ld.v1LayerID, ld.endpoint, ld.layerSize) + if err != nil { + progress.Update(progressOutput, ld.ID(), "Error pulling dependent layers") if uerr, ok := err.(*url.Error); ok { err = uerr.Err } - if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries { - time.Sleep(time.Duration(j) * 500 * time.Millisecond) - continue - } else if err != nil { - broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling dependent layers", nil)) - return nil, err + if terr, ok := err.(net.Error); ok && terr.Timeout() { + return nil, 0, err } - *layersDownloaded = true - defer layerReader.Close() + return nil, 0, xfer.DoNotRetry{Err: err} + } + *ld.layersDownloaded = true - reader := progressreader.New(progressreader.Config{ - In: layerReader, - Out: broadcaster, - Formatter: p.sf, - Size: layerSize, - NewLines: false, - ID: stringid.TruncateID(v1LayerID), - Action: "Downloading", - }) - - inflatedLayerData, err := archive.DecompressStream(reader) - if err != nil { - return nil, fmt.Errorf("could not get decompression stream: %v", err) - } - - l, err := p.config.LayerStore.Register(inflatedLayerData, parentID) - if err != nil { - return nil, fmt.Errorf("failed to register layer: %v", err) - } - logrus.Debugf("layer %s registered successfully", l.DiffID()) - - if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries { - time.Sleep(time.Duration(j) * 500 * time.Millisecond) - continue - } else if err != nil { - broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error downloading dependent layers", nil)) - return nil, err - } - - // Cache mapping from this v1 ID to content-addressable layer ID - if err := p.v1IDService.Set(v1LayerID, p.repoInfo.Index.Name, l.ChainID()); err != nil { - return nil, err - } - - broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Download complete", nil)) - broadcaster.Close() - return l, nil + tmpFile, err := ioutil.TempFile("", "GetImageBlob") + if err != nil { + layerReader.Close() + return nil, 0, err } - // not reached - return nil, nil + reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerReader), progressOutput, ld.layerSize, ld.ID(), "Downloading") + defer reader.Close() + + _, err = io.Copy(tmpFile, reader) + if err != nil { + return nil, 0, err + } + + progress.Update(progressOutput, ld.ID(), "Download complete") + + logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name()) + + tmpFile.Seek(0, 0) + return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), ld.layerSize, nil +} + +func (ld *v1LayerDescriptor) 
Registered(diffID layer.DiffID) { + // Cache mapping from this layer's DiffID to the blobsum + ld.v1IDService.Set(ld.v1LayerID, ld.indexName, diffID) } diff --git a/distribution/pull_v2.go b/distribution/pull_v2.go index 0c65d8d744..50244608f1 100644 --- a/distribution/pull_v2.go +++ b/distribution/pull_v2.go @@ -15,13 +15,12 @@ import ( "github.com/docker/distribution/manifest/schema1" "github.com/docker/distribution/reference" "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/image" "github.com/docker/docker/image/v1" "github.com/docker/docker/layer" - "github.com/docker/docker/pkg/archive" - "github.com/docker/docker/pkg/broadcaster" - "github.com/docker/docker/pkg/progressreader" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/ioutils" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/registry" "golang.org/x/net/context" @@ -31,23 +30,19 @@ type v2Puller struct { blobSumService *metadata.BlobSumService endpoint registry.APIEndpoint config *ImagePullConfig - sf *streamformatter.StreamFormatter repoInfo *registry.RepositoryInfo repo distribution.Repository - sessionID string } -func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) { +func (p *v2Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) { // TODO(tiborvass): was ReceiveTimeout p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull") if err != nil { - logrus.Debugf("Error getting v2 registry: %v", err) + logrus.Warnf("Error getting v2 registry: %v", err) return true, err } - p.sessionID = stringid.GenerateRandomID() - - if err := p.pullV2Repository(ref); err != nil { + if err := p.pullV2Repository(ctx, ref); err != nil { if registry.ContinueOnError(err) { logrus.Debugf("Error trying v2 registry: %v", err) return true, err @@ -57,7 +52,7 @@ func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) { return false, nil } -func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) { +func (p *v2Puller) pullV2Repository(ctx context.Context, ref reference.Named) (err error) { var refs []reference.Named taggedName := p.repoInfo.LocalName if tagged, isTagged := ref.(reference.Tagged); isTagged { @@ -73,7 +68,7 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) { } refs = []reference.Named{taggedName} } else { - manSvc, err := p.repo.Manifests(context.Background()) + manSvc, err := p.repo.Manifests(ctx) if err != nil { return err } @@ -98,98 +93,109 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) { for _, pullRef := range refs { // pulledNew is true if either new layers were downloaded OR if existing images were newly tagged // TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus? 
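Both new layer descriptors hand their result to the download manager the same way: the remote blob is streamed through a cancel-aware, progress-reporting reader into a temporary file, the file is rewound, and it is returned wrapped so that closing it also deletes it, via the tmpFileCloser helper added in pull.go. A condensed version of that tail end of Download; layerReader, size and ld are as in the v1 hunk above:

```go
// reader aborts when ctx is cancelled and reports bytes read as
// "Downloading" updates keyed by the truncated layer ID.
reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerReader), progressOutput, size, ld.ID(), "Downloading")
defer reader.Close()

tmpFile, err := ioutil.TempFile("", "GetImageBlob")
if err != nil {
	return nil, 0, err
}

if _, err := io.Copy(tmpFile, reader); err != nil {
	return nil, 0, err
}

progress.Update(progressOutput, ld.ID(), "Download complete")

// Rewind and hand the file to the download manager; closing the returned
// ReadCloser both closes the file and removes it (tmpFileCloser).
tmpFile.Seek(0, 0)
return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), size, nil
```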
- pulledNew, err := p.pullV2Tag(p.config.OutStream, pullRef) + pulledNew, err := p.pullV2Tag(ctx, pullRef) if err != nil { return err } layersDownloaded = layersDownloaded || pulledNew } - writeStatus(taggedName.String(), p.config.OutStream, p.sf, layersDownloaded) + writeStatus(taggedName.String(), p.config.ProgressOutput, layersDownloaded) return nil } -// downloadInfo is used to pass information from download to extractor -type downloadInfo struct { - tmpFile *os.File - digest digest.Digest - layer distribution.ReadSeekCloser - size int64 - err chan error - poolKey string - broadcaster *broadcaster.Buffered +type v2LayerDescriptor struct { + digest digest.Digest + repo distribution.Repository + blobSumService *metadata.BlobSumService } -type errVerification struct{} +func (ld *v2LayerDescriptor) Key() string { + return "v2:" + ld.digest.String() +} -func (errVerification) Error() string { return "verification failed" } +func (ld *v2LayerDescriptor) ID() string { + return stringid.TruncateID(ld.digest.String()) +} -func (p *v2Puller) download(di *downloadInfo) { - logrus.Debugf("pulling blob %q", di.digest) +func (ld *v2LayerDescriptor) DiffID() (layer.DiffID, error) { + return ld.blobSumService.GetDiffID(ld.digest) +} - blobs := p.repo.Blobs(context.Background()) +func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) { + logrus.Debugf("pulling blob %q", ld.digest) - layerDownload, err := blobs.Open(context.Background(), di.digest) + blobs := ld.repo.Blobs(ctx) + + layerDownload, err := blobs.Open(ctx, ld.digest) if err != nil { - logrus.Debugf("Error fetching layer: %v", err) - di.err <- err - return + logrus.Debugf("Error statting layer: %v", err) + if err == distribution.ErrBlobUnknown { + return nil, 0, xfer.DoNotRetry{Err: err} + } + return nil, 0, retryOnError(err) } - defer layerDownload.Close() - di.size, err = layerDownload.Seek(0, os.SEEK_END) + size, err := layerDownload.Seek(0, os.SEEK_END) if err != nil { // Seek failed, perhaps because there was no Content-Length // header. This shouldn't fail the download, because we can // still continue without a progress bar. - di.size = 0 + size = 0 } else { // Restore the seek offset at the beginning of the stream. 
_, err = layerDownload.Seek(0, os.SEEK_SET) if err != nil { - di.err <- err - return + return nil, 0, err } } - verifier, err := digest.NewDigestVerifier(di.digest) + reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerDownload), progressOutput, size, ld.ID(), "Downloading") + defer reader.Close() + + verifier, err := digest.NewDigestVerifier(ld.digest) if err != nil { - di.err <- err - return + return nil, 0, xfer.DoNotRetry{Err: err} } - digestStr := di.digest.String() + tmpFile, err := ioutil.TempFile("", "GetImageBlob") + if err != nil { + return nil, 0, xfer.DoNotRetry{Err: err} + } - reader := progressreader.New(progressreader.Config{ - In: ioutil.NopCloser(io.TeeReader(layerDownload, verifier)), - Out: di.broadcaster, - Formatter: p.sf, - Size: di.size, - NewLines: false, - ID: stringid.TruncateID(digestStr), - Action: "Downloading", - }) - io.Copy(di.tmpFile, reader) + _, err = io.Copy(tmpFile, io.TeeReader(reader, verifier)) + if err != nil { + return nil, 0, retryOnError(err) + } - di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Verifying Checksum", nil)) + progress.Update(progressOutput, ld.ID(), "Verifying Checksum") if !verifier.Verified() { - err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest) + err = fmt.Errorf("filesystem layer verification failed for digest %s", ld.digest) logrus.Error(err) - di.err <- err - return + tmpFile.Close() + if err := os.RemoveAll(tmpFile.Name()); err != nil { + logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name()) + } + + return nil, 0, xfer.DoNotRetry{Err: err} } - di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Download complete", nil)) + progress.Update(progressOutput, ld.ID(), "Download complete") - logrus.Debugf("Downloaded %s to tempfile %s", digestStr, di.tmpFile.Name()) - di.layer = layerDownload + logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name()) - di.err <- nil + tmpFile.Seek(0, 0) + return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), size, nil } -func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated bool, err error) { +func (ld *v2LayerDescriptor) Registered(diffID layer.DiffID) { + // Cache mapping from this layer's DiffID to the blobsum + ld.blobSumService.Add(diffID, ld.digest) +} + +func (p *v2Puller) pullV2Tag(ctx context.Context, ref reference.Named) (tagUpdated bool, err error) { tagOrDigest := "" if tagged, isTagged := ref.(reference.Tagged); isTagged { tagOrDigest = tagged.Tag() @@ -201,7 +207,7 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo logrus.Debugf("Pulling ref from V2 registry: %q", tagOrDigest) - manSvc, err := p.repo.Manifests(context.Background()) + manSvc, err := p.repo.Manifests(ctx) if err != nil { return false, err } @@ -231,33 +237,17 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo return false, err } - out.Write(p.sf.FormatStatus(tagOrDigest, "Pulling from %s", p.repo.Name())) + progress.Message(p.config.ProgressOutput, tagOrDigest, "Pulling from "+p.repo.Name()) - var downloads []*downloadInfo - - defer func() { - for _, d := range downloads { - p.config.Pool.removeWithError(d.poolKey, err) - if d.tmpFile != nil { - d.tmpFile.Close() - if err := os.RemoveAll(d.tmpFile.Name()); err != nil { - logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name()) - } - } - } - }() + var descriptors []xfer.DownloadDescriptor // Image history converted to the new format 
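In the v2 descriptor the digest check is folded into that same copy: the progress reader is teed into a digest verifier while being written to the temp file, so the blob is hashed as it downloads instead of in a second pass, and a verification failure is wrapped in xfer.DoNotRetry because retrying cannot help. Trimmed from the hunk above; retryOnError is the helper the hunk calls, its definition is not part of this excerpt:

```go
verifier, err := digest.NewDigestVerifier(ld.digest)
if err != nil {
	return nil, 0, xfer.DoNotRetry{Err: err} // a malformed digest will never succeed on retry
}

// Every byte copied into the temp file is also fed to the verifier.
if _, err := io.Copy(tmpFile, io.TeeReader(reader, verifier)); err != nil {
	return nil, 0, retryOnError(err) // errors not wrapped in DoNotRetry may be retried by the manager
}

progress.Update(progressOutput, ld.ID(), "Verifying Checksum")
if !verifier.Verified() {
	// The content does not match the manifest digest: clean up and refuse retries.
	tmpFile.Close()
	os.RemoveAll(tmpFile.Name())
	return nil, 0, xfer.DoNotRetry{Err: fmt.Errorf("filesystem layer verification failed for digest %s", ld.digest)}
}
```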
var history []image.History - poolKey := "v2layer:" - notFoundLocally := false - // Note that the order of this loop is in the direction of bottom-most // to top-most, so that the downloads slice gets ordered correctly. for i := len(verifiedManifest.FSLayers) - 1; i >= 0; i-- { blobSum := verifiedManifest.FSLayers[i].BlobSum - poolKey += blobSum.String() var throwAway struct { ThrowAway bool `json:"throwaway,omitempty"` @@ -276,119 +266,22 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo continue } - // Do we have a layer on disk corresponding to the set of - // blobsums up to this point? - if !notFoundLocally { - notFoundLocally = true - diffID, err := p.blobSumService.GetDiffID(blobSum) - if err == nil { - rootFS.Append(diffID) - if l, err := p.config.LayerStore.Get(rootFS.ChainID()); err == nil { - notFoundLocally = false - logrus.Debugf("Layer already exists: %s", blobSum.String()) - out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Already exists", nil)) - defer layer.ReleaseAndLog(p.config.LayerStore, l) - continue - } else { - rootFS.DiffIDs = rootFS.DiffIDs[:len(rootFS.DiffIDs)-1] - } - } + layerDescriptor := &v2LayerDescriptor{ + digest: blobSum, + repo: p.repo, + blobSumService: p.blobSumService, } - out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Pulling fs layer", nil)) - - tmpFile, err := ioutil.TempFile("", "GetImageBlob") - if err != nil { - return false, err - } - - d := &downloadInfo{ - poolKey: poolKey, - digest: blobSum, - tmpFile: tmpFile, - // TODO: seems like this chan buffer solved hanging problem in go1.5, - // this can indicate some deeper problem that somehow we never take - // error from channel in loop below - err: make(chan error, 1), - } - - downloads = append(downloads, d) - - broadcaster, found := p.config.Pool.add(d.poolKey) - broadcaster.Add(out) - d.broadcaster = broadcaster - if found { - d.err <- nil - } else { - go p.download(d) - } + descriptors = append(descriptors, layerDescriptor) } - for _, d := range downloads { - if err := <-d.err; err != nil { - return false, err - } - - if d.layer == nil { - // Wait for a different pull to download and extract - // this layer. 
- err = d.broadcaster.Wait() - if err != nil { - return false, err - } - - diffID, err := p.blobSumService.GetDiffID(d.digest) - if err != nil { - return false, err - } - rootFS.Append(diffID) - - l, err := p.config.LayerStore.Get(rootFS.ChainID()) - if err != nil { - return false, err - } - - defer layer.ReleaseAndLog(p.config.LayerStore, l) - - continue - } - - d.tmpFile.Seek(0, 0) - reader := progressreader.New(progressreader.Config{ - In: d.tmpFile, - Out: d.broadcaster, - Formatter: p.sf, - Size: d.size, - NewLines: false, - ID: stringid.TruncateID(d.digest.String()), - Action: "Extracting", - }) - - inflatedLayerData, err := archive.DecompressStream(reader) - if err != nil { - return false, fmt.Errorf("could not get decompression stream: %v", err) - } - - l, err := p.config.LayerStore.Register(inflatedLayerData, rootFS.ChainID()) - if err != nil { - return false, fmt.Errorf("failed to register layer: %v", err) - } - logrus.Debugf("layer %s registered successfully", l.DiffID()) - rootFS.Append(l.DiffID()) - - // Cache mapping from this layer's DiffID to the blobsum - if err := p.blobSumService.Add(l.DiffID(), d.digest); err != nil { - return false, err - } - - defer layer.ReleaseAndLog(p.config.LayerStore, l) - - d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.digest.String()), "Pull complete", nil)) - d.broadcaster.Close() - tagUpdated = true + resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput) + if err != nil { + return false, err } + defer release() - config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), rootFS, history) + config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), &resultRootFS, history) if err != nil { return false, err } @@ -403,30 +296,24 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo return false, err } - // Check for new tag if no layers downloaded - var oldTagImageID image.ID - if !tagUpdated { - oldTagImageID, err = p.config.TagStore.Get(ref) - if err != nil || oldTagImageID != imageID { - tagUpdated = true - } + if manifestDigest != "" { + progress.Message(p.config.ProgressOutput, "", "Digest: "+manifestDigest.String()) } - if tagUpdated { - if canonical, ok := ref.(reference.Canonical); ok { - if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil { - return false, err - } - } else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil { + oldTagImageID, err := p.config.TagStore.Get(ref) + if err == nil && oldTagImageID == imageID { + return false, nil + } + + if canonical, ok := ref.(reference.Canonical); ok { + if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil { return false, err } + } else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil { + return false, err } - if manifestDigest != "" { - out.Write(p.sf.FormatStatus("", "Digest: %s", manifestDigest)) - } - - return tagUpdated, nil + return true, nil } func verifyManifest(signedManifest *schema1.SignedManifest, ref reference.Reference) (m *schema1.Manifest, err error) { diff --git a/distribution/push.go b/distribution/push.go index ee41c2e1e3..7e6cd6c4d6 100644 --- a/distribution/push.go +++ b/distribution/push.go @@ -12,12 +12,14 @@ import ( "github.com/docker/docker/cliconfig" "github.com/docker/docker/daemon/events" "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" 
"github.com/docker/docker/image" "github.com/docker/docker/layer" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/registry" "github.com/docker/docker/tag" "github.com/docker/libtrust" + "golang.org/x/net/context" ) // ImagePushConfig stores push configuration. @@ -28,9 +30,9 @@ type ImagePushConfig struct { // AuthConfig holds authentication credentials for authenticating with // the registry. AuthConfig *cliconfig.AuthConfig - // OutStream is the output writer for showing the status of the push + // ProgressOutput is the interface for showing the status of the push // operation. - OutStream io.Writer + ProgressOutput progress.Output // RegistryService is the registry service to use for TLS configuration // and endpoint lookup. RegistryService *registry.Service @@ -48,6 +50,8 @@ type ImagePushConfig struct { // TrustKey is the private key for legacy signatures. This is typically // an ephemeral key, since these signatures are no longer verified. TrustKey libtrust.PrivateKey + // UploadManager dispatches uploads. + UploadManager *xfer.LayerUploadManager } // Pusher is an interface that abstracts pushing for different API versions. @@ -56,7 +60,7 @@ type Pusher interface { // Push returns an error if any, as well as a boolean that determines whether to retry Push on the next configured endpoint. // // TODO(tiborvass): have Push() take a reference to repository + tag, so that the pusher itself is repository-agnostic. - Push() (fallback bool, err error) + Push(ctx context.Context) (fallback bool, err error) } const compressionBufSize = 32768 @@ -66,7 +70,7 @@ const compressionBufSize = 32768 // whether a v1 or v2 pusher will be created. The other parameters are passed // through to the underlying pusher implementation for use during the actual // push operation. -func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig, sf *streamformatter.StreamFormatter) (Pusher, error) { +func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig) (Pusher, error) { switch endpoint.Version { case registry.APIVersion2: return &v2Pusher{ @@ -75,8 +79,7 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg endpoint: endpoint, repoInfo: repoInfo, config: imagePushConfig, - sf: sf, - layersPushed: make(map[digest.Digest]bool), + layersPushed: pushMap{layersPushed: make(map[digest.Digest]bool)}, }, nil case registry.APIVersion1: return &v1Pusher{ @@ -85,7 +88,6 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg endpoint: endpoint, repoInfo: repoInfo, config: imagePushConfig, - sf: sf, }, nil } return nil, fmt.Errorf("unknown version %d for registry %s", endpoint.Version, endpoint.URL) @@ -94,11 +96,9 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg // Push initiates a push operation on the repository named localName. // ref is the specific variant of the image to be pushed. // If no tag is provided, all tags will be pushed. -func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error { +func Push(ctx context.Context, ref reference.Named, imagePushConfig *ImagePushConfig) error { // FIXME: Allow to interrupt current push when new push of same image is done. 
- var sf = streamformatter.NewJSONStreamFormatter() - // Resolve the Repository name from fqn to RepositoryInfo repoInfo, err := imagePushConfig.RegistryService.ResolveRepository(ref) if err != nil { @@ -110,7 +110,7 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error { return err } - imagePushConfig.OutStream.Write(sf.FormatStatus("", "The push refers to a repository [%s]", repoInfo.CanonicalName)) + progress.Messagef(imagePushConfig.ProgressOutput, "", "The push refers to a repository [%s]", repoInfo.CanonicalName.String()) associations := imagePushConfig.TagStore.ReferencesByName(repoInfo.LocalName) if len(associations) == 0 { @@ -121,12 +121,20 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error { for _, endpoint := range endpoints { logrus.Debugf("Trying to push %s to %s %s", repoInfo.CanonicalName, endpoint.URL, endpoint.Version) - pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig, sf) + pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig) if err != nil { lastErr = err continue } - if fallback, err := pusher.Push(); err != nil { + if fallback, err := pusher.Push(ctx); err != nil { + // Was this push cancelled? If so, don't try to fall + // back. + select { + case <-ctx.Done(): + fallback = false + default: + } + if fallback { lastErr = err continue diff --git a/distribution/push_v1.go b/distribution/push_v1.go index a4a0de0811..155bbc8647 100644 --- a/distribution/push_v1.go +++ b/distribution/push_v1.go @@ -2,8 +2,6 @@ package distribution import ( "fmt" - "io" - "io/ioutil" "sync" "github.com/Sirupsen/logrus" @@ -15,25 +13,23 @@ import ( "github.com/docker/docker/image/v1" "github.com/docker/docker/layer" "github.com/docker/docker/pkg/ioutils" - "github.com/docker/docker/pkg/progressreader" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/registry" + "golang.org/x/net/context" ) type v1Pusher struct { + ctx context.Context v1IDService *metadata.V1IDService endpoint registry.APIEndpoint ref reference.Named repoInfo *registry.RepositoryInfo config *ImagePushConfig - sf *streamformatter.StreamFormatter session *registry.Session - - out io.Writer } -func (p *v1Pusher) Push() (fallback bool, err error) { +func (p *v1Pusher) Push(ctx context.Context) (fallback bool, err error) { tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name) if err != nil { return false, err @@ -55,7 +51,7 @@ func (p *v1Pusher) Push() (fallback bool, err error) { // TODO(dmcgowan): Check if should fallback return true, err } - if err := p.pushRepository(); err != nil { + if err := p.pushRepository(ctx); err != nil { // TODO(dmcgowan): Check if should fallback return false, err } @@ -306,12 +302,12 @@ func (p *v1Pusher) lookupImageOnEndpoint(wg *sync.WaitGroup, endpoint string, im logrus.Errorf("Error in LookupRemoteImage: %s", err) imagesToPush <- v1ID } else { - p.out.Write(p.sf.FormatStatus("", "Image %s already pushed, skipping", stringid.TruncateID(v1ID))) + progress.Messagef(p.config.ProgressOutput, "", "Image %s already pushed, skipping", stringid.TruncateID(v1ID)) } } } -func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error { +func (p *v1Pusher) pushImageToEndpoint(ctx context.Context, endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error { workerCount := len(imageList) // 
start a maximum of 5 workers to check if images exist on the specified endpoint. if workerCount > 5 { @@ -349,14 +345,14 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag for _, img := range imageList { v1ID := img.V1ID() if _, push := shouldPush[v1ID]; push { - if _, err := p.pushImage(img, endpoint); err != nil { + if _, err := p.pushImage(ctx, img, endpoint); err != nil { // FIXME: Continue on error? return err } } if topImage, isTopImage := img.(*v1TopImage); isTopImage { for _, tag := range tags[topImage.imageID] { - p.out.Write(p.sf.FormatStatus("", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag)) + progress.Messagef(p.config.ProgressOutput, "", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag) if err := p.session.PushRegistryTag(p.repoInfo.RemoteName, v1ID, tag, endpoint); err != nil { return err } @@ -367,8 +363,7 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag } // pushRepository pushes layers that do not already exist on the registry. -func (p *v1Pusher) pushRepository() error { - p.out = ioutils.NewWriteFlusher(p.config.OutStream) +func (p *v1Pusher) pushRepository(ctx context.Context) error { imgList, tags, referencedLayers, err := p.getImageList() defer func() { for _, l := range referencedLayers { @@ -378,7 +373,7 @@ func (p *v1Pusher) pushRepository() error { if err != nil { return err } - p.out.Write(p.sf.FormatStatus("", "Sending image list")) + progress.Message(p.config.ProgressOutput, "", "Sending image list") imageIndex := createImageIndex(imgList, tags) for _, data := range imageIndex { @@ -391,10 +386,10 @@ func (p *v1Pusher) pushRepository() error { if err != nil { return err } - p.out.Write(p.sf.FormatStatus("", "Pushing repository %s", p.repoInfo.CanonicalName)) + progress.Message(p.config.ProgressOutput, "", "Pushing repository "+p.repoInfo.CanonicalName.String()) // push the repository to each of the endpoints only if it does not exist. 
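For reference, a minimal sketch (not part of this patch) of the progress helpers that replace the old StreamFormatter writes throughout this file; the wrapper function and the literal strings here are only examples:

package example

import "github.com/docker/docker/pkg/progress"

// reportExamples shows the helper calls used in place of the old
// out.Write(sf.FormatStatus/FormatProgress(...)) pattern; id would be a
// truncated image or layer ID.
func reportExamples(out progress.Output, id string) {
	progress.Message(out, "", "Sending image list")                     // plain status line
	progress.Messagef(out, "", "Pushing repository %s", "example/repo") // formatted status line
	progress.Update(out, id, "Pushing")                                 // per-item progress update
}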
for _, endpoint := range repoData.Endpoints { - if err := p.pushImageToEndpoint(endpoint, imgList, tags, repoData); err != nil { + if err := p.pushImageToEndpoint(ctx, endpoint, imgList, tags, repoData); err != nil { return err } } @@ -402,11 +397,11 @@ func (p *v1Pusher) pushRepository() error { return err } -func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err error) { +func (p *v1Pusher) pushImage(ctx context.Context, v1Image v1Image, ep string) (checksum string, err error) { v1ID := v1Image.V1ID() jsonRaw := v1Image.Config() - p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pushing", nil)) + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pushing") // General rule is to use ID for graph accesses and compatibilityID for // calls to session.registry() @@ -417,7 +412,7 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e // Send the json if err := p.session.PushImageJSONRegistry(imgData, jsonRaw, ep); err != nil { if err == registry.ErrAlreadyExists { - p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image already pushed, skipping", nil)) + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image already pushed, skipping") return "", nil } return "", err @@ -437,15 +432,8 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e // Send the layer logrus.Debugf("rendered layer for %s of [%d] size", v1ID, size) - reader := progressreader.New(progressreader.Config{ - In: ioutil.NopCloser(arch), - Out: p.out, - Formatter: p.sf, - Size: size, - NewLines: false, - ID: stringid.TruncateID(v1ID), - Action: "Pushing", - }) + reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), p.config.ProgressOutput, size, stringid.TruncateID(v1ID), "Pushing") + defer reader.Close() checksum, checksumPayload, err := p.session.PushImageLayerRegistry(v1ID, reader, ep, jsonRaw) if err != nil { @@ -458,10 +446,10 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e return "", err } - if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.ChainID()); err != nil { + if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.DiffID()); err != nil { logrus.Warnf("Could not set v1 ID mapping: %v", err) } - p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image successfully pushed", nil)) + progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image successfully pushed") return imgData.Checksum, nil } diff --git a/distribution/push_v2.go b/distribution/push_v2.go index f2c23ea0ea..4f725eb60a 100644 --- a/distribution/push_v2.go +++ b/distribution/push_v2.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "io" - "io/ioutil" + "sync" "time" "github.com/Sirupsen/logrus" @@ -15,11 +15,12 @@ import ( "github.com/docker/distribution/manifest/schema1" "github.com/docker/distribution/reference" "github.com/docker/docker/distribution/metadata" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/image" "github.com/docker/docker/image/v1" "github.com/docker/docker/layer" - "github.com/docker/docker/pkg/progressreader" - "github.com/docker/docker/pkg/streamformatter" + "github.com/docker/docker/pkg/ioutils" + "github.com/docker/docker/pkg/progress" "github.com/docker/docker/pkg/stringid" "github.com/docker/docker/registry" "github.com/docker/docker/tag" @@ -32,16 +33,20 @@ type v2Pusher struct { endpoint registry.APIEndpoint repoInfo *registry.RepositoryInfo config *ImagePushConfig - sf 
*streamformatter.StreamFormatter repo distribution.Repository // layersPushed is the set of layers known to exist on the remote side. // This avoids redundant queries when pushing multiple tags that // involve the same layers. + layersPushed pushMap +} + +type pushMap struct { + sync.Mutex layersPushed map[digest.Digest]bool } -func (p *v2Pusher) Push() (fallback bool, err error) { +func (p *v2Pusher) Push(ctx context.Context) (fallback bool, err error) { p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "push", "pull") if err != nil { logrus.Debugf("Error getting v2 registry: %v", err) @@ -75,7 +80,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) { } for _, association := range associations { - if err := p.pushV2Tag(association); err != nil { + if err := p.pushV2Tag(ctx, association); err != nil { return false, err } } @@ -83,7 +88,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) { return false, nil } -func (p *v2Pusher) pushV2Tag(association tag.Association) error { +func (p *v2Pusher) pushV2Tag(ctx context.Context, association tag.Association) error { ref := association.Ref logrus.Debugf("Pushing repository: %s", ref.String()) @@ -92,8 +97,6 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error { return fmt.Errorf("could not find image from tag %s: %v", ref.String(), err) } - out := p.config.OutStream - var l layer.Layer topLayerID := img.RootFS.ChainID() @@ -107,33 +110,41 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error { defer layer.ReleaseAndLog(p.config.LayerStore, l) } - fsLayers := make(map[layer.DiffID]schema1.FSLayer) + var descriptors []xfer.UploadDescriptor // Push empty layer if necessary for _, h := range img.History { if h.EmptyLayer { - dgst, err := p.pushLayerIfNecessary(out, layer.EmptyLayer) - if err != nil { - return err + descriptors = []xfer.UploadDescriptor{ + &v2PushDescriptor{ + layer: layer.EmptyLayer, + blobSumService: p.blobSumService, + repo: p.repo, + layersPushed: &p.layersPushed, + }, } - p.layersPushed[dgst] = true - fsLayers[layer.EmptyLayer.DiffID()] = schema1.FSLayer{BlobSum: dgst} break } } + // Loop bounds condition is to avoid pushing the base layer on Windows. for i := 0; i < len(img.RootFS.DiffIDs); i++ { - dgst, err := p.pushLayerIfNecessary(out, l) - if err != nil { - return err + descriptor := &v2PushDescriptor{ + layer: l, + blobSumService: p.blobSumService, + repo: p.repo, + layersPushed: &p.layersPushed, } - - p.layersPushed[dgst] = true - fsLayers[l.DiffID()] = schema1.FSLayer{BlobSum: dgst} + descriptors = append(descriptors, descriptor) l = l.Parent() } + fsLayers, err := p.config.UploadManager.Upload(ctx, descriptors, p.config.ProgressOutput) + if err != nil { + return err + } + var tag string if tagged, isTagged := ref.(reference.Tagged); isTagged { tag = tagged.Tag() @@ -157,59 +168,124 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error { if tagged, isTagged := ref.(reference.Tagged); isTagged { // NOTE: do not change this format without first changing the trust client // code. This information is used to determine what was pushed and should be signed. 
- out.Write(p.sf.FormatStatus("", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize)) + progress.Messagef(p.config.ProgressOutput, "", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize) } } - manSvc, err := p.repo.Manifests(context.Background()) + manSvc, err := p.repo.Manifests(ctx) if err != nil { return err } return manSvc.Put(signed) } -func (p *v2Pusher) pushLayerIfNecessary(out io.Writer, l layer.Layer) (digest.Digest, error) { - logrus.Debugf("Pushing layer: %s", l.DiffID()) +type v2PushDescriptor struct { + layer layer.Layer + blobSumService *metadata.BlobSumService + repo distribution.Repository + layersPushed *pushMap +} + +func (pd *v2PushDescriptor) Key() string { + return "v2push:" + pd.repo.Name() + " " + pd.layer.DiffID().String() +} + +func (pd *v2PushDescriptor) ID() string { + return stringid.TruncateID(pd.layer.DiffID().String()) +} + +func (pd *v2PushDescriptor) DiffID() layer.DiffID { + return pd.layer.DiffID() +} + +func (pd *v2PushDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) { + diffID := pd.DiffID() + + logrus.Debugf("Pushing layer: %s", diffID) // Do we have any blobsums associated with this layer's DiffID? - possibleBlobsums, err := p.blobSumService.GetBlobSums(l.DiffID()) + possibleBlobsums, err := pd.blobSumService.GetBlobSums(diffID) if err == nil { - dgst, exists, err := p.blobSumAlreadyExists(possibleBlobsums) + dgst, exists, err := blobSumAlreadyExists(ctx, possibleBlobsums, pd.repo, pd.layersPushed) if err != nil { - out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Image push failed", nil)) - return "", err + progress.Update(progressOutput, pd.ID(), "Image push failed") + return "", retryOnError(err) } if exists { - out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Layer already exists", nil)) + progress.Update(progressOutput, pd.ID(), "Layer already exists") return dgst, nil } } // if digest was empty or not saved, or if blob does not exist on the remote repository, // then push the blob. 
- pushDigest, err := p.pushV2Layer(p.repo.Blobs(context.Background()), l) + bs := pd.repo.Blobs(ctx) + + // Send the layer + layerUpload, err := bs.Create(ctx) if err != nil { - return "", err + return "", retryOnError(err) } + defer layerUpload.Close() + + arch, err := pd.layer.TarStream() + if err != nil { + return "", xfer.DoNotRetry{Err: err} + } + + // don't care if this fails; best effort + size, _ := pd.layer.DiffSize() + + reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), progressOutput, size, pd.ID(), "Pushing") + defer reader.Close() + compressedReader := compress(reader) + + digester := digest.Canonical.New() + tee := io.TeeReader(compressedReader, digester.Hash()) + + nn, err := layerUpload.ReadFrom(tee) + compressedReader.Close() + if err != nil { + return "", retryOnError(err) + } + + pushDigest := digester.Digest() + if _, err := layerUpload.Commit(ctx, distribution.Descriptor{Digest: pushDigest}); err != nil { + return "", retryOnError(err) + } + + logrus.Debugf("uploaded layer %s (%s), %d bytes", diffID, pushDigest, nn) + progress.Update(progressOutput, pd.ID(), "Pushed") + // Cache mapping from this layer's DiffID to the blobsum - if err := p.blobSumService.Add(l.DiffID(), pushDigest); err != nil { - return "", err + if err := pd.blobSumService.Add(diffID, pushDigest); err != nil { + return "", xfer.DoNotRetry{Err: err} } + pd.layersPushed.Lock() + pd.layersPushed.layersPushed[pushDigest] = true + pd.layersPushed.Unlock() + return pushDigest, nil } // blobSumAlreadyExists checks if the registry already know about any of the // blobsums passed in the "blobsums" slice. If it finds one that the registry // knows about, it returns the known digest and "true". -func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest, bool, error) { +func blobSumAlreadyExists(ctx context.Context, blobsums []digest.Digest, repo distribution.Repository, layersPushed *pushMap) (digest.Digest, bool, error) { + layersPushed.Lock() for _, dgst := range blobsums { - if p.layersPushed[dgst] { + if layersPushed.layersPushed[dgst] { // it is already known that the push is not needed and // therefore doing a stat is unnecessary + layersPushed.Unlock() return dgst, true, nil } - _, err := p.repo.Blobs(context.Background()).Stat(context.Background(), dgst) + } + layersPushed.Unlock() + + for _, dgst := range blobsums { + _, err := repo.Blobs(ctx).Stat(ctx, dgst) switch err { case nil: return dgst, true, nil @@ -226,7 +302,7 @@ func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest // FSLayer digests. // FIXME: This should be moved to the distribution repo, since it will also // be useful for converting new manifests to the old format. 
-func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]schema1.FSLayer) (*schema1.Manifest, error) { +func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]digest.Digest) (*schema1.Manifest, error) { if len(img.History) == 0 { return nil, errors.New("empty history when trying to create V2 manifest") } @@ -271,7 +347,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif if !present { return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String()) } - dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent)) + dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent)) if err != nil { return nil, err } @@ -294,7 +370,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif reversedIndex := len(img.History) - i - 1 history[reversedIndex].V1Compatibility = string(jsonBytes) - fsLayerList[reversedIndex] = fsLayer + fsLayerList[reversedIndex] = schema1.FSLayer{BlobSum: fsLayer} parent = v1ID } @@ -315,11 +391,11 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String()) } - dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent + " " + string(img.RawJSON()))) + dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent + " " + string(img.RawJSON()))) if err != nil { return nil, err } - fsLayerList[0] = fsLayer + fsLayerList[0] = schema1.FSLayer{BlobSum: fsLayer} // Top-level v1compatibility string should be a modified version of the // image config. @@ -346,66 +422,3 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif History: history, }, nil } - -func rawJSON(value interface{}) *json.RawMessage { - jsonval, err := json.Marshal(value) - if err != nil { - return nil - } - return (*json.RawMessage)(&jsonval) -} - -func (p *v2Pusher) pushV2Layer(bs distribution.BlobService, l layer.Layer) (digest.Digest, error) { - out := p.config.OutStream - displayID := stringid.TruncateID(string(l.DiffID())) - - out.Write(p.sf.FormatProgress(displayID, "Preparing", nil)) - - arch, err := l.TarStream() - if err != nil { - return "", err - } - defer arch.Close() - - // Send the layer - layerUpload, err := bs.Create(context.Background()) - if err != nil { - return "", err - } - defer layerUpload.Close() - - // don't care if this fails; best effort - size, _ := l.DiffSize() - - reader := progressreader.New(progressreader.Config{ - In: ioutil.NopCloser(arch), // we'll take care of close here. 
- Out: out, - Formatter: p.sf, - Size: size, - NewLines: false, - ID: displayID, - Action: "Pushing", - }) - - compressedReader := compress(reader) - - digester := digest.Canonical.New() - tee := io.TeeReader(compressedReader, digester.Hash()) - - out.Write(p.sf.FormatProgress(displayID, "Pushing", nil)) - nn, err := layerUpload.ReadFrom(tee) - compressedReader.Close() - if err != nil { - return "", err - } - - dgst := digester.Digest() - if _, err := layerUpload.Commit(context.Background(), distribution.Descriptor{Digest: dgst}); err != nil { - return "", err - } - - logrus.Debugf("uploaded layer %s (%s), %d bytes", l.DiffID(), dgst, nn) - out.Write(p.sf.FormatProgress(displayID, "Pushed", nil)) - - return dgst, nil -} diff --git a/distribution/push_v2_test.go b/distribution/push_v2_test.go index ab9e6612c6..96a3939607 100644 --- a/distribution/push_v2_test.go +++ b/distribution/push_v2_test.go @@ -116,10 +116,10 @@ func TestCreateV2Manifest(t *testing.T) { t.Fatalf("json decoding failed: %v", err) } - fsLayers := map[layer.DiffID]schema1.FSLayer{ - layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): {BlobSum: digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")}, - layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): {BlobSum: digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa")}, - layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): {BlobSum: digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")}, + fsLayers := map[layer.DiffID]digest.Digest{ + layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"), + layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"), + layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"), } manifest, err := CreateV2Manifest("testrepo", "testtag", img, fsLayers) diff --git a/distribution/registry.go b/distribution/registry.go index ed8b8c20e5..bb5b58a3af 100644 --- a/distribution/registry.go +++ b/distribution/registry.go @@ -13,10 +13,12 @@ import ( "github.com/docker/distribution" "github.com/docker/distribution/digest" "github.com/docker/distribution/manifest/schema1" + "github.com/docker/distribution/registry/api/errcode" "github.com/docker/distribution/registry/client" "github.com/docker/distribution/registry/client/auth" "github.com/docker/distribution/registry/client/transport" "github.com/docker/docker/cliconfig" + "github.com/docker/docker/distribution/xfer" "github.com/docker/docker/registry" "golang.org/x/net/context" ) @@ -59,7 +61,7 @@ func NewV2Repository(repoInfo *registry.RepositoryInfo, endpoint registry.APIEnd authTransport := transport.NewTransport(base, modifiers...) 
pingClient := &http.Client{ Transport: authTransport, - Timeout: 5 * time.Second, + Timeout: 15 * time.Second, } endpointStr := strings.TrimRight(endpoint.URL, "/") + "/v2/" req, err := http.NewRequest("GET", endpointStr, nil) @@ -132,3 +134,23 @@ func (th *existingTokenHandler) AuthorizeRequest(req *http.Request, params map[s req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.token)) return nil } + +// retryOnError wraps the error in xfer.DoNotRetry if we should not retry the +// operation after this error. +func retryOnError(err error) error { + switch v := err.(type) { + case errcode.Errors: + return retryOnError(v[0]) + case errcode.Error: + switch v.Code { + case errcode.ErrorCodeUnauthorized, errcode.ErrorCodeUnsupported, errcode.ErrorCodeDenied: + return xfer.DoNotRetry{Err: err} + } + + } + // let's be nice and fallback if the error is a completely + // unexpected one. + // If new errors have to be handled in some way, please + // add them to the switch above. + return err +} diff --git a/distribution/registry_unit_test.go b/distribution/registry_unit_test.go index bf13934622..77d810e25b 100644 --- a/distribution/registry_unit_test.go +++ b/distribution/registry_unit_test.go @@ -11,9 +11,9 @@ import ( "github.com/docker/distribution/reference" "github.com/docker/distribution/registry/client/auth" "github.com/docker/docker/cliconfig" - "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/registry" "github.com/docker/docker/utils" + "golang.org/x/net/context" ) func TestTokenPassThru(t *testing.T) { @@ -72,8 +72,7 @@ func TestTokenPassThru(t *testing.T) { MetaHeaders: http.Header{}, AuthConfig: authConfig, } - sf := streamformatter.NewJSONStreamFormatter() - puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf) + puller, err := newPuller(endpoint, repoInfo, imagePullConfig) if err != nil { t.Fatal(err) } @@ -86,7 +85,7 @@ func TestTokenPassThru(t *testing.T) { logrus.Debug("About to pull") // We expect it to fail, since we haven't mock'd the full registry exchange in our handler above tag, _ := reference.WithTag(n, "tag_goes_here") - _ = p.pullV2Repository(tag) + _ = p.pullV2Repository(context.Background(), tag) if !gotToken { t.Fatal("Failed to receive registry token") diff --git a/distribution/xfer/download.go b/distribution/xfer/download.go new file mode 100644 index 0000000000..69c8bad031 --- /dev/null +++ b/distribution/xfer/download.go @@ -0,0 +1,420 @@ +package xfer + +import ( + "errors" + "fmt" + "io" + "time" + + "github.com/Sirupsen/logrus" + "github.com/docker/docker/image" + "github.com/docker/docker/layer" + "github.com/docker/docker/pkg/archive" + "github.com/docker/docker/pkg/ioutils" + "github.com/docker/docker/pkg/progress" + "golang.org/x/net/context" +) + +const maxDownloadAttempts = 5 + +// LayerDownloadManager figures out which layers need to be downloaded, then +// registers and downloads those, taking into account dependencies between +// layers. +type LayerDownloadManager struct { + layerStore layer.Store + tm TransferManager +} + +// NewLayerDownloadManager returns a new LayerDownloadManager. +func NewLayerDownloadManager(layerStore layer.Store, concurrencyLimit int) *LayerDownloadManager { + return &LayerDownloadManager{ + layerStore: layerStore, + tm: NewTransferManager(concurrencyLimit), + } +} + +type downloadTransfer struct { + Transfer + + layerStore layer.Store + layer layer.Layer + err error +} + +// result returns the layer resulting from the download, if the download +// and registration were successful. 
+func (d *downloadTransfer) result() (layer.Layer, error) {
+	return d.layer, d.err
+}
+
+// A DownloadDescriptor references a layer that may need to be downloaded.
+type DownloadDescriptor interface {
+	// Key returns the key used to deduplicate downloads.
+	Key() string
+	// ID returns the ID for display purposes.
+	ID() string
+	// DiffID should return the DiffID for this layer, or an error
+	// if it is unknown (for example, if it has not been downloaded
+	// before).
+	DiffID() (layer.DiffID, error)
+	// Download is called to perform the download.
+	Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error)
+}
+
+// DownloadDescriptorWithRegistered is a DownloadDescriptor that has an
+// additional Registered method which gets called after a downloaded layer is
+// registered. This allows the user of the download manager to know the DiffID
+// of each registered layer. This method is called if a cast to
+// DownloadDescriptorWithRegistered is successful.
+type DownloadDescriptorWithRegistered interface {
+	DownloadDescriptor
+	Registered(diffID layer.DiffID)
+}
+
+// Download is a blocking function which ensures the requested layers are
+// present in the layer store. It uses the string returned by the Key method to
+// deduplicate downloads. If a given layer is not already known to be present in
+// the layer store, and the key is not used by an in-progress download, the
+// Download method is called to get the layer tar data. Layers are then
+// registered in the appropriate order. The caller must call the returned
+// release function once it is done with the returned RootFS object.
+func (ldm *LayerDownloadManager) Download(ctx context.Context, initialRootFS image.RootFS, layers []DownloadDescriptor, progressOutput progress.Output) (image.RootFS, func(), error) {
+	var (
+		topLayer       layer.Layer
+		topDownload    *downloadTransfer
+		watcher        *Watcher
+		missingLayer   bool
+		transferKey    = ""
+		downloadsByKey = make(map[string]*downloadTransfer)
+	)
+
+	rootFS := initialRootFS
+	for _, descriptor := range layers {
+		key := descriptor.Key()
+		transferKey += key
+
+		if !missingLayer {
+			missingLayer = true
+			diffID, err := descriptor.DiffID()
+			if err == nil {
+				getRootFS := rootFS
+				getRootFS.Append(diffID)
+				l, err := ldm.layerStore.Get(getRootFS.ChainID())
+				if err == nil {
+					// Layer already exists.
+					logrus.Debugf("Layer already exists: %s", descriptor.ID())
+					progress.Update(progressOutput, descriptor.ID(), "Already exists")
+					if topLayer != nil {
+						layer.ReleaseAndLog(ldm.layerStore, topLayer)
+					}
+					topLayer = l
+					missingLayer = false
+					rootFS.Append(diffID)
+					continue
+				}
+			}
+		}
+
+		// Does this layer have the same data as a previous layer in
+		// the stack? If so, avoid downloading it more than once.
+		var topDownloadUncasted Transfer
+		if existingDownload, ok := downloadsByKey[key]; ok {
+			xferFunc := ldm.makeDownloadFuncFromDownload(descriptor, existingDownload, topDownload)
+			defer topDownload.Transfer.Release(watcher)
+			topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
+			topDownload = topDownloadUncasted.(*downloadTransfer)
+			continue
+		}
+
+		// Layer is not known to exist - download and register it.
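For illustration, a sketch (not part of this patch) of what a DownloadDescriptor implementation and a Download call might look like; the urlDescriptor type and its HTTP fetch are hypothetical, only the interface and method signatures come from the code above:

package example

import (
	"errors"
	"io"
	"net/http"

	"github.com/docker/docker/distribution/xfer"
	"github.com/docker/docker/image"
	"github.com/docker/docker/layer"
	"github.com/docker/docker/pkg/progress"
	"golang.org/x/net/context"
)

// urlDescriptor is a hypothetical descriptor that fetches a layer over HTTP.
type urlDescriptor struct {
	url    string
	diffID layer.DiffID // empty until the layer has been registered once
}

func (d *urlDescriptor) Key() string { return "url:" + d.url }
func (d *urlDescriptor) ID() string  { return d.url }

// DiffID is only known after a previous download registered this layer.
func (d *urlDescriptor) DiffID() (layer.DiffID, error) {
	if d.diffID == "" {
		return "", errors.New("diffID not known yet")
	}
	return d.diffID, nil
}

// Download returns the compressed layer data and its length, if known.
func (d *urlDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
	resp, err := http.Get(d.url) // placeholder; a real implementation would honor ctx
	if err != nil {
		return nil, 0, err
	}
	return resp.Body, resp.ContentLength, nil
}

// Registered satisfies DownloadDescriptorWithRegistered so the DiffID assigned
// during registration can be remembered for later lookups.
func (d *urlDescriptor) Registered(diffID layer.DiffID) { d.diffID = diffID }

// Caller side: Download blocks until every layer is in the store, and the
// returned release func must be called once the RootFS is no longer needed.
func pullLayers(ctx context.Context, ldm *xfer.LayerDownloadManager, urls []string, out progress.Output) (image.RootFS, func(), error) {
	var descriptors []xfer.DownloadDescriptor
	for _, u := range urls {
		descriptors = append(descriptors, &urlDescriptor{url: u})
	}
	return ldm.Download(ctx, *image.NewRootFS(), descriptors, out)
}

Because urlDescriptor also implements Registered, the download manager reports back the DiffID of each registered layer, which is what makes the "Already exists" short-cut possible on a later Download call.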
+ progress.Update(progressOutput, descriptor.ID(), "Pulling fs layer") + + var xferFunc DoFunc + if topDownload != nil { + xferFunc = ldm.makeDownloadFunc(descriptor, "", topDownload) + defer topDownload.Transfer.Release(watcher) + } else { + xferFunc = ldm.makeDownloadFunc(descriptor, rootFS.ChainID(), nil) + } + topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput) + topDownload = topDownloadUncasted.(*downloadTransfer) + downloadsByKey[key] = topDownload + } + + if topDownload == nil { + return rootFS, func() { layer.ReleaseAndLog(ldm.layerStore, topLayer) }, nil + } + + // Won't be using the list built up so far - will generate it + // from downloaded layers instead. + rootFS.DiffIDs = []layer.DiffID{} + + defer func() { + if topLayer != nil { + layer.ReleaseAndLog(ldm.layerStore, topLayer) + } + }() + + select { + case <-ctx.Done(): + topDownload.Transfer.Release(watcher) + return rootFS, func() {}, ctx.Err() + case <-topDownload.Done(): + break + } + + l, err := topDownload.result() + if err != nil { + topDownload.Transfer.Release(watcher) + return rootFS, func() {}, err + } + + // Must do this exactly len(layers) times, so we don't include the + // base layer on Windows. + for range layers { + if l == nil { + topDownload.Transfer.Release(watcher) + return rootFS, func() {}, errors.New("internal error: too few parent layers") + } + rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...) + l = l.Parent() + } + return rootFS, func() { topDownload.Transfer.Release(watcher) }, err +} + +// makeDownloadFunc returns a function that performs the layer download and +// registration. If parentDownload is non-nil, it waits for that download to +// complete before the registration step, and registers the downloaded data +// on top of parentDownload's resulting layer. Otherwise, it registers the +// layer on top of the ChainID given by parentLayer. +func (ldm *LayerDownloadManager) makeDownloadFunc(descriptor DownloadDescriptor, parentLayer layer.ChainID, parentDownload *downloadTransfer) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + d := &downloadTransfer{ + Transfer: NewTransfer(), + layerStore: ldm.layerStore, + } + + go func() { + defer func() { + close(progressChan) + }() + + progressOutput := progress.ChanOutput(progressChan) + + select { + case <-start: + default: + progress.Update(progressOutput, descriptor.ID(), "Waiting") + <-start + } + + if parentDownload != nil { + // Did the parent download already fail or get + // cancelled? + select { + case <-parentDownload.Done(): + _, err := parentDownload.result() + if err != nil { + d.err = err + return + } + default: + } + } + + var ( + downloadReader io.ReadCloser + size int64 + err error + retries int + ) + + for { + downloadReader, size, err = descriptor.Download(d.Transfer.Context(), progressOutput) + if err == nil { + break + } + + // If an error was returned because the context + // was cancelled, we shouldn't retry. 
+ select { + case <-d.Transfer.Context().Done(): + d.err = err + return + default: + } + + retries++ + if _, isDNR := err.(DoNotRetry); isDNR || retries == maxDownloadAttempts { + logrus.Errorf("Download failed: %v", err) + d.err = err + return + } + + logrus.Errorf("Download failed, retrying: %v", err) + delay := retries * 5 + ticker := time.NewTicker(time.Second) + + selectLoop: + for { + progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay) + select { + case <-ticker.C: + delay-- + if delay == 0 { + ticker.Stop() + break selectLoop + } + case <-d.Transfer.Context().Done(): + ticker.Stop() + d.err = errors.New("download cancelled during retry delay") + return + } + + } + } + + close(inactive) + + if parentDownload != nil { + select { + case <-d.Transfer.Context().Done(): + d.err = errors.New("layer registration cancelled") + downloadReader.Close() + return + case <-parentDownload.Done(): + } + + l, err := parentDownload.result() + if err != nil { + d.err = err + downloadReader.Close() + return + } + parentLayer = l.ChainID() + } + + reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(d.Transfer.Context(), downloadReader), progressOutput, size, descriptor.ID(), "Extracting") + defer reader.Close() + + inflatedLayerData, err := archive.DecompressStream(reader) + if err != nil { + d.err = fmt.Errorf("could not get decompression stream: %v", err) + return + } + + d.layer, err = d.layerStore.Register(inflatedLayerData, parentLayer) + if err != nil { + select { + case <-d.Transfer.Context().Done(): + d.err = errors.New("layer registration cancelled") + default: + d.err = fmt.Errorf("failed to register layer: %v", err) + } + return + } + + progress.Update(progressOutput, descriptor.ID(), "Pull complete") + withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered) + if hasRegistered { + withRegistered.Registered(d.layer.DiffID()) + } + + // Doesn't actually need to be its own goroutine, but + // done like this so we can defer close(c). + go func() { + <-d.Transfer.Released() + if d.layer != nil { + layer.ReleaseAndLog(d.layerStore, d.layer) + } + }() + }() + + return d + } +} + +// makeDownloadFuncFromDownload returns a function that performs the layer +// registration when the layer data is coming from an existing download. It +// waits for sourceDownload and parentDownload to complete, and then +// reregisters the data from sourceDownload's top layer on top of +// parentDownload. This function does not log progress output because it would +// interfere with the progress reporting for sourceDownload, which has the same +// Key. +func (ldm *LayerDownloadManager) makeDownloadFuncFromDownload(descriptor DownloadDescriptor, sourceDownload *downloadTransfer, parentDownload *downloadTransfer) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + d := &downloadTransfer{ + Transfer: NewTransfer(), + layerStore: ldm.layerStore, + } + + go func() { + defer func() { + close(progressChan) + }() + + <-start + + close(inactive) + + select { + case <-d.Transfer.Context().Done(): + d.err = errors.New("layer registration cancelled") + return + case <-parentDownload.Done(): + } + + l, err := parentDownload.result() + if err != nil { + d.err = err + return + } + parentLayer := l.ChainID() + + // sourceDownload should have already finished if + // parentDownload finished, but wait for it explicitly + // to be sure. 
+ select { + case <-d.Transfer.Context().Done(): + d.err = errors.New("layer registration cancelled") + return + case <-sourceDownload.Done(): + } + + l, err = sourceDownload.result() + if err != nil { + d.err = err + return + } + + layerReader, err := l.TarStream() + if err != nil { + d.err = err + return + } + defer layerReader.Close() + + d.layer, err = d.layerStore.Register(layerReader, parentLayer) + if err != nil { + d.err = fmt.Errorf("failed to register layer: %v", err) + return + } + + withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered) + if hasRegistered { + withRegistered.Registered(d.layer.DiffID()) + } + + // Doesn't actually need to be its own goroutine, but + // done like this so we can defer close(c). + go func() { + <-d.Transfer.Released() + if d.layer != nil { + layer.ReleaseAndLog(d.layerStore, d.layer) + } + }() + }() + + return d + } +} diff --git a/distribution/xfer/download_test.go b/distribution/xfer/download_test.go new file mode 100644 index 0000000000..ff665df344 --- /dev/null +++ b/distribution/xfer/download_test.go @@ -0,0 +1,332 @@ +package xfer + +import ( + "bytes" + "errors" + "io" + "io/ioutil" + "sync/atomic" + "testing" + "time" + + "github.com/docker/distribution/digest" + "github.com/docker/docker/image" + "github.com/docker/docker/layer" + "github.com/docker/docker/pkg/archive" + "github.com/docker/docker/pkg/progress" + "golang.org/x/net/context" +) + +const maxDownloadConcurrency = 3 + +type mockLayer struct { + layerData bytes.Buffer + diffID layer.DiffID + chainID layer.ChainID + parent layer.Layer +} + +func (ml *mockLayer) TarStream() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewBuffer(ml.layerData.Bytes())), nil +} + +func (ml *mockLayer) ChainID() layer.ChainID { + return ml.chainID +} + +func (ml *mockLayer) DiffID() layer.DiffID { + return ml.diffID +} + +func (ml *mockLayer) Parent() layer.Layer { + return ml.parent +} + +func (ml *mockLayer) Size() (size int64, err error) { + return 0, nil +} + +func (ml *mockLayer) DiffSize() (size int64, err error) { + return 0, nil +} + +func (ml *mockLayer) Metadata() (map[string]string, error) { + return make(map[string]string), nil +} + +type mockLayerStore struct { + layers map[layer.ChainID]*mockLayer +} + +func createChainIDFromParent(parent layer.ChainID, dgsts ...layer.DiffID) layer.ChainID { + if len(dgsts) == 0 { + return parent + } + if parent == "" { + return createChainIDFromParent(layer.ChainID(dgsts[0]), dgsts[1:]...) + } + // H = "H(n-1) SHA256(n)" + dgst, err := digest.FromBytes([]byte(string(parent) + " " + string(dgsts[0]))) + if err != nil { + // Digest calculation is not expected to throw an error, + // any error at this point is a program error + panic(err) + } + return createChainIDFromParent(layer.ChainID(dgst), dgsts[1:]...) 
+} + +func (ls *mockLayerStore) Register(reader io.Reader, parentID layer.ChainID) (layer.Layer, error) { + var ( + parent layer.Layer + err error + ) + + if parentID != "" { + parent, err = ls.Get(parentID) + if err != nil { + return nil, err + } + } + + l := &mockLayer{parent: parent} + _, err = l.layerData.ReadFrom(reader) + if err != nil { + return nil, err + } + diffID, err := digest.FromBytes(l.layerData.Bytes()) + if err != nil { + return nil, err + } + l.diffID = layer.DiffID(diffID) + l.chainID = createChainIDFromParent(parentID, l.diffID) + + ls.layers[l.chainID] = l + return l, nil +} + +func (ls *mockLayerStore) Get(chainID layer.ChainID) (layer.Layer, error) { + l, ok := ls.layers[chainID] + if !ok { + return nil, layer.ErrLayerDoesNotExist + } + return l, nil +} + +func (ls *mockLayerStore) Release(l layer.Layer) ([]layer.Metadata, error) { + return []layer.Metadata{}, nil +} + +func (ls *mockLayerStore) Mount(id string, parent layer.ChainID, label string, init layer.MountInit) (layer.RWLayer, error) { + return nil, errors.New("not implemented") +} + +func (ls *mockLayerStore) Unmount(id string) error { + return errors.New("not implemented") +} + +func (ls *mockLayerStore) DeleteMount(id string) ([]layer.Metadata, error) { + return nil, errors.New("not implemented") +} + +func (ls *mockLayerStore) Changes(id string) ([]archive.Change, error) { + return nil, errors.New("not implemented") +} + +type mockDownloadDescriptor struct { + currentDownloads *int32 + id string + diffID layer.DiffID + registeredDiffID layer.DiffID + expectedDiffID layer.DiffID + simulateRetries int +} + +// Key returns the key used to deduplicate downloads. +func (d *mockDownloadDescriptor) Key() string { + return d.id +} + +// ID returns the ID for display purposes. +func (d *mockDownloadDescriptor) ID() string { + return d.id +} + +// DiffID should return the DiffID for this layer, or an error +// if it is unknown (for example, if it has not been downloaded +// before). +func (d *mockDownloadDescriptor) DiffID() (layer.DiffID, error) { + if d.diffID != "" { + return d.diffID, nil + } + return "", errors.New("no diffID available") +} + +func (d *mockDownloadDescriptor) Registered(diffID layer.DiffID) { + d.registeredDiffID = diffID +} + +func (d *mockDownloadDescriptor) mockTarStream() io.ReadCloser { + // The mock implementation returns the ID repeated 5 times as a tar + // stream instead of actual tar data. The data is ignored except for + // computing IDs. + return ioutil.NopCloser(bytes.NewBuffer([]byte(d.id + d.id + d.id + d.id + d.id))) +} + +// Download is called to perform the download. +func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) { + if d.currentDownloads != nil { + defer atomic.AddInt32(d.currentDownloads, -1) + + if atomic.AddInt32(d.currentDownloads, 1) > maxDownloadConcurrency { + return nil, 0, errors.New("concurrency limit exceeded") + } + } + + // Sleep a bit to simulate a time-consuming download. 
+ for i := int64(0); i <= 10; i++ { + select { + case <-ctx.Done(): + return nil, 0, ctx.Err() + case <-time.After(10 * time.Millisecond): + progressOutput.WriteProgress(progress.Progress{ID: d.ID(), Action: "Downloading", Current: i, Total: 10}) + } + } + + if d.simulateRetries != 0 { + d.simulateRetries-- + return nil, 0, errors.New("simulating retry") + } + + return d.mockTarStream(), 0, nil +} + +func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor { + return []DownloadDescriptor{ + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id1", + expectedDiffID: layer.DiffID("sha256:68e2c75dc5c78ea9240689c60d7599766c213ae210434c53af18470ae8c53ec1"), + }, + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id2", + expectedDiffID: layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"), + }, + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id3", + expectedDiffID: layer.DiffID("sha256:58745a8bbd669c25213e9de578c4da5c8ee1c836b3581432c2b50e38a6753300"), + }, + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id2", + expectedDiffID: layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"), + }, + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id4", + expectedDiffID: layer.DiffID("sha256:0dfb5b9577716cc173e95af7c10289322c29a6453a1718addc00c0c5b1330936"), + simulateRetries: 1, + }, + &mockDownloadDescriptor{ + currentDownloads: currentDownloads, + id: "id5", + expectedDiffID: layer.DiffID("sha256:0a5f25fa1acbc647f6112a6276735d0fa01e4ee2aa7ec33015e337350e1ea23d"), + }, + } +} + +func TestSuccessfulDownload(t *testing.T) { + layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)} + ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency) + + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + receivedProgress := make(map[string]int64) + + go func() { + for p := range progressChan { + if p.Action == "Downloading" { + receivedProgress[p.ID] = p.Current + } else if p.Action == "Already exists" { + receivedProgress[p.ID] = -1 + } + } + close(progressDone) + }() + + var currentDownloads int32 + descriptors := downloadDescriptors(¤tDownloads) + + firstDescriptor := descriptors[0].(*mockDownloadDescriptor) + + // Pre-register the first layer to simulate an already-existing layer + l, err := layerStore.Register(firstDescriptor.mockTarStream(), "") + if err != nil { + t.Fatal(err) + } + firstDescriptor.diffID = l.DiffID() + + rootFS, releaseFunc, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan)) + if err != nil { + t.Fatalf("download error: %v", err) + } + + releaseFunc() + + close(progressChan) + <-progressDone + + if len(rootFS.DiffIDs) != len(descriptors) { + t.Fatal("got wrong number of diffIDs in rootfs") + } + + for i, d := range descriptors { + descriptor := d.(*mockDownloadDescriptor) + + if descriptor.diffID != "" { + if receivedProgress[d.ID()] != -1 { + t.Fatalf("did not get 'already exists' message for %v", d.ID()) + } + } else if receivedProgress[d.ID()] != 10 { + t.Fatalf("missing or wrong progress output for %v (got: %d)", d.ID(), receivedProgress[d.ID()]) + } + + if rootFS.DiffIDs[i] != descriptor.expectedDiffID { + t.Fatalf("rootFS item %d has the wrong diffID (expected: %v got: %v)", i, descriptor.expectedDiffID, rootFS.DiffIDs[i]) + } + + if descriptor.diffID == "" && descriptor.registeredDiffID 
!= rootFS.DiffIDs[i] {
+			t.Fatal("diffID mismatch between rootFS and Registered callback")
+		}
+	}
+}
+
+func TestCancelledDownload(t *testing.T) {
+	ldm := NewLayerDownloadManager(&mockLayerStore{make(map[layer.ChainID]*mockLayer)}, maxDownloadConcurrency)
+
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+
+	go func() {
+		for range progressChan {
+		}
+		close(progressDone)
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	go func() {
+		<-time.After(time.Millisecond)
+		cancel()
+	}()
+
+	descriptors := downloadDescriptors(nil)
+	_, _, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
+	if err != context.Canceled {
+		t.Fatal("expected download to be cancelled")
+	}
+
+	close(progressChan)
+	<-progressDone
+}
diff --git a/distribution/xfer/transfer.go b/distribution/xfer/transfer.go
new file mode 100644
index 0000000000..c0ae3d1c4b
--- /dev/null
+++ b/distribution/xfer/transfer.go
@@ -0,0 +1,343 @@
+package xfer
+
+import (
+	"sync"
+
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+// DoNotRetry is an error wrapper indicating that the error cannot be resolved
+// with a retry.
+type DoNotRetry struct {
+	Err error
+}
+
+// Error returns the stringified representation of the encapsulated error.
+func (e DoNotRetry) Error() string {
+	return e.Err.Error()
+}
+
+// Watcher is returned by Watch and can be passed to Release to stop watching.
+type Watcher struct {
+	// signalChan is used to signal to the watcher goroutine that
+	// new progress information is available, or that the transfer
+	// has finished.
+	signalChan chan struct{}
+	// releaseChan signals to the watcher goroutine that the watcher
+	// should be detached.
+	releaseChan chan struct{}
+	// running remains open as long as the watcher is watching the
+	// transfer. It gets closed if the transfer finishes or the
+	// watcher is detached.
+	running chan struct{}
+}
+
+// Transfer represents an in-progress transfer.
+type Transfer interface {
+	Watch(progressOutput progress.Output) *Watcher
+	Release(*Watcher)
+	Context() context.Context
+	Cancel()
+	Done() <-chan struct{}
+	Released() <-chan struct{}
+	Broadcast(masterProgressChan <-chan progress.Progress)
+}
+
+type transfer struct {
+	mu sync.Mutex
+
+	ctx    context.Context
+	cancel context.CancelFunc
+
+	// watchers keeps track of the goroutines monitoring progress output,
+	// indexed by the channels that release them.
+	watchers map[chan struct{}]*Watcher
+
+	// lastProgress is the most recently received progress event.
+	lastProgress progress.Progress
+	// hasLastProgress is true when lastProgress has been set.
+	hasLastProgress bool
+
+	// running remains open as long as the transfer is in progress.
+	running chan struct{}
+	// hasWatchers stays open until all watchers release the transfer.
+	hasWatchers chan struct{}
+
+	// broadcastDone is true if the master progress channel has closed.
+	broadcastDone bool
+	// broadcastSyncChan allows watchers to "ping" the broadcasting
+	// goroutine to wait for it to deplete its input channel. This ensures
+	// a detaching watcher won't miss an event that was sent before it
+	// started detaching.
+	broadcastSyncChan chan struct{}
+}
+
+// NewTransfer creates a new transfer.
+func NewTransfer() Transfer { + t := &transfer{ + watchers: make(map[chan struct{}]*Watcher), + running: make(chan struct{}), + hasWatchers: make(chan struct{}), + broadcastSyncChan: make(chan struct{}), + } + + // This uses context.Background instead of a caller-supplied context + // so that a transfer won't be cancelled automatically if the client + // which requested it is ^C'd (there could be other viewers). + t.ctx, t.cancel = context.WithCancel(context.Background()) + + return t +} + +// Broadcast copies the progress and error output to all viewers. +func (t *transfer) Broadcast(masterProgressChan <-chan progress.Progress) { + for { + var ( + p progress.Progress + ok bool + ) + select { + case p, ok = <-masterProgressChan: + default: + // We've depleted the channel, so now we can handle + // reads on broadcastSyncChan to let detaching watchers + // know we're caught up. + select { + case <-t.broadcastSyncChan: + continue + case p, ok = <-masterProgressChan: + } + } + + t.mu.Lock() + if ok { + t.lastProgress = p + t.hasLastProgress = true + for _, w := range t.watchers { + select { + case w.signalChan <- struct{}{}: + default: + } + } + + } else { + t.broadcastDone = true + } + t.mu.Unlock() + if !ok { + close(t.running) + return + } + } +} + +// Watch adds a watcher to the transfer. The supplied channel gets progress +// updates and is closed when the transfer finishes. +func (t *transfer) Watch(progressOutput progress.Output) *Watcher { + t.mu.Lock() + defer t.mu.Unlock() + + w := &Watcher{ + releaseChan: make(chan struct{}), + signalChan: make(chan struct{}), + running: make(chan struct{}), + } + + if t.broadcastDone { + close(w.running) + return w + } + + t.watchers[w.releaseChan] = w + + go func() { + defer func() { + close(w.running) + }() + done := false + for { + t.mu.Lock() + hasLastProgress := t.hasLastProgress + lastProgress := t.lastProgress + t.mu.Unlock() + + // This might write the last progress item a + // second time (since channel closure also gets + // us here), but that's fine. + if hasLastProgress { + progressOutput.WriteProgress(lastProgress) + } + + if done { + return + } + + select { + case <-w.signalChan: + case <-w.releaseChan: + done = true + // Since the watcher is going to detach, make + // sure the broadcaster is caught up so we + // don't miss anything. + select { + case t.broadcastSyncChan <- struct{}{}: + case <-t.running: + } + case <-t.running: + done = true + } + } + }() + + return w +} + +// Release is the inverse of Watch; indicating that the watcher no longer wants +// to be notified about the progress of the transfer. All calls to Watch must +// be paired with later calls to Release so that the lifecycle of the transfer +// is properly managed. +func (t *transfer) Release(watcher *Watcher) { + t.mu.Lock() + delete(t.watchers, watcher.releaseChan) + + if len(t.watchers) == 0 { + close(t.hasWatchers) + t.cancel() + } + t.mu.Unlock() + + close(watcher.releaseChan) + // Block until the watcher goroutine completes + <-watcher.running +} + +// Done returns a channel which is closed if the transfer completes or is +// cancelled. Note that having 0 watchers causes a transfer to be cancelled. +func (t *transfer) Done() <-chan struct{} { + // Note that this doesn't return t.ctx.Done() because that channel will + // be closed the moment Cancel is called, and we need to return a + // channel that blocks until a cancellation is actually acknowledged by + // the transfer function. 
+	return t.running
+}
+
+// Released returns a channel which is closed once all watchers release the
+// transfer.
+func (t *transfer) Released() <-chan struct{} {
+	return t.hasWatchers
+}
+
+// Context returns the context associated with the transfer.
+func (t *transfer) Context() context.Context {
+	return t.ctx
+}
+
+// Cancel cancels the context associated with the transfer.
+func (t *transfer) Cancel() {
+	t.cancel()
+}
+
+// DoFunc is a function called by the transfer manager to actually perform
+// a transfer. It should be non-blocking. It should wait until the start channel
+// is closed before transferring any data. If the function closes inactive, that
+// signals to the transfer manager that the job is no longer actively moving
+// data - for example, it may be waiting for a dependent transfer to finish.
+// This prevents it from taking up a slot.
+type DoFunc func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer
+
+// TransferManager is used by LayerDownloadManager and LayerUploadManager to
+// schedule and deduplicate transfers. It is up to the TransferManager
+// implementation to make the scheduling and concurrency decisions.
+type TransferManager interface {
+	// Transfer checks if a transfer with the given key is in progress. If
+	// so, it returns progress and error output from that transfer.
+	// Otherwise, it will call xferFunc to initiate the transfer.
+	Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher)
+}
+
+type transferManager struct {
+	mu sync.Mutex
+
+	concurrencyLimit int
+	activeTransfers  int
+	transfers        map[string]Transfer
+	waitingTransfers []chan struct{}
+}
+
+// NewTransferManager returns a new TransferManager.
+func NewTransferManager(concurrencyLimit int) TransferManager {
+	return &transferManager{
+		concurrencyLimit: concurrencyLimit,
+		transfers:        make(map[string]Transfer),
+	}
+}
+
+// Transfer checks if a transfer matching the given key is in progress. If not,
+// it starts one by calling xferFunc. The caller supplies a channel which
+// receives progress output from the transfer.
+func (tm *transferManager) Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher) {
+	tm.mu.Lock()
+	defer tm.mu.Unlock()
+
+	if xfer, present := tm.transfers[key]; present {
+		// Transfer is already in progress.
+		watcher := xfer.Watch(progressOutput)
+		return xfer, watcher
+	}
+
+	start := make(chan struct{})
+	inactive := make(chan struct{})
+
+	if tm.activeTransfers < tm.concurrencyLimit {
+		close(start)
+		tm.activeTransfers++
+	} else {
+		tm.waitingTransfers = append(tm.waitingTransfers, start)
+	}
+
+	masterProgressChan := make(chan progress.Progress)
+	xfer := xferFunc(masterProgressChan, start, inactive)
+	watcher := xfer.Watch(progressOutput)
+	go xfer.Broadcast(masterProgressChan)
+	tm.transfers[key] = xfer
+
+	// When the transfer is finished, remove from the map.
+	go func() {
+		for {
+			select {
+			case <-inactive:
+				tm.mu.Lock()
+				tm.inactivate(start)
+				tm.mu.Unlock()
+				inactive = nil
+			case <-xfer.Done():
+				tm.mu.Lock()
+				if inactive != nil {
+					tm.inactivate(start)
+				}
+				delete(tm.transfers, key)
+				tm.mu.Unlock()
+				return
+			}
+		}
+	}()
+
+	return xfer, watcher
+}
+
+func (tm *transferManager) inactivate(start chan struct{}) {
+	// If the transfer was started, remove it from the activeTransfers
+	// count.
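For illustration, a sketch (not part of this patch) of the obligations a DoFunc has to meet, distilled from the DoFunc comment above; the exampleDoFunc name and the io.Copy step are placeholders:

package example

import (
	"io"

	"github.com/docker/docker/distribution/xfer"
	"github.com/docker/docker/pkg/progress"
)

// exampleDoFunc returns a DoFunc that copies src to dst once it is allowed to run.
func exampleDoFunc(src io.Reader, dst io.Writer) xfer.DoFunc {
	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) xfer.Transfer {
		t := xfer.NewTransfer()
		go func() {
			// The manager's Broadcast loop only finishes (and Done() only
			// fires) once the master progress channel is closed.
			defer close(progressChan)

			// Do not move any data until the manager grants a concurrency slot.
			<-start

			// If the transfer later blocks on something other than moving
			// data (say, waiting for a dependency), close(inactive) here
			// would hand the slot back to the manager.

			select {
			case <-t.Context().Done():
				return // cancelled, e.g. after every watcher released the transfer
			default:
				_, _ = io.Copy(dst, src) // placeholder for the real work
			}
		}()
		return t
	}
}

NewTransferManager(n).Transfer(key, exampleDoFunc(src, dst), out) would then deduplicate by key and run at most n such transfers at a time.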
+ select { + case <-start: + // Start next transfer if any are waiting + if len(tm.waitingTransfers) != 0 { + close(tm.waitingTransfers[0]) + tm.waitingTransfers = tm.waitingTransfers[1:] + } else { + tm.activeTransfers-- + } + default: + } +} diff --git a/distribution/xfer/transfer_test.go b/distribution/xfer/transfer_test.go new file mode 100644 index 0000000000..7eeb304033 --- /dev/null +++ b/distribution/xfer/transfer_test.go @@ -0,0 +1,385 @@ +package xfer + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/docker/docker/pkg/progress" +) + +func TestTransfer(t *testing.T) { + makeXferFunc := func(id string) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + select { + case <-start: + default: + t.Fatalf("transfer function not started even though concurrency limit not reached") + } + + xfer := NewTransfer() + go func() { + for i := 0; i <= 10; i++ { + progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10} + time.Sleep(10 * time.Millisecond) + } + close(progressChan) + }() + return xfer + } + } + + tm := NewTransferManager(5) + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + receivedProgress := make(map[string]int64) + + go func() { + for p := range progressChan { + val, present := receivedProgress[p.ID] + if !present { + if p.Current != 0 { + t.Fatalf("got unexpected progress value: %d (expected 0)", p.Current) + } + } else if p.Current == 10 { + // Special case: last progress output may be + // repeated because the transfer finishing + // causes the latest progress output to be + // written to the channel (in case the watcher + // missed it). + if p.Current != 9 && p.Current != 10 { + t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1) + } + } else if p.Current != val+1 { + t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1) + } + receivedProgress[p.ID] = p.Current + } + close(progressDone) + }() + + // Start a few transfers + ids := []string{"id1", "id2", "id3"} + xfers := make([]Transfer, len(ids)) + watchers := make([]*Watcher, len(ids)) + for i, id := range ids { + xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan)) + } + + for i, xfer := range xfers { + <-xfer.Done() + xfer.Release(watchers[i]) + } + close(progressChan) + <-progressDone + + for _, id := range ids { + if receivedProgress[id] != 10 { + t.Fatalf("final progress value %d instead of 10", receivedProgress[id]) + } + } +} + +func TestConcurrencyLimit(t *testing.T) { + concurrencyLimit := 3 + var runningJobs int32 + + makeXferFunc := func(id string) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + xfer := NewTransfer() + go func() { + <-start + totalJobs := atomic.AddInt32(&runningJobs, 1) + if int(totalJobs) > concurrencyLimit { + t.Fatalf("too many jobs running") + } + for i := 0; i <= 10; i++ { + progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10} + time.Sleep(10 * time.Millisecond) + } + atomic.AddInt32(&runningJobs, -1) + close(progressChan) + }() + return xfer + } + } + + tm := NewTransferManager(concurrencyLimit) + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + receivedProgress := make(map[string]int64) + + go func() { + for p := range progressChan { + receivedProgress[p.ID] = p.Current + } + close(progressDone) + }() 
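+
+	// Note that the consumer goroutine above only records the most recent
+	// Current value for each ID; that is all the final assertions need.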
+ + // Start more transfers than the concurrency limit + ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"} + xfers := make([]Transfer, len(ids)) + watchers := make([]*Watcher, len(ids)) + for i, id := range ids { + xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan)) + } + + for i, xfer := range xfers { + <-xfer.Done() + xfer.Release(watchers[i]) + } + close(progressChan) + <-progressDone + + for _, id := range ids { + if receivedProgress[id] != 10 { + t.Fatalf("final progress value %d instead of 10", receivedProgress[id]) + } + } +} + +func TestInactiveJobs(t *testing.T) { + concurrencyLimit := 3 + var runningJobs int32 + testDone := make(chan struct{}) + + makeXferFunc := func(id string) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + xfer := NewTransfer() + go func() { + <-start + totalJobs := atomic.AddInt32(&runningJobs, 1) + if int(totalJobs) > concurrencyLimit { + t.Fatalf("too many jobs running") + } + for i := 0; i <= 10; i++ { + progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10} + time.Sleep(10 * time.Millisecond) + } + atomic.AddInt32(&runningJobs, -1) + close(inactive) + <-testDone + close(progressChan) + }() + return xfer + } + } + + tm := NewTransferManager(concurrencyLimit) + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + receivedProgress := make(map[string]int64) + + go func() { + for p := range progressChan { + receivedProgress[p.ID] = p.Current + } + close(progressDone) + }() + + // Start more transfers than the concurrency limit + ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"} + xfers := make([]Transfer, len(ids)) + watchers := make([]*Watcher, len(ids)) + for i, id := range ids { + xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan)) + } + + close(testDone) + for i, xfer := range xfers { + <-xfer.Done() + xfer.Release(watchers[i]) + } + close(progressChan) + <-progressDone + + for _, id := range ids { + if receivedProgress[id] != 10 { + t.Fatalf("final progress value %d instead of 10", receivedProgress[id]) + } + } +} + +func TestWatchRelease(t *testing.T) { + ready := make(chan struct{}) + + makeXferFunc := func(id string) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + xfer := NewTransfer() + go func() { + defer func() { + close(progressChan) + }() + <-ready + for i := int64(0); ; i++ { + select { + case <-time.After(10 * time.Millisecond): + case <-xfer.Context().Done(): + return + } + progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10} + } + }() + return xfer + } + } + + tm := NewTransferManager(5) + + type watcherInfo struct { + watcher *Watcher + progressChan chan progress.Progress + progressDone chan struct{} + receivedFirstProgress chan struct{} + } + + progressConsumer := func(w watcherInfo) { + first := true + for range w.progressChan { + if first { + close(w.receivedFirstProgress) + } + first = false + } + close(w.progressDone) + } + + // Start a transfer + watchers := make([]watcherInfo, 5) + var xfer Transfer + watchers[0].progressChan = make(chan progress.Progress) + watchers[0].progressDone = make(chan struct{}) + watchers[0].receivedFirstProgress = make(chan struct{}) + xfer, watchers[0].watcher = tm.Transfer("id1", makeXferFunc("id1"), 
progress.ChanOutput(watchers[0].progressChan)) + go progressConsumer(watchers[0]) + + // Give it multiple watchers + for i := 1; i != len(watchers); i++ { + watchers[i].progressChan = make(chan progress.Progress) + watchers[i].progressDone = make(chan struct{}) + watchers[i].receivedFirstProgress = make(chan struct{}) + watchers[i].watcher = xfer.Watch(progress.ChanOutput(watchers[i].progressChan)) + go progressConsumer(watchers[i]) + } + + // Now that the watchers are set up, allow the transfer goroutine to + // proceed. + close(ready) + + // Confirm that each watcher gets progress output. + for _, w := range watchers { + <-w.receivedFirstProgress + } + + // Release one watcher every 5ms + for _, w := range watchers { + xfer.Release(w.watcher) + <-time.After(5 * time.Millisecond) + } + + // Now that all watchers have been released, Released() should + // return a closed channel. + <-xfer.Released() + + // Done() should return a closed channel because the xfer func returned + // due to cancellation. + <-xfer.Done() + + for _, w := range watchers { + close(w.progressChan) + <-w.progressDone + } +} + +func TestDuplicateTransfer(t *testing.T) { + ready := make(chan struct{}) + + var xferFuncCalls int32 + + makeXferFunc := func(id string) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + atomic.AddInt32(&xferFuncCalls, 1) + xfer := NewTransfer() + go func() { + defer func() { + close(progressChan) + }() + <-ready + for i := int64(0); ; i++ { + select { + case <-time.After(10 * time.Millisecond): + case <-xfer.Context().Done(): + return + } + progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10} + } + }() + return xfer + } + } + + tm := NewTransferManager(5) + + type transferInfo struct { + xfer Transfer + watcher *Watcher + progressChan chan progress.Progress + progressDone chan struct{} + receivedFirstProgress chan struct{} + } + + progressConsumer := func(t transferInfo) { + first := true + for range t.progressChan { + if first { + close(t.receivedFirstProgress) + } + first = false + } + close(t.progressDone) + } + + // Try to start multiple transfers with the same ID + transfers := make([]transferInfo, 5) + for i := range transfers { + t := &transfers[i] + t.progressChan = make(chan progress.Progress) + t.progressDone = make(chan struct{}) + t.receivedFirstProgress = make(chan struct{}) + t.xfer, t.watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(t.progressChan)) + go progressConsumer(*t) + } + + // Allow the transfer goroutine to proceed. + close(ready) + + // Confirm that each watcher gets progress output. + for _, t := range transfers { + <-t.receivedFirstProgress + } + + // Confirm that the transfer function was called exactly once. + if xferFuncCalls != 1 { + t.Fatal("transfer function wasn't called exactly once") + } + + // Release one watcher every 5ms + for _, t := range transfers { + t.xfer.Release(t.watcher) + <-time.After(5 * time.Millisecond) + } + + for _, t := range transfers { + // Now that all watchers have been released, Released() should + // return a closed channel. + <-t.xfer.Released() + // Done() should return a closed channel because the xfer func returned + // due to cancellation. 
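+		// Cancellation happened when the last watcher was released above,
+		// since all entries in transfers share the same underlying Transfer.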
+ <-t.xfer.Done() + } + + for _, t := range transfers { + close(t.progressChan) + <-t.progressDone + } +} diff --git a/distribution/xfer/upload.go b/distribution/xfer/upload.go new file mode 100644 index 0000000000..9a7d2c17a7 --- /dev/null +++ b/distribution/xfer/upload.go @@ -0,0 +1,159 @@ +package xfer + +import ( + "errors" + "time" + + "github.com/Sirupsen/logrus" + "github.com/docker/distribution/digest" + "github.com/docker/docker/layer" + "github.com/docker/docker/pkg/progress" + "golang.org/x/net/context" +) + +const maxUploadAttempts = 5 + +// LayerUploadManager provides task management and progress reporting for +// uploads. +type LayerUploadManager struct { + tm TransferManager +} + +// NewLayerUploadManager returns a new LayerUploadManager. +func NewLayerUploadManager(concurrencyLimit int) *LayerUploadManager { + return &LayerUploadManager{ + tm: NewTransferManager(concurrencyLimit), + } +} + +type uploadTransfer struct { + Transfer + + diffID layer.DiffID + digest digest.Digest + err error +} + +// An UploadDescriptor references a layer that may need to be uploaded. +type UploadDescriptor interface { + // Key returns the key used to deduplicate uploads. + Key() string + // ID returns the ID for display purposes. + ID() string + // DiffID should return the DiffID for this layer. + DiffID() layer.DiffID + // Upload is called to perform the Upload. + Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) +} + +// Upload is a blocking function which ensures the listed layers are present on +// the remote registry. It uses the string returned by the Key method to +// deduplicate uploads. +func (lum *LayerUploadManager) Upload(ctx context.Context, layers []UploadDescriptor, progressOutput progress.Output) (map[layer.DiffID]digest.Digest, error) { + var ( + uploads []*uploadTransfer + digests = make(map[layer.DiffID]digest.Digest) + dedupDescriptors = make(map[string]struct{}) + ) + + for _, descriptor := range layers { + progress.Update(progressOutput, descriptor.ID(), "Preparing") + + key := descriptor.Key() + if _, present := dedupDescriptors[key]; present { + continue + } + dedupDescriptors[key] = struct{}{} + + xferFunc := lum.makeUploadFunc(descriptor) + upload, watcher := lum.tm.Transfer(descriptor.Key(), xferFunc, progressOutput) + defer upload.Release(watcher) + uploads = append(uploads, upload.(*uploadTransfer)) + } + + for _, upload := range uploads { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-upload.Transfer.Done(): + if upload.err != nil { + return nil, upload.err + } + digests[upload.diffID] = upload.digest + } + } + + return digests, nil +} + +func (lum *LayerUploadManager) makeUploadFunc(descriptor UploadDescriptor) DoFunc { + return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer { + u := &uploadTransfer{ + Transfer: NewTransfer(), + diffID: descriptor.DiffID(), + } + + go func() { + defer func() { + close(progressChan) + }() + + progressOutput := progress.ChanOutput(progressChan) + + select { + case <-start: + default: + progress.Update(progressOutput, descriptor.ID(), "Waiting") + <-start + } + + retries := 0 + for { + digest, err := descriptor.Upload(u.Transfer.Context(), progressOutput) + if err == nil { + u.digest = digest + break + } + + // If an error was returned because the context + // was cancelled, we shouldn't retry. 
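+				// A non-blocking receive on the context's Done channel
+				// distinguishes cancellation from an ordinary upload failure
+				// that should be retried.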
+ select { + case <-u.Transfer.Context().Done(): + u.err = err + return + default: + } + + retries++ + if _, isDNR := err.(DoNotRetry); isDNR || retries == maxUploadAttempts { + logrus.Errorf("Upload failed: %v", err) + u.err = err + return + } + + logrus.Errorf("Upload failed, retrying: %v", err) + delay := retries * 5 + ticker := time.NewTicker(time.Second) + + selectLoop: + for { + progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay) + select { + case <-ticker.C: + delay-- + if delay == 0 { + ticker.Stop() + break selectLoop + } + case <-u.Transfer.Context().Done(): + ticker.Stop() + u.err = errors.New("upload cancelled during retry delay") + return + } + } + } + }() + + return u + } +} diff --git a/distribution/xfer/upload_test.go b/distribution/xfer/upload_test.go new file mode 100644 index 0000000000..df5b2ba9c0 --- /dev/null +++ b/distribution/xfer/upload_test.go @@ -0,0 +1,153 @@ +package xfer + +import ( + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/docker/distribution/digest" + "github.com/docker/docker/layer" + "github.com/docker/docker/pkg/progress" + "golang.org/x/net/context" +) + +const maxUploadConcurrency = 3 + +type mockUploadDescriptor struct { + currentUploads *int32 + diffID layer.DiffID + simulateRetries int +} + +// Key returns the key used to deduplicate downloads. +func (u *mockUploadDescriptor) Key() string { + return u.diffID.String() +} + +// ID returns the ID for display purposes. +func (u *mockUploadDescriptor) ID() string { + return u.diffID.String() +} + +// DiffID should return the DiffID for this layer. +func (u *mockUploadDescriptor) DiffID() layer.DiffID { + return u.diffID +} + +// Upload is called to perform the upload. +func (u *mockUploadDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) { + if u.currentUploads != nil { + defer atomic.AddInt32(u.currentUploads, -1) + + if atomic.AddInt32(u.currentUploads, 1) > maxUploadConcurrency { + return "", errors.New("concurrency limit exceeded") + } + } + + // Sleep a bit to simulate a time-consuming upload. + for i := int64(0); i <= 10; i++ { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(10 * time.Millisecond): + progressOutput.WriteProgress(progress.Progress{ID: u.ID(), Current: i, Total: 10}) + } + } + + if u.simulateRetries != 0 { + u.simulateRetries-- + return "", errors.New("simulating retry") + } + + // For the mock implementation, use SHA256(DiffID) as the returned + // digest. 
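+	// Deriving the digest from the DiffID keeps the mock deterministic, so
+	// the tests below can assert against the precomputed expectedDigests map.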
+ return digest.FromBytes([]byte(u.diffID.String())) +} + +func uploadDescriptors(currentUploads *int32) []UploadDescriptor { + return []UploadDescriptor{ + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0}, + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"), 0}, + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"), 0}, + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0}, + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"), 1}, + &mockUploadDescriptor{currentUploads, layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"), 0}, + } +} + +var expectedDigests = map[layer.DiffID]digest.Digest{ + layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"): digest.Digest("sha256:c5095d6cf7ee42b7b064371dcc1dc3fb4af197f04d01a60009d484bd432724fc"), + layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"): digest.Digest("sha256:968cbfe2ff5269ea1729b3804767a1f57ffbc442d3bc86f47edbf7e688a4f36e"), + layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"): digest.Digest("sha256:8a5e56ab4b477a400470a7d5d4c1ca0c91235fd723ab19cc862636a06f3a735d"), + layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"): digest.Digest("sha256:5e733e5cd3688512fc240bd5c178e72671c9915947d17bb8451750d827944cb2"), + layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"): digest.Digest("sha256:ec4bb98d15e554a9f66c3ef9296cf46772c0ded3b1592bd8324d96e2f60f460c"), +} + +func TestSuccessfulUpload(t *testing.T) { + lum := NewLayerUploadManager(maxUploadConcurrency) + + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + receivedProgress := make(map[string]int64) + + go func() { + for p := range progressChan { + receivedProgress[p.ID] = p.Current + } + close(progressDone) + }() + + var currentUploads int32 + descriptors := uploadDescriptors(¤tUploads) + + digests, err := lum.Upload(context.Background(), descriptors, progress.ChanOutput(progressChan)) + if err != nil { + t.Fatalf("upload error: %v", err) + } + + close(progressChan) + <-progressDone + + if len(digests) != len(expectedDigests) { + t.Fatal("wrong number of keys in digests map") + } + + for key, val := range expectedDigests { + if digests[key] != val { + t.Fatalf("mismatch in digest array for key %v (expected %v, got %v)", key, val, digests[key]) + } + if receivedProgress[key.String()] != 10 { + t.Fatalf("missing or wrong progress output for %v", key) + } + } +} + +func TestCancelledUpload(t *testing.T) { + lum := NewLayerUploadManager(maxUploadConcurrency) + + progressChan := make(chan progress.Progress) + progressDone := make(chan struct{}) + + go func() { + for range progressChan { + } + close(progressDone) + }() + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + <-time.After(time.Millisecond) + cancel() + }() + + descriptors := uploadDescriptors(nil) + _, err := lum.Upload(ctx, descriptors, progress.ChanOutput(progressChan)) + if err != context.Canceled { + t.Fatal("expected upload to be cancelled") + } + + close(progressChan) + <-progressDone 
+} diff --git a/integration-cli/docker_cli_pull_test.go b/integration-cli/docker_cli_pull_test.go index 411acec539..8ea31a7c52 100644 --- a/integration-cli/docker_cli_pull_test.go +++ b/integration-cli/docker_cli_pull_test.go @@ -140,7 +140,7 @@ func (s *DockerHubPullSuite) TestPullAllTagsFromCentralRegistry(c *check.C) { } // TestPullClientDisconnect kills the client during a pull operation and verifies that the operation -// still succesfully completes on the daemon side. +// gets cancelled. // // Ref: docker/docker#15589 func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) { @@ -161,14 +161,8 @@ func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) { err = pullCmd.Process.Kill() c.Assert(err, checker.IsNil) - maxAttempts := 20 - for i := 0; ; i++ { - if _, err := s.CmdWithError("inspect", repoName); err == nil { - break - } - if i >= maxAttempts { - c.Fatal("timeout reached: image was not pulled after client disconnected") - } - time.Sleep(500 * time.Millisecond) + time.Sleep(2 * time.Second) + if _, err := s.CmdWithError("inspect", repoName); err == nil { + c.Fatal("image was pulled after client disconnected") } } diff --git a/pkg/broadcaster/buffered.go b/pkg/broadcaster/buffered.go deleted file mode 100644 index 57f5f97862..0000000000 --- a/pkg/broadcaster/buffered.go +++ /dev/null @@ -1,167 +0,0 @@ -package broadcaster - -import ( - "errors" - "io" - "sync" -) - -// Buffered keeps track of one or more observers watching the progress -// of an operation. For example, if multiple clients are trying to pull an -// image, they share a Buffered struct for the download operation. -type Buffered struct { - sync.Mutex - // c is a channel that observers block on, waiting for the operation - // to finish. - c chan struct{} - // cond is a condition variable used to wake up observers when there's - // new data available. - cond *sync.Cond - // history is a buffer of the progress output so far, so a new observer - // can catch up. The history is stored as a slice of separate byte - // slices, so that if the writer is a WriteFlusher, the flushes will - // happen in the right places. - history [][]byte - // wg is a WaitGroup used to wait for all writes to finish on Close - wg sync.WaitGroup - // result is the argument passed to the first call of Close, and - // returned to callers of Wait - result error -} - -// NewBuffered returns an initialized Buffered structure. -func NewBuffered() *Buffered { - b := &Buffered{ - c: make(chan struct{}), - } - b.cond = sync.NewCond(b) - return b -} - -// closed returns true if and only if the broadcaster has been closed -func (broadcaster *Buffered) closed() bool { - select { - case <-broadcaster.c: - return true - default: - return false - } -} - -// receiveWrites runs as a goroutine so that writes don't block the Write -// function. It writes the new data in broadcaster.history each time there's -// activity on the broadcaster.cond condition variable. -func (broadcaster *Buffered) receiveWrites(observer io.Writer) { - n := 0 - - broadcaster.Lock() - - // The condition variable wait is at the end of this loop, so that the - // first iteration will write the history so far. 
- for { - newData := broadcaster.history[n:] - // Make a copy of newData so we can release the lock - sendData := make([][]byte, len(newData), len(newData)) - copy(sendData, newData) - broadcaster.Unlock() - - for len(sendData) > 0 { - _, err := observer.Write(sendData[0]) - if err != nil { - broadcaster.wg.Done() - return - } - n++ - sendData = sendData[1:] - } - - broadcaster.Lock() - - // If we are behind, we need to catch up instead of waiting - // or handling a closure. - if len(broadcaster.history) != n { - continue - } - - // detect closure of the broadcast writer - if broadcaster.closed() { - broadcaster.Unlock() - broadcaster.wg.Done() - return - } - - broadcaster.cond.Wait() - - // Mutex is still locked as the loop continues - } -} - -// Write adds data to the history buffer, and also writes it to all current -// observers. -func (broadcaster *Buffered) Write(p []byte) (n int, err error) { - broadcaster.Lock() - defer broadcaster.Unlock() - - // Is the broadcaster closed? If so, the write should fail. - if broadcaster.closed() { - return 0, errors.New("attempted write to a closed broadcaster.Buffered") - } - - // Add message in p to the history slice - newEntry := make([]byte, len(p), len(p)) - copy(newEntry, p) - broadcaster.history = append(broadcaster.history, newEntry) - - broadcaster.cond.Broadcast() - - return len(p), nil -} - -// Add adds an observer to the broadcaster. The new observer receives the -// data from the history buffer, and also all subsequent data. -func (broadcaster *Buffered) Add(w io.Writer) error { - // The lock is acquired here so that Add can't race with Close - broadcaster.Lock() - defer broadcaster.Unlock() - - if broadcaster.closed() { - return errors.New("attempted to add observer to a closed broadcaster.Buffered") - } - - broadcaster.wg.Add(1) - go broadcaster.receiveWrites(w) - - return nil -} - -// CloseWithError signals to all observers that the operation has finished. Its -// argument is a result that should be returned to waiters blocking on Wait. -func (broadcaster *Buffered) CloseWithError(result error) { - broadcaster.Lock() - if broadcaster.closed() { - broadcaster.Unlock() - return - } - broadcaster.result = result - close(broadcaster.c) - broadcaster.cond.Broadcast() - broadcaster.Unlock() - - // Don't return until all writers have caught up. - broadcaster.wg.Wait() -} - -// Close signals to all observers that the operation has finished. It causes -// all calls to Wait to return nil. -func (broadcaster *Buffered) Close() { - broadcaster.CloseWithError(nil) -} - -// Wait blocks until the operation is marked as completed by the Close method, -// and all writer goroutines have completed. It returns the argument that was -// passed to Close. -func (broadcaster *Buffered) Wait() error { - <-broadcaster.c - broadcaster.wg.Wait() - return broadcaster.result -} diff --git a/pkg/ioutils/readers.go b/pkg/ioutils/readers.go index b4544de53c..e73b02bbf1 100644 --- a/pkg/ioutils/readers.go +++ b/pkg/ioutils/readers.go @@ -4,6 +4,8 @@ import ( "crypto/sha256" "encoding/hex" "io" + + "golang.org/x/net/context" ) type readCloserWrapper struct { @@ -81,3 +83,72 @@ func (r *OnEOFReader) runFunc() { r.Fn = nil } } + +// cancelReadCloser wraps an io.ReadCloser with a context for cancelling read +// operations. +type cancelReadCloser struct { + cancel func() + pR *io.PipeReader // Stream to read from + pW *io.PipeWriter +} + +// NewCancelReadCloser creates a wrapper that closes the ReadCloser when the +// context is cancelled. 
The returned io.ReadCloser must be closed when it is +// no longer needed. +func NewCancelReadCloser(ctx context.Context, in io.ReadCloser) io.ReadCloser { + pR, pW := io.Pipe() + + // Create a context used to signal when the pipe is closed + doneCtx, cancel := context.WithCancel(context.Background()) + + p := &cancelReadCloser{ + cancel: cancel, + pR: pR, + pW: pW, + } + + go func() { + _, err := io.Copy(pW, in) + select { + case <-ctx.Done(): + // If the context was closed, p.closeWithError + // was already called. Calling it again would + // change the error that Read returns. + default: + p.closeWithError(err) + } + in.Close() + }() + go func() { + for { + select { + case <-ctx.Done(): + p.closeWithError(ctx.Err()) + case <-doneCtx.Done(): + return + } + } + }() + + return p +} + +// Read wraps the Read method of the pipe that provides data from the wrapped +// ReadCloser. +func (p *cancelReadCloser) Read(buf []byte) (n int, err error) { + return p.pR.Read(buf) +} + +// closeWithError closes the wrapper and its underlying reader. It will +// cause future calls to Read to return err. +func (p *cancelReadCloser) closeWithError(err error) { + p.pW.CloseWithError(err) + p.cancel() +} + +// Close closes the wrapper its underlying reader. It will cause +// future calls to Read to return io.EOF. +func (p *cancelReadCloser) Close() error { + p.closeWithError(io.EOF) + return nil +} diff --git a/pkg/ioutils/readers_test.go b/pkg/ioutils/readers_test.go index 5be68cb16c..9abc1054df 100644 --- a/pkg/ioutils/readers_test.go +++ b/pkg/ioutils/readers_test.go @@ -2,8 +2,12 @@ package ioutils import ( "fmt" + "io/ioutil" "strings" "testing" + "time" + + "golang.org/x/net/context" ) // Implement io.Reader @@ -65,3 +69,26 @@ func TestHashData(t *testing.T) { t.Fatalf("Expecting %s, got %s", expected, actual) } } + +type perpetualReader struct{} + +func (p *perpetualReader) Read(buf []byte) (n int, err error) { + for i := 0; i != len(buf); i++ { + buf[i] = 'a' + } + return len(buf), nil +} + +func TestCancelReadCloser(t *testing.T) { + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + cancelReadCloser := NewCancelReadCloser(ctx, ioutil.NopCloser(&perpetualReader{})) + for { + var buf [128]byte + _, err := cancelReadCloser.Read(buf[:]) + if err == context.DeadlineExceeded { + break + } else if err != nil { + t.Fatalf("got unexpected error: %v", err) + } + } +} diff --git a/pkg/progress/progress.go b/pkg/progress/progress.go new file mode 100644 index 0000000000..1f3b34a832 --- /dev/null +++ b/pkg/progress/progress.go @@ -0,0 +1,63 @@ +package progress + +import ( + "fmt" +) + +// Progress represents the progress of a transfer. +type Progress struct { + ID string + + // Progress contains a Message or... + Message string + + // ...progress of an action + Action string + Current int64 + Total int64 + + LastUpdate bool +} + +// Output is an interface for writing progress information. It's +// like a writer for progress, but we don't call it Writer because +// that would be confusing next to ProgressReader (also, because it +// doesn't implement the io.Writer interface). +type Output interface { + WriteProgress(Progress) error +} + +type chanOutput chan<- Progress + +func (out chanOutput) WriteProgress(p Progress) error { + out <- p + return nil +} + +// ChanOutput returns a Output that writes progress updates to the +// supplied channel. 
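+// WriteProgress sends on the supplied channel, so it blocks if the channel is
+// unbuffered and nothing is reading from it; callers typically drain the
+// channel in a separate goroutine, as the tests in this package do.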
+func ChanOutput(progressChan chan<- Progress) Output { + return chanOutput(progressChan) +} + +// Update is a convenience function to write a progress update to the channel. +func Update(out Output, id, action string) { + out.WriteProgress(Progress{ID: id, Action: action}) +} + +// Updatef is a convenience function to write a printf-formatted progress update +// to the channel. +func Updatef(out Output, id, format string, a ...interface{}) { + Update(out, id, fmt.Sprintf(format, a...)) +} + +// Message is a convenience function to write a progress message to the channel. +func Message(out Output, id, message string) { + out.WriteProgress(Progress{ID: id, Message: message}) +} + +// Messagef is a convenience function to write a printf-formatted progress +// message to the channel. +func Messagef(out Output, id, format string, a ...interface{}) { + Message(out, id, fmt.Sprintf(format, a...)) +} diff --git a/pkg/progress/progressreader.go b/pkg/progress/progressreader.go new file mode 100644 index 0000000000..c39e2b69fb --- /dev/null +++ b/pkg/progress/progressreader.go @@ -0,0 +1,59 @@ +package progress + +import ( + "io" +) + +// Reader is a Reader with progress bar. +type Reader struct { + in io.ReadCloser // Stream to read from + out Output // Where to send progress bar to + size int64 + current int64 + lastUpdate int64 + id string + action string +} + +// NewProgressReader creates a new ProgressReader. +func NewProgressReader(in io.ReadCloser, out Output, size int64, id, action string) *Reader { + return &Reader{ + in: in, + out: out, + size: size, + id: id, + action: action, + } +} + +func (p *Reader) Read(buf []byte) (n int, err error) { + read, err := p.in.Read(buf) + p.current += int64(read) + updateEvery := int64(1024 * 512) //512kB + if p.size > 0 { + // Update progress for every 1% read if 1% < 512kB + if increment := int64(0.01 * float64(p.size)); increment < updateEvery { + updateEvery = increment + } + } + if p.current-p.lastUpdate > updateEvery || err != nil { + p.updateProgress(err != nil && read == 0) + p.lastUpdate = p.current + } + + return read, err +} + +// Close closes the progress reader and its underlying reader. 
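+// If the stream is closed before size bytes have been read, a final update
+// with Current set to size is emitted so that observers see a complete bar.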
+func (p *Reader) Close() error { + if p.current < p.size { + // print a full progress bar when closing prematurely + p.current = p.size + p.updateProgress(false) + } + return p.in.Close() +} + +func (p *Reader) updateProgress(last bool) { + p.out.WriteProgress(Progress{ID: p.id, Action: p.action, Current: p.current, Total: p.size, LastUpdate: last}) +} diff --git a/pkg/progress/progressreader_test.go b/pkg/progress/progressreader_test.go new file mode 100644 index 0000000000..b14d401561 --- /dev/null +++ b/pkg/progress/progressreader_test.go @@ -0,0 +1,75 @@ +package progress + +import ( + "bytes" + "io" + "io/ioutil" + "testing" +) + +func TestOutputOnPrematureClose(t *testing.T) { + content := []byte("TESTING") + reader := ioutil.NopCloser(bytes.NewReader(content)) + progressChan := make(chan Progress, 10) + + pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read") + + part := make([]byte, 4, 4) + _, err := io.ReadFull(pr, part) + if err != nil { + pr.Close() + t.Fatal(err) + } + +drainLoop: + for { + select { + case <-progressChan: + default: + break drainLoop + } + } + + pr.Close() + + select { + case <-progressChan: + default: + t.Fatalf("Expected some output when closing prematurely") + } +} + +func TestCompleteSilently(t *testing.T) { + content := []byte("TESTING") + reader := ioutil.NopCloser(bytes.NewReader(content)) + progressChan := make(chan Progress, 10) + + pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read") + + out, err := ioutil.ReadAll(pr) + if err != nil { + pr.Close() + t.Fatal(err) + } + if string(out) != "TESTING" { + pr.Close() + t.Fatalf("Unexpected output %q from reader", string(out)) + } + +drainLoop: + for { + select { + case <-progressChan: + default: + break drainLoop + } + } + + pr.Close() + + select { + case <-progressChan: + t.Fatalf("Should have closed silently when read is complete") + default: + } +} diff --git a/pkg/progressreader/progressreader.go b/pkg/progressreader/progressreader.go deleted file mode 100644 index f48442b591..0000000000 --- a/pkg/progressreader/progressreader.go +++ /dev/null @@ -1,68 +0,0 @@ -// Package progressreader provides a Reader with a progress bar that can be -// printed out using the streamformatter package. -package progressreader - -import ( - "io" - - "github.com/docker/docker/pkg/jsonmessage" - "github.com/docker/docker/pkg/streamformatter" -) - -// Config contains the configuration for a Reader with progress bar. -type Config struct { - In io.ReadCloser // Stream to read from - Out io.Writer // Where to send progress bar to - Formatter *streamformatter.StreamFormatter - Size int64 - Current int64 - LastUpdate int64 - NewLines bool - ID string - Action string -} - -// New creates a new Config. 
-func New(newReader Config) *Config { - return &newReader -} - -func (config *Config) Read(p []byte) (n int, err error) { - read, err := config.In.Read(p) - config.Current += int64(read) - updateEvery := int64(1024 * 512) //512kB - if config.Size > 0 { - // Update progress for every 1% read if 1% < 512kB - if increment := int64(0.01 * float64(config.Size)); increment < updateEvery { - updateEvery = increment - } - } - if config.Current-config.LastUpdate > updateEvery || err != nil { - updateProgress(config) - config.LastUpdate = config.Current - } - - if err != nil && read == 0 { - updateProgress(config) - if config.NewLines { - config.Out.Write(config.Formatter.FormatStatus("", "")) - } - } - return read, err -} - -// Close closes the reader (Config). -func (config *Config) Close() error { - if config.Current < config.Size { - //print a full progress bar when closing prematurely - config.Current = config.Size - updateProgress(config) - } - return config.In.Close() -} - -func updateProgress(config *Config) { - progress := jsonmessage.JSONProgress{Current: config.Current, Total: config.Size} - fmtMessage := config.Formatter.FormatProgress(config.ID, config.Action, &progress) - config.Out.Write(fmtMessage) -} diff --git a/pkg/progressreader/progressreader_test.go b/pkg/progressreader/progressreader_test.go deleted file mode 100644 index 21d9b0f057..0000000000 --- a/pkg/progressreader/progressreader_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package progressreader - -import ( - "bufio" - "bytes" - "io" - "io/ioutil" - "testing" - - "github.com/docker/docker/pkg/streamformatter" -) - -func TestOutputOnPrematureClose(t *testing.T) { - var outBuf bytes.Buffer - content := []byte("TESTING") - reader := ioutil.NopCloser(bytes.NewReader(content)) - writer := bufio.NewWriter(&outBuf) - - prCfg := Config{ - In: reader, - Out: writer, - Formatter: streamformatter.NewStreamFormatter(), - Size: int64(len(content)), - NewLines: true, - ID: "Test", - Action: "Read", - } - pr := New(prCfg) - - part := make([]byte, 4, 4) - _, err := io.ReadFull(pr, part) - if err != nil { - pr.Close() - t.Fatal(err) - } - - if err := writer.Flush(); err != nil { - pr.Close() - t.Fatal(err) - } - - tlen := outBuf.Len() - pr.Close() - if err := writer.Flush(); err != nil { - t.Fatal(err) - } - - if outBuf.Len() == tlen { - t.Fatalf("Expected some output when closing prematurely") - } -} - -func TestCompleteSilently(t *testing.T) { - var outBuf bytes.Buffer - content := []byte("TESTING") - reader := ioutil.NopCloser(bytes.NewReader(content)) - writer := bufio.NewWriter(&outBuf) - - prCfg := Config{ - In: reader, - Out: writer, - Formatter: streamformatter.NewStreamFormatter(), - Size: int64(len(content)), - NewLines: true, - ID: "Test", - Action: "Read", - } - pr := New(prCfg) - - out, err := ioutil.ReadAll(pr) - if err != nil { - pr.Close() - t.Fatal(err) - } - if string(out) != "TESTING" { - pr.Close() - t.Fatalf("Unexpected output %q from reader", string(out)) - } - - if err := writer.Flush(); err != nil { - pr.Close() - t.Fatal(err) - } - - tlen := outBuf.Len() - pr.Close() - if err := writer.Flush(); err != nil { - t.Fatal(err) - } - - if outBuf.Len() > tlen { - t.Fatalf("Should have closed silently when read is complete") - } -} diff --git a/pkg/streamformatter/streamformatter.go b/pkg/streamformatter/streamformatter.go index d3ac39ebdb..b67a53d648 100644 --- a/pkg/streamformatter/streamformatter.go +++ b/pkg/streamformatter/streamformatter.go @@ -7,6 +7,7 @@ import ( "io" "github.com/docker/docker/pkg/jsonmessage" + 
"github.com/docker/docker/pkg/progress" ) // StreamFormatter formats a stream, optionally using JSON. @@ -92,6 +93,44 @@ func (sf *StreamFormatter) FormatProgress(id, action string, progress *jsonmessa return []byte(action + " " + progress.String() + endl) } +// NewProgressOutput returns a progress.Output object that can be passed to +// progress.NewProgressReader. +func (sf *StreamFormatter) NewProgressOutput(out io.Writer, newLines bool) progress.Output { + return &progressOutput{ + sf: sf, + out: out, + newLines: newLines, + } +} + +type progressOutput struct { + sf *StreamFormatter + out io.Writer + newLines bool +} + +// WriteProgress formats progress information from a ProgressReader. +func (out *progressOutput) WriteProgress(prog progress.Progress) error { + var formatted []byte + if prog.Message != "" { + formatted = out.sf.FormatStatus(prog.ID, prog.Message) + } else { + jsonProgress := jsonmessage.JSONProgress{Current: prog.Current, Total: prog.Total} + formatted = out.sf.FormatProgress(prog.ID, prog.Action, &jsonProgress) + } + _, err := out.out.Write(formatted) + if err != nil { + return err + } + + if out.newLines && prog.LastUpdate { + _, err = out.out.Write(out.sf.FormatStatus("", "")) + return err + } + + return nil +} + // StdoutFormatter is a streamFormatter that writes to the standard output. type StdoutFormatter struct { io.Writer diff --git a/registry/session.go b/registry/session.go index 645e5d44b3..5017aeacac 100644 --- a/registry/session.go +++ b/registry/session.go @@ -17,7 +17,6 @@ import ( "net/url" "strconv" "strings" - "time" "github.com/Sirupsen/logrus" "github.com/docker/distribution/reference" @@ -270,7 +269,6 @@ func (r *Session) GetRemoteImageJSON(imgID, registry string) ([]byte, int64, err // GetRemoteImageLayer retrieves an image layer from the registry func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io.ReadCloser, error) { var ( - retries = 5 statusCode = 0 res *http.Response err error @@ -281,14 +279,9 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io if err != nil { return nil, fmt.Errorf("Error while getting from the server: %v", err) } - // TODO(tiborvass): why are we doing retries at this level? - // These retries should be generic to both v1 and v2 - for i := 1; i <= retries; i++ { - statusCode = 0 - res, err = r.client.Do(req) - if err == nil { - break - } + statusCode = 0 + res, err = r.client.Do(req) + if err != nil { logrus.Debugf("Error contacting registry %s: %v", registry, err) if res != nil { if res.Body != nil { @@ -296,11 +289,8 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io } statusCode = res.StatusCode } - if i == retries { - return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)", - statusCode, imgID) - } - time.Sleep(time.Duration(i) * 5 * time.Second) + return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)", + statusCode, imgID) } if res.StatusCode != 200 {