diff --git a/registry/registry.go b/registry/registry.go index 3d0a3ed2da..7bcf066019 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -256,12 +256,43 @@ func (r *Registry) GetRemoteImageJSON(imgID, registry string, token []string) ([ return jsonString, imageSize, nil } -func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string) (io.ReadCloser, error) { - req, err := r.reqFactory.NewRequest("GET", registry+"images/"+imgID+"/layer", nil) +func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string, imgSize int64) (io.ReadCloser, error) { + var ( + retries = 5 + headRes *http.Response + hasResume bool = false + imageURL = fmt.Sprintf("%simages/%s/layer", registry, imgID) + ) + headReq, err := r.reqFactory.NewRequest("HEAD", imageURL, nil) + if err != nil { + return nil, fmt.Errorf("Error while getting from the server: %s\n", err) + } + setTokenAuth(headReq, token) + for i := 1; i <= retries; i++ { + headRes, err = r.client.Do(headReq) + if err != nil && i == retries { + return nil, fmt.Errorf("Eror while making head request: %s\n", err) + } else if err != nil { + time.Sleep(time.Duration(i) * 5 * time.Second) + continue + } + break + } + + if headRes.Header.Get("Accept-Ranges") == "bytes" && imgSize > 0 { + hasResume = true + } + + req, err := r.reqFactory.NewRequest("GET", imageURL, nil) if err != nil { return nil, fmt.Errorf("Error while getting from the server: %s\n", err) } setTokenAuth(req, token) + if hasResume { + utils.Debugf("server supports resume") + return utils.ResumableRequestReader(r.client, req, 5, imgSize), nil + } + utils.Debugf("server doesn't support resume") res, err := r.client.Do(req) if err != nil { return nil, err @@ -725,6 +756,13 @@ type Registry struct { indexEndpoint string } +func AddRequiredHeadersToRedirectedRequests(req *http.Request, via []*http.Request) error { + if via != nil && via[0] != nil { + req.Header = via[0].Header + } + return nil +} + func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, indexEndpoint string) (r *Registry, err error) { httpDial := func(proto string, addr string) (net.Conn, error) { conn, err := net.Dial(proto, addr) @@ -744,7 +782,8 @@ func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, inde r = &Registry{ authConfig: authConfig, client: &http.Client{ - Transport: httpTransport, + Transport: httpTransport, + CheckRedirect: AddRequiredHeadersToRedirectedRequests, }, indexEndpoint: indexEndpoint, } diff --git a/registry/registry_test.go b/registry/registry_test.go index 0a5be5e543..e207359e61 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -70,7 +70,7 @@ func TestGetRemoteImageJSON(t *testing.T) { func TestGetRemoteImageLayer(t *testing.T) { r := spawnTestRegistry(t) - data, err := r.GetRemoteImageLayer(IMAGE_ID, makeURL("/v1/"), TOKEN) + data, err := r.GetRemoteImageLayer(IMAGE_ID, makeURL("/v1/"), TOKEN, 0) if err != nil { t.Fatal(err) } @@ -78,7 +78,7 @@ func TestGetRemoteImageLayer(t *testing.T) { t.Fatal("Expected non-nil data result") } - _, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), TOKEN) + _, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), TOKEN, 0) if err == nil { t.Fatal("Expected image not found error") } diff --git a/server/server.go b/server/server.go index e7134d8f0c..3239893e0e 100644 --- a/server/server.go +++ b/server/server.go @@ -1137,7 +1137,7 @@ func (srv *Server) pullImage(r *registry.Registry, out io.Writer, imgID, endpoin status = fmt.Sprintf("Pulling fs layer [retries: %d]", j) } out.Write(sf.FormatProgress(utils.TruncateID(id), status, nil)) - layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token) + layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token, int64(imgSize)) if uerr, ok := err.(*url.Error); ok { err = uerr.Err } diff --git a/utils/resumablerequestreader.go b/utils/resumablerequestreader.go new file mode 100644 index 0000000000..e01f4e6d71 --- /dev/null +++ b/utils/resumablerequestreader.go @@ -0,0 +1,87 @@ +package utils + +import ( + "fmt" + "io" + "net/http" + "time" +) + +type resumableRequestReader struct { + client *http.Client + request *http.Request + lastRange int64 + totalSize int64 + currentResponse *http.Response + failures uint32 + maxFailures uint32 +} + +// ResumableRequestReader makes it possible to resume reading a request's body transparently +// maxfail is the number of times we retry to make requests again (not resumes) +// totalsize is the total length of the body; auto detect if not provided +func ResumableRequestReader(c *http.Client, r *http.Request, maxfail uint32, totalsize int64) io.ReadCloser { + return &resumableRequestReader{client: c, request: r, maxFailures: maxfail, totalSize: totalsize} +} + +func (r *resumableRequestReader) Read(p []byte) (n int, err error) { + if r.client == nil || r.request == nil { + return 0, fmt.Errorf("client and request can't be nil\n") + } + isFreshRequest := false + if r.lastRange != 0 && r.currentResponse == nil { + readRange := fmt.Sprintf("bytes=%d-%d", r.lastRange, r.totalSize) + r.request.Header.Set("Range", readRange) + time.Sleep(5 * time.Second) + } + if r.currentResponse == nil { + r.currentResponse, err = r.client.Do(r.request) + isFreshRequest = true + } + if err != nil && r.failures+1 != r.maxFailures { + r.cleanUpResponse() + r.failures += 1 + time.Sleep(5 * time.Duration(r.failures) * time.Second) + return 0, nil + } else if err != nil { + r.cleanUpResponse() + return 0, err + } + if r.currentResponse.StatusCode == 416 && r.lastRange == r.totalSize && r.currentResponse.ContentLength == 0 { + r.cleanUpResponse() + return 0, io.EOF + } else if r.currentResponse.StatusCode != 206 && r.lastRange != 0 && isFreshRequest { + r.cleanUpResponse() + return 0, fmt.Errorf("the server doesn't support byte ranges") + } + if r.totalSize == 0 { + r.totalSize = r.currentResponse.ContentLength + } else if r.totalSize <= 0 { + r.cleanUpResponse() + return 0, fmt.Errorf("failed to auto detect content length") + } + n, err = r.currentResponse.Body.Read(p) + r.lastRange += int64(n) + if err != nil { + r.cleanUpResponse() + } + if err != nil && err != io.EOF { + Debugf("encountered error during pull and clearing it before resume: %s", err) + err = nil + } + return n, err +} + +func (r *resumableRequestReader) Close() error { + r.cleanUpResponse() + r.client = nil + r.request = nil + return nil +} + +func (r *resumableRequestReader) cleanUpResponse() { + if r.currentResponse != nil { + r.currentResponse.Body.Close() + r.currentResponse = nil + } +}