1
0
Fork 0
mirror of https://github.com/moby/moby.git synced 2022-11-09 12:21:53 -05:00

Improved push and pull with upload manager and download manager

This commit adds a transfer manager which deduplicates and schedules
transfers, and also an upload manager and download manager that build on
top of the transfer manager to provide high-level interfaces for uploads
and downloads. The push and pull code is modified to use these building
blocks.

Some benefits of the changes:

- Simplification of push/pull code
- Pushes can upload layers concurrently
- Failed downloads and uploads are retried after backoff delays
- Cancellation is supported, but individual transfers will only be
  cancelled if all pushes or pulls using them are cancelled.
- The distribution code is decoupled from Docker Engine packages and API
  conventions (i.e. streamformatter), which will make it easier to split
  out.

This commit also includes unit tests for the new distribution/xfer
package. The tests cover 87.8% of the statements in the package.

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
This commit is contained in:
Aaron Lehmann 2015-11-13 16:59:01 -08:00
parent 7470e39c73
commit 572ce80230
36 changed files with 2675 additions and 1127 deletions

View file

@ -0,0 +1,420 @@
package xfer
import (
"errors"
"fmt"
"io"
"time"
"github.com/Sirupsen/logrus"
"github.com/docker/docker/image"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/archive"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/docker/pkg/progress"
"golang.org/x/net/context"
)
const maxDownloadAttempts = 5
// LayerDownloadManager figures out which layers need to be downloaded, then
// registers and downloads those, taking into account dependencies between
// layers.
type LayerDownloadManager struct {
layerStore layer.Store
tm TransferManager
}
// NewLayerDownloadManager returns a new LayerDownloadManager.
func NewLayerDownloadManager(layerStore layer.Store, concurrencyLimit int) *LayerDownloadManager {
return &LayerDownloadManager{
layerStore: layerStore,
tm: NewTransferManager(concurrencyLimit),
}
}
type downloadTransfer struct {
Transfer
layerStore layer.Store
layer layer.Layer
err error
}
// result returns the layer resulting from the download, if the download
// and registration were successful.
func (d *downloadTransfer) result() (layer.Layer, error) {
return d.layer, d.err
}
// A DownloadDescriptor references a layer that may need to be downloaded.
type DownloadDescriptor interface {
// Key returns the key used to deduplicate downloads.
Key() string
// ID returns the ID for display purposes.
ID() string
// DiffID should return the DiffID for this layer, or an error
// if it is unknown (for example, if it has not been downloaded
// before).
DiffID() (layer.DiffID, error)
// Download is called to perform the download.
Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error)
}
// DownloadDescriptorWithRegistered is a DownloadDescriptor that has an
// additional Registered method which gets called after a downloaded layer is
// registered. This allows the user of the download manager to know the DiffID
// of each registered layer. This method is called if a cast to
// DownloadDescriptorWithRegistered is successful.
type DownloadDescriptorWithRegistered interface {
DownloadDescriptor
Registered(diffID layer.DiffID)
}
// Download is a blocking function which ensures the requested layers are
// present in the layer store. It uses the string returned by the Key method to
// deduplicate downloads. If a given layer is not already known to present in
// the layer store, and the key is not used by an in-progress download, the
// Download method is called to get the layer tar data. Layers are then
// registered in the appropriate order. The caller must call the returned
// release function once it is is done with the returned RootFS object.
func (ldm *LayerDownloadManager) Download(ctx context.Context, initialRootFS image.RootFS, layers []DownloadDescriptor, progressOutput progress.Output) (image.RootFS, func(), error) {
var (
topLayer layer.Layer
topDownload *downloadTransfer
watcher *Watcher
missingLayer bool
transferKey = ""
downloadsByKey = make(map[string]*downloadTransfer)
)
rootFS := initialRootFS
for _, descriptor := range layers {
key := descriptor.Key()
transferKey += key
if !missingLayer {
missingLayer = true
diffID, err := descriptor.DiffID()
if err == nil {
getRootFS := rootFS
getRootFS.Append(diffID)
l, err := ldm.layerStore.Get(getRootFS.ChainID())
if err == nil {
// Layer already exists.
logrus.Debugf("Layer already exists: %s", descriptor.ID())
progress.Update(progressOutput, descriptor.ID(), "Already exists")
if topLayer != nil {
layer.ReleaseAndLog(ldm.layerStore, topLayer)
}
topLayer = l
missingLayer = false
rootFS.Append(diffID)
continue
}
}
}
// Does this layer have the same data as a previous layer in
// the stack? If so, avoid downloading it more than once.
var topDownloadUncasted Transfer
if existingDownload, ok := downloadsByKey[key]; ok {
xferFunc := ldm.makeDownloadFuncFromDownload(descriptor, existingDownload, topDownload)
defer topDownload.Transfer.Release(watcher)
topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
topDownload = topDownloadUncasted.(*downloadTransfer)
continue
}
// Layer is not known to exist - download and register it.
progress.Update(progressOutput, descriptor.ID(), "Pulling fs layer")
var xferFunc DoFunc
if topDownload != nil {
xferFunc = ldm.makeDownloadFunc(descriptor, "", topDownload)
defer topDownload.Transfer.Release(watcher)
} else {
xferFunc = ldm.makeDownloadFunc(descriptor, rootFS.ChainID(), nil)
}
topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
topDownload = topDownloadUncasted.(*downloadTransfer)
downloadsByKey[key] = topDownload
}
if topDownload == nil {
return rootFS, func() { layer.ReleaseAndLog(ldm.layerStore, topLayer) }, nil
}
// Won't be using the list built up so far - will generate it
// from downloaded layers instead.
rootFS.DiffIDs = []layer.DiffID{}
defer func() {
if topLayer != nil {
layer.ReleaseAndLog(ldm.layerStore, topLayer)
}
}()
select {
case <-ctx.Done():
topDownload.Transfer.Release(watcher)
return rootFS, func() {}, ctx.Err()
case <-topDownload.Done():
break
}
l, err := topDownload.result()
if err != nil {
topDownload.Transfer.Release(watcher)
return rootFS, func() {}, err
}
// Must do this exactly len(layers) times, so we don't include the
// base layer on Windows.
for range layers {
if l == nil {
topDownload.Transfer.Release(watcher)
return rootFS, func() {}, errors.New("internal error: too few parent layers")
}
rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...)
l = l.Parent()
}
return rootFS, func() { topDownload.Transfer.Release(watcher) }, err
}
// makeDownloadFunc returns a function that performs the layer download and
// registration. If parentDownload is non-nil, it waits for that download to
// complete before the registration step, and registers the downloaded data
// on top of parentDownload's resulting layer. Otherwise, it registers the
// layer on top of the ChainID given by parentLayer.
func (ldm *LayerDownloadManager) makeDownloadFunc(descriptor DownloadDescriptor, parentLayer layer.ChainID, parentDownload *downloadTransfer) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
d := &downloadTransfer{
Transfer: NewTransfer(),
layerStore: ldm.layerStore,
}
go func() {
defer func() {
close(progressChan)
}()
progressOutput := progress.ChanOutput(progressChan)
select {
case <-start:
default:
progress.Update(progressOutput, descriptor.ID(), "Waiting")
<-start
}
if parentDownload != nil {
// Did the parent download already fail or get
// cancelled?
select {
case <-parentDownload.Done():
_, err := parentDownload.result()
if err != nil {
d.err = err
return
}
default:
}
}
var (
downloadReader io.ReadCloser
size int64
err error
retries int
)
for {
downloadReader, size, err = descriptor.Download(d.Transfer.Context(), progressOutput)
if err == nil {
break
}
// If an error was returned because the context
// was cancelled, we shouldn't retry.
select {
case <-d.Transfer.Context().Done():
d.err = err
return
default:
}
retries++
if _, isDNR := err.(DoNotRetry); isDNR || retries == maxDownloadAttempts {
logrus.Errorf("Download failed: %v", err)
d.err = err
return
}
logrus.Errorf("Download failed, retrying: %v", err)
delay := retries * 5
ticker := time.NewTicker(time.Second)
selectLoop:
for {
progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay)
select {
case <-ticker.C:
delay--
if delay == 0 {
ticker.Stop()
break selectLoop
}
case <-d.Transfer.Context().Done():
ticker.Stop()
d.err = errors.New("download cancelled during retry delay")
return
}
}
}
close(inactive)
if parentDownload != nil {
select {
case <-d.Transfer.Context().Done():
d.err = errors.New("layer registration cancelled")
downloadReader.Close()
return
case <-parentDownload.Done():
}
l, err := parentDownload.result()
if err != nil {
d.err = err
downloadReader.Close()
return
}
parentLayer = l.ChainID()
}
reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(d.Transfer.Context(), downloadReader), progressOutput, size, descriptor.ID(), "Extracting")
defer reader.Close()
inflatedLayerData, err := archive.DecompressStream(reader)
if err != nil {
d.err = fmt.Errorf("could not get decompression stream: %v", err)
return
}
d.layer, err = d.layerStore.Register(inflatedLayerData, parentLayer)
if err != nil {
select {
case <-d.Transfer.Context().Done():
d.err = errors.New("layer registration cancelled")
default:
d.err = fmt.Errorf("failed to register layer: %v", err)
}
return
}
progress.Update(progressOutput, descriptor.ID(), "Pull complete")
withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered)
if hasRegistered {
withRegistered.Registered(d.layer.DiffID())
}
// Doesn't actually need to be its own goroutine, but
// done like this so we can defer close(c).
go func() {
<-d.Transfer.Released()
if d.layer != nil {
layer.ReleaseAndLog(d.layerStore, d.layer)
}
}()
}()
return d
}
}
// makeDownloadFuncFromDownload returns a function that performs the layer
// registration when the layer data is coming from an existing download. It
// waits for sourceDownload and parentDownload to complete, and then
// reregisters the data from sourceDownload's top layer on top of
// parentDownload. This function does not log progress output because it would
// interfere with the progress reporting for sourceDownload, which has the same
// Key.
func (ldm *LayerDownloadManager) makeDownloadFuncFromDownload(descriptor DownloadDescriptor, sourceDownload *downloadTransfer, parentDownload *downloadTransfer) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
d := &downloadTransfer{
Transfer: NewTransfer(),
layerStore: ldm.layerStore,
}
go func() {
defer func() {
close(progressChan)
}()
<-start
close(inactive)
select {
case <-d.Transfer.Context().Done():
d.err = errors.New("layer registration cancelled")
return
case <-parentDownload.Done():
}
l, err := parentDownload.result()
if err != nil {
d.err = err
return
}
parentLayer := l.ChainID()
// sourceDownload should have already finished if
// parentDownload finished, but wait for it explicitly
// to be sure.
select {
case <-d.Transfer.Context().Done():
d.err = errors.New("layer registration cancelled")
return
case <-sourceDownload.Done():
}
l, err = sourceDownload.result()
if err != nil {
d.err = err
return
}
layerReader, err := l.TarStream()
if err != nil {
d.err = err
return
}
defer layerReader.Close()
d.layer, err = d.layerStore.Register(layerReader, parentLayer)
if err != nil {
d.err = fmt.Errorf("failed to register layer: %v", err)
return
}
withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered)
if hasRegistered {
withRegistered.Registered(d.layer.DiffID())
}
// Doesn't actually need to be its own goroutine, but
// done like this so we can defer close(c).
go func() {
<-d.Transfer.Released()
if d.layer != nil {
layer.ReleaseAndLog(d.layerStore, d.layer)
}
}()
}()
return d
}
}

View file

@ -0,0 +1,332 @@
package xfer
import (
"bytes"
"errors"
"io"
"io/ioutil"
"sync/atomic"
"testing"
"time"
"github.com/docker/distribution/digest"
"github.com/docker/docker/image"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/archive"
"github.com/docker/docker/pkg/progress"
"golang.org/x/net/context"
)
const maxDownloadConcurrency = 3
type mockLayer struct {
layerData bytes.Buffer
diffID layer.DiffID
chainID layer.ChainID
parent layer.Layer
}
func (ml *mockLayer) TarStream() (io.ReadCloser, error) {
return ioutil.NopCloser(bytes.NewBuffer(ml.layerData.Bytes())), nil
}
func (ml *mockLayer) ChainID() layer.ChainID {
return ml.chainID
}
func (ml *mockLayer) DiffID() layer.DiffID {
return ml.diffID
}
func (ml *mockLayer) Parent() layer.Layer {
return ml.parent
}
func (ml *mockLayer) Size() (size int64, err error) {
return 0, nil
}
func (ml *mockLayer) DiffSize() (size int64, err error) {
return 0, nil
}
func (ml *mockLayer) Metadata() (map[string]string, error) {
return make(map[string]string), nil
}
type mockLayerStore struct {
layers map[layer.ChainID]*mockLayer
}
func createChainIDFromParent(parent layer.ChainID, dgsts ...layer.DiffID) layer.ChainID {
if len(dgsts) == 0 {
return parent
}
if parent == "" {
return createChainIDFromParent(layer.ChainID(dgsts[0]), dgsts[1:]...)
}
// H = "H(n-1) SHA256(n)"
dgst, err := digest.FromBytes([]byte(string(parent) + " " + string(dgsts[0])))
if err != nil {
// Digest calculation is not expected to throw an error,
// any error at this point is a program error
panic(err)
}
return createChainIDFromParent(layer.ChainID(dgst), dgsts[1:]...)
}
func (ls *mockLayerStore) Register(reader io.Reader, parentID layer.ChainID) (layer.Layer, error) {
var (
parent layer.Layer
err error
)
if parentID != "" {
parent, err = ls.Get(parentID)
if err != nil {
return nil, err
}
}
l := &mockLayer{parent: parent}
_, err = l.layerData.ReadFrom(reader)
if err != nil {
return nil, err
}
diffID, err := digest.FromBytes(l.layerData.Bytes())
if err != nil {
return nil, err
}
l.diffID = layer.DiffID(diffID)
l.chainID = createChainIDFromParent(parentID, l.diffID)
ls.layers[l.chainID] = l
return l, nil
}
func (ls *mockLayerStore) Get(chainID layer.ChainID) (layer.Layer, error) {
l, ok := ls.layers[chainID]
if !ok {
return nil, layer.ErrLayerDoesNotExist
}
return l, nil
}
func (ls *mockLayerStore) Release(l layer.Layer) ([]layer.Metadata, error) {
return []layer.Metadata{}, nil
}
func (ls *mockLayerStore) Mount(id string, parent layer.ChainID, label string, init layer.MountInit) (layer.RWLayer, error) {
return nil, errors.New("not implemented")
}
func (ls *mockLayerStore) Unmount(id string) error {
return errors.New("not implemented")
}
func (ls *mockLayerStore) DeleteMount(id string) ([]layer.Metadata, error) {
return nil, errors.New("not implemented")
}
func (ls *mockLayerStore) Changes(id string) ([]archive.Change, error) {
return nil, errors.New("not implemented")
}
type mockDownloadDescriptor struct {
currentDownloads *int32
id string
diffID layer.DiffID
registeredDiffID layer.DiffID
expectedDiffID layer.DiffID
simulateRetries int
}
// Key returns the key used to deduplicate downloads.
func (d *mockDownloadDescriptor) Key() string {
return d.id
}
// ID returns the ID for display purposes.
func (d *mockDownloadDescriptor) ID() string {
return d.id
}
// DiffID should return the DiffID for this layer, or an error
// if it is unknown (for example, if it has not been downloaded
// before).
func (d *mockDownloadDescriptor) DiffID() (layer.DiffID, error) {
if d.diffID != "" {
return d.diffID, nil
}
return "", errors.New("no diffID available")
}
func (d *mockDownloadDescriptor) Registered(diffID layer.DiffID) {
d.registeredDiffID = diffID
}
func (d *mockDownloadDescriptor) mockTarStream() io.ReadCloser {
// The mock implementation returns the ID repeated 5 times as a tar
// stream instead of actual tar data. The data is ignored except for
// computing IDs.
return ioutil.NopCloser(bytes.NewBuffer([]byte(d.id + d.id + d.id + d.id + d.id)))
}
// Download is called to perform the download.
func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
if d.currentDownloads != nil {
defer atomic.AddInt32(d.currentDownloads, -1)
if atomic.AddInt32(d.currentDownloads, 1) > maxDownloadConcurrency {
return nil, 0, errors.New("concurrency limit exceeded")
}
}
// Sleep a bit to simulate a time-consuming download.
for i := int64(0); i <= 10; i++ {
select {
case <-ctx.Done():
return nil, 0, ctx.Err()
case <-time.After(10 * time.Millisecond):
progressOutput.WriteProgress(progress.Progress{ID: d.ID(), Action: "Downloading", Current: i, Total: 10})
}
}
if d.simulateRetries != 0 {
d.simulateRetries--
return nil, 0, errors.New("simulating retry")
}
return d.mockTarStream(), 0, nil
}
func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor {
return []DownloadDescriptor{
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id1",
expectedDiffID: layer.DiffID("sha256:68e2c75dc5c78ea9240689c60d7599766c213ae210434c53af18470ae8c53ec1"),
},
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id2",
expectedDiffID: layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
},
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id3",
expectedDiffID: layer.DiffID("sha256:58745a8bbd669c25213e9de578c4da5c8ee1c836b3581432c2b50e38a6753300"),
},
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id2",
expectedDiffID: layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
},
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id4",
expectedDiffID: layer.DiffID("sha256:0dfb5b9577716cc173e95af7c10289322c29a6453a1718addc00c0c5b1330936"),
simulateRetries: 1,
},
&mockDownloadDescriptor{
currentDownloads: currentDownloads,
id: "id5",
expectedDiffID: layer.DiffID("sha256:0a5f25fa1acbc647f6112a6276735d0fa01e4ee2aa7ec33015e337350e1ea23d"),
},
}
}
func TestSuccessfulDownload(t *testing.T) {
layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)}
ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
receivedProgress := make(map[string]int64)
go func() {
for p := range progressChan {
if p.Action == "Downloading" {
receivedProgress[p.ID] = p.Current
} else if p.Action == "Already exists" {
receivedProgress[p.ID] = -1
}
}
close(progressDone)
}()
var currentDownloads int32
descriptors := downloadDescriptors(&currentDownloads)
firstDescriptor := descriptors[0].(*mockDownloadDescriptor)
// Pre-register the first layer to simulate an already-existing layer
l, err := layerStore.Register(firstDescriptor.mockTarStream(), "")
if err != nil {
t.Fatal(err)
}
firstDescriptor.diffID = l.DiffID()
rootFS, releaseFunc, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
if err != nil {
t.Fatalf("download error: %v", err)
}
releaseFunc()
close(progressChan)
<-progressDone
if len(rootFS.DiffIDs) != len(descriptors) {
t.Fatal("got wrong number of diffIDs in rootfs")
}
for i, d := range descriptors {
descriptor := d.(*mockDownloadDescriptor)
if descriptor.diffID != "" {
if receivedProgress[d.ID()] != -1 {
t.Fatalf("did not get 'already exists' message for %v", d.ID())
}
} else if receivedProgress[d.ID()] != 10 {
t.Fatalf("missing or wrong progress output for %v (got: %d)", d.ID(), receivedProgress[d.ID()])
}
if rootFS.DiffIDs[i] != descriptor.expectedDiffID {
t.Fatalf("rootFS item %d has the wrong diffID (expected: %v got: %v)", i, descriptor.expectedDiffID, rootFS.DiffIDs[i])
}
if descriptor.diffID == "" && descriptor.registeredDiffID != rootFS.DiffIDs[i] {
t.Fatal("diffID mismatch between rootFS and Registered callback")
}
}
}
func TestCancelledDownload(t *testing.T) {
ldm := NewLayerDownloadManager(&mockLayerStore{make(map[layer.ChainID]*mockLayer)}, maxDownloadConcurrency)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
go func() {
for range progressChan {
}
close(progressDone)
}()
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-time.After(time.Millisecond)
cancel()
}()
descriptors := downloadDescriptors(nil)
_, _, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
if err != context.Canceled {
t.Fatal("expected download to be cancelled")
}
close(progressChan)
<-progressDone
}

