diff --git a/pkg/beam/unix.go b/pkg/beam/unix.go index b480c47eb9..b2d0d94150 100644 --- a/pkg/beam/unix.go +++ b/pkg/beam/unix.go @@ -21,6 +21,43 @@ func debugCheckpoint(msg string, args ...interface{}) { type UnixConn struct { *net.UnixConn + fds []*os.File +} + +// Framing: +// In order to handle framing in Send/Receive, as these give frame +// boundaries, we use a very simple 4-byte header. It is a big-endian +// uint32 where the high bit is set if the message includes a file +// descriptor. The rest of the uint32 is the length of the next frame. +// We need the bit in order to be able to assign received fds to +// the right message, as multiple messages may be coalesced into +// a single receive operation. +func makeHeader(data []byte, fds []int) ([]byte, error) { + header := make([]byte, 4) + + length := uint32(len(data)) + + if length > 0x7fffffff { + return nil, fmt.Errorf("Data to large") + } + + if len(fds) != 0 { + length = length | 0x80000000 + } + header[0] = byte((length >> 24) & 0xff) + header[1] = byte((length >> 16) & 0xff) + header[2] = byte((length >> 8) & 0xff) + header[3] = byte((length >> 0) & 0xff) + + return header, nil +} + +func parseHeader(header []byte) (uint32, bool) { + length := uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3]) + hasFd := length&0x80000000 != 0 + length = length & ^uint32(0x80000000) + + return length, hasFd } func FileConn(f *os.File) (*UnixConn, error) { @@ -33,7 +70,7 @@ func FileConn(f *os.File) (*UnixConn, error) { conn.Close() return nil, fmt.Errorf("%d: not a unix connection", f.Fd()) } - return &UnixConn{uconn}, nil + return &UnixConn{UnixConn: uconn}, nil } @@ -52,7 +89,7 @@ func (conn *UnixConn) Send(data []byte, f *os.File) error { if f != nil { fds = append(fds, int(f.Fd())) } - if err := sendUnix(conn.UnixConn, data, fds...); err != nil { + if err := conn.sendUnix(data, fds...); err != nil { return err } @@ -76,42 +113,104 @@ func (conn *UnixConn) Receive() (rdata 
[]byte, rf *os.File, rerr error) { } debugCheckpoint("===DEBUG=== Receive() -> '%s'[%d]. Hit enter to continue.\n", rdata, fd) }() - for { - data, fds, err := receiveUnix(conn.UnixConn) + + // Read header + header := make([]byte, 4) + nRead := uint32(0) + + for nRead < 4 { + n, err := conn.receiveUnix(header[nRead:]) if err != nil { return nil, nil, err } - var f *os.File - if len(fds) > 1 { - for _, fd := range fds[1:] { - syscall.Close(fd) - } - } - if len(fds) >= 1 { - f = os.NewFile(uintptr(fds[0]), "") - } - return data, f, nil + nRead = nRead + uint32(n) } - panic("impossibru") - return nil, nil, nil + + length, hasFd := parseHeader(header) + + if hasFd { + if len(conn.fds) == 0 { + return nil, nil, fmt.Errorf("No expected file descriptor in message") + } + + rf = conn.fds[0] + conn.fds = conn.fds[1:] + } + + rdata = make([]byte, length) + + nRead = 0 + for nRead < length { + n, err := conn.receiveUnix(rdata[nRead:]) + if err != nil { + return nil, nil, err + } + nRead = nRead + uint32(n) + } + + return } -func receiveUnix(conn *net.UnixConn) ([]byte, []int, error) { - buf := make([]byte, 4096) - oob := make([]byte, 4096) +func (conn *UnixConn) receiveUnix(buf []byte) (int, error) { + oob := make([]byte, syscall.CmsgSpace(4)) bufn, oobn, _, _, err := conn.ReadMsgUnix(buf, oob) if err != nil { - return nil, nil, err + return 0, err } - return buf[:bufn], extractFds(oob[:oobn]), nil + fd := extractFd(oob[:oobn]) + if fd != -1 { + f := os.NewFile(uintptr(fd), "") + conn.fds = append(conn.fds, f) + } + + return bufn, nil } -func sendUnix(conn *net.UnixConn, data []byte, fds ...int) error { - _, _, err := conn.WriteMsgUnix(data, syscall.UnixRights(fds...), nil) - return err +func (conn *UnixConn) sendUnix(data []byte, fds ...int) error { + header, err := makeHeader(data, fds) + if err != nil { + return err + } + + // There is a bug in conn.WriteMsgUnix where it doesn't correctly return + // the number of bytes written 
(http://code.google.com/p/go/issues/detail?id=7645) + // So, we can't rely on the return value from it. However, we must use it to + // send the fds. In order to handle this, we only write one byte using WriteMsgUnix + // (when we have to), as that can only ever block or fully succeed. We then write + // the rest with conn.Write() + // The reader side should not rely on this though, as hopefully this gets fixed + // in Go later. + written := 0 + if len(fds) != 0 { + oob := syscall.UnixRights(fds...) + wrote, _, err := conn.WriteMsgUnix(header[0:1], oob, nil) + if err != nil { + return err + } + written = written + wrote + } + + for written < len(header) { + wrote, err := conn.Write(header[written:]) + if err != nil { + return err + } + written = written + wrote + } + + written = 0 + for written < len(data) { + wrote, err := conn.Write(data[written:]) + if err != nil { + return err + } + written = written + wrote + } + + return nil } -func extractFds(oob []byte) (fds []int) { +func extractFd(oob []byte) int { // Grab forklock to make sure no forks accidentally inherit the new // fds before they are made CLOEXEC // There is a slight race condition between ReadMsgUnix returns and @@ -122,20 +221,27 @@ func extractFds(oob []byte) (fds []int) { defer syscall.ForkLock.Unlock() scms, err := syscall.ParseSocketControlMessage(oob) if err != nil { - return + return -1 } + + foundFd := -1 for _, scm := range scms { - gotFds, err := syscall.ParseUnixRights(&scm) + fds, err := syscall.ParseUnixRights(&scm) if err != nil { continue } - fds = append(fds, gotFds...) for _, fd := range fds { - syscall.CloseOnExec(fd) + if foundFd == -1 { + syscall.CloseOnExec(fd) + foundFd = fd + } else { + syscall.Close(fd) + } } } - return + + return foundFd } func socketpair() ([2]int, error) {