diff --git a/pkg/beam/beam.go b/pkg/beam/beam.go new file mode 100644 index 0000000000..98fc4b064e --- /dev/null +++ b/pkg/beam/beam.go @@ -0,0 +1,102 @@ +package beam + +import ( + "io" + "os" +) + +type Sender interface { + Send([]byte, *os.File) error +} + +type Receiver interface { + Receive() ([]byte, *os.File, error) +} + +type ReceiveCloser interface { + Receiver + Close() error +} + +type SendCloser interface { + Sender + Close() error +} + +func SendPipe(dst Sender, data []byte) (*os.File, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, err + } + if err := dst.Send(data, r); err != nil { + r.Close() + w.Close() + return nil, err + } + return w, nil +} + +func SendPair(dst Sender, data []byte) (in ReceiveCloser, out SendCloser, err error) { + local, remote, err := SocketPair() + if err != nil { + return nil, nil, err + } + defer func() { + if err != nil { + local.Close() + remote.Close() + } + }() + endpoint, err := FileConn(local) + if err != nil { + return nil, nil, err + } + local.Close() + if err := dst.Send(data, remote); err != nil { + return nil, nil, err + } + return ReceiveCloser(endpoint), SendCloser(endpoint), nil +} + +func ReceivePair(src Receiver) ([]byte, Receiver, Sender, error) { + for { + data, f, err := src.Receive() + if err != nil { + return nil, nil, nil, err + } + if f == nil { + // Skip empty attachments + continue + } + conn, err := FileConn(f) + if err != nil { + // Skip beam attachments which are not connections + // (for example might be a regular file, directory etc) + continue + } + return data, Receiver(conn), Sender(conn), nil + } + panic("impossibru!") + return nil, nil, nil, nil +} + +func Copy(dst Sender, src Receiver) (int, error) { + var n int + for { + payload, attachment, err := src.Receive() + if err == io.EOF { + return n, nil + } else if err != nil { + return n, err + } + if err := dst.Send(payload, attachment); err != nil { + if attachment != nil { + attachment.Close() + } + return n, err + } + n++ + } + panic("impossibru!") + return n, nil +} diff --git a/pkg/beam/examples/beamsh/beamsh.go b/pkg/beam/examples/beamsh/beamsh.go index 076dcef498..ef4354f936 100644 --- a/pkg/beam/examples/beamsh/beamsh.go +++ b/pkg/beam/examples/beamsh/beamsh.go @@ -87,6 +87,7 @@ func main() { func executeRootScript(script []*dockerscript.Command) error { if len(rootPlugins) > 0 { + // If there are root plugins, wrap the script inside them var ( rootCmd *dockerscript.Command lastCmd *dockerscript.Command @@ -108,7 +109,7 @@ func executeRootScript(script []*dockerscript.Command) error { return executeScript(nil, script) } -func executeScript(client *net.UnixConn, script []*dockerscript.Command) error { +func executeScript(out beam.Sender, script []*dockerscript.Command) error { Debugf("executeScript(%s)\n", scriptString(script)) defer Debugf("executeScript(%s) DONE\n", scriptString(script)) var background sync.WaitGroup @@ -116,12 +117,12 @@ func executeScript(client *net.UnixConn, script []*dockerscript.Command) error { for _, cmd := range script { if cmd.Background { background.Add(1) - go func(client *net.UnixConn, cmd *dockerscript.Command) { - executeCommand(client, cmd) + go func(out beam.Sender, cmd *dockerscript.Command) { + executeCommand(out, cmd) background.Done() - }(client, cmd) + }(out, cmd) } else { - if err := executeCommand(client, cmd); err != nil { + if err := executeCommand(out, cmd); err != nil { return err } } @@ -136,8 +137,8 @@ func executeScript(client *net.UnixConn, script []*dockerscript.Command) error { // 4) [in the background] Run the handler // 5) Recursively executeScript() all children commands and wait for them to complete // 6) Wait for handler to return and (shortly afterwards) output copy to complete -// 7) -func executeCommand(client *net.UnixConn, cmd *dockerscript.Command) error { +// 7) Profit +func executeCommand(out beam.Sender, cmd *dockerscript.Command) error { if flX { fmt.Printf("+ %v\n", strings.Replace(strings.TrimRight(cmd.String(), "\n"), "\n", "\n+ ", -1)) } @@ -177,9 +178,9 @@ func executeCommand(client *net.UnixConn, cmd *dockerscript.Command) error { tasks.Done() }() go func() { - if client != nil { + if out != nil { Debugf("[%s] copy start...\n", strings.Join(cmd.Args, " ")) - n, err := beamCopy(client, outPub) + n, err := beam.Copy(out, outPub) if err != nil { Fatal(err) } @@ -199,18 +200,18 @@ func executeCommand(client *net.UnixConn, cmd *dockerscript.Command) error { } -type Handler func([]string, *net.UnixConn, *net.UnixConn) +type Handler func([]string, beam.Receiver, beam.Sender) func GetHandler(name string) Handler { if name == "log" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { var tasks sync.WaitGroup - stdout, err := sendWPipe(out, data.Empty().Set("cmd", "log", "stdout").Set("fromcmd", args...).Bytes()) + stdout, err := beam.SendPipe(out, data.Empty().Set("cmd", "log", "stdout").Set("fromcmd", args...).Bytes()) if err != nil { return } defer stdout.Close() - stderr, err := sendWPipe(out, data.Empty().Set("cmd", "log", "stderr").Set("fromcmd", args...).Bytes()) + stderr, err := beam.SendPipe(out, data.Empty().Set("cmd", "log", "stderr").Set("fromcmd", args...).Bytes()) if err != nil { return } @@ -221,14 +222,14 @@ func GetHandler(name string) Handler { } var n int = 1 for { - payload, attachment, err := beam.Receive(in) + payload, attachment, err := in.Receive() if err != nil { return } if attachment == nil { continue } - w, err := sendWPipe(out, payload) + w, err := beam.SendPipe(out, payload) if err != nil { fmt.Fprintf(stderr, "%v\n", err) attachment.Close() @@ -269,20 +270,20 @@ func GetHandler(name string) Handler { tasks.Wait() } } else if name == "render" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { - stdout, err := sendWPipe(out, data.Empty().Set("cmd", "log", "stdout").Set("fromcmd", args...).Bytes()) + return func(args []string, in beam.Receiver, out beam.Sender) { + stdout, err := beam.SendPipe(out, data.Empty().Set("cmd", "log", "stdout").Set("fromcmd", args...).Bytes()) if err != nil { return } defer stdout.Close() - stderr, err := sendWPipe(out, data.Empty().Set("cmd", "log", "stderr").Set("fromcmd", args...).Bytes()) + stderr, err := beam.SendPipe(out, data.Empty().Set("cmd", "log", "stderr").Set("fromcmd", args...).Bytes()) if err != nil { return } defer stderr.Close() if len(args) != 2 { fmt.Fprintf(stderr, "Usage: %s FORMAT\n", args[0]) - beam.Send(out, data.Empty().Set("status", "1").Bytes(), nil) + out.Send(data.Empty().Set("status", "1").Bytes(), nil) return } txt := args[1] @@ -291,7 +292,7 @@ func GetHandler(name string) Handler { } t := template.Must(template.New("render").Parse(txt)) for { - payload, attachment, err := beam.Receive(in) + payload, attachment, err := in.Receive() if err != nil { return } @@ -301,18 +302,18 @@ func GetHandler(name string) Handler { } if err := t.Execute(stdout, msg); err != nil { fmt.Fprintf(stderr, "rendering error: %v\n", err) - beam.Send(out, data.Empty().Set("status", "1").Bytes(), nil) + out.Send(data.Empty().Set("status", "1").Bytes(), nil) return } - if err := beam.Send(out, payload, attachment); err != nil { + if err := out.Send(payload, attachment); err != nil { return } } } } else if name == "devnull" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { for { - _, attachment, err := beam.Receive(in) + _, attachment, err := in.Receive() if err != nil { return } @@ -322,11 +323,11 @@ func GetHandler(name string) Handler { } } } else if name == "stdio" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { var tasks sync.WaitGroup defer tasks.Wait() for { - payload, attachment, err := beam.Receive(in) + payload, attachment, err := in.Receive() if err != nil { return } @@ -353,8 +354,8 @@ func GetHandler(name string) Handler { } } } else if name == "echo" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { - stdout, err := sendWPipe(out, data.Empty().Set("cmd", "log", "stdout").Bytes()) + return func(args []string, in beam.Receiver, out beam.Sender) { + stdout, err := beam.SendPipe(out, data.Empty().Set("cmd", "log", "stdout").Bytes()) if err != nil { return } @@ -362,13 +363,13 @@ func GetHandler(name string) Handler { stdout.Close() } } else if name == "pass" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { for { - payload, attachment, err := beam.Receive(in) + payload, attachment, err := in.Receive() if err != nil { return } - if err := beam.Send(out, payload, attachment); err != nil { + if err := out.Send(payload, attachment); err != nil { if attachment != nil { attachment.Close() } @@ -377,20 +378,20 @@ func GetHandler(name string) Handler { } } } else if name == "in" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { os.Chdir(args[1]) GetHandler("pass")([]string{"pass"}, in, out) } } else if name == "exec" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { cmd := exec.Command(args[1], args[2:]...) - stdout, err := sendWPipe(out, data.Empty().Set("cmd", "log", "stdout").Set("fromcmd", args...).Bytes()) + stdout, err := beam.SendPipe(out, data.Empty().Set("cmd", "log", "stdout").Set("fromcmd", args...).Bytes()) if err != nil { return } defer stdout.Close() cmd.Stdout = stdout - stderr, err := sendWPipe(out, data.Empty().Set("cmd", "log", "stderr").Set("fromcmd", args...).Bytes()) + stderr, err := beam.SendPipe(out, data.Empty().Set("cmd", "log", "stderr").Set("fromcmd", args...).Bytes()) if err != nil { return } @@ -404,12 +405,12 @@ func GetHandler(name string) Handler { } else { status = "ok" } - beam.Send(out, data.Empty().Set("status", status).Set("cmd", args...).Bytes(), nil) + out.Send(data.Empty().Set("status", status).Set("cmd", args...).Bytes(), nil) } } else if name == "trace" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { for { - p, a, err := beam.Receive(in) + p, a, err := in.Receive() if err != nil { return } @@ -423,17 +424,17 @@ func GetHandler(name string) Handler { msg = fmt.Sprintf("%s [%d]", msg, a.Fd()) } fmt.Printf("===> %s\n", msg) - beam.Send(out, p, a) + out.Send(p, a) } } } else if name == "emit" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { - beam.Send(out, data.Parse(args[1:]).Bytes(), nil) + return func(args []string, in beam.Receiver, out beam.Sender) { + out.Send(data.Parse(args[1:]).Bytes(), nil) } } else if name == "print" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { for { - _, a, err := beam.Receive(in) + _, a, err := in.Receive() if err != nil { return } @@ -443,10 +444,10 @@ func GetHandler(name string) Handler { } } } else if name == "multiprint" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { var tasks sync.WaitGroup for { - payload, a, err := beam.Receive(in) + payload, a, err := in.Receive() if err != nil { return } @@ -465,25 +466,25 @@ func GetHandler(name string) Handler { tasks.Wait() } } else if name == "listen" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { if len(args) != 2 { - beam.Send(out, data.Empty().Set("status", "1").Set("message", "wrong number of arguments").Bytes(), nil) + out.Send(data.Empty().Set("status", "1").Set("message", "wrong number of arguments").Bytes(), nil) return } u, err := url.Parse(args[1]) if err != nil { - beam.Send(out, data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) + out.Send(data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) return } l, err := net.Listen(u.Scheme, u.Host) if err != nil { - beam.Send(out, data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) + out.Send(data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) return } for { conn, err := l.Accept() if err != nil { - beam.Send(out, data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) + out.Send(data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) return } f, err := connToFile(conn) @@ -491,13 +492,13 @@ func GetHandler(name string) Handler { conn.Close() continue } - beam.Send(out, data.Empty().Set("type", "socket").Set("remoteaddr", conn.RemoteAddr().String()).Bytes(), f) + out.Send(data.Empty().Set("type", "socket").Set("remoteaddr", conn.RemoteAddr().String()).Bytes(), f) } } } else if name == "beamsend" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { if len(args) < 2 { - if err := beam.Send(out, data.Empty().Set("status", "1").Set("message", "wrong number of arguments").Bytes(), nil); err != nil { + if err := out.Send(data.Empty().Set("status", "1").Set("message", "wrong number of arguments").Bytes(), nil); err != nil { Fatal(err) } return @@ -506,16 +507,16 @@ func GetHandler(name string) Handler { connector = dialer connections, err := connector(args[1]) if err != nil { - beam.Send(out, data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) + out.Send(data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) return } // Copy in to conn SendToConn(connections, in) } } else if name == "beamreceive" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { if len(args) != 2 { - if err := beam.Send(out, data.Empty().Set("status", "1").Set("message", "wrong number of arguments").Bytes(), nil); err != nil { + if err := out.Send(data.Empty().Set("status", "1").Set("message", "wrong number of arguments").Bytes(), nil); err != nil { Fatal(err) } return @@ -524,26 +525,26 @@ func GetHandler(name string) Handler { connector = listener connections, err := connector(args[1]) if err != nil { - beam.Send(out, data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) + out.Send(data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) return } // Copy in to conn ReceiveFromConn(connections, out) } } else if name == "connect" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { if len(args) != 2 { - beam.Send(out, data.Empty().Set("status", "1").Set("message", "wrong number of arguments").Bytes(), nil) + out.Send(data.Empty().Set("status", "1").Set("message", "wrong number of arguments").Bytes(), nil) return } u, err := url.Parse(args[1]) if err != nil { - beam.Send(out, data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) + out.Send(data.Empty().Set("status", "1").Set("message", err.Error()).Bytes(), nil) return } var tasks sync.WaitGroup for { - _, attachment, err := beam.Receive(in) + _, attachment, err := in.Receive() if err != nil { break } @@ -553,10 +554,10 @@ func GetHandler(name string) Handler { Logf("connecting to %s/%s\n", u.Scheme, u.Host) conn, err := net.Dial(u.Scheme, u.Host) if err != nil { - beam.Send(out, data.Empty().Set("cmd", "msg", "connect error: " + err.Error()).Bytes(), nil) + out.Send(data.Empty().Set("cmd", "msg", "connect error: " + err.Error()).Bytes(), nil) return } - beam.Send(out, data.Empty().Set("cmd", "msg", "connection established").Bytes(), nil) + out.Send(data.Empty().Set("cmd", "msg", "connection established").Bytes(), nil) tasks.Add(1) go func(attachment *os.File, conn net.Conn) { defer tasks.Done() @@ -580,13 +581,13 @@ func GetHandler(name string) Handler { tasks.Wait() } } else if name == "openfile" { - return func(args []string, in *net.UnixConn, out *net.UnixConn) { + return func(args []string, in beam.Receiver, out beam.Sender) { for _, name := range args { f, err := os.Open(name) if err != nil { continue } - if err := beam.Send(out, data.Empty().Set("path", name).Set("type", "file").Bytes(), f); err != nil { + if err := out.Send(data.Empty().Set("path", name).Set("type", "file").Bytes(), f); err != nil { f.Close() } } @@ -652,39 +653,6 @@ func scriptString(script []*dockerscript.Command) string { return fmt.Sprintf("'%s'", strings.Join(lines, "; ")) } -func beamCopy(dst *net.UnixConn, src *net.UnixConn) (int, error) { - var n int - for { - payload, attachment, err := beam.Receive(src) - if err == io.EOF { - return n, nil - } else if err != nil { - return n, err - } - if err := beam.Send(dst, payload, attachment); err != nil { - if attachment != nil { - attachment.Close() - } - return n, err - } - n++ - } - panic("impossibru!") - return n, nil -} - -func sendWPipe(conn *net.UnixConn, payload []byte) (*os.File, error) { - r, w, err := os.Pipe() - if err != nil { - return nil, err - } - if err := beam.Send(conn, payload, r); err != nil { - r.Close() - w.Close() - return nil, err - } - return w, nil -} func dialer(addr string) (chan net.Conn, error) { u, err := url.Parse(addr) @@ -731,11 +699,11 @@ func listener(addr string) (chan net.Conn, error) { -func SendToConn(connections chan net.Conn, src *net.UnixConn) error { +func SendToConn(connections chan net.Conn, src beam.Receiver) error { var tasks sync.WaitGroup defer tasks.Wait() for { - payload, attachment, err := beam.Receive(src) + payload, attachment, err := src.Receive() if err == io.EOF { return nil } else if err != nil { @@ -787,7 +755,7 @@ func msgDesc(payload []byte, attachment *os.File) string { } -func ReceiveFromConn(connections chan net.Conn, dst *net.UnixConn) error { +func ReceiveFromConn(connections chan net.Conn, dst beam.Sender) error { for conn := range connections { err := func () error { Logf("parsing message from network...\n") @@ -825,7 +793,7 @@ func ReceiveFromConn(connections chan net.Conn, dst *net.UnixConn) error { } bicopy(conn, f) }(buf[skip:n], conn, pub) - if err := beam.Send(dst, []byte(header), priv); err != nil { + if err := dst.Send([]byte(header), priv); err != nil { return err } return nil diff --git a/pkg/beam/service.go b/pkg/beam/service.go index 9a21e353df..576f1d786d 100644 --- a/pkg/beam/service.go +++ b/pkg/beam/service.go @@ -16,18 +16,18 @@ import ( // Note that if the underlying file descriptor received in attachment is nil or does // not point to a connection, that message will be skipped. // -func Listen(conn *net.UnixConn, name string) (net.Listener, error) { - endpoint, err := SendPipe(conn, []byte(name)) +func Listen(conn Sender, name string) (net.Listener, error) { + in, _, err := SendPair(conn, []byte(name)) if err != nil { return nil, err } return &listener{ name: name, - endpoint: endpoint, + endpoint: in, }, nil } -func Connect(ctx *net.UnixConn, name string) (net.Conn, error) { +func Connect(ctx *UnixConn, name string) (net.Conn, error) { l, err := Listen(ctx, name) if err != nil { return nil, err @@ -41,12 +41,12 @@ func Connect(ctx *net.UnixConn, name string) (net.Conn, error) { type listener struct { name string - endpoint *net.UnixConn + endpoint ReceiveCloser } func (l *listener) Accept() (net.Conn, error) { for { - _, f, err := Receive(l.endpoint) + _, f, err := l.endpoint.Receive() if err != nil { return nil, err } diff --git a/pkg/beam/unix.go b/pkg/beam/unix.go index 25767bbb4f..94d7b5b4fc 100644 --- a/pkg/beam/unix.go +++ b/pkg/beam/unix.go @@ -19,9 +19,27 @@ func debugCheckpoint(msg string, args ...interface{}) { tty.Close() } +type UnixConn struct { + *net.UnixConn +} + +func FileConn(f *os.File) (*UnixConn, error) { + conn, err := net.FileConn(f) + if err != nil { + return nil, err + } + uconn, ok := conn.(*net.UnixConn) + if !ok { + conn.Close() + return nil, fmt.Errorf("%d: not a unix connection", f.Fd()) + } + return &UnixConn{uconn}, nil + +} + // Send sends a new message on conn with data and f as payload and // attachment, respectively. -func Send(conn *net.UnixConn, data []byte, f *os.File) error { +func (conn *UnixConn) Send(data []byte, f *os.File) error { { var fd int = -1 if f != nil { @@ -33,7 +51,7 @@ func Send(conn *net.UnixConn, data []byte, f *os.File) error { if f != nil { fds = append(fds, int(f.Fd())) } - return sendUnix(conn, data, fds...) + return sendUnix(conn.UnixConn, data, fds...) } // Receive waits for a new message on conn, and receives its payload @@ -42,7 +60,7 @@ func Send(conn *net.UnixConn, data []byte, f *os.File) error { // If more than 1 file descriptor is sent in the message, they are all // closed except for the first, which is the attachment. // It is legal for a message to have no attachment or an empty payload. -func Receive(conn *net.UnixConn) (rdata []byte, rf *os.File, rerr error) { +func (conn *UnixConn) Receive() (rdata []byte, rf *os.File, rerr error) { defer func() { var fd int = -1 if rf != nil { @@ -51,7 +69,7 @@ func Receive(conn *net.UnixConn) (rdata []byte, rf *os.File, rerr error) { debugCheckpoint("===DEBUG=== Receive() -> '%s'[%d]. Hit enter to continue.\n", rdata, fd) }() for { - data, fds, err := receiveUnix(conn) + data, fds, err := receiveUnix(conn.UnixConn) if err != nil { return nil, nil, err } @@ -70,63 +88,6 @@ func Receive(conn *net.UnixConn) (rdata []byte, rf *os.File, rerr error) { return nil, nil, nil } -// SendPipe creates a new unix socket pair, sends one end as the attachment -// to a beam message with the payload `data`, and returns the other end. -// -// This is a common pattern to open a new service endpoint. -// For example, a service wishing to advertise its presence to clients might -// open an endpoint with: -// -// endpoint, _ := SendPipe(conn, []byte("sql")) -// defer endpoint.Close() -// for { -// conn, _ := endpoint.Receive() -// go func() { -// Handle(conn) -// conn.Close() -// }() -// } -// -// Note that beam does not distinguish between clients and servers in the logical -// sense: any program wishing to establishing a communication with another program -// may use SendPipe() to create an endpoint. -// For example, here is how an application might use it to connect to a database client. -// -// endpoint, _ := SendPipe(conn, []byte("userdb")) -// defer endpoint.Close() -// conn, _ := endpoint.Receive() -// defer conn.Close() -// db := NewDBClient(conn) -// -// In this example note that we only need the first connection out of the endpoint, -// but we could open new ones to retry after a broken connection. -// Note that, because the underlying service transport is abstracted away, this -// allows for arbitrarily complex service discovery and retry logic to take place, -// without complicating application code. -// -func SendPipe(conn *net.UnixConn, data []byte) (endpoint *net.UnixConn, err error) { - debugCheckpoint("===DEBUG=== SendPipe('%s'). Hit enter to confirm: ", data) - local, remote, err := SocketPair() - if err != nil { - return nil, err - } - defer func() { - if err != nil { - local.Close() - remote.Close() - } - }() - endpoint, err = FdConn(int(local.Fd())) - if err != nil { - return nil, err - } - local.Close() - if err := Send(conn, data, remote); err != nil { - return nil, err - } - return endpoint, nil -} - func receiveUnix(conn *net.UnixConn) ([]byte, []int, error) { buf := make([]byte, 4096) oob := make([]byte, 4096) @@ -204,7 +165,7 @@ func SocketPair() (a *os.File, b *os.File, err error) { return os.NewFile(uintptr(pair[0]), ""), os.NewFile(uintptr(pair[1]), ""), nil } -func USocketPair() (*net.UnixConn, *net.UnixConn, error) { +func USocketPair() (*UnixConn, *UnixConn, error) { debugCheckpoint("===DEBUG=== USocketPair(). Hit enter to confirm: ") defer debugCheckpoint ("===DEBUG=== USocketPair() returned. Hit enter to confirm ") a, b, err := SocketPair() @@ -213,11 +174,11 @@ func USocketPair() (*net.UnixConn, *net.UnixConn, error) { } defer a.Close() defer b.Close() - uA, err := FdConn(int(a.Fd())) + uA, err := FileConn(a) if err != nil { return nil, nil, err } - uB, err := FdConn(int(b.Fd())) + uB, err := FileConn(b) if err != nil { uA.Close() return nil, nil, err diff --git a/pkg/beam/unix_test.go b/pkg/beam/unix_test.go index bdd03b1a07..09815aa0d6 100644 --- a/pkg/beam/unix_test.go +++ b/pkg/beam/unix_test.go @@ -45,25 +45,25 @@ func TestSendUnixSocket(t *testing.T) { // defer glueA.Close() // defer glueB.Close() go func() { - err := Send(b2, []byte("a"), glueB) + err := b2.Send([]byte("a"), glueB) if err != nil { t.Fatal(err) } }() go func() { - err := Send(a2, []byte("b"), glueA) + err := a2.Send([]byte("b"), glueA) if err != nil { t.Fatal(err) } }() - connAhdr, connA, err := Receive(a1) + connAhdr, connA, err := a1.Receive() if err != nil { t.Fatal(err) } if string(connAhdr) != "b" { t.Fatalf("unexpected: %s", connAhdr) } - connBhdr, connB, err := Receive(b1) + connBhdr, connB, err := b1.Receive() if err != nil { t.Fatal(err) }