1
0
Fork 0
mirror of https://github.com/moby/moby.git synced 2022-11-09 12:21:53 -05:00

Improved push and pull with upload manager and download manager

This commit adds a transfer manager which deduplicates and schedules
transfers, and also an upload manager and download manager that build on
top of the transfer manager to provide high-level interfaces for uploads
and downloads. The push and pull code is modified to use these building
blocks.

Some benefits of the changes:

- Simplification of push/pull code
- Pushes can upload layers concurrently
- Failed downloads and uploads are retried after backoff delays
- Cancellation is supported, but individual transfers will only be
  cancelled if all pushes or pulls using them are cancelled.
- The distribution code is decoupled from Docker Engine packages and API
  conventions (i.e. streamformatter), which will make it easier to split
  out.

This commit also includes unit tests for the new distribution/xfer
package. The tests cover 87.8% of the statements in the package.

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
This commit is contained in:
Aaron Lehmann 2015-11-13 16:59:01 -08:00
parent 7470e39c73
commit 572ce80230
36 changed files with 2675 additions and 1127 deletions

View file

@ -23,7 +23,7 @@ import (
"github.com/docker/docker/pkg/httputils"
"github.com/docker/docker/pkg/jsonmessage"
flag "github.com/docker/docker/pkg/mflag"
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/ulimit"
"github.com/docker/docker/pkg/units"
@ -169,16 +169,9 @@ func (cli *DockerCli) CmdBuild(args ...string) error {
context = replaceDockerfileTarWrapper(context, newDockerfile, relDockerfile)
// Setup an upload progress bar
// FIXME: ProgressReader shouldn't be this annoying to use
sf := streamformatter.NewStreamFormatter()
var body io.Reader = progressreader.New(progressreader.Config{
In: context,
Out: cli.out,
Formatter: sf,
NewLines: true,
ID: "",
Action: "Sending build context to Docker daemon",
})
progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(cli.out, true)
var body io.Reader = progress.NewProgressReader(context, progressOutput, 0, "", "Sending build context to Docker daemon")
var memory int64
if *flMemoryString != "" {
@ -447,17 +440,10 @@ func getContextFromURL(out io.Writer, remoteURL, dockerfileName string) (absCont
return "", "", fmt.Errorf("unable to download remote context %s: %v", remoteURL, err)
}
defer response.Body.Close()
progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(out, true)
// Pass the response body through a progress reader.
progReader := &progressreader.Config{
In: response.Body,
Out: out,
Formatter: streamformatter.NewStreamFormatter(),
Size: response.ContentLength,
NewLines: true,
ID: "",
Action: fmt.Sprintf("Downloading build context from remote url: %s", remoteURL),
}
progReader := progress.NewProgressReader(response.Body, progressOutput, response.ContentLength, "", fmt.Sprintf("Downloading build context from remote url: %s", remoteURL))
return getContextFromReader(progReader, dockerfileName)
}

View file

@ -23,7 +23,7 @@ import (
"github.com/docker/docker/pkg/archive"
"github.com/docker/docker/pkg/chrootarchive"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/ulimit"
"github.com/docker/docker/runconfig"
@ -325,7 +325,7 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R
sf := streamformatter.NewJSONStreamFormatter()
errf := func(err error) error {
// Do not write the error in the http output if it's still empty.
// This prevents from writing a 200(OK) when there is an interal error.
// This prevents from writing a 200(OK) when there is an internal error.
if !output.Flushed() {
return err
}
@ -401,23 +401,17 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R
remoteURL := r.FormValue("remote")
// Currently, only used if context is from a remote url.
// The field `In` is set by DetectContextFromRemoteURL.
// Look at code in DetectContextFromRemoteURL for more information.
pReader := &progressreader.Config{
// TODO: make progressreader streamformatter-agnostic
Out: output,
Formatter: sf,
Size: r.ContentLength,
NewLines: true,
ID: "Downloading context",
Action: remoteURL,
createProgressReader := func(in io.ReadCloser) io.ReadCloser {
progressOutput := sf.NewProgressOutput(output, true)
return progress.NewProgressReader(in, progressOutput, r.ContentLength, "Downloading context", remoteURL)
}
var (
context builder.ModifiableContext
dockerfileName string
)
context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, pReader)
context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, createProgressReader)
if err != nil {
return errf(err)
}

View file

@ -29,7 +29,7 @@ import (
"github.com/docker/docker/pkg/httputils"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/stringid"
"github.com/docker/docker/pkg/stringutils"
@ -264,17 +264,11 @@ func (b *Builder) download(srcURL string) (fi builder.FileInfo, err error) {
return
}
stdoutFormatter := b.Stdout.(*streamformatter.StdoutFormatter)
progressOutput := stdoutFormatter.StreamFormatter.NewProgressOutput(stdoutFormatter.Writer, true)
progressReader := progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Downloading")
// Download and dump result to tmp file
if _, err = io.Copy(tmpFile, progressreader.New(progressreader.Config{
In: resp.Body,
// TODO: make progressreader streamformatter agnostic
Out: b.Stdout.(*streamformatter.StdoutFormatter).Writer,
Formatter: b.Stdout.(*streamformatter.StdoutFormatter).StreamFormatter,
Size: resp.ContentLength,
NewLines: true,
ID: "",
Action: "Downloading",
})); err != nil {
if _, err = io.Copy(tmpFile, progressReader); err != nil {
tmpFile.Close()
return
}

View file

@ -34,6 +34,7 @@ import (
"github.com/docker/docker/daemon/network"
"github.com/docker/docker/distribution"
dmetadata "github.com/docker/docker/distribution/metadata"
"github.com/docker/docker/distribution/xfer"
derr "github.com/docker/docker/errors"
"github.com/docker/docker/image"
"github.com/docker/docker/image/tarexport"
@ -49,7 +50,9 @@ import (
"github.com/docker/docker/pkg/namesgenerator"
"github.com/docker/docker/pkg/nat"
"github.com/docker/docker/pkg/parsers/filters"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/pkg/signal"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/stringid"
"github.com/docker/docker/pkg/stringutils"
"github.com/docker/docker/pkg/sysinfo"
@ -66,6 +69,16 @@ import (
lntypes "github.com/docker/libnetwork/types"
"github.com/docker/libtrust"
"github.com/opencontainers/runc/libcontainer"
"golang.org/x/net/context"
)
const (
// maxDownloadConcurrency is the maximum number of downloads that
// may take place at a time for each pull.
maxDownloadConcurrency = 3
// maxUploadConcurrency is the maximum number of uploads that
// may take place at a time for each push.
maxUploadConcurrency = 5
)
var (
@ -126,7 +139,8 @@ type Daemon struct {
containers *contStore
execCommands *exec.Store
tagStore tag.Store
distributionPool *distribution.Pool
downloadManager *xfer.LayerDownloadManager
uploadManager *xfer.LayerUploadManager
distributionMetadataStore dmetadata.Store
trustKey libtrust.PrivateKey
idIndex *truncindex.TruncIndex
@ -738,7 +752,8 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo
return nil, err
}
distributionPool := distribution.NewPool()
d.downloadManager = xfer.NewLayerDownloadManager(d.layerStore, maxDownloadConcurrency)
d.uploadManager = xfer.NewLayerUploadManager(maxUploadConcurrency)
ifs, err := image.NewFSStoreBackend(filepath.Join(imageRoot, "imagedb"))
if err != nil {
@ -834,7 +849,6 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo
d.containers = &contStore{s: make(map[string]*container.Container)}
d.execCommands = exec.NewStore()
d.tagStore = tagStore
d.distributionPool = distributionPool
d.distributionMetadataStore = distributionMetadataStore
d.trustKey = trustKey
d.idIndex = truncindex.NewTruncIndex([]string{})
@ -1038,23 +1052,53 @@ func (daemon *Daemon) TagImage(newTag reference.Named, imageName string) error {
return nil
}
func writeDistributionProgress(cancelFunc func(), outStream io.Writer, progressChan <-chan progress.Progress) {
progressOutput := streamformatter.NewJSONStreamFormatter().NewProgressOutput(outStream, false)
operationCancelled := false
for prog := range progressChan {
if err := progressOutput.WriteProgress(prog); err != nil && !operationCancelled {
logrus.Errorf("error writing progress to client: %v", err)
cancelFunc()
operationCancelled = true
// Don't return, because we need to continue draining
// progressChan until it's closed to avoid a deadlock.
}
}
}
// PullImage initiates a pull operation. image is the repository name to pull, and
// tag may be either empty, or indicate a specific tag to pull.
func (daemon *Daemon) PullImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error {
// Include a buffer so that slow client connections don't affect
// transfer performance.
progressChan := make(chan progress.Progress, 100)
writesDone := make(chan struct{})
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
writeDistributionProgress(cancelFunc, outStream, progressChan)
close(writesDone)
}()
imagePullConfig := &distribution.ImagePullConfig{
MetaHeaders: metaHeaders,
AuthConfig: authConfig,
OutStream: outStream,
ProgressOutput: progress.ChanOutput(progressChan),
RegistryService: daemon.RegistryService,
EventsService: daemon.EventsService,
MetadataStore: daemon.distributionMetadataStore,
LayerStore: daemon.layerStore,
ImageStore: daemon.imageStore,
TagStore: daemon.tagStore,
Pool: daemon.distributionPool,
DownloadManager: daemon.downloadManager,
}
return distribution.Pull(ref, imagePullConfig)
err := distribution.Pull(ctx, ref, imagePullConfig)
close(progressChan)
<-writesDone
return err
}
// ExportImage exports a list of images to the given output stream. The
@ -1069,10 +1113,23 @@ func (daemon *Daemon) ExportImage(names []string, outStream io.Writer) error {
// PushImage initiates a push operation on the repository named localName.
func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error {
// Include a buffer so that slow client connections don't affect
// transfer performance.
progressChan := make(chan progress.Progress, 100)
writesDone := make(chan struct{})
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
writeDistributionProgress(cancelFunc, outStream, progressChan)
close(writesDone)
}()
imagePushConfig := &distribution.ImagePushConfig{
MetaHeaders: metaHeaders,
AuthConfig: authConfig,
OutStream: outStream,
ProgressOutput: progress.ChanOutput(progressChan),
RegistryService: daemon.RegistryService,
EventsService: daemon.EventsService,
MetadataStore: daemon.distributionMetadataStore,
@ -1080,9 +1137,13 @@ func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]st
ImageStore: daemon.imageStore,
TagStore: daemon.tagStore,
TrustKey: daemon.trustKey,
UploadManager: daemon.uploadManager,
}
return distribution.Push(ref, imagePushConfig)
err := distribution.Push(ctx, ref, imagePushConfig)
close(progressChan)
<-writesDone
return err
}
// LookupImage looks up an image by name and returns it as an ImageInspect

View file

@ -21,7 +21,6 @@ import (
"github.com/docker/docker/pkg/httputils"
"github.com/docker/docker/pkg/idtools"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/urlutil"
"github.com/docker/docker/registry"
"github.com/docker/docker/runconfig"
@ -239,7 +238,7 @@ func (d Docker) Start(c *container.Container) error {
// DetectContextFromRemoteURL returns a context and in certain cases the name of the dockerfile to be used
// irrespective of user input.
// progressReader is only used if remoteURL is actually a URL (not empty, and not a Git endpoint).
func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReader *progressreader.Config) (context builder.ModifiableContext, dockerfileName string, err error) {
func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, createProgressReader func(in io.ReadCloser) io.ReadCloser) (context builder.ModifiableContext, dockerfileName string, err error) {
switch {
case remoteURL == "":
context, err = builder.MakeTarSumContext(r)
@ -262,8 +261,7 @@ func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReade
},
// fallback handler (tar context)
"": func(rc io.ReadCloser) (io.ReadCloser, error) {
progressReader.In = rc
return progressReader, nil
return createProgressReader(rc), nil
},
})
default:

View file

@ -13,7 +13,7 @@ import (
"github.com/docker/docker/image"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/httputils"
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/runconfig"
)
@ -47,16 +47,8 @@ func (daemon *Daemon) ImportImage(src string, newRef reference.Named, msg string
if err != nil {
return err
}
progressReader := progressreader.New(progressreader.Config{
In: resp.Body,
Out: outStream,
Formatter: sf,
Size: resp.ContentLength,
NewLines: true,
ID: "",
Action: "Importing",
})
archive = progressReader
progressOutput := sf.NewProgressOutput(outStream, true)
archive = progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Importing")
}
defer archive.Close()

View file

@ -23,20 +23,20 @@ func (idserv *V1IDService) namespace() string {
}
// Get finds a layer by its V1 ID.
func (idserv *V1IDService) Get(v1ID, registry string) (layer.ChainID, error) {
func (idserv *V1IDService) Get(v1ID, registry string) (layer.DiffID, error) {
if err := v1.ValidateID(v1ID); err != nil {
return layer.ChainID(""), err
return layer.DiffID(""), err
}
idBytes, err := idserv.store.Get(idserv.namespace(), registry+","+v1ID)
if err != nil {
return layer.ChainID(""), err
return layer.DiffID(""), err
}
return layer.ChainID(idBytes), nil
return layer.DiffID(idBytes), nil
}
// Set associates an image with a V1 ID.
func (idserv *V1IDService) Set(v1ID, registry string, id layer.ChainID) error {
func (idserv *V1IDService) Set(v1ID, registry string, id layer.DiffID) error {
if err := v1.ValidateID(v1ID); err != nil {
return err
}

View file

@ -24,22 +24,22 @@ func TestV1IDService(t *testing.T) {
testVectors := []struct {
registry string
v1ID string
layerID layer.ChainID
layerID layer.DiffID
}{
{
registry: "registry1",
v1ID: "f0cd5ca10b07f35512fc2f1cbf9a6cefbdb5cba70ac6b0c9e5988f4497f71937",
layerID: layer.ChainID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
layerID: layer.DiffID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
},
{
registry: "registry2",
v1ID: "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e",
layerID: layer.ChainID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
layerID: layer.DiffID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
},
{
registry: "registry1",
v1ID: "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e",
layerID: layer.ChainID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"),
layerID: layer.DiffID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"),
},
}

View file

@ -1,51 +0,0 @@
package distribution
import (
"sync"
"github.com/docker/docker/pkg/broadcaster"
)
// A Pool manages concurrent pulls. It deduplicates in-progress downloads.
type Pool struct {
sync.Mutex
pullingPool map[string]*broadcaster.Buffered
}
// NewPool creates a new Pool.
func NewPool() *Pool {
return &Pool{
pullingPool: make(map[string]*broadcaster.Buffered),
}
}
// add checks if a pull is already running, and returns (broadcaster, true)
// if a running operation is found. Otherwise, it creates a new one and returns
// (broadcaster, false).
func (pool *Pool) add(key string) (*broadcaster.Buffered, bool) {
pool.Lock()
defer pool.Unlock()
if p, exists := pool.pullingPool[key]; exists {
return p, true
}
broadcaster := broadcaster.NewBuffered()
pool.pullingPool[key] = broadcaster
return broadcaster, false
}
func (pool *Pool) removeWithError(key string, broadcasterResult error) error {
pool.Lock()
defer pool.Unlock()
if broadcaster, exists := pool.pullingPool[key]; exists {
broadcaster.CloseWithError(broadcasterResult)
delete(pool.pullingPool, key)
}
return nil
}
func (pool *Pool) remove(key string) error {
return pool.removeWithError(key, nil)
}

View file

@ -1,28 +0,0 @@
package distribution
import (
"testing"
)
func TestPools(t *testing.T) {
p := NewPool()
if _, found := p.add("test1"); found {
t.Fatal("Expected pull test1 not to be in progress")
}
if _, found := p.add("test2"); found {
t.Fatal("Expected pull test2 not to be in progress")
}
if _, found := p.add("test1"); !found {
t.Fatalf("Expected pull test1 to be in progress`")
}
if err := p.remove("test2"); err != nil {
t.Fatal(err)
}
if err := p.remove("test2"); err != nil {
t.Fatal(err)
}
if err := p.remove("test1"); err != nil {
t.Fatal(err)
}
}

View file

@ -2,7 +2,7 @@ package distribution
import (
"fmt"
"io"
"os"
"strings"
"github.com/Sirupsen/logrus"
@ -10,11 +10,12 @@ import (
"github.com/docker/docker/cliconfig"
"github.com/docker/docker/daemon/events"
"github.com/docker/docker/distribution/metadata"
"github.com/docker/docker/distribution/xfer"
"github.com/docker/docker/image"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/registry"
"github.com/docker/docker/tag"
"golang.org/x/net/context"
)
// ImagePullConfig stores pull configuration.
@ -25,9 +26,9 @@ type ImagePullConfig struct {
// AuthConfig holds authentication credentials for authenticating with
// the registry.
AuthConfig *cliconfig.AuthConfig
// OutStream is the output writer for showing the status of the pull
// ProgressOutput is the interface for showing the status of the pull
// operation.
OutStream io.Writer
ProgressOutput progress.Output
// RegistryService is the registry service to use for TLS configuration
// and endpoint lookup.
RegistryService *registry.Service
@ -36,14 +37,12 @@ type ImagePullConfig struct {
// MetadataStore is the storage backend for distribution-specific
// metadata.
MetadataStore metadata.Store
// LayerStore manages layers.
LayerStore layer.Store
// ImageStore manages images.
ImageStore image.Store
// TagStore manages tags.
TagStore tag.Store
// Pool manages concurrent pulls.
Pool *Pool
// DownloadManager manages concurrent pulls.
DownloadManager *xfer.LayerDownloadManager
}
// Puller is an interface that abstracts pulling for different API versions.
@ -51,7 +50,7 @@ type Puller interface {
// Pull tries to pull the image referenced by `tag`
// Pull returns an error if any, as well as a boolean that determines whether to retry Pull on the next configured endpoint.
//
Pull(ref reference.Named) (fallback bool, err error)
Pull(ctx context.Context, ref reference.Named) (fallback bool, err error)
}
// newPuller returns a Puller interface that will pull from either a v1 or v2
@ -59,14 +58,13 @@ type Puller interface {
// whether a v1 or v2 puller will be created. The other parameters are passed
// through to the underlying puller implementation for use during the actual
// pull operation.
func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig, sf *streamformatter.StreamFormatter) (Puller, error) {
func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig) (Puller, error) {
switch endpoint.Version {
case registry.APIVersion2:
return &v2Puller{
blobSumService: metadata.NewBlobSumService(imagePullConfig.MetadataStore),
endpoint: endpoint,
config: imagePullConfig,
sf: sf,
repoInfo: repoInfo,
}, nil
case registry.APIVersion1:
@ -74,7 +72,6 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo,
v1IDService: metadata.NewV1IDService(imagePullConfig.MetadataStore),
endpoint: endpoint,
config: imagePullConfig,
sf: sf,
repoInfo: repoInfo,
}, nil
}
@ -83,9 +80,7 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo,
// Pull initiates a pull operation. image is the repository name to pull, and
// tag may be either empty, or indicate a specific tag to pull.
func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
var sf = streamformatter.NewJSONStreamFormatter()
func Pull(ctx context.Context, ref reference.Named, imagePullConfig *ImagePullConfig) error {
// Resolve the Repository name from fqn to RepositoryInfo
repoInfo, err := imagePullConfig.RegistryService.ResolveRepository(ref)
if err != nil {
@ -120,12 +115,19 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
for _, endpoint := range endpoints {
logrus.Debugf("Trying to pull %s from %s %s", repoInfo.LocalName, endpoint.URL, endpoint.Version)
puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf)
puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
if err != nil {
errors = append(errors, err.Error())
continue
}
if fallback, err := puller.Pull(ref); err != nil {
if fallback, err := puller.Pull(ctx, ref); err != nil {
// Was this pull cancelled? If so, don't try to fall
// back.
select {
case <-ctx.Done():
fallback = false
default:
}
if fallback {
if _, ok := err.(registry.ErrNoSupport); !ok {
// Because we found an error that's not ErrNoSupport, discard all subsequent ErrNoSupport errors.
@ -165,11 +167,11 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
// status message indicates that a newer image was downloaded. Otherwise, it
// indicates that the image is up to date. requestedTag is the tag the message
// will refer to.
func writeStatus(requestedTag string, out io.Writer, sf *streamformatter.StreamFormatter, layersDownloaded bool) {
func writeStatus(requestedTag string, out progress.Output, layersDownloaded bool) {
if layersDownloaded {
out.Write(sf.FormatStatus("", "Status: Downloaded newer image for %s", requestedTag))
progress.Message(out, "", "Status: Downloaded newer image for "+requestedTag)
} else {
out.Write(sf.FormatStatus("", "Status: Image is up to date for %s", requestedTag))
progress.Message(out, "", "Status: Image is up to date for "+requestedTag)
}
}
@ -183,3 +185,16 @@ func validateRepoName(name string) error {
}
return nil
}
// tmpFileClose creates a closer function for a temporary file that closes the file
// and also deletes it.
func tmpFileCloser(tmpFile *os.File) func() error {
return func() error {
tmpFile.Close()
if err := os.RemoveAll(tmpFile.Name()); err != nil {
logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
}
return nil
}
}

View file

@ -1,43 +1,42 @@
package distribution
import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/url"
"strings"
"sync"
"time"
"github.com/Sirupsen/logrus"
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/client/transport"
"github.com/docker/docker/distribution/metadata"
"github.com/docker/docker/distribution/xfer"
"github.com/docker/docker/image"
"github.com/docker/docker/image/v1"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/archive"
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/pkg/stringid"
"github.com/docker/docker/registry"
"golang.org/x/net/context"
)
type v1Puller struct {
v1IDService *metadata.V1IDService
endpoint registry.APIEndpoint
config *ImagePullConfig
sf *streamformatter.StreamFormatter
repoInfo *registry.RepositoryInfo
session *registry.Session
}
func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) {
func (p *v1Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) {
if _, isDigested := ref.(reference.Digested); isDigested {
// Allowing fallback, because HTTPS v1 is before HTTP v2
return true, registry.ErrNoSupport{errors.New("Cannot pull by digest with v1 registry")}
return true, registry.ErrNoSupport{Err: errors.New("Cannot pull by digest with v1 registry")}
}
tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name)
@ -62,19 +61,17 @@ func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) {
logrus.Debugf("Fallback from error: %s", err)
return true, err
}
if err := p.pullRepository(ref); err != nil {
if err := p.pullRepository(ctx, ref); err != nil {
// TODO(dmcgowan): Check if should fallback
return false, err
}
out := p.config.OutStream
out.Write(p.sf.FormatStatus("", "%s: this image was pulled from a legacy registry. Important: This registry version will not be supported in future versions of docker.", p.repoInfo.CanonicalName.Name()))
progress.Message(p.config.ProgressOutput, "", p.repoInfo.CanonicalName.Name()+": this image was pulled from a legacy registry. Important: This registry version will not be supported in future versions of docker.")
return false, nil
}
func (p *v1Puller) pullRepository(ref reference.Named) error {
out := p.config.OutStream
out.Write(p.sf.FormatStatus("", "Pulling repository %s", p.repoInfo.CanonicalName.Name()))
func (p *v1Puller) pullRepository(ctx context.Context, ref reference.Named) error {
progress.Message(p.config.ProgressOutput, "", "Pulling repository "+p.repoInfo.CanonicalName.Name())
repoData, err := p.session.GetRepositoryData(p.repoInfo.RemoteName)
if err != nil {
@ -112,46 +109,18 @@ func (p *v1Puller) pullRepository(ref reference.Named) error {
}
}
errors := make(chan error)
layerDownloaded := make(chan struct{})
layersDownloaded := false
var wg sync.WaitGroup
for _, imgData := range repoData.ImgList {
if isTagged && imgData.Tag != tagged.Tag() {
continue
}
wg.Add(1)
go func(img *registry.ImgData) {
p.downloadImage(out, repoData, img, layerDownloaded, errors)
wg.Done()
}(imgData)
}
go func() {
wg.Wait()
close(errors)
}()
var lastError error
selectLoop:
for {
select {
case err, ok := <-errors:
if !ok {
break selectLoop
}
lastError = err
case <-layerDownloaded:
layersDownloaded = true
err := p.downloadImage(ctx, repoData, imgData, &layersDownloaded)
if err != nil {
return err
}
}
if lastError != nil {
return lastError
}
localNameRef := p.repoInfo.LocalName
if isTagged {
localNameRef, err = reference.WithTag(localNameRef, tagged.Tag())
@ -159,194 +128,143 @@ selectLoop:
localNameRef = p.repoInfo.LocalName
}
}
writeStatus(localNameRef.String(), out, p.sf, layersDownloaded)
writeStatus(localNameRef.String(), p.config.ProgressOutput, layersDownloaded)
return nil
}
func (p *v1Puller) downloadImage(out io.Writer, repoData *registry.RepositoryData, img *registry.ImgData, layerDownloaded chan struct{}, errors chan error) {
func (p *v1Puller) downloadImage(ctx context.Context, repoData *registry.RepositoryData, img *registry.ImgData, layersDownloaded *bool) error {
if img.Tag == "" {
logrus.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID)
return
return nil
}
localNameRef, err := reference.WithTag(p.repoInfo.LocalName, img.Tag)
if err != nil {
retErr := fmt.Errorf("Image (id: %s) has invalid tag: %s", img.ID, img.Tag)
logrus.Debug(retErr.Error())
errors <- retErr
return retErr
}
if err := v1.ValidateID(img.ID); err != nil {
errors <- err
return
return err
}
out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name()), nil))
progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name())
success := false
var lastErr error
var isDownloaded bool
for _, ep := range p.repoInfo.Index.Mirrors {
ep += "v1/"
out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil))
if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil {
progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep))
if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil {
// Don't report errors when pulling from mirrors.
logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err)
continue
}
if isDownloaded {
layerDownloaded <- struct{}{}
}
success = true
break
}
if !success {
for _, ep := range repoData.Endpoints {
out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil))
if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil {
progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep)
if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil {
// It's not ideal that only the last error is returned, it would be better to concatenate the errors.
// As the error is also given to the output stream the user will see the error.
lastErr = err
out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err), nil))
progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err)
continue
}
if isDownloaded {
layerDownloaded <- struct{}{}
}
success = true
break
}
}
if !success {
err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName.Name(), lastErr)
out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
errors <- err
return
progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), err.Error())
return err
}
out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Download complete")
return nil
}
func (p *v1Puller) pullImage(out io.Writer, v1ID, endpoint string, localNameRef reference.Named) (layersDownloaded bool, err error) {
func (p *v1Puller) pullImage(ctx context.Context, v1ID, endpoint string, localNameRef reference.Named, layersDownloaded *bool) (err error) {
var history []string
history, err = p.session.GetRemoteHistory(v1ID, endpoint)
if err != nil {
return false, err
return err
}
if len(history) < 1 {
return false, fmt.Errorf("empty history for image %s", v1ID)
return fmt.Errorf("empty history for image %s", v1ID)
}
out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pulling dependent layers", nil))
// FIXME: Try to stream the images?
// FIXME: Launch the getRemoteImage() in goroutines
progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pulling dependent layers")
var (
referencedLayers []layer.Layer
parentID layer.ChainID
newHistory []image.History
img *image.V1Image
imgJSON []byte
imgSize int64
descriptors []xfer.DownloadDescriptor
newHistory []image.History
imgJSON []byte
imgSize int64
)
defer func() {
for _, l := range referencedLayers {
layer.ReleaseAndLog(p.config.LayerStore, l)
}
}()
layersDownloaded = false
// Iterate over layers from top-most to bottom-most, checking if any
// already exist on disk.
var i int
for i = 0; i != len(history); i++ {
v1LayerID := history[i]
// Do we have a mapping for this particular v1 ID on this
// registry?
if layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name); err == nil {
// Does the layer actually exist
if l, err := p.config.LayerStore.Get(layerID); err == nil {
for j := i; j >= 0; j-- {
logrus.Debugf("Layer already exists: %s", history[j])
out.Write(p.sf.FormatProgress(stringid.TruncateID(history[j]), "Already exists", nil))
}
referencedLayers = append(referencedLayers, l)
parentID = layerID
break
}
}
}
needsDownload := i
// Iterate over layers, in order from bottom-most to top-most. Download
// config for all layers, and download actual layer data if needed.
for i = len(history) - 1; i >= 0; i-- {
// config for all layers and create descriptors.
for i := len(history) - 1; i >= 0; i-- {
v1LayerID := history[i]
imgJSON, imgSize, err = p.downloadLayerConfig(out, v1LayerID, endpoint)
imgJSON, imgSize, err = p.downloadLayerConfig(v1LayerID, endpoint)
if err != nil {
return layersDownloaded, err
}
img = &image.V1Image{}
if err := json.Unmarshal(imgJSON, img); err != nil {
return layersDownloaded, err
}
if i < needsDownload {
l, err := p.downloadLayer(out, v1LayerID, endpoint, parentID, imgSize, &layersDownloaded)
// Note: This needs to be done even in the error case to avoid
// stale references to the layer.
if l != nil {
referencedLayers = append(referencedLayers, l)
}
if err != nil {
return layersDownloaded, err
}
parentID = l.ChainID()
return err
}
// Create a new-style config from the legacy configs
h, err := v1.HistoryFromConfig(imgJSON, false)
if err != nil {
return layersDownloaded, err
return err
}
newHistory = append(newHistory, h)
layerDescriptor := &v1LayerDescriptor{
v1LayerID: v1LayerID,
indexName: p.repoInfo.Index.Name,
endpoint: endpoint,
v1IDService: p.v1IDService,
layersDownloaded: layersDownloaded,
layerSize: imgSize,
session: p.session,
}
descriptors = append(descriptors, layerDescriptor)
}
rootFS := image.NewRootFS()
l := referencedLayers[len(referencedLayers)-1]
for l != nil {
rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...)
l = l.Parent()
}
config, err := v1.MakeConfigFromV1Config(imgJSON, rootFS, newHistory)
resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput)
if err != nil {
return layersDownloaded, err
return err
}
defer release()
config, err := v1.MakeConfigFromV1Config(imgJSON, &resultRootFS, newHistory)
if err != nil {
return err
}
imageID, err := p.config.ImageStore.Create(config)
if err != nil {
return layersDownloaded, err
return err
}
if err := p.config.TagStore.AddTag(localNameRef, imageID, true); err != nil {
return layersDownloaded, err
return err
}
return layersDownloaded, nil
return nil
}
func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) {
out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Pulling metadata", nil))
func (p *v1Puller) downloadLayerConfig(v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) {
progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Pulling metadata")
retries := 5
for j := 1; j <= retries; j++ {
imgJSON, imgSize, err := p.session.GetRemoteImageJSON(v1LayerID, endpoint)
if err != nil && j == retries {
out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling layer metadata", nil))
progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Error pulling layer metadata")
return nil, 0, err
} else if err != nil {
time.Sleep(time.Duration(j) * 500 * time.Millisecond)
@ -360,95 +278,66 @@ func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string
return nil, 0, nil
}
func (p *v1Puller) downloadLayer(out io.Writer, v1LayerID, endpoint string, parentID layer.ChainID, layerSize int64, layersDownloaded *bool) (l layer.Layer, err error) {
// ensure no two downloads of the same layer happen at the same time
poolKey := "layer:" + v1LayerID
broadcaster, found := p.config.Pool.add(poolKey)
broadcaster.Add(out)
if found {
logrus.Debugf("Image (id: %s) pull is already running, skipping", v1LayerID)
if err = broadcaster.Wait(); err != nil {
return nil, err
}
layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name)
if err != nil {
return nil, err
}
// Does the layer actually exist
l, err := p.config.LayerStore.Get(layerID)
if err != nil {
return nil, err
}
return l, nil
}
type v1LayerDescriptor struct {
v1LayerID string
indexName string
endpoint string
v1IDService *metadata.V1IDService
layersDownloaded *bool
layerSize int64
session *registry.Session
}
// This must use a closure so it captures the value of err when
// the function returns, not when the 'defer' is evaluated.
defer func() {
p.config.Pool.removeWithError(poolKey, err)
}()
func (ld *v1LayerDescriptor) Key() string {
return "v1:" + ld.v1LayerID
}
retries := 5
for j := 1; j <= retries; j++ {
// Get the layer
status := "Pulling fs layer"
if j > 1 {
status = fmt.Sprintf("Pulling fs layer [retries: %d]", j)
}
broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), status, nil))
layerReader, err := p.session.GetRemoteImageLayer(v1LayerID, endpoint, layerSize)
func (ld *v1LayerDescriptor) ID() string {
return stringid.TruncateID(ld.v1LayerID)
}
func (ld *v1LayerDescriptor) DiffID() (layer.DiffID, error) {
return ld.v1IDService.Get(ld.v1LayerID, ld.indexName)
}
func (ld *v1LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
progress.Update(progressOutput, ld.ID(), "Pulling fs layer")
layerReader, err := ld.session.GetRemoteImageLayer(ld.v1LayerID, ld.endpoint, ld.layerSize)
if err != nil {
progress.Update(progressOutput, ld.ID(), "Error pulling dependent layers")
if uerr, ok := err.(*url.Error); ok {
err = uerr.Err
}
if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries {
time.Sleep(time.Duration(j) * 500 * time.Millisecond)
continue
} else if err != nil {
broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling dependent layers", nil))
return nil, err
if terr, ok := err.(net.Error); ok && terr.Timeout() {
return nil, 0, err
}
*layersDownloaded = true
defer layerReader.Close()
return nil, 0, xfer.DoNotRetry{Err: err}
}
*ld.layersDownloaded = true
reader := progressreader.New(progressreader.Config{
In: layerReader,
Out: broadcaster,
Formatter: p.sf,
Size: layerSize,
NewLines: false,
ID: stringid.TruncateID(v1LayerID),
Action: "Downloading",
})
inflatedLayerData, err := archive.DecompressStream(reader)
if err != nil {
return nil, fmt.Errorf("could not get decompression stream: %v", err)
}
l, err := p.config.LayerStore.Register(inflatedLayerData, parentID)
if err != nil {
return nil, fmt.Errorf("failed to register layer: %v", err)
}
logrus.Debugf("layer %s registered successfully", l.DiffID())
if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries {
time.Sleep(time.Duration(j) * 500 * time.Millisecond)
continue
} else if err != nil {
broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error downloading dependent layers", nil))
return nil, err
}
// Cache mapping from this v1 ID to content-addressable layer ID
if err := p.v1IDService.Set(v1LayerID, p.repoInfo.Index.Name, l.ChainID()); err != nil {
return nil, err
}
broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Download complete", nil))
broadcaster.Close()
return l, nil
tmpFile, err := ioutil.TempFile("", "GetImageBlob")
if err != nil {
layerReader.Close()
return nil, 0, err
}
// not reached
return nil, nil
reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerReader), progressOutput, ld.layerSize, ld.ID(), "Downloading")
defer reader.Close()
_, err = io.Copy(tmpFile, reader)
if err != nil {
return nil, 0, err
}
progress.Update(progressOutput, ld.ID(), "Download complete")
logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name())
tmpFile.Seek(0, 0)
return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), ld.layerSize, nil
}
func (ld *v1LayerDescriptor) Registered(diffID layer.DiffID) {
// Cache mapping from this layer's DiffID to the blobsum
ld.v1IDService.Set(ld.v1LayerID, ld.indexName, diffID)
}