View file

@ -0,0 +1,343 @@
package xfer
import (
"sync"
"github.com/docker/docker/pkg/progress"
"golang.org/x/net/context"
)
// DoNotRetry is an error wrapper indicating that the error cannot be resolved
// with a retry.
type DoNotRetry struct {
Err error
}
// Error returns the stringified representation of the encapsulated error.
func (e DoNotRetry) Error() string {
return e.Err.Error()
}
// Watcher is returned by Watch and can be passed to Release to stop watching.
type Watcher struct {
// signalChan is used to signal to the watcher goroutine that
// new progress information is available, or that the transfer
// has finished.
signalChan chan struct{}
// releaseChan signals to the watcher goroutine that the watcher
// should be detached.
releaseChan chan struct{}
// running remains open as long as the watcher is watching the
// transfer. It gets closed if the transfer finishes or the
// watcher is detached.
running chan struct{}
}
// Transfer represents an in-progress transfer.
type Transfer interface {
Watch(progressOutput progress.Output) *Watcher
Release(*Watcher)
Context() context.Context
Cancel()
Done() <-chan struct{}
Released() <-chan struct{}
Broadcast(masterProgressChan <-chan progress.Progress)
}
type transfer struct {
mu sync.Mutex
ctx context.Context
cancel context.CancelFunc
// watchers keeps track of the goroutines monitoring progress output,
// indexed by the channels that release them.
watchers map[chan struct{}]*Watcher
// lastProgress is the most recently received progress event.
lastProgress progress.Progress
// hasLastProgress is true when lastProgress has been set.
hasLastProgress bool
// running remains open as long as the transfer is in progress.
running chan struct{}
// hasWatchers stays open until all watchers release the trasnfer.
hasWatchers chan struct{}
// broadcastDone is true if the master progress channel has closed.
broadcastDone bool
// broadcastSyncChan allows watchers to "ping" the broadcasting
// goroutine to wait for it for deplete its input channel. This ensures
// a detaching watcher won't miss an event that was sent before it
// started detaching.
broadcastSyncChan chan struct{}
}
// NewTransfer creates a new transfer.
func NewTransfer() Transfer {
t := &transfer{
watchers: make(map[chan struct{}]*Watcher),
running: make(chan struct{}),
hasWatchers: make(chan struct{}),
broadcastSyncChan: make(chan struct{}),
}
// This uses context.Background instead of a caller-supplied context
// so that a transfer won't be cancelled automatically if the client
// which requested it is ^C'd (there could be other viewers).
t.ctx, t.cancel = context.WithCancel(context.Background())
return t
}
// Broadcast copies the progress and error output to all viewers.
func (t *transfer) Broadcast(masterProgressChan <-chan progress.Progress) {
for {
var (
p progress.Progress
ok bool
)
select {
case p, ok = <-masterProgressChan:
default:
// We've depleted the channel, so now we can handle
// reads on broadcastSyncChan to let detaching watchers
// know we're caught up.
select {
case <-t.broadcastSyncChan:
continue
case p, ok = <-masterProgressChan:
}
}
t.mu.Lock()
if ok {
t.lastProgress = p
t.hasLastProgress = true
for _, w := range t.watchers {
select {
case w.signalChan <- struct{}{}:
default:
}
}
} else {
t.broadcastDone = true
}
t.mu.Unlock()
if !ok {
close(t.running)
return
}
}
}
// Watch adds a watcher to the transfer. The supplied channel gets progress
// updates and is closed when the transfer finishes.
func (t *transfer) Watch(progressOutput progress.Output) *Watcher {
t.mu.Lock()
defer t.mu.Unlock()
w := &Watcher{
releaseChan: make(chan struct{}),
signalChan: make(chan struct{}),
running: make(chan struct{}),
}
if t.broadcastDone {
close(w.running)
return w
}
t.watchers[w.releaseChan] = w
go func() {
defer func() {
close(w.running)
}()
done := false
for {
t.mu.Lock()
hasLastProgress := t.hasLastProgress
lastProgress := t.lastProgress
t.mu.Unlock()
// This might write the last progress item a
// second time (since channel closure also gets
// us here), but that's fine.
if hasLastProgress {
progressOutput.WriteProgress(lastProgress)
}
if done {
return
}
select {
case <-w.signalChan:
case <-w.releaseChan:
done = true
// Since the watcher is going to detach, make
// sure the broadcaster is caught up so we
// don't miss anything.
select {
case t.broadcastSyncChan <- struct{}{}:
case <-t.running:
}
case <-t.running:
done = true
}
}
}()
return w
}
// Release is the inverse of Watch; indicating that the watcher no longer wants
// to be notified about the progress of the transfer. All calls to Watch must
// be paired with later calls to Release so that the lifecycle of the transfer
// is properly managed.
func (t *transfer) Release(watcher *Watcher) {
t.mu.Lock()
delete(t.watchers, watcher.releaseChan)
if len(t.watchers) == 0 {
close(t.hasWatchers)
t.cancel()
}
t.mu.Unlock()
close(watcher.releaseChan)
// Block until the watcher goroutine completes
<-watcher.running
}
// Done returns a channel which is closed if the transfer completes or is
// cancelled. Note that having 0 watchers causes a transfer to be cancelled.
func (t *transfer) Done() <-chan struct{} {
// Note that this doesn't return t.ctx.Done() because that channel will
// be closed the moment Cancel is called, and we need to return a
// channel that blocks until a cancellation is actually acknowledged by
// the transfer function.
return t.running
}
// Released returns a channel which is closed once all watchers release the
// transfer.
func (t *transfer) Released() <-chan struct{} {
return t.hasWatchers
}
// Context returns the context associated with the transfer.
func (t *transfer) Context() context.Context {
return t.ctx
}
// Cancel cancels the context associated with the transfer.
func (t *transfer) Cancel() {
t.cancel()
}
// DoFunc is a function called by the transfer manager to actually perform
// a transfer. It should be non-blocking. It should wait until the start channel
// is closed before transfering any data. If the function closes inactive, that
// signals to the transfer manager that the job is no longer actively moving
// data - for example, it may be waiting for a dependent tranfer to finish.
// This prevents it from taking up a slot.
type DoFunc func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer
// TransferManager is used by LayerDownloadManager and LayerUploadManager to
// schedule and deduplicate transfers. It is up to the TransferManager
// implementation to make the scheduling and concurrency decisions.
type TransferManager interface {
// Transfer checks if a transfer with the given key is in progress. If
// so, it returns progress and error output from that transfer.
// Otherwise, it will call xferFunc to initiate the transfer.
Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher)
}
type transferManager struct {
mu sync.Mutex
concurrencyLimit int
activeTransfers int
transfers map[string]Transfer
waitingTransfers []chan struct{}
}
// NewTransferManager returns a new TransferManager.
func NewTransferManager(concurrencyLimit int) TransferManager {
return &transferManager{
concurrencyLimit: concurrencyLimit,
transfers: make(map[string]Transfer),
}
}
// Transfer checks if a transfer matching the given key is in progress. If not,
// it starts one by calling xferFunc. The caller supplies a channel which
// receives progress output from the transfer.
func (tm *transferManager) Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher) {
tm.mu.Lock()
defer tm.mu.Unlock()
if xfer, present := tm.transfers[key]; present {
// Transfer is already in progress.
watcher := xfer.Watch(progressOutput)
return xfer, watcher
}
start := make(chan struct{})
inactive := make(chan struct{})
if tm.activeTransfers < tm.concurrencyLimit {
close(start)
tm.activeTransfers++
} else {
tm.waitingTransfers = append(tm.waitingTransfers, start)
}
masterProgressChan := make(chan progress.Progress)
xfer := xferFunc(masterProgressChan, start, inactive)
watcher := xfer.Watch(progressOutput)
go xfer.Broadcast(masterProgressChan)
tm.transfers[key] = xfer
// When the transfer is finished, remove from the map.
go func() {
for {
select {
case <-inactive:
tm.mu.Lock()
tm.inactivate(start)
tm.mu.Unlock()
inactive = nil
case <-xfer.Done():
tm.mu.Lock()
if inactive != nil {
tm.inactivate(start)
}
delete(tm.transfers, key)
tm.mu.Unlock()
return
}
}
}()
return xfer, watcher
}
func (tm *transferManager) inactivate(start chan struct{}) {
// If the transfer was started, remove it from the activeTransfers
// count.
select {
case <-start:
// Start next transfer if any are waiting
if len(tm.waitingTransfers) != 0 {
close(tm.waitingTransfers[0])
tm.waitingTransfers = tm.waitingTransfers[1:]
} else {
tm.activeTransfers--
}
default:
}
}

View file

@ -0,0 +1,385 @@
package xfer
import (
"sync/atomic"
"testing"
"time"
"github.com/docker/docker/pkg/progress"
)
func TestTransfer(t *testing.T) {
makeXferFunc := func(id string) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
select {
case <-start:
default:
t.Fatalf("transfer function not started even though concurrency limit not reached")
}
xfer := NewTransfer()
go func() {
for i := 0; i <= 10; i++ {
progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
time.Sleep(10 * time.Millisecond)
}
close(progressChan)
}()
return xfer
}
}
tm := NewTransferManager(5)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
receivedProgress := make(map[string]int64)
go func() {
for p := range progressChan {
val, present := receivedProgress[p.ID]
if !present {
if p.Current != 0 {
t.Fatalf("got unexpected progress value: %d (expected 0)", p.Current)
}
} else if p.Current == 10 {
// Special case: last progress output may be
// repeated because the transfer finishing
// causes the latest progress output to be
// written to the channel (in case the watcher
// missed it).
if p.Current != 9 && p.Current != 10 {
t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1)
}
} else if p.Current != val+1 {
t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1)
}
receivedProgress[p.ID] = p.Current
}
close(progressDone)
}()
// Start a few transfers
ids := []string{"id1", "id2", "id3"}
xfers := make([]Transfer, len(ids))
watchers := make([]*Watcher, len(ids))
for i, id := range ids {
xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
}
for i, xfer := range xfers {
<-xfer.Done()
xfer.Release(watchers[i])
}
close(progressChan)
<-progressDone
for _, id := range ids {
if receivedProgress[id] != 10 {
t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
}
}
}
func TestConcurrencyLimit(t *testing.T) {
concurrencyLimit := 3
var runningJobs int32
makeXferFunc := func(id string) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
xfer := NewTransfer()
go func() {
<-start
totalJobs := atomic.AddInt32(&runningJobs, 1)
if int(totalJobs) > concurrencyLimit {
t.Fatalf("too many jobs running")
}
for i := 0; i <= 10; i++ {
progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
time.Sleep(10 * time.Millisecond)
}
atomic.AddInt32(&runningJobs, -1)
close(progressChan)
}()
return xfer
}
}
tm := NewTransferManager(concurrencyLimit)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
receivedProgress := make(map[string]int64)
go func() {
for p := range progressChan {
receivedProgress[p.ID] = p.Current
}
close(progressDone)
}()
// Start more transfers than the concurrency limit
ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"}
xfers := make([]Transfer, len(ids))
watchers := make([]*Watcher, len(ids))
for i, id := range ids {
xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
}
for i, xfer := range xfers {
<-xfer.Done()
xfer.Release(watchers[i])
}
close(progressChan)
<-progressDone
for _, id := range ids {
if receivedProgress[id] != 10 {
t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
}
}
}
func TestInactiveJobs(t *testing.T) {
concurrencyLimit := 3
var runningJobs int32
testDone := make(chan struct{})
makeXferFunc := func(id string) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
xfer := NewTransfer()
go func() {
<-start
totalJobs := atomic.AddInt32(&runningJobs, 1)
if int(totalJobs) > concurrencyLimit {
t.Fatalf("too many jobs running")
}
for i := 0; i <= 10; i++ {
progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
time.Sleep(10 * time.Millisecond)
}
atomic.AddInt32(&runningJobs, -1)
close(inactive)
<-testDone
close(progressChan)
}()
return xfer
}
}
tm := NewTransferManager(concurrencyLimit)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
receivedProgress := make(map[string]int64)
go func() {
for p := range progressChan {
receivedProgress[p.ID] = p.Current
}
close(progressDone)
}()
// Start more transfers than the concurrency limit
ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"}
xfers := make([]Transfer, len(ids))
watchers := make([]*Watcher, len(ids))
for i, id := range ids {
xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
}
close(testDone)
for i, xfer := range xfers {
<-xfer.Done()
xfer.Release(watchers[i])
}
close(progressChan)
<-progressDone
for _, id := range ids {
if receivedProgress[id] != 10 {
t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
}
}
}
func TestWatchRelease(t *testing.T) {
ready := make(chan struct{})
makeXferFunc := func(id string) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
xfer := NewTransfer()
go func() {
defer func() {
close(progressChan)
}()
<-ready
for i := int64(0); ; i++ {
select {
case <-time.After(10 * time.Millisecond):
case <-xfer.Context().Done():
return
}
progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10}
}
}()
return xfer
}
}
tm := NewTransferManager(5)
type watcherInfo struct {
watcher *Watcher
progressChan chan progress.Progress
progressDone chan struct{}
receivedFirstProgress chan struct{}
}
progressConsumer := func(w watcherInfo) {
first := true
for range w.progressChan {
if first {
close(w.receivedFirstProgress)
}
first = false
}
close(w.progressDone)
}
// Start a transfer
watchers := make([]watcherInfo, 5)
var xfer Transfer
watchers[0].progressChan = make(chan progress.Progress)
watchers[0].progressDone = make(chan struct{})
watchers[0].receivedFirstProgress = make(chan struct{})
xfer, watchers[0].watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(watchers[0].progressChan))
go progressConsumer(watchers[0])
// Give it multiple watchers
for i := 1; i != len(watchers); i++ {
watchers[i].progressChan = make(chan progress.Progress)
watchers[i].progressDone = make(chan struct{})
watchers[i].receivedFirstProgress = make(chan struct{})
watchers[i].watcher = xfer.Watch(progress.ChanOutput(watchers[i].progressChan))
go progressConsumer(watchers[i])
}
// Now that the watchers are set up, allow the transfer goroutine to
// proceed.
close(ready)
// Confirm that each watcher gets progress output.
for _, w := range watchers {
<-w.receivedFirstProgress
}
// Release one watcher every 5ms
for _, w := range watchers {
xfer.Release(w.watcher)
<-time.After(5 * time.Millisecond)
}
// Now that all watchers have been released, Released() should
// return a closed channel.
<-xfer.Released()
// Done() should return a closed channel because the xfer func returned
// due to cancellation.
<-xfer.Done()
for _, w := range watchers {
close(w.progressChan)
<-w.progressDone
}
}
func TestDuplicateTransfer(t *testing.T) {
ready := make(chan struct{})
var xferFuncCalls int32
makeXferFunc := func(id string) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
atomic.AddInt32(&xferFuncCalls, 1)
xfer := NewTransfer()
go func() {
defer func() {
close(progressChan)
}()
<-ready
for i := int64(0); ; i++ {
select {
case <-time.After(10 * time.Millisecond):
case <-xfer.Context().Done():
return
}
progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10}
}
}()
return xfer
}
}
tm := NewTransferManager(5)
type transferInfo struct {
xfer Transfer
watcher *Watcher
progressChan chan progress.Progress
progressDone chan struct{}
receivedFirstProgress chan struct{}
}
progressConsumer := func(t transferInfo) {
first := true
for range t.progressChan {
if first {
close(t.receivedFirstProgress)
}
first = false
}
close(t.progressDone)
}
// Try to start multiple transfers with the same ID
transfers := make([]transferInfo, 5)
for i := range transfers {
t := &transfers[i]
t.progressChan = make(chan progress.Progress)
t.progressDone = make(chan struct{})
t.receivedFirstProgress = make(chan struct{})
t.xfer, t.watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(t.progressChan))
go progressConsumer(*t)
}
// Allow the transfer goroutine to proceed.
close(ready)
// Confirm that each watcher gets progress output.
for _, t := range transfers {
<-t.receivedFirstProgress
}
// Confirm that the transfer function was called exactly once.
if xferFuncCalls != 1 {
t.Fatal("transfer function wasn't called exactly once")
}
// Release one watcher every 5ms
for _, t := range transfers {
t.xfer.Release(t.watcher)
<-time.After(5 * time.Millisecond)
}
for _, t := range transfers {
// Now that all watchers have been released, Released() should
// return a closed channel.
<-t.xfer.Released()
// Done() should return a closed channel because the xfer func returned
// due to cancellation.
<-t.xfer.Done()
}
for _, t := range transfers {
close(t.progressChan)
<-t.progressDone
}
}

