From 73823e5e56446b23ce01bb8e44a9670ab2552b0a Mon Sep 17 00:00:00 2001 From: Tibor Vass Date: Fri, 15 May 2015 18:35:04 -0700 Subject: [PATCH] Add transport package to support CancelRequest Signed-off-by: Tibor Vass --- graph/pull.go | 12 +-- graph/push.go | 12 +-- pkg/transport/LICENSE | 27 +++++++ pkg/transport/transport.go | 148 +++++++++++++++++++++++++++++++++++++ registry/auth.go | 26 +++---- registry/endpoint.go | 25 +++---- registry/endpoint_test.go | 3 +- registry/registry.go | 106 +++++++++++--------------- registry/registry_test.go | 15 ++-- registry/service.go | 12 ++- registry/session.go | 59 ++++++++++++--- registry/session_v2.go | 4 +- registry/token.go | 4 +- 13 files changed, 325 insertions(+), 128 deletions(-) create mode 100644 pkg/transport/LICENSE create mode 100644 pkg/transport/transport.go diff --git a/graph/pull.go b/graph/pull.go index 013729a8c0..0c33e313a8 100644 --- a/graph/pull.go +++ b/graph/pull.go @@ -17,6 +17,7 @@ import ( "github.com/docker/docker/pkg/progressreader" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/stringid" + "github.com/docker/docker/pkg/transport" "github.com/docker/docker/registry" "github.com/docker/docker/utils" ) @@ -55,16 +56,17 @@ func (s *TagStore) Pull(image string, tag string, imagePullConfig *ImagePullConf defer s.poolRemove("pull", utils.ImageReference(repoInfo.LocalName, tag)) logrus.Debugf("pulling image from host %q with remote name %q", repoInfo.Index.Name, repoInfo.RemoteName) - endpoint, err := repoInfo.GetEndpoint() + + endpoint, err := repoInfo.GetEndpoint(imagePullConfig.MetaHeaders) if err != nil { return err } - + // TODO(tiborvass): reuse client from endpoint? // Adds Docker-specific headers as well as user-specified headers (metaHeaders) - tr := ®istry.DockerHeaders{ + tr := transport.NewTransport( registry.NewTransport(registry.ReceiveTimeout, endpoint.IsSecure), - imagePullConfig.MetaHeaders, - } + registry.DockerHeaders(imagePullConfig.MetaHeaders)..., + ) client := registry.HTTPClient(tr) r, err := registry.NewSession(client, imagePullConfig.AuthConfig, endpoint) if err != nil { diff --git a/graph/push.go b/graph/push.go index 6e9a367bcb..817ef707fc 100644 --- a/graph/push.go +++ b/graph/push.go @@ -18,6 +18,7 @@ import ( "github.com/docker/docker/pkg/progressreader" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/stringid" + "github.com/docker/docker/pkg/transport" "github.com/docker/docker/registry" "github.com/docker/docker/runconfig" "github.com/docker/docker/utils" @@ -509,16 +510,17 @@ func (s *TagStore) Push(localName string, imagePushConfig *ImagePushConfig) erro } defer s.poolRemove("push", repoInfo.LocalName) - endpoint, err := repoInfo.GetEndpoint() + endpoint, err := repoInfo.GetEndpoint(imagePushConfig.MetaHeaders) if err != nil { return err } - + // TODO(tiborvass): reuse client from endpoint? // Adds Docker-specific headers as well as user-specified headers (metaHeaders) - tr := ®istry.DockerHeaders{ + tr := transport.NewTransport( registry.NewTransport(registry.NoTimeout, endpoint.IsSecure), - imagePushConfig.MetaHeaders, - } + registry.DockerHeaders(imagePushConfig.MetaHeaders)..., + ) + client := registry.HTTPClient(tr) r, err := registry.NewSession(client, imagePushConfig.AuthConfig, endpoint) if err != nil { return err diff --git a/pkg/transport/LICENSE b/pkg/transport/LICENSE new file mode 100644 index 0000000000..d02f24fd52 --- /dev/null +++ b/pkg/transport/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The oauth2 Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go new file mode 100644 index 0000000000..510d8b4bc2 --- /dev/null +++ b/pkg/transport/transport.go @@ -0,0 +1,148 @@ +package transport + +import ( + "io" + "net/http" + "sync" +) + +type RequestModifier interface { + ModifyRequest(*http.Request) error +} + +type headerModifier http.Header + +// NewHeaderRequestModifier returns a RequestModifier that merges the HTTP headers +// passed as an argument, with the HTTP headers of a request. +// +// If the same key is present in both, the modifying header values for that key, +// are appended to the values for that same key in the request header. +func NewHeaderRequestModifier(header http.Header) RequestModifier { + return headerModifier(header) +} + +func (h headerModifier) ModifyRequest(req *http.Request) error { + for k, s := range http.Header(h) { + req.Header[k] = append(req.Header[k], s...) + } + + return nil +} + +// NewTransport returns an http.RoundTripper that modifies requests according to +// the RequestModifiers passed in the arguments, before sending the requests to +// the base http.RoundTripper (which, if nil, defaults to http.DefaultTransport). +func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper { + return &transport{ + Modifiers: modifiers, + Base: base, + } +} + +// transport is an http.RoundTripper that makes HTTP requests after +// copying and modifying the request +type transport struct { + Modifiers []RequestModifier + Base http.RoundTripper + + mu sync.Mutex // guards modReq + modReq map[*http.Request]*http.Request // original -> modified +} + +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := CloneRequest(req) + for _, modifier := range t.Modifiers { + if err := modifier.ModifyRequest(req2); err != nil { + return nil, err + } + } + + t.setModReq(req, req2) + res, err := t.base().RoundTrip(req2) + if err != nil { + t.setModReq(req, nil) + return nil, err + } + res.Body = &OnEOFReader{ + Rc: res.Body, + Fn: func() { t.setModReq(req, nil) }, + } + return res, nil +} + +// CancelRequest cancels an in-flight request by closing its connection. +func (t *transport) CancelRequest(req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + if cr, ok := t.base().(canceler); ok { + t.mu.Lock() + modReq := t.modReq[req] + delete(t.modReq, req) + t.mu.Unlock() + cr.CancelRequest(modReq) + } +} + +func (t *transport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +func (t *transport) setModReq(orig, mod *http.Request) { + t.mu.Lock() + defer t.mu.Unlock() + if t.modReq == nil { + t.modReq = make(map[*http.Request]*http.Request) + } + if mod == nil { + delete(t.modReq, orig) + } else { + t.modReq[orig] = mod + } +} + +// CloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +func CloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + + return r2 +} + +// OnEOFReader ensures a callback function is called +// on Close() and when the underlying Reader returns an io.EOF error +type OnEOFReader struct { + Rc io.ReadCloser + Fn func() +} + +func (r *OnEOFReader) Read(p []byte) (n int, err error) { + n, err = r.Rc.Read(p) + if err == io.EOF { + r.runFunc() + } + return +} + +func (r *OnEOFReader) Close() error { + err := r.Rc.Close() + r.runFunc() + return err +} + +func (r *OnEOFReader) runFunc() { + if fn := r.Fn; fn != nil { + fn() + r.Fn = nil + } +} diff --git a/registry/auth.go b/registry/auth.go index 0b6c3b0f95..33f8fa0689 100644 --- a/registry/auth.go +++ b/registry/auth.go @@ -44,8 +44,6 @@ func (auth *RequestAuthorization) getToken() (string, error) { return auth.tokenCache, nil } - client := auth.registryEndpoint.HTTPClient() - for _, challenge := range auth.registryEndpoint.AuthChallenges { switch strings.ToLower(challenge.Scheme) { case "basic": @@ -57,7 +55,7 @@ func (auth *RequestAuthorization) getToken() (string, error) { params[k] = v } params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ",")) - token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client) + token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint) if err != nil { return "", err } @@ -104,7 +102,6 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri status string reqBody []byte err error - client = registryEndpoint.HTTPClient() reqStatusCode = 0 serverAddress = authConfig.ServerAddress ) @@ -128,7 +125,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri // using `bytes.NewReader(jsonBody)` here causes the server to respond with a 411 status. b := strings.NewReader(string(jsonBody)) - req1, err := client.Post(serverAddress+"users/", "application/json; charset=utf-8", b) + req1, err := registryEndpoint.client.Post(serverAddress+"users/", "application/json; charset=utf-8", b) if err != nil { return "", fmt.Errorf("Server Error: %s", err) } @@ -151,7 +148,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri if string(reqBody) == "\"Username or email already exists\"" { req, err := http.NewRequest("GET", serverAddress+"users/", nil) req.SetBasicAuth(authConfig.Username, authConfig.Password) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return "", err } @@ -180,7 +177,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri // protected, so people can use `docker login` as an auth check. req, err := http.NewRequest("GET", serverAddress+"users/", nil) req.SetBasicAuth(authConfig.Username, authConfig.Password) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return "", err } @@ -217,7 +214,6 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri var ( err error allErrors []error - client = registryEndpoint.HTTPClient() ) for _, challenge := range registryEndpoint.AuthChallenges { @@ -225,9 +221,9 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri switch strings.ToLower(challenge.Scheme) { case "basic": - err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client) + err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint) case "bearer": - err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client) + err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint) default: // Unsupported challenge types are explicitly skipped. err = fmt.Errorf("unsupported auth scheme: %q", challenge.Scheme) @@ -245,7 +241,7 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri return "", fmt.Errorf("no successful auth challenge for %s - errors: %s", registryEndpoint, allErrors) } -func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error { +func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error { req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil) if err != nil { return err @@ -253,7 +249,7 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str req.SetBasicAuth(authConfig.Username, authConfig.Password) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return err } @@ -266,8 +262,8 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str return nil } -func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error { - token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client) +func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error { + token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint) if err != nil { return err } @@ -279,7 +275,7 @@ func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return err } diff --git a/registry/endpoint.go b/registry/endpoint.go index 25f66ad25f..ce92668f41 100644 --- a/registry/endpoint.go +++ b/registry/endpoint.go @@ -11,6 +11,7 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/distribution/registry/api/v2" + "github.com/docker/docker/pkg/transport" ) // for mocking in unit tests @@ -41,9 +42,9 @@ func scanForAPIVersion(address string) (string, APIVersion) { } // NewEndpoint parses the given address to return a registry endpoint. -func NewEndpoint(index *IndexInfo) (*Endpoint, error) { +func NewEndpoint(index *IndexInfo, metaHeaders http.Header) (*Endpoint, error) { // *TODO: Allow per-registry configuration of endpoints. - endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure) + endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure, metaHeaders) if err != nil { return nil, err } @@ -81,7 +82,7 @@ func validateEndpoint(endpoint *Endpoint) error { return nil } -func newEndpoint(address string, secure bool) (*Endpoint, error) { +func newEndpoint(address string, secure bool, metaHeaders http.Header) (*Endpoint, error) { var ( endpoint = new(Endpoint) trimmedAddress string @@ -98,11 +99,13 @@ func newEndpoint(address string, secure bool) (*Endpoint, error) { return nil, err } endpoint.IsSecure = secure + tr := NewTransport(ConnectTimeout, endpoint.IsSecure) + endpoint.client = HTTPClient(transport.NewTransport(tr, DockerHeaders(metaHeaders)...)) return endpoint, nil } -func (repoInfo *RepositoryInfo) GetEndpoint() (*Endpoint, error) { - return NewEndpoint(repoInfo.Index) +func (repoInfo *RepositoryInfo) GetEndpoint(metaHeaders http.Header) (*Endpoint, error) { + return NewEndpoint(repoInfo.Index, metaHeaders) } // Endpoint stores basic information about a registry endpoint. @@ -174,7 +177,7 @@ func (e *Endpoint) pingV1() (RegistryInfo, error) { return RegistryInfo{Standalone: false}, err } - resp, err := e.HTTPClient().Do(req) + resp, err := e.client.Do(req) if err != nil { return RegistryInfo{Standalone: false}, err } @@ -222,7 +225,7 @@ func (e *Endpoint) pingV2() (RegistryInfo, error) { return RegistryInfo{}, err } - resp, err := e.HTTPClient().Do(req) + resp, err := e.client.Do(req) if err != nil { return RegistryInfo{}, err } @@ -261,11 +264,3 @@ HeaderLoop: return RegistryInfo{}, fmt.Errorf("v2 registry endpoint returned status %d: %q", resp.StatusCode, http.StatusText(resp.StatusCode)) } - -func (e *Endpoint) HTTPClient() *http.Client { - if e.client == nil { - tr := NewTransport(ConnectTimeout, e.IsSecure) - e.client = HTTPClient(tr) - } - return e.client -} diff --git a/registry/endpoint_test.go b/registry/endpoint_test.go index 9567ba2352..6f67867bbb 100644 --- a/registry/endpoint_test.go +++ b/registry/endpoint_test.go @@ -19,7 +19,7 @@ func TestEndpointParse(t *testing.T) { {"0.0.0.0:5000", "https://0.0.0.0:5000/v0/"}, } for _, td := range testData { - e, err := newEndpoint(td.str, false) + e, err := newEndpoint(td.str, false, nil) if err != nil { t.Errorf("%q: %s", td.str, err) } @@ -60,6 +60,7 @@ func TestValidateEndpointAmbiguousAPIVersion(t *testing.T) { testEndpoint := Endpoint{ URL: testServerURL, Version: APIVersionUnknown, + client: HTTPClient(NewTransport(ConnectTimeout, false)), } if err = validateEndpoint(&testEndpoint); err != nil { diff --git a/registry/registry.go b/registry/registry.go index 4f5403002a..b0706e348e 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -19,6 +19,7 @@ import ( "github.com/docker/docker/autogen/dockerversion" "github.com/docker/docker/pkg/parsers/kernel" "github.com/docker/docker/pkg/timeoutconn" + "github.com/docker/docker/pkg/transport" "github.com/docker/docker/pkg/useragent" ) @@ -36,17 +37,32 @@ const ( ConnectTimeout ) -type httpsTransport struct { - *http.Transport +// dockerUserAgent is the User-Agent the Docker client uses to identify itself. +// It is populated on init(), comprising version information of different components. +var dockerUserAgent string + +func init() { + httpVersion := make([]useragent.VersionInfo, 0, 6) + httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION}) + httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()}) + httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT}) + if kernelVersion, err := kernel.GetKernelVersion(); err == nil { + httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()}) + } + httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS}) + httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH}) + + dockerUserAgent = useragent.AppendVersions("", httpVersion...) } +type httpsRequestModifier struct{ tlsConfig *tls.Config } + // DRAGONS(tiborvass): If someone wonders why do we set tlsconfig in a roundtrip, // it's because it's so as to match the current behavior in master: we generate the // certpool on every-goddam-request. It's not great, but it allows people to just put // the certs in /etc/docker/certs.d/.../ and let docker "pick it up" immediately. Would // prefer an fsnotify implementation, but that was out of scope of my refactoring. -// TODO: improve things -func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { +func (m *httpsRequestModifier) ModifyRequest(req *http.Request) error { var ( roots *x509.CertPool certs []tls.Certificate @@ -66,7 +82,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { logrus.Debugf("hostDir: %s", hostDir) fs, err := ioutil.ReadDir(hostDir) if err != nil && !os.IsNotExist(err) { - return nil, err + return nil } for _, f := range fs { @@ -77,7 +93,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { logrus.Debugf("crt: %s", hostDir+"/"+f.Name()) data, err := ioutil.ReadFile(path.Join(hostDir, f.Name())) if err != nil { - return nil, err + return err } roots.AppendCertsFromPEM(data) } @@ -86,11 +102,11 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { keyName := certName[:len(certName)-5] + ".key" logrus.Debugf("cert: %s", hostDir+"/"+f.Name()) if !hasFile(fs, keyName) { - return nil, fmt.Errorf("Missing key %s for certificate %s", keyName, certName) + return fmt.Errorf("Missing key %s for certificate %s", keyName, certName) } cert, err := tls.LoadX509KeyPair(path.Join(hostDir, certName), path.Join(hostDir, keyName)) if err != nil { - return nil, err + return err } certs = append(certs, cert) } @@ -99,38 +115,32 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { certName := keyName[:len(keyName)-4] + ".cert" logrus.Debugf("key: %s", hostDir+"/"+f.Name()) if !hasFile(fs, certName) { - return nil, fmt.Errorf("Missing certificate %s for key %s", certName, keyName) + return fmt.Errorf("Missing certificate %s for key %s", certName, keyName) } } } - if tr.Transport.TLSClientConfig == nil { - tr.Transport.TLSClientConfig = &tls.Config{ - // Avoid fallback to SSL protocols < TLS1.0 - MinVersion: tls.VersionTLS10, - } - } - tr.Transport.TLSClientConfig.RootCAs = roots - tr.Transport.TLSClientConfig.Certificates = certs + m.tlsConfig.RootCAs = roots + m.tlsConfig.Certificates = certs } - return tr.Transport.RoundTrip(req) + return nil } func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper { - tlsConfig := tls.Config{ + tlsConfig := &tls.Config{ // Avoid fallback to SSL protocols < TLS1.0 MinVersion: tls.VersionTLS10, InsecureSkipVerify: !secure, } - transport := &http.Transport{ + tr := &http.Transport{ DisableKeepAlives: true, Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tlsConfig, + TLSClientConfig: tlsConfig, } switch timeout { case ConnectTimeout: - transport.Dial = func(proto string, addr string) (net.Conn, error) { + tr.Dial = func(proto string, addr string) (net.Conn, error) { // Set the connect timeout to 30 seconds to allow for slower connection // times... d := net.Dialer{Timeout: 30 * time.Second, DualStack: true} @@ -144,7 +154,7 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper { return conn, nil } case ReceiveTimeout: - transport.Dial = func(proto string, addr string) (net.Conn, error) { + tr.Dial = func(proto string, addr string) (net.Conn, error) { d := net.Dialer{DualStack: true} conn, err := d.Dial(proto, addr) @@ -159,51 +169,23 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper { if secure { // note: httpsTransport also handles http transport // but for HTTPS, it sets up the certs - return &httpsTransport{transport} + return transport.NewTransport(tr, &httpsRequestModifier{tlsConfig}) } - return transport + return tr } -type DockerHeaders struct { - http.RoundTripper - Headers http.Header -} - -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map -func cloneRequest(r *http.Request) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - // deep copy of the Header - r2.Header = make(http.Header, len(r.Header)) - for k, s := range r.Header { - r2.Header[k] = append([]string(nil), s...) +// DockerHeaders returns request modifiers that ensure requests have +// the User-Agent header set to dockerUserAgent and that metaHeaders +// are added. +func DockerHeaders(metaHeaders http.Header) []transport.RequestModifier { + modifiers := []transport.RequestModifier{ + transport.NewHeaderRequestModifier(http.Header{"User-Agent": []string{dockerUserAgent}}), } - return r2 -} - -func (tr *DockerHeaders) RoundTrip(req *http.Request) (*http.Response, error) { - req = cloneRequest(req) - httpVersion := make([]useragent.VersionInfo, 0, 4) - httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION}) - httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()}) - httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT}) - if kernelVersion, err := kernel.GetKernelVersion(); err == nil { - httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()}) + if metaHeaders != nil { + modifiers = append(modifiers, transport.NewHeaderRequestModifier(metaHeaders)) } - httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS}) - httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH}) - - userAgent := useragent.AppendVersions(req.UserAgent(), httpVersion...) - - req.Header.Set("User-Agent", userAgent) - - for k, v := range tr.Headers { - req.Header[k] = v - } - return tr.RoundTripper.RoundTrip(req) + return modifiers } type debugTransport struct{ http.RoundTripper } diff --git a/registry/registry_test.go b/registry/registry_test.go index d4a5ded082..33e86ff43a 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/docker/docker/cliconfig" + "github.com/docker/docker/pkg/transport" ) var ( @@ -21,12 +22,12 @@ const ( func spawnTestRegistrySession(t *testing.T) *Session { authConfig := &cliconfig.AuthConfig{} - endpoint, err := NewEndpoint(makeIndex("/v1/")) + endpoint, err := NewEndpoint(makeIndex("/v1/"), nil) if err != nil { t.Fatal(err) } var tr http.RoundTripper = debugTransport{NewTransport(ReceiveTimeout, endpoint.IsSecure)} - tr = &DockerHeaders{&authTransport{RoundTripper: tr, AuthConfig: authConfig}, nil} + tr = transport.NewTransport(AuthTransport(tr, authConfig, false), DockerHeaders(nil)...) client := HTTPClient(tr) r, err := NewSession(client, authConfig, endpoint) if err != nil { @@ -48,7 +49,7 @@ func spawnTestRegistrySession(t *testing.T) *Session { func TestPingRegistryEndpoint(t *testing.T) { testPing := func(index *IndexInfo, expectedStandalone bool, assertMessage string) { - ep, err := NewEndpoint(index) + ep, err := NewEndpoint(index, nil) if err != nil { t.Fatal(err) } @@ -68,7 +69,7 @@ func TestPingRegistryEndpoint(t *testing.T) { func TestEndpoint(t *testing.T) { // Simple wrapper to fail test if err != nil expandEndpoint := func(index *IndexInfo) *Endpoint { - endpoint, err := NewEndpoint(index) + endpoint, err := NewEndpoint(index, nil) if err != nil { t.Fatal(err) } @@ -77,7 +78,7 @@ func TestEndpoint(t *testing.T) { assertInsecureIndex := func(index *IndexInfo) { index.Secure = true - _, err := NewEndpoint(index) + _, err := NewEndpoint(index, nil) assertNotEqual(t, err, nil, index.Name+": Expected error for insecure index") assertEqual(t, strings.Contains(err.Error(), "insecure-registry"), true, index.Name+": Expected insecure-registry error for insecure index") index.Secure = false @@ -85,7 +86,7 @@ func TestEndpoint(t *testing.T) { assertSecureIndex := func(index *IndexInfo) { index.Secure = true - _, err := NewEndpoint(index) + _, err := NewEndpoint(index, nil) assertNotEqual(t, err, nil, index.Name+": Expected cert error for secure index") assertEqual(t, strings.Contains(err.Error(), "certificate signed by unknown authority"), true, index.Name+": Expected cert error for secure index") index.Secure = false @@ -151,7 +152,7 @@ func TestEndpoint(t *testing.T) { } for _, address := range badEndpoints { index.Name = address - _, err := NewEndpoint(index) + _, err := NewEndpoint(index, nil) checkNotEqual(t, err, nil, "Expected error while expanding bad endpoint") } } diff --git a/registry/service.go b/registry/service.go index 067df107c2..6811749272 100644 --- a/registry/service.go +++ b/registry/service.go @@ -1,6 +1,10 @@ package registry -import "github.com/docker/docker/cliconfig" +import ( + "net/http" + + "github.com/docker/docker/cliconfig" +) type Service struct { Config *ServiceConfig @@ -27,7 +31,7 @@ func (s *Service) Auth(authConfig *cliconfig.AuthConfig) (string, error) { if err != nil { return "", err } - endpoint, err := NewEndpoint(index) + endpoint, err := NewEndpoint(index, nil) if err != nil { return "", err } @@ -44,11 +48,11 @@ func (s *Service) Search(term string, authConfig *cliconfig.AuthConfig, headers } // *TODO: Search multiple indexes. - endpoint, err := repoInfo.GetEndpoint() + endpoint, err := repoInfo.GetEndpoint(http.Header(headers)) if err != nil { return nil, err } - r, err := NewSession(endpoint.HTTPClient(), authConfig, endpoint) + r, err := NewSession(endpoint.client, authConfig, endpoint) if err != nil { return nil, err } diff --git a/registry/session.go b/registry/session.go index 686e322dab..8e54bc8211 100644 --- a/registry/session.go +++ b/registry/session.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/sha256" "errors" + "sync" // this is required for some certificates _ "crypto/sha512" "encoding/hex" @@ -22,6 +23,7 @@ import ( "github.com/docker/docker/cliconfig" "github.com/docker/docker/pkg/httputils" "github.com/docker/docker/pkg/tarsum" + "github.com/docker/docker/pkg/transport" ) type Session struct { @@ -31,7 +33,18 @@ type Session struct { authConfig *cliconfig.AuthConfig } -// authTransport handles the auth layer when communicating with a v1 registry (private or official) +type authTransport struct { + http.RoundTripper + *cliconfig.AuthConfig + + alwaysSetBasicAuth bool + token []string + + mu sync.Mutex // guards modReq + modReq map[*http.Request]*http.Request // original -> modified +} + +// AuthTransport handles the auth layer when communicating with a v1 registry (private or official) // // For private v1 registries, set alwaysSetBasicAuth to true. // @@ -44,16 +57,23 @@ type Session struct { // If the server sends a token without the client having requested it, it is ignored. // // This RoundTripper also has a CancelRequest method important for correct timeout handling. -type authTransport struct { - http.RoundTripper - *cliconfig.AuthConfig - - alwaysSetBasicAuth bool - token []string +func AuthTransport(base http.RoundTripper, authConfig *cliconfig.AuthConfig, alwaysSetBasicAuth bool) http.RoundTripper { + if base == nil { + base = http.DefaultTransport + } + return &authTransport{ + RoundTripper: base, + AuthConfig: authConfig, + alwaysSetBasicAuth: alwaysSetBasicAuth, + modReq: make(map[*http.Request]*http.Request), + } } -func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = cloneRequest(req) +func (tr *authTransport) RoundTrip(orig *http.Request) (*http.Response, error) { + req := transport.CloneRequest(orig) + tr.mu.Lock() + tr.modReq[orig] = req + tr.mu.Unlock() if tr.alwaysSetBasicAuth { req.SetBasicAuth(tr.Username, tr.Password) @@ -73,14 +93,33 @@ func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) { } resp, err := tr.RoundTripper.RoundTrip(req) if err != nil { + delete(tr.modReq, orig) return nil, err } if askedForToken && len(resp.Header["X-Docker-Token"]) > 0 { tr.token = resp.Header["X-Docker-Token"] } + resp.Body = &transport.OnEOFReader{ + Rc: resp.Body, + Fn: func() { delete(tr.modReq, orig) }, + } return resp, nil } +// CancelRequest cancels an in-flight request by closing its connection. +func (tr *authTransport) CancelRequest(req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + if cr, ok := tr.RoundTripper.(canceler); ok { + tr.mu.Lock() + modReq := tr.modReq[req] + delete(tr.modReq, req) + tr.mu.Unlock() + cr.CancelRequest(modReq) + } +} + // TODO(tiborvass): remove authConfig param once registry client v2 is vendored func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint *Endpoint) (r *Session, err error) { r = &Session{ @@ -105,7 +144,7 @@ func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint } } - client.Transport = &authTransport{RoundTripper: client.Transport, AuthConfig: authConfig, alwaysSetBasicAuth: alwaysSetBasicAuth} + client.Transport = AuthTransport(client.Transport, authConfig, alwaysSetBasicAuth) jar, err := cookiejar.New(nil) if err != nil { diff --git a/registry/session_v2.go b/registry/session_v2.go index c639f9226a..b660172898 100644 --- a/registry/session_v2.go +++ b/registry/session_v2.go @@ -27,7 +27,7 @@ func getV2Builder(e *Endpoint) *v2.URLBuilder { func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) { // TODO check if should use Mirror if index.Official { - ep, err = newEndpoint(REGISTRYSERVER, true) + ep, err = newEndpoint(REGISTRYSERVER, true, nil) if err != nil { return } @@ -38,7 +38,7 @@ func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) } else if r.indexEndpoint.String() == index.GetAuthConfigKey() { ep = r.indexEndpoint } else { - ep, err = NewEndpoint(index) + ep, err = NewEndpoint(index, nil) if err != nil { return } diff --git a/registry/token.go b/registry/token.go index af7d5f3fcb..e27cb6f528 100644 --- a/registry/token.go +++ b/registry/token.go @@ -13,7 +13,7 @@ type tokenResponse struct { Token string `json:"token"` } -func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint, client *http.Client) (token string, err error) { +func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint) (token string, err error) { realm, ok := params["realm"] if !ok { return "", errors.New("no realm specified for token auth challenge") @@ -56,7 +56,7 @@ func getToken(username, password string, params map[string]string, registryEndpo req.URL.RawQuery = reqParams.Encode() - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return "", err }