View file

@ -15,13 +15,12 @@ import (
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/reference"
"github.com/docker/docker/distribution/metadata"
"github.com/docker/docker/distribution/xfer"
"github.com/docker/docker/image"
"github.com/docker/docker/image/v1"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/archive"
"github.com/docker/docker/pkg/broadcaster"
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/pkg/stringid"
"github.com/docker/docker/registry"
"golang.org/x/net/context"
@ -31,23 +30,19 @@ type v2Puller struct {
blobSumService *metadata.BlobSumService
endpoint registry.APIEndpoint
config *ImagePullConfig
sf *streamformatter.StreamFormatter
repoInfo *registry.RepositoryInfo
repo distribution.Repository
sessionID string
}
func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) {
func (p *v2Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) {
// TODO(tiborvass): was ReceiveTimeout
p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull")
if err != nil {
logrus.Debugf("Error getting v2 registry: %v", err)
logrus.Warnf("Error getting v2 registry: %v", err)
return true, err
}
p.sessionID = stringid.GenerateRandomID()
if err := p.pullV2Repository(ref); err != nil {
if err := p.pullV2Repository(ctx, ref); err != nil {
if registry.ContinueOnError(err) {
logrus.Debugf("Error trying v2 registry: %v", err)
return true, err
@ -57,7 +52,7 @@ func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) {
return false, nil
}
func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
func (p *v2Puller) pullV2Repository(ctx context.Context, ref reference.Named) (err error) {
var refs []reference.Named
taggedName := p.repoInfo.LocalName
if tagged, isTagged := ref.(reference.Tagged); isTagged {
@ -73,7 +68,7 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
}
refs = []reference.Named{taggedName}
} else {
manSvc, err := p.repo.Manifests(context.Background())
manSvc, err := p.repo.Manifests(ctx)
if err != nil {
return err
}
@ -98,98 +93,109 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
for _, pullRef := range refs {
// pulledNew is true if either new layers were downloaded OR if existing images were newly tagged
// TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus?
pulledNew, err := p.pullV2Tag(p.config.OutStream, pullRef)
pulledNew, err := p.pullV2Tag(ctx, pullRef)
if err != nil {
return err
}
layersDownloaded = layersDownloaded || pulledNew
}
writeStatus(taggedName.String(), p.config.OutStream, p.sf, layersDownloaded)
writeStatus(taggedName.String(), p.config.ProgressOutput, layersDownloaded)
return nil
}
// downloadInfo is used to pass information from download to extractor
type downloadInfo struct {
tmpFile *os.File
digest digest.Digest
layer distribution.ReadSeekCloser
size int64
err chan error
poolKey string
broadcaster *broadcaster.Buffered
type v2LayerDescriptor struct {
digest digest.Digest
repo distribution.Repository
blobSumService *metadata.BlobSumService
}
type errVerification struct{}
func (ld *v2LayerDescriptor) Key() string {
return "v2:" + ld.digest.String()
}
func (errVerification) Error() string { return "verification failed" }
func (ld *v2LayerDescriptor) ID() string {
return stringid.TruncateID(ld.digest.String())
}
func (p *v2Puller) download(di *downloadInfo) {
logrus.Debugf("pulling blob %q", di.digest)
func (ld *v2LayerDescriptor) DiffID() (layer.DiffID, error) {
return ld.blobSumService.GetDiffID(ld.digest)
}
blobs := p.repo.Blobs(context.Background())
func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
logrus.Debugf("pulling blob %q", ld.digest)
layerDownload, err := blobs.Open(context.Background(), di.digest)
blobs := ld.repo.Blobs(ctx)
layerDownload, err := blobs.Open(ctx, ld.digest)
if err != nil {
logrus.Debugf("Error fetching layer: %v", err)
di.err <- err
return
logrus.Debugf("Error statting layer: %v", err)
if err == distribution.ErrBlobUnknown {
return nil, 0, xfer.DoNotRetry{Err: err}
}
return nil, 0, retryOnError(err)
}
defer layerDownload.Close()
di.size, err = layerDownload.Seek(0, os.SEEK_END)
size, err := layerDownload.Seek(0, os.SEEK_END)
if err != nil {
// Seek failed, perhaps because there was no Content-Length
// header. This shouldn't fail the download, because we can
// still continue without a progress bar.
di.size = 0
size = 0
} else {
// Restore the seek offset at the beginning of the stream.
_, err = layerDownload.Seek(0, os.SEEK_SET)
if err != nil {
di.err <- err
return
return nil, 0, err
}
}
verifier, err := digest.NewDigestVerifier(di.digest)
reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerDownload), progressOutput, size, ld.ID(), "Downloading")
defer reader.Close()
verifier, err := digest.NewDigestVerifier(ld.digest)
if err != nil {
di.err <- err
return
return nil, 0, xfer.DoNotRetry{Err: err}
}
digestStr := di.digest.String()
tmpFile, err := ioutil.TempFile("", "GetImageBlob")
if err != nil {
return nil, 0, xfer.DoNotRetry{Err: err}
}
reader := progressreader.New(progressreader.Config{
In: ioutil.NopCloser(io.TeeReader(layerDownload, verifier)),
Out: di.broadcaster,
Formatter: p.sf,
Size: di.size,
NewLines: false,
ID: stringid.TruncateID(digestStr),
Action: "Downloading",
})
io.Copy(di.tmpFile, reader)
_, err = io.Copy(tmpFile, io.TeeReader(reader, verifier))
if err != nil {
return nil, 0, retryOnError(err)
}
di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Verifying Checksum", nil))
progress.Update(progressOutput, ld.ID(), "Verifying Checksum")
if !verifier.Verified() {
err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest)
err = fmt.Errorf("filesystem layer verification failed for digest %s", ld.digest)
logrus.Error(err)
di.err <- err
return
tmpFile.Close()
if err := os.RemoveAll(tmpFile.Name()); err != nil {
logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
}
return nil, 0, xfer.DoNotRetry{Err: err}
}
di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Download complete", nil))
progress.Update(progressOutput, ld.ID(), "Download complete")
logrus.Debugf("Downloaded %s to tempfile %s", digestStr, di.tmpFile.Name())
di.layer = layerDownload
logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name())
di.err <- nil
tmpFile.Seek(0, 0)
return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), size, nil
}
func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated bool, err error) {
func (ld *v2LayerDescriptor) Registered(diffID layer.DiffID) {
// Cache mapping from this layer's DiffID to the blobsum
ld.blobSumService.Add(diffID, ld.digest)
}
func (p *v2Puller) pullV2Tag(ctx context.Context, ref reference.Named) (tagUpdated bool, err error) {
tagOrDigest := ""
if tagged, isTagged := ref.(reference.Tagged); isTagged {
tagOrDigest = tagged.Tag()
@ -201,7 +207,7 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
logrus.Debugf("Pulling ref from V2 registry: %q", tagOrDigest)
manSvc, err := p.repo.Manifests(context.Background())
manSvc, err := p.repo.Manifests(ctx)
if err != nil {
return false, err
}
@ -231,33 +237,17 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
return false, err
}
out.Write(p.sf.FormatStatus(tagOrDigest, "Pulling from %s", p.repo.Name()))
progress.Message(p.config.ProgressOutput, tagOrDigest, "Pulling from "+p.repo.Name())
var downloads []*downloadInfo
defer func() {
for _, d := range downloads {
p.config.Pool.removeWithError(d.poolKey, err)
if d.tmpFile != nil {
d.tmpFile.Close()
if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name())
}
}
}
}()
var descriptors []xfer.DownloadDescriptor
// Image history converted to the new format
var history []image.History
poolKey := "v2layer:"
notFoundLocally := false
// Note that the order of this loop is in the direction of bottom-most
// to top-most, so that the downloads slice gets ordered correctly.
for i := len(verifiedManifest.FSLayers) - 1; i >= 0; i-- {
blobSum := verifiedManifest.FSLayers[i].BlobSum
poolKey += blobSum.String()
var throwAway struct {
ThrowAway bool `json:"throwaway,omitempty"`
@ -276,119 +266,22 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
continue
}
// Do we have a layer on disk corresponding to the set of
// blobsums up to this point?
if !notFoundLocally {
notFoundLocally = true
diffID, err := p.blobSumService.GetDiffID(blobSum)
if err == nil {
rootFS.Append(diffID)
if l, err := p.config.LayerStore.Get(rootFS.ChainID()); err == nil {
notFoundLocally = false
logrus.Debugf("Layer already exists: %s", blobSum.String())
out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Already exists", nil))
defer layer.ReleaseAndLog(p.config.LayerStore, l)
continue
} else {
rootFS.DiffIDs = rootFS.DiffIDs[:len(rootFS.DiffIDs)-1]
}
}
layerDescriptor := &v2LayerDescriptor{
digest: blobSum,
repo: p.repo,
blobSumService: p.blobSumService,
}
out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Pulling fs layer", nil))
tmpFile, err := ioutil.TempFile("", "GetImageBlob")
if err != nil {
return false, err
}
d := &downloadInfo{
poolKey: poolKey,
digest: blobSum,
tmpFile: tmpFile,
// TODO: seems like this chan buffer solved hanging problem in go1.5,
// this can indicate some deeper problem that somehow we never take
// error from channel in loop below
err: make(chan error, 1),
}
downloads = append(downloads, d)
broadcaster, found := p.config.Pool.add(d.poolKey)
broadcaster.Add(out)
d.broadcaster = broadcaster
if found {
d.err <- nil
} else {
go p.download(d)
}
descriptors = append(descriptors, layerDescriptor)
}
for _, d := range downloads {
if err := <-d.err; err != nil {
return false, err
}
if d.layer == nil {
// Wait for a different pull to download and extract
// this layer.
err = d.broadcaster.Wait()
if err != nil {
return false, err
}
diffID, err := p.blobSumService.GetDiffID(d.digest)
if err != nil {
return false, err
}
rootFS.Append(diffID)
l, err := p.config.LayerStore.Get(rootFS.ChainID())
if err != nil {
return false, err
}
defer layer.ReleaseAndLog(p.config.LayerStore, l)
continue
}
d.tmpFile.Seek(0, 0)
reader := progressreader.New(progressreader.Config{
In: d.tmpFile,
Out: d.broadcaster,
Formatter: p.sf,
Size: d.size,
NewLines: false,
ID: stringid.TruncateID(d.digest.String()),
Action: "Extracting",
})
inflatedLayerData, err := archive.DecompressStream(reader)
if err != nil {
return false, fmt.Errorf("could not get decompression stream: %v", err)
}
l, err := p.config.LayerStore.Register(inflatedLayerData, rootFS.ChainID())
if err != nil {
return false, fmt.Errorf("failed to register layer: %v", err)
}
logrus.Debugf("layer %s registered successfully", l.DiffID())
rootFS.Append(l.DiffID())
// Cache mapping from this layer's DiffID to the blobsum
if err := p.blobSumService.Add(l.DiffID(), d.digest); err != nil {
return false, err
}
defer layer.ReleaseAndLog(p.config.LayerStore, l)
d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.digest.String()), "Pull complete", nil))
d.broadcaster.Close()
tagUpdated = true
resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput)
if err != nil {
return false, err
}
defer release()
config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), rootFS, history)
config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), &resultRootFS, history)
if err != nil {
return false, err
}
@ -403,30 +296,24 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
return false, err
}
// Check for new tag if no layers downloaded
var oldTagImageID image.ID
if !tagUpdated {
oldTagImageID, err = p.config.TagStore.Get(ref)
if err != nil || oldTagImageID != imageID {
tagUpdated = true
}
if manifestDigest != "" {
progress.Message(p.config.ProgressOutput, "", "Digest: "+manifestDigest.String())
}
if tagUpdated {
if canonical, ok := ref.(reference.Canonical); ok {
if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil {
return false, err
}
} else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil {
oldTagImageID, err := p.config.TagStore.Get(ref)
if err == nil && oldTagImageID == imageID {
return false, nil
}
if canonical, ok := ref.(reference.Canonical); ok {
if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil {
return false, err
}
} else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil {
return false, err
}
if manifestDigest != "" {
out.Write(p.sf.FormatStatus("", "Digest: %s", manifestDigest))
}
return tagUpdated, nil
return true, nil
}
func verifyManifest(signedManifest *schema1.SignedManifest, ref reference.Reference) (m *schema1.Manifest, err error) {

View file

@ -12,12 +12,14 @@ import (
"github.com/docker/docker/cliconfig"
"github.com/docker/docker/daemon/events"
"github.com/docker/docker/distribution/metadata"
"github.com/docker/docker/distribution/xfer"
"github.com/docker/docker/image"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/registry"
"github.com/docker/docker/tag"
"github.com/docker/libtrust"
"golang.org/x/net/context"
)
// ImagePushConfig stores push configuration.
@ -28,9 +30,9 @@ type ImagePushConfig struct {
// AuthConfig holds authentication credentials for authenticating with
// the registry.
AuthConfig *cliconfig.AuthConfig
// OutStream is the output writer for showing the status of the push
// ProgressOutput is the interface for showing the status of the push
// operation.
OutStream io.Writer
ProgressOutput progress.Output
// RegistryService is the registry service to use for TLS configuration
// and endpoint lookup.
RegistryService *registry.Service
@ -48,6 +50,8 @@ type ImagePushConfig struct {
// TrustKey is the private key for legacy signatures. This is typically
// an ephemeral key, since these signatures are no longer verified.
TrustKey libtrust.PrivateKey
// UploadManager dispatches uploads.
UploadManager *xfer.LayerUploadManager
}
// Pusher is an interface that abstracts pushing for different API versions.
@ -56,7 +60,7 @@ type Pusher interface {
// Push returns an error if any, as well as a boolean that determines whether to retry Push on the next configured endpoint.
//
// TODO(tiborvass): have Push() take a reference to repository + tag, so that the pusher itself is repository-agnostic.
Push() (fallback bool, err error)
Push(ctx context.Context) (fallback bool, err error)
}
const compressionBufSize = 32768
@ -66,7 +70,7 @@ const compressionBufSize = 32768
// whether a v1 or v2 pusher will be created. The other parameters are passed
// through to the underlying pusher implementation for use during the actual
// push operation.
func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig, sf *streamformatter.StreamFormatter) (Pusher, error) {
func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig) (Pusher, error) {
switch endpoint.Version {
case registry.APIVersion2:
return &v2Pusher{
@ -75,8 +79,7 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
endpoint: endpoint,
repoInfo: repoInfo,
config: imagePushConfig,
sf: sf,
layersPushed: make(map[digest.Digest]bool),
layersPushed: pushMap{layersPushed: make(map[digest.Digest]bool)},
}, nil
case registry.APIVersion1:
return &v1Pusher{
@ -85,7 +88,6 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
endpoint: endpoint,
repoInfo: repoInfo,
config: imagePushConfig,
sf: sf,
}, nil
}
return nil, fmt.Errorf("unknown version %d for registry %s", endpoint.Version, endpoint.URL)
@ -94,11 +96,9 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
// Push initiates a push operation on the repository named localName.
// ref is the specific variant of the image to be pushed.
// If no tag is provided, all tags will be pushed.
func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
func Push(ctx context.Context, ref reference.Named, imagePushConfig *ImagePushConfig) error {
// FIXME: Allow to interrupt current push when new push of same image is done.
var sf = streamformatter.NewJSONStreamFormatter()
// Resolve the Repository name from fqn to RepositoryInfo
repoInfo, err := imagePushConfig.RegistryService.ResolveRepository(ref)
if err != nil {
@ -110,7 +110,7 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
return err
}
imagePushConfig.OutStream.Write(sf.FormatStatus("", "The push refers to a repository [%s]", repoInfo.CanonicalName))
progress.Messagef(imagePushConfig.ProgressOutput, "", "The push refers to a repository [%s]", repoInfo.CanonicalName.String())
associations := imagePushConfig.TagStore.ReferencesByName(repoInfo.LocalName)
if len(associations) == 0 {
@ -121,12 +121,20 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
for _, endpoint := range endpoints {
logrus.Debugf("Trying to push %s to %s %s", repoInfo.CanonicalName, endpoint.URL, endpoint.Version)
pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig, sf)
pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig)
if err != nil {
lastErr = err
continue
}
if fallback, err := pusher.Push(); err != nil {
if fallback, err := pusher.Push(ctx); err != nil {
// Was this push cancelled? If so, don't try to fall
// back.
select {
case <-ctx.Done():
fallback = false
default:
}
if fallback {
lastErr = err
continue

View file

@ -2,8 +2,6 @@ package distribution
import (
"fmt"
"io"
"io/ioutil"
"sync"
"github.com/Sirupsen/logrus"
@ -15,25 +13,23 @@ import (
"github.com/docker/docker/image/v1"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/pkg/stringid"
"github.com/docker/docker/registry"
"golang.org/x/net/context"
)
type v1Pusher struct {
ctx context.Context
v1IDService *metadata.V1IDService
endpoint registry.APIEndpoint
ref reference.Named
repoInfo *registry.RepositoryInfo
config *ImagePushConfig
sf *streamformatter.StreamFormatter
session *registry.Session
out io.Writer
}
func (p *v1Pusher) Push() (fallback bool, err error) {
func (p *v1Pusher) Push(ctx context.Context) (fallback bool, err error) {
tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name)
if err != nil {
return false, err
@ -55,7 +51,7 @@ func (p *v1Pusher) Push() (fallback bool, err error) {
// TODO(dmcgowan): Check if should fallback
return true, err
}
if err := p.pushRepository(); err != nil {
if err := p.pushRepository(ctx); err != nil {
// TODO(dmcgowan): Check if should fallback
return false, err
}
@ -306,12 +302,12 @@ func (p *v1Pusher) lookupImageOnEndpoint(wg *sync.WaitGroup, endpoint string, im
logrus.Errorf("Error in LookupRemoteImage: %s", err)
imagesToPush <- v1ID
} else {
p.out.Write(p.sf.FormatStatus("", "Image %s already pushed, skipping", stringid.TruncateID(v1ID)))
progress.Messagef(p.config.ProgressOutput, "", "Image %s already pushed, skipping", stringid.TruncateID(v1ID))
}
}
}
func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error {
func (p *v1Pusher) pushImageToEndpoint(ctx context.Context, endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error {
workerCount := len(imageList)
// start a maximum of 5 workers to check if images exist on the specified endpoint.
if workerCount > 5 {
@ -349,14 +345,14 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag
for _, img := range imageList {
v1ID := img.V1ID()
if _, push := shouldPush[v1ID]; push {
if _, err := p.pushImage(img, endpoint); err != nil {
if _, err := p.pushImage(ctx, img, endpoint); err != nil {
// FIXME: Continue on error?
return err
}
}
if topImage, isTopImage := img.(*v1TopImage); isTopImage {
for _, tag := range tags[topImage.imageID] {
p.out.Write(p.sf.FormatStatus("", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag))
progress.Messagef(p.config.ProgressOutput, "", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag)
if err := p.session.PushRegistryTag(p.repoInfo.RemoteName, v1ID, tag, endpoint); err != nil {
return err
}
@ -367,8 +363,7 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag
}
// pushRepository pushes layers that do not already exist on the registry.
func (p *v1Pusher) pushRepository() error {
p.out = ioutils.NewWriteFlusher(p.config.OutStream)
func (p *v1Pusher) pushRepository(ctx context.Context) error {
imgList, tags, referencedLayers, err := p.getImageList()
defer func() {
for _, l := range referencedLayers {
@ -378,7 +373,7 @@ func (p *v1Pusher) pushRepository() error {
if err != nil {
return err
}
p.out.Write(p.sf.FormatStatus("", "Sending image list"))
progress.Message(p.config.ProgressOutput, "", "Sending image list")
imageIndex := createImageIndex(imgList, tags)
for _, data := range imageIndex {
@ -391,10 +386,10 @@ func (p *v1Pusher) pushRepository() error {
if err != nil {
return err
}
p.out.Write(p.sf.FormatStatus("", "Pushing repository %s", p.repoInfo.CanonicalName))
progress.Message(p.config.ProgressOutput, "", "Pushing repository "+p.repoInfo.CanonicalName.String())
// push the repository to each of the endpoints only if it does not exist.
for _, endpoint := range repoData.Endpoints {
if err := p.pushImageToEndpoint(endpoint, imgList, tags, repoData); err != nil {
if err := p.pushImageToEndpoint(ctx, endpoint, imgList, tags, repoData); err != nil {
return err
}
}
@ -402,11 +397,11 @@ func (p *v1Pusher) pushRepository() error {
return err
}
func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err error) {
func (p *v1Pusher) pushImage(ctx context.Context, v1Image v1Image, ep string) (checksum string, err error) {
v1ID := v1Image.V1ID()
jsonRaw := v1Image.Config()
p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pushing", nil))
progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pushing")
// General rule is to use ID for graph accesses and compatibilityID for
// calls to session.registry()
@ -417,7 +412,7 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
// Send the json
if err := p.session.PushImageJSONRegistry(imgData, jsonRaw, ep); err != nil {
if err == registry.ErrAlreadyExists {
p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image already pushed, skipping", nil))
progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image already pushed, skipping")
return "", nil
}
return "", err
@ -437,15 +432,8 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
// Send the layer
logrus.Debugf("rendered layer for %s of [%d] size", v1ID, size)
reader := progressreader.New(progressreader.Config{
In: ioutil.NopCloser(arch),
Out: p.out,
Formatter: p.sf,
Size: size,
NewLines: false,
ID: stringid.TruncateID(v1ID),
Action: "Pushing",
})
reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), p.config.ProgressOutput, size, stringid.TruncateID(v1ID), "Pushing")
defer reader.Close()
checksum, checksumPayload, err := p.session.PushImageLayerRegistry(v1ID, reader, ep, jsonRaw)
if err != nil {
@ -458,10 +446,10 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
return "", err
}
if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.ChainID()); err != nil {
if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.DiffID()); err != nil {
logrus.Warnf("Could not set v1 ID mapping: %v", err)
}
p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image successfully pushed", nil))
progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image successfully pushed")
return imgData.Checksum, nil
}

View file

@ -5,7 +5,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"sync"
"time"
"github.com/Sirupsen/logrus"
@ -15,11 +15,12 @@ import (
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/reference"
"github.com/docker/docker/distribution/metadata"
"github.com/docker/docker/distribution/xfer"
"github.com/docker/docker/image"
"github.com/docker/docker/image/v1"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/progress"
"github.com/docker/docker/pkg/stringid"
"github.com/docker/docker/registry"
"github.com/docker/docker/tag"
@ -32,16 +33,20 @@ type v2Pusher struct {
endpoint registry.APIEndpoint
repoInfo *registry.RepositoryInfo
config *ImagePushConfig
sf *streamformatter.StreamFormatter
repo distribution.Repository
// layersPushed is the set of layers known to exist on the remote side.
// This avoids redundant queries when pushing multiple tags that
// involve the same layers.
layersPushed pushMap
}
type pushMap struct {
sync.Mutex
layersPushed map[digest.Digest]bool
}
func (p *v2Pusher) Push() (fallback bool, err error) {
func (p *v2Pusher) Push(ctx context.Context) (fallback bool, err error) {
p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "push", "pull")
if err != nil {
logrus.Debugf("Error getting v2 registry: %v", err)
@ -75,7 +80,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) {
}
for _, association := range associations {
if err := p.pushV2Tag(association); err != nil {
if err := p.pushV2Tag(ctx, association); err != nil {
return false, err
}
}
@ -83,7 +88,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) {
return false, nil
}
func (p *v2Pusher) pushV2Tag(association tag.Association) error {
func (p *v2Pusher) pushV2Tag(ctx context.Context, association tag.Association) error {
ref := association.Ref
logrus.Debugf("Pushing repository: %s", ref.String())
@ -92,8 +97,6 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
return fmt.Errorf("could not find image from tag %s: %v", ref.String(), err)
}
out := p.config.OutStream
var l layer.Layer
topLayerID := img.RootFS.ChainID()
@ -107,33 +110,41 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
defer layer.ReleaseAndLog(p.config.LayerStore, l)
}
fsLayers := make(map[layer.DiffID]schema1.FSLayer)
var descriptors []xfer.UploadDescriptor
// Push empty layer if necessary
for _, h := range img.History {
if h.EmptyLayer {
dgst, err := p.pushLayerIfNecessary(out, layer.EmptyLayer)
if err != nil {
return err
descriptors = []xfer.UploadDescriptor{
&v2PushDescriptor{
layer: layer.EmptyLayer,
blobSumService: p.blobSumService,
repo: p.repo,
layersPushed: &p.layersPushed,
},
}
p.layersPushed[dgst] = true
fsLayers[layer.EmptyLayer.DiffID()] = schema1.FSLayer{BlobSum: dgst}
break
}
}
// Loop bounds condition is to avoid pushing the base layer on Windows.
for i := 0; i < len(img.RootFS.DiffIDs); i++ {
dgst, err := p.pushLayerIfNecessary(out, l)
if err != nil {
return err
descriptor := &v2PushDescriptor{
layer: l,
blobSumService: p.blobSumService,
repo: p.repo,
layersPushed: &p.layersPushed,
}
p.layersPushed[dgst] = true
fsLayers[l.DiffID()] = schema1.FSLayer{BlobSum: dgst}
descriptors = append(descriptors, descriptor)
l = l.Parent()
}
fsLayers, err := p.config.UploadManager.Upload(ctx, descriptors, p.config.ProgressOutput)
if err != nil {
return err
}
var tag string
if tagged, isTagged := ref.(reference.Tagged); isTagged {
tag = tagged.Tag()
@ -157,59 +168,124 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
if tagged, isTagged := ref.(reference.Tagged); isTagged {
// NOTE: do not change this format without first changing the trust client
// code. This information is used to determine what was pushed and should be signed.
out.Write(p.sf.FormatStatus("", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize))
progress.Messagef(p.config.ProgressOutput, "", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize)
}
}
manSvc, err := p.repo.Manifests(context.Background())
manSvc, err := p.repo.Manifests(ctx)
if err != nil {
return err
}
return manSvc.Put(signed)
}
func (p *v2Pusher) pushLayerIfNecessary(out io.Writer, l layer.Layer) (digest.Digest, error) {
logrus.Debugf("Pushing layer: %s", l.DiffID())
type v2PushDescriptor struct {
layer layer.Layer
blobSumService *metadata.BlobSumService
repo distribution.Repository
layersPushed *pushMap
}
func (pd *v2PushDescriptor) Key() string {
return "v2push:" + pd.repo.Name() + " " + pd.layer.DiffID().String()
}
func (pd *v2PushDescriptor) ID() string {
return stringid.TruncateID(pd.layer.DiffID().String())
}
func (pd *v2PushDescriptor) DiffID() layer.DiffID {
return pd.layer.DiffID()
}
func (pd *v2PushDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) {
diffID := pd.DiffID()
logrus.Debugf("Pushing layer: %s", diffID)
// Do we have any blobsums associated with this layer's DiffID?
possibleBlobsums, err := p.blobSumService.GetBlobSums(l.DiffID())
possibleBlobsums, err := pd.blobSumService.GetBlobSums(diffID)
if err == nil {
dgst, exists, err := p.blobSumAlreadyExists(possibleBlobsums)
dgst, exists, err := blobSumAlreadyExists(ctx, possibleBlobsums, pd.repo, pd.layersPushed)
if err != nil {
out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Image push failed", nil))
return "", err
progress.Update(progressOutput, pd.ID(), "Image push failed")
return "", retryOnError(err)
}
if exists {
out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Layer already exists", nil))
progress.Update(progressOutput, pd.ID(), "Layer already exists")
return dgst, nil
}
}
// if digest was empty or not saved, or if blob does not exist on the remote repository,
// then push the blob.
pushDigest, err := p.pushV2Layer(p.repo.Blobs(context.Background()), l)
bs := pd.repo.Blobs(ctx)
// Send the layer
layerUpload, err := bs.Create(ctx)
if err != nil {
return "", err
return "", retryOnError(err)
}
defer layerUpload.Close()
arch, err := pd.layer.TarStream()
if err != nil {
return "", xfer.DoNotRetry{Err: err}
}
// don't care if this fails; best effort
size, _ := pd.layer.DiffSize()
reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), progressOutput, size, pd.ID(), "Pushing")
defer reader.Close()
compressedReader := compress(reader)
digester := digest.Canonical.New()
tee := io.TeeReader(compressedReader, digester.Hash())
nn, err := layerUpload.ReadFrom(tee)
compressedReader.Close()
if err != nil {
return "", retryOnError(err)
}
pushDigest := digester.Digest()
if _, err := layerUpload.Commit(ctx, distribution.Descriptor{Digest: pushDigest}); err != nil {
return "", retryOnError(err)
}
logrus.Debugf("uploaded layer %s (%s), %d bytes", diffID, pushDigest, nn)
progress.Update(progressOutput, pd.ID(), "Pushed")
// Cache mapping from this layer's DiffID to the blobsum
if err := p.blobSumService.Add(l.DiffID(), pushDigest); err != nil {
return "", err
if err := pd.blobSumService.Add(diffID, pushDigest); err != nil {
return "", xfer.DoNotRetry{Err: err}
}
pd.layersPushed.Lock()
pd.layersPushed.layersPushed[pushDigest] = true
pd.layersPushed.Unlock()
return pushDigest, nil
}
// blobSumAlreadyExists checks if the registry already know about any of the
// blobsums passed in the "blobsums" slice. If it finds one that the registry
// knows about, it returns the known digest and "true".
func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest, bool, error) {
func blobSumAlreadyExists(ctx context.Context, blobsums []digest.Digest, repo distribution.Repository, layersPushed *pushMap) (digest.Digest, bool, error) {
layersPushed.Lock()
for _, dgst := range blobsums {
if p.layersPushed[dgst] {
if layersPushed.layersPushed[dgst] {
// it is already known that the push is not needed and
// therefore doing a stat is unnecessary
layersPushed.Unlock()
return dgst, true, nil
}
_, err := p.repo.Blobs(context.Background()).Stat(context.Background(), dgst)
}
layersPushed.Unlock()
for _, dgst := range blobsums {
_, err := repo.Blobs(ctx).Stat(ctx, dgst)
switch err {
case nil:
return dgst, true, nil
@ -226,7 +302,7 @@ func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest
// FSLayer digests.
// FIXME: This should be moved to the distribution repo, since it will also
// be useful for converting new manifests to the old format.
func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]schema1.FSLayer) (*schema1.Manifest, error) {
func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]digest.Digest) (*schema1.Manifest, error) {
if len(img.History) == 0 {
return nil, errors.New("empty history when trying to create V2 manifest")
}
@ -271,7 +347,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
if !present {
return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String())
}
dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent))
dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent))
if err != nil {
return nil, err
}
@ -294,7 +370,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
reversedIndex := len(img.History) - i - 1
history[reversedIndex].V1Compatibility = string(jsonBytes)
fsLayerList[reversedIndex] = fsLayer
fsLayerList[reversedIndex] = schema1.FSLayer{BlobSum: fsLayer}
parent = v1ID
}
@ -315,11 +391,11 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String())
}
dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent + " " + string(img.RawJSON())))
dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent + " " + string(img.RawJSON())))
if err != nil {
return nil, err
}
fsLayerList[0] = fsLayer
fsLayerList[0] = schema1.FSLayer{BlobSum: fsLayer}
// Top-level v1compatibility string should be a modified version of the
// image config.
@ -346,66 +422,3 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
History: history,
}, nil
}
func rawJSON(value interface{}) *json.RawMessage {
jsonval, err := json.Marshal(value)
if err != nil {
return nil
}
return (*json.RawMessage)(&jsonval)
}
func (p *v2Pusher) pushV2Layer(bs distribution.BlobService, l layer.Layer) (digest.Digest, error) {
out := p.config.OutStream
displayID := stringid.TruncateID(string(l.DiffID()))
out.Write(p.sf.FormatProgress(displayID, "Preparing", nil))
arch, err := l.TarStream()
if err != nil {
return "", err
}
defer arch.Close()
// Send the layer
layerUpload, err := bs.Create(context.Background())
if err != nil {
return "", err
}
defer layerUpload.Close()
// don't care if this fails; best effort
size, _ := l.DiffSize()
reader := progressreader.New(progressreader.Config{
In: ioutil.NopCloser(arch), // we'll take care of close here.
Out: out,
Formatter: p.sf,
Size: size,
NewLines: false,
ID: displayID,
Action: "Pushing",
})
compressedReader := compress(reader)
digester := digest.Canonical.New()
tee := io.TeeReader(compressedReader, digester.Hash())
out.Write(p.sf.FormatProgress(displayID, "Pushing", nil))
nn, err := layerUpload.ReadFrom(tee)
compressedReader.Close()
if err != nil {
return "", err
}
dgst := digester.Digest()
if _, err := layerUpload.Commit(context.Background(), distribution.Descriptor{Digest: dgst}); err != nil {
return "", err
}
logrus.Debugf("uploaded layer %s (%s), %d bytes", l.DiffID(), dgst, nn)
out.Write(p.sf.FormatProgress(displayID, "Pushed", nil))
return dgst, nil
}

View file

@ -116,10 +116,10 @@ func TestCreateV2Manifest(t *testing.T) {
t.Fatalf("json decoding failed: %v", err)
}
fsLayers := map[layer.DiffID]schema1.FSLayer{
layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): {BlobSum: digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")},
layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): {BlobSum: digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa")},
layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): {BlobSum: digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")},
fsLayers := map[layer.DiffID]digest.Digest{
layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
}
manifest, err := CreateV2Manifest("testrepo", "testtag", img, fsLayers)

View file

@ -13,10 +13,12 @@ import (
"github.com/docker/distribution"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/registry/api/errcode"
"github.com/docker/distribution/registry/client"
"github.com/docker/distribution/registry/client/auth"
"github.com/docker/distribution/registry/client/transport"
"github.com/docker/docker/cliconfig"
"github.com/docker/docker/distribution/xfer"
"github.com/docker/docker/registry"
"golang.org/x/net/context"
)
@ -59,7 +61,7 @@ func NewV2Repository(repoInfo *registry.RepositoryInfo, endpoint registry.APIEnd
authTransport := transport.NewTransport(base, modifiers...)
pingClient := &http.Client{
Transport: authTransport,
Timeout: 5 * time.Second,
Timeout: 15 * time.Second,
}
endpointStr := strings.TrimRight(endpoint.URL, "/") + "/v2/"
req, err := http.NewRequest("GET", endpointStr, nil)
@ -132,3 +134,23 @@ func (th *existingTokenHandler) AuthorizeRequest(req *http.Request, params map[s
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.token))
return nil
}
// retryOnError wraps the error in xfer.DoNotRetry if we should not retry the
// operation after this error.
func retryOnError(err error) error {
switch v := err.(type) {
case errcode.Errors:
return retryOnError(v[0])
case errcode.Error:
switch v.Code {
case errcode.ErrorCodeUnauthorized, errcode.ErrorCodeUnsupported, errcode.ErrorCodeDenied:
return xfer.DoNotRetry{Err: err}
}
}
// let's be nice and fallback if the error is a completely
// unexpected one.
// If new errors have to be handled in some way, please
// add them to the switch above.
return err
}

View file

@ -11,9 +11,9 @@ import (
"github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/client/auth"
"github.com/docker/docker/cliconfig"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/registry"
"github.com/docker/docker/utils"
"golang.org/x/net/context"
)
func TestTokenPassThru(t *testing.T) {
@ -72,8 +72,7 @@ func TestTokenPassThru(t *testing.T) {
MetaHeaders: http.Header{},
AuthConfig: authConfig,
}
sf := streamformatter.NewJSONStreamFormatter()
puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf)
puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
if err != nil {
t.Fatal(err)
}
@ -86,7 +85,7 @@ func TestTokenPassThru(t *testing.T) {
logrus.Debug("About to pull")
// We expect it to fail, since we haven't mock'd the full registry exchange in our handler above
tag, _ := reference.WithTag(n, "tag_goes_here")
_ = p.pullV2Repository(tag)
_ = p.pullV2Repository(context.Background(), tag)
if !gotToken {
t.Fatal("Failed to receive registry token")

View file

@ -0,0 +1,420 @@
package xfer
import (
"errors"
"fmt"
"io"
"time"
"github.com/Sirupsen/logrus"
"github.com/docker/docker/image"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/archive"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/progress"
"golang.org/x/net/context"
)
const maxDownloadAttempts = 5
// LayerDownloadManager figures out which layers need to be downloaded, then
// registers and downloads those, taking into account dependencies between
// layers.
type LayerDownloadManager struct {
layerStore layer.Store
tm TransferManager
}
// NewLayerDownloadManager returns a new LayerDownloadManager.
func NewLayerDownloadManager(layerStore layer.Store, concurrencyLimit int) *LayerDownloadManager {
return &LayerDownloadManager{
layerStore: layerStore,
tm: NewTransferManager(concurrencyLimit),
}
}
type downloadTransfer struct {
Transfer
layerStore layer.Store
layer layer.Layer
err error
}
// result returns the layer resulting from the download, if the download
// and registration were successful.
func (d *downloadTransfer) result() (layer.Layer, error) {
return d.layer, d.err
}
// A DownloadDescriptor references a layer that may need to be downloaded.
type DownloadDescriptor interface {
// Key returns the key used to deduplicate downloads.
Key() string
// ID returns the ID for display purposes.
ID() string
// DiffID should return the DiffID for this layer, or an error
// if it is unknown (for example, if it has not been downloaded
// before).
DiffID() (layer.DiffID, error)
// Download is called to perform the download.
Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error)
}
// DownloadDescriptorWithRegistered is a DownloadDescriptor that has an
// additional Registered method which gets called after a downloaded layer is
// registered. This allows the user of the download manager to know the DiffID
// of each registered layer. This method is called if a cast to
// DownloadDescriptorWithRegistered is successful.
type DownloadDescriptorWithRegistered interface {
DownloadDescriptor
Registered(diffID layer.DiffID)
}
// Download is a blocking function which ensures the requested layers are
// present in the layer store. It uses the string returned by the Key method to
// deduplicate downloads. If a given layer is not already known to present in
// the layer store, and the key is not used by an in-progress download, the
// Download method is called to get the layer tar data. Layers are then
// registered in the appropriate order. The caller must call the returned
// release function once it is is done with the returned RootFS object.
func (ldm *LayerDownloadManager) Download(ctx context.Context, initialRootFS image.RootFS, layers []DownloadDescriptor, progressOutput progress.Output) (image.RootFS, func(), error) {
var (
topLayer layer.Layer
topDownload *downloadTransfer
watcher *Watcher
missingLayer bool
transferKey = ""
downloadsByKey = make(map[string]*downloadTransfer)
)
rootFS := initialRootFS
for _, descriptor := range layers {
key := descriptor.Key()
transferKey += key
if !missingLayer {
missingLayer = true
diffID, err := descriptor.DiffID()
if err == nil {
getRootFS := rootFS
getRootFS.Append(diffID)
l, err := ldm.layerStore.Get(getRootFS.ChainID())
if err == nil {
// Layer already exists.
logrus.Debugf("Layer already exists: %s", descriptor.ID())
progress.Update(progressOutput, descriptor.ID(), "Already exists")
if topLayer != nil {
layer.ReleaseAndLog(ldm.layerStore, topLayer)
}
topLayer = l
missingLayer = false
rootFS.Append(diffID)
continue
}
}
}
// Does this layer have the same data as a previous layer in
// the stack? If so, avoid downloading it more than once.
var topDownloadUncasted Transfer
if existingDownload, ok := downloadsByKey[key]; ok {
xferFunc := ldm.makeDownloadFuncFromDownload(descriptor, existingDownload, topDownload)
defer topDownload.Transfer.Release(watcher)
topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
topDownload = topDownloadUncasted.(*downloadTransfer)
continue
}
// Layer is not known to exist - download and register it.
progress.Update(progressOutput, descriptor.ID(), "Pulling fs layer")
var xferFunc DoFunc
if topDownload != nil {
xferFunc = ldm.makeDownloadFunc(descriptor, "", topDownload)
defer topDownload.Transfer.Release(watcher)
} else {
xferFunc = ldm.makeDownloadFunc(descriptor, rootFS.ChainID(), nil)
}
topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
topDownload = topDownloadUncasted.(*downloadTransfer)
downloadsByKey[key] = topDownload
}
if topDownload == nil {
return rootFS, func() { layer.ReleaseAndLog(ldm.layerStore, topLayer) }, nil
}
// Won't be using the list built up so far - will generate it
// from downloaded layers instead.
rootFS.DiffIDs = []layer.DiffID{}
defer func() {
if topLayer != nil {
layer.ReleaseAndLog(ldm.layerStore, topLayer)
}
}()
select {
case <-ctx.Done():
topDownload.Transfer.Release(watcher)
return rootFS, func() {}, ctx.Err()
case <-topDownload.Done():
break
}
l, err := topDownload.result()
if err != nil {
topDownload.Transfer.Release(watcher)
return rootFS, func() {}, err
}
// Must do this exactly len(layers) times, so we don't include the
// base layer on Windows.
for range layers {
if l == nil {
topDownload.Transfer.Release(watcher)
return rootFS, func() {}, errors.New("internal error: too few parent layers")
}
rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...)
l = l.Parent()
}
return rootFS, func() { topDownload.Transfer.Release(watcher) }, err
}
// makeDownloadFunc returns a function that performs the layer download and
// registration. If parentDownload is non-nil, it waits for that download to
// complete before the registration step, and registers the downloaded data
// on top of parentDownload's resulting layer. Otherwise, it registers the
// layer on top of the ChainID given by parentLayer.
func (ldm *LayerDownloadManager) makeDownloadFunc(descriptor DownloadDescriptor, parentLayer layer.ChainID, parentDownload *downloadTransfer) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
d := &downloadTransfer{
Transfer: NewTransfer(),
layerStore: ldm.layerStore,
}
go func() {
defer func() {
close(progressChan)
}()
progressOutput := progress.ChanOutput(progressChan)
select {
case <-start:
default:
progress.Update(progressOutput, descriptor.ID(), "Waiting")
<-start
}
if parentDownload != nil {
// Did the parent download already fail or get
// cancelled?
select {
case <-parentDownload.Done():
_, err := parentDownload.result()
if err != nil {
d.err = err
return
}
default:
}
}
var (
downloadReader io.ReadCloser
size int64
err error
retries int
)
for {
downloadReader, size, err = descriptor.Download(d.Transfer.Context(), progressOutput)
if err == nil {
break
}
// If an error was returned because the context
// was cancelled, we shouldn't retry.
select {
case <-d.Transfer.Context().Done():
d.err = err
return
default:
}
retries++
if _, isDNR := err.(DoNotRetry); isDNR || retries == maxDownloadAttempts {
logrus.Errorf("Download failed: %v", err)
d.err = err
return
}
logrus.Errorf("Download failed, retrying: %v", err)
delay := retries * 5
ticker := time.NewTicker(time.Second)
selectLoop:
for {
progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay)
select {
case <-ticker.C:
delay--
if delay == 0 {
ticker.Stop()
break selectLoop
}
case <-d.Transfer.Context().Done():
ticker.Stop()
d.err = errors.New("download cancelled during retry delay")
return
}
}
}
close(inactive)
if parentDownload != nil {
select {
case <-d.Transfer.Context().Done():
d.err = errors.New("layer registration cancelled")
downloadReader.Close()
return
case <-parentDownload.Done():
}
l, err := parentDownload.result()
if err != nil {
d.err = err
downloadReader.Close()
return
}
parentLayer = l.ChainID()
}
reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(d.Transfer.Context(), downloadReader), progressOutput, size, descriptor.ID(), "Extracting")
defer reader.Close()
inflatedLayerData, err := archive.DecompressStream(reader)
if err != nil {
d.err = fmt.Errorf("could not get decompression stream: %v", err)
return
}
d.layer, err = d.layerStore.Register(inflatedLayerData, parentLayer)
if err != nil {
select {
case <-d.Transfer.Context().Done():
d.err = errors.New("layer registration cancelled")
default:
d.err = fmt.Errorf("failed to register layer: %v", err)
}
return
}
progress.Update(progressOutput, descriptor.ID(), "Pull complete")
withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered)
if hasRegistered {
withRegistered.Registered(d.layer.DiffID())
}
// Doesn't actually need to be its own goroutine, but
// done like this so we can defer close(c).
go func() {
<-d.Transfer.Released()
if d.layer != nil {
layer.ReleaseAndLog(d.layerStore, d.layer)
}
}()
}()
return d
}
}
// makeDownloadFuncFromDownload returns a function that performs the layer
// registration when the layer data is coming from an existing download. It
// waits for sourceDownload and parentDownload to complete, and then
// reregisters the data from sourceDownload's top layer on top of
// parentDownload. This function does not log progress output because it would
// interfere with the progress reporting for sourceDownload, which has the same
// Key.
func (ldm *LayerDownloadManager) makeDownloadFuncFromDownload(descriptor DownloadDescriptor, sourceDownload *downloadTransfer, parentDownload *downloadTransfer) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
d := &downloadTransfer{
Transfer: NewTransfer(),
layerStore: ldm.layerStore,
}
go func() {
defer func() {
close(progressChan)
}()
<-start
close(inactive)
select {
case <-d.Transfer.Context().Done():
d.err = errors.New("layer registration cancelled")
return
case <-parentDownload.Done():
}
l, err := parentDownload.result()
if err != nil {
d.err = err
return
}
parentLayer := l.ChainID()
// sourceDownload should have already finished if
// parentDownload finished, but wait for it explicitly
// to be sure.
select {
case <-d.Transfer.Context().Done():
d.err = errors.New("layer registration cancelled")
return
case <-sourceDownload.Done():
}
l, err = sourceDownload.result()
if err != nil {
d.err = err
return
}
layerReader, err := l.TarStream()
if err != nil {
d.err = err
return
}
defer layerReader.Close()
d.layer, err = d.layerStore.Register(layerReader, parentLayer)
if err != nil {
d.err = fmt.Errorf("failed to register layer: %v", err)
return
}
withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered)
if hasRegistered {
withRegistered.Registered(d.layer.DiffID())
}
// Doesn't actually need to be its own goroutine, but
// done like this so we can defer close(c).
go func() {
<-d.Transfer.Released()
if d.layer != nil {
layer.ReleaseAndLog(d.layerStore, d.layer)
}
}()
}()
return d
}
}

View file

@ -0,0 +1,332 @@
package xfer
import (
"bytes"
"errors"
"io"
"io/ioutil"
"sync/atomic"
"testing"
"time"
"github.com/docker/distribution/digest"
"github.com/docker/docker/image"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/archive"
"github.com/docker/docker/pkg/progress"
"golang.org/x/net/context"
)
const maxDownloadConcurrency = 3
type mockLayer struct {
layerData bytes.Buffer
diffID layer.DiffID
chainID layer.ChainID
parent layer.Layer
}
func (ml *mockLayer) TarStream() (io.ReadCloser, error) {
return ioutil.NopCloser(bytes.NewBuffer(ml.layerData.Bytes())), nil
}
func (ml *mockLayer) ChainID() layer.ChainID {
return ml.chainID
}
func (ml *mockLayer) DiffID() layer.DiffID {
return ml.diffID
}
func (ml *mockLayer) Parent() layer.Layer {
return ml.parent
}
func (ml *mockLayer) Size() (size int64, err error) {
return 0, nil
}
func (ml *mockLayer) DiffSize() (size int64, err error) {
return 0, nil
}
func (ml *mockLayer) Metadata() (map[string]string, error) {
return make(map[string]string), nil
}
type mockLayerStore struct {
layers map[layer.ChainID]*mockLayer
}
func createChainIDFromParent(parent layer.ChainID, dgsts ...layer.DiffID) layer.ChainID {
if len(dgsts) == 0 {
return parent
}
if parent == "" {
return createChainIDFromParent(layer.ChainID(dgsts[0]), dgsts[1:]...)
}
// H = "H(n-1) SHA256(n)"
dgst, err := digest.FromBytes([]byte(string(parent) + " " + string(dgsts[0])))
if err != nil {
// Digest calculation is not expected to throw an error,
// any error at this point is a program error
panic(err)
}
return createChainIDFromParent(layer.ChainID(dgst), dgsts[1:]...)
}
func (ls *mockLayerStore) Register(reader io.Reader, parentID layer.ChainID) (layer.Layer, error) {
var (
parent layer.Layer
err error
)
if parentID != "" {
parent, err = ls.Get(parentID)
if err != nil {
return nil, err
}
}
l := &mockLayer{parent: parent}
_, err = l.layerData.ReadFrom(reader)
if err != nil {
return nil, err
}
diffID, err := digest.FromBytes(l.layerData.Bytes())
if err != nil {
return nil, err
}
l.diffID = layer.DiffID(diffID)
l.chainID = createChainIDFromParent(parentID, l.diffID)
ls.layers[l.chainID] = l
return l, nil
}
func (ls *mockLayerStore) Get(chainID layer.ChainID) (layer.Layer, error) {
l, ok := ls.layers[chainID]
if !ok {
return nil, layer.ErrLayerDoesNotExist
}
return l, nil
}
func (ls *mockLayerStore) Release(l layer.Layer) ([]layer.Metadata, error) {
return []layer.Metadata{}, nil
}
func (ls *mockLayerStore) Mount(id string, parent layer.ChainID, label string, init layer.MountInit) (layer.RWLayer, error) {
return nil, errors.New("not implemented")
}
func (ls *mockLayerStore) Unmount(id string) error {
return errors.New("not implemented")
}
func (ls *mockLayerStore) DeleteMount(id string) ([]layer.Metadata, error) {
return nil, errors.New("not implemented")
}
func (ls *mockLayerStore) Changes(id string) ([]archive.Change, error) {
return nil, errors.New("not implemented")
}
type mockDownloadDescriptor struct {
currentDownloads *int32
id string
diffID layer.DiffID
registeredDiffID layer.DiffID
expectedDiffID layer.DiffID
simulateRetries int
}
// Key returns the key used to deduplicate downloads.
func (d *mockDownloadDescriptor) Key() string {
return d.id
}
// ID returns the ID for display purposes.
func (d *mockDownloadDescriptor) ID() string {
return d.id
}
// DiffID should return the DiffID for this layer, or an error
// if it is unknown (for example, if it has not been downloaded
// before).
func (d *mockDownloadDescriptor) DiffID() (layer.DiffID, error) {
if d.diffID != "" {
return d.diffID, nil
}
return "", errors.New("no diffID available")
}
func (d *mockDownloadDescriptor) Registered(diffID layer.DiffID) {
d.registeredDiffID = diffID
}
func (d *mockDownloadDescriptor) mockTarStream() io.ReadCloser {
// The mock implementation returns the ID repeated 5 times as a tar
// stream instead of actual tar data. The data is ignored except for
// computing IDs.
return ioutil.NopCloser(bytes.NewBuffer([]byte(d.id + d.id + d.id + d.id + d.id)))
}
// Download is called to perform the download.
func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
if d.currentDownloads != nil {
defer atomic.AddInt32(d.currentDownloads, -1)
if atomic.AddInt32(d.currentDownloads, 1) > maxDownloadConcurrency {
return nil, 0, errors.New("concurrency limit exceeded")
}
}
// Sleep a bit to simulate a time-consuming download.
for i := int64(0); i <= 10; i++ {
select {
case <-ctx.Done():
return nil, 0, ctx.Err()
case <-time.After(10 * time.Millisecond):
progressOutput.WriteProgress(progress.Progress{ID: d.ID(), Action: "Downloading", Current: i, Total: 10})
}
}
if d.simulateRetries != 0 {
d.simulateRetries--
return nil, 0, errors.New("simulating retry")
}
return d.mockTarStream(), 0, nil
}
func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor {
return []DownloadDescriptor{
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id1",
expectedDiffID: layer.DiffID("sha256:68e2c75dc5c78ea9240689c60d7599766c213ae210434c53af18470ae8c53ec1"),
},
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id2",
expectedDiffID: layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
},
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id3",
expectedDiffID: layer.DiffID("sha256:58745a8bbd669c25213e9de578c4da5c8ee1c836b3581432c2b50e38a6753300"),
},
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id2",
expectedDiffID: layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
},
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id4",
expectedDiffID: layer.DiffID("sha256:0dfb5b9577716cc173e95af7c10289322c29a6453a1718addc00c0c5b1330936"),
simulateRetries: 1,
},
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id5",
expectedDiffID: layer.DiffID("sha256:0a5f25fa1acbc647f6112a6276735d0fa01e4ee2aa7ec33015e337350e1ea23d"),
},
}
}
func TestSuccessfulDownload(t *testing.T) {
layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)}
ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
receivedProgress := make(map[string]int64)
go func() {
for p := range progressChan {
if p.Action == "Downloading" {
receivedProgress[p.ID] = p.Current
} else if p.Action == "Already exists" {
receivedProgress[p.ID] = -1
}
}
close(progressDone)
}()
var currentDownloads int32
descriptors := downloadDescriptors(&currentDownloads)
firstDescriptor := descriptors[0].(*mockDownloadDescriptor)
// Pre-register the first layer to simulate an already-existing layer
l, err := layerStore.Register(firstDescriptor.mockTarStream(), "")
if err != nil {
t.Fatal(err)
}
firstDescriptor.diffID = l.DiffID()
rootFS, releaseFunc, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
if err != nil {
t.Fatalf("download error: %v", err)
}
releaseFunc()
close(progressChan)
<-progressDone
if len(rootFS.DiffIDs) != len(descriptors) {
t.Fatal("got wrong number of diffIDs in rootfs")
}
for i, d := range descriptors {
descriptor := d.(*mockDownloadDescriptor)
if descriptor.diffID != "" {
if receivedProgress[d.ID()] != -1 {
t.Fatalf("did not get 'already exists' message for %v", d.ID())
}
} else if receivedProgress[d.ID()] != 10 {
t.Fatalf("missing or wrong progress output for %v (got: %d)", d.ID(), receivedProgress[d.ID()])
}
if rootFS.DiffIDs[i] != descriptor.expectedDiffID {
t.Fatalf("rootFS item %d has the wrong diffID (expected: %v got: %v)", i, descriptor.expectedDiffID, rootFS.DiffIDs[i])
}
if descriptor.diffID == "" && descriptor.registeredDiffID != rootFS.DiffIDs[i] {
t.Fatal("diffID mismatch between rootFS and Registered callback")
}
}
}
func TestCancelledDownload(t *testing.T) {
ldm := NewLayerDownloadManager(&mockLayerStore{make(map[layer.ChainID]*mockLayer)}, maxDownloadConcurrency)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
go func() {
for range progressChan {
}
close(progressDone)
}()
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-time.After(time.Millisecond)
cancel()
}()
descriptors := downloadDescriptors(nil)
_, _, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
if err != context.Canceled {
t.Fatal("expected download to be cancelled")
}
close(progressChan)
<-progressDone
}