159
distribution/xfer/upload.go Normal file
View file

@ -0,0 +1,159 @@
package xfer
import (
"errors"
"time"
"github.com/Sirupsen/logrus"
"github.com/docker/distribution/digest"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/progress"
"golang.org/x/net/context"
)
const maxUploadAttempts = 5
// LayerUploadManager provides task management and progress reporting for
// uploads.
type LayerUploadManager struct {
tm TransferManager
}
// NewLayerUploadManager returns a new LayerUploadManager.
func NewLayerUploadManager(concurrencyLimit int) *LayerUploadManager {
return &LayerUploadManager{
tm: NewTransferManager(concurrencyLimit),
}
}
type uploadTransfer struct {
Transfer
diffID layer.DiffID
digest digest.Digest
err error
}
// An UploadDescriptor references a layer that may need to be uploaded.
type UploadDescriptor interface {
// Key returns the key used to deduplicate uploads.
Key() string
// ID returns the ID for display purposes.
ID() string
// DiffID should return the DiffID for this layer.
DiffID() layer.DiffID
// Upload is called to perform the Upload.
Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error)
}
// Upload is a blocking function which ensures the listed layers are present on
// the remote registry. It uses the string returned by the Key method to
// deduplicate uploads.
func (lum *LayerUploadManager) Upload(ctx context.Context, layers []UploadDescriptor, progressOutput progress.Output) (map[layer.DiffID]digest.Digest, error) {
var (
uploads []*uploadTransfer
digests = make(map[layer.DiffID]digest.Digest)
dedupDescriptors = make(map[string]struct{})
)
for _, descriptor := range layers {
progress.Update(progressOutput, descriptor.ID(), "Preparing")
key := descriptor.Key()
if _, present := dedupDescriptors[key]; present {
continue
}
dedupDescriptors[key] = struct{}{}
xferFunc := lum.makeUploadFunc(descriptor)
upload, watcher := lum.tm.Transfer(descriptor.Key(), xferFunc, progressOutput)
defer upload.Release(watcher)
uploads = append(uploads, upload.(*uploadTransfer))
}
for _, upload := range uploads {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-upload.Transfer.Done():
if upload.err != nil {
return nil, upload.err
}
digests[upload.diffID] = upload.digest
}
}
return digests, nil
}
func (lum *LayerUploadManager) makeUploadFunc(descriptor UploadDescriptor) DoFunc {
return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
u := &uploadTransfer{
Transfer: NewTransfer(),
diffID: descriptor.DiffID(),
}
go func() {
defer func() {
close(progressChan)
}()
progressOutput := progress.ChanOutput(progressChan)
select {
case <-start:
default:
progress.Update(progressOutput, descriptor.ID(), "Waiting")
<-start
}
retries := 0
for {
digest, err := descriptor.Upload(u.Transfer.Context(), progressOutput)
if err == nil {
u.digest = digest
break
}
// If an error was returned because the context
// was cancelled, we shouldn't retry.
select {
case <-u.Transfer.Context().Done():
u.err = err
return
default:
}
retries++
if _, isDNR := err.(DoNotRetry); isDNR || retries == maxUploadAttempts {
logrus.Errorf("Upload failed: %v", err)
u.err = err
return
}
logrus.Errorf("Upload failed, retrying: %v", err)
delay := retries * 5
ticker := time.NewTicker(time.Second)
selectLoop:
for {
progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay)
select {
case <-ticker.C:
delay--
if delay == 0 {
ticker.Stop()
break selectLoop
}
case <-u.Transfer.Context().Done():
ticker.Stop()
u.err = errors.New("upload cancelled during retry delay")
return
}
}
}
}()
return u
}
}

