mirror of
				https://github.com/moby/moby.git
				synced 2022-11-09 12:21:53 -05:00 
			
		
		
		
	Clean up ProgressStatus
- Rename to Broadcaster - Document exported types - Change Wait function to just wait. Writing a message to the writer and adding the writer to the observers list are now handled by separate function calls. - Avoid importing logrus (the condition where it was used should never happen, anyway). - Make writes non-blocking Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
This commit is contained in:
		
							parent
							
								
									80513d85cf
								
							
						
					
					
						commit
						26c9b58504
					
				
					 7 changed files with 204 additions and 126 deletions
				
			
		| 
						 | 
				
			
			@ -108,7 +108,7 @@ func (s *TagStore) recursiveLoad(address, tmpImageDir string) error {
 | 
			
		|||
		// ensure no two downloads of the same layer happen at the same time
 | 
			
		||||
		if ps, found := s.poolAdd("pull", "layer:"+img.ID); found {
 | 
			
		||||
			logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID)
 | 
			
		||||
			ps.Wait(nil, nil)
 | 
			
		||||
			ps.Wait()
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,8 +13,8 @@ func init() {
 | 
			
		|||
 | 
			
		||||
func TestPools(t *testing.T) {
 | 
			
		||||
	s := &TagStore{
 | 
			
		||||
		pullingPool: make(map[string]*progressreader.ProgressStatus),
 | 
			
		||||
		pushingPool: make(map[string]*progressreader.ProgressStatus),
 | 
			
		||||
		pullingPool: make(map[string]*progressreader.Broadcaster),
 | 
			
		||||
		pushingPool: make(map[string]*progressreader.Broadcaster),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, found := s.poolAdd("pull", "test1"); found {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -138,29 +138,30 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 | 
			
		|||
			}
 | 
			
		||||
 | 
			
		||||
			// ensure no two downloads of the same image happen at the same time
 | 
			
		||||
			ps, found := p.poolAdd("pull", "img:"+img.ID)
 | 
			
		||||
			broadcaster, found := p.poolAdd("pull", "img:"+img.ID)
 | 
			
		||||
			if found {
 | 
			
		||||
				msg := p.sf.FormatProgress(stringid.TruncateID(img.ID), "Layer already being pulled by another client. Waiting.", nil)
 | 
			
		||||
				ps.Wait(out, msg)
 | 
			
		||||
				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Layer already being pulled by another client. Waiting.", nil))
 | 
			
		||||
				broadcaster.Add(out)
 | 
			
		||||
				broadcaster.Wait()
 | 
			
		||||
				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
 | 
			
		||||
				errors <- nil
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			ps.AddObserver(out)
 | 
			
		||||
			broadcaster.Add(out)
 | 
			
		||||
			defer p.poolRemove("pull", "img:"+img.ID)
 | 
			
		||||
 | 
			
		||||
			// we need to retain it until tagging
 | 
			
		||||
			p.graph.Retain(sessionID, img.ID)
 | 
			
		||||
			imgIDs = append(imgIDs, img.ID)
 | 
			
		||||
 | 
			
		||||
			ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName), nil))
 | 
			
		||||
			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName), nil))
 | 
			
		||||
			success := false
 | 
			
		||||
			var lastErr, err error
 | 
			
		||||
			var isDownloaded bool
 | 
			
		||||
			for _, ep := range p.repoInfo.Index.Mirrors {
 | 
			
		||||
				ep += "v1/"
 | 
			
		||||
				ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil))
 | 
			
		||||
				if isDownloaded, err = p.pullImage(ps, img.ID, ep, repoData.Tokens); err != nil {
 | 
			
		||||
				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil))
 | 
			
		||||
				if isDownloaded, err = p.pullImage(broadcaster, img.ID, ep, repoData.Tokens); err != nil {
 | 
			
		||||
					// Don't report errors when pulling from mirrors.
 | 
			
		||||
					logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err)
 | 
			
		||||
					continue
 | 
			
		||||