View file

@ -0,0 +1,343 @@
package xfer
import (
"sync"
"github.com/docker/docker/pkg/progress"
"golang.org/x/net/context"
)
// DoNotRetry is an error wrapper indicating that the error cannot be resolved
// with a retry.
type DoNotRetry struct {
Err error
}
// Error returns the stringified representation of the encapsulated error.
func (e DoNotRetry) Error() string {
return e.Err.Error()
}
// Watcher is returned by Watch and can be passed to Release to stop watching.
type Watcher struct {
// signalChan is used to signal to the watcher goroutine that
// new progress information is available, or that the transfer
// has finished.
signalChan chan struct{}
// releaseChan signals to the watcher goroutine that the watcher
// should be detached.
releaseChan chan struct{}
// running remains open as long as the watcher is watching the
// transfer. It gets closed if the transfer finishes or the
// watcher is detached.
running chan struct{}
}
// Transfer represents an in-progress transfer.
type Transfer interface {
Watch(progressOutput progress.Output) *Watcher
Release(*Watcher)
Context() context.Context
Cancel()
Done() <-chan struct{}
Released() <-chan struct{}
Broadcast(masterProgressChan <-chan progress.Progress)
}
type transfer struct {
mu sync.Mutex
ctx context.Context
cancel context.CancelFunc
// watchers keeps track of the goroutines monitoring progress output,
// indexed by the channels that release them.
watchers map[chan struct{}]*Watcher
// lastProgress is the most recently received progress event.
lastProgress progress.Progress
// hasLastProgress is true when lastProgress has been set.
hasLastProgress bool
// running remains open as long as the transfer is in progress.
running chan struct{}
// hasWatchers stays open until all watchers release the trasnfer.
hasWatchers chan struct{}
// broadcastDone is true if the master progress channel has closed.
broadcastDone bool
// broadcastSyncChan allows watchers to "ping" the broadcasting
// goroutine to wait for it for deplete its input channel. This ensures
// a detaching watcher won't miss an event that was sent before it
// started detaching.
broadcastSyncChan chan struct{}
}
// NewTransfer creates a new transfer.
func NewTransfer() Transfer {
t := &transfer{
watchers: make(map[chan struct{}]*Watcher),
running: make(chan struct{}),
hasWatchers: make(chan struct{}),
broadcastSyncChan: make(chan struct{}),
}
// This uses context.Background instead of a caller-supplied context
// so that a transfer won't be cancelled automatically if the client
// which requested it is ^C'd (there could be other viewers).
t.ctx, t.cancel = context.WithCancel(context.Background())
return t
}
// Broadcast copies the progress and error output to all viewers.
func (t *transfer) Broadcast(masterProgressChan <-chan progress.Progress) {
for {
var (
p progress.Progress
ok bool
)
select {
case p, ok = <-masterProgressChan:
default:
// We've depleted the channel, so now we can handle
// reads on broadcastSyncChan to let detaching watchers
// know we're caught up.
select {
case <-t.broadcastSyncChan:
continue
case p, ok = <-masterProgressChan:
}
}
t.mu.Lock()
if ok {
t.lastProgress = p
t.hasLastProgress = true
for _, w := range t.watchers {
select {
case w.signalChan <- struct{}{}:
default:
}
}
} else {
t.broadcastDone = true
}
t.mu.Unlock()
if !ok {
close(t.running)
return
}
}
}
// Watch adds a watcher to the transfer. The supplied channel gets progress
// updates and is closed when the transfer finishes.
func (t *transfer) Watch(progressOutput progress.Output) *Watcher {
t.mu.Lock()
defer t.mu.Unlock()
w := &Watcher{
releaseChan: make(chan struct{}),
signalChan: make(chan struct{}),
running: make(chan struct{}),
}
if t.broadcastDone {
close(w.running)
return w
}
t.watchers[w.releaseChan] = w
go func() {
defer func() {
close(w.running)
}()
done := false
for {
t.mu.Lock()
hasLastProgress := t.hasLastProgress
lastProgress := t.lastProgress
t.mu.Unlock()
// This might write the last progress item a
// second time (since channel closure also gets
// us here), but that's fine.
if hasLastProgress {
progressOutput.WriteProgress(lastProgress)
}
if done {
return
}
select {
case <-w.signalChan:
case <-w.releaseChan:
done = true
// Since the watcher is going to detach, make
// sure the broadcaster is caught up so we
// don't miss anything.
select {
case t.broadcastSyncChan <- struct{}{}:
case <-t.running:
}
case <-t.running:
done = true
}
}
}()
return w
}
// Release is the inverse of Watch; indicating that the watcher no longer wants
// to be notified about the progress of the transfer. All calls to Watch must
// be paired with later calls to Release so that the lifecycle of the transfer
// is properly managed.
func (t *transfer) Release(watcher *Watcher) {
t.mu.Lock()
delete(t.watchers, watcher.releaseChan)
if len(t.watchers) == 0 {
close(t.hasWatchers)
t.cancel()
}
t.mu.Unlock()
close(watcher.releaseChan)
// Block until the watcher goroutine completes
<-watcher.running
}
// Done returns a channel which is closed if the transfer completes or is
// cancelled. Note that having 0 watchers causes a transfer to be cancelled.
func (t *transfer) Done() <-chan struct{} {
// Note that this doesn't return t.ctx.Done() because that channel will
// be closed the moment Cancel is called, and we need to return a
// channel that blocks until a cancellation is actually acknowledged by
// the transfer function.
return t.running
}
// Released returns a channel which is closed once all watchers release the
// transfer.
func (t *transfer) Released() <-chan struct{} {
return t.hasWatchers
}
// Context returns the context associated with the transfer.
func (t *transfer) Context() context.Context {
return t.ctx
}
// Cancel cancels the context associated with the transfer.
func (t *transfer) Cancel() {
t.cancel()
}
// DoFunc is a function called by the transfer manager to actually perform
// a transfer. It should be non-blocking. It should wait until the start channel
// is closed before transfering any data. If the function closes inactive, that
// signals to the transfer manager that the job is no longer actively moving
// data - for example, it may be waiting for a dependent tranfer to finish.
// This prevents it from taking up a slot.
type DoFunc func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer
// TransferManager is used by LayerDownloadManager and LayerUploadManager to
// schedule and deduplicate transfers. It is up to the TransferManager
// implementation to make the scheduling and concurrency decisions.
type TransferManager interface {
// Transfer checks if a transfer with the given key is in progress. If
// so, it returns progress and error output from that transfer.
// Otherwise, it will call xferFunc to initiate the transfer.
Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher)
}
type transferManager struct {
mu sync.Mutex
concurrencyLimit int
activeTransfers int
transfers map[string]Transfer
waitingTransfers []chan struct{}
}
// NewTransferManager returns a new TransferManager.
func NewTransferManager(concurrencyLimit int) TransferManager {
return &transferManager{
concurrencyLimit: concurrencyLimit,
transfers: make(map[string]Transfer),
}
}
// Transfer checks if a transfer matching the given key is in progress. If not,
// it starts one by calling xferFunc. The caller supplies a channel which
// receives progress output from the transfer.
func (tm *transferManager) Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher) {
tm.mu.Lock()
defer tm.mu.Unlock()
if xfer, present := tm.transfers[key]; present {
// Transfer is already in progress.
watcher := xfer.Watch(progressOutput)
return xfer, watcher
}
start := make(chan struct{})
inactive := make(chan struct{})
if tm.activeTransfers < tm.concurrencyLimit {
close(start)
tm.activeTransfers++
} else {
tm.waitingTransfers = append(tm.waitingTransfers, start)
}
masterProgressChan := make(chan progress.Progress)
xfer := xferFunc(masterProgressChan, start, inactive)
watcher := xfer.Watch(progressOutput)
go xfer.Broadcast(masterProgressChan)
tm.transfers[key] = xfer
// When the transfer is finished, remove from the map.
go func() {
for {
select {
case <-inactive:
tm.mu.Lock()
tm.inactivate(start)
tm.mu.Unlock()
inactive = nil
case <-xfer.Done():
tm.mu.Lock()
if inactive != nil {
tm.inactivate(start)
}
delete(tm.transfers, key)
tm.mu.Unlock()
return
}
}
}()
return xfer, watcher
}
func (tm *transferManager) inactivate(start chan struct{}) {
// If the transfer was started, remove it from the activeTransfers
// count.
select {
case <-start:
// Start next transfer if any are waiting
if len(tm.waitingTransfers) != 0 {
close(tm.waitingTransfers[0])
tm.waitingTransfers = tm.waitingTransfers[1:]
} else {
tm.activeTransfers--
}
default:
}
}

View file

@ -0,0 +1,385 @@
package xfer
import (
"sync/atomic"
"testing"
"time"
"github.com/docker/docker/pkg/progress"
)
func TestTransfer(t *testing.T) {
makeXferFunc := func(id string) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
select {
case <-start:
default:
t.Fatalf("transfer function not started even though concurrency limit not reached")
}
xfer := NewTransfer()
go func() {
for i := 0; i <= 10; i++ {
progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
time.Sleep(10 * time.Millisecond)
}
close(progressChan)
}()
return xfer
}
}
tm := NewTransferManager(5)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
receivedProgress := make(map[string]int64)
go func() {
for p := range progressChan {
val, present := receivedProgress[p.ID]
if !present {
if p.Current != 0 {
t.Fatalf("got unexpected progress value: %d (expected 0)", p.Current)
}
} else if p.Current == 10 {
// Special case: last progress output may be
// repeated because the transfer finishing
// causes the latest progress output to be
// written to the channel (in case the watcher
// missed it).
if p.Current != 9 && p.Current != 10 {
t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1)
}
} else if p.Current != val+1 {
t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1)
}
receivedProgress[p.ID] = p.Current
}
close(progressDone)
}()
// Start a few transfers
ids := []string{"id1", "id2", "id3"}
xfers := make([]Transfer, len(ids))
watchers := make([]*Watcher, len(ids))
for i, id := range ids {
xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
}
for i, xfer := range xfers {
<-xfer.Done()
xfer.Release(watchers[i])
}
close(progressChan)
<-progressDone
for _, id := range ids {
if receivedProgress[id] != 10 {
t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
}
}
}
func TestConcurrencyLimit(t *testing.T) {
concurrencyLimit := 3
var runningJobs int32
makeXferFunc := func(id string) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
xfer := NewTransfer()
go func() {
<-start
totalJobs := atomic.AddInt32(&runningJobs, 1)
if int(totalJobs) > concurrencyLimit {
t.Fatalf("too many jobs running")
}
for i := 0; i <= 10; i++ {
progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
time.Sleep(10 * time.Millisecond)
}
atomic.AddInt32(&runningJobs, -1)
close(progressChan)
}()
return xfer
}
}
tm := NewTransferManager(concurrencyLimit)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
receivedProgress := make(map[string]int64)
go func() {
for p := range progressChan {
receivedProgress[p.ID] = p.Current
}
close(progressDone)
}()
// Start more transfers than the concurrency limit
ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"}
xfers := make([]Transfer, len(ids))
watchers := make([]*Watcher, len(ids))
for i, id := range ids {
xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
}
for i, xfer := range xfers {
<-xfer.Done()
xfer.Release(watchers[i])
}
close(progressChan)
<-progressDone
for _, id := range ids {
if receivedProgress[id] != 10 {
t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
}
}
}
func TestInactiveJobs(t *testing.T) {
concurrencyLimit := 3
var runningJobs int32
testDone := make(chan struct{})
makeXferFunc := func(id string) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
xfer := NewTransfer()
go func() {
<-start
totalJobs := atomic.AddInt32(&runningJobs, 1)
if int(totalJobs) > concurrencyLimit {
t.Fatalf("too many jobs running")
}
for i := 0; i <= 10; i++ {
progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
time.Sleep(10 * time.Millisecond)
}
atomic.AddInt32(&runningJobs, -1)
close(inactive)
<-testDone
close(progressChan)
}()
return xfer
}
}
tm := NewTransferManager(concurrencyLimit)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
receivedProgress := make(map[string]int64)
go func() {
for p := range progressChan {
receivedProgress[p.ID] = p.Current
}
close(progressDone)
}()
// Start more transfers than the concurrency limit
ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"}
xfers := make([]Transfer, len(ids))
watchers := make([]*Watcher, len(ids))
for i, id := range ids {
xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
}
close(testDone)
for i, xfer := range xfers {
<-xfer.Done()
xfer.Release(watchers[i])
}
close(progressChan)
<-progressDone
for _, id := range ids {
if receivedProgress[id] != 10 {
t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
}
}
}
func TestWatchRelease(t *testing.T) {
ready := make(chan struct{})
makeXferFunc := func(id string) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
xfer := NewTransfer()
go func() {
defer func() {
close(progressChan)
}()
<-ready
for i := int64(0); ; i++ {
select {
case <-time.After(10 * time.Millisecond):
case <-xfer.Context().Done():
return
}
progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10}
}
}()
return xfer
}
}
tm := NewTransferManager(5)
type watcherInfo struct {
watcher *Watcher
progressChan chan progress.Progress
progressDone chan struct{}
receivedFirstProgress chan struct{}
}
progressConsumer := func(w watcherInfo) {
first := true
for range w.progressChan {
if first {
close(w.receivedFirstProgress)
}
first = false
}
close(w.progressDone)
}
// Start a transfer
watchers := make([]watcherInfo, 5)
var xfer Transfer
watchers[0].progressChan = make(chan progress.Progress)
watchers[0].progressDone = make(chan struct{})
watchers[0].receivedFirstProgress = make(chan struct{})
xfer, watchers[0].watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(watchers[0].progressChan))
go progressConsumer(watchers[0])
// Give it multiple watchers
for i := 1; i != len(watchers); i++ {
watchers[i].progressChan = make(chan progress.Progress)
watchers[i].progressDone = make(chan struct{})
watchers[i].receivedFirstProgress = make(chan struct{})
watchers[i].watcher = xfer.Watch(progress.ChanOutput(watchers[i].progressChan))
go progressConsumer(watchers[i])
}
// Now that the watchers are set up, allow the transfer goroutine to
// proceed.
close(ready)
// Confirm that each watcher gets progress output.
for _, w := range watchers {
<-w.receivedFirstProgress
}
// Release one watcher every 5ms
for _, w := range watchers {
xfer.Release(w.watcher)
<-time.After(5 * time.Millisecond)
}
// Now that all watchers have been released, Released() should
// return a closed channel.
<-xfer.Released()
// Done() should return a closed channel because the xfer func returned
// due to cancellation.
<-xfer.Done()
for _, w := range watchers {
close(w.progressChan)
<-w.progressDone
}
}
func TestDuplicateTransfer(t *testing.T) {
ready := make(chan struct{})
var xferFuncCalls int32
makeXferFunc := func(id string) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
atomic.AddInt32(&xferFuncCalls, 1)
xfer := NewTransfer()
go func() {
defer func() {
close(progressChan)
}()
<-ready
for i := int64(0); ; i++ {
select {
case <-time.After(10 * time.Millisecond):
case <-xfer.Context().Done():
return
}
progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10}
}
}()
return xfer
}
}
tm := NewTransferManager(5)
type transferInfo struct {
xfer Transfer
watcher *Watcher
progressChan chan progress.Progress
progressDone chan struct{}
receivedFirstProgress chan struct{}
}
progressConsumer := func(t transferInfo) {
first := true
for range t.progressChan {
if first {
close(t.receivedFirstProgress)
}
first = false
}
close(t.progressDone)
}
// Try to start multiple transfers with the same ID
transfers := make([]transferInfo, 5)
for i := range transfers {
t := &transfers[i]
t.progressChan = make(chan progress.Progress)
t.progressDone = make(chan struct{})
t.receivedFirstProgress = make(chan struct{})
t.xfer, t.watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(t.progressChan))
go progressConsumer(*t)
}
// Allow the transfer goroutine to proceed.
close(ready)
// Confirm that each watcher gets progress output.
for _, t := range transfers {
<-t.receivedFirstProgress
}
// Confirm that the transfer function was called exactly once.
if xferFuncCalls != 1 {
t.Fatal("transfer function wasn't called exactly once")
}
// Release one watcher every 5ms
for _, t := range transfers {
t.xfer.Release(t.watcher)
<-time.After(5 * time.Millisecond)
}
for _, t := range transfers {
// Now that all watchers have been released, Released() should
// return a closed channel.
<-t.xfer.Released()
// Done() should return a closed channel because the xfer func returned
// due to cancellation.
<-t.xfer.Done()
}
for _, t := range transfers {
close(t.progressChan)
<-t.progressDone
}
}