View file

@ -0,0 +1,153 @@
package xfer
import (
"errors"
"sync/atomic"
"testing"
"time"
"github.com/docker/distribution/digest"
"github.com/docker/docker/layer"
"github.com/docker/docker/pkg/progress"
"golang.org/x/net/context"
)
const maxUploadConcurrency = 3
type mockUploadDescriptor struct {
currentUploads *int32
diffID layer.DiffID
simulateRetries int
}
// Key returns the key used to deduplicate downloads.
func (u *mockUploadDescriptor) Key() string {
return u.diffID.String()
}
// ID returns the ID for display purposes.
func (u *mockUploadDescriptor) ID() string {
return u.diffID.String()
}
// DiffID should return the DiffID for this layer.
func (u *mockUploadDescriptor) DiffID() layer.DiffID {
return u.diffID
}
// Upload is called to perform the upload.
func (u *mockUploadDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) {
if u.currentUploads != nil {
defer atomic.AddInt32(u.currentUploads, -1)
if atomic.AddInt32(u.currentUploads, 1) > maxUploadConcurrency {
return "", errors.New("concurrency limit exceeded")
}
}
// Sleep a bit to simulate a time-consuming upload.
for i := int64(0); i <= 10; i++ {
select {
case <-ctx.Done():
return "", ctx.Err()
case <-time.After(10 * time.Millisecond):
progressOutput.WriteProgress(progress.Progress{ID: u.ID(), Current: i, Total: 10})
}
}
if u.simulateRetries != 0 {
u.simulateRetries--
return "", errors.New("simulating retry")
}
// For the mock implementation, use SHA256(DiffID) as the returned
// digest.
return digest.FromBytes([]byte(u.diffID.String()))
}
func uploadDescriptors(currentUploads *int32) []UploadDescriptor {
return []UploadDescriptor{
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0},
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"), 0},
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"), 0},
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0},
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"), 1},
&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"), 0},
}
}
var expectedDigests = map[layer.DiffID]digest.Digest{
layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"): digest.Digest("sha256:c5095d6cf7ee42b7b064371dcc1dc3fb4af197f04d01a60009d484bd432724fc"),
layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"): digest.Digest("sha256:968cbfe2ff5269ea1729b3804767a1f57ffbc442d3bc86f47edbf7e688a4f36e"),
layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"): digest.Digest("sha256:8a5e56ab4b477a400470a7d5d4c1ca0c91235fd723ab19cc862636a06f3a735d"),
layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"): digest.Digest("sha256:5e733e5cd3688512fc240bd5c178e72671c9915947d17bb8451750d827944cb2"),
layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"): digest.Digest("sha256:ec4bb98d15e554a9f66c3ef9296cf46772c0ded3b1592bd8324d96e2f60f460c"),
}
func TestSuccessfulUpload(t *testing.T) {
lum := NewLayerUploadManager(maxUploadConcurrency)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
receivedProgress := make(map[string]int64)
go func() {
for p := range progressChan {
receivedProgress[p.ID] = p.Current
}
close(progressDone)
}()
var currentUploads int32
descriptors := uploadDescriptors(&currentUploads)
digests, err := lum.Upload(context.Background(), descriptors, progress.ChanOutput(progressChan))
if err != nil {
t.Fatalf("upload error: %v", err)
}
close(progressChan)
<-progressDone
if len(digests) != len(expectedDigests) {
t.Fatal("wrong number of keys in digests map")
}
for key, val := range expectedDigests {
if digests[key] != val {
t.Fatalf("mismatch in digest array for key %v (expected %v, got %v)", key, val, digests[key])
}
if receivedProgress[key.String()] != 10 {
t.Fatalf("missing or wrong progress output for %v", key)
}
}
}
func TestCancelledUpload(t *testing.T) {
lum := NewLayerUploadManager(maxUploadConcurrency)
progressChan := make(chan progress.Progress)
progressDone := make(chan struct{})
go func() {
for range progressChan {
}
close(progressDone)
}()
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-time.After(time.Millisecond)
cancel()
}()
descriptors := uploadDescriptors(nil)
_, err := lum.Upload(ctx, descriptors, progress.ChanOutput(progressChan))
if err != context.Canceled {
t.Fatal("expected upload to be cancelled")
}
close(progressChan)
<-progressDone
}