| 
						 | 
				
			
			@ -171,12 +172,12 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 | 
			
		|||
			}
 | 
			
		||||
			if !success {
 | 
			
		||||
				for _, ep := range repoData.Endpoints {
 | 
			
		||||
					ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil))
 | 
			
		||||
					if isDownloaded, err = p.pullImage(ps, img.ID, ep, repoData.Tokens); err != nil {
 | 
			
		||||
					broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName, ep), nil))
 | 
			
		||||
					if isDownloaded, err = p.pullImage(broadcaster, img.ID, ep, repoData.Tokens); err != nil {
 | 
			
		||||
						// It's not ideal that only the last error is returned, it would be better to concatenate the errors.
 | 
			
		||||
						// As the error is also given to the output stream the user will see the error.
 | 
			
		||||
						lastErr = err
 | 
			
		||||
						ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err), nil))
 | 
			
		||||
						broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName, ep, err), nil))
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					layersDownloaded = layersDownloaded || isDownloaded
 | 
			
		||||
| 
						 | 
				
			
			@ -186,11 +187,11 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 | 
			
		|||
			}
 | 
			
		||||
			if !success {
 | 
			
		||||
				err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr)
 | 
			
		||||
				ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
 | 
			
		||||
				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
 | 
			
		||||
				errors <- err
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			ps.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
 | 
			
		||||
			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
 | 
			
		||||
 | 
			
		||||
			errors <- nil
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			@ -244,18 +245,19 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 | 
			
		|||
		id := history[i]
 | 
			
		||||
 | 
			
		||||
		// ensure no two downloads of the same layer happen at the same time
 | 
			
		||||
		ps, found := p.poolAdd("pull", "layer:"+id)
 | 
			
		||||
		broadcaster, found := p.poolAdd("pull", "layer:"+id)
 | 
			
		||||
		if found {
 | 
			
		||||
			logrus.Debugf("Image (id: %s) pull is already running, skipping", id)
 | 
			
		||||
			msg := p.sf.FormatProgress(stringid.TruncateID(imgID), "Layer already being pulled by another client. Waiting.", nil)
 | 
			
		||||
			ps.Wait(out, msg)
 | 
			
		||||
			out.Write(p.sf.FormatProgress(stringid.TruncateID(imgID), "Layer already being pulled by another client. Waiting.", nil))
 | 
			
		||||
			broadcaster.Add(out)
 | 
			
		||||
			broadcaster.Wait()
 | 
			
		||||
		} else {
 | 
			
		||||
			ps.AddObserver(out)
 | 
			
		||||
			broadcaster.Add(out)
 | 
			
		||||
		}
 | 
			
		||||
		defer p.poolRemove("pull", "layer:"+id)
 | 
			
		||||
 | 
			
		||||
		if !p.graph.Exists(id) {
 | 
			
		||||
			ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil))
 | 
			
		||||
			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil))
 | 
			
		||||
			var (
 | 
			
		||||
				imgJSON []byte
 | 
			
		||||
				imgSize int64
 | 
			
		||||
| 
						 | 
				
			
			@ -266,7 +268,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 | 
			
		|||
			for j := 1; j <= retries; j++ {
 | 
			
		||||
				imgJSON, imgSize, err = p.session.GetRemoteImageJSON(id, endpoint)
 | 
			
		||||
				if err != nil && j == retries {
 | 
			
		||||
					ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
 | 
			
		||||
					broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
 | 
			
		||||
					return layersDownloaded, err
 | 
			
		||||
				} else if err != nil {
 | 
			
		||||
					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 | 
			
		||||
| 
						 | 
				
			
			@ -275,7 +277,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 | 
			
		|||
				img, err = image.NewImgJSON(imgJSON)
 | 
			
		||||
				layersDownloaded = true
 | 
			
		||||
				if err != nil && j == retries {
 | 
			
		||||
					ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
 | 
			
		||||
					broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
 | 
			
		||||
					return layersDownloaded, fmt.Errorf("Failed to parse json: %s", err)
 | 
			
		||||
				} else if err != nil {
 | 
			
		||||
					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 | 
			
		||||
| 
						 | 
				
			
			@ -291,7 +293,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 | 
			
		|||
				if j > 1 {
 | 
			
		||||
					status = fmt.Sprintf("Pulling fs layer [retries: %d]", j)
 | 
			
		||||
				}
 | 
			
		||||
				ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), status, nil))
 | 
			
		||||
				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), status, nil))
 | 
			
		||||
				layer, err := p.session.GetRemoteImageLayer(img.ID, endpoint, imgSize)
 | 
			
		||||
				if uerr, ok := err.(*url.Error); ok {
 | 
			
		||||
					err = uerr.Err
 | 
			
		||||
| 
						 | 
				
			
			@ -300,7 +302,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 | 
			
		|||
					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 | 
			
		||||
					continue
 | 
			
		||||
				} else if err != nil {
 | 
			
		||||
					ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
 | 
			
		||||
					broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil))
 | 
			
		||||
					return layersDownloaded, err
 | 
			
		||||
				}
 | 
			
		||||
				layersDownloaded = true
 | 
			
		||||
