Add transport package to support CancelRequest

Signed-off-by: Tibor Vass <tibor@docker.com>
This commit is contained in:
Tibor Vass 2015-05-15 18:35:04 -07:00
parent cf8c0d0f56
commit 73823e5e56
13 changed files with 325 additions and 128 deletions

View File

@ -17,6 +17,7 @@ import (
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/stringid"
"github.com/docker/docker/pkg/transport"
"github.com/docker/docker/registry"
"github.com/docker/docker/utils"
)
@ -55,16 +56,17 @@ func (s *TagStore) Pull(image string, tag string, imagePullConfig *ImagePullConf
defer s.poolRemove("pull", utils.ImageReference(repoInfo.LocalName, tag))
logrus.Debugf("pulling image from host %q with remote name %q", repoInfo.Index.Name, repoInfo.RemoteName)
endpoint, err := repoInfo.GetEndpoint()
endpoint, err := repoInfo.GetEndpoint(imagePullConfig.MetaHeaders)
if err != nil {
return err
}
// TODO(tiborvass): reuse client from endpoint?
// Adds Docker-specific headers as well as user-specified headers (metaHeaders)
tr := &registry.DockerHeaders{
tr := transport.NewTransport(
registry.NewTransport(registry.ReceiveTimeout, endpoint.IsSecure),
imagePullConfig.MetaHeaders,
}
registry.DockerHeaders(imagePullConfig.MetaHeaders)...,
)
client := registry.HTTPClient(tr)
r, err := registry.NewSession(client, imagePullConfig.AuthConfig, endpoint)
if err != nil {

View File

@ -18,6 +18,7 @@ import (
"github.com/docker/docker/pkg/progressreader"
"github.com/docker/docker/pkg/streamformatter"
"github.com/docker/docker/pkg/stringid"
"github.com/docker/docker/pkg/transport"
"github.com/docker/docker/registry"
"github.com/docker/docker/runconfig"
"github.com/docker/docker/utils"
@ -509,16 +510,17 @@ func (s *TagStore) Push(localName string, imagePushConfig *ImagePushConfig) erro
}
defer s.poolRemove("push", repoInfo.LocalName)
endpoint, err := repoInfo.GetEndpoint()
endpoint, err := repoInfo.GetEndpoint(imagePushConfig.MetaHeaders)
if err != nil {
return err
}
// TODO(tiborvass): reuse client from endpoint?
// Adds Docker-specific headers as well as user-specified headers (metaHeaders)
tr := &registry.DockerHeaders{
tr := transport.NewTransport(
registry.NewTransport(registry.NoTimeout, endpoint.IsSecure),
imagePushConfig.MetaHeaders,
}
registry.DockerHeaders(imagePushConfig.MetaHeaders)...,
)
client := registry.HTTPClient(tr)
r, err := registry.NewSession(client, imagePushConfig.AuthConfig, endpoint)
if err != nil {
return err

27
pkg/transport/LICENSE Normal file
View File

@ -0,0 +1,27 @@
Copyright (c) 2009 The oauth2 Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

148
pkg/transport/transport.go Normal file
View File

@ -0,0 +1,148 @@
package transport
import (
"io"
"net/http"
"sync"
)
type RequestModifier interface {
ModifyRequest(*http.Request) error
}
type headerModifier http.Header
// NewHeaderRequestModifier returns a RequestModifier that merges the HTTP headers
// passed as an argument, with the HTTP headers of a request.
//
// If the same key is present in both, the modifying header values for that key,
// are appended to the values for that same key in the request header.
func NewHeaderRequestModifier(header http.Header) RequestModifier {
return headerModifier(header)
}
func (h headerModifier) ModifyRequest(req *http.Request) error {
for k, s := range http.Header(h) {
req.Header[k] = append(req.Header[k], s...)
}
return nil
}
// NewTransport returns an http.RoundTripper that modifies requests according to
// the RequestModifiers passed in the arguments, before sending the requests to
// the base http.RoundTripper (which, if nil, defaults to http.DefaultTransport).
func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper {
return &transport{
Modifiers: modifiers,
Base: base,
}
}
// transport is an http.RoundTripper that makes HTTP requests after
// copying and modifying the request
type transport struct {
Modifiers []RequestModifier
Base http.RoundTripper
mu sync.Mutex // guards modReq
modReq map[*http.Request]*http.Request // original -> modified
}
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
req2 := CloneRequest(req)
for _, modifier := range t.Modifiers {
if err := modifier.ModifyRequest(req2); err != nil {
return nil, err
}
}
t.setModReq(req, req2)
res, err := t.base().RoundTrip(req2)
if err != nil {
t.setModReq(req, nil)
return nil, err
}
res.Body = &OnEOFReader{
Rc: res.Body,
Fn: func() { t.setModReq(req, nil) },
}
return res, nil
}
// CancelRequest cancels an in-flight request by closing its connection.
func (t *transport) CancelRequest(req *http.Request) {
type canceler interface {
CancelRequest(*http.Request)
}
if cr, ok := t.base().(canceler); ok {
t.mu.Lock()
modReq := t.modReq[req]
delete(t.modReq, req)
t.mu.Unlock()
cr.CancelRequest(modReq)
}
}
func (t *transport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
func (t *transport) setModReq(orig, mod *http.Request) {
t.mu.Lock()
defer t.mu.Unlock()
if t.modReq == nil {
t.modReq = make(map[*http.Request]*http.Request)
}
if mod == nil {
delete(t.modReq, orig)
} else {
t.modReq[orig] = mod
}
}
// CloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func CloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
return r2
}
// OnEOFReader ensures a callback function is called
// on Close() and when the underlying Reader returns an io.EOF error
type OnEOFReader struct {
Rc io.ReadCloser
Fn func()
}
func (r *OnEOFReader) Read(p []byte) (n int, err error) {
n, err = r.Rc.Read(p)
if err == io.EOF {
r.runFunc()
}
return
}
func (r *OnEOFReader) Close() error {
err := r.Rc.Close()
r.runFunc()
return err
}
func (r *OnEOFReader) runFunc() {
if fn := r.Fn; fn != nil {
fn()
r.Fn = nil
}
}

View File

@ -44,8 +44,6 @@ func (auth *RequestAuthorization) getToken() (string, error) {
return auth.tokenCache, nil
}
client := auth.registryEndpoint.HTTPClient()
for _, challenge := range auth.registryEndpoint.AuthChallenges {
switch strings.ToLower(challenge.Scheme) {
case "basic":
@ -57,7 +55,7 @@ func (auth *RequestAuthorization) getToken() (string, error) {
params[k] = v
}
params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ","))
token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client)
token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint)
if err != nil {
return "", err
}
@ -104,7 +102,6 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
status string
reqBody []byte
err error
client = registryEndpoint.HTTPClient()
reqStatusCode = 0
serverAddress = authConfig.ServerAddress
)
@ -128,7 +125,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
// using `bytes.NewReader(jsonBody)` here causes the server to respond with a 411 status.
b := strings.NewReader(string(jsonBody))
req1, err := client.Post(serverAddress+"users/", "application/json; charset=utf-8", b)
req1, err := registryEndpoint.client.Post(serverAddress+"users/", "application/json; charset=utf-8", b)
if err != nil {
return "", fmt.Errorf("Server Error: %s", err)
}
@ -151,7 +148,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
if string(reqBody) == "\"Username or email already exists\"" {
req, err := http.NewRequest("GET", serverAddress+"users/", nil)
req.SetBasicAuth(authConfig.Username, authConfig.Password)
resp, err := client.Do(req)
resp, err := registryEndpoint.client.Do(req)
if err != nil {
return "", err
}
@ -180,7 +177,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
// protected, so people can use `docker login` as an auth check.
req, err := http.NewRequest("GET", serverAddress+"users/", nil)
req.SetBasicAuth(authConfig.Username, authConfig.Password)
resp, err := client.Do(req)
resp, err := registryEndpoint.client.Do(req)
if err != nil {
return "", err
}
@ -217,7 +214,6 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
var (
err error
allErrors []error
client = registryEndpoint.HTTPClient()
)
for _, challenge := range registryEndpoint.AuthChallenges {
@ -225,9 +221,9 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
switch strings.ToLower(challenge.Scheme) {
case "basic":
err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client)
err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint)
case "bearer":
err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client)
err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint)
default:
// Unsupported challenge types are explicitly skipped.
err = fmt.Errorf("unsupported auth scheme: %q", challenge.Scheme)
@ -245,7 +241,7 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
return "", fmt.Errorf("no successful auth challenge for %s - errors: %s", registryEndpoint, allErrors)
}
func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error {
func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error {
req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil)
if err != nil {
return err
@ -253,7 +249,7 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str
req.SetBasicAuth(authConfig.Username, authConfig.Password)
resp, err := client.Do(req)
resp, err := registryEndpoint.client.Do(req)
if err != nil {
return err
}
@ -266,8 +262,8 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str
return nil
}
func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error {
token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client)
func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error {
token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint)
if err != nil {
return err
}
@ -279,7 +275,7 @@ func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
resp, err := client.Do(req)
resp, err := registryEndpoint.client.Do(req)
if err != nil {
return err
}

View File

@ -11,6 +11,7 @@ import (
"github.com/Sirupsen/logrus"
"github.com/docker/distribution/registry/api/v2"
"github.com/docker/docker/pkg/transport"
)
// for mocking in unit tests
@ -41,9 +42,9 @@ func scanForAPIVersion(address string) (string, APIVersion) {
}
// NewEndpoint parses the given address to return a registry endpoint.
func NewEndpoint(index *IndexInfo) (*Endpoint, error) {
func NewEndpoint(index *IndexInfo, metaHeaders http.Header) (*Endpoint, error) {
// *TODO: Allow per-registry configuration of endpoints.
endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure)
endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure, metaHeaders)
if err != nil {
return nil, err
}
@ -81,7 +82,7 @@ func validateEndpoint(endpoint *Endpoint) error {
return nil
}
func newEndpoint(address string, secure bool) (*Endpoint, error) {
func newEndpoint(address string, secure bool, metaHeaders http.Header) (*Endpoint, error) {
var (
endpoint = new(Endpoint)
trimmedAddress string
@ -98,11 +99,13 @@ func newEndpoint(address string, secure bool) (*Endpoint, error) {
return nil, err
}
endpoint.IsSecure = secure
tr := NewTransport(ConnectTimeout, endpoint.IsSecure)
endpoint.client = HTTPClient(transport.NewTransport(tr, DockerHeaders(metaHeaders)...))
return endpoint, nil
}
func (repoInfo *RepositoryInfo) GetEndpoint() (*Endpoint, error) {
return NewEndpoint(repoInfo.Index)
func (repoInfo *RepositoryInfo) GetEndpoint(metaHeaders http.Header) (*Endpoint, error) {
return NewEndpoint(repoInfo.Index, metaHeaders)
}
// Endpoint stores basic information about a registry endpoint.
@ -174,7 +177,7 @@ func (e *Endpoint) pingV1() (RegistryInfo, error) {
return RegistryInfo{Standalone: false}, err
}
resp, err := e.HTTPClient().Do(req)
resp, err := e.client.Do(req)
if err != nil {
return RegistryInfo{Standalone: false}, err
}
@ -222,7 +225,7 @@ func (e *Endpoint) pingV2() (RegistryInfo, error) {
return RegistryInfo{}, err
}
resp, err := e.HTTPClient().Do(req)
resp, err := e.client.Do(req)
if err != nil {
return RegistryInfo{}, err
}
@ -261,11 +264,3 @@ HeaderLoop:
return RegistryInfo{}, fmt.Errorf("v2 registry endpoint returned status %d: %q", resp.StatusCode, http.StatusText(resp.StatusCode))
}
func (e *Endpoint) HTTPClient() *http.Client {
if e.client == nil {
tr := NewTransport(ConnectTimeout, e.IsSecure)
e.client = HTTPClient(tr)
}
return e.client
}

View File

@ -19,7 +19,7 @@ func TestEndpointParse(t *testing.T) {
{"0.0.0.0:5000", "https://0.0.0.0:5000/v0/"},
}
for _, td := range testData {
e, err := newEndpoint(td.str, false)
e, err := newEndpoint(td.str, false, nil)
if err != nil {
t.Errorf("%q: %s", td.str, err)
}
@ -60,6 +60,7 @@ func TestValidateEndpointAmbiguousAPIVersion(t *testing.T) {
testEndpoint := Endpoint{
URL: testServerURL,
Version: APIVersionUnknown,
client: HTTPClient(NewTransport(ConnectTimeout, false)),
}
if err = validateEndpoint(&testEndpoint); err != nil {

View File

@ -19,6 +19,7 @@ import (
"github.com/docker/docker/autogen/dockerversion"
"github.com/docker/docker/pkg/parsers/kernel"
"github.com/docker/docker/pkg/timeoutconn"
"github.com/docker/docker/pkg/transport"
"github.com/docker/docker/pkg/useragent"
)
@ -36,17 +37,32 @@ const (
ConnectTimeout
)
type httpsTransport struct {
*http.Transport
// dockerUserAgent is the User-Agent the Docker client uses to identify itself.
// It is populated on init(), comprising version information of different components.
var dockerUserAgent string
func init() {
httpVersion := make([]useragent.VersionInfo, 0, 6)
httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION})
httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()})
httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT})
if kernelVersion, err := kernel.GetKernelVersion(); err == nil {
httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()})
}
httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS})
httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH})
dockerUserAgent = useragent.AppendVersions("", httpVersion...)
}
type httpsRequestModifier struct{ tlsConfig *tls.Config }
// DRAGONS(tiborvass): If someone wonders why do we set tlsconfig in a roundtrip,
// it's because it's so as to match the current behavior in master: we generate the
// certpool on every-goddam-request. It's not great, but it allows people to just put
// the certs in /etc/docker/certs.d/.../ and let docker "pick it up" immediately. Would
// prefer an fsnotify implementation, but that was out of scope of my refactoring.
// TODO: improve things
func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
func (m *httpsRequestModifier) ModifyRequest(req *http.Request) error {
var (
roots *x509.CertPool
certs []tls.Certificate
@ -66,7 +82,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
logrus.Debugf("hostDir: %s", hostDir)
fs, err := ioutil.ReadDir(hostDir)
if err != nil && !os.IsNotExist(err) {
return nil, err
return nil
}
for _, f := range fs {
@ -77,7 +93,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
logrus.Debugf("crt: %s", hostDir+"/"+f.Name())
data, err := ioutil.ReadFile(path.Join(hostDir, f.Name()))
if err != nil {
return nil, err
return err
}
roots.AppendCertsFromPEM(data)
}
@ -86,11 +102,11 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
keyName := certName[:len(certName)-5] + ".key"
logrus.Debugf("cert: %s", hostDir+"/"+f.Name())
if !hasFile(fs, keyName) {
return nil, fmt.Errorf("Missing key %s for certificate %s", keyName, certName)
return fmt.Errorf("Missing key %s for certificate %s", keyName, certName)
}
cert, err := tls.LoadX509KeyPair(path.Join(hostDir, certName), path.Join(hostDir, keyName))
if err != nil {
return nil, err
return err
}
certs = append(certs, cert)
}
@ -99,38 +115,32 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
certName := keyName[:len(keyName)-4] + ".cert"
logrus.Debugf("key: %s", hostDir+"/"+f.Name())
if !hasFile(fs, certName) {
return nil, fmt.Errorf("Missing certificate %s for key %s", certName, keyName)
return fmt.Errorf("Missing certificate %s for key %s", certName, keyName)
}
}
}
if tr.Transport.TLSClientConfig == nil {
tr.Transport.TLSClientConfig = &tls.Config{
// Avoid fallback to SSL protocols < TLS1.0
MinVersion: tls.VersionTLS10,
}
}
tr.Transport.TLSClientConfig.RootCAs = roots
tr.Transport.TLSClientConfig.Certificates = certs
m.tlsConfig.RootCAs = roots
m.tlsConfig.Certificates = certs
}
return tr.Transport.RoundTrip(req)
return nil
}
func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
tlsConfig := tls.Config{
tlsConfig := &tls.Config{
// Avoid fallback to SSL protocols < TLS1.0
MinVersion: tls.VersionTLS10,
InsecureSkipVerify: !secure,
}
transport := &http.Transport{
tr := &http.Transport{
DisableKeepAlives: true,
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tlsConfig,
TLSClientConfig: tlsConfig,
}
switch timeout {
case ConnectTimeout:
transport.Dial = func(proto string, addr string) (net.Conn, error) {
tr.Dial = func(proto string, addr string) (net.Conn, error) {
// Set the connect timeout to 30 seconds to allow for slower connection
// times...
d := net.Dialer{Timeout: 30 * time.Second, DualStack: true}
@ -144,7 +154,7 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
return conn, nil
}
case ReceiveTimeout:
transport.Dial = func(proto string, addr string) (net.Conn, error) {
tr.Dial = func(proto string, addr string) (net.Conn, error) {
d := net.Dialer{DualStack: true}
conn, err := d.Dial(proto, addr)
@ -159,51 +169,23 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
if secure {
// note: httpsTransport also handles http transport
// but for HTTPS, it sets up the certs
return &httpsTransport{transport}
return transport.NewTransport(tr, &httpsRequestModifier{tlsConfig})
}
return transport
return tr
}
type DockerHeaders struct {
http.RoundTripper
Headers http.Header
}
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
// DockerHeaders returns request modifiers that ensure requests have
// the User-Agent header set to dockerUserAgent and that metaHeaders
// are added.
func DockerHeaders(metaHeaders http.Header) []transport.RequestModifier {
modifiers := []transport.RequestModifier{
transport.NewHeaderRequestModifier(http.Header{"User-Agent": []string{dockerUserAgent}}),
}
return r2
}
func (tr *DockerHeaders) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
httpVersion := make([]useragent.VersionInfo, 0, 4)
httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION})
httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()})
httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT})
if kernelVersion, err := kernel.GetKernelVersion(); err == nil {
httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()})
if metaHeaders != nil {
modifiers = append(modifiers, transport.NewHeaderRequestModifier(metaHeaders))
}
httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS})
httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH})
userAgent := useragent.AppendVersions(req.UserAgent(), httpVersion...)
req.Header.Set("User-Agent", userAgent)
for k, v := range tr.Headers {
req.Header[k] = v
}
return tr.RoundTripper.RoundTrip(req)
return modifiers
}
type debugTransport struct{ http.RoundTripper }