159
distribution/xfer/upload.go Normal file
View file

@ -0,0 +1,159 @@
package xfer
import (
"errors"
"time"
"github.com/Sirupsen/logrus"
"github.com/docker/distribution/digest"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/progress"
"golang.org/x/net/context"
)
const maxUploadAttempts = 5
// LayerUploadManager provides task management and progress reporting for
// uploads.
type LayerUploadManager struct {
tm TransferManager
}
// NewLayerUploadManager returns a new LayerUploadManager.
func NewLayerUploadManager(concurrencyLimit int) *LayerUploadManager {
return &LayerUploadManager{
tm: NewTransferManager(concurrencyLimit),
}
}
type uploadTransfer struct {
Transfer
diffID layer.DiffID
digest digest.Digest
err error
}
// An UploadDescriptor references a layer that may need to be uploaded.
type UploadDescriptor interface {
// Key returns the key used to deduplicate uploads.
Key() string
// ID returns the ID for display purposes.
ID() string
// DiffID should return the DiffID for this layer.
DiffID() layer.DiffID
// Upload is called to perform the Upload.
Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error)
}
// Upload is a blocking function which ensures the listed layers are present on
// the remote registry. It uses the string returned by the Key method to
// deduplicate uploads.
func (lum *LayerUploadManager) Upload(ctx context.Context, layers []UploadDescriptor, progressOutput progress.Output) (map[layer.DiffID]digest.Digest, error) {
var (
uploads []*uploadTransfer
digests = make(map[layer.DiffID]digest.Digest)
dedupDescriptors = make(map[string]struct{})
)
for _, descriptor := range layers {
progress.Update(progressOutput, descriptor.ID(), "Preparing")
key := descriptor.Key()
if _, present := dedupDescriptors[key]; present {
continue
}
dedupDescriptors[key] = struct{}{}
xferFunc := lum.makeUploadFunc(descriptor)
upload, watcher := lum.tm.Transfer(descriptor.Key(), xferFunc, progressOutput)
defer upload.Release(watcher)
uploads = append(uploads, upload.(*uploadTransfer))
}
for _, upload := range uploads {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-upload.Transfer.Done():
if upload.err != nil {
return nil, upload.err
}
digests[upload.diffID] = upload.digest
}
}
return digests, nil
}
func (lum *LayerUploadManager) makeUploadFunc(descriptor UploadDescriptor) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
u := &uploadTransfer{
Transfer: NewTransfer(),
diffID: descriptor.DiffID(),
}
go func() {
defer func() {
close(progressChan)
}()
progressOutput := progress.ChanOutput(progressChan)
select {
case <-start:
default:
progress.Update(progressOutput, descriptor.ID(), "Waiting")
<-start
}
retries := 0
for {
digest, err := descriptor.Upload(u.Transfer.Context(), progressOutput)
if err == nil {
u.digest = digest
break
}
// If an error was returned because the context
// was cancelled, we shouldn't retry.
select {
case <-u.Transfer.Context().Done():
u.err = err
return
default:
}
retries++
if _, isDNR := err.(DoNotRetry); isDNR || retries == maxUploadAttempts {
logrus.Errorf("Upload failed: %v", err)
u.err = err
return
}
logrus.Errorf("Upload failed, retrying: %v", err)
delay := retries * 5
ticker := time.NewTicker(time.Second)
selectLoop:
for {
progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay)
select {
case <-ticker.C:
delay--
if delay == 0 {
ticker.Stop()
break selectLoop
}
case <-u.Transfer.Context().Done():
ticker.Stop()
u.err = errors.New("upload cancelled during retry delay")
return
}
}
}
}()
return u
}
}

