diff --git a/graph/load.go b/graph/load.go
index f1dfce7a1c..a58c5a3cf9 100644
--- a/graph/load.go
+++ b/graph/load.go
@@ -108,7 +108,7 @@ func (s *TagStore) recursiveLoad(address, tmpImageDir string) error {
 		// ensure no two downloads of the same layer happen at the same time
 		if ps, found := s.poolAdd("pull", "layer:"+img.ID); found {
 			logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID)
-			ps.Wait(nil, nil)
+			ps.Wait()
 			return nil
 		}
 
diff --git a/graph/pools_test.go b/graph/pools_test.go
index f88a1cf15b..a7b27271b7 100644
--- a/graph/pools_test.go
+++ b/graph/pools_test.go
@@ -13,8 +13,8 @@ func init() {
 
 func TestPools(t *testing.T) {
 	s := &TagStore{
-		pullingPool: make(map[string]*progressreader.ProgressStatus),
-		pushingPool: make(map[string]*progressreader.ProgressStatus),
+		pullingPool: make(map[string]*progressreader.Broadcaster),
+		pushingPool: make(map[string]*progressreader.Broadcaster),
 	}
 
 	if _, found := s.poolAdd("pull", "test1"); found {
diff --git a/graph/pull_v1.go b/graph/pull_v1.go
index d5f9492790..d36e9b712a 100644
--- a/graph/pull_v1.go
+++ b/graph/pull_v1.go
@@ -138,29 +138,30 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 			}
 
 			// ensure no two downloads of the same image happen at the same time
-			ps, found := p.poolAdd("pull", "img:"+img.ID)
+			broadcaster, found := p.poolAdd("pull", "img:"+img.ID)
 			if found {
-				msg := p.sf.FormatProgress(stringid.TruncateID(img.ID), "Layer already being pulled by another client. Waiting.", nil)
-				ps.Wait(out, msg)
+				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Layer already being pulled by another client. Waiting.", nil))
+				broadcaster.Add(out)
+				broadcaster.Wait()
 				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
 				errors <- nil
 				return
 			}
-			ps.AddObserver(out)
+			broadcaster.Add(out)
 			defer p.poolRemove("pull", "img:"+img.ID)
 
 			// we need to retain it until tagging
 			p.graph.Retain(sessionID, img.ID)
 			imgIDs = append(imgIDs, img.ID)
 
-			ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName), nil))
+			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName), nil))
 			success := false
 			var lastErr, err error
 			var isDownloaded bool
 			for _, ep := range p.repoInfo.Index.Mirrors {
 				ep += "v1/"
-				ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil))
-				if isDownloaded, err = p.pullImage(ps, img.ID, ep, repoData.Tokens); err != nil {
+				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil))
+				if isDownloaded, err = p.pullImage(broadcaster, img.ID, ep, repoData.Tokens); err != nil {
 					// Don't report errors when pulling from mirrors.
logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err) continue @@ -171,12 +172,12 @@ func (p *v1Puller) pullRepository(askedTag string) error { } if !success { for _, ep := range repoData.Endpoints { - ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil)) - if isDownloaded, err = p.pullImage(ps, img.ID, ep, repoData.Tokens); err != nil { + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil)) + if isDownloaded, err = p.pullImage(broadcaster, img.ID, ep, repoData.Tokens); err != nil { // It's not ideal that only the last error is returned, it would be better to concatenate the errors. // As the error is also given to the output stream the user will see the error. lastErr = err - ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err), nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err), nil)) continue } layersDownloaded = layersDownloaded || isDownloaded @@ -186,11 +187,11 @@ func (p *v1Puller) pullRepository(askedTag string) error { } if !success { err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr) - ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil)) errors <- err return } - ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil)) errors <- nil } @@ -244,18 +245,19 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri id := history[i] // ensure no two downloads of the same layer happen at the same time - ps, found := p.poolAdd("pull", "layer:"+id) + broadcaster, found := p.poolAdd("pull", "layer:"+id) if found { logrus.Debugf("Image (id: %s) pull is already running, skipping", id) - msg := p.sf.FormatProgress(stringid.TruncateID(imgID), "Layer already being pulled by another client. Waiting.", nil) - ps.Wait(out, msg) + out.Write(p.sf.FormatProgress(stringid.TruncateID(imgID), "Layer already being pulled by another client. 
Waiting.", nil)) + broadcaster.Add(out) + broadcaster.Wait() } else { - ps.AddObserver(out) + broadcaster.Add(out) } defer p.poolRemove("pull", "layer:"+id) if !p.graph.Exists(id) { - ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil)) var ( imgJSON []byte imgSize int64 @@ -266,7 +268,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri for j := 1; j <= retries; j++ { imgJSON, imgSize, err = p.session.GetRemoteImageJSON(id, endpoint) if err != nil && j == retries { - ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) return layersDownloaded, err } else if err != nil { time.Sleep(time.Duration(j) * 500 * time.Millisecond) @@ -275,7 +277,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri img, err = image.NewImgJSON(imgJSON) layersDownloaded = true if err != nil && j == retries { - ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) return layersDownloaded, fmt.Errorf("Failed to parse json: %s", err) } else if err != nil { time.Sleep(time.Duration(j) * 500 * time.Millisecond) @@ -291,7 +293,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri if j > 1 { status = fmt.Sprintf("Pulling fs layer [retries: %d]", j) } - ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), status, nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), status, nil)) layer, err := p.session.GetRemoteImageLayer(img.ID, endpoint, imgSize) if uerr, ok := err.(*url.Error); ok { err = uerr.Err @@ -300,7 +302,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri time.Sleep(time.Duration(j) * 500 * time.Millisecond) continue } else if err != nil { - ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) return layersDownloaded, err } layersDownloaded = true @@ -309,7 +311,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri err = p.graph.Register(img, progressreader.New(progressreader.Config{ In: layer, - Out: ps, + Out: broadcaster, Formatter: p.sf, Size: imgSize, NewLines: false, @@ -320,14 +322,14 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri time.Sleep(time.Duration(j) * 500 * time.Millisecond) continue } else if err != nil { - ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error downloading dependent layers", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error downloading dependent layers", nil)) return layersDownloaded, err } else { break } } } - ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil)) } return layersDownloaded, nil } diff --git a/graph/pull_v2.go b/graph/pull_v2.go index ed5605befb..afbda480ce 100644 --- a/graph/pull_v2.go +++ b/graph/pull_v2.go @@ -73,28 +73,29 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) { } - ps, found := p.poolAdd("pull", taggedName) + broadcaster, 
found := p.poolAdd("pull", taggedName) if found { // Another pull of the same repository is already taking place; just wait for it to finish - msg := p.sf.FormatStatus("", "Repository %s already being pulled by another client. Waiting.", p.repoInfo.CanonicalName) - ps.Wait(p.config.OutStream, msg) + p.config.OutStream.Write(p.sf.FormatStatus("", "Repository %s already being pulled by another client. Waiting.", p.repoInfo.CanonicalName)) + broadcaster.Add(p.config.OutStream) + broadcaster.Wait() return nil } defer p.poolRemove("pull", taggedName) - ps.AddObserver(p.config.OutStream) + broadcaster.Add(p.config.OutStream) var layersDownloaded bool for _, tag := range tags { // pulledNew is true if either new layers were downloaded OR if existing images were newly tagged // TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus? - pulledNew, err := p.pullV2Tag(ps, tag, taggedName) + pulledNew, err := p.pullV2Tag(broadcaster, tag, taggedName) if err != nil { return err } layersDownloaded = layersDownloaded || pulledNew } - writeStatus(taggedName, ps, p.sf, layersDownloaded) + writeStatus(taggedName, broadcaster, p.sf, layersDownloaded) return nil } @@ -119,16 +120,17 @@ func (p *v2Puller) download(di *downloadInfo) { out := di.out - ps, found := p.poolAdd("pull", "img:"+di.img.ID) + broadcaster, found := p.poolAdd("pull", "img:"+di.img.ID) if found { - msg := p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Layer already being pulled by another client. Waiting.", nil) - ps.Wait(out, msg) + out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Layer already being pulled by another client. Waiting.", nil)) + broadcaster.Add(out) + broadcaster.Wait() out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil)) di.err <- nil return } - ps.AddObserver(out) + broadcaster.Add(out) defer p.poolRemove("pull", "img:"+di.img.ID) tmpFile, err := ioutil.TempFile("", "GetImageBlob") if err != nil { @@ -163,7 +165,7 @@ func (p *v2Puller) download(di *downloadInfo) { reader := progressreader.New(progressreader.Config{ In: ioutil.NopCloser(io.TeeReader(layerDownload, verifier)), - Out: ps, + Out: broadcaster, Formatter: p.sf, Size: di.size, NewLines: false, @@ -172,7 +174,7 @@ func (p *v2Puller) download(di *downloadInfo) { }) io.Copy(tmpFile, reader) - ps.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil)) if !verifier.Verified() { err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest) @@ -181,7 +183,7 @@ func (p *v2Puller) download(di *downloadInfo) { return } - ps.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil)) + broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil)) logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, tmpFile.Name()) di.layer = layerDownload diff --git a/graph/tags.go b/graph/tags.go index 51b19babbb..d3da607824 100644 --- a/graph/tags.go +++ b/graph/tags.go @@ -37,8 +37,8 @@ type TagStore struct { sync.Mutex // FIXME: move push/pull-related fields // to a helper type - pullingPool map[string]*progressreader.ProgressStatus - pushingPool map[string]*progressreader.ProgressStatus + pullingPool map[string]*progressreader.Broadcaster + pushingPool map[string]*progressreader.Broadcaster registryService *registry.Service eventsService *events.Events trustService 
*trust.TrustStore @@ -94,8 +94,8 @@ func NewTagStore(path string, cfg *TagStoreConfig) (*TagStore, error) { graph: cfg.Graph, trustKey: cfg.Key, Repositories: make(map[string]Repository), - pullingPool: make(map[string]*progressreader.ProgressStatus), - pushingPool: make(map[string]*progressreader.ProgressStatus), + pullingPool: make(map[string]*progressreader.Broadcaster), + pushingPool: make(map[string]*progressreader.Broadcaster), registryService: cfg.Registry, eventsService: cfg.Events, trustService: cfg.Trust, @@ -428,10 +428,10 @@ func validateDigest(dgst string) error { return nil } -// poolAdd checks if a push or pull is already running, and returns (ps, true) -// if a running operation is found. Otherwise, it creates a new one and returns -// (ps, false). -func (store *TagStore) poolAdd(kind, key string) (*progressreader.ProgressStatus, bool) { +// poolAdd checks if a push or pull is already running, and returns +// (broadcaster, true) if a running operation is found. Otherwise, it creates a +// new one and returns (broadcaster, false). +func (store *TagStore) poolAdd(kind, key string) (*progressreader.Broadcaster, bool) { store.Lock() defer store.Unlock() @@ -442,18 +442,18 @@ func (store *TagStore) poolAdd(kind, key string) (*progressreader.ProgressStatus return p, true } - ps := progressreader.NewProgressStatus() + broadcaster := progressreader.NewBroadcaster() switch kind { case "pull": - store.pullingPool[key] = ps + store.pullingPool[key] = broadcaster case "push": - store.pushingPool[key] = ps + store.pushingPool[key] = broadcaster default: panic("Unknown pool type") } - return ps, false + return broadcaster, false } func (store *TagStore) poolRemove(kind, key string) error { @@ -462,12 +462,12 @@ func (store *TagStore) poolRemove(kind, key string) error { switch kind { case "pull": if ps, exists := store.pullingPool[key]; exists { - ps.Done() + ps.Close() delete(store.pullingPool, key) } case "push": if ps, exists := store.pushingPool[key]; exists { - ps.Done() + ps.Close() delete(store.pushingPool, key) } default: diff --git a/pkg/progressreader/broadcaster.go b/pkg/progressreader/broadcaster.go new file mode 100644 index 0000000000..4b08ce405d --- /dev/null +++ b/pkg/progressreader/broadcaster.go @@ -0,0 +1,146 @@ +package progressreader + +import ( + "bytes" + "errors" + "io" + "sync" +) + +// Broadcaster keeps track of one or more observers watching the progress +// of an operation. For example, if multiple clients are trying to pull an +// image, they share a Broadcaster for the download operation. +type Broadcaster struct { + sync.Mutex + // c is a channel that observers block on, waiting for the operation + // to finish. + c chan struct{} + // cond is a condition variable used to wake up observers when there's + // new data available. + cond *sync.Cond + // history is a buffer of the progress output so far, so a new observer + // can catch up. + history bytes.Buffer + // wg is a WaitGroup used to wait for all writes to finish on Close + wg sync.WaitGroup + // isClosed is set to true when Close is called to avoid closing c + // multiple times. 
+	isClosed bool
+}
+
+// NewBroadcaster returns a Broadcaster structure
+func NewBroadcaster() *Broadcaster {
+	b := &Broadcaster{
+		c: make(chan struct{}),
+	}
+	b.cond = sync.NewCond(b)
+	return b
+}
+
+// closed returns true if and only if the broadcaster has been closed
+func (broadcaster *Broadcaster) closed() bool {
+	select {
+	case <-broadcaster.c:
+		return true
+	default:
+		return false
+	}
+}
+
+// receiveWrites runs as a goroutine so that writes don't block the Write
+// function. It writes the new data in broadcaster.history each time there's
+// activity on the broadcaster.cond condition variable.
+func (broadcaster *Broadcaster) receiveWrites(observer io.Writer) {
+	n := 0
+
+	broadcaster.Lock()
+
+	// The condition variable wait is at the end of this loop, so that the
+	// first iteration will write the history so far.
+	for {
+		newData := broadcaster.history.Bytes()[n:]
+		// Make a copy of newData so we can release the lock
+		sendData := make([]byte, len(newData), len(newData))
+		copy(sendData, newData)
+		broadcaster.Unlock()
+
+		if len(sendData) > 0 {
+			written, err := observer.Write(sendData)
+			if err != nil {
+				broadcaster.wg.Done()
+				return
+			}
+			n += written
+		}
+
+		broadcaster.Lock()
+
+		// detect closure of the broadcast writer
+		if broadcaster.closed() {
+			broadcaster.Unlock()
+			broadcaster.wg.Done()
+			return
+		}
+
+		if broadcaster.history.Len() == n {
+			broadcaster.cond.Wait()
+		}
+
+		// Mutex is still locked as the loop continues
+	}
+}
+
+// Write adds data to the history buffer, and also writes it to all current
+// observers.
+func (broadcaster *Broadcaster) Write(p []byte) (n int, err error) {
+	broadcaster.Lock()
+	defer broadcaster.Unlock()
+
+	// Is the broadcaster closed? If so, the write should fail.
+	if broadcaster.closed() {
+		return 0, errors.New("attempted write to closed progressreader Broadcaster")
+	}
+
+	broadcaster.history.Write(p)
+	broadcaster.cond.Broadcast()
+
+	return len(p), nil
+}
+
+// Add adds an observer to the Broadcaster. The new observer receives the
+// data from the history buffer, and also all subsequent data.
+func (broadcaster *Broadcaster) Add(w io.Writer) error {
+	// The lock is acquired here so that Add can't race with Close
+	broadcaster.Lock()
+	defer broadcaster.Unlock()
+
+	if broadcaster.closed() {
+		return errors.New("attempted to add observer to closed progressreader Broadcaster")
+	}
+
+	broadcaster.wg.Add(1)
+	go broadcaster.receiveWrites(w)
+
+	return nil
+}
+
+// Close signals to all observers that the operation has finished.
+func (broadcaster *Broadcaster) Close() {
+	broadcaster.Lock()
+	if broadcaster.isClosed {
+		broadcaster.Unlock()
+		return
+	}
+	broadcaster.isClosed = true
+	close(broadcaster.c)
+	broadcaster.cond.Broadcast()
+	broadcaster.Unlock()
+
+	// Don't return from Close until all writers have caught up.
+	broadcaster.wg.Wait()
+}
+
+// Wait blocks until the operation is marked as completed by the Close method.
+func (broadcaster *Broadcaster) Wait() {
+	<-broadcaster.c
+}
diff --git a/pkg/progressreader/progressstatus.go b/pkg/progressreader/progressstatus.go
deleted file mode 100644
index f536b84053..0000000000
--- a/pkg/progressreader/progressstatus.go
+++ /dev/null
@@ -1,72 +0,0 @@
-package progressreader
-
-import (
-	"bytes"
-	"io"
-	"sync"
-
-	"github.com/Sirupsen/logrus"
-)
-
-type ProgressStatus struct {
-	sync.Mutex
-	c         chan struct{}
-	observers []io.Writer
-	history   bytes.Buffer
-}
-
-func NewProgressStatus() *ProgressStatus {
-	return &ProgressStatus{
-		c:         make(chan struct{}),
-		observers: []io.Writer{},
-	}
-}
-
-func (ps *ProgressStatus) Write(p []byte) (n int, err error) {
-	ps.Lock()
-	defer ps.Unlock()
-	ps.history.Write(p)
-	for _, w := range ps.observers {
-		// copy paste from MultiWriter, replaced return with continue
-		n, err = w.Write(p)
-		if err != nil {
-			continue
-		}
-		if n != len(p) {
-			err = io.ErrShortWrite
-			continue
-		}
-	}
-	return len(p), nil
-}
-
-func (ps *ProgressStatus) AddObserver(w io.Writer) {
-	ps.Lock()
-	defer ps.Unlock()
-	w.Write(ps.history.Bytes())
-	ps.observers = append(ps.observers, w)
-}
-
-func (ps *ProgressStatus) Done() {
-	ps.Lock()
-	close(ps.c)
-	ps.history.Reset()
-	ps.Unlock()
-}
-
-func (ps *ProgressStatus) Wait(w io.Writer, msg []byte) error {
-	ps.Lock()
-	channel := ps.c
-	ps.Unlock()
-
-	if channel == nil {
-		// defensive
-		logrus.Debugf("Channel is nil ")
-	}
-	if w != nil {
-		w.Write(msg)
-		ps.AddObserver(w)
-	}
-	<-channel
-	return nil
-}
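Note (not part of the patch): a minimal usage sketch of the progressreader.Broadcaster API added above, assuming the docker/docker import path shown in the diff. The bytes.Buffer observers are illustrative stand-ins for the pull code's client output streams; the point is that a late observer is replayed the history buffer, and that Close flushes all observers before returning.

package main

import (
	"bytes"
	"fmt"

	"github.com/docker/docker/pkg/progressreader"
)

func main() {
	b := progressreader.NewBroadcaster()

	// First observer attaches before any progress is written.
	var first bytes.Buffer
	b.Add(&first)

	b.Write([]byte("Pulling fs layer\n"))
	b.Write([]byte("Downloading\n"))

	// Late observer: receiveWrites replays the history buffer to it
	// before delivering subsequent writes.
	var second bytes.Buffer
	b.Add(&second)

	b.Write([]byte("Download complete\n"))

	// Close marks the operation finished and blocks until every observer
	// goroutine has drained its writes; Wait (as used by poolAdd callers)
	// returns once the Broadcaster is closed.
	b.Close()
	b.Wait()

	fmt.Print(first.String())  // all three lines
	fmt.Print(second.String()) // history replay plus the final line
}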