From dacae746b70f50dd1f3ea9d40834386b96b6c200 Mon Sep 17 00:00:00 2001 From: Brian Goff Date: Fri, 7 Nov 2014 15:21:19 -0500 Subject: [PATCH] Cleanup api server creation Current implementation is hard to reason about because of trying to mix unix/tcp server implementations, even though they are quite different. This cleans that up. Also makes it possible to create and manage a new API server easily, e.g. for adding an introspection socket to a container. Built in such a way as to allow a non-HTTP server to work as well, such as libchan. Signed-off-by: Brian Goff --- api/server/server.go | 213 +++++++++++++++++++++++++++---------------- 1 file changed, 134 insertions(+), 79 deletions(-) diff --git a/api/server/server.go b/api/server/server.go index d77a6c22a2..13affc334a 100644 --- a/api/server/server.go +++ b/api/server/server.go @@ -3,8 +3,7 @@ package server import ( "bufio" "bytes" - "crypto/tls" - "crypto/x509" + "encoding/base64" "encoding/json" "expvar" @@ -19,6 +18,9 @@ import ( "strings" "syscall" + "crypto/tls" + "crypto/x509" + "code.google.com/p/go.net/websocket" "github.com/docker/libcontainer/user" "github.com/gorilla/mux" @@ -39,6 +41,18 @@ var ( activationLock chan struct{} ) +type HttpServer struct { + srv *http.Server + l net.Listener +} + +func (s *HttpServer) Serve() error { + return s.srv.Serve(s.l) +} +func (s *HttpServer) Close() error { + return s.l.Close() +} + type HttpApiFunc func(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error func hijackServer(w http.ResponseWriter) (io.ReadCloser, io.Writer, error) { @@ -1334,9 +1348,14 @@ func ServeRequest(eng *engine.Engine, apiversion version.Version, w http.Respons return nil } -// ServeFD creates an http.Server and sets it up to serve given a socket activated +// serveFd creates an http.Server and sets it up to serve given a socket activated // argument. -func ServeFd(addr string, handle http.Handler) error { +func serveFd(addr string, job *engine.Job) error { + r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version")) + if err != nil { + return err + } + ls, e := systemd.ListenFD(addr) if e != nil { return e @@ -1354,7 +1373,7 @@ func ServeFd(addr string, handle http.Handler) error { for i := range ls { listener := ls[i] go func() { - httpSrv := http.Server{Handler: handle} + httpSrv := http.Server{Handler: r} chErrors <- httpSrv.Serve(listener) }() } @@ -1382,6 +1401,41 @@ func lookupGidByName(nameOrGid string) (int, error) { return -1, fmt.Errorf("Group %s not found", nameOrGid) } +func setupTls(cert, key, ca string, l net.Listener) (net.Listener, error) { + tlsCert, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?", + cert, key, err) + } + tlsConfig := &tls.Config{ + NextProtos: []string{"http/1.1"}, + Certificates: []tls.Certificate{tlsCert}, + // Avoid fallback on insecure SSL protocols + MinVersion: tls.VersionTLS10, + } + + if ca != "" { + certPool := x509.NewCertPool() + file, err := ioutil.ReadFile(ca) + if err != nil { + return nil, fmt.Errorf("Couldn't read CA certificate: %s", err) + } + certPool.AppendCertsFromPEM(file) + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.ClientCAs = certPool + } + + return tls.NewListener(l, tlsConfig), nil +} + +func newListener(proto, addr string, bufferRequests bool) (net.Listener, error) { + if bufferRequests { + return listenbuffer.NewListenBuffer(proto, addr, activationLock) + } + + return net.Listen(proto, addr) +} + func changeGroup(addr string, nameOrGid string) error { gid, err := lookupGidByName(nameOrGid) if err != nil { @@ -1392,99 +1446,95 @@ func changeGroup(addr string, nameOrGid string) error { return os.Chown(addr, 0, gid) } -// ListenAndServe sets up the required http.Server and gets it listening for -// each addr passed in and does protocol specific checking. -func ListenAndServe(proto, addr string, job *engine.Job) error { - var l net.Listener +func setSocketGroup(addr, group string) error { + if group == "" { + return nil + } + + if err := changeGroup(addr, group); err != nil { + if group != "docker" { + return err + } + log.Debugf("Warning: could not chgrp %s to docker: %v", addr, err) + } + + return nil +} + +func setupUnixHttp(addr string, job *engine.Job) (*HttpServer, error) { r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version")) if err != nil { - return err + return nil, err } - if proto == "fd" { - return ServeFd(addr, r) + if err := syscall.Unlink(addr); err != nil && !os.IsNotExist(err) { + return nil, err } + mask := syscall.Umask(0777) + defer syscall.Umask(mask) - if proto == "unix" { - if err := syscall.Unlink(addr); err != nil && !os.IsNotExist(err) { - return err - } - } - - var oldmask int - if proto == "unix" { - oldmask = syscall.Umask(0777) - } - - if job.GetenvBool("BufferRequests") { - l, err = listenbuffer.NewListenBuffer(proto, addr, activationLock) - } else { - l, err = net.Listen(proto, addr) - } - - if proto == "unix" { - syscall.Umask(oldmask) - } + l, err := newListener("unix", addr, job.GetenvBool("BufferRequests")) if err != nil { - return err + return nil, err } - if proto != "unix" && (job.GetenvBool("Tls") || job.GetenvBool("TlsVerify")) { - tlsCert := job.Getenv("TlsCert") - tlsKey := job.Getenv("TlsKey") - cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey) - if err != nil { - return fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?", - tlsCert, tlsKey, err) - } - tlsConfig := &tls.Config{ - NextProtos: []string{"http/1.1"}, - Certificates: []tls.Certificate{cert}, - // Avoid fallback on insecure SSL protocols - MinVersion: tls.VersionTLS10, - } + if err := setSocketGroup(addr, job.Getenv("SocketGroup")); err != nil { + return nil, err + } + + if err := os.Chmod(addr, 0660); err != nil { + return nil, err + } + + return &HttpServer{&http.Server{Addr: addr, Handler: r}, l}, nil +} + +func setupTcpHttp(addr string, job *engine.Job) (*HttpServer, error) { + if !strings.HasPrefix(addr, "127.0.0.1") && !job.GetenvBool("TlsVerify") { + log.Infof("/!\\ DON'T BIND ON ANOTHER IP ADDRESS THAN 127.0.0.1 IF YOU DON'T KNOW WHAT YOU'RE DOING /!\\") + } + + r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version")) + if err != nil { + return nil, err + } + + l, err := newListener("tcp", addr, job.GetenvBool("BufferRequests")) + if err != nil { + return nil, err + } + + if job.GetenvBool("Tls") || job.GetenvBool("TlsVerify") { + var tlsCa string if job.GetenvBool("TlsVerify") { - certPool := x509.NewCertPool() - file, err := ioutil.ReadFile(job.Getenv("TlsCa")) - if err != nil { - return fmt.Errorf("Couldn't read CA certificate: %s", err) - } - certPool.AppendCertsFromPEM(file) - - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - tlsConfig.ClientCAs = certPool + tlsCa = job.Getenv("TlsCa") + } + l, err = setupTls(job.Getenv("TlsCert"), job.Getenv("TlsKey"), tlsCa, l) + if err != nil { + return nil, err } - l = tls.NewListener(l, tlsConfig) } + return &HttpServer{&http.Server{Addr: addr, Handler: r}, l}, nil +} +// NewServer sets up the required Server and does protocol specific checking. +func NewServer(proto, addr string, job *engine.Job) (Server, error) { // Basic error and sanity checking switch proto { + case "fd": + return nil, serveFd(addr, job) case "tcp": - if !strings.HasPrefix(addr, "127.0.0.1") && !job.GetenvBool("TlsVerify") { - log.Infof("/!\\ DON'T BIND ON ANOTHER IP ADDRESS THAN 127.0.0.1 IF YOU DON'T KNOW WHAT YOU'RE DOING /!\\") - } + return setupTcpHttp(addr, job) case "unix": - socketGroup := job.Getenv("SocketGroup") - if socketGroup != "" { - if err := changeGroup(addr, socketGroup); err != nil { - if socketGroup == "docker" { - // if the user hasn't explicitly specified the group ownership, don't fail on errors. - log.Debugf("Warning: could not chgrp %s to docker: %s", addr, err.Error()) - } else { - return err - } - } - - } - if err := os.Chmod(addr, 0660); err != nil { - return err - } + return setupUnixHttp(addr, job) default: - return fmt.Errorf("Invalid protocol format.") + return nil, fmt.Errorf("Invalid protocol format.") } +} - httpSrv := http.Server{Addr: addr, Handler: r} - return httpSrv.Serve(l) +type Server interface { + Serve() error + Close() error } // ServeApi loops through all of the protocols sent in to docker and spawns @@ -1506,7 +1556,12 @@ func ServeApi(job *engine.Job) engine.Status { } go func() { log.Infof("Listening for HTTP on %s (%s)", protoAddrParts[0], protoAddrParts[1]) - chErrors <- ListenAndServe(protoAddrParts[0], protoAddrParts[1], job) + srv, err := NewServer(protoAddrParts[0], protoAddrParts[1], job) + if err != nil { + chErrors <- err + return + } + chErrors <- srv.Serve() }() }