1
0
Fork 0
mirror of https://github.com/moby/moby.git synced 2022-11-09 12:21:53 -05:00

Use stdlib TLS dialer

Since go1.8, the stdlib TLS net.Conn implementation implements the
`CloseWrite()` interface.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>
Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
Brian Goff 2018-03-23 14:39:30 -04:00 committed by Sebastiaan van Stijn
parent 5cb95f693d
commit 2ac277a56f
No known key found for this signature in database
GPG key ID: 76698F39D527CE8C
4 changed files with 104 additions and 148 deletions

View file

@ -9,7 +9,6 @@ import (
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/docker/docker/api/types"
@ -17,21 +16,6 @@ import (
"github.com/pkg/errors"
)
// tlsClientCon holds tls information and a dialed connection.
type tlsClientCon struct {
*tls.Conn
rawConn net.Conn
}
func (c *tlsClientCon) CloseWrite() error {
// Go standard tls.Conn doesn't provide the CloseWrite() method so we do it
// on its underlying connection.
if conn, ok := c.rawConn.(types.CloseWriter); ok {
return conn.CloseWrite()
}
return nil
}
// postHijacked sends a POST request and hijacks the connection.
func (cli *Client) postHijacked(ctx context.Context, path string, query url.Values, body interface{}, headers map[string][]string) (types.HijackedResponse, error) {
bodyEncoded, err := encodeData(body)
@ -54,96 +38,9 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
return types.HijackedResponse{Conn: conn, Reader: bufio.NewReader(conn)}, err
}
func tlsDial(network, addr string, config *tls.Config) (net.Conn, error) {
return tlsDialWithDialer(new(net.Dialer), network, addr, config)
}
// We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in
// order to return our custom tlsClientCon struct which holds both the tls.Conn
// object _and_ its underlying raw connection. The rationale for this is that
// we need to be able to close the write end of the connection when attaching,
// which tls.Conn does not provide.
func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) {
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
timeout := dialer.Timeout
if !dialer.Deadline.IsZero() {
deadlineTimeout := time.Until(dialer.Deadline)
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- errors.New("")
})
}
proxyDialer, err := sockets.DialerFromEnvironment(dialer)
if err != nil {
return nil, err
}
rawConn, err := proxyDialer.Dial(network, addr)
if err != nil {
return nil, err
}
// When we set up a TCP connection for hijack, there could be long periods
// of inactivity (a long running command with no output) that in certain
// network setups may cause ECONNTIMEOUT, leaving the client in an unknown
// state. Setting TCP KeepAlive on the socket connection will prohibit
// ECONNTIMEOUT unless the socket connection truly is broken
if tcpConn, ok := rawConn.(*net.TCPConn); ok {
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(30 * time.Second)
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
config = tlsConfigClone(config)
config.ServerName = hostname
}
conn := tls.Client(rawConn, config)
if timeout == 0 {
err = conn.Handshake()
} else {
go func() {
errChannel <- conn.Handshake()
}()
err = <-errChannel
}
if err != nil {
rawConn.Close()
return nil, err
}
// This is Docker difference with standard's crypto/tls package: returned a
// wrapper which holds both the TLS and raw connections.
return &tlsClientCon{conn, rawConn}, nil
}
func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
if tlsConfig != nil && proto != "unix" && proto != "npipe" {
// Notice this isn't Go standard's tls.Dial function
return tlsDial(proto, addr, tlsConfig)
return tls.Dial(proto, addr, tlsConfig)
}
if proto == "npipe" {
return sockets.DialPipe(addr, 32*time.Second)

103
client/hijack_test.go Normal file
View file

@ -0,0 +1,103 @@
package client
import (
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/docker/docker/api/server/httputils"
"github.com/docker/docker/api/types"
"github.com/gotestyourself/gotestyourself/assert"
"github.com/pkg/errors"
"golang.org/x/net/context"
)
func TestTLSCloseWriter(t *testing.T) {
t.Parallel()
var chErr chan error
ts := &httptest.Server{Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
chErr = make(chan error, 1)
defer close(chErr)
if err := httputils.ParseForm(req); err != nil {
chErr <- errors.Wrap(err, "error parsing form")
http.Error(w, err.Error(), 500)
return
}
r, rw, err := httputils.HijackConnection(w)
if err != nil {
chErr <- errors.Wrap(err, "error hijacking connection")
http.Error(w, err.Error(), 500)
return
}
defer r.Close()
fmt.Fprint(rw, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\n")
buf := make([]byte, 5)
_, err = r.Read(buf)
if err != nil {
chErr <- errors.Wrap(err, "error reading from client")
return
}
_, err = rw.Write(buf)
if err != nil {
chErr <- errors.Wrap(err, "error writing to client")
return
}
})}}
var (
l net.Listener
err error
)
for i := 1024; i < 10000; i++ {
l, err = net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", i))
if err == nil {
break
}
}
assert.Assert(t, err)
ts.Listener = l
defer l.Close()
defer func() {
if chErr != nil {
assert.Assert(t, <-chErr)
}
}()
ts.StartTLS()
defer ts.Close()
serverURL, err := url.Parse(ts.URL)
assert.Assert(t, err)
client, err := NewClient("tcp://"+serverURL.Host, "", ts.Client(), nil)
assert.Assert(t, err)
resp, err := client.postHijacked(context.Background(), "/asdf", url.Values{}, nil, map[string][]string{"Content-Type": {"text/plain"}})
assert.Assert(t, err)
defer resp.Close()
if _, ok := resp.Conn.(types.CloseWriter); !ok {
t.Fatal("tls conn did not implement the CloseWrite interface")
}
_, err = resp.Conn.Write([]byte("hello"))
assert.Assert(t, err)
b, err := ioutil.ReadAll(resp.Reader)
assert.Assert(t, err)
assert.Assert(t, string(b) == "hello")
assert.Assert(t, resp.CloseWrite())
// This should error since writes are closed
_, err = resp.Conn.Write([]byte("no"))
assert.Assert(t, err != nil)
}

View file

@ -1,11 +0,0 @@
// +build go1.8
package client // import "github.com/docker/docker/client"
import "crypto/tls"
// tlsConfigClone returns a clone of tls.Config. This function is provided for
// compatibility for go1.7 that doesn't include this method in stdlib.
func tlsConfigClone(c *tls.Config) *tls.Config {
return c.Clone()
}

View file

@ -1,33 +0,0 @@
// +build go1.7,!go1.8
package client // import "github.com/docker/docker/client"
import "crypto/tls"
// tlsConfigClone returns a clone of tls.Config. This function is provided for
// compatibility for go1.7 that doesn't include this method in stdlib.
func tlsConfigClone(c *tls.Config) *tls.Config {
return &tls.Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: c.ClientSessionCache,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
}
}