| 
						 | 
				
			
			@ -309,7 +311,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 | 
			
		|||
				err = p.graph.Register(img,
 | 
			
		||||
					progressreader.New(progressreader.Config{
 | 
			
		||||
						In:        layer,
 | 
			
		||||
						Out:       ps,
 | 
			
		||||
						Out:       broadcaster,
 | 
			
		||||
						Formatter: p.sf,
 | 
			
		||||
						Size:      imgSize,
 | 
			
		||||
						NewLines:  false,
 | 
			
		||||
| 
						 | 
				
			
			@ -320,14 +322,14 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 | 
			
		|||
					time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 | 
			
		||||
					continue
 | 
			
		||||
				} else if err != nil {
 | 
			
		||||
					ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error downloading dependent layers", nil))
 | 
			
		||||
					broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Error downloading dependent layers", nil))
 | 
			
		||||
					return layersDownloaded, err
 | 
			
		||||
				} else {
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		ps.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil))
 | 
			
		||||
		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil))
 | 
			
		||||
	}
 | 
			
		||||
	return layersDownloaded, nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -73,28 +73,29 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) {
 | 
			
		|||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ps, found := p.poolAdd("pull", taggedName)
 | 
			
		||||
	broadcaster, found := p.poolAdd("pull", taggedName)
 | 
			
		||||
	if found {
 | 
			
		||||
		// Another pull of the same repository is already taking place; just wait for it to finish
 | 
			
		||||
		msg := p.sf.FormatStatus("", "Repository %s already being pulled by another client. Waiting.", p.repoInfo.CanonicalName)
 | 
			
		||||
		ps.Wait(p.config.OutStream, msg)
 | 
			
		||||
		p.config.OutStream.Write(p.sf.FormatStatus("", "Repository %s already being pulled by another client. Waiting.", p.repoInfo.CanonicalName))
 | 
			
		||||
		broadcaster.Add(p.config.OutStream)
 | 
			
		||||
		broadcaster.Wait()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	defer p.poolRemove("pull", taggedName)
 | 
			
		||||
	ps.AddObserver(p.config.OutStream)
 | 
			
		||||
	broadcaster.Add(p.config.OutStream)
 | 
			
		||||
 | 
			
		||||
	var layersDownloaded bool
 | 
			
		||||
	for _, tag := range tags {
 | 
			
		||||
		// pulledNew is true if either new layers were downloaded OR if existing images were newly tagged
 | 
			
		||||
		// TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus?
 | 
			
		||||
		pulledNew, err := p.pullV2Tag(ps, tag, taggedName)
 | 
			
		||||
		pulledNew, err := p.pullV2Tag(broadcaster, tag, taggedName)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		layersDownloaded = layersDownloaded || pulledNew
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	writeStatus(taggedName, ps, p.sf, layersDownloaded)
 | 
			
		||||
	writeStatus(taggedName, broadcaster, p.sf, layersDownloaded)
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -119,16 +120,17 @@ func (p *v2Puller) download(di *downloadInfo) {
 | 
			
		|||
 | 
			
		||||
	out := di.out
 | 
			
		||||
 | 
			
		||||
	ps, found := p.poolAdd("pull", "img:"+di.img.ID)
 | 
			
		||||
	broadcaster, found := p.poolAdd("pull", "img:"+di.img.ID)
 | 
			
		||||
	if found {
 | 
			
		||||
		msg := p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Layer already being pulled by another client. Waiting.", nil)
 | 
			
		||||
		ps.Wait(out, msg)
 | 
			
		||||
		out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Layer already being pulled by another client. Waiting.", nil))
 | 
			
		||||
		broadcaster.Add(out)
 | 
			
		||||
		broadcaster.Wait()
 | 
			
		||||
		out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
 | 
			
		||||
		di.err <- nil
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ps.AddObserver(out)
 | 
			
		||||
	broadcaster.Add(out)
 | 
			
		||||
	defer p.poolRemove("pull", "img:"+di.img.ID)
 | 
			
		||||
	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -163,7 +165,7 @@ func (p *v2Puller) download(di *downloadInfo) {
 | 
			
		|||
 | 
			
		||||
	reader := progressreader.New(progressreader.Config{
 | 
			
		||||
		In:        ioutil.NopCloser(io.TeeReader(layerDownload, verifier)),
 | 
			
		||||
		Out:       ps,
 | 
			
		||||
		Out:       broadcaster,
 | 
			
		||||
		Formatter: p.sf,
 | 
			
		||||
		Size:      di.size,
 | 
			
		||||
		NewLines:  false,
 | 
			
		||||
| 
						 | 
				
			
			@ -172,7 +174,7 @@ func (p *v2Puller) download(di *downloadInfo) {
 | 
			
		|||
	})
 | 
			
		||||
	io.Copy(tmpFile, reader)
 | 
			
		||||
 | 
			
		||||
	ps.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil))
 | 
			
		||||
	broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil))
 | 
			
		||||
 | 
			
		||||
	if !verifier.Verified() {
 | 
			
		||||
		err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest)
 | 
			
		||||
| 
						 | 
				
			
			@ -181,7 +183,7 @@ func (p *v2Puller) download(di *downloadInfo) {
 | 
			
		|||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ps.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
 | 
			
		||||
	broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
 | 
			
		||||
 | 
			
		||||
	logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, tmpFile.Name())
 | 
			
		||||
	di.layer = layerDownload
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -37,8 +37,8 @@ type TagStore struct {
 | 
			
		|||
	sync.Mutex
 | 
			
		||||
	// FIXME: move push/pull-related fields
 | 
			
		||||
	// to a helper type
 | 
			
		||||
	pullingPool     map[string]*progressreader.ProgressStatus
 | 
			
		||||
	pushingPool     map[string]*progressreader.ProgressStatus
 | 
			
		||||
	pullingPool     map[string]*progressreader.Broadcaster
 | 
			
		||||
	pushingPool     map[string]*progressreader.Broadcaster
 | 
			
		||||
	registryService *registry.Service
 | 
			
		||||
	eventsService   *events.Events
 | 
			
		||||
	trustService    *trust.TrustStore
 | 
			
		||||
| 
						 | 
				
			
			@ -94,8 +94,8 @@ func NewTagStore(path string, cfg *TagStoreConfig) (*TagStore, error) {
 | 
			
		|||
		graph:           cfg.Graph,
 | 
			
		||||
		trustKey:        cfg.Key,
 | 
			
		||||
		Repositories:    make(map[string]Repository),
 | 
			
		||||
		pullingPool:     make(map[string]*progressreader.ProgressStatus),
 | 
			
		||||
		pushingPool:     make(map[string]*progressreader.ProgressStatus),
 | 
			
		||||
		pullingPool:     make(map[string]*progressreader.Broadcaster),
 | 
			
		||||
		pushingPool:     make(map[string]*progressreader.Broadcaster),
 | 
			
		||||
		registryService: cfg.Registry,
 | 
			
		||||
		eventsService:   cfg.Events,
 | 
			
		||||
		trustService:    cfg.Trust,
 | 
			
		||||
| 
						 | 
				
			
			@ -428,10 +428,10 @@ func validateDigest(dgst string) error {
 | 
			
		|||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// poolAdd checks if a push or pull is already running, and returns (ps, true)
 | 
			
		||||
// if a running operation is found. Otherwise, it creates a new one and returns
 | 
			
		||||
// (ps, false).
 | 
			
		||||
func (store *TagStore) poolAdd(kind, key string) (*progressreader.ProgressStatus, bool) {
 | 
			
		||||
// poolAdd checks if a push or pull is already running, and returns
 | 
			
		||||
// (broadcaster, true) if a running operation is found. Otherwise, it creates a
 | 
			
		||||
// new one and returns (broadcaster, false).
 | 
			
		||||
func (store *TagStore) poolAdd(kind, key string) (*progressreader.Broadcaster, bool) {
 | 
			
		||||
	store.Lock()
 | 
			
		||||
	defer store.Unlock()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -442,18 +442,18 @@ func (store *TagStore) poolAdd(kind, key string) (*progressreader.ProgressStatus
 | 
			
		|||
		return p, true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ps := progressreader.NewProgressStatus()
 | 
			
		||||
	broadcaster := progressreader.NewBroadcaster()
 | 
			
		||||
 | 
			
		||||
	switch kind {
 | 
			
		||||
	case "pull":
 | 
			
		||||
		store.pullingPool[key] = ps
 | 
			
		||||
		store.pullingPool[key] = broadcaster
 | 
			
		||||
	case "push":
 | 
			
		||||
		store.pushingPool[key] = ps
 | 
			
		||||
		store.pushingPool[key] = broadcaster
 | 
			
		||||
	default:
 | 
			
		||||
		panic("Unknown pool type")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ps, false
 | 
			
		||||
	return broadcaster, false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (store *TagStore) poolRemove(kind, key string) error {
 | 
			
		||||
| 
						 | 
				
			
			@ -462,12 +462,12 @@ func (store *TagStore) poolRemove(kind, key string) error {
 | 
			
		|||
	switch kind {
 | 
			
		||||
	case "pull":
 | 
			
		||||
		if ps, exists := store.pullingPool[key]; exists {
 | 
			
		||||
			ps.Done()
 | 
			
		||||
			ps.Close()
 | 
			
		||||
			delete(store.pullingPool, key)
 | 
			
		||||
		}
 | 
			
		||||
	case "push":
 | 
			
		||||
		if ps, exists := store.pushingPool[key]; exists {
 | 
			
		||||
			ps.Done()
 | 
			
		||||
			ps.Close()
 | 
			
		||||
			delete(store.pushingPool, key)
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										146
									
								
								pkg/progressreader/broadcaster.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										146
									
								
								pkg/progressreader/broadcaster.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,146 @@
 | 
			
		|||
package progressreader
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"io"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Broadcaster keeps track of one or more observers watching the progress
 | 
			
		||||
// of an operation. For example, if multiple clients are trying to pull an
 | 
			
		||||
// image, they share a Broadcaster for the download operation.
 | 
			
		||||
type Broadcaster struct {
 | 
			
		||||
	sync.Mutex
 | 
			
		||||
	// c is a channel that observers block on, waiting for the operation
 | 
			
		||||
	// to finish.
 | 
			
		||||
	c chan struct{}
 | 
			
		||||
	// cond is a condition variable used to wake up observers when there's
 | 
			
		||||
	// new data available.
 | 
			
		||||
	cond *sync.Cond
 | 
			
		||||
	// history is a buffer of the progress output so far, so a new observer
 | 
			
		||||
	// can catch up.
 | 
			
		||||
	history bytes.Buffer
 | 
			
		||||
	// wg is a WaitGroup used to wait for all writes to finish on Close
 | 
			
		||||
	wg sync.WaitGroup
 | 
			
		||||
	// isClosed is set to true when Close is called to avoid closing c
 | 
			
		||||
	// multiple times.
 | 
			
		||||
	isClosed bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewBroadcaster returns a Broadcaster structure
 | 
			
		||||
func NewBroadcaster() *Broadcaster {
 | 
			
		||||
	b := &Broadcaster{
 | 
			
		||||
		c: make(chan struct{}),
 | 
			
		||||
	}
 | 
			
		||||
	b.cond = sync.NewCond(b)
 | 
			
		||||
	return b
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// closed returns true if and only if the broadcaster has been closed
 | 
			
		||||
func (broadcaster *Broadcaster) closed() bool {
 | 
			
		||||
	select {
 | 
			
		||||
	case <-broadcaster.c:
 | 
			
		||||
		return true
 | 
			
		||||
	default:
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// receiveWrites runs as a goroutine so that writes don't block the Write
 | 
			
		||||
// function. It writes the new data in broadcaster.history each time there's
 | 
			
		||||
// activity on the broadcaster.cond condition variable.
 | 
			
		||||
func (broadcaster *Broadcaster) receiveWrites(observer io.Writer) {
 | 
			
		||||
	n := 0
 | 
			
		||||
 | 
			
		||||
	broadcaster.Lock()
 | 
			
		||||
 | 
			
		||||
	// The condition variable wait is at the end of this loop, so that the
 | 
			
		||||
	// first iteration will write the history so far.
 | 
			
		||||
	for {
 | 
			
		||||
		newData := broadcaster.history.Bytes()[n:]
 | 
			
		||||
		// Make a copy of newData so we can release the lock
 | 
			
		||||
		sendData := make([]byte, len(newData), len(newData))
 | 
			
		||||
		copy(sendData, newData)
 | 
			
		||||
		broadcaster.Unlock()
 | 
			
		||||
 | 
			
		||||
		if len(sendData) > 0 {
 | 
			
		||||
			written, err := observer.Write(sendData)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				broadcaster.wg.Done()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			n += written
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		broadcaster.Lock()
 | 
			
		||||
 | 
			
		||||
		// detect closure of the broadcast writer
 | 
			
		||||
		if broadcaster.closed() {
 | 
			
		||||
			broadcaster.Unlock()
 | 
			
		||||
			broadcaster.wg.Done()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if broadcaster.history.Len() == n {
 | 
			
		||||
			broadcaster.cond.Wait()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Mutex is still locked as the loop continues
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Write adds data to the history buffer, and also writes it to all current
 | 
			
		||||
// observers.
 | 
			
		||||
func (broadcaster *Broadcaster) Write(p []byte) (n int, err error) {
 | 
			
		||||
	broadcaster.Lock()
 | 
			
		||||
	defer broadcaster.Unlock()
 | 
			
		||||
 | 
			
		||||
	// Is the broadcaster closed? If so, the write should fail.
 | 
			
		||||
	if broadcaster.closed() {
 | 
			
		||||
		return 0, errors.New("attempted write to closed progressreader Broadcaster")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	broadcaster.history.Write(p)
 | 
			
		||||
	broadcaster.cond.Broadcast()
 | 
			
		||||
 | 
			
		||||
	return len(p), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Add adds an observer to the Broadcaster. The new observer receives the
 | 
			
		||||
// data from the history buffer, and also all subsequent data.
 | 
			
		||||
func (broadcaster *Broadcaster) Add(w io.Writer) error {
 | 
			
		||||
	// The lock is acquired here so that Add can't race with Close
 | 
			
		||||
	broadcaster.Lock()
 | 
			
		||||
	defer broadcaster.Unlock()
 | 
			
		||||
 | 
			
		||||
	if broadcaster.closed() {
 | 
			
		||||
		return errors.New("attempted to add observer to closed progressreader Broadcaster")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	broadcaster.wg.Add(1)
 | 
			
		||||
	go broadcaster.receiveWrites(w)
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close signals to all observers that the operation has finished.
 | 
			
		||||
func (broadcaster *Broadcaster) Close() {
 | 
			
		||||
	broadcaster.Lock()
 | 
			
		||||
	if broadcaster.isClosed {
 | 
			
		||||
		broadcaster.Unlock()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	broadcaster.isClosed = true
 | 
			
		||||
	close(broadcaster.c)
 | 
			
		||||
	broadcaster.cond.Broadcast()
 | 
			
		||||
	broadcaster.Unlock()
 | 
			
		||||
 | 
			
		||||
	// Don't return from Close until all writers have caught up.
 | 
			
		||||
	broadcaster.wg.Wait()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Wait blocks until the operation is marked as completed by the Done method.
 | 
			
		||||
func (broadcaster *Broadcaster) Wait() {
 | 
			
		||||
	<-broadcaster.c
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -1,72 +0,0 @@
 | 
			
		|||
package progressreader
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"io"
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"github.com/docker/docker/vendor/src/github.com/Sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ProgressStatus struct {
 | 
			
		||||
	sync.Mutex
 | 
			
		||||
	c         chan struct{}
 | 
			
		||||
	observers []io.Writer
 | 
			
		||||
	history   bytes.Buffer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewProgressStatus() *ProgressStatus {
 | 
			
		||||
	return &ProgressStatus{
 | 
			
		||||
		c:         make(chan struct{}),
 | 
			
		||||
		observers: []io.Writer{},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ps *ProgressStatus) Write(p []byte) (n int, err error) {
 | 
			
		||||
	ps.Lock()
 | 
			
		||||
	defer ps.Unlock()
 | 
			
		||||
	ps.history.Write(p)
 | 
			
		||||
	for _, w := range ps.observers {
 | 
			
		||||
		// copy paste from MultiWriter, replaced return with continue
 | 
			
		||||
		n, err = w.Write(p)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if n != len(p) {
 | 
			
		||||
			err = io.ErrShortWrite
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return len(p), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ps *ProgressStatus) AddObserver(w io.Writer) {
 | 
			
		||||
	ps.Lock()
 | 
			
		||||
	defer ps.Unlock()
 | 
			
		||||
	w.Write(ps.history.Bytes())
 | 
			
		||||
	ps.observers = append(ps.observers, w)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ps *ProgressStatus) Done() {
 | 
			
		||||
	ps.Lock()
 | 
			
		||||
	close(ps.c)
 | 
			
		||||
	ps.history.Reset()
 | 
			
		||||
	ps.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ps *ProgressStatus) Wait(w io.Writer, msg []byte) error {
 | 
			
		||||
	ps.Lock()
 | 
			
		||||
	channel := ps.c
 | 
			
		||||
	ps.Unlock()
 | 
			
		||||
 | 
			
		||||
	if channel == nil {
 | 
			
		||||
		// defensive
 | 
			
		||||
		logrus.Debugf("Channel is nil ")
 | 
			
		||||
	}
 | 
			
		||||
	if w != nil {
 | 
			
		||||
		w.Write(msg)
 | 
			
		||||
		ps.AddObserver(w)
 | 
			
		||||
	}
 | 
			
		||||
	<-channel
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue