diff --git a/distribution/pull.go b/distribution/pull.go index 603d5de5be..49c0d3b660 100644 --- a/distribution/pull.go +++ b/distribution/pull.go @@ -2,7 +2,6 @@ package distribution import ( "fmt" - "os" "github.com/Sirupsen/logrus" "github.com/docker/docker/api" @@ -187,16 +186,3 @@ func validateRepoName(name string) error { } return nil } - -// tmpFileClose creates a closer function for a temporary file that closes the file -// and also deletes it. -func tmpFileCloser(tmpFile *os.File) func() error { - return func() error { - tmpFile.Close() - if err := os.RemoveAll(tmpFile.Name()); err != nil { - logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name()) - } - - return nil - } -} diff --git a/distribution/pull_v1.go b/distribution/pull_v1.go index 312f7e30e2..a7080df697 100644 --- a/distribution/pull_v1.go +++ b/distribution/pull_v1.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "net" "net/url" + "os" "strings" "time" @@ -279,6 +280,7 @@ type v1LayerDescriptor struct { layersDownloaded *bool layerSize int64 session *registry.Session + tmpFile *os.File } func (ld *v1LayerDescriptor) Key() string { @@ -308,7 +310,7 @@ func (ld *v1LayerDescriptor) Download(ctx context.Context, progressOutput progre } *ld.layersDownloaded = true - tmpFile, err := ioutil.TempFile("", "GetImageBlob") + ld.tmpFile, err = ioutil.TempFile("", "GetImageBlob") if err != nil { layerReader.Close() return nil, 0, err @@ -317,17 +319,28 @@ func (ld *v1LayerDescriptor) Download(ctx context.Context, progressOutput progre reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerReader), progressOutput, ld.layerSize, ld.ID(), "Downloading") defer reader.Close() - _, err = io.Copy(tmpFile, reader) + _, err = io.Copy(ld.tmpFile, reader) if err != nil { + ld.Close() return nil, 0, err } progress.Update(progressOutput, ld.ID(), "Download complete") - logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name()) + logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), ld.tmpFile.Name()) - tmpFile.Seek(0, 0) - return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), ld.layerSize, nil + ld.tmpFile.Seek(0, 0) + return ld.tmpFile, ld.layerSize, nil +} + +func (ld *v1LayerDescriptor) Close() { + if ld.tmpFile != nil { + ld.tmpFile.Close() + if err := os.RemoveAll(ld.tmpFile.Name()); err != nil { + logrus.Errorf("Failed to remove temp file: %s", ld.tmpFile.Name()) + } + ld.tmpFile = nil + } } func (ld *v1LayerDescriptor) Registered(diffID layer.DiffID) { diff --git a/distribution/pull_v2.go b/distribution/pull_v2.go index 04d05e02f4..cb07b5172a 100644 --- a/distribution/pull_v2.go +++ b/distribution/pull_v2.go @@ -114,6 +114,7 @@ type v2LayerDescriptor struct { repoInfo *registry.RepositoryInfo repo distribution.Repository V2MetadataService *metadata.V2MetadataService + tmpFile *os.File } func (ld *v2LayerDescriptor) Key() string { @@ -131,6 +132,18 @@ func (ld *v2LayerDescriptor) DiffID() (layer.DiffID, error) { func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) { logrus.Debugf("pulling blob %q", ld.digest) + var err error + + if ld.tmpFile == nil { + ld.tmpFile, err = createDownloadFile() + } else { + _, err = ld.tmpFile.Seek(0, os.SEEK_SET) + } + if err != nil { + return nil, 0, xfer.DoNotRetry{Err: err} + } + + tmpFile := ld.tmpFile blobs := ld.repo.Blobs(ctx) layerDownload, err := blobs.Open(ctx, ld.digest) @@ -164,17 +177,13 @@ func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progre return nil, 0, xfer.DoNotRetry{Err: err} } - tmpFile, err := ioutil.TempFile("", "GetImageBlob") - if err != nil { - return nil, 0, xfer.DoNotRetry{Err: err} - } - _, err = io.Copy(tmpFile, io.TeeReader(reader, verifier)) if err != nil { tmpFile.Close() if err := os.Remove(tmpFile.Name()); err != nil { logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name()) } + ld.tmpFile = nil return nil, 0, retryOnError(err) } @@ -188,6 +197,7 @@ func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progre if err := os.Remove(tmpFile.Name()); err != nil { logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name()) } + ld.tmpFile = nil return nil, 0, xfer.DoNotRetry{Err: err} } @@ -202,9 +212,19 @@ func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progre if err := os.Remove(tmpFile.Name()); err != nil { logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name()) } + ld.tmpFile = nil return nil, 0, xfer.DoNotRetry{Err: err} } - return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), size, nil + return tmpFile, size, nil +} + +func (ld *v2LayerDescriptor) Close() { + if ld.tmpFile != nil { + ld.tmpFile.Close() + if err := os.RemoveAll(ld.tmpFile.Name()); err != nil { + logrus.Errorf("Failed to remove temp file: %s", ld.tmpFile.Name()) + } + } } func (ld *v2LayerDescriptor) Registered(diffID layer.DiffID) { @@ -711,3 +731,7 @@ func fixManifestLayers(m *schema1.Manifest) error { return nil } + +func createDownloadFile() (*os.File, error) { + return ioutil.TempFile("", "GetImageBlob") +} diff --git a/distribution/xfer/download.go b/distribution/xfer/download.go index 69c8bad031..2536f1dd23 100644 --- a/distribution/xfer/download.go +++ b/distribution/xfer/download.go @@ -59,6 +59,10 @@ type DownloadDescriptor interface { DiffID() (layer.DiffID, error) // Download is called to perform the download. Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) + // Close is called when the download manager is finished with this + // descriptor and will not call Download again or read from the reader + // that Download returned. + Close() } // DownloadDescriptorWithRegistered is a DownloadDescriptor that has an @@ -229,6 +233,8 @@ func (ldm *LayerDownloadManager) makeDownloadFunc(descriptor DownloadDescriptor, retries int ) + defer descriptor.Close() + for { downloadReader, size, err = descriptor.Download(d.Transfer.Context(), progressOutput) if err == nil { diff --git a/distribution/xfer/download_test.go b/distribution/xfer/download_test.go index 6dc6708531..32d5502546 100644 --- a/distribution/xfer/download_test.go +++ b/distribution/xfer/download_test.go @@ -199,6 +199,9 @@ func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput pr return d.mockTarStream(), 0, nil } +func (d *mockDownloadDescriptor) Close() { +} + func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor { return []DownloadDescriptor{ &mockDownloadDescriptor{