View File

@ -8,6 +8,7 @@ import (
"testing"
"github.com/docker/docker/cliconfig"
"github.com/docker/docker/pkg/transport"
)
var (
@ -21,12 +22,12 @@ const (
func spawnTestRegistrySession(t *testing.T) *Session {
authConfig := &cliconfig.AuthConfig{}
endpoint, err := NewEndpoint(makeIndex("/v1/"))
endpoint, err := NewEndpoint(makeIndex("/v1/"), nil)
if err != nil {
t.Fatal(err)
}
var tr http.RoundTripper = debugTransport{NewTransport(ReceiveTimeout, endpoint.IsSecure)}
tr = &DockerHeaders{&authTransport{RoundTripper: tr, AuthConfig: authConfig}, nil}
tr = transport.NewTransport(AuthTransport(tr, authConfig, false), DockerHeaders(nil)...)
client := HTTPClient(tr)
r, err := NewSession(client, authConfig, endpoint)
if err != nil {
@ -48,7 +49,7 @@ func spawnTestRegistrySession(t *testing.T) *Session {
func TestPingRegistryEndpoint(t *testing.T) {
testPing := func(index *IndexInfo, expectedStandalone bool, assertMessage string) {
ep, err := NewEndpoint(index)
ep, err := NewEndpoint(index, nil)
if err != nil {
t.Fatal(err)
}
@ -68,7 +69,7 @@ func TestPingRegistryEndpoint(t *testing.T) {
func TestEndpoint(t *testing.T) {
// Simple wrapper to fail test if err != nil
expandEndpoint := func(index *IndexInfo) *Endpoint {
endpoint, err := NewEndpoint(index)
endpoint, err := NewEndpoint(index, nil)
if err != nil {
t.Fatal(err)
}
@ -77,7 +78,7 @@ func TestEndpoint(t *testing.T) {
assertInsecureIndex := func(index *IndexInfo) {
index.Secure = true
_, err := NewEndpoint(index)
_, err := NewEndpoint(index, nil)
assertNotEqual(t, err, nil, index.Name+": Expected error for insecure index")
assertEqual(t, strings.Contains(err.Error(), "insecure-registry"), true, index.Name+": Expected insecure-registry error for insecure index")
index.Secure = false
@ -85,7 +86,7 @@ func TestEndpoint(t *testing.T) {
assertSecureIndex := func(index *IndexInfo) {
index.Secure = true
_, err := NewEndpoint(index)
_, err := NewEndpoint(index, nil)
assertNotEqual(t, err, nil, index.Name+": Expected cert error for secure index")
assertEqual(t, strings.Contains(err.Error(), "certificate signed by unknown authority"), true, index.Name+": Expected cert error for secure index")
index.Secure = false
@ -151,7 +152,7 @@ func TestEndpoint(t *testing.T) {
}
for _, address := range badEndpoints {
index.Name = address
_, err := NewEndpoint(index)
_, err := NewEndpoint(index, nil)
checkNotEqual(t, err, nil, "Expected error while expanding bad endpoint")
}
}

View File

@ -1,6 +1,10 @@
package registry
import "github.com/docker/docker/cliconfig"
import (
"net/http"
"github.com/docker/docker/cliconfig"
)
type Service struct {
Config *ServiceConfig
@ -27,7 +31,7 @@ func (s *Service) Auth(authConfig *cliconfig.AuthConfig) (string, error) {
if err != nil {
return "", err
}
endpoint, err := NewEndpoint(index)
endpoint, err := NewEndpoint(index, nil)
if err != nil {
return "", err
}
@ -44,11 +48,11 @@ func (s *Service) Search(term string, authConfig *cliconfig.AuthConfig, headers
}
// *TODO: Search multiple indexes.
endpoint, err := repoInfo.GetEndpoint()
endpoint, err := repoInfo.GetEndpoint(http.Header(headers))
if err != nil {
return nil, err
}
r, err := NewSession(endpoint.HTTPClient(), authConfig, endpoint)
r, err := NewSession(endpoint.client, authConfig, endpoint)
if err != nil {
return nil, err
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"crypto/sha256"
"errors"
"sync"
// this is required for some certificates
_ "crypto/sha512"
"encoding/hex"
@ -22,6 +23,7 @@ import (
"github.com/docker/docker/cliconfig"
"github.com/docker/docker/pkg/httputils"
"github.com/docker/docker/pkg/tarsum"
"github.com/docker/docker/pkg/transport"
)
type Session struct {
@ -31,7 +33,18 @@ type Session struct {
authConfig *cliconfig.AuthConfig
}
// authTransport handles the auth layer when communicating with a v1 registry (private or official)
type authTransport struct {
http.RoundTripper
*cliconfig.AuthConfig
alwaysSetBasicAuth bool
token []string
mu sync.Mutex // guards modReq
modReq map[*http.Request]*http.Request // original -> modified
}
// AuthTransport handles the auth layer when communicating with a v1 registry (private or official)
//
// For private v1 registries, set alwaysSetBasicAuth to true.
//
@ -44,16 +57,23 @@ type Session struct {
// If the server sends a token without the client having requested it, it is ignored.
//
// This RoundTripper also has a CancelRequest method important for correct timeout handling.
type authTransport struct {
http.RoundTripper
*cliconfig.AuthConfig
alwaysSetBasicAuth bool
token []string
func AuthTransport(base http.RoundTripper, authConfig *cliconfig.AuthConfig, alwaysSetBasicAuth bool) http.RoundTripper {
if base == nil {
base = http.DefaultTransport
}
return &authTransport{
RoundTripper: base,
AuthConfig: authConfig,
alwaysSetBasicAuth: alwaysSetBasicAuth,
modReq: make(map[*http.Request]*http.Request),
}
}
func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req)
func (tr *authTransport) RoundTrip(orig *http.Request) (*http.Response, error) {
req := transport.CloneRequest(orig)
tr.mu.Lock()
tr.modReq[orig] = req
tr.mu.Unlock()
if tr.alwaysSetBasicAuth {
req.SetBasicAuth(tr.Username, tr.Password)
@ -73,14 +93,33 @@ func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
resp, err := tr.RoundTripper.RoundTrip(req)
if err != nil {
delete(tr.modReq, orig)
return nil, err
}
if askedForToken && len(resp.Header["X-Docker-Token"]) > 0 {
tr.token = resp.Header["X-Docker-Token"]
}
resp.Body = &transport.OnEOFReader{
Rc: resp.Body,
Fn: func() { delete(tr.modReq, orig) },
}
return resp, nil
}
// CancelRequest cancels an in-flight request by closing its connection.
func (tr *authTransport) CancelRequest(req *http.Request) {
type canceler interface {
CancelRequest(*http.Request)
}
if cr, ok := tr.RoundTripper.(canceler); ok {
tr.mu.Lock()
modReq := tr.modReq[req]
delete(tr.modReq, req)
tr.mu.Unlock()
cr.CancelRequest(modReq)
}
}
// TODO(tiborvass): remove authConfig param once registry client v2 is vendored
func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint *Endpoint) (r *Session, err error) {
r = &Session{
@ -105,7 +144,7 @@ func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint
}
}
client.Transport = &authTransport{RoundTripper: client.Transport, AuthConfig: authConfig, alwaysSetBasicAuth: alwaysSetBasicAuth}
client.Transport = AuthTransport(client.Transport, authConfig, alwaysSetBasicAuth)
jar, err := cookiejar.New(nil)
if err != nil {

View File

@ -27,7 +27,7 @@ func getV2Builder(e *Endpoint) *v2.URLBuilder {
func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) {
// TODO check if should use Mirror
if index.Official {
ep, err = newEndpoint(REGISTRYSERVER, true)
ep, err = newEndpoint(REGISTRYSERVER, true, nil)
if err != nil {
return
}
@ -38,7 +38,7 @@ func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error)
} else if r.indexEndpoint.String() == index.GetAuthConfigKey() {
ep = r.indexEndpoint
} else {
ep, err = NewEndpoint(index)
ep, err = NewEndpoint(index, nil)
if err != nil {
return
}

View File

@ -13,7 +13,7 @@ type tokenResponse struct {
Token string `json:"token"`
}
func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint, client *http.Client) (token string, err error) {
func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint) (token string, err error) {
realm, ok := params["realm"]
if !ok {
return "", errors.New("no realm specified for token auth challenge")
@ -56,7 +56,7 @@ func getToken(username, password string, params map[string]string, registryEndpo
req.URL.RawQuery = reqParams.Encode()
resp, err := client.Do(req)
resp, err := registryEndpoint.client.Do(req)
if err != nil {
return "", err
}