View file

@ -0,0 +1,153 @@
package xfer
import (
"errors"
"sync/atomic"
"testing"
"time"
"github.com/docker/distribution/digest"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/progress"
"golang.org/x/net/context"
)
const maxUploadConcurrency = 3
type mockUploadDescriptor struct {
currentUploads *int32
diffID layer.DiffID
simulateRetries int
}
// Key returns the key used to deduplicate downloads.
func (u *mockUploadDescriptor) Key() string {
return u.diffID.String()
}
// ID returns the ID for display purposes.
func (u *mockUploadDescriptor) ID() string {
return u.diffID.String()
}
// DiffID should return the DiffID for this layer.
func (u *mockUploadDescriptor) DiffID() layer.DiffID {
return u.diffID
}
// Upload is called to perform the upload.
func (u *mockUploadDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) {
if u.currentUploads != nil {
defer atomic.AddInt32(u.currentUploads, -1)
if atomic.AddInt32(u.currentUploads, 1) > maxUploadConcurrency {
return "", errors.New("concurrency limit exceeded")
}
}
// Sleep a bit to simulate a time-consuming upload.
for i := int64(0); i <= 10; i++ {
select {
case <-ctx.Done():
return "", ctx.Err()
case <-time.After(10 * time.Millisecond):
progressOutput.WriteProgress(progress.Progress{ID: u.ID(), Current: i, Total: 10})
}
}
if u.simulateRetries != 0 {
u.simulateRetries--
return "", errors.New("simulating retry")
}
// For the mock implementation, use SHA256(DiffID) as the returned
// digest.
return digest.FromBytes([]byte(u.diffID.String()))
}
func uploadDescriptors(currentUploads *int32) []UploadDescriptor {
return []UploadDescriptor{
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0},
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"), 0},
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"), 0},
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0},
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"), 1},
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"), 0},
}
}
var expectedDigests = map[layer.DiffID]digest.Digest{
layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"): digest.Digest("sha256:c5095d6cf7ee42b7b064371dcc1dc3fb4af197f04d01a60009d484bd432724fc"),
layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"): digest.Digest("sha256:968cbfe2ff5269ea1729b3804767a1f57ffbc442d3bc86f47edbf7e688a4f36e"),
layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"): digest.Digest("sha256:8a5e56ab4b477a400470a7d5d4c1ca0c91235fd723ab19cc862636a06f3a735d"),
layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"): digest.Digest("sha256:5e733e5cd3688512fc240bd5c178e72671c9915947d17bb8451750d827944cb2"),
layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"): digest.Digest("sha256:ec4bb98d15e554a9f66c3ef9296cf46772c0ded3b1592bd8324d96e2f60f460c"),
}
func TestSuccessfulUpload(t *testing.T) {
lum := NewLayerUploadManager(maxUploadConcurrency)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
receivedProgress := make(map[string]int64)
go func() {
for p := range progressChan {
receivedProgress[p.ID] = p.Current
}
close(progressDone)
}()
var currentUploads int32
descriptors := uploadDescriptors(&currentUploads)
digests, err := lum.Upload(context.Background(), descriptors, progress.ChanOutput(progressChan))
if err != nil {
t.Fatalf("upload error: %v", err)
}
close(progressChan)
<-progressDone
if len(digests) != len(expectedDigests) {
t.Fatal("wrong number of keys in digests map")
}
for key, val := range expectedDigests {
if digests[key] != val {
t.Fatalf("mismatch in digest array for key %v (expected %v, got %v)", key, val, digests[key])
}
if receivedProgress[key.String()] != 10 {
t.Fatalf("missing or wrong progress output for %v", key)
}
}
}
func TestCancelledUpload(t *testing.T) {
lum := NewLayerUploadManager(maxUploadConcurrency)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
go func() {
for range progressChan {
}
close(progressDone)
}()
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-time.After(time.Millisecond)
cancel()
}()
descriptors := uploadDescriptors(nil)
_, err := lum.Upload(ctx, descriptors, progress.ChanOutput(progressChan))
if err != context.Canceled {
t.Fatal("expected upload to be cancelled")
}
close(progressChan)
<-progressDone
}

View file

@ -140,7 +140,7 @@ func (s *DockerHubPullSuite) TestPullAllTagsFromCentralRegistry(c *check.C) {
}
// TestPullClientDisconnect kills the client during a pull operation and verifies that the operation
// still succesfully completes on the daemon side.
// gets cancelled.
//
// Ref: docker/docker#15589
func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) {
@ -161,14 +161,8 @@ func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) {
err = pullCmd.Process.Kill()
c.Assert(err, checker.IsNil)
maxAttempts := 20
for i := 0; ; i++ {
if _, err := s.CmdWithError("inspect", repoName); err == nil {
break
}
if i >= maxAttempts {
c.Fatal("timeout reached: image was not pulled after client disconnected")
}
time.Sleep(500 * time.Millisecond)
time.Sleep(2 * time.Second)
if _, err := s.CmdWithError("inspect", repoName); err == nil {
c.Fatal("image was pulled after client disconnected")
}
}

View file

@ -1,167 +0,0 @@
package broadcaster
import (
"errors"
"io"
"sync"
)
// Buffered keeps track of one or more observers watching the progress
// of an operation. For example, if multiple clients are trying to pull an
// image, they share a Buffered struct for the download operation.
type Buffered struct {
sync.Mutex
// c is a channel that observers block on, waiting for the operation
// to finish.
c chan struct{}
// cond is a condition variable used to wake up observers when there's
// new data available.
cond *sync.Cond
// history is a buffer of the progress output so far, so a new observer
// can catch up. The history is stored as a slice of separate byte
// slices, so that if the writer is a WriteFlusher, the flushes will
// happen in the right places.
history [][]byte
// wg is a WaitGroup used to wait for all writes to finish on Close
wg sync.WaitGroup
// result is the argument passed to the first call of Close, and
// returned to callers of Wait
result error
}
// NewBuffered returns an initialized Buffered structure.
func NewBuffered() *Buffered {
b := &Buffered{
c: make(chan struct{}),
}
b.cond = sync.NewCond(b)
return b
}
// closed returns true if and only if the broadcaster has been closed
func (broadcaster *Buffered) closed() bool {
select {
case <-broadcaster.c:
return true
default:
return false
}
}
// receiveWrites runs as a goroutine so that writes don't block the Write
// function. It writes the new data in broadcaster.history each time there's
// activity on the broadcaster.cond condition variable.
func (broadcaster *Buffered) receiveWrites(observer io.Writer) {
n := 0
broadcaster.Lock()
// The condition variable wait is at the end of this loop, so that the
// first iteration will write the history so far.
for {
newData := broadcaster.history[n:]
// Make a copy of newData so we can release the lock
sendData := make([][]byte, len(newData), len(newData))
copy(sendData, newData)
broadcaster.Unlock()
for len(sendData) > 0 {
_, err := observer.Write(sendData[0])
if err != nil {
broadcaster.wg.Done()
return
}
n++
sendData = sendData[1:]
}
broadcaster.Lock()
// If we are behind, we need to catch up instead of waiting
// or handling a closure.
if len(broadcaster.history) != n {
continue
}
// detect closure of the broadcast writer
if broadcaster.closed() {
broadcaster.Unlock()
broadcaster.wg.Done()
return
}
broadcaster.cond.Wait()
// Mutex is still locked as the loop continues
}
}
// Write adds data to the history buffer, and also writes it to all current
// observers.
func (broadcaster *Buffered) Write(p []byte) (n int, err error) {
broadcaster.Lock()
defer broadcaster.Unlock()
// Is the broadcaster closed? If so, the write should fail.
if broadcaster.closed() {
return 0, errors.New("attempted write to a closed broadcaster.Buffered")
}
// Add message in p to the history slice
newEntry := make([]byte, len(p), len(p))
copy(newEntry, p)
broadcaster.history = append(broadcaster.history, newEntry)
broadcaster.cond.Broadcast()
return len(p), nil
}
// Add adds an observer to the broadcaster. The new observer receives the
// data from the history buffer, and also all subsequent data.
func (broadcaster *Buffered) Add(w io.Writer) error {
// The lock is acquired here so that Add can't race with Close
broadcaster.Lock()
defer broadcaster.Unlock()
if broadcaster.closed() {
return errors.New("attempted to add observer to a closed broadcaster.Buffered")
}
broadcaster.wg.Add(1)
go broadcaster.receiveWrites(w)
return nil
}
// CloseWithError signals to all observers that the operation has finished. Its
// argument is a result that should be returned to waiters blocking on Wait.
func (broadcaster *Buffered) CloseWithError(result error) {
broadcaster.Lock()
if broadcaster.closed() {
broadcaster.Unlock()
return
}
broadcaster.result = result
close(broadcaster.c)
broadcaster.cond.Broadcast()
broadcaster.Unlock()
// Don't return until all writers have caught up.
broadcaster.wg.Wait()
}
// Close signals to all observers that the operation has finished. It causes
// all calls to Wait to return nil.
func (broadcaster *Buffered) Close() {
broadcaster.CloseWithError(nil)
}
// Wait blocks until the operation is marked as completed by the Close method,
// and all writer goroutines have completed. It returns the argument that was
// passed to Close.
func (broadcaster *Buffered) Wait() error {
<-broadcaster.c
broadcaster.wg.Wait()
return broadcaster.result
}

View file

@ -4,6 +4,8 @@ import (
"crypto/sha256"
"encoding/hex"
"io"
"golang.org/x/net/context"
)
type readCloserWrapper struct {
@ -81,3 +83,72 @@ func (r *OnEOFReader) runFunc() {
r.Fn = nil
}
}
// cancelReadCloser wraps an io.ReadCloser with a context for cancelling read
// operations.
type cancelReadCloser struct {
cancel func()
pR *io.PipeReader // Stream to read from
pW *io.PipeWriter
}
// NewCancelReadCloser creates a wrapper that closes the ReadCloser when the
// context is cancelled. The returned io.ReadCloser must be closed when it is
// no longer needed.
func NewCancelReadCloser(ctx context.Context, in io.ReadCloser) io.ReadCloser {
pR, pW := io.Pipe()
// Create a context used to signal when the pipe is closed
doneCtx, cancel := context.WithCancel(context.Background())
p := &cancelReadCloser{
cancel: cancel,
pR: pR,
pW: pW,
}
go func() {
_, err := io.Copy(pW, in)
select {
case <-ctx.Done():
// If the context was closed, p.closeWithError
// was already called. Calling it again would
// change the error that Read returns.
default:
p.closeWithError(err)
}
in.Close()
}()
go func() {
for {
select {
case <-ctx.Done():
p.closeWithError(ctx.Err())
case <-doneCtx.Done():
return
}
}
}()
return p
}
// Read wraps the Read method of the pipe that provides data from the wrapped
// ReadCloser.
func (p *cancelReadCloser) Read(buf []byte) (n int, err error) {
return p.pR.Read(buf)
}
// closeWithError closes the wrapper and its underlying reader. It will
// cause future calls to Read to return err.
func (p *cancelReadCloser) closeWithError(err error) {
p.pW.CloseWithError(err)
p.cancel()
}
// Close closes the wrapper its underlying reader. It will cause
// future calls to Read to return io.EOF.
func (p *cancelReadCloser) Close() error {
p.closeWithError(io.EOF)
return nil
}

View file

@ -2,8 +2,12 @@ package ioutils
import (
"fmt"
"io/ioutil"
"strings"
"testing"
"time"
"golang.org/x/net/context"
)
// Implement io.Reader
@ -65,3 +69,26 @@ func TestHashData(t *testing.T) {
t.Fatalf("Expecting %s, got %s", expected, actual)
}
}
type perpetualReader struct{}
func (p *perpetualReader) Read(buf []byte) (n int, err error) {
for i := 0; i != len(buf); i++ {
buf[i] = 'a'
}
return len(buf), nil
}
func TestCancelReadCloser(t *testing.T) {
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
cancelReadCloser := NewCancelReadCloser(ctx, ioutil.NopCloser(&perpetualReader{}))
for {
var buf [128]byte
_, err := cancelReadCloser.Read(buf[:])
if err == context.DeadlineExceeded {
break
} else if err != nil {
t.Fatalf("got unexpected error: %v", err)
}
}
}

