From 24f9187a0467ca66c30e26c3d9e3ee58daeb720f Mon Sep 17 00:00:00 2001 From: Alexander Larsson Date: Mon, 31 Mar 2014 11:06:39 +0200 Subject: [PATCH] beam: Add simple framing system for UnixConn This is needed for Send/Receive to correctly handle boundaries between messages. The framing uses a single big-endian uint32 header for each frame, of which the high bit is used to indicate whether the message contains a file descriptor or not, and the remaining 31 bits hold the frame length. This is enough to separate out each message sent and to decide to which message each file descriptor belongs, even though multiple Sends may be coalesced into a single read, and/or one Send can be split into multiple writes. Docker-DCO-1.1-Signed-off-by: Alexander Larsson (github: alexlarsson) Docker-DCO-1.1-Signed-off-by: Solomon Hykes (github: shykes) --- pkg/beam/unix.go | 166 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 136 insertions(+), 30 deletions(-) diff --git a/pkg/beam/unix.go b/pkg/beam/unix.go index b480c47eb9..b2d0d94150 100644 --- a/pkg/beam/unix.go +++ b/pkg/beam/unix.go @@ -21,6 +21,43 @@ func debugCheckpoint(msg string, args ...interface{}) { type UnixConn struct { *net.UnixConn + fds []*os.File +} + +// Framing: +// In order to handle framing in Send/Receive, as these give frame +// boundaries, we use a very simple 4-byte header. It is a big-endian +// uint32 where the high bit is set if the message includes a file +// descriptor. The rest of the uint32 is the length of the next frame. +// We need the bit in order to be able to assign received fds to +// the right message, as multiple messages may be coalesced into +// a single receive operation. 
+func makeHeader(data []byte, fds []int) ([]byte, error) { + header := make([]byte, 4) + + length := uint32(len(data)) + + if length > 0x7fffffff { + return nil, fmt.Errorf("Data to large") + } + + if len(fds) != 0 { + length = length | 0x80000000 + } + header[0] = byte((length >> 24) & 0xff) + header[1] = byte((length >> 16) & 0xff) + header[2] = byte((length >> 8) & 0xff) + header[3] = byte((length >> 0) & 0xff) + + return header, nil +} + +func parseHeader(header []byte) (uint32, bool) { + length := uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3]) + hasFd := length&0x80000000 != 0 + length = length & ^uint32(0x80000000) + + return length, hasFd } func FileConn(f *os.File) (*UnixConn, error) { @@ -33,7 +70,7 @@ func FileConn(f *os.File) (*UnixConn, error) { conn.Close() return nil, fmt.Errorf("%d: not a unix connection", f.Fd()) } - return &UnixConn{uconn}, nil + return &UnixConn{UnixConn: uconn}, nil } @@ -52,7 +89,7 @@ func (conn *UnixConn) Send(data []byte, f *os.File) error { if f != nil { fds = append(fds, int(f.Fd())) } - if err := sendUnix(conn.UnixConn, data, fds...); err != nil { + if err := conn.sendUnix(data, fds...); err != nil { return err } @@ -76,42 +113,104 @@ func (conn *UnixConn) Receive() (rdata []byte, rf *os.File, rerr error) { } debugCheckpoint("===DEBUG=== Receive() -> '%s'[%d]. 
Hit enter to continue.\n", rdata, fd) }() - for { - data, fds, err := receiveUnix(conn.UnixConn) + + // Read header + header := make([]byte, 4) + nRead := uint32(0) + + for nRead < 4 { + n, err := conn.receiveUnix(header[nRead:]) if err != nil { return nil, nil, err } - var f *os.File - if len(fds) > 1 { - for _, fd := range fds[1:] { - syscall.Close(fd) - } - } - if len(fds) >= 1 { - f = os.NewFile(uintptr(fds[0]), "") - } - return data, f, nil + nRead = nRead + uint32(n) } - panic("impossibru") - return nil, nil, nil + + length, hasFd := parseHeader(header) + + if hasFd { + if len(conn.fds) == 0 { + return nil, nil, fmt.Errorf("No expected file descriptor in message") + } + + rf = conn.fds[0] + conn.fds = conn.fds[1:] + } + + rdata = make([]byte, length) + + nRead = 0 + for nRead < length { + n, err := conn.receiveUnix(rdata[nRead:]) + if err != nil { + return nil, nil, err + } + nRead = nRead + uint32(n) + } + + return } -func receiveUnix(conn *net.UnixConn) ([]byte, []int, error) { - buf := make([]byte, 4096) - oob := make([]byte, 4096) +func (conn *UnixConn) receiveUnix(buf []byte) (int, error) { + oob := make([]byte, syscall.CmsgSpace(4)) bufn, oobn, _, _, err := conn.ReadMsgUnix(buf, oob) if err != nil { - return nil, nil, err + return 0, err } - return buf[:bufn], extractFds(oob[:oobn]), nil + fd := extractFd(oob[:oobn]) + if fd != -1 { + f := os.NewFile(uintptr(fd), "") + conn.fds = append(conn.fds, f) + } + + return bufn, nil } -func sendUnix(conn *net.UnixConn, data []byte, fds ...int) error { - _, _, err := conn.WriteMsgUnix(data, syscall.UnixRights(fds...), nil) - return err +func (conn *UnixConn) sendUnix(data []byte, fds ...int) error { + header, err := makeHeader(data, fds) + if err != nil { + return err + } + + // There is a bug in conn.WriteMsgUnix where it doesn't correctly return + // the number of bytes written (http://code.google.com/p/go/issues/detail?id=7645) + // So, we can't rely on the return value from it. 
However, we must use it to + // send the fds. In order to handle this we only write one byte using WriteMsgUnix + // (when we have to), as that can only ever block or fully succeed. We then write + // the rest with conn.Write() + // The reader side should not rely on this though, as hopefully this gets fixed + // in Go later. + written := 0 + if len(fds) != 0 { + oob := syscall.UnixRights(fds...) + wrote, _, err := conn.WriteMsgUnix(header[0:1], oob, nil) + if err != nil { + return err + } + written = written + wrote + } + + for written < len(header) { + wrote, err := conn.Write(header[written:]) + if err != nil { + return err + } + written = written + wrote + } + + written = 0 + for written < len(data) { + wrote, err := conn.Write(data[written:]) + if err != nil { + return err + } + written = written + wrote + } + + return nil } -func extractFds(oob []byte) (fds []int) { +func extractFd(oob []byte) int { // Grab forklock to make sure no forks accidentally inherit the new // fds before they are made CLOEXEC // There is a slight race condition between ReadMsgUnix returns and @@ -122,20 +221,27 @@ func extractFds(oob []byte) (fds []int) { defer syscall.ForkLock.Unlock() scms, err := syscall.ParseSocketControlMessage(oob) if err != nil { - return + return -1 } + + foundFd := -1 for _, scm := range scms { - gotFds, err := syscall.ParseUnixRights(&scm) + fds, err := syscall.ParseUnixRights(&scm) if err != nil { continue } - fds = append(fds, gotFds...) for _, fd := range fds { - syscall.CloseOnExec(fd) + if foundFd == -1 { + syscall.CloseOnExec(fd) + foundFd = fd + } else { + syscall.Close(fd) + } } } - return + + return foundFd } func socketpair() ([2]int, error) {