beam: Add simple framing system for UnixConn

This is needed for Send/Recieve to correctly handle borders between
the messages.

The framing uses a single 32bit uint32 length for each frame, of which
the high bit is used to indicate whether the message contains a file
descriptor or not. This is enough to separate out each message sent
and to decide to which message each file descriptors belongs, even
though multiple Sends may be coalesced into a single read, and/or one
Send can be split into multiple writes.

Docker-DCO-1.1-Signed-off-by: Alexander Larsson <alexl@redhat.com> (github: alexlarsson)
Docker-DCO-1.1-Signed-off-by: Solomon Hykes <solomon@docker.com> (github: shykes)
This commit is contained in:
Alexander Larsson 2014-03-31 11:06:39 +02:00 committed by Solomon Hykes
parent c42db412b6
commit 24f9187a04
1 changed files with 136 additions and 30 deletions

View File

@ -21,6 +21,43 @@ func debugCheckpoint(msg string, args ...interface{}) {
type UnixConn struct {
*net.UnixConn
fds []*os.File
}
// Framing:
// In order to handle framing in Send/Recieve, as these give frame
// boundaries we use a very simple 4 bytes header. It is a big endiand
// uint32 where the high bit is set if the message includes a file
// descriptor. The rest of the uint32 is the length of the next frame.
// We need the bit in order to be able to assign recieved fds to
// the right message, as multiple messages may be coalesced into
// a single recieve operation.
func makeHeader(data []byte, fds []int) ([]byte, error) {
header := make([]byte, 4)
length := uint32(len(data))
if length > 0x7fffffff {
return nil, fmt.Errorf("Data to large")
}
if len(fds) != 0 {
length = length | 0x80000000
}
header[0] = byte((length >> 24) & 0xff)
header[1] = byte((length >> 16) & 0xff)
header[2] = byte((length >> 8) & 0xff)
header[3] = byte((length >> 0) & 0xff)
return header, nil
}
func parseHeader(header []byte) (uint32, bool) {
length := uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
hasFd := length&0x80000000 != 0
length = length & ^uint32(0x80000000)
return length, hasFd
}
func FileConn(f *os.File) (*UnixConn, error) {
@ -33,7 +70,7 @@ func FileConn(f *os.File) (*UnixConn, error) {
conn.Close()
return nil, fmt.Errorf("%d: not a unix connection", f.Fd())
}
return &UnixConn{uconn}, nil
return &UnixConn{UnixConn: uconn}, nil
}
@ -52,7 +89,7 @@ func (conn *UnixConn) Send(data []byte, f *os.File) error {
if f != nil {
fds = append(fds, int(f.Fd()))
}
if err := sendUnix(conn.UnixConn, data, fds...); err != nil {
if err := conn.sendUnix(data, fds...); err != nil {
return err
}
@ -76,42 +113,104 @@ func (conn *UnixConn) Receive() (rdata []byte, rf *os.File, rerr error) {
}
debugCheckpoint("===DEBUG=== Receive() -> '%s'[%d]. Hit enter to continue.\n", rdata, fd)
}()
for {
data, fds, err := receiveUnix(conn.UnixConn)
// Read header
header := make([]byte, 4)
nRead := uint32(0)
for nRead < 4 {
n, err := conn.receiveUnix(header[nRead:])
if err != nil {
return nil, nil, err
}
var f *os.File
if len(fds) > 1 {
for _, fd := range fds[1:] {
syscall.Close(fd)
}
}
if len(fds) >= 1 {
f = os.NewFile(uintptr(fds[0]), "")
}
return data, f, nil
nRead = nRead + uint32(n)
}
panic("impossibru")
return nil, nil, nil
length, hasFd := parseHeader(header)
if hasFd {
if len(conn.fds) == 0 {
return nil, nil, fmt.Errorf("No expected file descriptor in message")
}
rf = conn.fds[0]
conn.fds = conn.fds[1:]
}
rdata = make([]byte, length)
nRead = 0
for nRead < length {
n, err := conn.receiveUnix(rdata[nRead:])
if err != nil {
return nil, nil, err
}
nRead = nRead + uint32(n)
}
return
}
func receiveUnix(conn *net.UnixConn) ([]byte, []int, error) {
buf := make([]byte, 4096)
oob := make([]byte, 4096)
func (conn *UnixConn) receiveUnix(buf []byte) (int, error) {
oob := make([]byte, syscall.CmsgSpace(4))
bufn, oobn, _, _, err := conn.ReadMsgUnix(buf, oob)
if err != nil {
return nil, nil, err
return 0, err
}
return buf[:bufn], extractFds(oob[:oobn]), nil
fd := extractFd(oob[:oobn])
if fd != -1 {
f := os.NewFile(uintptr(fd), "")
conn.fds = append(conn.fds, f)
}
return bufn, nil
}
func sendUnix(conn *net.UnixConn, data []byte, fds ...int) error {
_, _, err := conn.WriteMsgUnix(data, syscall.UnixRights(fds...), nil)
return err
func (conn *UnixConn) sendUnix(data []byte, fds ...int) error {
header, err := makeHeader(data, fds)
if err != nil {
return err
}
// There is a bug in conn.WriteMsgUnix where it doesn't correctly return
// the number of bytes writte (http://code.google.com/p/go/issues/detail?id=7645)
// So, we can't rely on the return value from it. However, we must use it to
// send the fds. In order to handle this we only write one byte using WriteMsgUnix
// (when we have to), as that can only ever block or fully suceed. We then write
// the rest with conn.Write()
// The reader side should not rely on this though, as hopefully this gets fixed
// in go later.
written := 0
if len(fds) != 0 {
oob := syscall.UnixRights(fds...)
wrote, _, err := conn.WriteMsgUnix(header[0:1], oob, nil)
if err != nil {
return err
}
written = written + wrote
}
for written < len(header) {
wrote, err := conn.Write(header[written:])
if err != nil {
return err
}
written = written + wrote
}
written = 0
for written < len(data) {
wrote, err := conn.Write(data[written:])
if err != nil {
return err
}
written = written + wrote
}
return nil
}
func extractFds(oob []byte) (fds []int) {
func extractFd(oob []byte) int {
// Grab forklock to make sure no forks accidentally inherit the new
// fds before they are made CLOEXEC
// There is a slight race condition between ReadMsgUnix returns and
@ -122,20 +221,27 @@ func extractFds(oob []byte) (fds []int) {
defer syscall.ForkLock.Unlock()
scms, err := syscall.ParseSocketControlMessage(oob)
if err != nil {
return
return -1
}
foundFd := -1
for _, scm := range scms {
gotFds, err := syscall.ParseUnixRights(&scm)
fds, err := syscall.ParseUnixRights(&scm)
if err != nil {
continue
}
fds = append(fds, gotFds...)
for _, fd := range fds {
syscall.CloseOnExec(fd)
if foundFd == -1 {
syscall.CloseOnExec(fd)
foundFd = fd
} else {
syscall.Close(fd)
}
}
}
return
return foundFd
}
func socketpair() ([2]int, error) {