63
pkg/progress/progress.go Normal file
View file

@ -0,0 +1,63 @@
package progress
import (
"fmt"
)
// Progress represents the progress of a transfer.
type Progress struct {
ID string
// Progress contains a Message or...
Message string
// ...progress of an action
Action string
Current int64
Total int64
LastUpdate bool
}
// Output is an interface for writing progress information. It's
// like a writer for progress, but we don't call it Writer because
// that would be confusing next to ProgressReader (also, because it
// doesn't implement the io.Writer interface).
type Output interface {
WriteProgress(Progress) error
}
type chanOutput chan<- Progress
func (out chanOutput) WriteProgress(p Progress) error {
out <- p
return nil
}
// ChanOutput returns a Output that writes progress updates to the
// supplied channel.
func ChanOutput(progressChan chan<- Progress) Output {
return chanOutput(progressChan)
}
// Update is a convenience function to write a progress update to the channel.
func Update(out Output, id, action string) {
out.WriteProgress(Progress{ID: id, Action: action})
}
// Updatef is a convenience function to write a printf-formatted progress update
// to the channel.
func Updatef(out Output, id, format string, a ...interface{}) {
Update(out, id, fmt.Sprintf(format, a...))
}
// Message is a convenience function to write a progress message to the channel.
func Message(out Output, id, message string) {
out.WriteProgress(Progress{ID: id, Message: message})
}
// Messagef is a convenience function to write a printf-formatted progress
// message to the channel.
func Messagef(out Output, id, format string, a ...interface{}) {
Message(out, id, fmt.Sprintf(format, a...))
}

View file

@ -0,0 +1,59 @@
package progress
import (
"io"
)
// Reader is a Reader with progress bar.
type Reader struct {
in io.ReadCloser // Stream to read from
out Output // Where to send progress bar to
size int64
current int64
lastUpdate int64
id string
action string
}
// NewProgressReader creates a new ProgressReader.
func NewProgressReader(in io.ReadCloser, out Output, size int64, id, action string) *Reader {
return &Reader{
in: in,
out: out,
size: size,
id: id,
action: action,
}
}
func (p *Reader) Read(buf []byte) (n int, err error) {
read, err := p.in.Read(buf)
p.current += int64(read)
updateEvery := int64(1024 * 512) //512kB
if p.size > 0 {
// Update progress for every 1% read if 1% < 512kB
if increment := int64(0.01 * float64(p.size)); increment < updateEvery {
updateEvery = increment
}
}
if p.current-p.lastUpdate > updateEvery || err != nil {
p.updateProgress(err != nil && read == 0)
p.lastUpdate = p.current
}
return read, err
}
// Close closes the progress reader and its underlying reader.
func (p *Reader) Close() error {
if p.current < p.size {
// print a full progress bar when closing prematurely
p.current = p.size
p.updateProgress(false)
}
return p.in.Close()
}
func (p *Reader) updateProgress(last bool) {
p.out.WriteProgress(Progress{ID: p.id, Action: p.action, Current: p.current, Total: p.size, LastUpdate: last})
}

View file

@ -0,0 +1,75 @@
package progress
import (
"bytes"
"io"
"io/ioutil"
"testing"
)
func TestOutputOnPrematureClose(t *testing.T) {
content := []byte("TESTING")
reader := ioutil.NopCloser(bytes.NewReader(content))
progressChan := make(chan Progress, 10)
pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read")
part := make([]byte, 4, 4)
_, err := io.ReadFull(pr, part)
if err != nil {
pr.Close()
t.Fatal(err)
}
drainLoop:
for {
select {
case <-progressChan:
default:
break drainLoop
}
}
pr.Close()
select {
case <-progressChan:
default:
t.Fatalf("Expected some output when closing prematurely")
}
}
func TestCompleteSilently(t *testing.T) {
content := []byte("TESTING")
reader := ioutil.NopCloser(bytes.NewReader(content))
progressChan := make(chan Progress, 10)
pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read")
out, err := ioutil.ReadAll(pr)
if err != nil {
pr.Close()
t.Fatal(err)
}
if string(out) != "TESTING" {
pr.Close()
t.Fatalf("Unexpected output %q from reader", string(out))
}
drainLoop:
for {
select {
case <-progressChan:
default:
break drainLoop
}
}
pr.Close()
select {
case <-progressChan:
t.Fatalf("Should have closed silently when read is complete")
default:
}
}

View file

@ -1,68 +0,0 @@
// Package progressreader provides a Reader with a progress bar that can be
// printed out using the streamformatter package.
package progressreader
import (
"io"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/streamformatter"
)
// Config contains the configuration for a Reader with progress bar.
type Config struct {
In io.ReadCloser // Stream to read from
Out io.Writer // Where to send progress bar to
Formatter *streamformatter.StreamFormatter
Size int64
Current int64
LastUpdate int64
NewLines bool
ID string
Action string
}
// New creates a new Config.
func New(newReader Config) *Config {
return &newReader
}
func (config *Config) Read(p []byte) (n int, err error) {
read, err := config.In.Read(p)
config.Current += int64(read)
updateEvery := int64(1024 * 512) //512kB
if config.Size > 0 {
// Update progress for every 1% read if 1% < 512kB
if increment := int64(0.01 * float64(config.Size)); increment < updateEvery {
updateEvery = increment
}
}
if config.Current-config.LastUpdate > updateEvery || err != nil {
updateProgress(config)
config.LastUpdate = config.Current
}
if err != nil && read == 0 {
updateProgress(config)
if config.NewLines {
config.Out.Write(config.Formatter.FormatStatus("", ""))
}
}
return read, err
}
// Close closes the reader (Config).
func (config *Config) Close() error {
if config.Current < config.Size {
//print a full progress bar when closing prematurely
config.Current = config.Size
updateProgress(config)
}
return config.In.Close()
}
func updateProgress(config *Config) {
progress := jsonmessage.JSONProgress{Current: config.Current, Total: config.Size}
fmtMessage := config.Formatter.FormatProgress(config.ID, config.Action, &progress)
config.Out.Write(fmtMessage)
}

View file

@ -1,94 +0,0 @@
package progressreader
import (
"bufio"
"bytes"
"io"
"io/ioutil"
"testing"
"github.com/docker/docker/pkg/streamformatter"
)
func TestOutputOnPrematureClose(t *testing.T) {
var outBuf bytes.Buffer
content := []byte("TESTING")
reader := ioutil.NopCloser(bytes.NewReader(content))
writer := bufio.NewWriter(&outBuf)
prCfg := Config{
In: reader,
Out: writer,
Formatter: streamformatter.NewStreamFormatter(),
Size: int64(len(content)),
NewLines: true,
ID: "Test",
Action: "Read",
}
pr := New(prCfg)
part := make([]byte, 4, 4)
_, err := io.ReadFull(pr, part)
if err != nil {
pr.Close()
t.Fatal(err)
}
if err := writer.Flush(); err != nil {
pr.Close()
t.Fatal(err)
}
tlen := outBuf.Len()
pr.Close()
if err := writer.Flush(); err != nil {
t.Fatal(err)
}
if outBuf.Len() == tlen {
t.Fatalf("Expected some output when closing prematurely")
}
}
func TestCompleteSilently(t *testing.T) {
var outBuf bytes.Buffer
content := []byte("TESTING")
reader := ioutil.NopCloser(bytes.NewReader(content))
writer := bufio.NewWriter(&outBuf)
prCfg := Config{
In: reader,
Out: writer,
Formatter: streamformatter.NewStreamFormatter(),
Size: int64(len(content)),
NewLines: true,
ID: "Test",
Action: "Read",
}
pr := New(prCfg)
out, err := ioutil.ReadAll(pr)
if err != nil {
pr.Close()
t.Fatal(err)
}
if string(out) != "TESTING" {
pr.Close()
t.Fatalf("Unexpected output %q from reader", string(out))
}
if err := writer.Flush(); err != nil {
pr.Close()
t.Fatal(err)
}
tlen := outBuf.Len()
pr.Close()
if err := writer.Flush(); err != nil {
t.Fatal(err)
}
if outBuf.Len() > tlen {
t.Fatalf("Should have closed silently when read is complete")
}
}

View file

@ -7,6 +7,7 @@ import (
"io"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/progress"
)
// StreamFormatter formats a stream, optionally using JSON.
@ -92,6 +93,44 @@ func (sf *StreamFormatter) FormatProgress(id, action string, progress *jsonmessa
return []byte(action + " " + progress.String() + endl)
}
// NewProgressOutput returns a progress.Output object that can be passed to
// progress.NewProgressReader.
func (sf *StreamFormatter) NewProgressOutput(out io.Writer, newLines bool) progress.Output {
return &progressOutput{
sf: sf,
out: out,
newLines: newLines,
}
}
type progressOutput struct {
sf *StreamFormatter
out io.Writer
newLines bool
}
// WriteProgress formats progress information from a ProgressReader.
func (out *progressOutput) WriteProgress(prog progress.Progress) error {
var formatted []byte
if prog.Message != "" {
formatted = out.sf.FormatStatus(prog.ID, prog.Message)
} else {
jsonProgress := jsonmessage.JSONProgress{Current: prog.Current, Total: prog.Total}
formatted = out.sf.FormatProgress(prog.ID, prog.Action, &jsonProgress)
}
_, err := out.out.Write(formatted)
if err != nil {
return err
}
if out.newLines && prog.LastUpdate {
_, err = out.out.Write(out.sf.FormatStatus("", ""))
return err
}
return nil
}
// StdoutFormatter is a streamFormatter that writes to the standard output.
type StdoutFormatter struct {
io.Writer

View file

@ -17,7 +17,6 @@ import (
"net/url"
"strconv"
"strings"
"time"
"github.com/Sirupsen/logrus"
"github.com/docker/distribution/reference"
@ -270,7 +269,6 @@ func (r *Session) GetRemoteImageJSON(imgID, registry string) ([]byte, int64, err
// GetRemoteImageLayer retrieves an image layer from the registry
func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io.ReadCloser, error) {
var (
retries = 5
statusCode = 0
res *http.Response
err error
@ -281,14 +279,9 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io
if err != nil {
return nil, fmt.Errorf("Error while getting from the server: %v", err)
}
// TODO(tiborvass): why are we doing retries at this level?
// These retries should be generic to both v1 and v2
for i := 1; i <= retries; i++ {
statusCode = 0
res, err = r.client.Do(req)
if err == nil {
break
}
statusCode = 0
res, err = r.client.Do(req)
if err != nil {
logrus.Debugf("Error contacting registry %s: %v", registry, err)
if res != nil {
if res.Body != nil {
@ -296,11 +289,8 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io
}
statusCode = res.StatusCode
}
if i == retries {
return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)",
statusCode, imgID)
}
time.Sleep(time.Duration(i) * 5 * time.Second)
return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)",
statusCode, imgID)
}
if res.StatusCode != 200 {