mirror of
https://github.com/moby/moby.git
synced 2022-11-09 12:21:53 -05:00
Use stdlib TLS dialer
Since go1.8, the stdlib TLS net.Conn implementation implements the `CloseWrite()` interface. Signed-off-by: Brian Goff <cpuguy83@gmail.com> Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
parent
5cb95f693d
commit
2ac277a56f
4 changed files with 104 additions and 148 deletions
105
client/hijack.go
105
client/hijack.go
|
@ -9,7 +9,6 @@ import (
|
|||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
|
@ -17,21 +16,6 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// tlsClientCon holds tls information and a dialed connection.
|
||||
type tlsClientCon struct {
|
||||
*tls.Conn
|
||||
rawConn net.Conn
|
||||
}
|
||||
|
||||
func (c *tlsClientCon) CloseWrite() error {
|
||||
// Go standard tls.Conn doesn't provide the CloseWrite() method so we do it
|
||||
// on its underlying connection.
|
||||
if conn, ok := c.rawConn.(types.CloseWriter); ok {
|
||||
return conn.CloseWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// postHijacked sends a POST request and hijacks the connection.
|
||||
func (cli *Client) postHijacked(ctx context.Context, path string, query url.Values, body interface{}, headers map[string][]string) (types.HijackedResponse, error) {
|
||||
bodyEncoded, err := encodeData(body)
|
||||
|
@ -54,96 +38,9 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
|
|||
return types.HijackedResponse{Conn: conn, Reader: bufio.NewReader(conn)}, err
|
||||
}
|
||||
|
||||
func tlsDial(network, addr string, config *tls.Config) (net.Conn, error) {
|
||||
return tlsDialWithDialer(new(net.Dialer), network, addr, config)
|
||||
}
|
||||
|
||||
// We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in
|
||||
// order to return our custom tlsClientCon struct which holds both the tls.Conn
|
||||
// object _and_ its underlying raw connection. The rationale for this is that
|
||||
// we need to be able to close the write end of the connection when attaching,
|
||||
// which tls.Conn does not provide.
|
||||
func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) {
|
||||
// We want the Timeout and Deadline values from dialer to cover the
|
||||
// whole process: TCP connection and TLS handshake. This means that we
|
||||
// also need to start our own timers now.
|
||||
timeout := dialer.Timeout
|
||||
|
||||
if !dialer.Deadline.IsZero() {
|
||||
deadlineTimeout := time.Until(dialer.Deadline)
|
||||
if timeout == 0 || deadlineTimeout < timeout {
|
||||
timeout = deadlineTimeout
|
||||
}
|
||||
}
|
||||
|
||||
var errChannel chan error
|
||||
|
||||
if timeout != 0 {
|
||||
errChannel = make(chan error, 2)
|
||||
time.AfterFunc(timeout, func() {
|
||||
errChannel <- errors.New("")
|
||||
})
|
||||
}
|
||||
|
||||
proxyDialer, err := sockets.DialerFromEnvironment(dialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawConn, err := proxyDialer.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// When we set up a TCP connection for hijack, there could be long periods
|
||||
// of inactivity (a long running command with no output) that in certain
|
||||
// network setups may cause ECONNTIMEOUT, leaving the client in an unknown
|
||||
// state. Setting TCP KeepAlive on the socket connection will prohibit
|
||||
// ECONNTIMEOUT unless the socket connection truly is broken
|
||||
if tcpConn, ok := rawConn.(*net.TCPConn); ok {
|
||||
tcpConn.SetKeepAlive(true)
|
||||
tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
}
|
||||
|
||||
colonPos := strings.LastIndex(addr, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(addr)
|
||||
}
|
||||
hostname := addr[:colonPos]
|
||||
|
||||
// If no ServerName is set, infer the ServerName
|
||||
// from the hostname we're connecting to.
|
||||
if config.ServerName == "" {
|
||||
// Make a copy to avoid polluting argument or default.
|
||||
config = tlsConfigClone(config)
|
||||
config.ServerName = hostname
|
||||
}
|
||||
|
||||
conn := tls.Client(rawConn, config)
|
||||
|
||||
if timeout == 0 {
|
||||
err = conn.Handshake()
|
||||
} else {
|
||||
go func() {
|
||||
errChannel <- conn.Handshake()
|
||||
}()
|
||||
|
||||
err = <-errChannel
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// This is Docker difference with standard's crypto/tls package: returned a
|
||||
// wrapper which holds both the TLS and raw connections.
|
||||
return &tlsClientCon{conn, rawConn}, nil
|
||||
}
|
||||
|
||||
func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
|
||||
if tlsConfig != nil && proto != "unix" && proto != "npipe" {
|
||||
// Notice this isn't Go standard's tls.Dial function
|
||||
return tlsDial(proto, addr, tlsConfig)
|
||||
return tls.Dial(proto, addr, tlsConfig)
|
||||
}
|
||||
if proto == "npipe" {
|
||||
return sockets.DialPipe(addr, 32*time.Second)
|
||||
|
|
103
client/hijack_test.go
Normal file
103
client/hijack_test.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/docker/docker/api/server/httputils"
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/gotestyourself/gotestyourself/assert"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestTLSCloseWriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var chErr chan error
|
||||
ts := &httptest.Server{Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
chErr = make(chan error, 1)
|
||||
defer close(chErr)
|
||||
if err := httputils.ParseForm(req); err != nil {
|
||||
chErr <- errors.Wrap(err, "error parsing form")
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
r, rw, err := httputils.HijackConnection(w)
|
||||
if err != nil {
|
||||
chErr <- errors.Wrap(err, "error hijacking connection")
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
fmt.Fprint(rw, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\n")
|
||||
|
||||
buf := make([]byte, 5)
|
||||
_, err = r.Read(buf)
|
||||
if err != nil {
|
||||
chErr <- errors.Wrap(err, "error reading from client")
|
||||
return
|
||||
}
|
||||
_, err = rw.Write(buf)
|
||||
if err != nil {
|
||||
chErr <- errors.Wrap(err, "error writing to client")
|
||||
return
|
||||
}
|
||||
})}}
|
||||
|
||||
var (
|
||||
l net.Listener
|
||||
err error
|
||||
)
|
||||
for i := 1024; i < 10000; i++ {
|
||||
l, err = net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", i))
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Assert(t, err)
|
||||
|
||||
ts.Listener = l
|
||||
defer l.Close()
|
||||
|
||||
defer func() {
|
||||
if chErr != nil {
|
||||
assert.Assert(t, <-chErr)
|
||||
}
|
||||
}()
|
||||
|
||||
ts.StartTLS()
|
||||
defer ts.Close()
|
||||
|
||||
serverURL, err := url.Parse(ts.URL)
|
||||
assert.Assert(t, err)
|
||||
|
||||
client, err := NewClient("tcp://"+serverURL.Host, "", ts.Client(), nil)
|
||||
assert.Assert(t, err)
|
||||
|
||||
resp, err := client.postHijacked(context.Background(), "/asdf", url.Values{}, nil, map[string][]string{"Content-Type": {"text/plain"}})
|
||||
assert.Assert(t, err)
|
||||
defer resp.Close()
|
||||
|
||||
if _, ok := resp.Conn.(types.CloseWriter); !ok {
|
||||
t.Fatal("tls conn did not implement the CloseWrite interface")
|
||||
}
|
||||
|
||||
_, err = resp.Conn.Write([]byte("hello"))
|
||||
assert.Assert(t, err)
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Reader)
|
||||
assert.Assert(t, err)
|
||||
assert.Assert(t, string(b) == "hello")
|
||||
assert.Assert(t, resp.CloseWrite())
|
||||
|
||||
// This should error since writes are closed
|
||||
_, err = resp.Conn.Write([]byte("no"))
|
||||
assert.Assert(t, err != nil)
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
// +build go1.8
|
||||
|
||||
package client // import "github.com/docker/docker/client"
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
// tlsConfigClone returns a clone of tls.Config. This function is provided for
|
||||
// compatibility for go1.7 that doesn't include this method in stdlib.
|
||||
func tlsConfigClone(c *tls.Config) *tls.Config {
|
||||
return c.Clone()
|
||||
}
|
|
@ -1,33 +0,0 @@
|
|||
// +build go1.7,!go1.8
|
||||
|
||||
package client // import "github.com/docker/docker/client"
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
// tlsConfigClone returns a clone of tls.Config. This function is provided for
|
||||
// compatibility for go1.7 that doesn't include this method in stdlib.
|
||||
func tlsConfigClone(c *tls.Config) *tls.Config {
|
||||
return &tls.Config{
|
||||
Rand: c.Rand,
|
||||
Time: c.Time,
|
||||
Certificates: c.Certificates,
|
||||
NameToCertificate: c.NameToCertificate,
|
||||
GetCertificate: c.GetCertificate,
|
||||
RootCAs: c.RootCAs,
|
||||
NextProtos: c.NextProtos,
|
||||
ServerName: c.ServerName,
|
||||
ClientAuth: c.ClientAuth,
|
||||
ClientCAs: c.ClientCAs,
|
||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||
CipherSuites: c.CipherSuites,
|
||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||
SessionTicketKey: c.SessionTicketKey,
|
||||
ClientSessionCache: c.ClientSessionCache,
|
||||
MinVersion: c.MinVersion,
|
||||
MaxVersion: c.MaxVersion,
|
||||
CurvePreferences: c.CurvePreferences,
|
||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||
Renegotiation: c.Renegotiation,
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue